Source code for indra.preassembler

from __future__ import absolute_import, print_function, unicode_literals
from builtins import dict, str

import sys
import time
import logging
import itertools
import functools
import collections
import networkx as nx
import multiprocessing as mp
try:
    import pygraphviz as pgv
except ImportError:
    pass
from indra.util import fast_deepcopy
from indra.statements import *

logger = logging.getLogger(__name__)


class Preassembler(object):
    """De-duplicates statements and arranges them in a specificity hierarchy.

    Parameters
    ----------
    hierarchies : dict[:py:class:`indra.preassembler.hierarchy_manager`]
        A dictionary of hierarchies with keys such as 'entity' (hierarchy of
        entities, primarily specifying relationships between genes and their
        families) and 'modification' pointing to HierarchyManagers.
    stmts : list of :py:class:`indra.statements.Statement` or None
        A set of statements to perform pre-assembly on. If None, statements
        should be added using the :py:meth:`add_statements` method.

    Attributes
    ----------
    stmts : list of :py:class:`indra.statements.Statement`
        Starting set of statements for preassembly.
    unique_stmts : list of :py:class:`indra.statements.Statement`
        Statements resulting from combining duplicates.
    related_stmts : list of :py:class:`indra.statements.Statement`
        Top-level statements after building the refinement hierarchy.
    hierarchies : dict[:py:class:`indra.preassembler.hierarchy_manager`]
        A dictionary of hierarchies with keys such as 'entity' and
        'modification' pointing to HierarchyManagers.
    """
    def __init__(self, hierarchies, stmts=None):
        self.hierarchies = hierarchies
        if stmts:
            logger.debug("Deepcopying stmts in __init__")
            self.stmts = fast_deepcopy(stmts)
        else:
            self.stmts = []
        self.unique_stmts = None
        self.related_stmts = None
    def add_statements(self, stmts):
        """Add to the current list of statements.

        Parameters
        ----------
        stmts : list of :py:class:`indra.statements.Statement`
            Statements to add to the current list.
        """
        self.stmts += fast_deepcopy(stmts)
    def combine_duplicates(self):
        """Combine duplicates among `stmts` and save result in `unique_stmts`.

        A wrapper around the static method :py:meth:`combine_duplicate_stmts`.
        """
        if self.unique_stmts is None:
            self.unique_stmts = self.combine_duplicate_stmts(self.stmts)
        return self.unique_stmts
    @staticmethod
    def _get_stmt_matching_groups(stmts):
        """Use the matches_key method to get sets of matching statements."""
        def match_func(x):
            return x.matches_key()

        # Remove exact duplicates using a set() call, then make copies:
        st = list(set(stmts))
        # Group statements according to whether they are matches (differing
        # only in their evidence).
        # Sort the statements in place by matches_key()
        st.sort(key=match_func)
        return itertools.groupby(st, key=match_func)
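    # Illustrative note (not in the original source): statements that are
    # identical up to their Evidence share a matches_key() and therefore
    # fall into the same group. For instance, two copies of
    # Phosphorylation(MAP2K1(), MAPK1(), T, 185) with different evidence
    # texts form a single group, while a statement with a different residue
    # or position forms its own group.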
    @staticmethod
    def combine_duplicate_stmts(stmts):
        """Combine evidence from duplicate Statements.

        Statements are deemed to be duplicates if they have the same key
        returned by the `matches_key()` method of the Statement class. This
        generally means that statements must be identical in terms of their
        arguments and can differ only in their associated `Evidence` objects.

        This function keeps the first instance of each set of duplicate
        statements and merges the lists of Evidence from all of the other
        statements.

        Parameters
        ----------
        stmts : list of :py:class:`indra.statements.Statement`
            Set of statements to de-duplicate.

        Returns
        -------
        list of :py:class:`indra.statements.Statement`
            Unique statements with accumulated evidence across duplicates.

        Examples
        --------
        De-duplicate and combine evidence for two statements differing only
        in their evidence lists:

        >>> map2k1 = Agent('MAP2K1')
        >>> mapk1 = Agent('MAPK1')
        >>> stmt1 = Phosphorylation(map2k1, mapk1, 'T', '185',
        ...     evidence=[Evidence(text='evidence 1')])
        >>> stmt2 = Phosphorylation(map2k1, mapk1, 'T', '185',
        ...     evidence=[Evidence(text='evidence 2')])
        >>> uniq_stmts = Preassembler.combine_duplicate_stmts([stmt1, stmt2])
        >>> uniq_stmts
        [Phosphorylation(MAP2K1(), MAPK1(), T, 185)]
        >>> sorted([e.text for e in uniq_stmts[0].evidence]) # doctest:+IGNORE_UNICODE
        ['evidence 1', 'evidence 2']
        """
        unique_stmts = []
        for _, duplicates in Preassembler._get_stmt_matching_groups(stmts):
            ev_keys = set()
            # Get the first statement and add the evidence of all subsequent
            # Statements to it
            duplicates = list(duplicates)
            for stmt_ix, stmt in enumerate(duplicates):
                if stmt_ix == 0:
                    new_stmt = stmt.make_generic_copy()
                    if len(duplicates) == 1:
                        new_stmt.uuid = stmt.uuid
                raw_text = [None if ag is None else ag.db_refs.get('TEXT')
                            for ag in stmt.agent_list(deep_sorted=True)]
                raw_grounding = [None if ag is None else ag.db_refs
                                 for ag in stmt.agent_list(deep_sorted=True)]
                for ev in stmt.evidence:
                    ev_key = ev.matches_key()
                    if ev_key not in ev_keys:
                        # In case there are already agents annotations, we
                        # just add a new key for raw_text, otherwise create
                        # a new key
                        if 'agents' in ev.annotations:
                            ev.annotations['agents']['raw_text'] = raw_text
                            ev.annotations['agents']['raw_grounding'] = \
                                raw_grounding
                        else:
                            ev.annotations['agents'] = \
                                {'raw_text': raw_text,
                                 'raw_grounding': raw_grounding}
                        if 'prior_uuids' not in ev.annotations.keys():
                            ev.annotations['prior_uuids'] = []
                        ev.annotations['prior_uuids'].append(stmt.uuid)
                        new_stmt.evidence.append(ev)
                        ev_keys.add(ev_key)
            # new_stmt is always set in the first iteration above
            assert isinstance(new_stmt, Statement)
            unique_stmts.append(new_stmt)
        return unique_stmts
    def _get_entities(self, stmt, stmt_type, eh):
        entities = []
        for a in stmt.agent_list():
            # Entity is None: add the None to the entities list
            if a is None and stmt_type != Complex:
                entities.append(a)
                continue
            # Entity is not None, but could be ungrounded or not
            # in a family
            else:
                a_ns, a_id = a.get_grounding()
                # No grounding available--in this case, use the
                # entity_matches_key
                if a_ns is None or a_id is None:
                    entities.append(a.entity_matches_key())
                    continue
                # We have grounding, now check for a component ID
                uri = eh.get_uri(a_ns, a_id)
                # This is the component ID corresponding to the agent
                # in the entity hierarchy
                component = eh.components.get(uri)
                # If no component ID, use the entity_matches_key()
                if component is None:
                    entities.append(a.entity_matches_key())
                # Component ID, so this is in a family
                else:
                    # We turn the component ID into a string so that
                    # we can sort it along with entity_matches_keys
                    # for Complexes
                    entities.append(str(component))
        return entities

    def _get_stmt_by_group(self, stmt_type, stmts_this_type, eh):
        """Group Statements of `stmt_type` by their hierarchical relations."""
        # Dict of stmt group key tuples, indexed by their first Agent
        stmt_by_first = collections.defaultdict(lambda: [])
        # Dict of stmt group key tuples, indexed by their second Agent
        stmt_by_second = collections.defaultdict(lambda: [])
        # Dict of statements with None first, with second Agent as keys
        none_first = collections.defaultdict(lambda: [])
        # Dict of statements with None second, with first Agent as keys
        none_second = collections.defaultdict(lambda: [])
        # The dict of all statement groups, with tuples of components
        # or entity_matches_keys as keys
        stmt_by_group = collections.defaultdict(lambda: [])
        # Here we group Statements according to the hierarchy graph
        # components that their agents are part of
        for stmt_tuple in stmts_this_type:
            _, stmt = stmt_tuple
            entities = self._get_entities(stmt, stmt_type, eh)
            # At this point we have an entity list
            # If we're dealing with Complexes, sort the entities and use
            # as dict key
            if stmt_type == Complex:
                # There shouldn't be any statements of the type
                # e.g., Complex([Foo, None, Bar])
                assert None not in entities
                assert len(entities) > 0
                entities.sort()
                key = tuple(entities)
                if stmt_tuple not in stmt_by_group[key]:
                    stmt_by_group[key].append(stmt_tuple)
            elif stmt_type == Conversion:
                assert len(entities) > 0
                key = (entities[0],
                       tuple(sorted(entities[1:len(stmt.obj_from)+1])),
                       tuple(sorted(entities[-len(stmt.obj_to):])))
                if stmt_tuple not in stmt_by_group[key]:
                    stmt_by_group[key].append(stmt_tuple)
            # Now look at all other statement types
            # All other statements will have one or two entities
            elif len(entities) == 1:
                # If only one entity, we only need the one key
                # It should not be None!
                assert None not in entities
                key = tuple(entities)
                if stmt_tuple not in stmt_by_group[key]:
                    stmt_by_group[key].append(stmt_tuple)
            else:
                # Make sure we only have two entities, and they are not both
                # None
                key = tuple(entities)
                assert len(key) == 2
                assert key != (None, None)
                # First agent is None; add in the statements, indexed by
                # 2nd
                if key[0] is None and stmt_tuple not in none_first[key[1]]:
                    none_first[key[1]].append(stmt_tuple)
                # Second agent is None; add in the statements, indexed by
                # 1st
                elif key[1] is None and stmt_tuple not in none_second[key[0]]:
                    none_second[key[0]].append(stmt_tuple)
                # Neither entity is None!
                elif None not in key:
                    if stmt_tuple not in stmt_by_group[key]:
                        stmt_by_group[key].append(stmt_tuple)
                    if key not in stmt_by_first[key[0]]:
                        stmt_by_first[key[0]].append(key)
                    if key not in stmt_by_second[key[1]]:
                        stmt_by_second[key[1]].append(key)
        # When we've gotten here, we should have stmt_by_group entries, and
        # we may or may not have stmt_by_first/second dicts filled out
        # (depending on the statement type).
        if none_first:
            # Get the keys associated with stmts having a None first
            # argument
            for second_arg, stmts in none_first.items():
                # Look for any statements with this second arg
                second_arg_keys = stmt_by_second[second_arg]
                # If there are no more specific statements matching this
                # set of statements with a None first arg, then the
                # statements with the None first arg deserve to be in
                # their own group.
                if not second_arg_keys:
                    stmt_by_group[(None, second_arg)] = stmts
                # On the other hand, if there are statements with a matching
                # second arg component, we need to add the None first
                # statements to all groups with the matching second arg
                for second_arg_key in second_arg_keys:
                    stmt_by_group[second_arg_key] += stmts
        # Now do the corresponding steps for the statements with None as the
        # second argument:
        if none_second:
            for first_arg, stmts in none_second.items():
                # Look for any statements with this first arg
                first_arg_keys = stmt_by_first[first_arg]
                # If there are no more specific statements matching this
                # set of statements with a None second arg, then the
                # statements with the None second arg deserve to be in
                # their own group.
                if not first_arg_keys:
                    stmt_by_group[(first_arg, None)] = stmts
                # On the other hand, if there are statements with a matching
                # first arg component, we need to add the None second
                # statements to all groups with the matching first arg
                for first_arg_key in first_arg_keys:
                    stmt_by_group[first_arg_key] += stmts
        return stmt_by_group

    def _generate_id_maps(self, unique_stmts, poolsize=None,
                          size_cutoff=100, split_idx=None):
        """Connect statements using their refinement relationships."""
        # Check arguments relating to multiprocessing
        if poolsize is None:
            logger.debug('combine_related: poolsize not set, '
                         'not using multiprocessing.')
            use_mp = False
        elif sys.version_info[0] >= 3 and sys.version_info[1] >= 4:
            use_mp = True
            logger.info('combine_related: Python >= 3.4 detected, '
                        'using multiprocessing with poolsize %d, '
                        'size_cutoff %d' % (poolsize, size_cutoff))
        else:
            use_mp = False
            logger.info('combine_related: Python < 3.4 detected, '
                        'not using multiprocessing.')
        eh = self.hierarchies['entity']
        # Make a list of Statement types
        stmts_by_type = collections.defaultdict(lambda: [])
        for idx, stmt in enumerate(unique_stmts):
            stmts_by_type[type(stmt)].append((idx, stmt))
        child_proc_groups = []
        parent_proc_groups = []
        skipped_groups = 0
        # Each Statement type can be preassembled independently
        for stmt_type, stmts_this_type in stmts_by_type.items():
            logger.info('Grouping %s (%s)' %
                        (stmt_type.__name__, len(stmts_this_type)))
            stmt_by_group = self._get_stmt_by_group(stmt_type,
                                                    stmts_this_type, eh)
            # Divide statements by group size
            # If we're not using multiprocessing, then all groups are local
            for g_name, g in stmt_by_group.items():
                if len(g) < 2:
                    skipped_groups += 1
                    continue
                if use_mp and len(g) >= size_cutoff:
                    child_proc_groups.append(g)
                else:
                    parent_proc_groups.append(g)
        # Now run preassembly!
        logger.debug("Groups: %d parent, %d worker, %d skipped."
                     % (len(parent_proc_groups), len(child_proc_groups),
                        skipped_groups))
        supports_func = functools.partial(_set_supports_stmt_pairs,
                                          hierarchies=self.hierarchies,
                                          split_idx=split_idx,
                                          check_entities_match=False)
        # Check if we are running any groups in child processes; note that if
        # use_mp is False, child_proc_groups will be empty
        if child_proc_groups:
            # Get a multiprocessing context
            ctx = mp.get_context('spawn')
            pool = ctx.Pool(poolsize)
            # Run the large groups remotely
            logger.debug("Running %d groups in child processes"
                         % len(child_proc_groups))
            res = pool.map_async(supports_func, child_proc_groups)
            workers_ready = False
        else:
            workers_ready = True
        # Run the small groups locally
        logger.debug("Running %d groups in parent process"
                     % len(parent_proc_groups))
        stmt_ix_map = [supports_func(stmt_tuples)
                       for stmt_tuples in parent_proc_groups]
        logger.debug("Done running parent process groups")
        while not workers_ready:
            logger.debug("Checking child processes")
            if res.ready():
                workers_ready = True
                logger.debug('Child process group comparisons successful? %s'
                             % res.successful())
                if not res.successful():
                    raise Exception("Sorry, there was a problem with "
                                    "preassembly in the child processes.")
                else:
                    stmt_ix_map += res.get()
                logger.debug("Closing pool...")
                pool.close()
                logger.debug("Joining pool...")
                pool.join()
                logger.debug("Pool closed and joined.")
            time.sleep(1)
        logger.debug("Done.")
        # Combine all redundant map edges
        stmt_ix_map_set = set([])
        for group_ix_map in stmt_ix_map:
            for ix_pair in group_ix_map:
                stmt_ix_map_set.add(ix_pair)
        return stmt_ix_map_set
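    # Note (added for clarity, not in the original source): each pair in the
    # set returned above has the form (index of the more specific statement,
    # index of the more general statement), following the convention of
    # _set_supports_stmt_pairs defined at module level below.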
    def find_contradicts(self):
        """Return pairs of contradicting Statements.

        Returns
        -------
        contradicts : list(tuple(Statement, Statement))
            A list of Statement pairs that are contradicting.
        """
        eh = self.hierarchies['entity']
        # Make a dict of Statements by type
        stmts_by_type = collections.defaultdict(lambda: [])
        for idx, stmt in enumerate(self.stmts):
            stmts_by_type[type(stmt)].append((idx, stmt))
        # Handle Statements with polarity first
        pos_stmts = AddModification.__subclasses__()
        neg_stmts = [modclass_to_inverse[c] for c in pos_stmts]
        pos_stmts += [Activation, IncreaseAmount]
        neg_stmts += [Inhibition, DecreaseAmount]
        contradicts = []
        for pst, nst in zip(pos_stmts, neg_stmts):
            poss = stmts_by_type.get(pst, [])
            negs = stmts_by_type.get(nst, [])
            pos_stmt_by_group = self._get_stmt_by_group(pst, poss, eh)
            neg_stmt_by_group = self._get_stmt_by_group(nst, negs, eh)
            for key, pg in pos_stmt_by_group.items():
                ng = neg_stmt_by_group.get(key, [])
                for (_, st1), (_, st2) in itertools.product(pg, ng):
                    if st1.contradicts(st2, self.hierarchies):
                        contradicts.append((st1, st2))
        # Handle neutral Statements next
        neu_stmts = [Influence, ActiveForm]
        for stt in neu_stmts:
            stmts = stmts_by_type.get(stt, [])
            for (_, st1), (_, st2) in itertools.combinations(stmts, 2):
                if st1.contradicts(st2, self.hierarchies):
                    contradicts.append((st1, st2))
        return contradicts
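# The function below is an illustrative sketch, not part of the original
# module. It mirrors the doctest examples above to show the typical
# de-duplication flow, plus a contradiction check; it assumes the default
# hierarchies exported by indra.preassembler.hierarchy_manager.
def _example_preassembly():
    from indra.preassembler.hierarchy_manager import hierarchies
    map2k1 = Agent('MAP2K1')
    mapk1 = Agent('MAPK1')
    # Two statements identical up to their evidence
    stmt1 = Phosphorylation(map2k1, mapk1, 'T', '185',
                            evidence=[Evidence(text='evidence 1')])
    stmt2 = Phosphorylation(map2k1, mapk1, 'T', '185',
                            evidence=[Evidence(text='evidence 2')])
    pa = Preassembler(hierarchies, [stmt1, stmt2])
    unique = pa.combine_duplicates()
    # One statement remains, carrying both pieces of evidence
    assert len(unique) == 1
    assert len(unique[0].evidence) == 2
    # find_contradicts pairs up opposing statements, e.g. an Activation and
    # an Inhibition of the same entities are expected to contradict
    act = Activation(map2k1, mapk1, evidence=[Evidence(text='e3')])
    inh = Inhibition(map2k1, mapk1, evidence=[Evidence(text='e4')])
    pa2 = Preassembler(hierarchies, [act, inh])
    assert len(pa2.find_contradicts()) == 1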
def _set_supports_stmt_pairs(stmt_tuples, split_idx=None, hierarchies=None,
                             check_entities_match=False):
    # This is useful when deep-debugging, but even for normal debugging it
    # is too verbose.
    # logger.debug("Getting support pairs for %d tuples with idx %s and "
    #              "stmts %s split at %s."
    #              % (len(stmt_tuples), [idx for idx, _ in stmt_tuples],
    #                 [(s.get_hash(shallow=True), s)
    #                  for _, s in stmt_tuples], split_idx))
    # Make the iterator by one of two methods, depending on the case
    if split_idx is None:
        stmt_pair_iter = itertools.combinations(stmt_tuples, 2)
    else:
        stmt_group_a = []
        stmt_group_b = []
        for idx, stmt in stmt_tuples:
            if idx <= split_idx:
                stmt_group_a.append((idx, stmt))
            else:
                stmt_group_b.append((idx, stmt))
        stmt_pair_iter = itertools.product(stmt_group_a, stmt_group_b)
    # Actually create the index maps.
    ix_map = []
    for stmt_tuple1, stmt_tuple2 in stmt_pair_iter:
        stmt_ix1, stmt1 = stmt_tuple1
        stmt_ix2, stmt2 = stmt_tuple2
        if check_entities_match and not stmt1.entities_match(stmt2):
            continue
        if stmt1.refinement_of(stmt2, hierarchies):
            ix_map.append((stmt_ix1, stmt_ix2))
        elif stmt2.refinement_of(stmt1, hierarchies):
            ix_map.append((stmt_ix2, stmt_ix1))
    return ix_map
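# Illustration (not part of the original module) of how split_idx partitions
# the comparisons above: indices at or below the split form one group, the
# rest form another, and only cross-group pairs are considered instead of
# all pairwise combinations.
def _example_split_pairs():
    # Stand-ins for (index, Statement) tuples
    stmt_tuples = list(enumerate(['s0', 's1', 's2', 's3']))
    split_idx = 1
    group_a = [t for t in stmt_tuples if t[0] <= split_idx]
    group_b = [t for t in stmt_tuples if t[0] > split_idx]
    pairs = [(ix1, ix2) for (ix1, _), (ix2, _) in
             itertools.product(group_a, group_b)]
    # Four cross-group pairs instead of the six full combinations
    assert pairs == [(0, 2), (0, 3), (1, 2), (1, 3)]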
def render_stmt_graph(statements, reduce=True, english=False, rankdir=None,
                      agent_style=None):
    """Render the statement hierarchy as a pygraphviz graph.

    Parameters
    ----------
    statements : list of :py:class:`indra.statements.Statement`
        A list of top-level statements with associated supporting statements
        resulting from building a statement hierarchy with
        :py:meth:`combine_related`.
    reduce : bool
        Whether to perform a transitive reduction of the edges in the graph.
        Default is True.
    english : bool
        If True, the statements in the graph are represented by their
        English-assembled equivalent; otherwise they are represented as
        text-formatted Statements.
    rankdir : str or None
        Argument to pass through to the pygraphviz `AGraph` constructor
        specifying graph layout direction. In particular, a value of 'LR'
        specifies a left-to-right direction. If None, the pygraphviz default
        is used.
    agent_style : dict or None
        Dict of attributes specifying the visual properties of nodes. If
        None, the following default attributes are used::

            agent_style = {'color': 'lightgray', 'style': 'filled',
                           'fontname': 'arial'}

    Returns
    -------
    pygraphviz.AGraph
        Pygraphviz graph with nodes representing statements and edges
        pointing from supported statements to supported_by statements.

    Examples
    --------
    Pattern for getting statements and rendering as a Graphviz graph:

    >>> from indra.preassembler.hierarchy_manager import hierarchies
    >>> braf = Agent('BRAF')
    >>> map2k1 = Agent('MAP2K1')
    >>> st1 = Phosphorylation(braf, map2k1)
    >>> st2 = Phosphorylation(braf, map2k1, residue='S')
    >>> pa = Preassembler(hierarchies, [st1, st2])
    >>> pa.combine_related() # doctest:+ELLIPSIS
    [Phosphorylation(BRAF(), MAP2K1(), S)]
    >>> graph = render_stmt_graph(pa.related_stmts)
    >>> graph.write('example_graph.dot') # To make the DOT file
    >>> graph.draw('example_graph.png', prog='dot') # To make an image

    Resulting graph:

    .. image:: /images/example_graph.png
        :align: center
        :alt: Example statement graph rendered by Graphviz
    """
    from indra.assemblers.english import EnglishAssembler
    # Set the default agent formatting properties
    if agent_style is None:
        agent_style = {'color': 'lightgray', 'style': 'filled',
                       'fontname': 'arial'}
    # Sets to store all of the nodes and edges as we recursively process all
    # of the statements
    nodes = set([])
    edges = set([])
    stmt_dict = {}

    # Recursive function for processing all statements
    def process_stmt(stmt):
        nodes.add(str(stmt.matches_key()))
        stmt_dict[str(stmt.matches_key())] = stmt
        for sby_stmt in stmt.supported_by:
            edges.add((str(stmt.matches_key()), str(sby_stmt.matches_key())))
            process_stmt(sby_stmt)

    # Process all of the top-level statements, getting the supporting
    # statements recursively
    for stmt in statements:
        process_stmt(stmt)
    # Create a networkx graph from the nodes
    nx_graph = nx.DiGraph()
    nx_graph.add_edges_from(edges)
    # Perform transitive reduction if desired
    if reduce:
        nx_graph = nx.algorithms.dag.transitive_reduction(nx_graph)
    # Create a pygraphviz graph from the nx graph
    try:
        pgv_graph = pgv.AGraph(name='statements', directed=True,
                               rankdir=rankdir)
    except NameError:
        logger.error('Cannot generate graph because '
                     'pygraphviz could not be imported.')
        return None
    for node in nx_graph.nodes():
        stmt = stmt_dict[node]
        if english:
            ea = EnglishAssembler([stmt])
            stmt_str = ea.make_model()
        else:
            stmt_str = str(stmt)
        pgv_graph.add_node(node,
                           label='%s (%d)' % (stmt_str, len(stmt.evidence)),
                           **agent_style)
    pgv_graph.add_edges_from(nx_graph.edges())
    return pgv_graph
def flatten_stmts(stmts):
    """Return the full set of unique stmts in a pre-assembled stmt graph.

    The flattened list of statements returned by this function can be
    compared to the original set of unique statements to make sure no
    statements have been lost during the preassembly process.

    Parameters
    ----------
    stmts : list of :py:class:`indra.statements.Statement`
        A list of top-level statements with associated supporting statements
        resulting from building a statement hierarchy with
        :py:meth:`combine_related`.

    Returns
    -------
    stmts : list of :py:class:`indra.statements.Statement`
        List of all statements contained in the hierarchical statement graph.

    Examples
    --------
    Calling :py:meth:`combine_related` on two statements results in one
    top-level statement; calling :py:func:`flatten_stmts` recovers both:

    >>> from indra.preassembler.hierarchy_manager import hierarchies
    >>> braf = Agent('BRAF')
    >>> map2k1 = Agent('MAP2K1')
    >>> st1 = Phosphorylation(braf, map2k1)
    >>> st2 = Phosphorylation(braf, map2k1, residue='S')
    >>> pa = Preassembler(hierarchies, [st1, st2])
    >>> pa.combine_related() # doctest:+ELLIPSIS
    [Phosphorylation(BRAF(), MAP2K1(), S)]
    >>> flattened = flatten_stmts(pa.related_stmts)
    >>> flattened.sort(key=lambda x: x.matches_key())
    >>> flattened
    [Phosphorylation(BRAF(), MAP2K1()), Phosphorylation(BRAF(), MAP2K1(), S)]
    """
    total_stmts = set(stmts)
    for stmt in stmts:
        if stmt.supported_by:
            children = flatten_stmts(stmt.supported_by)
            total_stmts = total_stmts.union(children)
    return list(total_stmts)
def flatten_evidence(stmts, collect_from=None):
    """Add evidence from *supporting* stmts to evidence for *supported* stmts.

    Parameters
    ----------
    stmts : list of :py:class:`indra.statements.Statement`
        A list of top-level statements with associated supporting statements
        resulting from building a statement hierarchy with
        :py:meth:`combine_related`.
    collect_from : str in ('supports', 'supported_by')
        String indicating whether to collect and flatten evidence from the
        `supports` attribute of each statement or the `supported_by`
        attribute. If not set, defaults to 'supported_by'.

    Returns
    -------
    stmts : list of :py:class:`indra.statements.Statement`
        Statement hierarchy identical to the one passed, but with the
        evidence lists for each statement now containing all of the evidence
        associated with the statements they are supported by.

    Examples
    --------
    Flattening evidence adds the two pieces of evidence from the supporting
    statement to the evidence list of the top-level statement:

    >>> from indra.preassembler.hierarchy_manager import hierarchies
    >>> braf = Agent('BRAF')
    >>> map2k1 = Agent('MAP2K1')
    >>> st1 = Phosphorylation(braf, map2k1,
    ...     evidence=[Evidence(text='foo'), Evidence(text='bar')])
    >>> st2 = Phosphorylation(braf, map2k1, residue='S',
    ...     evidence=[Evidence(text='baz'), Evidence(text='bak')])
    >>> pa = Preassembler(hierarchies, [st1, st2])
    >>> pa.combine_related() # doctest:+ELLIPSIS
    [Phosphorylation(BRAF(), MAP2K1(), S)]
    >>> [e.text for e in pa.related_stmts[0].evidence] # doctest:+IGNORE_UNICODE
    ['baz', 'bak']
    >>> flattened = flatten_evidence(pa.related_stmts)
    >>> sorted([e.text for e in flattened[0].evidence]) # doctest:+IGNORE_UNICODE
    ['bak', 'bar', 'baz', 'foo']
    """
    if collect_from is None:
        collect_from = 'supported_by'
    if collect_from not in ('supports', 'supported_by'):
        raise ValueError('collect_from must be one of "supports", '
                         '"supported_by"')
    # Copy all of the statements--these will be the ones where we update
    # the evidence lists
    stmts = fast_deepcopy(stmts)
    for stmt in stmts:
        total_evidence = _flatten_evidence_for_stmt(stmt, collect_from)
        stmt.evidence = total_evidence
    return stmts
def _flatten_evidence_for_stmt(stmt, collect_from):
    supp_stmts = (stmt.supports if collect_from == 'supports'
                  else stmt.supported_by)
    total_evidence = set(stmt.evidence)
    for supp_stmt in supp_stmts:
        child_evidence = _flatten_evidence_for_stmt(supp_stmt, collect_from)
        total_evidence = total_evidence.union(child_evidence)
    return list(total_evidence)