fastapi-extra 0.3.5__tar.gz → 0.4.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (34):
  1. {fastapi_extra-0.3.5 → fastapi_extra-0.4.0}/PKG-INFO +1 -1
  2. {fastapi_extra-0.3.5 → fastapi_extra-0.4.0}/fastapi_extra/__init__.py +1 -1
  3. {fastapi_extra-0.3.5 → fastapi_extra-0.4.0}/fastapi_extra/database/service.py +17 -2
  4. fastapi_extra-0.4.0/fastapi_extra/database/sqlmap.py +154 -0
  5. {fastapi_extra-0.3.5 → fastapi_extra-0.4.0}/fastapi_extra/native/cursor.pyx +20 -8
  6. fastapi_extra-0.4.0/fastapi_extra/native/routing.pyx +196 -0
  7. fastapi_extra-0.4.0/fastapi_extra/native/urlparse.pyx +88 -0
  8. {fastapi_extra-0.3.5 → fastapi_extra-0.4.0}/fastapi_extra/response.py +1 -3
  9. {fastapi_extra-0.3.5 → fastapi_extra-0.4.0}/fastapi_extra/types.py +3 -1
  10. {fastapi_extra-0.3.5 → fastapi_extra-0.4.0}/fastapi_extra.egg-info/PKG-INFO +1 -1
  11. {fastapi_extra-0.3.5 → fastapi_extra-0.4.0}/fastapi_extra.egg-info/SOURCES.txt +1 -0
  12. fastapi_extra-0.3.5/fastapi_extra/native/routing.pyx +0 -191
  13. fastapi_extra-0.3.5/fastapi_extra/native/urlparse.pyx +0 -50
  14. {fastapi_extra-0.3.5 → fastapi_extra-0.4.0}/LICENSE +0 -0
  15. {fastapi_extra-0.3.5 → fastapi_extra-0.4.0}/README.rst +0 -0
  16. {fastapi_extra-0.3.5 → fastapi_extra-0.4.0}/fastapi_extra/_patch.py +0 -0
  17. {fastapi_extra-0.3.5 → fastapi_extra-0.4.0}/fastapi_extra/cache/__init__.py +0 -0
  18. {fastapi_extra-0.3.5 → fastapi_extra-0.4.0}/fastapi_extra/cache/redis.py +0 -0
  19. {fastapi_extra-0.3.5 → fastapi_extra-0.4.0}/fastapi_extra/cursor.pyi +0 -0
  20. {fastapi_extra-0.3.5 → fastapi_extra-0.4.0}/fastapi_extra/database/__init__.py +0 -0
  21. {fastapi_extra-0.3.5 → fastapi_extra-0.4.0}/fastapi_extra/database/model.py +0 -0
  22. {fastapi_extra-0.3.5 → fastapi_extra-0.4.0}/fastapi_extra/database/session.py +0 -0
  23. {fastapi_extra-0.3.5 → fastapi_extra-0.4.0}/fastapi_extra/dependency.py +0 -0
  24. {fastapi_extra-0.3.5 → fastapi_extra-0.4.0}/fastapi_extra/form.py +0 -0
  25. {fastapi_extra-0.3.5 → fastapi_extra-0.4.0}/fastapi_extra/py.typed +0 -0
  26. {fastapi_extra-0.3.5 → fastapi_extra-0.4.0}/fastapi_extra/routing.pyi +0 -0
  27. {fastapi_extra-0.3.5 → fastapi_extra-0.4.0}/fastapi_extra/settings.py +0 -0
  28. {fastapi_extra-0.3.5 → fastapi_extra-0.4.0}/fastapi_extra/urlparse.pyi +0 -0
  29. {fastapi_extra-0.3.5 → fastapi_extra-0.4.0}/fastapi_extra/utils.py +0 -0
  30. {fastapi_extra-0.3.5 → fastapi_extra-0.4.0}/fastapi_extra.egg-info/dependency_links.txt +0 -0
  31. {fastapi_extra-0.3.5 → fastapi_extra-0.4.0}/fastapi_extra.egg-info/requires.txt +0 -0
  32. {fastapi_extra-0.3.5 → fastapi_extra-0.4.0}/fastapi_extra.egg-info/top_level.txt +0 -0
  33. {fastapi_extra-0.3.5 → fastapi_extra-0.4.0}/pyproject.toml +0 -0
  34. {fastapi_extra-0.3.5 → fastapi_extra-0.4.0}/setup.cfg +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: fastapi-extra
3
- Version: 0.3.5
3
+ Version: 0.4.0
4
4
  Summary: extra package for fastapi.
5
5
  Author-email: Ziyan Yin <408856732@qq.com>
6
6
  License: BSD-3-Clause
@@ -1,4 +1,4 @@
1
- __version__ = "0.3.5"
1
+ __version__ = "0.4.0"
2
2
 
3
3
 
4
4
  from fastapi import FastAPI
@@ -2,13 +2,18 @@ __author__ = "ziyan.yin"
2
2
  __date__ = "2025-01-12"
3
3
 
4
4
  from contextvars import ContextVar
5
- from typing import Any, Generic, TypeVar
5
+ from typing import Any, Generic, Sequence, TypeVar
6
+
7
+ from sqlalchemy import ColumnExpressionArgument
8
+ from sqlmodel import select
6
9
 
7
10
  from fastapi_extra.database.model import SQLModel
8
11
  from fastapi_extra.database.session import AsyncSession, DefaultSession
9
12
  from fastapi_extra.dependency import AbstractService
10
13
 
11
14
Model = TypeVar("Model", bound=SQLModel)
# A single primary-key column value.
ID = int | str
# Identity accepted by ModelService.get() and passed on to AsyncSession.get():
# a scalar, a tuple of column values (any length, for composite keys), or a
# mapping of primary-key column name to value.
PK = ID | tuple[ID, ...] | dict[str, ID]
12
17
 
13
18
 
14
19
  class ModelService(AbstractService, Generic[Model], abstract=True):
@@ -40,14 +45,24 @@ class ModelService(AbstractService, Generic[Model], abstract=True):
40
45
  assert _session is not None, "Session is not initialized"
41
46
  return _session
42
47
 
43
- async def get(self, ident: int | str, **kwargs: Any) -> Model | None:
48
+ async def get(self, ident: PK, **kwargs: Any) -> Model | None:
44
49
  return await self.session.get(self.__model__, ident, **kwargs)
50
+
51
async def get_list(self, *clause: ColumnExpressionArgument[bool] | bool) -> Sequence[Model]:
    """Return every ``__model__`` row that satisfies all WHERE *clause* items.

    With no clauses the statement has no WHERE condition.
    """
    statement = select(self.__model__).where(*clause)
    result = await self.session.exec(statement)
    return result.all()
45
53
 
46
54
async def create_model(self, **kwargs: Any) -> Model:
    """Validate *kwargs* into a new ``__model__`` instance and persist it.

    The instance is added to the session and the session is flushed before
    the instance is returned.
    """
    instance = self.__model__.model_validate(kwargs)
    session = self.session
    session.add(instance)
    await session.flush()
    return instance
59
+
60
async def update_model(self, model: Model, **kwargs: Any) -> Model:
    """Assign the given field values to *model*, flush, and return it.

    Keys in *kwargs* that are not declared fields of the model's class are
    silently ignored.
    """
    # ``__fields__`` is deprecated in Pydantic v2 (and instance access to
    # ``model_fields`` is deprecated too); read the class-level mapping.
    declared = type(model).model_fields
    for key, value in kwargs.items():
        if key in declared:
            setattr(model, key, value)
    await self.session.flush()
    return model
51
66
 
52
67
  async def delete(self, model: Model) -> None:
53
68
  return await self.session.delete(model)
@@ -0,0 +1,154 @@
1
+ __author__ = "ziyan.yin"
2
+ __date__ = "2026-05-14"
3
+
4
+
5
+ import inspect
6
+ from collections.abc import Sequence
7
+ from pathlib import Path
8
+ from typing import (Any, Callable, Protocol, get_args, get_origin,
9
+ get_type_hints)
10
+
11
+ from pydantic import BaseModel, Field
12
+ from sqlmodel import text
13
+ from sqlmodel.ext.asyncio.session import AsyncSession
14
+
15
+ from fastapi_extra.settings import Settings
16
+ from fastapi_extra.types import NoneType, P, T
17
+
18
+
19
class SQLTemplateConfig(BaseModel):
    """Location and file suffix of the SQL template files read by ``install()``."""

    # Directory to scan; a relative path is resolved by ``pathlib.Path``
    # against the process working directory.
    path: str = "./template/sql"
    # File-name suffix (including the leading dot) that marks a template file.
    suffix: str = ".sql"
22
+
23
+
24
class SQLTemplateSettings(Settings):
    """Settings section (``sqlmap``) that configures SQL template loading."""

    # Defaults to an all-default SQLTemplateConfig so that a configuration
    # without a ``sqlmap`` section does not make the module-level
    # ``SQLTemplateSettings()`` call (and thus importing this module) fail.
    sqlmap: SQLTemplateConfig = Field(default_factory=SQLTemplateConfig, alias="sqlmap")
26
+
27
+
28
_settings = SQLTemplateSettings()  # type: ignore
# Template name (the file's stem) -> ``text()`` clause; filled by install().
_templates: dict[str, Any] = {}


def install() -> None:
    """Load the SQL template files from the configured directory.

    Every regular file directly inside ``_settings.sqlmap.path`` whose name
    ends with ``_settings.sqlmap.suffix`` is read as UTF-8, wrapped in
    ``text()`` and registered in ``_templates`` under its stem.  A missing
    directory is ignored (nothing is loaded).
    """
    path = Path(_settings.sqlmap.path)
    if not path.is_dir():
        return
    # sorted(): deterministic load order across platforms/filesystems.
    for file in sorted(path.glob(f"*{_settings.sqlmap.suffix}")):
        # Skip e.g. a sub-directory named "foo.sql"; read_text() would fail.
        if not file.is_file():
            continue
        _templates[file.stem] = text(file.read_text(encoding="utf-8"))
38
+
39
+
40
def _is_base_model(cls: Any) -> bool:
    """Return True when *cls* is a Pydantic ``BaseModel`` subclass.

    ``issubclass`` is guarded because some typing constructs (e.g. generic
    aliases on some Python versions) can pass ``isinstance(cls, type)`` and
    still make ``issubclass`` raise TypeError; those count as "not a model".
    """
    if not isinstance(cls, type):
        return False
    try:
        return issubclass(cls, BaseModel)
    except TypeError:
        return False
46
+
47
+
48
class SessionService(Protocol):
    """Structural type for the ``self`` of ``@Mapped`` methods.

    The generated wrappers only use ``self.session``, an ``AsyncSession``.
    """

    session: AsyncSession
50
+
51
+
52
def Mapped(func: Callable[P, T]) -> Callable[P, T]:
    """Replace a method stub with an async call of the SQL template named after it.

    The decorated function's body is never executed:

    * its ``__name__`` selects the template registered by ``install()``
      (the template file's stem);
    * its parameters, except the first (``self``), become the SQL bind
      parameters, bound exactly as a call to the stub would bind them
      (defaults applied, ``TypeError`` on bad arguments);
    * its return annotation selects how the result is read:

      - ``None`` or no annotation: execute only, return ``None``;
      - a ``Sequence`` type such as ``list[User]``: all rows, each validated
        into ``User`` if it is a Pydantic model, otherwise scalar values;
      - anything else (``User``, ``Optional[User]``, ``int | None``, ...):
        one row (model) or one scalar.  ``None`` is returned only when the
        annotation is a union containing ``None``; otherwise a missing
        result raises ``ValueError``.

    The returned wrapper is a coroutine function that uses ``self.session``.
    Raises ``KeyError`` at call time if the template was not loaded.
    """
    from functools import wraps
    from types import UnionType
    from typing import Union

    func_name = func.__name__
    sig = inspect.signature(func)
    return_hint = get_type_hints(func).get("return", NoneType)
    origin = get_origin(return_hint)
    # Named ``type_args`` so it is not shadowed by the wrappers' ``*args``.
    type_args = get_args(return_hint)

    def _sql_params(service: Any, args: tuple[Any, ...], kwargs: dict[str, Any]) -> dict[str, Any]:
        # Bind like a real call to *func*, then drop the first parameter
        # (self / the service); the rest are the SQL bind parameters.
        bound = sig.bind(service, *args, **kwargs)
        bound.apply_defaults()
        return dict(list(bound.arguments.items())[1:])

    def _template() -> Any:
        # Resolved per call, not at decoration time: install() may run after
        # the module defining the decorated method has been imported.
        try:
            return _templates[func_name]
        except KeyError:
            raise KeyError(
                f"SQL template {func_name!r} is not loaded; check the sqlmap "
                "path/suffix settings and that install() has been called"
            ) from None

    if return_hint is NoneType:

        @wraps(func)
        async def execute_only(self: SessionService, *args: Any, **kwargs: Any) -> None:
            params = _sql_params(self, args, kwargs)
            await self.session.execute(_template(), params=params)

        execute_only.__signature__ = sig  # type: ignore[attr-defined]
        return execute_only  # type: ignore

    # get_origin() may return a non-class (e.g. typing.Union for
    # Optional[User]); issubclass() would raise TypeError on it.
    if isinstance(origin, type) and issubclass(origin, Sequence):
        inner_type = type_args[0] if type_args else Any
        is_model = _is_base_model(inner_type)

        @wraps(func)
        async def fetch_all(self: SessionService, *args: Any, **kwargs: Any) -> Sequence:
            params = _sql_params(self, args, kwargs)
            result = await self.session.execute(_template(), params=params)
            if is_model:
                # Whole rows, validated attribute-wise into the model.
                return [inner_type.model_validate(row, from_attributes=True) for row in result.all()]  # type: ignore
            # Not a model: first column of every row.
            return result.scalars().all()

        fetch_all.__signature__ = sig  # type: ignore[attr-defined]
        return fetch_all  # type: ignore

    # Only unions (Optional[User] / User | None) are unwrapped; for any other
    # generic such as dict[str, int] the annotation itself is the result type.
    if origin is Union or origin is UnionType:
        is_nullable = NoneType in type_args
        actual_type = next((a for a in type_args if a is not NoneType), return_hint)
    else:
        is_nullable = False
        actual_type = return_hint
    is_model = _is_base_model(actual_type)

    @wraps(func)
    async def fetch_one(self: SessionService, *args: Any, **kwargs: Any) -> Any:
        params = _sql_params(self, args, kwargs)
        result = await self.session.execute(_template(), params=params)

        # A model needs the whole first row; anything else is a scalar.
        data = result.first() if is_model else result.scalar_one_or_none()

        if data is None:
            if is_nullable:
                return None
            raise ValueError(f"Query {func_name} expected a result but got None")

        if is_model:
            return actual_type.model_validate(data, from_attributes=True)
        return data

    fetch_one.__signature__ = sig  # type: ignore[attr-defined]
    return fetch_one  # type: ignore
154
+
@@ -1,7 +1,12 @@
1
- cimport cython
2
- from cpython cimport time
1
+ # cython: language_level=3
2
+ # cython: boundscheck=False
3
+ # cython: wraparound=False
4
+ # cython: initializedcheck=False
5
+ # cython: nonecheck=False
6
+
3
7
 
4
- import datetime
8
+ cimport cython
9
+ from cpython cimport datetime, time
5
10
 
6
11
 
7
12
  cdef int _sequence_length = 10
@@ -17,16 +22,19 @@ cdef class Cursor:
17
22
  long long last_point
18
23
 
19
24
def __init__(self, seed: int):
    # fetch() shifts the seed into a 4-bit slot just above the sequence
    # bits, so reduce it modulo 16 to keep it inside that slot.
    self.seed = seed % 16
    # Fresh cursor: no time tick seen yet, sequence counter at zero.
    self.last_point = 0
    self.cursor = 0
25
28
 
26
29
  cdef inline long long fetch(self) nogil:
27
30
  cdef:
28
31
  long long count = 0
29
- long long point = int((time.time() - _start_point) * 100)
32
+ # 使用 10ms 作为刻度
33
+ long long point = <long long>((time.time() - _start_point) * 100)
34
+
35
+ # 简单的时钟回拨处理
36
+ if point < self.last_point:
37
+ point = self.last_point
30
38
 
31
39
  if self.last_point == point:
32
40
  count = self.cursor + 1
@@ -34,10 +42,14 @@ cdef class Cursor:
34
42
  return 0
35
43
  else:
36
44
  self.last_point = point
37
- self.cursor = count
45
+
46
+ self.cursor = <int>count
47
+ # ID 组成: 时间戳(高位) + 种子(4位) + 序列号(10位)
38
48
  return (point << (_sequence_length + 4)) + (self.seed << _sequence_length) + count
39
49
 
40
50
  def next_val(self) -> int:
51
+ cdef long long index
52
+
41
53
  index = self.fetch()
42
54
  while index == 0:
43
55
  index = self.fetch()
@@ -0,0 +1,196 @@
1
+ # cython: language_level=3
2
+ # cython: boundscheck=False
3
+ # cython: wraparound=False
4
+ # cython: initializedcheck=False
5
+ # cython: nonecheck=False
6
+
7
+
8
+ __author__ = "ziyan.yin"
9
+ __describe__ = ""
10
+
11
+ cimport cython
12
+
13
+ from starlette import _utils as starlette_utils
14
+ from starlette.datastructures import URL
15
+ from starlette.responses import RedirectResponse
16
+
17
# Integer values of starlette.routing.Match (NONE / PARTIAL / FULL), compared
# against ``match.value``.  An anonymous enum gives compile-time C constants
# without the ``DEF`` statement, which is deprecated in Cython 3.
cdef enum:
    MATCH_NONE = 0
    MATCH_PARTIAL = 1
    MATCH_FULL = 2
21
+
22
+
23
cdef int find_params(unicode path):
    """Return the index of the first ``{`` in *path* (where the first path
    parameter starts), or -1 when *path* has no parameter."""
    cdef Py_ssize_t pos
    for pos in range(len(path)):
        if path[pos] == u"{":
            return <int>pos
    return -1
30
+
31
+
32
cdef int get_longest_common_prefix(unicode path, unicode node_path):
    """Return the length of the longest common prefix of the two strings."""
    cdef int limit = min(len(path), len(node_path))
    cdef int matched = 0
    while matched < limit and path[matched] == node_path[matched]:
        matched += 1
    return matched
42
+
43
+
44
cdef class RouteNode:
    """One node of the radix tree that indexes routes by path prefix.

    ``static_routes`` hold routes whose parameter-free path equals
    ``prefix``; ``params_routes`` hold routes whose path is ``prefix``
    followed by a ``{parameter}``.  ``children`` maps the first character
    after ``prefix`` to the child node, and ``parent`` points back up.
    """
    # NOTE: deliberately not ``@cython.no_gc``: each child references its
    # parent and each parent its children, so nodes form reference cycles,
    # which no_gc types must never be part of (they could never be freed).

    cdef readonly:
        unicode prefix
        list params_routes
        list static_routes
        dict children

    cdef public object parent

    def __init__(self, prefix: str):
        self.prefix = prefix
        self.params_routes = []
        self.static_routes = []
        self.children = {}
        self.parent = None

    def add_route(self, fullpath: str, handler: object):
        """Register *handler* for *fullpath* in the tree below this node.

        A path containing ``{`` is stored as a parameter route on the node for
        the text before the first ``{``; otherwise as a static route.
        """
        cdef int index = find_params(fullpath)
        cdef bint wild_child = index >= 0
        cdef unicode path = fullpath[:index] if wild_child else fullpath
        insert_route(self, path, wild_child, handler)
67
+
68
+
69
cdef void insert_route(RouteNode node, unicode path, bint wild_child, object handler):
    """Attach *handler* to the node for *path*, creating/splitting nodes.

    *path* must start with ``node.prefix``.  Walks down (iteratively) while an
    existing child's whole prefix is a prefix of *path*; when a child only
    shares part of its prefix, a new node holding the shared part is inserted
    above it.
    """
    cdef RouteNode existing, created, split
    cdef Py_UCS4 key
    cdef int common

    while node.prefix != path:
        key = path[len(node.prefix)]
        existing = node.children.get(key)

        if existing is None:
            # No child continues this path: the whole remainder becomes a leaf.
            created = RouteNode(path)
            created.parent = node
            node.children[key] = created
            node = created
            break

        common = get_longest_common_prefix(existing.prefix, path)
        if common < len(existing.prefix):
            # Split: an intermediate node takes the shared prefix and adopts
            # the existing child under the character where they diverge.
            split = RouteNode(existing.prefix[:common])
            split.parent = node
            node.children[key] = split
            split.children[existing.prefix[common]] = existing
            existing.parent = split
            existing = split
        node = existing

    if wild_child:
        node.params_routes.append(handler)
    else:
        node.static_routes.append(handler)
111
+
112
+
113
cdef RouteNode search_node(unicode url):
    """Return the deepest node reached by following *url* down from ROOT.

    Descent stops when *url* is exhausted, when it diverges inside the
    current node's prefix, or when no child continues it.  The returned
    node's prefix therefore need not equal *url*; callers compare it.
    """
    cdef RouteNode current_node = ROOT
    cdef Py_ssize_t n = len(url)
    cdef Py_ssize_t i = get_longest_common_prefix(url, current_node.prefix)
    cdef object child

    # Children are keyed by the character right after the parent's prefix,
    # so only consult them once that whole prefix has matched; otherwise
    # url[i] would be compared against keys for a different position.
    while i < n and i == len(current_node.prefix):
        child = current_node.children.get(url[i])
        if child is None:
            break
        current_node = <RouteNode>child
        i = get_longest_common_prefix(url, current_node.prefix)
    return current_node
126
+
127
+
128
cdef tuple _match_route(unicode route_path, object scope):
    """Find the route that should serve *scope*, whose path is *route_path*.

    Static routes of the node whose prefix equals *route_path* are tried
    first, then parameter routes from the deepest reached node up to ROOT.
    Returns ``(route, child_scope, match_value)`` for the first FULL match,
    else the first PARTIAL match, else ``(None, None, MATCH_NONE)``.
    """
    cdef RouteNode leaf_node = search_node(route_path)
    cdef RouteNode current_node
    cdef object route, match, child_scope
    cdef object partial = None
    cdef object partial_scope = None

    # 1. Static routes: only an exact path match can apply.
    if leaf_node.prefix == route_path:
        for route in leaf_node.static_routes:
            match, child_scope = route.matches(scope)
            if match.value == MATCH_FULL:
                return route, child_scope, MATCH_FULL
            if match.value == MATCH_PARTIAL and partial is None:
                partial, partial_scope = route, child_scope

    # 2. Parameter routes, walking back up towards the root.
    current_node = leaf_node
    while current_node is not None:
        for route in current_node.params_routes:
            match, child_scope = route.matches(scope)
            if match.value == MATCH_FULL:
                return route, child_scope, MATCH_FULL
            if match.value == MATCH_PARTIAL and partial is None:
                partial, partial_scope = route, child_scope
        current_node = current_node.parent

    if partial is None:
        return None, None, MATCH_NONE
    return partial, partial_scope, MATCH_PARTIAL


async def handle(scope, receive, send):
    """ASGI dispatcher that replaces the router's ``middleware_stack``.

    Lifespan events go to ``router.lifespan``.  Other requests are routed via
    the radix tree; if nothing matches and ``redirect_slashes`` is enabled,
    an HTTP request is redirected when the slash-toggled path would match
    any route (as Starlette's Router does); otherwise ``router.default``
    handles it.
    """
    router = scope["app"].router

    if "router" not in scope:
        scope["router"] = router

    if scope["type"] == "lifespan":
        await router.lifespan(scope, receive, send)
        return

    scope["path"] = route_path = starlette_utils.get_route_path(scope)
    route, child_scope, _ = _match_route(route_path, scope)
    if route is not None:
        scope.update(child_scope)
        await route.handle(scope, receive, send)
        return

    if scope["type"] == "http" and router.redirect_slashes and route_path != "/":
        new_path = route_path.rstrip("/") if route_path.endswith("/") else route_path + "/"
        redirect_scope = {**scope, "path": new_path}
        # Search the tree for the *new* path: its node can differ from the
        # original path's node, and parameter routes count as well.
        _, _, match_value = _match_route(new_path, redirect_scope)
        if match_value != MATCH_NONE:
            response = RedirectResponse(url=str(URL(scope=redirect_scope)))
            await response(scope, receive, send)
            return

    await router.default(scope, receive, send)
188
+
189
+
190
ROOT = RouteNode("")


def install(app):
    """Index all of *app*'s routes in the tree and take over its dispatch.

    Sets ``app.router.middleware_stack`` to :func:`handle`.
    """
    from starlette.routing import Mount

    for route in app.routes:
        if isinstance(route, Mount) and find_params(route.path) < 0:
            # A Mount serves every path *below* its prefix, so it must be a
            # prefix (parameter-style) route; as a static route it would only
            # be tried on an exact path match and never reached.
            insert_route(ROOT, route.path, True, route)
        else:
            ROOT.add_route(route.path, route)
    app.router.middleware_stack = handle
@@ -0,0 +1,88 @@
1
+ # cython: language_level=3
2
+ # cython: boundscheck=False
3
+ # cython: wraparound=False
4
+ # cython: initializedcheck=False
5
+ # cython: nonecheck=False
6
+
7
+
8
+ __author__ = "ziyan.yin"
9
+ __describe__ = ""
10
+
11
+
12
+ from libc.stdlib cimport strtol
13
+ from libc.string cimport memmove, strlen
14
+
15
+
16
cdef inline int _hex_value(char c) nogil:
    # Numeric value of the hex digit *c*, or -1 if *c* is not a hex digit.
    if c'0' <= c <= c'9':
        return c - c'0'
    if c'a' <= c <= c'f':
        return c - c'a' + 10
    if c'A' <= c <= c'F':
        return c - c'A' + 10
    return -1


cdef inline size_t _unquote_optimized(char* c_str, bint change_plus) nogil:
    """Percent-decode the NUL-terminated buffer *c_str* in place.

    ``%XY`` with two hex digits becomes the byte 0xXY.  An invalid or
    truncated escape (``%zz``, ``%4``) is copied through unchanged, as
    ``urllib.parse.unquote`` does, instead of being turned into a bogus byte.
    When *change_plus* is true, ``+`` becomes a space.  Decoding never makes
    the data longer, so writing behind the read position is safe.  Returns
    the decoded length; the bytes after it are stale input.
    """
    cdef:
        size_t n = strlen(c_str)
        size_t read_pos = 0
        size_t write_pos = 0
        int hi, lo
        char ch

    while read_pos < n:
        ch = c_str[read_pos]
        read_pos += 1
        if ch == c'+' and change_plus:
            ch = c' '
        elif ch == c'%' and read_pos + 1 < n:
            # Two more bytes are available: decode only if both are hex.
            hi = _hex_value(c_str[read_pos])
            lo = _hex_value(c_str[read_pos + 1])
            if hi >= 0 and lo >= 0:
                ch = <char>((hi << 4) | lo)
                read_pos += 2
        c_str[write_pos] = ch
        write_pos += 1

    return write_pos
45
+
46
def unquote(bytes val, str encoding = "utf-8", str errors = "replace"):
    """Percent-decode *val* and decode the bytes with *encoding*.

    ``+`` is left as is.  *errors* is passed to ``bytes.decode``; it defaults
    to "replace" (as in ``urllib.parse.unquote``) so malformed client input
    such as ``%ff`` cannot raise ``UnicodeDecodeError``.
    """
    if not val: return ""
    # Decode in a private mutable copy: the caller's bytes object is
    # immutable and possibly shared, so it must never be modified in place.
    cdef bytearray tmp = bytearray(val)
    cdef char* c_raw = tmp
    cdef size_t new_len = _unquote_optimized(c_raw, 0)
    # Drop the stale tail in place instead of slicing into another copy.
    del tmp[new_len:]
    return tmp.decode(encoding, errors)
54
+
55
+
56
def unquote_plus(val: bytes, encoding: str = "utf-8", errors: str = "replace") -> str:
    """Like ``unquote`` but also turns ``+`` into a space (form encoding).

    *errors* is passed to ``bytes.decode``; it defaults to "replace" (as in
    ``urllib.parse.unquote_plus``) so malformed input cannot raise.
    """
    if not val: return ""
    # Decode in a private mutable copy: the caller's bytes object is
    # immutable and possibly shared, so it must never be modified in place.
    cdef bytearray tmp = bytearray(val)
    cdef char* c_raw = tmp
    cdef size_t new_len = _unquote_optimized(c_raw, 1)
    # Drop the stale tail in place instead of slicing into another copy.
    del tmp[new_len:]
    return tmp.decode(encoding, errors)
64
+
65
+
66
def parse_qsl(bytes qs, bint keep_blank_values = False):
    """Split a query string into a list of ``(name, value)`` string pairs.

    Follows ``urllib.parse.parse_qsl`` with ``&`` as the only separator:

    * each pair is split at its *first* ``=``, so ``b"a=b=c"`` gives
      ``("a", "b=c")``;
    * a pair with an empty value (``b"a"`` or ``b"a="``) is dropped unless
      *keep_blank_values* is true, in which case its value is ``""``;
    * empty segments (``b"&&"``) are skipped;
    * names and values are decoded with ``unquote_plus``.
    """
    if not qs:
        return []

    cdef list r = []
    cdef bytes name_value, raw_name, raw_value

    for name_value in qs.split(b'&'):
        if not name_value:
            continue
        raw_name, _, raw_value = name_value.partition(b'=')
        if not raw_value and not keep_blank_values:
            continue
        r.append((unquote_plus(raw_name), unquote_plus(raw_value)))
    return r
@@ -223,7 +223,7 @@ class APIResponse(JSONResponse):
223
223
 
224
224
  def init_headers(self, headers: Mapping[str, str] | None = None) -> None:
225
225
  self.raw_headers = [
226
- (b"content-length", str(len(self.body)).encode("latin-1")),
226
+ (b"content-length", f"{len(self.body)}".encode("latin-1")),
227
227
  (b"content-type", b"application/json; charset=utf-8"),
228
228
  ]
229
229
  if headers:
@@ -234,8 +234,6 @@ class APIResponse(JSONResponse):
234
234
  self.raw_headers.extend(raw_headers)
235
235
 
236
236
 
237
-
238
-
239
237
  class APIError(Exception):
240
238
  __slots__ = ("code", "message")
241
239
 
@@ -4,7 +4,7 @@ __date__ = "2024-12-25"
4
4
 
5
5
  import datetime
6
6
  import decimal
7
- from typing import Annotated, Any, TypeVar, Union
7
+ from typing import Annotated, Any, ParamSpec, TypeVar, Union
8
8
 
9
9
  from pydantic import BaseModel, PlainSerializer
10
10
  from sqlmodel import SQLModel
@@ -19,8 +19,10 @@ T = TypeVar("T", bound=Any)
19
19
  E = TypeVar("E", bound=Exception)
20
20
  C = TypeVar("C", bound=Comparable)
21
21
  S = TypeVar("S", bound=Serializable)
22
+ P = ParamSpec("P")
22
23
  Schema = TypeVar("Schema", bound=BaseModel)
23
24
  Model = TypeVar("Model", bound=SQLModel)
25
+ NoneType = type(None)
24
26
 
25
27
  Cursor = Annotated[int, PlainSerializer(lambda x: str(x), return_type=str)]
26
28
  LocalDateTime = Annotated[
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: fastapi-extra
3
- Version: 0.3.5
3
+ Version: 0.4.0
4
4
  Summary: extra package for fastapi.
5
5
  Author-email: Ziyan Yin <408856732@qq.com>
6
6
  License: BSD-3-Clause
@@ -27,6 +27,7 @@ fastapi_extra/database/__init__.py
27
27
  fastapi_extra/database/model.py
28
28
  fastapi_extra/database/service.py
29
29
  fastapi_extra/database/session.py
30
+ fastapi_extra/database/sqlmap.py
30
31
  fastapi_extra/native/cursor.pyx
31
32
  fastapi_extra/native/routing.pyx
32
33
  fastapi_extra/native/urlparse.pyx
@@ -1,191 +0,0 @@
1
- __author__ = "ziyan.yin"
2
- __describe__ = ""
3
-
4
- cimport cython
5
-
6
- from starlette import _utils as starlette_utils
7
- from starlette.datastructures import URL
8
- from starlette.responses import RedirectResponse
9
-
10
-
11
- cdef int find_params(unicode path):
12
- for i, ch in enumerate(path):
13
- if ch == "{":
14
- return i
15
- return -1
16
-
17
-
18
- cdef int get_longest_common_prefix(unicode path, unicode node_path):
19
- cdef int i
20
- cdef int max_len = min(len(path), len(node_path))
21
- for i in range(max_len):
22
- if path[i] != node_path[i]:
23
- return i
24
- return max_len
25
-
26
-
27
- @cython.no_gc
28
- cdef class RouteNode:
29
-
30
- cdef readonly:
31
- unicode prefix
32
- list params_routes
33
- list static_routes
34
- dict children
35
-
36
- cdef public object parent
37
-
38
- def __cinit__(self, prefix: str):
39
- self.prefix = prefix
40
- self.params_routes = []
41
- self.static_routes = []
42
- self.children = {}
43
- self.parent = None
44
-
45
- def add_route(self, fullpath: str, handler: object):
46
- wild_child = False
47
- if (index := find_params(fullpath)) >= 0:
48
- wild_child = True
49
- path = fullpath[:index]
50
- else:
51
- path = fullpath
52
- insert_route(self, path, wild_child, handler)
53
-
54
-
55
- cdef void insert_route(RouteNode node, unicode path, bint wild_child, object handler):
56
- if node.prefix == path:
57
- add_node(node, wild_child, handler)
58
- return
59
-
60
- cdef Py_UCS4 key = path.removeprefix(node.prefix)[0]
61
- if key not in node.children:
62
- add_child_node(node, key, path, wild_child, handler)
63
- return
64
-
65
- child_node = node.children[key]
66
- i = get_longest_common_prefix(child_node.prefix, path)
67
- longest_prefix = child_node.prefix[0: i]
68
- if i == len(child_node.prefix):
69
- insert_route(node.children[key], path, wild_child, handler)
70
- return
71
- next_node = RouteNode.__new__(RouteNode, longest_prefix)
72
- next_node.parent = node
73
- node.children[key] = next_node
74
- next_node.children[child_node.prefix[i]] = child_node
75
- child_node.parent = next_node
76
- insert_route(next_node, path, wild_child, handler)
77
-
78
-
79
- cdef inline void add_child_node(RouteNode node, Py_UCS4 key, unicode path, bint wild_child, object handler):
80
- child = RouteNode.__new__(RouteNode, path)
81
- child.parent = node
82
- add_node(child, wild_child, handler)
83
- node.children[key] = child
84
-
85
-
86
- cdef inline void add_node(RouteNode node, bint wild_child, object handler):
87
- if wild_child:
88
- node.params_routes.append(handler)
89
- else:
90
- node.static_routes.append(handler)
91
-
92
-
93
- root_node = RouteNode.__new__(RouteNode, "")
94
-
95
-
96
- cdef RouteNode search_node(unicode url):
97
- cdef RouteNode current_node = root_node
98
- cdef int n = len(url)
99
- cdef int i = get_longest_common_prefix(url, current_node.prefix)
100
-
101
- while i < n:
102
- key = url[i]
103
- if key not in current_node.children:
104
- break
105
- current_node = current_node.children[key]
106
- i = get_longest_common_prefix(url, current_node.prefix)
107
-
108
- return current_node
109
-
110
-
111
- async def handle(scope, receive, send):
112
- router = scope["app"].router
113
- assert scope["type"] in ("http", "websocket", "lifespan")
114
-
115
- if "router" not in scope:
116
- scope["router"] = router
117
-
118
- if scope["type"] == "lifespan":
119
- await router.lifespan(scope, receive, send)
120
- return
121
-
122
- partial = None
123
-
124
- scope["path"] = route_path = starlette_utils.get_route_path(scope)
125
- leaf_node = search_node(route_path)
126
-
127
- if leaf_node.prefix == route_path:
128
- for route in leaf_node.static_routes:
129
- match, child_scope = route.matches(scope)
130
- if match.value == 2:
131
- scope.update(child_scope)
132
- await route.handle(scope, receive, send)
133
- return
134
- elif match.value == 1 and partial is None:
135
- partial = route
136
- partial_scope = child_scope
137
- else:
138
- current_node = leaf_node
139
- routes = current_node.params_routes
140
- while current_node.parent:
141
- for route in routes:
142
- match, child_scope = route.matches(scope)
143
- if match.value == 2:
144
- scope.update(child_scope)
145
- await route.handle(scope, receive, send)
146
- return
147
- elif match.value == 1 and partial is None:
148
- partial = route
149
- partial_scope = child_scope
150
- current_node = current_node.parent
151
-
152
- if partial is not None:
153
- scope.update(partial_scope)
154
- await partial.handle(scope, receive, send)
155
- return
156
-
157
- if scope["type"] == "http" and router.redirect_slashes and route_path != "/":
158
- redirect_scope = dict(scope)
159
- if route_path.endswith("/"):
160
- redirect_scope["path"] = redirect_scope["path"].rstrip("/")
161
- else:
162
- redirect_scope["path"] = redirect_scope["path"] + "/"
163
-
164
- if leaf_node.prefix == redirect_scope["path"]:
165
- for route in leaf_node.static_routes:
166
- match, child_scope = route.matches(redirect_scope)
167
- if match.value != 0:
168
- redirect_url = URL(scope=redirect_scope)
169
- response = RedirectResponse(url=str(redirect_url))
170
- await response(scope, receive, send)
171
- return
172
- else:
173
- current_node = leaf_node
174
- routes = current_node.params_routes
175
- while current_node.parent:
176
- for route in routes:
177
- if match.value != 0:
178
- redirect_url = URL(scope=redirect_scope)
179
- response = RedirectResponse(url=str(redirect_url))
180
- await response(scope, receive, send)
181
- return
182
- current_node = current_node.parent
183
-
184
- await router.default(scope, receive, send)
185
-
186
-
187
- def install(app):
188
- for route in app.routes:
189
- root_node.add_route(route.path, route)
190
-
191
- app.router.middleware_stack = handle
@@ -1,50 +0,0 @@
1
- __author__ = "ziyan.yin"
2
- __describe__ = ""
3
-
4
-
5
- from libc.stdlib cimport strtol
6
- from libc.string cimport memmove, strlen
7
-
8
-
9
- cdef inline size_t _unquote(char* c_string, bint change_plus):
10
- cdef:
11
- int i = 0
12
- char[2] quote
13
- size_t n = strlen(c_string)
14
-
15
- while i < n:
16
- if c_string[i] == '+' and change_plus:
17
- c_string[i] = ' '
18
- elif c_string[i] == '%':
19
- quote[0] = c_string[i + 1]
20
- quote[1] = c_string[i + 2]
21
- c_string[i] = strtol(quote, NULL, 16)
22
- memmove(c_string + i + 1, c_string + i + 3, n - i - 2)
23
- n -= 2
24
- i += 1
25
- return n
26
-
27
-
28
- def unquote(val: bytes, encoding: str = "utf-8") -> str:
29
- return val[:_unquote(val, 0)].decode(encoding)
30
-
31
-
32
- def unquote_plus(val: bytes, encoding: str = "utf-8") -> str:
33
- return val[:_unquote(val, 1)].decode(encoding)
34
-
35
-
36
- def parse_qsl(qs: bytes, keep_blank_values: bool = False) -> list[tuple[str, str]]:
37
- query_args = qs.split(b'&') if qs else []
38
- r = []
39
- for name_value in query_args:
40
- if not name_value:
41
- continue
42
- nv = name_value.split(b'=')
43
- if len(nv) < 2:
44
- if not keep_blank_values:
45
- continue
46
- nv.append(b'')
47
- name = unquote_plus(nv[0])
48
- value = unquote_plus(nv[1])
49
- r.append((name, value))
50
- return r
File without changes
File without changes
File without changes