~awuerl/blitzortung-python/master

« back to all changes in this revision

Viewing changes to blitzortung/db.py

  • Committer: Andreas Würl
  • Date: 2012-01-29 15:34:23 UTC
  • Revision ID: git-v1:cdca5487c8322e426d0859349fa52643e0a47019
added python code from blitzortung-tracker-tools

Show diffs side-by-side

added added

removed removed

Lines of Context:
 
1
# -*- coding: utf8 -*-
 
2
 
 
3
'''
 
4
 
 
5
@author: Andreas Würl
 
6
 
 
7
'''
 
8
 
 
9
import os
 
10
 
 
11
import math
 
12
 
 
13
import pytz
 
14
import shapely.wkb
 
15
 
 
16
import psycopg2
 
17
import psycopg2.extras
 
18
import psycopg2.extensions
 
19
 
 
20
import GeoTypes
 
21
 
 
22
import data
 
23
import geom
 
24
 
 
25
GeoTypes.initialisePsycopgTypes(psycopg_module=psycopg2, psycopg_extensions_module=psycopg2.extensions)
 
26
 
 
27
from abc import ABCMeta, abstractmethod
 
28
 
 
29
class TimeInterval:
 
30
 
 
31
  def __init__(self, start = None, end = None):
 
32
    self.start = start
 
33
    self.end = end
 
34
 
 
35
  def get_start(self):
 
36
    return self.start
 
37
 
 
38
  def get_end(self):
 
39
    return self.end
 
40
 
 
41
  def __str__(self):
 
42
    return '[' + str(self.start) + ' - ' + str(self.end) + ']'
 
43
 
 
44
 
 
45
class Query:
 
46
  '''
 
47
  simple class for building of complex queries
 
48
  '''
 
49
 
 
50
  def __init__(self):
 
51
    self.sql = ''
 
52
    self.conditions = []
 
53
    self.parameters = {}
 
54
    self.table_name = None
 
55
    self.columns = None
 
56
    self.limit = None
 
57
    self.order = []
 
58
 
 
59
  def set_table_name(self, table_name):
 
60
    self.table_name = table_name
 
61
 
 
62
  def set_columns(self, columns):
 
63
    self.columns = columns
 
64
 
 
65
  def add_order(self, order):
 
66
    self.order.append(order)
 
67
 
 
68
  def set_limit(self, limit):
 
69
    if self.limit != None:
 
70
      raise Error("overriding Query.limit")
 
71
    self.limit = limit
 
72
 
 
73
  def add_condition(self, condition, parameters = None):
 
74
    self.conditions.append(condition)
 
75
    if parameters != None:
 
76
      self.parameters.update(parameters)
 
77
 
 
78
  def add_parameters(self, parameters):
 
79
    self.parameters.update(parameters)
 
80
 
 
81
  def __str__(self):
 
82
    sql = 'SELECT '
 
83
 
 
84
    if self.columns:
 
85
      for index, column in enumerate(self.columns):
 
86
        if index != 0:
 
87
          sql += ', '
 
88
        sql += column
 
89
      sql += ' '
 
90
 
 
91
    sql += 'FROM ' + self.table_name + ' '
 
92
 
 
93
    for index, condition in enumerate(self.conditions):
 
94
      if index == 0:
 
95
        sql += 'WHERE '
 
96
      else:
 
97
        sql += 'AND '
 
98
      sql += condition + ' '
 
99
 
 
100
    if len(self.order) > 0:
 
101
      sql += 'ORDER BY '
 
102
      for index, order in enumerate(self.order):
 
103
        if index != 0:
 
104
          sql += ', '
 
105
        sql += order.get_column() + ' '
 
106
        if order.is_desc():
 
107
          sql += 'DESC '
 
108
 
 
109
    if self.limit:
 
110
      sql += 'LIMIT ' + str(self.limit.get_number()) + ' '
 
111
 
 
112
    return sql
 
113
 
 
114
  def get_parameters(self):
 
115
      return self.parameters
 
116
 
 
117
  def parse_args(self, args):
 
118
    for arg in args:
 
119
      if arg:
 
120
        if isinstance(arg, TimeInterval):
 
121
 
 
122
          if arg.get_start() != None:
 
123
            self.add_condition('timestamp >= %(starttime)s', {'starttime': arg.get_start()})
 
124
 
 
125
          if arg.get_end() != None:
 
126
            self.add_condition('timestamp < %(endtime)s', {'endtime': arg.get_end()})
 
127
 
 
128
        elif isinstance(arg, shapely.geometry.base.BaseGeometry):
 
129
 
 
130
          if arg.is_valid:
 
131
 
 
132
            self.add_condition('SetSRID(CAST(%(envelope)s AS geometry), %(srid)s) && st_transform(the_geom, %(srid)s)', {'envelope': shapely.wkb.dumps(arg.envelope).encode('hex')})
 
133
 
 
134
            if not arg.equals(arg.envelope):
 
135
              self.add_condition('Intersects(SetSRID(CAST(%(geometry)s AS geometry), %(srid)s), st_transform(the_geom, %(srid)s))', {'geometry': shapely.wkb.dumps(arg).encode('hex')})
 
136
 
 
137
          else:
 
138
              raise Error("invalid geometry in db.Stroke.select()")
 
139
 
 
140
        elif isinstance(arg, Order):
 
141
            self.add_order(arg)
 
142
 
 
143
        elif isinstance(arg, Limit):
 
144
            self.setLimit(arg)
 
145
 
 
146
        else:
 
147
            print 'WARNING: ' + __name__ + ' unhandled object ' + str(type(arg))
 
148
 
 
149
  def get_results(self, db):
 
150
 
 
151
    resulting_strokes = []
 
152
    if db.cur.rowcount > 0:
 
153
      for result in db.cur.fetchall():
 
154
        resulting_strokes.append(db.create(result))
 
155
 
 
156
    return resulting_strokes
 
157
 
 
158
class RasterQuery(Query):
 
159
 
 
160
  def __init__(self, raster):
 
161
    Query.__init__(self)
 
162
 
 
163
    self.raster = raster
 
164
 
 
165
    env = self.raster.getEnv()
 
166
 
 
167
    if env.is_valid:
 
168
      self.add_condition('SetSRID(CAST(%(envelope)s AS geometry), %(srid)s) && st_transform(the_geom, %(srid)s)', {'envelope': shapely.wkb.dumps(env).encode('hex')})
 
169
    else:
 
170
      raise Error("invalid Raster geometry in db.Stroke.select()")
 
171
 
 
172
  def __str__(self):
 
173
    sql = 'SELECT '
 
174
 
 
175
    sql += 'TRUNC((ST_X(ST_TRANSFORM(the_geom, %(srid)s)) - ' + str(self.raster.getXMin()) + ') /' + str(self.raster.getXDiv()) + ') AS rx, '
 
176
    sql += 'TRUNC((ST_Y(ST_TRANSFORM(the_geom, %(srid)s)) - ' + str(self.raster.getYMin()) + ') /' + str(self.raster.getYDiv()) + ') AS ry, '
 
177
    sql += 'count(*) AS count FROM ('
 
178
 
 
179
    sql += Query.__str__(self)
 
180
 
 
181
    sql += ') AS ' + self.table_name + ' GROUP BY rx, ry'
 
182
 
 
183
    return sql
 
184
 
 
185
  def get_results(self, db):
 
186
 
 
187
    if db.cur.rowcount > 0:
 
188
      for result in db.cur.fetchall():
 
189
        self.raster.set(result['rx'], result['ry'], result['count'])
 
190
    return self.raster
 
191
 
 
192
class Order(object):
 
193
    '''
 
194
    definition for query search order
 
195
    '''
 
196
 
 
197
    def __init__(self, column, desc = False):
 
198
        self.column = column
 
199
        self.desc = desc
 
200
 
 
201
    def get_column(self):
 
202
        return self.column
 
203
 
 
204
    def is_desc(self):
 
205
        return self.desc
 
206
 
 
207
 
 
208
class Limit(object):
 
209
    '''
 
210
    definition of query result limit
 
211
    '''
 
212
 
 
213
    def __init__(self, limit):
 
214
        self.limit = limit
 
215
 
 
216
    def get_number(self):
 
217
        return self.limit
 
218
 
 
219
 
 
220
class Center(object):
 
221
    '''
 
222
    definition of query center point
 
223
    '''
 
224
 
 
225
    def __init__(self, center):
 
226
        self.center = center
 
227
 
 
228
    def get_point(self):
 
229
        return self.center
 
230
 
 
231
 
 
232
class Base(object):
 
233
    '''
 
234
    abstract base class for database access objects
 
235
 
 
236
    creation of database 
 
237
 
 
238
    psql as user postgres:
 
239
 
 
240
    CREATE USER blitzortung PASSWORD 'blitzortung' INHERIT;
 
241
 
 
242
    createdb -T postgistemplate -E utf8 -O blitzortung blitzortung
 
243
 
 
244
    psql blitzortung
 
245
 
 
246
    GRANT SELECT ON spatial_ref_sys TO blitzortung;
 
247
    GRANT SELECT ON geometry_columns TO blitzortung;
 
248
    GRANT INSERT, DELETE ON geometry_columns TO blitzortung;
 
249
 
 
250
    '''
 
251
    __metaclass__ = ABCMeta
 
252
 
 
253
    DefaultTimezone = pytz.UTC
 
254
 
 
255
    def __init__(self):
 
256
        '''
 
257
        create PostgreSQL db access object
 
258
        '''
 
259
 
 
260
        connection = "host='localhost' dbname='blitzortung' user='blitzortung' password='blitzortung'"
 
261
        self.schema_name = None
 
262
        self.cur = None
 
263
        self.conn = None
 
264
 
 
265
        self.srid = geom.Geometry.DefaultSrid
 
266
        self.tz = Base.DefaultTimezone
 
267
 
 
268
        try:
 
269
            self.conn = psycopg2.connect(connection)
 
270
            self.cur = self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor)
 
271
        except Exception, e:
 
272
            print e
 
273
 
 
274
            if self.cur != None:
 
275
                try:
 
276
                    self.cur.close()
 
277
                except NameError:
 
278
                    pass
 
279
 
 
280
            if self.conn != None:
 
281
                try:
 
282
                    self.conn.close()
 
283
                except NameError:
 
284
                    pass
 
285
 
 
286
    def is_connected(self):
 
287
        if self.conn != None:
 
288
            return not self.conn.closed
 
289
        else:
 
290
            return False
 
291
 
 
292
    def set_table_name(self, table_name):
 
293
        self.table_name = table_name
 
294
 
 
295
    def get_table_name(self):
 
296
        return self.table_name
 
297
 
 
298
    def get_full_table_name(self):
 
299
        if self.get_schema_name() != None:
 
300
            return '"' + self.get_schema_name() + '"."' + self.get_table_name() + '"'
 
301
        else:
 
302
            return self.get_table_name()
 
303
 
 
304
    def set_schema_name(self, schema_name):
 
305
        self.schema_name = schema_name
 
306
 
 
307
    def get_schema_name(self):
 
308
        return self.schema_name
 
309
 
 
310
    def get_srid(self):
 
311
        return self.srid
 
312
 
 
313
    def set_srid(self, srid):
 
314
        self.srid = srid
 
315
 
 
316
    def get_timezone(self):
 
317
        return self.tz
 
318
 
 
319
    def set_timezone(self, tz):
 
320
        self.tz = tz
 
321
 
 
322
    def commit(self):
 
323
        ''' commit pending database transaction '''
 
324
        self.conn.commit()
 
325
 
 
326
    def rollback(self):
 
327
        ''' rollback pending database transaction '''
 
328
        self.conn.rollback()
 
329
 
 
330
    @abstractmethod
 
331
    def insert(self, object):
 
332
        pass
 
333
 
 
334
    @abstractmethod
 
335
    def select(self, args):
 
336
        pass
 
337
 
 
338
 
 
339
class Stroke(Base):
 
340
    '''
 
341
    stroke db access class
 
342
 
 
343
    database table creation (as db user blitzortung, database blitzortung): 
 
344
 
 
345
    CREATE TABLE strokes (id bigserial, timestamp timestamptz, nanoseconds SMALLINT, PRIMARY KEY(id));
 
346
    SELECT AddGeometryColumn('public','strokes','the_geom','4326','POINT',2);
 
347
    GRANT SELECT ON TABLE strokes TO bogroup_ro;
 
348
 
 
349
    ALTER TABLE strokes ADD COLUMN amplitude REAL;
 
350
    ALTER TABLE strokes ADD COLUMN error2d SMALLINT;
 
351
    ALTER TABLE strokes ADD COLUMN type SMALLINT;
 
352
    ALTER TABLE strokes ADD COLUMN stationcount SMALLINT;
 
353
 
 
354
    CREATE INDEX strokes_timestamp ON strokes USING btree("timestamp");
 
355
    CREATE INDEX strokes_geom ON strokes USING gist(the_geom);
 
356
    CREATE INDEX strokes_timestamp_geom ON strokes USING gist("timestamp", the_geom);
 
357
    CREATE INDEX strokes_id_timestamp_geom ON strokes USING gist(id, "timestamp", the_geom);
 
358
 
 
359
    empty the table with the following commands:
 
360
 
 
361
    DELETE FROM strokes;
 
362
    ALTER SEQUENCE strokes_id_seq RESTART 1;
 
363
 
 
364
    '''
 
365
 
 
366
    def __init__(self):
 
367
        Base.__init__(self)
 
368
        self.set_table_name('strokes')
 
369
 
 
370
    def insert(self, stroke):
 
371
        sql = 'INSERT INTO ' + self.get_full_table_name() + \
 
372
            ' ("timestamp", nanoseconds, the_geom, amplitude, error2d, type, stationcount) ' + \
 
373
            'VALUES (\'%s\', %d, st_setsrid(makepoint(%f, %f), 4326), %f, %d, %d, %d)' \
 
374
            %(stroke.get_time(), stroke.get_nanoseconds(), stroke.get_location().x, stroke.get_location().y, stroke.get_amplitude(), stroke.get_lateral_error(), stroke.get_type(), stroke.get_station_count())
 
375
        self.cur.execute(sql)
 
376
 
 
377
    def get_latest_time(self):
 
378
        sql = 'SELECT timestamp FROM ' + self.get_full_table_name() + \
 
379
            ' ORDER BY timestamp DESC LIMIT 1'
 
380
        self.cur.execute(sql)
 
381
        if self.cur.rowcount == 1:
 
382
            result = self.cur.fetchone()
 
383
            return result['timestamp']
 
384
        else:
 
385
            return None
 
386
 
 
387
    def create(self, result):
 
388
        stroke = data.Stroke()
 
389
 
 
390
        stroke.set_time(result['timestamp'])
 
391
        stroke.set_nanoseconds(result['nanoseconds'])
 
392
        stroke.set_location(shapely.wkb.loads(result['the_geom'].decode('hex')))
 
393
        stroke.set_amplitude(result['amplitude'])
 
394
        stroke.set_type(result['type'])
 
395
        stroke.set_station_count(result['stationcount'])
 
396
        stroke.set_lateral_error(result['error2d'])
 
397
 
 
398
        return stroke
 
399
 
 
400
    def select_query(self, args, query = Query()):
 
401
        ' build up query object for select statement '
 
402
        query.set_table_name(self.get_full_table_name())
 
403
        query.set_columns(['"timestamp"', 'nanoseconds', 'st_transform(the_geom, %i) AS the_geom' % self.srid, 'amplitude', 'type', 'error2d', 'stationcount'])
 
404
        query.add_parameters({'srid': self.srid})
 
405
 
 
406
        query.add_condition('the_geom IS NOT NULL')
 
407
        query.parse_args(args)
 
408
        return query
 
409
 
 
410
    def select(self, *args):
 
411
 
 
412
        ' build up query '
 
413
        query = self.select_query(args)
 
414
 
 
415
        return self.select_execute(query)
 
416
 
 
417
    def select_raster(self, raster, *args):
 
418
 
 
419
        ' build up query '
 
420
        query = self.select_query(args, RasterQuery(raster))
 
421
 
 
422
        return self.select_execute(query)
 
423
 
 
424
    def select_execute(self, query):
 
425
        ' set timezone for query '
 
426
        self.cur.execute('SET TIME ZONE \'%s\'' %(str(self.tz)))
 
427
 
 
428
        ' perform query '
 
429
        self.cur.execute(str(query), query.get_parameters())
 
430
 
 
431
        ' collect and return data '   
 
432
        return query.get_results(self)
 
433
 
 
434
class Location(Base):
 
435
  '''
 
436
  geonames db access class
 
437
 
 
438
  CREATE SCHEMA geo;
 
439
 
 
440
  CREATE TABLE geo.geonames (id bigserial, "name" character varying, PRIMARY KEY(id));
 
441
  SELECT AddGeometryColumn('geo','geonames','the_geom','4326','POINT',2);
 
442
 
 
443
  ALTER TABLE geo.geonames ADD COLUMN "class" INTEGER;
 
444
  ALTER TABLE geo.geonames ADD COLUMN feature_class CHARACTER(1);
 
445
  ALTER TABLE geo.geonames ADD COLUMN feature_code VARCHAR;
 
446
  ALTER TABLE geo.geonames ADD COLUMN country_code VARCHAR;
 
447
  ALTER TABLE geo.geonames ADD COLUMN admin_code_1 VARCHAR;
 
448
  ALTER TABLE geo.geonames ADD COLUMN admin_code_2 VARCHAR;
 
449
  ALTER TABLE geo.geonames ADD COLUMN population INTEGER;
 
450
  ALTER TABLE geo.geonames ADD COLUMN elevation SMALLINT;
 
451
 
 
452
  CREATE INDEX geonames_geom ON geo.geonames USING gist(the_geom);
 
453
 
 
454
  '''
 
455
 
 
456
  def __init__(self):
 
457
    Base.__init__(self)
 
458
    self.set_schema_name('geo')
 
459
    self.set_table_name('geonames')
 
460
 
 
461
  def delete_all(self):
 
462
    self.cur.execute('DELETE FROM ' + self.get_full_table_name())
 
463
 
 
464
  def insert(self, line):
 
465
    fields = line.strip().split('\t')
 
466
    name = fields[1]
 
467
    latitude = float(fields[4])
 
468
    longitude = float(fields[5])
 
469
    feature_class = fields[6]
 
470
    feature_code = fields[7]
 
471
    country_code = fields[8]
 
472
    admin_code_1 = fields[10]
 
473
    admin_code_2 = fields[11]
 
474
    admin_code_3 = fields[12]
 
475
    admin_code_4 = fields[13]
 
476
    population = int(fields[14])
 
477
    if fields[15] != '':
 
478
      elevation = int(fields[15])
 
479
    else:
 
480
      elevation = -1
 
481
 
 
482
    name = name.replace("'", "''")
 
483
 
 
484
    classification = self.size_class(population)
 
485
 
 
486
    if classification is not None:
 
487
      self.cur.execute('INSERT INTO ' + self.get_full_table_name() + '''
 
488
        (the_geom, name, class, feature_class, feature_code, country_code, admin_code_1, admin_code_2, population, elevation)
 
489
      VALUES(
 
490
        GeomFromText('POINT(%f %f)', 4326), '%s', %d, '%s', '%s', '%s', '%s', '%s', %d, %d)'''
 
491
                       % (longitude, latitude, name, classification, feature_class, feature_code, country_code, admin_code_1, admin_code_2, population, elevation))
 
492
 
 
493
  def size_class(self, n):
 
494
    if n < 1:
 
495
      return None
 
496
    base = math.floor(math.log(n)/math.log(10)) - 1
 
497
    relative = n / math.pow(10, base)
 
498
    order = min(2, math.floor(relative/25))
 
499
    if base < 0:
 
500
      base = 0
 
501
    return min(15, base * 3 + order)
 
502
 
 
503
  def select(self, *args):
 
504
    self.center = None
 
505
    self.min_population = 1000
 
506
    self.max_distance = 10000
 
507
    self.limit = 10
 
508
 
 
509
    for arg in args:
 
510
      if arg != None:
 
511
        if isinstance(arg, Center):
 
512
          ' center point information given '
 
513
          self.center = arg
 
514
        elif isinstance(arg, Limit):
 
515
          ' limit information given '
 
516
          self.limit = arg
 
517
 
 
518
    if self.is_connected():
 
519
      queryString = '''SELECT
 
520
          name,
 
521
          country_code,
 
522
          admin_code_1,
 
523
          admin_code_2,
 
524
          feature_class,
 
525
          feature_code,
 
526
          elevation,
 
527
          st_transform(the_geom, %(srid)s) AS the_geom,
 
528
          population,
 
529
          distance_sphere(the_geom, c.center) AS distance,
 
530
          st_azimuth(the_geom, c.center) AS azimuth
 
531
        FROM
 
532
          (SELECT SetSRID(MakePoint(%(center_x)s, %(center_y)s), %(srid)s) as center ) as c,
 
533
          %(table_name)s
 
534
        WHERE
 
535
          feature_class='P'
 
536
          AND population >= %(min_population)s
 
537
          AND st_transform(the_geom, %(srid)s) && st_expand(c.center, %(max_distance)s) order by distance limit %(limit)s''';
 
538
 
 
539
      params = {}
 
540
      params['srid'] = self.get_srid()
 
541
      params['table_name'] = self.get_full_table_name()
 
542
      params['center_x'] = self.center.get_point().x
 
543
      params['center_y'] = self.center.get_point().y
 
544
      params['min_population'] = self.min_population
 
545
      params['max_distance'] = self.max_distance
 
546
      params['limit'] = self.limit
 
547
 
 
548
      self.cur.execute(queryString % params)
 
549
 
 
550
      locations = []
 
551
      if self.cur.rowcount > 0:
 
552
        for result in self.cur.fetchall():
 
553
          location = {}
 
554
          location['name'] = result['name']
 
555
          location['distance'] = result['distance']
 
556
          location['azimuth'] = result['azimuth']
 
557
          locations.append(location)
 
558
 
 
559
      return locations
 
560