~johannes.baiter/mnemosyne-proj/mnemodroid

Viewing changes to src/com/mnemodroid/mnemosyne/openSM2sync/server.py

  • Committer: Johannes Baiter
  • Date: 2011-02-15 00:30:59 UTC
  • Revision ID: johannes.baiter@gmail.com-20110215003059-83fn5ebmjs89jl2d
Relocated python scripts, added README and shellscript for tarball-creation, some bugfixes

#
# server.py - Max Usachev <maxusachev@gmail.com>
#             Ed Bartosh <bartosh@gmail.com>
#             Peter Bienstman <Peter.Bienstman@UGent.be>

import os
import sys
import cgi
import time
import select
import socket
import tarfile
import httplib
import tempfile
from wsgiref.simple_server import WSGIServer, WSGIRequestHandler

from utils import traceback_string, rand_uuid
from text_formats.xml_format import XMLFormat
from partner import Partner, UnsizedLogEntryStreamReader, BUFFER_SIZE

# Avoid delays caused by Nagle's algorithm.
# http://www.cmlenz.net/archives/2008/03/python-httplib-performance-problems
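# Setting TCP_NODELAY disables this buffering of small writes, so the many
# short HTTP requests and replies exchanged during a sync go out immediately.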

realsocket = socket.socket
def socketwrap(family=socket.AF_INET, type=socket.SOCK_STREAM, proto=0):
    sockobj = realsocket(family, type, proto)
    sockobj.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
    return sockobj
socket.socket = socketwrap

# Work around http://bugs.python.org/issue6085.
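# (The default address_string() does a reverse DNS lookup via getfqdn() for
# every request, which can stall each response on some networks.)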

def not_insane_address_string(self):
    host, port = self.client_address[:2]
    return "%s (no getfqdn)" % host

WSGIRequestHandler.address_string = not_insane_address_string

# Don't pollute our testsuite output.

def dont_log(*kwargs):
    pass

WSGIRequestHandler.log_message = dont_log

# Register binary formats.

from binary_formats.mnemosyne_format import MnemosyneFormat
BinaryFormats = [MnemosyneFormat]
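
# Binary formats are duck-typed: judging from how they are used below (in
# 'binary_format_for' and 'get_server_entire_database_binary'), a format is
# constructed with the database and needs to offer 'supports',
# 'binary_file_and_size' and 'clean_up'. The class below is only an
# illustrative sketch of that interface; its name and internals are made up,
# and it is deliberately not added to 'BinaryFormats'.

class ExampleBinaryFormat(object):

    def __init__(self, database):
        self.database = database

    def supports(self, program_name, program_version, database_version):
        # A real format would claim support only for clients it can serve.
        return False

    def binary_file_and_size(self, store_pregenerated_data,
                             interested_in_old_reps):
        # Return an open file object containing the binary database, together
        # with its size in bytes.
        filename = self.database.path()
        return file(filename, "rb"), os.path.getsize(filename)

    def clean_up(self):
        # Remove any temporary files created in 'binary_file_and_size'.
        pass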


class Session(object):

    """Very basic session support.

    Note that although the current code supports multiple open sessions at
    once, it does not yet support the locking mechanisms to make this
    thread-safe.

    """

    def __init__(self, client_info, database):
        self.token = rand_uuid()
        self.client_info = client_info
        self.database = database
        self.client_log = []
        self.apply_error = None
        self.expires = time.time() + 60*60
        self.backup_file = self.database.backup()
        self.database.set_sync_partner_info(client_info)

    def is_expired(self):
        return time.time() > self.expires

    def close(self):
        self.database.update_last_log_index_synced_for(\
            self.client_info["machine_id"])
        self.database.save()

    def terminate(self):

        """Restore from backup if the session failed to close normally."""

        self.database.restore(self.backup_file)


class Server(WSGIServer, Partner):

    program_name = "unknown-SRS-app"
    program_version = "unknown"

    def __init__(self, machine_id, port, ui):
        self.machine_id = machine_id
        WSGIServer.__init__(self, ("", port), WSGIRequestHandler)
        self.set_app(self.wsgi_app)
        Partner.__init__(self, ui)
        self.text_format = XMLFormat()
        self.stopped = False
        self.sessions = {} # {session_token: session}
        self.session_token_for_user = {} # {user_name: session_token}

    def serve_until_stopped(self):
        while not self.stopped:
            # We time out every 0.25 seconds, so that changing
            # self.stopped can have an effect.
            if select.select([self.socket], [], [], 0.25)[0]:
                self.handle_request()
        self.socket.close()

    def wsgi_app(self, environ, start_response):
        # Catch badly formed requests.
        status, method, args = self.get_method(environ)
        if status != "200 OK":
            response_headers = [("Content-type", "text/plain")]
            start_response(status, response_headers)
            return [status]
        # Note that it is no use wrapping the function call in a try/except
        # statement. The response could be an iterable, in which case more
        # calls to e.g. 'get_server_log_entries' could follow outside of this
        # function 'wsgi_app'. Any exceptions that occur then will no longer
        # be caught here. Therefore, we need to catch all of our exceptions
        # ourselves at the lowest level.
        response_headers = [("Content-type", self.text_format.mime_type)]
        start_response("200 OK", response_headers)
        return getattr(self, method)(environ, **args)

    def get_method(self, environ):
        # Convert e.g. GET /foo_bar into get_foo_bar.
        method = (environ["REQUEST_METHOD"] + \
                  environ["PATH_INFO"].replace("/", "_")).lower()
        args = cgi.parse_qs(environ["QUERY_STRING"])
        args = dict([(key, val[0]) for key, val in args.iteritems()])
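        # (cgi.parse_qs returns a list of values per key; the query arguments
        # used by this protocol are single-valued, so we keep only the first.)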
        # 'login' and 'status' are the only methods that don't require a
        # session token.
        if method == "put_login" or method == "get_status":
            if len(args) == 0:
                return "200 OK", method, args
            else:
                return "400 Bad Request", None, None
        # See if the token matches.
        if "session_token" not in args or args["session_token"] \
            not in self.sessions:
            return "403 Forbidden", None, None
        # See if the method exists.
        if hasattr(self, method) and callable(getattr(self, method)):
            return "200 OK", method, args
        else:
            return "404 Not Found", None, None

    # The following functions are not yet thread-safe.

    def create_session(self, client_info):
        database = self.load_database(client_info["database_name"])
        session = Session(client_info, database)
        self.sessions[session.token] = session
        self.session_token_for_user[client_info["username"]] = session.token
        return session

    def close_session_with_token(self, session_token):
        session = self.sessions[session_token]
        session.close()
        self.unload_database(session.database)
        del self.session_token_for_user[session.client_info["username"]]
        del self.sessions[session_token]
        self.ui.close_progress()

    def cancel_session_with_token(self, session_token):

        """Cancel a session at the user's request, e.g. after detecting
        conflicts.

        """

        session = self.sessions[session_token]
        self.unload_database(session.database)
        del self.session_token_for_user[session.client_info["username"]]
        del self.sessions[session_token]
        self.ui.close_progress()

    def terminate_session_with_token(self, session_token):

        """Clean up a session which failed to close normally."""

        session = self.sessions[session_token]
        session.terminate()
        self.unload_database(session.database)
        del self.session_token_for_user[session.client_info["username"]]
        del self.sessions[session_token]
        self.ui.close_progress()

    def terminate_all_sessions(self):
        for session_token in self.sessions.keys():
            self.terminate_session_with_token(session_token)

    def handle_error(self, session=None, traceback_string=None):
        if session:
            self.terminate_session_with_token(session.token)
        if traceback_string:
            self.ui.show_error(traceback_string)
            return self.text_format.repr_message("Internal server error",
                traceback_string)

    def stop(self):
        self.terminate_all_sessions()
        self.stopped = True
        self.ui.close_progress()

    def binary_format_for(self, session):
        for BinaryFormat in BinaryFormats:
            binary_format = BinaryFormat(session.database)
            if binary_format.supports(session.client_info["program_name"],
                session.client_info["program_version"],
                session.client_info["database_version"]):
                return binary_format
        return None

    def supports_binary_transfer(self, session):

        """For testability; can easily be overridden by the testsuite."""

        return self.binary_format_for(session) is not None

    # The following functions are to be overridden by the actual server code,
    # to implement e.g. authorisation, storage, etc.

    def authorise(self, username, password):

        """Returns True if 'password' is correct for 'username'."""

        raise NotImplementedError

    def load_database(self, database_name):

        """Returns a database object for the database named 'database_name'.
        Should create the database if it does not exist yet.

        """

        raise NotImplementedError

    def unload_database(self, database):

        """Gives a custom server the chance to do some after-sync cleanup.

        """

        pass

    # The following are methods that are supported by the server through GET
    # and PUT calls. 'get_foo_bar' gets executed after a 'GET /foo_bar'
    # request. Similarly, 'put_foo_bar' gets executed after a 'PUT /foo_bar'
    # request.
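    # For example (the token value is illustrative), 'PUT /login' is
    # dispatched to 'put_login', and
    # 'GET /server_log_entries?session_token=abc123' is dispatched to
    # 'get_server_log_entries(environ, session_token="abc123")'.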

    def get_status(self, environ):
        return [self.text_format.repr_message("OK")]

    def put_login(self, environ):
        session = None
        try:
            self.ui.set_progress_text("Client logging in...")
            client_info_repr = environ["wsgi.input"].readline()
            client_info = self.text_format.parse_partner_info(\
                client_info_repr)
            if not self.authorise(client_info["username"],
                client_info["password"]):
                return [self.text_format.repr_message("Access denied")]
            # Close an old session that is waiting in vain for client input.
            old_running_session_token = self.session_token_for_user.\
                get(client_info["username"])
            if old_running_session_token:
                self.terminate_session_with_token(old_running_session_token)
            session = self.create_session(client_info)
            # If the client database is empty, perhaps it was reset, and we
            # need to delete the partnership from our side too.
            if session.client_info["database_is_empty"] == True:
                session.database.remove_partnership_with(\
                    session.client_info["machine_id"])
            # Make sure there are no cycles in the sync graph.
            server_in_client_partners = self.machine_id in \
                session.client_info["partners"]
            client_in_server_partners = session.client_info["machine_id"] in \
                session.database.partners()
            if (server_in_client_partners and not client_in_server_partners)\
               or \
               (client_in_server_partners and not server_in_client_partners):
                self.terminate_session_with_token(session.token)
                return [self.text_format.repr_message("Sync cycle detected")]
            session.database.create_if_needed_partnership_with(\
                client_info["machine_id"])
            session.database.merge_partners(client_info["partners"])
            # Note that we need to send 'user_id' to the client as well, so
            # that the client can make sure the 'user_id's (used to label the
            # anonymous uploaded logs) are consistent across machines.
            server_info = {"user_id": session.database.user_id(),
                "machine_id": self.machine_id,
                "program_name": self.program_name,
                "program_version": self.program_version,
                "database_version": session.database.version,
                "partners": session.database.partners(),
                "session_token": session.token,
                "supports_binary_transfer": \
                    self.supports_binary_transfer(session)}
            # Add optional program-specific information.
            server_info = \
                session.database.append_to_sync_partner_info(server_info)
            # We check if files were updated outside of the program. This can
            # generate MEDIA_EDITED log entries, so it should be done first.
            session.database.check_for_edited_media_files()
            return [self.text_format.repr_partner_info(server_info)\
                   .encode("utf-8")]
        except:
            # We need to be really thorough in our exception handling, so as
            # to always revert the database to its last backup if an error
            # occurs. It is important that this happens as soon as possible,
            # especially if this server is being run as a built-in server in a
            # thread in an SRS desktop application.
            # As mentioned before, the error handling should happen here, at
            # the lowest level, and not in e.g. 'wsgi_app'.
            return [self.handle_error(session, traceback_string())]

    def put_client_log_entries(self, environ, session_token):
        try:
            session = self.sessions[session_token]
            self.ui.set_progress_text("Receiving log entries...")
            socket = environ["wsgi.input"]
            # In order to do conflict resolution easily, one of the sync
            # partners has to have both logs in memory. We do this at the
            # server side, as the client could be a resource-limited mobile
            # device.
            session.client_log = []
            client_o_ids = []
            def callback(context, log_entry):
                context["session_client_log"].append(log_entry)
                if log_entry["type"] > 5: # not STARTED_PROGRAM,
                    # STOPPED_PROGRAM, STARTED_SCHEDULER, LOADED_DATABASE,
                    # SAVED_DATABASE
                    if "fname" in log_entry:
                        log_entry["o_id"] = log_entry["fname"]
                    context["client_o_ids"].append(log_entry["o_id"])
            context = {"session_client_log": session.client_log,
                       "client_o_ids": client_o_ids}
            adapted_stream = UnsizedLogEntryStreamReader(socket,
                self.text_format.log_entries_footer())
            self.download_log_entries(adapted_stream, callback, context)
            # Now we can determine whether there are conflicts.
            for log_entry in session.database.log_entries_to_sync_for(\
                session.client_info["machine_id"]):
                if not log_entry:
                    continue  # Irrelevant entry for card-based clients.
                if "fname" in log_entry:
                    log_entry["o_id"] = log_entry["fname"]
                if log_entry["type"] > 5 and \
                    log_entry["o_id"] in client_o_ids:
                    return [self.text_format.repr_message("Conflict")]
            return [self.text_format.repr_message("OK")]
        except:
            return [self.handle_error(session, traceback_string())]

    def put_client_entire_database_binary(self, environ, session_token):
        try:
            session = self.sessions[session_token]
            self.ui.set_progress_text("Getting entire binary database...")
            filename = session.database.path()
            session.database.abandon()
            self.download_binary_file(filename, environ["wsgi.input"])
            session.database.load(filename)
            session.database.create_if_needed_partnership_with(\
                session.client_info["machine_id"])
            session.database.remove_partnership_with(self.machine_id)
            return [self.text_format.repr_message("OK")]
        except:
            return [self.handle_error(session, traceback_string())]

    def get_server_log_entries(self, environ, session_token):
        try:
            session = self.sessions[session_token]
            self.ui.set_progress_text("Sending log entries...")
            log_entries = session.database.log_entries_to_sync_for(\
                session.client_info["machine_id"],
                session.client_info["interested_in_old_reps"])
            number_of_entries = session.database.\
                number_of_log_entries_to_sync_for(\
                session.client_info["machine_id"],
                session.client_info["interested_in_old_reps"])
            for buffer in self.stream_log_entries(log_entries,
                number_of_entries):
                yield buffer
        except:
            yield self.handle_error(session, traceback_string())
        # Now that all the data is underway to the client, we can already
        # start applying the client log entries. If any errors occur, we save
        # them and communicate them to the client in 'get_sync_finish'.
        try:
            self.ui.set_progress_text("Applying log entries...")
            # First, dump to the science log, so that we can skip over the new
            # logs in case the client uploads them.
            session.database.dump_to_science_log()
            for log_entry in session.client_log:
                session.database.apply_log_entry(log_entry)
            # Skip over the logs that the client promised to upload.
            if session.client_info["upload_science_logs"]:
                session.database.skip_science_log()
        except:
            session.apply_error = traceback_string()

    def get_server_entire_database(self, environ, session_token):
        try:
            session = self.sessions[session_token]
            self.ui.set_progress_text("Sending entire database...")
            session.database.dump_to_science_log()
            log_entries = session.database.all_log_entries(\
                session.client_info["interested_in_old_reps"])
            number_of_entries = session.database.number_of_log_entries(\
                session.client_info["interested_in_old_reps"])
            for buffer in self.stream_log_entries(log_entries,
                number_of_entries):
                yield buffer
        except:
            yield self.handle_error(session, traceback_string())

    def get_server_entire_database_binary(self, environ, session_token):
        try:
            session = self.sessions[session_token]
            self.ui.set_progress_text("Sending entire binary database...")
            binary_format = self.binary_format_for(session)
            binary_file, file_size = binary_format.binary_file_and_size(\
                session.client_info["store_pregenerated_data"],
                session.client_info["interested_in_old_reps"])
            for buffer in self.stream_binary_file(binary_file, file_size):
                yield buffer
            binary_format.clean_up()
            # This is a full sync, so we don't need to apply client log
            # entries here.
        except:
            yield self.handle_error(session, traceback_string())

    def put_client_media_files(self, environ, session_token):
        try:
            session = self.sessions[session_token]
            self.ui.set_progress_text("Getting media files...")
            socket = environ["wsgi.input"]
            size = int(socket.readline())
            tar_pipe = tarfile.open(mode="r|", fileobj=socket)
            # Work around http://bugs.python.org/issue7693.
            tar_pipe.extractall(session.database.media_dir().encode("utf-8"))
            return [self.text_format.repr_message("OK")]
        except:
            return [self.handle_error(session, traceback_string())]

    def get_server_media_files(self, environ, session_token,
                               redownload_all=False):
        try:
            session = self.sessions[session_token]
            # Note that for media files we use a tar stream directly for
            # efficiency reasons, and bypass the routines in Partner.
            self.ui.set_progress_text("Sending media files...")
            # Determine files to send across.
            if redownload_all in ["1", "True", "true"]:
                filenames = list(session.database.all_media_filenames())
            else:
                filenames = list(session.database.media_filenames_to_sync_for(\
                    session.client_info["machine_id"]))
            if len(filenames) == 0:
                yield "0\n"
                return
            # Create a temporary tar file with the files.
            tmp_file = tempfile.NamedTemporaryFile(delete=False)
            tmp_file_name = tmp_file.name
            saved_path = os.getcwdu()
            os.chdir(session.database.media_dir())
            tar_pipe = tarfile.open(mode="w|", fileobj=tmp_file,
                bufsize=BUFFER_SIZE, format=tarfile.PAX_FORMAT)
            for filename in filenames:
                tar_pipe.add(filename)
            tar_pipe.close()
            # Stream tar file across.
            tmp_file = file(tmp_file_name, "rb")
            file_size = os.path.getsize(tmp_file_name)
            for buffer in self.stream_binary_file(tmp_file, file_size):
                yield buffer
            os.remove(tmp_file_name)
            os.chdir(saved_path)
        except:
            yield self.handle_error(session, traceback_string())

    def get_sync_cancel(self, environ, session_token):
        try:
            self.ui.set_progress_text("Sync cancelled!")
            self.cancel_session_with_token(session_token)
            return [self.text_format.repr_message("OK")]
        except:
            session = self.sessions[session_token]
            return [self.handle_error(session, traceback_string())]

    def get_sync_finish(self, environ, session_token):
        try:
            session = self.sessions[session_token]
            if session.apply_error:
                return [self.handle_error(session, session.apply_error)]
            self.ui.set_progress_text("Sync finished!")
            self.close_session_with_token(session_token)
            # Now is a good time to garbage-collect dangling sessions.
            # Only relevant for a multi-user server.
            # (Iterate over a copy of the keys, since terminating a session
            # removes it from 'self.sessions'.)
            for expired_token in self.sessions.keys():
                if self.sessions[expired_token].is_expired():
                    self.terminate_session_with_token(expired_token)
            return [self.text_format.repr_message("OK")]
        except:
            return [self.handle_error(session, traceback_string())]
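

# A minimal sketch of how an application could plug into the hooks above.
# Everything below is illustrative only: the class name, the hard-coded
# account and the commented-out usage are made up, and a real server would
# also need to pass in a ui object (with 'set_progress_text',
# 'close_progress' and 'show_error') and a database object implementing the
# calls used throughout this file.

class ExampleServer(Server):

    program_name = "example-srs"
    program_version = "0.1"

    def __init__(self, machine_id, port, ui, database):
        Server.__init__(self, machine_id, port, ui)
        self._database = database
        self._accounts = {"alice": "secret"}

    def authorise(self, username, password):
        return self._accounts.get(username) == password

    def load_database(self, database_name):
        # A real implementation would open or create 'database_name' here.
        return self._database

#server = ExampleServer("example-machine-id", 8512, ui, database)
#server.serve_until_stopped()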