# Source code for trx.streamlines_ops
# -*- coding: utf-8 -*-
"""Set operations on streamlines with precision-based matching."""
from functools import reduce
import itertools
import numpy as np
[docs]
# Number of points taken from each end of a streamline to build its hash key.
# Referenced by get_streamline_key; it was previously used but never defined.
MIN_NB_POINTS = 5
# Indices of the first and last MIN_NB_POINTS points:
# [0, 1, 2, 3, 4, -1, -2, -3, -4, -5].
KEY_INDEX = np.concatenate(
    (range(MIN_NB_POINTS), range(-1, -MIN_NB_POINTS - 1, -1))
)
[docs]
def intersection(left, right):
    """Keep the entries of ``left`` whose key also appears in ``right``.

    Parameters
    ----------
    left : dict
        Hash dictionary returned by :func:`hash_streamlines`. Its values are
        the ones kept in the result.
    right : dict
        Hash dictionary returned by :func:`hash_streamlines`. Only its keys
        are consulted.

    Returns
    -------
    dict
        New dictionary with the ``left`` entries whose keys are shared with
        ``right``, in ``left``'s iteration order.
    """
    shared = {}
    for key, index in left.items():
        if key in right:
            shared[key] = index
    return shared
[docs]
def difference(left, right):
    """Keep the entries of ``left`` whose key is absent from ``right``.

    Parameters
    ----------
    left : dict
        Hash dictionary returned by :func:`hash_streamlines`. Its values are
        the ones kept in the result.
    right : dict
        Hash dictionary returned by :func:`hash_streamlines`. Only its keys
        are consulted.

    Returns
    -------
    dict
        New dictionary with the ``left`` entries whose keys do not occur in
        ``right``, in ``left``'s iteration order.
    """
    return dict(
        (key, index) for key, index in left.items() if key not in right
    )
[docs]
def union(left, right):
    """Merge two streamline hash dictionaries into a new one.

    Parameters
    ----------
    left : dict
        Hash dictionary returned by :func:`hash_streamlines`. Its values take
        precedence on shared keys.
    right : dict
        Hash dictionary returned by :func:`hash_streamlines`.

    Returns
    -------
    dict
        A copy of ``right`` into which every ``left`` entry has been written,
        so shared keys carry ``left``'s value. Neither input is modified.
    """
    merged = right.copy()
    for key, index in left.items():
        merged[key] = index
    return merged
[docs]
def get_streamline_key(streamline, precision=None):
    """Produce a hash key from a streamline using a few points.

    Only the points selected by ``KEY_INDEX`` (the first and last few points)
    are used, so the cost does not grow with the streamline length.
    Streamlines too short for every index in ``KEY_INDEX`` to be valid are
    used whole.

    Parameters
    ----------
    streamline : array_like
        A single streamline (N, 3).
    precision : int, optional
        The number of decimals to keep when hashing the points of the
        streamlines. Allows a soft comparison of streamlines. If None, no
        rounding is performed.

    Returns
    -------
    bytes
        Raw bytes of the selected (and optionally rounded) points. Negative
        zeros are normalised to positive zero before serialisation.
    """
    points = np.asarray(streamline)

    # Shortest length for which every entry of KEY_INDEX is a valid index:
    # positive indices need len > max index, negative ones need len >= -min.
    min_nb_points = max(int(KEY_INDEX.max()) + 1, -int(KEY_INDEX.min()))

    if len(points) < min_nb_points:
        key = points.copy()
    else:
        # Advanced (integer-array) indexing already returns a copy.
        key = points[KEY_INDEX]

    if precision is not None:
        key = np.round(key, precision)

    if key.dtype.kind == 'f':
        # -0.0 and 0.0 compare equal but differ in their byte pattern, and
        # rounding small negatives (e.g. np.round(-0.001, 2)) yields -0.0.
        # Adding +0.0 maps -0.0 to +0.0 so equal points give equal keys.
        key += 0.0

    return key.tobytes()
[docs]
def hash_streamlines(streamlines, start_index=0, precision=None):
    """Map each streamline's hash key to its index.

    Every streamline is turned into a key with :func:`get_streamline_key`.
    The key is associated with the streamline's position, counted from
    ``start_index``. When several streamlines produce the same key, the
    later index overwrites the earlier one.

    Parameters
    ----------
    streamlines : list of ndarray
        The streamlines to hash, in order.
    start_index : int, optional
        Index assigned to the first streamline. 0 by default.
    precision : int, optional
        The number of decimals to keep when hashing the points of the
        streamlines. Allows a soft comparison of streamlines. If None, no
        rounding is performed.

    Returns
    -------
    dict
        Keys are the byte keys from :func:`get_streamline_key`; values are
        streamline indices starting at ``start_index``.
    """
    key_to_index = {}
    for index, streamline in enumerate(streamlines, start=start_index):
        key_to_index[get_streamline_key(streamline, precision)] = index
    return key_to_index