# -*- coding: utf-8 -*-
from nibabel.streamlines.tractogram import TractogramItem
from nibabel.streamlines.tractogram import Tractogram
from nibabel.streamlines.array_sequence import ArraySequence
import os
import logging
import nibabel as nib
import numpy as np
try:
    import dipy
    dipy_available = True
except ImportError:
    dipy_available = False
def close_or_delete_mmap(obj):
"""
    Close the underlying memory-mapped buffer if the object exposes one;
    otherwise recurse into an ArraySequence or delete a bare np.memmap.
    Parameters
    ----------
obj : object
The object that potentially has a memory-mapped file to be closed.
"""
if hasattr(obj, '_mmap') and obj._mmap is not None:
obj._mmap.close()
elif isinstance(obj, ArraySequence):
close_or_delete_mmap(obj._data)
close_or_delete_mmap(obj._offsets)
close_or_delete_mmap(obj._lengths)
elif isinstance(obj, np.memmap):
del obj
else:
        logging.debug('Object to be closed or deleted must be a np.memmap '
                      'or an ArraySequence.')
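# Illustrative sketch of releasing a memory-mapped array; the temporary file
# name below is hypothetical:
#   >>> import numpy as np
#   >>> mm = np.memmap('example.dat', dtype=np.float32, mode='w+', shape=(10,))
#   >>> close_or_delete_mmap(mm)  # closes the mmap underlying the np.memmap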
def split_name_with_gz(filename):
"""
    Return the clean basename and extension of a file.
    Unlike os.path.splitext, this correctly handles double extensions such as
    ".nii.gz" and ".trk.gz".
Parameters
----------
filename: str
The filename to clean
Returns
-------
base, ext : tuple(str, str)
Clean basename and the full extension
"""
base, ext = os.path.splitext(filename)
if ext == ".gz":
# Test if we have a .nii additional extension
temp_base, add_ext = os.path.splitext(base)
if add_ext == ".nii" or add_ext == ".trk":
ext = add_ext + ext
base = temp_base
return base, ext
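# Illustrative sketch (paths are made up):
#   >>> split_name_with_gz('/data/subj1/fa.nii.gz')
#   ('/data/subj1/fa', '.nii.gz')
#   >>> split_name_with_gz('bundle.trk')
#   ('bundle', '.trk')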
def get_reference_info_wrapper(reference):
""" Will compare the spatial attribute of 2 references.
Parameters
----------
reference : Nifti or Trk filename, Nifti1Image or TrkFile, Nifti1Header or
trk.header (dict), TrxFile or trx.header (dict)
Reference that provides the spatial attribute.
Returns
-------
output : tuple
        - affine ndarray (4,4), np.float32, transformation of VOX to RASMM
- dimensions ndarray (3,), int16, volume shape for each axis
- voxel_sizes ndarray (3,), float32, size of voxel for each axis
- voxel_order, string, Typically 'RAS' or 'LPS'
"""
from trx import trx_file_memmap
is_nifti = False
is_trk = False
is_sft = False
is_trx = False
if isinstance(reference, str):
_, ext = split_name_with_gz(reference)
if ext in ['.nii', '.nii.gz']:
header = nib.load(reference).header
is_nifti = True
elif ext == '.trk':
header = nib.streamlines.load(reference, lazy_load=True).header
is_trk = True
elif ext == '.trx':
header = trx_file_memmap.load(reference).header
is_trx = True
elif isinstance(reference, trx_file_memmap.TrxFile):
header = reference.header
is_trx = True
elif isinstance(reference, nib.nifti1.Nifti1Image):
header = reference.header
is_nifti = True
elif isinstance(reference, nib.streamlines.trk.TrkFile):
header = reference.header
is_trk = True
elif isinstance(reference, nib.nifti1.Nifti1Header):
header = reference
is_nifti = True
elif isinstance(reference, dict) and 'magic_number' in reference:
header = reference
is_trk = True
elif isinstance(reference, dict) and 'NB_VERTICES' in reference:
header = reference
is_trx = True
elif dipy_available and \
isinstance(reference, dipy.io.stateful_tractogram.StatefulTractogram):
is_sft = True
if is_nifti:
affine = header.get_best_affine()
dimensions = header['dim'][1:4]
voxel_sizes = header['pixdim'][1:4]
if not affine[0:3, 0:3].any():
            raise ValueError('Invalid affine, contains only zeros. '
                             'Cannot determine voxel order from transformation.')
voxel_order = ''.join(nib.aff2axcodes(affine))
elif is_trk:
affine = header['voxel_to_rasmm']
dimensions = header['dimensions']
voxel_sizes = header['voxel_sizes']
voxel_order = header['voxel_order']
elif is_sft:
affine, dimensions, voxel_sizes, voxel_order = reference.space_attributes
elif is_trx:
affine = header['VOXEL_TO_RASMM']
dimensions = header['DIMENSIONS']
voxel_sizes = nib.affines.voxel_sizes(affine)
voxel_order = ''.join(nib.aff2axcodes(affine))
else:
        raise TypeError('Input reference is not one of the supported formats')
if isinstance(voxel_order, np.bytes_):
voxel_order = voxel_order.decode('utf-8')
if dipy_available:
from dipy.io.utils import is_reference_info_valid
is_reference_info_valid(affine, dimensions, voxel_sizes, voxel_order)
return affine, dimensions, voxel_sizes, voxel_order
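# Illustrative sketch with an in-memory NIfTI image (the 2 mm isotropic affine
# below is made up for the example):
#   >>> import numpy as np, nibabel as nib
#   >>> img = nib.Nifti1Image(np.zeros((10, 10, 10), dtype=np.float32),
#   ...                       np.diag([2., 2., 2., 1.]))
#   >>> affine, dims, vox_sizes, vox_order = get_reference_info_wrapper(img)
#   >>> vox_order  # expected to be 'RAS' for a diagonal, positive affine
#   'RAS'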
def get_axis_shift_vector(flip_axes):
"""
Parameters
----------
    flip_axes : list of str
        List of axes to flip.
        Possible values are 'x', 'y', 'z'.
    Returns
    -------
    shift_vector : np.ndarray (3,)
        Vector with -1.0 for each axis to flip and 0.0 otherwise.
"""
shift_vector = np.zeros(3)
if 'x' in flip_axes:
shift_vector[0] = -1.0
if 'y' in flip_axes:
shift_vector[1] = -1.0
if 'z' in flip_axes:
shift_vector[2] = -1.0
return shift_vector
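# Illustrative sketch:
#   >>> get_axis_shift_vector(['x', 'z'])
#   array([-1.,  0., -1.])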
def get_axis_flip_vector(flip_axes):
"""
Parameters
----------
    flip_axes : list of str
        List of axes to flip.
        Possible values are 'x', 'y', 'z'.
    Returns
    -------
    flip_vector : np.ndarray (3,)
        Vector with -1.0 for each axis to flip and 1.0 otherwise.
"""
flip_vector = np.ones(3)
if 'x' in flip_axes:
flip_vector[0] = -1.0
if 'y' in flip_axes:
flip_vector[1] = -1.0
if 'z' in flip_axes:
flip_vector[2] = -1.0
return flip_vector
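# Illustrative sketch:
#   >>> get_axis_flip_vector(['y'])
#   array([ 1., -1.,  1.])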
def get_shift_vector(sft):
"""
When flipping a tractogram the shift vector is used to change the origin
of the grid from the corner to the center of the grid.
Parameters
----------
sft : StatefulTractogram
StatefulTractogram object
Returns
-------
shift_vector : ndarray
Shift vector to apply to the streamlines
"""
dims = sft.space_attributes[1]
shift_vector = -1.0 * (np.array(dims) / 2.0)
return shift_vector
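# Illustrative sketch, assuming `sft` is an existing StatefulTractogram whose
# reference volume has dimensions (100, 100, 60):
#   >>> get_shift_vector(sft)
#   array([-50., -50., -30.])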
def flip_sft(sft, flip_axes):
""" Flip the streamlines in the StatefulTractogram according to the
flip_axes. Uses the spatial information to flip according to the center
of the grid.
Parameters
----------
sft : StatefulTractogram
StatefulTractogram to flip
flip_axes : list of str
Axes to flip.
Possible values are 'x', 'y', 'z'
Returns
-------
sft : StatefulTractogram
StatefulTractogram with flipped axes
"""
if not dipy_available:
logging.error('Dipy library is missing, cannot use functions related '
'to the StatefulTractogram.')
return None
flip_vector = get_axis_flip_vector(flip_axes)
shift_vector = get_shift_vector(sft)
flipped_streamlines = []
for streamline in sft.streamlines:
mod_streamline = streamline + shift_vector
mod_streamline *= flip_vector
mod_streamline -= shift_vector
flipped_streamlines.append(mod_streamline)
from dipy.io.stateful_tractogram import StatefulTractogram
new_sft = StatefulTractogram.from_sft(flipped_streamlines, sft,
data_per_point=sft.data_per_point,
data_per_streamline=sft.data_per_streamline)
return new_sft
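# Illustrative sketch, assuming `sft` is an existing StatefulTractogram:
#   >>> flipped = flip_sft(sft, ['x'])  # left-right flip about the grid center
#   >>> len(flipped.streamlines) == len(sft.streamlines)
#   True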
def get_reverse_enum(space_str, origin_str):
""" Convert string representation to enums for the StatefulTractogram.
Parameters
----------
space_str : str
String representing the space.
origin_str : str
String representing the origin.
Returns
-------
    output : tuple(Space, Origin)
        Space and Origin as enums.
"""
if not dipy_available:
logging.error('Dipy library is missing, cannot use functions related '
'to the StatefulTractogram.')
return None
from dipy.io.stateful_tractogram import Space, Origin
origin = Origin.NIFTI if origin_str.lower() == 'nifti' else Origin.TRACKVIS
if space_str.lower() == 'rasmm':
space = Space.RASMM
elif space_str.lower() == 'voxmm':
space = Space.VOXMM
else:
space = Space.VOX
return space, origin
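# Illustrative sketch:
#   >>> space, origin = get_reverse_enum('voxmm', 'trackvis')
#   >>> space.name, origin.name
#   ('VOXMM', 'TRACKVIS')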
def convert_data_dict_to_tractogram(data):
""" Convert a data from a lazy tractogram to a tractogram
Keyword arguments:
data -- The data dictionary to convert into a nibabel tractogram
Returns:
A Tractogram object
"""
streamlines = ArraySequence(data['strs'])
for key in data['dps']:
shape = (len(streamlines), len(data['dps'][key]) // len(streamlines))
data['dps'][key] = np.array(data['dps'][key]).reshape(shape)
for key in data['dpv']:
shape = (len(streamlines._data), len(
data['dpv'][key]) // len(streamlines._data))
data['dpv'][key] = np.array(data['dpv'][key]).reshape(shape)
tmp_arr = ArraySequence()
tmp_arr._data = data['dpv'][key]
tmp_arr._offsets = streamlines._offsets
tmp_arr._lengths = streamlines._lengths
data['dpv'][key] = tmp_arr
obj = Tractogram(streamlines, data_per_point=data['dpv'],
data_per_streamline=data['dps'])
return obj
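# Illustrative sketch with a minimal hand-built dictionary (values made up):
#   >>> data = {'strs': [[[0., 0., 0.], [1., 1., 1.]]], 'dpv': {}, 'dps': {}}
#   >>> tractogram = convert_data_dict_to_tractogram(data)
#   >>> len(tractogram.streamlines)
#   1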
def append_generator_to_dict(gen, data):
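    """ Append a single generated item (TractogramItem or raw streamline) to
    the data dictionary ('strs', 'dpv', 'dps' keys) used by
    convert_data_dict_to_tractogram.
    """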
if isinstance(gen, TractogramItem):
data['strs'].append(gen.streamline.tolist())
for key in gen.data_for_points:
if key not in data['dpv']:
data['dpv'][key] = np.array([])
data['dpv'][key] = np.append(
data['dpv'][key], gen.data_for_points[key])
for key in gen.data_for_streamline:
if key not in data['dps']:
data['dps'][key] = np.array([])
data['dps'][key] = np.append(
data['dps'][key], gen.data_for_streamline[key])
else:
data['strs'].append(gen.tolist())
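# Illustrative sketch, accumulating a raw streamline (coordinates made up):
#   >>> data = {'strs': [], 'dpv': {}, 'dps': {}}
#   >>> append_generator_to_dict(np.array([[0., 0., 0.], [1., 1., 1.]]), data)
#   >>> len(data['strs'])
#   1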
def verify_trx_dtype(trx, dict_dtype):
""" Verify if the dtype of the data in the trx is the same as the one in
the dict.
Parameters
----------
    trx : TrxFile
        TrxFile whose data dtypes are verified.
dict_dtype : dict
Dictionary containing all elements dtype to verify.
Returns
-------
output : bool
True if the dtype is the same, False otherwise.
"""
identical = True
for key in dict_dtype:
if key == 'positions':
if trx.streamlines._data.dtype != dict_dtype[key]:
logging.warning('Positions dtype is different')
identical = False
elif key == 'offsets':
if trx.streamlines._offsets.dtype != dict_dtype[key]:
logging.warning('Offsets dtype is different')
identical = False
elif key == 'dpv':
for key_dpv in dict_dtype[key]:
if trx.data_per_vertex[key_dpv]._data.dtype != dict_dtype[key][key_dpv]:
logging.warning(
'Data per vertex ({}) dtype is different'.format(key_dpv))
identical = False
elif key == 'dps':
for key_dps in dict_dtype[key]:
if trx.data_per_streamline[key_dps].dtype != dict_dtype[key][key_dps]:
logging.warning(
'Data per streamline ({}) dtype is different'.format(key_dps))
identical = False
elif key == 'dpg':
for key_group in dict_dtype[key]:
for key_dpg in dict_dtype[key][key_group]:
                    if trx.data_per_group[key_group][key_dpg].dtype != dict_dtype[key][key_group][key_dpg]:
logging.warning(
'Data per group ({}) dtype is different'.format(key_dpg))
identical = False
elif key == 'groups':
for key_group in dict_dtype[key]:
                if trx.groups[key_group].dtype != dict_dtype[key][key_group]:
                    logging.warning(
                        'Group ({}) dtype is different'.format(key_group))
identical = False
return identical
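# Illustrative sketch, assuming `trx` is a loaded TrxFile; the expected dtypes
# below are only an example:
#   >>> expected = {'positions': np.dtype('float32'),
#   ...             'offsets': np.dtype('uint64')}
#   >>> verify_trx_dtype(trx, expected)
#   True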