~yrke/opaal/timedart-vars+discrete

« back to all changes in this revision

Viewing changes to opaal/model_parsers/pyuppaal/successor_generator.py

  • Committer: Kenneth Yrke Jørgensen
  • Date: 2011-05-17 08:07:38 UTC
  • mfrom: (37.1.19 timedart-vars)
  • Revision ID: mail@yrke.dk-20110517080738-y09sdczr3npky2af
Merged with vars

Show diffs side-by-side

added added

removed removed

Lines of Context:
44
44
        }
45
45
 
46
46
class SuccessorGenerator:
47
 
    def __init__(self, model, constant_overrides=dict()):
 
47
    def __init__(self, model, pwlist, constant_overrides=dict()):
48
48
        model = SimplifyModel(model, constant_overrides).simplify()
49
49
 
 
50
        self.pwlist = pwlist
50
51
        self.model = model
51
52
        self.transitions = {}
52
53
        self.sync_transitions = {}
54
55
        self.num_templates = len(self.model.templates)
55
56
        self.invariants = {}
56
57
        self.externs = {}
 
58
        self.maxconstant = None
57
59
 
58
60
        #mapping from template => location id => nice name
59
61
        self.location_labels = defaultdict(dict)
65
67
        self.clocks = declvisitor.clocks
66
68
 
67
69
        self.constants = {}
 
70
        clockList = []
68
71
        #constant overrides
69
72
        for (cident, cval) in constant_overrides.iteritems():
70
73
            #XXX assumes expression is already python
112
115
        logger.info("Constants: %s", self.constants)
113
116
        logger.info("Variables: %s", self.variables)
114
117
        logger.info("Clocks: %s", self.clocks)
 
118
        for c, _ in self.clocks:
 
119
            clockList.append(c)
115
120
        logger.info("Externs: %s", self.externs)
116
121
 
117
122
        #Calculate invariants
146
151
                    target_idx = t.locations.index(tr.target)
147
152
                    debug_info = {}
148
153
 
 
154
                    lowerVal = {}
 
155
                    higherVal = {}
 
156
                    resets = set()
149
157
                    guard_code = None
150
158
                    debug_info['guard_code'] = ''
151
 
                    if tr.guard.value != '':
152
 
                        logger.debug("Guard: %s ==> %s", tr.guard.value, util.expression_to_python_DEPR(tr.guard.value))
153
 
                        guard_code = compile(util.expression_to_python_DEPR(tr.guard.value), "<guard>", 'eval')
 
159
 
 
160
                    clockGuard = []
 
161
                    varGuard = []
 
162
 
 
163
                    if tr.guard.value != '' and tr.guard.value != None:
 
164
                        elements = tr.guard.value.split("and")
 
165
                        for s in elements:
 
166
                            splitGuard = s.strip().split(" ")
 
167
                            if splitGuard[0].strip() in clockList:
 
168
                                clockGuard.append(s.strip())
 
169
                            else:
 
170
                                varGuard.append(s.strip())
 
171
 
 
172
                    #Find guard
 
173
                        for s in clockGuard:
 
174
                            mccandidate = 0
 
175
                            splitGuard = s.strip().split(" ")
 
176
 
 
177
                            if splitGuard[1] == ">=":
 
178
                                lowerVal[splitGuard[0]] = int(splitGuard[2])
 
179
                                mccandidate =  int(splitGuard[2])
 
180
 
 
181
                            if splitGuard[1] == ">":
 
182
                                lowerVal[splitGuard[0]] = int(splitGuard[2])+1
 
183
                                mccandidate =  int(splitGuard[2]) +1
 
184
 
 
185
                            if splitGuard[1] == "<=":
 
186
                                higherVal[splitGuard[0]] = int(splitGuard[2])
 
187
                                mccandidate =  int(splitGuard[2])
 
188
                            
 
189
                            if splitGuard[1] == "<":
 
190
                                tmp = int(splitGuard[2])-1
 
191
                                if tmp < 0:
 
192
                                    print "Error, we can have number smaller than 0"
 
193
                                    exit()
 
194
                                higherVal[splitGuard[0]] = tmp
 
195
                                mccandidate =  tmp
 
196
 
 
197
                            if splitGuard[1] == "==":
 
198
                                lowerVal[splitGuard[0]] = int(splitGuard[2])
 
199
                                higherVal[splitGuard[0]] = int(splitGuard[2])
 
200
                                mccandidate =  int(splitGuard[2])
 
201
 
 
202
                            #Set the max constant
 
203
                            if mccandidate > self.maxconstant:
 
204
                                self.maxconstant = mccandidate
 
205
 
 
206
                        #logger.debug("Guard: %s ==> %s", tr.guard.value, util.expression_to_python_DEPR(tr.guard.value))
 
207
                        
 
208
                        discreteUpdateStr = " and ".join(varGuard)
 
209
                        if discreteUpdateStr:
 
210
                            guard_code = compile(util.expression_to_python_DEPR(discreteUpdateStr), "<guard>", 'eval')
154
211
                        debug_info['guard_code'] = tr.guard.value
155
212
 
156
213
                    update_code = None
160
217
                        debug_info['update_code'] = tr.assignment.value
161
218
 
162
219
                    list_curtrans = [(target_idx, guard_code, update_code, 
163
 
                        tr.synchronisation.value, debug_info)]
 
220
                        tr.synchronisation.value, debug_info, lowerVal, higherVal, resets)]
164
221
 
165
222
                    self.transitions[t_idx][l_idx] += list_curtrans
166
223
 
167
224
                    if tr.synchronisation.value != "":
168
225
                        self.sync_transitions[t_idx][l_idx][tr.synchronisation.value] += list_curtrans
169
226
 
 
227
        self.infconstant = self.maxconstant + 2 #bad name, it is used for inf + 1
170
228
        #print self.location_labels
171
229
        #print "Transitions:"
172
230
        #print self.transitions
173
231
 
174
 
    def checkInvariant(self, state):
175
 
        for t_idx in self.invariants.keys():
176
 
            cur_inv = self.invariants[t_idx][state.locs[t_idx]]
177
 
            if cur_inv != '' and not eval(cur_inv, self.constants, state):
178
 
                return False
179
 
        return True
180
 
 
181
 
 
182
 
    def trans_successors(self, state, trans_info=False):
 
232
#    def checkInvariant(self, a):
 
233
#        state,_ = a
 
234
#        for t_idx in self.invariants.keys():
 
235
#            cur_inv = self.invariants[t_idx][state.locs[t_idx]]
 
236
#            if cur_inv != '' and not eval(cur_inv, self.constants, state):
 
237
#                return False
 
238
#        return True
 
239
 
 
240
 
 
241
    def trans_successors(self, state, w, p, trans_info=False):
183
242
        #Take an active transition
184
243
        for t_idx in xrange(self.num_templates):
185
 
            for (target_idx, guard, update, sync, debug_info) in self.transitions[t_idx][state.locs[t_idx]]:
 
244
            for (target_idx, guard, update, sync, debug_info, lowerVal, higherVal, resets) in self.transitions[t_idx][state.locs[t_idx]]:
 
245
 
 
246
                #logger.debug("")
 
247
                #logger.debug("")
 
248
                #logger.debug("New state")
 
249
                #logger.debug("%s - %s - %s", state, w, p)
186
250
                #Evaluate guard
187
 
                if guard:
188
 
                    logger.debug("Evaluating guard: %s on %s: ", debug_info['guard_code'], state)
189
 
                if guard == None or eval(guard, self.constants, state):
190
 
                    #If synchronisation, find trans to sync with
191
 
                    if sync != '':
192
 
                        if sync[-1] == '?':
193
 
                            #Only look for matching pair to ! syncs
194
 
                            continue
195
 
                        brother_sync = sync[:-1] + '?'
196
 
                        for t2_idx in xrange(self.num_templates):
197
 
                            for (target2_idx, guard2, update2, sync2, debug_info2) in self.sync_transitions[t2_idx][state.locs[t2_idx]][brother_sync]:
198
 
                                if guard2 == None or eval(guard2, self.constants, state):
199
 
                                    #sync found
200
 
                                    a = state.copy()
201
 
                                    a.locs[t_idx] = target_idx
202
 
                                    a.locs[t2_idx] = target2_idx
203
 
                                    #XXX, handle if update and update2 are overlapping
204
 
                                    if update != None:
205
 
                                        try:
206
 
                                            exec update in self.constants, a
207
 
                                        except Exception, e:
208
 
                                            raise VirtualMachineException('Executing "' + debug_info['update_code'] + '": ' + str(e))
209
 
                                    if update2 != None:
210
 
                                        try:
211
 
                                            exec update2 in self.constants, a
212
 
                                        except Exception, e:
213
 
                                            raise VirtualMachineException('Executing "' + debug_info2['update_code'] + '": ' + str(e))
214
 
                                    yield trans_info and (a, sync[:-1]) or a
215
 
 
216
 
                    #Not synchronising
217
 
                    else:
218
 
                        a = state.copy()
219
 
                        a.locs[t_idx] = target_idx
220
 
                        #Execute update
221
 
                        logger.debug("Executing update: %s on %s: ", debug_info['update_code'], a)
222
 
                        if update != None:
223
 
                            try:
224
 
                                exec update in self.constants, a
225
 
                            except Exception, e:
226
 
                                raise VirtualMachineException('Executing "' + debug_info['update_code'] + '": ' + str(e))
227
 
                        #print "Result:", a.vars
228
 
                        yield trans_info and (a, "") or a
229
 
    
230
 
    def delay_successors_1step(self, state):
231
 
        """Generate the 1 time unit delay successor"""
232
 
        a = state.copy()
233
 
        delay_has_effect = False
234
 
        for (c, cmax) in self.clocks:
235
 
            if a.clocks[c] <= cmax:
236
 
                a.clocks[c] = a.clocks[c]+1
237
 
                delay_has_effect = True
238
 
        if delay_has_effect:
 
251
 
 
252
                #Evalulate discrete Guard
 
253
                if guard != '' and guard != None and not eval(guard, self.constants, state):
 
254
                    continue
 
255
 
 
256
               
 
257
                start = 0
 
258
                for c in lowerVal.keys():
 
259
                    guard = lowerVal[c] - state[c]
 
260
                    if guard > start:
 
261
                        start = guard
 
262
 
 
263
                if start < 0 or p <= start: #Terminate early if can be satisfyed or if already passed
 
264
                    return
 
265
 
 
266
                end = self.maxconstant + 1
 
267
                for c in higherVal.keys():
 
268
                    guard = higherVal[c] - state[c]
 
269
                    if guard < end:
 
270
                        end = guard
 
271
                 
 
272
 
 
273
                #logger.debug("lowerVals %s, higerVals %s", lowerVal, higherVal)
 
274
                #logger.debug("The state has start %s and end %s ", start, end)
 
275
        
 
276
                #Check if the guard can be satisfied
 
277
                if  start > end or w > end:
 
278
                    return
 
279
        
 
280
                #Calculate discrete update
 
281
                newState = state.copy()
 
282
                newState.locs[t_idx] = target_idx
 
283
                #Calculate successor states
 
284
                #logger.debug("Successors:")
 
285
                if update == None:
 
286
                    for c, _ in self.clocks:
 
287
                        newVal = newState[c] + w
 
288
                        if newVal > self.maxconstant + 1: #Normalize if larger than max constant 
 
289
                            newVal = self.maxconstant + 1
 
290
                        newState[c] = newVal - w
 
291
                        
 
292
                    yield (newState, w, self.infconstant) # We only need to normalize if we had a open higher guard
 
293
                else: #Is allready passed?
 
294
                    if p < end:
 
295
                        end = p - 1 # Minus one as value p is passed
 
296
                    #yield many states
 
297
                    for i in range(start, end+1): #dont need infin as max is mc+1
 
298
                        a = newState.copy()
 
299
                        for c,_ in self.clocks:
 
300
                            tmp = a[c] + i
 
301
                            if tmp > self.maxconstant+1:
 
302
                                 tmp = self.maxconstant+1
 
303
                            a[c] = tmp
 
304
                        #for r in resets: #This is wrong collected wrong data
 
305
                        #    a[r] = 0
 
306
                        exec update in self.constants, a 
 
307
                        #if normalize:
 
308
                        #    s1, w1, p1 = self.normalize(a,0,self.infconstant)
 
309
                        #else:
 
310
                        s1, w1, p1 = a,0,self.infconstant
 
311
                        #logger.debug("%s - %s - %s", s1, w1, p1)
 
312
                        yield s1, w1, p1 
 
313
 
 
314
    def successors(self, state, w, p):
 
315
        for a in self.trans_successors(state, w, p):
 
316
            #if self.checkInvariant(a):
239
317
            yield a
240
 
 
241
 
    def delay_successors_allsteps(self, state):
242
 
        """Generate all delay successors"""
243
 
        a = state.copy()
244
 
        delay_has_effect = True
245
 
        #Using invariants we might shortcut this loop when 
246
 
        # we see at state that violates the invariant as all guards are on the
247
 
        # form < and <= 
248
 
        while delay_has_effect:
249
 
            delay_has_effect = False
250
 
            for (c, cmax) in self.clocks:
251
 
                if a.clocks[c] <= cmax:
252
 
                    a.clocks[c] = a.clocks[c]+1
253
 
                    delay_has_effect = True
254
 
            if delay_has_effect:
255
 
                yield a
256
 
                a = a.copy()
257
 
 
258
 
    def delay_successors_interesting(self, state):
259
 
        """Generate the next "interesting" delay successor, that is the next
260
 
        state that has an enabled transition."""
261
 
        a = state.copy()
262
 
        delay_has_effect = True
263
 
        while delay_has_effect:
264
 
            delay_has_effect = False
265
 
            for (c, cmax) in self.clocks:
266
 
                if a.clocks[c] <= cmax:
267
 
                    a.clocks[c] = a.clocks[c]+1
268
 
                    delay_has_effect = True
269
 
            if delay_has_effect:
270
 
                #Check if delay is only option in new state
271
 
                for b in self.trans_successors(a):
272
 
                    yield a
273
 
                    #There was an option that was not delay,
274
 
                    #Stop generating delay successors
275
 
                    return
276
 
 
277
 
    def successors(self, state):
278
 
        for a in self.trans_successors(state):
279
 
            if self.checkInvariant(a):
280
 
                yield a
281
 
        for a in self.delay_successors_1step(state):
282
 
            if self.checkInvariant(a):
283
 
                yield a
284
 
 
285
 
    def successors_transinfo(self, state):
286
 
        for (a, ti) in self.trans_successors(state, trans_info=True):
287
 
            if self.checkInvariant(a):
288
 
                yield (a, ti)
289
 
        for a in self.delay_successors_1step(state):
290
 
            if self.checkInvariant(a):
291
 
                yield (a, "")
 
318
        #for a in self.delay_successors_1step(state):
 
319
        #    if self.checkInvariant(a):
 
320
        #        yield a
292
321
 
293
322
    def get_initialstate(self):
294
323
        #calculate initial state
339
368
                initstate.clocks[c] = 0
340
369
 
341
370
        logger.info("Initial state: locs: %s vars: %s", initstate.locs, initstate.vars)
342
 
        return initstate
 
371
        return initstate, 0, self.maxconstant
343
372
 
344
373
# vim:ts=4:sw=4:expandtab