learning.py 25.9 KB
Newer Older
1
from .observationtable import Table
2
import random
3
from systems.implementations import SuspensionAutomaton
4
5
import os, inspect
import helpers.graphhelper as gh
Michele's avatar
Michele committed
6
import logging
7
import helpers.traces as th
8
from systems.iopurpose import InputPurpose, OutputPurpose
9
10
11

class LearningAlgorithm:

12
    def __init__(self, teacher, oracle, tester, tablePreciseness = 1000,
                 modelPreciseness = 0.1, closeStrategy = None,
                 printPath = None, maxLoops=10, logger=None, outputPurpose=None,
                 inputPurpose=None):
        """Set up the learner, its observation table and initial hypotheses.

        Arguments:
        teacher -- system being learned; must provide getInputAlphabet(),
            getOutputAlphabet() and getQuiescence()
        oracle -- stored as ``_oracle`` (used for observation queries)
        tester -- stored as the public attribute ``tester``
        tablePreciseness -- stored threshold for the table preciseness
        modelPreciseness -- stored threshold for the model preciseness
        closeStrategy -- forwarded unchanged to Table
        printPath -- stored as ``_printPath``
        maxLoops -- maximum number of loops with no effect on the hPlus
            model
        logger -- logger to use; logging.getLogger(__name__) when falsy
        outputPurpose -- defaults to an OutputPurpose over the teacher's
            outputs plus quiescence
        inputPurpose -- defaults to an InputPurpose over a copy of the
            teacher's inputs
        """
        self._logger = logger or logging.getLogger(__name__)

        # Fall back to purposes that cover the teacher's whole alphabets.
        if inputPurpose is None:
            inputPurpose = InputPurpose(teacher.getInputAlphabet().copy())
        if outputPurpose is None:
            outputPurpose = OutputPurpose(
                teacher.getOutputAlphabet().union({teacher.getQuiescence()}))
        self._inputPurpose = inputPurpose
        self._outputPurpose = outputPurpose

        self._teacher = teacher
        self._oracle = oracle
        self.tester = tester
        self._tablePreciseness = tablePreciseness
        self._modelPreciseness = modelPreciseness

        # NOTE: the table receives the caller's ``logger`` argument as-is
        # (possibly None), not the fallback stored in self._logger.
        self._table = Table(teacher.getInputAlphabet().copy(),
                            teacher.getOutputAlphabet().copy(),
                            teacher.getQuiescence(),
                            closeStrategy, logger=logger,
                            outputPurpose=self._outputPurpose,
                            inputPurpose=self._inputPurpose)

        # Maximum number of loops with no effect on hPlus model
        self._noEffectLimit = maxLoops

        # Current number of loops
        self._currentLoop = 0

        # Placeholder one-state hypotheses; hPlus is the chaos variant.
        self._hMinus = SuspensionAutomaton(1,
                                    self._teacher.getInputAlphabet().copy(),
                                    self._teacher.getOutputAlphabet().copy(),
                                    self._teacher.getQuiescence())
        self._hPlus = SuspensionAutomaton(1,
                                    self._teacher.getInputAlphabet().copy(),
                                    self._teacher.getOutputAlphabet().copy(),
                                    self._teacher.getQuiescence(),
                                    chaos = True)

        self._printPath = printPath

58
59
    # this update uses a realistic teacher. If I need an output to happen I
    # cannot force it to happen.
60
    def updateTable(self):
61
62
63
64
65
66
        # First, try to avoid impossible traces: ask observation query
        for trace in self._table.getObservableTraces():
            observedOutputs = self._table.getOutputs(trace)
            observation = self._oracle.observation(trace, observedOutputs)
            if observation:
                self._table.updateEntry(trace, observation=observation)
67

68
        # For all traces for which we did not observed all possible outputs
69
        oTraces = self._table.getObservableTraces()
70

71
        trie = th.make_trie(oTraces)
72

73
        # Until we tried K times with no results
74
        K = len(oTraces) * 150 # TODO: should not be hardcoded
75
76
77
        found = 0
        tries = 0
        while tries < K:
78

79
80
            tries += 1
            oTraces = self._table.getObservableTraces()
81

82
83
84
85
86
87
            trie = th.make_trie(oTraces)
            subtrie = trie

            # if no trace is observable (best scenario)
            if len(oTraces) == 0:
                break
88

89
90
91
92
93
94
95
96
97
            observations = {}       # Dictionary with obtained outputs
            consecutiveInputs = ()  # keep track of last inputs sequence
            currentTrace = ()       # keep track of current trace

            i = 0
            # We build a trace until we either
            #         1 - observe an output that makes the trace not a prefix
            #         2 - there is no continuation of that trace in prefixes
            # We stop when we observed at least an output for each observable
98
            while len(oTraces) > len(observations.keys()): #and i < K:
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
                i += 1

                # check if trie contains no traces (but still has a child)
                children = trie.keys()
                hasTrace = False
                for child in children:
                    if trie[child] != {}:
                        hasTrace = True
                if not hasTrace:
                    break

                # if currentTrace is observable and we did not process it
                # already, we ask an output and we add the result to
                # observations[currentTrace]
                if (currentTrace in oTraces and
                    currentTrace not in observations.keys()):
                    # there might be some inputs waiting to be processed
                    if consecutiveInputs != ():
                        output = self._processInputs(consecutiveInputs)
                        # reset the inputs, because we processed them
                        consecutiveInputs = ()
                        if output == None:
                            # SUT not input enabled: reset
                            currentTrace = ()
                            subtrie = trie
                            self._teacher.reset()
                            continue
                    else:
                        # no input to process, ask an output
                        output = self._teacher.output()
129

130
131
132
                    # we have an output for currentTrace, add it to observations
                    # this is the first output we observe for currentTrace
                    observations[currentTrace] = set([output])
133

134
135
                    # remove currentTrace from trie
                    th.remove_from_trie(trie, currentTrace)
136

137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
                    # if that output is not a valid continuation
                    if output not in subtrie.keys():
                        # reset the process
                        currentTrace = ()
                        subtrie = trie
                        self._teacher.reset()
                        continue

                    # navigate trie
                    subtrie = subtrie[output]
                    currentTrace = currentTrace + (output,)

                else:
                    # currentTrace not observable, or already observed
                    # get an input from subtries
                    children = subtrie.keys()
                    inputChildren = [x for x in children \
                                    if x in self._teacher.getInputAlphabet()]
155

156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
                    if len(inputChildren) > 0:
                        # process this input, add it to consecutiveInputs
                        # and navigate subtrie
                        input = random.sample(inputChildren,1)[0]
                        consecutiveInputs = consecutiveInputs + (input,)
                        subtrie = subtrie[input]
                        currentTrace = currentTrace + (input,)
                        continue
                    else:
                        # no inputs available, wait for output
                        # there might be some inputs waiting to be processed
                        if consecutiveInputs != ():
                            output = self._processInputs(consecutiveInputs)
                            # reset the inputs, because we processed them
                            consecutiveInputs = ()
                            if output == None:
                                # SUT not input enabled: reset
                                currentTrace = ()
                                subtrie = trie
                                self._teacher.reset()
                                continue
                        else:
                            # no input to process, ask an output
                            output = self._teacher.output()
180

181
182
183
184
185
186
187
188
189
190
191
                        # we have an output for currentTrace,
                        # if currentTrace is in otraces add it to observations
                        if currentTrace in oTraces:
                            observations[currentTrace].add(output)

                        # remove currentTrace from trie
                        if th.in_trie(trie, currentTrace):
                            th.remove_from_trie(trie, currentTrace)

                        # if that output is not a valid continuation
                        if output not in subtrie.keys():
192

193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
                            # reset the process
                            currentTrace = ()
                            subtrie = trie
                            self._teacher.reset()
                            continue

                        # navigate trie
                        subtrie = subtrie[output]
                        currentTrace = currentTrace + (output,)

            # end while loop
            # observations contains observed outputs

            found += len(observations.keys())

            for trace in observations.keys():
                # Only if trace is a prefix in S, then
                # add trace + output to row (S cdot L_delta)
                if self._table.isInS(trace):
                    for output in observations[trace]:
                        self._table.addOneLetterExtension(trace, output)
214

215
216
217
                # Update set of outputs for traces where deltas are removed
                for deltaTrace in self._table.getDeltaTraces(trace):
                    for output in observations[trace]:
218
                        self._table.updateEntry(deltaTrace, output=output)
219

220
                for output in observations[trace]:
221
                    self._table.updateEntry(trace, output=output)
222

223
224
225
226
227
            # Observation query
            # ask observation query for all entries because I could have added
            # some 'impossible' traces
            for trace in self._table.getObservableTraces():
                observedOutputs = self._table.getOutputs(trace)
228
                observation = self._oracle.observation(trace, observedOutputs)
229
230
231
                if observation:
                    self._table.updateEntry(trace, observation=observation)

232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
    # # this update function uses teacher.process(trace)
    # # in case a InputOutputTeacher is used, outputs in trace are forced to happen
    # # this is not realistic, but still useful at the moment.
    # def oldUpdateTable(self):
    #     temp = 0
    #     tot = 0
    #     for c in range(200):
    #         for trace in self._table.getObservableTraces():
    #             observedOutputs = self._table.getOutputs(trace)
    #             output = self._teacher.process(trace)
    #             for i in range(10):
    #                 # try again if retrieving output is unsuccesful
    #                 if output != None:
    #                     break
    #                 output = self._teacher.process(trace)
    #             tot += 1
    #             if output != None:
    #                 # Only if trace is a prefix in S, then
    #                 # add trace + output to row (S cdot L_delta)
    #                 if self._table.isInS(trace):
    #                     self._table.addOneLetterExtension(trace, output)
    #
    #                 # Update set of outputs for traces where deltas are removed
    #                 for deltaTrace in self._table.getDeltaTraces(trace):
    #                     self._table.updateEntry(deltaTrace, output)
    #
    #                 # Add this output to the set of outputs observed after trace
    #                 observedOutputs.add(output)
    #             else:
    #                 temp += 1
    #
    #             observation = self._oracle.observation(trace, observedOutputs)
    #             self._table.updateEntry(trace, output, observation)
265

266
267
268
269
270
271
272


    def _processInputs(self, consecutiveInputs):
        if consecutiveInputs != ():
            output = self._teacher.oneOutput(consecutiveInputs)
            if output == None:
                # SUT did not accept an input.
273
                self._logger.warning("SUT did not accept input in " + str(consecutiveInputs))
274
275
276
                return None
            return output
        return self._teacher.output()
277

278
279

    def stabilizeTable(self):
280
        # While nothing changes, keep closing and consistent the table
281
282
283
        closingRows = self._table.isNotGloballyClosed()
        consistentCheck = self._table.isNotGloballyConsistent()
        while closingRows or consistentCheck:
Michele's avatar
Michele committed
284
            while closingRows:
Michele's avatar
Michele committed
285
286
                self._logger.info("Closing table")
                self._logger.debug(closingRows)
287
                self._table.promote(closingRows)
Michele's avatar
Michele committed
288
289
                # After promoting one should check if some one letter
                # extensions should also be added
Michele's avatar
Michele committed
290
                if self._table.addOneLetterExtensions(closingRows):
Michele's avatar
Michele committed
291
292
                    self._logger.info("something changed")
                if self._logger.isEnabledFor(logging.DEBUG):
293
294
                    self._table.printTable(prefix="_c_")
                self.updateTable()
295
296
                closingRows = self._table.isNotGloballyClosed()
                consistentCheck = self._table.isNotGloballyConsistent()
Michele's avatar
Michele committed
297
            # Table is closed, check for consistency
298
            if consistentCheck:
Michele's avatar
Michele committed
299
300
                self._logger.info("Consistency check")
                self._logger.debug(consistentCheck)
301
                if self._table.addColumn(consistentCheck, force=True):
Michele's avatar
Michele committed
302
303
304
                    self._logger.info("something changed")
                if self._logger.isEnabledFor(logging.DEBUG):
                    self._table.printTable(prefix="_i_")
305
306
307
308
309
                # TODO: is an update needed here? in theory, when I encounter
                # an inconsistency, by adding a column, the interesting row
                # will immediately make the table not closed, no need of
                # update, right?
                #self.updateTable()
310
311
                closingRows = self._table.isNotGloballyClosed()
                consistentCheck = self._table.isNotGloballyConsistent()
312

Michele's avatar
Michele committed
313
    def getHypothesis(self, chaos=False):
314
        # If table is not closed, ERROR
315
        if self._table.isNotGloballyClosed():
Michele's avatar
Michele committed
316
317
            self._logger.error("Tried to get hipotheses with table not \
                                closed or not consistent")
318
            return None, None
Michele's avatar
Michele committed
319
320
        # Get equivalence classes
        rows = self._table.getEquivalenceClasses(chaos)
321
322
323
        hyp = SuspensionAutomaton(len(rows),
                                    self._teacher.getInputAlphabet().copy(),
                                    self._teacher.getOutputAlphabet().copy(),
Michele's avatar
Michele committed
324
325
                                    self._teacher.getQuiescence(),
                                    chaos)
326
327

        # assign to each equivalence class a state number
328
        # start with equivalence class of empty trace to 0
329
330
331
332
333
334
335
336
337
        assignments = {():0}
        count = 1
        for row in rows:
            if row != ():
                assignments[row] = count
                count = count + 1

        # add transitions
        for row in rows:
338
339
            allLabels = self._getAllLabels(row)

340
            for label in allLabels:
341
342
343
344
                # create row and search it in the table
                extension = row + (label,)
                if self._table.isInRows(extension):
                    for target in rows:
345
                        found = False
346
347
348
349
350
                        # TODO: the method of table called at next line is
                        # private. Change to public, or add a public version
                        if self._table._rowEquality(extension, target, chaos):
                            hyp.addTransition(assignments[row], label,
                                              assignments[target])
351
352
353
                            found = True
                            break
                    if not found:
354
                        self._logger.warning("Chaotic behaviour")
355
356
357
                        # Either the table is not closed, or
                        # First column of extension has an empty set.
                        # If label is an input, then send it to ChaosDelta
358
                        #   - unless it is not enabled.
359
360
361
362
363
                        # If it is an output, then send it to Chaos
                        # It cannot be quiescence
                        if label in  self._teacher.getInputAlphabet():
                            hyp.addTransition(assignments[row], label,
                                              hyp.getChaosDelta())
364
                        elif label in self._table.getPossibleOutputs(row):
365
366
                            hyp.addTransition(assignments[row], label,
                                              hyp.getChaos())
367
368
369
370
371
372
373
                elif (len(row) > 0 and
                      row[-1] == hyp.getQuiescence() and
                      label == hyp.getQuiescence()):
                    # fixing issue #2: non existing row because it ends with
                    # a sequence of quiescence
                    hyp.addTransition(assignments[row], label,
                                      assignments[row])
374
375
                elif (chaos and label in self._table.getPossibleOutputs(row)):
                    self._logger.warning("Chaotic behaviour 2")
376
                    # Add transitions to chaotic state if necessary
Michele's avatar
Michele committed
377
378
379
380
381
382
383
                    if row in self._table.getObservableTraces():
                        if label != hyp.getQuiescence():
                            hyp.addTransition(assignments[row], label,
                                              hyp.getChaos())
                        else:
                            hyp.addTransition(assignments[row], label,
                                              hyp.getChaosDelta())
Michele's avatar
Michele committed
384
385
        return hyp

386
    # Get all labels enabled after row
387
    # check input enabledness
388
    def _getAllLabels(self, row):
389
390
391
392
393
394
395
        enabledInputs = self._teacher.getInputAlphabet()
        enabledOutputs = self._teacher.getOutputAlphabet()
        if self._inputPurpose != None:
            enabledInputs = self._inputPurpose.getEnabled(row)
        if self._outputPurpose != None:
            enabledOutputs = self._outputPurpose.getEnabled(row)
        allLabels = enabledInputs.union(enabledOutputs,
396
397
398
                                        set((self._teacher.getQuiescence(),)))
        return allLabels

399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
    # Generate DOT files for hypotheses. hyp = hMinus|hPlus|both
    def generateDOT(self, path=None, hyp="both", format="pdf"):
        if path == None:
            path = os.path.dirname(os.path.abspath(
                            inspect.getfile(inspect.currentframe())))
        pathHminus = os.path.join(path, "hypotheses", "loop_" +
                                  str(self._currentLoop), "hMinus")
        pathHplus = os.path.join(path, "hypotheses", "loop_" +
                                 str(self._currentLoop), "hPlus")
        if hyp != None:
            if hyp == "hMinus" or hyp == "both":
                gh.createDOTFile(self._hMinus, pathHminus, format)
            if hyp == "hPlus" or hyp == "both":
                gh.createDOTFile(self._hPlus, pathHplus, format)

414
415
    # Start the learning process
    def run(self):
416
417
418
419
420
421
422
423
        noEffectLoops = 0

        oldTablePreciseness = self._table.preciseness()
        oldModelPreciseness = 0

        # While we have not reached our limit for lopps with no effect on
        # the table
        while noEffectLoops < self._noEffectLimit:
424
            self._currentLoop = self._currentLoop + 1
425
            self._logger.info("Learning loop number " + str(self._currentLoop))
426
            # Fill the table and make it closed and consistent
427
428
            self.updateTable()
            self.stabilizeTable()
429
430
            # Is the table quiescence reducible? If not make it so and
            # then fill it again, make it closed and consitent
431
432
433
            newSuffixes = self._table.isNotQuiescenceReducible()
            while (newSuffixes or not self._table.isStable()):
                if newSuffixes:
Michele's avatar
Michele committed
434
                    self._logger.info("Quiescence reducible")
435
                    self._logger.debug(newSuffixes)
Michele's avatar
Michele committed
436
437
                    if self._logger.isEnabledFor(logging.DEBUG):
                        self._table.printTable(prefix="_Q_")
438
                    if self._table.addColumn(newSuffixes, force=True):
Michele's avatar
Michele committed
439
440
441
442
                        self._logger.info("somethig changed")
                        if self._logger.isEnabledFor(logging.DEBUG):
                            self._table.printTable(prefix="_Qafter_")
                        self.updateTable()
443
                self.stabilizeTable()
444
                newSuffixes = self._table.isNotQuiescenceReducible()
445

446
447
448
449
450
            self._hMinus = self.getHypothesis()
            self._hPlus = self.getHypothesis(chaos=True)

            if self._printPath != None:
                self.generateDOT(path=self._printPath)
Michele's avatar
Michele committed
451
                self._table.printTable(self._printPath,
452
                                       prefix="loop_"+str(self._currentLoop))
Michele's avatar
Michele committed
453

454
455
456
457
458
459
460
            # Table preciseness will increase easily, it does not depend
            # on nondeterministic choices. Adding rows and columns to the
            # table is enough to increase it.
            # On the contrary, model preciseness will increase only if
            # something really changed in the behaviour of the system.
            # If model preciseness remain the same for too long, then we
            # might want to stop learning.
461
            currentModelPreciseness = self._hPlus.preciseness()
462
463
464
465
            if oldModelPreciseness == currentModelPreciseness:
                # TODO: problem here. These two values could be the same even
                # if the behaviour of the system changed. This might happen in
                # rare cases. For the moment I consider those cases so rare
Michele's avatar
Michele committed
466
                # that I do not handle them.
467
468
469
470
471
472
473
                noEffectLoops = noEffectLoops + 1
            else:
                oldModelPreciseness = currentModelPreciseness
                noEffectLoops = 0

            if (self._table.preciseness() < self._tablePreciseness or
                currentModelPreciseness < self._modelPreciseness):
474
                self._logger.info("Requested table preciseness: " +
475
476
                                str(self._tablePreciseness) + ". Actual: "+
                                str(self._table.preciseness()))
477
478
479
                self._logger.info("Requested model preciseness: " +
                                str(self._modelPreciseness) + ". Actual: "+
                                str(currentModelPreciseness))
480
481
482
483
484
485
486
487

                samples = set(['extendTable', 'testing'])

                # if model preciseness is 1 (maximum), update does not make sense
                if (currentModelPreciseness) < 1:
                    samples.add('update')

                choice = random.sample(samples, 1)[0]
488
                # TODO: forcing testing: remove!!
489
                choice = 'testing'
490
                if choice == 'extendTable':
491
                    self._logger.info("Improve preciseness method: extend the table")
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
                    # create a set of suffix closed suffixes and add them
                    # to the table
                    # TODO: are there clever ways to create it?
                    columns = set()
                    # Get maximum length in order to be sure to modify the Table
                    length = self._table.getLengthColumns() + 1
                    # we want to avoid quiescence for the moment TODO
                    inAndOut = self._teacher.getInputAlphabet().union(
                                        self._teacher.getOutputAlphabet())
                    newColumn = ()
                    for i in range(length):
                        label = random.sample(inAndOut,1)[0]
                        newColumn = (label,) + newColumn
                        columns.add(newColumn)
                    self._logger.debug(newColumn)

                    self._table.addColumn(columns,True)

                    # TODO: I do not like the next part very much, commented
                    # for the moment. Maybe just increasing columns (as we
                    # would have found a counterexample) is enough

                    # promote 10% of random rows
                    # rows = [x for x in self._table._rows if \
                    #         not self._table.isInS(x) and self._table.isDefined(x)]
                    # randomRows = random.sample(rows, int(len(rows)*0.1))
                    # self._logger.debug(randomRows)
                    # self._table.promote(randomRows)
                    # self._table.addOneLetterExtensions(randomRows)


523
                elif choice == 'update':
524
                    self._logger.info("Improve preciseness method: update the table")
525
526
527
528
                    # TODO: consider to remove this function call and just
                    # add a continue statement, because it is the first
                    # thing to do in the main while loop. Keep for now
                    # for clarity.
529
                    self.updateTable()
Michele's avatar
Michele committed
530
                elif choice == 'testing':
531
                    self._logger.info("Improve preciseness method: testing")
532
533
534
535
536
537
538
539
540
                    ce, output = self.tester.findCounterexample(self._hMinus)
                    # Handle counterexample
                    if ce:
                        if not self._table.handleCounterexample(ce, output):
                            self._logger.info("Handling counterexample: failed")
                        else:
                            self._logger.info("Handling counterexample: succeded")
                    else:
                        self._logger.info("No counterexample found, proceed")
541
542
                continue
            else:
543
                self._logger.info("Requested table preciseness: " +
544
545
                                str(self._tablePreciseness) + ". Actual: "+
                                str(self._table.preciseness()))
546
547
548
                self._logger.info("Requested model preciseness: " +
                                str(self._modelPreciseness) + ". Actual: "+
                                str(currentModelPreciseness))
549
                self._logger.info("Stop learning")
550
                # Exit while loop, return hMinus and hPlus
551
                break
552
553
        if self._logger.isEnabledFor(logging.DEBUG):
            self._table.printTable(prefix="_final_")
554
        return (self._hMinus, self._hPlus)