omlish 0.0.0.dev123__py3-none-any.whl → 0.0.0.dev125__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
@@ -0,0 +1,376 @@
1
+ # ruff: noqa: UP006 UP007
2
import abc
import email.parser
import http.client
import http.server
import io
import typing as ta

from .versions import HttpProtocolVersion
from .versions import HttpProtocolVersions
10
+
11
+
12
# Parsed header blocks are represented by the stdlib message class. This is a plain assignment; the trailing comment
# marks it as a type alias without spelling out ``ta.TypeAlias``.
HttpHeaders = http.client.HTTPMessage  # ta.TypeAlias

# Result type of the line-reading coroutines driven by ``HttpRequestParser._run_read_line_coro``.
T = ta.TypeVar('T')
16
+
17
+
18
+ ##
19
+
20
+
21
class ParseHttpRequestResult(abc.ABC):  # noqa
    """Common base for every outcome of parsing an HTTP request (see ``HttpRequestParser``).

    Each class in this hierarchy lists only the slots it adds itself. Redeclaring a base-class slot in a subclass
    shadows the base's slot descriptor, wastes a pointer per instance, and is documented by Python as leaving the
    program's meaning undefined. :meth:`__repr__` therefore collects slot names from the whole MRO rather than reading
    ``self.__slots__``, which holds only the most-derived class's own slots.
    """

    __slots__ = (
        'server_version',
        'request_line',
        'request_version',
        'version',
        'headers',
        'close_connection',
    )

    def __init__(
            self,
            *,
            server_version: HttpProtocolVersion,
            request_line: str,
            request_version: HttpProtocolVersion,
            version: HttpProtocolVersion,
            headers: ta.Optional[HttpHeaders],
            close_connection: bool,
    ) -> None:
        """
        :param server_version: The protocol version the parser was configured with.
        :param request_line: The decoded request line without its line terminator, or ``'-'`` if parsing stopped
            before it was decoded.
        :param request_version: The version named in the request line, or the parser's default request version if
            none was parsed.
        :param version: The protocol version to respond with.
        :param headers: The parsed request headers, or ``None`` if parsing stopped before headers were parsed.
        :param close_connection: Whether the connection should be closed after this request.
        """
        super().__init__()

        self.server_version = server_version
        self.request_line = request_line
        self.request_version = request_version
        self.version = version
        self.headers = headers
        self.close_connection = close_connection

    def __repr__(self) -> str:
        # Walk the MRO most-derived first so subclass-specific fields lead, listing each slot name once.
        attrs: ta.List[str] = []
        for klass in type(self).__mro__:
            slots = klass.__dict__.get('__slots__', ())
            if isinstance(slots, str):
                slots = (slots,)
            for a in slots:
                if a not in attrs:
                    attrs.append(a)
        return f'{self.__class__.__name__}({", ".join(f"{a}={getattr(self, a)!r}" for a in attrs)})'


class EmptyParsedHttpResult(ParseHttpRequestResult):
    """The request line was empty (``b''`` - presumably end of input) or contained only whitespace."""

    # Without this, the subclass would silently regain a per-instance ``__dict__``.
    __slots__ = ()


class ParseHttpRequestError(ParseHttpRequestResult):
    """Parsing failed; ``code`` is the HTTP status to answer with and ``message`` describes the failure."""

    __slots__ = (
        'code',
        'message',
    )

    def __init__(
            self,
            *,
            code: http.HTTPStatus,
            message: ta.Union[str, ta.Tuple[str, str]],

            **kwargs: ta.Any,
    ) -> None:
        """
        :param code: The error status.
        :param message: Either a single string, or a pair of strings (the parser uses a short message plus the
            underlying exception's text).
        :param kwargs: The common fields accepted by :class:`ParseHttpRequestResult`.
        """
        super().__init__(**kwargs)

        self.code = code
        self.message = message


class ParsedHttpRequest(ParseHttpRequestResult):
    """A successfully parsed request line and header block."""

    __slots__ = (
        'method',
        'path',
        'expects_continue',
    )

    def __init__(
            self,
            *,
            method: str,
            path: str,
            headers: HttpHeaders,
            expects_continue: bool,

            **kwargs: ta.Any,
    ) -> None:
        """
        :param method: The request method, e.g. ``'GET'``.
        :param path: The request target, with any run of leading slashes reduced to one.
        :param headers: The parsed request headers (never ``None`` for a successful parse).
        :param expects_continue: Whether the client sent ``Expect: 100-continue`` on an HTTP/1.1+ connection.
        :param kwargs: The remaining common fields accepted by :class:`ParseHttpRequestResult`.
        """
        super().__init__(
            headers=headers,
            **kwargs,
        )

        self.method = method
        self.path = path
        self.expects_continue = expects_continue

    # Narrows the base class's ``ta.Optional[HttpHeaders]``: a successful parse always carries headers. Annotation only,
    # so it does not conflict with the inherited ``headers`` slot.
    headers: HttpHeaders
108
+
109
+
110
+ #
111
+
112
+
113
class HttpRequestParser:
    """Parses an HTTP/0.9 - HTTP/1.1 request line and header block into a :class:`ParseHttpRequestResult`.

    The parsing steps closely mirror ``http.server.BaseHTTPRequestHandler.parse_request``, but no socket is involved.
    Input is pulled one line at a time, in one of two ways:

      - through a blocking ``read_line(max_size) -> bytes`` callable passed to :meth:`parse`, or
      - by driving the generator returned by :meth:`coro_parse` directly. It yields the maximum size it will accept
        for the next line and must be sent that line.

    Over-long, malformed or unsupported requests are reported as :class:`ParseHttpRequestError` results rather than
    raised.
    """

    DEFAULT_SERVER_VERSION = HttpProtocolVersions.HTTP_1_0

    # The default request version. This only affects responses up until the point where the request line is parsed, so
    # it mainly decides what the client gets back when sending a malformed request line.
    # Most web servers default to HTTP 0.9, i.e. don't send a status line.
    DEFAULT_REQUEST_VERSION = HttpProtocolVersions.HTTP_0_9

    #

    DEFAULT_MAX_LINE: int = 0x10000
    DEFAULT_MAX_HEADERS: int = 100

    #

    def __init__(
            self,
            *,
            server_version: HttpProtocolVersion = DEFAULT_SERVER_VERSION,

            max_line: int = DEFAULT_MAX_LINE,
            max_headers: int = DEFAULT_MAX_HEADERS,
    ) -> None:
        """
        :param server_version: The highest protocol version this server speaks; must be below HTTP/2.0.
        :param max_line: Maximum length in bytes, line terminator included, of the request line and of each header
            line.
        :param max_headers: Maximum number of header lines, the terminating blank line included.
        :raises ValueError: If ``server_version`` is HTTP/2.0 or newer.
        """
        super().__init__()

        if server_version >= HttpProtocolVersions.HTTP_2_0:
            raise ValueError(f'Unsupported protocol version: {server_version}')
        self._server_version = server_version

        self._max_line = max_line
        self._max_headers = max_headers

    #

    @property
    def server_version(self) -> HttpProtocolVersion:
        return self._server_version

    #

    def _run_read_line_coro(
            self,
            gen: ta.Generator[int, bytes, T],
            read_line: ta.Callable[[int], bytes],
    ) -> T:
        """Drive a line-reading coroutine to completion with a blocking ``read_line`` callable.

        Each size the coroutine yields is passed to ``read_line``, and the returned line is sent back in. The
        coroutine's return value is returned, including when it finishes before yielding anything.
        """
        try:
            sz = next(gen)
            while True:
                sz = gen.send(read_line(sz))
        except StopIteration as e:
            return e.value

    #

    def parse_request_version(self, version_str: str) -> HttpProtocolVersion:
        """Parse a version token such as ``'HTTP/1.1'`` into an :class:`HttpProtocolVersion`.

        :raises ValueError: If the token is not ``HTTP/<major>.<minor>`` with ASCII-digit components of at most 10
            characters each.
        """
        if not version_str.startswith('HTTP/'):
            raise ValueError(version_str)  # noqa

        base_version_number = version_str.split('/', 1)[1]
        version_number_parts = base_version_number.split('.')

        # RFC 2145 section 3.1 says there can be only one "." and
        #  - major and minor numbers MUST be treated as separate integers;
        #  - HTTP/2.4 is a lower version than HTTP/2.13, which in turn is lower than HTTP/12.3;
        #  - Leading zeros MUST be ignored by recipients.
        if len(version_number_parts) != 2:
            raise ValueError(version_number_parts)  # noqa
        # ``str.isdigit`` alone also accepts non-ASCII digits. Some of those are parsed by ``int()``, e.g. Arabic-Indic
        # digits. Others make ``int()`` fail, e.g. the Latin-1 superscripts which can appear in an ISO-8859-1 decoded
        # request line. The protocol grammar only allows ASCII DIGIT.
        if any(not (component.isascii() and component.isdigit()) for component in version_number_parts):
            raise ValueError('non digit in http version')  # noqa
        if any(len(component) > 10 for component in version_number_parts):
            raise ValueError('unreasonable length http version')  # noqa

        return HttpProtocolVersion(
            int(version_number_parts[0]),
            int(version_number_parts[1]),
        )

    #

    def coro_read_raw_headers(self) -> ta.Generator[int, bytes, ta.List[bytes]]:
        """Coroutine collecting raw header lines up to and including the terminating blank line (or ``b''``).

        Yields ``max_line + 1`` as the size of each line it wants, so an over-long line is detectable, and is sent each
        line. Returns the collected lines, terminator included.

        :raises http.client.LineTooLong: If a line is longer than ``max_line`` bytes.
        :raises http.client.HTTPException: If more than ``max_headers`` lines are received.
        """
        raw_headers: ta.List[bytes] = []
        while True:
            line = yield self._max_line + 1
            if len(line) > self._max_line:
                raise http.client.LineTooLong('header line')
            raw_headers.append(line)
            if len(raw_headers) > self._max_headers:
                raise http.client.HTTPException(f'got more than {self._max_headers} headers')
            if line in (b'\r\n', b'\n', b''):
                break
        return raw_headers

    def read_raw_headers(self, read_line: ta.Callable[[int], bytes]) -> ta.List[bytes]:
        """Blocking counterpart of :meth:`coro_read_raw_headers`."""
        return self._run_read_line_coro(self.coro_read_raw_headers(), read_line)

    def parse_raw_headers(self, raw_headers: ta.Sequence[bytes]) -> HttpHeaders:
        """Parse raw header lines, as returned by :meth:`read_raw_headers`, into an ``HttpHeaders`` message.

        ``http.client.parse_headers`` is deliberately not used here. It re-reads the lines and enforces its own fixed
        line-length and header-count limits, which would silently override larger ``max_line`` / ``max_headers``
        settings. The lines have already been size-checked while being read, so they are decoded and passed straight
        to the email parser. That is the same final step ``parse_headers`` performs.
        """
        return email.parser.Parser(_class=HttpHeaders).parsestr(b''.join(raw_headers).decode('iso-8859-1'))

    #

    def coro_parse(self) -> ta.Generator[int, bytes, ParseHttpRequestResult]:
        """Coroutine parsing one request: the request line, then the header block.

        Yields the maximum size it will accept for each line and must be sent that line. :meth:`parse` drives it with a
        blocking ``read_line`` callable. Returns:

          - :class:`EmptyParsedHttpResult` if the request line is empty or whitespace-only;
          - :class:`ParseHttpRequestError` for any of the following:
            - an over-long request line,
            - a bad or unsupported version,
            - bad request-line syntax,
            - a non-GET HTTP/0.9 request,
            - an over-long header line or too many header lines;
          - :class:`ParsedHttpRequest` otherwise.
        """
        raw_request_line = yield self._max_line + 1

        # Common result kwargs

        request_line = '-'
        request_version = self.DEFAULT_REQUEST_VERSION

        # Set to min(server, request) when it gets that far, but if it fails before that the server authoritatively
        # responds with its own version.
        version = self._server_version

        headers: ta.Optional[HttpHeaders] = None

        close_connection = True

        def result_kwargs():
            # Deliberately late-binding: reflects the values the locals above hold when a result is built.
            return dict(
                server_version=self._server_version,
                request_line=request_line,
                request_version=request_version,
                version=version,
                headers=headers,
                close_connection=close_connection,
            )

        # Decode line

        if len(raw_request_line) > self._max_line:
            return ParseHttpRequestError(
                code=http.HTTPStatus.REQUEST_URI_TOO_LONG,
                message='Request line too long',
                **result_kwargs(),
            )

        if not raw_request_line:
            return EmptyParsedHttpResult(**result_kwargs())

        request_line = raw_request_line.decode('iso-8859-1').rstrip('\r\n')

        # Split words

        words = request_line.split()
        if not words:
            return EmptyParsedHttpResult(**result_kwargs())

        # Parse and set version

        if len(words) >= 3:  # Enough to determine protocol version
            version_str = words[-1]
            try:
                request_version = self.parse_request_version(version_str)

            except (ValueError, IndexError):
                return ParseHttpRequestError(
                    code=http.HTTPStatus.BAD_REQUEST,
                    message=f'Bad request version ({version_str!r})',
                    **result_kwargs(),
                )

            if (
                    request_version < HttpProtocolVersions.HTTP_0_9 or
                    request_version >= HttpProtocolVersions.HTTP_2_0
            ):
                return ParseHttpRequestError(
                    code=http.HTTPStatus.HTTP_VERSION_NOT_SUPPORTED,
                    message=f'Invalid HTTP version ({version_str})',
                    **result_kwargs(),
                )

            version = min(self._server_version, request_version)

            # Connections persist by default from HTTP/1.1 on.
            if version >= HttpProtocolVersions.HTTP_1_1:
                close_connection = False

        # Verify word count

        if not 2 <= len(words) <= 3:
            return ParseHttpRequestError(
                code=http.HTTPStatus.BAD_REQUEST,
                message=f'Bad request syntax ({request_line!r})',
                **result_kwargs(),
            )

        # Parse method and path

        method, path = words[:2]
        if len(words) == 2:
            # No version token: an HTTP/0.9 request, which only supports GET and never persists.
            close_connection = True
            if method != 'GET':
                return ParseHttpRequestError(
                    code=http.HTTPStatus.BAD_REQUEST,
                    message=f'Bad HTTP/0.9 request type ({method!r})',
                    **result_kwargs(),
                )

        # gh-87389: The purpose of replacing '//' with '/' is to protect against open redirect attacks possibly
        # triggered if the path starts with '//' because http clients treat //path as an absolute URI without scheme
        # (similar to http://path) rather than a path.
        if path.startswith('//'):
            path = '/' + path.lstrip('/')  # Reduce to a single /

        # Parse headers

        try:
            raw_headers = yield from self.coro_read_raw_headers()
            headers = self.parse_raw_headers(raw_headers)

        # LineTooLong subclasses HTTPException, so it must be handled first.
        except http.client.LineTooLong as err:
            return ParseHttpRequestError(
                code=http.HTTPStatus.REQUEST_HEADER_FIELDS_TOO_LARGE,
                message=('Line too long', str(err)),
                **result_kwargs(),
            )

        except http.client.HTTPException as err:
            return ParseHttpRequestError(
                code=http.HTTPStatus.REQUEST_HEADER_FIELDS_TOO_LARGE,
                message=('Too many headers', str(err)),
                **result_kwargs(),
            )

        # Check for connection directive

        conn_type = headers.get('Connection', '').lower()
        if conn_type == 'close':
            close_connection = True
        elif conn_type == 'keep-alive' and version >= HttpProtocolVersions.HTTP_1_1:
            close_connection = False

        # Check for expect directive

        expects_continue = (
            headers.get('Expect', '').lower() == '100-continue' and
            version >= HttpProtocolVersions.HTTP_1_1
        )

        # Return

        return ParsedHttpRequest(
            method=method,
            path=path,
            expects_continue=expects_continue,
            **result_kwargs(),
        )

    def parse(self, read_line: ta.Callable[[int], bytes]) -> ParseHttpRequestResult:
        """Blocking counterpart of :meth:`coro_parse`.

        :param read_line: Called with the maximum number of bytes wanted. It should return the next line with its
            terminator, or ``b''`` (presumably at end of input), in the manner of a file object's ``readline(size)``.
        """
        return self._run_read_line_coro(self.coro_parse(), read_line)
@@ -0,0 +1,17 @@
1
+ # ruff: noqa: UP006 UP007
2
+ import typing as ta
3
+
4
+
5
class HttpProtocolVersion(ta.NamedTuple):
    """An HTTP protocol version as a ``(major, minor)`` pair.

    As a tuple it compares field by field, so ``(2, 4) < (2, 13) < (12, 3)``.
    """

    major: int
    minor: int

    def __str__(self) -> str:
        # Rendered in request-line form, e.g. ``HTTP/1.1``.
        major, minor = self
        return f'HTTP/{major}.{minor}'


class HttpProtocolVersions:
    """Namespace of well-known :class:`HttpProtocolVersion` constants."""

    HTTP_0_9 = HttpProtocolVersion(major=0, minor=9)
    HTTP_1_0 = HttpProtocolVersion(major=1, minor=0)
    HTTP_1_1 = HttpProtocolVersion(major=1, minor=1)
    HTTP_2_0 = HttpProtocolVersion(major=2, minor=0)