Source code for

from __future__ import absolute_import, print_function, unicode_literals
from builtins import dict, str

__all__ = ['stmts_from_json', 'stmts_from_json_file', 'stmts_to_json',
           'stmts_to_json_file', 'draw_stmt_graph',
           'UnresolvedUuidError', 'InputError']

import json
import logging
from indra.statements.statements import Statement, Unresolved

logger = logging.getLogger(__name__)

[docs]def stmts_from_json(json_in, on_missing_support='handle'): """Get a list of Statements from Statement jsons. In the case of pre-assembled Statements which have `supports` and `supported_by` lists, the uuids will be replaced with references to Statement objects from the json, where possible. The method of handling missing support is controled by the `on_missing_support` key-word argument. Parameters ---------- json_in : iterable[dict] A json list containing json dict representations of INDRA Statements, as produced by the `to_json` methods of subclasses of Statement, or equivalently by `stmts_to_json`. on_missing_support : Optional[str] Handles the behavior when a uuid reference in `supports` or `supported_by` attribute cannot be resolved. This happens because uuids can only be linked to Statements contained in the `json_in` list, and some may be missing if only some of all the Statements from pre- assembly are contained in the list. Options: - *'handle'* : (default) convert unresolved uuids into `Unresolved` Statement objects. - *'ignore'* : Simply omit any uuids that cannot be linked to any Statements in the list. - *'error'* : Raise an error upon hitting an un-linkable uuid. Returns ------- stmts : list[:py:class:`Statement`] A list of INDRA Statements. """ stmts = [] uuid_dict = {} for json_stmt in json_in: try: st = Statement._from_json(json_stmt) except Exception as e: logger.warning("Error creating statement: %s" % e) continue stmts.append(st) uuid_dict[st.uuid] = st for st in stmts: _promote_support(st.supports, uuid_dict, on_missing_support) _promote_support(st.supported_by, uuid_dict, on_missing_support) return stmts
[docs]def stmts_from_json_file(fname, format='json'): """Return a list of statements loaded from a JSON file. Parameters ---------- fname : str Path to the JSON file to load statements from. format : Optional[str] One of 'json' to assume regular JSON formatting or 'jsonl' assuming each statement is on a new line. Returns ------- list[indra.statements.Statement] The list of INDRA Statements loaded from the JSOn file. """ with open(fname, 'r') as fh: if format == 'json': return stmts_from_json(json.load(fh)) else: return stmts_from_json([json.loads(line) for line in fh.readlines()])
[docs]def stmts_to_json_file(stmts, fname, format='json', **kwargs): """Serialize a list of INDRA Statements into a JSON file. Parameters ---------- stmts : list[indra.statement.Statements] The list of INDRA Statements to serialize into the JSON file. fname : str Path to the JSON file to serialize Statements into. format : Optional[str] One of 'json' to use regular JSON with indent=1 formatting or 'jsonl' to put each statement on a new line without indents. """ sj = stmts_to_json(stmts, **kwargs) with open(fname, 'w') as fh: if format == 'json': json.dump(sj, fh, indent=1) else: for json_stmt in sj: json.dump(json_stmt, fh) fh.write('\n')
[docs]def stmts_to_json(stmts_in, use_sbo=False, matches_fun=None): """Return the JSON-serialized form of one or more INDRA Statements. Parameters ---------- stmts_in : Statement or list[Statement] A Statement or list of Statement objects to serialize into JSON. use_sbo : Optional[bool] If True, SBO annotations are added to each applicable element of the JSON. Default: False matches_fun : Optional[function] A custom function which, if provided, is used to construct the matches key which is then hashed and put into the return value. Default: None Returns ------- json_dict : dict JSON-serialized INDRA Statements. """ if not isinstance(stmts_in, list): json_dict = stmts_in.to_json(use_sbo=use_sbo) return json_dict else: json_dict = [st.to_json(use_sbo=use_sbo, matches_fun=matches_fun) for st in stmts_in] return json_dict
def _promote_support(sup_list, uuid_dict, on_missing='handle'): """Promote the list of support-related uuids to Statements, if possible.""" valid_handling_choices = ['handle', 'error', 'ignore'] if on_missing not in valid_handling_choices: raise InputError('Invalid option for `on_missing_support`: \'%s\'\n' 'Choices are: %s.' % (on_missing, str(valid_handling_choices))) for idx, uuid in enumerate(sup_list): if uuid in uuid_dict.keys(): sup_list[idx] = uuid_dict[uuid] elif on_missing == 'handle': sup_list[idx] = Unresolved(uuid) elif on_missing == 'ignore': sup_list.remove(uuid) elif on_missing == 'error': raise UnresolvedUuidError("Uuid %s not found in stmt jsons." % uuid) return
[docs]def draw_stmt_graph(stmts): """Render the attributes of a list of Statements as directed graphs. The layout works well for a single Statement or a few Statements at a time. This function displays the plot of the graph using Parameters ---------- stmts : list[indra.statements.Statement] A list of one or more INDRA Statements whose attribute graph should be drawn. """ import networkx try: import matplotlib.pyplot as plt except Exception: logger.error('Could not import matplotlib, not drawing graph.') return try: # This checks whether networkx has this package to work with. import pygraphviz except Exception: logger.error('Could not import pygraphviz, not drawing graph.') return import numpy g = networkx.compose_all([stmt.to_graph() for stmt in stmts]) plt.figure() plt.ion() g.graph['graph'] = {'rankdir': 'LR'} pos = networkx.drawing.nx_agraph.graphviz_layout(g, prog='dot') g = g.to_undirected() # Draw nodes options = { 'marker': 'o', 's': 200, 'c': [0.85, 0.85, 1], 'facecolor': '0.5', 'lw': 0, } ax = plt.gca() nodelist = list(g) xy = numpy.asarray([pos[v] for v in nodelist]) node_collection = ax.scatter(xy[:, 0], xy[:, 1], **options) node_collection.set_zorder(2) # Draw edges networkx.draw_networkx_edges(g, pos, arrows=False, edge_color='0.5') # Draw labels edge_labels = {(e[0], e[1]): e[2].get('label') for e in g.edges(data=True)} networkx.draw_networkx_edge_labels(g, pos, edge_labels=edge_labels) node_labels = {n[0]: n[1].get('label') for n in g.nodes(data=True)} for key, label in node_labels.items(): if len(label) > 25: parts = label.split(' ') parts.insert(int(len(parts)/2), '\n') label = ' '.join(parts) node_labels[key] = label networkx.draw_networkx_labels(g, pos, labels=node_labels) ax.get_xaxis().set_visible(False) ax.get_yaxis().set_visible(False)
[docs]class UnresolvedUuidError(Exception): pass
[docs]class InputError(Exception): pass