# This file is part of DEAP.
#
# DEAP is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License as
# published by the Free Software Foundation, either version 3 of
# the License, or (at your option) any later version.
#
# DEAP is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public
# License along with DEAP. If not, see <http://www.gnu.org/licenses/>.

try:
    from itertools import product
except ImportError:  # itertools.product only exists from Python 2.6 on
    def product(*args, **kwds):
        """Return the cartesian product of the input iterables.

        Pure-Python fallback mirroring :func:`itertools.product`; it is only
        defined when the standard-library version cannot be imported. Accepts
        the same optional ``repeat`` keyword and yields tuples.
        """
        # product('ABCD', 'xy') --> Ax Ay Bx By Cx Cy Dx Dy
        # product(range(2), repeat=3) --> 000 001 010 011 100 101 110 111
        # Materialise each iterable as a tuple in a real list: the former
        # ``map(tuple, args) * n`` only works where map returns a list.
        pools = [tuple(pool) for pool in args] * kwds.get('repeat', 1)
        result = [[]]
        for pool in pools:
            result = [x + [y] for x in result for y in pool]
        for prod in result:
            yield tuple(prod)


class SortingNetwork(list):
    """Sorting network class.

    From Wikipedia : A sorting network is an abstract mathematical model
    of a network of wires and comparator modules that is used to sort a
    sequence of numbers. Each comparator connects two wires and sorts the
    values by outputting the smaller value to one wire, and the larger
    value to the other.

    The network itself is the list of its levels. Each level is a list of
    ``(wire1, wire2)`` connectors with ``wire1 < wire2`` whose wire ranges do
    not overlap, so a whole level counts as one parallel step.
    """

    def __init__(self, dimension, connectors=None):
        """Create a network with ``dimension`` wires.

        :param dimension: Number of wires, i.e. the length of the sequences
                          the network is meant to sort.
        :param connectors: Optional iterable of ``(wire1, wire2)`` pairs,
                           added in order with :meth:`addConnector`.
        """
        # Number of wires; read by assess() and draw().
        self.dimension = dimension
        # ``None`` replaces the former ``[]`` default so no mutable default
        # object is shared between calls.
        if connectors is None:
            connectors = ()
        for wire1, wire2 in connectors:
            self.addConnector(wire1, wire2)

    def addConnector(self, wire1, wire2):
        """Add a connector between wire1 and wire2 in the network.

        A connector from a wire to itself is ignored, and the pair is
        normalised so the lower wire comes first. The connector is appended
        to the last level when its wire range does not intersect the range
        of any connector already in that level; otherwise it opens a new
        level.
        """
        if wire1 == wire2:
            return

        if wire1 > wire2:
            wire1, wire2 = wire2, wire1

        try:
            last_level = self[-1]
        except IndexError:
            # Empty network, create new level and connector
            self.append([(wire1, wire2)])
            return

        for wires in last_level:
            # Inclusive intersection of [wires[0], wires[1]] and [wire1, wire2].
            if wires[1] >= wire1 and wires[0] <= wire2:
                self.append([(wire1, wire2)])
                return

        last_level.append((wire1, wire2))

    def sort(self, values):
        """Sort the values in-place based on the connectors in the network.

        Levels are applied first to last; every connector swaps
        ``values[wire1]`` and ``values[wire2]`` when they are out of order.
        """
        for level in self:
            for wire1, wire2 in level:
                if values[wire1] > values[wire2]:
                    values[wire1], values[wire2] = values[wire2], values[wire1]

    def assess(self, cases=None):
        """Try to sort the **cases** using the network, return the number of
        misses. If **cases** is None, test all possible cases according to
        the network dimensionality.

        Each case is copied to a list before sorting, so the inputs are not
        modified. Cases are expected to contain only 0/1 values and to have
        length ``dimension``: the expected result is looked up from the
        number of ones in the case.
        """
        if cases is None:
            cases = product(range(2), repeat=self.dimension)

        # ordered[k] is the sorted 0/1 sequence of length dimension with k ones.
        ordered = [[0] * (self.dimension - i) + [1] * i
                   for i in range(self.dimension + 1)]
        misses = 0
        for sequence in cases:
            sequence = list(sequence)
            self.sort(sequence)
            misses += (sequence != ordered[sum(sequence)])
        return misses

    def draw(self):
        """Return an ASCII representation of the network.

        Every wire is a row of dashes starting with its index and ``" o"``.
        Level ``index`` marks its connectors with ``x`` in column
        ``(index + 1) * 6``, fills the wires strictly between the endpoints
        with ``|``, and puts a ``|`` one column further right in the spacer
        rows between wires.
        """
        # The spacer bar of the last level lands at column 6 * depth + 1, so
        # rows need at least 6 * depth + 2 cells. The historical 7 * depth is
        # kept whenever it is wide enough (depth >= 2), so those drawings are
        # unchanged. Depth 0 and 1 previously raised IndexError.
        width = max(7 * self.depth, 6 * self.depth + 2)

        str_wires = [["-"] * width]
        str_wires[0][0] = "0"
        str_wires[0][1] = " o"
        str_spaces = []

        for i in range(1, self.dimension):
            str_wires.append(["-"] * width)
            str_spaces.append([" "] * width)
            str_wires[i][0] = str(i)
            str_wires[i][1] = " o"

        for index, level in enumerate(self):
            column = (index + 1) * 6
            for wire1, wire2 in level:
                str_wires[wire1][column] = "x"
                str_wires[wire2][column] = "x"
                for i in range(wire1, wire2):
                    str_spaces[i][column + 1] = "|"
                for i in range(wire1 + 1, wire2):
                    str_wires[i][column] = "|"

        # Interleave wire rows with the spacer rows between them.
        rows = ["".join(str_wires[0])]
        for line, space in zip(str_wires[1:], str_spaces):
            rows.append("".join(space))
            rows.append("".join(line))
        return "\n".join(rows)

    @property
    def depth(self):
        """Return the number of parallel steps that it takes to sort any input.
        """
        return len(self)

    @property
    def length(self):
        """Return the number of comparison-swap used."""
        return sum(len(level) for level in self)