3
# Copyright 2009 Facebook
5
# Licensed under the Apache License, Version 2.0 (the "License"); you may
6
# not use this file except in compliance with the License. You may obtain
7
# a copy of the License at
9
# http://www.apache.org/licenses/LICENSE-2.0
11
# Unless required by applicable law or agreed to in writing, software
12
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
13
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
14
# License for the specific language governing permissions and limitations
17
"""A lightweight wrapper around MySQLdb."""
21
import MySQLdb.constants
22
import MySQLdb.converters
23
import MySQLdb.cursors
27
_log = logging.getLogger('tornado.database')
29
class Connection(object):
30
"""A lightweight wrapper around MySQLdb DB-API connections.
32
The main value we provide is wrapping rows in a dict/object so that
33
columns can be accessed by name. Typical usage:
35
db = database.Connection("localhost", "mydatabase")
36
for article in db.query("SELECT * FROM articles"):
39
Cursors are hidden by the implementation, but other than that, the methods
40
are very similar to the DB-API.
42
We explicitly set the timezone to UTC and the character encoding to
43
UTF-8 on all connections to avoid time zone and encoding errors.
45
def __init__(self, host, database, user=None, password=None):
47
self.database = database
49
args = dict(conv=CONVERSIONS, use_unicode=True, charset="utf8",
50
db=database, init_command='SET time_zone = "+0:00"',
51
sql_mode="TRADITIONAL")
54
if password is not None:
55
args["passwd"] = password
57
# We accept a path to a MySQL socket file or a host(:port) string
59
args["unix_socket"] = host
62
pair = host.split(":")
64
args["host"] = pair[0]
65
args["port"] = int(pair[1])
75
_log.error("Cannot connect to MySQL on %s", self.host,
82
"""Closes this database connection."""
83
if getattr(self, "_db", None) is not None:
88
"""Closes the existing database connection and re-opens it."""
90
self._db = MySQLdb.connect(**self._db_args)
91
self._db.autocommit(True)
93
def iter(self, query, *parameters):
94
"""Returns an iterator for the given query and parameters."""
95
if self._db is None: self.reconnect()
96
cursor = MySQLdb.cursors.SSCursor(self._db)
98
self._execute(cursor, query, parameters)
99
column_names = [d[0] for d in cursor.description]
101
yield Row(zip(column_names, row))
105
def query(self, query, *parameters):
106
"""Returns a row list for the given query and parameters."""
107
cursor = self._cursor()
109
self._execute(cursor, query, parameters)
110
column_names = [d[0] for d in cursor.description]
111
return [Row(itertools.izip(column_names, row)) for row in cursor]
115
def get(self, query, *parameters):
116
"""Returns the first row returned for the given query."""
117
rows = self.query(query, *parameters)
121
raise Exception("Multiple rows returned for Database.get() query")
125
def execute(self, query, *parameters):
126
"""Executes the given query, returning the lastrowid from the query."""
127
cursor = self._cursor()
129
self._execute(cursor, query, parameters)
130
return cursor.lastrowid
134
def executemany(self, query, parameters):
135
"""Executes the given query against all the given param sequences.
137
We return the lastrowid from the query.
139
cursor = self._cursor()
141
cursor.executemany(query, parameters)
142
return cursor.lastrowid
147
if self._db is None: self.reconnect()
148
return self._db.cursor()
150
def _execute(self, cursor, query, parameters):
152
return cursor.execute(query, parameters)
153
except OperationalError:
154
_log.error("Error connecting to MySQL on %s", self.host)
160
"""A dict that allows for object-like property access syntax."""
161
def __getattr__(self, name):
165
raise AttributeError(name)
168
# Fix the access conversions to properly recognize unicode/binary
169
FIELD_TYPE = MySQLdb.constants.FIELD_TYPE
170
FLAG = MySQLdb.constants.FLAG
171
CONVERSIONS = copy.deepcopy(MySQLdb.converters.conversions)
173
[FIELD_TYPE.BLOB, FIELD_TYPE.STRING, FIELD_TYPE.VAR_STRING] + \
174
([FIELD_TYPE.VARCHAR] if 'VARCHAR' in vars(FIELD_TYPE) else []):
175
CONVERSIONS[field_type].insert(0, (FLAG.BINARY, str))
178
# Alias some common MySQL exceptions
179
IntegrityError = MySQLdb.IntegrityError
180
OperationalError = MySQLdb.OperationalError