# Copyright 2014 Hewlett-Packard Development Company, L.P.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.

from collections import OrderedDict
import copy
import fnmatch
import glob
import importlib
from inspect import getmembers
from inspect import isfunction
import os
import sys

class BanditTestSet():
32
def __init__(self, logger, config, profile=None):
    '''Builds the test set from the configured plugins directory.

    :param logger: logger the other methods write debug/error output to
    :param config: object whose get_setting() supplies the 'plugins_dir'
        and 'plugin_name_pattern' settings read by load_tests
    :param profile: optional dict with 'include'/'exclude' lists of test
        names used to filter the loaded tests
    '''
    # Store these before doing any work: the other methods read
    # self.logger and self.config.
    self.logger = logger
    self.config = config
    filter_list = self._filter_list_from_config(profile=profile)
    self.load_tests(filter=filter_list)

def _filter_list_from_config(self, profile=None):
    '''Builds an (include, exclude) tuple of test-name lists from a profile.

    An empty include list means that all tests are included; an empty
    exclude list means that none are excluded. Specifically named excludes
    are later subtracted from the included tests.

    :param profile: optional dict that may hold 'include' and 'exclude'
        lists of test names. Anything that is not a dict (including None)
        yields two empty lists. A non-list value, or an include list that
        contains 'all', leaves that side empty.
    :return: (include_list, exclude_list) tuple of newly created lists
    '''
    include_list = []
    exclude_list = []

    # if a profile isn't set, there is nothing to do here
    if not isinstance(profile, dict):
        return (include_list, exclude_list)

    # include needs to be a list and must not contain 'all' ('all' is the
    # same as an empty include list: every test is included)
    includes = profile.get('include')
    if isinstance(includes, list) and 'all' not in includes:
        include_list = list(includes)

    # NOTE(review): an older comment said an exclude list containing 'all'
    # excludes every test, but 'all' is passed through like any other name
    # here, and the test filter only removes tests whose names appear in
    # the list - confirm the intended behaviour.
    excludes = profile.get('exclude')
    if isinstance(excludes, list):
        exclude_list = list(excludes)

    self.logger.debug(
        "_filter_list_from_config completed - include: %s, exclude %s",
        include_list, exclude_list
    )
    return (include_list, exclude_list)

def _filter_tests(self, filter):
    '''Filters the test set according to the filter tuple

    Filters the test set according to the filter tuple which contains
    include and exclude lists. An empty include list keeps every test,
    otherwise only the named tests are kept; any test named in the exclude
    list is then removed. A check type whose tests are all filtered out
    stays in self.tests with an empty dictionary.

    :param filter: Include, exclude lists tuple; None applies no filtering
    :return: None -- self.tests is replaced by the filtered dictionary
    '''
    if filter is None:
        include_list, exclude_list = [], []
    else:
        include_list = filter[0]
        exclude_list = filter[1]

    def _keep(test_name):
        # if the include list is empty, every test is included; if it
        # isn't, only the tests named in it are
        if include_list and test_name not in include_list:
            return False
        # remove the items specified in exclude list
        return test_name not in exclude_list

    # Build a new dictionary instead of deleting entries while iterating.
    # The test functions themselves are shared with the old dictionary,
    # just as copy.deepcopy shared them (it returns functions unchanged).
    filtered = OrderedDict()
    for check_type, tests in self.tests.items():
        filtered[check_type] = {
            name: test for name, test in tests.items() if _keep(name)
        }
    self.tests = filtered

    self.logger.debug('obtained filtered set of tests:')
    for check_type, tests in self.tests.items():
        self.logger.debug('\t%s : %s', check_type, tests)

def _get_decorators_list(self):
    '''Returns a list of decorator function names

    Returns a list of decorator function names so that they can be
    ignored when discovering test function names.

    :return: list with the name of every function attribute of the
        bandit.core.test_properties module
    '''
    # we need to know the name of the decorators so we can automatically
    # ignore them when discovering functions
    decorator_source_file = "bandit.core.test_properties"
    module = importlib.import_module(decorator_source_file)

    return_list = [name for name, _ in getmembers(module, isfunction)]
    self.logger.debug('_get_decorators_list returning: %s', return_list)
    return return_list

def load_tests(self, filter=None):
    '''Loads all tests in the plugins directory into the tests dictionary.

    Each file in the 'plugins_dir' setting whose name matches the
    'plugin_name_pattern' setting is imported as '<dir name>.<module>'.
    Every function in it that has a ``_checks`` attribute, and is not one
    of the decorator functions, is registered under each check type it
    lists. The collected tests are then passed through _filter_tests.

    :param filter: Include, exclude lists tuple handed to _filter_tests
    :return: None -- the result is stored in self.tests
    '''
    # tests are a dictionary of functions, grouped by check type
    # where the key is the function name, and the value is the
    # function itself.
    #   eg. tests[check_type][fn_name] = function
    # Always start from an empty dictionary so nothing from an earlier load
    # (or a pre-existing, possibly shared, 'tests' attribute) leaks in.
    self.tests = OrderedDict()

    directory = self.config.get_setting('plugins_dir')
    plugin_name_pattern = self.config.get_setting('plugin_name_pattern')

    decorators = self._get_decorators_list()

    # plugins are imported as '<outer>.<module_name>', so the parent of the
    # plugins directory has to be on sys.path; add it only once so repeated
    # loads don't keep growing sys.path
    plugins_parent = os.path.dirname(directory)
    if plugins_parent not in sys.path:
        sys.path.append(plugins_parent)
    outer = os.path.basename(os.path.normpath(directory))

    # Match plugin files the way glob.glob1 did (that helper is undocumented
    # and deprecated since Python 3.13): an unreadable or missing directory
    # yields nothing, and hidden files are skipped unless the pattern itself
    # starts with '.'. Sorting makes the load order - and therefore which
    # definition wins a name collision - deterministic.
    try:
        candidates = os.listdir(directory)
    except OSError:
        candidates = []
    if not plugin_name_pattern.startswith('.'):
        candidates = [c for c in candidates if not c.startswith('.')]
    plugin_files = sorted(fnmatch.filter(candidates, plugin_name_pattern))

    for plugin_file in plugin_files:
        module_name = plugin_file.split('.')[0]
        module_path = outer + '.' + module_name

        # try to import the module by name
        self.logger.debug("importing plugin module: %s", module_path)
        try:
            module = importlib.import_module(module_path)
        except ImportError as e:
            # NOTE(review): a plugin that fails to import is logged and
            # skipped; confirm whether this should abort the run instead.
            self.logger.error("could not import plugin module '%s.%s'",
                              directory, module_name)
            self.logger.error("\tdetail: '%s'", str(e))
            continue

        # otherwise register every function in the module, unless it's one
        # of our decorators or it doesn't declare any check types
        for fn_name, function in getmembers(module, isfunction):
            if fn_name in decorators or not hasattr(function, '_checks'):
                continue
            for check in function._checks:
                # first time this check type is seen: start it empty
                tests_for_check = self.tests.setdefault(check, {})
                if fn_name in tests_for_check:
                    # on a test name collision keep the first definition
                    # NOTE(review): an older comment said to "bail" here;
                    # confirm whether loading should abort instead.
                    self.logger.error(
                        "Duplicate function definition "
                        "%s in %s", fn_name, plugin_file
                    )
                    continue
                tests_for_check[fn_name] = function
                self.logger.debug(
                    'added function %s targetting %s',
                    fn_name, check
                )

    self._filter_tests(filter)

def get_tests(self, checktype):
204
'''Returns all tests that are of type checktype
206
:param checktype: The type of test to filter on
207
:return: A dictionary of tests which are of the specified type
210
self.logger.debug('get_tests called with check type: %s', checktype)
211
if checktype in self.tests:
212
scoped_tests = self.tests[checktype]
213
self.logger.debug('get_tests returning scoped_tests : %s',