2
# This file is part of the Connection-Set Algebra (CSA).
3
# Copyright (C) 2010 Mikael Djurfeldt
5
# CSA is free software; you can redistribute it and/or modify
6
# it under the terms of the GNU General Public License as published by
7
# the Free Software Foundation; either version 3 of the License, or
8
# (at your option) any later version.
10
# CSA is distributed in the hope that it will be useful,
11
# but WITHOUT ANY WARRANTY; without even the implied warranty of
12
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13
# GNU General Public License for more details.
15
# You should have received a copy of the GNU General Public License
16
# along with this program. If not, see <http://www.gnu.org/licenses/>.
20
import random as _random
23
import valueset as _vs
27
# Random: operator producing random connection masks.
#   Random * valueSet  -> ValueSetRandomMask (valueSet)
#   Random (p = ...)   -> _elementary.ConstantRandomMask (p)
#   Random (N = ...)   -> _elementary.SampleNRandomMask (N)
# NOTE(review): this copy of the file has lost its indentation, has bare
# numbers (apparently the original line numbers) interleaved with the code,
# and is missing some original lines; compare against upstream CSA.
class Random (_cs.Operator):
# Combine with a value set: connection (i, j) is then kept with
# probability valueSet (i, j) (see ValueSetRandomMask).
28
def __mul__ (self, valueSet):
29
return ValueSetRandomMask (valueSet)
# Build a random mask from p (passed to ConstantRandomMask; presumably a
# connection probability - TODO confirm) or from N (passed to
# SampleNRandomMask). fanIn / fanOut have no implementation in the
# visible code. Parameter checks use assert, so they vanish under -O.
# NOTE(review): the guard selecting the p branch (original line 32,
# presumably `if p != None:`) is missing from this copy.
31
def __call__ (self, p = None, N = None, fanIn = None, fanOut = None):
33
assert N == None and fanIn == None and fanOut == None, \
34
'inconsistent parameters'
35
return _elementary.ConstantRandomMask (p)
# NOTE(review): original lines 36-37 (presumably the `elif N != None:`
# guard and the first line of the assert whose message follows) are
# missing from this copy.
38
'inconsistent parameters'
39
return _elementary.SampleNRandomMask (N)
# NOTE(review): if this assert is reached with p and N both None, its
# condition looks inverted: a call passing fanIn/fanOut passes it and then
# fails with 'inconsistent parameters', while a call passing nothing fails
# with 'not implemenented' (note the typo). Presumably intended:
# `assert fanIn == None and fanOut == None, 'not implemented'` - confirm.
40
assert fanIn != None or fanOut != None, 'not implemenented'
41
assert False, 'inconsistent parameters'
44
# Mask that keeps each connection (i, j) at random with probability
# valueSet (i, j): a draw from random.random (), which lies in [0, 1),
# is compared against the value-set entry.
class ValueSetRandomMask (_cs.Mask):
45
def __init__ (self, valueSet):
46
_cs.Mask.__init__ (self)
47
self.valueSet = valueSet
# Snapshot the global `random` module state so that each iteration can
# replay the same sequence of draws (restored in startIteration).
48
self.state = _random.getstate ()
# Rewind the global `random` generator to the snapshot taken in __init__.
# Side effect: this resets the shared module-level generator, which also
# affects any other code using `random`. The `state` argument is not used
# by the visible code.
50
def startIteration (self, state):
51
_random.setstate (self.state)
# Scan j over [low1, high1) (outer) and i over [low0, high0) (inner),
# drawing one random number per candidate pair, so the pairs kept depend
# on this visiting order.
# NOTE(review): the statement under the `if` (original line 57, presumably
# `yield (i, j)`) is missing from this copy.
53
def iterator (self, low0, high0, low1, high1, state):
54
for j in xrange (low1, high1):
55
for i in xrange (low0, high0):
56
if _random.random () < self.valueSet (i, j):
60
# Disc: operator for distance-threshold masks; `Disc (r) * metric` builds a
# DiscMask keeping the pairs (i, j) where metric (i, j) < r.
class Disc (_cs.Operator):
# NOTE(review): the body of __init__ (original lines 62-63) is missing from
# this copy; __mul__ reads self.r, so it presumably stores r - confirm.
61
def __init__ (self, r):
# Bind the stored radius to a metric, producing a DiscMask.
64
def __mul__ (self, metric):
65
return DiscMask (self.r, metric)
68
# Mask keeping the pairs (i, j) whose distance metric (i, j) is strictly
# less than r.
class DiscMask (_cs.Mask):
# NOTE(review): original lines 71-73 are missing from this copy; iterator
# reads self.r and self.metric, so they are presumably assigned there.
69
def __init__ (self, r, metric):
70
_cs.Mask.__init__ (self)
# Scan j over [low1, high1) (outer) and i over [low0, high0) (inner),
# testing each pair's distance against the radius.
# NOTE(review): the statement under the `if` (original line 78, presumably
# `yield (i, j)`) is missing from this copy.
74
def iterator (self, low0, high0, low1, high1, state):
75
for j in xrange (low1, high1):
76
for i in xrange (low0, high0):
77
if self.metric (i, j) < self.r:
81
# Gaussian: operator for distance-dependent Gaussian value sets;
# `Gaussian (sigma, cutoff) * metric` builds a GaussianValueSet.
class Gaussian (_cs.Operator):
# NOTE(review): the body of __init__ (original lines 83-85) is missing from
# this copy; __mul__ reads self.sigma and self.cutoff - confirm upstream.
82
def __init__ (self, sigma, cutoff):
# Bind the stored width and cutoff to a metric.
86
def __mul__ (self, metric):
87
return GaussianValueSet (self.sigma, self.cutoff, metric)
90
# Value set mapping (i, j) to exp (-d^2 / (2 sigma^2)) with d = metric (i, j)
# when d < cutoff, and to 0.0 otherwise (unnormalized: 1.0 at d = 0).
class GaussianValueSet (_vs.ValueSet):
# NOTE(review): original lines 93-95 are missing from this copy; __call__
# reads self.metric and self.cutoff, so they are presumably assigned there.
91
def __init__ (self, sigma, cutoff, metric):
# Cache 2 * sigma^2, the denominator of the exponent.
92
self.sigma22 = 2* sigma * sigma
96
def __call__ (self, i, j):
97
d = self.metric (i, j)
# NOTE(review): `_math` is not imported in the visible part of this file;
# confirm an `import math as _math` exists upstream.
98
return _math.exp (- d * d / self.sigma22) if d < self.cutoff else 0.0
101
# Block: operator that expands connections into M x N blocks.
# `Block (M, N) * other` coerces `other` with _cs.coerceCSet and wraps the
# result in a BlockMask when it is a Mask, otherwise in a BlockCSet.
# (The index arithmetic in BlockMask.iterator suggests each operand pair
# covers sources [M*k, M*(k+1)) and targets [N*post, N*(post+1)) - confirm.)
class Block (_cs.Operator):
# NOTE(review): the body of __init__ (original lines 103-105) is missing
# from this copy; __mul__ reads self.M and self.N - confirm upstream.
102
def __init__ (self, M, N):
106
def __mul__ (self, other):
107
c = _cs.coerceCSet (other)
108
if isinstance (c, _cs.Mask):
109
return BlockMask (self.M, self.N, c)
# NOTE(review): original line 110 (presumably `else:`) is missing here.
111
return _cs.ConnectionSet (BlockCSet (self.M, self.N, c))
114
class BlockMask (_cs.Mask):
115
def __init__ (self, M, N, mask):
116
_cs.Mask.__init__ (self)
121
def iterator (self, low0, high0, low1, high1, state):
122
maskIter = self.m.iterator (low0 / self.M,
123
(high0 + self.M - 1) / self.M,
125
(high1 + self.N - 1) / self.N,
129
(i, j) = maskIter.next ()
131
# collect connections in one connection matrix column
135
(i, j) = maskIter.next ()
137
# generate blocks for the column
138
for jj in xrange (max (self.N * post, low1),
139
min (self.N * (post + 1), high1)):
141
for ii in xrange (max (self.M * k, low0),
142
min (self.M * (k + 1), high0)):
145
except StopIteration:
147
# generate blocks for the last column
148
for jj in xrange (max (self.N * post, low1),
149
min (self.N * (post + 1), high1)):
151
for ii in xrange (max (self.M * k, low0),
152
min (self.M * (k + 1), high0)):