~tribaal/txaws/xss-hardening

« back to all changes in this revision

Viewing changes to txaws/server/schema.py

  • Committer: Duncan McGreggor
  • Date: 2009-11-22 02:20:42 UTC
  • mto: (44.3.2 484858-s3-scripts)
  • mto: This revision was merged to the branch mainline in revision 52.
  • Revision ID: duncan@canonical.com-20091122022042-4zi231hxni1z53xd
* Updated the LICENSE file with copyright information.
* Updated the README with license information.

Show diffs side-by-side

added added

removed removed

Lines of Context:
1
 
from datetime import datetime
2
 
from operator import itemgetter
3
 
 
4
 
from dateutil.tz import tzutc
5
 
from dateutil.parser import parse
6
 
 
7
 
from txaws.server.exception import APIError
8
 
 
9
 
 
10
 
class SchemaError(APIError):
11
 
    """Raised when failing to extract or bundle L{Parameter}s."""
12
 
 
13
 
    def __init__(self, message):
14
 
        code = self.__class__.__name__[:-len("Error")]
15
 
        super(SchemaError, self).__init__(400, code=code, message=message)
16
 
 
17
 
 
18
 
class MissingParameterError(SchemaError):
19
 
    """Raised when a parameter is missing.
20
 
 
21
 
    @param name: The name of the missing parameter.
22
 
    """
23
 
 
24
 
    def __init__(self, name, kind=None):
25
 
        message = "The request must contain the parameter %s" % name
26
 
        if kind is not None:
27
 
            message += " (%s)" % (kind,)
28
 
        super(MissingParameterError, self).__init__(message)
29
 
 
30
 
 
31
 
class InconsistentParameterError(SchemaError):
32
 
    def __init__(self, name):
33
 
        message = "Parameter %s is used inconsistently" % (name,)
34
 
        super(InconsistentParameterError, self).__init__(message)
35
 
 
36
 
 
37
 
class InvalidParameterValueError(SchemaError):
38
 
    """Raised when the value of a parameter is invalid."""
39
 
 
40
 
 
41
 
class InvalidParameterCombinationError(SchemaError):
42
 
    """
43
 
    Raised when there is more than one parameter with the same name,
44
 
    when this isn't explicitly allowed for.
45
 
 
46
 
    @param name: The name of the missing parameter.
47
 
    """
48
 
 
49
 
    def __init__(self, name):
50
 
        message = "The parameter '%s' may only be specified once." % name
51
 
        super(InvalidParameterCombinationError, self).__init__(message)
52
 
 
53
 
 
54
 
class UnknownParameterError(SchemaError):
55
 
    """Raised when a parameter to extract is unknown."""
56
 
 
57
 
    def __init__(self, name):
58
 
        message = "The parameter %s is not recognized" % name
59
 
        super(UnknownParameterError, self).__init__(message)
60
 
 
61
 
 
62
 
class UnknownParametersError(Exception):
63
 
    """
64
 
    Raised when extra unknown fields are passed to L{Structure.parse}.
65
 
 
66
 
    @ivar result: The already coerced result representing the known parameters.
67
 
    @ivar unknown: The unknown parameters.
68
 
    """
69
 
    def __init__(self, result, unknown):
70
 
        self.result = result
71
 
        self.unknown = unknown
72
 
        message = "The parameters %s are not recognized" % (unknown,)
73
 
        super(UnknownParametersError, self).__init__(message)
74
 
 
75
 
 
76
 
class Parameter(object):
77
 
    """A single parameter in an HTTP request.
78
 
 
79
 
    @param name: A name for the key of the parameter, as specified
80
 
        in a request. For example, a single parameter would be specified
81
 
        simply as 'GroupName'.  If more than one group name was accepted,
82
 
        it would be specified as 'GroupName.n'.  A more complex example
83
 
        is 'IpPermissions.n.Groups.m.GroupName'.
84
 
    @param optional: If C{True} the parameter may not be present.
85
 
    @param default: A default value for the parameter, if not present.
86
 
    @param min: Minimum value for a parameter.
87
 
    @param max: Maximum value for a parameter.
88
 
    @param allow_none: Whether the parameter may be C{None}.
89
 
    @param validator: A callable to validate the parameter, returning a bool.
90
 
    """
91
 
 
92
 
    supports_multiple = False
93
 
    kind = None
94
 
 
95
 
    def __init__(self, name=None, optional=False, default=None,
96
 
                 min=None, max=None, allow_none=False, validator=None,
97
 
                 doc=None):
98
 
        self.name = name
99
 
        self.optional = optional
100
 
        self.default = default
101
 
        self.min = min
102
 
        self.max = max
103
 
        self.allow_none = allow_none
104
 
        self.validator = validator
105
 
        self.doc = doc
106
 
 
107
 
    def coerce(self, value):
108
 
        """Coerce a single value according to this parameter's settings.
109
 
 
110
 
        @param value: A L{str}, or L{None}. If L{None} is passed - meaning no
111
 
            value is avalable at all, not even the empty string - and this
112
 
            parameter is optional, L{self.default} will be returned.
113
 
        """
114
 
        if value is None:
115
 
            if self.optional:
116
 
                return self.default
117
 
            else:
118
 
                value = ""
119
 
        if value == "":
120
 
            if not self.allow_none:
121
 
                raise MissingParameterError(self.name, kind=self.kind)
122
 
            return self.default
123
 
        try:
124
 
            self._check_range(value)
125
 
            parsed = self.parse(value)
126
 
            if self.validator and not self.validator(parsed):
127
 
                raise ValueError(value)
128
 
            return parsed
129
 
        except ValueError:
130
 
            try:
131
 
                value = value.decode("utf-8")
132
 
                message = "Invalid %s value %s" % (self.kind, value)
133
 
            except UnicodeDecodeError:
134
 
                message = "Invalid %s value" % self.kind
135
 
            raise InvalidParameterValueError(message)
136
 
 
137
 
    def _check_range(self, value):
138
 
        """Check that the given C{value} is in the expected range."""
139
 
        if self.min is None and self.max is None:
140
 
            return
141
 
 
142
 
        measure = self.measure(value)
143
 
        prefix = "Value (%s) for parameter %s is invalid.  %s"
144
 
 
145
 
        if self.min is not None and measure < self.min:
146
 
            message = prefix % (value, self.name,
147
 
                                self.lower_than_min_template % self.min)
148
 
            raise InvalidParameterValueError(message)
149
 
 
150
 
        if self.max is not None and measure > self.max:
151
 
            message = prefix % (value, self.name,
152
 
                                self.greater_than_max_template % self.max)
153
 
            raise InvalidParameterValueError(message)
154
 
 
155
 
    def parse(self, value):
156
 
        """
157
 
        Parse a single parameter value coverting it to the appropriate type.
158
 
        """
159
 
        raise NotImplementedError()
160
 
 
161
 
    def format(self, value):
162
 
        """
163
 
        Format a single parameter value in a way suitable for an HTTP request.
164
 
        """
165
 
        raise NotImplementedError()
166
 
 
167
 
    def measure(self, value):
168
 
        """
169
 
        Return an C{int} providing a measure for C{value}, used for C{range}.
170
 
        """
171
 
        raise NotImplementedError()
172
 
 
173
 
 
174
 
class Unicode(Parameter):
175
 
    """A parameter that must be a C{unicode}."""
176
 
 
177
 
    kind = "unicode"
178
 
 
179
 
    lower_than_min_template = "Length must be at least %s."
180
 
    greater_than_max_template = "Length exceeds maximum of %s."
181
 
 
182
 
    def parse(self, value):
183
 
        return value.decode("utf-8")
184
 
 
185
 
    def format(self, value):
186
 
        return value.encode("utf-8")
187
 
 
188
 
    def measure(self, value):
189
 
        return len(value)
190
 
 
191
 
 
192
 
class UnicodeLine(Unicode):
193
 
    """A parameter that must be a C{unicode} string without newlines."""
194
 
 
195
 
    def coerce(self, value):
196
 
        super(UnicodeLine, self).coerce(value)
197
 
        if "\n" in value:
198
 
            raise InvalidParameterValueError("Can't contain newlines.")
199
 
 
200
 
 
201
 
class RawStr(Parameter):
202
 
    """A parameter that must be a C{str}."""
203
 
 
204
 
    kind = "raw string"
205
 
 
206
 
    def parse(self, value):
207
 
        return str(value)
208
 
 
209
 
    def format(self, value):
210
 
        return value
211
 
 
212
 
 
213
 
class Integer(Parameter):
214
 
    """A parameter that must be a positive C{int}."""
215
 
 
216
 
    kind = "integer"
217
 
 
218
 
    lower_than_min_template = "Value must be at least %s."
219
 
    greater_than_max_template = "Value exceeds maximum of %s."
220
 
 
221
 
    def __init__(self, name=None, optional=False, default=None,
222
 
                 min=0, max=None, allow_none=False, validator=None,
223
 
                 doc=None):
224
 
        super(Integer, self).__init__(name, optional, default, min, max,
225
 
                                      allow_none, validator, doc=doc)
226
 
 
227
 
    def parse(self, value):
228
 
        return int(value)
229
 
 
230
 
    def format(self, value):
231
 
        return str(value)
232
 
 
233
 
    def measure(self, value):
234
 
        return int(value)
235
 
 
236
 
 
237
 
class Bool(Parameter):
238
 
    """A parameter that must be a C{bool}."""
239
 
 
240
 
    kind = "boolean"
241
 
 
242
 
    def parse(self, value):
243
 
        if value == "true":
244
 
            return True
245
 
        if value == "false":
246
 
            return False
247
 
        raise ValueError()
248
 
 
249
 
    def format(self, value):
250
 
        if value:
251
 
            return "true"
252
 
        else:
253
 
            return "false"
254
 
 
255
 
 
256
 
class Enum(Parameter):
257
 
    """A parameter with enumerated values.
258
 
 
259
 
    @param name: The name of the parameter, as specified in a request.
260
 
    @param optional: If C{True} the parameter may not be present.
261
 
    @param default: A default value for the parameter, if not present.
262
 
    @param mapping: A mapping of accepted values to the values that
263
 
        will be returned by C{parse}.
264
 
    """
265
 
 
266
 
    kind = "enum"
267
 
 
268
 
    def __init__(self, name=None, mapping=None, optional=False, default=None,
269
 
                 doc=None):
270
 
        super(Enum, self).__init__(name, optional=optional, default=default,
271
 
                                   doc=doc)
272
 
        if mapping is None:
273
 
            raise TypeError("Must provide mapping")
274
 
        self.mapping = mapping
275
 
        self.reverse = dict((value, key) for key, value in mapping.iteritems())
276
 
 
277
 
    def parse(self, value):
278
 
        try:
279
 
            return self.mapping[value]
280
 
        except KeyError:
281
 
            raise ValueError()
282
 
 
283
 
    def format(self, value):
284
 
        return self.reverse[value]
285
 
 
286
 
 
287
 
class Date(Parameter):
288
 
    """A parameter that must be a valid ISO 8601 formatted date."""
289
 
 
290
 
    kind = "date"
291
 
 
292
 
    def parse(self, value):
293
 
        return parse(value).replace(tzinfo=tzutc())
294
 
 
295
 
    def format(self, value):
296
 
        # Convert value to UTC.
297
 
        tt = value.utctimetuple()
298
 
        utc_value = datetime(
299
 
            tt.tm_year, tt.tm_mon, tt.tm_mday, tt.tm_hour, tt.tm_min,
300
 
            tt.tm_sec)
301
 
        return datetime.strftime(utc_value, "%Y-%m-%dT%H:%M:%SZ")
302
 
 
303
 
 
304
 
class List(Parameter):
305
 
    """
306
 
    A homogenous list of instances of a parameterized type.
307
 
 
308
 
    There is a strange behavior that lists can have any starting index and any
309
 
    gaps are ignored.  Conventionally they are 1-based, and so indexes proceed
310
 
    like 1, 2, 3...  However, any non-negative index can be used and the
311
 
    ordering will be used to determine the true index. So::
312
 
 
313
 
        {5: 'a', 7: 'b', 9: 'c'}
314
 
 
315
 
    becomes::
316
 
 
317
 
        ['a', 'b', 'c']
318
 
    """
319
 
 
320
 
    kind = "list"
321
 
    supports_multiple = True
322
 
 
323
 
    def __init__(self, name=None, item=None, optional=False, default=None,
324
 
                 doc=None):
325
 
        """
326
 
        @param item: A L{Parameter} instance which will be used to parse and
327
 
            format the values in the list.
328
 
        """
329
 
        if item is None:
330
 
            raise TypeError("Must provide item")
331
 
        super(List, self).__init__(name, optional=optional, default=default,
332
 
                                   doc=doc)
333
 
        if item.name is None:
334
 
            item.name = name
335
 
        self.item = item
336
 
        if default is None:
337
 
            self.default = []
338
 
 
339
 
    def parse(self, value):
340
 
        """
341
 
        Convert a dictionary of {relative index: value} to a list of parsed
342
 
        C{value}s.
343
 
        """
344
 
        indices = []
345
 
        if not isinstance(value, dict):
346
 
            # We interpret non-list inputs as a list of one element, for
347
 
            # compatibility with certain EC2 APIs.
348
 
            return [self.item.coerce(value)]
349
 
        for index in value.keys():
350
 
            try:
351
 
                indices.append(int(index))
352
 
            except ValueError:
353
 
                raise UnknownParameterError(index)
354
 
        result = [None] * len(value)
355
 
        for index_index, index in enumerate(sorted(indices)):
356
 
            v = value[str(index)]
357
 
            if index < 0:
358
 
                raise UnknownParameterError(index)
359
 
            result[index_index] = self.item.coerce(v)
360
 
        return result
361
 
 
362
 
    def format(self, value):
363
 
        """
364
 
        Convert a list like::
365
 
 
366
 
            ["a", "b", "c"]
367
 
 
368
 
        to:
369
 
 
370
 
            {"1": "a", "2": "b", "3": "c"}
371
 
 
372
 
        C{value} may also be an L{Arguments} instance, mapping indices to
373
 
        values. Who knows why.
374
 
        """
375
 
        if isinstance(value, Arguments):
376
 
            return dict((str(i), self.item.format(v)) for i, v in value)
377
 
        return dict((str(i + 1), self.item.format(v))
378
 
                    for i, v in enumerate(value))
379
 
 
380
 
 
381
 
class Structure(Parameter):
382
 
    """
383
 
    A structure with named fields of parameterized types.
384
 
    """
385
 
 
386
 
    kind = "structure"
387
 
    supports_multiple = True
388
 
 
389
 
    def __init__(self, name=None, fields=None, optional=False, default=None,
390
 
                 doc=None):
391
 
        """
392
 
        @param fields: A mapping of field name to field L{Parameter} instance.
393
 
        """
394
 
        if fields is None:
395
 
            raise TypeError("Must provide fields")
396
 
        super(Structure, self).__init__(name, optional=optional,
397
 
                                        default=default, doc=doc)
398
 
        _namify_arguments(fields)
399
 
        self.fields = fields
400
 
 
401
 
    def parse(self, value):
402
 
        """
403
 
        Convert a dictionary of raw values to a dictionary of processed values.
404
 
        """
405
 
        result = {}
406
 
        rest = {}
407
 
        for k, v in value.iteritems():
408
 
            if k in self.fields:
409
 
                if (isinstance(v, dict)
410
 
                        and not self.fields[k].supports_multiple):
411
 
                    if len(v) == 1:
412
 
                        # We support "foo.1" as "foo" as long as there is only
413
 
                        # one "foo.#" parameter provided.... -_-
414
 
                        v = v.values()[0]
415
 
                    else:
416
 
                        raise InvalidParameterCombinationError(k)
417
 
                result[k] = self.fields[k].coerce(v)
418
 
            else:
419
 
                rest[k] = v
420
 
        for k, v in self.fields.iteritems():
421
 
            if k not in result:
422
 
                result[k] = v.coerce(None)
423
 
        if rest:
424
 
            raise UnknownParametersError(result, rest)
425
 
        return result
426
 
 
427
 
    def format(self, value):
428
 
        """
429
 
        Convert a dictionary of processed values to a dictionary of raw values.
430
 
        """
431
 
        if not isinstance(value, Arguments):
432
 
            value = value.iteritems()
433
 
        return dict((k, self.fields[k].format(v)) for k, v in value)
434
 
 
435
 
 
436
 
class Arguments(object):
437
 
    """Arguments parsed from a request."""
438
 
 
439
 
    def __init__(self, tree):
440
 
        """Initialize a new L{Arguments} instance.
441
 
 
442
 
        @param tree: The C{dict}-based structure of the L{Argument} instance
443
 
            to create.
444
 
        """
445
 
        for key, value in tree.iteritems():
446
 
            self.__dict__[key] = self._wrap(value)
447
 
 
448
 
    def __str__(self):
449
 
        return "Arguments(%s)" % (self.__dict__,)
450
 
 
451
 
    __repr__ = __str__
452
 
 
453
 
    def __iter__(self):
454
 
        """Returns an iterator yielding C{(name, value)} tuples."""
455
 
        return self.__dict__.iteritems()
456
 
 
457
 
    def __getitem__(self, index):
458
 
        """Return the argument value with the given L{index}."""
459
 
        return self.__dict__[index]
460
 
 
461
 
    def __len__(self):
462
 
        """Return the number of arguments."""
463
 
        return len(self.__dict__)
464
 
 
465
 
    def __contains__(self, key):
466
 
        """Return whether an argument with the given name is present."""
467
 
        return key in self.__dict__
468
 
 
469
 
    def _wrap(self, value):
470
 
        """Wrap the given L{tree} with L{Arguments} as necessary.
471
 
 
472
 
        @param tree: A {dict}, containing L{dict}s and/or leaf values, nested
473
 
            arbitrarily deep.
474
 
        """
475
 
        if isinstance(value, dict):
476
 
            if any(isinstance(name, int) for name in value.keys()):
477
 
                if not all(isinstance(name, int) for name in value.keys()):
478
 
                    raise RuntimeError("Integer and non-integer keys: %r"
479
 
                                       % value.keys())
480
 
                items = sorted(value.iteritems(), key=itemgetter(0))
481
 
                return [self._wrap(val) for _, val in items]
482
 
            else:
483
 
                return Arguments(value)
484
 
        elif isinstance(value, list):
485
 
            return [self._wrap(x) for x in value]
486
 
        else:
487
 
            return value
488
 
 
489
 
 
490
 
def _namify_arguments(mapping):
491
 
    """
492
 
    Ensure that a mapping of names to parameters has the parameters set to the
493
 
    correct name.
494
 
    """
495
 
    result = []
496
 
    for name, parameter in mapping.iteritems():
497
 
        parameter.name = name
498
 
        result.append(parameter)
499
 
    return result
500
 
 
501
 
 
502
 
class Schema(object):
503
 
    """
504
 
    The schema that the arguments of an HTTP request must be compliant with.
505
 
    """
506
 
 
507
 
    def __init__(self, *_parameters, **kwargs):
508
 
        """Initialize a new L{Schema} instance.
509
 
 
510
 
        Any number of L{Parameter} instances can be passed. The parameter names
511
 
        are used in L{Schema.extract} and L{Schema.bundle}. For example::
512
 
 
513
 
          schema = Schema(name="SetName", parameters=[Unicode("Name")])
514
 
 
515
 
        means that the result of L{Schema.extract} would have a C{Name}
516
 
        attribute. Similarly, L{Schema.bundle} would look for a C{Name}
517
 
        attribute.
518
 
 
519
 
        A more complex example::
520
 
 
521
 
          schema = Schema(
522
 
              name="SetNames",
523
 
              parameters=[List("Names", Unicode())])
524
 
 
525
 
        means that the result of L{Schema.extract} would have a C{Names}
526
 
        attribute, which would itself contain a list of names. Similarly,
527
 
        L{Schema.bundle} would look for a C{Names} attribute.
528
 
 
529
 
        Currently all parameters other than C{parameters} have no effect; they
530
 
        are merely exposed as attributes of instances of Schema, and are able
531
 
        to be overridden in L{extend}.
532
 
 
533
 
        @param name: (keyword) The name of the API call that this schema
534
 
            represents. Accessible via the C{name} attribute.
535
 
        @param parameters: (keyword) The parameters of the API, as a list of
536
 
            named L{Parameter} instances.
537
 
        @param doc: (keyword) The documentation of this API Call. Accessible
538
 
            via the C{doc} attribute.
539
 
        @param result: (keyword) A description of the result of this API
540
 
            call. Accessible via the C{result} attribute.
541
 
        @param errors: (keyword) A sequence of exception classes that the API
542
 
            can potentially raise. Accessible as a L{set} via the C{errors}
543
 
            attribute.
544
 
        """
545
 
        self.name = kwargs.pop('name', None)
546
 
        self.doc = kwargs.pop('doc', None)
547
 
        self.result = kwargs.pop('result', None)
548
 
        self.errors = set(kwargs.pop('errors', []))
549
 
        if 'parameters' in kwargs:
550
 
            if len(_parameters) > 0:
551
 
                raise TypeError("parameters= must only be passed "
552
 
                                "without positional arguments")
553
 
            self._parameters = kwargs['parameters']
554
 
        else:
555
 
            self._parameters = self._convert_old_schema(_parameters)
556
 
 
557
 
    def get_parameters(self):
558
 
        """
559
 
        Get the list of parameters this schema supports.
560
 
        """
561
 
        return self._parameters[:]
562
 
 
563
 
    def extract(self, params):
564
 
        """Extract parameters from a raw C{dict} according to this schema.
565
 
 
566
 
        @param params: The raw parameters to parse.
567
 
        @return: A tuple of an L{Arguments} object holding the extracted
568
 
            arguments and any unparsed arguments.
569
 
        """
570
 
        structure = Structure(fields=dict([(p.name, p)
571
 
                                           for p in self._parameters]))
572
 
        try:
573
 
            tree = structure.coerce(self._convert_flat_to_nest(params))
574
 
            rest = {}
575
 
        except UnknownParametersError, error:
576
 
            tree = error.result
577
 
            rest = self._convert_nest_to_flat(error.unknown)
578
 
        return Arguments(tree), rest
579
 
 
580
 
    def bundle(self, *arguments, **extra):
581
 
        """Bundle the given arguments in a C{dict} with EC2-style format.
582
 
 
583
 
        @param arguments: L{Arguments} instances to bundle. Keys in
584
 
            later objects will override those in earlier objects.
585
 
        @param extra: Any number of additional parameters. These will override
586
 
            similarly named arguments in L{arguments}.
587
 
        """
588
 
        params = {}
589
 
 
590
 
        for argument in arguments:
591
 
            params.update(argument)
592
 
 
593
 
        params.update(extra)
594
 
        result = {}
595
 
        for name, value in params.iteritems():
596
 
            if value is None:
597
 
                continue
598
 
            segments = name.split('.')
599
 
            first = segments[0]
600
 
            parameter = self.get_parameter(first)
601
 
            if parameter is None:
602
 
                raise RuntimeError("Parameter '%s' not in schema" % name)
603
 
            else:
604
 
                if value is None:
605
 
                    result[name] = ""
606
 
                else:
607
 
                    result[name] = parameter.format(value)
608
 
 
609
 
        return self._convert_nest_to_flat(result)
610
 
 
611
 
    def get_parameter(self, name):
612
 
        """
613
 
        Get the parameter on this schema with the given C{name}.
614
 
        """
615
 
        for parameter in self._parameters:
616
 
            if parameter.name == name:
617
 
                return parameter
618
 
 
619
 
    def _convert_flat_to_nest(self, params):
620
 
        """
621
 
        Convert a structure in the form of::
622
 
 
623
 
            {'foo.1.bar': 'value',
624
 
             'foo.2.baz': 'value'}
625
 
 
626
 
        to::
627
 
 
628
 
            {'foo': {'1': {'bar': 'value'},
629
 
                     '2': {'baz': 'value'}}}
630
 
 
631
 
        This is intended for use both during parsing of HTTP arguments like
632
 
        'foo.1.bar=value' and when dealing with schema declarations that look
633
 
        like 'foo.n.bar'.
634
 
 
635
 
        This is the inverse of L{_convert_nest_to_flat}.
636
 
        """
637
 
        result = {}
638
 
        for k, v in params.iteritems():
639
 
            last = result
640
 
            segments = k.split('.')
641
 
            for index, item in enumerate(segments):
642
 
                if index == len(segments) - 1:
643
 
                    newd = v
644
 
                else:
645
 
                    newd = {}
646
 
                if not isinstance(last, dict):
647
 
                    raise InconsistentParameterError(k)
648
 
                if type(last.get(item)) is dict and type(newd) is not dict:
649
 
                    raise InconsistentParameterError(k)
650
 
                last = last.setdefault(item, newd)
651
 
        return result
652
 
 
653
 
    def _convert_nest_to_flat(self, params, _result=None, _prefix=None):
654
 
        """
655
 
        Convert a data structure that looks like::
656
 
 
657
 
            {"foo": {"bar": "baz", "shimmy": "sham"}}
658
 
 
659
 
        to::
660
 
 
661
 
            {"foo.bar": "baz",
662
 
             "foo.shimmy": "sham"}
663
 
 
664
 
        This is the inverse of L{_convert_flat_to_nest}.
665
 
        """
666
 
        if _result is None:
667
 
            _result = {}
668
 
        for k, v in params.iteritems():
669
 
            if _prefix is None:
670
 
                path = k
671
 
            else:
672
 
                path = _prefix + '.' + k
673
 
            if isinstance(v, dict):
674
 
                self._convert_nest_to_flat(v, _result=_result, _prefix=path)
675
 
            else:
676
 
                _result[path] = v
677
 
        return _result
678
 
 
679
 
    def extend(self, *schema_items, **kwargs):
680
 
        """
681
 
        Add any number of schema items to a new schema.
682
 
 
683
 
        Takes the same arguments as the constructor, and returns a new
684
 
        L{Schema} instance.
685
 
 
686
 
        If parameters, result, or errors is specified, they will be merged with
687
 
        the existing parameters, result, or errors.
688
 
        """
689
 
        new_kwargs = {
690
 
            'name': self.name,
691
 
            'doc': self.doc,
692
 
            'parameters': self._parameters[:],
693
 
            'result': self.result.copy() if self.result else {},
694
 
            'errors': self.errors.copy() if self.errors else set()}
695
 
        if 'parameters' in kwargs:
696
 
            new_params = kwargs.pop('parameters')
697
 
            new_kwargs['parameters'].extend(new_params)
698
 
        new_kwargs['result'].update(kwargs.pop('result', {}))
699
 
        new_kwargs['errors'].update(kwargs.pop('errors', set()))
700
 
        new_kwargs.update(kwargs)
701
 
 
702
 
        if schema_items:
703
 
            parameters = self._convert_old_schema(schema_items)
704
 
            new_kwargs['parameters'].extend(parameters)
705
 
        return Schema(**new_kwargs)
706
 
 
707
 
    def _convert_old_schema(self, parameters):
708
 
        """
709
 
        Convert an ugly old schema, using dotted names, to the hot new schema,
710
 
        using List and Structure.
711
 
 
712
 
        The old schema assumes that every other dot implies an array. So a list
713
 
        of two parameters,
714
 
 
715
 
            [Integer("foo.bar.baz.quux"), Integer("foo.bar.shimmy")]
716
 
 
717
 
        becomes::
718
 
 
719
 
            [List(
720
 
                "foo",
721
 
                item=Structure(
722
 
                    fields={"baz": List(item=Integer()),
723
 
                            "shimmy": Integer()}))]
724
 
 
725
 
        By design, the old schema syntax ignored the names "bar" and "quux".
726
 
        """
727
 
        # 'merged' here is an associative list that maps parameter names to
728
 
        # Parameter instances, OR sub-associative lists which represent nested
729
 
        # lists and structures.
730
 
        # e.g.,
731
 
        #    [Integer("foo")]
732
 
        # becomes
733
 
        #    [("foo", Integer("foo"))]
734
 
        # and
735
 
        #    [Integer("foo.bar")]
736
 
        # (which represents a list of integers called "foo" with a meaningless
737
 
        # index name of "bar") becomes
738
 
        #     [("foo", [("bar", Integer("foo.bar"))])].
739
 
        merged = []
740
 
        for parameter in parameters:
741
 
            segments = parameter.name.split('.')
742
 
            _merge_associative_list(merged, segments, parameter)
743
 
        result = [self._inner_convert_old_schema(node, 1) for node in merged]
744
 
        return result
745
 
 
746
 
    def _inner_convert_old_schema(self, node, depth):
747
 
        """
748
 
        Internal recursion helper for L{_convert_old_schema}.
749
 
 
750
 
        @param node: A node in the associative list tree as described in
751
 
            _convert_old_schema. A two tuple of (name, parameter).
752
 
        @param depth: The depth that the node is at. This is important to know
753
 
            if we're currently processing a list or a structure. ("foo.N" is a
754
 
            list called "foo", "foo.N.fieldname" describes a field in a list of
755
 
            structs).
756
 
        """
757
 
        name, parameter_description = node
758
 
        if not isinstance(parameter_description, list):
759
 
            # This is a leaf, i.e., an actual L{Parameter} instance.
760
 
            return parameter_description
761
 
        if depth % 2 == 0:
762
 
            # we're processing a structure.
763
 
            fields = {}
764
 
            for node in parameter_description:
765
 
                fields[node[0]] = self._inner_convert_old_schema(
766
 
                    node, depth + 1)
767
 
            return Structure(name, fields=fields)
768
 
        else:
769
 
            # we're processing a list.
770
 
            if not isinstance(parameter_description, list):
771
 
                raise TypeError("node %r must be an associative list"
772
 
                                % (parameter_description,))
773
 
            if not len(parameter_description) == 1:
774
 
                raise ValueError(
775
 
                    "Multiple different index names specified: %r"
776
 
                    % ([item[0] for item in parameter_description],))
777
 
            subnode = parameter_description[0]
778
 
            item = self._inner_convert_old_schema(subnode, depth + 1)
779
 
            return List(name=name, item=item, optional=item.optional)
780
 
 
781
 
 
782
 
def _merge_associative_list(alist, path, value):
783
 
    """
784
 
    Merge a value into an associative list at the given path, maintaining
785
 
    insertion order. Examples will explain it::
786
 
 
787
 
        >>> alist = []
788
 
        >>> _merge_associative_list(alist, ["foo", "bar"], "barvalue")
789
 
        >>> _merge_associative_list(alist, ["foo", "baz"], "bazvalue")
790
 
        >>> alist == [("foo", [("bar", "barvalue"), ("baz", "bazvalue")])]
791
 
 
792
 
    @param alist: An associative list of names to values.
793
 
    @param path: A path through sub-alists which we ultimately want to point to
794
 
    C{value}.
795
 
    @param value: The value to set.
796
 
    @return: None. This operation mutates the associative list in place.
797
 
    """
798
 
    for key in path[:-1]:
799
 
        for item in alist:
800
 
            if item[0] == key:
801
 
                alist = item[1]
802
 
                break
803
 
        else:
804
 
            subalist = []
805
 
            alist.append((key, subalist))
806
 
            alist = subalist
807
 
    alist.append((path[-1], value))