plain 0.74.0__py3-none-any.whl → 0.76.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,150 @@
1
+ from __future__ import annotations
2
+
3
+ #
4
+ #
5
+ # This file is part of gunicorn released under the MIT license.
6
+ # See the LICENSE for more information.
7
+ #
8
+ # Vendored and modified for Plain.
9
+
10
+ # We don't need to call super() in __init__ methods of our
11
+ # BaseException and Exception classes because we also define
12
+ # our own __str__ methods so there is no need to pass 'message'
13
+ # to the base class to get a meaningful output from 'str(exc)'.
14
+ # pylint: disable=super-init-not-called
15
+
16
+
17
class ParseException(Exception):
    """Base class for every HTTP parsing failure raised by this package."""
19
+
20
+
21
class NoMoreData(IOError):
    """Raised when the stream ends before a complete message was read."""

    def __init__(self, buf: bytes | None = None):
        # Keep whatever was already buffered so the message can show it.
        self.buf = buf

    def __str__(self) -> str:
        return "No more data after: %r" % (self.buf,)
27
+
28
+
29
class ConfigurationProblem(ParseException):
    """Server misconfiguration detected during parsing; maps to HTTP 500."""

    def __init__(self, info: str):
        self.info = info
        self.code = 500

    def __str__(self) -> str:
        return "Configuration problem: %s" % (self.info,)
36
+
37
+
38
class InvalidRequestLine(ParseException):
    """The request line was malformed; maps to HTTP 400."""

    def __init__(self, req: str):
        self.req = req
        self.code = 400

    def __str__(self) -> str:
        return "Invalid HTTP request line: %r" % (self.req,)
45
+
46
+
47
class InvalidRequestMethod(ParseException):
    """The request method failed token/length validation."""

    def __init__(self, method: str):
        self.method = method

    def __str__(self) -> str:
        return "Invalid HTTP method: %r" % (self.method,)
53
+
54
+
55
class InvalidHTTPVersion(ParseException):
    """The HTTP-version token was unparseable or unsupported."""

    def __init__(self, version: str):
        self.version = version

    def __str__(self) -> str:
        return "Invalid HTTP Version: %r" % (self.version,)
61
+
62
+
63
class InvalidHeader(ParseException):
    """A header field was malformed or contained forbidden characters."""

    def __init__(self, hdr: str, req: str | None = None):
        self.hdr = hdr
        # The offending request, when the caller has one to attach.
        self.req = req

    def __str__(self) -> str:
        return "Invalid HTTP Header: %r" % (self.hdr,)
70
+
71
+
72
class ObsoleteFolding(ParseException):
    """A header used obsolete line folding (obs-fold), which is rejected."""

    def __init__(self, hdr: str):
        self.hdr = hdr

    def __str__(self) -> str:
        return "Obsolete line folding is unacceptable: %r" % (self.hdr,)
78
+
79
+
80
class InvalidHeaderName(ParseException):
    """A header field name was not a valid RFC 9110 token."""

    def __init__(self, hdr: str):
        self.hdr = hdr

    def __str__(self) -> str:
        return "Invalid HTTP header name: %r" % (self.hdr,)
86
+
87
+
88
class UnsupportedTransferCoding(ParseException):
    """An unrecognized Transfer-Encoding value; maps to HTTP 501."""

    def __init__(self, hdr: str):
        self.hdr = hdr
        self.code = 501

    def __str__(self) -> str:
        return "Unsupported transfer coding: %r" % (self.hdr,)
95
+
96
+
97
class InvalidChunkSize(IOError):
    """A chunked-body size line was not valid hexadecimal."""

    def __init__(self, data: bytes):
        self.data = data

    def __str__(self) -> str:
        return "Invalid chunk size: %r" % (self.data,)
103
+
104
+
105
class ChunkMissingTerminator(IOError):
    """A body chunk was not terminated by the mandatory CRLF."""

    def __init__(self, term: bytes):
        self.term = term

    def __str__(self) -> str:
        return "Invalid chunk terminator is not '\\r\\n': %r" % (self.term,)
111
+
112
+
113
class LimitRequestLine(ParseException):
    """The request line exceeded the configured size limit."""

    def __init__(self, size: int, max_size: int):
        self.size = size
        self.max_size = max_size

    def __str__(self) -> str:
        return "Request Line is too large (%d > %d)" % (self.size, self.max_size)
120
+
121
+
122
class LimitRequestHeaders(ParseException):
    """A header-section resource limit (count, field size, buffer) was hit."""

    def __init__(self, msg: str):
        self.msg = msg

    def __str__(self) -> str:
        return self.msg
128
+
129
+
130
class InvalidProxyLine(ParseException):
    """A PROXY protocol line was malformed; maps to HTTP 400."""

    def __init__(self, line: str):
        self.line = line
        self.code = 400

    def __str__(self) -> str:
        return "Invalid PROXY line: %r" % (self.line,)
137
+
138
+
139
class ForbiddenProxyRequest(ParseException):
    """A PROXY request arrived from a host not on the allow list; HTTP 403."""

    def __init__(self, host: str):
        self.host = host
        self.code = 403

    def __str__(self) -> str:
        return "Proxy request from %r not allowed" % (self.host,)
146
+
147
+
148
class InvalidSchemeHeaders(ParseException):
    """Two secure-scheme headers disagreed about http vs https."""

    def __str__(self) -> str:
        return "Contradictory scheme headers"
@@ -0,0 +1,399 @@
1
+ from __future__ import annotations
2
+
3
+ #
4
+ #
5
+ # This file is part of gunicorn released under the MIT license.
6
+ # See the LICENSE for more information.
7
+ #
8
+ # Vendored and modified for Plain.
9
+ import io
10
+ import re
11
+ from typing import TYPE_CHECKING, Any
12
+
13
+ from ..util import bytes_to_str, split_request_uri
14
+ from .body import Body, ChunkedReader, EOFReader, LengthReader
15
+ from .errors import (
16
+ InvalidHeader,
17
+ InvalidHeaderName,
18
+ InvalidHTTPVersion,
19
+ InvalidRequestLine,
20
+ InvalidRequestMethod,
21
+ InvalidSchemeHeaders,
22
+ LimitRequestHeaders,
23
+ LimitRequestLine,
24
+ NoMoreData,
25
+ ObsoleteFolding,
26
+ UnsupportedTransferCoding,
27
+ )
28
+
29
+ if TYPE_CHECKING:
30
+ from ..config import Config
31
+
32
# Absolute ceilings, independent of the configurable limits below.
MAX_REQUEST_LINE = 8190  # hard upper bound for the request line (bytes)
MAX_HEADERS = 32768  # hard upper bound for the number of header fields
DEFAULT_MAX_HEADERFIELD_SIZE = 8190  # fallback per-field size limit (bytes)

# Request size limits for DDoS protection
LIMIT_REQUEST_LINE = 4094  # Maximum HTTP request line size in bytes
LIMIT_REQUEST_FIELDS = 100  # Maximum number of HTTP header fields
LIMIT_REQUEST_FIELD_SIZE = 8190  # Maximum size of an HTTP header field in bytes

# verbosely on purpose, avoid backslash ambiguity
RFC9110_5_6_2_TOKEN_SPECIALS = r"!#$%&'*+-.^_`|~"
# An RFC 9110 "token": one or more tchar (specials above plus alphanumerics).
TOKEN_RE = re.compile(rf"[{re.escape(RFC9110_5_6_2_TOKEN_SPECIALS)}0-9a-zA-Z]+")
# Methods containing lowercase letters or '#' are rejected outright.
METHOD_BADCHAR_RE = re.compile("[a-z#]")
# usually 1.0 or 1.1 - RFC9112 permits restricting to single-digit versions
VERSION_RE = re.compile(r"HTTP/(\d)\.(\d)")
# NUL, CR and LF are never valid in a header value (RFC 9110 Section 5.5).
RFC9110_5_5_INVALID_AND_DANGEROUS = re.compile(r"[\0\r\n]")
48
+
49
+
50
class Message:
    """One HTTP message parsed eagerly from an ``unreader`` stream.

    ``__init__`` drives the whole lifecycle: ``parse`` (implemented by the
    subclass) consumes the request line and headers, leftover bytes are
    pushed back onto the unreader, and ``set_body_reader`` attaches the
    appropriate body reader (chunked / content-length / EOF).
    """

    def __init__(self, cfg: Config, unreader: Any, peer_addr: tuple[str, int] | Any):
        self.cfg = cfg
        self.unreader = unreader
        self.peer_addr = peer_addr
        self.remote_addr = peer_addr
        self.version: tuple[int, int] | None = None
        self.headers: list[tuple[str, str]] = []
        self.trailers: list[tuple[str, str]] = []
        self.body: Body | None = None
        self.scheme = "https" if cfg.is_ssl else "http"
        self.must_close = False

        # set headers limits
        self.limit_request_fields = LIMIT_REQUEST_FIELDS
        if self.limit_request_fields <= 0 or self.limit_request_fields > MAX_HEADERS:
            self.limit_request_fields = MAX_HEADERS
        self.limit_request_field_size = LIMIT_REQUEST_FIELD_SIZE
        if self.limit_request_field_size < 0:
            self.limit_request_field_size = DEFAULT_MAX_HEADERFIELD_SIZE

        # set max header buffer size: worst case is every field at its size
        # limit plus CRLF, plus the final blank-line terminator.
        max_header_field_size = (
            self.limit_request_field_size or DEFAULT_MAX_HEADERFIELD_SIZE
        )
        self.max_buffer_headers = (
            self.limit_request_fields * (max_header_field_size + 2) + 4
        )

        unused = self.parse(self.unreader)
        self.unreader.unread(unused)
        self.set_body_reader()

    def force_close(self) -> None:
        """Force ``should_close()`` to report True for this connection."""
        self.must_close = True

    def parse(self, unreader: Any) -> bytes:
        """Consume the message head; return unparsed leftover bytes."""
        raise NotImplementedError()

    def parse_headers(
        self, data: bytes, from_trailer: bool = False
    ) -> list[tuple[str, str]]:
        """Parse raw header bytes into a list of ``(NAME, value)`` tuples.

        Raises a ParseException subclass for malformed, oversized,
        obsolete-folded, or scheme-contradicting header fields.
        """
        cfg = self.cfg
        headers = []

        # Split lines on \r\n
        lines = [bytes_to_str(line) for line in data.split(b"\r\n")]

        # handle scheme headers
        scheme_header = False
        secure_scheme_headers = {}
        forwarder_headers = []
        if from_trailer:
            # nonsense. either a request is https from the beginning
            # .. or we are just behind a proxy who does not remove conflicting trailers
            pass
        elif (
            "*" in cfg.forwarded_allow_ips
            or not isinstance(self.peer_addr, tuple)
            or self.peer_addr[0] in cfg.forwarded_allow_ips
        ):
            secure_scheme_headers = cfg.secure_scheme_headers
            forwarder_headers = cfg.forwarder_headers

        # Parse headers into key/value pairs paying attention
        # to continuation lines.
        while lines:
            if len(headers) >= self.limit_request_fields:
                raise LimitRequestHeaders("limit request headers fields")

            # Parse initial header name: value pair.
            curr = lines.pop(0)
            header_length = len(curr) + len("\r\n")
            if curr.find(":") <= 0:
                raise InvalidHeader(curr)
            name, value = curr.split(":", 1)
            if not TOKEN_RE.fullmatch(name):
                raise InvalidHeaderName(name)

            # this is still a dangerous place to do this
            # but it is more correct than doing it before the pattern match:
            # after we entered Unicode wonderland, 8bits could case-shift into ASCII:
            # b"\xDF".decode("latin-1").upper().encode("ascii") == b"SS"
            name = name.upper()

            value = [value.strip(" \t")]

            # Consume value continuation lines..
            while lines and lines[0].startswith((" ", "\t")):
                # Obsolete folding is not permitted (RFC 7230)
                raise ObsoleteFolding(name)
            value = " ".join(value)

            if RFC9110_5_5_INVALID_AND_DANGEROUS.search(value):
                raise InvalidHeader(name)

            if header_length > self.limit_request_field_size > 0:
                raise LimitRequestHeaders("limit request headers fields size")

            if name in secure_scheme_headers:
                secure = value == secure_scheme_headers[name]
                scheme = "https" if secure else "http"
                if scheme_header:
                    if scheme != self.scheme:
                        raise InvalidSchemeHeaders()
                else:
                    scheme_header = True
                    self.scheme = scheme

            # ambiguous mapping allows fooling downstream, e.g. merging non-identical headers:
            # X-Forwarded-For: 2001:db8::ha:cc:ed
            # X_Forwarded_For: 127.0.0.1,::1
            # HTTP_X_FORWARDED_FOR = 2001:db8::ha:cc:ed,127.0.0.1,::1
            # Only modify after fixing *ALL* header transformations; network to wsgi env
            if "_" in name:
                if name in forwarder_headers or "*" in forwarder_headers:
                    # This forwarder may override our environment
                    pass
                elif self.cfg.header_map == "dangerous":
                    # as if we did not know we cannot safely map this
                    pass
                elif self.cfg.header_map == "drop":
                    # almost as if it never had been there
                    # but still counts against resource limits
                    continue
                else:
                    # fail-safe fallthrough: refuse
                    raise InvalidHeaderName(name)

            headers.append((name, value))

        return headers

    def set_body_reader(self) -> None:
        """Attach a body reader derived from Transfer-Encoding / Content-Length."""
        chunked = False
        content_length_str: str | None = None

        for name, value in self.headers:
            if name == "CONTENT-LENGTH":
                if content_length_str is not None:
                    raise InvalidHeader("CONTENT-LENGTH", req=self)
                content_length_str = value
            elif name == "TRANSFER-ENCODING":
                # T-E can be a list
                # https://datatracker.ietf.org/doc/html/rfc9112#name-transfer-encoding
                vals = [v.strip() for v in value.split(",")]
                for val in vals:
                    if val.lower() == "chunked":
                        # DANGER: transfer codings stack, and stacked chunking is never intended
                        if chunked:
                            raise InvalidHeader("TRANSFER-ENCODING", req=self)
                        chunked = True
                    elif val.lower() == "identity":
                        # does not do much, could still plausibly desync from what the proxy does
                        # safe option: nuke it, its never needed
                        if chunked:
                            raise InvalidHeader("TRANSFER-ENCODING", req=self)
                    elif val.lower() in ("compress", "deflate", "gzip"):
                        # chunked should be the last one
                        if chunked:
                            raise InvalidHeader("TRANSFER-ENCODING", req=self)
                        self.force_close()
                    else:
                        raise UnsupportedTransferCoding(value)

        if chunked:
            # two potentially dangerous cases:
            # a) CL + TE (TE overrides CL.. only safe if the recipient sees it that way too)
            # b) chunked HTTP/1.0 (always faulty)
            if self.version < (1, 1):
                # framing wonky, see RFC 9112 Section 6.1
                raise InvalidHeader("TRANSFER-ENCODING", req=self)
            if content_length_str is not None:
                # we cannot be certain the message framing we understood matches proxy intent
                # -> whatever happens next, remaining input must not be trusted
                raise InvalidHeader("CONTENT-LENGTH", req=self)
            self.body = Body(ChunkedReader(self, self.unreader))
        elif content_length_str is not None:
            content_length: int
            try:
                # RFC 9110: Content-Length is 1*DIGIT, ASCII only.
                # str.isnumeric() also accepts non-ASCII numerals (e.g.
                # Arabic-Indic digits) that int() happily parses, letting
                # non-conforming Content-Length values through -- require
                # ASCII decimal digits explicitly.
                if content_length_str.isascii() and content_length_str.isdigit():
                    content_length = int(content_length_str)
                else:
                    raise InvalidHeader("CONTENT-LENGTH", req=self)
            except ValueError:
                raise InvalidHeader("CONTENT-LENGTH", req=self)

            if content_length < 0:
                raise InvalidHeader("CONTENT-LENGTH", req=self)

            self.body = Body(LengthReader(self.unreader, content_length))
        else:
            self.body = Body(EOFReader(self.unreader))

    def should_close(self) -> bool:
        """Return True when this connection must not be kept alive."""
        if self.must_close:
            return True
        for h, v in self.headers:
            if h == "CONNECTION":
                v = v.lower().strip(" \t")
                if v == "close":
                    return True
                elif v == "keep-alive":
                    return False
                break
        # No explicit Connection header: HTTP/1.0 and below default to close.
        return self.version <= (1, 0)  # type: ignore[operator]
256
+
257
+
258
class Request(Message):
    """An HTTP request: request line + headers + body framing."""

    def __init__(
        self,
        cfg: Config,
        unreader: Any,
        peer_addr: tuple[str, int] | Any,
        req_number: int = 1,
    ):
        self.method: str | None = None
        self.uri: str | None = None
        self.path: str | None = None
        self.query: str | None = None
        self.fragment: str | None = None

        # get max request line size
        self.limit_request_line = LIMIT_REQUEST_LINE
        if self.limit_request_line < 0 or self.limit_request_line >= MAX_REQUEST_LINE:
            self.limit_request_line = MAX_REQUEST_LINE

        self.req_number = req_number
        super().__init__(cfg, unreader, peer_addr)

    def get_data(self, unreader: Any, buf: io.BytesIO, stop: bool = False) -> None:
        """Pull one chunk into ``buf``; on EOF raise StopIteration when
        ``stop`` is set (clean end of keep-alive), else NoMoreData."""
        data = unreader.read()
        if not data:
            if stop:
                raise StopIteration()
            raise NoMoreData(buf.getvalue())
        buf.write(data)

    def parse(self, unreader: Any) -> bytes:
        """Read and parse the request line and header section.

        Returns any bytes read past the header terminator; the caller
        pushes them back onto the unreader for the body reader.
        """
        buf = io.BytesIO()
        self.get_data(unreader, buf, stop=True)

        # get request line
        line, rbuf = self.read_line(unreader, buf, self.limit_request_line)

        self.parse_request_line(line)
        buf = io.BytesIO()
        buf.write(rbuf)

        # Headers: accumulate data until the blank-line terminator shows up.
        # (idx/done are computed once per iteration; the original also
        # computed them redundantly before entering the loop.)
        data = buf.getvalue()
        while True:
            idx = data.find(b"\r\n\r\n")
            done = data[:2] == b"\r\n"

            if idx < 0 and not done:
                self.get_data(unreader, buf)
                data = buf.getvalue()
                if len(data) > self.max_buffer_headers:
                    raise LimitRequestHeaders("max buffer headers")
            else:
                break

        if done:
            # No header fields at all: everything past the leading CRLF
            # belongs to the next read.
            self.unreader.unread(data[2:])
            return b""

        self.headers = self.parse_headers(data[:idx], from_trailer=False)

        ret = data[idx + 4 :]
        buf = None
        return ret

    def read_line(
        self, unreader: Any, buf: io.BytesIO, limit: int = 0
    ) -> tuple[bytes, bytes]:
        """Read up to the first CRLF; return ``(line, residue)``.

        Raises LimitRequestLine as soon as the line provably exceeds
        ``limit`` (when limit > 0).
        """
        data = buf.getvalue()

        while True:
            idx = data.find(b"\r\n")
            if idx >= 0:
                # check if the request line is too large
                if idx > limit > 0:
                    raise LimitRequestLine(idx, limit)
                break
            if len(data) - 2 > limit > 0:
                raise LimitRequestLine(len(data), limit)
            self.get_data(unreader, buf)
            data = buf.getvalue()

        return (
            data[:idx],  # request line,
            data[idx + 2 :],
        )  # residue in the buffer, skip \r\n

    def parse_request_line(self, line_bytes: bytes) -> None:
        """Validate and split ``METHOD SP request-target SP HTTP-version``."""
        bits = [bytes_to_str(bit) for bit in line_bytes.split(b" ", 2)]
        if len(bits) != 3:
            raise InvalidRequestLine(bytes_to_str(line_bytes))

        # Method: RFC9110 Section 9
        self.method = bits[0]

        # Enforce IANA-style method restrictions
        if METHOD_BADCHAR_RE.search(self.method):
            raise InvalidRequestMethod(self.method)
        if not 3 <= len(bits[0]) <= 20:
            raise InvalidRequestMethod(self.method)
        # Standard restriction: RFC9110 token
        if not TOKEN_RE.fullmatch(self.method):
            raise InvalidRequestMethod(self.method)

        # URI
        self.uri = bits[1]

        # Python stdlib explicitly tells us it will not perform validation.
        # https://docs.python.org/3/library/urllib.parse.html#url-parsing-security
        # There are *four* `request-target` forms in rfc9112, none of them can be empty:
        # 1. origin-form, which starts with a slash
        # 2. absolute-form, which starts with a non-empty scheme
        # 3. authority-form, (for CONNECT) which contains a colon after the host
        # 4. asterisk-form, which is an asterisk (`\x2A`)
        # => manually reject one always invalid URI: empty
        if len(self.uri) == 0:
            raise InvalidRequestLine(bytes_to_str(line_bytes))

        try:
            parts = split_request_uri(self.uri)
        except ValueError:
            raise InvalidRequestLine(bytes_to_str(line_bytes))
        self.path = parts.path or ""
        self.query = parts.query or ""
        self.fragment = parts.fragment or ""

        # Version
        match = VERSION_RE.fullmatch(bits[2])
        if match is None:
            raise InvalidHTTPVersion(bits[2])
        self.version = (int(match.group(1)), int(match.group(2)))
        if not (1, 0) <= self.version < (2, 0):
            # Only HTTP/1.0 and HTTP/1.1 are supported
            # NOTE(review): passes the version tuple (not the raw token) to
            # InvalidHTTPVersion, whose parameter is annotated str.
            raise InvalidHTTPVersion(self.version)

    def set_body_reader(self) -> None:
        """Like Message.set_body_reader, but a request without explicit
        framing has a zero-length body rather than read-until-EOF."""
        super().set_body_reader()
        if isinstance(self.body.reader, EOFReader):  # type: ignore[union-attr]
            self.body = Body(LengthReader(self.unreader, 0))
@@ -0,0 +1,69 @@
1
+ from __future__ import annotations
2
+
3
+ from collections.abc import Iterator
4
+
5
+ #
6
+ #
7
+ # This file is part of gunicorn released under the MIT license.
8
+ # See the LICENSE for more information.
9
+ #
10
+ # Vendored and modified for Plain.
11
+ from typing import TYPE_CHECKING, Any
12
+
13
+ from .message import Request
14
+ from .unreader import IterUnreader, SocketUnreader
15
+
16
+ if TYPE_CHECKING:
17
+ import socket
18
+
19
+ from ..config import Config
20
+
21
+
22
class Parser:
    """Iterator yielding successive HTTP messages from one connection source."""

    # Subclasses set this to the Message subclass to instantiate per request.
    mesg_class = None

    def __init__(
        self,
        cfg: Config,
        source: socket.socket | Any,
        source_addr: tuple[str, int] | Any,
    ) -> None:
        self.cfg = cfg
        # A socket-like source (anything with recv) is wrapped in a
        # SocketUnreader; any other source is treated as an iterable of bytes.
        if hasattr(source, "recv"):
            self.unreader = SocketUnreader(source)
        else:
            self.unreader = IterUnreader(source)
        self.mesg = None
        self.source_addr = source_addr

        # request counter (for keepalive connections)
        self.req_count = 0

    def __iter__(self) -> Iterator[Request]:
        return self

    def __next__(self) -> Request:
        """Parse and return the next message, or raise StopIteration."""
        # Stop if HTTP dictates a stop.
        if self.mesg and self.mesg.should_close():
            raise StopIteration()

        # Discard any unread body of the previous message
        if self.mesg:
            data = self.mesg.body.read(8192)
            while data:
                data = self.mesg.body.read(8192)

        # Parse the next request
        self.req_count += 1
        self.mesg = self.mesg_class(
            self.cfg, self.unreader, self.source_addr, self.req_count
        )
        if not self.mesg:
            raise StopIteration()
        return self.mesg

    # Python 2-style alias kept for backward compatibility.
    next = __next__
66
+
67
+
68
class RequestParser(Parser):
    """Parser specialized to yield Request messages."""

    mesg_class = Request
@@ -0,0 +1,88 @@
1
+ from __future__ import annotations
2
+
3
+ #
4
+ #
5
+ # This file is part of gunicorn released under the MIT license.
6
+ # See the LICENSE for more information.
7
+ #
8
+ # Vendored and modified for Plain.
9
+ import io
10
+ import os
11
+ import socket
12
+ from collections.abc import Iterable, Iterator
13
+ from typing import TYPE_CHECKING
14
+
15
+ if TYPE_CHECKING:
16
+ pass
17
+
18
+ # Classes that can undo reading data from
19
+ # a given type of data source.
20
+
21
+
22
class Unreader:
    """Buffered reader that supports pushing consumed data back.

    Subclasses supply ``chunk()`` to fetch raw bytes; ``unread()`` appends
    bytes back onto the internal buffer so a later ``read()`` sees them
    again.
    """

    def __init__(self):
        self.buf = io.BytesIO()

    def chunk(self) -> bytes:
        """Fetch the next raw chunk from the underlying source."""
        raise NotImplementedError()

    def read(self, size: int | None = None) -> bytes:
        """Return up to ``size`` bytes (everything available if ``size`` is
        None or negative); ``b""`` signals end of stream."""
        if size is not None and not isinstance(size, int):
            raise TypeError("size parameter must be an int or long.")

        if size is not None:
            if size == 0:
                return b""
            if size < 0:
                # Negative sizes behave like "read whatever is available".
                size = None

        self.buf.seek(0, os.SEEK_END)

        if size is None:
            # Drain the pushback buffer first; only hit the source when empty.
            if self.buf.tell():
                drained = self.buf.getvalue()
                self.buf = io.BytesIO()
                return drained
            return self.chunk()

        # Accumulate until we hold at least `size` bytes or the source dries up.
        while self.buf.tell() < size:
            piece = self.chunk()
            if not piece:
                # Source exhausted: hand back whatever we gathered.
                partial = self.buf.getvalue()
                self.buf = io.BytesIO()
                return partial
            self.buf.write(piece)

        # Split off exactly `size` bytes, keeping the surplus buffered.
        gathered = self.buf.getvalue()
        self.buf = io.BytesIO()
        self.buf.write(gathered[size:])
        return gathered[:size]

    def unread(self, data: bytes) -> None:
        """Push ``data`` back so subsequent reads return it again."""
        self.buf.seek(0, os.SEEK_END)
        self.buf.write(data)
64
+
65
+
66
class SocketUnreader(Unreader):
    """Unreader whose chunks come from a connected socket's recv()."""

    def __init__(self, sock: socket.socket, max_chunk: int = 8192):
        super().__init__()
        self.sock = sock
        # Upper bound for a single recv() call.
        self.mxchunk = max_chunk

    def chunk(self) -> bytes:
        # recv() returning b"" indicates the peer closed the connection.
        return self.sock.recv(self.mxchunk)
74
+
75
+
76
class IterUnreader(Unreader):
    """Unreader that draws its chunks from an arbitrary iterable of bytes."""

    def __init__(self, iterable: Iterable[bytes]):
        super().__init__()
        # Dropped to None once the iterable is exhausted.
        self.iter: Iterator[bytes] | None = iter(iterable)

    def chunk(self) -> bytes:
        """Return the next chunk, or b"" once the iterable is exhausted."""
        if not self.iter:
            return b""
        try:
            return next(self.iter)
        except StopIteration:
            # Remember exhaustion so later calls short-circuit to b"".
            self.iter = None
            return b""