#!/usr/bin/env python3
# 20180114
# Jan Mojzis
# Public domain.

import logging
import socket
import sys
import os
import time
import struct
import subprocess
import atexit

class SSH:
        """
        Minimal, unencrypted SSH transport client used to probe a server.

        Implements just enough of the RFC 4253 binary packet protocol to
        exchange identification strings, read the server's SSH_MSG_KEXINIT,
        send a client SSH_MSG_KEXINIT and send an SSH_MSG_KEXDH_INIT.
        Errors are logged and re-raised as exceptions.
        """

        SSH_MSG_KEXINIT = b'\x14'
        SSH_MSG_KEXDH_INIT = b'\x1e'

        # Largest packet_length value accepted from the server. Larger values
        # are rejected as a protocol error instead of being read.
        MAX_PACKET_LENGTH = 65536

        def __init__(self, host = "127.0.0.1", port = 2222):
                """
                Open a TCP connection to host:port.

                Blocking operations on the connection time out after
                self.timeout (60) seconds. Connection errors are logged and
                re-raised.
                """

                self.timeout = 60
                try:
                        # create_connection() sets the timeout on this socket only,
                        # not process-wide. It also closes the socket itself if
                        # connect() fails.
                        self.s = socket.create_connection((host, port), self.timeout)
                except Exception as e:
                        logging.fatal("unable to connect to host='%s', port=%d: %s", host, port, e)
                        raise
                else:
                        logging.debug("connected to host='%s', port=%d", host, port)

        def close(self):
                """
                Close the underlying TCP connection.
                """

                self.s.close()

        def recv(self, l):
                """
                Receive up to l bytes from the connection.

                Raises Exception("no data received") when recv() returns no
                data, i.e. the peer closed the connection.
                """

                data = self.s.recv(l)
                if not data:
                        raise Exception("no data received")
                return data

        def recv_exact(self, l):
                """
                Receive exactly l bytes.

                Loops over recv() because TCP may deliver one SSH packet in
                several segments. Raises Exception("no data received") if the
                peer closes the connection before l bytes have arrived.
                """

                chunks = []
                remaining = l
                while remaining > 0:
                        chunk = self.recv(remaining)
                        chunks.append(chunk)
                        remaining -= len(chunk)
                return b''.join(chunks)

        def packet_hello_receive(self):
                """
                Read the server identification line byte by byte, up to and
                including the first b'\\n', and return it as bytes.
                """

                ret = b''
                try:
                        while True:
                                ch = self.recv(1)
                                ret += ch
                                if ch == b'\n':
                                        break
                except Exception as e:
                        logging.fatal("unable to receive server hello message: %s", e)
                        raise
                else:
                        logging.debug("server hello message received: %s", [ret])

                return ret

        def packet_hello_send(self, msg = b'SSH-2.0-autopkgtest\r\n'):
                """
                Send the client identification line msg in full.
                """

                try:
                        self.s.sendall(msg)
                except Exception as e:
                        logging.fatal("unable to send client hello message: %s", e)
                        raise
                else:
                        logging.debug("client hello message send: %s", [msg])

        def _parseuint32(self, x = b'', pos = 0):
                """
                Return the big-endian uint32 stored at offset pos of x.
                Raises struct.error if fewer than 4 bytes are available there.
                """

                return struct.unpack_from('>I', x, pos)[0]

        def _sshstring(self, x = b'', pos = 0):
                """
                Encode x as an SSH string: a big-endian uint32 length followed
                by the bytes. pos is accepted but unused.
                """

                ret = struct.pack('>I', len(x))
                return ret + x

        def packet_send(self, packet = b''):
                """
                Frame the payload as an unencrypted binary packet and send it.

                Layout: uint32 packet_length || byte padding_length ||
                payload || padding. There are 9..16 zero padding bytes, chosen
                so that the total length is a multiple of 8. RFC 4253 requires
                at least 4 bytes of padding, which this satisfies.
                """

                paddinglen = 2 * 8 - ((len(packet) + 5) % 8)
                packet = struct.pack('B', paddinglen) + packet + paddinglen * b'\x00'

                # prepend packet_length
                packet = self._sshstring(packet)

                try:
                        self.s.sendall(packet)
                except Exception as e:
                        logging.fatal("unable to send SSH packet: %s", e)
                        raise

        def packet_kex_receive(self):
                """
                Receive the server's SSH_MSG_KEXINIT packet and return its ten
                name-lists, in wire order, as a list of raw bytes.

                Raises Exception if the packet cannot be received, is too
                large, is not an SSH_MSG_KEXINIT, or is truncated.
                """

                ret = []

                try:
                        # Read the length prefix and then exactly that many
                        # bytes, so the stream stays aligned on packet
                        # boundaries however TCP segments the data.
                        l = self._parseuint32(self.recv_exact(4))
                        logging.debug("server kex: length = %d", l)
                        if l > self.MAX_PACKET_LENGTH:
                                raise Exception("packet length %d too large" % l)
                        packet = self.recv_exact(l)
                except Exception as e:
                        logging.fatal("unable to receive server kex message: %s", e)
                        raise

                try:
                        pos = 0

                        # padding length (skipping; name-lists are length-prefixed)
                        pos += 1

                        # packet type
                        if packet[pos] != self.SSH_MSG_KEXINIT[0]:
                                raise Exception("bad bad server kex message: not SSH_MSG_KEXINIT")
                        pos += 1

                        # 16-byte cookie
                        cookie = packet[pos:pos+16]
                        logging.debug("server kex: cookie = %s", [cookie])
                        pos += 16

                        for text in [   'kex algorithms',
                                        'server host key algorithms',
                                        'encryption algorithms client to server',
                                        'encryption algorithms server to client',
                                        'mac algorithms client to server',
                                        'mac algorithms server to client',
                                        'compress algorithms client to server',
                                        'compress algorithms server to client',
                                        'languages client to server',
                                        'languages server to client' ]:

                                l = self._parseuint32(packet, pos)
                                pos += 4
                                # A slice past the end would silently truncate.
                                if pos + l > len(packet):
                                        raise Exception("%s: name-list truncated" % text)
                                data = packet[pos:pos+l]
                                pos += l
                                logging.debug("server kex: %s = %s", text, [data])
                                ret.append(data)

                        # first_kex_packet_follows, reserved and padding are ignored

                except Exception as e:
                        logging.fatal("unable to parse server kex message: %s", e)
                        raise

                return ret

        def packet_kex_send(self):
                """
                Send a client SSH_MSG_KEXINIT that offers exactly one algorithm
                per category: a random cookie, ten name-lists,
                first_kex_packet_follows = 0 and a zero reserved field.
                """

                #packet type
                packet = self.SSH_MSG_KEXINIT

                #cookie
                cookie = os.urandom(16)
                logging.debug("client kex: cookie = %s", [cookie])
                packet += cookie

                for text,data in [   ('kex algorithms',b'curve25519-sha256@libssh.org'),
                                ('server host key algorithms',b'ssh-ed25519',),
                                ('encryption algorithms client to server',b'chacha20-poly1305@openssh.com'),
                                ('encryption algorithms server to client',b'chacha20-poly1305@openssh.com'),
                                ('mac algorithms client to server',b'hmac-sha2-256'),
                                ('mac algorithms server to client',b'hmac-sha2-256'),
                                ('compress algorithms client to server',b'none'),
                                ('compress algorithms server to client',b'none'),
                                ('languages client to server',b''),
                                ('languages server to client',b'') ]:

                        packet += self._sshstring(data)
                        logging.debug("client kex: %s = %s", text, [data])

                #kex first packet follows
                packet += b'\x00'

                #reserved
                packet += b'\x00\x00\x00\x00'

                self.packet_send(packet)

        def packet_kexdh_send(self, pk):
                """
                Send an SSH_MSG_KEXDH_INIT that carries the client public key pk
                as an SSH string.
                """

                #packet type
                packet = self.SSH_MSG_KEXDH_INIT

                #pk
                logging.debug("client kexdh: pk = %s", [pk])
                packet += self._sshstring(pk)

                self.packet_send(packet)

        def packet_kexdh_receive(self):
                """
                Wait for the server's response to SSH_MSG_KEXDH_INIT.

                Returns the bytes from a single recv(65536) without parsing
                them. Raises Exception if the server closed the connection or
                the receive failed, for example on a socket timeout.
                """

                try:
                        packet = self.recv(65536)
                except Exception as e:
                        raise Exception("unable to receive server kexdh message: %s" % e) from e
                return packet


def test_chacha20(ip = "127.0.0.1", port = 2222):
        """
        Tests if TinySSH has chacha20-poly1305@openssh.com.
        chacha20-poly1305@openssh.com is state-of-the-art ciphersuite,
        encrypts and authenticates data using 256bit secret-key.
        This cipher is safe against quantum computers.
        TinySSH must support chacha20-poly1305@openssh.com
        """

        # The docstring above is also logged as the test description.
        logging.info("test_chacha20: %s", test_chacha20.__doc__)
        s = SSH(ip, port)
        try:
                s.packet_hello_receive()
                s.packet_hello_send(b'SSH-2.0-autopkgtest-test-chacha20\r\n')
                data = s.packet_kex_receive()
                # data[2] and data[3] are the server's encryption name-lists
                # (client-to-server, server-to-client). Name-lists are
                # comma-separated, so compare whole names rather than substrings.
                for algorithms in (data[2], data[3]):
                        if b'chacha20-poly1305@openssh.com' not in algorithms.split(b','):
                                raise Exception("TinySSH doesn't have chacha20-poly1305@openssh.com")
        finally:
                # close the connection even when the check fails
                s.close()
        logging.info("test_chacha20: OK\n")

def test_invalid_pk(ip = "127.0.0.1", port = 2222):
        """
        Tests if TinySSH accepts bad public-key.
        Client tries to send zero public-key.
        It's dangerous, because computed shared secret-key is also zero.
        TinySSH must immediately drop the connection.
        """

        # The docstring above is also logged as the test description.
        logging.info("test_invalid_pk: %s", test_invalid_pk.__doc__)
        s = SSH(ip, port)
        try:
                s.packet_hello_receive()
                s.packet_hello_send(b'SSH-2.0-autopkgtest-invalid-pk\r\n')
                s.packet_kex_receive()
                s.packet_kex_send()
                # pk modulo p is zero: these 32 bytes are the little-endian
                # encoding of the Curve25519 field prime p = 2^255 - 19.
                pk = b'\xed' + 30*b'\xff' + b'\x7f'

                s.packet_kexdh_send(pk)
                try:
                        s.packet_kexdh_receive()
                except Exception as e:
                        # Expected: the receive failed, typically because the
                        # server closed the connection.
                        # NOTE(review): a receive timeout also lands here and
                        # counts as a pass.
                        logging.debug(e)
                else:
                        raise Exception("TinySSH accepts bad public-key")
        finally:
                # close the connection even when the test fails
                s.close()
        logging.info("test_invalid_pk: OK\n")

def is_docker():
        """
        Detect docker

        Heuristic: returns True if the file /.dockerenv exists, which is
        taken to mean the script runs inside a Docker container; otherwise
        returns False.
        """

        return os.path.exists('/.dockerenv')
                        
def run_systemd_2222():
        """
        Run TinySSH on port 2222 using systemd.
        Default port 22 is used by OpenSSH server.

        Rewrites ListenStream= in /lib/systemd/system/tinysshd.socket to
        2222, then reloads systemd and restarts tinysshd.socket. Failing
        commands are logged as warnings but do not raise. The second command
        runs even if the first fails.
        """

        commands = (
                "sed -i 's/^ListenStream=.*/ListenStream=2222/' /lib/systemd/system/tinysshd.socket",
                "systemctl daemon-reload && systemctl restart tinysshd.socket",
        )
        for cmd in commands:
                status = os.system(cmd)
                if status != 0:
                        logging.warning("command %r failed with status %d", cmd, status)

# Module-level handle intended for the tcpserver subprocess.
# cleanup() kills it whenever it is not None.
server = None

def cleanup():
        """
        atexit handler: if the module-level `server` is set, kill that
        process and wait for it to exit.
        """

        if server is not None:
                server.kill()
                # reap the child so it does not linger as a zombie
                server.wait()

def run_tcpserver_2222(ip = "127.0.0.1", port = 2222):
        """
        Run TinySSH on port 2222 using tcpserver.
        Default port 22 is used by OpenSSH server.

        Runs tinysshd-makekey to create ./sshkeydir, then starts tcpserver
        listening on ip:port with `tinysshd -- sshkeydir` as the program it
        runs. The process handle is stored in the module-level `server` so
        cleanup() can kill it. Finally sleeps one second to give tcpserver
        time to start listening.
        """
        # Without this, the assignment below would bind a local name and
        # cleanup() would never see the process.
        global server

        status = os.system('tinysshd-makekey sshkeydir')
        if status != 0:
                # presumably e.g. sshkeydir left over from an earlier run;
                # only warn and let the tests reveal any real problem
                logging.warning("tinysshd-makekey failed with status %d", status)
        server = subprocess.Popen(["tcpserver", "-HRDl0", ip, str(port), "tinysshd", "--", "sshkeydir"])
        time.sleep(1)

if __name__ == '__main__':

        # usage: [ip [port]], defaulting to 127.0.0.1 and 2222
        try:
                ip = sys.argv[1]
        except IndexError:
                ip = "127.0.0.1"

        try:
                port = int(sys.argv[2])
        except IndexError:
                port = 2222

        logging.basicConfig(level=logging.DEBUG)

        # Also search /usr/sbin for executables (presumably where tinysshd is
        # installed). Tolerate an unset PATH.
        os.environ["PATH"] = os.environ.get("PATH", os.defpath) + os.pathsep + "/usr/sbin"
        # Work in the autopkgtest scratch directory; sshkeydir is created
        # relative to it.
        tmp = os.getenv("AUTOPKGTEST_TMP") or '/tmp'
        os.chdir(tmp)

        # Register before starting anything so that a started server is
        # always cleaned up. cleanup() does nothing while `server` is None.
        atexit.register(cleanup)

        # Only start a local TinySSH when testing port 2222; port 22 is used
        # by the OpenSSH server.
        if port == 2222:
                if is_docker():
                        run_tcpserver_2222(ip, port)
                else:
                        run_systemd_2222()

        test_chacha20(ip, port)
        test_invalid_pk(ip, port)
