Commit f0a3bf82 authored by SirBlueRabbit's avatar SirBlueRabbit
Browse files

refactored mh_prob calculation

parent 256e90b8
......@@ -2,7 +2,7 @@ mod triangulation;
use clap::Parser;
use serde_derive::Deserialize;
use std::{error::Error, fs, io::Write, path::Path, time};
use triangulation::{observable::Observable, Model, Precomputed, StepCount, Triangulation};
use triangulation::{observable::Observable, MHProbs, Model, StepCount, Triangulation};
fn main() -> Result<(), Box<dyn Error>> {
std::env::set_var("RUST_BACKTRACE", "1");
......@@ -14,12 +14,12 @@ fn main() -> Result<(), Box<dyn Error>> {
// load config
let config = load_config(&args)?;
// precompute selection probabilities
let precomputed = Precomputed::new(&config.model);
// precompute Metropolis-Hastings probabilities
let mh_probs = MHProbs::new(&config.model);
// equilibration phase
let mut step_count_therm = StepCount::default();
let mut triangulation = thermalise(&config, &precomputed, &mut step_count_therm);
let mut triangulation = thermalise(&config, &mh_probs, &mut step_count_therm);
println!("=== Thermalisation steps ===");
println!("{}", step_count_therm);
......@@ -27,7 +27,7 @@ fn main() -> Result<(), Box<dyn Error>> {
let mut step_count_meas = StepCount::default();
measure(
&config,
&precomputed,
&mh_probs,
args.index,
&mut triangulation,
&mut step_count_meas,
......@@ -45,21 +45,17 @@ fn load_config(args: &Args) -> Result<Config, Box<dyn Error>> {
Ok(toml::from_str(&config_toml)?)
}
fn thermalise(
config: &Config,
precomputed: &Precomputed,
step_count: &mut StepCount,
) -> Triangulation {
let mut triangulation = Triangulation::with_capacity(2 * config.model.sweep() + 12);
for _ in 0..(config.model.sweep() * config.markov_chain.thermalisation) {
triangulation.mc_step(config.model, precomputed, step_count);
fn thermalise(config: &Config, mh_probs: &MHProbs, step_count: &mut StepCount) -> Triangulation {
let mut triangulation = Triangulation::with_capacity(12 * (2 * config.model.target_volume + 1));
for _ in 0..(config.model.target_volume * config.markov_chain.thermalisation) {
triangulation.mc_step(config.model, mh_probs, step_count);
}
triangulation
}
fn measure(
config: &Config,
precomputed: &Precomputed,
mh_probs: &MHProbs,
index: Option<usize>,
triangulation: &mut Triangulation,
step_count: &mut StepCount,
......@@ -89,8 +85,8 @@ fn measure(
// execute measurement phase and write results to files
for _ in 0..(config.markov_chain.amount) {
for _ in 0..(config.model.sweep() * config.markov_chain.wait) {
triangulation.mc_step(config.model, precomputed, step_count);
for _ in 0..(config.model.target_volume * config.markov_chain.wait) {
triangulation.mc_step(config.model, mh_probs, step_count);
}
for (observable, file) in config
.measurement
......
use super::PrecomputedMove;
use super::MHProbs;
use super::{collections::Label, HalfEdge, Step, Triangulation, Weights};
impl Triangulation {
pub fn choose_step(
&self,
weights: Weights,
shard: &PrecomputedMove,
stars: &PrecomputedMove,
) -> Step {
pub fn choose_step(&self, weights: Weights, mh_probs: &MHProbs) -> Step {
// sample step
let step = self.sample_step(weights, &shard.select_grow, &stars.select_grow);
let step = self.sample_step(weights, mh_probs);
// validate step
if self.validate_step(step) {
// accept or reject step
if self.accept_step(
step,
&shard.accept_grow,
&shard.accept_shrink,
&stars.accept_grow,
&stars.accept_shrink,
) {
if self.accept_step(step, mh_probs) {
step
} else {
Step::Rejected
......@@ -30,14 +19,14 @@ impl Triangulation {
}
}
fn sample_step(&self, weights: Weights, shard_probs: &[f32], stars_probs: &[f32]) -> Step {
fn sample_step(&self, weights: Weights, mh_probs: &MHProbs) -> Step {
let cum_weights = weights.cumulative();
let seed = fastrand::usize(0..*cum_weights.last().unwrap());
if seed < cum_weights[0] {
self.sample_shard(shard_probs)
self.sample_shard(&mh_probs.shard_select_grow)
} else if seed < cum_weights[1] {
self.sample_stars(stars_probs)
self.sample_stars(&mh_probs.stars_select_grow)
} else {
unreachable!("Step type seed outside valid range");
}
......@@ -80,20 +69,13 @@ impl Triangulation {
}
}
fn accept_step(
&self,
step: Step,
shard_grow: &[f32],
shard_shrink: &[f32],
stars_grow: &[f32],
stars_shrink: &[f32],
) -> bool {
fn accept_step(&self, step: Step, mh_probs: &MHProbs) -> bool {
let n3 = self.volume();
let p_accept = match step {
Step::Move02(_) => shard_grow[n3],
Step::Move20(_) => shard_shrink[n3],
Step::Move23(_) => stars_grow[n3],
Step::Move32(_) => stars_shrink[n3],
Step::Move02(_) => mh_probs.shard_accept_grow[n3],
Step::Move20(_) => mh_probs.shard_accept_shrink[n3],
Step::Move23(_) => mh_probs.stars_accept_grow[n3],
Step::Move32(_) => mh_probs.stars_accept_shrink[n3],
Step::Invalid => return true,
Step::Rejected => unreachable!(),
};
......
......@@ -4,32 +4,12 @@ use std::{
ops::{Index, IndexMut},
};
#[derive(Debug, Clone)]
pub struct Pool<T: Copy> {
elements: Box<[Element<T>]>,
current_hole: usize,
size: usize,
}
#[derive(Debug, Clone)]
pub struct Bag<T> {
indices: Box<[Option<usize>]>,
labels: Box<[Option<Label<T>>]>,
size: usize,
}
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Label<T> {
value: usize,
pub value: usize,
object_type: std::marker::PhantomData<T>,
}
#[derive(Debug, Clone, Copy)]
enum Element<T: Copy> {
Object(T),
Hole(usize),
}
impl<T: Copy> From<usize> for Label<T> {
fn from(value: usize) -> Self {
Label {
......@@ -39,30 +19,23 @@ impl<T: Copy> From<usize> for Label<T> {
}
}
impl<T: Copy> Label<T> {
pub fn value(&self) -> usize {
self.value
impl<T: Copy> fmt::Display for Label<T> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{}", self.value)
}
}
impl<T: Copy> Index<Label<T>> for Pool<T> {
type Output = T;
fn index(&self, label: Label<T>) -> &Self::Output {
match &self.elements[label.value] {
Element::Object(object) => object,
Element::Hole(_) => panic!("Label {} is not in use!", label.value),
}
}
#[derive(Debug, Clone, Copy)]
enum Element<T: Copy> {
Object(T),
Hole(usize),
}
impl<T: Copy> IndexMut<Label<T>> for Pool<T> {
fn index_mut(&mut self, label: Label<T>) -> &mut Self::Output {
match &mut self.elements[label.value] {
Element::Object(object) => object,
Element::Hole(_) => panic!("Label {} is not in use!", label.value),
}
}
#[derive(Debug, Clone)]
pub struct Pool<T: Copy> {
elements: Box<[Element<T>]>,
current_hole: usize,
size: usize,
}
impl<T: Copy> Pool<T> {
......@@ -118,6 +91,48 @@ impl<T: Copy> Pool<T> {
}
}
impl<T: Copy> Index<Label<T>> for Pool<T> {
type Output = T;
fn index(&self, label: Label<T>) -> &Self::Output {
match &self.elements[label.value] {
Element::Object(object) => object,
Element::Hole(_) => panic!("Label {} is not in use!", label.value),
}
}
}
impl<T: Copy> IndexMut<Label<T>> for Pool<T> {
fn index_mut(&mut self, label: Label<T>) -> &mut Self::Output {
match &mut self.elements[label.value] {
Element::Object(object) => object,
Element::Hole(_) => panic!("Label {} is not in use!", label.value),
}
}
}
impl<T: Copy + fmt::Display> fmt::Display for Pool<T> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
self.elements
.iter()
.enumerate()
.try_for_each(|(i, element)| match element {
Element::Hole(_) => Ok(()),
Element::Object(obj) => {
writeln!(f, "\t[{}]: \t{}", i, obj)
}
})?;
Ok(())
}
}
#[derive(Debug, Clone)]
pub struct Bag<T> {
indices: Box<[Option<usize>]>,
labels: Box<[Option<Label<T>>]>,
size: usize,
}
impl<T: Copy> Bag<T> {
pub fn with_capacity(capacity: usize) -> Bag<T> {
Bag {
......@@ -158,21 +173,6 @@ impl<T: Copy> Bag<T> {
}
}
impl<T: Copy + fmt::Display> fmt::Display for Pool<T> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
self.elements
.iter()
.enumerate()
.try_for_each(|(i, element)| match element {
Element::Hole(_) => Ok(()),
Element::Object(obj) => {
writeln!(f, "\t[{}]: \t{}", i, obj)
}
})?;
Ok(())
}
}
impl<T: Copy + fmt::Display> fmt::Display for Bag<T> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
self.labels.iter().try_for_each(|element| match element {
......@@ -185,12 +185,6 @@ impl<T: Copy + fmt::Display> fmt::Display for Bag<T> {
}
}
impl<T: Copy> fmt::Display for Label<T> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{}", self.value)
}
}
#[cfg(test)]
mod tests {
use super::*;
......
......@@ -30,7 +30,7 @@ const VERTICES: [[usize; 2]; 12] = [
pub struct Model {
parameters: Parameters,
weights: Weights,
volume: usize,
pub target_volume: usize,
eps: f32,
}
......@@ -47,16 +47,13 @@ pub struct Weights {
}
#[derive(Clone, Debug)]
pub struct PrecomputedMove {
select_grow: Box<[f32]>,
accept_grow: Box<[f32]>,
accept_shrink: Box<[f32]>,
}
#[derive(Clone, Debug)]
pub struct Precomputed {
shard: PrecomputedMove,
stars: PrecomputedMove,
pub struct MHProbs {
shard_select_grow: Box<[f32]>,
shard_accept_grow: Box<[f32]>,
shard_accept_shrink: Box<[f32]>,
stars_select_grow: Box<[f32]>,
stars_accept_grow: Box<[f32]>,
stars_accept_shrink: Box<[f32]>,
}
#[derive(Debug, Clone)]
......@@ -115,18 +112,25 @@ impl StepCount {
}
}
impl Precomputed {
pub fn new(model: &Model) -> Precomputed {
let shard = model.precompute(2, 1);
let stars = model.precompute(1, 0);
Precomputed { shard, stars }
impl MHProbs {
pub fn new(model: &Model) -> MHProbs {
let shard = model.mh_probs(2, 1);
let stars = model.mh_probs(1, 0);
MHProbs {
shard_select_grow: shard.0,
shard_accept_grow: shard.1,
shard_accept_shrink: shard.2,
stars_select_grow: stars.0,
stars_accept_grow: stars.1,
stars_accept_shrink: stars.2,
}
}
}
impl Model {
#[cfg(test)]
pub fn new(
volume: usize,
target_volume: usize,
eps: f32,
kappa_0: f32,
kappa_3: f32,
......@@ -138,29 +142,22 @@ impl Model {
Model {
parameters,
weights,
volume,
target_volume,
eps,
}
}
pub fn sweep(&self) -> usize {
self.volume
}
fn precompute(&self, dn3: usize, dn0: usize) -> PrecomputedMove {
fn mh_probs(&self, dn3: usize, dn0: usize) -> (Box<[f32]>, Box<[f32]>, Box<[f32]>) {
let ratio = self.ratio(dn3, dn0);
let norm = self.norm(&ratio, dn3);
let select_grow = self.select_grow(&ratio, &norm, dn3);
let (accept_grow, accept_shrink) = self.accept(&norm, dn3);
PrecomputedMove {
select_grow,
accept_grow,
accept_shrink,
}
(select_grow, accept_grow, accept_shrink)
}
fn ratio(&self, dn3: usize, dn0: usize) -> Box<[f32]> {
let n = self.volume;
let n = self.target_volume;
let eps = self.eps;
let k0 = self.parameters.kappa_0;
let k3 = self.parameters.kappa_3;
......@@ -176,7 +173,7 @@ impl Model {
}
fn norm(&self, ratio: &[f32], dn3: usize) -> Box<[f32]> {
let n = self.volume;
let n = self.target_volume;
(0..=(2 * n))
.into_iter()
.map(|n3| {
......@@ -190,7 +187,7 @@ impl Model {
}
fn select_grow(&self, ratio: &[f32], norm: &[f32], dn3: usize) -> Box<[f32]> {
let n = self.volume;
let n = self.target_volume;
(0..=(2 * n))
.into_iter()
.map(|n3| {
......@@ -206,7 +203,7 @@ impl Model {
}
fn accept(&self, norm: &[f32], dn3: usize) -> (Box<[f32]>, Box<[f32]>) {
let n = self.volume;
let n = self.target_volume;
let accept_grow = (0..=(2 * n))
.into_iter()
.map(|n3| {
......@@ -274,7 +271,7 @@ impl Triangulation {
}
pub fn remove_tet(&mut self, label: Label<HalfEdge>) {
let index = (label.value() / 12) * 12;
let index = (label.value / 12) * 12;
(0..12).rev().for_each(|i| {
let label = Label::<HalfEdge>::from(index + i);
self.half_edges.remove(label);
......@@ -282,8 +279,8 @@ impl Triangulation {
});
}
pub fn mc_step(&mut self, model: Model, precomputed: &Precomputed, step_count: &mut StepCount) {
let step = self.choose_step(model.weights, &precomputed.shard, &precomputed.stars);
pub fn mc_step(&mut self, model: Model, mh_probs: &MHProbs, step_count: &mut StepCount) {
let step = self.choose_step(model.weights, mh_probs);
match step {
Step::Move02(_) => step_count.move02 += 1,
Step::Move20(_) => step_count.move20 += 1,
......@@ -298,12 +295,12 @@ impl Triangulation {
impl Label<HalfEdge> {
pub fn next(&self) -> Label<HalfEdge> {
let val = self.value();
let val = self.value;
Label::<HalfEdge>::from((val / 12) * 12 + NEXT[val % 12])
}
pub fn adj_int(&self) -> Label<HalfEdge> {
let val = self.value();
let val = self.value;
Label::<HalfEdge>::from((val / 12) * 12 + ADJ_INT[val % 12])
}
......@@ -313,7 +310,7 @@ impl Label<HalfEdge> {
#[cfg(test)]
pub fn vertices(&self, half_edges: &Pool<HalfEdge>) -> [Label<Vertex>; 4] {
let val = (self.value() / 12) * 12;
let val = (self.value / 12) * 12;
let label01 = Label::<HalfEdge>::from(val);
let label23 = Label::<HalfEdge>::from(val + 6);
let v0 = half_edges[label01].vertex_tail;
......@@ -324,7 +321,7 @@ impl Label<HalfEdge> {
}
pub fn same_tetrahedron(&self, other: Label<HalfEdge>) -> bool {
self.value() / 12 == other.value() / 12
self.value / 12 == other.value / 12
}
}
......@@ -462,14 +459,13 @@ mod tests {
// initialise
let model = Model::new(2, 0.01, 0.0, 3.0, 1, 1);
let precomputed = Precomputed::new(&model);
let mut triangulation = Triangulation::with_capacity(2 * model.sweep() + 12);
let mh_probs = MHProbs::new(&model);
let mut triangulation = Triangulation::with_capacity(12 * (2 * model.target_volume + 1));
print!("{}", triangulation);
// perform a large number of steps and check each iteration
(0..(100_000 * model.sweep())).for_each(|_| {
let step =
triangulation.choose_step(model.weights, &precomputed.shard, &precomputed.stars);
(0..(100_000 * model.target_volume)).for_each(|_| {
let step = triangulation.choose_step(model.weights, &mh_probs);
dbg!(step);
triangulation.do_step(step);
print!("{}", triangulation);
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment