omlish 0.0.0.dev217__py3-none-any.whl → 0.0.0.dev218__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
omlish/sockets/bind.py ADDED
@@ -0,0 +1,332 @@
1
+ # ruff: noqa: UP006 UP007
2
+ # @omlish-lite
3
+ """
4
+ TODO:
5
+ - DupSocketBinder
6
+ """
7
+ import abc
8
+ import dataclasses as dc
9
+ import errno
10
+ import os
11
+ import socket as socket_
12
+ import stat
13
+ import typing as ta
14
+
15
+ from omlish.lite.check import check
16
+ from omlish.lite.dataclasses import dataclass_maybe_post_init
17
+ from omlish.sockets.addresses import SocketAddress
18
+ from omlish.sockets.addresses import SocketAndAddress
19
+
20
+
21
+ SocketBinderT = ta.TypeVar('SocketBinderT', bound='SocketBinder')
22
+ SocketBinderConfigT = ta.TypeVar('SocketBinderConfigT', bound='SocketBinder.Config')
23
+
24
+
25
+ ##
26
+
27
+
28
class SocketBinder(abc.ABC, ta.Generic[SocketBinderConfigT]):
    """
    Abstract owner of a single stream socket that is created, configured, bound and listened on.

    Instances are context managers: entering calls `bind()`, exiting calls `close()`. Subclasses supply the address
    family, the address to bind to, the `_post_bind` hook that records `_name` / `_port`, and `accept()`.

    A binder is single-use: `close()` closes the socket but keeps it attached, so a later `bind()` raises
    `AlreadyBoundError`.
    """

    @dc.dataclass(frozen=True)
    class Config:
        # Backlog passed to socket.listen().
        listen_backlog: int = 5

        # Request SO_REUSEADDR / SO_REUSEPORT on the socket when the platform exposes them.
        allow_reuse_address: bool = True
        allow_reuse_port: bool = True

        # Mark the socket's file descriptor as inheritable.
        set_inheritable: bool = False

        #

        @classmethod
        def new(
            cls,
            target: ta.Union[
                int,
                ta.Tuple[str, int],
                str,
            ],
        ) -> 'SocketBinder.Config':
            """
            Build a concrete config from a shorthand target.

            - `int` -> TCP config for that port (default host)
            - `(host, port)` tuple -> TCP config
            - `str` -> unix socket config for that file path

            Raises `TypeError` for any other target type.
            """

            if isinstance(target, int):
                return TcpSocketBinder.Config(
                    port=target,
                )

            elif isinstance(target, tuple):
                host, port = target
                return TcpSocketBinder.Config(
                    host=host,
                    port=port,
                )

            elif isinstance(target, str):
                return UnixSocketBinder.Config(
                    file=target,
                )

            else:
                raise TypeError(target)

    #

    def __init__(self, config: SocketBinderConfigT) -> None:
        super().__init__()

        self._config = config

    #

    @classmethod
    def new(cls, target: ta.Any) -> 'SocketBinder':
        """
        Create the binder matching `target`, which is either a ready `Config` or any shorthand accepted by
        `Config.new`.

        Raises `TypeError` if the resulting config is neither a TCP nor a unix socket config.
        """

        config: SocketBinder.Config
        if isinstance(target, SocketBinder.Config):
            config = target

        else:
            config = SocketBinder.Config.new(target)

        if isinstance(config, TcpSocketBinder.Config):
            return TcpSocketBinder(config)

        elif isinstance(config, UnixSocketBinder.Config):
            return UnixSocketBinder(config)

        else:
            raise TypeError(config)

    #

    class Error(RuntimeError):
        pass

    class NotBoundError(Error):
        pass

    class AlreadyBoundError(Error):
        pass

    #

    @property
    @abc.abstractmethod
    def address_family(self) -> int:
        raise NotImplementedError

    @property
    @abc.abstractmethod
    def address(self) -> SocketAddress:
        raise NotImplementedError

    #

    # The three attributes below are declared but deliberately left unassigned until binding: their presence on the
    # instance is what marks the binder as bound.

    _socket: socket_.socket

    @property
    def is_bound(self) -> bool:
        return hasattr(self, '_socket')

    @property
    def socket(self) -> socket_.socket:
        try:
            return self._socket
        except AttributeError:
            raise self.NotBoundError from None

    _name: str

    @property
    def name(self) -> str:
        try:
            return self._name
        except AttributeError:
            raise self.NotBoundError from None

    _port: ta.Optional[int]

    @property
    def port(self) -> ta.Optional[int]:
        try:
            return self._port
        except AttributeError:
            raise self.NotBoundError from None

    #

    def fileno(self) -> int:
        return self.socket.fileno()

    #

    def __enter__(self: SocketBinderT) -> SocketBinderT:
        self.bind()

        return self

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        self.close()

    #

    def _init_socket(self) -> None:
        """
        Create the socket and apply the configured options.

        Raises `AlreadyBoundError` if a socket is already attached. The socket is only attached to the binder once
        every option has been applied; if configuring it fails it is closed before the error propagates.
        """

        if hasattr(self, '_socket'):
            raise self.AlreadyBoundError

        sock = socket_.socket(self.address_family, socket_.SOCK_STREAM)

        try:
            if self._config.allow_reuse_address and hasattr(socket_, 'SO_REUSEADDR'):
                sock.setsockopt(socket_.SOL_SOCKET, socket_.SO_REUSEADDR, 1)

            # Since Linux 6.12.9, SO_REUSEPORT is not allowed on other address families than AF_INET/AF_INET6.
            if (
                self._config.allow_reuse_port and hasattr(socket_, 'SO_REUSEPORT') and
                self.address_family in (socket_.AF_INET, socket_.AF_INET6)
            ):
                try:
                    sock.setsockopt(socket_.SOL_SOCKET, socket_.SO_REUSEPORT, 1)
                except OSError as err:
                    # Platforms that define the constant but reject the option are tolerated.
                    if err.errno not in (errno.ENOPROTOOPT, errno.EINVAL):
                        raise

            if self._config.set_inheritable and hasattr(sock, 'set_inheritable'):
                sock.set_inheritable(True)

        except BaseException:
            # Do not leak the descriptor of a socket that was never handed to the binder.
            sock.close()
            raise

        self._socket = sock

    def _discard_socket(self) -> None:
        """Close the attached socket and drop all bound state, returning the binder to its unbound state."""

        try:
            self._socket.close()
        finally:
            for attr in ('_socket', '_name', '_port'):
                vars(self).pop(attr, None)

    def _pre_bind(self) -> None:
        pass

    def _post_bind(self) -> None:
        pass

    def bind(self) -> None:
        """
        Create the socket and bind it to `address`, running the `_pre_bind` / `_post_bind` hooks around the bind.

        Raises `AlreadyBoundError` if already bound. If a hook or the bind itself fails, the freshly created socket
        is closed and detached again so the binder is left unbound rather than holding a half-initialized socket.
        """

        self._init_socket()

        try:
            self._pre_bind()

            self.socket.bind(self.address)

            self._post_bind()

        except BaseException:
            self._discard_socket()
            raise

        # `_post_bind` implementations are expected to have populated `_name` and `_port`.
        check.state(all(hasattr(self, a) for a in ('_socket', '_name', '_port')))

    #

    def close(self) -> None:
        """Close the socket if one is attached. The socket stays attached, so the binder still reports as bound."""

        if hasattr(self, '_socket'):
            self._socket.close()

    #

    def listen(self) -> None:
        self.socket.listen(self._config.listen_backlog)

    @abc.abstractmethod
    def accept(self, socket: ta.Optional[socket_.socket] = None) -> SocketAndAddress:
        raise NotImplementedError
224
+
225
+
226
+ ##
227
+
228
+
229
class TcpSocketBinder(SocketBinder):
    """SocketBinder for AF_INET stream sockets bound to a `(host, port)` pair."""

    @dc.dataclass(frozen=True)
    class Config(SocketBinder.Config):
        DEFAULT_HOST: ta.ClassVar[str] = 'localhost'
        host: str = DEFAULT_HOST

        # The default is only a placeholder: `__post_init__` rejects any port that is not greater than zero.
        port: int = 0

        def __post_init__(self) -> None:
            dataclass_maybe_post_init(super())
            check.non_empty_str(self.host)
            check.isinstance(self.port, int)
            check.arg(self.port > 0)

    def __init__(self, config: Config) -> None:
        checked_config = check.isinstance(config, self.Config)
        super().__init__(checked_config)

        self._address = (config.host, config.port)

    #

    address_family = socket_.AF_INET

    @property
    def address(self) -> SocketAddress:
        return self._address

    #

    def _post_bind(self) -> None:
        super()._post_bind()

        # Only the first two fields of the bound socket name are used.
        bound_host, bound_port, *_ = self.socket.getsockname()

        self._name = socket_.getfqdn(bound_host)
        self._port = bound_port

    #

    def accept(self, socket: ta.Optional[socket_.socket] = None) -> SocketAndAddress:
        """Accept one connection from `socket`, or from the binder's own socket when none is given."""

        listener = self.socket if socket is None else socket

        connection, peer = listener.accept()
        return SocketAndAddress(connection, peer)
274
+
275
+
276
+ ##
277
+
278
+
279
class UnixSocketBinder(SocketBinder):
    """SocketBinder for AF_UNIX stream sockets bound to a filesystem path."""

    @dc.dataclass(frozen=True)
    class Config(SocketBinder.Config):
        # The default is only a placeholder: `__post_init__` requires a non-empty path.
        file: str = ''

        # When set, an existing file at `file` is removed before binding.
        unlink: bool = False

        def __post_init__(self) -> None:
            dataclass_maybe_post_init(super())
            check.non_empty_str(self.file)

    def __init__(self, config: Config) -> None:
        checked_config = check.isinstance(config, self.Config)
        super().__init__(checked_config)

        self._address = config.file

    #

    address_family = socket_.AF_UNIX

    @property
    def address(self) -> SocketAddress:
        return self._address

    #

    def _pre_bind(self) -> None:
        super()._pre_bind()

        if not self._config.unlink:
            return

        try:
            os.unlink(self._config.file)
        except FileNotFoundError:
            # Nothing to remove.
            pass

    def _post_bind(self) -> None:
        super()._post_bind()

        bound_path = self.socket.getsockname()

        # Restrict the socket file to read/write/execute for owner and group.
        os.chmod(bound_path, stat.S_IRWXU | stat.S_IRWXG)  # noqa

        self._name = bound_path
        self._port = None

    #

    def accept(self, sock: ta.Optional[socket_.socket] = None) -> SocketAndAddress:
        """
        Accept one connection from `sock`, or from the binder's own socket when none is given.

        The peer address reported by the socket is discarded; a fixed `('', 0)` is returned in its place.
        """

        listener = self.socket if sock is None else sock

        connection, _ = listener.accept()
        return SocketAndAddress(connection, ('', 0))
@@ -1,30 +1,12 @@
1
1
  # ruff: noqa: UP006 UP007
2
2
  # @omlish-lite
3
- import abc
4
3
  import typing as ta
5
4
 
6
5
  from .addresses import SocketAddress
6
+ from .io import SocketIoPair # noqa
7
7
 
8
8
 
9
- SocketHandlerFactory = ta.Callable[[SocketAddress, ta.BinaryIO, ta.BinaryIO], 'SocketHandler']
9
+ SocketHandler = ta.Callable[[SocketAddress, 'SocketIoPair'], None] # ta.TypeAlias
10
10
 
11
11
 
12
12
  ##
13
-
14
-
15
- class SocketHandler(abc.ABC):
16
- def __init__(
17
- self,
18
- client_address: SocketAddress,
19
- rfile: ta.BinaryIO,
20
- wfile: ta.BinaryIO,
21
- ) -> None:
22
- super().__init__()
23
-
24
- self._client_address = client_address
25
- self._rfile = rfile
26
- self._wfile = wfile
27
-
28
- @abc.abstractmethod
29
- def handle(self) -> None:
30
- raise NotImplementedError
omlish/sockets/io.py ADDED
@@ -0,0 +1,69 @@
1
+ # ruff: noqa: UP006 UP007
2
+ # @omlish-lite
3
+ import io
4
+ import socket
5
+ import typing as ta
6
+
7
+
8
+ ##
9
+
10
+
11
class SocketWriter(io.BufferedIOBase):
    """
    Simple writable BufferedIOBase implementation for a socket.

    Does not hold data in a buffer, avoiding any need to call flush().
    """

    def __init__(self, sock):
        super().__init__()

        self._sock = sock

    def writable(self):
        return True

    def write(self, b):
        """Send all of `b` through the socket and return the number of bytes it contained."""

        self._sock.sendall(b)
        # Go through memoryview so the count is in bytes for any buffer-protocol object, not just `bytes`.
        with memoryview(b) as view:
            return view.nbytes

    def fileno(self):
        return self._sock.fileno()


class SocketIoPair(ta.NamedTuple):
    """A read file object and a write file object wrapping the same socket."""

    r: ta.BinaryIO
    w: ta.BinaryIO

    @classmethod
    def from_socket(
        cls,
        sock: socket.socket,
        *,
        r_buf_size: int = -1,
        w_buf_size: int = 0,
    ) -> 'SocketIoPair':
        """
        Wrap `sock` in a reader / writer pair.

        The reader is `sock.makefile('rb', r_buf_size)`.

        The writer depends on `w_buf_size`, following stdlib `socketserver.StreamRequestHandler`:

        - `0` (the default) yields an unbuffered `SocketWriter`, whose writes go out via `sendall()` and never need
          flushing.
        - any other value yields `sock.makefile('wb', w_buf_size)`, a buffered writer that must be flushed for
          pending data to reach the socket.
        """

        rf: ta.Any = sock.makefile('rb', r_buf_size)

        wf: ta.Any
        if w_buf_size == 0:
            wf = SocketWriter(sock)
        else:
            wf = sock.makefile('wb', w_buf_size)

        return cls(rf, wf)
55
+
56
+
57
+ ##
58
+
59
+
60
def close_socket_immediately(sock: socket.socket) -> None:
    """
    Shut down the write side of `sock`, then close it.

    The explicit shutdown is done because socket.close() merely releases the socket and waits for GC to perform the
    actual close. An `OSError` from the shutdown is ignored; the socket is closed regardless.
    """

    how = socket.SHUT_WR

    try:
        sock.shutdown(how)
    except OSError:
        # Some platforms may raise ENOTCONN here
        pass

    sock.close()
File without changes
@@ -0,0 +1,99 @@
1
+ # @omlish-lite
2
+ # ruff: noqa: UP006 UP007
3
+ import dataclasses as dc
4
+ import socket
5
+ import typing as ta
6
+
7
+ from ..addresses import SocketAndAddress
8
+ from ..handlers import SocketHandler
9
+ from ..io import SocketIoPair
10
+ from ..io import close_socket_immediately
11
+
12
+
13
+ SocketServerHandler = ta.Callable[[SocketAndAddress], None] # ta.TypeAlias
14
+
15
+
16
+ ##
17
+
18
+
19
@dc.dataclass(frozen=True)
class StandardSocketServerHandler:
    """
    Wraps a connection handler with the usual per-connection socket setup and teardown.

    Calling it applies the configured timeout and TCP_NODELAY settings to the connection's socket, runs `handler`,
    and finally shuts down and closes the socket unless `no_close` is set.
    """

    handler: SocketServerHandler

    # Socket timeout in seconds applied before handling; `None` leaves the socket's timeout untouched.
    timeout: ta.Optional[float] = None

    # http://bugs.python.org/issue6192
    # TODO: https://eklitzke.org/the-caveats-of-tcp-nodelay
    disable_nagle_algorithm: bool = False

    # When true the connection's socket is left open after handling; closing it becomes the handler's job.
    no_close: bool = False

    def __call__(self, conn: SocketAndAddress) -> None:
        try:
            if self.timeout is not None:
                conn.socket.settimeout(self.timeout)

            if self.disable_nagle_algorithm:
                conn.socket.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, True)

            self.handler(conn)

        finally:
            # Runs even when setup or the handler raised, so a failed connection is not left open.
            if not self.no_close:
                close_socket_immediately(conn.socket)
43
+
44
+
45
+ #
46
+
47
+
48
@dc.dataclass(frozen=True)
class CallbackWrappedSocketServerHandler:
    """
    Runs a connection handler surrounded by optional lifecycle callbacks.

    Order of events for each connection: `before_handle`, `handler`, then `after_handle` if nothing raised or
    `on_error` if an `Exception` was raised, and `finally_` in every case.
    """

    handler: SocketServerHandler

    before_handle: ta.Optional[SocketServerHandler] = None
    after_handle: ta.Optional[SocketServerHandler] = None

    # Like `__exit__`: returning a true value suppresses the exception, anything else lets it propagate.
    on_error: ta.Optional[ta.Callable[[SocketAndAddress, Exception], bool]] = None

    finally_: ta.Optional[SocketServerHandler] = None

    def __call__(self, conn: SocketAndAddress) -> None:
        try:
            if self.before_handle is not None:
                self.before_handle(conn)

            self.handler(conn)

        except Exception as e:
            # Without an `on_error` callback the exception always propagates.
            suppressed = self.on_error is not None and self.on_error(conn, e)
            if not suppressed:
                raise

        else:
            if self.after_handle is not None:
                self.after_handle(conn)

        finally:
            if self.finally_ is not None:
                self.finally_(conn)
80
+
81
+
82
+ #
83
+
84
+
85
@dc.dataclass(frozen=True)
class SocketHandlerSocketServerHandler:
    """
    Adapts a `SocketHandler`, which works on file objects, to the connection-level handler interface.

    Calling it wraps the connection's socket in a `SocketIoPair` and invokes `handler` with the peer address and
    that pair.
    """

    handler: SocketHandler

    # Buffer sizes forwarded to `SocketIoPair.from_socket` for the reader and the writer respectively.
    r_buf_size: int = -1
    w_buf_size: int = 0

    def __call__(self, conn: SocketAndAddress) -> None:
        fp = SocketIoPair.from_socket(
            conn.socket,
            r_buf_size=self.r_buf_size,
            w_buf_size=self.w_buf_size,
        )

        try:
            self.handler(conn.address, fp)

        finally:
            # Push out anything the handler left pending in the writer before the connection is torn down by whoever
            # owns the socket. Skipped if the handler already closed the writer.
            if not fp.w.closed:
                try:
                    fp.w.flush()
                except OSError:
                    # A final socket write can fail, for instance when the peer is already gone.
                    pass
@@ -0,0 +1,144 @@
1
+ # @omlish-lite
2
+ # ruff: noqa: UP006 UP007
3
+ import abc
4
+ import contextlib
5
+ import selectors
6
+ import threading
7
+ import typing as ta
8
+
9
+ from ..bind import SocketBinder
10
+ from .handlers import SocketServerHandler
11
+
12
+
13
+ ##
14
+
15
+
16
class SocketServer(abc.ABC):
    """
    Single-threaded accept loop over a `SocketBinder`.

    The server binds and listens through the binder, polls the listening socket with a selector, and passes each
    accepted connection to `handler` on the thread running the loop. `shutdown()` asks the loop to stop.
    """

    def __init__(
        self,
        binder: SocketBinder,
        handler: SocketServerHandler,
        *,
        on_error: ta.Optional[ta.Callable[[BaseException], None]] = None,
        poll_interval: float = .5,
        shutdown_timeout: ta.Optional[float] = None,
    ) -> None:
        super().__init__()

        self._binder = binder
        self._handler = handler
        self._on_error = on_error
        self._poll_interval = poll_interval
        self._shutdown_timeout = shutdown_timeout

        self._lock = threading.RLock()
        # Set once a listen context has been left; cleared while one is active.
        self._is_shutdown = threading.Event()
        # Polled by the loop; set by `shutdown()`.
        self._should_shutdown = False

    @property
    def binder(self) -> SocketBinder:
        return self._binder

    @property
    def handler(self) -> SocketServerHandler:
        return self._handler

    #

    class SelectorProtocol(ta.Protocol):
        def register(self, *args, **kwargs) -> None:
            raise NotImplementedError

        def select(self, *args, **kwargs) -> bool:
            raise NotImplementedError

    # Prefer poll() where the platform provides it, otherwise fall back to select().
    Selector: ta.ClassVar[ta.Any] = getattr(selectors, 'PollSelector', selectors.SelectSelector)

    #

    @contextlib.contextmanager
    def _listen_context(self) -> ta.Iterator[SelectorProtocol]:
        """
        Bind and listen, then yield a selector registered for read events on the listening socket.

        The server lock is held and the binder stays entered for the whole duration of the context.
        """

        with self._lock, self._binder:
            self._binder.listen()

            self._is_shutdown.clear()
            try:
                # XXX: Consider using another file descriptor or connecting to the socket to wake this up instead of
                # polling. Polling reduces our responsiveness to a shutdown request and wastes cpu at all other times.
                with self.Selector() as selector:
                    selector.register(self._binder.fileno(), selectors.EVENT_READ)

                    yield selector

            finally:
                self._is_shutdown.set()

    @contextlib.contextmanager
    def loop_context(self, poll_interval: ta.Optional[float] = None) -> ta.Iterator[ta.Iterator[bool]]:
        """
        Start listening and yield an iterator that drives the accept loop one poll at a time.

        Each step of the iterator waits up to `poll_interval` seconds (the server default when `None`) for a
        connection, handles it if one arrived, and yields whether anything was ready. The iterator ends once
        shutdown has been requested, or when accepting a connection fails with `OSError`.
        """

        interval = self._poll_interval if poll_interval is None else poll_interval

        with self._listen_context() as selector:
            def loop():
                while not self._should_shutdown:
                    ready = selector.select(interval)

                    # bpo-35017: shutdown() called during select(), exit immediately.
                    if self._should_shutdown:
                        break  # type: ignore[unreachable]

                    if ready:
                        try:
                            conn = self._binder.accept()

                        except OSError as exc:
                            if self._on_error is not None:
                                self._on_error(exc)

                            # NOTE(review): this ends the whole loop rather than skipping the failed accept - confirm
                            # that is intended.
                            return

                        else:
                            self._handler(conn)

                    yield bool(ready)

            yield loop()

    def run(self, poll_interval: ta.Optional[float] = None) -> None:
        """Run the accept loop on the calling thread until it ends."""

        with self.loop_context(poll_interval=poll_interval) as loop:
            for _ in loop:
                pass

    #

    # Sentinel type distinguishing "no timeout argument given" from an explicit `None`; it cannot be instantiated.
    class _NOT_SET:  # noqa
        def __new__(cls, *args, **kwargs):  # noqa
            raise TypeError

    def shutdown(
        self,
        block: bool = False,
        timeout: ta.Union[float, None, ta.Type[_NOT_SET]] = _NOT_SET,
    ) -> None:
        """
        Ask the accept loop to stop.

        With `block` set, additionally wait for the listen context to be left, for at most `timeout` seconds (the
        server's `shutdown_timeout` when not given), raising `TimeoutError` if that does not happen in time.
        """

        self._should_shutdown = True

        if not block:
            return

        if timeout is self._NOT_SET:
            timeout = self._shutdown_timeout

        if not self._is_shutdown.wait(timeout=timeout):  # type: ignore
            raise TimeoutError

    #

    def __enter__(self) -> 'SocketServer':
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.shutdown()