Commit d8e2fd87 authored by Rick Smetsers's avatar Rick Smetsers
Browse files

DFA encoding

parent 9db3ffab
import collections
from z3gi.model.ra import Action
"""Test scenarios contain a list of traces together with the number of locations and registers of the SUT generating
these traces.
"""
DFATestCase = collections.namedtuple('DFATestCase', 'description traces nr_states')
# Define data
# 16 2
# 1 4 1 0 0 0
# 1 4 0 1 0 0
# 1 4 0 0 1 0
# 1 5 1 0 1 1 1
# 1 6 1 1 1 1 0 1
# 1 6 0 1 0 0 0 0
# 1 6 1 0 0 0 0 0
# 1 7 0 0 0 1 1 0 1
# 1 7 0 0 0 0 1 0 1
# 0 3 1 0 1
# 0 4 0 0 0 0
# 0 4 1 1 0 1
# 0 5 0 0 0 0 0
# 0 5 0 0 1 0 1
# 0 6 0 1 0 1 1 1
# 0 7 1 0 0 0 1 1 1
# store something and accept, as long as you give the stored value, accept, otherwise go back to start and reject
sut_m1 = DFATestCase("Abbadingo website example: This DFA accepts strings of 0's and 1's in which the number of 0's minus twice the number of 1's is either 1 or 3 more than a multiple of 5.",
[('1000', True),
('0100', True),
('0010', True),
('10111', True),
('111101', True),
('010000', True),
('100000', True),
('0001101', True),
('0000101', True),
('101', False),
('0000', False),
('1101', False),
('00000', False),
('00101', False),
('010111', False),
('1000111', False),
('', False)
], 5)
\ No newline at end of file
import unittest
from encode.fa import DFAEncoder
from tests.dfa_testscenario import *
from z3gi.learn.fa import FALearner
from z3gi.define.fa import DFA, MealyMachine
import model.fa
import z3
num_exp = 1
class DFALearnerTest(unittest.TestCase):
def setUp(self):
pass
def test_sut1(self):
self.check_scenario(sut_m1)
def check_scenario(self, test_scenario):
print("Scenario " + test_scenario.description)
for i in range(0, num_exp):
(succ, fa, model) = self.learn_model(test_scenario)
self.assertTrue(succ, msg="Register Automaton could not be inferred")
self.assertEqual(len(fa.states), test_scenario.nr_states,
"Wrong number of states.")
self.assertEqual(len(fa.states), test_scenario.nr_states,
"Wrong number of registers.")
exported = fa.export(model)
print("Learned model: \n", exported)
self.assertEqual(len(exported.states()), test_scenario.nr_states,
"Wrong number of states in exported model. ")
self.assertEqual(len(exported.registers()), test_scenario.nr_registers,
"Wrong number of registers in exported model. ")
self.check_ra_against_obs(exported, test_scenario)
def check_ra_against_obs(self, learned_fa, test_scenario):
"""Checks if the learned RA corresponds to the scenario observations"""
for trace, acc in test_scenario.traces:
self.assertEqual(learned_fa.accepts(trace), acc,
"Register automaton output doesn't correspond to observation {0}".format(str(trace)))
def learn_model(self, test_scenario):
labels = set()
for label, _ in test_scenario.traces:
labels.add(label)
learner = FALearner(list(labels), encoder=DFAEncoder(), verbose=True)
for trace in test_scenario.traces:
learner.add(trace)
min_states = test_scenario.nr_states - 1
max_states = test_scenario.nr_states + 1
result = learner._learn_model(min_states, max_states) #
return result
#
# """
# Visitor class for implementing procedures on inferred RAs.
# """
# class RaVisitor:
# def __init__(self):
# super().__init__()
# """
# Visits every location and transition in the register automaton.
# """
# def process(self, model, ra, labels, regs, states):
# to_visit = [ra.start]
# visited = []
# while (len(to_visit) > 0):
# loc = to_visit.pop(0)
# acc = model.eval(ra.output(loc))
# self._visit_location(loc, acc)
# visited.append(loc)
# next_trans = []
# for l in labels:
# for r in regs:
# guard_enabled = model.eval(ra.guard(loc, l, r))
# if guard_enabled:
# next_loc = model.eval(ra.transition(loc, l, r))
# next_asg = model.eval(ra.update(loc, l))
# next_trans.append((loc, l, r, next_asg, next_loc))
#
# for (start_loc, label, guard, asg, end_loc) in next_trans:
# self._visit_transition(start_loc, label, guard, asg, end_loc)
# if end_loc not in visited and end_loc not in to_visit:
# to_visit.append(end_loc)
# # we sort according to the location strings so we get them in order
# to_visit.sort(key=lambda loc: str(loc))
# """
# Visits states in the RA in lexographical order starting from the initial location.
# """
# def _visit_location(self, loc, acc):
# raise NotImplementedError()
# """
# Visits transitions in the RA.
# """
# def _visit_transition(self, start_loc, label, guard, asg, end_loc):
# raise NotImplementedError()
# class RaPrinter(RaVisitor):
# def __init__(self):
# super().__init__()
# """
# Prints location.
# """
# def _visit_location(self, loc, acc):
# print("{0}({1})".format(str(loc), "+" if acc == True else "-") )
# """
# Prints transition.
# """
# def _visit_transition(self, start_loc, label, guard, asg, end_loc):
# print("\t{0} -> {1}({2}) {3} {4}".format(str(start_loc), str(label), str(guard), str(asg), str(end_loc)))
# # TODO it should probably store states/regs as strings
# class SimpleRa():
# def __init__(self, states, loc_to_acc, loc_to_trans, registers):
# super().__init__()
# self.states = states
# self.loc_to_acc = loc_to_acc
# self.loc_to_trans = loc_to_trans
# self.register = registers
# def get_start_loc(self):
# return self.states[0]
# def get_states(self):
# return list(self.states)
# def get_transitions(self, loc, label=None):
# if label is None:
# return list(self.loc_to_trans[loc])
# else:
# return list([trans for trans in self.loc_to_trans[loc] if trans[1] == label])
# def get_registers(self):
# return list(self.registers)
# def get_acc(self, loc):
# return self.loc_to_acc[loc]
# class NoTransitionTriggeredException(Exception):
# pass
# class SimpleRaSimulator():
# def __init__(self, sra):
# super().__init__()
# self.ra = sra
# """
# Runs the given sequence of values on the RA.
# """
# def accepts(self, trace):
# init = -1
# reg_val = dict()
# for reg in self.ra.get_registers():
# reg_val[reg] = init
# loc = self.ra.get_start_loc()
# for act in trace:
# next_transitions = self.ra.get_transitions(loc, act.label)
# # to define a fresh guard we need to know which register guards are present
# active_regs = [trans[2] for trans in next_transitions]
# n_loc = None
# for (_, _, guard, asg, next_loc) in next_transitions:
# if (self._is_satisfied(act, guard, active_regs, reg_val)):
# if not is_fresh(asg):
# reg_val[asg] = act.value
# n_loc = next_loc
# break
# if n_loc is None:
# print("In location {0} with trans. {1}, \n reg vals {2} and crt val {3}".format(
# str(loc), str(next_transitions), str(reg_val), str(act.value)
# ))
# raise NoTransitionTriggeredException()
# else:
# loc = n_loc
# return self.ra.get_acc(loc)
# def _is_satisfied(self, act, guard, active_regs, reg_val):
# if is_fresh(guard):
# reg_vals = list([reg_val[reg] for reg in active_regs])
# return act.value not in reg_vals
# else:
# return act.value is reg_val[guard]
# """
# Builds a SRA from the inferred uninterpreted functions for the RA.
# """
# class SimpleRaBuilder(RaVisitor):
# def __init__(self):
# super().__init__()
# self.states = []
# self.loc_to_acc = dict()
# self.loc_to_trans = dict()
# self.registers = []
# def _visit_location(self, loc, acc):
# self.states.append(loc)
# self.loc_to_acc[loc] = acc
# if loc not in self.loc_to_trans:
# self.loc_to_trans[loc] = []
# def _visit_transition(self, start_loc, label, guard, asg, end_loc):
# self.loc_to_trans[start_loc].append((start_loc, label, guard, asg, end_loc))
# if not is_fresh(guard) and guard not in self.registers:
# self.registers.append(guard)
# if not is_fresh(asg) and asg not in self.registers:
# self.registers.append(asg)
# """
# Builds a SRA from the RA generated functions. It uses as states and registers the actual Z3 constants.
# """
# def build_ra(self):
# return SimpleRa(self.states, self.loc_to_acc, self.loc_to_trans, self.registers.sort(key=lambda reg: str(reg)))
#
#
# def is_fresh(reg):
# return str(reg) == str("fresh")
\ No newline at end of file
...@@ -11,10 +11,12 @@ class FSM(Automaton,metaclass=ABCMeta): ...@@ -11,10 +11,12 @@ class FSM(Automaton,metaclass=ABCMeta):
@abstractmethod @abstractmethod
def __init__(self, num_states): def __init__(self, num_states):
self.State, self.states = enum('State', ['state{0}'.format(i) for i in range(num_states)]) self.State, self.states = enum('State', ['state{0}'.format(i) for i in range(num_states)])
self.start = self.states[0]
class DFA(FSM): class DFA(FSM):
def __init__(self, labels, num_states): def __init__(self, labels, num_states):
super.__init__(num_states) super().__init__(num_states)
labels = list(labels)
self.Label, elements = enum('Label', labels) self.Label, elements = enum('Label', labels)
self.labels = {labels[i]: elements[i] for i in range(len(labels))} self.labels = {labels[i]: elements[i] for i in range(len(labels))}
self.transition = z3.Function('transition', self.State, self.Label, self.State) self.transition = z3.Function('transition', self.State, self.Label, self.State)
...@@ -22,13 +24,13 @@ class DFA(FSM): ...@@ -22,13 +24,13 @@ class DFA(FSM):
def export(self, model : z3.ModelRef) -> model.fa.DFA: def export(self, model : z3.ModelRef) -> model.fa.DFA:
builder = DFABuilder(self) builder = DFABuilder(self)
dfa = builder.build_dfa(self) dfa = builder.build_dfa(model)
return dfa return dfa
class MealyMachine(FSM): class MealyMachine(FSM):
def __init__(self, input_labels, output_labels, num_states): def __init__(self, input_labels, output_labels, num_states):
super.__init__(num_states) super().__init__(num_states)
self.Input, elements = enum('Input', input_labels) self.Input, elements = enum('Input', input_labels)
self.inputs = {input_labels[i]: elements[i] for i in range(len(input_labels))} self.inputs = {input_labels[i]: elements[i] for i in range(len(input_labels))}
self.Output, elements = enum('Output', output_labels) self.Output, elements = enum('Output', output_labels)
...@@ -45,7 +47,7 @@ class Mapper(object): ...@@ -45,7 +47,7 @@ class Mapper(object):
def __init__(self, fa): def __init__(self, fa):
self.Element = z3.DeclareSort('Element') self.Element = z3.DeclareSort('Element')
self.start = self.element(0) self.start = self.element(0)
self.map = z3.Function('map', self.Element, fa.Location) self.map = z3.Function('map', self.Element, fa.State)
def element(self, name): def element(self, name):
return z3.Const("n"+str(name), self.Element) return z3.Const("n"+str(name), self.Element)
...@@ -79,18 +81,19 @@ class DFABuilder(object): ...@@ -79,18 +81,19 @@ class DFABuilder(object):
self.dfa = dfa self.dfa = dfa
def build_dfa(self, m : z3.ModelRef) -> model.fa.DFA: def build_dfa(self, m : z3.ModelRef) -> model.fa.DFA:
tr = FATranslator() tr = FATranslator(self.dfa)
mut_dfa = model.fa.MutableDFA() mut_dfa = model.fa.MutableDFA()
for state in self.dfa.states: for state in self.dfa.states:
accepting = m.eval(self.dfa.output(state)) accepting = m.eval(self.dfa.output(state))
mut_dfa.add_state(tr.z3_to_state(state), tr.z3_to_bool(accepting)) mut_dfa.add_state(tr.z3_to_state(state), tr.z3_to_bool(accepting))
for state in self.dfa.states: for state in self.dfa.states:
for labels in self.dfa.labels: for labels in self.dfa.labels.values():
to_state = m.eval(self.dfa.transition(state, labels)) to_state = m.eval(self.dfa.transition(state, labels))
trans = Transition( trans = Transition(
tr.z3_to_state(state), tr.z3_to_state(state),
tr.z3_to_label(labels), tr.z3_to_label(labels),
tr.z3_to_state(to_state)) tr.z3_to_state(to_state))
mut_dfa.add_transition(state, trans)
return mut_dfa.to_immutable() return mut_dfa.to_immutable()
......
...@@ -8,7 +8,7 @@ import z3 ...@@ -8,7 +8,7 @@ import z3
class DFAEncoder(Encoder): class DFAEncoder(Encoder):
def __init__(self, labels): def __init__(self):
self.tree = Tree(itertools.count(0)) self.tree = Tree(itertools.count(0))
self.cache = {} self.cache = {}
self.labels = set() self.labels = set()
...@@ -32,7 +32,8 @@ class DFAEncoder(Encoder): ...@@ -32,7 +32,8 @@ class DFAEncoder(Encoder):
def node_constraints(self, dfa, mapper): def node_constraints(self, dfa, mapper):
constraints = [] constraints = []
for node, accept in self.cache: for node in self.cache:
accept = self.cache[node]
n = mapper.element(node.id) n = mapper.element(node.id)
constraints.append(dfa.output(mapper.map(n)) == accept) constraints.append(dfa.output(mapper.map(n)) == accept)
return constraints return constraints
......
...@@ -4,45 +4,43 @@ from encode.ra import RAEncoder ...@@ -4,45 +4,43 @@ from encode.ra import RAEncoder
from learn import Learner from learn import Learner
import model.fa import model.fa
import z3
from encode.fa import DFAEncoder, MealyEncoder
from learn import Learner
import model.fa
class MealyLearner(Learner): class FALearner(Learner):
def __init__(self, labels, io=False, outputs=None, encoder=None, solver=None, verbose=False): def __init__(self, labels, encoder, solver=None, verbose=False):
if not encoder:
encoder = RAEncoder()
if not solver: if not solver:
solver = z3.Solver() solver = z3.Solver()
if outputs:
self.outputs = outputs
self.labels = labels
self.encoder = encoder
self.solver = solver self.solver = solver
self.encoder = encoder
self.verbose = verbose self.verbose = verbose
self.io = io
def add(self, trace): def add(self, trace):
self.encoder.add(trace) self.encoder.add(trace)
def model(self, min_states=1, max_states=20) -> model.fa.MealyMachine: def model(self, min_states=1, max_states=20):
(succ, ra_def, m) = self._learn_model(min_states=min_states, (succ, fa, m) = self._learn_model(min_states=min_states,
max_states=max_states) max_states=max_states)
if succ: if succ:
return ra_def.export(m) return fa.export(m)
return None return None
def _learn_model(self, min_states=1, max_states=20): def _learn_model(self, min_states=1, max_states=20):
"""generates the definition and model for an ra whose traces include the traces added so far""" """generates the definition and model for an fa whose traces include the traces added so far"""
num_values = len(self.encoder.values) for num_states in range(min_states, max_states + 1):
for num_locations in range(min_states, max_states + 1): fa, constraints = self.encoder.build(num_states)
ra, constraints = self.encoder.build(num_locations)
self.solver.add(constraints) self.solver.add(constraints)
result = self.solver.check() result = self.solver.check()
if self.verbose: if self.verbose:
print("Learning with {0} states. Result: {1}" print("Learning with {0} states. Result: {1}"
.format(num_locations, result)) .format(num_states, result))
if result == z3.sat: if result == z3.sat:
model = self.solver.model() model = self.solver.model()
self.solver.reset() self.solver.reset()
return (True, ra, model) return (True, fa, model)
else: else:
self.solver.reset() self.solver.reset()
# TODO: process the unsat core? # TODO: process the unsat core?
......
...@@ -15,7 +15,7 @@ class IOTransition(Transition): ...@@ -15,7 +15,7 @@ class IOTransition(Transition):
class DFA(Acceptor): class DFA(Acceptor):
def __init__(self, states, state_to_trans, state_to_acc): def __init__(self, states, state_to_trans, state_to_acc):
super().__init__(states, state_to_trans, state_to_acc) super().__init__(states, state_to_trans, state_to_acc)
def transitions(self, state: State, label:Label = None) -> List[Transition]: def transitions(self, state: State, label:Label = None) -> List[Transition]:
return super().transitions(state, label) return super().transitions(state, label)
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment