# -*- coding: utf-8 -*-
import logging
import os
import nibabel as nib
from nibabel.streamlines.array_sequence import ArraySequence
from nibabel.streamlines.tractogram import Tractogram, TractogramItem
import numpy as np
# dipy is optional: the StatefulTractogram helpers below check this flag
# before touching any dipy API.
try:
    import dipy
    dipy_available = True
except ImportError:
    dipy_available = False
[docs]
def close_or_delete_mmap(obj):
    """Close the memory-mapped file backing ``obj``, if there is one.

    - An object exposing a non-None ``_mmap`` attribute has that map closed.
      For example, this applies to a ``np.memmap`` that still owns its map.
    - An ``ArraySequence`` is handled recursively through its ``_data``,
      ``_offsets`` and ``_lengths`` arrays.
    - A ``np.memmap`` without an open map needs no action.
    - Anything else is only reported at debug level.

    Note that this function cannot drop the caller's reference to ``obj``.
    The caller must ``del`` or rebind it to release the array itself.

    Parameters
    ----------
    obj : object
        The object that potentially has a memory-mapped file to be closed.
    """
    if getattr(obj, "_mmap", None) is not None:
        obj._mmap.close()
    elif isinstance(obj, ArraySequence):
        for attr_name in ("_data", "_offsets", "_lengths"):
            close_or_delete_mmap(getattr(obj, attr_name))
    elif isinstance(obj, np.memmap):
        # Nothing to close. A ``del obj`` here would only unbind the local
        # name and would not release the caller's array.
        pass
    else:
        logging.debug("Object to be close or deleted must be np.memmap")
[docs]
def split_name_with_gz(filename):
    """
    Split a filename into its clean basename and its full extension.

    The double extensions ".nii.gz" and ".trk.gz" are kept together. Any
    other name is split on its last dot, exactly like ``os.path.splitext``.

    Parameters
    ----------
    filename: str
        The filename to clean

    Returns
    -------
    base, ext : tuple(str, str)
        Clean basename and the full extension
    """
    root, extension = os.path.splitext(filename)
    if extension != ".gz":
        return root, extension

    # A gzip suffix may hide a known inner extension; re-join it if so.
    inner_root, inner_extension = os.path.splitext(root)
    if inner_extension in (".nii", ".trk"):
        return inner_root, inner_extension + extension
    return root, extension
[docs]
def get_reference_info_wrapper(reference):  # noqa: C901
    """Extract the spatial attributes of a reference.

    Parameters
    ----------
    reference : str, Nifti1Image, TrkFile, Nifti1Header, dict, TrxFile or
        StatefulTractogram
        Reference that provides the spatial attributes. Accepted forms are:

        - a filename ending in ``.nii``, ``.nii.gz``, ``.trk`` or ``.trx``;
        - a loaded Nifti1Image, TrkFile or TrxFile;
        - a Nifti1Header;
        - a trk header dict, recognised by its ``"magic_number"`` key;
        - a trx header dict, recognised by its ``"NB_VERTICES"`` key;
        - when dipy is installed, a StatefulTractogram.

    Returns
    -------
    output : tuple
        - affine : (4, 4) transformation of VOX to RASMM
        - dimensions : (3,) volume shape for each axis
        - voxel_sizes : (3,) size of voxel for each axis
        - voxel_order : str, typically 'RAS' or 'LPS'

        Exact dtypes depend on the header the values were read from.

    Raises
    ------
    TypeError
        If the reference is not one of the supported formats. This includes
        a filename with an unsupported extension.
    ValueError
        If a Nifti affine has an all-zero 3x3 part, which makes the voxel
        order undeterminable.
    """
    from trx import trx_file_memmap

    is_nifti = False
    is_trk = False
    is_sft = False
    is_trx = False
    if isinstance(reference, str):
        _, ext = split_name_with_gz(reference)
        if ext in [".nii", ".nii.gz"]:
            header = nib.load(reference).header
            is_nifti = True
        elif ext == ".trk":
            header = nib.streamlines.load(reference, lazy_load=True).header
            is_trk = True
        elif ext == ".trx":
            trx_obj = trx_file_memmap.load(reference)
            # Only the header is needed. Keep a copy of it and close the
            # TrxFile so its memory maps are not left open.
            header = dict(trx_obj.header)
            trx_obj.close()
            is_trx = True
    elif isinstance(reference, trx_file_memmap.TrxFile):
        header = reference.header
        is_trx = True
    elif isinstance(reference, nib.nifti1.Nifti1Image):
        header = reference.header
        is_nifti = True
    elif isinstance(reference, nib.streamlines.trk.TrkFile):
        header = reference.header
        is_trk = True
    elif isinstance(reference, nib.nifti1.Nifti1Header):
        header = reference
        is_nifti = True
    elif isinstance(reference, dict) and "magic_number" in reference:
        header = reference
        is_trk = True
    elif isinstance(reference, dict) and "NB_VERTICES" in reference:
        header = reference
        is_trx = True
    elif dipy_available:
        # Import the submodule explicitly. A bare ``import dipy`` is not
        # guaranteed to expose ``dipy.io.stateful_tractogram`` as an attribute.
        from dipy.io.stateful_tractogram import StatefulTractogram

        if isinstance(reference, StatefulTractogram):
            is_sft = True

    if is_nifti:
        affine = header.get_best_affine()
        dimensions = header["dim"][1:4]
        voxel_sizes = header["pixdim"][1:4]
        if not affine[0:3, 0:3].any():
            raise ValueError(
                "Invalid affine, contains only zeros."
                "Cannot determine voxel order from transformation"
            )
        voxel_order = "".join(nib.aff2axcodes(affine))
    elif is_trk:
        affine = header["voxel_to_rasmm"]
        dimensions = header["dimensions"]
        voxel_sizes = header["voxel_sizes"]
        voxel_order = header["voxel_order"]
    elif is_sft:
        affine, dimensions, voxel_sizes, voxel_order = reference.space_attributes
    elif is_trx:
        affine = header["VOXEL_TO_RASMM"]
        dimensions = header["DIMENSIONS"]
        voxel_sizes = nib.affines.voxel_sizes(affine)
        voxel_order = "".join(nib.aff2axcodes(affine))
    else:
        raise TypeError("Input reference is not one of the supported format")

    # Trk headers can store the voxel order as bytes. np.bytes_ subclasses
    # bytes, so this check covers both forms.
    if isinstance(voxel_order, bytes):
        voxel_order = voxel_order.decode("utf-8")

    if dipy_available:
        from dipy.io.utils import is_reference_info_valid

        # NOTE(review): the return value is ignored, so invalid attributes
        # are not rejected here. Confirm this is intended.
        is_reference_info_valid(affine, dimensions, voxel_sizes, voxel_order)

    return affine, dimensions, voxel_sizes, voxel_order
[docs]
def get_axis_shift_vector(flip_axes):
    """Build a vector marking the axes to flip with -1 and the others with 0.

    Parameters
    ----------
    flip_axes : str or list of str
        Axes to flip.
        Possible values are 'x', 'y', 'z'

    Returns
    -------
    shift_vector : np.ndarray (3,)
        Float vector holding -1.0 for each axis listed in ``flip_axes`` and
        0.0 for the others.
    """
    shift_vector = np.zeros(3)
    for index, axis in enumerate("xyz"):
        if axis in flip_axes:
            shift_vector[index] = -1.0
    return shift_vector
[docs]
def get_axis_flip_vector(flip_axes):
    """Build a per-axis sign vector used to mirror coordinates.

    Parameters
    ----------
    flip_axes : str or list of str
        Axes to flip.
        Possible values are 'x', 'y', 'z'

    Returns
    -------
    flip_vector : np.ndarray (3,)
        Float vector holding -1.0 for each axis listed in ``flip_axes`` and
        1.0 for the others.
    """
    signs = [-1.0 if axis in flip_axes else 1.0 for axis in ("x", "y", "z")]
    return np.array(signs)
[docs]
def get_shift_vector(sft):
    """
    Compute the offset that moves the grid's centre onto the origin.

    When a tractogram is flipped, adding this vector moves the origin from
    the grid corner to the grid centre. Subtracting it afterwards moves the
    origin back.

    Parameters
    ----------
    sft : StatefulTractogram
        StatefulTractogram object. Its ``space_attributes[1]`` entry holds
        the grid dimensions.

    Returns
    -------
    shift_vector : ndarray
        Minus half of each grid dimension.
    """
    grid_shape = np.asarray(sft.space_attributes[1])
    return grid_shape / -2.0
[docs]
def flip_sft(sft, flip_axes):
    """Mirror the streamlines of a StatefulTractogram about the grid centre.

    Each point is shifted so that the grid centre sits at the origin,
    mirrored along the requested axes, and then shifted back.

    NOTE(review): the shift is half the grid dimensions, so the result is
    presumably only meaningful when ``sft`` is in voxel space. Confirm that
    callers convert beforehand.

    Parameters
    ----------
    sft : StatefulTractogram
        StatefulTractogram to flip
    flip_axes : list of str
        Axes to flip.
        Possible values are 'x', 'y', 'z'

    Returns
    -------
    sft : StatefulTractogram or None
        New StatefulTractogram built from ``sft`` with mirrored streamlines,
        carrying over its per-point and per-streamline data. None (after
        logging an error) if dipy is not installed.
    """
    if not dipy_available:
        logging.error(
            "Dipy library is missing, cannot use functions related "
            "to the StatefulTractogram."
        )
        return None

    signs = get_axis_flip_vector(flip_axes)
    centre_offset = get_shift_vector(sft)

    # Centre the grid, mirror the selected axes, then undo the centring.
    mirrored = [
        (points + centre_offset) * signs - centre_offset
        for points in sft.streamlines
    ]

    from dipy.io.stateful_tractogram import StatefulTractogram

    return StatefulTractogram.from_sft(
        mirrored,
        sft,
        data_per_point=sft.data_per_point,
        data_per_streamline=sft.data_per_streamline,
    )
[docs]
def get_reverse_enum(space_str, origin_str):
    """Convert string representations to StatefulTractogram enums.

    Parameters
    ----------
    space_str : str
        String representing the space. Matching is case-insensitive:

        - ``"rasmm"`` maps to ``Space.RASMM``;
        - ``"voxmm"`` maps to ``Space.VOXMM``;
        - any other value silently maps to ``Space.VOX``.
    origin_str : str
        String representing the origin. Matching is case-insensitive:

        - ``"nifti"`` maps to ``Origin.NIFTI``;
        - any other value silently maps to ``Origin.TRACKVIS``.

    Returns
    -------
    output : tuple or None
        ``(space, origin)`` as dipy ``Space`` and ``Origin`` enums. Returns
        None (after logging an error) when dipy is not installed.
    """
    if not dipy_available:
        logging.error(
            "Dipy library is missing, cannot use functions related "
            "to the StatefulTractogram."
        )
        return None

    from dipy.io.stateful_tractogram import Origin, Space

    origin = Origin.NIFTI if origin_str.lower() == "nifti" else Origin.TRACKVIS
    space_by_name = {"rasmm": Space.RASMM, "voxmm": Space.VOXMM}
    space = space_by_name.get(space_str.lower(), Space.VOX)
    return space, origin
[docs]
def convert_data_dict_to_tractogram(data):
    """Convert data accumulated from a lazy tractogram into a Tractogram.

    Parameters
    ----------
    data : dict
        The data dictionary to convert into a nibabel tractogram. It has
        three keys:

        - ``"strs"``: the streamlines, each a sequence of points;
        - ``"dps"``: per-streamline values, flattened per key;
        - ``"dpv"``: per-point values, flattened per key.

        The ``"dps"`` and ``"dpv"`` entries are replaced in place. Each
        ``"dps"`` entry becomes a 2-D array with one row per streamline.
        Each ``"dpv"`` entry becomes an ArraySequence with one row per point.

    Returns
    -------
    Tractogram
        A Tractogram object.
    """
    streamlines = ArraySequence(data["strs"])
    n_streamlines = len(streamlines)
    n_points = len(streamlines._data)

    for key in data["dps"]:
        data["dps"][key] = _reshape_rows(data["dps"][key], n_streamlines)

    for key in data["dpv"]:
        per_point = ArraySequence()
        per_point._data = _reshape_rows(data["dpv"][key], n_points)
        # Reuse the streamline layout so each value lines up with its point.
        per_point._offsets = streamlines._offsets
        per_point._lengths = streamlines._lengths
        data["dpv"][key] = per_point

    return Tractogram(
        streamlines, data_per_point=data["dpv"], data_per_streamline=data["dps"]
    )


def _reshape_rows(values, n_rows):
    """Copy ``values`` into an array of ``n_rows`` equally wide rows."""
    arr = np.array(values)
    # Guard against dividing by zero when there are no rows (empty input).
    n_cols = arr.size // n_rows if n_rows else 0
    return arr.reshape((n_rows, n_cols))
[docs]
def append_generator_to_dict(gen, data):
    """Append one item produced by a tractogram generator to ``data``.

    Parameters
    ----------
    gen : TractogramItem or array-like
        Either a TractogramItem, whose streamline, per-point values and
        per-streamline values are all recorded, or a bare streamline, of
        which only the points are recorded. In both cases the points must
        provide ``tolist()``.
    data : dict
        Accumulator with a ``"strs"`` list and ``"dpv"`` / ``"dps"`` dicts.
        It is updated in place:

        - the points are appended to ``"strs"`` as nested lists;
        - each data key's values are appended to a growing numpy array.
    """
    if not isinstance(gen, TractogramItem):
        data["strs"].append(gen.tolist())
        return

    def accumulate(section, values):
        # np.append without an axis flattens both inputs, so every key ends
        # up holding a 1-D array.
        for name in values:
            bucket = data[section]
            bucket.setdefault(name, np.array([]))
            bucket[name] = np.append(bucket[name], values[name])

    data["strs"].append(gen.streamline.tolist())
    accumulate("dpv", gen.data_for_points)
    accumulate("dps", gen.data_for_streamline)
[docs]
def verify_trx_dtype(trx, dict_dtype):  # noqa: C901
    """Verify that the dtypes of the arrays in a TrxFile match the expected ones.

    Parameters
    ----------
    trx : TrxFile
        TrxFile to verify.
    dict_dtype : dict
        Expected dtypes. Recognised keys are:

        - ``"positions"`` and ``"offsets"``: one dtype each;
        - ``"dpv"`` and ``"dps"``: a mapping of data name to dtype;
        - ``"dpg"``: a mapping of group name to a mapping of data name to
          dtype;
        - ``"groups"``: a mapping of group name to dtype.

        Other keys are ignored.

    Returns
    -------
    output : bool
        True if every listed dtype matches, False otherwise. A warning is
        logged for each mismatch.
    """

    def _matches(actual, expected, message):
        # Report whether the dtypes agree, logging a warning on mismatch.
        if actual != expected:
            logging.warning(message)
            return False
        return True

    identical = True
    for key in dict_dtype:
        expected = dict_dtype[key]
        if key == "positions":
            if not _matches(
                trx.streamlines._data.dtype, expected, "Positions dtype is different"
            ):
                identical = False
        elif key == "offsets":
            if not _matches(
                trx.streamlines._offsets.dtype, expected, "Offsets dtype is different"
            ):
                identical = False
        elif key == "dpv":
            for key_dpv in expected:
                if not _matches(
                    trx.data_per_vertex[key_dpv]._data.dtype,
                    expected[key_dpv],
                    "Data per vertex ({}) dtype is different".format(key_dpv),
                ):
                    identical = False
        elif key == "dps":
            for key_dps in expected:
                if not _matches(
                    trx.data_per_streamline[key_dps].dtype,
                    expected[key_dps],
                    "Data per streamline ({}) dtype is different".format(key_dps),
                ):
                    identical = False
        elif key == "dpg":
            for key_group in expected:
                for key_dpg in expected[key_group]:
                    # Per-group data lives in ``data_per_group``. A TrxFile
                    # has no ``data_per_point`` attribute.
                    if not _matches(
                        trx.data_per_group[key_group][key_dpg].dtype,
                        expected[key_group][key_dpg],
                        "Data per group ({}) dtype is different".format(key_dpg),
                    ):
                        identical = False
        elif key == "groups":
            for key_group in expected:
                # Groups are read from ``trx.groups``. Each entry is expected
                # to expose ``dtype`` directly.
                if not _matches(
                    trx.groups[key_group].dtype,
                    expected[key_group],
                    "Groups ({}) dtype is different".format(key_group),
                ):
                    identical = False
    return identical