learning.py 27.1 KB
Newer Older
Michele's avatar
Michele committed
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20
# Copyright (c) 2015 Michele Volpato
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.

21
from .observationtable import Table
22
import random
23
from systems.implementations import SuspensionAutomaton
24 25
import os, inspect
import helpers.graphhelper as gh
Michele's avatar
Michele committed
26
import logging
27
import helpers.traces as th
28
from systems.iopurpose import InputPurpose, OutputPurpose
29 30 31

class LearningAlgorithm:
    """Learn suspension-automaton hypotheses of a system from observations.

    An observation table (``Table``) is filled by interacting with a teacher
    (the system) and by asking observation queries to an oracle. On every
    learning loop two hypotheses are built from the table: ``hMinus``
    (``getHypothesis()``) and ``hPlus`` (``getHypothesis(chaos=True)``).
    """

    # Number of sampling rounds per observable trace in updateTable.
    # NOTE: the intended bound is the number of outputs plus quiescence,
    # i.e. len(teacher.getOutputAlphabet()) + 1; 35 is a hard-coded override
    # kept for behavioural compatibility. TODO: revisit.
    _TRIES_PER_TRACE = 35

    def __init__(self, teacher, oracle, tester, tablePreciseness = 1000,
                 modelPreciseness = 0.1, closeStrategy = None,
                 printPath = None, maxLoops=10, logger=None, outputPurpose=None,
                 inputPurpose=None):
        """Set up the learner.

        :param teacher: system under learning; provides input/output
            alphabets, the quiescence label, outputs and resets.
        :param oracle: answers observation queries for table traces.
        :param tester: searches counterexamples for hMinus.
        :param tablePreciseness: table preciseness required to stop learning.
        :param modelPreciseness: hPlus preciseness required to stop learning.
        :param closeStrategy: forwarded to ``Table``.
        :param printPath: if not None, DOT files and tables are written there
            at every loop.
        :param maxLoops: consecutive loops with unchanged hPlus preciseness
            after which learning stops.
        :param logger: logger to use; defaults to this module's logger.
        :param outputPurpose: output purpose; defaults to one enabling all
            outputs plus quiescence.
        :param inputPurpose: input purpose; defaults to one enabling all
            inputs.
        """
        self._logger = logger or logging.getLogger(__name__)

        # If the input (output) purpose is not defined, use one that always
        # returns all inputs (all outputs plus quiescence).
        if inputPurpose is None:
            inputPurpose = InputPurpose(teacher.getInputAlphabet().copy())
        if outputPurpose is None:
            outputPurpose = OutputPurpose(
                teacher.getOutputAlphabet().union({teacher.getQuiescence()}))
        self._inputPurpose = inputPurpose
        self._outputPurpose = outputPurpose

        self._teacher = teacher
        self._oracle = oracle
        self.tester = tester
        self._tablePreciseness = tablePreciseness
        self._modelPreciseness = modelPreciseness

        # NOTE(review): the raw ``logger`` argument (possibly None) is
        # forwarded here, not self._logger -- confirm this is intended.
        self._table = Table(teacher.getInputAlphabet().copy(),
                            teacher.getOutputAlphabet().copy(),
                            teacher.getQuiescence(),
                            closeStrategy, logger=logger,
                            outputPurpose=self._outputPurpose,
                            inputPurpose=self._inputPurpose)

        # Maximum number of consecutive loops with no effect on hPlus model
        self._noEffectLimit = maxLoops
        # Current number of loops
        self._currentLoop = 0

        # Initial one-state hypotheses; hPlus is built with chaos enabled.
        self._hMinus = SuspensionAutomaton(1,
                                           teacher.getInputAlphabet().copy(),
                                           teacher.getOutputAlphabet().copy(),
                                           teacher.getQuiescence())
        self._hPlus = SuspensionAutomaton(1,
                                          teacher.getInputAlphabet().copy(),
                                          teacher.getOutputAlphabet().copy(),
                                          teacher.getQuiescence(),
                                          chaos = True)

        self._printPath = printPath

    def updateTable(self):
        """Fill the table by observing a realistic teacher.

        Outputs cannot be forced to happen: inputs are sent and outputs are
        only observed. Observation queries are asked first. Then, for at most
        ``len(observable traces) * _TRIES_PER_TRACE`` rounds (or until no
        trace is observable), random paths through the trie of observable
        traces are executed. The outputs observed after observable traces are
        written into the table, and observation queries are asked again.
        """
        # First, try to avoid impossible traces: ask observation queries
        self._askObservationQueries()

        oTraces = self._table.getObservableTraces()
        for _ in range(len(oTraces) * self._TRIES_PER_TRACE):
            oTraces = self._table.getObservableTraces()
            trie = th.make_trie(oTraces)

            # if no trace is observable (best scenario) we are done
            if not oTraces:
                break

            observations = self._sampleOutputs(oTraces, trie)
            self._recordObservations(observations)

            # Ask observation queries for all entries again, because some
            # 'impossible' traces could have been added
            self._askObservationQueries()

    def _askObservationQueries(self):
        """Ask the oracle an observation query for every observable trace
        and store every truthy answer in the table."""
        for trace in self._table.getObservableTraces():
            observedOutputs = self._table.getOutputs(trace)
            observation = self._oracle.observation(trace, observedOutputs)
            if observation:
                self._table.updateEntry(trace, observation=observation)

    def _sampleOutputs(self, oTraces, trie):
        """Walk random paths through ``trie`` on the teacher, collecting outputs.

        A path is extended with an input chosen at random among the children
        of the current trie node (when the current trace is not an observable
        trace still lacking an observation), otherwise with the output the
        teacher produces. The path restarts, and the teacher is reset, when
        the SUT refuses the pending inputs or produces an output that is not a
        child of the current node. ``trie`` is pruned as traces get observed.
        Sampling stops when every observable trace has an observation, or
        when no child of the trie root has a non-empty sub-trie.

        :return: dict mapping observable traces to the set of outputs
            observed right after them.
        """
        inputs = self._teacher.getInputAlphabet()
        observations = {}   # observable trace -> set of observed outputs
        pending = ()        # inputs chosen but not yet sent to the SUT
        currentTrace = ()   # trace executed so far (inputs and outputs)
        subtrie = trie      # trie node reached by currentTrace

        while len(oTraces) > len(observations):
            # stop if the trie root has no child with a non-empty sub-trie
            if not any(trie[child] != {} for child in trie):
                break

            firstVisit = (currentTrace in oTraces and
                          currentTrace not in observations)
            if not firstVisit:
                # currentTrace not observable, or already observed: extend it
                # with an input from the sub-trie, if there is one
                inputChildren = [x for x in subtrie if x in inputs]
                if inputChildren:
                    label = random.choice(inputChildren)
                    pending += (label,)
                    subtrie = subtrie[label]
                    currentTrace += (label,)
                    continue

            # Send the pending inputs (if any) and wait for an output
            output, refused = self._nextOutput(pending)
            pending = ()
            if refused:
                # SUT not input enabled: reset
                currentTrace, subtrie = (), trie
                self._teacher.reset()
                continue

            if firstVisit:
                # first output observed after currentTrace
                observations[currentTrace] = {output}
                th.remove_from_trie(trie, currentTrace)
            else:
                if currentTrace in oTraces:
                    observations[currentTrace].add(output)
                if th.in_trie(trie, currentTrace):
                    th.remove_from_trie(trie, currentTrace)

            if output in subtrie:
                # navigate trie
                subtrie = subtrie[output]
                currentTrace += (output,)
            else:
                # output is not a valid continuation: restart the process
                currentTrace, subtrie = (), trie
                self._teacher.reset()

        return observations

    def _nextOutput(self, pendingInputs):
        """Send ``pendingInputs`` (if any) and return the next SUT output.

        :return: pair ``(output, refused)``. ``refused`` is True only when
            inputs were sent and ``_processInputs`` returned None. With no
            pending inputs the teacher is asked for an output and ``refused``
            is False.
        """
        if pendingInputs != ():
            output = self._processInputs(pendingInputs)
            return output, output is None
        return self._teacher.output(), False

    def _recordObservations(self, observations):
        """Write outputs collected by ``_sampleOutputs`` into the table."""
        for trace, outputs in observations.items():
            # Only if trace is a prefix in S, then
            # add trace + output to row (S cdot L_delta)
            if self._table.isInS(trace):
                for output in outputs:
                    self._table.addOneLetterExtension(trace, output)

            # Update set of outputs for traces where deltas are removed
            for deltaTrace in self._table.getDeltaTraces(trace):
                for output in outputs:
                    self._table.updateEntry(deltaTrace, output=output)

            for output in outputs:
                self._table.updateEntry(trace, output=output)

    def _processInputs(self, consecutiveInputs):
        """Send a sequence of inputs to the SUT and return the next output.

        :param consecutiveInputs: tuple of inputs; ``()`` just asks the
            teacher for an output.
        :return: the output from ``teacher.oneOutput``, or None if the SUT
            did not accept an input (a warning is logged).
        """
        if consecutiveInputs == ():
            return self._teacher.output()
        output = self._teacher.oneOutput(consecutiveInputs)
        if output is None:
            # SUT did not accept an input.
            self._logger.warning("SUT did not accept input in %s",
                                 consecutiveInputs)
        return output

    def stabilizeTable(self):
        """Make the table globally closed and consistent.

        While the table is not closed, promote the offending rows, add their
        one letter extensions and update the table. Once it is closed, add a
        column for any inconsistency. Repeat until neither check reports
        anything.
        """
        closingRows = self._table.isNotGloballyClosed()
        consistentCheck = self._table.isNotGloballyConsistent()
        while closingRows or consistentCheck:
            while closingRows:
                self._logger.debug("Table is not closed")
                self._logger.debug(closingRows)
                self._table.promote(closingRows)
                # After promoting one should check if some one letter
                # extensions should also be added
                self._table.addOneLetterExtensions(closingRows)
                if self._logger.isEnabledFor(logging.DEBUG):
                    self._table.printTable(prefix="_c_")
                self.updateTable()
                closingRows = self._table.isNotGloballyClosed()
                consistentCheck = self._table.isNotGloballyConsistent()

            # Table is closed, check for consistency
            if consistentCheck:
                self._logger.debug("Table is not consistent")
                self._logger.debug(consistentCheck)
                self._table.addColumn(consistentCheck, force=True)
                if self._logger.isEnabledFor(logging.DEBUG):
                    self._table.printTable(prefix="_i_")
                # TODO: is an update needed here? In theory, when an
                # inconsistency is found, adding a column immediately makes
                # the interesting row non-closed, so no update is needed.
                closingRows = self._table.isNotGloballyClosed()
                consistentCheck = self._table.isNotGloballyConsistent()

    def getHypothesis(self, chaos=False):
        """Build a hypothesis automaton from the current table.

        :param chaos: build the chaotic hypothesis (hPlus) instead of hMinus.
        :return: a ``SuspensionAutomaton`` with one state per equivalence
            class, or None (with an error logged) if the table is not
            globally closed.
        """
        if self._table.isNotGloballyClosed():
            self._logger.error("Tried to get hypotheses with table not "
                               "closed or not consistent")
            return None

        # Get equivalence classes
        rows = self._table.getEquivalenceClasses(chaos)
        hyp = SuspensionAutomaton(len(rows),
                                  self._teacher.getInputAlphabet().copy(),
                                  self._teacher.getOutputAlphabet().copy(),
                                  self._teacher.getQuiescence(),
                                  chaos)

        # Assign a state number to each equivalence class; the class of the
        # empty trace gets state 0.
        assignments = {(): 0}
        count = 1
        for row in rows:
            if row != ():
                assignments[row] = count
                count += 1

        # add transitions
        for row in rows:
            # TODO reduce allLabels set, use outputExpert
            enabledOutputs = self._table.getOutputs(row)
            for label in self._getAllLabels(row, enabledOutputs):
                # create row and search it in the table
                extension = row + (label,)
                if self._table.isInRows(extension):
                    for target in rows:
                        # TODO: _moreSpecificRow is private in Table. Change
                        # it to public, or add a public version.
                        if self._table._moreSpecificRow(extension, target,
                                                        chaos):
                            hyp.addTransition(assignments[row], label,
                                              assignments[target])
                            break
                    else:
                        self._logger.warning("Chaotic behaviour")
                        # Either the table is not closed, or the first
                        # column of extension has an empty set.
                        # If label is an input, send it to ChaosDelta
                        #   - unless it is not enabled.
                        # If it is an output, send it to Chaos.
                        # It cannot be quiescence.
                        if label in self._teacher.getInputAlphabet():
                            hyp.addTransition(assignments[row], label,
                                              hyp.getChaosDelta())
                        elif label in self._table.getPossibleOutputs(row):
                            hyp.addTransition(assignments[row], label,
                                              hyp.getChaos())
                elif (row and
                      row[-1] == hyp.getQuiescence() and
                      label == hyp.getQuiescence()):
                    # fixing issue #2: non existing row because it ends with
                    # a sequence of quiescence
                    hyp.addTransition(assignments[row], label,
                                      assignments[row])
                elif chaos and label in self._table.getPossibleOutputs(row):
                    self._logger.warning("Chaotic behaviour 2")
                    # Add transitions to chaotic state if necessary
                    if row in self._table.getObservableTraces():
                        if label != hyp.getQuiescence():
                            hyp.addTransition(assignments[row], label,
                                              hyp.getChaos())
                        else:
                            hyp.addTransition(assignments[row], label,
                                              hyp.getChaosDelta())
        return hyp

    def _getAllLabels(self, row, outputs):
        """Return the labels enabled after ``row``.

        The labels are the inputs and outputs enabled by the purposes (or the
        full alphabets when a purpose is None), plus quiescence.
        """
        enabledInputs = self._teacher.getInputAlphabet()
        enabledOutputs = self._teacher.getOutputAlphabet()
        if self._inputPurpose is not None:
            enabledInputs = self._inputPurpose.getEnabled(row, outputs)
        if self._outputPurpose is not None:
            enabledOutputs = self._outputPurpose.getEnabled(row, outputs)
        return enabledInputs.union(enabledOutputs,
                                   {self._teacher.getQuiescence()})

    def generateDOT(self, path=None, hyp="both", format="pdf"):
        """Generate DOT files for the current hypotheses.

        Files are written to ``<path>/hypotheses/loop_<n>/hMinus`` and
        ``.../hPlus``, where ``n`` is the current loop number.

        :param path: base directory; defaults to the directory of this module.
        :param hyp: "hMinus", "hPlus" or "both"; any other value writes
            nothing.
        :param format: forwarded to ``gh.createDOTFile``.
        """
        if path is None:
            path = os.path.dirname(os.path.abspath(
                inspect.getfile(inspect.currentframe())))
        loopDir = os.path.join(path, "hypotheses",
                               "loop_" + str(self._currentLoop))
        if hyp in ("hMinus", "both"):
            gh.createDOTFile(self._hMinus, os.path.join(loopDir, "hMinus"),
                             format)
        if hyp in ("hPlus", "both"):
            gh.createDOTFile(self._hPlus, os.path.join(loopDir, "hPlus"),
                             format)

    def run(self):
        """Run the learning process and return ``(hMinus, hPlus)``.

        Each loop fills the table, makes it closed, consistent and quiescence
        reducible, and builds both hypotheses. Learning stops when both the
        requested table and model (hPlus) preciseness are reached, or after
        ``maxLoops`` consecutive loops without a change in hPlus preciseness.
        Otherwise the table is improved (see ``_improveTable``).
        """
        noEffectLoops = 0
        oldModelPreciseness = 0

        # While we have not reached our limit for loops with no effect on
        # the table
        while noEffectLoops < self._noEffectLimit:
            self._currentLoop += 1
            self._logger.info("Learning loop number " + str(self._currentLoop))

            # Fill the table and make it closed and consistent
            self.updateTable()  # TODO learning sometimes stops here!
            self.stabilizeTable()

            # Is the table quiescence reducible? If not make it so and
            # then fill it again, make it closed and consistent
            newSuffixes = self._table.isNotQuiescenceReducible()
            while newSuffixes or not self._table.isStable():
                if newSuffixes:
                    self._logger.debug("Table is not quiescence reducible")
                    self._logger.debug(newSuffixes)
                    if self._table.addColumn(newSuffixes, force=True):
                        if self._logger.isEnabledFor(logging.DEBUG):
                            self._table.printTable(prefix="_Q_")
                        self.updateTable()
                self.stabilizeTable()
                newSuffixes = self._table.isNotQuiescenceReducible()

            self._hMinus = self.getHypothesis()
            self._hPlus = self.getHypothesis(chaos=True)

            if self._printPath is not None:
                self.generateDOT(path=self._printPath)
                self._table.printTable(self._printPath,
                                       prefix="loop_" + str(self._currentLoop))

            # Table preciseness increases easily: adding rows and columns is
            # enough, independently of nondeterministic choices. Model
            # preciseness increases only if something really changed in the
            # behaviour of the system; if it stays the same for too long we
            # stop learning.
            currentModelPreciseness = self._hPlus.preciseness()
            if oldModelPreciseness == currentModelPreciseness:
                # TODO: the two values could be equal even if the behaviour
                # of the system changed. This is considered rare enough not
                # to be handled for the moment.
                noEffectLoops += 1
            else:
                oldModelPreciseness = currentModelPreciseness
                noEffectLoops = 0

            self._logPreciseness(currentModelPreciseness)
            if not (self._table.preciseness() < self._tablePreciseness or
                    currentModelPreciseness < self._modelPreciseness):
                self._logger.info("Stop learning")
                # Exit while loop, return hMinus and hPlus
                break

            self._improveTable(currentModelPreciseness)

        if self._logger.isEnabledFor(logging.DEBUG):
            self._table.printTable(prefix="_final_")
        return (self._hMinus, self._hPlus)

    def _logPreciseness(self, modelPreciseness):
        """Log requested vs. actual table and model preciseness."""
        self._logger.info("Requested table preciseness: " +
                          str(self._tablePreciseness) + ". Actual: " +
                          str(self._table.preciseness()))
        self._logger.info("Requested model preciseness: " +
                          str(self._modelPreciseness) + ". Actual: " +
                          str(modelPreciseness))

    def _improveTable(self, modelPreciseness):
        """Try to increase preciseness by extending the table, updating it,
        or testing hMinus for a counterexample.

        NOTE: a method is drawn at random but then forced to 'testing'.
        """
        samples = set(['extendTable', 'testing'])
        # if model preciseness is 1 (maximum), update does not make sense
        if modelPreciseness < 1:
            samples.add('update')

        # random.sample() rejects sets since Python 3.11, so draw from a
        # tuple; the draw is kept so the RNG sequence is unchanged.
        choice = random.choice(tuple(samples))
        # TODO: forcing testing: remove!!
        choice = 'testing'

        if choice == 'extendTable':
            self._logger.info("Improve preciseness method: extend the table")
            self._extendTable()
        elif choice == 'update':
            self._logger.info("Improve preciseness method: update the table")
            # TODO: consider removing this call, since updating is the first
            # thing done in the main loop. Kept for clarity.
            self.updateTable()
        elif choice == 'testing':
            self._logger.info("Improve preciseness method: testing")
            ce, output = self.tester.findCounterexample(self._hMinus)
            # Handle counterexample
            if ce:
                if not self._table.handleCounterexample(ce, output):
                    self._logger.info("Handling counterexample: failed")
                else:
                    self._logger.info("Handling counterexample: succeded")
            else:
                self._logger.info("No counterexample found, proceed")

    def _extendTable(self):
        """Add a random suffix-closed set of columns to the table.

        A random word of length ``getLengthColumns() + 1`` over inputs and
        outputs is built, and all of its non-empty suffixes are added.
        """
        # TODO: are there clever ways to create it?
        # Get maximum length in order to be sure to modify the Table
        length = self._table.getLengthColumns() + 1
        # we want to avoid quiescence for the moment TODO
        labels = tuple(self._teacher.getInputAlphabet().union(
            self._teacher.getOutputAlphabet()))
        columns = set()
        newColumn = ()
        for _ in range(length):
            newColumn = (random.choice(labels),) + newColumn
            columns.add(newColumn)
        self._logger.debug(newColumn)

        self._table.addColumn(columns, True)
        # TODO: promoting 10% of random rows was also considered (disabled);
        # increasing columns, as for a counterexample, may be enough.