~ntt-pf-lab/nova/monkey_patch_notification

« back to all changes in this revision

Viewing changes to vendor/tornado/tornado/database.py

  • Committer: Jesse Andrews
  • Date: 2010-05-28 06:05:26 UTC
  • Revision ID: git-v1:bf6e6e718cdc7488e2da87b21e258ccc065fe499
initial commit

Show diffs side-by-side

added added

removed removed

Lines of Context:
 
1
#!/usr/bin/env python
 
2
#
 
3
# Copyright 2009 Facebook
 
4
#
 
5
# Licensed under the Apache License, Version 2.0 (the "License"); you may
 
6
# not use this file except in compliance with the License. You may obtain
 
7
# a copy of the License at
 
8
#
 
9
#     http://www.apache.org/licenses/LICENSE-2.0
 
10
#
 
11
# Unless required by applicable law or agreed to in writing, software
 
12
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 
13
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 
14
# License for the specific language governing permissions and limitations
 
15
# under the License.
 
16
 
 
17
"""A lightweight wrapper around MySQLdb."""
 
18
 
 
19
import copy
 
20
import MySQLdb
 
21
import MySQLdb.constants
 
22
import MySQLdb.converters
 
23
import MySQLdb.cursors
 
24
import itertools
 
25
import logging
 
26
 
 
27
_log = logging.getLogger('tornado.database')
 
28
 
 
29
class Connection(object):
 
30
    """A lightweight wrapper around MySQLdb DB-API connections.
 
31
 
 
32
    The main value we provide is wrapping rows in a dict/object so that
 
33
    columns can be accessed by name. Typical usage:
 
34
 
 
35
        db = database.Connection("localhost", "mydatabase")
 
36
        for article in db.query("SELECT * FROM articles"):
 
37
            print article.title
 
38
 
 
39
    Cursors are hidden by the implementation, but other than that, the methods
 
40
    are very similar to the DB-API.
 
41
 
 
42
    We explicitly set the timezone to UTC and the character encoding to
 
43
    UTF-8 on all connections to avoid time zone and encoding errors.
 
44
    """
 
45
    def __init__(self, host, database, user=None, password=None):
 
46
        self.host = host
 
47
        self.database = database
 
48
 
 
49
        args = dict(conv=CONVERSIONS, use_unicode=True, charset="utf8",
 
50
                    db=database, init_command='SET time_zone = "+0:00"',
 
51
                    sql_mode="TRADITIONAL")
 
52
        if user is not None:
 
53
            args["user"] = user
 
54
        if password is not None:
 
55
            args["passwd"] = password
 
56
 
 
57
        # We accept a path to a MySQL socket file or a host(:port) string
 
58
        if "/" in host:
 
59
            args["unix_socket"] = host
 
60
        else:
 
61
            self.socket = None
 
62
            pair = host.split(":")
 
63
            if len(pair) == 2:
 
64
                args["host"] = pair[0]
 
65
                args["port"] = int(pair[1])
 
66
            else:
 
67
                args["host"] = host
 
68
                args["port"] = 3306
 
69
 
 
70
        self._db = None
 
71
        self._db_args = args
 
72
        try:
 
73
            self.reconnect()
 
74
        except:
 
75
            _log.error("Cannot connect to MySQL on %s", self.host,
 
76
                          exc_info=True)
 
77
 
 
78
    def __del__(self):
 
79
        self.close()
 
80
 
 
81
    def close(self):
 
82
        """Closes this database connection."""
 
83
        if getattr(self, "_db", None) is not None:
 
84
            self._db.close()
 
85
            self._db = None
 
86
 
 
87
    def reconnect(self):
 
88
        """Closes the existing database connection and re-opens it."""
 
89
        self.close()
 
90
        self._db = MySQLdb.connect(**self._db_args)
 
91
        self._db.autocommit(True)
 
92
 
 
93
    def iter(self, query, *parameters):
 
94
        """Returns an iterator for the given query and parameters."""
 
95
        if self._db is None: self.reconnect()
 
96
        cursor = MySQLdb.cursors.SSCursor(self._db)
 
97
        try:
 
98
            self._execute(cursor, query, parameters)
 
99
            column_names = [d[0] for d in cursor.description]
 
100
            for row in cursor:
 
101
                yield Row(zip(column_names, row))
 
102
        finally:
 
103
            cursor.close()
 
104
 
 
105
    def query(self, query, *parameters):
 
106
        """Returns a row list for the given query and parameters."""
 
107
        cursor = self._cursor()
 
108
        try:
 
109
            self._execute(cursor, query, parameters)
 
110
            column_names = [d[0] for d in cursor.description]
 
111
            return [Row(itertools.izip(column_names, row)) for row in cursor]
 
112
        finally:
 
113
            cursor.close()
 
114
 
 
115
    def get(self, query, *parameters):
 
116
        """Returns the first row returned for the given query."""
 
117
        rows = self.query(query, *parameters)
 
118
        if not rows:
 
119
            return None
 
120
        elif len(rows) > 1:
 
121
            raise Exception("Multiple rows returned for Database.get() query")
 
122
        else:
 
123
            return rows[0]
 
124
 
 
125
    def execute(self, query, *parameters):
 
126
        """Executes the given query, returning the lastrowid from the query."""
 
127
        cursor = self._cursor()
 
128
        try:
 
129
            self._execute(cursor, query, parameters)
 
130
            return cursor.lastrowid
 
131
        finally:
 
132
            cursor.close()
 
133
 
 
134
    def executemany(self, query, parameters):
 
135
        """Executes the given query against all the given param sequences.
 
136
 
 
137
        We return the lastrowid from the query.
 
138
        """
 
139
        cursor = self._cursor()
 
140
        try:
 
141
            cursor.executemany(query, parameters)
 
142
            return cursor.lastrowid
 
143
        finally:
 
144
            cursor.close()
 
145
 
 
146
    def _cursor(self):
 
147
        if self._db is None: self.reconnect()
 
148
        return self._db.cursor()
 
149
 
 
150
    def _execute(self, cursor, query, parameters):
 
151
        try:
 
152
            return cursor.execute(query, parameters)
 
153
        except OperationalError:
 
154
            _log.error("Error connecting to MySQL on %s", self.host)
 
155
            self.close()
 
156
            raise
 
157
 
 
158
 
 
159
class Row(dict):
 
160
    """A dict that allows for object-like property access syntax."""
 
161
    def __getattr__(self, name):
 
162
        try:
 
163
            return self[name]
 
164
        except KeyError:
 
165
            raise AttributeError(name)
 
166
 
 
167
 
 
168
# Fix the access conversions to properly recognize unicode/binary
 
169
FIELD_TYPE = MySQLdb.constants.FIELD_TYPE
 
170
FLAG = MySQLdb.constants.FLAG
 
171
CONVERSIONS = copy.deepcopy(MySQLdb.converters.conversions)
 
172
for field_type in \
 
173
        [FIELD_TYPE.BLOB, FIELD_TYPE.STRING, FIELD_TYPE.VAR_STRING] + \
 
174
        ([FIELD_TYPE.VARCHAR] if 'VARCHAR' in vars(FIELD_TYPE) else []):
 
175
    CONVERSIONS[field_type].insert(0, (FLAG.BINARY, str))
 
176
 
 
177
 
 
178
# Alias some common MySQL exceptions
 
179
IntegrityError = MySQLdb.IntegrityError
 
180
OperationalError = MySQLdb.OperationalError