brent.graph
module
The brent.graph
module contains the DAG object. This is the
main object that you'll talk to when constructing a causal graph.
Source code
"""
The `brent.graph` module contains the DAG object. This is the
main object that you'll talk to when constructing a causal graph.
"""
import logging
import numpy as np
import pandas as pd
import networkx as nx
from graphviz import Digraph
from brent.common import normalise, window, is_path_blocked
class DAG:
    """
    A Directed-Acyclic-Graph (DAG) object describes a graphical model of a dataset.

    This object is generated from a pandas dataframe. Every column in the
    dataframe will result in a node/variable in the DAG object.

    ```
    from brent import DAG
    from brent.common import make_fake_df
    # let's start with a new dataset
    df = make_fake_df(4)
    dag = DAG(df).add_edge("a", "b").add_edge("b", "c").add_edge("c","d")
    ```
    """

    def __init__(self, dataframe: pd.DataFrame):
        """
        Create a new DAG from a dataframe.

        Inputs:

        - **dataframe**: pandas object that contains all variables

        Example:

        ```
        from brent import DAG
        from brent.common import make_fake_df
        # let's start with a new dataset
        df = make_fake_df(4)
        dag = DAG(df).add_edge("a", "b").add_edge("b", "c").add_edge("c","d")
        ```
        """
        self.df = dataframe
        self.graph = nx.DiGraph()
        # Every column starts out as an unconnected node.
        for node in self.df.columns:
            self.graph.add_node(node)
        # `cached` is switched on by `.cache()`, which also fills `prob_tables`
        # with one probability table per node.
        self.cached = False
        self.prob_tables = {}

    @property
    def undirected_graph(self):
        """
        Fetch the `undirected` variant of the NetworkX graph. This can be
        useful when trying to determine all paths between two nodes.
        """
        return self.graph.to_undirected()

    @property
    def origin_nodes(self):
        """These nodes are nodes that do not have any edges going in."""
        return tuple(x for x in self.graph.nodes() if self.graph.in_degree(x) == 0)

    @property
    def marginal_table(self):
        """
        The marginal table is a table with all possible values and associated probability.

        It is built by starting from the node table of the last node and merging
        in the node table of every other node via `merge_probs`.
        """
        nodes = list(self.graph.nodes)
        logging.debug(f"about to calculate marginal table with nodes {nodes}")
        logging.debug(f"updating table for node {nodes[-1]}")
        marginal = self.calc_node_table(nodes.pop())
        # Tables are handed to logging as lazy arguments so that they are only
        # rendered to text when debug logging is actually enabled.
        logging.debug("current node table:\n%s", marginal)
        logging.debug("current marginal table:\n%s", marginal)
        for node in nodes:
            logging.debug(f"updating table for node {node}")
            # Compute the node table once and reuse it for logging and merging.
            node_table = self.calc_node_table(node)
            logging.debug("current node table:\n%s", node_table)
            marginal = self.merge_probs(marginal, node_table)
            logging.debug("current marginal table:\n%s", marginal)
        return marginal

    @property
    def nodes(self):
        """The nodes of the graph."""
        return list(self.graph.nodes)

    @property
    def edges(self):
        """The edges of the graph."""
        return list(self.graph.edges)

    def copy(self):
        """
        Returns a copy of the current DAG.

        The copy is built from the same dataframe object, receives a copy of the
        graph and starts out uncached.
        """
        new_dag = DAG(self.df)
        new_dag.graph = self.graph.copy()
        return new_dag

    def edge_direction(self, node_a, node_b):
        """
        Determines the `<-` vs. `->` direction of an edge between two nodes.

        Raises a `ValueError` when neither node is a parent of the other.
        """
        if node_b in self.parents(node_a):
            return "<-"
        if node_a in self.parents(node_b):
            return "->"
        raise ValueError(f"node '{node_a}' is not connected to '{node_b}'")

    def undirected_paths(self, node_a, node_b):
        """
        Returns a list of all the paths that are between `node_a` and `node_b`.
        These paths do not take the direction into account and will turn the
        directed graph into an undirected one.
        """
        return list(nx.all_simple_paths(self.undirected_graph, node_a, node_b))

    def directed_paths(self, node_a, node_b):
        """
        Find all paths between `node_a` and `node_b`. These paths may be
        probabilistically inactive. If you want the active paths call
        `DAG.active_paths` instead.

        Every path is a flat list that alternates nodes and arrows,
        e.g. `["a", "->", "b", "<-", "c"]`.
        """
        output_paths = []
        for path in self.undirected_paths(node_a, node_b):
            new_path = []
            for n1, n2 in window(path):
                new_path.append(n1)
                new_path.append(self.edge_direction(n1, n2))
            # Close the path with its final node. Reading it from `path` instead
            # of the loop variable keeps this safe when `window` yields nothing.
            new_path.append(path[-1])
            output_paths.append(new_path)
            # str() keeps the log line working for non-string node names.
            logging.debug(f"found directed path: {' '.join(map(str, new_path))}")
        return output_paths

    def active_paths(self, node_a, node_b, z=None):
        """
        Returns a list of all the paths that can influence probability between
        `node_a` and `node_b`. You can also supply a collection of nodes `z` that
        are given. These nodes might block (or enable) a path to be active. If there
        are no active paths between two nodes then this means that the two nodes
        are (conditionally, if `z` is given) independent.

        ## Input

        - **node_a**: Name of a node in the graph
        - **node_b**: Name of a node in the graph
        - **z**: optional collection of given nodes; nothing is given by default
        """
        given = () if z is None else z
        active_paths = []
        for path in self.directed_paths(node_a, node_b):
            # Given nodes are marked as "given(<node>)" before the path is
            # handed to `is_path_blocked`.
            z_path = [f"given({_})" if _ in given else _ for _ in path]
            logging.debug(f"now checking {' '.join(map(str, z_path))}")
            if not is_path_blocked(z_path):
                active_paths.append(z_path)
        return active_paths

    def cache(self):
        """
        When calling `.cache()` the API will realise that the graph will no longer change.
        We will also cache all the probability tables that are associated with a node.

        Returns the DAG itself. Raises a `ValueError` if the DAG is already cached.
        """
        if self.cached:
            raise ValueError("cannot call `.cache()` on a DAG that is already cached!")
        for node in self.nodes:
            self.prob_tables[node] = self.calc_node_table(node)
        self.cached = True
        return self

    def calc_node_table(self, name):
        """
        Calculates probability table for a given node.

        Suppose we have graph `a -> b -> c`: a `.calc_node_table("b")`
        call will calculate `P(b|a)` in the `prob` column of the result.

        ## Input

        - **name**: Name of a node/variable in the graph

        ## Output

        A dataframe with one column per parent, a column for the node itself
        and a `prob` column, one row per observed combination of values.
        Raises a `ValueError` if `name` is not a node of the graph.
        """
        if name not in self.nodes:
            raise ValueError(f"node {name} not in available nodes: {self.nodes}")
        if name in self.prob_tables:
            return self.prob_tables[name]
        # Read the parents as a list straight from the graph so that the column
        # order of the result does not depend on set iteration order.
        parents = list(self.graph.predecessors(name))
        logging.debug(f"creating node table node={name} parents={parents}")

        def group_sizes(dataf, groups):
            # Number of rows in the group each row belongs to. Only the helper
            # column is aggregated, so columns that cannot be summed are never touched.
            return dataf.assign(count=1).groupby(groups)["count"].transform("sum")

        family = parents + [name]
        if parents:
            tbl = self.df.assign(parent_size=lambda d: group_sizes(d, parents))
        else:
            # Without parents every row is compared against the whole dataset.
            tbl = self.df.assign(parent_size=lambda d: d.shape[0])
        return (tbl
                .assign(node_size=lambda d: group_sizes(d, family))
                .assign(prob=lambda d: d['node_size'] / d['parent_size'])
                [family + ["prob"]]
                .drop_duplicates()
                .reset_index(drop=True))

    def merge_probs(self, this_df, that_df):
        """
        Merges two probability dataframes while checking if nodes
        are connected in the graph. If `this_df` denotes `p(C|A,B)`
        in table form and `that_df` denotes `p(B|A)` in table form
        then the output of this function will denote `p(C,B|A)` in
        table form. Note that this method is mainly meant for internal usage.

        ## Input

        - **this_df**: probability table (dataframe with a `prob` column), say `p(C|A,B)`
        - **that_df**: probability table (dataframe with a `prob` column), say `p(B|A)`

        ## Output

        A dataframe that merges the two former tables, say `p(C,B|A)`.
        """
        common_cols = list(set(this_df.columns)
                           .intersection(set(that_df.columns))
                           .difference({"prob"}))
        if len(common_cols) == 0:
            # No shared variable: extend `this_df` with the variables of
            # `that_df` (one copy per value found in the node table) and retry.
            # NOTE(review): when `that_df` has several non-`prob` columns each
            # copy only receives one of them - confirm this path is intended.
            columns = set(that_df.columns).difference({"prob"})
            loose_tables = []
            for c in columns:
                for value in self.calc_node_table(c)[c].values:
                    loose_tables.append(this_df.assign(**{c: value}))
            join_able = pd.concat(loose_tables)
            return self.merge_probs(join_able, that_df)
        # Join on the shared variables; the product of both `prob` columns is
        # passed through `normalise` to form the new `prob` column.
        return (this_df
                .set_index(common_cols)
                .join(that_df.set_index(common_cols), lsuffix="1", rsuffix="2")
                .assign(prob=lambda x: normalise(x.prob1 * x.prob2))
                .drop("prob1", axis=1)
                .drop("prob2", axis=1)
                .reset_index())

    def add_edge(self, source, sink) -> 'DAG':
        """
        Adds an edge to the graph.

        ## Input

        - **source**: Name of a node in the graph
        - **sink**: Name of a node in the graph

        ## Output

        The DAG itself, so calls can be chained. Raises a `ValueError` when a
        node is not a dataframe column or when the edge would create a cycle,
        and a `RuntimeError` when the DAG has been cached.

        ## Example

        ```
        from brent import DAG
        from brent.common import make_fake_df
        (DAG(dataframe=make_fake_df(4))
            .add_edge("a", "b")
            .add_edge("b", "c")
            .add_edge("c", "d"))
        ```
        """
        if source not in self.df.columns:
            raise ValueError(f"cause {source} not in dataframe")
        if sink not in self.df.columns:
            raise ValueError(f"effect {sink} not in dataframe")
        if self.cached:
            raise RuntimeError("Cannot change a graph when the dag is baked.")
        # Try the edge on a copy first so a rejected edge leaves the graph untouched.
        new_graph = self.graph.copy()
        new_graph.add_edge(source, sink)
        if not nx.is_directed_acyclic_graph(new_graph):
            raise ValueError(f"edge {source} -> {sink} causes DAG to get cycle")
        self.graph.add_edge(source, sink)
        logging.debug(f"created connection {source} -> {sink}")
        return self

    def children(self, node):
        """
        Return the children of a node as a set.

        ## Input

        - **node**: Name of a node

        ## Example

        ```
        from brent import DAG
        from brent.common import make_fake_df
        dag = (DAG(dataframe=make_fake_df(4))
                .add_edge("a", "b")
                .add_edge("b", "c")
                .add_edge("c", "d"))
        dag.children("b") #outputs {"c"}
        dag.children("c") #outputs {"d"}
        ```
        """
        return set(self.graph.successors(node))

    def parents(self, node):
        """
        Return the parents of a node as a set.

        ## Input

        - **node**: Name of a node

        ## Example

        ```
        from brent import DAG
        from brent.common import make_fake_df
        dag = (DAG(dataframe=make_fake_df(4))
                .add_edge("a", "b")
                .add_edge("b", "c")
                .add_edge("c", "d"))
        dag.parents("b") #outputs {"a"}
        dag.parents("c") #outputs {"b"}
        ```
        """
        return set(self.graph.predecessors(node))

    def connections(self, node):
        """
        Return all nodes connected to the one passed in, as a set.

        ## Input

        - **node**: Name of a node

        ## Example

        ```
        from brent import DAG
        from brent.common import make_fake_df
        dag = (DAG(dataframe=make_fake_df(4))
                .add_edge("a", "b")
                .add_edge("b", "c")
                .add_edge("c", "d"))
        dag.connections("b") #outputs {"a","c"}
        dag.connections("c") #outputs {"b","d"}
        ```
        """
        return self.children(node) | self.parents(node)

    def independences(self):
        """
        **NOT IMPLEMENTED YET**
        """
        # https://www.slideshare.net/duytungpham18/lecture-1-graphical-models
        pass

    def plot(self):
        """A pretty plotting function. Returns a graphviz `Digraph`."""
        d = Digraph()
        d.attr(rankdir='LR')
        d.attr('node', shape='circle')
        for n in self.graph.nodes:
            d.node(n)
        for n1, n2 in self.graph.edges:
            d.edge(n1, n2)
        return d

    def nx_plot(self, **kwargs):
        """
        A customizable plotting function. The function comes from `networkx`.
        It merely wraps around the `nx.draw` method, documentation of this project
        can be found [here](https://networkx.github.io/documentation/stable/index.html).

        Keyword arguments are forwarded to `nx.draw` and take precedence over
        the defaults used here (`node_size`, `with_labels`, `node_color`).
        """
        # Merge instead of passing both explicitly, so overriding a default
        # does not raise "got multiple values for keyword argument".
        options = {"node_size": 500, "with_labels": True, "node_color": "white"}
        options.update(kwargs)
        nx.draw(self.graph, **options)
Classes
class DAG
-
A Directed-Acyclic-Graph (DAG) object describes a graphical model of a dataset. This object is generated from a pandas dataframe. Every column in the dataframe will result in a node/variable in the DAG object.
from brent import DAG from brent.common import make_fake_df # let's start with a new dataset df = make_fake_df(4) dag = DAG(df).add_edge("a", "b").add_edge("b", "c").add_edge("c","d")
Source code
class DAG:
    """
    A Directed-Acyclic-Graph (DAG) object describes a graphical model of a dataset.

    This object is generated from a pandas dataframe. Every column in the
    dataframe will result in a node/variable in the DAG object.

    ```
    from brent import DAG
    from brent.common import make_fake_df
    # let's start with a new dataset
    df = make_fake_df(4)
    dag = DAG(df).add_edge("a", "b").add_edge("b", "c").add_edge("c","d")
    ```
    """

    def __init__(self, dataframe: pd.DataFrame):
        """
        Create a new DAG from a dataframe.

        Inputs:

        - **dataframe**: pandas object that contains all variables

        Example:

        ```
        from brent import DAG
        from brent.common import make_fake_df
        # let's start with a new dataset
        df = make_fake_df(4)
        dag = DAG(df).add_edge("a", "b").add_edge("b", "c").add_edge("c","d")
        ```
        """
        self.df = dataframe
        self.graph = nx.DiGraph()
        # Every column starts out as an unconnected node.
        for node in self.df.columns:
            self.graph.add_node(node)
        # `cached` is switched on by `.cache()`, which also fills `prob_tables`
        # with one probability table per node.
        self.cached = False
        self.prob_tables = {}

    @property
    def undirected_graph(self):
        """
        Fetch the `undirected` variant of the NetworkX graph. This can be
        useful when trying to determine all paths between two nodes.
        """
        return self.graph.to_undirected()

    @property
    def origin_nodes(self):
        """These nodes are nodes that do not have any edges going in."""
        return tuple(x for x in self.graph.nodes() if self.graph.in_degree(x) == 0)

    @property
    def marginal_table(self):
        """
        The marginal table is a table with all possible values and associated probability.

        It is built by starting from the node table of the last node and merging
        in the node table of every other node via `merge_probs`.
        """
        nodes = list(self.graph.nodes)
        logging.debug(f"about to calculate marginal table with nodes {nodes}")
        logging.debug(f"updating table for node {nodes[-1]}")
        marginal = self.calc_node_table(nodes.pop())
        # Tables are handed to logging as lazy arguments so that they are only
        # rendered to text when debug logging is actually enabled.
        logging.debug("current node table:\n%s", marginal)
        logging.debug("current marginal table:\n%s", marginal)
        for node in nodes:
            logging.debug(f"updating table for node {node}")
            # Compute the node table once and reuse it for logging and merging.
            node_table = self.calc_node_table(node)
            logging.debug("current node table:\n%s", node_table)
            marginal = self.merge_probs(marginal, node_table)
            logging.debug("current marginal table:\n%s", marginal)
        return marginal

    @property
    def nodes(self):
        """The nodes of the graph."""
        return list(self.graph.nodes)

    @property
    def edges(self):
        """The edges of the graph."""
        return list(self.graph.edges)

    def copy(self):
        """
        Returns a copy of the current DAG.

        The copy is built from the same dataframe object, receives a copy of the
        graph and starts out uncached.
        """
        new_dag = DAG(self.df)
        new_dag.graph = self.graph.copy()
        return new_dag

    def edge_direction(self, node_a, node_b):
        """
        Determines the `<-` vs. `->` direction of an edge between two nodes.

        Raises a `ValueError` when neither node is a parent of the other.
        """
        if node_b in self.parents(node_a):
            return "<-"
        if node_a in self.parents(node_b):
            return "->"
        raise ValueError(f"node '{node_a}' is not connected to '{node_b}'")

    def undirected_paths(self, node_a, node_b):
        """
        Returns a list of all the paths that are between `node_a` and `node_b`.
        These paths do not take the direction into account and will turn the
        directed graph into an undirected one.
        """
        return list(nx.all_simple_paths(self.undirected_graph, node_a, node_b))

    def directed_paths(self, node_a, node_b):
        """
        Find all paths between `node_a` and `node_b`. These paths may be
        probabilistically inactive. If you want the active paths call
        `DAG.active_paths` instead.

        Every path is a flat list that alternates nodes and arrows,
        e.g. `["a", "->", "b", "<-", "c"]`.
        """
        output_paths = []
        for path in self.undirected_paths(node_a, node_b):
            new_path = []
            for n1, n2 in window(path):
                new_path.append(n1)
                new_path.append(self.edge_direction(n1, n2))
            # Close the path with its final node. Reading it from `path` instead
            # of the loop variable keeps this safe when `window` yields nothing.
            new_path.append(path[-1])
            output_paths.append(new_path)
            # str() keeps the log line working for non-string node names.
            logging.debug(f"found directed path: {' '.join(map(str, new_path))}")
        return output_paths

    def active_paths(self, node_a, node_b, z=None):
        """
        Returns a list of all the paths that can influence probability between
        `node_a` and `node_b`. You can also supply a collection of nodes `z` that
        are given. These nodes might block (or enable) a path to be active. If there
        are no active paths between two nodes then this means that the two nodes
        are (conditionally, if `z` is given) independent.

        ## Input

        - **node_a**: Name of a node in the graph
        - **node_b**: Name of a node in the graph
        - **z**: optional collection of given nodes; nothing is given by default
        """
        given = () if z is None else z
        active_paths = []
        for path in self.directed_paths(node_a, node_b):
            # Given nodes are marked as "given(<node>)" before the path is
            # handed to `is_path_blocked`.
            z_path = [f"given({_})" if _ in given else _ for _ in path]
            logging.debug(f"now checking {' '.join(map(str, z_path))}")
            if not is_path_blocked(z_path):
                active_paths.append(z_path)
        return active_paths

    def cache(self):
        """
        When calling `.cache()` the API will realise that the graph will no longer change.
        We will also cache all the probability tables that are associated with a node.

        Returns the DAG itself. Raises a `ValueError` if the DAG is already cached.
        """
        if self.cached:
            raise ValueError("cannot call `.cache()` on a DAG that is already cached!")
        for node in self.nodes:
            self.prob_tables[node] = self.calc_node_table(node)
        self.cached = True
        return self

    def calc_node_table(self, name):
        """
        Calculates probability table for a given node.

        Suppose we have graph `a -> b -> c`: a `.calc_node_table("b")`
        call will calculate `P(b|a)` in the `prob` column of the result.

        ## Input

        - **name**: Name of a node/variable in the graph

        ## Output

        A dataframe with one column per parent, a column for the node itself
        and a `prob` column, one row per observed combination of values.
        Raises a `ValueError` if `name` is not a node of the graph.
        """
        if name not in self.nodes:
            raise ValueError(f"node {name} not in available nodes: {self.nodes}")
        if name in self.prob_tables:
            return self.prob_tables[name]
        # Read the parents as a list straight from the graph so that the column
        # order of the result does not depend on set iteration order.
        parents = list(self.graph.predecessors(name))
        logging.debug(f"creating node table node={name} parents={parents}")

        def group_sizes(dataf, groups):
            # Number of rows in the group each row belongs to. Only the helper
            # column is aggregated, so columns that cannot be summed are never touched.
            return dataf.assign(count=1).groupby(groups)["count"].transform("sum")

        family = parents + [name]
        if parents:
            tbl = self.df.assign(parent_size=lambda d: group_sizes(d, parents))
        else:
            # Without parents every row is compared against the whole dataset.
            tbl = self.df.assign(parent_size=lambda d: d.shape[0])
        return (tbl
                .assign(node_size=lambda d: group_sizes(d, family))
                .assign(prob=lambda d: d['node_size'] / d['parent_size'])
                [family + ["prob"]]
                .drop_duplicates()
                .reset_index(drop=True))

    def merge_probs(self, this_df, that_df):
        """
        Merges two probability dataframes while checking if nodes
        are connected in the graph. If `this_df` denotes `p(C|A,B)`
        in table form and `that_df` denotes `p(B|A)` in table form
        then the output of this function will denote `p(C,B|A)` in
        table form. Note that this method is mainly meant for internal usage.

        ## Input

        - **this_df**: probability table (dataframe with a `prob` column), say `p(C|A,B)`
        - **that_df**: probability table (dataframe with a `prob` column), say `p(B|A)`

        ## Output

        A dataframe that merges the two former tables, say `p(C,B|A)`.
        """
        common_cols = list(set(this_df.columns)
                           .intersection(set(that_df.columns))
                           .difference({"prob"}))
        if len(common_cols) == 0:
            # No shared variable: extend `this_df` with the variables of
            # `that_df` (one copy per value found in the node table) and retry.
            # NOTE(review): when `that_df` has several non-`prob` columns each
            # copy only receives one of them - confirm this path is intended.
            columns = set(that_df.columns).difference({"prob"})
            loose_tables = []
            for c in columns:
                for value in self.calc_node_table(c)[c].values:
                    loose_tables.append(this_df.assign(**{c: value}))
            join_able = pd.concat(loose_tables)
            return self.merge_probs(join_able, that_df)
        # Join on the shared variables; the product of both `prob` columns is
        # passed through `normalise` to form the new `prob` column.
        return (this_df
                .set_index(common_cols)
                .join(that_df.set_index(common_cols), lsuffix="1", rsuffix="2")
                .assign(prob=lambda x: normalise(x.prob1 * x.prob2))
                .drop("prob1", axis=1)
                .drop("prob2", axis=1)
                .reset_index())

    def add_edge(self, source, sink) -> 'DAG':
        """
        Adds an edge to the graph.

        ## Input

        - **source**: Name of a node in the graph
        - **sink**: Name of a node in the graph

        ## Output

        The DAG itself, so calls can be chained. Raises a `ValueError` when a
        node is not a dataframe column or when the edge would create a cycle,
        and a `RuntimeError` when the DAG has been cached.

        ## Example

        ```
        from brent import DAG
        from brent.common import make_fake_df
        (DAG(dataframe=make_fake_df(4))
            .add_edge("a", "b")
            .add_edge("b", "c")
            .add_edge("c", "d"))
        ```
        """
        if source not in self.df.columns:
            raise ValueError(f"cause {source} not in dataframe")
        if sink not in self.df.columns:
            raise ValueError(f"effect {sink} not in dataframe")
        if self.cached:
            raise RuntimeError("Cannot change a graph when the dag is baked.")
        # Try the edge on a copy first so a rejected edge leaves the graph untouched.
        new_graph = self.graph.copy()
        new_graph.add_edge(source, sink)
        if not nx.is_directed_acyclic_graph(new_graph):
            raise ValueError(f"edge {source} -> {sink} causes DAG to get cycle")
        self.graph.add_edge(source, sink)
        logging.debug(f"created connection {source} -> {sink}")
        return self

    def children(self, node):
        """
        Return the children of a node as a set.

        ## Input

        - **node**: Name of a node

        ## Example

        ```
        from brent import DAG
        from brent.common import make_fake_df
        dag = (DAG(dataframe=make_fake_df(4))
                .add_edge("a", "b")
                .add_edge("b", "c")
                .add_edge("c", "d"))
        dag.children("b") #outputs {"c"}
        dag.children("c") #outputs {"d"}
        ```
        """
        return set(self.graph.successors(node))

    def parents(self, node):
        """
        Return the parents of a node as a set.

        ## Input

        - **node**: Name of a node

        ## Example

        ```
        from brent import DAG
        from brent.common import make_fake_df
        dag = (DAG(dataframe=make_fake_df(4))
                .add_edge("a", "b")
                .add_edge("b", "c")
                .add_edge("c", "d"))
        dag.parents("b") #outputs {"a"}
        dag.parents("c") #outputs {"b"}
        ```
        """
        return set(self.graph.predecessors(node))

    def connections(self, node):
        """
        Return all nodes connected to the one passed in, as a set.

        ## Input

        - **node**: Name of a node

        ## Example

        ```
        from brent import DAG
        from brent.common import make_fake_df
        dag = (DAG(dataframe=make_fake_df(4))
                .add_edge("a", "b")
                .add_edge("b", "c")
                .add_edge("c", "d"))
        dag.connections("b") #outputs {"a","c"}
        dag.connections("c") #outputs {"b","d"}
        ```
        """
        return self.children(node) | self.parents(node)

    def independences(self):
        """
        **NOT IMPLEMENTED YET**
        """
        # https://www.slideshare.net/duytungpham18/lecture-1-graphical-models
        pass

    def plot(self):
        """A pretty plotting function. Returns a graphviz `Digraph`."""
        d = Digraph()
        d.attr(rankdir='LR')
        d.attr('node', shape='circle')
        for n in self.graph.nodes:
            d.node(n)
        for n1, n2 in self.graph.edges:
            d.edge(n1, n2)
        return d

    def nx_plot(self, **kwargs):
        """
        A customizable plotting function. The function comes from `networkx`.
        It merely wraps around the `nx.draw` method, documentation of this project
        can be found [here](https://networkx.github.io/documentation/stable/index.html).

        Keyword arguments are forwarded to `nx.draw` and take precedence over
        the defaults used here (`node_size`, `with_labels`, `node_color`).
        """
        # Merge instead of passing both explicitly, so overriding a default
        # does not raise "got multiple values for keyword argument".
        options = {"node_size": 500, "with_labels": True, "node_color": "white"}
        options.update(kwargs)
        nx.draw(self.graph, **options)
Instance variables
var edges
-
The edges of the graph.
Source code
@property def edges(self): """The edges of the graph.""" return list(self.graph.edges)
var marginal_table
-
The marginal table is a table with all possible values and associated probability.
Source code
@property def marginal_table(self): """ The marginal table is a table with all possible values and associated probability. """ nodes = list(self.graph.nodes) logging.debug(f"about to calculate marginal table with nodes {nodes}") logging.debug(f"updating table for node {nodes[-1]}") marginal = self.calc_node_table(nodes.pop()) logging.debug(f"current node table:\n{marginal}") logging.debug(f"current marginal table:\n{marginal}") for node in nodes: logging.debug(f"updating table for node {node}") logging.debug(f"current node table:\n{self.calc_node_table(node)}") marginal = self.merge_probs(marginal, self.calc_node_table(node)) logging.debug(f"current marginal table:\n{marginal}") return marginal
var nodes
-
The nodes of the graph.
Source code
@property def nodes(self): """The nodes of the graph.""" return list(self.graph.nodes)
var origin_nodes
-
These nodes are nodes that do not have any edges going in.
Source code
@property def origin_nodes(self): """These nodes are nodes that do not have any edges going in.""" return tuple(x for x in self.graph.nodes() if self.graph.in_degree(x) == 0)
var undirected_graph
-
Fetch the `undirected` variant of the NetworkX graph. This can be useful when trying to determine all paths between two nodes.
Source code
@property def undirected_graph(self): """ Fetch the `undirected` variant of the NetworkX graph. This can be useful when trying to determine all paths between two nodes. """ return self.graph.to_undirected()
Methods
def __init__(self, dataframe)
-
Create a new DAG from a dataframe.
Inputs:
- dataframe: pandas object that contains all variables
Example:
from brent import DAG from brent.common import make_fake_df # let's start with a new dataset df = make_fake_df(4) dag = DAG(df).add_edge("a", "b").add_edge("b", "c").add_edge("c","d")
Source code
def __init__(self, dataframe: pd.DataFrame):
    """
    Build a DAG whose nodes are the columns of `dataframe`.

    Inputs:

    - **dataframe**: pandas object that contains all variables

    Example:

    ```
    from brent import DAG
    from brent.common import make_fake_df
    # let's start with a new dataset
    df = make_fake_df(4)
    dag = DAG(df).add_edge("a", "b").add_edge("b", "c").add_edge("c","d")
    ```
    """
    self.df = dataframe
    graph = nx.DiGraph()
    # One unconnected node per dataframe column.
    for column in dataframe.columns:
        graph.add_node(column)
    self.graph = graph
    # Nothing is cached until `.cache()` is called.
    self.cached = False
    self.prob_tables = {}
def active_paths(self, node_a, node_b, z=[])
-
Returns a list of all the paths that can influence probability between `node_a` and `node_b`. You can also supply a collection of nodes `z` that are given. These nodes might block (or enable) a path to be active. If there are no active paths between two nodes then this means that the two nodes are (conditionally, if `z` is given) independent.
Source code
def active_paths(self, node_a, node_b, z=None):
    """
    Returns a list of all the paths that can influence probability between
    `node_a` and `node_b`. You can also supply a collection of nodes `z` that
    are given. These nodes might block (or enable) a path to be active. If there
    are no active paths between two nodes then this means that the two nodes
    are (conditionally, if `z` is given) independent.

    ## Input

    - **node_a**: Name of a node in the graph
    - **node_b**: Name of a node in the graph
    - **z**: optional collection of given nodes; nothing is given by default
    """
    given = () if z is None else z
    active_paths = []
    for path in self.directed_paths(node_a, node_b):
        # Given nodes are marked as "given(<node>)" before the path is
        # handed to `is_path_blocked`.
        z_path = [f"given({_})" if _ in given else _ for _ in path]
        # str() keeps the log line working for non-string node names.
        logging.debug(f"now checking {' '.join(map(str, z_path))}")
        if not is_path_blocked(z_path):
            active_paths.append(z_path)
    return active_paths
def add_edge(self, source, sink)
-
Adds an edge to the graph.
Input
- source: Name of a node in the graph
- sink: Name of a node in the graph
Example
from brent import DAG from brent.common import make_fake_df (DAG(dataframe=make_fake_df(4)) .add_edge("a", "b") .add_edge("b", "c") .add_edge("c", "d"))
Source code
def add_edge(self, source, sink) -> 'DAG': """ Adds an edge to the graph. ## Input - **source**: Name of a node in the graph - **sink**: Name of a node in the graph ## Example ``` from brent import DAG from brent.common import make_fake_df (DAG(dataframe=make_fake_df(4)) .add_edge("a", "b") .add_edge("b", "c") .add_edge("c", "d")) ``` """ if source not in self.df.columns: raise ValueError(f"cause {source} not in dataframe") if sink not in self.df.columns: raise ValueError(f"effect {sink} not in dataframe") if self.cached: raise RuntimeError("Cannot change a graph when the dag is baked.") new_graph = self.graph.copy() new_graph.add_edge(source, sink) if not nx.is_directed_acyclic_graph(new_graph): raise ValueError(f"edge {source} -> {sink} causes DAG to get cycle") self.graph.add_edge(source, sink) logging.debug(f"created connection {source} -> {sink}") return self
def cache(self)
-
When calling `.cache()` the API will realise that the graph will no longer change. We will also cache all the probability tables that are associated with a node. Returns the DAG itself.
Source code
def cache(self): """ When calling `.cache()` the API will realise that the graph will no longer change. We will also cache all the probability tables that are associated with a node. :return: """ if self.cached: raise ValueError("cannot call `.cache()` on a DAG that is already cached!") for node in self.nodes: self.prob_tables[node] = self.calc_node_table(node) self.cached = True return self
def calc_node_table(self, name)
-
Calculates probability table for a given node.
Suppose we have graph `A -> B -> C`: a `.calc_node_table("B")` call will calculate `P(B|A)` in the `prob` column of the result.
Input
- name: Name of a node/variable in the graph
Source code
def calc_node_table(self, name): """ Calculates probability table for a given node. Suppose we have graph `A -> B -> C`: `.calc_node_table("b")` call will calculate `P(B|A)` in the `probs` column of the result. ## Input - **name**: Name of a node/variable in the graph """ if name not in self.nodes: raise ValueError(f"node {name} not in available nodes: {self.nodes}") if name in self.prob_tables: return self.prob_tables[name] parents = self.parents(name) tbl = self.df.copy() logging.debug(f"creating node table node={name} parents={parents}") def calculate_parents_size(dataf, groups=[]): return (dataf .assign(count=1) .groupby(groups) .transform(np.sum)['count']) if len(parents) == 0: tbl = tbl.assign(parent_size=lambda d: d.shape[0]) else: tbl = tbl.assign(parent_size=lambda d: calculate_parents_size(d, list(parents))) return (tbl .assign(node_size=lambda d: calculate_parents_size(d, list(parents) + [name])) .assign(prob=lambda d: d['node_size'] / d['parent_size']) [list(parents) + [name] + ["prob"]] .drop_duplicates() .reset_index(drop=True))
def children(self, node)
-
Return the children of a node.
Input
- node: Name of a node
Example
from brent import DAG from brent.common import make_fake_df dag = (DAG(dataframe=make_fake_df(4)) .add_edge("a", "b") .add_edge("b", "c") .add_edge("c", "d")) dag.children("b") #outputs "c" dag.children("c") #outputs "d"
Source code
def children(self, node): """ Return the children of a node. ## Input - **node**: Name of a node ## Example ``` from brent import DAG from brent.common import make_fake_df dag = (DAG(dataframe=make_fake_df(4)) .add_edge("a", "b") .add_edge("b", "c") .add_edge("c", "d")) dag.children("b") #outputs "c" dag.children("c") #outputs "d" ``` """ return set(self.graph.successors(node))
def connections(self, node)
-
Return all nodes connected to the one passed in.
Input
- node: Name of a node
Example
from brent import DAG
from brent.common import make_fake_df
dag = (DAG(dataframe=make_fake_df(4))
        .add_edge("a", "b")
        .add_edge("b", "c")
        .add_edge("c", "d"))
dag.connections("b") #outputs {"a","c"}
dag.connections("c") #outputs {"b","d"}
Source code
def connections(self, node): """ Return all nodes connected to the one passed in. ## Input - **node**: Name of a node ## Example ``` from brent import DAG from brent.common import make_fake_df dag = (DAG(dataframe=make_fake_df(4)) .add_edge("a", "b") .add_edge("b", "c") .add_edge("c", "d")) dag.children("b") #outputs ["a","c"] dag.children("c") #outputs ["b","d"] ``` """ return set(list(self.children(node)) + list(self.parents(node)))
def copy(self)
-
Returns a copy of the current DAG.
Source code
def copy(self):
    """
    Return a new DAG over the same dataframe with a copy of the graph.

    The new DAG comes straight from the constructor, so it starts out
    uncached with empty probability tables.
    """
    duplicate = DAG(self.df)
    graph_copy = self.graph.copy()
    duplicate.graph = graph_copy
    return duplicate
def directed_paths(self, node_a, node_b)
-
Find all paths between `node_a` and `node_b`. These paths may be probabilistically inactive. If you want the active paths call `DAG.active_paths()` instead.
Source code
def directed_paths(self, node_a, node_b):
    """
    Find all paths between `node_a` and `node_b`. These paths may be
    probabilistically inactive. If you want the active paths call
    `DAG.active_paths` instead.

    Every path is a flat list that alternates nodes and arrows,
    e.g. `["a", "->", "b", "<-", "c"]`.
    """
    output_paths = []
    for path in self.undirected_paths(node_a, node_b):
        new_path = []
        for n1, n2 in window(path):
            new_path.append(n1)
            new_path.append(self.edge_direction(n1, n2))
        # Close the path with its final node. Reading it from `path` instead
        # of the loop variable keeps this safe when `window` yields nothing.
        new_path.append(path[-1])
        output_paths.append(new_path)
        # str() keeps the log line working for non-string node names.
        logging.debug(f"found directed path: {' '.join(map(str, new_path))}")
    return output_paths
def edge_direction(self, node_a, node_b)
-
Determines the `<-` vs. `->` direction of an edge between two nodes.
Source code
def edge_direction(self, node_a, node_b): """Determines the `<-` vs. `->` direction of an edge between two nodes.""" if node_b in self.parents(node_a): return "<-" if node_a in self.parents(node_b): return "->" raise ValueError(f"node '{node_a}' is not connected to '{node_b}'")
def independences(self)
-
NOT IMPLEMENTED YET
Source code
def independences(self):
    """
    **NOT IMPLEMENTED YET**

    Placeholder that currently does nothing and returns `None`.
    """
    # Background material for a future implementation:
    # https://www.slideshare.net/duytungpham18/lecture-1-graphical-models
    return None
def merge_probs(self, this_df, that_df)
-
Merges two probability dataframes while checking if nodes are connected in the graph. If `this_df` denotes `p(C|A,B)` in table form and `that_df` denotes `p(B|A)` in table form then the output of this function will denote `p(C,B|A)` in table form. Note that this method is mainly meant for internal usage.
Input
- this_df: Probability table (dataframe) of a node/variable in the graph, say `p(C|A,B)`
- that_df: Probability table (dataframe) of a node/variable in the graph, say `p(B|A)`
Output
A dataframe that merges the two former tables, say `p(C,B|A)`.
Source code
def merge_probs(self, this_df, that_df):
    """
    Merges two probability dataframes while checking if nodes
    are connected in the graph. If `this_df` denotes `p(C|A,B)`
    in table form and `that_df` denotes `p(B|A)` in table form
    then the output of this function will denote `p(C,B|A)` in
    table form. Note that this method is mainly meant for internal usage.

    ## Input

    - **this_df**: probability table (dataframe with a `prob` column), say `p(C|A,B)`
    - **that_df**: probability table (dataframe with a `prob` column), say `p(B|A)`

    ## Output

    A dataframe that merges the two former tables, say `p(C,B|A)`.
    """
    # Variables (every column except `prob`) that both tables have in common.
    common_cols = list(set(this_df.columns)
                       .intersection(set(that_df.columns))
                       .difference({"prob"}))
    if len(common_cols) == 0:
        # No shared variable: extend `this_df` with the variables of
        # `that_df` (one copy per value found in the node table) and retry.
        # NOTE(review): when `that_df` has several non-`prob` columns each
        # copy only receives one of them - confirm this path is intended.
        columns = set(that_df.columns).difference({"prob"})
        loose_tables = []
        for c in columns:
            for value in self.calc_node_table(c)[c].values:
                loose_tables.append(this_df.assign(**{c: value}))
        join_able = pd.concat(loose_tables)
        return self.merge_probs(join_able, that_df)
    # Join on the shared variables; the suffixes turn the two `prob` columns
    # into `prob1`/`prob2`, whose product is passed through `normalise`.
    return (this_df
            .set_index(common_cols)
            .join(that_df.set_index(common_cols), lsuffix="1", rsuffix="2")
            .assign(prob=lambda x: normalise(x.prob1 * x.prob2))
            .drop("prob1", axis=1)
            .drop("prob2", axis=1)
            .reset_index())
- this_df: Name of a node/variable in the graph, say
def nx_plot(self, **kwargs)
-
A customizable plotting function. The function comes from `networkx`. It merely wraps around the `nx.draw` method; documentation of this project can be found [here](https://networkx.github.io/documentation/stable/index.html).
Source code
def nx_plot(self, **kwargs):
    """
    A customizable plotting function. The function comes from `networkx`.
    It merely wraps around the `nx.draw` method, documentation of this project
    can be found [here](https://networkx.github.io/documentation/stable/index.html).

    Keyword arguments are forwarded to `nx.draw` and take precedence over
    the defaults used here (`node_size`, `with_labels`, `node_color`).
    """
    # Merge instead of passing both explicitly, so overriding a default
    # does not raise "got multiple values for keyword argument".
    options = {"node_size": 500, "with_labels": True, "node_color": "white"}
    options.update(kwargs)
    nx.draw(self.graph, **options)
def parents(self, node)
-
Return the parents of a node.
Input
- node: Name of a node
Example
from brent import DAG
from brent.common import make_fake_df
dag = (DAG(dataframe=make_fake_df(4))
        .add_edge("a", "b")
        .add_edge("b", "c")
        .add_edge("c", "d"))
dag.parents("b") #outputs {"a"}
dag.parents("c") #outputs {"b"}
Source code
def parents(self, node): """ Return the parents of a node. ## Input - **node**: Name of a node ## Example ``` from brent import DAG from brent.common import make_fake_df dag = (DAG(dataframe=make_fake_df(4)) .add_edge("a", "b") .add_edge("b", "c") .add_edge("c", "d")) dag.children("b") #outputs "a" dag.children("c") #outputs "b" ``` """ return set(self.graph.predecessors(node))
def plot(self)
-
A pretty plotting function.
Source code
def plot(self):
    """Render the graph left-to-right with circular nodes and return the `Digraph`."""
    drawing = Digraph()
    drawing.attr(rankdir='LR')
    drawing.attr('node', shape='circle')
    for node_name in self.graph.nodes:
        drawing.node(node_name)
    for edge in self.graph.edges:
        tail, head = edge
        drawing.edge(tail, head)
    return drawing
def undirected_paths(self, node_a, node_b)
-
Returns a list of all the paths that are between `node_a` and `node_b`. These paths do not take the direction into account and will turn the directed graph into an undirected one.
Source code
def undirected_paths(self, node_a, node_b):
    """
    List every simple path between `node_a` and `node_b`.

    Edge direction is ignored: the search runs on the undirected variant
    of the graph.
    """
    simple_paths = nx.all_simple_paths(self.undirected_graph, node_a, node_b)
    return [*simple_paths]