#!/usr/bin/env python
#
# Copyright (C) 2019 CS-SI. All Rights Reserved.
# Author: Yoann Vandoorselaere <yoannv@gmail.com>
#
# This file is part of the preludedb-admin program.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2, or (at your option)
# any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.

from __future__ import absolute_import, division, print_function, unicode_literals

import argparse
import io
import multiprocessing
import operator
import os
import prelude
import preludedb

try:
    import queue
except ImportError:
    import Queue as queue

import signal
import sys
import time


# Signals handled by the parent process (see Command.__init__) and ignored by
# every worker process (see Worker.run).
SIGLIST = [signal.SIGINT, signal.SIGTERM, signal.SIGQUIT, signal.SIGUSR1]


# Top-level parser; each command class registers its own sub-parser on
# SUBPARSERS from its class body.
TOP_PARSER = argparse.ArgumentParser(prog="preludedb-admin")
SUBPARSERS = TOP_PARSER.add_subparsers(help='sub-command help')

DATABASE_HELP = "Database settings (example: \"type=mysql user=prelude name=mydb\")"
# Default number of idents requested per query by get_idents_limited(),
# before it is multiplied by the number of processes.
DEFAULT_LIMIT = 8192
# Sentinel put on the results queue by a worker when it leaves its main loop.
PROCESS_EXITED = -1


# Count of workers created and not yet exited: incremented in Worker.__init__,
# decremented in MultiprocessCommand.exit_worker.
running_worker = multiprocessing.Value('i', 0)
# Set to False by the first SIGINT/SIGTERM to request a stop after the
# current transaction.
continue_processing = multiprocessing.Value('i', True)


def ask_confirmation(question):
    """Ask a yes/no question on stderr and return the answer read from stdin.

    Surrounding whitespace and letter case of the reply are ignored. The
    reminder is printed again until the reply is 'y' or 'n'.

    :param question: Text displayed in front of the " [y/n]" suffix.
    :return: True when the user replied 'y'; False when the user replied 'n'
             or when standard input reached end-of-file.
    """
    func = input if sys.version_info[0] > 2 else raw_input
    answers = {"y": True, "n": False}

    print("%s [y/n]" % question, file=sys.stderr)
    while True:
        try:
            reply = func()
        except EOFError:
            # No answer can be obtained (closed or exhausted stdin): refuse
            # instead of failing with a traceback.
            return False

        try:
            return answers[reply.strip().lower()]
        except KeyError:
            print("Please respond with \'y\' or \'n\'.\n", file=sys.stderr)



class DatabaseAction(argparse.Action):
    """argparse action turning a database settings string into a DB object.

    A TypedDB is built when the namespace already carries an "object"
    attribute, a plain preludedb.DB otherwise. The resulting object is
    stored on the namespace under the argument's destination name.
    """

    def __call__(self, parser, namespace, values, option_string=None):
        if not hasattr(namespace, "object"):
            try:
                database = preludedb.DB(preludedb.SQL(values))
            except RuntimeError as error:
                print(error)
                sys.exit(1)
        else:
            database = TypedDB(namespace.object, values)

        # Close the SQL connection now to force a reopen after fork()
        database.get_sql().close()

        setattr(namespace, self.dest, database)


class TypedDB(preludedb.DB):
    """preludedb.DB exposing getIdents/get/delete aliases for one object type.

    For type "alert" or "heartbeat" the aliases point to the matching
    get<Type>Idents, get<Type> and delete<Type> methods. Any other type
    leaves the instance without aliases.
    """

    def __init__(self, type, *args):
        try:
            preludedb.DB.__init__(self, preludedb.SQL(*args))
        except RuntimeError as error:
            print(error)
            sys.exit(1)

        if type == "alert":
            kind = "Alert"
        elif type == "heartbeat":
            kind = "Heartbeat"
        else:
            return

        self.getIdents = getattr(self, "get%sIdents" % kind)
        self.get = getattr(self, "get%s" % kind)
        self.delete = getattr(self, "delete%s" % kind)


class Worker(multiprocessing.Process):
    """Child process executing the tasks queued by the parent command.

    Tasks are read from task_queue until a None sentinel is received or a
    task raises. Truthy task results, then PROCESS_EXITED, are put on
    results_queue.
    """

    def __init__(self, task_queue, results_queue, cmdobj):
        super(Worker, self).__init__()

        self._cmdobj = cmdobj
        self.task_queue = task_queue
        self.results_queue = results_queue

        # Per-worker counter of processed events, readable from the parent.
        self.processed = multiprocessing.Value('i', 0)

        # The increment happens before the actual fork: otherwise the parent could
        # evaluate continue_processing before this worker has been accounted for.
        with running_worker.get_lock():
            running_worker.value += 1

    def run(self):
        """Worker main loop, executed in the child process."""
        for signo in SIGLIST:
            signal.signal(signo, signal.SIG_IGN)

        command = self._cmdobj
        command.init_worker()

        failed = False
        while not failed:
            task = self.task_queue.get()
            if task is None:
                break

            try:
                result = command.run_worker_transaction(task, worker=self)
                if result:
                    self.results_queue.put(result)

            except Exception as error:
                failed = True
                print("[%s error]: %s" % (self.name, str(error)), file=sys.stderr)

        self.results_queue.put(PROCESS_EXITED)
        command.exit_worker(not failed)


class GenericCommand(object):
    """Base class shared by every sub-command.

    Sub-classes provide a class-level ``parser`` (an argparse sub-parser),
    override run_parent() and, when they process events, run_worker(). The
    ``support_*`` class attributes select which common options __init__()
    adds to the sub-parser.
    """
    # Number of events handled by run_worker_transaction() in this process.
    processed = 0
    # Number of completed run_worker() calls; drives the transaction boundaries.
    iteration = 0
    # Number of items handed to push_worker() so far.
    pushed_count = 0
    # Worker processes owned by the command (replaced per instance by MultiprocessCommand).
    workers = []

    # Names of the option attributes on which transaction methods are called.
    support_transaction = []
    support_events_stats = support_limit = support_offset = support_criteria = True

    def __init__(self):
        """Register the common options, then parse the command line."""
        if self.support_transaction:
            self.parser.add_argument("-e", "--events-per-transaction", type=int, default=0, help="Maximum number of event to process per transaction (default no transaction)")

        if self.support_criteria:
            self.parser.add_argument("-c", "--criteria", type=prelude.IDMEFCriteria, default=None, help="IDMEF criteria")

        if self.support_offset:
            self.parser.add_argument("-o", "--offset", type=int, default=0, help="Skip processing until 'offset' events")

        if self.support_limit:
            self.parser.add_argument("-l", "--limit", type=int, default=-1, help="Process at most 'limit' events")

        self._options = TOP_PARSER.parse_args()
        self._offset = getattr(self._options, "offset", 0)

        # A limit of -1 (also used when the option is not declared) means "no limit".
        self._limit = getattr(self._options, "limit", -1)
        if self._limit == -1:
            self._limit = sys.maxsize

    def init_parent(self):
        """Hook called in the parent before run_parent(); does nothing by default."""
        pass

    def run_parent(self):
        """Hook implementing the command in the parent; does nothing by default."""
        pass

    def exit_parent(self):
        """Hook called in the parent once the command is over."""
        return self.exit_worker()

    def _make_transaction(self, type, force=False, op=operator.eq):
        """Call the method named `type` on every option listed in support_transaction.

        Nothing happens unless --events-per-transaction is greater than 1 and
        either `force` is set or op(iteration % events_per_transaction, 0) holds.
        """
        if not self.support_transaction:
            return

        eps = self._options.events_per_transaction

        if eps > 1 and (force or op(self.iteration % eps, 0)):
            for i in self.support_transaction:
                getattr(getattr(self._options, i), type)()

    def exit_worker(self, success=True):
        """End the transaction left open by the last iteration, on success only."""
        if success:
            # operator.ne: only end a transaction that the last iteration did not already end.
            self._make_transaction("transaction_end", op=operator.ne)

    def run_worker_transaction(self, args, worker=None):
        """Run run_worker() on each item of `args`, within transactions.

        :param args: Iterable of items; each item is passed to run_worker().
        :param worker: Optional Worker whose shared `processed` counter is updated.
        :return: The value returned by the last run_worker() call, None if none ran.
        :raises Exception: Whatever run_worker() raised, after the transaction abort.
        """
        ret = None
        for i in args:
            if not continue_processing.value:
                break

            # A sized item (batch) counts for its length, anything else for one event.
            # AttributeError covers Python 2 old-style instances without __len__.
            try:
                count = len(i)
            except (TypeError, AttributeError):
                count = 1

            self._make_transaction("transaction_start")

            try:
                ret = self.run_worker(i)
            except Exception:
                self._make_transaction("transaction_abort", force=True)
                raise

            self.processed += count
            self.iteration += 1
            self._make_transaction("transaction_end")

            if worker:
                with worker.processed.get_lock():
                    worker.processed.value += count

        return ret

    @property
    def continue_processing(self):
        """Truthy while no stop was requested and the limit is not reached."""
        # self._limit is always defined, unlike self._options.limit which is
        # missing for commands declaring support_limit = False.
        return continue_processing.value and self.pushed_count < self._limit

    def get_idents_limited(self, db, force_offset_zero=False, limit=DEFAULT_LIMIT):
        """Yield tuples of idents read from `db` until exhaustion or limit.

        :param db: Object providing getIdents(criteria, limit, offset).
        :param force_offset_zero: When True, every query uses the base offset
            instead of skipping the items already pushed.
        :param limit: Query size per process.
        """
        ncpu = getattr(self._options, "multiprocess", 1)
        limit = min(limit * ncpu, self._limit)

        while self.continue_processing:
            remaining = self._limit - self.pushed_count
            count = self.pushed_count if force_offset_zero is False else 0

            res = db.getIdents(self._options.criteria, min(remaining, limit), self._offset + count)
            if not res:
                break

            yield tuple(res)

    def push_worker(self, args, batch=False):
        """Process `args` in the current process.

        :param args: Sequence of items to process.
        :param batch: When True, `args` is passed to run_worker() as a single item.
        """
        # Count the items before wrapping: a batch of N items must account
        # for N in pushed_count, not for 1.
        count = len(args)
        if batch:
            args = [args]

        self.run_worker_transaction(args)
        self.pushed_count += count


class MultiprocessCommand(GenericCommand):
    """Command able to spread its work over several Worker processes.

    With --multiprocess 1 everything runs in the parent process. Otherwise
    tasks are put on a bounded queue consumed by the workers.
    """

    def __init__(self):
        """Register --multiprocess, parse the options and start the workers."""
        self.parser.add_argument("-m", "--multiprocess", type=int, default=multiprocessing.cpu_count(),
                                 help="number of processes to use. Default value matches the number of available CPUs (i.e. %d)" % multiprocessing.cpu_count())

        GenericCommand.__init__(self)

        if self._options.multiprocess == 1:
            # Single process: the parent acts as the worker.
            self.init_worker()
            return

        self._results = multiprocessing.Queue()
        self._tasks = multiprocessing.Queue(maxsize=self._options.multiprocess)

        self.workers = [Worker(self._tasks, self._results, self) for i in range(self._options.multiprocess)]
        for i in self.workers:
            i.start()

    def exit_parent(self):
        """Ask every live worker to stop, then wait for them."""
        if self._options.multiprocess <= 1:
            return GenericCommand.exit_worker(self)

        # One None sentinel per live child.
        for _ in multiprocessing.active_children():
            self._tasks.put(None)

        # Wait for all the children to exit
        for i in multiprocessing.active_children():
            i.join()

    def wait_results(self, n):
        """Block until `n` entries were read from the results queue.

        Each PROCESS_EXITED entry read decrements the process count option.
        Does nothing when fewer than two processes are in use.
        """
        if self._options.multiprocess < 2:
            return

        for i in range(n):
            if self._results.get() == PROCESS_EXITED:
                self._options.multiprocess -= 1

    def init_worker(self):
        """Hook called once in each worker before it processes tasks."""
        pass

    def exit_worker(self, success=True):
        """Finish the worker's transaction and unregister the worker."""
        GenericCommand.exit_worker(self, success)

        with running_worker.get_lock():
            running_worker.value -= 1

            # If this is the last running worker, we make sure that the parent feeder is released
            # from a potential blocking put() (in case the queue was already full).
            if running_worker.value == 0:
                try:
                    self._tasks.get_nowait()
                except queue.Empty:
                    pass

    def _chunkify(self, lst, n):
        """Yield `n` interleaved slices of `lst` (some may be empty)."""
        for i in range(n):
            yield lst[i::n]

    def push_worker(self, args, batch=False):
        """Queue `args` for the workers, or process it locally with one process.

        :param args: Sequence of items to process.
        :param batch: When True, each worker receives its chunk as a single item.
        """
        if self._options.multiprocess == 1:
            return GenericCommand.push_worker(self, args, batch)

        if len(args) == 1 and not batch:
            self._tasks.put(args)
        else:
            # A batch is always split in one task per process, even with a
            # single item: a caller then waiting for one result per process
            # (see wait_results) would otherwise block forever.
            for chunk in self._chunkify(args, self._options.multiprocess):
                self._tasks.put([chunk] if batch else chunk)

        self.pushed_count += len(args)

    @property
    def continue_processing(self):
        """Same as the base property, additionally requiring a running worker."""
        ret = GenericCommand.continue_processing.fget(self)
        if self._options.multiprocess == 1:
            return ret

        with running_worker.get_lock():
            return ret and running_worker.value


class Load(MultiprocessCommand):
    """'load' sub-command: insert the IDMEF messages read from a file."""
    support_transaction = ["database"]

    parser = SUBPARSERS.add_parser("load", help="Load a Prelude database from a file")
    parser.add_argument("database", action=DatabaseAction, help=DATABASE_HELP)
    parser.add_argument(
        "infile", nargs="?", type=argparse.FileType('r'), default=sys.stdin,
        help="File to load the data from (default to stdin)"
    )

    def run_parent(self):
        """Read messages from infile until EOF and push the matching ones."""
        while self.continue_processing:
            message = prelude.IDMEF()
            try:
                message << self._options.infile
            except EOFError:
                break

            criteria = self._options.criteria
            if not criteria or criteria.match(message):
                self.push_worker([message])

    def run_worker(self, idmef):
        """Insert one message into the target database."""
        database = self._options.database
        database.insert(idmef)


class Copy(MultiprocessCommand):
    """'copy' sub-command: insert the events of database1 into database2."""
    support_transaction = ["database2"]

    parser = SUBPARSERS.add_parser(
        "copy", help="Copy content of a Prelude database to another database"
    )
    parser.add_argument(
        "object", type=str, choices=["alert", "heartbeat"], help="Type of object"
    )
    parser.add_argument("database1", action=DatabaseAction, help=DATABASE_HELP)
    parser.add_argument("database2", action=DatabaseAction, help=DATABASE_HELP)

    def run_parent(self):
        """Push every group of idents found in the source database."""
        source = self._options.database1
        for idents in self.get_idents_limited(source):
            self.push_worker(idents)

    def run_worker(self, ident):
        """Fetch one event from database1 and insert it into database2."""
        options = self._options
        options.database2.insert(options.database1.get(ident))


class Move(Copy):
    """'move' sub-command: copy each event, then delete it from database1."""
    parser = SUBPARSERS.add_parser(
        "move", help="Move content of a Prelude database to another database"
    )
    parser.add_argument(
        "object", type=str, choices=["alert", "heartbeat"], help="Type of object"
    )
    parser.add_argument(
        "database1", type=str, action=DatabaseAction, help=DATABASE_HELP
    )
    parser.add_argument(
        "database2", type=str, action=DatabaseAction, help=DATABASE_HELP
    )

    def run_worker(self, ident):
        """Copy one event to database2, then remove it from database1."""
        super(Move, self).run_worker(ident)

        source = self._options.database1
        source.delete(ident)


class Delete(MultiprocessCommand):
    """'delete' sub-command: remove the events matching the criteria."""
    support_offset = False
    support_transaction = ["database"]

    parser = SUBPARSERS.add_parser("delete", help="Delete content of a Prelude database")
    parser.add_argument(
        "object", type=str, choices=["alert", "heartbeat"], help="Type of object"
    )
    parser.add_argument(
        "database", type=str, action=DatabaseAction, help=DATABASE_HELP
    )
    parser.add_argument(
        "--optimize", action='store_true', help="Run OPTIMIZE after delete operation"
    )

    def init_parent(self):
        """Require a confirmation when no criteria restricts the deletion."""
        if self._options.criteria:
            return

        if not ask_confirmation("No criteria provided : this will delete the entire database. Proceed ?"):
            sys.exit(0)

    def run_parent(self):
        """Delete the events batch by batch, then optimize or warn."""
        database = self._options.database

        for idents in self.get_idents_limited(database, force_offset_zero=True):
            self.push_worker(idents, batch=True)

            # The next ident query would return incorrect results before every
            # process has finished its share of the batch.
            self.wait_results(self._options.multiprocess)

        if not self._options.optimize:
            print("WARNING: Orphan data may remain. Please run OPTIMIZE command to complete the process.", file=sys.stderr)
        else:
            self._options.database.optimize()

    def run_worker(self, ident):
        """Delete `ident` from the database and report completion."""
        database = self._options.database
        database.delete(ident)
        return True


class Print(GenericCommand):
    """'print' sub-command: write the text form of each event to outfile."""
    parser = SUBPARSERS.add_parser("print", help="Print message from a Prelude database")
    parser.add_argument(
        "object", type=str, choices=["alert", "heartbeat"], help="Type of object"
    )
    parser.add_argument(
        "database", type=str, action=DatabaseAction, help=DATABASE_HELP
    )

    # On Python 3 the output stream uses the "surrogateescape" error handler.
    if sys.version_info[0] > 2:
        type = argparse.FileType(mode='w', errors='surrogateescape')
        default = io.TextIOWrapper(sys.stdout.buffer, errors="surrogateescape")
    else:
        type = argparse.FileType(mode='w')
        default = sys.stdout

    parser.add_argument(
        "outfile", nargs="?", type=type, default=default,
        help="File to write the data to (default to stdout)"
    )

    def run_parent(self):
        """Push every group of idents found in the database."""
        database = self._options.database
        for idents in self.get_idents_limited(database):
            self.push_worker(idents)

    def run_worker(self, ident):
        """Write the string form of one event to the output file."""
        options = self._options
        options.outfile.write(str(options.database.get(ident)))


class Save(MultiprocessCommand):
    """'save' sub-command: dump the events of a database to a file."""
    support_transaction = False

    parser = SUBPARSERS.add_parser("save", help="Save a Prelude database to a file")
    parser.add_argument(
        "object", type=str, choices=["alert", "heartbeat"], help="Type of object"
    )
    parser.add_argument(
        "database", type=str, action=DatabaseAction, help=DATABASE_HELP
    )
    parser.add_argument(
        "outfile", nargs="?", type=argparse.FileType('w'), default=sys.stdout,
        help="File to write the data to (default to stdout)"
    )

    def __init__(self, *args, **kwargs):
        # The lock is created before the base constructor runs, since the
        # base constructor is what starts the worker processes.
        self._lock = multiprocessing.Lock()
        MultiprocessCommand.__init__(self, *args, **kwargs)

    def run_parent(self):
        """Push every group of idents found in the database."""
        database = self._options.database
        for idents in self.get_idents_limited(database):
            self.push_worker(idents)

    def run_worker(self, i):
        """Fetch one event and write it to outfile while holding the lock."""
        options = self._options
        message = options.database.get(i)

        with self._lock:
            message >> options.outfile
            options.outfile.flush()


class Count(GenericCommand):
    """'count' sub-command: print the number of events in a database."""
    support_events_stats = False

    parser = SUBPARSERS.add_parser("count", help="Retrieve event count from a database")
    parser.add_argument(
        "object", type=str, choices=["alert", "heartbeat"], help="Type of object"
    )
    parser.add_argument(
        "database", type=str, action=DatabaseAction, help=DATABASE_HELP
    )

    def run_parent(self):
        """Query the count of create_time values and print it on stdout."""
        options = self._options
        selection = ["count(%s.create_time)" % options.object]

        rows = options.database.getValues(
            selection,
            limit=options.limit,
            offset=options.offset,
            criteria=options.criteria
        )
        print("COUNT : %d" % rows[0][0])


class Optimize(GenericCommand):
    """'optimize' sub-command: call optimize() on the database."""
    support_criteria = False
    support_offset = False
    support_limit = False
    support_events_stats = False

    parser = SUBPARSERS.add_parser("optimize", help="Optimize a Prelude database")
    parser.add_argument(
        "database", type=str, action=DatabaseAction, help=DATABASE_HELP
    )

    def run_parent(self):
        """Run the database optimization."""
        database = self._options.database
        database.optimize()


def update(val):
    """Split a "PATH=VALUE" command line argument into its two parts.

    Only the first "=" separates the path from the value, so the value
    itself may contain "=" characters.

    :param val: Argument in the form "PATH=VALUE".
    :return: The (path, value) tuple.
    :raises ValueError: If `val` contains no "=" (reported by argparse as an
        invalid argument).
    """
    path, separator, value = val.partition("=")
    if not separator:
        raise ValueError("expected PATH=VALUE, got '%s'" % val)

    return path, value


class Update(GenericCommand):
    """'update' sub-command: set the given paths to the given values."""
    support_events_stats = False

    name = "update"

    parser = SUBPARSERS.add_parser("update", help="Update value in a Prelude database")
    parser.add_argument(
        "database", type=str, action=DatabaseAction, help=DATABASE_HELP
    )
    parser.add_argument(
        "values", metavar="PATH=VALUE", type=update, nargs="+",
        help="Paths to update with their respective values"
    )

    def run_parent(self):
        """Build the path and value lists, then issue a single update."""
        options = self._options

        paths = []
        values = []
        for path, value in options.values:
            paths += [path]
            values += [prelude.IDMEFValue(value, path=prelude.IDMEFPath(path))]

        options.database.update(
            paths, values, options.criteria,
            limit=options.limit, offset=options.offset
        )


class Command(object):
    """Program driver: selects the sub-command from sys.argv and runs it.

    Instantiating the class does all the work: it builds the command,
    installs the signal handlers, runs the command and prints statistics
    on stderr.
    """
    # Maps the sub-command name given on the command line to its class.
    COMMAND_TBL = {
        "count": Count,
        "delete": Delete,
        "copy": Copy,
        "load": Load,
        "move": Move,
        "optimize": Optimize,
        "print": Print,
        "save": Save,
        "update": Update
    }

    def _pstat(self, count, elapsed, name=None, status=""):
        """Print one statistics line on stderr.

        :param count: Number of processed events.
        :param elapsed: Seconds elapsed since the command started.
        :param name: Optional worker name used as line prefix.
        :param status: Optional worker status shown next to the name; when
            given, the elapsed time is not displayed.
        """
        hdr = ""
        if name:
            if status:
                status = " " + status

            hdr = "[%s%s] " % (name, status)

        cmdname = self._cmd.__class__.__name__.lower()
        if self._cmd.support_events_stats:
            hdr += "%d '%s' events processed" % (count, cmdname)
        else:
            hdr += "'%s' operation executed" % cmdname

        evsec = ""
        if not status:
            evsec = " in %f seconds" % elapsed

        if self._cmd.support_events_stats:
            # elapsed may be 0 when statistics are requested right after the
            # start: report a null rate rather than dividing by zero.
            rate = count / elapsed if elapsed > 0 else 0.0
            evsec += " (%f events/sec average)" % rate

        print("%s%s." % (hdr, evsec), file=sys.stderr)

    def _handle_stats_signal(self, signo=None, frame=None):
        """Print per-worker then global statistics (also used as signal handler)."""
        elapsed = time.time() - self._time_start

        if signo:
            print("\n", file=sys.stderr)

        total = 0
        for i in self._cmd.workers:
            count = i.processed.value
            total += count
            self._pstat(count, elapsed, name=i.name, status="running" if i.is_alive() else "stopped")

        # Without workers (or with idle ones) fall back to the parent's own counter.
        self._pstat(total or self._cmd.processed, elapsed)

    def _handle_signal(self, signo, frame):
        """Request a graceful stop on first call, force the exit on the next one."""
        if continue_processing.value:
            print("Interrupted by signal. Will exit after current transaction (send one more time for immediate exit).", file=sys.stderr)
            continue_processing.value = False
            return

        print("Interrupted by signal (user forced)", file=sys.stderr)
        self._handle_stats_signal(signo, frame)

        for i in multiprocessing.active_children():
            os.kill(i.pid, signal.SIGKILL)

        # We don't want to use sys.exit() here, since the parent feeder might be locked in a put().
        os._exit(1)

    def __init__(self):
        try:
            command_class = self.COMMAND_TBL[sys.argv[1]]
        except (KeyError, IndexError):
            TOP_PARSER.print_help()
            sys.exit(1)

        self._cmd = command_class()

        try:
            self._cmd.init_parent()
        except BaseException:
            # BaseException on purpose: a sys.exit() raised by init_parent()
            # must still let exit_parent() stop the workers already started.
            self._cmd.exit_parent()
            raise

        # Set before the handlers are installed, since the statistics
        # handler reads it.
        self._time_start = time.time()

        signal.signal(signal.SIGINT, self._handle_signal)
        signal.signal(signal.SIGTERM, self._handle_signal)
        signal.signal(signal.SIGQUIT, self._handle_stats_signal)
        signal.signal(signal.SIGUSR1, self._handle_stats_signal)

        try:
            self._cmd.run_parent()
        finally:
            self._cmd.exit_parent()

        self._handle_stats_signal()


# Script entry point: instantiating Command selects the sub-command from
# sys.argv and runs it.
if __name__ == "__main__":
    Command()
