Commit 11813c0d authored by Paul Fiterau Brostean's avatar Paul Fiterau Brostean
Browse files

Generation kinda works now.

parent e13f2dd3
......@@ -19,15 +19,25 @@ class SUT(metaclass=ABCMeta):
"""Runs the list of inputs or input signatures comprising the input interface"""
ActionSignature = collections.namedtuple("ActionSignature", ('label', 'num_params'))
class RASUT(metaclass=ABCMeta):
def input_interface(self) -> List[ActionSignature]:
class ObjectSUT(SUT):
"""Wraps a"""
def run(self, seq:List[Action]):
"""Runs a sequence of inputs on the SUT and returns an observation"""
class ObjectSUT(RASUT):
"""Wraps around an object and calls methods on it corresponding to the Actions"""
def __init__(self, act_sigs, obj_gen):
self.obj_gen = obj_gen
self.acts = {act_sig.label:act_sig for act_sig in act_sigs}
def run(self, seq:List[object]):
def run(self, seq:List[Action]):
obj = self.obj_gen()
values = set()
out_seq = []
......@@ -40,8 +50,9 @@ class ObjectSUT(SUT):
outp = meth(val)
outp_action = self.parse_out(outp)
out_seq[:-1] = outp_action
return list(zip(seq, out_seq))
obs = list(zip(seq, out_seq))
return obs
def parse_out(self, outp) -> Action:
......@@ -51,11 +62,11 @@ class ObjectSUT(SUT):
if isinstance(outp, str):
return Action(outp, fresh)
if isinstance(outp, int):
return ("int", outp)
return Action("int", outp)
if isinstance(outp, tuple) and len(outp) == 2:
(lbl, val) = outp
if isinstance(val, int) and isinstance(lbl, str):
return outp
return Action(lbl, val)
raise Exception("Cannot process output")
def input_interface(self) -> List[ActionSignature]:
......@@ -10,7 +10,7 @@ class Stack():
def get(self):
if len(self.list) == 0:
return SUT.OK
return SUT.NOK
return ("OGET", self.list.pop())
......@@ -22,5 +22,5 @@ class Stack():
return SUT.NOK
def new_stack_sul(size):
return ObjectSUT(lambda : Stack(size), Stack.INTERFACE)
def new_stack_sut(size):
return ObjectSUT(Stack.INTERFACE, lambda : Stack(size))
from abc import ABCMeta, abstractmethod
from typing import List
from sut import SUT, ActionSignature
from typing import List, Tuple
from encode.iora import IORAEncoder
from learn.algorithm import learn
from learn.ra import RALearner
from model.ra import Action
from sut import SUT, ActionSignature, RASUT
# class RAObservation():
......@@ -9,6 +14,9 @@ from sut import SUT, ActionSignature
# def values(self):
# for
from sut.stack import new_stack_sut
from test import IORATest
class ObservationGeneration(metaclass=ABCMeta):
......@@ -17,15 +25,19 @@ class ObservationGeneration(metaclass=ABCMeta):
class ExhaustiveRAGenerator(ObservationGeneration):
def __init__(self, sut:SUT, act_sigs:List[ActionSignature]):
def __init__(self, sut:RASUT):
self.sut = sut
self.act_sigs = act_sigs
for sig in act_sigs:
self.act_sigs = sut.input_interface()
for sig in self.act_sigs:
if sig.num_params > 1:
raise Exception("This generator assumes at most one parameter per action")
def generate_observations(self, max_depth, max_registers=3) -> List[list]:
return self._generate_observations([0,[]], 0, max_depth, max_registers)
def generate_observations(self, max_depth, max_registers=3) -> List[Tuple[Action, Action]]:
val_obs = self._generate_observations([(0,[])], 0, max_depth, max_registers)
obs = [obs for (_, obs) in val_obs]
return obs
def _generate_observations(self, prev_obs, crt_depth, max_depth, max_values):
if crt_depth > max_depth:
......@@ -36,16 +48,24 @@ class ExhaustiveRAGenerator(ObservationGeneration):
for act_sig in self.act_sigs:
label = act_sig.label
if act_sig.num_params == 1:
for i in range(0, max(num_val, max_values)):
for i in range(0, min(num_val+1, max_values)):
seq = [inp for (inp, _) in obs]
seq.append((label, i))
new_obs[:-1] = (max(num_val, i),
seq.append(Action(label, i))
new_obs.append((max(num_val+1, i),
seq = [inp for (inp, _) in obs]
seq = seq.append((label, None))
new_obs[:-1] = (num_val,
seq.append(Action(label, None))
if crt_depth < max_depth:
new_obs.extend(self._generate_observations(new_obs, crt_depth+1, max_depth, max_values))
extended_obs = self._generate_observations(new_obs, crt_depth + 1, max_depth, max_values)
return new_obs
stack_sut = new_stack_sut(1)
gen = ExhaustiveRAGenerator(stack_sut)
obs = gen.generate_observations(2)
print("\n".join( [str(obs) for obs in obs]))
learner = RALearner(IORAEncoder())
learn(learner, IORATest, obs)
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment