#!/usr/bin/python3

r'''Stereo processing

SYNOPSIS

  $ mrcal-stereo                        \
     --az-fov-deg               90      \
     --el-fov-deg               90      \
     --sgbm-block-size          5       \
     --sgbm-p1                  600     \
     --sgbm-p2                  2400    \
     --sgbm-uniqueness-ratio    5       \
     --sgbm-disp12-max-diff     1       \
     --sgbm-speckle-window-size 200     \
     --sgbm-speckle-range       2       \
     --outdir /tmp                      \
     left.cameramodel right.cameramodel \
     left.jpg         right.jpg

  Writing '/tmp/rectified0.cameramodel'
  Writing '/tmp/rectified1.cameramodel'
  ##### processing left.jpg and right.jpg
  Writing '/tmp/left-rectified.png'
  Writing '/tmp/right-rectified.png'
  Writing '/tmp/left-disparity.png'
  Writing '/tmp/left-range.png'

Given a pair of calibrated cameras and pairs of images captured by these
cameras, this tool runs the whole stereo processing sequence to produce
disparity and range images.

mrcal functions are used to construct the rectified system. Currently only the
OpenCV SGBM routine is available to perform stereo matching, but more options
will be made available with time.

The commandline arguments to configure the SGBM matcher (--sgbm-...) map to the
corresponding OpenCV APIs. Omitting an --sgbm-... argument will result in the
defaults being used in the cv2.StereoSGBM_create() call. Usually the
cv2.StereoSGBM_create() defaults are terrible, and produce a disparity map that
isn't great. The --sgbm-... arguments in the synopsis above are a good start to
get usable stereo.

The rectified system is constructed with the axes

- x: from the origin of the first camera to the origin of the second camera (the
  baseline direction)

- y: completes the system from x,z

- z: the mean "forward" direction of the two input cameras, with the component
  parallel to the baseline subtracted off

The active window in this system is specified using a few parameters. These
refer to

- the "azimuth" (or "az"): the direction along the baseline: rectified x axis

- the "elevation" (or "el"): the direction across the baseline: rectified y axis

The rectified field of view is given by the arguments --az-fov-deg and
--el-fov-deg. At this time there's no auto-detection logic, and these must be
given. Changing these is a "zoom" operation.

To pan the stereo system, pass --az0-deg and/or --el0-deg. These specify the
center of the rectified images, and are optional.

Finally, the resolution of the rectified images is given with --pixels-per-deg.
This is optional, and defaults to the resolution of the input images. If we want
to scale the input resolution, pass a value <0. For instance, to generate
rectified images at half the input resolution, pass --pixels-per-deg=-0.5. Note
that the Python argparse has a problem with negative numbers, so
"--pixels-per-deg -0.5" does not work.

The input images are specified by a pair of globs, so we can process many images
with a single call. Each glob is expanded, and the filenames are sorted. The
resulting lists of files are assumed to match up.

'''

import sys
import argparse
import re
import os

def parse_args():
    r'''Parse and validate the commandline

    Reads sys.argv via argparse, and applies the cross-argument checks that
    argparse cannot express by itself:

    - exactly two image globs are required unless --show-geometry
    - --pixels-per-deg is converted from its "RES" or "RESAZ,RESEL" string into
      a list of 1 or 2 nonzero floats. If omitted, (-1,-1) is used
    - --az-fov-deg and --el-fov-deg are required unless --already-rectified
    - --clahe requires --force-grayscale

    Returns the argparse namespace. On a validation failure a message is
    printed to stderr, and the process exits with status 1 (argparse's own
    errors exit with status 2, as usual)

    '''

    def positive_float(string):
        # argparse "type" callback: accept only finite-or-infinite values > 0
        try:
            value = float(string)
        except ValueError:
            raise argparse.ArgumentTypeError("argument MUST be a positive floating-point number. Got '{}'".format(string))
        # "not value > 0" instead of "value <= 0" so that NaN is rejected too:
        # every comparison against NaN is False
        if not value > 0:
            raise argparse.ArgumentTypeError("argument MUST be a positive floating-point number. Got '{}'".format(string))
        return value

    def existing_directory(d):
        # argparse "type" callback for --outdir. parser.error() prints the
        # usage and exits with status 2
        if not os.path.isdir(d):
            parser.error(f"--outdir requires an existing directory as the arg, but got '{d}'")
        return d

    def die(message):
        # Report a validation failure that argparse cannot detect by itself
        print(message, file=sys.stderr)
        sys.exit(1)


    parser = \
        argparse.ArgumentParser(description = __doc__,
                                formatter_class=argparse.RawDescriptionHelpFormatter)


    ######## geometry and rectification system parameters
    parser.add_argument('--az-fov-deg',
                        type=float,
                        help='''The field of view in the azimuth direction, in
                        degrees. There's no auto-detection at this time, so this
                        argument is required (unless --already-rectified)''')
    parser.add_argument('--el-fov-deg',
                        type=float,
                        help='''The field of view in the elevation direction, in
                        degrees. There's no auto-detection at this time, so this
                        argument is required (unless --already-rectified)''')
    parser.add_argument('--az0-deg',
                        default = None,
                        type=float,
                        help='''The azimuth center of the rectified images. "0"
                        means "the horizontal center of the rectified system is
                        the mean forward direction of the two cameras projected
                        to lie perpendicular to the baseline". If omitted, we
                        align the center of the rectified system with the center
                        of the two cameras' views''')
    parser.add_argument('--el0-deg',
                        default = 0,
                        type=float,
                        help='''The elevation center of the rectified system.
                        "0" means "the vertical center of the rectified system
                        lies along the mean forward direction of the two
                        cameras" Defaults to 0.''')
    parser.add_argument('--pixels-per-deg',
                        help='''The resolution of the rectified images. This is
                        either a whitespace-less, comma-separated list of two
                        values (az,el) or a single value to be applied to both
                        axes. If a resolution of >0 is requested, the value is
                        used as is. If a resolution of <0 is requested, we use
                        this as a scale factor on the resolution of the input
                        image. For instance, to downsample by a factor of 2,
                        pass -0.5. By default, we use -1 for both axes: the
                        resolution of the input image at the center of the
                        rectified system.''')
    parser.add_argument('--rectification',
                        choices=('LENSMODEL_PINHOLE', 'LENSMODEL_LATLON'),
                        default = 'LENSMODEL_LATLON',
                        help='''The lens model to use for rectification.
                        Currently two models are supported: LENSMODEL_LATLON
                        (the default) and LENSMODEL_PINHOLE. Pinhole stereo
                        works badly for wide lenses and suffers from varying
                        angular resolution across the image. LENSMODEL_LATLON
                        rectification uses a transverse equirectangular
                        projection, and does not suffer from these effects. It
                        is thus the recommended model''')

    parser.add_argument('--already-rectified',
                        action='store_true',
                        help='''If given, assume the given models and images
                        already represent a rectified system. This will be
                        checked, and the models will be used as-is if the checks
                        pass''')

    ######## image pre-filtering
    parser.add_argument('--clahe',
                        action='store_true',
                        help='''If given, apply CLAHE equalization to the images
                        prior to the stereo matching. If --already-rectified, we
                        still apply this equalization, if requested. Requires
                        --force-grayscale''')

    parser.add_argument('--force-grayscale',
                        action='store_true',
                        help='''If given, convert the images to grayscale prior
                        to doing anything else with them. By default, read the
                        images in their default format, and pass those
                        possibly-color images to all the processing steps.
                        Required if --clahe''')

    ######## --show-geometry
    parser.add_argument('--show-geometry',
                        action='store_true',
                        help='''If given, we construct the rectified stereo
                        system, but instead of continuing with the stereo
                        processing, we render the geometry of the stereo world.
                        The images are ignored in this mode.''')
    parser.add_argument('--axis-scale',
                        type=float,
                        help='''Used if --show-geometry. Scale for the camera
                        axes. By default a reasonable default is chosen (see
                        mrcal.show_geometry() for the logic)''')
    parser.add_argument('--title',
                        type=str,
                        default = None,
                        help='''Used if --show-geometry. Title string for the plot''')
    parser.add_argument('--hardcopy',
                        type=str,
                        help='''Used if --show-geometry. Write the output to
                        disk, instead of making an interactive plot. The output
                        filename is given in the option''')
    parser.add_argument('--terminal',
                        type=str,
                        help=r'''Used if --show-geometry. The gnuplotlib
                        terminal. The default is almost always right, so most
                        people don't need this option''')
    parser.add_argument('--set',
                        type=str,
                        action='append',
                        help='''Used if --show-geometry. Extra 'set' directives
                        to pass to gnuplotlib. May be given multiple times''')
    parser.add_argument('--unset',
                        type=str,
                        action='append',
                        help='''Used if --show-geometry. Extra 'unset'
                        directives to pass to gnuplotlib. May be given multiple
                        times''')


    ######## stereo processing
    parser.add_argument('--force', '-f',
                        action='store_true',
                        default=False,
                        help='''By default existing files are not overwritten. Pass --force to overwrite them
                        without complaint''')
    parser.add_argument('--outdir',
                        default='.',
                        type=existing_directory,
                        help='''Directory to write the output into. If omitted,
                        we use the current directory''')
    parser.add_argument('--tag',
                        help='''String to use in the output filenames.
                        Non-specific output filenames if omitted ''')
    parser.add_argument('--disparity-range',
                        type=int,
                        nargs=2,
                        default=(0,100),
                        help='''The disparity limits to use in the search, in
                        pixels. Two integers are expected: MIN_DISPARITY
                        MAX_DISPARITY. Completely arbitrarily, we default to
                        MIN_DISPARITY=0 and MAX_DISPARITY=100''')
    parser.add_argument('--valid-intrinsics-region',
                        action='store_true',
                        help='''If given, annotate the image with its
                        valid-intrinsics region. This will end up in the
                        rectified images, and make it clear where successful
                        matching shouldn't be expected''')
    parser.add_argument('--range-image-limits',
                        type=positive_float,
                        nargs=2,
                        default=(1,1000),
                        help='''The nearest,furthest range to encode in the range image.
                        Defaults to 1,1000, arbitrarily''')

    parser.add_argument('--sgbm-block-size',
                        type=int,
                        default = 5,
                        help='''A parameter for the OpenCV SGBM matcher. If
                        omitted, 5 is used''')
    parser.add_argument('--sgbm-p1',
                        type=int,
                        help='''A parameter for the OpenCV SGBM matcher. If
                        omitted, the OpenCV default is used''')
    parser.add_argument('--sgbm-p2',
                        type=int,
                        help='''A parameter for the OpenCV SGBM matcher. If
                        omitted, the OpenCV default is used''')
    parser.add_argument('--sgbm-disp12-max-diff',
                        type=int,
                        help='''A parameter for the OpenCV SGBM matcher. If
                        omitted, the OpenCV default is used''')
    parser.add_argument('--sgbm-pre-filter-cap',
                        type=int,
                        help='''A parameter for the OpenCV SGBM matcher. If
                        omitted, the OpenCV default is used''')
    parser.add_argument('--sgbm-uniqueness-ratio',
                        type=int,
                        help='''A parameter for the OpenCV SGBM matcher. If
                        omitted, the OpenCV default is used''')
    parser.add_argument('--sgbm-speckle-window-size',
                        type=int,
                        help='''A parameter for the OpenCV SGBM matcher. If
                        omitted, the OpenCV default is used''')
    parser.add_argument('--sgbm-speckle-range',
                        type=int,
                        help='''A parameter for the OpenCV SGBM matcher. If
                        omitted, the OpenCV default is used''')
    parser.add_argument('--sgbm-mode',
                        choices=('SGBM','HH','HH4','SGBM_3WAY'),
                        help='''A parameter for the OpenCV SGBM matcher. Must be
                        one of ('SGBM','HH','HH4','SGBM_3WAY'). If omitted, the
                        OpenCV default (SGBM) is used''')

    parser.add_argument('models',
                        type=str,
                        nargs = 2,
                        help='''Camera models representing cameras used to
                        capture the images. Both intrinsics and extrinsics are
                        used''')
    parser.add_argument('images',
                        type=str,
                        nargs='*',
                        help='''The image globs to use for the triangulation. If
                        not --show-geometry: exactly two image globs must be
                        given. If --show-geometry: the images are ignored''')

    args = parser.parse_args()

    if not args.show_geometry and \
       len(args.images) != 2:
        die("""Argument-parsing error:
  --show-geometry not given: exactly 2 images should have been given, since we're
  performing stereo processing""")

    if args.pixels_per_deg is None:
        args.pixels_per_deg = (-1, -1)
    else:
        try:
            resolutions = [float(x) for x in args.pixels_per_deg.split(',')]
        except ValueError:
            resolutions = None

        # str.split() always produces at least one field, so only the upper
        # bound needs checking. "x < 0 or x > 0" rejects both 0 and NaN
        if resolutions is None   or \
           len(resolutions) > 2  or \
           not all(x < 0 or x > 0 for x in resolutions):
            die("""Argument-parsing error:
  --pixels-per-deg requires RESX,RESY or RESXY, where RES... is a value <0 or >0""")
        args.pixels_per_deg = resolutions

    if not args.already_rectified and \
       (args.az_fov_deg is None or \
        args.el_fov_deg is None ):
        die("""Argument-parsing error:
  --az-fov-deg and --el-fov-deg are required if not --already-rectified""")

    if args.clahe and not args.force_grayscale:
        die("--clahe requires --force-grayscale")

    return args

# The commandline is parsed here, before the imports that follow, so that --help
# works without building stuff. This lets the manpages and README be generated
# from the --help output
args = parse_args()






import numpy as np
import numpysane as nps
import cv2
import glob
import mrcal

# Replace the --sgbm-mode name with the cv2 constant of the matching name.
# argparse already restricts the choices, so the sanity check below firing
# means there's a bug in the argument parser
if args.sgbm_mode is not None:
    if args.sgbm_mode not in ('SGBM', 'HH', 'HH4', 'SGBM_3WAY'):
        raise Exception("arg-parsing error. This is a bug. Please report")
    args.sgbm_mode = getattr(cv2, 'StereoSGBM_MODE_' + args.sgbm_mode)

# args.pixels_per_deg holds either (az,el) or a single value that applies to
# both axes. Taking the first and last elements handles both cases
pixels_per_deg_az = args.pixels_per_deg[0]
pixels_per_deg_el = args.pixels_per_deg[-1]

# The tag ends up in the output filenames, so each run of characters outside of
# a conservative set is collapsed into a single '_'
if args.tag is not None:
    args.tag = re.compile('[^a-zA-Z0-9_+-]+').sub('_', args.tag)


models = list(map(mrcal.cameramodel, args.models))

if args.already_rectified:
    # The given models are used as is, after validation
    models_rectified = models
    mrcal.stereo._validate_models_rectified(models_rectified)
else:
    models_rectified = \
        mrcal.rectified_system(models,
                               az_fov_deg          = args.az_fov_deg,
                               el_fov_deg          = args.el_fov_deg,
                               el0_deg             = args.el0_deg,
                               az0_deg             = args.az0_deg,
                               pixels_per_deg_az   = pixels_per_deg_az,
                               pixels_per_deg_el   = pixels_per_deg_el,
                               rectification_model = args.rectification)

if args.show_geometry:
    # Display the geometry of the two cameras in the stereo pair, and of the
    # rectified system. The process exits at the end of this branch: no stereo
    # matching happens in this mode
    plot_options = dict(terminal = args.terminal,
                        hardcopy = args.hardcopy,
                        _set     = args.set,
                        unset    = args.unset,
                        wait     = args.hardcopy is None)
    if args.title is not None:
        plot_options['title'] = args.title

    # --axis-scale is passed on only if it was given, so that omitting it
    # leaves the mrcal.show_geometry() default logic untouched
    show_geometry_options = dict()
    if args.axis_scale is not None:
        show_geometry_options['axis_scale'] = args.axis_scale

    data_tuples, plot_options = \
        mrcal.show_geometry( list(models) + list(models_rectified),
                             ( "camera0", "camera1", "camera0-rectified", "camera1-rectified" ),
                             show_calobjects  = False,
                             return_plot_args = True,
                             **show_geometry_options,
                             **plot_options)

    # Sample pixel coordinates around the perimeter of the first rectified
    # imager: Nside intervals along each edge. The four pieces below are the
    # y=0 edge, the x=w-1 edge, the y=h-1 edge (reversed) and the x=0 edge
    # (reversed). Each piece after the first omits its starting corner, which
    # the previous piece already produced
    Nside = 8
    model = models_rectified[0]
    w,h = model.imagersize()
    linspace_w = np.linspace(0,w-1, Nside+1)
    linspace_h = np.linspace(0,h-1, Nside+1)
    # shape (Nloop, 2)
    qloop = \
       nps.transpose( nps.glue( nps.cat( linspace_w,
                                         0*linspace_w),
                                nps.cat( 0*linspace_h[1:] + w-1,
                                         linspace_h[1:]),
                                nps.cat( linspace_w[-2::-1],
                                         0*linspace_w[-2::-1] + h-1),
                                nps.cat( 0*linspace_h[-2::-1],
                                         linspace_h[-2::-1]),
                                axis=-1) )

    # shape (Nloop,3): unit vectors in cam-local coords
    v = mrcal.unproject(np.ascontiguousarray(qloop),
                        *model.intrinsics(),
                        normalize = True)

    # Scale the vectors to a nice-looking length. This is intended to work with
    # the defaults to mrcal.show_geometry(). That function sets the right,down
    # axis lengths to baseline/3. Here I default to a bit less: baseline/4
    baseline = \
        nps.mag(mrcal.compose_Rt( models_rectified[1].extrinsics_Rt_fromref(),
                                  models_rectified[0].extrinsics_Rt_toref())[3,:])
    v *= baseline/4

    # shape (Nloop,3): ref-coord system points
    p = mrcal.transform_point_Rt(model.extrinsics_Rt_toref(), v)

    # shape (Nloop,2,3). Nloop-length different lines with 2 points in each one.
    # The second point of each line is row 3 of the camera-to-ref Rt transform
    # (presumably the translation, i.e. the camera origin in the ref coords)
    p = nps.xchg(nps.cat(p,p), -2,-3)
    p[:,1,:] = model.extrinsics_Rt_toref()[3,:]

    data_tuples.append( (p, dict(tuplesize = -3,
                                 _with     = 'lines lt 1',)))

    # Imported here so that gnuplotlib is loaded only if we're plotting
    import gnuplotlib as gp
    gp.plot(*data_tuples, **plot_options)

    sys.exit()




def write_output(func, filename):
    r'''Write one output file, honoring --outdir, --tag and --force

    The output path is built as OUTDIR/STEM[-TAG]EXTENSION from the given
    filename and the global commandline arguments. If that file already exists
    and --force was not given, a warning is printed and nothing is written.
    Otherwise func(path) is called to do the actual writing

    '''
    suffix = '' if args.tag is None else '-' + args.tag

    stem,extension = os.path.splitext(filename)
    path = f"{args.outdir}/{stem}{suffix}{extension}"

    if not args.force and os.path.isfile(path):
        print(f"WARNING: '{path}' already exists. Not overwriting this file. Pass -f to overwrite")
        return

    print(f"Writing '{path}'")
    func(path)

# Unless the given system is already rectified, write out the rectified models
# and compute the maps used to rectify each image
if not args.already_rectified:
    write_output(lambda filename: models_rectified[0].write(filename),
                 'rectified0.cameramodel')
    write_output(lambda filename: models_rectified[1].write(filename),
                 'rectified1.cameramodel')

    rectification_maps = mrcal.rectification_maps(models, models_rectified)


if args.clahe:
    clahe = cv2.createCLAHE()
    clahe.setClipLimit(8)


# The stereo matcher doesn't depend on the images, so it is set up once, prior
# to the image loop
disp_min,disp_max = args.disparity_range
if disp_max <= disp_min:
    raise Exception(f"--disparity-range requires MIN_DISPARITY < MAX_DISPARITY, but got {disp_min} {disp_max}")

# This is a hard-coded property of the OpenCV StereoSGBM implementation
disparity_scale = 16

# OpenCV wants the NUMBER of disparities to search (max-min), not the max
# disparity. It must be a positive multiple of disparity_scale, so I round to
# the nearest multiple, but never all the way down to 0
num_disparities = \
    max(disparity_scale,
        disparity_scale*round((disp_max - disp_min)/disparity_scale))
# The upper limit that is actually searched, after the rounding
disp_max = disp_min + num_disparities

# I only add non-default args. StereoSGBM_create() doesn't like being given
# None args
kwargs = dict()
if args.sgbm_p1 is not None:
    kwargs['P1']                = args.sgbm_p1
if args.sgbm_p2 is not None:
    kwargs['P2']                = args.sgbm_p2
if args.sgbm_disp12_max_diff is not None:
    kwargs['disp12MaxDiff']     = args.sgbm_disp12_max_diff
if args.sgbm_pre_filter_cap is not None:
    kwargs['preFilterCap']      = args.sgbm_pre_filter_cap
if args.sgbm_uniqueness_ratio is not None:
    kwargs['uniquenessRatio']   = args.sgbm_uniqueness_ratio
if args.sgbm_speckle_window_size is not None:
    kwargs['speckleWindowSize'] = args.sgbm_speckle_window_size
if args.sgbm_speckle_range is not None:
    kwargs['speckleRange']      = args.sgbm_speckle_range
if args.sgbm_mode is not None:
    kwargs['mode']              = args.sgbm_mode
stereo = \
    cv2.StereoSGBM_create(minDisparity      = disp_min,
                          numDisparities    = num_disparities,
                          # blocksize is required, so I always pass it.
                          # There's a default set in the argument parser, so
                          # this is never None
                          blockSize         = args.sgbm_block_size,
                          **kwargs)


# Each glob is expanded and sorted. The two resulting lists are assumed to
# match up, element by element
image_filenames_all = [sorted(glob.glob(os.path.expanduser(g))) for g in args.images]

if len(image_filenames_all[0]) != len(image_filenames_all[1]):
    raise Exception(f"First glob matches {len(image_filenames_all[0])} files, but the second glob matches {len(image_filenames_all[1])} files. These must be identical")

if len(image_filenames_all[0]) == 0:
    print(f"WARNING: the image globs {args.images} matched no files. Nothing to do",
          file=sys.stderr)

for image_filenames in zip(*image_filenames_all):

    print(f"##### processing {os.path.basename(image_filenames[0])} and {os.path.basename(image_filenames[1])}")

    flags = (cv2.IMREAD_GRAYSCALE,) if args.force_grayscale else ()
    images = [cv2.imread(f, *flags) for f in image_filenames]

    for image_filename,image,model,model_filename in \
        zip(image_filenames, images, models, args.models):

        # cv2.imread() reports failures by returning None, not by raising
        if image is None:
            raise Exception(f"Couldn't read image '{image_filename}'")

        # This doesn't really matter: I don't use the input imagersize. But a
        # mismatch suggests the user probably messed up, and it would save them
        # time to yell at them
        imagersize_image = np.array((image.shape[1], image.shape[0]))
        imagersize_model = model.imagersize()
        if np.any(imagersize_image - imagersize_model):
            raise Exception(f"Image '{image_filename}' dimensions {imagersize_image} don't match the model '{model_filename}' dimensions {imagersize_model}")

    if args.clahe:
        images = [ clahe.apply(image) for image in images ]

    if args.valid_intrinsics_region:
        for image,model in zip(images, models):
            mrcal.annotate_image__valid_intrinsics_region(image, model)

    # Input filenames with the directory and extension stripped. The output
    # filenames are built from these
    image_filenames_base = \
        [os.path.splitext(os.path.basename(f))[0] for f in image_filenames]

    if not args.already_rectified:
        images_rectified = [mrcal.transform_image(image, rectification_map) \
                            for image,rectification_map in \
                            zip(images, rectification_maps)]

        write_output(lambda filename: cv2.imwrite(filename, images_rectified[0]),
                     image_filenames_base[0] + '-rectified.png')
        write_output(lambda filename: cv2.imwrite(filename, images_rectified[1]),
                     image_filenames_base[1] + '-rectified.png')
    else:
        images_rectified = images


    # Done with all the preliminaries. Run the stereo matching. The disparities
    # come back as integers, scaled up by disparity_scale
    disparity = stereo.compute(*images_rectified)
    write_output(lambda filename: \
                 cv2.imwrite(filename,
                             mrcal.apply_color_map(disparity,
                                                   disp_min*disparity_scale,
                                                   disp_max*disparity_scale)),
                 image_filenames_base[0] + '-disparity.png')

    _range = mrcal.stereo_range( disparity.astype(np.float32) / disparity_scale,
                                 models_rectified )
    write_output(lambda filename: \
                 cv2.imwrite(filename, mrcal.apply_color_map(_range,
                                                             *args.range_image_limits)),
                 image_filenames_base[0] + '-range.png')