/// inference/lambda.rs BAYES STAR, (c) coppola.ai 2024 use super::{ inference::{compute_each_combination, groups_from_backlinks, Inferencer}, table::{GenericNodeType, PropositionNode}, }; use crate::{ inference::inference::build_factor_context_for_assignment, model::{objects::EXISTENCE_FUNCTION, weights::CLASS_LABELS}, }; use std::error::Error; impl Inferencer { pub fn initialize_lambda(&mut self) -> Result<(), Box> { for node in &self.proposition_graph.all_nodes { for outcome in CLASS_LABELS { self.data.set_lambda_value(node, outcome, 1f64); } for parent in &self.proposition_graph.get_all_backward(node) { for outcome in CLASS_LABELS { self.data.set_lambda_message(node, parent, outcome, 1f64); } } } Ok(()) } pub fn do_lambda_traversal(&mut self) -> Result<(), Box> { let mut bfs_order = self.bfs_order.clone(); bfs_order.reverse(); for node in &bfs_order { self.lambda_visit_node(node)?; } Ok(()) } pub fn lambda_visit_node(&mut self, from_node: &PropositionNode) -> Result<(), Box> { self.lambda_send_messages(from_node)?; let is_observed = self.is_observed(from_node)?; if is_observed { self.lambda_set_from_evidence(from_node)?; } else { self.lambda_compute_value(&from_node)?; } Ok(()) } pub fn lambda_set_from_evidence( &mut self, node: &PropositionNode, ) -> Result<(), Box> { let as_single = node.extract_single(); let probability = self .fact_memory .get_proposition_probability(&as_single)? .unwrap(); self.data.set_lambda_value(node, 1, probability); self.data.set_lambda_value(node, 0, 1f64 - probability); Ok(()) } pub fn lambda_compute_value( &mut self, node: &PropositionNode, ) -> Result<(), Box> { let is_observed = self.is_observed(node)?; assert!(!is_observed); let children = self.proposition_graph.get_all_forward(node); for class_label in &CLASS_LABELS { let mut product = 1f64; for (_child_index, child_node) in children.iter().enumerate() { let child_lambda = self .data .get_lambda_message(&child_node, node,*class_label) .unwrap(); product *= child_lambda; } self.data .set_lambda_value(&node, *class_label, product); } Ok(()) } pub fn lambda_send_messages(&mut self, node: &PropositionNode) -> Result<(), Box> { let parent_nodes = self.proposition_graph.get_all_backward(node); let all_combinations = compute_each_combination(&parent_nodes); let lambda_true = self.data.get_lambda_value(node, 1).unwrap(); let lambda_false = self.data.get_lambda_value(node, 0).unwrap(); for (to_index, to_parent) in parent_nodes.iter().enumerate() { let mut sum_true = 0f64; let mut sum_false = 0f64; for combination in &all_combinations { let mut pi_product = 1f64; for (other_index, other_parent) in parent_nodes.iter().enumerate() { if other_index != to_index { let class_bool = combination.get(other_parent).unwrap(); let class_label = if *class_bool { 1 } else { 0 }; let this_pi = self.data.get_pi_message(&other_parent, node, class_label).unwrap(); pi_product *= this_pi; } } let probability_true = self.score_factor_assignment(&parent_nodes, combination, node)?; let probability_false = 1f64 - probability_true; let parent_assignment = combination.get(to_parent).unwrap(); let true_factor = probability_true * pi_product * lambda_true; let false_factor = probability_false * pi_product * lambda_false; if *parent_assignment { sum_true += true_factor + false_factor; } else { sum_false += true_factor + false_factor; } } self.data.set_lambda_message(node, to_parent, 1, sum_true); self.data.set_lambda_message(node, to_parent, 0, sum_false); } Ok(()) } }