from __future__ import annotations
from typing import Any, Annotated, Final, Literal, Optional, Type, Self, Protocol, runtime_checkable, overload, assert_never
from abc import abstractmethod
from collections.abc import Mapping, MutableMapping
from enum import IntEnum, StrEnum, auto
from dataclasses import dataclass, field
from subprocess import Popen, PIPE
from more_itertools import collapse, zip_broadcast
import threading
import ipaddress
import socks
SOCKS_VERSION: Literal[5] = 5
type PackingSequence = str | tuple[str, str]
type RecursiveMapping[K, V] = Mapping[K, V | RecursiveMapping[K, V]]
PACKABLE_DEFERRED_FORMAT: Final[str] = "&"
PACKABLE_VARIABLE_LENGTH_DECLARATION_FORMAT: Final[str] = "%*"
[docs]
# NOTE(review): this copy of the file has lost its indentation, and the "[docs]"
# lines look like documentation-page residue rather than code. It cannot be
# told from here whether SupportsUnbytes (below) is nested inside Packable or
# is a sibling top-level class, and Packable's own members are not visible —
# confirm against the original source before re-indenting.
@runtime_checkable
class Packable(Protocol):
[docs]
# Structural type for classes that can rebuild an instance from raw bytes via a
# `__unbytes__` classmethod. IPv4.IPv4Address and IPv6.IPv6Address implement it
# as `cls(input)`, pairing it with a `__bytes__` that returns `.packed`.
@runtime_checkable
class SupportsUnbytes(Protocol):
[docs]
# Construct an instance of `cls` from its byte encoding.
@classmethod
@abstractmethod
def __unbytes__(cls, input: bytes) -> Self: ...
[docs]
@dataclass(frozen=True)
class SocketAddress:
address: Any = field()
port: Annotated[Optional[int], "!H"] = None
[docs]
def __str__(self):
if self.port is None:
return f"{self.address}"
return self.url_literal
@property
def sockaddr(self) -> tuple[str, int]:
return f"{self.address}", self.port or 0
[docs]
def with_port(self, port: Optional[int]) -> Self:
return type(self)(self.address, port)
[docs]
def with_default_port(self, port: int) -> Self:
return self.with_port(self.port or port)
@property
def url_literal(self) -> str:
return ":".join(map(str, filter(lambda x: x is not None, [self.address, self.port])))
@dataclass(frozen=True)
class IPv4(SocketAddress):
    """Socket address whose host is an IPv4 literal, stored as ``IPv4.IPv4Address``."""

    address: IPv4.IPv4Address

    class IPv4Address(ipaddress.IPv4Address):
        """``ipaddress.IPv4Address`` that converts to and from its packed 4-byte form."""

        def __bytes__(self) -> bytes:
            """Return the packed representation (most significant octet first)."""
            return self.packed

        @classmethod
        def __unbytes__(cls, input: bytes) -> Self:
            """Inverse of ``__bytes__``: build an address from packed bytes."""
            return cls(input)

    def __init__(self, address: str | int | bytes | ipaddress.IPv4Address, *argv, **kwargs):
        """Coerce ``address`` to ``IPv4.IPv4Address``, then defer to the dataclass init.

        Before, only ``str`` was converted. A plain ``ipaddress.IPv4Address``,
        an int, or packed bytes was stored as-is and lacked ``__bytes__``. Now
        anything ``ipaddress.IPv4Address`` accepts is converted, and invalid
        input raises the usual ``ipaddress`` errors. ``port`` and any other
        arguments pass through unchanged.
        """
        if not isinstance(address, IPv4.IPv4Address):
            address = IPv4.IPv4Address(address)
        super().__init__(address, *argv, **kwargs)
@dataclass(frozen=True)
class IPv6(SocketAddress):
    """Socket address whose host is an IPv6 literal, stored as ``IPv6.IPv6Address``."""

    address: IPv6.IPv6Address

    class IPv6Address(ipaddress.IPv6Address):
        """``ipaddress.IPv6Address`` that converts to and from its packed 16-byte form."""

        def __bytes__(self) -> bytes:
            """Return the packed representation (most significant octet first)."""
            return self.packed

        @classmethod
        def __unbytes__(cls, input: bytes) -> Self:
            """Inverse of ``__bytes__``: build an address from packed bytes."""
            return cls(input)

    def __init__(self, address: str | int | bytes | ipaddress.IPv6Address, *argv, **kwargs):
        """Coerce ``address`` to ``IPv6.IPv6Address``, then defer to the dataclass init.

        Before, only ``str`` was converted. A plain ``ipaddress.IPv6Address``,
        an int, or packed bytes was stored as-is and lacked ``__bytes__``. Now
        anything ``ipaddress.IPv6Address`` accepts is converted. ``port`` and
        any other arguments pass through unchanged.
        """
        if not isinstance(address, IPv6.IPv6Address):
            address = IPv6.IPv6Address(address)
        super().__init__(address, *argv, **kwargs)

    @property
    def url_literal(self) -> str:
        """``[address]:port``, or ``[address]`` without a port.

        The brackets keep the address's own colons apart from the port separator.
        SEE: https://www.ietf.org/rfc/rfc2732.txt
        """
        host = f"[{self.address}]"
        return host if self.port is None else f"{host}:{self.port}"
[docs]
@dataclass(frozen=True)
class Host(SocketAddress):
address: Annotated[str, "!B%*s"]
type Address = IPv4 | IPv6 | Host
class Socks5Method(IntEnum):
    """Authentication method codes exchanged during SOCKS5 method selection.

    SEE: https://datatracker.ietf.org/doc/html/rfc1928#section-3
    """

    NO_AUTHENTICATION_REQUIRED = 0
    GSSAPI = 1
    USERNAME_PASSWORD = 2
    # Codes 0x03-0x7F are IANA assigned and 0x80-0xFE are reserved for
    # private methods; neither range has members here.
    NO_ACCEPTABLE_METHODS = 255
class Socks5Command(IntEnum):
    """CMD byte of a SOCKS5 request.

    SEE: https://datatracker.ietf.org/doc/html/rfc1928#section-4
    """

    CONNECT = 1
    BIND = 2
    UDP_ASSOCIATE = 3
class Socks5AddressType(IntEnum):
    """ATYP byte: says how the address field that follows is encoded.

    SEE: https://datatracker.ietf.org/doc/html/rfc1928#section-4
    """

    IPv4 = 1
    # Fully-qualified domain name: the first octet of the address field holds
    # the number of name octets that follow; there is no terminating NUL.
    DOMAINNAME = 3
    IPv6 = 4
@dataclass
class Socks5MethodSelectionRequest:
    """Client greeting that opens a SOCKS5 session by listing offered methods.

    SEE: https://datatracker.ietf.org/doc/html/rfc1928#section-3

    Header
    ------
    | version | method_count | methods              |
    | 1 byte  | 1 byte       | [method_count] bytes |
    """

    version: Annotated[int, "!B"]
    # No separate count field: the "!B%*B" format presumably writes a one-byte
    # length ahead of the one-byte method codes.
    methods: Annotated[list[int], "!B%*B"]
@dataclass
class Socks5MethodSelectionResponse:
    """Server's reply to the greeting, naming the single method it selected.

    SEE: https://datatracker.ietf.org/doc/html/rfc1928#section-3

    Method Selection Response
    -------------------------
    | version | method |
    | 1 byte  | 1 byte |
    """

    version: Annotated[int, "!B"]
    # NO_ACCEPTABLE_METHODS (0xFF) means none of the client's offers was accepted.
    method: Socks5Method
@dataclass(frozen=True)
class Socks5Address:
    """ATYP code paired with the socket address it describes.

    SEE: https://datatracker.ietf.org/doc/html/rfc1928#section-5
    """

    type: Socks5AddressType
    # "&" marks a deferred type: the metadata names the sibling field ("type")
    # and the resolver classmethod ("address_type") — presumably used by the
    # unpacker to choose IPv4 / Host / IPv6.
    sockaddr: Annotated[Address, "&", "type", "address_type"]

    @classmethod
    @overload
    def address_type(cls, type: Literal[Socks5AddressType.IPv4]) -> Type[IPv4]: ...

    @classmethod
    @overload
    def address_type(cls, type: Literal[Socks5AddressType.DOMAINNAME]) -> Type[Host]: ...

    @classmethod
    @overload
    def address_type(cls, type: Literal[Socks5AddressType.IPv6]) -> Type[IPv6]: ...

    @classmethod
    def address_type(cls, type: Socks5AddressType) -> Type[IPv4] | Type[Host] | Type[IPv6]:
        """Map an ATYP value to the SocketAddress subclass it decodes to."""
        if type == Socks5AddressType.IPv4:
            return IPv4
        if type == Socks5AddressType.DOMAINNAME:
            return Host
        if type == Socks5AddressType.IPv6:
            return IPv6
        assert_never(type)

    @classmethod
    def from_address(cls, address: Address) -> Self:
        """Wrap ``address``, picking the ATYP that matches its class."""
        if isinstance(address, IPv4):
            return cls(Socks5AddressType.IPv4, address)
        if isinstance(address, Host):
            return cls(Socks5AddressType.DOMAINNAME, address)
        if isinstance(address, IPv6):
            return cls(Socks5AddressType.IPv6, address)
        assert_never(address)
@dataclass(frozen=True)
class Socks5Request:
    """Request the client sends after method-dependent sub-negotiation completes.

    SEE: https://datatracker.ietf.org/doc/html/rfc1928#section-4

    Request
    -------
    | version | cmd    | rsv    | atyp   | dst.addr | dst.port |
    | 1 byte  | 1 byte | 1 byte | 1 byte | variable | 2 bytes  |
    """

    version: Annotated[int, "!B"]
    command: Socks5Command
    # RSV: the RFC requires 0x00.
    reserved: Annotated[int, "!B"]
    # atyp, dst.addr and dst.port all live in this one value.
    destination: Socks5Address
class Socks5ReplyType(IntEnum):
    """REP codes a SOCKS5 server returns in its reply.

    SEE: https://datatracker.ietf.org/doc/html/rfc1928#section-6
    """

    SUCCEEDED = 0x00
    GENERAL_SOCKS_SERVER_FAILURE = 0x01
    CONNECTION_NOT_ALLOWED_BY_RULESET = 0x02
    NETWORK_UNREACHABLE = 0x03
    HOST_UNREACHABLE = 0x04
    CONNECTION_REFUSED = 0x05
    TTL_EXPIRED = 0x06
    COMMAND_NOT_SUPPORTED = 0x07
    ADDRESS_TYPE_NOT_SUPPORTED = 0x08
    # 0x09-0xFF are unassigned.

    @property
    def message(self) -> str:
        """Human-readable text for this reply code.

        Failure codes are looked up in PySocks' ``socks.SOCKS5_ERRORS``. That
        table has no 0x00 entry, so indexing it with SUCCEEDED raised
        ``KeyError``; SUCCEEDED is answered here instead. A code missing from
        the table falls back to "Unknown error", the same default PySocks uses.
        """
        if self is Socks5ReplyType.SUCCEEDED:
            return "Succeeded"
        return socks.SOCKS5_ERRORS.get(self.value, "Unknown error")
@dataclass(frozen=True)
class Socks5Reply:
    """Server's reply to a Socks5Request.

    SEE: https://datatracker.ietf.org/doc/html/rfc1928#section-6

    Reply
    -----
    | version | reply  | rsv    | atyp   | bnd.addr | bnd.port |
    | 1 byte  | 1 byte | 1 byte | 1 byte | variable | 2 bytes  |

    The trailing address is the server-bound address (BND.ADDR / BND.PORT),
    not the request's destination; it is held in ``server_bound_address``.
    """

    version: Annotated[int, "!B"]
    reply: Socks5ReplyType
    # RSV: the RFC requires 0x00.
    reserved: Annotated[int, "!B"] = 0x00
    # Defaults to 0.0.0.0:0. Sharing one default instance is safe because
    # Socks5Address and IPv4 are frozen dataclasses.
    server_bound_address: Socks5Address = Socks5Address.from_address(IPv4("0.0.0.0", 0))
[docs]
class Socks5State(StrEnum):
LISTEN = auto()
HANDSHAKE = auto()
REQUEST = auto()
ESTABLISHED = auto()
CLOSED = auto()
@dataclass(frozen=True)
class Pattern:
    """One routing pattern; negative (non-matching) patterns print with a leading "!"."""

    address: str
    is_positive_match: bool = True

    def __str__(self):
        negation = "" if self.is_positive_match else "!"
        return f"{negation}{self.address}"
[docs]
class UpstreamScheme(StrEnum):
SSH = auto()
SOCKS5 = auto()
SOCKS5H = auto()
@property
def default_port(self):
match self:
case UpstreamScheme.SSH:
return 22
case UpstreamScheme.SOCKS5 | UpstreamScheme.SOCKS5H:
return 1080
case _ as unreachable:
assert_never(unreachable)
@dataclass(frozen=True)
class UpstreamAddress:
    """An upstream endpoint: a scheme plus the address to reach it at."""

    scheme: UpstreamScheme
    address: Address

    def __str__(self):
        return f"{self.scheme}://{self.address}"

    def with_default_port(self, port: Optional[int] = None) -> Self:
        """Return a copy whose address gets a default port if it has none.

        The fallback is ``port`` when truthy, otherwise the scheme's own
        ``default_port``.
        """
        fallback_port = port or self.scheme.default_port
        completed = self.address.with_default_port(fallback_port)
        return type(self)(self.scheme, completed)
type RoutingEntry = list[Pattern]
type RoutingTable = Mapping[UpstreamAddress, RoutingEntry]
Socks5Addresses: Mapping[Socks5AddressType, type[Address]] = {
Socks5AddressType.IPv4: IPv4,
Socks5AddressType.IPv6: IPv6,
Socks5AddressType.DOMAINNAME: Host,
}
Socks5AddressTypes = {
IPv4: Socks5AddressType.IPv4,
IPv6: Socks5AddressType.IPv6,
Host: Socks5AddressType.DOMAINNAME,
}
@dataclass(frozen=True)
class SSHUpstream:
    """Upstream reached through an ``ssh -D`` dynamic port forward.

    ``ssh_client`` is the spawned ssh process; ``proxy_server`` is the address
    whose port ssh was told to forward with ``-D``.
    """

    ssh_client: Popen
    proxy_server: Address

    @classmethod
    def create(cls, upstream: Address, proxy_server: Address, ssh_options: Optional[dict] = None) -> Self:
        """Spawn ``ssh -NT -D <port> [-o k=v ...] <host> [-p <port>]`` and wrap it.

        Only ``proxy_server.port`` is passed to ssh. Each ``ssh_options``
        entry becomes a ``-o key=value`` pair. ``-p`` is added only when
        ``upstream`` has a port.
        """
        command = ["ssh", "-NT", "-D", f"{proxy_server.port}"]
        for key, value in (ssh_options or {}).items():
            command += ["-o", f"{key}={value}"]
        command.append(f"{upstream.address}")
        if upstream.port is not None:
            command += ["-p", f"{upstream.port}"]
        # NOTE(review): stdout/stderr are piped; if nothing drains them, a chatty
        # ssh could block once a pipe buffer fills — confirm a reader exists.
        process = Popen(command, stdout=PIPE, stderr=PIPE)
        return cls(process, proxy_server)
[docs]
@dataclass(frozen=True)
class ProxyUpstream:
proxy_server: Address
type Upstream = SSHUpstream | ProxyUpstream
[docs]
@dataclass(frozen=True)
class RetryOptions:
tries: int = -1
delay: float = 1
max_delay: Optional[float] = None
backoff: float = 1
jitter: float = 0
[docs]
@classmethod
def exponential_backoff(cls, *argv, backoff=2, **kwargs):
return cls(*argv, **dict(backoff=backoff, **kwargs))
[docs]
# Configuration and runtime state for the router: routing table, timeouts,
# retry policy, a lock, the live upstreams and a termination flag.
# NOTE(review): this copy has lost its indentation and the class may continue
# past this point — confirm against the original before re-indenting.
@dataclass
class ApplicationContext:
name: str = "socks-router"
# Upstream -> patterns routed through it; empty by default.
routing_table: RoutingTable = field(default_factory=dict)
# seconds
ssh_connection_timeout: int = 10
# seconds; None is allowed — presumably meaning "no timeout", confirm at the use site.
remote_socket_timeout: Optional[float] = 10
# seconds
proxy_poll_socket_timeout: float = 0.1
# Built by RetryOptions.exponential_backoff, i.e. backoff multiplier 2.
proxy_retry_options: RetryOptions = field(default_factory=RetryOptions.exponential_backoff)
# NOTE(review): presumably guards `upstreams` across threads — confirm the locking discipline elsewhere.
mutex: threading.Lock = field(default_factory=threading.Lock)
# Upstream address -> live SSH/proxy upstream; starts empty.
upstreams: MutableMapping[UpstreamAddress, Upstream] = field(default_factory=dict)
is_terminating: bool = False