#!/usr/bin/env python

# Script for estimating response functions for spherical deconvolution
# A number of different approaches are available within this script for performing response function estimation.

# Make the corresponding MRtrix3 Python libraries available
import inspect, os, sys
# The MRtrix3 Python modules are expected in a 'lib' directory that is a
# sibling of the directory holding this script; symlinks are resolved first so
# the lookup is relative to the real location of the script file.
lib_folder = os.path.realpath(os.path.join(os.path.dirname(os.path.realpath(inspect.getfile(inspect.currentframe()))), os.pardir, 'lib'))
if not os.path.isdir(lib_folder):
  # Newline-terminate the message so the shell prompt does not end up on the same line
  sys.stderr.write('Unable to locate MRtrix3 Python libraries\n')
  sys.exit(1)
# Place the directory first on the module search path
sys.path.insert(0, lib_folder)

from mrtrix3 import algorithm, app, image, path, run


# Register authorship and the one-line synopsis of the script
app.init('Robert E. Smith (robert.smith@florey.edu.au) and Thijs Dhollander (thijs.dhollander@gmail.com)',
         'Estimate response function(s) for spherical deconvolution')

# Description paragraphs, added in the order listed here
for paragraph in (
    'dwi2response acts as a \'master\' script for performing various types of response function estimation; a range of different algorithms are available for completing this task. When using this script, the name of the algorithm to be used must appear as the first argument on the command-line after \'dwi2response\'. The subsequent compulsory arguments and options available depend on the particular algorithm being invoked.',
    'Each algorithm available also has its own help page, including necessary references; e.g. to see the help page of the \'fa\' algorithm, type \'dwi2response fa\'.',
):
  app.cmdline.addDescription(paragraph)

# General options: those needing nothing beyond a help string are table-driven
# (registration order is preserved); -fslgrad takes two values and is added separately
shared_group = app.cmdline.add_argument_group('Options common to all dwi2response algorithms')
for flag, description in (
    ('-shells', 'The b-value shell(s) to use in response function estimation (single value for single-shell response, comma-separated list for multi-shell response)'),
    ('-lmax', 'The maximum harmonic degree(s) of response function estimation (single value for single-shell response, comma-separated list for multi-shell response)'),
    ('-mask', 'Provide an initial mask for response voxel selection'),
    ('-voxels', 'Output an image showing the final voxel selection(s)'),
    ('-grad', 'Pass the diffusion gradient table in MRtrix format'),
):
  shared_group.add_argument(flag, help=description)
shared_group.add_argument('-fslgrad', nargs=2, metavar=('bvecs', 'bvals'), help='Pass the diffusion gradient table in FSL bvecs/bvals format')
# The gradient table may be supplied in only one of the two formats
app.cmdline.flagMutuallyExclusiveOptions( [ 'grad', 'fslgrad' ] )

# Import the command-line settings for all algorithms found in the relevant directory
algorithm.initialise()


# All options are now registered: process the command line
app.parse()


# Look up the module implementing the algorithm named by the user
alg = algorithm.getModule(app.args.algorithm)


# Check the output paths before any processing takes place: first the optional
# -voxels image common to every algorithm, then the algorithm's own outputs
if app.args.voxels:
  app.checkOutputPath(app.args.voxels)
alg.checkOutputPaths()


# Sanitise some inputs, and get ready for data import

# -lmax: comma-separated list of even integers.
# Only the integer conversion sits inside the try block, and only ValueError is
# caught: a catch-all handler around the evenness check would also intercept
# whatever app.error() raises to terminate the script (presumably SystemExit -
# confirm against the app module), as well as KeyboardInterrupt, and then
# report the wrong "must be a number" message for an odd value.
if app.args.lmax:
  try:
    lmax = [ int(x) for x in app.args.lmax.split(',') ]
  except ValueError:
    app.error('Parameter lmax must be a number')
  if any(lmax_value % 2 for lmax_value in lmax):
    app.error('Value of lmax must be even')
  if alg.needsSingleShell() and len(lmax) != 1:
    app.error('Can only specify a single lmax value for single-shell algorithms')

# -shells: comma-separated list of b-values. The parsed values are used only
# for validation; the user's original string is what gets appended to the
# dwiextract command later in the script.
shells_option = ''
if app.args.shells:
  try:
    shells_values = [ int(round(float(x))) for x in app.args.shells.split(',') ]
  except (ValueError, OverflowError):
    # ValueError: entry is not numeric (or is NaN); OverflowError: entry is infinite
    app.error('-shells option should provide a comma-separated list of b-values')
  if alg.needsSingleShell() and len(shells_values) != 1:
    app.error('Can only specify a single b-value shell for single-shell algorithms')
  shells_option = ' -shells ' + app.args.shells

# Extra dwiextract flags appended to the import command for single-shell algorithms
singleshell_option = ' -singleshell -no_bzero' if alg.needsSingleShell() else ''

# Build the gradient-table option appended to the mrconvert import command.
# If neither -grad nor -fslgrad was given, nothing is appended, and the input
# image header must then already contain a 'dw_scheme' entry.
if app.args.grad:
  grad_import_option = ' -grad ' + path.fromUser(app.args.grad, True)
elif app.args.fslgrad:
  bvecs_path, bvals_path = app.args.fslgrad
  grad_import_option = ' -fslgrad ' + path.fromUser(bvecs_path, True) + ' ' + path.fromUser(bvals_path, True)
else:
  grad_import_option = ''
  input_keyval = image.Header(path.fromUser(app.args.input, False)).keyval()
  if 'dw_scheme' not in input_keyval:
    app.error('Script requires diffusion gradient table: either in image header, or using -grad / -fslgrad option')

app.makeTempDir()

# Get standard input data into the temporary directory.
# grad_import_option, shells_option and singleshell_option each carry their own
# leading space (or are empty), so the pieces are joined without a separator.
if alg.needsSingleShell() or shells_option:
  # Shell selection requested or required: pipe the converted data through dwiextract
  dwi_import_cmd = ''.join([
      'mrconvert ', path.fromUser(app.args.input, True), ' - -strides 0,0,0,1', grad_import_option,
      ' | dwiextract - ', path.toTemp('dwi.mif', True), shells_option, singleshell_option ])
else:
  # Don't discard b=0 in multi-shell algorithms
  dwi_import_cmd = ''.join([
      'mrconvert ', path.fromUser(app.args.input, True), ' ', path.toTemp('dwi.mif', True),
      ' -strides 0,0,0,1', grad_import_option ])
run.command(dwi_import_cmd)

# A user-supplied mask is converted to a bitwise image alongside the DWI data
if app.args.mask:
  mask_import_cmd = ''.join([
      'mrconvert ', path.fromUser(app.args.mask, True), ' ', path.toTemp('mask.mif', True), ' -datatype bit' ])
  run.command(mask_import_cmd)

# Algorithm-specific input handling (behaviour is defined by the algorithm module)
alg.getInputs()

# The remaining commands refer to files by bare name within the temporary directory
app.gotoTempDir()


if not app.args.mask:
  # No mask supplied by the user: derive one from the DWI data
  run.command('dwi2mask dwi.mif mask.mif')
else:
  # Check that the brain mask is appropriate: its first three dimensions must match those of the DWI
  mask_dims = image.Header('mask.mif').size()[:3]
  dwi_dims = image.Header('dwi.mif').size()[:3]
  if mask_dims != dwi_dims:
    app.error('Dimensions of provided mask image do not match DWI')

# Whichever way the mask was obtained, it must contain at least one voxel
mask_voxel_count = int(image.statistic('mask.mif', 'count', '-mask mask.mif'))
if mask_voxel_count == 0:
  if app.args.mask:
    mask_origin = 'Provided'
  else:
    mask_origin = 'Generated'
  app.error(mask_origin + ' mask image does not contain any voxels')


# Hand over to the selected estimation algorithm; processing differs per algorithm from this point
alg.execute()


# Common finalisation for all algorithms: if requested via -voxels, export
# 'voxels.mif' from the temporary directory to the user-specified path
if app.args.voxels:
  voxels_destination = path.fromUser(app.args.voxels, True)
  if app.forceOverwrite:
    overwrite_flag = ' -force'
  else:
    overwrite_flag = ''
  run.command('mrconvert voxels.mif ' + voxels_destination + overwrite_flag)
app.complete()
