~storm/storm/trunk

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
#
# Copyright (c) 2006, 2007 Canonical
#
# Written by Gustavo Niemeyer <gustavo@niemeyer.net>
#
# This file is part of Storm Object Relational Mapper.
#
# Storm is free software; you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License as
# published by the Free Software Foundation; either version 2.1 of
# the License, or (at your option) any later version.
#
# Storm is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#

"""Basic database interfacing mechanisms for Storm.

This is the common code for database support; specific databases are
supported in modules in L{storm.databases}.
"""

from collections.abc import Callable
from functools import wraps

from storm.expr import Expr, State, compile
# Circular import: imported at the end of the module.
# from storm.tracer import trace
from storm.variables import Variable
from storm.xid import Xid
from storm.exceptions import (
    ClosedError, ConnectionBlockedError, DatabaseError, DisconnectionError,
    Error, ProgrammingError, wrap_exceptions)
from storm.uri import URI
import storm


# Public API of this module: the names exported by a star-import.
# Other classes here (CursorWrapper, ConnectionWrapper) and the STATE_*
# constants are importable explicitly but are not part of this list.
__all__ = ["Database", "Connection", "Result",
           "convert_param_marks", "create_database", "register_scheme"]


# Lifecycle states stored in Connection._state:
#   STATE_CONNECTED    - the raw connection is in use as normal.
#   STATE_DISCONNECTED - the connection was lost; _ensure_connected() raises
#                        DisconnectionError until a rollback happens.
#   STATE_RECONNECT    - a rollback happened; the next _ensure_connected()
#                        will call raw_connect() to open a new connection.
STATE_CONNECTED, STATE_DISCONNECTED, STATE_RECONNECT = range(1, 4)


class Result(object):
    """A representation of the results from a single SQL statement."""

    _closed = False

    def __init__(self, connection, raw_cursor):
        """
        @param connection: The L{Connection} that produced C{raw_cursor};
            its C{_check_disconnect} is used around every fetch.
        @param raw_cursor: The DB-API cursor holding the results.
        """
        self._connection = connection # Ensures deallocation order.
        self._raw_cursor = raw_cursor
        if raw_cursor.arraysize == 1:
            # Default of 1 is silly: it makes fetchmany() (used by
            # __iter__) fetch a single row per round-trip.
            self._raw_cursor.arraysize = 10

    def __del__(self):
        """Close the cursor."""
        # A finalizer has nobody to report errors to, so failures are
        # deliberately swallowed (bare except kept on purpose so nothing
        # escapes the finalizer).
        try:
            self.close()
        except:
            pass

    def close(self):
        """Close the underlying raw cursor, if it hasn't already been closed.

        The reference to the raw cursor is released even if closing it
        raises; since C{_closed} is already set at that point, a failed
        close is never retried, so there is no reason to keep it alive.
        """
        if not self._closed:
            self._closed = True
            try:
                self._raw_cursor.close()
            finally:
                self._raw_cursor = None

    def get_one(self):
        """Fetch one result from the cursor.

        The result will be converted to an appropriate format via
        L{from_database}.

        @raise DisconnectionError: Raised when the connection is lost.
            Reconnection happens automatically on rollback.

        @return: A converted row or None, if no data is left.
        """
        row = self._connection._check_disconnect(self._raw_cursor.fetchone)
        if row is not None:
            return tuple(self.from_database(row))
        return None

    def get_all(self):
        """Fetch all results from the cursor.

        The results will be converted to an appropriate format via
        L{from_database}.

        @raise DisconnectionError: Raised when the connection is lost.
            Reconnection happens automatically on rollback.

        @return: A list of converted row tuples.  If the driver returns an
            empty (falsy) result, that value is returned unchanged.
        """
        result = self._connection._check_disconnect(self._raw_cursor.fetchall)
        if result:
            return [tuple(self.from_database(row)) for row in result]
        return result

    def __iter__(self):
        """Yield all results, one at a time.

        Rows are fetched in batches of the cursor's C{arraysize} via
        C{fetchmany}, and each one is converted via L{from_database}.

        @raise DisconnectionError: Raised when the connection is lost.
            Reconnection happens automatically on rollback.
        """
        while True:
            results = self._connection._check_disconnect(
                self._raw_cursor.fetchmany)
            if not results:
                break
            for result in results:
                yield tuple(self.from_database(result))

    @property
    def rowcount(self):
        """
        See PEP 249 for further details on rowcount.

        @return: the number of affected rows, or None if the database
            backend does not provide this information. Return value
            is undefined if all results have not yet been retrieved.
        """
        # PEP 249 uses -1 for "not available"; expose that as None.
        if self._raw_cursor.rowcount == -1:
            return None
        return self._raw_cursor.rowcount

    def get_insert_identity(self, primary_columns, primary_variables):
        """Get a query which will return the row that was just inserted.

        This must be overridden in database-specific subclasses.

        @rtype: L{storm.expr.Expr}
        """
        raise NotImplementedError

    @staticmethod
    def set_variable(variable, value):
        """Set the given variable's value from the database."""
        variable.set(value, from_db=True)

    @staticmethod
    def from_database(row):
        """Convert a row fetched from the database to an agnostic format.

        This method is intended to be overridden in subclasses, but
        not called externally.

        If there are any peculiarities in the datatypes returned from
        a database backend, this method should be overridden in the
        backend subclass to convert them.
        """
        return row


class CursorWrapper(object):
    """Proxy for a DB-API cursor that re-raises driver errors as Storm ones.

    Attribute reads and writes are forwarded to the wrapped cursor.  Any
    callable attribute is handed back wrapped, so that exceptions raised
    while calling it pass through C{wrap_exceptions} for this database.
    """

    def __init__(self, cursor, database):
        # Our own __setattr__ forwards everything to the cursor, so store
        # the wrapper's state by going around it.
        set_own = super().__setattr__
        set_own('_cursor', cursor)
        set_own('_database', database)

    def __getattr__(self, name):
        # Only reached for names not found on the wrapper itself.
        target = getattr(self._cursor, name)
        if not isinstance(target, Callable):
            return target

        @wraps(target)
        def call_translating_errors(*args, **kwargs):
            with wrap_exceptions(self._database):
                return target(*args, **kwargs)

        return call_translating_errors

    def __setattr__(self, name, value):
        # Writes (e.g. arraysize) land on the underlying cursor.
        return setattr(self._cursor, name, value)

    def __iter__(self):
        # An explicit loop (not "yield from") so that closing this
        # generator never forwards close() to the raw cursor.
        with wrap_exceptions(self._database):
            for row in self._cursor:
                yield row

    def __enter__(self):
        return self

    def __exit__(self, type_, value, tb):
        # Always close the cursor on leaving the with block; any exception
        # from the block itself is not suppressed.
        with wrap_exceptions(self._database):
            self.close()


class ConnectionWrapper(object):
    """Proxy for a DB-API connection that re-raises driver errors as Storm
    ones.

    Attribute reads and writes are forwarded to the wrapped connection, and
    callable attributes are returned wrapped in C{wrap_exceptions}.  Used as
    a context manager it commits on a clean exit and rolls back otherwise.
    """

    def __init__(self, connection, database):
        # __setattr__ is redirected to the connection, so populate the
        # instance dictionary directly.
        own_state = self.__dict__
        own_state['_connection'] = connection
        own_state['_database'] = database

    def __getattr__(self, name):
        # Only reached for names not found on the wrapper itself.
        target = getattr(self._connection, name)
        if not isinstance(target, Callable):
            return target

        @wraps(target)
        def call_translating_errors(*args, **kwargs):
            with wrap_exceptions(self._database):
                return target(*args, **kwargs)

        return call_translating_errors

    def __setattr__(self, name, value):
        return setattr(self._connection, name, value)

    def __enter__(self):
        return self

    def __exit__(self, type_, value, tb):
        # Commit only if the with block finished without an exception;
        # otherwise roll back.  Any exception still propagates.
        clean_exit = type_ is None and value is None and tb is None
        with wrap_exceptions(self._database):
            if not clean_exit:
                self.rollback()
            else:
                self.commit()

    def cursor(self):
        """Return a L{CursorWrapper} around a new raw cursor."""
        with wrap_exceptions(self._database):
            return CursorWrapper(self._connection.cursor(), self._database)


class Connection(object):
    """A connection to a database.

    @cvar result_factory: A callable which takes this L{Connection}
        and the backend cursor and returns an instance of L{Result}.
    @type param_mark: C{str}
    @cvar param_mark: The dbapi paramstyle that the database backend expects.
    @type compile: L{storm.expr.Compile}
    @cvar compile: The compiler to use for connections of this type.
    """

    result_factory = Result
    param_mark = "?"
    compile = compile

    _blocked = False  # Set by block_access(); checked before executing.
    _closed = False   # Set once close() has run.
    _two_phase_transaction = False  # If True, a two-phase transaction has
                                    # been started with begin()
    _state = STATE_CONNECTED  # One of the module's STATE_* constants.

    def __init__(self, database, event=None):
        """
        @param database: The L{Database} to connect to; its C{raw_connect}
            is called immediately to open the raw connection.
        @param event: Optional event system; when given, L{execute} emits
            C{"register-transaction"} on it.
        """
        self._database = database # Ensures deallocation order.
        self._event = event
        self._raw_connection = self._database.raw_connect()

    def __del__(self):
        """Close the connection."""
        # A finalizer has nobody to report errors to, so failures are
        # deliberately swallowed (bare except kept on purpose so nothing
        # escapes the finalizer).
        try:
            self.close()
        except:
            pass

    def block_access(self):
        """Block access to the connection.

        Attempts to execute statements or commit a transaction will
        result in a C{ConnectionBlockedError} exception.  Rollbacks
        are permitted as that operation is often used in case of
        failures.
        """
        self._blocked = True

    def unblock_access(self):
        """Unblock access to the connection."""
        self._blocked = False

    def execute(self, statement, params=None, noresult=False):
        """Execute a statement with the given parameters.

        @type statement: L{Expr} or C{str}
        @param statement: The statement to execute. It will be
            compiled if necessary.
        @param params: Parameters for a string statement; must be None
            when C{statement} is an L{Expr}.
        @param noresult: If True, no result will be returned.

        @raise ClosedError: Raised if the connection has been closed.
        @raise ConnectionBlockedError: Raised if access to the connection
            has been blocked with L{block_access}.
        @raise DisconnectionError: Raised when the connection is lost.
            Reconnection happens automatically on rollback.
        @raise ValueError: Raised if C{params} is given with an L{Expr}.

        @return: The result of C{self.result_factory}, or None if
            C{noresult} is True.
        """
        if self._closed:
            raise ClosedError("Connection is closed")
        if self._blocked:
            raise ConnectionBlockedError("Access to connection is blocked")
        if self._event:
            self._event.emit("register-transaction")
        self._ensure_connected()
        if isinstance(statement, Expr):
            if params is not None:
                raise ValueError("Can't pass parameters with expressions")
            state = State()
            statement = self.compile(statement, state)
            params = state.parameters
        # Statements are compiled with "?" marks; translate to the
        # backend's paramstyle outside of single-quoted literals.
        statement = convert_param_marks(statement, "?", self.param_mark)
        raw_cursor = self.raw_execute(statement, params)
        if noresult:
            self._check_disconnect(raw_cursor.close)
            return None
        return self.result_factory(self, raw_cursor)

    def close(self):
        """Close the connection if it is not already closed.

        The raw connection reference is released even if closing it
        raises; since C{_closed} is already set, a failed close is never
        retried, so there is no reason to keep the object alive.
        """
        if not self._closed:
            self._closed = True
            raw_connection = self._raw_connection
            if raw_connection is not None:
                try:
                    raw_connection.close()
                finally:
                    self._raw_connection = None

    def begin(self, xid):
        """Begin a two-phase transaction."""
        if self._two_phase_transaction:
            raise ProgrammingError("begin cannot be used inside a transaction")
        self._ensure_connected()
        raw_xid = self._raw_xid(xid)
        self._check_disconnect(self._raw_connection.tpc_begin, raw_xid)
        self._two_phase_transaction = True

    def prepare(self):
        """Run the prepare phase of a two-phase transaction."""
        if not self._two_phase_transaction:
            raise ProgrammingError("prepare must be called inside a two-phase "
                                   "transaction")
        self._check_disconnect(self._raw_connection.tpc_prepare)

    def commit(self, xid=None):
        """Commit the connection.

        @param xid: Optionally the L{Xid} of a previously prepared
             transaction to commit. This form should be called outside
             of a transaction, and is intended for use in recovery.

        @raise ConnectionBlockedError: Raised if access to the connection
            has been blocked with L{block_access}.
        @raise DisconnectionError: Raised when the connection is lost.
            Reconnection happens automatically on rollback.

        """
        try:
            self._ensure_connected()
            if xid:
                raw_xid = self._raw_xid(xid)
                self._check_disconnect(self._raw_connection.tpc_commit, raw_xid)
            elif self._two_phase_transaction:
                self._check_disconnect(self._raw_connection.tpc_commit)
                self._two_phase_transaction = False
            else:
                self._check_disconnect(self._raw_connection.commit)
        finally:
            # The tracer hook runs whether or not the commit succeeded.
            self._check_disconnect(trace, "connection_commit", self, xid)

    def recover(self):
        """Return a list of L{Xid}\\ s representing pending transactions."""
        self._ensure_connected()
        raw_xids = self._check_disconnect(self._raw_connection.tpc_recover)
        return [Xid(raw_xid[0], raw_xid[1], raw_xid[2])
                for raw_xid in raw_xids]

    def rollback(self, xid=None):
        """Rollback the connection.

        If the connection is not currently connected, or the rollback
        fails with a disconnection error, the connection is marked for
        reconnection on next use instead of raising.

        @param xid: Optionally the L{Xid} of a previously prepared
             transaction to rollback. This form should be called outside
             of a transaction, and is intended for use in recovery.
        """
        try:
            if self._state == STATE_CONNECTED:
                try:
                    if xid:
                        raw_xid = self._raw_xid(xid)
                        self._raw_connection.tpc_rollback(raw_xid)
                    elif self._two_phase_transaction:
                        self._raw_connection.tpc_rollback()
                    else:
                        self._raw_connection.rollback()
                except Error as exc:
                    if self.is_disconnection_error(exc):
                        self._raw_connection = None
                        self._state = STATE_RECONNECT
                        self._two_phase_transaction = False
                    else:
                        raise
                else:
                    self._two_phase_transaction = False
            else:
                self._two_phase_transaction = False
                self._state = STATE_RECONNECT
        finally:
            self._check_disconnect(trace, "connection_rollback", self, xid)

    @staticmethod
    def to_database(params):
        """Convert some parameters into values acceptable to a database backend.

        It is acceptable to override this method in subclasses, but it
        is not intended to be used externally.

        This delegates conversion to any
        L{Variable <storm.variable.Variable>}\\ s in the parameter list, and
        passes through all other values untouched.
        """
        for param in params:
            if isinstance(param, Variable):
                yield param.get(to_db=True)
            else:
                yield param

    def build_raw_cursor(self):
        """Get a new dbapi cursor object.

        It is acceptable to override this method in subclasses, but it
        is not intended to be called externally.
        """
        return self._raw_connection.cursor()

    def raw_execute(self, statement, params=None):
        """Execute a raw statement with the given parameters.

        It's acceptable to override this method in subclasses, but it
        is not intended to be called externally.

        The C{connection_raw_execute} tracer hook is run before the
        statement executes, followed by C{connection_raw_execute_success}
        on success or C{connection_raw_execute_error} if tracing or
        execution raises.

        @return: The dbapi cursor object, as fetched from L{build_raw_cursor}.
        """
        raw_cursor = self._check_disconnect(self.build_raw_cursor)
        self._prepare_execution(raw_cursor, params, statement)
        args = self._execution_args(params, statement)
        self._run_execution(raw_cursor, args, params, statement)
        return raw_cursor

    def _execution_args(self, params, statement):
        """Get the appropriate statement execution arguments."""
        if params:
            args = (statement, tuple(self.to_database(params)))
        else:
            args = (statement,)
        return args

    def _run_execution(self, raw_cursor, args, params, statement):
        """Complete the statement execution, along with result reports."""
        try:
            self._check_disconnect(raw_cursor.execute, *args)
        except Exception as error:
            self._check_disconnect(
                trace, "connection_raw_execute_error", self, raw_cursor,
                statement, params or (), error)
            raise
        else:
            self._check_disconnect(
                trace, "connection_raw_execute_success", self, raw_cursor,
                statement, params or ())

    def _prepare_execution(self, raw_cursor, params, statement):
        """Prepare the statement execution to be run."""
        try:
            self._check_disconnect(
                trace, "connection_raw_execute", self, raw_cursor,
                statement, params or ())
        except Exception as error:
            self._check_disconnect(
                trace, "connection_raw_execute_error", self, raw_cursor,
                statement, params or (), error)
            raise

    def _ensure_connected(self):
        """Ensure that we are connected to the database.

        If the connection is marked as dead, or if we can't reconnect,
        then raise DisconnectionError.

        @raise ConnectionBlockedError: If access has been blocked.
        """
        if self._blocked:
            raise ConnectionBlockedError("Access to connection is blocked")
        if self._state == STATE_CONNECTED:
            return
        elif self._state == STATE_DISCONNECTED:
            raise DisconnectionError("Already disconnected")
        elif self._state == STATE_RECONNECT:
            try:
                self._raw_connection = self._database.raw_connect()
            except DatabaseError as exc:
                self._state = STATE_DISCONNECTED
                self._raw_connection = None
                raise DisconnectionError(str(exc)) from exc
            else:
                self._state = STATE_CONNECTED

    def is_disconnection_error(self, exc, extra_disconnection_errors=()):
        """Check whether an exception represents a database disconnection.

        This should be overridden by backends to detect whichever
        exception values are used to represent this condition.
        """
        return False

    def _raw_xid(self, xid):
        """Return a raw xid from the given high-level L{Xid} object."""
        return self._raw_connection.xid(xid.format_id,
                                        xid.global_transaction_id,
                                        xid.branch_qualifier)

    def _check_disconnect(self, function, *args, **kwargs):
        """Run the given function, checking for database disconnections.

        On a disconnection error (per L{is_disconnection_error}) the
        connection is marked disconnected, the raw connection reference is
        dropped, and a L{DisconnectionError} chained to the original is
        raised.  Other exceptions propagate unchanged.
        """
        # Allow the caller to specify additional exception types that
        # should be treated as possible disconnection errors.
        extra_disconnection_errors = kwargs.pop(
            'extra_disconnection_errors', ())
        try:
            return function(*args, **kwargs)
        except Exception as exc:
            if self.is_disconnection_error(exc, extra_disconnection_errors):
                self._state = STATE_DISCONNECTED
                self._raw_connection = None
                raise DisconnectionError(str(exc)) from exc
            else:
                raise

    def preset_primary_key(self, primary_columns, primary_variables):
        """Process primary variables before an insert happens.

        This method may be overwritten by backends to implement custom
        changes in primary variables before an insert happens.
        """


class Database(object):
    """A database that can be connected to.

    This should be subclassed for individual database backends.

    @cvar connection_factory: A callable which will take this database
        and an optional event system, and should return an instance of
        L{Connection}.
    """

    connection_factory = Connection

    def __init__(self, uri=None):
        """
        @param uri: The URI object this database was created from, if any.
        """
        self._uri = uri
        # Cache of combined exception types built by
        # _make_combined_exception_type, keyed by (wrapper_type, dbapi_type).
        self._exception_types = {}

    def get_uri(self):
        """Return the URI object this database was created with."""
        return self._uri

    def connect(self, event=None):
        """Create a connection to the database.

        It calls C{self.connection_factory} to allow for ease of
        customization.

        @param event: The event system to broadcast messages with. If
            not specified, then no events will be broadcast.

        @return: An instance of L{Connection}.
        """
        return self.connection_factory(self, event)

    def raw_connect(self):
        """Create a raw database connection.

        This is used by L{Connection} objects to connect to the
        database.  It should be overridden in subclasses to do any
        database-specific connection setup.

        @return: A DB-API connection object.
        """
        raise NotImplementedError

    @property
    def _exception_module(self):
        """The module where appropriate DB-API exception types are defined.

        Subclasses should override this if they support re-raising DB-API
        exceptions as StormError instances.
        """
        return None

    def _make_combined_exception_type(self, wrapper_type, dbapi_type):
        """Make a combined exception based on both DB-API and Storm.

        Storm historically defined its own exception types as ABCs and
        registered the DB-API exception types as virtual subclasses.
        However, this doesn't work properly in Python 3
        (https://bugs.python.org/issue12029).  Instead, we create and cache
        subclass-specific exception types that inherit from both StormError
        and the DB-API exception type, allowing code that catches either
        StormError (or subclasses) or the specific DB-API exceptions to keep
        working.

        @type wrapper_type: L{type}
        @param wrapper_type: The type of the wrapper exception to create; a
            subclass of L{StormError}.
        @type dbapi_type: L{type}
        @param dbapi_type: The type of the DB-API exception.

        @return: The combined exception type, a subclass of both
            C{dbapi_type} and C{wrapper_type}.
        """
        # Key on the type objects rather than on dbapi_type.__name__:
        # distinct DB-API classes can share a name, and the cached type
        # must really subclass both of the types requested here.
        key = (wrapper_type, dbapi_type)
        combined_type = self._exception_types.get(key)
        if combined_type is None:
            combined_type = type(
                dbapi_type.__name__, (dbapi_type, wrapper_type), {})
            self._exception_types[key] = combined_type
        return combined_type

    def _wrap_exception(self, wrapper_type, exception):
        """Wrap a DB-API exception as a StormError instance.

        This constructs a wrapper exception with the same C{args} as the
        DB-API exception.  Subclasses may override this to set additional
        attributes on the wrapper exception.

        @type wrapper_type: L{type}
        @param wrapper_type: The type of the wrapper exception to create; a
            subclass of L{StormError}.
        @type exception: L{Exception}
        @param exception: The DB-API exception to wrap.

        @return: The wrapped exception; an instance of L{StormError}.
        """
        return self._make_combined_exception_type(
            wrapper_type, exception.__class__)(*exception.args)


def convert_param_marks(statement, from_param_mark, to_param_mark):
    """Replace parameter marks in an SQL statement with another style.

    Occurrences of C{from_param_mark} inside single-quoted string literals
    are left alone.  Splitting on the quote character makes every
    even-indexed piece lie outside a literal.  A doubled C{''} escape
    toggles twice, so it keeps that parity intact.

    @param statement: The SQL text.
    @param from_param_mark: The mark currently used (e.g. C{"?"}).
    @param to_param_mark: The mark to substitute (e.g. C{"%s"}).
    @return: The converted statement, or C{statement} itself when there
        is nothing to convert.
    """
    # TODO: Add support for $foo$bar$foo$ literals.
    nothing_to_do = (from_param_mark == to_param_mark
                     or from_param_mark not in statement)
    if nothing_to_do:
        return statement
    pieces = statement.split("'")
    return "'".join(
        piece if index % 2 else piece.replace(from_param_mark, to_param_mark)
        for index, piece in enumerate(pieces))


# Registry of URI scheme name -> database factory, filled in by
# register_scheme() and consulted by create_database().
_database_schemes = {}


def register_scheme(scheme, factory):
    """Register a handler for a new database URI scheme.

    Registering a scheme a second time replaces the earlier factory.

    @param scheme: the database URI scheme
    @param factory: a function taking a URI instance and returning a database.
    """
    _database_schemes.update({scheme: factory})


def create_database(uri):
    """Create a database instance.

    @param uri: A URI instance, or a string describing the URI. Some examples:

    "sqlite:"
        An in memory sqlite database.

    "sqlite:example.db"
        A SQLite database called example.db

    "postgres:test"
        The database 'test' from the local postgres server.

    "postgres://user:password@host/test"
        The database test on machine host with supplied user credentials,
        using postgres.

    "anything:..."
        Where 'anything' has previously been registered with
        L{register_scheme}.

    @raise ImportError: If the scheme is not registered and there is no
        backend module of that name under C{storm.databases}.
    """
    import importlib

    if isinstance(uri, str):
        uri = URI(uri)
    if uri.scheme in _database_schemes:
        factory = _database_schemes[uri.scheme]
    else:
        # Fall back to the built-in backend module named after the scheme;
        # import_module returns the leaf module directly (no fromlist hack).
        module = importlib.import_module(
            "%s.databases.%s" % (storm.__name__, uri.scheme))
        factory = module.create_from_uri
    return factory(uri)

# Deal with circular import.
from storm.tracer import trace