Commit 9c09fa4c authored by Paul Fiterau Brostean's avatar Paul Fiterau Brostean
Browse files

Updates. Tests pass.

parent d8e2fd87
......@@ -3,9 +3,6 @@ import unittest
from encode.fa import DFAEncoder
from tests.dfa_testscenario import *
from z3gi.learn.fa import FALearner
from z3gi.define.fa import DFA, MealyMachine
import model.fa
import z3
num_exp = 1
......@@ -32,8 +29,6 @@ class DFALearnerTest(unittest.TestCase):
print("Learned model: \n", exported)
self.assertEqual(len(exported.states()), test_scenario.nr_states,
"Wrong number of states in exported model. ")
self.assertEqual(len(exported.registers()), test_scenario.nr_registers,
"Wrong number of registers in exported model. ")
self.check_ra_against_obs(exported, test_scenario)
def check_ra_against_obs(self, learned_fa, test_scenario):
......@@ -54,153 +49,4 @@ class DFALearnerTest(unittest.TestCase):
max_states = test_scenario.nr_states + 1
result = learner._learn_model(min_states, max_states) #
return result
#
# """
# Visitor class for implementing procedures on inferred RAs.
# """
# class RaVisitor:
# def __init__(self):
# super().__init__()
# """
# Visits every location and transition in the register automaton.
# """
# def process(self, model, ra, labels, regs, states):
# to_visit = [ra.start]
# visited = []
# while (len(to_visit) > 0):
# loc = to_visit.pop(0)
# acc = model.eval(ra.output(loc))
# self._visit_location(loc, acc)
# visited.append(loc)
# next_trans = []
# for l in labels:
# for r in regs:
# guard_enabled = model.eval(ra.guard(loc, l, r))
# if guard_enabled:
# next_loc = model.eval(ra.transition(loc, l, r))
# next_asg = model.eval(ra.update(loc, l))
# next_trans.append((loc, l, r, next_asg, next_loc))
#
# for (start_loc, label, guard, asg, end_loc) in next_trans:
# self._visit_transition(start_loc, label, guard, asg, end_loc)
# if end_loc not in visited and end_loc not in to_visit:
# to_visit.append(end_loc)
# # we sort according to the location strings so we get them in order
# to_visit.sort(key=lambda loc: str(loc))
# """
# Visits states in the RA in lexographical order starting from the initial location.
# """
# def _visit_location(self, loc, acc):
# raise NotImplementedError()
# """
# Visits transitions in the RA.
# """
# def _visit_transition(self, start_loc, label, guard, asg, end_loc):
# raise NotImplementedError()
# class RaPrinter(RaVisitor):
# def __init__(self):
# super().__init__()
# """
# Prints location.
# """
# def _visit_location(self, loc, acc):
# print("{0}({1})".format(str(loc), "+" if acc == True else "-") )
# """
# Prints transition.
# """
# def _visit_transition(self, start_loc, label, guard, asg, end_loc):
# print("\t{0} -> {1}({2}) {3} {4}".format(str(start_loc), str(label), str(guard), str(asg), str(end_loc)))
# # TODO it should probably store states/regs as strings
# class SimpleRa():
# def __init__(self, states, loc_to_acc, loc_to_trans, registers):
# super().__init__()
# self.states = states
# self.loc_to_acc = loc_to_acc
# self.loc_to_trans = loc_to_trans
# self.register = registers
# def get_start_loc(self):
# return self.states[0]
# def get_states(self):
# return list(self.states)
# def get_transitions(self, loc, label=None):
# if label is None:
# return list(self.loc_to_trans[loc])
# else:
# return list([trans for trans in self.loc_to_trans[loc] if trans[1] == label])
# def get_registers(self):
# return list(self.registers)
# def get_acc(self, loc):
# return self.loc_to_acc[loc]
# class NoTransitionTriggeredException(Exception):
# pass
# class SimpleRaSimulator():
# def __init__(self, sra):
# super().__init__()
# self.ra = sra
# """
# Runs the given sequence of values on the RA.
# """
# def accepts(self, trace):
# init = -1
# reg_val = dict()
# for reg in self.ra.get_registers():
# reg_val[reg] = init
# loc = self.ra.get_start_loc()
# for act in trace:
# next_transitions = self.ra.get_transitions(loc, act.label)
# # to define a fresh guard we need to know which register guards are present
# active_regs = [trans[2] for trans in next_transitions]
# n_loc = None
# for (_, _, guard, asg, next_loc) in next_transitions:
# if (self._is_satisfied(act, guard, active_regs, reg_val)):
# if not is_fresh(asg):
# reg_val[asg] = act.value
# n_loc = next_loc
# break
# if n_loc is None:
# print("In location {0} with trans. {1}, \n reg vals {2} and crt val {3}".format(
# str(loc), str(next_transitions), str(reg_val), str(act.value)
# ))
# raise NoTransitionTriggeredException()
# else:
# loc = n_loc
# return self.ra.get_acc(loc)
# def _is_satisfied(self, act, guard, active_regs, reg_val):
# if is_fresh(guard):
# reg_vals = list([reg_val[reg] for reg in active_regs])
# return act.value not in reg_vals
# else:
# return act.value is reg_val[guard]
# """
# Builds a SRA from the inferred uninterpreted functions for the RA.
# """
# class SimpleRaBuilder(RaVisitor):
# def __init__(self):
# super().__init__()
# self.states = []
# self.loc_to_acc = dict()
# self.loc_to_trans = dict()
# self.registers = []
# def _visit_location(self, loc, acc):
# self.states.append(loc)
# self.loc_to_acc[loc] = acc
# if loc not in self.loc_to_trans:
# self.loc_to_trans[loc] = []
# def _visit_transition(self, start_loc, label, guard, asg, end_loc):
# self.loc_to_trans[start_loc].append((start_loc, label, guard, asg, end_loc))
# if not is_fresh(guard) and guard not in self.registers:
# self.registers.append(guard)
# if not is_fresh(asg) and asg not in self.registers:
# self.registers.append(asg)
# """
# Builds a SRA from the RA generated functions. It uses as states and registers the actual Z3 constants.
# """
# def build_ra(self):
# return SimpleRa(self.states, self.loc_to_acc, self.loc_to_trans, self.registers.sort(key=lambda reg: str(reg)))
#
#
# def is_fresh(reg):
# return str(reg) == str("fresh")
\ No newline at end of file
return result
\ No newline at end of file
......@@ -93,7 +93,7 @@ class DFABuilder(object):
tr.z3_to_state(state),
tr.z3_to_label(labels),
tr.z3_to_state(to_state))
mut_dfa.add_transition(state, trans)
mut_dfa.add_transition(tr.z3_to_state(state), trans)
return mut_dfa.to_immutable()
......
......@@ -9,7 +9,7 @@ class Learner(metaclass=ABCMeta):
pass
@abstractmethod
def model(self) -> Automaton:
def model(self, old_model=None) -> Automaton:
""""Infers a minimal model whose behavior corresponds to the traces added so far.
Returns None if no model could be obtained."""
pass
......
......@@ -15,7 +15,7 @@ class ActiveLearner():
self.tester = model_tester
def learn(self, inputs:List[str]):
model = None
model = self._epsilon(inputs)
while True:
model = self._learn_new(self.learner, model)
trace = self.tester.find_ce(model)
......
......@@ -8,7 +8,7 @@ import z3
from encode.fa import DFAEncoder, MealyEncoder
from learn import Learner
import model.fa
from model import Automaton
class FALearner(Learner):
def __init__(self, labels, encoder, solver=None, verbose=False):
......@@ -21,7 +21,9 @@ class FALearner(Learner):
def add(self, trace):
self.encoder.add(trace)
def model(self, min_states=1, max_states=20):
def model(self, min_states=1, max_states=20, old_model:Automaton=None) -> Automaton:
if old_model is not None:
min_states = len(old_model.states())
(succ, fa, m) = self._learn_model(min_states=min_states,
max_states=max_states)
if succ:
......
......@@ -22,9 +22,15 @@ class RALearner(Learner):
def add(self, trace):
self.encoder.add(trace)
def model(self, min_locations=1, max_locations=20, num_registers=0) -> model.ra.RegisterAutomaton:
def model(self, min_locations=1, max_locations=20, min_registers=0, max_registers=10,
old_model:model.ra.RegisterAutomaton = None) -> model.ra.RegisterAutomaton:
if old_model is not None:
min_locations = len(old_model.states())
min_registers = len(old_model.registers())
(succ, ra_def, m) = self._learn_model(min_locations=min_locations,
max_locations=max_locations, num_registers=num_registers)
max_locations=max_locations, min_registers=min_registers,
max_registers=max_registers)
if succ:
return ra_def.export(m)
return None
......
from abc import ABCMeta, abstractmethod
from typing import List
"""The most basic transition class available"""
"""The most basic transition class available"""
class Transition():
def __init__(self, start_state, start_label, end_state):
self.start_state = start_state
......@@ -11,71 +11,81 @@ class Transition():
self.end_state = end_state
def __str__(self, shortened=False):
short = "{0} -> {1}".format(self.start_label, self.end_state)
if not shortened:
return "{0} {1} -> {2}".format(self.start_state, self.start_label, self.end_state)
return "{0} {1}".format(self.start_state, short)
else:
"{1} -> {2}".format(self.start_label, self.end_state)
return short
"""Exception raised when no transition can be fired"""
class NoTransitionFired(Exception):
pass
pass
"""Exception raised when several transitions can be fired in a deterministic machine"""
class MultipleTransitionsFired(Exception):
pass
"""A basic abstract automaton model"""
class Automaton(metaclass=ABCMeta):
def __init__(self, states, state_to_trans):
super().__init__()
self._states = states
self._state_to_trans = state_to_trans
def __init__(self, states, state_to_trans):
super().__init__()
self._states = states
self._state_to_trans = state_to_trans
def start_state(self):
return self._states[0]
def start_state(self):
return self._states[0]
def states(self):
return list(self._states)
def states(self):
return list(self._states)
def transitions(self, state, label=None) -> List[Transition]:
if label is None:
return list(self._state_to_trans[state])
else:
return list([trans for trans in self._state_to_trans[state] if trans.start_label == label])
def transitions(self, state, label=None) -> List[Transition]:
if label is None:
return list(self._state_to_trans[state])
else:
return list([trans for trans in self._state_to_trans[state] if trans.start_label == label])
def state(self, trace):
"""state function which also provides a basic implementation"""
crt_state = self.start_state()
for symbol in trace:
transitions = self.transitions(crt_state, symbol)
fired_transitions = [trans for trans in transitions if trans.start_label == symbol]
def state(self, trace):
"""state function which also provides a basic implementation"""
crt_state = self.start_state()
for symbol in trace:
transitions = self.transitions(crt_state, symbol)
fired_transitions = [trans for trans in transitions if trans.start_label == symbol]
# the number of fired transitions can be more than one since we could have multiple equalities
if len(fired_transitions) == 0:
raise NoTransitionFired
# the number of fired transitions can be more than one since we could have multiple equalities
if len(fired_transitions) == 0:
raise NoTransitionFired
if len(fired_transitions) > 1:
raise MultipleTransitionsFired
if len(fired_transitions) > 1:
raise MultipleTransitionsFired
fired_transition = fired_transitions[0]
crt_state = fired_transition.end_state
fired_transition = fired_transitions[0]
crt_state = fired_transition.end_state
return crt_state
return crt_state
@abstractmethod
def output(self, trace):
pass
@abstractmethod
def output(self, trace):
pass
# Basic __str__ function which works for most FSMs.
def __str__(self):
str_rep = ""
for state in self.states():
str_rep += str(state) + "\n"
for tran in self.transitions(state):
str_rep += "\t" + tran.__str__(shortened=True) + "\n"
# Basic __str__ function which works for most FSMs.
def __str__(self):
str_rep = ""
for state in self.states():
str_rep += str(state) + "\n"
for tran in self.transitions(state):
str_rep += "\t"+tran.__str__(shortened=True) + "\n"
return str_rep
return str_rep
class MutableAutomatonMixin(metaclass=ABCMeta):
def add_state(self, state):
......@@ -91,7 +101,10 @@ class MutableAutomatonMixin(metaclass=ABCMeta):
def to_immutable(self) -> Automaton:
pass
"""An automaton model that generates output"""
class Transducer(Automaton, metaclass=ABCMeta):
def __init__(self, states, state_to_trans):
super().__init__(states, state_to_trans)
......@@ -100,7 +113,10 @@ class Transducer(Automaton, metaclass=ABCMeta):
def output(self, trace):
pass
"""An automaton model whose states are accepting/rejecting"""
class Acceptor(Automaton, metaclass=ABCMeta):
def __init__(self, states, state_to_trans, state_to_acc):
super().__init__(states, state_to_trans)
......@@ -120,8 +136,8 @@ class Acceptor(Automaton, metaclass=ABCMeta):
def __str__(self):
return str(self._state_to_acc) + "\n" + super().__str__()
class MutableAcceptorMixin(MutableAutomatonMixin, metaclass=ABCMeta):
def add_state(self, state, accepts):
super().add_state(state)
self._state_to_acc[state] = accepts
......@@ -13,6 +13,13 @@ class IOTransition(Transition):
super().__init__(start_state, start_label, end_state)
self.output = output
def __str__(self, shortened=False):
short = "{0}/{1} -> {2}".format(self.start_label, self.output, self.end_state)
if not shortened:
return "{0} {1}".format(self.start_state, short)
else:
return short
class DFA(Acceptor):
def __init__(self, states, state_to_trans, state_to_acc):
super().__init__(states, state_to_trans, state_to_acc)
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment