Commit d8e2fd87 authored by Rick Smetsers's avatar Rick Smetsers
Browse files

DFA encoding

parent 9db3ffab
import collections
from z3gi.model.ra import Action
"""Test scenarios contain a list of traces together with the number of locations and registers of the SUT generating
these traces.
"""
DFATestCase = collections.namedtuple('DFATestCase', 'description traces nr_states')
# Define data
# 16 2
# 1 4 1 0 0 0
# 1 4 0 1 0 0
# 1 4 0 0 1 0
# 1 5 1 0 1 1 1
# 1 6 1 1 1 1 0 1
# 1 6 0 1 0 0 0 0
# 1 6 1 0 0 0 0 0
# 1 7 0 0 0 1 1 0 1
# 1 7 0 0 0 0 1 0 1
# 0 3 1 0 1
# 0 4 0 0 0 0
# 0 4 1 1 0 1
# 0 5 0 0 0 0 0
# 0 5 0 0 1 0 1
# 0 6 0 1 0 1 1 1
# 0 7 1 0 0 0 1 1 1
# store something and accept, as long as you give the stored value, accept, otherwise go back to start and reject
sut_m1 = DFATestCase("Abbadingo website example: This DFA accepts strings of 0's and 1's in which the number of 0's minus twice the number of 1's is either 1 or 3 more than a multiple of 5.",
[('1000', True),
('0100', True),
('0010', True),
('10111', True),
('111101', True),
('010000', True),
('100000', True),
('0001101', True),
('0000101', True),
('101', False),
('0000', False),
('1101', False),
('00000', False),
('00101', False),
('010111', False),
('1000111', False),
('', False)
], 5)
\ No newline at end of file
import unittest
from encode.fa import DFAEncoder
from tests.dfa_testscenario import *
from z3gi.learn.fa import FALearner
from z3gi.define.fa import DFA, MealyMachine
import model.fa
import z3
num_exp = 1
class DFALearnerTest(unittest.TestCase):
def setUp(self):
pass
def test_sut1(self):
self.check_scenario(sut_m1)
def check_scenario(self, test_scenario):
print("Scenario " + test_scenario.description)
for i in range(0, num_exp):
(succ, fa, model) = self.learn_model(test_scenario)
self.assertTrue(succ, msg="Register Automaton could not be inferred")
self.assertEqual(len(fa.states), test_scenario.nr_states,
"Wrong number of states.")
self.assertEqual(len(fa.states), test_scenario.nr_states,
"Wrong number of registers.")
exported = fa.export(model)
print("Learned model: \n", exported)
self.assertEqual(len(exported.states()), test_scenario.nr_states,
"Wrong number of states in exported model. ")
self.assertEqual(len(exported.registers()), test_scenario.nr_registers,
"Wrong number of registers in exported model. ")
self.check_ra_against_obs(exported, test_scenario)
def check_ra_against_obs(self, learned_fa, test_scenario):
"""Checks if the learned RA corresponds to the scenario observations"""
for trace, acc in test_scenario.traces:
self.assertEqual(learned_fa.accepts(trace), acc,
"Register automaton output doesn't correspond to observation {0}".format(str(trace)))
def learn_model(self, test_scenario):
labels = set()
for label, _ in test_scenario.traces:
labels.add(label)
learner = FALearner(list(labels), encoder=DFAEncoder(), verbose=True)
for trace in test_scenario.traces:
learner.add(trace)
min_states = test_scenario.nr_states - 1
max_states = test_scenario.nr_states + 1
result = learner._learn_model(min_states, max_states) #
return result
#
# """
# Visitor class for implementing procedures on inferred RAs.
# """
# class RaVisitor:
# def __init__(self):
# super().__init__()
# """
# Visits every location and transition in the register automaton.
# """
# def process(self, model, ra, labels, regs, states):
# to_visit = [ra.start]
# visited = []
# while (len(to_visit) > 0):
# loc = to_visit.pop(0)
# acc = model.eval(ra.output(loc))
# self._visit_location(loc, acc)
# visited.append(loc)
# next_trans = []
# for l in labels:
# for r in regs:
# guard_enabled = model.eval(ra.guard(loc, l, r))
# if guard_enabled:
# next_loc = model.eval(ra.transition(loc, l, r))
# next_asg = model.eval(ra.update(loc, l))
# next_trans.append((loc, l, r, next_asg, next_loc))
#
# for (start_loc, label, guard, asg, end_loc) in next_trans:
# self._visit_transition(start_loc, label, guard, asg, end_loc)
# if end_loc not in visited and end_loc not in to_visit:
# to_visit.append(end_loc)
# # we sort according to the location strings so we get them in order
# to_visit.sort(key=lambda loc: str(loc))
# """
# Visits states in the RA in lexographical order starting from the initial location.
# """
# def _visit_location(self, loc, acc):
# raise NotImplementedError()
# """
# Visits transitions in the RA.
# """
# def _visit_transition(self, start_loc, label, guard, asg, end_loc):
# raise NotImplementedError()
# class RaPrinter(RaVisitor):
# def __init__(self):
# super().__init__()
# """
# Prints location.
# """
# def _visit_location(self, loc, acc):
# print("{0}({1})".format(str(loc), "+" if acc == True else "-") )
# """
# Prints transition.
# """
# def _visit_transition(self, start_loc, label, guard, asg, end_loc):
# print("\t{0} -> {1}({2}) {3} {4}".format(str(start_loc), str(label), str(guard), str(asg), str(end_loc)))
# # TODO it should probably store states/regs as strings
# class SimpleRa():
# def __init__(self, states, loc_to_acc, loc_to_trans, registers):
# super().__init__()
# self.states = states
# self.loc_to_acc = loc_to_acc
# self.loc_to_trans = loc_to_trans
# self.register = registers
# def get_start_loc(self):
# return self.states[0]
# def get_states(self):
# return list(self.states)
# def get_transitions(self, loc, label=None):
# if label is None:
# return list(self.loc_to_trans[loc])
# else:
# return list([trans for trans in self.loc_to_trans[loc] if trans[1] == label])
# def get_registers(self):
# return list(self.registers)
# def get_acc(self, loc):
# return self.loc_to_acc[loc]
# class NoTransitionTriggeredException(Exception):
# pass
# class SimpleRaSimulator():
# def __init__(self, sra):
# super().__init__()
# self.ra = sra
# """
# Runs the given sequence of values on the RA.
# """
# def accepts(self, trace):
# init = -1
# reg_val = dict()
# for reg in self.ra.get_registers():
# reg_val[reg] = init
# loc = self.ra.get_start_loc()
# for act in trace:
# next_transitions = self.ra.get_transitions(loc, act.label)
# # to define a fresh guard we need to know which register guards are present
# active_regs = [trans[2] for trans in next_transitions]
# n_loc = None
# for (_, _, guard, asg, next_loc) in next_transitions:
# if (self._is_satisfied(act, guard, active_regs, reg_val)):
# if not is_fresh(asg):
# reg_val[asg] = act.value
# n_loc = next_loc
# break
# if n_loc is None:
# print("In location {0} with trans. {1}, \n reg vals {2} and crt val {3}".format(
# str(loc), str(next_transitions), str(reg_val), str(act.value)
# ))
# raise NoTransitionTriggeredException()
# else:
# loc = n_loc
# return self.ra.get_acc(loc)
# def _is_satisfied(self, act, guard, active_regs, reg_val):
# if is_fresh(guard):
# reg_vals = list([reg_val[reg] for reg in active_regs])
# return act.value not in reg_vals
# else:
# return act.value is reg_val[guard]
# """
# Builds a SRA from the inferred uninterpreted functions for the RA.
# """
# class SimpleRaBuilder(RaVisitor):
# def __init__(self):
# super().__init__()
# self.states = []
# self.loc_to_acc = dict()
# self.loc_to_trans = dict()
# self.registers = []
# def _visit_location(self, loc, acc):
# self.states.append(loc)
# self.loc_to_acc[loc] = acc
# if loc not in self.loc_to_trans:
# self.loc_to_trans[loc] = []
# def _visit_transition(self, start_loc, label, guard, asg, end_loc):
# self.loc_to_trans[start_loc].append((start_loc, label, guard, asg, end_loc))
# if not is_fresh(guard) and guard not in self.registers:
# self.registers.append(guard)
# if not is_fresh(asg) and asg not in self.registers:
# self.registers.append(asg)
# """
# Builds a SRA from the RA generated functions. It uses as states and registers the actual Z3 constants.
# """
# def build_ra(self):
# return SimpleRa(self.states, self.loc_to_acc, self.loc_to_trans, self.registers.sort(key=lambda reg: str(reg)))
#
#
# def is_fresh(reg):
# return str(reg) == str("fresh")
\ No newline at end of file
......@@ -11,10 +11,12 @@ class FSM(Automaton,metaclass=ABCMeta):
@abstractmethod
def __init__(self, num_states):
self.State, self.states = enum('State', ['state{0}'.format(i) for i in range(num_states)])
self.start = self.states[0]
class DFA(FSM):
def __init__(self, labels, num_states):
super.__init__(num_states)
super().__init__(num_states)
labels = list(labels)
self.Label, elements = enum('Label', labels)
self.labels = {labels[i]: elements[i] for i in range(len(labels))}
self.transition = z3.Function('transition', self.State, self.Label, self.State)
......@@ -22,13 +24,13 @@ class DFA(FSM):
def export(self, model : z3.ModelRef) -> model.fa.DFA:
builder = DFABuilder(self)
dfa = builder.build_dfa(self)
dfa = builder.build_dfa(model)
return dfa
class MealyMachine(FSM):
def __init__(self, input_labels, output_labels, num_states):
super.__init__(num_states)
super().__init__(num_states)
self.Input, elements = enum('Input', input_labels)
self.inputs = {input_labels[i]: elements[i] for i in range(len(input_labels))}
self.Output, elements = enum('Output', output_labels)
......@@ -45,7 +47,7 @@ class Mapper(object):
def __init__(self, fa):
self.Element = z3.DeclareSort('Element')
self.start = self.element(0)
self.map = z3.Function('map', self.Element, fa.Location)
self.map = z3.Function('map', self.Element, fa.State)
def element(self, name):
return z3.Const("n"+str(name), self.Element)
......@@ -79,18 +81,19 @@ class DFABuilder(object):
self.dfa = dfa
def build_dfa(self, m : z3.ModelRef) -> model.fa.DFA:
tr = FATranslator()
tr = FATranslator(self.dfa)
mut_dfa = model.fa.MutableDFA()
for state in self.dfa.states:
accepting = m.eval(self.dfa.output(state))
mut_dfa.add_state(tr.z3_to_state(state), tr.z3_to_bool(accepting))
for state in self.dfa.states:
for labels in self.dfa.labels:
for labels in self.dfa.labels.values():
to_state = m.eval(self.dfa.transition(state, labels))
trans = Transition(
tr.z3_to_state(state),
tr.z3_to_label(labels),
tr.z3_to_state(to_state))
mut_dfa.add_transition(state, trans)
return mut_dfa.to_immutable()
......
......@@ -8,7 +8,7 @@ import z3
class DFAEncoder(Encoder):
def __init__(self, labels):
def __init__(self):
self.tree = Tree(itertools.count(0))
self.cache = {}
self.labels = set()
......@@ -32,7 +32,8 @@ class DFAEncoder(Encoder):
def node_constraints(self, dfa, mapper):
constraints = []
for node, accept in self.cache:
for node in self.cache:
accept = self.cache[node]
n = mapper.element(node.id)
constraints.append(dfa.output(mapper.map(n)) == accept)
return constraints
......
......@@ -4,45 +4,43 @@ from encode.ra import RAEncoder
from learn import Learner
import model.fa
import z3
from encode.fa import DFAEncoder, MealyEncoder
from learn import Learner
import model.fa
class MealyLearner(Learner):
def __init__(self, labels, io=False, outputs=None, encoder=None, solver=None, verbose=False):
if not encoder:
encoder = RAEncoder()
class FALearner(Learner):
def __init__(self, labels, encoder, solver=None, verbose=False):
if not solver:
solver = z3.Solver()
if outputs:
self.outputs = outputs
self.labels = labels
self.encoder = encoder
self.solver = solver
self.encoder = encoder
self.verbose = verbose
self.io = io
def add(self, trace):
self.encoder.add(trace)
def model(self, min_states=1, max_states=20) -> model.fa.MealyMachine:
(succ, ra_def, m) = self._learn_model(min_states=min_states,
def model(self, min_states=1, max_states=20):
(succ, fa, m) = self._learn_model(min_states=min_states,
max_states=max_states)
if succ:
return ra_def.export(m)
return fa.export(m)
return None
def _learn_model(self, min_states=1, max_states=20):
"""generates the definition and model for an ra whose traces include the traces added so far"""
num_values = len(self.encoder.values)
for num_locations in range(min_states, max_states + 1):
ra, constraints = self.encoder.build(num_locations)
"""generates the definition and model for an fa whose traces include the traces added so far"""
for num_states in range(min_states, max_states + 1):
fa, constraints = self.encoder.build(num_states)
self.solver.add(constraints)
result = self.solver.check()
if self.verbose:
print("Learning with {0} states. Result: {1}"
.format(num_locations, result))
.format(num_states, result))
if result == z3.sat:
model = self.solver.model()
self.solver.reset()
return (True, ra, model)
return (True, fa, model)
else:
self.solver.reset()
# TODO: process the unsat core?
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment