#!/usr/bin/python3
"""
    Copyright (C) 2021  Michael Ablassmeier <abi@grinser.de>

    This program is free software: you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation, either version 3 of the License, or
    (at your option) any later version.

    This program is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with this program.  If not, see <https://www.gnu.org/licenses/>.
"""
import os
import sys
import shutil
import tempfile
import signal
from functools import partial
import json
import time
import argparse
from argparse import Namespace
import logging

from typing import List, Tuple, Dict, IO
from libvirtnbdbackup import argopt
from libvirtnbdbackup import __version__
from libvirtnbdbackup import qemu
from libvirtnbdbackup.qemu.exceptions import ProcessError, QemuHelperError
from libvirtnbdbackup import output
from libvirtnbdbackup.output.exceptions import OutputException
from libvirtnbdbackup.common import common as lib
from libvirtnbdbackup.common.processinfo import processInfo
from libvirtnbdbackup.logcount import logCount
from libvirtnbdbackup.exceptions import RestoreError
from libvirtnbdbackup.sparsestream import streamer
from libvirtnbdbackup.sparsestream import types
from libvirtnbdbackup.sparsestream import exceptions


def checkRequirements():
    """Verify that the external helper binaries are available.

    Looks up "nbdkit" and "qemu-nbd" in the search path and logs one
    error per missing utility. Nothing is raised or returned.
    """
    missingTools = [
        tool for tool in ("nbdkit", "qemu-nbd") if not shutil.which(tool)
    ]
    for tool in missingTools:
        logging.error("Please install required [%s] utility.", tool)


def checkDevice(args: Namespace, device: str) -> None:
    """Check if the target device looks like /dev/nbdX and exists,
    otherwise it is likely the nbd module isn't loaded on the system.

    Problems are only reported via logging.error; nothing is raised
    or returned.

    :param args: parsed command line arguments, handed to lib.exists
    :param device: path of the target block device
    """
    if not device.startswith("/dev/nbd"):
        logging.error("Target device [%s] seems not to be a nbd device?", device)

    if not lib.exists(args, device):
        logging.error(
            "Target device [%s] does not exist, nbd module not loaded?", device
        )


def locatePlugin(args: Namespace) -> str:
    """Attempt to locate the nbdkit plugin that is passed to the
    nbdkit process.

    The plugin is expected next to the executed script. A missing
    plugin is only reported via logging.error; the computed path is
    returned either way.

    :param args: parsed command line arguments, handed to lib.exists
    :return: absolute path where the plugin is expected
    """
    pluginFileName = "virtnbd-nbdkit-plugin"
    # sys.argv[0] may carry no directory part at all (script started by its
    # bare name); resolve it first so the result is not "/<plugin>".
    installDir = os.path.dirname(os.path.abspath(sys.argv[0]))
    nbdkitModule = os.path.join(installDir, pluginFileName)

    if not lib.exists(args, nbdkitModule):
        logging.error("Failed to locate nbdkit plugin: [%s]", pluginFileName)

    return nbdkitModule


def replayChanges(dataRanges: List, args: Namespace) -> None:
    """Write the blocks recorded for incremental or differential
    backup files onto the NBD device.

    Only entries whose "inc" flag is True are replayed. Each block is
    read from its backup file at "offset" and written to args.device
    at "originalOffset".

    :param dataRanges: block map entries as produced by getDataRanges
    :param args: parsed command line arguments (device, noprogress)
    """
    logging.info("Replaying changes from incremental backups")
    incrementalBlocks = [entry for entry in dataRanges if entry["inc"] is True]
    totalBytes = sum(entry["length"] for entry in incrementalBlocks)
    bar = lib.progressBar(totalBytes, "replaying..", args)

    with output.openfile(args.device, "wb") as target:
        for entry in incrementalBlocks:
            if args.noprogress:
                logging.info(
                    "Replaying offset %s from %s", entry["offset"], entry["file"]
                )
            sourcePath = os.path.abspath(entry["file"])
            with output.openfile(sourcePath, "rb") as source:
                source.seek(entry["offset"])
                target.seek(entry["originalOffset"])
                payload = source.read(entry["length"])
                target.write(payload)
            # Rewind and flush after every block, then account for it.
            target.seek(0)
            target.flush()
            bar.update(entry["length"])

    bar.close()


def handleSignal(nbdkitProcess: processInfo, device: str, blockMap, log, signum, _):
    """Catch signal, attempt to stop processes.

    Disconnects the device, removes the temporary blockmap file and the
    nbdkit logfile, kills the nbdkit process and exits with status 0.

    :param nbdkitProcess: info about the nbdkit process (pid, logFile)
    :param device: device that was connected
    :param blockMap: temporary blockmap file object, only .name is used
    :param log: logger-like object providing info() and warning()
    :param signum: number of the received signal
    :param _: stack frame passed by the signal module, unused
    """
    log.info("Received signal: [%s]", signum)
    qemu.util("").disconnect(device)

    cleanupTargets = (
        ("Removing temporary blockmap file: [%s]", blockMap.name),
        ("Removing nbdkit logfile: [%s]", nbdkitProcess.logFile),
    )
    for message, path in cleanupTargets:
        log.info(message, path)
        try:
            os.remove(path)
        except OSError as errmsg:
            # A file that is already gone must not prevent the nbdkit
            # process from being stopped below.
            log.warning("Unable to remove file [%s]: %s", path, errmsg)

    lib.killProc(nbdkitProcess.pid)
    sys.exit(0)


def _expectTerminator(sTypes, reader) -> None:
    """Consume the frame terminator from the reader and raise
    RestoreError if something else is found."""
    # The read must always happen: it advances the stream position. Doing it
    # inside an assert would skip it when python runs with -O.
    if reader.read(len(sTypes.TERM)) != sTypes.TERM:
        logging.error("Missing frame terminator in backup file [%s]", reader.name)
        raise RestoreError


def getDataRanges(stream, sTypes, reader) -> Tuple[List, Dict]:
    """Read block offsets from backup stream image.

    Reads the metadata frame first, then walks all following frames up
    to the STOP frame and records one dictionary per frame with keys:

      offset              reader position right after the frame header
      originalOffset      start value taken from the frame header
      nextOriginalOffset  originalOffset + length
      length              length value taken from the frame header
      data                True if the frame is of kind DATA
      file                absolute path of the backup file
      inc                 value of the "incremental" metadata field
      nextBlockOffset     reader position after this frame plus
                          FRAME_LEN, None for the last entry

    :param stream: object providing readFrame() and loadMetadata()
    :param sTypes: stream constants (TERM, STOP, DATA, FRAME_LEN)
    :param reader: backup file opened for binary reading
    :return: tuple of (list of block entries, metadata)
    :raises RestoreError: unreadable header, compressed image or
        missing frame terminator
    """
    try:
        _, _, length = stream.readFrame(reader)
        meta = stream.loadMetadata(reader.read(length))
    except exceptions.StreamFormatException as errmsg:
        logging.error("Unable to read metadata header: %s", errmsg)
        raise RestoreError from errmsg

    if lib.isCompressed(meta):
        logging.error("Mapping compressed images currently not supported.")
        raise RestoreError

    _expectTerminator(sTypes, reader)

    dataRanges: List = []
    while True:
        try:
            kind, start, length = stream.readFrame(reader)
        except exceptions.StreamFormatException as errmsg:
            logging.error("Unable to read frame header: %s", errmsg)
            raise RestoreError from errmsg

        if kind == sTypes.STOP:
            # A stream without any frames has no last entry to patch.
            if dataRanges:
                dataRanges[-1]["nextBlockOffset"] = None
            break

        blockInfo = {
            "offset": reader.tell(),
            "originalOffset": start,
            "nextOriginalOffset": start + length,
            "length": length,
            "data": kind == sTypes.DATA,
            "file": os.path.abspath(reader.name),
            "inc": meta["incremental"],
        }

        if kind == sTypes.DATA:
            # Only data frames carry a payload in the file: skip it
            # together with its terminator.
            reader.seek(length, os.SEEK_CUR)
            _expectTerminator(sTypes, reader)

        blockInfo["nextBlockOffset"] = reader.tell() + sTypes.FRAME_LEN
        dataRanges.append(blockInfo)

    return dataRanges, meta


def dumpBlockMap(tfile: IO, dataRanges: List) -> bool:
    """Serialize the block map as indented JSON into the given file.

    :param tfile: file object opened for binary writing
    :param dataRanges: list of block map entries
    :return: True if written, False if the write failed with OSError
        (the error is logged)
    """
    serialized = json.dumps(dataRanges, indent=4).encode()
    try:
        tfile.write(serialized)
    except OSError as err:
        logging.error("Unable to write blockmap file: %s", err)
        return False
    return True


def _removeFile(path: str) -> None:
    """Remove the given file, log a warning instead of failing if
    that is not possible"""
    try:
        os.remove(path)
    except OSError as e:
        logging.warning("Unable to remove file [%s]: %s", path, e)


def _connectDevice(args: Namespace, maxRetry: int = 10) -> bool:
    """Connect the nbdkit export to the target device using qemu-nbd.

    As long as the error reported by qemu-nbd contains "Connection
    refused" the call is repeated once per second, up to maxRetry
    retries. Any other error aborts immediately.

    :param args: parsed command line arguments
    :param maxRetry: number of retries for refused connections
    :return: True if the device was connected, False otherwise
    """
    nbdCmd = [
        "qemu-nbd",
        "-c",
        f"{args.device}",
        f"nbd://127.0.0.1:{args.listen_port}/{args.export_name}",
        "-f",
        "raw",
    ]
    if args.readonly:
        logging.warning("Device will be mapped readonly without cow.")
        logging.warning("Mounting will only work with '-o norecovery,ro'")
        nbdCmd.append("-r")
    logging.debug(nbdCmd)

    retryCnt = 0
    while True:
        try:
            qemu.runcmd(cmdLine=nbdCmd, toPipe=True)
            return True
        except ProcessError as e:
            if "Connection refused" not in str(e):
                # Not a startup race: retrying would fail the same way.
                logging.error("Failed to map device:")
                logging.error("Stderr: [%s]", str(e))
                return False
            if retryCnt >= maxRetry:
                logging.error("Unable to connect device after service start: %s", e)
                return False
            logging.info("NBD server refused connection, retry [%s]", retryCnt)
            time.sleep(1)
            retryCnt += 1


def main() -> None:
    """Map full backup file to nbd device for single file or
    instant recovery.

    Parses the command line, reads the block offsets of all given
    backup files into a temporary blockmap file, starts an nbdkit
    process using that map and connects it to the target device via
    qemu-nbd. If more than one file is given, the changes of the
    incremental files are replayed onto the device afterwards. The
    function then sleeps until interrupted; SIGINT triggers
    handleSignal. All failures end the program with exit status 1.
    """
    parser = argparse.ArgumentParser(
        description="Map backup image(s) to block device",
        epilog=(
            "Examples:\n"
            "   # Map full backup to device /dev/nbd0:\n"
            "\t%(prog)s -f /backup/sda.full.data\n"
            "   # Map full backup to device /dev/nbd2:\n"
            "\t%(prog)s -f /backup/sda.full.data -d /dev/nbd2\n"
            "   # Map sequence of full and incremental to device /dev/nbd2:\n"
            "\t%(prog)s -f /backup/sda.full.data,/backup/sda.inc.1.data -d /dev/nbd2\n"
        ),
        formatter_class=argparse.RawTextHelpFormatter,
    )
    opt = parser.add_argument_group("General options")
    opt.add_argument(
        "-f", "--file", required=True, type=str, help="List of Backup files to map"
    )
    opt.add_argument(
        "-b",
        "--blocksize",
        required=False,
        type=str,
        default="4096",
        help="Maximum blocksize passed to nbdkit. (default: %(default)s)",
    )
    opt.add_argument(
        "-d",
        "--device",
        default="/dev/nbd0",
        type=str,
        help="Target device. (default: %(default)s)",
    )
    opt.add_argument(
        "-e",
        "--export-name",
        default="sda",
        type=str,
        help="Export name passed to nbdkit. (default: %(default)s)",
    )
    opt.add_argument(
        "-t",
        "--threads",
        default=1,
        type=str,
        help="Amount of threads passed to nbdkit process. (default: %(default)s)",
    )
    opt.add_argument(
        "-l",
        "--listen-address",
        default="127.0.0.1",
        type=str,
        help="IP Address for nbdkit process to listen on. (default: %(default)s)",
    )
    opt.add_argument(
        "-p",
        "--listen-port",
        default="10809",
        type=str,
        help="Port for nbdkit process to listen on. (default: %(default)s)",
    )
    opt.add_argument(
        "-n",
        "--noprogress",
        required=False,
        action="store_true",
        default=False,
        help="Disable progress bar",
    )
    logopt = parser.add_argument_group("Logging options")
    argopt.addLogArgs(logopt, parser.prog)
    debopt = parser.add_argument_group("Debug options")
    debopt.add_argument(
        "-r",
        "--readonly",
        required=False,
        action="store_true",
        help="Map image readonly (default: %(default)s)",
    )
    argopt.addDebugArgs(debopt)
    args = lib.argparse(parser)
    args.sshClient = None
    fileLog = lib.getLogFile(args.logfile) or sys.exit(1)

    counter = logCount()
    lib.configLogger(args, fileLog, counter)
    lib.printVersion(__version__)
    nbdkitModule = locatePlugin(args)
    logging.info("Logfile: [%s]", args.logfile)
    logging.info("Plugin location: [%s]", nbdkitModule)

    checkRequirements()
    checkDevice(args, args.device)
    dataFiles = args.file.split(",")
    isSequence = len(dataFiles) > 1

    if isSequence and "full.data" not in dataFiles[0]:
        logging.error("Sequence must start with a full backup")
    if isSequence and args.readonly:
        logging.error("Device mapping with incrementals doesn't work in readonly mode")

    # The checks above only log; the error counter decides whether to go on.
    if counter.count.errors > 0:
        sys.exit(1)

    fullImage = os.path.abspath(dataFiles[0])

    stream = streamer.SparseStream(types)
    sTypes = types.SparseStreamTypes()

    dataRanges: List = []
    for dFile in dataFiles:
        try:
            with output.openfile(dFile, "rb") as reader:
                ranges, _ = getDataRanges(stream, sTypes, reader)
        except OutputException as e:
            logging.error("[%s]: [%s]", dFile, e)
            sys.exit(1)
        except RestoreError:
            logging.error("Unable to read meta header from backup file.")
            sys.exit(1)

        dataRanges.extend(ranges)
        if args.verbose is True:
            logging.info(json.dumps(dataRanges, indent=4))
        else:
            logging.info(
                "Parsed [%s] block offsets from file [%s]", len(dataRanges), dFile
            )

    # The file must outlive this block (delete=False): nbdkit is started
    # with its name and handleSignal removes it on exit.
    with tempfile.NamedTemporaryFile(
        delete=False, prefix="block.", suffix=".map"
    ) as blockMap:
        logging.info("Write blockmap to temporary file: [%s]", blockMap.name)
        mapWritten = dumpBlockMap(blockMap, dataRanges)
    if not mapWritten:
        _removeFile(blockMap.name)
        sys.exit(1)

    logging.info("Target device: %s", args.device)

    qFh = qemu.util(args.export_name)
    try:
        nbdkitProcess = qFh.startNbdkitProcess(
            args, nbdkitModule, blockMap.name, fullImage
        )
    except QemuHelperError as e:
        logging.error("Failed to start nbdkit process: [%s]", e)
        _removeFile(blockMap.name)
        sys.exit(1)

    logging.info(
        "Started nbdkit process pid: [%s], Logfile: [%s]",
        nbdkitProcess.pid,
        nbdkitProcess.logFile,
    )
    signal.signal(
        signal.SIGINT,
        partial(handleSignal, nbdkitProcess, args.device, blockMap, logging),
    )

    if not _connectDevice(args):
        # The nbdkit logfile is kept on purpose, it may explain the failure.
        _removeFile(blockMap.name)
        lib.killProc(nbdkitProcess.pid)
        sys.exit(1)

    if isSequence:
        try:
            replayChanges(dataRanges, args)
        except OutputException as e:
            logging.error("Failed to replay changes: %s", e)
            _removeFile(blockMap.name)
            lib.killProc(nbdkitProcess.pid)
            sys.exit(1)

    logging.info("Done mapping backup image to [%s]", args.device)
    logging.info("Press CTRL+C to disconnect")
    # Stay alive until interrupted; cleanup happens in the signal handler.
    while True:
        time.sleep(60)


# Run the mapping tool only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
