~ben-hutchings/ensoft-sextant/upload-perf

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
# -----------------------------------------
# Sextant
# Copyright 2014, Ensoft Ltd.
# Author: Patrick Stevens, using work from Patrick Stevens and James Harkin
# -----------------------------------------
# API to interact with a Neo4J server: upload, query and delete programs in a DB

# Names exported by `from <this module> import *`.
# Note that set_common_cutoff is defined in this module but is not listed
# here, so it has to be imported by name.
__all__ = ("Validator", "AddToDatabase", "FunctionQueryResult", "Function",
           "SextantConnection")

import re  # for validation of function/program names
import logging
from datetime import datetime
import os
import getpass
from collections import namedtuple

from neo4jrestclient.client import GraphDatabase
import neo4jrestclient.client as client

# Threshold for marking a function as 'common': Function.is_common becomes
# True once more than this many functions call it. The value can be changed
# at runtime with set_common_cutoff().
COMMON_CUTOFF = 10


class Validator():
    """Sanitises/checks strings, to prevent Cypher injection attacks."""

    # Whitelist of characters which may appear in a name that is interpolated
    # into a Cypher query. The pattern is anchored with \Z rather than $,
    # because $ also matches just before a trailing newline and would let
    # "name\n" through.
    _ALLOWED = re.compile(r'^[A-Za-z0-9\-:\.\$_@\*\(\)%\+,]+\Z')

    # Any run of characters outside the (narrower) set kept by sanitise().
    _STRIPPED = re.compile(r'[^\.\-_a-zA-Z0-9]+')

    @staticmethod
    def validate(input_):
        """
        Checks whether we can allow a string to be passed into a Cypher query.
        The string must be non-empty and consist only of letters, digits and
        the characters - : . $ _ @ * ( ) % + ,
        :param input_: the string we wish to validate
        :return: bool(the string is allowed); False for non-string input
        """
        try:
            match = Validator._ALLOWED.match(input_)
        except TypeError:
            # not a string at all (e.g. None), so certainly not allowed
            return False
        return match is not None

    @staticmethod
    def sanitise(input_):
        """
        Strips harmful characters from the given string.
        Only letters, digits and the characters . - _ are kept; note that this
        is a smaller set than the one accepted by validate().
        :param input_: string to sanitise
        :return: the sanitised string
        """
        return Validator._STRIPPED.sub('', input_)


class AddToDatabase():
    """Collects the functions and calls of one program and uploads them."""

    def __init__(self, program_name='', sextant_connection=None,
                 uploader='', uploader_id='', date=None):
        """
        Object which can be used to add functions and calls to a new program.
        When a connection is given, the program node is created in the
        database straight away; functions and calls are only queued, and are
        sent when commit() is called.
        :param program_name: the name of the new program to be created
          (must already be validated against Validator)
        :param sextant_connection: the SextantConnection to use for connections
        :param uploader: string identifier of user who is uploading
        :param uploader_id: Unix user-id of logged-in user
        :param date: string date of today
        """
        # The name ends up inside Cypher queries, so refuse anything which
        # does not validate. Note that in this case we return before any
        # attribute has been set on the object.
        if not Validator.validate(program_name):
            return

        self.program_name = program_name
        self.parent_database_connection = sextant_connection

        # transactions for uploading functions and relationships respectively
        self._funcs_tx = None
        self._calls_tx = None

        # maps each known function name to the set of names it calls
        self._calldict = {}

        if not self.parent_database_connection:
            # nothing to talk to: leave the transactions unset
            return

        database = self.parent_database_connection._db

        program_node = database.nodes.create(name=program_name,
                                             type='program',
                                             uploader=uploader,
                                             uploader_id=uploader_id,
                                             date=date)
        self._parent_id = program_node.id

        self._funcs_tx = database.transaction(using_globals=False,
                                              for_query=True)
        self._calls_tx = database.transaction(using_globals=False,
                                              for_query=True)

    @staticmethod
    def _get_display_name(function_name):
        """
        Works out the name to show the user for a raw function name, together
        with the kind of function it is.

        For instance, __libc_start_main@plt gives
        ("__libc_start_main", "plt_stub"). The function type is currently one
        of "plt_stub", "function_pointer" or "normal".

        :param function_name: the name straight from objdump of a function
        :return: ("display name", "function type")

        """
        if function_name.endswith("@plt"):
            # drop the "@plt" suffix for display
            return function_name[:-4], "plt_stub"

        if function_name.startswith("_._function_pointer_"):
            return function_name, "function_pointer"

        return function_name, "normal"

    def add_function(self, function_name):
        """
        Queues a function of the program, ready to be sent to the remote
        database by commit().
        If the function name is already in use, this method effectively does
          nothing and returns True.

        :param function_name: a string which must pass Validator.validate
        :return: True if the request succeeded, False otherwise
        """
        if not Validator.validate(function_name):
            return False

        if function_name in self._calldict:
            # already queued
            return True

        display_name, function_group = self._get_display_name(function_name)

        # create the function node as a child of the program node, and hand
        # back its name and database id (commit() needs these for the calls)
        template = ('START n = node({}) '
                    'CREATE (n)-[:subject]->(m:func {{type: "{}", name: "{}"}}) '
                    'RETURN m.name, id(m)')
        statement = template.format(self._parent_id, function_group,
                                    display_name)

        self._funcs_tx.append(statement)
        self._calldict[function_name] = set()

        return True

    def add_function_call(self, fn_calling, fn_called):
        """
        Records a function call of the program, ready to be sent to the
        database by commit().
        Effectively does nothing if there is already a function call between
          these two functions.
        Function names must pass Validator.validate for easy security
          purposes; returns False if they fail validation.
        :param fn_calling: the name of the calling-function as a string.
          It should already exist in the AddToDatabase; if it does not,
          this method will create a stub for it.
        :param fn_called: name of the function called by fn_calling.
          If it does not exist, we create a stub representation for it.
        :return: True if successful, False otherwise
        """
        calling_ok = Validator.validate(fn_calling)
        called_ok = Validator.validate(fn_called)
        if not (calling_ok and called_ok):
            return False

        # make sure both ends of the call exist (callee first, then caller)
        for name in (fn_called, fn_calling):
            if name not in self._calldict:
                self.add_function(name)

        self._calldict[fn_calling].add(fn_called)

        return True

    def commit(self):
        """
        Call this when you are finished with the object.
        Changes are not synced to the remote database until this is called.
        The functions are sent first; the ids which come back are then used
        to build and send the calls.
        """
        # Each query result has a single row, (name, id), of the new node.
        name_id_rows = (result[0] for result in self._funcs_tx.commit())
        # so id_funcs is a dict with id_funcs['name'] == id
        id_funcs = dict(name_id_rows)

        logging.info('Functions uploaded.')

        for caller, called in self._calldict.items():
            if not called:
                # this function calls nothing: no relationships to create
                continue

            # add all the connections for this caller in one query;
            # ids are keyed by display name, so translate the raw names
            caller_id = id_funcs[self._get_display_name(caller)[0]]
            id_list = ','.join(str(id_funcs[self._get_display_name(fn)[0]])
                               for fn in called)

            template = (' MATCH n WHERE id(n) = {}'
                        ' UNWIND[{}] as called_id'
                        ' MATCH m WHERE id(m) = called_id'
                        ' CREATE (n)-[:calls]->(m)')

            self._calls_tx.append(template.format(caller_id, id_list))

        self._calls_tx.commit()
        logging.info('Calls uploaded')


class FunctionQueryResult:
    """
    A graph of function calls arising as the result of a Neo4J query.

    Attributes:
      program_name -- name of the program this result reflects
      functions -- list of the Function objects found in the query output,
        each with functions_i_call filled in
    """

    def __init__(self, parent_db, program_name='', rest_output=None):
        """
        Build the call graph from the output of a query.
        :param parent_db: database object used to look up the calls made by
          each function in the result. If None, the calls are instead read
          from the nodes in rest_output (useful for unit testing).
        :param program_name: name of the program the result reflects
        :param rest_output: output of the query, or None for an empty result
        """
        self.program_name = program_name
        self._parent_db_connection = parent_db
        self.functions = self._rest_node_output_to_graph(rest_output)
        self._update_common_functions()

    def __eq__(self, other):
        """
        Two results are equal if they are for the same program name and hold
        equal functions under the same names.
        """
        if not isinstance(other, FunctionQueryResult):
            # let Python fall back to its default comparison rather than
            # failing with AttributeError on e.g. `result == None`
            return NotImplemented

        # we make a dictionary so that we can perform easy comparison
        selfdict = {func.name: func for func in self.functions}
        otherdict = {func.name: func for func in other.functions}

        return self.program_name == other.program_name and selfdict == otherdict

    def _update_common_functions(self):
        """
        Loop over all functions: increment the called-by count of their callees.
        """
        for func in self.functions:
            for called in func.functions_i_call:
                called.number_calling_me += 1

    def _rest_node_output_to_graph(self, rest_output):
        """
        Convert the output of a REST API query into our internal representation.
        :param rest_output: output of the REST call as a Neo4j QuerySequence
        :return: iterable of <Function>s ready to initialise self.functions.
        """

        if rest_output is None or not rest_output.elements:
            return []

        # how we store this is: a dict
        #   with keys  'functionname'
        #   and values [the function object we will use,
        #               and a set of (function names this function calls),
        #               and numeric ID of this node in the Neo4J database]

        result = {}

        # initial pass for names of functions

        # if the following check fails, we've probably called db.query
        # to get it to not return client.Node objects, which is wrong.
        # we attempt to handle this a bit later; this should never arise, but
        # we can cope with it happening in some cases, like the test suite

        if not isinstance(rest_output.elements, list):
            logging.warning('Not a list: {}'.format(type(rest_output.elements)))

        for node_list in rest_output.elements:
            assert(isinstance(node_list, list))
            for node in node_list:
                if isinstance(node, client.Node):
                    name = node.properties['name']
                    node_id = node.id
                    node_type = node.properties['type']
                else:  # this is the handling we mentioned earlier;
                    # we are a dictionary instead of a Node, as for some
                    # reason we've returned Raw rather than Node data.
                    # We should never reach this code, but just in case.
                    name = node['data']['name']
                    # hacky workaround to get the id: it is the last
                    # component of the node's own URL
                    node_id = node['self'].split('/')[-1]
                    node_type = node['data']['type']

                result[name] = [Function(self.program_name,
                                         function_name=name,
                                         function_type=node_type),
                                set(),
                                node_id]

        # end initialisation of names-dictionary

        if self._parent_db_connection is not None:
            # This is the normal case, of extracting results from a server.
            # We leave the other case in because it is useful for unit testing.

            # We collect the name-name pairs of caller-callee, batched for speed
            new_tx = self._parent_db_connection.transaction(using_globals=False,
                                                            for_query=True)
            for index in result:
                q = ("START n=node({})"
                     "MATCH n-[calls:calls]->(m)"
                     "RETURN n.name, m.name").format(result[index][2])
                new_tx.append(q)

            logging.debug('exec')
            results = new_tx.execute()

            # results is a list of query results, each of those being a list of
            # calls.

            for call_list in results:
                if call_list:
                    # call_list has element 0 being an arbitrary call this
                    # function makes; element 0 of that call is the name of the
                    # function itself. Think {{'orig', 'b'}, {'orig', 'c'}}.
                    orig = call_list[0][0]
                    # result['orig'] is [<Function>, ('callee1','callee2')];
                    # zip(*rows)[1] is the column of callee names
                    result[orig][1] |= set(list(zip(*call_list.elements))[1])
                    # recall: set union is denoted by |

        else:
            # we don't have a parent database connection.
            # This has probably arisen because we created this object from a
            # test suite, or something like that.
            # Only the first node of each row is consulted here.
            for node in rest_output.elements:
                node_name = node[0].properties['name']
                result[node_name][1] |= {relationship.end.properties['name']
                                         for relationship in node[0].relationships.outgoing()}

        logging.debug('Relationships complete.')

        # Turn each set of callee names into a list of Function objects,
        # silently dropping callees which are not part of this result.
        for function, calls, node_id in result.values():
            function.functions_i_call = [result[name][0]
                                         for name in calls
                                         if name in result]

        return [list_element[0]
                for list_element in result.values()
                if list_element[0]]

    def get_functions(self):
        """
        :return: a list of Function objects present in the query result
        """
        return self.functions

    def get_function(self, name):
        """
        Given a function name, returns the Function object which has that name.
        If no function with that name exists, returns None.
        """
        return next((func for func in self.functions if func.name == name),
                    None)


def set_common_cutoff(common_def):
    """
    Changes the module-wide COMMON_CUTOFF threshold: a function is deemed
    'common' when more than this many functions call it.
    If this is never called, the threshold stays at its default of 10.
    :param common_def: the new number of incoming connections to use
    """
    global COMMON_CUTOFF
    COMMON_CUTOFF = common_def


class Function(object):
    """
    Represents a function which might appear in a FunctionQueryResult.

    Attributes:
      parent_program -- name of the program the function belongs to
      name -- name of the function
      type -- the function type string it was constructed with
      functions_i_call -- list of the Function objects this function calls
      attributes -- list, empty on construction
      is_common -- True once number_calling_me exceeds COMMON_CUTOFF
      number_calling_me -- how many functions are known to call this one
    """

    def __init__(self, program_name='', function_name='', function_type=''):
        """
        :param program_name: name of the program the function belongs to
        :param function_name: name of the function
        :param function_type: string describing the type of the function
        """
        self.parent_program = program_name
        self.attributes = []
        self.type = function_type
        self.functions_i_call = []
        self.name = function_name
        self.is_common = False
        self._number_calling_me = 0
        # care: _number_calling_me is not automatically updated, except by
        # any invocation of FunctionQueryResult._update_common_functions.

    def __eq__(self, other):
        """
        Functions are equal if they agree on program, name and attributes and
        call functions with the same set of names. The callees are compared by
        name only, so this does not recurse through the call graph.
        """
        if not isinstance(other, Function):
            # let Python fall back to its default comparison rather than
            # failing with AttributeError on e.g. `func == None`
            return NotImplemented

        my_callee_names = {func.name for func in self.functions_i_call}
        other_callee_names = {func.name for func in other.functions_i_call}

        return (self.parent_program == other.parent_program
                and self.name == other.name
                and my_callee_names == other_callee_names
                and self.attributes == other.attributes)

    @property
    def number_calling_me(self):
        """Number of functions recorded as calling this one."""
        return self._number_calling_me

    @number_calling_me.setter
    def number_calling_me(self, value):
        # is_common is derived from the count, so keep the two in step
        self._number_calling_me = value
        self.is_common = (self._number_calling_me > COMMON_CUTOFF)


class SextantConnection:
    """
    RESTful connection to a remote database.
    It can be used to create/delete/query programs.

    Every program or function name which is interpolated into a Cypher query
    is first checked with Validator.validate, either directly or by way of
    check_program_exists / check_function_exists.
    """

    # One row of the output of programs_with_metadata.
    ProgramWithMetadata = namedtuple('ProgramWithMetadata',
                                     ['uploader', 'uploader_id',
                                      'program_name', 'date',
                                      'number_of_funcs'])

    def __init__(self, url):
        """
        :param url: URL of the database, passed to GraphDatabase
        """
        self.url = url
        self._db = GraphDatabase(url)

    def new_program(self, name_of_program):
        """
        Request that the remote database create a new program with the given name.
        This procedure will create a new program remotely; you can manipulate
          that program using the returned AddToDatabase object.
        The name can appear in the database already, but this is not recommended
          because then delete_program will not know which to delete. Check first
          using self.check_program_exists.
        The name specified must pass Validator.validate()ion; this is a measure
          to prevent Cypher injection attacks.
        The program is recorded as uploaded by the current user (name and
          Unix user-id) at the current local time.
        :param name_of_program: string program name
        :return: AddToDatabase instance if successful
        :raise ValueError: if name_of_program fails validation
        """

        if not Validator.validate(name_of_program):
            raise ValueError(
                "{} is not a valid program name".format(name_of_program))

        uploader = getpass.getuser()
        uploader_id = os.getuid()

        return AddToDatabase(sextant_connection=self,
                             program_name=name_of_program,
                             uploader=uploader, uploader_id=uploader_id,
                             date=str(datetime.now()))

    def delete_program(self, name_of_program):
        """
        Request that the remote database delete a specified program.
        :param name_of_program: a string which must pass Validator.validate
        :return: False if the name fails validation, True once the delete
          query has been sent
        """
        if not Validator.validate(name_of_program):
            return False

        # delete the program node n, everything directly related to it (b),
        # and all relationships of either
        q = """MATCH (n) WHERE n.name= "{}" AND n.type="program"
        OPTIONAL MATCH (n)-[r]-(b) OPTIONAL MATCH (b)-[rel]-()
        DELETE  b,rel DELETE n, r""".format(name_of_program)

        self._db.query(q)

        return True

    def _execute_query(self, prog_name='', query=''):
        """
        Executes a Cypher query against the remote database.
        Note that this returns a FunctionQueryResult, so is unsuitable for any
          other expected outputs (such as lists of names). For those instances,
          it is better to run self._db.query explicitly.
        Intended only to be used for non-updating queries
          (such as "get functions" rather than "create").
        The query is sent verbatim, so any names in it must already have been
          validated by the caller.
        :param prog_name: name of the program the result object will reflect
        :param query: verbatim query we wish the server to execute
        :return: a FunctionQueryResult corresponding to the server's output
        """
        rest_output = self._db.query(query, returns=client.Node)

        return FunctionQueryResult(parent_db=self._db,
                                   program_name=prog_name,
                                   rest_output=rest_output)

    def get_program_names(self):
        """
        Execute query to retrieve the names of all programs in the database.
        Any name in the result can be used verbatim in any SextantConnection
          method which requires a program-name input.
        :return: a set of program-name strings.
        """
        q = """MATCH (n) WHERE n.type = "program" RETURN n.name"""
        program_names = self._db.query(q, returns=str).elements

        # each row holds a single column, the name
        return {el[0] for el in program_names}

    def programs_with_metadata(self):
        """
        Returns a set of namedtuples which represent the current database.

        The namedtuples have .uploader, .uploader_id, .program_name, .date,
        .number_of_funcs.
        :return: set of namedtuples

        """

        q = ("MATCH (base) WHERE base.type = 'program' "
             "MATCH (base)-[:subject]->(n)"
             "RETURN base.uploader, base.uploader_id, base.name, base.date, count(n)")
        result = self._db.query(q)
        return {self.ProgramWithMetadata(*res) for res in result}

    def check_program_exists(self, program_name):
        """
        Execute query to check whether a program with the given name exists.
        Returns False if the program_name fails validation against Validator.
        :param program_name: string name of the program to look for
        :return: bool(the program exists in the database).
        """

        if not Validator.validate(program_name):
            return False

        q = ("MATCH (base) WHERE base.name = '{}' AND base.type = 'program' "
             "RETURN count(base)").format(program_name)

        result = self._db.query(q, returns=int)
        return result.elements[0][0] > 0

    def check_function_exists(self, program_name, function_name):
        """
        Execute query to check whether a function with the given name exists.
        We only check for functions which are children of a program with the
          given program_name.
        The other query methods rely on this check to keep unvalidated
          function names out of their queries.
        :param program_name: string name of the program within which to check
        :param function_name: string name of the function to check for existence
        :return: bool(names validate correctly, and function exists in program)
        """
        # this also validates program_name
        if not self.check_program_exists(program_name):
            return False

        # function_name is interpolated into the query below, so it must be
        # validated as well
        if not Validator.validate(function_name):
            return False

        q = ("MATCH (base) WHERE base.name = '{}' AND base.type = 'program'"
             "MATCH (base)-[r:subject]->(m) WHERE m.name = '{}'"
             "RETURN count(m)").format(program_name, function_name)

        result = self._db.query(q, returns=int)
        return result.elements[0][0] > 0

    def get_function_names(self, program_name):
        """
        Execute query to retrieve the names of all functions in the program.
        Any of the output names can be used verbatim in any SextantConnection
          method which requires a function-name input.
        :param program_name: name of the program whose functions to retrieve
        :return: None if program_name doesn't exist in the remote database,
          a set of function-name strings otherwise.
        """

        if not self.check_program_exists(program_name):
            return None

        q = ("MATCH (base) WHERE base.name = '{}' AND base.type = 'program' "
             "MATCH (base)-[r:subject]->(m) "
             "RETURN  m.name").format(program_name)
        return {func[0] for func in self._db.query(q)}

    def get_all_functions_called(self, program_name, function_calling):
        """
        Execute query to find all functions called by a function (indirectly).
        If the given function is not present in the program, returns None;
          likewise if the program_name does not exist.
        :param program_name: a string name of the program we wish to query under
        :param function_calling: string name of a function whose children to find
        :return: FunctionQueryResult, maximal subgraph rooted at function_calling
        """

        if not self.check_program_exists(program_name):
            return None

        if not self.check_function_exists(program_name, function_calling):
            return None

        # @@@ type in query - does it matter?
        q = """MATCH (base) WHERE base.name = '{}' ANd base.type = 'program'
            MATCH (base)-[:subject]->(m) WHERE m.name='{}'
            MATCH (m)-[:calls*]->(n)
            RETURN distinct n, m""".format(program_name, function_calling)

        return self._execute_query(program_name, q)

    def get_all_functions_calling(self, program_name, function_called):
        """
        Execute query to find all functions which call a function (indirectly).
        If the given function is not present in the program, returns None;
          likewise if the program_name does not exist.
        :param program_name: a string name of the program we wish to query
        :param function_called: string name of a function whose parents to find
        :return: FunctionQueryResult, maximal connected subgraph with leaf function_called
        """

        if not self.check_program_exists(program_name):
            return None

        if not self.check_function_exists(program_name, function_called):
            return None

        # callers whose own name equals the program name are excluded
        q = """MATCH (base) WHERE base.name = '{}' AND base.type = 'program'
            MATCH (base)-[r:subject]->(m) WHERE m.name='{}'
            MATCH (n)-[:calls*]->(m) WHERE n.name <> '{}'
            RETURN distinct n , m"""
        q = q.format(program_name, function_called, program_name)

        return self._execute_query(program_name, q)

    def get_call_paths(self, program_name, function_calling, function_called):
        """
        Execute query to find all possible routes between two specific nodes.
        If the given functions are not present in the program, returns None;
          ditto if the program_name does not exist.
        :param program_name: string program name
        :param function_calling: string
        :param function_called: string
        :return: FunctionQueryResult, the union of all subgraphs reachable by
          adding a source at function_calling and a sink at function_called.
        """

        if not self.check_program_exists(program_name):
            return None

        if not self.check_function_exists(program_name, function_called):
            return None

        if not self.check_function_exists(program_name, function_calling):
            return None

        q = r"""MATCH (pr) WHERE pr.name = '{}' AND pr.type = 'program'
                MATCH p=(start {{name: "{}" }})-[:calls*]->(end {{name:"{}"}})
                  WHERE (pr)-[:subject]->(start)
                WITH DISTINCT nodes(p) AS result
                UNWIND result AS answer
                RETURN answer"""
        q = q.format(program_name, function_calling, function_called)

        return self._execute_query(program_name, q)

    def get_whole_program(self, program_name):
        """Execute query to find the entire program with a given name.
        If the program is not present in the remote database, returns None.
        :param: program_name: a string name of the program we wish to return.
        :return: a FunctionQueryResult consisting of the program graph.
        """

        if not self.check_program_exists(program_name):
            return None

        query = """MATCH (base) WHERE base.name = '{}' AND base.type = 'program'
                MATCH (base)-[subject:subject]->(m)
                RETURN DISTINCT (m)""".format(program_name)

        return self._execute_query(program_name, query)

    def get_shortest_path_between_functions(self, program_name, func1, func2):
        """
        Execute query to get a single, shortest, path between two functions.
        If the program or either function does not exist, returns None.
        :param program_name: string name of the program we wish to search under
        :param func1: the name of the originating function of our shortest path
        :param func2: the name of the function at which to terminate the path
        :return: FunctionQueryResult shortest path between func1 and func2.
        """
        if not self.check_program_exists(program_name):
            return None

        if not self.check_function_exists(program_name, func1):
            return None

        if not self.check_function_exists(program_name, func2):
            return None

        # NOTE(review): unlike the other queries, this one matches func1 and
        # func2 by name alone and does not mention program_name, so it looks
        # like it could pick up same-named functions of another program -
        # TODO confirm.
        q = """MATCH (func1 {{ name:"{}" }}),(func2 {{ name:"{}" }}),
            p = shortestPath((func1)-[:calls*]->(func2))
            UNWIND nodes(p) AS ans
            RETURN ans""".format(func1, func2)

        return self._execute_query(program_name, q)