Commit 1ef7d507 authored by Cas Visser's avatar Cas Visser
Browse files

added logic for using LY ADS based on threshold as well as framework for...

added logic for using LY ADS based on threshold as well as framework for recording identification power stats
parent 92dd12d2
#!/bin/bash
models_location="./experiment_models"
#models="model1.dot model4.dot TCP_Linux_Server.dot"
models="model3.dot BitVise.dot"
thresholds="never"
for threshold in $thresholds; do
for model in $models; do
for i in {1..20}; do
echo "Running iteration ${i} of ${model} with threshold ${threshold}."
cargo run -- -e hads -m "${models_location}/${model}" --ly_ads_threshold $threshold
sed -n '2p' ./results.csv | cat >> ./ip_results/total_results.csv
mv ./identification_power_log.csv "./ip_results/${model}-${threshold}-${i}.csv"
done
done
done
\ No newline at end of file
round,ot_ads_ip,ly_ads_ip,using_ly_ads
1,0,0,false
2,1,1,false
3,2,2,false
4,3.6,4,false
5,4.571429,6,true
6,6.25,7,false
7,6.6,9,true
8,8.568182,10,false
9,9.884615,12,true
10,11.92,14,false
11,11.973215,15,true
12,13.9578,16.88889,true
13,15.698246,18,false
14,15.826666,19,true
15,17.831932,19.904762,false
16,17.872461,21,false
17,17.989038,22,true
18,20.375,23,false
19,20.48,23.92,false
20,20.576921,25,true
21,22.79227,26,false
22,22.885092,27,false
23,22.971512,28,true
24,25.547619,29,false
25,25.636175,30,false
26,25.719072,31,true
27,28.206898,32,false
28,28.283978,33,false
This diff is collapsed.
name,learned,rounds,num_states,num_inputs,learn_inputs,learn_resets,test_inputs,test_resets,ads_score,ly_ads_threshold,times_ly_ads_used,ly_ads_inputs,ly_ads_resets,learning_algorithm
model4.dot,true,22,34,14,27988,3965,32686,1545,NaN,1.2,10,347,44,lsharp
......@@ -229,7 +229,7 @@ impl AdsTree {
u_i as f32,
u_i_o as f32,
u_i as f32,
0.99 * child_score,
child_score,
);
}
......
......@@ -13,6 +13,8 @@ use rayon::prelude::*;
use std::collections::VecDeque;
#[allow(unused_imports)]
use std::time::Instant;
use std::io::Write;
use std::fs::OpenOptions;
type Witness = Option<Box<InputWord>>;
pub struct LSharp<'a> {
......@@ -36,6 +38,10 @@ pub struct LSharp<'a> {
ads_id_power_rule3: Vec<f32>,
saved_ads: Option<LyADS>,
use_ly_ads: usize,
ly_ads_threshold: f32,
times_ly_ads_used: usize,
ly_ads_inputs: usize,
ly_ads_resets: usize,
}
impl<'a> LSharp<'a> {
......@@ -46,6 +52,7 @@ impl<'a> LSharp<'a> {
rule2_mode: Rule2,
rule3_mode: Rule3,
use_ly_ads: usize,
ly_ads_threshold: f32,
) -> Self {
let observation_tree = ObsTree::new(sul.get_alphabet());
Self {
......@@ -62,6 +69,10 @@ impl<'a> LSharp<'a> {
round: 0,
saved_ads: None,
use_ly_ads,
ly_ads_threshold,
times_ly_ads_used: 0,
ly_ads_inputs: 0,
ly_ads_resets: 0,
}
}
......@@ -587,10 +598,10 @@ impl<'a> LSharp<'a> {
ads: &mut (impl AdaptiveDistinguishingSequence + std::fmt::Debug),
from_state: Option<State>,
) -> (Box<InputWord>, Box<OutputWord>) {
log::info!(
"ADS : \n {}",
String::from_utf8(ads.get_print_tree().to_vec()).unwrap()
);
// log::info!(
// "ADS : \n {}",
// String::from_utf8(ads.get_print_tree().to_vec()).unwrap()
// );
let start_state = from_state.unwrap_or_else(|| State::from(0));
let access_seq = self.observation_tree.get_access_seq(start_state);
let _curr_state_sul = self.sul.step(&access_seq);
......@@ -833,8 +844,42 @@ impl<'a> LSharp<'a> {
self.make_obs_tree_adequate();
let mut ret = self.construct_hypothesis();
// Print statements for comparing identification power of LY ADS and OT ADS
let ly_ads = LyADS::new(&ret);
let ot_ads = OtADS::new(&self.observation_tree, &self.basis);
let use_ly_ads_now = ly_ads.identification_power() >= self.ly_ads_threshold * ot_ads.identification_power()
&& ly_ads.identification_power() > 0.0
&& self.ly_ads_threshold >= 0.0 ;
println!("LY ADS IP: {:.3}", ly_ads.identification_power());
println!("OT ADS IP: {:.3}", ot_ads.identification_power());
let mut ip_vec = Vec::new();
writeln!(
&mut ip_vec,
"{},{},{},{}",
self.round,
ot_ads.identification_power(),
ly_ads.identification_power(),
use_ly_ads_now
).expect("Could not write to the ip log vector!");
let file = "identification_power_log.csv";
let mut f = OpenOptions::new()
.write(true)
.append(true)
.open(file)
.unwrap_or_else(|file| panic!("Could not open file {}", file));
f.write_all(&ip_vec)
.unwrap_or_else(|file| panic!("Write exception when writing to file {}", file));
ip_vec.clear();
#[allow(clippy::collapsible_if)]
if self.use_ly_ads > 0 && self.round % self.use_ly_ads == 0 {
if use_ly_ads_now {
println!("Using LY ADS");
self.times_ly_ads_used += 1;
let (i_before, r_before) = self.sul.get_counts();
loop {
let (inconsistent_hyp, ly_ads) = self.check_hypothesis_with_ads(&ret);
if inconsistent_hyp {
......@@ -845,6 +890,10 @@ impl<'a> LSharp<'a> {
break;
}
}
let (i_after, r_after) = self.sul.get_counts();
self.ly_ads_inputs += i_after - i_before;
self.ly_ads_resets += r_after - r_before;
}
log::info!(
"ADS identification powers for rule 2 {:?}",
......@@ -855,13 +904,6 @@ impl<'a> LSharp<'a> {
self.ads_id_power_rule3
);
// Print statements for comparing identification power of LY ADS and OT ADS
// let ly_ads = LyADS::new(&ret);
// let ot_ads = OtADS::new(&self.observation_tree, &self.basis);
// println!("LY ADS: {:?}", ly_ads);
// println!("ADS IP: {:.3}", ly_ads.identification_power());
// println!("OT IP: {:.3}", ot_ads.identification_power());
ret
}
......@@ -957,6 +999,18 @@ impl<'a> LSharp<'a> {
}
(inconsistent, Some(ly_ads))
}
pub fn get_times_ly_ads_used(&self) -> usize {
self.times_ly_ads_used
}
pub fn get_ly_ads_inputs(&self) -> usize {
self.ly_ads_inputs
}
pub fn get_ly_ads_resets(&self) -> usize {
self.ly_ads_resets
}
}
fn concat_slices<T: Copy>(slices: &[&[T]]) -> Vec<T> {
......
......@@ -18,6 +18,7 @@ pub fn learn_fsm(
seed: i32,
logs: Option<Vec<(Box<[InputSymbol]>, Box<[OutputSymbol]>)>>,
use_ly_ads: usize,
ly_ads_threshold: f32,
) -> LearnResult {
let mealy_machine = Arc::new(sul.clone());
let mut sut = Simulator::new(Arc::clone(&mealy_machine));
......@@ -44,6 +45,7 @@ pub fn learn_fsm(
rule2_mode,
rule3_mode,
use_ly_ads,
ly_ads_threshold,
);
let mut idx = 1;
learner.init_obs_tree(logs);
......@@ -128,6 +130,9 @@ pub fn learn_fsm(
}
let ads_score = learner.get_ads_score();
let times_ly_ads_used = learner.get_times_ly_ads_used();
let ly_ads_inputs = learner.get_ly_ads_inputs();
let ly_ads_resets = learner.get_ly_ads_resets();
let (i, r) = sut.get_counts();
LearnResult {
learn_inputs: i,
......@@ -140,5 +145,9 @@ pub fn learn_fsm(
num_inputs: rev_input_map.len(),
num_states: hyp_states,
ads_score,
ly_ads_threshold,
times_ly_ads_used,
ly_ads_inputs,
ly_ads_resets,
}
}
......@@ -59,10 +59,24 @@ fn main() {
EqOracle::Hads
};
let _seed = matches.value_of("seed");
let ly_ads_threshold = matches.value_of("ly_ads_threshold").unwrap().parse().unwrap_or(-1.0);
let mut ip_vec = Vec::new();
writeln!(
&mut ip_vec,
"round,ot_ads_ip,ly_ads_ip,using_ly_ads"
).expect("Could not write to the ip log vector!");
let file = "identification_power_log.csv";
let mut f =
std::fs::File::create(file).unwrap_or_else(|file| panic!("Could not create file {}", file));
f.write_all(&ip_vec)
.unwrap_or_else(|file| panic!("Write exception when writing to file {}", file));
ip_vec.clear();
let mut result_vec = Vec::new();
writeln!(
&mut result_vec,
"name,learned,rounds,num_states,num_inputs,learn_inputs,learn_resets,test_inputs,test_resets,ads_score,learning_algorithm"
"name,learned,rounds,num_states,num_inputs,learn_inputs,learn_resets,test_inputs,test_resets,ads_score,ly_ads_threshold,times_ly_ads_used,ly_ads_inputs,ly_ads_resets,learning_algorithm"
).expect("Could not write to the results vector!");
let file = "results.csv";
let mut f =
......@@ -70,10 +84,12 @@ fn main() {
f.write_all(&result_vec)
.unwrap_or_else(|file| panic!("Write exception when writing to file {}", file));
result_vec.clear();
let num_repeat = match matches.value_of("repeat") {
Some(n) => n.parse::<usize>().unwrap(),
None => 1,
};
let log_path = matches.value_of("log");
let mut rng = rand::thread_rng();
if Path::new(path_name).is_file() {
......@@ -90,6 +106,7 @@ fn main() {
seed,
logs,
use_ly_ads,
ly_ads_threshold,
);
println!("Learning finished!\n{}", learn_results);
writeln!(
......@@ -126,6 +143,7 @@ fn main() {
seed,
None,
use_ly_ads,
ly_ads_threshold,
);
println!("Learning finished!\n{}", learn_results);
writeln!(
......
......@@ -50,19 +50,16 @@ impl AdaptiveDistinguishingSequence for AdsTree {
)
}
// Identification power is the expected number of states that can be
// excluded, normalized from 0 to 1, or more mathematically: the
// probability of ending up in a certain leaf, times the number of states
// that that leaf can distinguish.
// excluded, or more mathematically: the probability of ending up in a
// certain leaf, times the number of states that that leaf can distinguish.
fn identification_power(&self) -> f32 {
(leaves.map(|leaf| leaf.len() * (n_states - leaf.len())).sum()) / (n_states * (n_states - 1))
let n_states = self.get_leaves_indices().iter().map(|leaf_idx|
self.ads.arena[*leaf_idx].val.initial_states.len()).sum::<usize>();
self.get_leaves_indices().iter().map(|leaf_idx| {
let leaf_size = self.ads.arena[*leaf_idx].val.initial_states.len();
leaf_size * (n_states - leaf_size)
}).sum::<usize>() as f32 / (n_states * (n_states - 1)) as f32
}).sum::<usize>() as f32 / n_states as f32
}
fn reset_to_root(&mut self) {
......
......@@ -114,6 +114,14 @@ pub fn parse() -> Result<clap::ArgMatches<'static>, Box<dyn std::error::Error>>
.takes_value(true)
.multiple(true),
)
.arg(
Arg::with_name("ly_ads_threshold")
.long("ly_ads_threshold")
.value_name("ly_ads_threshold")
.help("Run the LY_ADS from every state in the basis on the hypothesis if it is at least ly_ads_threshold times better. 0 means use always. Pass 'never' to use never.")
.default_value("never")
.takes_value(true)
)
.get_matches();
Ok(matches)
}
......@@ -31,12 +31,16 @@ pub struct LearnResult {
pub num_states: usize,
pub num_inputs: usize,
pub ads_score: f32,
pub ly_ads_threshold: f32,
pub times_ly_ads_used: usize,
pub ly_ads_inputs: usize,
pub ly_ads_resets: usize,
}
impl LearnResult {
pub fn to_csv_entry(&self) -> String {
format!(
"{},{},{},{},{},{},{},{},{}",
"{},{},{},{},{},{},{},{},{},{},{},{},{}",
self.success,
self.rounds,
self.num_states,
......@@ -45,7 +49,11 @@ impl LearnResult {
self.learn_resets,
self.test_inputs,
self.test_resets,
self.ads_score
self.ads_score,
self.ly_ads_threshold,
self.times_ly_ads_used,
self.ly_ads_inputs,
self.ly_ads_resets
)
}
}
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment