Source code for socks_router.models

from __future__ import annotations

from typing import Any, Annotated, Final, Literal, Optional, Type, Self, Protocol, runtime_checkable, overload, assert_never
from abc import abstractmethod
from collections.abc import Mapping, MutableMapping
from enum import IntEnum, StrEnum, auto
from dataclasses import dataclass, field
from subprocess import Popen, PIPE
from more_itertools import collapse, zip_broadcast

import threading
import ipaddress
import socks

SOCKS_VERSION: Literal[5] = 5

type PackingSequence = str | tuple[str, str]

type RecursiveMapping[K, V] = Mapping[K, V | RecursiveMapping[K, V]]

PACKABLE_DEFERRED_FORMAT: Final[str] = "&"
PACKABLE_VARIABLE_LENGTH_DECLARATION_FORMAT: Final[str] = "%*"


@runtime_checkable
class Packable(Protocol):
    """Structural type for objects that can describe their own packing format."""

    @classmethod
    def __pack_format__(cls) -> str:
        """Return the struct-style format string describing the serialized form."""
        ...
[docs] @runtime_checkable class SupportsUnbytes(Protocol):
[docs] @classmethod @abstractmethod def __unbytes__(cls, input: bytes) -> Self: ...
[docs] @dataclass(frozen=True) class SocketAddress: address: Any = field() port: Annotated[Optional[int], "!H"] = None
[docs] def __str__(self): if self.port is None: return f"{self.address}" return self.url_literal
@property def sockaddr(self) -> tuple[str, int]: return f"{self.address}", self.port or 0
[docs] def with_port(self, port: Optional[int]) -> Self: return type(self)(self.address, port)
[docs] def with_default_port(self, port: int) -> Self: return self.with_port(self.port or port)
@property def url_literal(self) -> str: return ":".join(map(str, filter(lambda x: x is not None, [self.address, self.port])))
@dataclass(frozen=True)
class IPv4(SocketAddress):
    """An IPv4 socket address: an IPv4 literal plus an optional port."""

    address: IPv4.IPv4Address

    class IPv4Address(ipaddress.IPv4Address):
        """``ipaddress.IPv4Address`` extended with this module's packing hooks."""

        @classmethod
        def __pack_format__(cls) -> str:
            # Four unsigned bytes.
            return "!4B"

        def __bytes__(self) -> bytes:
            return self.packed

        @classmethod
        def __unbytes__(cls, input: bytes) -> Self:
            # ipaddress accepts a 4-byte packed value directly.
            return cls(input)

    def __init__(self, address: str | int | bytes | ipaddress.IPv4Address, *argv, **kwargs):
        """Create the address, normalising ``address`` to :class:`IPv4.IPv4Address`.

        Strings, integers, packed bytes and plain ``ipaddress.IPv4Address``
        objects are all converted, so the stored address always provides
        ``__pack_format__`` / ``__bytes__``. Values that already are
        ``IPv4.IPv4Address`` are stored as-is.

        Raises:
            ValueError: if ``address`` is not a valid IPv4 address
                (``ipaddress.AddressValueError`` subclasses ``ValueError``).
        """
        if not isinstance(address, IPv4.IPv4Address):
            # Previously only ``str`` was converted; other inputs (e.g. a plain
            # ipaddress.IPv4Address) slipped through without the packing hooks.
            address = IPv4.IPv4Address(address)
        super().__init__(address, *argv, **kwargs)
@dataclass(frozen=True)
class IPv6(SocketAddress):
    """An IPv6 socket address: an IPv6 literal plus an optional port."""

    address: IPv6.IPv6Address

    class IPv6Address(ipaddress.IPv6Address):
        """``ipaddress.IPv6Address`` extended with this module's packing hooks."""

        @classmethod
        def __pack_format__(cls) -> str:
            # Sixteen unsigned bytes.
            return "!16B"

        def __bytes__(self) -> bytes:
            return self.packed

        @classmethod
        def __unbytes__(cls, input: bytes) -> Self:
            # ipaddress accepts a 16-byte packed value directly.
            return cls(input)

    def __init__(self, address: str | int | bytes | ipaddress.IPv6Address, *argv, **kwargs):
        """Create the address, normalising ``address`` to :class:`IPv6.IPv6Address`.

        Strings, integers, packed bytes and plain ``ipaddress.IPv6Address``
        objects are all converted, so the stored address always provides
        ``__pack_format__`` / ``__bytes__``. Values that already are
        ``IPv6.IPv6Address`` are stored as-is.

        Raises:
            ValueError: if ``address`` is not a valid IPv6 address
                (``ipaddress.AddressValueError`` subclasses ``ValueError``).
        """
        if not isinstance(address, IPv6.IPv6Address):
            # Previously only ``str`` was converted; other inputs (e.g. a plain
            # ipaddress.IPv6Address) slipped through without the packing hooks.
            address = IPv6.IPv6Address(address)
        super().__init__(address, *argv, **kwargs)

    @property
    def url_literal(self) -> str:
        """Bracketed ``[address]:port`` form (port omitted when None).

        SEE: https://www.ietf.org/rfc/rfc2732.txt
        """
        host = f"[{self.address}]"
        if self.port is None:
            return host
        return host + ":" + str(self.port)
@dataclass(frozen=True)
class Host(SocketAddress):
    """A socket address whose ``address`` is a host name rather than an IP literal."""

    # "!B%*s" appears to describe a one-byte length followed by that many
    # characters, matching the SOCKS5 DOMAINNAME layout (packer not shown here).
    address: Annotated[str, "!B%*s"]
# Any concrete socket address: an IPv4 literal, an IPv6 literal, or a host name.
type Address = IPv4 | IPv6 | Host
class Socks5Method(IntEnum):
    """Authentication methods offered/selected in the SOCKS5 greeting.

    SEE: https://datatracker.ietf.org/doc/html/rfc1928#section-3
    """

    NO_AUTHENTICATION_REQUIRED = 0
    GSSAPI = 1
    USERNAME_PASSWORD = 2
    # 0x03..0x7F are IANA-assigned and 0x80..0xFE are reserved for private
    # methods; neither range is modelled as a member.
    NO_ACCEPTABLE_METHODS = 255

    @classmethod
    def __pack_format__(cls) -> str:
        # Serialized as a single unsigned byte.
        return "!B"
class Socks5Command(IntEnum):
    """CMD field of a SOCKS5 request.

    SEE: https://datatracker.ietf.org/doc/html/rfc1928#section-4
    """

    CONNECT = 1
    BIND = 2
    UDP_ASSOCIATE = 3

    @classmethod
    def __pack_format__(cls) -> str:
        # Serialized as a single unsigned byte.
        return "!B"
class Socks5AddressType(IntEnum):
    """ATYP field: which kind of address follows in a request or reply.

    SEE: https://datatracker.ietf.org/doc/html/rfc1928#section-4
    """

    IPv4 = 1
    # Fully-qualified domain name: the first octet holds the length of the name
    # that follows; there is no terminating NUL.
    DOMAINNAME = 3
    IPv6 = 4

    @classmethod
    def __pack_format__(cls) -> str:
        # Serialized as a single unsigned byte.
        return "!B"
@dataclass
class Socks5MethodSelectionRequest:
    """Client greeting listing the authentication methods it supports.

    SEE: https://datatracker.ietf.org/doc/html/rfc1928#section-3

    Wire layout::

        | version | method_count | methods              |
        | 1 byte  | 1 byte       | [method_count] bytes |

    ``method_count`` is not stored separately; the "%*" in the ``methods``
    format hint appears to stand for that length prefix.
    """

    version: Annotated[int, "!B"]
    methods: Annotated[list[int], "!B%*B"]
@dataclass
class Socks5MethodSelectionResponse:
    """Server answer to the greeting, naming the chosen authentication method.

    Wire layout::

        | version | method |
        | 1 byte  | 1 byte |
    """

    version: Annotated[int, "!B"]
    method: Socks5Method
@dataclass(frozen=True)
class Socks5Address:
    """ATYP byte plus the address it describes, as used in requests and replies.

    The ``sockaddr`` annotation carries ``"&"`` (deferred format), the name of
    the discriminating field (``"type"``) and the name of the classmethod that
    maps it to a concrete address class (``"address_type"``). Presumably the
    packer elsewhere reads these to pick the address format.
    """

    type: Socks5AddressType
    sockaddr: Annotated[Address, "&", "type", "address_type"]

    @classmethod
    @overload
    def address_type(cls, type: Literal[Socks5AddressType.IPv4]) -> Type[IPv4]: ...

    @classmethod
    @overload
    def address_type(cls, type: Literal[Socks5AddressType.DOMAINNAME]) -> Type[Host]: ...

    @classmethod
    @overload
    def address_type(cls, type: Literal[Socks5AddressType.IPv6]) -> Type[IPv6]: ...

    @classmethod
    def address_type(cls, type: Socks5AddressType) -> Type[IPv4] | Type[Host] | Type[IPv6]:
        """Return the address class that corresponds to an ATYP value."""
        if type == Socks5AddressType.IPv4:
            return IPv4
        if type == Socks5AddressType.DOMAINNAME:
            return Host
        if type == Socks5AddressType.IPv6:
            return IPv6
        assert_never(type)

    @classmethod
    def from_address(cls, address: Address) -> Self:
        """Wrap ``address`` together with the ATYP value matching its class."""
        if isinstance(address, IPv4):
            return cls(Socks5AddressType.IPv4, address)
        if isinstance(address, Host):
            return cls(Socks5AddressType.DOMAINNAME, address)
        if isinstance(address, IPv6):
            return cls(Socks5AddressType.IPv6, address)
        assert_never(address)
@dataclass(frozen=True)
class Socks5Request:
    """Client request sent after method negotiation.

    SEE: https://datatracker.ietf.org/doc/html/rfc1928#section-4

    Wire layout::

        | version | cmd    | rsv  | atyp   | dst.addr    | dst.port |
        | 1 byte  | 1 byte | 0x00 | 1 byte | 4-255 bytes | 2 bytes  |

    ``destination`` covers atyp, dst.addr and dst.port.
    """

    version: Annotated[int, "!B"]
    command: Socks5Command
    reserved: Annotated[int, "!B"]
    destination: Socks5Address
class Socks5ReplyType(IntEnum):
    """REP field of a SOCKS5 reply.

    SEE: https://datatracker.ietf.org/doc/html/rfc1928#section-6
    """

    SUCCEEDED = 0x00
    GENERAL_SOCKS_SERVER_FAILURE = 0x01
    CONNECTION_NOT_ALLOWED_BY_RULESET = 0x02
    NETWORK_UNREACHABLE = 0x03
    HOST_UNREACHABLE = 0x04
    CONNECTION_REFUSED = 0x05
    TTL_EXPIRED = 0x06
    COMMAND_NOT_SUPPORTED = 0x07
    ADDRESS_TYPE_NOT_SUPPORTED = 0x08
    # 0x09..0xFF are unassigned and have no member.

    @classmethod
    def __pack_format__(cls) -> str:
        # Serialized as a single unsigned byte.
        return "!B"

    @property
    def message(self) -> str:
        """Human-readable description of this reply code.

        Failure codes are looked up in PySocks' ``socks.SOCKS5_ERRORS`` table.
        That table only describes failures, so ``SUCCEEDED`` (0x00) previously
        raised ``KeyError``. It now gets an explicit message. Any code missing
        from the table falls back to a message derived from the member name.
        """
        if self is Socks5ReplyType.SUCCEEDED:
            return "Succeeded"
        fallback = self.name.replace("_", " ").capitalize()
        return socks.SOCKS5_ERRORS.get(self.value, fallback)
@dataclass(frozen=True)
class Socks5Reply:
    """Server reply to a SOCKS5 request.

    SEE: https://datatracker.ietf.org/doc/html/rfc1928#section-6

    Wire layout::

        | version | rep    | rsv  | atyp   | bnd.addr    | bnd.port |
        | 1 byte  | 1 byte | 0x00 | 1 byte | 4-255 bytes | 2 bytes  |

    ``server_bound_address`` covers atyp, bnd.addr and bnd.port. By default it
    is the IPv4 address 0.0.0.0 with port 0.
    """

    version: Annotated[int, "!B"]
    reply: Socks5ReplyType
    reserved: Annotated[int, "!B"] = 0
    server_bound_address: Socks5Address = Socks5Address.from_address(IPv4("0.0.0.0", 0))
[docs] class Socks5State(StrEnum): LISTEN = auto() HANDSHAKE = auto() REQUEST = auto() ESTABLISHED = auto() CLOSED = auto()
@dataclass(frozen=True)
class Pattern:
    """A routing pattern that may be negated.

    ``str()`` renders a negated pattern with a leading ``!``.
    """

    address: str
    is_positive_match: bool = True

    def __str__(self):
        negation = "" if self.is_positive_match else "!"
        return negation + str(self.address)
[docs] class UpstreamScheme(StrEnum): SSH = auto() SOCKS5 = auto() SOCKS5H = auto() @property def default_port(self): match self: case UpstreamScheme.SSH: return 22 case UpstreamScheme.SOCKS5 | UpstreamScheme.SOCKS5H: return 1080 case _ as unreachable: assert_never(unreachable)
[docs] @dataclass(frozen=True) class UpstreamAddress(object): scheme: UpstreamScheme address: Address
[docs] def __str__(self): return f"{self.scheme}://{self.address}"
[docs] def with_default_port(self, port: Optional[int] = None) -> Self: return type(self)(self.scheme, self.address.with_default_port(port or self.scheme.default_port))
type RoutingEntry = list[Pattern] type RoutingTable = Mapping[UpstreamAddress, RoutingEntry] Socks5Addresses: Mapping[Socks5AddressType, type[Address]] = { Socks5AddressType.IPv4: IPv4, Socks5AddressType.IPv6: IPv6, Socks5AddressType.DOMAINNAME: Host, } Socks5AddressTypes = { IPv4: Socks5AddressType.IPv4, IPv6: Socks5AddressType.IPv6, Host: Socks5AddressType.DOMAINNAME, }
[docs] @dataclass(frozen=True) class SSHUpstream: ssh_client: Popen proxy_server: Address
[docs] @classmethod def create(cls, upstream: Address, proxy_server: Address, ssh_options: Optional[dict] = None) -> Self: return cls( Popen( [ "ssh", "-NT", "-D", f"{proxy_server.port}", *collapse(zip_broadcast("-o", [f"{key}={value}" for key, value in (ssh_options or {}).items()])), f"{upstream.address}", ] + ([] if upstream.port is None else ["-p", f"{upstream.port}"]), stdout=PIPE, stderr=PIPE, ), proxy_server, )
@dataclass(frozen=True)
class ProxyUpstream:
    """An upstream described solely by the address of a proxy server."""

    proxy_server: Address
# Either an SSHUpstream (an ssh process plus its proxy address) or a
# ProxyUpstream (just a proxy address).
type Upstream = SSHUpstream | ProxyUpstream
@dataclass(frozen=True)
class RetryOptions:
    """Parameters for retrying an operation.

    How each field is interpreted is up to the consumer (not shown here).
    """

    tries: int = -1
    delay: float = 1
    max_delay: Optional[float] = None
    backoff: float = 1
    jitter: float = 0

    @classmethod
    def exponential_backoff(cls, *argv, backoff=2, **kwargs):
        """Construct like ``RetryOptions(...)`` but with ``backoff`` defaulting to 2."""
        return cls(*argv, backoff=backoff, **kwargs)
@dataclass
class ApplicationContext:
    """Configuration and runtime state for the router, gathered in one mutable object."""

    name: str = "socks-router"
    # Upstream address -> list of patterns; starts empty.
    routing_table: RoutingTable = field(default_factory=dict)
    # seconds
    ssh_connection_timeout: int = 10
    # seconds (the annotation also permits None)
    remote_socket_timeout: Optional[float] = 10
    # seconds
    proxy_poll_socket_timeout: float = 0.1
    # Built via RetryOptions.exponential_backoff(), i.e. backoff factor 2.
    proxy_retry_options: RetryOptions = field(default_factory=RetryOptions.exponential_backoff)
    # A fresh lock per context; what it guards is decided by callers
    # (presumably `upstreams` — TODO confirm).
    mutex: threading.Lock = field(default_factory=threading.Lock)
    # Upstream objects keyed by their UpstreamAddress; starts empty.
    upstreams: MutableMapping[UpstreamAddress, Upstream] = field(default_factory=dict)
    is_terminating: bool = False