membership.rs 15.1 KB
Newer Older
1
use super::equivalence::equivalence_trait::CounterExample;
2
use crate::{
3
4
5
6
    ads::{
        traits::{AdaptiveDistinguishingSequence, AdsStatus},
        tree::threaded::Ads as OtADS,
    },
7
    automatadefs::{
8
        mealy::{InputSymbol, Mealy, OutputSymbol, State},
9
        traits::{FiniteStateMachine, ObservationTree},
10
11
12
13
14
15
16
17
18
19
20
    },
    learner::{
        apartness::{compute_witness, states_are_apart},
        obs_tree::array_tree::TreeErr,
    },
    sul::{simulator::Simulator, system_under_learning::SystemUnderLearning},
    util::{
        learning_config::{Rule2, Rule3},
        toolbox,
    },
};
Bharat Garhewal's avatar
Bharat Garhewal committed
21
use itertools::Itertools;
22
use rand::prelude::SliceRandom;
23
use rayon::iter::{IntoParallelRefIterator, ParallelIterator};
24
use std::collections::VecDeque;
Bharat's avatar
Bharat committed
25

26
pub struct Oracle<'a, T> {
27
    sul: &'a mut Simulator,
28
    _input_alphabet: Vec<InputSymbol>,
29
    obs_tree: T,
30
31
32
33
    rule2: Rule2,
    rule3: Rule3,
}

34
35
36
37
impl<'a, T> Oracle<'a, T>
where
    T: ObservationTree + Sync + Send,
{
38
    pub fn new(sul: &'a mut Simulator, rule2: Rule2, rule3: Rule3) -> Self {
39
        Self {
Bharat Garhewal's avatar
Bharat Garhewal committed
40
            _input_alphabet: sul.get_alphabet(),
41
            obs_tree: T::new(sul.get_alphabet().len()),
42
43
44
            sul,
            rule2,
            rule3,
45
46
47
48
49
50
51
        }
    }

    pub fn get_counts(&mut self) -> (usize, usize) {
        self.sul.get_counts()
    }

Bharat's avatar
Bharat committed
52
53
54
55
    /// Mutably borrow the underlying simulator.
    pub fn borrow_sul(&mut self) -> &mut Simulator {
        // Explicit reborrow of the stored `&mut Simulator` for the lifetime of `&mut self`.
        &mut *self.sul
    }

56
57
58
    /// Immutably borrow the underlying observation tree.
    pub fn borrow_tree(&self) -> &T {
        &self.obs_tree
59
60
    }

61
62
    /// Mutably borrow the underlying observation tree.
    pub fn borrow_mut_tree(&mut self) -> &mut T {
63
64
65
        &mut self.obs_tree
    }

66
    /// Identifies state `fs` amongst the basis candidates to just one basis candidate.
67
68
69
    pub fn identify_frontier(&mut self, fs: State, candidates: &mut Vec<State>) {
        candidates.retain(|&b| !states_are_apart(&self.obs_tree, fs, b));
        if candidates.len() < 2 {
70
71
72
            return;
        }
        let mut prefix = self.obs_tree.get_access_seq(fs);
73
        // let mut candidates = org_candidates;
74
75
        let (input_seq, output_seq) = match self.rule3 {
            Rule3::Ads => {
76
77
78
79
80
81
82
83
                if candidates.len() == 2 {
                    let mut wit = compute_witness(&self.obs_tree, candidates[0], candidates[1])
                        .expect("No witness found between basis states!");
                    let mut input_seq = prefix;
                    input_seq.append(&mut wit);
                    let output_seq = self.output_query(&input_seq);
                    (input_seq, output_seq)
                } else {
Bharat Garhewal's avatar
Bharat Garhewal committed
84
                    let suffix = &mut OtADS::new(&self.obs_tree, candidates);
85
86
                    self.adaptive_output_query(&mut prefix, None, suffix)
                }
87
88
89
90
            }
            Rule3::SepSeq => {
                let mut local_rng = rand::thread_rng();
                let mut basis_pair = candidates.choose_multiple(&mut local_rng, 2).copied();
91
92
93
94
95
96
97
98
                let (b1, b2) = (
                    basis_pair.next().expect("Safe"),
                    basis_pair.next().expect("Safe"),
                );
                let mut wit = compute_witness(&self.obs_tree, b1, b2)
                    .expect("No witness found between basis states!");
                let mut input_seq = prefix;
                input_seq.append(&mut wit);
99
                let output_seq = self.output_query(&input_seq);
100
                (input_seq, output_seq)
101
102
            }
        };
103
104
105
        let _ = self
            .obs_tree
            .insert_observation(None, &input_seq, &output_seq);
106
107
108
        candidates.retain(|&b| !states_are_apart(&self.obs_tree, fs, b));
        if candidates.len() > 1 {
            self.identify_frontier(fs, candidates);
109
        }
110
111
    }

112
113
    /// Explores the frontier for all basis states and returns a vector of the frontier states
    /// paired with their corresponding basis candidates.
114
    #[must_use]
115
    pub fn explore_frontier(&mut self, basis: &[State]) -> Vec<(State, Vec<State>)> {
116
        self.obs_tree
117
            .no_succ_defined(basis, true)
118
119
            .iter()
            .inspect(|(q, i)| log::debug!("{:?} has {:?} undefined", *q, *i))
120
            .map(|(q, i)| self._explore_frontier(*q, *i, basis))
121
            .inspect(|(q, cands)| log::debug!("Frontier {:?} has candidates {:?}", q, cands))
122
123
124
            .collect()
    }

125
    /// Explore the frontier for a single (state, input) pair.
126
    fn _explore_frontier(
127
128
129
130
        &mut self,
        q: State,
        i: InputSymbol,
        basis: &[State],
131
    ) -> (State, Vec<State>) {
132
133
        let mut access_q = self.obs_tree.get_access_seq(q);
        // let prefix = concat_slices(&[&access_q, &[i]]);
134
135
        let (input_seq, output_seq) = match self.rule2 {
            Rule2::Ads => {
136
                log::debug!("Constructing ADS for {:?} at {:?}", q, i);
137
                let suffix = &mut OtADS::new(&self.obs_tree, basis);
138
                self.adaptive_output_query(&mut access_q, Some(i), suffix)
139
140
            }
            Rule2::Nothing => {
141
142
                let mut prefix = access_q;
                prefix.push(i);
143
                let o_seq = self.output_query(&prefix);
144
                (prefix, o_seq)
145
146
147
            }
            Rule2::SepSeq => {
                let mut local_rng = rand::thread_rng();
148
149
                let mut wit = {
                    if basis.len() > 1 {
Bharat Garhewal's avatar
Bharat Garhewal committed
150
151
152
153
154
                        let (b1, b2) = basis
                            .choose_multiple(&mut local_rng, 2)
                            .copied()
                            .collect_tuple()
                            .expect("Safe");
155
156
157
158
159
160
161
162
                        compute_witness(&self.obs_tree, b1, b2).expect("Safe")
                    } else {
                        Vec::new()
                    }
                };
                let mut input_seq = access_q;
                input_seq.push(i);
                input_seq.append(&mut wit);
163
                let output_seq = self.output_query(&input_seq);
164
                (input_seq, output_seq)
165
166
            }
        };
167
168
169
170
171
        let _ = self
            .obs_tree
            .insert_observation(None, &input_seq, &output_seq);
        let fs = self.obs_tree.get_succ(q, &[i]).expect("Safe");
        log::debug!("Added frontier {:?}.", fs);
172
        let bs_not_sep = basis
173
            .par_iter()
174
            .copied()
175
            .filter(|&b| !states_are_apart(&self.obs_tree, fs, b))
176
177
178
179
            .collect();
        (fs, bs_not_sep)
    }

180
    /// Make a non-adaptive output query.
181
    pub fn output_query(&mut self, input_seq: &[InputSymbol]) -> Vec<OutputSymbol> {
Bharat Garhewal's avatar
Bharat Garhewal committed
182
183
        let out_seq = self
            .obs_tree
184
            .get_observation(None, input_seq)
Bharat Garhewal's avatar
Bharat Garhewal committed
185
186
187
            .unwrap_or_else(|| self.sul.trace(input_seq).to_vec());
        let _ = self.obs_tree.insert_observation(None, input_seq, &out_seq);
        out_seq.to_vec()
188
189
    }

190
    /// Make an adaptive output query of the form (prefix -> infix -> ADS).
191
    pub(crate) fn adaptive_output_query<ADS: AdaptiveDistinguishingSequence>(
192
        &mut self,
193
        prefix: &mut Vec<InputSymbol>,
194
        infix: Option<InputSymbol>,
195
196
197
198
199
200
        suffix: &mut ADS,
    ) -> (Vec<InputSymbol>, Vec<OutputSymbol>) {
        if let Some(i) = infix {
            prefix.push(i);
        }
        self._adaptive_output_query(prefix, suffix)
201
202
    }

203
204
    /// Make an adaptive query, where the infix has been moved into the prefix.
    fn _adaptive_output_query<ADS: AdaptiveDistinguishingSequence>(
205
        &mut self,
206
207
208
        prefix: &mut Vec<InputSymbol>,
        suffix: &mut ADS,
    ) -> (Vec<InputSymbol>, Vec<OutputSymbol>) {
209
210
        let tree_reply = self
            .obs_tree
211
212
            .get_succ(State::new(0), prefix)
            .and_then(|curr_state| self.answer_ads_from_tree(suffix, curr_state).ok());
213
        suffix.reset_to_root();
214
215
        if tree_reply.is_some() {
            unreachable!("ADS is not increasing the norm, we already knew this information.");
216
        }
217
218
219
220
221
222
223
224
        self.sul.reset();
        let prefix_out = self.sul.step(prefix);
        let (mut suffix_inputs, suffix_outputs) = self.sul_adaptive_query(suffix);
        let input_seq = prefix;
        input_seq.append(&mut suffix_inputs); //concat_slices(&[prefix, &suffix_inputs]);
        let output_seq = toolbox::concat_slices(&[&prefix_out, &suffix_outputs]);
        let _ = self.add_observation(input_seq, &output_seq);
        (input_seq.to_vec(), output_seq)
225
226
    }

227
228
    // Assuming the prefix has been sent to the SUL, perform the adaptive query.
    fn sul_adaptive_query<ADS: AdaptiveDistinguishingSequence>(
229
        &mut self,
230
        ads: &mut ADS,
231
    ) -> (Vec<InputSymbol>, Box<[OutputSymbol]>) {
232
233
234
235
        let mut inputs_sent = Vec::new();
        let mut outputs_received = Vec::new();
        let mut last_output = None;
        loop {
236
237
238
            let next_input = ads.next_input(last_output);
            if let Ok(next_input) = next_input {
                log::debug!("Next input: {:?}", next_input);
239
                inputs_sent.push(next_input);
240
                let o = self.sul.step(&[next_input])[0];
241
242
243
                last_output = Some(o);
                outputs_received.push(o);
            } else {
244
                log::debug!("Next input undefined.");
245
246
247
                break;
            }
        }
248
        (inputs_sent, outputs_received.into_boxed_slice())
249
250
    }

251
252
    pub fn ads_equiv_test(
        &mut self,
253
        ads: &mut super::equivalence::incomplete::impl_ads::AdsTree,
254
255
256
257
258
259
260
261
262
        prefix: &[InputSymbol],
        fsm: &Mealy,
    ) -> CounterExample {
        if let Some(tree_prefix_out) = self.obs_tree.get_observation(None, prefix) {
            let hyp_out = fsm.trace(prefix).1.to_vec();
            if tree_prefix_out != hyp_out {
                return Some((prefix.to_vec(), tree_prefix_out));
            }
            let ts = self.obs_tree.get_succ(State::new(0), prefix).expect("Safe");
263
            assert!(ads.next_input(None).is_ok());
264
265
266
            let ads_res = self.answer_ads_from_tree(ads, ts);
            match ads_res {
                Ok((ads_inputs, ads_outputs)) => {
267
                    println!("Tree had a reply, duplicate query!");
268
269
                    // assert!(self.check_consistency(fsm).is_none());
                    // println!("ADS: {:#?}", ads);
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
                    let mut input_seq = Vec::from(prefix);
                    input_seq.extend(ads_inputs.iter());
                    let mut output_seq = tree_prefix_out;
                    output_seq.extend(ads_outputs.iter());
                    let hyp_out = fsm.trace(&input_seq).1;
                    if output_seq != hyp_out.to_vec() {
                        return Some((input_seq, output_seq));
                    }
                }
                Err(_) => ads.reset_to_root(),
            }
        }

        self.sul.reset();
        let prefix_out = self.sul.step(prefix);
        let (mut hyp_curr, hyp_out) = fsm.trace(prefix);
286
        let hyp_prefix_state = hyp_curr;
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
        if prefix_out != hyp_out {
            log::info!("Found CE in the random prefix.");
            return Some((prefix.to_vec(), prefix_out.to_vec()));
        }
        let mut inputs_sent = Vec::from_iter(prefix.iter().copied());
        let mut outputs_received = Vec::from_iter(prefix_out.iter().copied());
        let mut hyp_out;
        let mut prev_output = None;
        loop {
            let next_input = ads.next_input(prev_output);
            match next_input {
                Ok(input) => {
                    log::debug!("Input sent: {:?}", input);
                    inputs_sent.push(input);
                    let o = self.sul.step(&[input])[0];
                    (hyp_curr, hyp_out) = fsm.step_from(hyp_curr, input);
                    log::debug!("Output Received: {:?}", o);
                    prev_output = Some(o);
                    outputs_received.push(o);
                    if o != hyp_out {
                        self.add_observation(&inputs_sent, &outputs_received);
                        return Some((inputs_sent, outputs_received));
                    }
                }
                Err(ads_err) => match ads_err {
                    AdsStatus::Done => {
313
314
315
316
317
318
319
320
                        // if !ads
                        //     .tree
                        //     .ref_at(ads.curr_idx)
                        //     .initial
                        //     .contains(&hyp_prefix_state)
                        // {
                        //     panic!("Probably not the state we wanted!");
                        // }
321
322
323
324
                        self.add_observation(&inputs_sent, &outputs_received);
                        return None;
                    }
                    AdsStatus::Unexpected => {
325
                        unreachable!("IADS root must always contain the hyp state to identify.");
326
327
328
329
330
                    }
                },
            }
        }
    }
Bharat's avatar
Bharat committed
331

332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
    /// Checks the observation tree against `hypothesis`.
    ///
    /// Walks the tree and the hypothesis in lockstep, breadth-first from
    /// `State::new(0)` in both. Returns the access sequence extended by the
    /// offending input, with the tree's outputs for it, at the first
    /// transition whose output differs; returns `None` if no difference is
    /// found.
    pub fn check_consistency(&self, hypothesis: &Mealy) -> CounterExample {
        let mut pending = VecDeque::from([(State::new(0), State::new(0))]);
        while let Some((q, r)) = pending.pop_front() {
            for idx in 0..hypothesis.input_alphabet().len() {
                let i = InputSymbol::from(idx as u16);
                let tree_step = self.obs_tree.get_out_succ(q, i);
                let (dest_hyp, out_hyp) = hypothesis.step_from(r, i);
                // Inputs without a transition in the tree are skipped.
                if let Some((out_tree, dest_tree)) = tree_step {
                    if out_hyp != out_tree {
                        let mut inputs = self.obs_tree.get_access_seq(q);
                        inputs.push(i);
                        let outputs = self.obs_tree.get_observation(None, &inputs).expect("Safe");
                        return Some((inputs, outputs));
                    }
                    pending.push_back((dest_tree, dest_hyp));
                }
            }
        }
        None
    }

    #[allow(clippy::type_complexity)]
361
    fn answer_ads_from_tree(
362
        &self,
363
364
        ads: &mut impl AdaptiveDistinguishingSequence,
        from_state: State,
365
    ) -> Result<(Box<[InputSymbol]>, Box<[OutputSymbol]>), TreeErr> {
366
367
368
369
370
        let mut prev_output = None;
        let mut outputs_received = vec![];
        let mut inputs_sent = vec![];
        let mut curr_state = from_state;
        loop {
371
            let next_input = ads.next_input(prev_output);
372
373
374
375
            if let Ok(next_input) = next_input {
                inputs_sent.push(next_input);
                let (output, dest) = self
                    .obs_tree
376
                    .get_out_succ(curr_state, next_input)
377
378
379
380
381
382
383
384
385
386
387
388
389
390
                    .ok_or(TreeErr::AbsentEntry)?;
                prev_output = Some(output);
                outputs_received.push(output);
                curr_state = dest;
            } else {
                ads.reset_to_root();
                return Ok((
                    inputs_sent.into_boxed_slice(),
                    outputs_received.into_boxed_slice(),
                ));
            }
        }
    }

391
392
393
394
395
    pub fn add_observation(
        &mut self,
        input_seq: &[InputSymbol],
        output_seq: &[OutputSymbol],
    ) -> State {
396
397
        self.obs_tree
            .insert_observation(None, input_seq, output_seq)
398
399
    }
}