"""
This module groups and sorts Statements for presentation in downstream tools
while aggregating the statements' statistics/metrics into the groupings. While
most usage of this module will be via the top-level function
`group_and_sort_statements`, alternative usages (including custom statement
data, multiple statement grouping levels, and multiple strategies for
aggregating statement-level metrics for higher-level groupings) are supported
through the various classes (see Class Overview below).
Vocabulary
----------
An "agent-pair" is, as the name suggests, a pair of agents from a statement,
usually defined by their canonical names.
A "relation" is the basic information of a statement, with all details (such as
sites, residues, mutations, and bound conditions) stripped away. Usually this
means it is just the statement type (or verb), subject name, and object name,
though in some corner cases it is different.
Simple Example
--------------
The principal function in the module is `group_and_sort_statements`, and if you
want statements grouped into agent-pairs, then by relations, sorted by evidence
count, simply use the function with its defaults, e.g.:
.. code-block:: python
for _, ag_key, rels, ag_metrics in group_and_sort_statements(stmts):
print(ag_key)
for _, rel_key, stmt_data, rel_metrics in rels:
print('\t', rel_key)
for _, stmt_hash, stmt_obj, stmt_metrics in stmt_data:
print('\t\t', stmt_obj)
Advanced Example
----------------
Custom data and aggregation methods are supported, respectively, by using
instances of the `StmtStat` class and subclassing the BasicAggregator (or more
generally, the AggregatorMeta) API. Custom sorting is implemented by defining
and passing a `sort_by` function to `group_and_sort_statements`.
For example, if you have custom statement metrics (e.g., a value obtained by
experiment such as differential expression of subject or object genes), want
the statements grouped only to the level of relations, and want to sort the
statements and relations independently. Suppose also that your measurement
applies equally at the statement and relation level and hence you don't want
any changes applied during aggregation (e.g. averaging). This is illustrated in
the example below:
.. code-block:: python
# Define a new aggregator that doesn't apply any aggregation function to
# the data, simply taking the last metric (effectively a noop):
class NoopAggregator(BasicAggregator):
def _merge(self, metric_array):
self.values = metric_array
# Create your StmtStat using custom data dict `my_data`, a dict of values
# keyed by statement hash:
my_stat = StmtStat('my_stat', my_data, int, NoopAggregator)
# Define a custom sort function using my stat and the default available
# ev_count. In effect this will sort relations by the custom stat, and then
# secondarily sort the statements within that relation (for which my_stat
# is by design the same) using their evidence counts.
def my_sort(metrics):
return metrics['my_stat'], metrics['ev_count']
# Iterate over the results.
groups = group_and_sort_statements(stmts, sort_by=my_sort,
custom_stats=[my_stat],
grouping_level='relation')
for _, rel_key, rel_stmts, rel_metrics in groups:
print(rel_key, rel_metrics['my_stat'])
for _, stmt_hash, stmt, metrics in rel_stmts:
print('\t', stmt, metrics['ev_count'])
Class Overview
--------------
Statements can have multiple metrics associated with them, most commonly
belief, evidence counts, and source counts, although other metrics may also be
applied. Such metrics imply an order on the set of Statements, and a user
should be able to apply that order to them for sorting or filtering. them.
These types of metric, or "stat", are represented by `StmtStat` classes.
Statements can be grouped based on the information they represent: by their
agents (e.g. subject is MEK and object is ERK), and by their type (e.g.
Phosphorylation). These groups are represented by `StmtGroup` objects, which on
their surface behave much like `defaultdict(list)` would, though more is going
on behind the scenes. The StmtGroup class is used internally by
`group_and_sort_statements` and would only need to be used directly if defining
an alternative statement-level grouping approach (e.g., grouping statements by
subject).
Like Statements, higher-level statement groups are subject to sorting and
filtering. That requires that the `StmtStat`s be aggregated over the statements
in a group. The Aggregator classes serve this purpose, using numpy to do sums
over arrays of metrics as Statements are "included" in the `StmtGroup`. Each
`StmtStat` must declare how its data should be aggregated, as different kinds
of data aggregate differently. Custom aggregation methods can be implemented by
subclassing the `BasicAggregator` class and using an instance of the custom
class to define a `StmtStat`.
"""
import logging
from collections import defaultdict
from itertools import permutations
from numpy import array, zeros, maximum, concatenate, append
from typing import List, Set, Dict, Optional, Union, Tuple
from indra.assemblers.english import EnglishAssembler
from indra.statements import Agent, Influence, Event, get_statement_by_name, \
Statement, Evidence
logger = logging.getLogger(__name__)
db_sources = ['psp', 'cbn', 'pc', 'bel_lc', 'signor', 'biogrid',
'tas', 'hprd', 'trrust', 'ctd', 'vhn', 'pe', 'drugbank',
'omnipath', 'conib', 'crog', 'dgi', 'minerva', 'creeds',
'ubibrowser', 'acsn']
"""Database source names as they appear in the DB"""
reader_sources = ['geneways', 'tees', 'gnbr', 'semrep', 'isi', 'trips',
'rlimsp', 'medscan', 'eidos', 'sparser', 'reach']
"""Reader source names as they appear in the DB"""
# These are mappings where the actual INDRA source, as it appears
# in the evidence source_api is inconsistent with the colors here and
# with what comes out of the INDRA DB
internal_source_mappings = {
'bel': 'bel_lc',
'phosphoelm': 'pe',
'biopax': 'pc',
'virhostnet': 'vhn',
'phosphosite': 'psp',
}
"""Maps from source_info.json names to DB names"""
reverse_source_mappings = {v: k for k, v in internal_source_mappings.items()}
"""Maps from db names to source_info.json names"""
all_sources = db_sources + reader_sources
"""Source names as they appear in the DB"""
# Dervied types
SourceColors = List[Tuple[str, Dict[str, Union[str, Dict[str, str]]]]]
def _get_relation_keyed_stmts(stmt_list, expand_nary=True):
"""Low level generator over a list of statements, intended for internal use.
Non-unique grouping keys are generated for each statement. Each row
generated will contain a tuple, beginning with a relation key,
generally of the from (verb, ...agents...), though sometimes different in
the case of some kinds of Statement, (e.g. ActiveForm), an agent key, which
is always (...agents...), and the Statement object itself.
If expand n-ary is set to True (the default), Complexes and
Conversions will have their many agents grouped into appropriate pairs,
with each pair yielded as a separate entry IN ADDITION to an entry for the
full set of agents. So Complex(A(), B(), C()) will yield entries for:
- Complex(A(), B()),
- Complex(B(), C()),
- Complex(A(), C()), and
- Complex(A(), B(), C()).
If False, only Complex(A(), B(), C()) will be generated.
"""
def name(agent):
return 'None' if agent is None else agent.name
for s in stmt_list:
# Create a key.
verb = s.__class__.__name__
ags = s.agent_list()
rel_key = None
if verb == 'Complex':
ag_ns = {name(ag) for ag in ags}
if expand_nary:
if 1 < len(ag_ns) < 6:
for pair in permutations(ag_ns, 2):
yield (verb,) + tuple(pair), tuple(pair), s
if len(ag_ns) == 2:
continue
ag_key = tuple(sorted(ag_ns))
elif verb == 'Conversion':
subj = name(s.subj)
objs_from = tuple(sorted({name(ag) for ag in s.obj_from}))
objs_to = tuple(sorted({name(ag) for ag in s.obj_to}))
if expand_nary:
for obj in objs_from:
yield (verb, subj, objs_from, objs_to), (subj, obj), s
for obj in objs_to:
yield (verb, subj, objs_from, objs_to), (subj, obj), s
ag_key = (subj, objs_from, objs_to)
elif verb in ['ActiveForm', 'HasActivity']:
ag_name = name(ags[0])
ag_key = (ag_name,)
rel_key = (verb, ag_name, s.activity,
s.is_active if verb == 'ActiveForm' else s.has_activity)
elif verb == 'Influence':
sns, sid = s.subj.concept.get_grounding()
ons, oid = s.obj.concept.get_grounding()
skey = s.subj.concept.name if not sid \
else sid.split('/')[-1].replace('_', ' ')
okey = s.obj.concept.name if not oid \
else oid.split('/')[-1].replace('_', ' ')
ag_key = (skey, okey)
else:
ag_key = tuple([name(ag) for ag in ags])
# Set the default relation key.
if rel_key is None:
rel_key = (verb,) + ag_key
# Yield the next (default) element.
yield rel_key, ag_key, s
[docs]class StmtStat:
"""Abstraction of a metric applied to a set of statements.
Can be instantiated either via the constructor or two factory class methods:
- s = StmtStat(name, {hash: value, ...}, data_type, AggClass)
- [s1, ...] = \
StmtStat.from_dicts({hash: {label: value, ...}, ...}, data_type, AggClass)
- [s_ev_count, s_belief] = \
StmtStat.from_stmts([Statement(), ...], ('ev_count', 'belief'))
Note that each stat will have only one metric associated with it, so dicts
ingested by `from_dicts` will have their values broken up into separate
StmtStat instances.
Parameters
----------
name : str
The label for this data (e.g. "ev_count" or "belief")
data : dict{int: Number}
The relevant statistics as a dict keyed by hash.
data_type : type
The type of the data (e.g. `int` or `float`).
agg_class : type
A subclass of BasicAggregator which defines how these statistics will be
merged.
"""
def __init__(self, name, data, data_type, agg_class):
self.name = name
self.data = data
self.data_type = data_type
self.agg_class = agg_class
[docs] @classmethod
def from_dicts(cls, dict_data, data_type, agg_class):
"""Generate a list of StmtStat's from a dict of dicts.
Example Usage:
>> source_counts = {9623812756876: {'reach': 1, 'sparser': 2},
>> -39877587165298: {'reach': 3, 'sparser': 0}}
>> stmt_stats = StmtStat.from_dicts(source_counts, int, SumAggregator)
Parameters
----------
dict_data : dict{int: dict{str: Number}}
A dictionary keyed by hash with dictionary elements, where each
element gives a set of measurements for the statement labels as
keys. A common example is `source_counts`.
data_type : type
The type of the data being given (e.g. `int` or `float`).
agg_class : type
A subclass of BasicAggregator which defines how these statistics
will be merged (e.g. `SumAggregator`).
"""
data_groups = defaultdict(dict)
for h, data_dict in dict_data.items():
for name, value in data_dict.items():
data_groups[name][h] = value
data_groups = dict(data_groups)
classes = []
for class_name, class_data in data_groups.items():
classes.append(cls(class_name, class_data, data_type, agg_class))
return classes
[docs] @classmethod
def from_stmts(cls, stmt_list, values=None):
"""Generate a list of StmtStat's from a list of stmts.
The stats will include "ev_count", "belief", and "ag_count" by default,
but a more limited selection may be specified using `values`.
Example usage:
>> stmt_stats = StmtStat.from_stmts(stmt_list, ('ag_count', 'belief'))
Parameters
----------
stmt_list : list[Statement]
A list of INDRA statements, from which basic stats will be derived.
values : Optional[tuple(str)]
A tuple of the names of the values to gather from the list of
statements. For example, if you already have evidence counts, you
might only want to gather belief and agent counts.
"""
type_dict = {'ev_count': {'type': int, 'agg': SumAggregator},
'belief': {'type': float, 'agg': MaxAggregator},
'ag_count': {'type': int, 'agg': SumAggregator}}
if values is None:
values = tuple(type_dict.keys())
# Iterate over statements, filling in values that may have been
# missing.
data = {k: {} for k in values}
for stmt in stmt_list:
sh = stmt.get_hash()
if 'ev_count' in values:
data['ev_count'][sh] = len(stmt.evidence)
if 'belief' in values:
data['belief'][sh] = stmt.belief
if 'ag_count' in values:
data['ag_count'][sh] = len(stmt.agent_list())
# Create the objects.
return [cls(k, d, type_dict[k]['type'], type_dict[k]['agg'])
for k, d in data.items()]
[docs]def make_standard_stats(ev_counts=None, beliefs=None, source_counts=None):
"""Generate the standard ev_counts, beliefs, and source count stats."""
stats = []
if ev_counts:
stats.append(StmtStat('ev_count', ev_counts, int, SumAggregator))
if beliefs:
stats.append(StmtStat('belief', beliefs, float, MaxAggregator))
if source_counts:
stats.extend(StmtStat.from_dicts(source_counts, int, SumAggregator))
return stats
[docs]class StmtGroup:
"""Creates higher-level stmt groupings and aggregates metrics accordingly.
Used internally by `group_and_sort_statements`.
This class manages the accumulation of statistics for statement groupings,
such as by relation or agent pair. It calculates metrics for these
higher-level groupings using metric-specific aggregators implementing the
AggregatorMeta API (e.g., MultiAggregator and any children of
BasicAggregator).
For example, evidence counts for a relation can be calculated as the sum of
the statement-level evidence counts, while the belief for the relation can
be calculated as the average or maximum of the statement-level beliefs.
The primary methods for instantiating this class are the two factory
class methods:
- from_stmt_stats
- from_dicts
See the methods for more details on their purpose and usage.
Once instantiated, the StmtGroup behaves like a defaultdict of lists, where
the keys are group-level keys, and the lists contain statements.
Statements can be iteratively added to the group via the dict-like syntax
`stmt_group[group_key].include(stmt)`. This allows the caller to generate
keys and trigger metric aggregation in a single iteration over statements.
Example usage:
.. code-block:: python
# Get ev_count, belief, and ag_count from a list of statements.
stmt_stats = StmtStat.from_stmts(stmt_list)
# Add another stat for a measure of relevance
stmt_stats.append(
StmtStat('relevance', relevance_dict, float, AveAggregator)
)
# Create the Group
sg = StmtGroup.from_stmt_stats(*stmt_stats)
# Load it full of Statements, grouped by agents.
sg.fill_from_stmt_stats()
sg.start()
for s in stmt_list:
key = (ag.get_grounding() for ag in s.agent_list())
sg[key].include(s)
sg.finish()
# Now the stats for each group are aggregated and available for use.
metrics = sg[(('FPLX', 'MEK'), ('FPLX', 'ERK'))].get_dict()
"""
[docs] @classmethod
def from_stmt_stats(cls, *stmt_stats):
"""Create a stmt group from StmtStat objects.
Return a StmtGroup constructed existing StmtStat objects. This
method offers the user the most control and customizability.
"""
# Organize the data into groups by aggregation class.
stat_groups = defaultdict(lambda: {'stats': defaultdict(list),
'keys': [], 'types': []})
for stat in stmt_stats:
if not isinstance(stat, StmtStat):
raise ValueError("All arguments must be `StmtStat` object.")
stat_groups[stat.agg_class]['keys'].append(stat.name)
stat_groups[stat.agg_class]['types'].append(stat.data_type)
for h, v in stat.data.items():
stat_groups[stat.agg_class]['stats'][h].append(v)
return cls(stat_groups)
[docs] @classmethod
def from_dicts(cls, ev_counts=None, beliefs=None, source_counts=None):
"""Init a stmt group from dicts keyed by hash.
Return a StmtGroup constructed from the given keyword arguments.
The dict keys of `source_counts` will be broken out into their own
StmtStat objects, so that the resulting data model is in effect a flat
list of measurement parameters. There is some risk of name collision, so
take care not to name any sources "ev_counts" or "belief".
"""
stats = make_standard_stats(ev_counts=ev_counts, beliefs=beliefs,
source_counts=source_counts)
return cls.from_stmt_stats(*stats)
def __init__(self, stat_groups):
"""In this case, init is primarily intended for internal use."""
self.__stats = {}
self.__started = False
self.__finished = False
# Check the groups and solidify them in more immutable types.
hash_set = None
self.__stmt_stats = {}
rows = []
for agg_class, info_dict in stat_groups.items():
if hash_set is None:
hash_set = set(info_dict['stats'].keys())
else:
if hash_set != set(info_dict['stats'].keys()):
raise ValueError(f"Stats from {info_dict['keys']} do "
f"not cover the same corpora of hashes.")
self.__stmt_stats[agg_class] = {
'stats': {h: array(l) for h, l in info_dict['stats'].items()},
'keys': tuple(info_dict['keys']),
'types': tuple(info_dict['types'])
}
rows.extend(info_dict['keys'])
self.__rows = tuple(rows)
[docs] def add_stats(self, *stmt_stats):
"""Add more stats to the object.
If you have started accumulating data from statements and doing
aggregation, (e.g. if you have "started"), or if you are "finished",
this request will lead to an error.
"""
new_stats = [s for s in stmt_stats if s.name not in self.row_set()]
if not new_stats:
return
if self.__started or self.__finished:
raise RuntimeError("Cannot add stats after accumulation has "
"started or after it has finished.")
for stat in new_stats:
if not isinstance(stat, StmtStat):
raise ValueError("All arguments must be StmtStat objects.")
if stat.agg_class in self.__stmt_stats:
self.__stmt_stats[stat.agg_class]['keys'] += (stat.name,)
self.__stmt_stats[stat.agg_class]['types'] += (stat.data_type,)
for h, v in stat.data.items():
old_arr = self.__stmt_stats[stat.agg_class]['stats'][h]
self.__stmt_stats[stat.agg_class]['stats'][h] = \
append(old_arr, v)
else:
self.__stmt_stats[stat.agg_class] = {
'stats': {h: array([v]) for h, v in stat.data.items()},
'keys': (stat.name,),
'types': (stat.data_type,)
}
self.__rows += (stat.name,)
return
[docs] def row_set(self):
"""Get a set of the rows (data labels) of the stats in this instance."""
return set(self.__rows)
def __getitem__(self, key):
if key not in self.__stats:
if not self.__started:
raise KeyError(f"Could not add key {key} before "
"accumulation started.")
if not self.__finished:
# Remember, this is passing REFERENCES to the stats dict.
self.__stats[key] = MultiAggregator(
agg_class(d['keys'], d['stats'], d['types'])
for agg_class, d in self.__stmt_stats.items())
else:
raise KeyError(f"Key \"{key}\" not found! "
f"{self.__class__.__name__} is finished.")
return self.__stats[key]
[docs] def start(self):
"""Mark the start of Statement aggregation.
This will freeze the addition of StmtStats and will enable new keyed
entries to be added and aggregated.
"""
self.__started = True
[docs] def finish(self):
"""Finish adding entries, new keys will be rejected."""
self.__finished = True
for stat_grp in self.__stats.values():
stat_grp.finish()
return
def is_finished(self):
return self.__finished
def is_started(self):
return self.__started
[docs] def get_new_instance(self):
"""Create an instance to gather another level of data."""
return self.__class__(self.__stmt_stats)
[docs] def fill_from_stmt_stats(self):
"""Use the statements stats as stats and hashes as keys.
This is used if you decide you just want to represent statements.
"""
if self.__started or self.__finished:
raise RuntimeError("Cannot fill from stats if accumulation has"
"already started or after it has finished.")
# Gather stat rows from the stmt_stats.
stat_rows = defaultdict(lambda: {'keys': tuple(), 'arr': array([]),
'types': tuple()})
for info_dict in self.__stmt_stats.values():
for h, arr in info_dict['stats'].items():
stat_rows[h]['keys'] += info_dict['keys']
stat_rows[h]['arr'] = concatenate([stat_rows[h]['arr'], arr])
stat_rows[h]['types'] += info_dict['types']
stat_rows = dict(stat_rows)
# Fill up the stats.
for h, data in stat_rows.items():
self.__stats[h] = BasicAggregator.from_array(data['keys'],
data['arr'],
data['types'])
# Mark as finished.
self.finish()
return
[docs]class MultiAggregator(AggregatorMeta):
"""Implement the AggregatorMeta API for multiple BasicAggregator children.
Takes an iterable of BasicAggregator children.
"""
def __init__(self, basic_aggs):
self.__basic_aggs = tuple(basic_aggs)
self.__keymap = {k: stat for stat in self.__basic_aggs
for k in stat.keys()}
return
[docs] def include(self, stmt):
for basic_agg in self.__basic_aggs:
basic_agg.include(stmt)
[docs] def get_dict(self):
return {k: v for basic_agg in self.__basic_aggs
for k, v in basic_agg.get_dict().items()}
def finish(self):
for basic_agg in self.__basic_aggs:
basic_agg.finish()
def __getitem__(self, key):
return self.__keymap[key][key]
[docs]class BasicAggregator(AggregatorMeta):
"""Gathers measurements for a statement or similar entity.
By defining a child of BasicAggregator, specifically defining the
operations that gather new data and finalize that data once all the
statements are collected, one can use arbitrary statistical methods to
aggregate metrics for high-level groupings of Statements for subsequent
sorting or filtering purposes.
Parameters
----------
keys : list[str]
A dict keyed by aggregation method of lists of the names for the
elements of data.
stmt_metrics : dict{int: np.ndarray}
A dictionary keyed by hash with each element a dict of arrays keyed
by aggregation type.
original_types : tuple(type)
The type classes of each numerical value stored in the base_group
dict, e.g. `(int, float, int)`.
"""
def __init__(self, keys, stmt_metrics, original_types):
self._keys = keys
self._stmt_metrics = stmt_metrics
self._original_types = original_types
self._values = zeros(len(keys))
self._count = 0
self.__stmt_hashes = set()
self.__frozen = False
self.__dict = None
@classmethod
def from_array(cls, keys, arr, original_types, stmt_metrics=None):
new_cls = cls(keys, stmt_metrics, original_types)
new_cls._values = arr
return new_cls
def _finalize(self):
return
def finish(self):
self._finalize()
self.__frozen = True
[docs] def include(self, stmt):
"""Include a statement and its statistics in the group."""
if self.__frozen:
raise RuntimeError(f"No longer adding more stmt data to "
f"{self.__class__.__name__}.")
if not isinstance(stmt, Statement):
raise ValueError(f"Invalid type for addition to BasicAggregator: "
f"{type(stmt)}. Must be a Statement.")
h = stmt.get_hash()
if h in self.__stmt_hashes:
return
assert self._stmt_metrics and h in self._stmt_metrics
self._merge(self._stmt_metrics[h])
self._count += 1
self.__stmt_hashes.add(h)
def _merge(self, metric_array):
raise NotImplemented
def __getitem__(self, item):
if item not in self._keys:
raise KeyError(f"Key '{item}' not found!")
idx = self._keys.index(item)
return self._values[idx].astype(self._original_types[idx])
def keys(self):
return self._keys[:]
[docs] def get_dict(self):
if not self.__frozen:
raise RuntimeError("Cannot load source dict until frozen.")
if self.__dict is None:
self.__dict = {key: value.astype(original_type)
for key, value, original_type
in zip(self._keys, self._values,
self._original_types)}
return self.__dict
[docs]class SumAggregator(BasicAggregator):
"""A stats aggregator that executes a sum."""
def _merge(self, metric_array):
self._values += metric_array
[docs]class AveAggregator(BasicAggregator):
"""A stats aggregator averages the included statement metrics."""
def _merge(self, metric_array):
self._values += metric_array
def _finalize(self):
self._values = self._values / self._count
[docs]class MaxAggregator(BasicAggregator):
"""A stats aggregator that takes the max of statement metrics."""
def _merge(self, metric_array):
self._values = maximum(self._values, metric_array)
def _get_ag_name_set_len(stmt):
return len(set(a.name if a else 'None' for a in stmt.agent_list()))
[docs]def group_and_sort_statements(stmt_list, sort_by='default', custom_stats=None,
grouping_level='agent-pair'):
"""Group statements by type and arguments, and sort by prevalence.
Parameters
----------
stmt_list : list[Statement]
A list of INDRA statements.
sort_by : str or function or None
If str, it indicates which parameter to sort by, such as 'belief' or
'ev_count', or 'ag_count'. Those are the default options because they
can be derived from a list of statements, however if you give a custom
`stmt_metrics`, you may use any of the parameters used to build it.
The default, 'default', is mostly a sort by ev_count but also favors
statements with fewer agents. Alternatively, you may give a function
that takes a dict as its single argument, a dictionary of metrics. These
metrics are determined by the contents of the `stmt_metrics` passed
as an argument (see StmtGroup for details), or else will contain
the default metrics that can be derived from the statements themselves,
namely `ev_count`, `belief`, and `ag_count`. The value may also
be None, in which case the sort function will return the
same value for all elements, and thus the original order of elements
will be preserved. This could have strange effects when statements are
grouped (i.e. when `grouping_level` is not 'statement'); such
functionality is untested and we make no guarantee that it will work.
custom_stats : list[StmtStat]
A list of custom statement statistics to be used in addition to, or upon
name conflict in place of, the default statement statistics derived from
the list of statements.
grouping_level : str
The options are 'agent-pair', 'relation', and 'statement'. These
correspond to grouping by agent pairs, agent and type relationships, and
a flat list of statements. The default is 'agent-pair'.
Returns
-------
sorted_groups : list[tuple]
A list of tuples of the form (sort_param, key, contents, metrics), where
the sort param is whatever value was calculated to sort the results,
the key is the unique key for the agent pair, relation, or statements,
and the contents are either relations, statements, or statement JSON,
depending on the level. This structure is recursive, so the each list
of relations will also follow this structure, all the way down to
the lowest level (statement JSON). The metrics a dict of the aggregated
metrics for the entry (e.g. source counts, evidence counts, etc).
"""
# Validate the grouping level parameter.
if grouping_level not in ['agent-pair', 'relation', 'statement']:
raise ValueError(f"Invalid grouping level: \"{grouping_level}\".")
# Get any missing default metrics.
if custom_stats is not None:
stats = custom_stats[:]
stat_rows = {stat.name for stat in custom_stats}
else:
stats = []
stat_rows = set()
missing_rows = {'ev_count', 'belief', 'ag_count'} - stat_rows
if missing_rows:
stats += StmtStat.from_stmts(stmt_list, missing_rows)
# Init the base group.
base_group = StmtGroup.from_stmt_stats(*stats)
base_group.fill_from_stmt_stats()
# Define the sort function.
if isinstance(sort_by, str):
def _sort_func(metric):
assert isinstance(sort_by, str)
if sort_by == 'default':
return metric['ev_count'] + 1/(1 + metric['ag_count'])
return metric[sort_by]
elif sort_by is None:
def _sort_func(metric):
return 0
else:
# Check that the sort function is a valid function.
sample_dict = dict.fromkeys(base_group.row_set(), 0)
try:
n = sort_by(sample_dict)
# If the return value is not sortable, this will raise a TypeError.
n < n
except Exception as e:
raise ValueError(f"Invalid sort function: {e}")
# Assign the function.
_sort_func = sort_by
# Write a recursive method to group statement content.
def iter_rows(rows, *metric_dicts):
assert metric_dicts
for key, contents in rows:
metrics = metric_dicts[0][key].get_dict()
if len(metric_dicts) > 1:
if isinstance(contents, dict):
contents = contents.items()
contents = sorted_rows(contents, *metric_dicts[1:])
yield (_sort_func(metrics), str(key)) if sort_by else 0, \
key, contents, metrics
def sorted_rows(rows, *metric_dicts):
return sorted(iter_rows(rows, *metric_dicts), key=lambda t: t[0],
reverse=True)
# Return the sorted statements, if that's all you want.
if grouping_level == 'statement':
stmt_rows = ((s.get_hash(), s) for s in stmt_list)
return sorted_rows(stmt_rows, base_group)
# Create gathering metrics from the statement data.
relation_metrics = base_group.get_new_instance()
relation_metrics.start()
if grouping_level == 'agent-pair':
agent_pair_metrics = base_group.get_new_instance()
agent_pair_metrics.start()
# Add up the grouped statements from the metrics.
if grouping_level == 'relation':
grouped_stmts = defaultdict(list)
else:
grouped_stmts = defaultdict(lambda: defaultdict(list))
expand = (grouping_level == 'agent-pair')
for rel_key, ag_key, stmt in _get_relation_keyed_stmts(stmt_list, expand):
relation_metrics[rel_key].include(stmt)
if grouping_level == 'agent-pair':
grouped_stmts[ag_key][rel_key].append((stmt.get_hash(), stmt))
agent_pair_metrics[ag_key].include(stmt)
else:
grouped_stmts[rel_key].append((stmt.get_hash(), stmt))
# Stop filling these stmt groups. No more "new" keys.
relation_metrics.finish()
if grouping_level == 'agent-pair':
agent_pair_metrics.finish()
# Sort the rows by count and agent names.
if grouping_level == 'relation':
return sorted_rows(grouped_stmts.items(), relation_metrics,
base_group)
return sorted_rows(grouped_stmts.items(), agent_pair_metrics,
relation_metrics, base_group)
[docs]def make_stmt_from_relation_key(relation_key, agents=None):
"""Make a Statement from the relation key.
Specifically, make a Statement object from the sort key used by
`group_and_sort_statements`.
"""
def make_agent(name):
if name == 'None' or name is None:
return None
return Agent(name)
verb = relation_key[0]
inps = relation_key[1:]
StmtClass = get_statement_by_name(verb)
if agents is None:
agents = []
if verb == 'Complex':
agents.extend([make_agent(name) for name in inps])
stmt = StmtClass(agents[:])
elif verb == 'Conversion':
names_from = [make_agent(name) for name in inps[1]]
names_to = [make_agent(name) for name in inps[2]]
agents.extend(names_from + names_to)
stmt = StmtClass(make_agent(inps[0]), names_from, names_to)
elif verb == 'ActiveForm' or verb == 'HasActivity':
agents.extend([make_agent(inps[0])])
stmt = StmtClass(agents[0], inps[1], inps[2])
elif verb == 'Influence':
agents.extend([make_agent(inp) for inp in inps[:2]])
stmt = Influence(*[Event(ag) for ag in agents])
elif verb == 'Association':
agents.extend([make_agent(inp) for inp in inps])
stmt = StmtClass([Event(ag) for ag in agents])
else:
agents.extend([make_agent(name) for name in inps])
stmt = StmtClass(*agents)
return stmt
[docs]def stmt_to_english(stmt):
"""Return an English assembled Statement as a sentence."""
ea = EnglishAssembler([stmt])
return ea.make_model()[:-1]
[docs]def make_string_from_relation_key(rel_key):
"""Make a Statement string via EnglishAssembler from the relation key.
Specifically, make a string from the key used by `group_and_sort_statements`
for contents grouped at the relation level.
"""
stmt = make_stmt_from_relation_key(rel_key)
return stmt_to_english(stmt)
def get_simplified_stmts(stmts):
simple_stmts = []
for rel_key, _, _ in _get_relation_keyed_stmts(stmts, expand_nary=False):
simple_stmts.append(make_stmt_from_relation_key(rel_key))
return simple_stmts
def _str_conversion_bits(tpl):
bolds = ['<b>%s</b>' % el for el in tpl]
return ', '.join(bolds[:-1]) + ', and ' + bolds[-1]
[docs]def make_top_level_label_from_names_key(names):
"""Make an english string from the tuple names."""
try:
if len(names) == 3 and isinstance(names[1], tuple): # Conversions
el_from = _str_conversion_bits(names[1])
el_to = _str_conversion_bits(names[2])
tl_label = ("<b>%s</b> converts %s to %s"
% (names[0], el_from, el_to))
else:
b_names = ['<b>%s</b>' % name for name in names]
if len(names) == 1:
tl_label = b_names[0]
elif len(names) == 2: # Singleton Modifications
if names[0] is None or names[0] == 'None':
tl_label = b_names[1] + " is modified"
else:
tl_label = b_names[0] + " affects " + b_names[1]
elif names[1] == "activity": # ActiveForms
if names[2] or names[2] == "True":
tl_label = b_names[0] + " is active"
else:
tl_label = b_names[0] + " is not active"
else: # Large Complexes
tl_label = b_names[0] + " affects "
tl_label += ", ".join(b_names[1:-1]) + ', and ' + b_names[-1]
return tl_label
except Exception as e:
logger.error("Could not handle: %s" % str(names))
raise e
[docs]def standardize_counts(counts):
"""Standardize hash-based counts dicts to be int-keyed."""
standardized_counts = {}
for k, v in counts.items():
try:
int_k = int(k)
standardized_counts[int_k] = v
except ValueError:
logger.warning('Could not convert statement hash %s to int' % k)
return standardized_counts
def get_available_ev_counts(stmts):
return {stmt.get_hash(): len(stmt.evidence) for stmt in stmts}
def get_available_beliefs(stmts):
return {stmt.get_hash(): stmt.belief for stmt in stmts}
def get_available_source_counts(
stmts: List[Statement],
custom_sources: Optional[List[str]] = None
) -> Dict[int, Dict[str, int]]:
return {stmt.get_hash(): _get_available_ev_source_counts(
stmt.evidence, custom_sources) for stmt in stmts}
[docs]def available_sources_stmts(
stmts: List[Statement],
custom_sources: Optional[List[str]] = None
) -> Set[str]:
"""Returns the set of sources available in a list of statements"""
sources = set()
for stmt in stmts:
source_counts = _get_available_ev_source_counts(stmt.evidence,
custom_sources)
for src, count in source_counts.items():
if count > 0:
sources.add(src)
# Break if all sources are present
if len(sources) == len(all_sources):
break
return sources
[docs]def available_sources_src_counts(
source_counts: Dict[int, Dict[str, int]],
custom_sources: Optional[List[str]] = None
) -> Set[str]:
"""Returns the set of sources available from a source counts dict"""
every_source = all_sources if custom_sources is None else custom_sources
sources = set()
for _, src_count in source_counts.items():
for src, count in src_count.items():
if count > 0:
sources.add(src)
if len(sources) == len(every_source):
break
return sources
def _get_available_ev_source_counts(
evidences: List[Evidence],
custom_sources: Optional[List[str]] = None
) -> Dict[str, int]:
counts = _get_initial_source_counts(custom_sources)
for ev in evidences:
sa = internal_source_mappings.get(ev.source_api, ev.source_api)
try:
counts[sa] += 1
except KeyError:
continue
return counts
def _get_initial_source_counts(custom_sources: Optional[List[str]] = None):
if custom_sources is None:
return {s: 0 for s in all_sources}
else:
return {s: 0 for s in custom_sources}