toms-fast 0.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- toms_fast-0.2.1.dist-info/METADATA +467 -0
- toms_fast-0.2.1.dist-info/RECORD +60 -0
- toms_fast-0.2.1.dist-info/WHEEL +4 -0
- toms_fast-0.2.1.dist-info/entry_points.txt +2 -0
- tomskit/__init__.py +0 -0
- tomskit/celery/README.md +693 -0
- tomskit/celery/__init__.py +4 -0
- tomskit/celery/celery.py +306 -0
- tomskit/celery/config.py +377 -0
- tomskit/cli/__init__.py +207 -0
- tomskit/cli/__main__.py +8 -0
- tomskit/cli/scaffold.py +123 -0
- tomskit/cli/templates/__init__.py +42 -0
- tomskit/cli/templates/base.py +348 -0
- tomskit/cli/templates/celery.py +101 -0
- tomskit/cli/templates/extensions.py +213 -0
- tomskit/cli/templates/fastapi.py +400 -0
- tomskit/cli/templates/migrations.py +281 -0
- tomskit/cli/templates_config.py +122 -0
- tomskit/logger/README.md +466 -0
- tomskit/logger/__init__.py +4 -0
- tomskit/logger/config.py +106 -0
- tomskit/logger/logger.py +290 -0
- tomskit/py.typed +0 -0
- tomskit/redis/README.md +462 -0
- tomskit/redis/__init__.py +6 -0
- tomskit/redis/config.py +85 -0
- tomskit/redis/redis_pool.py +87 -0
- tomskit/redis/redis_sync.py +66 -0
- tomskit/server/__init__.py +47 -0
- tomskit/server/config.py +117 -0
- tomskit/server/exceptions.py +412 -0
- tomskit/server/middleware.py +371 -0
- tomskit/server/parser.py +312 -0
- tomskit/server/resource.py +464 -0
- tomskit/server/server.py +276 -0
- tomskit/server/type.py +263 -0
- tomskit/sqlalchemy/README.md +590 -0
- tomskit/sqlalchemy/__init__.py +20 -0
- tomskit/sqlalchemy/config.py +125 -0
- tomskit/sqlalchemy/database.py +125 -0
- tomskit/sqlalchemy/pagination.py +359 -0
- tomskit/sqlalchemy/property.py +19 -0
- tomskit/sqlalchemy/sqlalchemy.py +131 -0
- tomskit/sqlalchemy/types.py +32 -0
- tomskit/task/README.md +67 -0
- tomskit/task/__init__.py +4 -0
- tomskit/task/task_manager.py +124 -0
- tomskit/tools/README.md +63 -0
- tomskit/tools/__init__.py +18 -0
- tomskit/tools/config.py +70 -0
- tomskit/tools/warnings.py +37 -0
- tomskit/tools/woker.py +81 -0
- tomskit/utils/README.md +666 -0
- tomskit/utils/README_SERIALIZER.md +644 -0
- tomskit/utils/__init__.py +35 -0
- tomskit/utils/fields.py +434 -0
- tomskit/utils/marshal_utils.py +137 -0
- tomskit/utils/response_utils.py +13 -0
- tomskit/utils/serializers.py +447 -0
|
@@ -0,0 +1,371 @@
|
|
|
1
|
+
"""
|
|
2
|
+
FastAPI middleware for request ID tracking and resource cleanup.
|
|
3
|
+
|
|
4
|
+
This module provides middleware to:
|
|
5
|
+
1. Automatically handle X-Request-ID headers
|
|
6
|
+
2. Clean up resources (database sessions, Redis connections, etc.) after request completion
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
import asyncio
|
|
10
|
+
import logging
|
|
11
|
+
import uuid
|
|
12
|
+
from abc import ABC, abstractmethod
|
|
13
|
+
from typing import Any, Optional
|
|
14
|
+
|
|
15
|
+
from fastapi import Request
|
|
16
|
+
from starlette.middleware.base import BaseHTTPMiddleware
|
|
17
|
+
from starlette.types import ASGIApp, Receive, Scope, Send
|
|
18
|
+
|
|
19
|
+
from tomskit.logger import set_app_trace_id
|
|
20
|
+
|
|
21
|
+
# Module-level logger named after this module; used by the middleware below to
# report application-handler exceptions and cleanup-strategy failures/timeouts.
logger = logging.getLogger(__name__)
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class RequestIDMiddleware(BaseHTTPMiddleware):
    """
    Propagate a per-request correlation ID through logs and responses.

    For every request this middleware:
    1. Takes the ID from the configured request header when the client sent one.
    2. Otherwise creates a random UUID4 (unless generation is disabled).
    3. Registers the ID as the logging trace ID via ``set_app_trace_id()``.
    4. Echoes the ID back in the same response header (unless disabled).

    Sharing one ID across services makes it possible to correlate log lines
    belonging to the same request.

    Example:
        from tomskit.server import FastApp, RequestIDMiddleware

        app = FastApp()
        app.add_middleware(RequestIDMiddleware)
    """

    def __init__(
        self,
        app: ASGIApp,
        header_name: str = "X-Request-ID",
        include_in_response: bool = True,
        generate_on_missing: bool = True,
    ):
        """
        Configure the middleware.

        Args:
            app: The wrapped ASGI application.
            header_name: Header used both to read and to write the ID.
                Default: "X-Request-ID".
            include_in_response: Echo the ID in the response headers. Default: True.
            generate_on_missing: Create a UUID4 when the request carries no ID.
                Default: True.
        """
        super().__init__(app)
        self.header_name = header_name
        self.include_in_response = include_in_response
        self.generate_on_missing = generate_on_missing

    def _resolve_request_id(self, request: Request) -> Optional[str]:
        """
        Return the ID for *request*.

        The client-supplied header value wins when it is non-empty. A fresh
        UUID4 is returned when it is missing and generation is enabled. In every
        other case the (empty or absent) header value is returned unchanged.
        """
        supplied = request.headers.get(self.header_name)
        if supplied:
            return supplied
        if self.generate_on_missing:
            return str(uuid.uuid4())
        return supplied

    async def dispatch(self, request: Request, call_next):
        """
        Attach the request ID to the logging context and to the response.

        Args:
            request: The incoming request.
            call_next: Callable that runs the rest of the middleware chain.

        Returns:
            The downstream response, with the ID header added when enabled and
            an ID is available.
        """
        request_id = self._resolve_request_id(request)

        # An empty/absent ID is neither logged nor echoed.
        if request_id:
            set_app_trace_id(request_id)

        response = await call_next(request)

        should_echo = bool(request_id) and self.include_in_response
        if should_echo:
            response.headers[self.header_name] = request_id

        return response
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
# ============================================================================
|
|
97
|
+
# Resource Cleanup Middleware (Strategy Pattern)
|
|
98
|
+
# ============================================================================
|
|
99
|
+
|
|
100
|
+
class CleanupStrategy(ABC):
    """
    Interface for one kind of per-request resource cleanup.

    Concrete strategies each release a single resource type (for example a
    database session or a Redis client). ``ResourceCleanupMiddleware`` invokes
    every registered strategy once a request's response has finished.
    """

    @abstractmethod
    async def cleanup(self, state: dict[str, Any]) -> None:
        """
        Release the resources this strategy is responsible for.

        Args:
            state: The ASGI scope ``state`` mapping, holding whatever resource
                handles were stored there while the request was processed.

        Raises:
            Exception: Implementations are expected to catch and log their own
                errors instead of letting them propagate.
        """
        ...

    @property
    @abstractmethod
    def resource_name(self) -> str:
        """
        Human-readable resource label used in log messages.

        Returns:
            A short name such as "database_session" or "redis_client".
        """
        ...
|
|
131
|
+
|
|
132
|
+
|
|
133
|
+
class ResourceCleanupMiddleware:
    """
    Resource cleanup middleware.

    Uses the strategy pattern to manage cleanup of multiple resources. Resources
    created during a request (such as database sessions or Redis connections) are
    cleaned up after the HTTP response completes, to prevent resource leaks.

    How it works:
    1. Wraps the ASGI ``send`` callable and forwards every message directly to the
       server, so server backpressure and send errors reach the application.
    2. Detects response completion (``http.response.body`` with ``more_body=False``)
       and starts all registered cleanup strategies right after that message is sent.
    3. If the application returns or raises without completing the response
       (unhandled exception, client disconnect during streaming, cancellation),
       cleanup is started when the application call ends instead.
    4. Cleanup runs at most once per request. It runs in its own task, so it is
       not interrupted by cancel scopes inside the application (e.g. streaming
       responses). The middleware call does not return before cleanup has finished,
       unless the call itself is cancelled, in which case cleanup keeps running in
       the background.

    Features:
    - Supports cleanup of multiple resource types (via strategy pattern)
    - Cleanup errors and timeouts are logged and never propagated
    - Supports streaming responses (Server-Sent Events, large file downloads, etc.)
    - Application exceptions are logged and re-raised unchanged, so outer error
      handlers (or the ASGI server) still produce the error response

    Example:
        from tomskit.server import FastApp, ResourceCleanupMiddleware
        from tomskit.sqlalchemy.database import DatabaseCleanupStrategy
        from tomskit.redis.redis_pool import RedisCleanupStrategy

        app = FastApp()

        # Use the database cleanup strategy
        app.add_middleware(
            ResourceCleanupMiddleware,
            strategies=[DatabaseCleanupStrategy()]
        )

        # Or use several strategies
        app.add_middleware(
            ResourceCleanupMiddleware,
            strategies=[
                DatabaseCleanupStrategy(),
                RedisCleanupStrategy(),
            ]
        )

        # If no strategies provided, no cleanup will be performed
        app.add_middleware(ResourceCleanupMiddleware)
    """

    def __init__(
        self,
        app: ASGIApp,
        strategies: Optional[list[CleanupStrategy]] = None,
        cleanup_timeout: float = 5.0,
    ):
        """
        Initialize resource cleanup middleware.

        Args:
            app: ASGI application instance
            strategies: List of cleanup strategies. If None or empty, no cleanup will be performed
            cleanup_timeout: Timeout for each cleanup strategy in seconds, default 5.0
        """
        self.app = app
        self.cleanup_timeout = cleanup_timeout

        # If no strategies provided, use empty list (no cleanup will be performed)
        if strategies is None:
            self.strategies: list[CleanupStrategy] = []
        else:
            self.strategies = strategies

        # Strong references to in-flight cleanup tasks. The event loop only keeps
        # weak references to tasks, and a cleanup may outlive its request when the
        # request is cancelled, so the task is held here until it finishes.
        self._cleanup_tasks: set[asyncio.Task[None]] = set()

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """
        Handle ASGI request.

        Args:
            scope: ASGI scope dictionary
            receive: ASGI receive callable
            send: ASGI send callable
        """
        # Only HTTP requests are handled; with no strategies there is nothing to clean up.
        if scope['type'] != 'http' or not self.strategies:
            await self.app(scope, receive, send)
            return

        # Get state (shared with request.state). Create an empty dict when the
        # server/framework has not provided one.
        if 'state' not in scope:
            scope['state'] = {}
        state = scope['state']

        cleanup_task: Optional[asyncio.Task[None]] = None

        def start_cleanup() -> asyncio.Task[None]:
            """Start the cleanup task on first call; return the same task afterwards."""
            nonlocal cleanup_task
            if cleanup_task is None:
                # A separate task is used so that cancel scopes inside the app
                # (e.g. a streaming response's task group) cannot abort cleanup.
                cleanup_task = asyncio.create_task(self._execute_cleanup(state))
                self._cleanup_tasks.add(cleanup_task)
                cleanup_task.add_done_callback(self._cleanup_tasks.discard)
            return cleanup_task

        async def send_wrapper(message: dict[str, Any]) -> None:  # type: ignore
            """Forward *message* to the server; start cleanup after the final body chunk."""
            # Await the real send so backpressure and send errors reach the app.
            await send(message)
            if (
                message['type'] == 'http.response.body'
                and not message.get('more_body', False)
            ):
                start_cleanup()

        try:
            await self.app(scope, receive, send_wrapper)  # type: ignore[arg-type]
        except Exception as e:
            logger.error(f"Application handler exception: {e}", exc_info=True)
            # Re-raise so outer error handling (or the server) builds the response.
            raise
        finally:
            # Also covers responses that never completed (exceptions, client
            # disconnects, cancellation). shield() keeps the cleanup running even
            # if this call is cancelled while waiting for it.
            await asyncio.shield(start_cleanup())

    async def _execute_cleanup(self, state: dict[str, Any]) -> None:
        """
        Execute all registered cleanup strategies concurrently.

        Args:
            state: ASGI scope state dictionary
        """
        if not self.strategies:
            return

        # Execute all cleanup strategies concurrently, each with its own timeout
        cleanup_tasks = [
            self._cleanup_with_timeout(strategy, state)
            for strategy in self.strategies
        ]

        results = await asyncio.gather(*cleanup_tasks, return_exceptions=True)

        # _cleanup_with_timeout already handles Exception subclasses, so anything
        # returned here is a BaseException (e.g. CancelledError) that stopped a strategy.
        for strategy, result in zip(self.strategies, results):
            if isinstance(result, BaseException):
                logger.error(
                    f"Cleanup strategy '{strategy.resource_name}' failed: {result}",
                    exc_info=result,
                )

    async def _cleanup_with_timeout(
        self,
        strategy: CleanupStrategy,
        state: dict[str, Any],
    ) -> None:
        """
        Execute a single cleanup strategy with timeout protection.

        Timeouts are logged as warnings and other exceptions as errors; neither
        is propagated.

        Args:
            strategy: Cleanup strategy instance
            state: ASGI scope state dictionary
        """
        try:
            await asyncio.wait_for(
                strategy.cleanup(state),
                timeout=self.cleanup_timeout,
            )
        except asyncio.TimeoutError:
            logger.warning(
                f"Cleanup strategy '{strategy.resource_name}' timed out "
                f"(timeout: {self.cleanup_timeout}s)"
            )
        except Exception as e:
            # Exceptions from cleanup strategies should not be raised
            logger.error(
                f"Cleanup strategy '{strategy.resource_name}' exception: {e}",
                exc_info=True,
            )
|
tomskit/server/parser.py
ADDED
|
@@ -0,0 +1,312 @@
|
|
|
1
|
+
from typing import (
|
|
2
|
+
Any, Dict, List, Literal, Mapping, Optional, Type, get_args, get_origin
|
|
3
|
+
)
|
|
4
|
+
|
|
5
|
+
from fastapi import HTTPException, Request, UploadFile
|
|
6
|
+
from starlette.datastructures import FormData
|
|
7
|
+
from pydantic import (
|
|
8
|
+
BaseModel, ConfigDict, Field, ValidationError, create_model, TypeAdapter
|
|
9
|
+
)
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class ParserModel(BaseModel):
    """
    Base model that allows extra fields and dictionary-style access.

    Instances support ``obj[key]``, ``obj[key] = value``, ``obj.get(key)`` and
    ``key in obj`` in addition to attribute access. Dictionary-style reads and
    membership tests only see data: declared fields and extra fields, never
    methods or Pydantic internals. The key ``"model_config"`` is redirected to
    the ``model_config_field`` attribute, because ``model_config`` is reserved by
    Pydantic for model configuration.
    """
    model_config = ConfigDict(
        extra="allow",
        populate_by_name=True,
        arbitrary_types_allowed=True,  # allow arbitrary classes (e.g. UploadFile) as field types
    )

    def _resolve_key(self, key: str) -> str:
        """Map a public key to the attribute name the value is stored under."""
        return "model_config_field" if key == "model_config" else key

    def _has_data_key(self, attr: str) -> bool:
        """Return True if *attr* is a declared field or an extra field of this instance."""
        return attr in type(self).model_fields or attr in (self.model_extra or {})

    def __getitem__(self, key: str) -> Any:
        # Missing keys yield None instead of raising KeyError.
        return self.get(key)

    def __setitem__(self, key: str, value: Any) -> None:
        setattr(self, self._resolve_key(key), value)

    def get(self, key: str, default: Any = None) -> Any:
        """Return the value stored under *key*, or *default* if it is not a data field."""
        attr = self._resolve_key(key)
        if not self._has_data_key(attr):
            return default
        return getattr(self, attr, default)

    def __contains__(self, key: str) -> bool:
        return self._has_data_key(self._resolve_key(key))

    def to_mapping(self) -> Mapping[str, Any]:
        """Dump the model to a plain dict, using field aliases as keys."""
        return self.model_dump(by_alias=True)
|
|
44
|
+
|
|
45
|
+
class RequestParser:
    """
    A Flask-RESTful-like request parser built on FastAPI + Pydantic v2.

    Supports:
    - JSON, query, form, header, cookie, path parameters
    - store/append/count actions
    - required/default/nullable/choices semantics
    - automatic TypeAdapter validation for built-ins, Annotated, custom types
    - UploadFile and List[UploadFile]

    Extraction rules:
    - Query string, form and header sources may repeat a key. ``append`` collects
      every occurrence and ``count`` counts them. ``store`` keeps the last
      occurrence, or the first one for headers and single ``UploadFile``
      arguments. When the argument type is ``List[...]``, ``store`` keeps all
      occurrences instead.
    - Header names are matched case-insensitively.
    - A JSON list given to an ``append`` argument supplies the items directly.
    """
    # Class-level cache: parser_id -> generated model class. Parsers sharing a
    # parser_id share one model, so an id must identify a single argument set.
    _model_cache: Dict[str, Type[ParserModel]] = {}

    # Allowed actions
    ALLOWED_ACTIONS = {"store", "append", "count"}

    # Allowed locations
    ALLOWED_LOCATIONS = {"args", "json", "form", "header", "cookie", "path"}

    # Locations whose source can carry several values for one key.
    _MULTI_VALUE_LOCATIONS = {"args", "form", "header"}

    def __init__(self, parser_id: Optional[str] = None) -> None:
        self._parser_id = parser_id
        self._arg_defs: List[Dict[str, Any]] = []
        self._model_cls: Optional[Type[ParserModel]] = None

    @staticmethod
    def _item_type(arg_type: Any) -> Any:
        """Return ``X`` for ``List[X]`` / ``list[X]`` (``Any`` for bare ``List``), else *arg_type*."""
        if get_origin(arg_type) is list:
            args = get_args(arg_type)
            return args[0] if args else Any
        return arg_type

    def add_argument(
        self,
        name: str,
        arg_type: Any,
        required: bool = False,
        default: Any = None,
        choices: Optional[List[Any]] = None,
        nullable: bool = False,
        location: Literal["args", "json", "form", "header", "cookie", "path"] = "json",
        action: Literal["store", "append", "count"] = "store",
        help: Optional[str] = None,
    ) -> None:
        """
        Register a new argument.

        Args:
            name: Key looked up in the request source (also the model field name).
            arg_type: Any type Pydantic's ``TypeAdapter`` accepts. For ``append``,
                either the item type or ``List[item]``.
            required: Whether the argument must be present (``store`` only).
            default: Value used when absent. For ``append`` it may be a list of items.
            choices: Allowed values (per item for ``append``).
            nullable: Whether ``None`` is an accepted value.
            location: Where in the request to look the argument up.
            action: ``store`` (single value), ``append`` (list) or ``count`` (int).
            help: Description stored on the generated Pydantic field.

        Raises:
            ValueError: on invalid configuration.
        """
        if not name or not isinstance(name, str):
            raise ValueError("RequestParser: `name` must be a non-empty string")
        if action not in self.ALLOWED_ACTIONS:
            raise ValueError(f"RequestParser: `action` must be one of {self.ALLOWED_ACTIONS}")
        if location not in self.ALLOWED_LOCATIONS:
            raise ValueError(f"RequestParser: `location` must be one of {self.ALLOWED_LOCATIONS}")
        if choices is not None and not isinstance(choices, list):
            raise ValueError("RequestParser: `choices` must be a list if provided")
        if required and default is not None:
            raise ValueError("RequestParser: `default` must be None when `required=True`")
        if not required and default is None and not nullable:
            raise ValueError("RequestParser: `default=None` with `required=False` requires `nullable=True`")

        # For "append", a list default lists individual items; otherwise it is one value.
        list_default = action == "append" and isinstance(default, list)
        default_items = default if list_default else [default]
        if choices and default is not None and any(item not in choices for item in default_items):
            raise ValueError("RequestParser: `default` must be one of `choices`")

        # Validate that arg_type (and, for append, its item type) is supported by Pydantic
        item_type = self._item_type(arg_type) if action == "append" else arg_type
        try:
            adapter = TypeAdapter(arg_type)
            item_adapter = adapter if item_type is arg_type else TypeAdapter(item_type)
        except Exception as e:
            raise ValueError(f"RequestParser: Invalid arg_type for '{name}': {e}") from e

        # Validate default value with the matching adapter
        if default is not None:
            try:
                if list_default:
                    TypeAdapter(List[item_type]).validate_python(default)  # type: ignore[valid-type]
                else:
                    adapter.validate_python(default)
            except ValidationError as e:
                raise ValueError(f"RequestParser: Invalid default for '{name}': {e}") from e

        # Each choice must be a valid single item
        for choice in (choices or []):
            try:
                item_adapter.validate_python(choice)
            except ValidationError as e:
                raise ValueError(f"RequestParser: Invalid choice {choice!r} for '{name}': {e}") from e

        self._arg_defs.append({
            "name": name,
            "type": arg_type,
            "required": required,
            "default": default,
            "choices": choices or [],
            "nullable": nullable,
            "location": location,
            "action": action,
            "help": help,
        })
        # Rebuild the model on the next parse so the new argument is included
        # (a model cached under this parser_id still takes precedence).
        self._model_cls = None

    def _build_model(self) -> Type[ParserModel]:
        """
        Dynamically create a Pydantic model class based on registered arguments.

        Raises:
            RuntimeError: if Pydantic rejects the generated field definitions.
        """
        if self._parser_id is not None:
            cached = RequestParser._model_cache.get(self._parser_id)
            if cached is not None:
                return cached

        fields: Dict[str, Any] = {}

        for d in self._arg_defs:
            name = d["name"]
            action = d["action"]
            choices = d["choices"]

            # Determine annotation and default. Choices constrain a single value,
            # so they become a Literal *before* List/Optional wrapping.
            ann: Any
            if action == "count":
                ann = int
                default: Any = 0
            elif action == "append":
                item_ann = Literal[tuple(choices)] if choices else self._item_type(d["type"])  # type: ignore[valid-type]
                ann = List[item_ann]  # type: ignore[valid-type]
                default = d["default"] if d["default"] is not None else []
            else:  # store
                ann = Literal[tuple(choices)] if choices else d["type"]  # type: ignore[valid-type]
                default = ... if d["required"] and d["default"] is None else d["default"]

            # nullable -> Optional
            if d["nullable"]:
                ann = Optional[ann]  # type: ignore

            # "model_config" is reserved by Pydantic: store it under another name with an alias.
            internal_name = name
            field_kwargs: Dict[str, Any] = {}
            if name == "model_config":
                internal_name = "model_config_field"
                field_kwargs["alias"] = "model_config"

            if d["help"]:
                field_kwargs["description"] = d["help"]

            fields[internal_name] = (ann, Field(default, **field_kwargs))

        try:
            cached_model = create_model("ParsedModel", __base__=ParserModel, **fields)
        except Exception as e:
            raise RuntimeError(f"RequestParser: Failed to create parser model: {e}") from e
        if self._parser_id is not None:
            RequestParser._model_cache[self._parser_id] = cached_model
        return cached_model  # type: ignore

    async def parse_args(self, request: Request) -> ParserModel:
        """
        Extract raw values from Request by location, merge per action,
        then validate & coerce via the generated Pydantic model.

        Only arguments actually present in the request are passed to the model,
        so defaults and "missing field" errors come from Pydantic.

        Raises:
            HTTPException: status 422 with ``{"errors": [...]}`` on validation errors.
        """
        if self._model_cls is None:
            self._model_cls = self._build_model()

        locations = {d["location"] for d in self._arg_defs}

        # 1) Content-Type decides whether the body is a form
        content_type = request.headers.get("content-type", "")
        is_form = content_type.startswith(
            ("multipart/form-data", "application/x-www-form-urlencoded")
        )

        form: FormData = FormData()
        if is_form and "form" in locations:
            try:
                form = await request.form()
            except Exception:
                # An unparseable form body is treated as empty.
                form = FormData()

        # 2) JSON body: only attempted when the body is not a form and some
        #    argument reads JSON. A non-object JSON body is treated as empty.
        json_body: Dict[str, Any] = {}
        if not is_form and "json" in locations:
            try:
                body = await request.json()
            except Exception:
                body = None
            if isinstance(body, dict):
                json_body = body

        # 3) Extract & merge; only arguments present in the request go into `raw`
        raw: Dict[str, Any] = {}
        for d in self._arg_defs:
            values = self._source_values(d, request, form, json_body)
            if values is not None:
                raw[d["name"]] = self._combine(d, values)

        # 4) Let Pydantic fill defaults & validate
        try:
            return self._model_cls.model_validate(raw)  # type: ignore
        except ValidationError as e:
            # Format errors so the offending field and reason are easy to locate
            error_messages = []
            for err in e.errors():
                field_loc = ".".join(str(x) for x in err.get("loc", []))
                msg = err.get("msg", "")
                err_type = err.get("type", "")
                inp = err.get("input", None)
                error_messages.append(
                    f"RequestParser: Field '{field_loc}': {msg} (type={err_type}, input={inp!r})"
                )
            raise HTTPException(status_code=422, detail={"errors": error_messages}) from e

    def _source_values(
        self,
        d: Dict[str, Any],
        request: Request,
        form: FormData,
        json_body: Mapping[str, Any],
    ) -> Optional[List[Any]]:
        """
        Return every raw value supplied for argument *d*, or None when absent.

        Query, form and header sources return all repeated values (header names
        are matched case-insensitively by Starlette's ``Headers``). JSON, cookie
        and path sources yield one value, except that a JSON list supplied to an
        ``append`` argument yields its items.
        """
        key = d["name"]
        loc = d["location"]
        if loc == "args":
            return request.query_params.getlist(key) or None
        if loc == "form":
            return form.getlist(key) or None  # type: ignore[attr-defined]
        if loc == "header":
            return request.headers.getlist(key) or None

        source: Mapping[str, Any]
        if loc == "json":
            source = json_body
        elif loc == "cookie":
            source = request.cookies
        else:  # path
            source = request.path_params
        if key not in source:
            return None
        value = source[key]
        if loc == "json" and d["action"] == "append" and isinstance(value, list):
            return list(value)
        return [value]

    def _combine(self, d: Dict[str, Any], values: List[Any]) -> Any:
        """Reduce the raw *values* of argument *d* to one value according to its action."""
        action = d["action"]
        if action == "store":
            if get_origin(d["type"]) is list and d["location"] in self._MULTI_VALUE_LOCATIONS:
                # List-typed argument from a repeatable source: keep every occurrence.
                return list(values)
            if d["location"] == "header" or d["type"] is UploadFile:
                # Headers and single-file uploads keep the first occurrence.
                return values[0]
        merged: Any = [] if action == "append" else 0 if action == "count" else None
        for value in values:
            merged = self._merge(merged, value, action)
        return merged

    @staticmethod
    def _merge(prev: Any, new: Any, action: str) -> Any:
        """Merge raw values according to action (store: replace, append: collect, count: +1)."""
        if action == "store":
            return new
        if action == "append":
            lst = prev if isinstance(prev, list) else []
            lst.append(new)
            return lst
        # count
        return (prev or 0) + 1
|