~opencrea/+junk/market


Viewing changes to ann/models/ann_layer.py

  • Committer: joannes
  • Date: 2017-09-23 21:42:25 UTC
  • Revision ID: joannes@debian-20170923214225-0ez3b2juxo7lduii
  • Commit message: convergent

@@ -10,10 +10,12 @@
 import os
 import time
 import datetime
+import random
 
 import logging
 _logger = logging.getLogger(__name__)
 
+LIMIT_PRECISION = 0.01
 
 class AnnLayer(models.Model):
     """ANN Layer"""
@@ -67,3 +69,217 @@
             ],
         string='Status', required=True, default='draft')
 
+    @api.multi
+    def unlink(self):
+        for layer in self:
+            for neuron in layer.neuron_ids:
+                neuron.input_connection_ids.unlink()
+                neuron.output_connection_ids.unlink()
+            layer.neuron_ids.unlink()
+        return super(AnnLayer, self).unlink()
+
+    @api.multi
+    def button_backforward_error(self):
+        "Backforward the error on the Network"
+        limit_precision = 0.01
+
+        #return importance correction weight
+        def find_importance(value):
+            "defined importance of connection value"
+            fib1 = 0
+            fib2 = 1
+            i = 0
+
+            while True:
+                fib3 = fib1 + fib2
+
+                if value <= i:
+                    return fib3
+                else:
+                    i += 1
+                    fib1 = fib2
+                    fib2 = fib3
+
+        for layer in self:
+
+            #Get the error
+            if layer.next_layer_id:
+                for neuron in layer.neuron_ids:
+
+                    real_value = 0.0
+                    total_importance = 0.0
+                    for connection in neuron.output_connection_ids:
+                        importance = find_importance(connection.weight_nb)
+                        add_real_value = connection.value_backforward * 1.0 / importance
+
+                        if abs(add_real_value) > limit_precision:
+                            real_value += add_real_value
+                            total_importance += 1.0 / importance
+                        elif add_real_value > 0.0:
+                            real_value += limit_precision
+                        else:
+                            real_value -= limit_precision
+
+
+                    if total_importance > 0.0:
+                        neuron.real_value = real_value / total_importance
+                    elif real_value >= 0.0:
+                        neuron.real_value = limit_precision
+                    else:
+                        neuron.real_value = - limit_precision
+
+    @api.multi
+    def button_train(self):
+        "Backforward the error on the Network"
+        error_nb_max = 10
+
+        def zero_limit(value):
+            "limit zero value"
+            if value >= 0.0 and value < LIMIT_PRECISION:
+                return LIMIT_PRECISION
+            elif value < 0.0 and value > - LIMIT_PRECISION:
+                return - LIMIT_PRECISION
+            else:
+                return value
+
+        #defined importance correction weight
+        def find_importance(value, input_abs):
+            "defined importance of connection value"
+            fib1 = 0
+            fib2 = 1
+            i = 1
+
+            if not value:
+                return 0
+            elif abs(value) >= 2 * input_abs:
+                return 1
+            else:
+                while True:
+                    fib3 = fib1 + fib2
+                    i += 1
+                    if abs(value) > input_abs / fib3:
+                        return i
+                    else:
+                        fib1 = fib2
+                        fib2 = fib3
+
+        for layer in self:
+
+            #Get the error
+            layer.button_backforward_error()
+
+            #compute next
+            if layer.previous_layer_id:
+
+                for neuron in layer.neuron_ids:
+                    #linear
+                    if neuron.activation_function == 'linear':
+                        error_value = zero_limit(neuron.output_value - neuron.real_value)
+
+                    #TODO: Save error value,  not used?
+                    neuron.error_value = error_value
+                    neuron.error_abs += (abs(error_value) + neuron.error_nb * neuron.error_abs) / (neuron.error_nb + 1)
+                    neuron.error_avg = (error_value + neuron.error_nb * neuron.error_avg) / (neuron.error_nb + 1)
+                    if not error_value:
+                        neuron.error_nb += 1
+
+                    if neuron.error_nb > error_nb_max:
+                        neuron.error_nb = error_nb_max
+                    elif neuron.error_nb < 1:
+                        neuron.error_nb = 1
+
+
+
+
+
+
+
+                    #TODO: neuron.input_abs = 0
+
+                    #sort connection by value = weight * input
+                    weighted_connection = {}
+                    for input_connection in neuron.input_connection_ids:
+                        input_connection_value = input_connection.from_neuron_id.output_value * input_connection.weight
+                        i2 = find_importance(input_connection_value, neuron.input_abs)
+                        if i2 not in list(weighted_connection.keys()):
+                            weighted_connection[i2] = [input_connection]
+                        else:
+                            weighted_connection[i2].append(input_connection)
+
+                    weighted_connection_keys = sorted(list(weighted_connection.keys()))
+                    if len(weighted_connection_keys) and weighted_connection_keys[0] == 0:
+                        weighted_connection[weighted_connection_keys[-1] + 1] = weighted_connection[0]
+                        del weighted_connection[0]
+                        weighted_connection_keys = sorted(list(weighted_connection.keys()))
+
+                    len_input = 0
+                    for key in weighted_connection_keys:
+                        len_input += len(weighted_connection[key])
+                        connection_error_value = zero_limit(error_value / len_input)
+
+                        for input_connection in weighted_connection[key]:
+                            input_connection_value = zero_limit(input_connection.from_neuron_id.output_value)
+                            input_connection_weight = zero_limit(input_connection.weight)
+
+                            #Define limit value for not dividing by zerog
+                            if connection_error_value >= - LIMIT_PRECISION and connection_error_value <= LIMIT_PRECISION:
+                                weight_error = 0.0
+
+                            elif input_connection_weight >= - LIMIT_PRECISION and input_connection_weight <= LIMIT_PRECISION:
+                                if connection_error_value * input_connection_value >= 0.0:
+                                    weight_error = 2 * LIMIT_PRECISION
+                                else:
+                                    weight_error = - 2 * LIMIT_PRECISION
+
+                            elif input_connection_value * input_connection_weight >= - LIMIT_PRECISION and \
+                                 input_connection_value * input_connection_weight <= LIMIT_PRECISION:
+                                if neuron.real_value * input_connection_value >= 0.0:
+                                    weight_error = 2 * LIMIT_PRECISION
+                                else:
+                                    weight_error = - 2 * LIMIT_PRECISION
+                            else:
+                                #compute error weight
+                                weight_error = zero_limit(- connection_error_value / input_connection_value)
+                                if key + 1.0 > input_connection.weight_importance:
+                                    weight_error = weight_error / (float(key) + 1.0)
+                                else:
+                                    weight_error = weight_error / (input_connection.weight_importance)
+
+                            print "==compute error====", error_value, connection_error_value, input_connection_weight, weight_error
+
+                            if input_connection_weight * (input_connection_weight + weight_error) <= 0.0:
+                                #Limit variation of weight when sign change
+                                if input_connection_weight + weight_error >= 0.0:
+                                    weight_error = 2 * LIMIT_PRECISION
+                                else:
+                                    weight_error = - 2 * LIMIT_PRECISION
+
+                            if abs(input_connection.weight_error_avg) > abs(weight_error):
+                                if weight_error >= 0.0:
+                                    weight_error = abs(input_connection.weight_error_avg)
+                                else:
+                                    weight_error = - abs(input_connection.weight_error_avg)
+
+
+                            #Backforward value
+                            if abs(input_connection_weight + weight_error) >= LIMIT_PRECISION:
+                                input_connection.value_backforward = input_connection_value * input_connection_weight / (input_connection_weight + weight_error)
+                            else:
+                                print "==ha====", input_connection_weight,  weight_error
+                                if input_connection_weight + weight_error >= 0.0:
+                                    input_connection.value_backforward = input_connection_value / LIMIT_PRECISION
+                                else:
+                                    input_connection.value_backforward = - input_connection_value / LIMIT_PRECISION
+
+                            input_connection.weight_error_avg = (9.0 * input_connection.weight_error_avg + weight_error) / 10.0
+                            input_connection.weight_error_abs = (9.0 * input_connection.weight_error_abs + abs(weight_error)) / 10.0
+                            input_connection.weight_importance = (9.0 * input_connection.weight_importance + float(key)) / 10.0
+                            input_connection.weight_nb = key
+
+
+
+
+                            #
+                            input_connection.weight = zero_limit(input_connection_weight + weight_error)
+
+                            input_connection.weight_error = weight_error
+                            input_connection.weight_nb = key
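
Both new training methods lean on a Fibonacci-based notion of connection "importance": button_train buckets each input connection by comparing |output_value * weight| against the neuron's input_abs divided by successive Fibonacci numbers, and button_backforward_error later turns the stored bucket index (weight_nb) back into a Fibonacci divisor when it redistributes the back-forwarded error. The standalone sketch below reproduces those two nested loops outside of any Odoo model so the value-to-bucket-to-divisor mapping can be inspected directly; the function names and sample values are illustrative, not part of the module.

def find_importance_index(value, input_abs):
    """Bucket index for a connection value, mirroring the nested
    find_importance(value, input_abs) helper in button_train: 0 for a zero
    value, 1 when |value| dominates input_abs, otherwise the first i such
    that |value| > input_abs / fib(i)."""
    fib1, fib2 = 0, 1
    i = 1
    if not value:
        return 0
    if abs(value) >= 2 * input_abs:
        return 1
    while True:
        fib3 = fib1 + fib2
        i += 1
        if abs(value) > input_abs / fib3:
            return i
        fib1, fib2 = fib2, fib3

def importance_divisor(index):
    """Fibonacci divisor for a bucket index, mirroring the nested
    find_importance(value) helper in button_backforward_error."""
    fib1, fib2 = 0, 1
    i = 0
    while True:
        fib3 = fib1 + fib2
        if index <= i:
            return fib3
        i += 1
        fib1, fib2 = fib2, fib3

if __name__ == "__main__":
    # Illustrative values only: small connections land in high buckets,
    # which map to large Fibonacci divisors.
    input_abs = 10.0
    for value in (25.0, 9.0, 4.0, 1.5, 0.4, 0.0):
        idx = find_importance_index(value, input_abs)
        div = importance_divisor(idx)
        print("value=%5.1f  bucket=%2d  divisor=%3d" % (value, idx, div))

Small connections thus end up in high buckets and contribute with weight 1/divisor when button_backforward_error averages value_backforward back onto the previous layer. The per-connection statistics written at the end of button_train (weight_error_avg, weight_error_abs, weight_importance) all follow the same (9.0 * old + new) / 10.0 pattern, i.e. an exponential moving average with a 0.9 decay, while zero_limit pushes any value whose magnitude is below LIMIT_PRECISION out to ±LIMIT_PRECISION so the later divisions can never hit zero.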