"""Toblerone utility functions"""
import os.path as op
import glob
import warnings
import multiprocessing
import copy
import numpy as np
import regtricks as rt
from scipy.sparse.linalg import eigs
NP_FLOAT = np.float32
STRUCTURES = ['L_Accu', 'L_Amyg', 'L_Caud', 'L_Hipp', 'L_Pall', 'L_Puta',
'L_Thal', 'R_Accu', 'R_Amyg', 'R_Caud', 'R_Hipp', 'R_Pall', 'R_Puta',
'R_Thal', 'BrStem']
def cascade_attributes(decorator):
    """
    Overrride default decorator behaviour to preserve docstrings etc
    of decorated functions - functools.wraps didn't seem to work.
    See https://stackoverflow.com/questions/6394511/python-functools-wraps-equivalent-for-classes
    """

    def attribute_preserving_decorator(original):
        wrapped = decorator(original)
        # Copy identity attributes across so the wrapper is transparent
        for attr in ('__name__', '__doc__', '__module__'):
            setattr(wrapped, attr, getattr(original, attr))
        return wrapped

    return attribute_preserving_decorator
def check_anat_dir(dir):
    """Check that dir contains output from FIRST and FAST"""
    has_first = op.isdir(op.join(dir, 'first_results'))
    has_fast = op.isfile(op.join(dir, 'T1_fast_pve_0.nii.gz'))
    return has_first and has_fast
def _loadFIRSTdir(dir_path):
    """Collect FIRST surface paths from a directory into a dict, keyed by
    the standard FIRST structure names (eg 'BrStem'). Every surface whose
    filename contains a known structure name is loaded; structures with no
    matching file are silently skipped (see FIRST documentation).
    """
    vtk_paths = glob.glob(op.join(dir_path, '*.vtk'))
    if not vtk_paths:
        raise RuntimeError("FIRST directory %s is empty" % dir_path)

    found = {}
    for path in vtk_paths:
        basename = op.split(path)[1]
        found.update({ struct: path for struct in STRUCTURES
                       if struct in basename })
    return found
def _loadFASTdir(dir):
"""Load the PV image paths for WM,GM,CSF from a FAST directory into a
dict, accessed by the keys FAST_GM for GM etc
"""
if not op.isdir(dir):
raise RuntimeError("FAST directory does not exist")
paths = {}
files = glob.glob(op.join(dir, '*.nii.gz'))
channels = { '_pve_{}'.format(c): t for (c,t)
in enumerate(['CSF', 'GM', 'WM']) }
for f in files:
fname = op.split(f)[1]
for c, t in channels.items():
if c in fname:
paths['FAST_' + t] = f
if len(paths) != 3:
raise RuntimeError("Could not load 3 PV maps from FAST directory")
return paths
def _loadSurfsToDict(fsdir):
"""Load the left/right white/pial surface paths from a FS directory into
a dictionary, accessed by the keys LWS/LPS/RPS/RWS
"""
sdir = op.realpath(op.join(fsdir, 'surf'))
if not op.isdir(sdir):
raise RuntimeError("Subject's surf directory does not exist")
surfs = {}
snames = {'L': 'lh', 'R': 'rh'}
exts = {'W': 'white', 'P': 'pial'}
for s in ['LWS', 'LPS', 'RWS', 'RPS']:
path = op.join(fsdir, 'surf', '%s.%s' % (snames[s[0]], exts[s[1]]))
if not op.exists(path):
raise RuntimeError("Could not find a file for %s in %s" % (s,fsdir))
surfs[s] = path
return surfs
def _splitExts(fname):
"""Split all extensions off a filename, eg:
'file.ext1.ext2' -> ('file', '.ext1.ext2')
"""
fname = op.split(fname)[1]
ext = ''
while '.' in fname:
fname, e = op.splitext(fname)
ext = e + ext
return fname, ext
def _distributeObjects(objs, ngroups):
"""Distribute a set of objects into n groups.
For preparing chunks before multiprocessing.pool.map
Returns a set of ranges, each of which are index numbers for
the original set of objs
"""
chunkSize = np.floor(len(objs) / ngroups).astype(np.int32)
chunks = []
for n in range(ngroups):
if n != ngroups - 1:
chunks.append(range(n * chunkSize, (n+1) * chunkSize))
else:
chunks.append(range(n * chunkSize, len(objs)))
assert sum(map(len, chunks)) == len(objs), \
"Distribute objects error: not all objects distributed"
return chunks
@cascade_attributes
def enforce_and_load_common_arguments(func):
    """
    Decorator to enforce and pre-processes common arguments in a
    kwargs dict that are used across multiple functions. Note
    some function-specific checking is still required. This intercepts the
    kwargs dict passed to the caller, does some checking and modification
    in place, and then returns to the caller. The following args are handled:

    Required args:
        ref: path to a reference image in which to operate.
        struct2ref: path/np.array/Registration representing transformation
            between structural space (that of the surfaces) and reference.
            If given as 'I', identity matrix will be used.

    Optional args:
        fslanat: a fslanat directory.
        flirt: bool denoting that the struct2ref is a FLIRT transform.
            This means it requires special treatment. If set, then it will be
            pre-processed in place by this function, and then the flag will
            be set back to false when the kwargs dict is returned to the caller
        struct: if FLIRT given, then the path to the structural image used
            for surface generation is required for said special treatment.
        cores: maximum number of cores to parallelise tasks across
            (default is N-1).

    Returns:
        a modified copy of kwargs dictionary passed to the caller.
    """

    def enforcer(ref, struct2ref, **kwargs):

        # Reference image path
        if not isinstance(ref, rt.ImageSpace):
            ref = rt.ImageSpace(ref)

        # If given a fslanat dir we can load the structural image in
        if kwargs.get('fslanat'):
            if not check_anat_dir(kwargs['fslanat']):
                raise RuntimeError("fslanat is not complete: it must contain "
                                   "FAST output and a first_results subdirectory")

            kwargs['fastdir'] = kwargs['fslanat']
            kwargs['firstdir'] = op.join(kwargs['fslanat'], 'first_results')

            # If no struct image given, try and pull it out from the anat dir
            # But, if it has been cropped relative to original T1, then give
            # warning (as we will not be able to convert FLIRT to world-world)
            if not kwargs.get('struct'):
                if kwargs.get('flirt'):
                    matpath = glob.glob(op.join(kwargs['fslanat'],
                                                '*nonroi2roi.mat'))[0]
                    nonroi2roi = np.loadtxt(matpath)
                    if np.any(np.abs(nonroi2roi[0:3, 3])):
                        print("Warning: T1 was cropped relative to T1_orig within" +
                              " fslanat dir.\n Please ensure the struct2ref FLIRT" +
                              " matrix is referenced to T1, not T1_orig.")

                s = op.join(kwargs['fslanat'], 'T1.nii.gz')
                kwargs['struct'] = s
                if not op.isfile(s):
                    raise RuntimeError("Could not find T1.nii.gz in the fslanat dir")

        # Structural to reference transformation. Either as array, path
        # to file containing matrix, or regtricks Registration object
        if not isinstance(struct2ref, (str, np.ndarray, rt.Registration)):
            # FIX: was raise RuntimeError("...", "...") - the stray comma made
            # the message a two-element tuple rather than a single string
            raise RuntimeError("struct2ref transform must be given (either path,"
                               " np.array or regtricks Registration object)")
        else:
            s2r = struct2ref

            if isinstance(s2r, str):
                if s2r == 'I':
                    matrix = np.identity(4)
                else:
                    _, matExt = op.splitext(s2r)

                    try:
                        if matExt in ['.txt', '.mat']:
                            matrix = np.loadtxt(s2r, dtype=NP_FLOAT)
                        # FIX: '.npz' previously written as 'npz' (no dot), so
                        # .npz files fell through to the raw np.fromfile branch
                        elif matExt in ['.npy', '.npz', '.pkl']:
                            matrix = np.load(s2r)
                        else:
                            matrix = np.fromfile(s2r, dtype=NP_FLOAT)
                    except Exception as e:
                        warnings.warn("""Could not load struct2ref matrix.
                            File should be any type valid with numpy.load().""")
                        raise e

                struct2ref = matrix

        # If FLIRT transform we need to do some clever preprocessing
        # We then set the flirt flag to false again (otherwise later steps will
        # repeat the tricks and end up reverting to the original - those steps don't
        # need to know what we did here, simply that it is now world-world again)
        if kwargs.get('flirt'):
            if not kwargs.get('struct'):
                raise RuntimeError("If using a FLIRT transform, the path to the"
                                   " structural image must also be given")

            struct2ref = rt.Registration.from_flirt(struct2ref,
                                                    kwargs['struct'], ref).src2ref
            kwargs['flirt'] = False
        elif isinstance(struct2ref, rt.Registration):
            struct2ref = struct2ref.src2ref

        assert isinstance(struct2ref, np.ndarray), 'should have cast struc2ref to np.array'

        # Processor cores
        if not kwargs.get('cores'):
            kwargs['cores'] = multiprocessing.cpu_count()

        # Supersampling factor: validation only - the value in kwargs is left
        # exactly as the caller provided it
        sup = kwargs.get('super')
        if sup is not None:
            try:
                if (type(sup) is list) and (len(sup) == 3):
                    sup = np.array([int(s) for s in sup])
                else:
                    # FIX: a bare scalar previously raised here (not
                    # subscriptable) despite the error message below
                    # promising "a value or list of 3 values"
                    sup = int(sup[0]) if hasattr(sup, '__getitem__') else int(sup)
                    sup = np.array([sup for _ in range(3)])

                if type(sup) is not np.ndarray:
                    raise RuntimeError()
            except Exception:
                # FIX: was a bare except, which also swallowed KeyboardInterrupt
                raise RuntimeError("-super must be a value or list of 3" +
                                   " values of int type")

        return ref, struct2ref, kwargs

    def common_args_enforced(ref, struct2ref, *args, **kwargs):
        ref, struct2ref, enforced = enforcer(ref, struct2ref, **kwargs)
        return func(ref, struct2ref, *args, **enforced)

    return common_args_enforced
def sparse_normalise(mat, axis, threshold=1e-6):
    """
    Normalise a sparse matrix so that all rows (axis=1) or columns (axis=0)
    sum to either 1 or zero. NB any rows or columns that sum to less than
    threshold will be rounded to zeros.

    Args:
        mat: sparse matrix to normalise
        axis: dimension for which sum should equal 1 (1 for row, 0 for col)
        threshold: any row/col with sum < threshold will be set to zero

    Returns:
        sparse matrix, same format as input.
    """

    # Deep copy so the caller's matrix is never modified
    out_format = type(mat)
    work = copy.deepcopy(mat)

    # CSR stores column indices, CSC stores row indices: pick the format
    # whose stored indices index along the axis being normalised
    if axis == 0:
        converted = work.tocsr()
    elif axis == 1:
        converted = work.tocsc()
    else:
        raise ValueError("Axis must be 0 or 1")

    # Rows/cols summing below threshold get a scale factor of zero
    totals = work.sum(axis).A.flatten()
    keep = totals > threshold
    scale = np.zeros(totals.size)
    scale[keep] = 1 / totals[keep]
    converted.data *= np.take(scale, converted.indices)

    # Sanity check
    check = converted.sum(axis).A.flatten()
    assert np.allclose(check[check > 0], 1, atol=1e-4), 'Did not normalise to 1'
    return out_format(converted)
def is_symmetric(a, tol=1e-9):
    """True if sparse matrix a and its transpose differ by at most tol
    in every element (empty matrices count as symmetric)."""
    if a.nnz == 0:
        return True
    deviation = np.abs(a - a.T)
    return not (deviation > tol).max()
def is_nsd(a):
    """True if sparse matrix a is negative semi-definite, judged by the
    leading eigenvalues from scipy's eigs (empty matrices count as NSD)."""
    if a.nnz == 0:
        return True
    eigenvalues = eigs(a)[0]
    return not (eigenvalues > 0).any()
def calculateXprods(points, tris):
    """
    Normal vectors for points,triangles array.
    For triangle vertices ABC, this is calculated as (C - A) x (B - A).
    """
    a = points[tris[:, 0], :]
    b = points[tris[:, 1], :]
    c = points[tris[:, 2], :]
    return np.cross(c - a, b - a, axis=1)
def slice_sparse(mat, slice0, slice1):
    """
    Slice a block out of a sparse matrix, ie mat[slice0,slice1].
    Scipy sparse matrices do not support slicing in this manner (unlike numpy)

    Args:
        mat (sparse): of any form
        slice0 (bool,array): mask to apply on axis 0 (rows)
        slice1 (bool,array): mask to apply on axis 1 (columns)

    Returns:
        CSR matrix
    """
    # Slice columns in CSC form, then rows in CSR form
    cols_done = mat.tocsc()[:, slice1].tocsr()
    return cols_done[slice0, :]
def mask_laplacian(lap, mask):
    """Apply a node mask to a Laplacian matrix, then rebuild the diagonal
    so that every remaining row sums to zero again."""
    masked = lap.tocsr()
    # Zero the diagonal before masking, then restore it afterwards as the
    # negative of each remaining row's off-diagonal sum
    masked[np.diag_indices_from(masked)] = 0
    masked = slice_sparse(masked, mask, mask)
    row_totals = masked.sum(1).A.flatten()
    masked[np.diag_indices_from(masked)] = -row_totals

    assert (np.abs(masked.sum(1)) < 1e-4).all(), 'Unweighted laplacian'
    assert is_nsd(masked), 'Not negative semi-definite'
    assert is_symmetric(masked), 'Not symmetric'
    return masked
def mask_projection_matrix(matrix, row_mask, col_mask):
    """
    Mask a sparse projection matrix, whilst preserving total signal intensity.
    For example, if the mask implies discarding voxels from a vol2surf matrix,
    upweight the weights of the remaining voxels to account for the discarded
    voxels. Rows represent the target domain and columns the source.

    Args:
        matrix (sparse): of any form, shape (N,M)
        row_mask (np.array,bool): flat vector of size N, rows to retain in matrix
        col_mask (np.array,bool): flat vector of size M, cols to retain in matrix

    Returns:
        CSR matrix, shape (sum(row_mask), sum(col_mask))
    """
    if (row_mask.dtype.kind != 'b') or (col_mask.dtype.kind != 'b'):
        raise ValueError("Row and column masks must be boolean arrays")
    if not (row_mask.shape[0] == row_mask.size == matrix.shape[0]):
        raise ValueError('Row mask size does not match matrix shape')
    if not (col_mask.shape[0] == col_mask.size == matrix.shape[1]):
        raise ValueError('Column mask size does not match matrix shape')

    # Rows are the output domain, so dropping them needs no correction
    rows_kept = matrix[row_mask, :]

    # Dropping columns removes source signal, so compute per-row scale
    # factors that restore each remaining row to its original total weight
    both_kept = rows_kept[:, col_mask]
    weight_before = rows_kept.sum(1).A.flatten()
    weight_after = both_kept.sum(1).A.flatten()
    nonzero = weight_after > 0
    scale = np.zeros_like(weight_before)
    scale[nonzero] = weight_before[nonzero] / weight_after[nonzero]

    # CSC stores row indices, which is what the per-row scale indexes by
    rescaled = both_kept.tocsc()
    rescaled.data = rescaled.data * np.take(scale, rescaled.indices)
    result = rescaled.tocsr()

    weight_final = result.sum(1).A.flatten()
    if not np.allclose(weight_final[nonzero], weight_before[nonzero]):
        raise RuntimeError('Masked matrix did not re-scale correctly')
    return result
def rebase_triangles(points, tris, tri_inds):
    """
    Re-express a patch of a larger surface as a new points and triangle
    matrix pair, indexed from 0. Useful for reducing computational
    complexity when working with a small patch of a surface where only
    a few nodes in the points array are required by the triangles matrix.

    Args:
        points (np.array): surface vertices, P x 3
        tris (np.array): surface triangles, T x 3
        tri_inds (np.array): row indices into triangles array, to rebase

    Returns:
        (points, tris) tuple of re-indexed points/tris.
    """

    # Map original vertex index -> new index, in order of first appearance.
    # A dict gives O(1) lookups; the previous implementation scanned a
    # Python list with np.argwhere per vertex and grew the points array
    # with np.vstack per new vertex, both O(n) per step (quadratic overall).
    # It also assigned a (1,1) argwhere result to a scalar slot, which
    # modern numpy rejects.
    lut = {}
    ts = np.empty((len(tri_inds), 3), dtype=np.int32)

    for t, tri_ind in enumerate(tri_inds):
        for v in range(3):
            vtx = int(tris[tri_ind, v])
            if vtx not in lut:
                lut[vtx] = len(lut)
            ts[t, v] = lut[vtx]

    # Gather the referenced vertices in first-appearance order (dicts
    # preserve insertion order) with a single fancy-indexing operation
    ps = points[np.fromiter(lut.keys(), dtype=np.int64), :]

    return (ps, ts)
def space_encloses_surface(space, points_vox):
    """Check that voxel-space vertex coordinates fit inside an image grid.

    Returns False if any rounded coordinate is negative, or equals/exceeds
    the grid extent (space.size) on any axis.
    """
    lower_ok = np.round(np.min(points_vox)) >= 0
    upper_ok = not (np.round(np.max(points_vox, axis=0)) >= space.size).any()
    return bool(lower_ok and upper_ok)
def load_surfs_to_hemispheres(**kwargs):
    """Build Hemisphere objects from surface paths given in kwargs.

    Accepts either 'fsdir' (a FreeSurfer subject directory, from which the
    lh/rh white/pial surfaces are discovered) or explicit LWS/LPS/RWS/RPS
    path entries. Returns a list of Hemisphere objects, one per side for
    which both white and pial surfaces are available.
    """
    from .classes import Hemisphere

    # If subdir given, then get all the surfaces out of the surf dir
    # If individual surface paths were given they will already be in scope
    if kwargs.get('fsdir'):
        kwargs.update(_loadSurfsToDict(kwargs['fsdir']))

    # Keep any side that has both a pial and a white surface
    sides = [ s for s in ('L', 'R')
              if (kwargs.get(s + 'PS') is not None)
              and (kwargs.get(s + 'WS') is not None) ]

    if not sides:
        raise RuntimeError("At least one hemisphere (eg LWS/LPS) required")

    return [ Hemisphere(kwargs[s + 'WS'], kwargs[s + 'PS'], s)
             for s in sides ]