Source code for trx.utils

# -*- coding: utf-8 -*-

import logging
import os

import nibabel as nib
from nibabel.streamlines.array_sequence import ArraySequence
from nibabel.streamlines.tractogram import Tractogram, TractogramItem
import numpy as np

try:
    import dipy

[docs] dipy_available = True
except ImportError: dipy_available = False
[docs] def close_or_delete_mmap(obj): """Close the memory-mapped file if it exists, otherwise set the object to None. Parameters ---------- obj : object The object that potentially has a memory-mapped file to be closed. """ if hasattr(obj, "_mmap") and obj._mmap is not None: obj._mmap.close() elif isinstance(obj, ArraySequence): close_or_delete_mmap(obj._data) close_or_delete_mmap(obj._offsets) close_or_delete_mmap(obj._lengths) elif isinstance(obj, np.memmap): del obj else: logging.debug("Object to be close or deleted must be np.memmap")
[docs] def split_name_with_gz(filename): """ Returns the clean basename and extension of a file. Means that this correctly manages the ".nii.gz" extensions. Parameters ---------- filename: str The filename to clean Returns ------- base, ext : tuple(str, str) Clean basename and the full extension """ base, ext = os.path.splitext(filename) if ext == ".gz": # Test if we have a .nii additional extension temp_base, add_ext = os.path.splitext(base) if add_ext == ".nii" or add_ext == ".trk": ext = add_ext + ext base = temp_base return base, ext
[docs] def get_reference_info_wrapper(reference): # noqa: C901 """Will compare the spatial attribute of 2 references. Parameters ---------- reference : Nifti or Trk filename, Nifti1Image or TrkFile, Nifti1Header or trk.header (dict), TrxFile or trx.header (dict) Reference that provides the spatial attribute. Returns ------- output : tuple - affine ndarray (4,4), np.float32, transformation of VOX to RASMM - dimensions ndarray (3,), int16, volume shape for each axis - voxel_sizes ndarray (3,), float32, size of voxel for each axis - voxel_order, string, Typically 'RAS' or 'LPS' """ from trx import trx_file_memmap is_nifti = False is_trk = False is_sft = False is_trx = False if isinstance(reference, str): _, ext = split_name_with_gz(reference) if ext in [".nii", ".nii.gz"]: header = nib.load(reference).header is_nifti = True elif ext == ".trk": header = nib.streamlines.load(reference, lazy_load=True).header is_trk = True elif ext == ".trx": header = trx_file_memmap.load(reference).header is_trx = True elif isinstance(reference, trx_file_memmap.TrxFile): header = reference.header is_trx = True elif isinstance(reference, nib.nifti1.Nifti1Image): header = reference.header is_nifti = True elif isinstance(reference, nib.streamlines.trk.TrkFile): header = reference.header is_trk = True elif isinstance(reference, nib.nifti1.Nifti1Header): header = reference is_nifti = True elif isinstance(reference, dict) and "magic_number" in reference: header = reference is_trk = True elif isinstance(reference, dict) and "NB_VERTICES" in reference: header = reference is_trx = True elif dipy_available and isinstance( reference, dipy.io.stateful_tractogram.StatefulTractogram ): is_sft = True if is_nifti: affine = header.get_best_affine() dimensions = header["dim"][1:4] voxel_sizes = header["pixdim"][1:4] if not affine[0:3, 0:3].any(): raise ValueError( "Invalid affine, contains only zeros." "Cannot determine voxel order from transformation" ) voxel_order = "".join(nib.aff2axcodes(affine)) elif is_trk: affine = header["voxel_to_rasmm"] dimensions = header["dimensions"] voxel_sizes = header["voxel_sizes"] voxel_order = header["voxel_order"] elif is_sft: affine, dimensions, voxel_sizes, voxel_order = reference.space_attributes elif is_trx: affine = header["VOXEL_TO_RASMM"] dimensions = header["DIMENSIONS"] voxel_sizes = nib.affines.voxel_sizes(affine) voxel_order = "".join(nib.aff2axcodes(affine)) else: raise TypeError("Input reference is not one of the supported format") if isinstance(voxel_order, np.bytes_): voxel_order = voxel_order.decode("utf-8") if dipy_available: from dipy.io.utils import is_reference_info_valid is_reference_info_valid(affine, dimensions, voxel_sizes, voxel_order) return affine, dimensions, voxel_sizes, voxel_order
[docs] def is_header_compatible(reference_1, reference_2): """Will compare the spatial attribute of 2 references. Parameters ---------- reference_1 : Nifti or Trk filename, Nifti1Image or TrkFile, Nifti1Header or trk.header (dict) Reference that provides the spatial attribute. reference_2 : Nifti or Trk filename, Nifti1Image or TrkFile, Nifti1Header or trk.header (dict) Reference that provides the spatial attribute. Returns ------- output : bool Does all the spatial attribute match """ affine_1, dimensions_1, voxel_sizes_1, voxel_order_1 = get_reference_info_wrapper( reference_1 ) affine_2, dimensions_2, voxel_sizes_2, voxel_order_2 = get_reference_info_wrapper( reference_2 ) identical_header = True if not np.allclose(affine_1, affine_2, rtol=1e-03, atol=1e-03): logging.error("Affine not equal") identical_header = False if not np.array_equal(dimensions_1, dimensions_2): logging.error("Dimensions not equal") identical_header = False if not np.allclose(voxel_sizes_1, voxel_sizes_2, rtol=1e-03, atol=1e-03): logging.error("Voxel_size not equal") identical_header = False if voxel_order_1 != voxel_order_2: logging.error("Voxel_order not equal") identical_header = False return identical_header
[docs] def get_axis_shift_vector(flip_axes): """ Parameters ---------- flip_axes : list of str String containing the axis to flip. Possible values are 'x', 'y', 'z' Returns ------- flip_vector : np.ndarray (3,) Vector containing the axis to flip. Possible values are -1, 1 """ shift_vector = np.zeros(3) if "x" in flip_axes: shift_vector[0] = -1.0 if "y" in flip_axes: shift_vector[1] = -1.0 if "z" in flip_axes: shift_vector[2] = -1.0 return shift_vector
[docs] def get_axis_flip_vector(flip_axes): """ Parameters ---------- flip_axes : list of str String containing the axis to flip. Possible values are 'x', 'y', 'z' Returns ------- flip_vector : np.ndarray (3,) Vector containing the axis to flip. Possible values are -1, 1 """ flip_vector = np.ones(3) if "x" in flip_axes: flip_vector[0] = -1.0 if "y" in flip_axes: flip_vector[1] = -1.0 if "z" in flip_axes: flip_vector[2] = -1.0 return flip_vector
[docs] def get_shift_vector(sft): """ When flipping a tractogram the shift vector is used to change the origin of the grid from the corner to the center of the grid. Parameters ---------- sft : StatefulTractogram StatefulTractogram object Returns ------- shift_vector : ndarray Shift vector to apply to the streamlines """ dims = sft.space_attributes[1] shift_vector = -1.0 * (np.array(dims) / 2.0) return shift_vector
[docs] def flip_sft(sft, flip_axes): """Flip the streamlines in the StatefulTractogram according to the flip_axes. Uses the spatial information to flip according to the center of the grid. Parameters ---------- sft : StatefulTractogram StatefulTractogram to flip flip_axes : list of str Axes to flip. Possible values are 'x', 'y', 'z' Returns ------- sft : StatefulTractogram StatefulTractogram with flipped axes """ if not dipy_available: logging.error( "Dipy library is missing, cannot use functions related " "to the StatefulTractogram." ) return None flip_vector = get_axis_flip_vector(flip_axes) shift_vector = get_shift_vector(sft) flipped_streamlines = [] for streamline in sft.streamlines: mod_streamline = streamline + shift_vector mod_streamline *= flip_vector mod_streamline -= shift_vector flipped_streamlines.append(mod_streamline) from dipy.io.stateful_tractogram import StatefulTractogram new_sft = StatefulTractogram.from_sft( flipped_streamlines, sft, data_per_point=sft.data_per_point, data_per_streamline=sft.data_per_streamline, ) return new_sft
[docs] def load_matrix_in_any_format(filepath): """Load a matrix from a txt file OR a npy file. Parameters ---------- filepath : str Path to the matrix file. Returns ------- matrix : numpy.ndarray The matrix. """ _, ext = os.path.splitext(filepath) if ext == ".txt": data = np.loadtxt(filepath) elif ext == ".npy": data = np.load(filepath) else: raise ValueError("Extension {} is not supported".format(ext)) return data
[docs] def get_reverse_enum(space_str, origin_str): """Convert string representation to enums for the StatefulTractogram. Parameters ---------- space_str : str String representing the space. origin_str : str String representing the origin. Returns ------- output : str Space and Origin as Enums. """ if not dipy_available: logging.error( "Dipy library is missing, cannot use functions related " "to the StatefulTractogram." ) return None from dipy.io.stateful_tractogram import Origin, Space origin = Origin.NIFTI if origin_str.lower() == "nifti" else Origin.TRACKVIS if space_str.lower() == "rasmm": space = Space.RASMM elif space_str.lower() == "voxmm": space = Space.VOXMM else: space = Space.VOX return space, origin
[docs] def convert_data_dict_to_tractogram(data): """Convert data from a lazy tractogram to a tractogram. Parameters ---------- data : dict The data dictionary to convert into a nibabel tractogram. Returns ------- Tractogram A Tractogram object. """ streamlines = ArraySequence(data["strs"]) streamlines._data = streamlines._data for key in data["dps"]: shape = (len(streamlines), len(data["dps"][key]) // len(streamlines)) data["dps"][key] = np.array(data["dps"][key]).reshape(shape) for key in data["dpv"]: shape = ( len(streamlines._data), len(data["dpv"][key]) // len(streamlines._data), ) data["dpv"][key] = np.array(data["dpv"][key]).reshape(shape) tmp_arr = ArraySequence() tmp_arr._data = data["dpv"][key] tmp_arr._offsets = streamlines._offsets tmp_arr._lengths = streamlines._lengths data["dpv"][key] = tmp_arr obj = Tractogram( streamlines, data_per_point=data["dpv"], data_per_streamline=data["dps"] ) return obj
[docs] def append_generator_to_dict(gen, data): if isinstance(gen, TractogramItem): data["strs"].append(gen.streamline.tolist()) for key in gen.data_for_points: if key not in data["dpv"]: data["dpv"][key] = np.array([]) data["dpv"][key] = np.append(data["dpv"][key], gen.data_for_points[key]) for key in gen.data_for_streamline: if key not in data["dps"]: data["dps"][key] = np.array([]) data["dps"][key] = np.append(data["dps"][key], gen.data_for_streamline[key]) else: data["strs"].append(gen.tolist())
[docs] def verify_trx_dtype(trx, dict_dtype): # noqa: C901 """Verify if the dtype of the data in the trx is the same as the one in the dict. Parameters ---------- trx : Tractogram Tractogram to verify. dict_dtype : dict Dictionary containing all elements dtype to verify. Returns ------- output : bool True if the dtype is the same, False otherwise. """ identical = True for key in dict_dtype: if key == "positions": if trx.streamlines._data.dtype != dict_dtype[key]: logging.warning("Positions dtype is different") identical = False elif key == "offsets": if trx.streamlines._offsets.dtype != dict_dtype[key]: logging.warning("Offsets dtype is different") identical = False elif key == "dpv": for key_dpv in dict_dtype[key]: if trx.data_per_vertex[key_dpv]._data.dtype != dict_dtype[key][key_dpv]: logging.warning( "Data per vertex ({}) dtype is different".format(key_dpv) ) identical = False elif key == "dps": for key_dps in dict_dtype[key]: if trx.data_per_streamline[key_dps].dtype != dict_dtype[key][key_dps]: logging.warning( "Data per streamline ({}) dtype is different".format(key_dps) ) identical = False elif key == "dpg": for key_group in dict_dtype[key]: for key_dpg in dict_dtype[key][key_group]: if ( trx.data_per_point[key_group][key_dpg].dtype != dict_dtype[key][key_group][key_dpg] ): logging.warning( "Data per group ({}) dtype is different".format(key_dpg) ) identical = False elif key == "groups": for key_group in dict_dtype[key]: if ( trx.data_per_point[key_group]._data.dtype != dict_dtype[key][key_group] ): logging.warning( "Data per group ({}) dtype is different".format(key_group) ) identical = False return identical