-
Notifications
You must be signed in to change notification settings - Fork 23
scratch of structure learning for discrete bn #128
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Draft
Roman223
wants to merge
1
commit into
aimclub:2.0.0
Choose a base branch
from
Roman223:2.0.0
base: 2.0.0
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Draft
Changes from all commits
Commits
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
Empty file.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,17 @@ | ||
| from bamt.local_typing.node_types import NodeType | ||
| from abc import ABC | ||
|
|
||
| # todo: abstract methods | ||
| class Checker(ABC): | ||
| def __init__(self): | ||
| self.node_type = NodeType | ||
| self.is_mixture = False | ||
| self.is_logit = False | ||
|
|
||
| def validate_argument(self, arg): | ||
| enumerator = self.node_type.__class__ | ||
| if isinstance(arg, str): | ||
| arg = enumerator(arg) | ||
| if arg not in self.node_type.__class__: | ||
| assert TypeError("Wrong type of argument.") | ||
| return True | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,126 @@ | ||
| from bamt.checkers.base import Checker | ||
| from bamt.checkers.node_checkers import RawNodeChecker | ||
|
|
||
| from bamt.local_typing.network_types import NetworkType | ||
| from bamt.local_typing.node_types import NodeSign | ||
|
|
||
|
|
||
| class NetworkChecker(Checker): | ||
| def __init__(self, descriptor): | ||
| super().__init__() | ||
| self.RESTRICTIONS = [("cont", "disc"), ("cont", "disc_num")] | ||
| self.checker_descriptor = { | ||
| "types": { | ||
| node_name: RawNodeChecker(node_type) | ||
| for node_name, node_type in descriptor["types"].items() | ||
| }, | ||
| "signs": { | ||
| node_name: NodeSign(1 if v == "pos" else -1) | ||
| for node_name, v in descriptor["signs"].items() | ||
| }, | ||
| } | ||
| if all( | ||
| node_checker.is_cont | ||
| for node_checker in self.checker_descriptor["types"].values() | ||
| ): | ||
| self.network_type = NetworkType.continuous | ||
| elif all( | ||
| not node_checker.is_cont | ||
| for node_checker in self.checker_descriptor["types"].values() | ||
| ): | ||
| self.network_type = NetworkType.discrete | ||
| else: | ||
| self.network_type = NetworkType.hybrid | ||
|
|
||
| def __getitem__(self, node_name): | ||
| node_checker = self.checker_descriptor["types"][node_name] | ||
|
|
||
| if node_checker.is_cont: | ||
| signs = {"signs": self.checker_descriptor["signs"][node_name]} | ||
| else: | ||
| signs = {} | ||
|
|
||
| return { | ||
| "node_checker": node_checker, | ||
| } | signs | ||
|
|
||
| def is_restricted_pair(self, node1, node2): | ||
| node_type_checkers = self.checker_descriptor["types"] | ||
| if ( | ||
| node_type_checkers[node1].node_type.name, | ||
| node_type_checkers[node2].node_type.name, | ||
| ) in self.RESTRICTIONS: | ||
| return True | ||
| else: | ||
| return False | ||
|
|
||
| def get_checker_rules(self): | ||
| return { | ||
| "descriptor": self.checker_descriptor, | ||
| "restriction_rule": self.is_restricted_pair, | ||
| } | ||
|
|
||
| def has_mixture_nodes(self): | ||
| if not getattr( | ||
| next(iter(self.checker_descriptor["types"].values())), "is_mixture" | ||
| ): | ||
| return None | ||
|
|
||
| if any( | ||
| node_checker.is_mixture | ||
| for node_checker in self.checker_descriptor["types"].values() | ||
| ): | ||
| return True | ||
| else: | ||
| return False | ||
|
|
||
| def has_logit_nodes(self): | ||
| if not getattr( | ||
| next(iter(self.checker_descriptor["types"].values())), "is_logit" | ||
| ): | ||
| return None | ||
|
|
||
| if any( | ||
| node_checker.is_logit | ||
| for node_checker in self.checker_descriptor["types"].values() | ||
| ): | ||
| return True | ||
| else: | ||
| return False | ||
|
|
||
| def validate_load(self, input_dict, network): | ||
| # check compatibility with father network. | ||
| if not network.use_mixture: | ||
| for node_name, node_data in input_dict["parameters"].items(): | ||
| node_checker = self[node_name]["node_checker"] | ||
| if node_checker.is_disc: | ||
| continue | ||
| else: | ||
| # Since we don't have information about types of nodes, we | ||
| # should derive it from parameters. | ||
| if not node_checker.has_combinations: | ||
| if list(node_data.keys()) == ["covars", "mean", "coef"]: | ||
| return "use_mixture" | ||
| else: | ||
| if any( | ||
| list(node_keys.keys()) == ["covars", "mean", "coef"] | ||
| for node_keys in node_data["hybcprob"].values() | ||
| ): | ||
| return "use_mixture" | ||
|
|
||
| # check if edges before and after are the same.They can be different in | ||
| # the case when user sets forbidden edges. | ||
| if not network.has_logit: | ||
| if not all( | ||
| edges_before == [edges_after[0], edges_after[1]] | ||
| for edges_before, edges_after in zip(input_dict["edges"], network.edges) | ||
| ): | ||
| # logger_network.error( | ||
| # f"This crucial parameter is not the same as father's parameter: has_logit." | ||
| # ) | ||
| return False | ||
| return True | ||
|
|
||
| @property | ||
| def is_disc(self): | ||
| return True if self.network_type is NetworkType.discrete else False |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
| @@ -0,0 +1,100 @@ | ||||||
| from bamt.checkers.base import Checker | ||||||
| from bamt.local_typing.node_types import ( | ||||||
| RawNodeType, | ||||||
| NodeType, | ||||||
| discrete_nodes, | ||||||
| continuous_nodes | ||||||
| ) | ||||||
|
|
||||||
|
|
||||||
| class RawNodeChecker(Checker): | ||||||
| def __init__(self, node_type): | ||||||
| # node_type can be only cont, disc or disc_num | ||||||
| super().__init__() | ||||||
| self.node_type = RawNodeType(node_type) | ||||||
|
|
||||||
| def __repr__(self): | ||||||
| return f"RawNodeChecker({self.node_type})" | ||||||
|
|
||||||
| @property | ||||||
| def is_cont(self): | ||||||
| return True if self.node_type is RawNodeType.cont else False | ||||||
|
|
||||||
| @property | ||||||
| def is_disc(self): | ||||||
| return True if self.node_type in [RawNodeType, RawNodeType.disc_num] else False | ||||||
|
|
||||||
|
Comment on lines
+25
to
+26
|
||||||
| return True if self.node_type in [RawNodeType, RawNodeType.disc_num] else False | |
| return True if self.node_type in [RawNodeType.disc, RawNodeType.disc_num] else False |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,22 +1,77 @@ | ||
| from networkx import DiGraph | ||
| from pgmpy.base.DAG import DAG | ||
| from bamt.core.graph.graph import Graph | ||
|
|
||
| from .graph import Graph | ||
| from networkx import DiGraph, topological_sort | ||
| from bamt.core.nodes.node import Node | ||
|
|
||
| from typing import Type, Sequence | ||
| from bamt.loggers.logger import logger_graphs | ||
| from bamt.local_typing.node_types import continuous_nodes | ||
|
|
||
|
|
||
| class DirectedAcyclicGraph(Graph): | ||
| def __init__(self): | ||
| super().__init__() | ||
| self._networkx_graph = DiGraph() | ||
| self.nodes: list[Type[Node]] = [] | ||
| self.edges: list[Sequence] = [] | ||
|
|
||
| def has_cycle(self): | ||
| pass | ||
|
|
||
| def from_container(self, container): | ||
| pass | ||
|
|
||
| def from_networkx(self, net): | ||
| pass | ||
|
|
||
| def get_family(self, descriptor): | ||
| """ | ||
| A function that updates each node accordingly structure; | ||
| """ | ||
| if not self.nodes: | ||
| logger_graphs.error("Vertex list is None") | ||
| return None | ||
| if not self.edges: | ||
| logger_graphs.error("Edges list is None") | ||
| return None | ||
|
|
||
| node_mapping = { | ||
| node_name: {"disc_parents": [], "cont_parents": [], "children": []} for node_name in self.nodes | ||
| } | ||
|
|
||
| for edge in self.edges: | ||
| parent, child = edge[0], edge[1] | ||
|
|
||
| if descriptor["types"][parent] in continuous_nodes: | ||
| node_mapping[child]["cont_parents"].append(parent) | ||
| else: | ||
| node_mapping[child]["disc_parents"].append(parent) | ||
| return node_mapping | ||
|
|
||
| def __getattr__(self, item): | ||
| return getattr(self._networkx_graph, item) | ||
| @staticmethod | ||
| def top_order(nodes: list[Type[Node]], | ||
| edges: list[Sequence]) -> list[str]: | ||
| """ | ||
| Function for topological sorting | ||
| """ | ||
| G = DiGraph() | ||
| G.add_nodes_from(nodes) | ||
| G.add_edges_from(edges) | ||
| return list(topological_sort(G)) | ||
|
|
||
| def __setattr__(self, key, value): | ||
| setattr(self._networkx_graph, key, value) | ||
| def from_pgmpy(self, pgmpy_dag: DAG): | ||
| self.nodes = pgmpy_dag.nodes | ||
| self.edges = pgmpy_dag.edges | ||
|
|
||
| def __delattr__(self, item): | ||
| delattr(self._networkx_graph, item) | ||
| # def __getattr__(self, item): | ||
| # return getattr(self._networkx_graph, item) | ||
| # | ||
| # def __setattr__(self, key, value): | ||
| # setattr(self._networkx_graph, key, value) | ||
| # | ||
| # def __delattr__(self, item): | ||
| # delattr(self._networkx_graph, item) | ||
|
|
||
| @property | ||
| def networkx_graph(self): | ||
| return self._networkx_graph | ||
| # @property | ||
| # def networkx_graph(self): | ||
| # return self._networkx_graph |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,4 +1,4 @@ | ||
| from classifier import Classifier | ||
| from continuous_distribution import ContinuousDistribution | ||
| from empirical_distribution import EmpiricalDistribution | ||
| from regressor import Regressor | ||
| from .classifier import Classifier | ||
| from .continuous_distribution import ContinuousDistribution | ||
| from .empirical_distribution import EmpiricalDistribution | ||
| from .regressor import Regressor |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Using 'assert' with an exception instance is incorrect; replace it with 'raise TypeError("Wrong type of argument.")' to properly signal the error.