# ran.py -- z3 encoding of register-automaton constraints built from observed traces.
import itertools
import z3

from encode import Encoder
from define.ra import SimpleRegisterAutomaton


class RANEncoder(Encoder):
    def __init__(self):
        """Create an encoder with no recorded traces.

        Sets up the root of the observation trie (node ids are drawn from a
        counter that starts at 0), an empty mapping from trie nodes to their
        recorded accept value, and an empty set of observed data values.
        """
        node_ids = itertools.count()
        self.trie = RANEncoder.Trie(node_ids)
        self.cache = dict()
        self.values = set()

    def add(self, trace):
        """Record one observed trace.

        `trace` is a pair ``(seq, accept)``. The sequence is normalised with
        ``determinize`` and looked up (creating nodes as needed) in the trie;
        the node it ends in is stored in ``self.cache`` together with
        ``accept``. The ``value`` attribute of every action in the original,
        un-normalised sequence is added to ``self.values``.
        """
        actions, verdict = trace
        canonical = determinize(actions)
        leaf = self.trie[canonical]
        self.cache[leaf] = verdict
        # Collect first, then merge, so the set is only touched once.
        self.values |= {action.value for action in actions}

    def build(self, ra, initialized=True):
        """Assemble the full constraint list for the automaton description `ra`.

        A fresh ``Mapper`` is created for `ra`; the result is the concatenation
        of the general axioms, the output constraints and the transition
        constraints, in that order. `initialized` is forwarded to ``axioms``.
        """
        mapper = RANEncoder.Mapper(ra)
        general = self.axioms(ra, mapper, initialized)
        outputs = self.output_constraints(ra, mapper)
        transitions = self.transition_constraints(ra, mapper)
        return general + outputs + transitions

    def print_tree(self):
        """Write the textual rendering of the observation trie to stdout."""
        rendering = str(self.trie)
        print(rendering)

    @staticmethod
    def axioms(ra: SimpleRegisterAutomaton, mapper, initialized):
        """Return the trace-independent axioms of the encoding.

        Args:
            ra: automaton description; this method reads its sorts ``Label``,
                ``Location`` and ``Register``, its constants ``start`` and
                ``fresh``, and calls ``transition``, ``update`` and ``used``.
            mapper: provides ``valuation``, ``start`` and ``init``
                (see ``RANEncoder.Mapper``).
            initialized: accepted for interface compatibility; none of the
                axioms below consults it.

        Returns:
            A list of z3 formulas.

        NOTE: the ``== True`` / ``== False`` comparisons are kept as written.
        On z3 terms they build formulas; do not replace them with Python
        truthiness tests.
        """
        # Bound variables shared by the quantified formulas below.
        l = z3.Const('l', ra.Label)
        q = z3.Const('q', ra.Location)
        r, rp = z3.Consts('r rp', ra.Register)
        axioms = [
            # In the start state of the mapper,
            # all registers contain an uninitialized value.
            z3.ForAll(
                [r],
                mapper.valuation(mapper.start, r) == mapper.init
            ),

            # If the fresh transition and the transition over register r lead
            # to the same location, the update on that label must target a
            # register other than fresh and other than r (otherwise the two
            # transitions should have been merged).
            z3.ForAll(
                [q, l, r],
                z3.Implies(
                    z3.And(
                        r != ra.fresh,
                        ra.transition(q, l, ra.fresh) == ra.transition(q, l, r),
                    ),
                    z3.And(
                        ra.update(q, l) != ra.fresh,
                        ra.update(q, l) != r
                    )
                )
            ),

            # The fresh register is never used.
            z3.ForAll(
                [q],
                ra.used(q, ra.fresh) == False
            ),

            # If a register is used after a transition, it was either already
            # used before the transition, or it was assigned on it (the
            # transition was the fresh one and the update targets r).
            z3.ForAll(
                [q, l, r, rp],
                z3.Implies(
                    z3.And(
                        ra.used(ra.transition(q, l, rp), r) == True
                    ),
                    z3.Or(
                        ra.used(q, r) == True,
                        z3.And(
                            ra.update(q, l) == r,
                            rp == ra.fresh
                        ),
                    )
                )
            ),

            # If a non-fresh register is updated, it is used in the location
            # reached by the fresh transition.
            z3.ForAll(
                [q, l, r],
                z3.Implies(
                    z3.And(
                        r != ra.fresh,
                        ra.update(q, l) == r
                    ),
                    ra.used(ra.transition(q, l, ra.fresh), r) == True
                )
            ),

            # Registers are not used in the start state.
            z3.ForAll(
                [r],
                ra.used(ra.start, r) == False
            )
        ]

        return axioms

    def output_constraints(self, ra, mapper):
        """Return one output constraint per recorded trace.

        For every ``(node, accept)`` pair in ``self.cache`` the location that
        the node's element constant is mapped to must have output ``accept``.

        Args:
            ra: automaton description; only ``ra.output`` is called here.
            mapper: provides ``element`` and ``map``.

        Returns:
            A list of z3 formulas, in the iteration order of ``self.cache``.
        """
        # A stricter variant, disabled in the original code, additionally
        # required any two distinct non-fresh registers to hold different
        # values in every recorded node. It is intentionally not emitted.
        return [
            ra.output(mapper.map(mapper.element(node.id))) == accept
            for node, accept in self.cache.items()
        ]

    def transition_constraints(self, ra, mapper):
        """Return the constraints that tie the trie to the automaton's transitions.

        The result starts with the requirement that the trie root is mapped to
        ``ra.start``. Five formulas follow for every trie transition
        ``node --(label, value)--> child``. The last entry is a ``Distinct``
        constraint over ``mapper.init`` and all value constants seen on trie
        transitions.

        Args:
            ra: automaton description; this method reads ``start``, ``fresh``,
                ``Register`` and ``labels[label]``, and calls ``transition``,
                ``update`` and ``used``.
            mapper: provides ``element``, ``value``, ``map``, ``valuation``,
                ``start`` and ``init``.

        Returns:
            A list of z3 formulas.
        """
        # Bound variables; the same z3 constants serve every quantifier below.
        r = z3.Const('r', ra.Register)
        rp = z3.Const('rp', ra.Register)

        constraints = [ra.start == mapper.map(mapper.start)]
        values = {mapper.init}
        for node, label, value, child in self.trie.transitions():
            n = mapper.element(node.id)
            l = ra.labels[label]
            v = mapper.value(value)
            c = mapper.element(child.id)
            # Locations that the two trie nodes are mapped to.
            src = mapper.map(n)
            dst = mapper.map(c)

            constraints.extend([
                # If the transition is over a register, then the register is in use.
                z3.ForAll(
                    [r],
                    z3.Implies(
                        z3.And(
                            r != ra.fresh,
                            ra.transition(src, l, r) == dst
                        ),
                        ra.used(src, r) == True
                    )
                ),

                # If the value of a non-fresh register differs between node and
                # child, that register is the one updated on this label.
                # NOTE(review): the original author left the question
                # "what if not used?" open here.
                z3.ForAll(
                    [r],
                    z3.Implies(
                        z3.And(
                            r != ra.fresh,
                            mapper.valuation(c, r) != mapper.valuation(n, r)),
                        ra.update(src, l) == r
                    )
                ),

                # If a used, non-fresh register keeps its value from node to
                # child, it is not the register updated on this label.
                z3.ForAll(
                    [r],
                    z3.Implies(
                        z3.And(
                            r != ra.fresh,
                            mapper.valuation(c, r) == mapper.valuation(n, r),
                            ra.used(src, r) == True
                        ),
                        ra.update(src, l) != r
                    )
                ),

                # For every register r: if r is non-fresh, is the register
                # updated on this label, and the fresh transition leads to the
                # child's location, then some non-fresh register of the child
                # holds the new value v. In every other case (including
                # r == fresh) r keeps its value from node to child.
                # A stricter variant, disabled in the original code, required
                # the updated register r itself to hold v in the child.
                z3.ForAll(
                    [r],
                    z3.If(
                        z3.And(
                            r != ra.fresh,
                            ra.update(src, l) == r,
                            ra.transition(src, l, ra.fresh) == dst
                        ),
                        z3.Exists(
                            [rp],
                            z3.And(
                                rp != ra.fresh,
                                mapper.valuation(c, rp) == v
                            )
                        ),
                        mapper.valuation(c, r) == mapper.valuation(n, r)
                    )
                ),

                # Map to the right transition: if some non-fresh register of
                # the node holds v, then for each such register the child is
                # reached over that register when it is in use, and over the
                # fresh transition otherwise. If no register holds v, the child
                # is reached over the fresh transition.
                z3.If(
                    z3.Exists(
                        [r],
                        z3.And(
                            r != ra.fresh,
                            mapper.valuation(n, r) == v
                        )
                    ),
                    z3.ForAll(
                        [r],
                        z3.Implies(
                            z3.And(
                                r != ra.fresh,
                                mapper.valuation(n, r) == v
                            ),
                            z3.If(
                                ra.used(src, r) == True,
                                # it might not keep the valuation
                                ra.transition(src, l, r) == dst,
                                ra.transition(src, l, ra.fresh) == dst,
                            )
                        )
                    ),
                    ra.transition(src, l, ra.fresh) == dst),
            ])
            values.add(v)

        # All value constants, including the "uninitialized" one, are pairwise different.
        constraints.append(z3.Distinct(list(values)))
        return constraints

    class Trie(object):
        """Prefix tree keyed by ``(label, value)`` pairs.

        Every node takes its ``id`` from a shared iterator at construction
        time, so ids are unique across the whole tree and follow creation
        order.
        """

        def __init__(self, counter):
            """Create a node, consuming the next id from the iterator `counter`."""
            self.id = next(counter)
            # Kept so that descendants created later draw from the same id source.
            self.counter = counter
            # Maps (label, value) -> child node.
            self.children = {}

        def __getitem__(self, seq):
            """Return the node reached by following `seq`, creating missing nodes.

            `seq` is an iterable of ``(label, value)`` pairs; an empty
            sequence returns this node itself.
            """
            node = self
            for label, value in seq:
                key = (label, value)
                child = node.children.get(key)
                if child is None:
                    # Build the child with the node's own class instead of
                    # spelling out the enclosing class name.
                    child = node.children[key] = type(self)(self.counter)
                node = child
            return node

        def __iter__(self):
            """Yield this node, then all descendants in depth-first pre-order."""
            yield self
            for child in self.children.values():
                yield from child

        def transitions(self):
            """Yield ``(node, label, value, child)`` for every edge of the tree."""
            for node in self:
                for (label, value), child in node.children.items():
                    yield node, label, value, child

        def __str__(self, tabs=0):
            """Render the subtree as nested ``(n<id> ...)`` groups.

            Each child goes on its own line, prefixed by `tabs` tab characters
            and shown as `` <value> -> <child>``; labels are not printed.
            """
            indent = "\t" * tabs
            parts = ["(n{0}".format(self.id)]
            for (_, value), child in self.children.items():
                parts.append(indent + " {0} -> {1}".format(value, child.__str__(tabs=tabs + 1)))
            return "\n".join(parts) + ")"


    class Mapper(object):
        """z3 vocabulary that relates trie nodes to locations of an automaton.

        Declares two uninterpreted sorts, ``Value`` and ``Element``, and two
        functions over them: ``map`` (Element -> ``ra.Location``) and
        ``valuation`` (Element x ``ra.Register`` -> Value).
        """

        def __init__(self, ra):
            """Declare the sorts, constants and functions for the automaton `ra`.

            Only ``ra.Location`` and ``ra.Register`` are read from `ra`.
            """
            self.Value = z3.DeclareSort('Value')
            self.Element = z3.DeclareSort('Element')
            # Element constant "n0"; RANEncoder numbers its trie root 0.
            self.start = self.element(0)
            # Value constant "v_", used by the encoder as the uninitialized value.
            self.init = self.value("_")
            self.map = z3.Function('map', self.Element, ra.Location)
            self.valuation = z3.Function('valuation', self.Element, ra.Register, self.Value)

        def value(self, name):
            """Return the ``Value`` constant named ``"v" + str(name)``."""
            return z3.Const("v" + str(name), self.Value)

        def element(self, name):
            """Return the ``Element`` constant named ``"n" + str(name)``."""
            return z3.Const("n" + str(name), self.Element)


def determinize(seq):
    """Rename the data values of a sequence to canonical indices.

    Every distinct value is replaced by its rank (0, 1, 2, ...) in order of
    first occurrence; labels are left untouched. `seq` is traversed twice, so
    pass a list or tuple rather than a one-shot iterator.
    """
    first_seen = {}
    for _, value in seq:
        # The rank of a new value is the number of distinct values seen so far.
        first_seen.setdefault(value, len(first_seen))
    return [(label, first_seen[value]) for label, value in seq]