#!/usr/bin/env python3

import os
import sys
import argparse
import base64
import binascii
import json

# For generation of the transient keypair
from cryptography.hazmat.primitives.asymmetric import ec
from cryptography.hazmat.primitives import hashes, serialization
from cryptography.hazmat.backends import default_backend
import cryptography

import htcondor

# For invoking the scitokens-cpp library
import cffi


# Handle to the process-wide C namespace (opened via dlopen(None) in
# init_ffi()); used to call free() on buffers handed back by the library.
g_C = None

# Handle to the scitokens shared library; populated by init_ffi().
g_ST = None

# The cffi.FFI instance holding the C declarations; populated by init_ffi().
g_FFI = None


def generate_private_key():
    """Create and return a fresh EC private key on the SECP256R1 curve."""
    curve = ec.SECP256R1()
    return ec.generate_private_key(curve, backend=default_backend())


def generate_key_id(private_key):
    """Derive a short key id from the public key of *private_key*.

    The id is the first six hexadecimal characters of the SHA-256 digest
    taken over the URL-safe base64 encoding of the public x coordinate.
    """
    x_coordinate = private_key.public_key().public_numbers().x
    encoded_x = base64.urlsafe_b64encode(
        cryptography.utils.int_to_bytes(x_coordinate))

    hasher = hashes.Hash(hashes.SHA256(), backend=default_backend())
    hasher.update(encoded_x)
    hex_digest = binascii.hexlify(hasher.finalize()).decode('utf-8')
    return hex_digest[:6]


def generate_public_key_dict(private_key):
    """Return the public half of *private_key* as a JWK-style dict.

    The dict carries the fixed ES256 / P-256 / EC / sig metadata, the
    base64url-encoded ``x`` and ``y`` coordinates, and the key id produced
    by generate_key_id().

    Coordinates are encoded at the full byte width of the curve, as
    RFC 7518 (sections 6.2.1.2 and 6.2.1.3) requires.  A minimal-length
    encoding would silently drop leading zero bytes for roughly one key in
    128 and yield a non-conforming JWK.
    """
    public_key = private_key.public_key()
    public_numbers = public_key.public_numbers()

    # Byte width of one coordinate for this curve (32 for P-256).
    coordinate_size = (public_key.curve.key_size + 7) // 8

    def encode_coordinate(value):
        # Fixed-width big-endian bytes, then URL-safe base64 as ASCII text.
        # NOTE(review): the '=' padding emitted by urlsafe_b64encode is kept
        # as before; confirm whether consumers would prefer it stripped.
        raw = value.to_bytes(coordinate_size, "big")
        return base64.urlsafe_b64encode(raw).decode('ascii')

    return {
        "alg": "ES256",
        "crv": "P-256",
        "x": encode_coordinate(public_numbers.x),
        "y": encode_coordinate(public_numbers.y),
        "kty": "EC",
        "use": "sig",
        "kid": generate_key_id(private_key)
    }


def generate_private_key_pem(private_key):
    """Serialize *private_key* to unencrypted PEM bytes.

    Uses the TraditionalOpenSSL private-key format.
    """
    pem_options = {
        "encoding": serialization.Encoding.PEM,
        "format": serialization.PrivateFormat.TraditionalOpenSSL,
        "encryption_algorithm": serialization.NoEncryption(),
    }
    return private_key.private_bytes(**pem_options)


def init_ffi():
    """Set up the cffi bindings and publish them through module globals.

    Declares the C functions this script calls, then stores the FFI object
    in g_FFI, the process-wide C namespace in g_C, and the scitokens shared
    library in g_ST.  The library file name depends on the platform.
    """
    global g_FFI, g_C, g_ST

    g_FFI = ffi = cffi.FFI()
    ffi.cdef("""
    void free(void *);
    int keycache_refresh_jwks(const char *issuer, char **err_msg);
    int keycache_get_cached_jwks(const char *issuer, char **jwks, char **err_msg);
    int keycache_set_jwks(const char *issuer, const char *jwks, char **err_msg);
    void *scitoken_key_create(const char *key_id, const char *algorithm, const char *public_contents, const char *private_contents, char **err_msg);
    void scitoken_key_destroy(void *private_key);
    void *scitoken_create(void *private_key);
    void scitoken_destroy(void *token);
    int scitoken_set_claim_string(void *token, const char *key, const char *value, char **err_msg);
    void scitoken_set_lifetime(void *token, int lifetime);
    int scitoken_serialize(const void *token, char **value, char **err_msg);
    int scitoken_config_set_str(const char *key, const char *value, char **err_msg);
    """)

    # dlopen(None) exposes the symbols already loaded in the process,
    # which is where free() comes from.
    g_C = ffi.dlopen(None)

    if sys.platform == "darwin":
        scitokens_lib = "libSciTokens.0.dylib"
    else:
        scitokens_lib = "libSciTokens.so.0"
    g_ST = ffi.dlopen(scitokens_lib)


def configure_cache(cache):
    """Tell the scitokens library where its key cache lives.

    Args:
        cache: Cache directory path, or None to use the HTCondor
            configuration value SEC_SCITOKENS_CACHE.

    Returns:
        Always 0.  When no location is given and none is configured, the
        library is left untouched.

    Raises:
        RuntimeError: the library rejected the setting; the message is the
            library's error text when one was provided.
    """
    try:
        if cache is None:
            try:
                # Look up SEC_SCITOKENS_CACHE in condor config
                cache = htcondor.param["SEC_SCITOKENS_CACHE"]
            except KeyError:
                # Nothing requested and nothing configured: keep whatever
                # cache location the library already uses.
                return 0
        err_msg = g_FFI.new("char**", g_FFI.NULL)
        rv = g_ST.scitoken_config_set_str(b"keycache.cache_home", cache.encode(), err_msg)
        if rv != 0:
            err_msg_str = "Unknown Error"
            if err_msg[0] != g_FFI.NULL:
                err_msg_str = g_FFI.string(err_msg[0]).decode()
                g_C.free(err_msg[0])
            raise RuntimeError(err_msg_str)
    except AttributeError:
        # Deliberate best-effort, kept from the original code.
        # NOTE(review): presumably covers a library build that does not
        # export scitoken_config_set_str (cffi reports a missing symbol as
        # AttributeError) -- confirm before narrowing this handler.
        pass
    return 0

def get_cached_jwks(issuer):
    """Fetch the JWKS the library has cached for *issuer*.

    Args:
        issuer: Issuer URL as a str.

    Returns:
        The cached JWKS parsed from JSON.

    Raises:
        RuntimeError: the library call failed (message taken from the
            library when available) or it reported success but returned a
            NULL JWKS pointer.
        ValueError: the returned text is not valid JSON or not valid UTF-8.
    """
    jwks_result = g_FFI.new("char**", g_FFI.NULL)
    err_msg = g_FFI.new("char**", g_FFI.NULL)
    rv = g_ST.keycache_get_cached_jwks(issuer.encode(), jwks_result, err_msg)
    if rv != 0:
        err_msg_str = "Unknown Error"
        if err_msg[0] != g_FFI.NULL:
            err_msg_str = g_FFI.string(err_msg[0]).decode()
            g_C.free(err_msg[0])
        raise RuntimeError(err_msg_str)

    # Test the char* the library wrote, not the char** box allocated above
    # (the box itself is never NULL).
    if jwks_result[0] == g_FFI.NULL:
        raise RuntimeError("Cached JWKS unexpectedly null.")
    try:
        return json.loads(g_FFI.string(jwks_result[0]).decode())
    finally:
        # Release the C buffer even when decoding or parsing fails.
        g_C.free(jwks_result[0])


def refresh_jwks(issuer):
    """Call keycache_refresh_jwks for *issuer*.

    Raises:
        RuntimeError: the call returned non-zero; carries the library's
            message, or "Unknown Error" when none was set.
    """
    error_out = g_FFI.new("char**", g_FFI.NULL)
    status = g_ST.keycache_refresh_jwks(issuer.encode(), error_out)
    if status == 0:
        return

    message = "Unknown Error"
    if error_out[0] != g_FFI.NULL:
        message = g_FFI.string(error_out[0]).decode()
        g_C.free(error_out[0])
    raise RuntimeError(message)


def set_jwks(issuer, jwks_dict):
    """Store *jwks_dict*, serialized as JSON, as the cached JWKS of *issuer*.

    Raises:
        RuntimeError: keycache_set_jwks returned non-zero; carries the
            library's message, or "Unknown Error" when none was set.
    """
    error_out = g_FFI.new("char**", g_FFI.NULL)
    issuer_bytes = issuer.encode()
    jwks_bytes = json.dumps(jwks_dict).encode()
    status = g_ST.keycache_set_jwks(issuer_bytes, jwks_bytes, error_out)
    if status == 0:
        return

    message = "Unknown Error"
    if error_out[0] != g_FFI.NULL:
        message = g_FFI.string(error_out[0]).decode()
        g_C.free(error_out[0])
    raise RuntimeError(message)


def inject_key(issuer, private_key):
    """Add the public key of *private_key* to the cached JWKS of *issuer*.

    Refreshes the cache entry, reads it back, appends the new public key to
    its 'keys' list and writes the result back to the cache.
    """
    refresh_jwks(issuer)
    key_set = get_cached_jwks(issuer)
    new_entry = generate_public_key_dict(private_key)
    key_set['keys'].append(new_entry)
    set_jwks(issuer, key_set)


def configure_token(token, args):
    """Apply the command-line settings in *args* to *token*.

    Always sets the "iss" claim from args.issuer.  Sets "sub", "scope" and
    "aud" from args.subject, args.scope and args.audience when those are
    not None.  Finally sets the token lifetime from args.lifetime.

    Raises:
        RuntimeError: setting a claim returned non-zero; carries the
            library's message, or "Unknown Error" when none was set.
    """
    def set_string_claim(claim, text):
        # One scitoken_set_claim_string call with its error handling.
        error_out = g_FFI.new("char**", g_FFI.NULL)
        status = g_ST.scitoken_set_claim_string(token, claim, text.encode(), error_out)
        if status == 0:
            return
        message = "Unknown Error"
        if error_out[0] != g_FFI.NULL:
            message = g_FFI.string(error_out[0]).decode()
            g_C.free(error_out[0])
        raise RuntimeError(message)

    set_string_claim(b"iss", args.issuer)

    # Optional claims, in the order they are applied; the attribute is read
    # only when its turn comes.
    optional_claims = (
        (b"sub", "subject"),
        (b"scope", "scope"),
        (b"aud", "audience"),
    )
    for claim, attribute in optional_claims:
        text = getattr(args, attribute)
        if text is not None:
            set_string_claim(claim, text)

    g_ST.scitoken_set_lifetime(token, args.lifetime)


def create_token(args, private_key):
    """Create, configure and serialize a token signed with *private_key*.

    Args:
        args: Parsed command-line namespace; passed to configure_token().
        private_key: EC private key used to sign the token.

    Returns:
        The serialized token as a str.

    Raises:
        RuntimeError: key creation, token creation, claim configuration or
            serialization failed.  The message comes from the library when
            it provided one.
    """
    def consume_error(err_msg):
        # Copy out and free the library's error text, if there is any.
        err_msg_str = "Unknown Error"
        if err_msg[0] != g_FFI.NULL:
            err_msg_str = g_FFI.string(err_msg[0]).decode()
            g_C.free(err_msg[0])
            err_msg[0] = g_FFI.NULL
        return err_msg_str

    err_msg = g_FFI.new("char**", g_FFI.NULL)
    key = g_ST.scitoken_key_create(generate_key_id(private_key).encode(),
        b"ES256",
        b"", # Public PEM -- not needed
        generate_private_key_pem(private_key),
        err_msg)
    # A NULL key must not be passed on to scitoken_create.
    if key == g_FFI.NULL:
        raise RuntimeError(consume_error(err_msg))

    try:
        token = g_ST.scitoken_create(key)
        if token == g_FFI.NULL:
            raise RuntimeError("Failed to create token")
        try:
            configure_token(token, args)
            value = g_FFI.new("char**", g_FFI.NULL)
            rv = g_ST.scitoken_serialize(token, value, err_msg)
            if rv != 0:
                raise RuntimeError(consume_error(err_msg))
            if value[0] == g_FFI.NULL:
                raise RuntimeError("Empty token returned")
            try:
                return g_FFI.string(value[0]).decode()
            finally:
                # The serialized buffer is owned by the caller; release it
                # the same way the other library-returned strings are.
                g_C.free(value[0])
        finally:
            g_ST.scitoken_destroy(token)
    finally:
        g_ST.scitoken_key_destroy(key)

def parse_args():
    """Parse the command line and return the resulting namespace.

    --issuer and --scope are required.  --lifetime is an int defaulting to
    one hour.  --subject, --audience and --cache default to None.
    """
    parser = argparse.ArgumentParser(usage="%(prog)s [options]")

    # (flags, add_argument keyword settings), in help-output order.
    option_table = (
        (("-i", "--issuer"),
         {"required": True, "help": "Issuer URL to use for new token"}),
        (("-s", "--subject"),
         {"help": "Subject to use for new token"}),
        (("-l", "--lifetime"),
         {"type": int, "default": 60*60,
          "help": "Lifetime (in seconds) for new token"}),
        (("-S", "--scope"),
         {"required": True, "help": "Scope to use for new token"}),
        (("-a", "--audience"),
         {"help": "Audience to use for new token"}),
        (("-c", "--cache"),
         {"help": "Location of cache directory"}),
    )
    for flags, settings in option_table:
        parser.add_argument(*flags, **settings)

    return parser.parse_args()


def main():
    """Generate a transient key, register it in the key cache, print a token.

    Steps: parse the command line, load the C bindings, configure the cache
    location, generate a private key, inject its public half into the
    cached JWKS of the issuer, then print the token signed with that key.
    """
    args = parse_args()
    init_ffi()

    configure_cache(args.cache)
    private_key = generate_private_key()
    inject_key(args.issuer, private_key)

    # create_token() serializes the key to PEM itself, so no separate PEM
    # copy is needed here.
    print(create_token(args, private_key))


if __name__ == '__main__':
    main()
