Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Empty file added bamt/checkers/__init__.py
Empty file.
17 changes: 17 additions & 0 deletions bamt/checkers/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
from bamt.local_typing.node_types import NodeType
from abc import ABC

# todo: abstract methods
class Checker(ABC):
def __init__(self):
self.node_type = NodeType
self.is_mixture = False
self.is_logit = False

def validate_argument(self, arg):
enumerator = self.node_type.__class__
if isinstance(arg, str):
arg = enumerator(arg)
if arg not in self.node_type.__class__:
assert TypeError("Wrong type of argument.")
Copy link

Copilot AI Apr 3, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Using 'assert' with an exception instance is incorrect; replace it with 'raise TypeError("Wrong type of argument.")' to properly signal the error.

Suggested change
assert TypeError("Wrong type of argument.")
raise TypeError("Wrong type of argument.")

Copilot uses AI. Check for mistakes.
return True
126 changes: 126 additions & 0 deletions bamt/checkers/network_checker.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
from bamt.checkers.base import Checker
from bamt.checkers.node_checkers import RawNodeChecker

from bamt.local_typing.network_types import NetworkType
from bamt.local_typing.node_types import NodeSign


class NetworkChecker(Checker):
def __init__(self, descriptor):
super().__init__()
self.RESTRICTIONS = [("cont", "disc"), ("cont", "disc_num")]
self.checker_descriptor = {
"types": {
node_name: RawNodeChecker(node_type)
for node_name, node_type in descriptor["types"].items()
},
"signs": {
node_name: NodeSign(1 if v == "pos" else -1)
for node_name, v in descriptor["signs"].items()
},
}
if all(
node_checker.is_cont
for node_checker in self.checker_descriptor["types"].values()
):
self.network_type = NetworkType.continuous
elif all(
not node_checker.is_cont
for node_checker in self.checker_descriptor["types"].values()
):
self.network_type = NetworkType.discrete
else:
self.network_type = NetworkType.hybrid

def __getitem__(self, node_name):
node_checker = self.checker_descriptor["types"][node_name]

if node_checker.is_cont:
signs = {"signs": self.checker_descriptor["signs"][node_name]}
else:
signs = {}

return {
"node_checker": node_checker,
} | signs

def is_restricted_pair(self, node1, node2):
node_type_checkers = self.checker_descriptor["types"]
if (
node_type_checkers[node1].node_type.name,
node_type_checkers[node2].node_type.name,
) in self.RESTRICTIONS:
return True
else:
return False

def get_checker_rules(self):
return {
"descriptor": self.checker_descriptor,
"restriction_rule": self.is_restricted_pair,
}

def has_mixture_nodes(self):
if not getattr(
next(iter(self.checker_descriptor["types"].values())), "is_mixture"
):
return None

if any(
node_checker.is_mixture
for node_checker in self.checker_descriptor["types"].values()
):
return True
else:
return False

def has_logit_nodes(self):
if not getattr(
next(iter(self.checker_descriptor["types"].values())), "is_logit"
):
return None

if any(
node_checker.is_logit
for node_checker in self.checker_descriptor["types"].values()
):
return True
else:
return False

def validate_load(self, input_dict, network):
# check compatibility with father network.
if not network.use_mixture:
for node_name, node_data in input_dict["parameters"].items():
node_checker = self[node_name]["node_checker"]
if node_checker.is_disc:
continue
else:
# Since we don't have information about types of nodes, we
# should derive it from parameters.
if not node_checker.has_combinations:
if list(node_data.keys()) == ["covars", "mean", "coef"]:
return "use_mixture"
else:
if any(
list(node_keys.keys()) == ["covars", "mean", "coef"]
for node_keys in node_data["hybcprob"].values()
):
return "use_mixture"

# check if edges before and after are the same.They can be different in
# the case when user sets forbidden edges.
if not network.has_logit:
if not all(
edges_before == [edges_after[0], edges_after[1]]
for edges_before, edges_after in zip(input_dict["edges"], network.edges)
):
# logger_network.error(
# f"This crucial parameter is not the same as father's parameter: has_logit."
# )
return False
return True

@property
def is_disc(self):
return True if self.network_type is NetworkType.discrete else False
100 changes: 100 additions & 0 deletions bamt/checkers/node_checkers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
from bamt.checkers.base import Checker
from bamt.local_typing.node_types import (
RawNodeType,
NodeType,
discrete_nodes,
continuous_nodes
)


class RawNodeChecker(Checker):
def __init__(self, node_type):
# node_type can be only cont, disc or disc_num
super().__init__()
self.node_type = RawNodeType(node_type)

def __repr__(self):
return f"RawNodeChecker({self.node_type})"

@property
def is_cont(self):
return True if self.node_type is RawNodeType.cont else False

@property
def is_disc(self):
return True if self.node_type in [RawNodeType, RawNodeType.disc_num] else False

Comment on lines +25 to +26
Copy link

Copilot AI Apr 3, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The check for discrete node appears incorrect by including the entire RawNodeType enum; consider using a specific member (e.g., RawNodeType.disc) instead of RawNodeType.

Suggested change
return True if self.node_type in [RawNodeType, RawNodeType.disc_num] else False
return True if self.node_type in [RawNodeType.disc, RawNodeType.disc_num] else False

Copilot uses AI. Check for mistakes.
@classmethod
def evolve(cls, node_type_evolved, cont_parents, disc_parents):
"""Method to create final Node Checker linked to node after stage 2"""
return NodeChecker(node_type_evolved, cont_parents, disc_parents)


class NodeChecker(Checker):
def __init__(self, node_type, cont_parents, disc_parents):
super().__init__()
self.node_type = NodeType(node_type)

if self.node_type in discrete_nodes:
self.discrete = True

self.has_disc_parents = True if disc_parents else False
self.has_cont_parents = True if cont_parents else False

self.root = (
True if not self.has_cont_parents and not self.has_disc_parents else False
)

def __repr__(self):
return f"NodeChecker({self.node_type})"

def node_validation(self):
if self.has_disc_parents:
if self.node_type in (
NodeType.mixture_gaussian,
NodeType.gaussian,
NodeType.logit,
):
return False

if self.has_cont_parents:
if self.node_type is NodeType.discrete:
return False

if not (self.has_cont_parents or self.has_disc_parents):
if self.node_type not in (
NodeType.discrete,
NodeType.gaussian,
NodeType.mixture_gaussian,
):
return False

return True

@property
def does_require_regressor(self):
return True if self.node_type in nodes_with_regressors else False

@property
def does_require_classifier(self):
return True if self.node_type in nodes_with_classifiers else False

@property
def has_combinations(self):
return True if self.node_type in nodes_with_combinations else False

@property
def is_mixture(self):
return True if self.node_type in mixture_nodes else False

@property
def is_logit(self):
return True if self.node_type in logit_nodes else False

@property
def is_disc(self):
return True if self.node_type in discrete_nodes else False

@property
def is_cont(self):
return True if self.node_type in continuous_nodes else False
79 changes: 67 additions & 12 deletions bamt/core/graph/dag.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,77 @@
from networkx import DiGraph
from pgmpy.base.DAG import DAG
from bamt.core.graph.graph import Graph

from .graph import Graph
from networkx import DiGraph, topological_sort
from bamt.core.nodes.node import Node

from typing import Type, Sequence
from bamt.loggers.logger import logger_graphs
from bamt.local_typing.node_types import continuous_nodes


class DirectedAcyclicGraph(Graph):
def __init__(self):
super().__init__()
self._networkx_graph = DiGraph()
self.nodes: list[Type[Node]] = []
self.edges: list[Sequence] = []

def has_cycle(self):
pass

def from_container(self, container):
pass

def from_networkx(self, net):
pass

def get_family(self, descriptor):
"""
A function that updates each node accordingly structure;
"""
if not self.nodes:
logger_graphs.error("Vertex list is None")
return None
if not self.edges:
logger_graphs.error("Edges list is None")
return None

node_mapping = {
node_name: {"disc_parents": [], "cont_parents": [], "children": []} for node_name in self.nodes
}

for edge in self.edges:
parent, child = edge[0], edge[1]

if descriptor["types"][parent] in continuous_nodes:
node_mapping[child]["cont_parents"].append(parent)
else:
node_mapping[child]["disc_parents"].append(parent)
return node_mapping

def __getattr__(self, item):
return getattr(self._networkx_graph, item)
@staticmethod
def top_order(nodes: list[Type[Node]],
edges: list[Sequence]) -> list[str]:
"""
Function for topological sorting
"""
G = DiGraph()
G.add_nodes_from(nodes)
G.add_edges_from(edges)
return list(topological_sort(G))

def __setattr__(self, key, value):
setattr(self._networkx_graph, key, value)
def from_pgmpy(self, pgmpy_dag: DAG):
self.nodes = pgmpy_dag.nodes
self.edges = pgmpy_dag.edges

def __delattr__(self, item):
delattr(self._networkx_graph, item)
# def __getattr__(self, item):
# return getattr(self._networkx_graph, item)
#
# def __setattr__(self, key, value):
# setattr(self._networkx_graph, key, value)
#
# def __delattr__(self, item):
# delattr(self._networkx_graph, item)

@property
def networkx_graph(self):
return self._networkx_graph
# @property
# def networkx_graph(self):
# return self._networkx_graph
6 changes: 5 additions & 1 deletion bamt/core/graph/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,8 @@

class Graph(ABC):
def __init__(self):
pass
self.nodes = []
self.edges = []

def __repr__(self):
return f"{self.__class__.__name__} with \nNodes:{self.nodes}\nEdges:{self.edges}"
8 changes: 4 additions & 4 deletions bamt/core/node_models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from classifier import Classifier
from continuous_distribution import ContinuousDistribution
from empirical_distribution import EmpiricalDistribution
from regressor import Regressor
from .classifier import Classifier
from .continuous_distribution import ContinuousDistribution
from .empirical_distribution import EmpiricalDistribution
from .regressor import Regressor
4 changes: 4 additions & 0 deletions bamt/core/node_models/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,10 @@ class Distribution(ABC):
def fit(self, X: np.ndarray) -> None:
pass

@abstractmethod
def __repr__(self) -> str:
pass

@abstractmethod
def sample(self, num_samples: int) -> np.ndarray:
pass
Loading
Loading