Source code for trx.utils

# -*- coding: utf-8 -*-
"""Utility functions for reference handling, coordinate flips, and file operations."""

import logging
import os

import nibabel as nib
from nibabel.streamlines.array_sequence import ArraySequence
from nibabel.streamlines.tractogram import Tractogram, TractogramItem
import numpy as np

# Optional dependency: helpers built on dipy's StatefulTractogram check
# ``dipy_available`` before touching dipy.
try:
    import dipy
except ImportError:
    dipy_available = False
else:
    dipy_available = True
def close_or_delete_mmap(obj):
    """Close the memory map reachable from ``obj``, if any.

    Parameters
    ----------
    obj : object
        Object that may be backed by a memory-mapped file:

        - anything with a non-None ``_mmap`` attribute (e.g. an
          ``np.memmap``) has ``_mmap.close()`` called;
        - an ``ArraySequence`` is handled by recursing into its ``_data``,
          ``_offsets`` and ``_lengths`` arrays;
        - an ``np.memmap`` whose ``_mmap`` is None is left untouched;
        - anything else only produces a debug log message.

    Returns
    -------
    None
    """
    if getattr(obj, "_mmap", None) is not None:
        obj._mmap.close()
    elif isinstance(obj, ArraySequence):
        for part in (obj._data, obj._offsets, obj._lengths):
            close_or_delete_mmap(part)
    elif isinstance(obj, np.memmap):
        # There is no owned mapping to close here. ``del obj`` would only
        # unbind this function's local name and would not affect the
        # caller's object, so nothing is done on purpose.
        pass
    else:
        logging.debug("Object to be close or deleted must be np.memmap")
def split_name_with_gz(filename):
    """Split ``filename`` into a basename and its (possibly compound) extension.

    A trailing ``.gz`` is merged with the preceding extension when that
    extension is ``.nii`` or ``.trk``. For example, ``"t1.nii.gz"`` gives
    ``("t1", ".nii.gz")``, while ``"notes.txt.gz"`` gives
    ``("notes.txt", ".gz")``.

    Parameters
    ----------
    filename : str
        The filename to split.

    Returns
    -------
    base : str
        Filename without the detected extension.
    ext : str
        The detected extension, including the leading dot, or ``""``.
    """
    root, extension = os.path.splitext(filename)
    if extension != ".gz":
        return root, extension

    inner_root, inner_extension = os.path.splitext(root)
    if inner_extension in (".nii", ".trk"):
        return inner_root, inner_extension + extension
    return root, extension
def get_reference_info_wrapper(reference):  # noqa: C901
    """Extract spatial attributes from a reference object.

    Parameters
    ----------
    reference : str or dict or Nifti1Image or TrkFile or Nifti1Header or TrxFile or StatefulTractogram
        Reference that provides the spatial attribute.

        - A ``str`` is dispatched on its extension: ``.nii``, ``.nii.gz``,
          ``.trk`` or ``.trx``.
        - A ``dict`` is treated as a TRK header when it has a
          ``"magic_number"`` key, and as a TRX header when it has a
          ``"NB_VERTICES"`` key.
        - A ``StatefulTractogram`` is only recognized when dipy is
          installed.

    Returns
    -------
    affine : ndarray (4, 4)
        Transformation of VOX to RASMM, as stored by the source.
    dimensions : ndarray (3,)
        Volume shape for each axis.
    voxel_sizes : ndarray (3,)
        Size of voxel for each axis.
    voxel_order : str
        Axis codes, typically 'RAS' or 'LPS'. Byte strings are decoded
        as UTF-8.

    Raises
    ------
    TypeError
        If ``reference`` (or a filename's extension) is not supported.
    ValueError
        If a NIfTI affine's 3x3 part contains only zeros.
    """
    from trx import trx_file_memmap

    is_nifti = False
    is_trk = False
    is_sft = False
    is_trx = False

    if isinstance(reference, str):
        _, ext = split_name_with_gz(reference)
        if ext in [".nii", ".nii.gz"]:
            header = nib.load(reference).header
            is_nifti = True
        elif ext == ".trk":
            header = nib.streamlines.load(reference, lazy_load=True).header
            is_trk = True
        elif ext == ".trx":
            trx = trx_file_memmap.load(reference)
            try:
                # Copy the header so it stays usable after close(). close()
                # is TrxFile's cleanup entry point and is assumed to release
                # its memory maps / temporary folder (verify in
                # trx_file_memmap).
                header = dict(trx.header)
            finally:
                trx.close()
            is_trx = True
    elif isinstance(reference, trx_file_memmap.TrxFile):
        header = reference.header
        is_trx = True
    elif isinstance(reference, nib.nifti1.Nifti1Image):
        header = reference.header
        is_nifti = True
    elif isinstance(reference, nib.streamlines.trk.TrkFile):
        header = reference.header
        is_trk = True
    elif isinstance(reference, nib.nifti1.Nifti1Header):
        header = reference
        is_nifti = True
    elif isinstance(reference, dict) and "magic_number" in reference:
        header = reference
        is_trk = True
    elif isinstance(reference, dict) and "NB_VERTICES" in reference:
        header = reference
        is_trx = True
    elif dipy_available:
        # Import the class explicitly. A bare ``import dipy`` may not expose
        # ``dipy.io.stateful_tractogram`` as an attribute, which would raise
        # AttributeError here instead of the TypeError below.
        from dipy.io.stateful_tractogram import StatefulTractogram

        is_sft = isinstance(reference, StatefulTractogram)

    if is_nifti:
        affine = header.get_best_affine()
        dimensions = header["dim"][1:4]
        voxel_sizes = header["pixdim"][1:4]
        if not affine[0:3, 0:3].any():
            raise ValueError(
                "Invalid affine, contains only zeros. "
                "Cannot determine voxel order from transformation"
            )
        voxel_order = "".join(nib.aff2axcodes(affine))
    elif is_trk:
        affine = header["voxel_to_rasmm"]
        dimensions = header["dimensions"]
        voxel_sizes = header["voxel_sizes"]
        voxel_order = header["voxel_order"]
    elif is_sft:
        affine, dimensions, voxel_sizes, voxel_order = reference.space_attributes
    elif is_trx:
        affine = header["VOXEL_TO_RASMM"]
        dimensions = header["DIMENSIONS"]
        voxel_sizes = nib.affines.voxel_sizes(affine)
        voxel_order = "".join(nib.aff2axcodes(affine))
    else:
        raise TypeError("Input reference is not one of the supported format")

    # np.bytes_ subclasses bytes, so this covers TRK headers as well as
    # plain byte strings supplied in a dict.
    if isinstance(voxel_order, bytes):
        voxel_order = voxel_order.decode("utf-8")

    if dipy_available:
        from dipy.io.utils import is_reference_info_valid

        # Best-effort sanity check. The boolean result is deliberately not
        # used to reject the reference.
        is_reference_info_valid(affine, dimensions, voxel_sizes, voxel_order)

    return affine, dimensions, voxel_sizes, voxel_order
def is_header_compatible(reference_1, reference_2):
    """Check whether two references describe the same spatial grid.

    Both references are resolved with ``get_reference_info_wrapper`` and
    their attributes are compared one by one:

    - affines and voxel sizes use ``np.allclose`` with
      ``rtol=atol=1e-03``;
    - dimensions must be exactly equal (``np.array_equal``);
    - voxel orders must compare equal.

    Every check runs, and each mismatch is reported with ``logging.error``.

    Parameters
    ----------
    reference_1 : str or dict or Nifti1Image or TrkFile or Nifti1Header or TrxFile
        First reference, of any kind accepted by
        ``get_reference_info_wrapper``.
    reference_2 : str or dict or Nifti1Image or TrkFile or Nifti1Header or TrxFile
        Second reference, of any kind accepted by
        ``get_reference_info_wrapper``.

    Returns
    -------
    bool
        True when all spatial attributes match, False otherwise.
    """
    affine_a, dims_a, sizes_a, order_a = get_reference_info_wrapper(reference_1)
    affine_b, dims_b, sizes_b, order_b = get_reference_info_wrapper(reference_2)

    tolerance = {"rtol": 1e-03, "atol": 1e-03}
    # Mismatch tests are evaluated lazily and in sequence, so each message
    # is logged right after its own comparison.
    mismatch_checks = (
        (lambda: not np.allclose(affine_a, affine_b, **tolerance), "Affine not equal"),
        (lambda: not np.array_equal(dims_a, dims_b), "Dimensions not equal"),
        (lambda: not np.allclose(sizes_a, sizes_b, **tolerance), "Voxel_size not equal"),
        (lambda: order_a != order_b, "Voxel_order not equal"),
    )

    compatible = True
    for is_mismatch, message in mismatch_checks:
        if is_mismatch():
            logging.error(message)
            compatible = False
    return compatible
def get_axis_shift_vector(flip_axes):
    """Build a per-axis shift vector from the axes selected for flipping.

    Parameters
    ----------
    flip_axes : list of str or str
        Axes to flip, tested by membership for each of 'x', 'y', 'z'.

    Returns
    -------
    shift_vector : np.ndarray (3,)
        Float vector holding -1.0 at index 0/1/2 when 'x'/'y'/'z' is in
        ``flip_axes``, and 0.0 elsewhere.
    """
    return np.array([-1.0 if axis in flip_axes else 0.0 for axis in "xyz"])
def get_axis_flip_vector(flip_axes):
    """Build a per-axis sign vector from the axes selected for flipping.

    Parameters
    ----------
    flip_axes : list of str or str
        Axes to flip, tested by membership for each of 'x', 'y', 'z'.

    Returns
    -------
    flip_vector : np.ndarray (3,)
        Float vector holding -1.0 at index 0/1/2 when 'x'/'y'/'z' is in
        ``flip_axes``, and 1.0 elsewhere.
    """
    selected = [axis in flip_axes for axis in "xyz"]
    return np.where(selected, -1.0, 1.0)
def get_shift_vector(sft):
    """Return minus half of the grid dimensions of ``sft``.

    ``flip_sft`` adds this vector to the coordinates before flipping and
    subtracts it afterwards, so that the flip is taken about the centre of
    the grid rather than its corner.

    Parameters
    ----------
    sft : StatefulTractogram
        Tractogram whose ``space_attributes[1]`` holds the grid dimensions.

    Returns
    -------
    shift_vector : ndarray
        ``-dimensions / 2`` as a float array.
    """
    grid_dims = np.array(sft.space_attributes[1])
    return -(grid_dims / 2.0)
def flip_sft(sft, flip_axes):
    """Mirror every streamline of a StatefulTractogram along the given axes.

    Each point ``p`` becomes ``(p + s) * f - s``, where:

    - ``s = get_shift_vector(sft)`` is minus half the grid dimensions;
    - ``f = get_axis_flip_vector(flip_axes)`` is -1 on flipped axes and
      1 elsewhere.

    Coordinates are used as stored.
    NOTE(review): this is a flip about the grid centre only if ``sft`` is
    already in voxel space with a corner origin; presumably the caller
    ensures this — confirm.

    Parameters
    ----------
    sft : StatefulTractogram
        StatefulTractogram to flip.
    flip_axes : list of str
        Axes to flip. Possible values are 'x', 'y', 'z'.

    Returns
    -------
    StatefulTractogram or None
        New tractogram built with ``StatefulTractogram.from_sft`` from the
        flipped streamlines, reusing ``sft``'s data_per_point and
        data_per_streamline. Returns None (after logging an error) when
        dipy is not installed.
    """
    if not dipy_available:
        logging.error(
            "Dipy library is missing, cannot use functions related "
            "to the StatefulTractogram."
        )
        return None

    axis_signs = get_axis_flip_vector(flip_axes)
    centre_offset = get_shift_vector(sft)

    def _mirror(points):
        # Move to centre-based coordinates, flip, then move back.
        moved = points + centre_offset
        moved *= axis_signs
        moved -= centre_offset
        return moved

    flipped_streamlines = [_mirror(points) for points in sft.streamlines]

    from dipy.io.stateful_tractogram import StatefulTractogram

    return StatefulTractogram.from_sft(
        flipped_streamlines,
        sft,
        data_per_point=sft.data_per_point,
        data_per_streamline=sft.data_per_streamline,
    )
def load_matrix_in_any_format(filepath):
    """Load a matrix from a ``.txt`` (text) or ``.npy`` (NumPy binary) file.

    The extension is matched case-insensitively, so ``matrix.TXT`` and
    ``matrix.NPY`` are accepted as well.

    Parameters
    ----------
    filepath : str
        Path to the matrix file.

    Returns
    -------
    matrix : numpy.ndarray
        The matrix, read with ``np.loadtxt`` for ``.txt`` files and
        ``np.load`` for ``.npy`` files.

    Raises
    ------
    ValueError
        If the extension is neither ``.txt`` nor ``.npy``.
    """
    _, ext = os.path.splitext(filepath)
    loaders = {".txt": np.loadtxt, ".npy": np.load}
    loader = loaders.get(ext.lower())
    if loader is None:
        raise ValueError("Extension {} is not supported".format(ext))
    return loader(filepath)
def get_reverse_enum(space_str, origin_str):
    """Map textual space/origin names to dipy's ``Space``/``Origin`` enums.

    Parameters
    ----------
    space_str : str
        Space name, compared case-insensitively. ``"rasmm"`` maps to
        ``Space.RASMM`` and ``"voxmm"`` to ``Space.VOXMM``; any other value
        gives ``Space.VOX``.
    origin_str : str
        Origin name, compared case-insensitively. ``"nifti"`` maps to
        ``Origin.NIFTI``; any other value gives ``Origin.TRACKVIS``.

    Returns
    -------
    tuple of (Space, Origin) or None
        ``(space, origin)``. Returns None (after logging an error) when dipy
        is not installed.
    """
    if not dipy_available:
        logging.error(
            "Dipy library is missing, cannot use functions related "
            "to the StatefulTractogram."
        )
        return None

    from dipy.io.stateful_tractogram import Origin, Space

    if origin_str.lower() == "nifti":
        origin = Origin.NIFTI
    else:
        origin = Origin.TRACKVIS

    space_by_name = {"rasmm": Space.RASMM, "voxmm": Space.VOXMM}
    space = space_by_name.get(space_str.lower(), Space.VOX)
    return space, origin
def _values_per_row(total, nb_rows):
    """Return how many values each of ``nb_rows`` rows holds (0 when there are no rows)."""
    return total // nb_rows if nb_rows else 0


def convert_data_dict_to_tractogram(data):
    """Build a nibabel ``Tractogram`` from an accumulator dictionary.

    Parameters
    ----------
    data : dict
        Dictionary with the following keys, as filled by
        ``append_generator_to_dict``:

        - ``"strs"``: the streamlines;
        - ``"dps"``: dict of flat per-streamline values;
        - ``"dpv"``: dict of flat per-vertex values.

        NOTE: ``data["dps"]`` and ``data["dpv"]`` are modified in place.
        Their values are replaced by the reshaped arrays and
        ArraySequences that are handed to the Tractogram.

    Returns
    -------
    Tractogram
        A Tractogram in which:

        - ``data_per_streamline[key]`` has shape
          ``(n_streamlines, n_values)``;
        - ``data_per_point[key]`` is an ArraySequence sharing the
          streamlines' offsets and lengths.

        Empty input yields an empty Tractogram instead of a
        ZeroDivisionError.
    """
    streamlines = ArraySequence(data["strs"])
    nb_streamlines = len(streamlines)
    nb_vertices = len(streamlines._data)

    for key in data["dps"]:
        raw = data["dps"][key]
        width = _values_per_row(len(raw), nb_streamlines)
        data["dps"][key] = np.array(raw).reshape((nb_streamlines, width))

    for key in data["dpv"]:
        raw = data["dpv"][key]
        width = _values_per_row(len(raw), nb_vertices)
        per_point = ArraySequence()
        per_point._data = np.array(raw).reshape((nb_vertices, width))
        # Reuse the streamline layout so each vertex value lines up with
        # its point.
        per_point._offsets = streamlines._offsets
        per_point._lengths = streamlines._lengths
        data["dpv"][key] = per_point

    return Tractogram(
        streamlines, data_per_point=data["dpv"], data_per_streamline=data["dps"]
    )
def append_generator_to_dict(gen, data):
    """Accumulate one item from a tractogram generator into ``data`` in place.

    Parameters
    ----------
    gen : TractogramItem or np.ndarray
        A ``TractogramItem`` contributes its streamline together with its
        ``data_for_points`` and ``data_for_streamline`` entries. Any other
        object is treated as a bare streamline array.
    data : dict
        Accumulator with a ``"strs"`` list and ``"dpv"``/``"dps"`` dicts.
        Each streamline is appended to ``"strs"`` as nested lists
        (``tolist()``). Metadata values are concatenated with ``np.append``
        onto one array per key; the array starts empty on first use.

    Returns
    -------
    None
        ``data`` is mutated; nothing is returned.
    """
    if not isinstance(gen, TractogramItem):
        data["strs"].append(gen.tolist())
        return

    data["strs"].append(gen.streamline.tolist())
    for bucket_name, source in (
        ("dpv", gen.data_for_points),
        ("dps", gen.data_for_streamline),
    ):
        bucket = data[bucket_name]
        for key in source:
            previous = bucket.setdefault(key, np.array([]))
            bucket[key] = np.append(previous, source[key])
def _dtype_matches(actual, expected, message):
    """Return True if ``actual == expected``; otherwise log ``message`` as a warning and return False."""
    if actual != expected:
        logging.warning(message)
        return False
    return True


def verify_trx_dtype(trx, dict_dtype):  # noqa: C901
    """Verify that the dtypes of a TrxFile's arrays match ``dict_dtype``.

    Parameters
    ----------
    trx : TrxFile
        Object exposing ``streamlines`` (an ArraySequence),
        ``data_per_vertex``, ``data_per_streamline``, ``data_per_group`` and
        ``groups``.
    dict_dtype : dict
        Expected dtypes. Recognized keys:

        - ``"positions"``, ``"offsets"``: a dtype;
        - ``"dpv"``, ``"dps"``: ``{name: dtype}``;
        - ``"dpg"``: ``{group: {name: dtype}}``;
        - ``"groups"``: ``{group: dtype}``.

        Other keys are ignored.

    Returns
    -------
    bool
        True if every listed dtype matches, False otherwise. Every entry is
        checked, and each mismatch is logged as a warning.
    """
    identical = True
    for key in dict_dtype:
        expected = dict_dtype[key]
        if key == "positions":
            if not _dtype_matches(
                trx.streamlines._data.dtype, expected, "Positions dtype is different"
            ):
                identical = False
        elif key == "offsets":
            if not _dtype_matches(
                trx.streamlines._offsets.dtype, expected, "Offsets dtype is different"
            ):
                identical = False
        elif key == "dpv":
            for name in expected:
                if not _dtype_matches(
                    trx.data_per_vertex[name]._data.dtype,
                    expected[name],
                    "Data per vertex ({}) dtype is different".format(name),
                ):
                    identical = False
        elif key == "dps":
            for name in expected:
                if not _dtype_matches(
                    trx.data_per_streamline[name].dtype,
                    expected[name],
                    "Data per streamline ({}) dtype is different".format(name),
                ):
                    identical = False
        elif key == "dpg":
            # Per-group metadata lives in ``data_per_group``; TrxFile has no
            # ``data_per_point`` attribute.
            for group in expected:
                for name in expected[group]:
                    if not _dtype_matches(
                        trx.data_per_group[group][name].dtype,
                        expected[group][name],
                        "Data per group ({}) dtype is different".format(name),
                    ):
                        identical = False
        elif key == "groups":
            # Group membership arrays live in ``groups`` and are plain arrays.
            for group in expected:
                if not _dtype_matches(
                    trx.groups[group].dtype,
                    expected[group],
                    "Group ({}) dtype is different".format(group),
                ):
                    identical = False
    return identical