#!/usr/bin/env python

# Generates an unbiased group-average template via image registration of images to a midway space.

# Make the corresponding MRtrix3 Python libraries available
import inspect, os, sys
# Resolve <directory of this script>/../lib and put it at the front of sys.path
# so that the "mrtrix3" package imported below is looked up there first.
lib_folder = os.path.realpath(os.path.join(os.path.dirname(os.path.realpath(inspect.getfile(inspect.currentframe()))), os.pardir, 'lib'))
if not os.path.isdir(lib_folder):
  # Terminate the message with a newline so it does not run into the shell prompt
  sys.stderr.write('Unable to locate MRtrix3 Python libraries\n')
  sys.exit(1)
sys.path.insert(0, lib_folder)


import shutil
# Short module-level aliases for the file-system operations used by this script
move, copy = shutil.move, shutil.copy
copytree, rmtree = shutil.copytree, shutil.rmtree
remove = os.remove

# Make input() behave the same on Python 2 and Python 3:
#   on Python 2 the name raw_input exists and is bound to input here;
#   on Python 3 looking up raw_input raises NameError and the builtin
#   input() is left untouched.
try:
  input = raw_input #pylint: disable=redefined-builtin
except NameError:
  # raw_input is not defined: nothing to rebind
  pass


import math
from mrtrix3 import app, file, image, path, run #pylint: disable=redefined-builtin


from mrtrix3.path import allindir

def abspath(*arg):
  """Join the given path components and return the result as an absolute path."""
  joined = os.path.join(*arg)
  return os.path.abspath(joined)

def relpath(*arg):
  """Join the given path components and express the result relative to app.workingDir."""
  joined = os.path.join(*arg)
  return os.path.relpath(joined, app.workingDir)

# Pure-Python stand-ins for the three numpy functions this script needs.
# They are defined unconditionally (so they can be exercised on their own) but
# are only bound to loadtxt / savetxt / dot when numpy cannot be imported.

def loadtxt_fallback(fname, delimiter=" ", dtype=float):
  """Read a text matrix from fname and return it as a list of rows.

  Each non-empty line becomes one row (a list of dtype-converted fields).
  Blank lines and anything following a '#' are ignored.  With the default
  delimiter (or None) fields are separated by any run of whitespace, so
  leading, trailing and repeated spaces do not produce empty fields.
  Raises whatever dtype() raises for a field it cannot convert.
  """
  rows = []
  with open(fname, "rb") as f:
    for raw in f:
      # drop trailing comments, then surrounding whitespace / line ending
      text = raw.decode(errors='ignore').split('#', 1)[0].strip()
      if not text:
        continue
      fields = text.split() if delimiter in (None, " ") else text.split(delimiter)
      rows.append([dtype(a) for a in fields])
  return rows

def savetxt_fallback(fname, X, delimiter=" ", fmt="%.14e"):
  """Write the rows of X to fname, one line per row.

  Every element is formatted with fmt and elements are joined by delimiter.
  The number of columns is taken from each row itself, so a matrix with a
  single row is written correctly.
  """
  # "with" guarantees the file is closed, and avoids touching an unbound
  # handle if open() itself fails
  with open(fname, 'wb') as fh:
    for row in X:
      row = tuple(row)
      line = delimiter.join([fmt] * len(row)) % row
      fh.write((line + '\n').encode(errors='ignore'))

def dot_fallback(a, b):
  """Return the matrix product of a and b (lists of rows).

  Input dimensions are not checked.
  """
  return [[sum(x*y for x,y in zip(a_row,b_col)) for b_col in zip(*b)] for a_row in a]

try:
  from numpy import loadtxt, savetxt, dot
except ImportError:
  app.console("numpy not found; using replacement functions")
  loadtxt, savetxt, dot = loadtxt_fallback, savetxt_fallback, dot_fallback

def check_linear_transformation (transformation, cmd, max_scaling = 0.5, max_shear = 0.2, max_rot = 2 * math.pi, pause_on_warn = True):
  """Sanity-check a linear transformation file and optionally retry the registration.

  The transformation is decomposed with "transformcalc ... decompose" into a
  temporary file named <transformation>decomp, from which the 'scaling',
  'shear' and 'angle_axis' entries are read.  A warning is issued when
    - any scaling value is negative or outside [1 - max_scaling, 1 + max_scaling],
    - any absolute shear value exceeds max_shear,
    - the absolute value of the first 'angle_axis' entry exceeds max_rot
      (presumably the rotation angle in radians, given the 2*pi default -- TODO confirm).

  If a check fails, cmd is rebuilt without any "*_init_matrix" option (and
  without that option's argument).  When cmd contained no "*_init_rotation"
  option, the rebuilt command is extended with
  "-<type>_init_translation mass -<type>_init_rotation search", executed, and
  the result is checked again by a recursive call.  Otherwise, if
  pause_on_warn is set, the user is asked to press enter before continuing.

  Parameters:
    transformation: path of the transformation text file to check
    cmd:            the command line (single string) that produced it
    max_scaling, max_shear, max_rot: tolerances described above
    pause_on_warn:  wait for user input when an implausible transformation
                    cannot be retried

  Returns True when all checks pass (or when the decomposition file is
  missing and the check is skipped), False otherwise; after a retry, the
  result of checking the retried registration.

  NOTE(review): if cmd contains neither "affine_scale" nor "rigid_scale", no
  init_rotation option is appended to the retried command, so a transformation
  that keeps failing would recurse without bound -- confirm callers always
  pass one of these options.
  """

  def load_key_value(file_path):
    """Parse "name: value value ..." lines of file_path into a dict.

    Lines of length 1 (a bare line ending) and lines starting with '#' are
    skipped.  The value part is split on whitespace into a list of strings.
    NOTE(review): for a repeated name the new list is appended as a single
    nested element of the existing list -- verify this is intended.
    """
    res = {}
    with open (file_path, "r") as f:
      for line in f.readlines():
        if len(line)==1 or line.startswith("#"):
          continue
        # partition(":") yields (name, ":", value); [::2] keeps name and value
        name, var = line.rstrip().partition(":")[::2]
        if name in res.keys():
          res[name].append(var.split())
        else:
          res[name] = var.split()
    return res

  bGood = True
  run.command ('transformcalc ' + transformation + ' decompose ' + transformation + 'decomp')
  if not os.path.isfile(transformation + 'decomp'): # does not exist if run with -continue option
    app.console(transformation + 'decomp not found. skipping check')
    return True
  data = load_key_value(transformation + 'decomp')
  # the decomposition file is only needed for this check
  run.function(remove, transformation + 'decomp')
  scaling = [float(value) for value in data['scaling']]
  if any([a < 0 for a in scaling]) or any([a > (1 + max_scaling) for a in scaling]) or any([a < (1 - max_scaling) for a in scaling]):
    app.warn ("large scaling (" + str(scaling) + ") in " + transformation)
    bGood = False
  shear = [float(value) for value in data['shear']]
  if any([abs(a) > max_shear for a in shear]):
    app.warn ("large shear (" + str(shear) + ") in " + transformation)
    bGood = False
  rot_angle = float(data['angle_axis'][0])
  if abs(rot_angle) > max_rot:
    app.warn ("large rotation (" + str(rot_angle) + ") in " + transformation)
    bGood = False

  if not bGood:
    # Rebuild the command token by token: drop "*_init_matrix <file>", note
    # whether an init_rotation option is already present, and detect whether
    # this was a rigid or an affine registration from its scale option.
    newcmd = []
    what = ''
    init_rotation_found = False
    skip = 0
    for e in cmd.split():
      if skip:
        # this token is the argument of a dropped option
        skip -= 1
        continue
      if '_init_rotation' in e:
        init_rotation_found = True
      if '_init_matrix' in e:
        skip = 1
        continue
      if 'affine_scale' in e:
        assert what != 'rigid'
        what = 'affine'
      elif 'rigid_scale' in e:
        assert what != 'affine'
        what = 'rigid'
      newcmd.append(e)
    newcmd=" ".join(newcmd)
    if not init_rotation_found:
      # First failure: retry once with mass-based translation and a rotation search
      app.console("replacing the transformation obtained with:")
      app.console(cmd)
      if what:
        newcmd += ' -'+what+'_init_translation mass -'+what+'_init_rotation search'
      app.console("by the one obtained with:")
      app.console(newcmd)
      run.command(newcmd)
      return check_linear_transformation (transformation, newcmd, max_scaling, max_shear, max_rot, pause_on_warn = pause_on_warn)
    # The command already used an init_rotation option: leave the decision to the user
    if pause_on_warn:
      app.warn("you might want to manually repeat mrregister with different parameters and overwrite the transformation file: \n%s" % transformation)
      app.console('The command that failed the test was: \n' + cmd)
      app.console('Working directory: \n' + os.getcwd())
      input("press enter to continue population_template")
  return bGood

class Input(object):
  """Record describing one input image and, optionally, its mask.

  Attributes: filename and directory of the image, the subject prefix
  derived from the file name, and filename/directory of the mask (both
  empty strings when no mask is given).
  """
  def __init__(self, filename, prefix, directory, mask_filename = '', mask_directory = ''):
    # image location and subject identifier
    self.filename, self.prefix, self.directory = filename, prefix, directory
    # mask location
    self.mask_filename, self.mask_directory = mask_filename, mask_directory

# Default multi-resolution schedules: one entry per registration stage.
# Rigid and affine registration share the same defaults.
rigid_scales  = [0.3, 0.4, 0.6, 0.8] + [1.0] * 2
rigid_lmax    = [2] * 3 + [4] * 3
affine_scales = list(rigid_scales)
affine_lmax   = list(rigid_lmax)

# Non-linear registration: 16 stages with 5 iterations each
nl_scales = [0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9] + [1.0] * 9
nl_niter  = [5] * len(nl_scales)
nl_lmax   = [2] * 8 + [4] * 8

# Accepted values of the -type option
registration_modes = [
  "rigid",
  "affine",
  "nonlinear",
  "rigid_affine",
  "rigid_nonlinear",
  "affine_nonlinear",
  "rigid_affine_nonlinear",
]


# Command-line interface: authorship and synopsis, positional arguments, and
# the options controlling linear (rigid / affine) registration.
app.init('David Raffelt (david.raffelt@florey.edu.au) & Max Pietsch (maximilian.pietsch@kcl.ac.uk) & Thijs Dhollander (thijs.dhollander@gmail.com)',
         'Generates an unbiased group-average template from a series of images')
app.cmdline.addDescription('First a template is optimised with linear registration (rigid or affine, affine is default), then non-linear registration is used to optimise the template further.')
app.cmdline.add_argument('input_dir', help='The input directory containing all images used to build the template')
app.cmdline.add_argument('template', help='The output template image')

linoptions = app.cmdline.add_argument_group('Options for the linear registration')
linoptions.add_argument('-linear_no_pause', action='store_true', help='Do not pause the script if a linear registration seems implausible')
linoptions.add_argument('-linear_estimator', help='Choose estimator for intensity difference metric. Valid choices are: l1 (least absolute: |x|), l2 (ordinary least squares), lp (least powers: |x|^1.2), Default: l2')
linoptions.add_argument('-rigid_scale', help='Specifiy the multi-resolution pyramid used to build the rigid template, in the form of a list of scale factors (default: %s). This and affine_scale implicitly  define the number of template levels' % ','.join([str(x) for x in rigid_scales]))
linoptions.add_argument('-rigid_lmax', help='Specifiy the lmax used for rigid registration for each scale factor, in the form of a list of integers (default: %s). The list must be the same length as the linear_scale factor list' % ','.join([str(x) for x in rigid_lmax]))
# The documented default must match the value actually used when -rigid_niter
# is not given (100 iterations per scale)
linoptions.add_argument('-rigid_niter', help='Specifiy the number of registration iterations used within each level before updating the template, in the form of a list of integers (default:100 for each scale). This must be a single number or a list of same length as the linear_scale factor list')
linoptions.add_argument('-affine_scale', help='Specifiy the multi-resolution pyramid used to build the affine template, in the form of a list of scale factors (default: %s). This and rigid_scale implicitly define the number of template levels' % ','.join([str(x) for x in affine_scales]))
linoptions.add_argument('-affine_lmax', help='Specifiy the lmax used for affine registration for each scale factor, in the form of a list of integers (default: %s). The list must be the same length as the linear_scale factor list' % ','.join([str(x) for x in affine_lmax]))
linoptions.add_argument('-affine_niter', help='Specifiy the number of registration iterations used within each level before updating the template, in the form of a list of integers (default:500 for each scale). This must be a single number or a list of same length as the linear_scale factor list')

# Options controlling the non-linear registration stages.  The schedule
# defaults shown in the help text are the module-level nl_* lists; the three
# smoothing / step-size options are kept as strings.
nloptions = app.cmdline.add_argument_group('Options for the non-linear registration')
nloptions.add_argument('-nl_scale', help='Specifiy the multi-resolution pyramid used to build the non-linear template, in the form of a list of scale factors (default: %s). This implicitly defines the number of template levels' % ','.join([str(x) for x in nl_scales]))
nloptions.add_argument('-nl_lmax', help='Specifiy the lmax used for non-linear registration for each scale factor, in the form of a list of integers (default: %s). The list must be the same length as the nl_scale factor list' % ','.join([str(x) for x in nl_lmax]))
nloptions.add_argument('-nl_niter', help='Specifiy the number of registration iterations used within each level before updating the template, in the form of a list of integers (default: %s). The list must be the same length as the nl_scale factor list' % ','.join([str(x) for x in nl_niter]))
nloptions.add_argument('-nl_update_smooth', default='2.0', help='Regularise the gradient update field with Gaussian smoothing (standard deviation in voxel units, Default 2.0 x voxel_size)')
nloptions.add_argument('-nl_disp_smooth', default='1.0', help='Regularise the displacement field with Gaussian smoothing (standard deviation in voxel units, Default 1.0 x voxel_size)')
nloptions.add_argument('-nl_grad_step', default='0.5', help='The gradient step size for non-linear registration (Default: 0.5)')

# Registration type, template geometry, masks and optional output locations
options = app.cmdline.add_argument_group('Input, output and general options')
# only the combined modes (those containing "_") are listed after "cominations of registration types"
options.add_argument('-type', help='Specifiy the types of registration stages to perform. Options are "rigid" (perform rigid registration only which might be useful for intra-subject registration in longitudinal analysis), "affine" (perform affine registration) and "nonlinear" as well as cominations of registration types: %s. Default: rigid_affine_nonlinear' % ', '.join('"'+x+'"' for x in registration_modes if "_" in x), default='rigid_affine_nonlinear')
options.add_argument('-voxel_size', help='Define the template voxel size in mm. Use either a single value for isotropic voxels or 3 comma separated values.')
options.add_argument('-initial_alignment', default='mass', help='Method of alignment to form the initial template. Options are "mass" (default), "geometric" and "none".')
options.add_argument('-mask_dir', help='Optionally input a set of masks inside a single directory, one per input image (with the same file name prefix). Using masks will speed up registration significantly')
options.add_argument('-warp_dir', help='Output a directory containing warps from each input to the template. If the folder does not exist it will be created')
options.add_argument('-transformed_dir', help='Output a directory containing the input images transformed to the template. If the folder does not exist it will be created')
options.add_argument('-linear_transformations_dir', help='Output a directory containing the linear transformations used to generate the template. If the folder does not exist it will be created')
options.add_argument('-template_mask', help='Output a template mask. Only works in -mask_dir has been input. The template mask is computed as the intersection of all subject masks in template space.')
options.add_argument('-noreorientation', action='store_true', help='Turn off FOD reorientation in mrregister. Reorientation is on by default if the number of volumes in the 4th dimension corresponds to the number of coefficients in an antipodally symmetric spherical harmonic series (i.e. 6, 15, 28, 45, 66 etc)')

# Parse the command line; the results are read from app.args below
app.parse()

# Validate the requested registration type and derive which stages to run.
if app.args.type not in registration_modes:
  # report the full list of accepted modes, not just a subset of it
  app.error("registration type must be one of %s. provided: %s" % (str(registration_modes), app.args.type))
# Substring tests: e.g. "rigid_affine_nonlinear" enables all three stages
dorigid     = "rigid"     in app.args.type
doaffine    = "affine"    in app.args.type
dolinear    = dorigid or doaffine
dononlinear = "nonlinear" in app.args.type
# Explicit check rather than an assert, which would be stripped under "python -O"
if not (dorigid or doaffine or dononlinear):
  app.error("FIXME: registration type not valid")

# Locate the input directory and collect the (sorted) names of the images in it
inputDir = relpath(app.args.input_dir)
app.args.input_dir = inputDir
if not os.path.exists(inputDir):
  app.error('input directory not found')
inFiles = allindir(inputDir, dir_path=False)
inFiles = sorted(inFiles)
if len(inFiles) > 1:
  app.console('Generating a population-average template from ' + str(len(inFiles)) + ' input images')
else:
  # a template cannot be built from fewer than two images
  app.error('Not enough images found in input directory. More than one image is needed to generate a population template')

# Template voxel size: None (not requested) or a list of three strings, each
# of which parses as a floating point number.
voxel_size = None
if app.args.voxel_size:
  voxel_size = app.args.voxel_size.split(',')
  if len(voxel_size) == 1:
    # a single value means isotropic voxels
    voxel_size = voxel_size * 3
  # Validate explicitly: an assert would vanish under "python -O", and a bare
  # "except:" would also swallow unrelated errors such as KeyboardInterrupt
  voxel_size_valid = len(voxel_size) == 3
  if voxel_size_valid:
    try:
      for v in voxel_size:
        float(v)
    except ValueError:
      voxel_size_valid = False
  if not voxel_size_valid:
    app.error('voxel size needs to be a single or three comma-separated floating point numbers; received: ' + str(app.args.voxel_size))

# Method used to align the inputs when forming the initial template
initial_alignment = app.args.initial_alignment
if initial_alignment not in ["mass", "geometric", "none"]:
  app.error('initial_alignment must be one of ' + " ".join(["mass", "geometric", "none"]))

# Optional robust estimator for the linear registration metric
linear_estimator = app.args.linear_estimator
if linear_estimator:
  # The estimator only affects linear stages, so it is the absence of linear
  # (not non-linear) registration that makes the option meaningless
  if not dolinear:
    app.error('linear_estimator specified when no linear registration is requested')
  if linear_estimator not in ["l1", "l2", "lp"]:
    app.error('linear_estimator must be one of ' + " ".join(["l1", "l2", "lp"]))

# Optional per-subject masks: list the mask directory and derive, for every
# mask file, the prefix that remains after removing the postfix shared by all
# mask file names.
useMasks = bool(app.args.mask_dir)
if useMasks:
  maskDir = relpath(app.args.mask_dir)
  app.args.mask_dir = maskDir
  if not os.path.exists(maskDir):
    app.error('mask directory not found')
  maskFiles = allindir(maskDir, dir_path=False)
  if len(maskFiles) < len(inFiles):
    app.error('there are not enough mask images for the number of images in the input directory')
  maskCommonPostfix = path.commonPostfix(maskFiles)
  maskPrefixes = [m.split(maskCommonPostfix)[0] for m in maskFiles]
else:
  app.warn('no masks input. Use input masks to reduce computation time and improve robustness')
  # a template mask is derived from the subject masks, so it needs -mask_dir
  if app.args.template_mask:
    app.error('you cannot output a template mask because no subject masks were input using -mask_dir')

# Build one Input record per image.  The subject prefix is the part of the
# file name preceding the postfix shared by all input file names; with masks
# enabled, the mask whose prefix is identical is attached to the record.
commonPostfix = path.commonPostfix(inFiles)
inputs = []
for i in inFiles:
  image.check3DNonunity(os.path.join(path.fromUser(inputDir, False), i))
  subj_prefix = i.split(commonPostfix)[0]
  if not useMasks:
    inputs.append(Input(i, subj_prefix, path.fromUser(inputDir, False)))
    continue
  if subj_prefix not in maskPrefixes:
    app.error ('no matching mask image was found for input image ' + i)
  index = maskPrefixes.index(subj_prefix)
  inputs.append(Input(filename=i,
                      prefix=subj_prefix,
                      directory=path.fromUser(inputDir, False),
                      mask_filename=maskFiles[index],
                      mask_directory=path.fromUser(maskDir, False)))


noreorientation = app.args.noreorientation

# Pause on implausible linear registrations unless -linear_no_pause was given
do_pause_on_warn = not app.args.linear_no_pause
if app.args.linear_no_pause and not dolinear:
  app.error("linear option set when no linear registration is performed")

# Output locations: make each path relative to the working directory and let
# the app module check that it may be written
app.args.template = relpath(app.args.template)
app.checkOutputPath(app.args.template)

for option in ('warp_dir', 'transformed_dir'):
  if getattr(app.args, option):
    setattr(app.args, option, relpath(getattr(app.args, option)))
    app.checkOutputPath(getattr(app.args, option))

if app.args.linear_transformations_dir:
  # exporting linear transformations only makes sense if some are estimated
  if not dolinear:
    app.error("linear option set when no linear registration is performed")
  app.args.linear_transformations_dir = relpath(app.args.linear_transformations_dir)
  app.checkOutputPath(app.args.linear_transformations_dir)


# automatically detect SH series from the dimensions of the first input image
image_size = image.Header(relpath(inputs[0].directory, inputs[0].filename)).size()
if not 3 <= len(image_size) <= 4:
  app.error('only 3 and 4 dimensional images can be used to build a template')
do_fod_registration = False
if len(image_size) == 4:
  # val has no fractional part for volume counts 6, 15, 28, 45, 66, ...
  # (the counts named in the -noreorientation help text)
  val = (math.sqrt (1 + 8 * image_size[3]) - 3.0) / 4.0
  if val == int(val) and not noreorientation:
    app.console("SH series detected, performing FOD registration")
    do_fod_registration = True

# rigid options: user-supplied schedules replace the module-level defaults
if app.args.rigid_scale:
  rigid_scales = list(map(float, app.args.rigid_scale.split(',')))
  if not dorigid:
    app.error("rigid_scales option set when no rigid registration is performed")
if app.args.rigid_lmax:
  if not dorigid:
    app.error("rigid_lmax option set when no rigid registration is performed")
  rigid_lmax = list(map(int, app.args.rigid_lmax.split(',')))
  if do_fod_registration and len(rigid_lmax) != len(rigid_scales):
    app.error('rigid_scales and rigid_lmax schedules are not equal in length')

if app.args.rigid_niter:
  if not dorigid:
    app.error("rigid_niter specified when no rigid registration is performed")
  rigid_niter = list(map(int, app.args.rigid_niter.split(',')))
  if len(rigid_niter) == 1:
    # a single number applies to every scale
    rigid_niter = rigid_niter * len(rigid_scales)
  elif len(rigid_niter) != len(rigid_scales):
    app.error('rigid_scales and rigid_niter schedules are not equal in length')
else:
  rigid_niter = [100] * len(rigid_scales)

# affine options: same handling as for the rigid schedules
if app.args.affine_scale:
  affine_scales = list(map(float, app.args.affine_scale.split(',')))
  if not doaffine:
    app.error("affine_scale option set when no affine registration is performed")
if app.args.affine_lmax:
  if not doaffine:
    app.error("affine_lmax option set when no affine registration is performed")
  affine_lmax = list(map(int, app.args.affine_lmax.split(',')))
  if do_fod_registration and len(affine_lmax) != len(affine_scales):
    app.error('affine_scales and affine_lmax schedules are not equal in length')

if app.args.affine_niter:
  if not doaffine:
    app.error("affine_niter specified when no affine registration is performed")
  affine_niter = list(map(int, app.args.affine_niter.split(',')))
  if len(affine_niter) == 1:
    # a single number applies to every scale
    affine_niter = affine_niter * len(affine_scales)
  elif len(affine_niter) != len(affine_scales):
    app.error('affine_scales and affine_niter schedules are not equal in length')
else:
  affine_niter = [500] * len(affine_scales)

def _stage_lmax(kind, scales, lmax):
  """Return an lmax schedule with exactly one entry per entry of scales.

  kind is 'rigid' or 'affine' and is only used in the error message.  A
  schedule of matching length is returned unchanged.  On a length mismatch
  (e.g. a custom scale schedule combined with the default lmax schedule) this
  is an error when FOD registration is performed; otherwise lmax is never
  used and a placeholder schedule of the right length is returned, so that
  the combined rigid + affine lists stay aligned stage by stage.
  """
  if len(lmax) == len(scales):
    return lmax
  if do_fod_registration:
    app.error(kind + '_scales and ' + kind + '_lmax schedules are not equal in length')
  return [0] * len(scales)

# Concatenate the rigid and affine schedules into one list of linear stages
linear_scales = []
linear_lmax   = []
linear_niter  = []
linear_type   = []
if dorigid:
  linear_scales += rigid_scales
  linear_lmax   += _stage_lmax('rigid', rigid_scales, rigid_lmax)
  linear_niter  += rigid_niter
  linear_type   += ['rigid'] * len(rigid_scales)

if doaffine:
  linear_scales += affine_scales
  linear_lmax   += _stage_lmax('affine', affine_scales, affine_lmax)
  linear_niter  += affine_niter
  linear_type   += ['affine'] * len(affine_scales)

# Internal consistency of the combined schedule (all lists describe the same stages)
assert len(linear_type) == len(linear_scales)
assert len(linear_scales) == len(linear_niter)
assert len(linear_lmax) == len(linear_scales)

app.console('initial alignment of images: %s' % initial_alignment)

# Report the linear schedule; lmax is only shown when FOD registration is performed
if dolinear:
  app.console('linear registration stages:')
  if do_fod_registration:
    for istage, [tpe, scale, lmax, niter] in enumerate(zip (linear_type, linear_scales, linear_lmax, linear_niter)):
      app.console('(%02i) %s scale: %.4f, niter: %i, lmax: %i' %(istage, tpe.ljust(9), scale, niter, lmax))
  else:
    for istage, [tpe, scale, niter] in enumerate(zip (linear_type, linear_scales, linear_niter)):
      app.console('(%02i) %s scale: %.4f, niter: %i, no reorientation' %(istage, tpe.ljust(9), scale, niter))

# appended to the image-writing commands below
datatype_option = ' -datatype float32'

# Non-linear schedule: user-supplied lists replace the defaults; all schedules
# are emptied when no non-linear registration is requested.
if dononlinear:
  if app.args.nl_scale:
    nl_scales = list(map(float, app.args.nl_scale.split(',')))
  if app.args.nl_niter:
    nl_niter = list(map(int, app.args.nl_niter.split(',')))
  if app.args.nl_lmax:
    nl_lmax = list(map(int, app.args.nl_lmax.split(',')))

  if len(nl_niter) != len(nl_scales):
    app.error('nl_scales and nl_niter schedules are not equal in length')

  app.console('nonlinear registration stages:')
  if do_fod_registration:
    # lmax is only relevant (and only checked) for FOD registration
    if len(nl_lmax) != len(nl_scales):
      app.error('nl_scales and nl_lmax schedules are not equal in length')
    for istage, [scale, lmax, niter] in enumerate(zip (nl_scales, nl_lmax, nl_niter)):
      app.console('(%02i) nonlinear scale: %.4f, niter: %i, lmax: %i' %(istage, scale, niter, lmax))
  else:
    for istage, [scale, niter] in enumerate(zip (nl_scales, nl_niter)):
      app.console('(%02i) nonlinear scale: %.4f, niter: %i, no reorientation' %(istage, scale, niter))
else:
  nl_scales = []
  nl_lmax   = []
  nl_niter  = []
  # warps only exist when non-linear registration is run
  if app.args.warp_dir:
    app.error('warp_dir specified when no nonlinear registration is performed')

# Create the temporary directory, change into it, and set up the directory
# layout used by the remainder of the script
app.makeTempDir()
app.gotoTempDir()

for dirname in ('inputs_transformed', 'linear_transforms_initial', 'linear_transforms'):
  file.makeDir(dirname)
# one directory per linear stage and one per non-linear stage
for level in range(len(linear_scales)):
  file.makeDir('linear_transforms_%i' % level)
for level in range(len(nl_scales)):
  file.makeDir('warps_%i' % level)

if useMasks:
  file.makeDir('masks_transformed')
# registration logs are only written at verbosity 2 and above
write_log = app.verbosity >= 2
if write_log:
  file.makeDir('log')

# Make initial template in average space
app.console('Generating initial template')
input_filenames = [abspath(i.directory, i.filename) for i in inputs]
if voxel_size is not None:
  # resample the average header to the requested voxel size
  run.command('mraverageheader -fill ' + ' '.join(input_filenames) + ' - | mrresize - -voxel ' + ','.join(voxel_size) + ' average_header.mif')
else:
  run.command('mraverageheader ' + ' '.join(input_filenames) + ' average_header.mif -fill')

# crop average space to extent defined by original masks
if useMasks:
  progress = app.progressBar('Importing input masks to average space for template cropping', len(inputs))
  mask_filenames = []
  for i in inputs:
    mask_in_average_space = os.path.join('masks_transformed', i.mask_filename)
    run.command('mrtransform ' + abspath(i.mask_directory, i.mask_filename) + ' -interp nearest -template average_header.mif ' + mask_in_average_space)
    mask_filenames.append(mask_in_average_space)
    progress.increment()
  progress.done()
  # the voxel-wise maximum over all masks defines the region to keep
  run.command('mrmath ' + ' '.join(mask_filenames) + ' max mask_initial.mif')
  run.command('mrcrop average_header.mif -mask mask_initial.mif average_header_cropped.mif')
  run.function(remove, 'mask_initial.mif')
  # replace the average header by its cropped version
  run.function(remove, 'average_header.mif')
  run.function(move, 'average_header_cropped.mif', 'average_header.mif')
  progress = app.progressBar('Erasing temporary mask images', len(mask_filenames))
  for mask in mask_filenames:
    run.function(remove, mask)
    progress.increment()
  progress.done()



# Bring every input image into the average space, either by plain resampling
# ('none') or after an initial alignment estimated by mrregister.
if initial_alignment == 'none':
  progress = app.progressBar('Resampling input images to template space with no initial alignment', len(inputs))
  for i in inputs:
    run.command('mrtransform ' + abspath(i.directory, i.filename) + ' -interp linear -template average_header.mif ' + os.path.join('inputs_transformed', i.prefix + '.mif') + datatype_option)
    progress.increment()
  progress.done()
  if not dolinear:
    # No initial transforms were estimated and no linear stage will produce
    # any: write identity matrices so that per-subject transform files exist
    for i in inputs:
      with open(os.path.join('linear_transforms_initial','%s.txt' % (i.prefix)),'w') as fout:
        fout.write('1 0 0 0\n0 1 0 0\n0 0 1 0\n0 0 0 1\n')
else:
  progress = app.progressBar('Performing initial rigid registration to template', len(inputs))
  mask = ''
  for i in inputs:
    if useMasks:
      mask = ' -mask1 ' + abspath(i.mask_directory, i.mask_filename)
    output = ' -rigid ' + os.path.join('linear_transforms_initial', i.prefix + '.txt')
    # Rigid registration with zero iterations and the requested translation
    # initialisation (presumably only the initialisation ends up in the
    # output matrix -- TODO confirm against mrregister)
    run.command('mrregister ' + abspath(i.directory, i.filename) + ' average_header.mif' + \
                mask + \
                ' -rigid_scale 1 ' + \
                ' -rigid_niter 0 ' + \
                ' -type rigid ' + \
                ' -noreorientation ' + \
                ' -rigid_init_translation ' + initial_alignment + ' ' + \
                datatype_option + \
                output)
    # translate input images to centre of mass without interpolation
    run.command('mrtransform ' + abspath(i.directory, i.filename) + \
                ' -linear ' + os.path.join('linear_transforms_initial', i.prefix + '.txt') + \
                datatype_option + \
                ' ' + os.path.join('inputs_transformed', i.prefix + '_translated.mif'))
    if useMasks:
      # apply the same transformation to the subject mask
      run.command('mrtransform ' + abspath(i.mask_directory, i.mask_filename) + \
                  ' -linear ' + os.path.join('linear_transforms_initial', i.prefix + '.txt') + \
                  datatype_option + ' ' + \
                  os.path.join('masks_transformed', i.prefix + '_translated.mif'))
    progress.increment()
  progress.done()
  # update average space to new extent
  run.command('mraverageheader ' + ' '.join([os.path.join('inputs_transformed', i.prefix + '_translated.mif') for i in inputs]) + ' average_header_tight.mif')
  # pad the tight average header (and resample it if a voxel size was requested);
  # -force because average_header.mif already exists at this point
  if voxel_size is None:
    run.command('mrpad -uniform 10 average_header_tight.mif average_header.mif -force')
  else:
    run.command('mrpad -uniform 10 average_header_tight.mif - | mrresize - -voxel '+','.join(voxel_size)+' average_header.mif -force')
  run.function(remove, 'average_header_tight.mif')
  if useMasks:
    # reslice masks
    progress = app.progressBar('Reslicing input masks to average header', len(inputs))
    for i in inputs:
      run.command('mrtransform ' + \
                  os.path.join('masks_transformed', i.prefix + '_translated.mif') + ' ' + \
                  os.path.join('masks_transformed', i.prefix + '.mif') + ' ' + \
                  '-interp nearest -template average_header.mif' + \
                  datatype_option)
      progress.increment()
    progress.done()
    # crop average space to extent defined by translated masks
    mask_filenames = []
    for i in inputs:
      mask_filenames.append(os.path.join('masks_transformed', i.prefix + '.mif'))
    run.command('mrmath ' + ' '.join(mask_filenames) + ' max mask_translated.mif' )
    run.command('mrcrop ' + 'average_header.mif -mask mask_translated.mif average_header_cropped.mif')
    # pad average space to allow for deviation from initial alignment
    run.command('mrpad -uniform 10 average_header_cropped.mif -force average_header.mif')
    run.function(remove, 'average_header_cropped.mif')
    # the average header changed again, so the masks are resliced once more
    # (overwriting the previous reslice) and the translated copies are removed
    progress = app.progressBar('Reslicing mask images to new padded average header', len(inputs))
    for i in inputs:
      run.command('mrtransform ' +
                  os.path.join('masks_transformed', i.prefix + '_translated.mif') + ' ' + \
                  os.path.join('masks_transformed', i.prefix + '.mif') + ' ' + \
                  '-interp nearest -template average_header.mif' + \
                  datatype_option + ' ' + \
                  '-force')
      run.function(remove, os.path.join('masks_transformed', i.prefix + '_translated.mif'))
      progress.increment()
    progress.done()
    run.function(remove, 'mask_translated.mif')
  # reslice input images
  progress = app.progressBar('Reslicing input images to average header', len(inputs))
  for i in inputs:
    run.command('mrtransform ' + \
                os.path.join('inputs_transformed', i.prefix + '_translated.mif') + ' ' + \
                os.path.join('inputs_transformed', i.prefix + '.mif') + ' ' + \
                '-interp linear -template average_header.mif' + \
                datatype_option)
    run.function(remove, os.path.join('inputs_transformed', i.prefix + '_translated.mif'))
    progress.increment()
  progress.done()


# The initial template is the voxel-wise mean of all images in inputs_transformed
run.command('mrmath ' + ' '.join(allindir('inputs_transformed')) + ' mean initial_template.mif')
current_template = 'initial_template.mif'


# Optimise template with linear registration
if not dolinear:
  for i in inputs:
    run.function(copy, os.path.join('linear_transforms_initial','%s.txt' % (i.prefix)), os.path.join('linear_transforms','%s.txt' % (i.prefix)))
else:
  for level, (regtype, scale, niter, lmax) in enumerate(zip(linear_type, linear_scales, linear_niter, linear_lmax)):
    progress = app.progressBar('Optimising template with linear registration (stage {0} of {1})'.format(level+1, len(linear_scales)), len(inputs))
    for i in inputs:
      initialise_option = ''
      if useMasks:
        mask_option = ' -mask1 ' + abspath(i.mask_directory, i.mask_filename)
      else:
        mask_option = ''
      lmax_option = ''
      metric_option = ''
      mrregister_log_option = ''
      if regtype == 'rigid':
        scale_option = ' -rigid_scale ' + str(scale)
        niter_option = ' -rigid_niter ' + str(niter)
        regtype_option = ' -type rigid'
        output_option = ' -rigid ' + os.path.join('linear_transforms_%i' % level, '%s.txt' % i.prefix)
        if level > 0:
          initialise_option = ' -rigid_init_matrix ' + os.path.join('linear_transforms_%i' % (level - 1), '%s.txt' % i.prefix)
        if do_fod_registration:
          lmax_option = ' -rigid_lmax ' + str(lmax)
        else:
          lmax_option = ' -noreorientation'
        if linear_estimator:
          metric_option = ' -rigid_metric.diff.estimator ' + linear_estimator
        if app.verbosity >= 2:
          mrregister_log_option = ' -info -rigid_log ' + os.path.join('log', i.filename + "_" + str(level) + '.log')
      else:
        scale_option = ' -affine_scale ' + str(scale)
        niter_option = ' -affine_niter ' + str(niter)
        regtype_option = ' -type affine'
        output_option = ' -affine ' + os.path.join('linear_transforms_%i' % level, '%s.txt' % i.prefix)
        if level > 0:
          initialise_option = ' -affine_init_matrix ' + os.path.join('linear_transforms_%i' % (level - 1), '%s.txt' % i.prefix)
        if do_fod_registration:
          lmax_option = ' -affine_lmax ' + str(lmax)
        else:
          lmax_option = ' -noreorientation'
        if linear_estimator:
          metric_option = ' -affine_metric.diff.estimator ' + linear_estimator
        if write_log:
          mrregister_log_option = ' -info -affine_log ' + os.path.join('log', i.filename + '_' + str(level) + '.log')

      command = 'mrregister ' + abspath(i.directory, i.filename) + ' ' + current_template + \
                ' -force' + \
                initialise_option + \
                mask_option + \
                scale_option + \
                niter_option + \
                lmax_option + \
                regtype_option + \
                metric_option + \
                datatype_option + \
                output_option + \
                mrregister_log_option
      run.command(command)
      check_linear_transformation(os.path.join('linear_transforms_%i' % level, '%s.txt' % i.prefix), command, pause_on_warn=do_pause_on_warn)
      progress.increment()
    progress.done()

    # Here we ensure the template doesn't drift or scale
    # Average the transformations of all inputs for this level; for a rigid level
    # the average is additionally passed through transformcalc's 'rigid' operation.
    # The (possibly rigid) average is then inverted.
    run.command('transformcalc ' + ' '.join(allindir('linear_transforms_%i' % level)) + ' average linear_transform_average.txt -force -quiet')
    if linear_type[level] == 'rigid':
      run.command('transformcalc linear_transform_average.txt rigid linear_transform_average.txt -force -quiet')
    run.command('transformcalc linear_transform_average.txt invert linear_transform_average_inv.txt -force -quiet')

    # Right-multiply each input's transformation by the inverse average and
    # overwrite the per-input file in place.
    # NOTE(review): presumably run.function returns None when the call is skipped
    # (e.g. when resuming a previous run), hence the guard - confirm
    average_inv = run.function(loadtxt, 'linear_transform_average_inv.txt')
    if average_inv is not None:
      for i in inputs:
        transform = dot(loadtxt(os.path.join('linear_transforms_%i' % level, '%s.txt' % i.prefix)), average_inv)
        savetxt(os.path.join('linear_transforms_%i' % level, '%s.txt' % i.prefix), transform)

    # Re-apply the corrected transformations so that inputs_transformed matches
    # the files written above
    progress = app.progressBar('Transforming all subjects to revised template', len(inputs))
    for i in inputs:
      run.command('mrtransform ' + abspath(i.directory, i.filename) + ' ' + \
                  '-template ' + current_template + ' ' + \
                  '-linear ' + os.path.join('linear_transforms_%i' % level, '%s.txt' % i.prefix) + ' ' + \
                  os.path.join('inputs_transformed', '%s.mif' % i.prefix) + \
                  datatype_option + ' ' + \
                  '-force')
      progress.increment()
    progress.done()

    # The mean of all transformed inputs becomes the template for the next level
    run.command('mrmath ' + ' '.join(allindir('inputs_transformed')) + ' mean linear_template' + str(level) + '.mif -force')
    current_template = 'linear_template' + str(level) + '.mif'

  # Copy every file of one level's transformation directory into 'linear_transforms'.
  # 'level' still holds the value left behind by the preceding loop
  # (presumably the final linear registration level - confirm).
  for entry in os.listdir('linear_transforms_%i' % level):
    run.function(copy, os.path.join('linear_transforms_%i' % level, entry), os.path.join('linear_transforms', entry))

# Build the initial template mask for non-linear registration: resample every
# input mask onto the current template grid using its linear transformation,
# take the voxel-wise minimum (i.e. the intersection), median-filter the
# result and dilate it with 5 passes
if useMasks and (dononlinear or app.args.template_mask):
  progress = app.progressBar('Generating template mask for non-linear registration', len(inputs))
  for subject in inputs:
    mask_in = abspath(subject.mask_directory, subject.mask_filename)
    linear_in = os.path.join('linear_transforms', '%s.txt' % subject.prefix)
    mask_out = os.path.join('masks_transformed', '%s.mif' % subject.prefix)
    # Nearest-neighbour interpolation is requested for the mask resampling
    run.command(' '.join(['mrtransform', mask_in,
                          '-template', current_template,
                          '-interp nearest',
                          '-linear', linear_in,
                          mask_out,
                          '-force']))
    progress.increment()
  progress.done()
  transformed_masks = ' '.join(allindir('masks_transformed'))
  run.command('mrmath ' + transformed_masks + ' min - | maskfilter - median - | maskfilter - dilate -npass 5 init_nl_template_mask.mif -force')
  current_template_mask = 'init_nl_template_mask.mif'

if dononlinear:
  # Optimise the template with non-linear registration.
  # Each stage registers every input to the current template, then rebuilds the
  # template (and, if masks are in use, the template mask) from the results.
  file.makeDir('warps')
  for level, (scale, niter, lmax) in enumerate(zip(nl_scales, nl_niter, nl_lmax)):
    progress = app.progressBar('Optimising template with non-linear registration (stage {0} of {1})'.format(level+1, len(nl_scales)), len(inputs))
    for i in inputs:
      if level > 0:
        # Later stages are initialised with the warp written by the previous
        # stage; no explicit scale option is passed in that case
        initialise_option = ' -nl_init ' + os.path.join('warps_%i' % (level-1), '%s.mif' % i.prefix)
        scale_option = ''
      else:
        # First stage: initialise from the input's linear transformation
        scale_option = ' -nl_scale ' + str(scale)
        if not doaffine: # rigid or no previous linear stage
          initialise_option = ' -rigid_init_matrix ' + os.path.join('linear_transforms', '%s.txt' % i.prefix)
        else:
          initialise_option = ' -affine_init_matrix ' + os.path.join('linear_transforms', '%s.txt' % i.prefix)

      if useMasks:
        mask_option = ' -mask1 ' + abspath(i.mask_directory, i.mask_filename) + ' -mask2 ' + current_template_mask
      else:
        mask_option = ''

      if do_fod_registration:
        lmax_option = ' -nl_lmax ' + str(lmax)
      else:
        lmax_option = ' -noreorientation'

      # Writes both the full warp for this stage and the transformed input image
      run.command('mrregister ' + abspath(i.directory, i.filename) + ' ' + current_template + ' ' + \
                  '-type nonlinear ' + \
                  '-nl_niter ' + str(niter) + ' ' + \
                  '-nl_warp_full ' + os.path.join('warps_%i' % level, '%s.mif' % i.prefix) + ' ' + \
                  '-transformed ' + os.path.join('inputs_transformed', '%s.mif' % i.prefix) + ' ' + \
                  '-nl_update_smooth ' +  app.args.nl_update_smooth + ' ' + \
                  '-nl_disp_smooth ' +  app.args.nl_disp_smooth + ' ' + \
                  '-nl_grad_step ' +  app.args.nl_grad_step + ' ' + \
                  '-force' + \
                  initialise_option + \
                  scale_option + \
                  mask_option + \
                  datatype_option + \
                  lmax_option)

      if level > 0:
        # The previous stage's warp was only needed as the initialisation above
        run.function(remove, os.path.join('warps_%i'%(level-1), '%s.mif' % i.prefix))
      if useMasks:
        # Bring the input mask into template space with the warp just estimated
        run.command('mrtransform ' + abspath(i.mask_directory, i.mask_filename) + ' ' + \
                    '-template ' + current_template + ' ' + \
                    '-warp_full ' + os.path.join('warps_%i' % level, '%s.mif' % i.prefix) + ' ' + \
                    os.path.join('masks_transformed', i.prefix + '.mif') + ' ' + \
                    '-interp nearest ' +
                    '-force')
      progress.increment()
    progress.done()

    # New template = mean of all transformed inputs.
    # -force is passed, as for every other command of this stage, so that
    # an already-existing output file does not abort a re-run.
    run.command('mrmath ' + ' '.join(allindir('inputs_transformed')) + ' mean nl_template' + str(level) + '.mif -force')
    current_template = 'nl_template' + str(level) + '.mif'

    if useMasks:
      # New template mask = intersection (voxel-wise min) of the transformed
      # masks, median-filtered and dilated with 5 passes
      run.command('mrmath ' + ' '.join(allindir('masks_transformed')) + ' min - | maskfilter - median - | maskfilter - dilate -npass 5 nl_template_mask' + str(level) + '.mif -force')
      current_template_mask = 'nl_template_mask' + str(level) + '.mif'

    if level < len(nl_scales) - 1:
      if scale < nl_scales[level + 1]:
        # The next stage uses a larger scale factor: resize each warp by the
        # ratio of the two scales, via a temporary file that replaces the original
        upsample_factor = nl_scales[level + 1] / scale
        for i in inputs:
          run.command('mrresize ' + os.path.join('warps_%i' % level, '%s.mif' % i.prefix) + ' -scale %f tmp.mif -force' % upsample_factor)
          run.function(move, 'tmp.mif', os.path.join('warps_' + str(level), '%s.mif' % i.prefix))
    else:
      # Final stage: collect the warps in the 'warps' directory
      for i in inputs:
        run.function(move, os.path.join('warps_' + str(level), '%s.mif' % i.prefix), 'warps')


# Write the final template to the user-specified location; overwriting is
# only requested from mrconvert when app.forceOverwrite is set
template_output = path.fromUser(app.args.template, True)
template_force = ' -force' if app.forceOverwrite else ''
run.command('mrconvert ' + current_template + ' ' + template_output + template_force)

# Copy each requested result directory to its user-specified location,
# removing whatever already exists at the destination first
for option_name, result_dir in (('warp_dir', 'warps'),
                                ('linear_transformations_dir', 'linear_transforms'),
                                ('transformed_dir', 'inputs_transformed')):
  requested = getattr(app.args, option_name)
  if not requested:
    continue
  destination = path.fromUser(requested, False)
  if os.path.exists(destination):
    run.function(rmtree, destination)
  run.function(copytree, result_dir, destination)

# Export the current template mask if an output path was given for it
if app.args.template_mask:
  mask_output = path.fromUser(app.args.template_mask, True)
  mask_force = ' -force' if app.forceOverwrite else ''
  run.command('mrconvert ' + current_template_mask + ' ' + mask_output + mask_force)

app.complete()
