#!/usr/bin/python3

# Authors: Daniel Kahn Gillmor <dkg@fifthhorseman.net>,
#          Gunnar Wolf <gwolf@debian.org>,
#          Jonathan McDowell <noodles@earth.li>
# License: Parts from dkg are GPLv3+

import os
from os import path
import socket
from subprocess import run, Popen, PIPE
from shutil import chown, copy, copyfile, rmtree
from distutils.dir_util import copy_tree
import sys
import tempfile
import hashlib
import codecs
from multiprocessing.pool import ThreadPool
from typing import List, Tuple, Optional
import datetime


def check_environ(should_run_on: str = 'kaufmann.debian.org') -> None:
    '''Refuse to continue unless we are on the expected host

    The fully-qualified name of the local host is compared against
    should_run_on.  Setting the RUNANYWAY environment variable to any
    nonempty value skips that comparison entirely, which is handy for
    testing.  An Exception is raised when neither condition holds.
    '''
    # The override is checked first so no hostname lookup happens when
    # it is set.
    if os.environ.get('RUNANYWAY', False):
        return
    here = socket.getfqdn(socket.gethostname())
    if here == should_run_on:
        return
    raise Exception('''
This script is meant to be run in %s
You can still run it if you are sure by setting
$RUNANYWAY to a nonempty value.
        ''' % (should_run_on))


def wkd_localpart(incoming: bytes) -> str:
    '''Return the z-base32 encoding of the SHA-1 digest of a localpart

    https://tools.ietf.org/html/draft-koch-openpgp-webkey-service-08#section-3.1
    describes why this is needed.

    See https://tools.ietf.org/html/rfc6189#section-5.1.6 for a
    description of the z-base32 scheme.

    The 160-bit digest is read most-significant-bit first and cut into
    32 groups of 5 bits; each group indexes the z-base32 alphabet.
    '''
    alphabet = "ybndrfg8ejkmcpqxot1uwisza345h769"

    digest = hashlib.sha1(incoming).digest()
    assert(len(digest) * 8 == 160)
    # Treat the whole digest as one big-endian integer so that each
    # 5-bit group is a plain shift-and-mask, whether or not it
    # straddles a byte boundary.
    bits = int.from_bytes(digest, 'big')
    return ''.join(alphabet[(bits >> shift) & 0b11111]
                   for shift in range(155, -1, -5))


def getdomainlocalpart(line: bytes, domain: bytes) -> Optional[bytes]:
    '''Extract the lowercased localpart from a gpg "uid:" listing line

    Returns None unless LINE starts with "uid:" and its user ID field
    (the tenth colon-separated field) ends in "@DOMAIN>".  Raises
    ValueError when a matching user ID does not contain exactly one "<".
    '''
    if not line.startswith(b'uid:'):
        return None
    userid = line.split(b':')[9]
    suffix = b'@' + domain + b'>'
    if not userid.endswith(suffix):
        return None
    pieces = userid.split(b'<')
    if len(pieces) != 2:
        raise ValueError("unexpected User ID %s" % (userid))
    # pieces[1] is "local@DOMAIN>"; strip the suffix to keep the localpart.
    return pieces[1][:-len(suffix)].lower()


def gpgbase(keyrings: List[str]) -> List[str]:
    '''Build the common gpg command line used for unattended invocations

    The result is the fixed set of batch-mode options followed by one
    "--keyring=PATH" argument per entry of KEYRINGS, in the given order.
    '''
    fixed = [
        'gpg',
        '--batch',
        '--no-options',
        '--with-colons',
        '--no-default-keyring',
        '--homedir=/dev/null',
        '--trust-model=always',
        '--fixed-list-mode',
    ]
    return fixed + ['--keyring=' + ring for ring in keyrings]


def emit_wkd_and_return_dane(localpart: bytes, domain: str, keyrings: List[str]) -> bytes:
    '''Write the WKD file for one address and return its DANE export

    Two gpg exports are run for LOCALPART@DOMAIN, both restricted to
    the matching user ID:

    - a plain export written to openpgpkey/DOMAIN/hu/<z-base32 hash>,
      relative to the current working directory;
    - an "export-dane" export captured from stdout and returned.

    These are handled differently because we want to generate a
    single, reproducible zonefile for the DNS, while we are generating
    a tree of files for WKD.  The caller assembles all of the returned
    OPENPGPKEY records into a single zonefile.

    subprocess.CalledProcessError is raised if either gpg call fails.
    '''
    hashed = wkd_localpart(localpart)
    # what do we do if this local part is not a proper encoding?
    address = codecs.decode(localpart) + '@' + domain

    # Pieces shared by both invocations.
    uid_filter = ['--export-filter', 'keep-uid=mbox=' + address]
    selector = ['--export', '<' + address + '>']

    wkd_target = path.join('openpgpkey', domain, 'hu', hashed)
    wkd_cmd = (gpgbase(keyrings)
               + ['--output', wkd_target, '--export-options', 'export-clean']
               + uid_filter
               + selector)
    run(wkd_cmd, check=True)

    dane_cmd = (gpgbase(keyrings)
                + ['--export-options', 'export-dane,export-clean']
                + uid_filter
                + selector)
    return run(dane_cmd, stdout=PIPE, check=True).stdout


def build_wkd_and_dane(domain: str, keyrings: List[str]) -> None:
    '''Publish WKD and DANE OPENPGPKEY for all domain-relevant OpenPGP certificates

    Everything is created relative to the current working directory:

    - openpgpkey/DOMAIN/hu/ with one exported file per address,
    - an empty openpgpkey/DOMAIN/policy file,
    - _openpgpkey.DOMAIN.zone holding the zonefile header followed by
      the DANE records ordered by localpart.

    os.mkdir raises FileExistsError if openpgpkey/DOMAIN already
    exists.  Any exception raised while exporting an individual
    address (for example subprocess.CalledProcessError from gpg) is
    re-raised here, before the policy file and zonefile are written.
    '''
    if not path.isdir('openpgpkey'):
        os.mkdir('openpgpkey')
    os.mkdir(path.join('openpgpkey', domain))
    os.mkdir(path.join('openpgpkey', domain, 'hu'))

    # FIXME: deal with IDN:
    bytedomain = codecs.encode(domain)

    # The context manager closes the pipe and reaps the gpg process.
    # Its exit status is deliberately not checked.
    with Popen(gpgbase(keyrings) + ['--list-keys', '@' + domain],
               stdout=PIPE) as lister:
        localparts = {getdomainlocalpart(line, bytedomain)
                      for line in lister.stdout}
    localparts.discard(None)

    def runner(x: bytes) -> Tuple[bytes, bytes]:
        return (x, emit_wkd_and_return_dane(x, domain, keyrings))

    # map() re-raises the first worker exception, so a failed export
    # aborts the run instead of silently producing an incomplete zone.
    with ThreadPool(None) as pool:
        dane_map = dict(pool.map(runner, localparts))

    # make the (empty) policy file:
    with open(path.join('openpgpkey', domain, 'policy'), 'wb'):
        pass
    # write out the zonefile all at once, ordered by the localpart
    with open(path.join('_openpgpkey.' + domain + '.zone'), 'wb') as zonefile:
        when = datetime.datetime.now()
        # FIXME: inspect serial number from existing zonefile --
        # update serial number if it was from the same day
        serial = 0
        zonefile.write(openpgpkey_zonefile_header(when, serial))
        for local in sorted(dane_map):
            zonefile.write(dane_map[local])


def _fix_one_perm(target: str, mode: int) -> None:
    'Best-effort: give TARGET to group "keyring" and set its mode to MODE'
    # The two steps are attempted independently, so a failed chown
    # (e.g. the group does not exist) does not prevent the chmod.
    try:
        chown(target, group="keyring")
    except (OSError, LookupError):
        # shutil.chown raises LookupError for an unknown group name.
        pass
    try:
        os.chmod(target, mode)
    except OSError:
        pass


def fix_perms(path: str) -> None:
    '''Fix the permissions of a given directory

    Ensures all files/directories are owned and writeable by the keyring group.
    Additionally they must be readable by all and directories executable.

    PATH itself and every directory below it get mode 0775; every file
    below it gets mode 0664.  All changes are best-effort: ownership or
    permission failures on individual entries are ignored.
    '''
    _fix_one_perm(path, 0o775)
    for root, dirs, files in os.walk(path):
        for cur in dirs:
            _fix_one_perm(os.path.join(root, cur), 0o775)
        for cur in files:
            _fix_one_perm(os.path.join(root, cur), 0o664)


def publish(srcdir: str,
            where: Optional[str] = None) -> None:
    '''Verify new keyrings in srcdir; if ok, then publish at where.

    Verification ensures that the new keyrings are signed-off by a
    member of debian's keyring-maint team.

    Publication consists of verifying the keyrings, placing them where
    onak can find them, produce a zonefile for OPENPGPKEY DANE
    records, and a tree of files for WKD.

    When WHERE is None the target prefix is taken from the PREFIX
    environment variable, defaulting to /srv/keyring.debian.org.

    Raises Exception when a required directory or sha512sums.txt is
    missing, when signature verification fails, when a digest does not
    match or is missing for an expected keyring, or when unhandled
    pending updates are present.  Note that this function changes the
    current working directory of the process.
    '''
    if where is None:
        prefix = os.environ.get('PREFIX', '/srv/keyring.debian.org')
    else:
        prefix = where
    pendingdir = path.join(prefix, 'pending-updates')
    hkpdir = path.join(prefix, 'keyrings-new')
    outputdir = path.join(prefix, 'pub')
    for direc in [srcdir, pendingdir, hkpdir, outputdir]:
        if not path.isdir(direc):
            raise Exception("%s is not a directory" % (direc))
    srcdir = path.realpath(srcdir)
    sha512fname = path.join(srcdir, 'sha512sums.txt')
    if not path.exists(sha512fname):
        raise Exception('sha512sums.txt not found in %s' % (srcdir))
    placeholder = path.join(srcdir, 'keyrings', '.placeholder')
    if path.exists(placeholder):
        os.unlink(placeholder)
    # gpgv needs the keyring in the filesystem, not just a file
    # descriptor (https://dev.gnupg.org/T4608)
    with tempfile.NamedTemporaryFile() as maint_keyring:
        maint_keyring.write(keyring_maint_keys())
        # gpgv opens the file by name, so the data must have left
        # Python's write buffer before it runs.
        maint_keyring.flush()
        gpgvcall = [
            'gpgv',
            '--enable-special-filenames',
            '--keyring',
            maint_keyring.name,
            '--output',
            '-',
            sha512fname]
        gpgvout = run(gpgvcall, stderr=PIPE, stdout=PIPE)
        if gpgvout.returncode != 0:
            raise Exception("gpg verification failed:\n%s" %
                            (codecs.decode(gpgvout.stderr)))
    # File names in the verified checksum list are opened relative to
    # srcdir.
    os.chdir(srcdir)
    files_to_check = set(
        path.join('keyrings', x + '.gpg') for x in [
            'debian-keyring',
            'debian-maintainers',
            'debian-nonupload',
            'debian-role-keys',
            'emeritus-keyring'])
    unexpected_files = set()
    # Only gpgv's stdout is parsed here, not the raw sha512sums.txt
    # file.
    for line in filter(lambda x: x, codecs.decode(gpgvout.stdout).split('\n')):
        (indigest, fname) = line.split()
        with open(fname, 'rb') as f:
            data = f.read()
        digest = hashlib.new('sha512', data=data).hexdigest()
        if digest != indigest:
            raise Exception(
                'mismatched digest for %s.\nWanted: %s\nGot: %s' %
                (fname, indigest, digest))
        if fname in files_to_check:
            files_to_check.remove(fname)
        else:
            unexpected_files.add(fname)
    if files_to_check:
        raise Exception('No sha512 digest found for: %s' % (files_to_check))
    if unexpected_files:
        print(
            'unexpected files (maybe add them to files_to_check):',
            unexpected_files)

    keyrings = ['keyring', 'maintainers', 'nonupload']
    # Refuse to overwrite pending updates that nobody has handled yet.
    for kname in keyrings:
        kfile = path.join(pendingdir, 'debian-%s.gpg' % (kname))
        if path.exists(kfile):
            raise Exception(
                'Unhandled pending updates.\nKeyrings in %s should be dealt with and removed' %
                (pendingdir))

    # Copy the current HKP keyrings into the pending directory before
    # they get replaced below.
    for kname in keyrings:
        kfile = path.join(hkpdir, 'debian-%s.gpg' % (kname))
        copy(kfile, pendingdir)

    print('Updating active keyrings.')
    copy_tree(srcdir, outputdir)
    fix_perms(outputdir)
    print('Updating HKP keyrings.')
    for kname in keyrings:
        kfile = path.join(srcdir, 'keyrings', 'debian-%s.gpg' % (kname))
        dst = os.path.join(hkpdir, os.path.basename(kfile))
        copyfile(kfile, dst)
    print('Publishing WKD and DANE data (may take a few minutes).')
    # Stage under prefix so the final os.rename calls do not cross
    # directories outside it.
    with tempfile.TemporaryDirectory(prefix='pub_staging_', dir=prefix) as wkd_staging:
        os.chdir(wkd_staging)

        def dkeyring(name: str) -> str:
            return path.join(srcdir, 'keyrings', 'debian-' + name + '.gpg')
        build_wkd_and_dane('debian.org',
                           [dkeyring(x) for x in [
                               'nonupload',
                               'keyring',
                               'role-keys']])
        wkd_deploy_path = path.join(prefix, 'openpgpkey')
        # not quite an atomic move: the previous tree is parked inside
        # the staging directory, which is removed when this block exits.
        if path.isdir(wkd_deploy_path):
            os.rename(wkd_deploy_path, 'openpgpkey.old')
        os.rename('openpgpkey', wkd_deploy_path)
        fix_perms(wkd_deploy_path)
        os.rename(
            '_openpgpkey.debian.org.zone',
            path.join(prefix, '_openpgpkey.debian.org.zone'))
        fix_perms(path.join(prefix, '_openpgpkey.debian.org.zone'))
        # Leave the staging directory before it is deleted.
        os.chdir(srcdir)
    run(['static-update-component', 'openpgpkey.debian.org'], check=True)
    run(['sudo', 'service', 'bind9', 'reload'], check=True)


def openpgpkey_zonefile_header(timestamp: datetime.datetime = None, sequence: int = 0) -> bytes:
    '''Build the fixed leading records of the _openpgpkey.debian.org zone

    The result is an SOA record followed by three NS records.  The SOA
    serial is the date of TIMESTAMP (the current time when None) as
    YYYYMMDD, followed by SEQUENCE zero-padded to two digits.

    These records were suggested by the Debian System Administration
    (DSA) team.
    '''
    stamp = datetime.datetime.now() if timestamp is None else timestamp
    # Numeric YYYYMMDD, e.g. 2021-03-04 -> 20210304.
    datepart = stamp.year * 10000 + stamp.month * 100 + stamp.day
    template = b'''; zonefile for OPENPGPKEY records (RFC 7929) for debian.org
_openpgpkey.debian.org.  3600 IN SOA kaufmann.debian.org. hostmaster.debian.org. (
                                 %d%02d ; serial
                                 1800       ; refresh (30 minutes)
                                 600        ; retry (10 minutes)
                                 1814400    ; expire (3 weeks)
                                 600        ; minimum (10 minutes)
                                 )
_openpgpkey.debian.org. 28800 IN NS sec2.rcode0.net.
_openpgpkey.debian.org. 28800 IN NS nsp.dnsnode.net.
_openpgpkey.debian.org. 28800 IN NS sec1.rcode0.net.

'''
    return template % (datepart, sequence)


def keyring_maint_keys() -> bytes:
    '''Extract keyring-maint keys from the local system keyrings.

    On DSA-managed hosts, /srv/keyring.debian.org/keyrings is more recent
    and up-to-date so we prefer it.  On other hosts that have the
    debian-keyring package installed, we can fall back to it.

    The first location containing both required keyring files is used.
    An Exception is raised when no location qualifies.  The return
    value is the raw output of a minimal gpg export of the
    keyring-maint user IDs.
    '''
    search_dirs = ('/srv/keyring.debian.org/keyrings',
                   '/usr/share/keyrings')
    wanted = ('debian-keyring.gpg', 'debian-nonupload.gpg')
    maintainers = ('Jonathan McDowell <noodles@earth.li>',
                   'William John Sullivan <johns@debian.org>',
                   'Gunnar Eyal Wolf Iszaevich <gwolf@debian.org>',
                   'Daniel Kahn Gillmor <dkg@debian.org>')

    for directory in search_dirs:
        candidates = [path.join(directory, name) for name in wanted]
        if path.isdir(directory) and all(path.exists(c) for c in candidates):
            break
    else:
        raise Exception(
            "Could not find keyrings to extract keyring-maint keys")

    argv = ['gpg',
            '--batch',
            '--homedir',
            '/dev/null',
            '--no-options',
            '--no-default-keyring',
            '--export-options',
            'export-minimal']
    for ring in candidates:
        argv.extend(['--keyring', ring])
    argv.append('--export')
    # A leading "=" makes gpg match the user ID exactly.
    argv.extend('=' + uid for uid in maintainers)

    exported = run(argv, stdout=PIPE, check=True)
    return exported.stdout


if __name__ == '__main__':
    # Usage:
    #   <script> DIRECTORY                         verify and publish
    #   <script> build-wkd [DOMAIN [KEYRING...]]   only build WKD/DANE data
    args = sys.argv[1:]
    if not args:
        raise Exception('Must provide directory containing new keyrings.')
    if args[0] == 'build-wkd':
        # Dispatch on the subcommand name rather than on the argument
        # count, so that a bare "build-wkd" uses the defaults below
        # instead of being treated as a directory to publish.
        domain = args[1] if len(args) > 1 else 'debian.org'
        keys = args[2:] or ['/usr/share/keyrings/debian-nonupload.gpg',
                            '/usr/share/keyrings/debian-keyring.gpg',
                            '/usr/share/keyrings/debian-role-keys.gpg']
        build_wkd_and_dane(domain, keys)
    elif len(args) > 1:
        raise Exception("do not know this subcommand: %s" % (args[0]))
    else:
        # standard update-keyrings
        check_environ()
        publish(args[0])
