Commit 6f7c6bac authored by Paul Fiterau Brostean's avatar Paul Fiterau Brostean
Browse files

Somewhat working exporter with correct fresh functions

parent 08a23a73
......@@ -24,7 +24,6 @@ class RaLearnerTest(unittest.TestCase):
def check_scenario(self, test_scenario : RaTestScenario):
print("Scenario " + test_scenario.description)
#result = self.learn_model(test_scenario)
(succ, ra, model) = self.learn_model(test_scenario)
self.assertTrue(succ, msg="Register Automaton could not be inferred")
......@@ -35,9 +34,9 @@ class RaLearnerTest(unittest.TestCase):
exported = ra.export(model)
print("Learned model: \n",exported)
self.assertEqual(len(exported.states()), test_scenario.nr_locations,
"Wrong number of locations in exported model. ")
self.assertEqual(len(exported.registers()), test_scenario.nr_locations,
"Wrong number of registers in exported model. ")
"Wrong number of locations in exported model. ")
self.assertEqual(len(exported.registers()), test_scenario.nr_registers,
"Wrong number of registers in exported model. ")
self.check_ra_against_obs(exported, test_scenario)
......@@ -45,7 +44,7 @@ class RaLearnerTest(unittest.TestCase):
"""Checks if the learned RA corresponds to the scenario observations"""
for trace, acc in test_scenario.traces:
self.assertEqual(learned_ra.accepts(trace), acc,
"Register automaton output doesn't correspond to observation {1}".format(str(trace)))
"Register automaton output doesn't correspond to observation {0}".format(str(trace)))
def learn_model(self, test_scenario : RaTestScenario) -> \
(bool, RegisterAutomaton, z3.ModelRef):
......
......@@ -42,7 +42,7 @@ class RALearner(Learner):
self.num_registers = num_registers
num_values = len(self.encoder.values)
for num_locations in range(max(self.num_locations, min_locations), max_locations + 1):
for num_registers in range(self.num_registers, max(self.num_registers, min(num_values, max_locations))):
for num_registers in range(self.num_registers, max(self.num_registers, min(num_values, num_locations))):
if self.io:
ra = IORegisterAutomaton(inputs=self.labels, outputs=self.outputs, num_locations=num_locations, num_registers=num_registers)
else:
......
......@@ -108,7 +108,7 @@ class Guard(metaclass=ABCMeta):
@abstractmethod
def get_registers(self):
def registers(self):
"""Returns the registers/constants over which the guard is formed"""
pass
......@@ -121,16 +121,16 @@ class EqualityGuard(Guard):
"""An equality guard holds iff. the parameter value is equal to the value assigned to its register."""
def __init__(self, register):
super().__init__()
self.register = register
self._register = register
def is_satisfied(self, valuation, value):
return valuation[self.register] == value
return valuation[self._register] == value
def get_registers(self):
return [self.register]
def registers(self):
return [self._register]
def __str__(self):
return "p=={0}".format(str(self.register))
return "p=={0}".format(str(self._register))
class OrGuard(Guard):
def __init__(self, guards):
......@@ -142,7 +142,7 @@ class OrGuard(Guard):
return True
return False
def get_registers(self):
def registers(self):
regs_of_guards = [guard.registers() for guard in self.guards]
regs = itertools.chain.from_iterable(regs_of_guards)
seen = set()
......@@ -151,30 +151,30 @@ class OrGuard(Guard):
def __str__(self):
all_guards = [str(guard) for guard in self.guards]
return "\\/".join(all_guards)
return " \\/ ".join(all_guards)
class FreshGuard(Guard):
"""An fresh guard holds if the parameter value is different from the value assigned to any of its registers."""
def __init__(self, guarded_registers = []):
super().__init__()
self.registers = guarded_registers
self._registers = guarded_registers
def is_satisfied(self, valuation, value):
for register in self.registers:
for register in self._registers:
if valuation[register] == value:
return False
return True
def get_registers(self):
return self.registers
def registers(self):
return self._registers
def __str__(self):
if len(self.registers) == 0:
if len(self._registers) == 0:
return "True"
else:
all_deq = ["p!={0}".format(str(reg)) for reg in self.registers]
return "/\\".join(all_deq)
all_deq = ["p!={0}".format(str(reg)) for reg in self._registers]
return " /\\ ".join(all_deq)
class Assignment(metaclass=ABCMeta):
"""An assignment updates the valuation of registers using the old valuation and the current parameter value"""
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment