#!/usr/bin/env python
#
# This is a save process which also buffers outgoing I/O between
# rounds, so that external viewers never see anything that hasn't
# been committed at the backup
#
# TODO: fencing.

import optparse, os, re, select, signal, sys, time

from xen.remus import save, util, vm
from xen.remus.device import ReplicatedDisk, ReplicatedDiskException
from xen.remus.device import BufferedNIC, BufferedNICException
from xen.xend import XendOptions

class CfgException(Exception):
    """Raised by Cfg.getargs() when the command line is unusable."""

class Cfg(object):
    """Run-time configuration, filled in from the command line.

    __init__ installs the defaults and builds the option parser;
    getargs() parses the arguments and overrides the defaults with
    whatever the user supplied.
    """

    # bit in `flags` that turns checkpoint compression on
    REMUS_FLAGS_COMPRESSION = 1

    def __init__(self):
        # must be set
        self.domid = 0

        self.host = 'localhost'
        self.nullremus = False
        self.port = XendOptions.instance().get_xend_relocation_port()
        # checkpoint interval in milliseconds
        self.interval = 200
        self.netbuffer = True
        self.flags = self.REMUS_FLAGS_COMPRESSION
        self.timer = False

        self.parser = self._build_parser()

    def _build_parser(self):
        """Return the optparse parser describing the command line."""
        parser = optparse.OptionParser(
            usage='%prog [options] domain [destination]')
        parser.add_option('-i', '--interval', dest='interval', type='int',
                          metavar='MS',
                          help='checkpoint every MS milliseconds')
        parser.add_option('-p', '--port', dest='port', type='int',
                          help='send stream to port PORT', metavar='PORT')
        parser.add_option('--blackhole', dest='nullremus', action='store_true',
                          help='replicate to /dev/null (no disk checkpoints, only memory & net buffering)')
        parser.add_option('--no-net', dest='nonet', action='store_true',
                          help='run without net buffering (benchmark option)')
        parser.add_option('--no-compression', dest='nocompress', action='store_true',
                          help='run without checkpoint compression')
        parser.add_option('--timer', dest='timer', action='store_true',
                          help='force pause at checkpoint interval (experimental)')
        return parser

    def usage(self):
        """Print the option summary generated by the parser."""
        self.parser.print_help()

    def getargs(self, argv=None):
        """Parse the command line and update this configuration.

        argv is the list of arguments to parse; when None (the default)
        optparse falls back to sys.argv[1:].

        Raises CfgException if no domain argument is present.
        """
        opts, args = self.parser.parse_args(argv)

        # Options without a default come back as None when omitted.  Test
        # for None rather than truthiness so that an explicit 0 is honoured
        # instead of being silently replaced by the default.
        if opts.interval is not None:
            self.interval = opts.interval
        if opts.port is not None:
            self.port = opts.port
        if opts.nullremus:
            self.nullremus = True
        if opts.nonet:
            self.netbuffer = False
        if opts.nocompress:
            self.flags &= ~self.REMUS_FLAGS_COMPRESSION
        if opts.timer:
            self.timer = True

        if not args:
            raise CfgException('Missing domain')
        # first positional argument is the domain, the optional second
        # one is the destination host
        self.domid = args[0]
        if len(args) > 1:
            self.host = args[1]

class SignalException(Exception):
    """Raised from a signal handler; carries the signal number."""

def run(cfg):
    """Checkpoint domain cfg.domid to cfg.host:cfg.port.

    Creates a buffer object for each replicated disk and each vif of the
    domain, opens the replication socket and hands the postsuspend /
    preresume / commit hooks defined below to save.Saver.

    Every buffer that was created is uninstalled on all exit paths.  The
    process then exits through sys.exit(): status 1 if checkpointing
    raised save.CheckpointError, status 0 if it returned, was interrupted
    from the keyboard or received SIGTERM.  Any other exception
    propagates to the caller once the buffers have been uninstalled.
    """
    # A function object used as a mutable attribute bag shared by the
    # nested functions below (they cannot rebind locals of run()).
    closure = lambda: None
    closure.cmd = None

    # I/O buffers.  Created up front so that die() and the final cleanup
    # can iterate over it no matter how far setup got.
    bufs = []

    def sigexception(signo, frame):
        # turn the signal into an exception so it unwinds through start()
        raise SignalException(signo)

    def die():
        # I am not sure what the best way to die is. xm destroy is another option,
        # or we could attempt to trigger some instant reboot.
        print("dying...")
        print(util.runcmd(['sudo', 'ifdown', 'eth2']))
        # dangling imq0 handle on vif locks up the system
        for buf in bufs:
            buf.uninstall()
        print(util.runcmd(['sudo', 'xm', 'destroy', cfg.domid]))
        print(util.runcmd(['sudo', 'ifup', 'eth2']))

    def getcommand():
        """Get a command to execute while running.
        Commands include:
          s: die prior to postsuspend hook
          s2: die after postsuspend hook
          r: die prior to preresume hook
          r2: die after preresume hook
          c: die prior to commit hook
          c2: die after commit hook
          """
        # zero timeout: poll stdin without blocking the checkpoint loop
        r, w, x = select.select([sys.stdin], [], [], 0)
        if sys.stdin not in r:
            return

        line = sys.stdin.readline()
        if not line:
            # EOF: a closed stdin is reported readable on every poll, so
            # stay quiet instead of complaining once per checkpoint
            return

        cmd = line.strip()
        if cmd not in ('s', 's2', 'r', 'r2', 'c', 'c2'):
            print("unknown command: %s" % cmd)
            # do not let garbage replace a previously accepted command
            return
        closure.cmd = cmd

    def postsuspend():
        'Begin external checkpointing after domain has paused'
        if not cfg.timer:
            # when not using a timer thread, sleep until now + interval
            closure.starttime = time.time()

        if closure.cmd == 's':
            die()

        for buf in bufs:
            buf.postsuspend()

        if closure.cmd == 's2':
            die()

    def preresume():
        'Complete external checkpointing before domain resumes'
        if closure.cmd == 'r':
            die()

        for buf in bufs:
            buf.preresume()

        if closure.cmd == 'r2':
            die()

    def commit():
        'commit network buffer'
        if closure.cmd == 'c':
            die()

        sys.stderr.write("PROF: flushed memory at %0.6f\n" % (time.time()))

        for buf in bufs:
            buf.commit()

        if closure.cmd == 'c2':
            die()

        # Since the domain is running at this point, it's a good time to
        # check for control channel commands
        getcommand()

        if not cfg.timer:
            # pad the round out to cfg.interval milliseconds, measured
            # from the timestamp taken in postsuspend()
            endtime = time.time()
            elapsed = (endtime - closure.starttime) * 1000

            if elapsed < cfg.interval:
                time.sleep((cfg.interval - elapsed) / 1000.0)

        # False ends checkpointing
        return True

    signal.signal(signal.SIGTERM, sigexception)

    rc = 0

    # Everything that can create a buffer runs inside try/finally so the
    # buffers are uninstalled even if setup or checkpointing fails with an
    # exception that is not handled here.
    try:
        dom = vm.VM(cfg.domid)

        # disks must commit before network can be released
        if not cfg.nullremus:
            for disk in dom.disks:
                try:
                    bufs.append(ReplicatedDisk(disk))
                except ReplicatedDiskException:
                    # report the disk and carry on with the remaining ones;
                    # sys.exc_info() keeps the syntax valid on both
                    # Python 2 and Python 3
                    print(sys.exc_info()[1])

        if cfg.netbuffer:
            for vif in dom.vifs:
                bufs.append(BufferedNIC(vif))

        if cfg.nullremus:
            fd = save.NullSocket((cfg.host, cfg.port))
        else:
            fd = save.MigrationSocket((cfg.host, cfg.port))

        # the interval is only handed to the saver in timer mode;
        # otherwise commit() does the pacing itself
        if cfg.timer:
            interval = cfg.interval
        else:
            interval = 0

        checkpointer = save.Saver(cfg.domid, fd, postsuspend, preresume,
                                  commit, interval, cfg.flags)

        try:
            checkpointer.start()
        except save.CheckpointError:
            print(sys.exc_info()[1])
            rc = 1
        except KeyboardInterrupt:
            pass
        except SignalException:
            print('*** signalled ***')
    finally:
        for buf in bufs:
            buf.uninstall()

    sys.exit(rc)

def main():
    """Script entry point: parse the command line, then checkpoint.

    Exits with status 1 after printing the error (and, for argument
    errors, the usage text) when the arguments are rejected or the
    domain cannot be set up.
    """
    cfg = Cfg()
    try:
        cfg.getargs()
    except CfgException:
        # sys.exc_info() keeps the syntax valid on both Python 2 and 3
        print(str(sys.exc_info()[1]))
        cfg.usage()
        sys.exit(1)

    try:
        run(cfg)
    except vm.VMException:
        print(str(sys.exc_info()[1]))
        sys.exit(1)

# Only act when executed as a script, so that importing this module does
# not parse sys.argv or start checkpointing as a side effect.
if __name__ == '__main__':
    main()
