"""Test the TCP implementation of UCSPI."""

import argparse
import dataclasses
import pathlib
import socket
import tempfile

from typing import Any

import utf8_locale

import ucspi_test


@dataclasses.dataclass(frozen=True)
class Config(ucspi_test.Config):
    """Runtime configuration for the Unix-domain socket test runner."""

    # The temporary directory in which the listening sockets are created.
    tempd: pathlib.Path
    # The object owning that directory; keeping a reference here stops it
    # from being cleaned up while the configuration is still in use.
    tempd_obj: tempfile.TemporaryDirectory[str]


class UnixRunner(ucspi_test.Runner):
    """Run ucspi-unix tests over Unix-domain stream sockets.

    Addresses are single-element lists holding the filesystem path of the socket.
    """

    def find_listening_address(self) -> list[str]:
        """Pick a not-yet-existing socket path inside the test's temporary directory."""
        print(f"{self.proto}.find_listening_address() starting")
        # Our own Config subclass is needed for the temporary directory.
        assert isinstance(self.cfg, Config), repr(self.cfg)

        sockpath = self.cfg.tempd / "listen.sock"
        # is_symlink() also catches dangling symlinks, which exists() misses.
        assert not sockpath.exists() and not sockpath.is_symlink(), repr(sockpath)
        return [str(sockpath)]

    @property
    def supports_remote_info(self) -> bool:
        """Report that remote-peer information is not supported (always False)."""
        return False

    @property
    def logs_to_stdout(self) -> bool:
        """Report that this protocol's tools log to stdout (always True)."""
        return True

    def get_listening_socket(self, addr: list[str]) -> socket.socket:
        """Create a Unix-domain stream socket listening on the path in `addr`.

        A stale socket file left at that path is removed first.

        Raises:
            ucspi_test.RunnerError: if `addr` is not exactly one path, if a
                non-socket file is in the way, or if any socket operation fails.
        """
        if len(addr) != 1:
            raise ucspi_test.RunnerError(
                f"{self.proto}.get_listening_socket(): unexpected address length for {addr!r}"
            )
        sockpath = addr[0]

        try:
            spath = pathlib.Path(sockpath)
            if spath.exists():
                if not spath.is_socket():
                    raise ucspi_test.RunnerError(
                        f"Expected a Unix-domain socket at {sockpath!r}, got {spath.stat()!r}"
                    )
                print(f"- removing the {sockpath} socket")
                spath.unlink()
        except OSError as err:
            raise ucspi_test.RunnerError(
                f"Could not examine and clean {sockpath!r} up: {err}"
            ) from err

        try:
            sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM, 0)
        except OSError as err:
            raise ucspi_test.RunnerError(f"Could not create a Unix-domain socket: {err}") from err

        # Close the socket if any later step fails, so the descriptor is not leaked.
        try:
            try:
                sock.bind(sockpath)
            except OSError as err:
                raise ucspi_test.RunnerError(
                    f"Could not bind the {sock!r} socket to {sockpath!r}: {err}"
                ) from err
            try:
                sock.listen(5)
            except OSError as err:
                raise ucspi_test.RunnerError(f"Could not listen on {sock!r}: {err}") from err
        except BaseException:
            sock.close()
            raise

        return sock

    def get_connected_socket(self, addr: list[str]) -> socket.socket:
        """Create a Unix-domain stream socket connected to the path in `addr`.

        Raises:
            ucspi_test.RunnerError: if `addr` is not exactly one path or if
                creating or connecting the socket fails.
        """
        if len(addr) != 1:
            raise ucspi_test.RunnerError(
                f"{self.proto}.get_connected_socket(): unexpected address length for {addr!r}"
            )
        sockpath = addr[0]

        try:
            sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM, 0)
        except OSError as err:
            raise ucspi_test.RunnerError(f"Could not create a Unix-domain socket: {err}") from err

        # Close the socket on failure so that the descriptor is not leaked.
        try:
            try:
                sock.connect(sockpath)
            except OSError as err:
                raise ucspi_test.RunnerError(
                    f"Could not connect a Unix-domain socket to {sockpath!r}: {err}"
                ) from err
        except BaseException:
            sock.close()
            raise

        return sock

    def format_local_addr(self, addr: list[str]) -> str:
        """Return the socket path of a single-element local address."""
        assert len(addr) == 1, repr(addr)
        return addr[0]

    def format_remote_addr(self, addr: Any) -> str:
        """Return a remote address as-is; it must already be a string."""
        assert isinstance(addr, str), repr(addr)
        return addr


def parse_args() -> Config | None:
    """Build the runtime configuration from the command-line arguments.

    This also creates a fresh temporary directory for the test sockets.
    Its TemporaryDirectory object is stored in the returned Config so that
    the directory is not cleaned up while the configuration is in use.
    This function always returns a Config; it never returns None.
    """
    parser = argparse.ArgumentParser(prog="uctest")
    parser.add_argument(
        "-d",
        "--bindir",
        type=pathlib.Path,
        required=True,
        help="the path to the UCSPI utilities",
    )
    parser.add_argument(
        "-p",
        "--proto",
        type=str,
        required=True,
        help="the UCSPI protocol ('tcp', 'unix', etc)",
    )
    opts = parser.parse_args()

    # Not used as a context manager: the directory must outlive this function.
    # pylint: disable-next=consider-using-with
    tmpdir = tempfile.TemporaryDirectory(prefix="ucspi-unix-test.")

    return Config(
        bindir=opts.bindir.absolute(),
        proto=opts.proto,
        tempd=pathlib.Path(tmpdir.name),
        tempd_obj=tmpdir,
        utf8_env=utf8_locale.UTF8Detect().detect().env,
    )


def main() -> None:
    """Parse the command-line arguments and run the Unix-domain socket tests."""
    cfg = parse_args()
    if cfg is None:
        # parse_args() is typed as possibly returning None; without a
        # configuration there is nothing to run.
        print("Could not build the test configuration, skipping the tests")
        return

    # Register the Unix-domain runner, then let ucspi_test run the tests.
    ucspi_test.add_handler("unix", UnixRunner)
    ucspi_test.run_test_handler(cfg)


if __name__ == "__main__":
    main()
