fastapi-extra 0.3.5__tar.gz → 0.4.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {fastapi_extra-0.3.5 → fastapi_extra-0.4.0}/PKG-INFO +1 -1
- {fastapi_extra-0.3.5 → fastapi_extra-0.4.0}/fastapi_extra/__init__.py +1 -1
- {fastapi_extra-0.3.5 → fastapi_extra-0.4.0}/fastapi_extra/database/service.py +17 -2
- fastapi_extra-0.4.0/fastapi_extra/database/sqlmap.py +154 -0
- {fastapi_extra-0.3.5 → fastapi_extra-0.4.0}/fastapi_extra/native/cursor.pyx +20 -8
- fastapi_extra-0.4.0/fastapi_extra/native/routing.pyx +196 -0
- fastapi_extra-0.4.0/fastapi_extra/native/urlparse.pyx +88 -0
- {fastapi_extra-0.3.5 → fastapi_extra-0.4.0}/fastapi_extra/response.py +1 -3
- {fastapi_extra-0.3.5 → fastapi_extra-0.4.0}/fastapi_extra/types.py +3 -1
- {fastapi_extra-0.3.5 → fastapi_extra-0.4.0}/fastapi_extra.egg-info/PKG-INFO +1 -1
- {fastapi_extra-0.3.5 → fastapi_extra-0.4.0}/fastapi_extra.egg-info/SOURCES.txt +1 -0
- fastapi_extra-0.3.5/fastapi_extra/native/routing.pyx +0 -191
- fastapi_extra-0.3.5/fastapi_extra/native/urlparse.pyx +0 -50
- {fastapi_extra-0.3.5 → fastapi_extra-0.4.0}/LICENSE +0 -0
- {fastapi_extra-0.3.5 → fastapi_extra-0.4.0}/README.rst +0 -0
- {fastapi_extra-0.3.5 → fastapi_extra-0.4.0}/fastapi_extra/_patch.py +0 -0
- {fastapi_extra-0.3.5 → fastapi_extra-0.4.0}/fastapi_extra/cache/__init__.py +0 -0
- {fastapi_extra-0.3.5 → fastapi_extra-0.4.0}/fastapi_extra/cache/redis.py +0 -0
- {fastapi_extra-0.3.5 → fastapi_extra-0.4.0}/fastapi_extra/cursor.pyi +0 -0
- {fastapi_extra-0.3.5 → fastapi_extra-0.4.0}/fastapi_extra/database/__init__.py +0 -0
- {fastapi_extra-0.3.5 → fastapi_extra-0.4.0}/fastapi_extra/database/model.py +0 -0
- {fastapi_extra-0.3.5 → fastapi_extra-0.4.0}/fastapi_extra/database/session.py +0 -0
- {fastapi_extra-0.3.5 → fastapi_extra-0.4.0}/fastapi_extra/dependency.py +0 -0
- {fastapi_extra-0.3.5 → fastapi_extra-0.4.0}/fastapi_extra/form.py +0 -0
- {fastapi_extra-0.3.5 → fastapi_extra-0.4.0}/fastapi_extra/py.typed +0 -0
- {fastapi_extra-0.3.5 → fastapi_extra-0.4.0}/fastapi_extra/routing.pyi +0 -0
- {fastapi_extra-0.3.5 → fastapi_extra-0.4.0}/fastapi_extra/settings.py +0 -0
- {fastapi_extra-0.3.5 → fastapi_extra-0.4.0}/fastapi_extra/urlparse.pyi +0 -0
- {fastapi_extra-0.3.5 → fastapi_extra-0.4.0}/fastapi_extra/utils.py +0 -0
- {fastapi_extra-0.3.5 → fastapi_extra-0.4.0}/fastapi_extra.egg-info/dependency_links.txt +0 -0
- {fastapi_extra-0.3.5 → fastapi_extra-0.4.0}/fastapi_extra.egg-info/requires.txt +0 -0
- {fastapi_extra-0.3.5 → fastapi_extra-0.4.0}/fastapi_extra.egg-info/top_level.txt +0 -0
- {fastapi_extra-0.3.5 → fastapi_extra-0.4.0}/pyproject.toml +0 -0
- {fastapi_extra-0.3.5 → fastapi_extra-0.4.0}/setup.cfg +0 -0
|
@@ -2,13 +2,18 @@ __author__ = "ziyan.yin"
|
|
|
2
2
|
__date__ = "2025-01-12"
|
|
3
3
|
|
|
4
4
|
from contextvars import ContextVar
|
|
5
|
-
from typing import Any, Generic, TypeVar
|
|
5
|
+
from typing import Any, Generic, Sequence, TypeVar
|
|
6
|
+
|
|
7
|
+
from sqlalchemy import ColumnExpressionArgument
|
|
8
|
+
from sqlmodel import select
|
|
6
9
|
|
|
7
10
|
from fastapi_extra.database.model import SQLModel
|
|
8
11
|
from fastapi_extra.database.session import AsyncSession, DefaultSession
|
|
9
12
|
from fastapi_extra.dependency import AbstractService
|
|
10
13
|
|
|
11
14
|
# Type variable for the SQLModel class that ModelService is generic over.
Model = TypeVar("Model", bound=SQLModel)
|
|
15
|
+
ID = int | str
|
|
16
|
+
PK = ID | tuple[ID] | dict[str, ID]
|
|
12
17
|
|
|
13
18
|
|
|
14
19
|
class ModelService(AbstractService, Generic[Model], abstract=True):
|
|
@@ -40,14 +45,24 @@ class ModelService(AbstractService, Generic[Model], abstract=True):
|
|
|
40
45
|
assert _session is not None, "Session is not initialized"
|
|
41
46
|
return _session
|
|
42
47
|
|
|
43
|
-
async def get(self, ident: PK, **kwargs: Any) -> Model | None:
    """Load one ``__model__`` row by primary key.

    Extra keyword arguments are forwarded unchanged to ``session.get``;
    the result is whatever that call returns (``None`` when absent).
    """
    session = self.session
    return await session.get(self.__model__, ident, **kwargs)
|
|
50
|
+
|
|
51
|
+
async def get_list(self, *clause: ColumnExpressionArgument[bool] | bool) -> Sequence[Model]:
    """Return all ``__model__`` rows satisfying every criterion in *clause*."""
    session = self.session
    statement = select(self.__model__).where(*clause)
    result = await session.exec(statement)
    return result.all()
|
|
45
53
|
|
|
46
54
|
async def create_model(self, **kwargs: Any) -> Model:
    """Build a ``__model__`` from *kwargs*, add it to the session and flush.

    The instance is validated with ``model_validate`` first; no commit is
    issued here. Returns the new (flushed) instance.
    """
    instance = self.__model__.model_validate(kwargs)
    self.session.add(instance)
    await self.session.flush()
    return instance
|
|
59
|
+
|
|
60
|
+
async def update_model(self, model: Model, **kwargs: Any) -> Model:
    """Assign *kwargs* onto *model* and flush the session.

    Keys that are not declared fields of the model's class are ignored.
    Returns the same, now mutated, instance.
    """
    model_cls = type(model)
    # Pydantic v2 exposes declared fields as ``model_fields`` on the class;
    # ``__fields__`` is the v1 name and only a deprecated (warning) alias in v2.
    fields = getattr(model_cls, "model_fields", None)
    if fields is None:
        fields = model_cls.__fields__
    for key, value in kwargs.items():
        if key in fields:
            setattr(model, key, value)
    await self.session.flush()
    return model
|
|
51
66
|
|
|
52
67
|
async def delete(self, model: Model) -> None:
    """Mark *model* for deletion in the current session (no flush or commit here)."""
    await self.session.delete(model)
|
|
@@ -0,0 +1,154 @@
|
|
|
1
|
+
__author__ = "ziyan.yin"
|
|
2
|
+
__date__ = "2026-05-14"
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
import inspect
|
|
6
|
+
from collections.abc import Sequence
|
|
7
|
+
from pathlib import Path
|
|
8
|
+
from typing import (Any, Callable, Protocol, get_args, get_origin,
|
|
9
|
+
get_type_hints)
|
|
10
|
+
|
|
11
|
+
from pydantic import BaseModel, Field
|
|
12
|
+
from sqlmodel import text
|
|
13
|
+
from sqlmodel.ext.asyncio.session import AsyncSession
|
|
14
|
+
|
|
15
|
+
from fastapi_extra.settings import Settings
|
|
16
|
+
from fastapi_extra.types import NoneType, P, T
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class SQLTemplateConfig(BaseModel):
    """Location and file suffix of the SQL template files."""

    # Directory scanned by install().
    path: str = "./template/sql"
    # Only files whose name ends with this suffix are loaded.
    suffix: str = ".sql"
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class SQLTemplateSettings(Settings):
    """Settings section ``sqlmap`` configuring SQL template loading."""

    # default_factory: every SQLTemplateConfig field has a default, so the
    # section is optional and instantiating these settings at import time
    # no longer fails when no ``sqlmap`` configuration is present.
    sqlmap: SQLTemplateConfig = Field(default_factory=SQLTemplateConfig, alias="sqlmap")
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
_settings = SQLTemplateSettings()  # type: ignore
# Template name (file stem) -> compiled ``text()`` clause; filled by install().
_templates: dict[str, Any] = dict()
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
def install(path: str | Path | None = None, suffix: str | None = None) -> None:
    """Load SQL template files into the module-level ``_templates`` registry.

    Every regular file ``<dir>/<name><suffix>`` is read as UTF-8, wrapped in
    ``text()`` and stored under ``<name>`` (the file stem), which is the key
    ``Mapped`` looks up by method name. A missing directory is not an error:
    nothing is loaded.

    Args:
        path: template directory; defaults to ``_settings.sqlmap.path``.
        suffix: template file suffix; defaults to ``_settings.sqlmap.suffix``.
    """
    template_dir = Path(_settings.sqlmap.path if path is None else path)
    template_suffix = _settings.sqlmap.suffix if suffix is None else suffix
    if not template_dir.is_dir():
        return
    for file in sorted(template_dir.glob(f"*{template_suffix}")):
        # A sub-directory matching the pattern would make read_text() fail.
        if file.is_file():
            _templates[file.stem] = text(file.read_text(encoding="utf-8"))
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
def _is_base_model(cls: Any) -> bool:
|
|
41
|
+
"""安全判断是否为 Pydantic BaseModel"""
|
|
42
|
+
try:
|
|
43
|
+
return isinstance(cls, type) and issubclass(cls, BaseModel)
|
|
44
|
+
except TypeError:
|
|
45
|
+
return False
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
class SessionService(Protocol):
    """Structural type of the objects ``Mapped`` wrappers are bound to: they expose ``session``."""

    session: AsyncSession
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
def _bind_sql_params(
    sig: inspect.Signature,
    service: Any,
    args: tuple[Any, ...],
    kwargs: dict[str, Any],
) -> dict[str, Any]:
    """Bind one call of a ``Mapped`` method and return its SQL parameters.

    The call is bound against *sig* with defaults applied; the first bound
    parameter (the service instance) is dropped and every remaining argument
    becomes a named bind parameter.

    Raises:
        TypeError: if the arguments do not fit *sig*.
    """
    bound = sig.bind(service, *args, **kwargs)
    bound.apply_defaults()
    return dict(list(bound.arguments.items())[1:])


def _get_template(func_name: str) -> Any:
    """Return the compiled SQL template registered under *func_name*.

    Raises:
        KeyError: when no template with that name was loaded by ``install()``.
    """
    try:
        return _templates[func_name]
    except KeyError:
        raise KeyError(
            f"SQL template {func_name!r} is not loaded; "
            "check the template directory and call install() first"
        ) from None


def Mapped(func: Callable[P, T]) -> Callable[P, T]:
    """Replace a method stub with a call of the SQL template named after it.

    The decorated function's body is never run. On each call its arguments
    (minus the first, the service instance) are bound as named parameters to
    ``_templates[func.__name__]``, which is executed on ``self.session``.
    The return annotation selects the result mapping:

    * ``None`` or no annotation: execute only, return ``None``.
    * a ``Sequence`` type such as ``list[X]``: all rows; each row is
      validated into ``X`` when ``X`` is a Pydantic model, otherwise the
      first column of every row is returned.
    * anything else (``X``, ``X | None``, ``Optional[X]``): one row (model)
      or one scalar. ``None`` is returned only when the annotation allows
      it; otherwise ``ValueError`` is raised.

    Templates are looked up at call time, so ``install()`` may run after
    decoration but must have run before the first call.
    """
    from functools import wraps

    func_name = func.__name__
    sig = inspect.signature(func)
    return_hint = get_type_hints(func).get("return", NoneType)
    origin = get_origin(return_hint)
    # Named type_args so the wrappers' own *args parameter does not shadow it.
    type_args = get_args(return_hint)

    # No result wanted: execute and discard.
    if return_hint is NoneType:

        @wraps(func)
        async def execute_only(self: SessionService, *args: Any, **kwargs: Any) -> None:
            sql_params = _bind_sql_params(sig, self, args, kwargs)
            await self.session.execute(_get_template(func_name), params=sql_params)

        execute_only.__signature__ = sig  # type: ignore[attr-defined]
        return execute_only  # type: ignore

    # Sequence results (list[X], Sequence[X], ...). The isinstance() guard is
    # required: origins such as typing.Union (from Optional[X]) or Literal are
    # not classes, and issubclass() raises TypeError for them.
    if isinstance(origin, type) and issubclass(origin, Sequence):
        inner_type = type_args[0] if type_args else Any
        rows_are_models = _is_base_model(inner_type)

        @wraps(func)
        async def fetch_all(self: SessionService, *args: Any, **kwargs: Any) -> Sequence:
            sql_params = _bind_sql_params(sig, self, args, kwargs)
            result = await self.session.execute(_get_template(func_name), params=sql_params)
            if rows_are_models:
                # Rows expose their columns as attributes, hence from_attributes.
                return [
                    inner_type.model_validate(row, from_attributes=True)  # type: ignore
                    for row in result.all()
                ]
            # Non-model element type: first column of every row.
            return result.scalars().all()

        fetch_all.__signature__ = sig  # type: ignore[attr-defined]
        return fetch_all  # type: ignore

    # Single result: X, X | None or Optional[X].
    is_nullable = NoneType in type_args
    # Unwrap Optional[X] / X | None to X; a bare X has no args and is used as is.
    actual_type = next((a for a in type_args if a is not NoneType), return_hint)
    row_is_model = _is_base_model(actual_type)

    @wraps(func)
    async def fetch_one(self: SessionService, *args: Any, **kwargs: Any) -> Any:
        sql_params = _bind_sql_params(sig, self, args, kwargs)
        result = await self.session.execute(_get_template(func_name), params=sql_params)
        # A model needs the whole first row; any other type takes one scalar.
        data = result.first() if row_is_model else result.scalar_one_or_none()
        if data is None:
            if is_nullable:
                return None
            raise ValueError(f"Query {func_name} expected a result but got None")
        if row_is_model:
            return actual_type.model_validate(data, from_attributes=True)
        return data

    fetch_one.__signature__ = sig  # type: ignore[attr-defined]
    return fetch_one  # type: ignore
|
|
154
|
+
|
|
@@ -1,7 +1,12 @@
|
|
|
1
|
-
|
|
2
|
-
|
|
1
|
+
# cython: language_level=3
|
|
2
|
+
# cython: boundscheck=False
|
|
3
|
+
# cython: wraparound=False
|
|
4
|
+
# cython: initializedcheck=False
|
|
5
|
+
# cython: nonecheck=False
|
|
6
|
+
|
|
3
7
|
|
|
4
|
-
|
|
8
|
+
cimport cython
|
|
9
|
+
from cpython cimport datetime, time
|
|
5
10
|
|
|
6
11
|
|
|
7
12
|
# Number of low bits of each generated ID reserved for the per-tick sequence counter.
cdef int _sequence_length = 10
|
|
@@ -17,16 +22,19 @@ cdef class Cursor:
|
|
|
17
22
|
long long last_point
|
|
18
23
|
|
|
19
24
|
def __init__(self, seed: int):
    # Generated IDs reserve 4 bits for the seed, so fold it into [0, 16).
    self.seed = seed % 16
    self.last_point = 0
    self.cursor = 0
|
|
25
28
|
|
|
26
29
|
cdef inline long long fetch(self) nogil:
|
|
27
30
|
cdef:
|
|
28
31
|
long long count = 0
|
|
29
|
-
|
|
32
|
+
# 使用 10ms 作为刻度
|
|
33
|
+
long long point = <long long>((time.time() - _start_point) * 100)
|
|
34
|
+
|
|
35
|
+
# 简单的时钟回拨处理
|
|
36
|
+
if point < self.last_point:
|
|
37
|
+
point = self.last_point
|
|
30
38
|
|
|
31
39
|
if self.last_point == point:
|
|
32
40
|
count = self.cursor + 1
|
|
@@ -34,10 +42,14 @@ cdef class Cursor:
|
|
|
34
42
|
return 0
|
|
35
43
|
else:
|
|
36
44
|
self.last_point = point
|
|
37
|
-
|
|
45
|
+
|
|
46
|
+
self.cursor = <int>count
|
|
47
|
+
# ID 组成: 时间戳(高位) + 种子(4位) + 序列号(10位)
|
|
38
48
|
return (point << (_sequence_length + 4)) + (self.seed << _sequence_length) + count
|
|
39
49
|
|
|
40
50
|
def next_val(self) -> int:
|
|
51
|
+
cdef long long index
|
|
52
|
+
|
|
41
53
|
index = self.fetch()
|
|
42
54
|
while index == 0:
|
|
43
55
|
index = self.fetch()
|
|
@@ -0,0 +1,196 @@
|
|
|
1
|
+
# cython: language_level=3
|
|
2
|
+
# cython: boundscheck=False
|
|
3
|
+
# cython: wraparound=False
|
|
4
|
+
# cython: initializedcheck=False
|
|
5
|
+
# cython: nonecheck=False
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
__author__ = "ziyan.yin"
|
|
9
|
+
__describe__ = ""
|
|
10
|
+
|
|
11
|
+
cimport cython
|
|
12
|
+
|
|
13
|
+
from starlette import _utils as starlette_utils
|
|
14
|
+
from starlette.datastructures import URL
|
|
15
|
+
from starlette.responses import RedirectResponse
|
|
16
|
+
|
|
17
|
+
# C-level constants compared against ``match.value`` in the dispatcher, so no
# Python attribute lookup is needed. ``cdef enum`` replaces the compile-time
# ``DEF`` statement, which is deprecated in Cython 3.
cdef enum:
    MATCH_NONE = 0
    MATCH_PARTIAL = 1
    MATCH_FULL = 2
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
cdef int find_params(unicode path):
    """Return the index of the first ``{`` in path, or -1 when there is none."""
    cdef Py_ssize_t pos
    for pos in range(len(path)):
        if path[pos] == u"{":
            return <int>pos
    return -1
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
cdef int get_longest_common_prefix(unicode path, unicode node_path):
    """Return how many leading characters path and node_path share."""
    cdef int limit = min(len(path), len(node_path))
    cdef int pos = 0
    while pos < limit and path[pos] == node_path[pos]:
        pos += 1
    return pos
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
# Deliberately not @cython.no_gc: a child references its parent (``parent``)
# and the parent references the child (``children``), so nodes form reference
# cycles, and no_gc types are invisible to the cycle collector.
cdef class RouteNode:
    """One node of the radix tree used for route lookup.

    ``prefix`` is the full path prefix from the root to this node. Handlers
    whose static path part equals ``prefix`` are kept in ``static_routes``
    (no path parameters) or ``params_routes`` (path contains ``{``).
    ``children`` is keyed by the character following ``prefix``.
    """

    cdef readonly:
        unicode prefix
        list params_routes
        list static_routes
        dict children

    cdef public object parent

    def __init__(self, prefix: str):
        self.prefix = prefix
        self.params_routes = []
        self.static_routes = []
        self.children = {}
        self.parent = None

    def add_route(self, fullpath: str, handler: object):
        """Register handler under fullpath.

        The part of fullpath before the first ``{`` is the tree key; when a
        ``{`` is present the handler is stored as a parameterised route.
        """
        cdef int index = find_params(fullpath)
        cdef bint wild_child = index >= 0
        cdef unicode path = fullpath[:index] if wild_child else fullpath
        insert_route(self, path, wild_child, handler)
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
cdef void insert_route(RouteNode node, unicode path, bint wild_child, object handler):
    """Store handler for path in the subtree rooted at node.

    node.prefix must be a prefix of path (true for the root, whose prefix is
    ""). Walks down matching edges, splitting an edge with a new intermediate
    node when path diverges inside it, and appends the handler to the node
    whose prefix equals path.
    """
    cdef RouteNode current = node
    cdef RouteNode existing, split_node, leaf
    cdef Py_UCS4 key
    cdef int common

    while current.prefix != path:
        # Character right after current.prefix selects the edge.
        key = path[len(current.prefix)]
        existing = current.children.get(key)
        if existing is None:
            leaf = RouteNode(path)
            leaf.parent = current
            (leaf.params_routes if wild_child else leaf.static_routes).append(handler)
            current.children[key] = leaf
            return

        common = get_longest_common_prefix(existing.prefix, path)
        if common != len(existing.prefix):
            # path leaves existing.prefix part-way: move the shared part into
            # a new intermediate node and hang the old child below it.
            split_node = RouteNode(existing.prefix[:common])
            split_node.parent = current
            current.children[key] = split_node
            split_node.children[existing.prefix[common]] = existing
            existing.parent = split_node
            existing = split_node
        current = existing

    (current.params_routes if wild_child else current.static_routes).append(handler)
|
|
111
|
+
|
|
112
|
+
|
|
113
|
+
cdef RouteNode search_node(unicode url):
    """Follow url down from ROOT and return the deepest node reached."""
    cdef RouteNode node = ROOT
    cdef RouteNode nxt
    cdef int url_len = len(url)
    cdef int matched = get_longest_common_prefix(url, node.prefix)

    while matched < url_len:
        nxt = node.children.get(url[matched])
        if nxt is None:
            break
        node = nxt
        matched = get_longest_common_prefix(url, node.prefix)
    return node
|
|
126
|
+
|
|
127
|
+
|
|
128
|
+
cdef tuple _resolve(unicode route_path, object scope):
    """Find the route that should serve scope for route_path.

    Static routes stored exactly at route_path are tried first, then the
    parameterised routes of the deepest node reached and of each of its
    ancestors up to the root. Returns ``(route, child_scope)`` for the first
    full match, else for the first partial match, else ``(None, None)``.
    """
    cdef RouteNode node = search_node(route_path)
    cdef object route, match, child_scope
    cdef object partial = None
    cdef object partial_scope = None

    if node.prefix == route_path:
        for route in node.static_routes:
            match, child_scope = route.matches(scope)
            if match.value == MATCH_FULL:
                return route, child_scope
            if match.value == MATCH_PARTIAL and partial is None:
                partial, partial_scope = route, child_scope

    while node is not None:
        for route in node.params_routes:
            match, child_scope = route.matches(scope)
            if match.value == MATCH_FULL:
                return route, child_scope
            if match.value == MATCH_PARTIAL and partial is None:
                partial, partial_scope = route, child_scope
        node = node.parent

    return partial, partial_scope


async def handle(scope, receive, send):
    """ASGI entry point installed as ``router.middleware_stack``.

    Lifespan events go to the router's lifespan handler. Other scopes are
    dispatched through the radix tree. If nothing matches and the router has
    ``redirect_slashes`` enabled, the same lookup (including parameterised
    and mounted routes) is repeated with the trailing slash toggled and a
    redirect is sent on a match. Otherwise the router's default handler runs.
    """
    cdef unicode route_path, new_path
    cdef object route, child_scope, redirect_scope
    cdef object router = scope["app"].router

    if "router" not in scope:
        scope["router"] = router

    if scope["type"] == "lifespan":
        await router.lifespan(scope, receive, send)
        return

    scope["path"] = route_path = starlette_utils.get_route_path(scope)
    route, child_scope = _resolve(route_path, scope)
    if route is not None:
        scope.update(child_scope)
        await route.handle(scope, receive, send)
        return

    if scope["type"] == "http" and router.redirect_slashes and route_path != "/":
        new_path = route_path.rstrip("/") if route_path.endswith("/") else route_path + "/"
        redirect_scope = {**scope, "path": new_path}
        # Redirect only if some route actually matches the toggled path.
        route, child_scope = _resolve(new_path, redirect_scope)
        if route is not None:
            response = RedirectResponse(url=str(URL(scope=redirect_scope)))
            await response(scope, receive, send)
            return

    await router.default(scope, receive, send)
|
|
188
|
+
|
|
189
|
+
|
|
190
|
+
ROOT = RouteNode("")


def install(app):
    """Index ``app.routes`` in the radix tree and install ``handle``.

    Ordinary routes are keyed by their path up to the first ``{``.
    ``Mount`` routes serve every path below their prefix, so they are
    registered as parameterised routes (tried for any URL under the prefix)
    rather than as exact static routes. ``Host`` routes are keyed by host name
    and expose no ``path``; they are registered at the root so they are tried
    for every URL.
    """
    cdef int index
    cdef unicode path
    from starlette.routing import Host, Mount

    for route in app.routes:
        if isinstance(route, Host):
            insert_route(ROOT, u"", True, route)
        elif isinstance(route, Mount):
            path = route.path
            index = find_params(path)
            insert_route(ROOT, path[:index] if index >= 0 else path, True, route)
        else:
            ROOT.add_route(route.path, route)

    app.router.middleware_stack = handle
|
|
@@ -0,0 +1,88 @@
|
|
|
1
|
+
# cython: language_level=3
|
|
2
|
+
# cython: boundscheck=False
|
|
3
|
+
# cython: wraparound=False
|
|
4
|
+
# cython: initializedcheck=False
|
|
5
|
+
# cython: nonecheck=False
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
__author__ = "ziyan.yin"
|
|
9
|
+
__describe__ = ""
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
from libc.stdlib cimport strtol
|
|
13
|
+
from libc.string cimport memmove, strlen
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
cdef inline int _hex_digit(unsigned char ch) noexcept nogil:
    """Return the value of an ASCII hex digit, or -1 if ch is not one."""
    if 48 <= ch <= 57:        # '0'..'9'
        return ch - 48
    if 65 <= ch <= 70:        # 'A'..'F'
        return ch - 55
    if 97 <= ch <= 102:       # 'a'..'f'
        return ch - 87
    return -1


cdef inline size_t _unquote_optimized(char* c_str, bint change_plus, Py_ssize_t length=-1) noexcept nogil:
    """Percent-decode c_str in place and return the decoded length.

    Decoding never lengthens the data, so one read cursor and a trailing
    write cursor suffice. ``+`` becomes a space when change_plus is true.
    ``%XY`` with two hex digits becomes that byte; any other ``%`` (invalid
    digits, or fewer than two characters left) is copied through unchanged,
    as urllib.parse.unquote does. length is the number of bytes to process;
    a negative value means c_str is NUL-terminated.
    """
    cdef:
        size_t n = strlen(c_str) if length < 0 else <size_t>length
        size_t read_pos = 0
        size_t write_pos = 0
        int high, low

    while read_pos < n:
        if c_str[read_pos] == b'+' and change_plus:
            c_str[write_pos] = b' '
            read_pos += 1
        elif c_str[read_pos] == b'%' and read_pos + 2 < n:
            high = _hex_digit(<unsigned char>c_str[read_pos + 1])
            low = _hex_digit(<unsigned char>c_str[read_pos + 2])
            if high < 0 or low < 0:
                # Not a valid escape: keep the '%' and continue after it.
                c_str[write_pos] = c_str[read_pos]
                read_pos += 1
            else:
                c_str[write_pos] = <char>((high << 4) | low)
                read_pos += 3
        else:
            c_str[write_pos] = c_str[read_pos]
            read_pos += 1
        write_pos += 1

    return write_pos


def unquote(bytes val, str encoding = "utf-8", str errors = "replace"):
    """Percent-decode val and decode the bytes with encoding.

    Undecodable bytes are handled according to errors (default
    ``"replace"``, as in urllib.parse.unquote). val itself is not modified:
    decoding happens in a private bytearray copy.
    """
    if not val:
        return ""
    cdef bytearray buf = bytearray(val)
    cdef char* raw = buf
    # Explicit length so an embedded NUL byte does not truncate the input.
    cdef size_t new_len = _unquote_optimized(raw, 0, len(buf))
    return buf[:new_len].decode(encoding, errors)


def unquote_plus(val: bytes, encoding: str = "utf-8", errors: str = "replace") -> str:
    """Like unquote, but ``+`` is also decoded to a space (form encoding)."""
    if not val:
        return ""
    cdef bytearray buf = bytearray(val)
    cdef char* raw = buf
    cdef size_t new_len = _unquote_optimized(raw, 1, len(buf))
    return buf[:new_len].decode(encoding, errors)
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
def parse_qsl(bytes qs, bint keep_blank_values = False):
    """Split a query string into decoded (name, value) pairs.

    Follows urllib.parse.parse_qsl: fields are separated by ``&`` and split
    at the first ``=`` only, so values may themselves contain ``=``. Names
    and values are decoded with unquote_plus. Fields without ``=`` and
    fields with an empty value are dropped unless keep_blank_values is true.
    """
    cdef list pairs = []
    if not qs:
        return pairs

    for field in qs.split(b'&'):
        if not field:
            continue
        name, _, value = field.partition(b'=')
        if not value and not keep_blank_values:
            continue
        pairs.append((unquote_plus(name), unquote_plus(value)))
    return pairs
|
|
@@ -223,7 +223,7 @@ class APIResponse(JSONResponse):
|
|
|
223
223
|
|
|
224
224
|
def init_headers(self, headers: Mapping[str, str] | None = None) -> None:
|
|
225
225
|
self.raw_headers = [
|
|
226
|
-
(b"content-length",
|
|
226
|
+
(b"content-length", f"{len(self.body)}".encode("latin-1")),
|
|
227
227
|
(b"content-type", b"application/json; charset=utf-8"),
|
|
228
228
|
]
|
|
229
229
|
if headers:
|
|
@@ -234,8 +234,6 @@ class APIResponse(JSONResponse):
|
|
|
234
234
|
self.raw_headers.extend(raw_headers)
|
|
235
235
|
|
|
236
236
|
|
|
237
|
-
|
|
238
|
-
|
|
239
237
|
class APIError(Exception):
|
|
240
238
|
__slots__ = ("code", "message")
|
|
241
239
|
|
|
@@ -4,7 +4,7 @@ __date__ = "2024-12-25"
|
|
|
4
4
|
|
|
5
5
|
import datetime
|
|
6
6
|
import decimal
|
|
7
|
-
from typing import Annotated, Any, TypeVar, Union
|
|
7
|
+
from typing import Annotated, Any, ParamSpec, TypeVar, Union
|
|
8
8
|
|
|
9
9
|
from pydantic import BaseModel, PlainSerializer
|
|
10
10
|
from sqlmodel import SQLModel
|
|
@@ -19,8 +19,10 @@ T = TypeVar("T", bound=Any)
|
|
|
19
19
|
# Bounded type variables shared across the package.
E = TypeVar("E", bound=Exception)
C = TypeVar("C", bound=Comparable)
S = TypeVar("S", bound=Serializable)
Schema = TypeVar("Schema", bound=BaseModel)
Model = TypeVar("Model", bound=SQLModel)

# Captures a callable's parameter list so decorators can preserve it.
P = ParamSpec("P")

# The class of ``None``; spelled this way because types.NoneType needs 3.10+.
NoneType = type(None)

# An ``int`` that serialises to ``str``.
Cursor = Annotated[int, PlainSerializer(lambda x: str(x), return_type=str)]
|
|
26
28
|
LocalDateTime = Annotated[
|
|
@@ -27,6 +27,7 @@ fastapi_extra/database/__init__.py
|
|
|
27
27
|
fastapi_extra/database/model.py
|
|
28
28
|
fastapi_extra/database/service.py
|
|
29
29
|
fastapi_extra/database/session.py
|
|
30
|
+
fastapi_extra/database/sqlmap.py
|
|
30
31
|
fastapi_extra/native/cursor.pyx
|
|
31
32
|
fastapi_extra/native/routing.pyx
|
|
32
33
|
fastapi_extra/native/urlparse.pyx
|
|
@@ -1,191 +0,0 @@
|
|
|
1
|
-
__author__ = "ziyan.yin"
|
|
2
|
-
__describe__ = ""
|
|
3
|
-
|
|
4
|
-
cimport cython
|
|
5
|
-
|
|
6
|
-
from starlette import _utils as starlette_utils
|
|
7
|
-
from starlette.datastructures import URL
|
|
8
|
-
from starlette.responses import RedirectResponse
|
|
9
|
-
|
|
10
|
-
|
|
11
|
-
cdef int find_params(unicode path):
    """Index of the first ``{`` in path, or -1 when the path has no parameters."""
    cdef Py_ssize_t pos
    for pos in range(len(path)):
        if path[pos] == u"{":
            return <int>pos
    return -1


cdef int get_longest_common_prefix(unicode path, unicode node_path):
    """Number of leading characters path and node_path have in common."""
    cdef int limit = min(len(path), len(node_path))
    cdef int pos = 0
    while pos < limit and path[pos] == node_path[pos]:
        pos += 1
    return pos
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
# Deliberately not @cython.no_gc: parent <-> children links make nodes form
# reference cycles, and no_gc types are invisible to the cycle collector.
cdef class RouteNode:
    """Radix-tree node used for route lookup.

    ``prefix`` is the full path prefix from the root to this node; handlers
    whose static path part equals it live in ``static_routes`` or, when the
    path has parameters, ``params_routes``. ``children`` is keyed by the
    character following ``prefix``.
    """

    cdef readonly:
        unicode prefix
        list params_routes
        list static_routes
        dict children

    cdef public object parent

    def __cinit__(self, prefix: str):
        # __cinit__ (not __init__) so that RouteNode.__new__(RouteNode, p),
        # as used by the helpers below, returns an initialised node.
        self.prefix = prefix
        self.params_routes = []
        self.static_routes = []
        self.children = {}
        self.parent = None

    def add_route(self, fullpath: str, handler: object):
        """Register handler; the part of fullpath before the first ``{`` is the key."""
        wild_child = False
        if (index := find_params(fullpath)) >= 0:
            wild_child = True
            path = fullpath[:index]
        else:
            path = fullpath
        insert_route(self, path, wild_child, handler)


cdef void insert_route(RouteNode node, unicode path, bint wild_child, object handler):
    """Insert handler for path below node, splitting an edge where paths diverge.

    node.prefix is a prefix of path on every call (the root prefix is "").
    """
    if node.prefix == path:
        add_node(node, wild_child, handler)
        return

    # Character right after node.prefix; indexing avoids building the
    # intermediate string that removeprefix() would allocate.
    cdef Py_UCS4 key = path[len(node.prefix)]
    if key not in node.children:
        add_child_node(node, key, path, wild_child, handler)
        return

    cdef RouteNode child_node = node.children[key]
    cdef int i = get_longest_common_prefix(child_node.prefix, path)
    if i == len(child_node.prefix):
        insert_route(child_node, path, wild_child, handler)
        return

    # path leaves child_node.prefix part-way: move the shared part into a
    # new intermediate node and hang the old child below it.
    next_node = RouteNode.__new__(RouteNode, child_node.prefix[:i])
    next_node.parent = node
    node.children[key] = next_node
    next_node.children[child_node.prefix[i]] = child_node
    child_node.parent = next_node
    insert_route(next_node, path, wild_child, handler)


cdef inline void add_child_node(RouteNode node, Py_UCS4 key, unicode path, bint wild_child, object handler):
    """Create a child of node with prefix path under key and store handler there."""
    child = RouteNode.__new__(RouteNode, path)
    child.parent = node
    add_node(child, wild_child, handler)
    node.children[key] = child


cdef inline void add_node(RouteNode node, bint wild_child, object handler):
    """Append handler to node's parameterised or static route list."""
    if wild_child:
        node.params_routes.append(handler)
    else:
        node.static_routes.append(handler)
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
root_node = RouteNode.__new__(RouteNode, "")


cdef RouteNode search_node(unicode url):
    """Follow url down from root_node and return the deepest node reached."""
    cdef RouteNode node = root_node
    cdef int url_len = len(url)
    cdef int matched = get_longest_common_prefix(url, node.prefix)

    while matched < url_len:
        child = node.children.get(url[matched])
        if child is None:
            break
        node = child
        matched = get_longest_common_prefix(url, node.prefix)
    return node
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
cdef tuple _resolve(unicode route_path, object scope):
    """Return (route, child_scope) for the best route for scope, or (None, None).

    Static routes stored exactly at route_path are tried first, then the
    parameterised routes of the deepest node reached and of every ancestor
    including the root. The first full match wins, else the first partial one.
    """
    cdef RouteNode node = search_node(route_path)
    partial = None
    partial_scope = None

    if node.prefix == route_path:
        for route in node.static_routes:
            match, child_scope = route.matches(scope)
            if match.value == 2:  # Match.FULL
                return route, child_scope
            if match.value == 1 and partial is None:  # Match.PARTIAL
                partial, partial_scope = route, child_scope

    # Each node's own params_routes must be read: reusing the leaf's list
    # would retry the same routes and never reach the ancestors' ones.
    while node is not None:
        for route in node.params_routes:
            match, child_scope = route.matches(scope)
            if match.value == 2:  # Match.FULL
                return route, child_scope
            if match.value == 1 and partial is None:  # Match.PARTIAL
                partial, partial_scope = route, child_scope
        node = node.parent

    return partial, partial_scope


async def handle(scope, receive, send):
    """ASGI entry point installed as the router's middleware stack.

    Lifespan events go to the router's lifespan handler; other scopes are
    dispatched through the radix tree. If nothing matches and
    ``redirect_slashes`` is enabled, the lookup is repeated with the trailing
    slash toggled and a redirect is sent on a match; otherwise the router's
    default handler runs.
    """
    router = scope["app"].router
    assert scope["type"] in ("http", "websocket", "lifespan")

    if "router" not in scope:
        scope["router"] = router

    if scope["type"] == "lifespan":
        await router.lifespan(scope, receive, send)
        return

    scope["path"] = route_path = starlette_utils.get_route_path(scope)
    route, child_scope = _resolve(route_path, scope)
    if route is not None:
        scope.update(child_scope)
        await route.handle(scope, receive, send)
        return

    if scope["type"] == "http" and router.redirect_slashes and route_path != "/":
        redirect_scope = dict(scope)
        if route_path.endswith("/"):
            redirect_scope["path"] = redirect_scope["path"].rstrip("/")
        else:
            redirect_scope["path"] = redirect_scope["path"] + "/"

        # Redirect only if a route really matches the toggled path.
        route, _ = _resolve(redirect_scope["path"], redirect_scope)
        if route is not None:
            redirect_url = URL(scope=redirect_scope)
            response = RedirectResponse(url=str(redirect_url))
            await response(scope, receive, send)
            return

    await router.default(scope, receive, send)
|
|
185
|
-
|
|
186
|
-
|
|
187
|
-
def install(app):
    """Index app.routes in the radix tree and install handle as the middleware stack.

    ``Mount`` routes serve every path below their prefix, so they are stored
    as parameterised routes instead of exact static ones. ``Host`` routes are
    keyed by host name and expose no ``path``; they are stored at the root.
    """
    from starlette.routing import Host, Mount

    for route in app.routes:
        if isinstance(route, Host):
            insert_route(root_node, "", True, route)
        elif isinstance(route, Mount):
            index = find_params(route.path)
            prefix = route.path[:index] if index >= 0 else route.path
            insert_route(root_node, prefix, True, route)
        else:
            root_node.add_route(route.path, route)

    app.router.middleware_stack = handle
|
|
@@ -1,50 +0,0 @@
|
|
|
1
|
-
__author__ = "ziyan.yin"
|
|
2
|
-
__describe__ = ""
|
|
3
|
-
|
|
4
|
-
|
|
5
|
-
from libc.stdlib cimport strtol
|
|
6
|
-
from libc.string cimport memmove, strlen
|
|
7
|
-
|
|
8
|
-
|
|
9
|
-
cdef inline int _hex_digit(unsigned char ch) noexcept nogil:
    """Return the value of an ASCII hex digit, or -1 if ch is not one."""
    if 48 <= ch <= 57:        # '0'..'9'
        return ch - 48
    if 65 <= ch <= 70:        # 'A'..'F'
        return ch - 55
    if 97 <= ch <= 102:       # 'a'..'f'
        return ch - 87
    return -1


cdef inline size_t _unquote(char* c_string, bint change_plus, Py_ssize_t length=-1) noexcept nogil:
    """Percent-decode c_string in place and return the decoded length.

    ``+`` becomes a space when change_plus is true; ``%XY`` with two hex
    digits becomes that byte; any other ``%`` is copied unchanged. Only
    indices below the length are ever read. A negative length means
    c_string is NUL-terminated.
    """
    cdef:
        size_t n = strlen(c_string) if length < 0 else <size_t>length
        size_t read_pos = 0
        size_t write_pos = 0
        int high, low

    while read_pos < n:
        if c_string[read_pos] == b'+' and change_plus:
            c_string[write_pos] = b' '
            read_pos += 1
        elif c_string[read_pos] == b'%' and read_pos + 2 < n:
            high = _hex_digit(<unsigned char>c_string[read_pos + 1])
            low = _hex_digit(<unsigned char>c_string[read_pos + 2])
            if high < 0 or low < 0:
                c_string[write_pos] = c_string[read_pos]
                read_pos += 1
            else:
                c_string[write_pos] = <char>((high << 4) | low)
                read_pos += 3
        else:
            c_string[write_pos] = c_string[read_pos]
            read_pos += 1
        write_pos += 1

    return write_pos


def unquote(val: bytes, encoding: str = "utf-8", errors: str = "replace") -> str:
    """Percent-decode val and decode it with encoding (errors as in bytes.decode).

    Works on a private bytearray copy: bytes objects are immutable and may be
    shared, so their buffer must never be written.
    """
    if not val:
        return ""
    cdef bytearray buf = bytearray(val)
    cdef char* raw = buf
    cdef size_t new_len = _unquote(raw, 0, len(buf))
    return buf[:new_len].decode(encoding, errors)


def unquote_plus(val: bytes, encoding: str = "utf-8", errors: str = "replace") -> str:
    """Like unquote, but ``+`` is also decoded to a space."""
    if not val:
        return ""
    cdef bytearray buf = bytearray(val)
    cdef char* raw = buf
    cdef size_t new_len = _unquote(raw, 1, len(buf))
    return buf[:new_len].decode(encoding, errors)
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
def parse_qsl(qs: bytes, keep_blank_values: bool = False) -> list[tuple[str, str]]:
    """Split a query string into decoded (name, value) pairs.

    Matches urllib.parse.parse_qsl: fields are separated by ``&`` and split
    at the first ``=`` only, so values may contain ``=``. Fields without
    ``=`` and fields with an empty value are dropped unless
    keep_blank_values is true.
    """
    pairs = []
    if not qs:
        return pairs

    for field in qs.split(b'&'):
        if not field:
            continue
        name, _, value = field.partition(b'=')
        if not value and not keep_blank_values:
            continue
        pairs.append((unquote_plus(name), unquote_plus(value)))
    return pairs
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|