import numbers
import logging
from copy import deepcopy
from collections import Counter
import scipy.stats
import kappy
import itertools
import numpy as np
import networkx as nx
from pysb import WILD, export, Observable, ComponentSet, Annotation
from pysb.core import as_complex_pattern, ComponentDuplicateNameError
from pysb.pattern import RulePatternMatcher
from indra.explanation.reporting import stmt_from_rule, agent_from_obs
from indra.statements import *
from indra.assemblers.pysb import assembler as pa
from indra.assemblers.pysb.kappa_util import im_json_to_graph
from indra.statements.agent import default_ns_order
from indra.ontology.bio import bio_ontology
from . import ModelChecker, PathResult, NodesContainer
from .model_checker import signed_edges_to_signed_nodes
logger = logging.getLogger(__name__)
try:
import paths_graph as pg
has_pg = True
except ImportError:
pg = None
has_pg = False
logger.warning('PathsGraph is not available')
[docs]class PysbModelChecker(ModelChecker):
"""Check a PySB model against a set of INDRA statements.
Parameters
----------
model : pysb.Model
A PySB model to check.
statements : Optional[list[indra.statements.Statement]]
A list of INDRA Statements to check the model against.
agent_obs: Optional[list[indra.statements.Agent]]
A list of INDRA Agents in a given state to be observed.
do_sampling : bool
Whether to use breadth-first search or weighted sampling to
generate paths. Default is False (breadth-first search).
seed : int
Random seed for sampling (optional, default is None).
model_stmts : list[indra.statements.Statement]
A list of INDRA statements used to assemble PySB model.
nodes_to_agents : dict
A dictionary mapping nodes of intermediate signed edges graph to INDRA
agents.
Attributes
----------
graph : nx.DiGraph
A DiGraph with signed nodes to find paths in.
"""
def __init__(self, model, statements=None, agent_obs=None,
             do_sampling=False, seed=None, model_stmts=None,
             nodes_to_agents=None):
    """Initialize the checker and the bookkeeping used by the influence map.

    The first five positional values handed to the parent class are
    ``model``, ``statements``, ``do_sampling``, ``seed`` and
    ``nodes_to_agents``; everything else is stored on this instance.
    """
    super().__init__(model, statements, do_sampling, seed, nodes_to_agents)
    # Agents for which observables are requested (empty list if none given)
    self.agent_obs = agent_obs if agent_obs else []
    # Grounded agents recovered from the model by the PySB assembler
    self.mps_to_agents, self.rules_to_mps = pa.get_grounded_agents(model)
    self.model_agents = self.get_model_agents()
    # Statements used to assemble the model (empty list if none given)
    self.model_stmts = model_stmts or []
    # Influence map; filled in lazily by get_im
    self._im = None
    # Lookup tables populated while the influence map is generated:
    # statements -> observables
    self.stmt_to_obs = {}
    # agents -> observables
    self.agent_to_obs = {}
    # rules -> downstream observables
    self.rule_obs_dict = {}
    # observables -> agents
    self.obs_to_agents = {}
def generate_im(self, model):
    """Return a graph representing the influence map generated by Kappa

    Parameters
    ----------
    model : pysb.Model
        The PySB model whose influence map is to be generated

    Returns
    -------
    graph : networkx.MultiDiGraph
        A MultiDiGraph representing the influence map
    """
    kappa_client = kappy.KappaStd()
    # Export the model to Kappa text and load it into the Kappa client
    kappa_client.add_model_string(export.export(model, 'kappa'))
    kappa_client.project_parse()
    # Run the influence map analysis and convert its JSON output to a graph
    return im_json_to_graph(
        kappa_client.analyses_influence_map(accuracy='medium'))
def draw_im(self, fname):
    """Draw and save the influence map in a file.

    Parameters
    ----------
    fname : str
        The name of the file to save the influence map in.
        The extension of the file will determine the file format,
        typically png or pdf.
    """
    # Convert the (possibly freshly generated) influence map to an AGraph
    # and lay it out with the dot program
    agraph = nx.nx_agraph.to_agraph(self.get_im())
    agraph.draw(fname, prog='dot')
def get_im(self, force_update=False):
    """Get the influence map for the model, generating it if necessary.

    Generating the map replaces the model's observables with observables
    derived from the statements to check and from ``self.agent_obs``, and
    repopulates ``stmt_to_obs``, ``agent_to_obs``, ``obs_to_agents`` and
    ``rule_obs_dict``.

    Parameters
    ----------
    force_update : bool
        Whether to generate the influence map when the function is called.
        If False, returns the previously generated influence map if
        available. Defaults to False.

    Returns
    -------
    networkx MultiDiGraph object containing the influence map.
        The influence map can be rendered as a pdf using the dot layout
        program as follows::

            im_agraph = nx.nx_agraph.to_agraph(influence_map)
            im_agraph.draw('influence_map.pdf', prog='dot')

    Raises
    ------
    Exception
        If the checker has no model.
    """
    # Compare against None instead of relying on truthiness so that a
    # previously generated but empty influence map is still reused
    if self._im is not None and not force_update:
        return self._im
    if not self.model:
        raise Exception("Cannot get influence map if there is no model.")

    def add_obs_for_agents(main_agent, ref_agents=None):
        """Add observables for an agent and its refinements to the model.

        Returns a NodesContainer holding the observable names, or None if
        no monomer pattern was found for any of the agents.
        """
        if ref_agents:
            all_agents = [main_agent] + ref_agents
        else:
            all_agents = [main_agent]
        ag_to_obj_mps = self.get_all_mps(all_agents, mapping=True)
        if not any(ag_to_obj_mps.values()):
            logger.debug('No monomer patterns found in model for agents %s'
                         ', skipping' % all_agents)
            return None
        obs_nodes = NodesContainer(main_agent, ref_agents)
        main_obs_set = set()
        ref_obs_set = set()
        for agent, obj_mps in ag_to_obj_mps.items():
            for obj_mp in obj_mps:
                obs_name = _monomer_pattern_label(obj_mp) + '_obs'
                self.obs_to_agents[obs_name] = agent
                # Add the observable
                obj_obs = Observable(obs_name, obj_mp, _export=False)
                if agent.matches(main_agent):
                    main_obs_set.add(obs_name)
                else:
                    ref_obs_set.add(obs_name)
                try:
                    self.model.add_component(obj_obs)
                    self.model.add_annotation(
                        Annotation(obs_name, agent.name,
                                   'from_indra_agent'))
                except ComponentDuplicateNameError:
                    # The same observable may be requested by several
                    # statements; an already registered one is kept as is
                    pass
        obs_nodes.main_interm = main_obs_set
        obs_nodes.ref_interm = ref_obs_set
        return obs_nodes

    # Create observables for all statements to check, and add to model
    # Remove any existing observables in the model
    self.model.observables = ComponentSet([])
    for stmt in self.statements:
        # Generate observables for Modification statements
        if isinstance(stmt, (Modification, SelfModification)):
            # If the statement is a regular Mod, the target is stmt.sub;
            # if it's a SelfMod, the target is stmt.enz
            if isinstance(stmt, Modification):
                sub = stmt.sub
            else:
                sub = stmt.enz
            # Add the mod for the agent
            if sub is None:
                self.stmt_to_obs[stmt] = NodesContainer(None)
            else:
                mod_condition_name = modclass_to_modtype[stmt.__class__]
                if isinstance(stmt, RemoveModification):
                    mod_condition_name = modtype_to_inverse[
                        mod_condition_name]
                # Add modification to substrate agent
                modified_sub = _add_modification_to_agent(
                    sub, mod_condition_name, stmt.residue,
                    stmt.position)
                # Get all refinements of substrate agent
                ref_subs = self.get_refinements(modified_sub)
                # Associate this statement with this observable
                self.stmt_to_obs[stmt] = add_obs_for_agents(
                    modified_sub, ref_subs)
        # Generate observables for Activation/Inhibition statements
        elif isinstance(stmt, RegulateActivity):
            if stmt.obj is None:
                self.stmt_to_obs[stmt] = NodesContainer(None)
            else:
                # Add activity to object agent
                regulated_obj = _add_activity_to_agent(
                    stmt.obj, stmt.obj_activity, stmt.is_activation)
                # Get all refinements of the *regulated* object so the
                # reference agents are consistent with the main agent
                # (mirrors the Modification branch above)
                ref_objs = self.get_refinements(regulated_obj)
                # Associate this statement with this observable
                self.stmt_to_obs[stmt] = add_obs_for_agents(
                    regulated_obj, ref_objs)
        elif isinstance(stmt, RegulateAmount):
            if stmt.obj is None:
                self.stmt_to_obs[stmt] = NodesContainer(None)
            else:
                # Get all refinements of object agent
                ref_objs = self.get_refinements(stmt.obj)
                self.stmt_to_obs[stmt] = add_obs_for_agents(
                    stmt.obj, ref_objs)
        elif isinstance(stmt, Influence):
            if stmt.obj is None:
                self.stmt_to_obs[stmt] = NodesContainer(None)
            else:
                # Get all refinements of object agent
                ref_objs = self.get_refinements(stmt.obj)
                concepts = [obj.concept for obj in ref_objs]
                self.stmt_to_obs[stmt] = add_obs_for_agents(
                    stmt.obj.concept, concepts)
    # Add observables for each agent
    for ag in self.agent_obs:
        self.agent_to_obs[ag] = add_obs_for_agents(ag)
    logger.info("Generating influence map")
    self._im = self.generate_im(self.model)
    # Now, for every rule in the model, collect the observables that are
    # directly downstream of it along with the sign of the connecting edge.
    # The node_type attribute tells us which nodes are observables.
    node_attributes = nx.get_node_attributes(self._im, 'node_type')
    for rule in self.model.rules:
        obs_list = []
        # Get successors of the rule node
        for neighb in self._im.neighbors(rule.name):
            # Skip anything that is not an observable
            if node_attributes[neighb] != 'variable':
                continue
            # Get the edge and check the polarity
            edge_sign = _get_edge_sign(self._im, (rule.name, neighb))
            obs_list.append((neighb, edge_sign))
        self.rule_obs_dict[rule.name] = obs_list
    return self._im
def get_graph(self, prune_im=True, prune_im_degrade=True,
              prune_im_subj_obj=False, add_namespaces=False,
              edge_filter_func=None):
    """Get influence map and convert it to a graph with signed nodes.

    Parameters
    ----------
    prune_im : bool
        If True, call ``prune_influence_map`` on the influence map.
    prune_im_degrade : bool
        If True, call ``prune_influence_map_degrade_bind_positive`` with
        the model statements.
    prune_im_subj_obj : bool
        If True, call ``prune_influence_map_subj_obj``.
    add_namespaces : bool
        Passed on to ``get_nodes_to_agents``.
    edge_filter_func :
        Not currently used in PySB.

    Returns
    -------
    The signed-node graph, which is also stored in ``self.graph`` and
    reused on later calls.
    """
    if not self.graph:
        # Always regenerate the influence map before building the graph
        influence_map = self.get_im(force_update=True)
        if prune_im:
            self.prune_influence_map()
        if prune_im_degrade:
            self.prune_influence_map_degrade_bind_positive(
                self.model_stmts)
        if prune_im_subj_obj:
            self.prune_influence_map_subj_obj()
        self.get_nodes_to_agents(add_namespaces=add_namespaces)
        self.graph = signed_edges_to_signed_nodes(
            influence_map, prune_nodes=False,
            edge_signs={'pos': 1, 'neg': -1})
    return self.graph
def get_nodes_to_agents(self, add_namespaces=False):
    """Return a dictionary mapping influence map nodes to INDRA agents.

    The mapping is computed once and stored in ``self.nodes_to_agents``;
    later calls return the stored dictionary.

    Parameters
    ----------
    add_namespaces : bool
        Whether to propagate namespaces to node data. Default: False.

    Returns
    -------
    nodes_to_agents : dict
        A dictionary mapping influence map nodes to INDRA agents.
    """
    if self.nodes_to_agents:
        return self.nodes_to_agents
    logger.info('Mapping nodes to agents')
    # Fetch the influence map first: generating it is what fills
    # self.obs_to_agents, which is merged into the mapping below
    im = self.get_im()
    nodes_to_agents = {}
    # First map rules to their subject agents
    for rule, mps in self.rules_to_mps.items():
        for mp in mps:
            for ann in self.model.annotations:
                if ann.subject == rule and \
                        ann.object == mp.monomer.name and \
                        ann.predicate == 'rule_has_subject':
                    nodes_to_agents[rule] = self.mps_to_agents[mp]
    # Add observables to agents stored earlier
    nodes_to_agents.update(self.obs_to_agents)
    # Optionally propagate namespaces to node data
    if add_namespaces:
        logger.info('Adding namespaces to influence map nodes')
        for node, data in im.nodes(data=True):
            agent = nodes_to_agents.get(node)
            if agent:
                ns, _ = agent.get_grounding()
                data['ns'] = ns
    self.nodes_to_agents = nodes_to_agents
    # Return the mapping on the first call too, as documented (previously
    # only the cached path returned it)
    return self.nodes_to_agents
def process_statement(self, stmt):
    """Translate a Statement into subject and object nodes to search between.

    Parameters
    ----------
    stmt : indra.statements.Statement
        The statement to check.

    Returns
    -------
    tuple
        ``(subj_nodes, obs_nodes, result_code)``. When the statement cannot
        be processed, the first two entries are None and ``result_code``
        is a string naming the reason; otherwise ``result_code`` is None.
    """
    self.get_im()
    # Check if this is one of the statement types that we can check
    handled_types = (Modification, RegulateAmount, RegulateActivity,
                     Influence)
    if not isinstance(stmt, handled_types):
        logger.info('Statement type %s not handled' %
                    stmt.__class__.__name__)
        return (None, None, 'STATEMENT_TYPE_NOT_HANDLED')
    # Get the polarity for the statement (0: positive, 1: negative)
    if isinstance(stmt, Modification):
        is_negative = isinstance(stmt, RemoveModification)
    elif isinstance(stmt, RegulateActivity):
        is_negative = not stmt.is_activation
    elif isinstance(stmt, RegulateAmount):
        is_negative = isinstance(stmt, DecreaseAmount)
    else:
        # Only Influence is left after the type check above
        is_negative = stmt.overall_polarity() == -1
    target_polarity = 1 if is_negative else 0
    # Get the subject and object (works also for Modifications)
    subj, obj = stmt.agent_list()
    # Get a list of monomer patterns matching the subject FIXME Currently
    # this will match rules with the corresponding monomer pattern on it.
    # In future, this statement should (possibly) also match rules in which
    # 1) the agent is in its active form, or 2) the agent is tagged as the
    # enzyme in a rule of the appropriate activity (e.g., a phosphorylation
    # rule) FIXME
    if subj is None:
        subj_nodes = NodesContainer(None)
        subj_nodes.all_nodes = None
    else:
        ref_agents = self.get_refinements(subj)
        subj_mps = self.get_all_mps([subj], ignore_activities=True)
        subj_ref_mps = self.get_all_mps(ref_agents, ignore_activities=True)
        if not (subj_mps or subj_ref_mps):
            return (None, None, 'SUBJECT_MONOMERS_NOT_FOUND')
        subj_nodes = NodesContainer(subj, ref_agents)
        # Remember the most recent failure code in case no subject
        # produces any input node at all
        meaningful_res_code = None
        # Each subject might produce a different input set and we need to
        # combine them
        for subj_mp in subj_mps:
            inp, res_code = self.process_subject(subj_mp)
            if res_code:
                meaningful_res_code = res_code
            else:
                subj_nodes.main_nodes += inp
        for subj_mp in subj_ref_mps:
            inp, res_code = self.process_subject(subj_mp)
            if res_code:
                meaningful_res_code = res_code
            else:
                subj_nodes.ref_nodes += inp
        subj_nodes.get_all_nodes()
        if not subj_nodes.all_nodes and meaningful_res_code:
            return (None, None, meaningful_res_code)
    # Observables may not be found for an activation since there may be no
    # rule in the model activating the object, and the object may not have
    # an "active" site of the appropriate type
    obs_nodes = self.stmt_to_obs[stmt]
    if obs_nodes is None:
        logger.info("No observables for stmt %s, returning False" % stmt)
        return (None, None, 'OBSERVABLES_NOT_FOUND')
    if obs_nodes.main_agent is None:
        # Statement object is None; modifications cannot be checked then
        if isinstance(stmt, Modification):
            return (None, None, 'STATEMENT_TYPE_NOT_HANDLED')
        obs_nodes.all_nodes = None
    else:
        # Pair every observable with the polarity we are looking for
        obs_nodes.main_nodes = [
            (obs, target_polarity) for obs in obs_nodes.main_interm]
        obs_nodes.ref_nodes = [
            (obs, target_polarity) for obs in obs_nodes.ref_interm]
        obs_nodes.get_all_nodes()
    return subj_nodes, obs_nodes, None
def process_subject(self, subj_mp):
    """Return the signed input rules for a subject monomer pattern.

    Returns
    -------
    tuple
        ``(input_set, result_code)``. ``input_set`` is a set of
        ``(rule_name, 0)`` tuples, or None when ``subj_mp`` is None or no
        input rule was found; in the latter case ``result_code`` is
        'INPUT_RULES_NOT_FOUND', otherwise it is None.
    """
    if subj_mp is None:
        return None, None
    rule_names = self._get_input_rules(subj_mp)
    if not rule_names:
        logger.info('Input rules not found for %s' % subj_mp)
        return (None, 'INPUT_RULES_NOT_FOUND')
    return {(rule_name, 0) for rule_name in rule_names}, None
def get_model_agents(self):
    """Return the set of agents that monomer patterns are mapped to."""
    return {agent for agent in self.mps_to_agents.values()}
def get_refinements(self, agent):
    """Return a list of refinement agents that are part of the model."""
    # Keep model agents that refine the given agent without matching it
    refinements = {
        candidate for candidate in self.model_agents
        if not candidate.matches(agent)
        and candidate.refinement_of(agent, bio_ontology)
    }
    return list(refinements)
def get_all_mps(self, agents, ignore_activities=False, mapping=False):
    """Get a list of all monomer patterns for a list of agents.

    Parameters
    ----------
    agents : list
        Agents to look up monomer patterns for.
    ignore_activities : bool
        Passed on to ``grounded_monomer_patterns``.
    mapping : bool
        If True, return a dict from each agent that has at least one
        monomer pattern to its list of patterns; otherwise return the set
        of all patterns found.
    """
    patterns_by_agent = {}
    all_patterns = []
    for agent in agents:
        found = list(pa.grounded_monomer_patterns(
            self.model, agent, ignore_activities=ignore_activities))
        if not found:
            continue
        patterns_by_agent[agent] = found
        all_patterns.extend(found)
    return patterns_by_agent if mapping else set(all_patterns)
def _get_input_rules(self, subj_mp):
    """Return names of rules that have the monomer pattern as their subject.

    A rule qualifies when its left-hand side matches ``subj_mp`` and it is
    annotated with the pattern's monomer as 'rule_has_subject'.

    Raises
    ------
    ValueError
        If ``subj_mp`` is None.
    """
    if subj_mp is None:
        raise ValueError("Cannot take None as an argument for subj_mp.")
    lhs_rules = _match_lhs(subj_mp, self.model.rules)
    logger.debug('Found %s input rules matching %s' %
                 (len(lhs_rules), str(subj_mp)))
    # Filter to include only rules where the subj_mp is actually the
    # subject (i.e., don't pick up upstream rules where the subject
    # is itself a substrate/object)
    # FIXME: Note that this will eliminate rules where the subject
    # being checked is included on the left hand side as
    # a bound condition rather than as an enzyme.
    monomer_name = subj_mp.monomer.name
    subject_rules = pa.rules_with_annotation(self.model, monomer_name,
                                             'rule_has_subject')
    logger.debug('%d rules with %s as subject' %
                 (len(subject_rules), monomer_name))
    lhs_names = {rule.name for rule in lhs_rules}
    subject_names = {rule.name for rule in subject_rules}
    input_rule_set = lhs_names & subject_names
    logger.debug('Final input rule set contains %d rules' %
                 len(input_rule_set))
    return input_rule_set
def _sample_paths(self, input_rule_set, obs_name, target_polarity,
                  max_paths=1, max_path_length=5):
    """Sample paths from the input rules to an observable.

    A dummy source node is connected to every input rule, a combined
    cycle-free paths graph is built for path lengths 1 through
    ``max_path_length``, its edges are weighted using the model's initial
    condition parameters, and ``max_paths`` paths are sampled from it.

    Returns
    -------
    PathResult
        With ``paths`` set to the sampled paths (source node stripped), or
        a 'NO_PATHS_FOUND' result if the combined paths graph is empty.

    Raises
    ------
    ValueError
        If ``max_paths`` is 0.
    ImportError
        If the paths_graph package could not be imported.
    """
    if max_paths == 0:
        raise ValueError("max_paths cannot be 0 for path sampling.")
    if not has_pg:
        raise ImportError("Paths Graph is not imported")

    def convert_polarities(path_list):
        # Re-encode the sign of every node in every path
        converted = []
        for path in path_list:
            converted.append(
                tuple((n[0], 0 if n[1] > 0 else 1) for n in path))
        return converted

    def no_paths_result():
        result = PathResult(
            False, 'NO_PATHS_FOUND', max_paths, max_path_length)
        result.path_metrics = None
        result.paths = []
        return result

    pg_polarity = 0 if target_polarity > 0 else 1
    nx_graph = self._im_to_signed_digraph(self.get_im())
    # Add edges from dummy node to input rules
    source_node = 'SOURCE_NODE'
    for rule in input_rule_set:
        nx_graph.add_edge(source_node, rule, sign=0)
    # -------------------------------------------------
    # Create combined paths_graph
    f_level, b_level = pg.get_reachable_sets(nx_graph, source_node,
                                             obs_name, max_path_length,
                                             signed=True)
    pg_list = [
        pg.CFPG.from_graph(
            nx_graph, source_node, obs_name, path_length, f_level,
            b_level, signed=True, target_polarity=pg_polarity)
        for path_length in range(1, max_path_length + 1)
    ]
    combined_pg = pg.CombinedCFPG(pg_list)
    # Make sure the combined paths graph is not empty
    if not combined_pg.graph:
        return no_paths_result()
    # Map each rule to its object (a monomer name)
    rule_obj_dict = {
        ann.subject: ann.object for ann in self.model.annotations
        if ann.predicate == 'rule_has_object'
    }
    # Get monomer initial conditions
    # FIXME: A hack that depends on the _0 convention
    # TODO: Wrap this in try/except?
    ic_dict = {
        mon.name: self.model.parameters['%s_0' % mon.name].value
        for mon in self.model.monomers
    }
    # Set weights in PG based on model initial conditions
    for cur_node in combined_pg.graph.nodes():
        # (u, v) -> (initial condition value, rule object or None)
        edge_info = {}
        rule_obj_list = []
        edge_weights_by_gene = {}
        for u, v in combined_pg.graph.out_edges(cur_node):
            # Get the object of the rule (a monomer name)
            rule_obj = rule_obj_dict.get(v[1][0])
            if rule_obj:
                # Add to list so we can count instances by gene
                rule_obj_list.append(rule_obj)
                # Get the abundance of rule object from the initial
                # conditions
                # TODO: Wrap in try/except?
                ic_value = ic_dict[rule_obj]
            else:
                ic_value = 1.0
            edge_info[(u, v)] = (ic_value, rule_obj)
            edge_weights_by_gene[rule_obj] = ic_value
        # Get frequency of different rule objects
        rule_obj_ctr = Counter(rule_obj_list)
        # Normalize results by weight sum and gene frequency at this level
        edge_weight_sum = sum(edge_weights_by_gene.values())
        edge_weights_norm = {}
        for edge, (ic_value, rule_obj) in edge_info.items():
            rule_obj_count = rule_obj_ctr[rule_obj] if rule_obj else 1
            edge_weights_norm[edge] = (
                (ic_value / float(edge_weight_sum)) /
                float(rule_obj_count))
        # Add edge weights to paths graph
        nx.set_edge_attributes(combined_pg.graph, name='weight',
                               values=edge_weights_norm)
    # Sample from the combined CFPG
    paths = combined_pg.sample_paths(max_paths)
    # -------------------------------------------------
    if not paths:
        # Sampling from a non-empty paths graph is expected to yield paths
        assert False
        return no_paths_result()
    pr = PathResult(True, 'PATHS_FOUND', max_paths, max_path_length)
    pr.path_metrics = None
    # Re-encode polarities, then strip off the SOURCE_NODE prefix
    pr.paths = convert_polarities(paths)
    pr.paths = [p[1:] for p in pr.paths]
    return pr
def score_paths(self, paths, agents_values, loss_of_function=False,
                sigma=0.15, include_final_node=False):
    """Return scores associated with a given set of paths.

    Parameters
    ----------
    paths : list[list[tuple[str, int]]]
        A list of paths obtained from path finding. Each path is a list
        of tuples (which are edges in the path), with the first element
        of the tuple the name of a rule, and the second element its
        polarity in the path.
    agents_values : dict[indra.statements.Agent, float]
        A dictionary of INDRA Agents and their corresponding measured
        value in a given experimental condition.
    loss_of_function : Optional[boolean]
        If True, flip the polarity of the path. For instance, if the effect
        of an inhibitory drug is explained, set this to True.
        Default: False
    sigma : Optional[float]
        The estimated standard deviation for the normally distributed
        measurement error in the observation model used to score paths
        with respect to data. Default: 0.15
    include_final_node : Optional[boolean]
        If True, only the last element of each path is left out of the
        score; otherwise the last two elements are left out.
        Default: False

    Returns
    -------
    list[tuple]
        ``(path, score)`` tuples sorted by decreasing score; paths with
        equal scores are ordered by increasing length.
    """
    def obs_model(x):
        return scipy.stats.norm(x, sigma)

    # Build up dict mapping observables to values
    obs_dict = {}
    for ag, val in agents_values.items():
        obs_list = self.agent_to_obs[ag]
        if obs_list is not None:
            for obs in obs_list:
                obs_dict[obs] = val
    # The flip does not depend on the path or node, so compute it once
    flip_polarity = -1 if loss_of_function else 1
    last_path_node_index = -1 if include_final_node else -2
    # For every path...
    path_scores = []
    for path in paths:
        logger.info('------')
        logger.info("Scoring path:")
        logger.info(path)
        # Look at every node in the path, excluding the final
        # observable...
        path_score = 0
        for node, sign in path[:last_path_node_index]:
            # ...and for each node check the sign to see if it matches the
            # data. So the first thing is to look at what's downstream
            # of the rule
            affected = self.rule_obs_dict[node]
            for affected_obs, rule_obs_sign in affected:
                pred_sign = sign * rule_obs_sign * flip_polarity
                # Check to see if this observable is in the data
                logger.info('%s %s: effect %s %s' %
                            (node, sign, affected_obs, pred_sign))
                measured_val = obs_dict.get(affected_obs)
                # Test for presence rather than truthiness so that a
                # measured value of exactly 0 is still scored
                if measured_val is not None:
                    # For negative predictions use CDF (prob that given
                    # measured value, true value lies below 0)
                    if pred_sign <= 0:
                        prob_correct = obs_model(measured_val).logcdf(0)
                    # For positive predictions, use log survival function
                    # (SF = 1 - CDF, i.e., prob that true value is
                    # above 0)
                    else:
                        prob_correct = obs_model(measured_val).logsf(0)
                    logger.info('Actual: %s, Log Probability: %s' %
                                (measured_val, prob_correct))
                    path_score += prob_correct
            if not affected:
                # The rule has no downstream observables at all
                logger.info('%s %s' % (node, sign))
                prob_correct = obs_model(0).logcdf(0)
                logger.info('Unmeasured node, Log Probability: %s' %
                            (prob_correct))
                path_score += prob_correct
        logger.info("Path score: %s" % path_score)
        path_scores.append(path_score)
    path_tuples = list(zip(paths, path_scores))
    # Sort first by path length
    sorted_by_length = sorted(path_tuples, key=lambda x: len(x[0]))
    # Sort by probability; sort in reverse order so large values
    # (higher probabilities) are ranked higher. The sort is stable, so
    # the length ordering is kept among equal scores.
    scored_paths = sorted(sorted_by_length, key=lambda x: x[1],
                          reverse=True)
    return scored_paths
def prune_influence_map(self):
    """Remove edges between rules causing problematic non-transitivity.

    First, all self-loops are removed. After this initial step, edges are
    removed between rules when they share *all* child nodes except for each
    other; that is, they have a mutual relationship with each other and
    share all of the same children.

    Note that edges must be removed in batch at the end to prevent edge
    removal from affecting the lists of rule children during the comparison
    process.
    """
    im = self.get_im()
    # First, remove all self-loops
    logger.info('Removing self loops')
    self_loops = []
    for edge in im.edges():
        if edge[0] != edge[1]:
            continue
        logger.info('Removing self loop: %s', edge)
        self_loops.append((edge[0], edge[1]))
    # Now remove all the edges to be removed with a single call
    im.remove_edges_from(self_loops)
    # Remove parameter nodes from influence map
    remove_im_params(self.model, im)
    # Now compare nodes pairwise and look for overlap between child nodes
    logger.info('Get successors of each node')
    succ_dict = {node: set(im.successors(node)) for node in im.nodes()}

    def successor_count(node):
        return len(succ_dict[node])

    # Sort and then group nodes by number of successors
    logger.info('Compare combinations of successors')
    nodes_sorted = sorted(im.nodes(), key=successor_count)
    # Now iterate over each group and then construct combinations
    # within the group to check for shared successors
    edges_to_remove = []
    for _, group in itertools.groupby(nodes_sorted, key=successor_count):
        for p1, p2 in itertools.combinations(group, 2):
            # Children are identical except for mutual relationship
            only_in_p1 = succ_dict[p1] - succ_dict[p2]
            only_in_p2 = succ_dict[p2] - succ_dict[p1]
            if only_in_p1 == {p2} and only_in_p2 == {p1}:
                for u, v in ((p1, p2), (p2, p1)):
                    edges_to_remove.append((u, v))
                    logger.debug('Will remove edge (%s, %s)', u, v)
    logger.info('Removing %d edges from influence map' %
                len(edges_to_remove))
    # Now remove all the edges to be removed with a single call
    im.remove_edges_from(edges_to_remove)
def prune_influence_map_subj_obj(self):
    """Prune influence map to include only edges where the object of the
    upstream rule matches the subject of the downstream rule."""
    # Cache of per-rule annotation lookups: the model's annotations are
    # not modified here, so each rule only needs to be scanned once rather
    # than once per edge it takes part in
    rule_info_cache = {}

    def get_rule_info(r):
        """Return the annotated 'subject' and/or 'object' of a rule."""
        if r not in rule_info_cache:
            result = {}
            for ann in self.model.annotations:
                if ann.subject == r:
                    if ann.predicate == 'rule_has_subject':
                        result['subject'] = ann.object
                    elif ann.predicate == 'rule_has_object':
                        result['object'] = ann.object
            rule_info_cache[r] = result
        return rule_info_cache[r]

    im = self.get_im()
    rules = im.nodes()
    edges_to_prune = []
    for r1, r2 in itertools.permutations(rules, 2):
        if (r1, r2) not in im.edges():
            continue
        r1_info = get_rule_info(r1)
        r2_info = get_rule_info(r2)
        # Edges are only judged when both annotations are available
        if 'object' not in r1_info or 'subject' not in r2_info:
            continue
        if r1_info['object'] != r2_info['subject']:
            logger.info("Removing edge %s --> %s" % (r1, r2))
            edges_to_prune.append((r1, r2))
    logger.info('Removing %d edges from influence map' %
                len(edges_to_prune))
    # Remove in one batch after the scan so iteration is not disturbed
    im.remove_edges_from(edges_to_prune)
def prune_influence_map_degrade_bind_positive(self, model_stmts):
    """Prune positive edges between X degrading and X forming a
    complex with Y."""
    im = self.get_im()
    edges_to_prune = []
    for r1, r2, data in im.edges(data=True):
        s1 = stmt_from_rule(r1, self.model, model_stmts)
        s2 = stmt_from_rule(r2, self.model, model_stmts)
        # Make sure this is a degradation/binding combo
        s1_is_degrad = (s1 and isinstance(s1, DecreaseAmount))
        s2_is_bind = (s2 and isinstance(s2, Complex) and 'bind' in r2)
        if not (s1_is_degrad and s2_is_bind):
            continue
        # Only prune when what is degraded is part of the complex and
        # the influence is positive
        member_names = [m.name for m in s2.members]
        if s1.obj.name in member_names and data['sign'] == 1:
            edges_to_prune.append((r1, r2))
    logger.info('Removing %d edges from influence map' %
                len(edges_to_prune))
    im.remove_edges_from(edges_to_prune)
def _im_to_signed_digraph(self, im):
    """Return a DiGraph copy of the influence map with 0/1 edge signs.

    Edges with a positive sign get ``sign=0``; all others get ``sign=1``.
    """
    signed_graph = nx.DiGraph()
    signed_graph.add_edges_from(
        (e[0], e[1], {'sign': 0 if _get_edge_sign(im, e) > 0 else 1})
        for e in im.edges())
    return signed_graph
def _find_sources_sample(im, target, sources, polarity, rule_obs_dict,
                         agent_to_obs, agents_values):
    """Sample immediate predecessors of a target, weighted by measurements.

    Up to 100 predecessors of ``target`` are drawn. Each candidate is
    weighted by how well the observables downstream of it (as listed in
    ``rule_obs_dict``) agree with the measured values.

    Parameters
    ----------
    im : networkx.MultiDiGraph
        The influence map.
    target : str
        Node whose predecessors are sampled.
    sources, polarity :
        Not used by the current implementation.
    rule_obs_dict : dict
        Maps a rule name to a list of ``(observable, sign)`` tuples.
    agent_to_obs : dict
        Maps each agent to an iterable of observable names.
    agents_values : dict
        Maps each agent to its measured value.

    Returns
    -------
    list
        The sampled predecessor nodes; empty if the target has no
        predecessors.
    """
    # Build up dict mapping observables to values
    obs_dict = {}
    for ag, val in agents_values.items():
        for obs in agent_to_obs[ag]:
            obs_dict[obs] = val
    sigma = 0.2

    def obs_model(x):
        return scipy.stats.norm(x, sigma)

    def _sample_pred(im, target, rule_obs_dict, obs_model):
        preds = list(_get_signed_predecessors(im, target, 1))
        if not preds:
            return None
        pred_scores = []
        for pred, sign in preds:
            pred_score = 0
            for affected_obs, rule_obs_sign in rule_obs_dict[pred]:
                pred_sign = sign * rule_obs_sign
                # Check to see if this observable is in the data
                logger.info('%s %s: effect %s %s' %
                            (pred, sign, affected_obs, pred_sign))
                measured_val = obs_dict.get(affected_obs)
                if measured_val:
                    logger.info('Actual: %s' % measured_val)
                    # Probability mass of the observation model at or
                    # below 1, centered on the measured value
                    tail_prob = obs_model(measured_val).cdf(1)
                    pred_score += (tail_prob if pred_sign == 1 else
                                   1-tail_prob)
            pred_scores.append(pred_score)
        # Normalize scores; if no candidate collected any evidence the
        # total is 0, so fall back to a uniform choice instead of
        # dividing by zero
        total = np.sum(pred_scores)
        if total > 0:
            probs = np.array(pred_scores) / total
        else:
            probs = [1.0 / len(preds)] * len(preds)
        pred_idx = np.random.choice(range(len(preds)), p=probs)
        return preds[pred_idx]

    preds = []
    for _ in range(100):
        pred = _sample_pred(im, target, rule_obs_dict, obs_model)
        if pred is None:
            # No predecessors: nothing can ever be sampled for this target
            break
        preds.append(pred[0])
    return preds
def remove_im_params(model, im):
    """Remove parameter nodes from the influence map.

    The influence map is modified in place.

    Parameters
    ----------
    model : pysb.core.Model
        PySB model.
    im : networkx.MultiDiGraph
        Influence map.

    Returns
    -------
    networkx.MultiDiGraph
        Influence map with the parameter nodes removed (the same object
        that was passed in).
    """
    for param in model.parameters:
        # If the node doesn't exist (e.g., it may have already been
        # removed), skip over the parameter. An explicit membership test
        # replaces the former bare ``except`` so unrelated errors are no
        # longer silently swallowed.
        if param.name in im:
            im.remove_node(param.name)
    return im
def _get_signed_predecessors(im, node, polarity):
    """Get upstream nodes in the influence map.

    Return the upstream nodes along with the overall polarity of the path
    to that node by accounting for the polarity of the path to the given
    node and the polarity of the edge between the given node and its
    immediate predecessors.

    Parameters
    ----------
    im : networkx.MultiDiGraph
        Graph containing the influence map.
    node : str
        The node (rule name) in the influence map to get predecessors
        (upstream nodes) for.
    polarity : int
        Polarity of the overall path to the given node.

    Returns
    -------
    generator of tuples, (node, polarity)
        Each tuple returned contains two elements, a node (string) and the
        polarity of the overall path (int) to that node.
    """
    for upstream in im.predecessors(node):
        # Combine the edge's own sign with the polarity accumulated so far
        yield (upstream, _get_edge_sign(im, (upstream, node)) * polarity)
def _get_edge_sign(im, edge):
    """Get the polarity of the influence by examining the edge sign.

    Parameters
    ----------
    im : networkx.MultiDiGraph
        Graph containing the influence map.
    edge : tuple
        The ``(source, target)`` pair to look up.

    Returns
    -------
    int
        1 or -1. If parallel edges carry different signs, 1 is returned
        and a warning is logged.

    Raises
    ------
    Exception
        If none of the parallel edges has a (non-zero) sign attribute, or
        if the sign is not 1 or -1.
    """
    edge_data = im[edge[0]][edge[1]]
    # Handle possible multiple edges between nodes
    signs = {v['sign'] for v in edge_data.values() if v.get('sign')}
    if not signs:
        # Previously this case failed with an IndexError before the
        # intended error could be raised
        raise Exception('No sign attribute for edge.')
    if len(signs) > 1:
        logger.warning("Edge %s has conflicting polarities; choosing "
                       "positive polarity by default" % str(edge))
        return 1
    sign = next(iter(signs))
    if abs(sign) != 1:
        # Report the offending value itself; ``edge`` is a plain tuple and
        # has no ``attr`` to read it from
        raise Exception('Unexpected edge sign: %s' % sign)
    return sign
def _add_modification_to_agent(agent, mod_type, residue, position):
    """Add a modification condition to an Agent.

    The agent itself is returned if it already carries an equal
    modification; otherwise a deep copy with the modification appended is
    returned and the original agent is left untouched.
    """
    new_mod = ModCondition(mod_type, residue, position)
    if any(existing.equals(new_mod) for existing in agent.mods):
        return agent
    modified_agent = deepcopy(agent)
    modified_agent.mods.append(new_mod)
    return modified_agent
def _add_activity_to_agent(agent, act_type, is_active):
    """Return the agent with an active ActivityCondition of the given type.

    The condition is always created as active; ``is_active`` does not
    influence the result. The agent itself is returned if it already has
    an equal activity condition, otherwise a deep copy is returned.
    """
    new_act = ActivityCondition(act_type, True)
    current_act = agent.activity
    if current_act is not None and current_act.equals(new_act):
        return agent
    active_agent = deepcopy(agent)
    active_agent.activity = new_act
    return active_agent
def _match_lhs(cp, rules):
    """Get rules with a left-hand side matching the given ComplexPattern."""
    # A rule matches as soon as one of its reactant complex patterns does
    return [
        rule for rule in rules
        if any(
            _cp_embeds_into(rule_cp, cp)
            for rule_cp in
            rule.rule_expression.reactant_pattern.complex_patterns)
    ]
def _cp_embeds_into(cp1, cp2):
    """Check that any state in ComplexPattern2 is matched in ComplexPattern1.

    Only the case where ``cp2`` consists of a single monomer pattern is
    supported; for anything else, or if either argument is None, the
    result is False.
    """
    if cp1 is None or cp2 is None:
        return False
    container = as_complex_pattern(cp1)
    query = as_complex_pattern(cp2)
    # If the thing we're matching to is just a monomer pattern, that makes
    # things easier--we just need to find the corresponding monomer pattern
    # in cp1
    if len(query.monomer_patterns) != 1:
        return False
    query_mp = query.monomer_patterns[0]
    return any(_mp_embeds_into(candidate, query_mp)
               for candidate in container.monomer_patterns)
def _mp_embeds_into(mp1, mp2):
    """Check that conditions in MonomerPattern2 are met in MonomerPattern1."""
    if mp1.monomer.name != mp2.monomer.name:
        return False
    available = mp1.site_conditions
    # Every site condition required by mp2 must be present and identical
    # in mp1
    return not any(
        site_name not in available or site_state != available[site_name]
        for site_name, site_state in mp2.site_conditions.items())
def _monomer_pattern_label(mp):
    """Return a string label for a MonomerPattern.

    The label is the monomer name followed by one ``site_state`` piece per
    site condition, joined with underscores. Conditions that are real
    numbers contribute nothing to the label.
    """
    def site_label(site, cond):
        if isinstance(cond, (tuple, list)):
            assert len(cond) == 2
            if cond[1] == WILD:
                # Leave the wildcard second element out of the label
                return '%s_%s' % (site, cond[0])
            return '%s_%s%s' % (site, cond[0], cond[1])
        if isinstance(cond, numbers.Real):
            return None
        return '%s_%s' % (site, cond)

    labels = (site_label(site, cond)
              for site, cond in mp.site_conditions.items())
    site_part = '_'.join(label for label in labels if label is not None)
    return '%s_%s' % (mp.monomer.name, site_part)