u-toolkit 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
u_toolkit/__init__.py ADDED
@@ -0,0 +1,9 @@
1
def get(obj: dict, key: str, fallback=None):
    """Look up a possibly nested value using a dot-separated key.

    ``get({"a": {"b": 1}}, "a.b")`` walks ``obj["a"]["b"]``. A key without a
    dot behaves like ``obj.get(key, fallback)``.

    :param obj: mapping to read from
    :param key: a key, or a dot-separated path of keys
    :param fallback: returned when a segment is missing or an intermediate
        value is not a mapping
    :return: the value found at the path, otherwise ``fallback``
    """
    from collections.abc import Mapping

    current = obj
    # Descend one segment at a time; the previous implementation returned
    # after resolving only the first segment.
    for part in key.split("."):
        if not isinstance(current, Mapping) or part not in current:
            return fallback
        current = current[part]
    return current
@@ -0,0 +1,49 @@
1
+ from typing import TypeVarTuple, overload
2
+
3
+ from pydantic import alias_generators
4
+ from pydantic.fields import ComputedFieldInfo, FieldInfo
5
+
6
+
7
# NOTE(review): not referenced by any code visible in this module.
Ts = TypeVarTuple("Ts")
8
+
9
+
10
@overload
def to_camel(string: str, _: ComputedFieldInfo | FieldInfo) -> str: ...


@overload
def to_camel(string: str) -> str: ...


def to_camel(string: str, *_, **__):
    """Convert ``string`` to camelCase, keeping all-uppercase names as they are.

    Any extra positional or keyword arguments are accepted and ignored.

    :param string: name to convert
    :return: ``string`` unchanged when ``str.isupper()`` is true, otherwise
        the result of pydantic's ``alias_generators.to_camel``
    """
    return string if string.isupper() else alias_generators.to_camel(string)
22
+
23
+
24
@overload
def to_snake(string: str, _: ComputedFieldInfo | FieldInfo) -> str: ...


@overload
def to_snake(string: str) -> str: ...


def to_snake(string: str, *_, **__):
    """Convert ``string`` to snake_case, keeping all-uppercase names as they are.

    Any extra positional or keyword arguments are accepted and ignored.

    :param string: name to convert
    :return: ``string`` unchanged when ``str.isupper()`` is true, otherwise
        the result of pydantic's ``alias_generators.to_snake``
    """
    return string if string.isupper() else alias_generators.to_snake(string)
36
+
37
+
38
@overload
def to_pascal(string: str, _: ComputedFieldInfo | FieldInfo) -> str: ...


@overload
def to_pascal(string: str) -> str: ...


def to_pascal(string: str, *_, **__):
    """Convert ``string`` to PascalCase, keeping all-uppercase names as they are.

    Any extra positional or keyword arguments are accepted and ignored.

    :param string: name to convert
    :return: ``string`` unchanged when ``str.isupper()`` is true, otherwise
        the result of pydantic's ``alias_generators.to_pascal``
    """
    return string if string.isupper() else alias_generators.to_pascal(string)
u_toolkit/datetime.py ADDED
@@ -0,0 +1,30 @@
1
+ from datetime import UTC, datetime
2
+
3
+
4
+ def to_utc(dt: datetime, /) -> datetime:
5
+ """转成 UTC 时间
6
+
7
+ :param dt: 时间
8
+ :return: UTC 时间
9
+ """
10
+ if dt.tzinfo is None:
11
+ return dt.replace(tzinfo=UTC)
12
+ return dt.astimezone(UTC)
13
+
14
+
15
def to_naive(dt: datetime, /) -> datetime:
    """Drop the timezone marker from ``dt`` without shifting its clock fields.

    :param dt: datetime to strip
    :return: the same wall-clock time with ``tzinfo`` set to ``None``
    """
    stripped = dt.replace(tzinfo=None)
    return stripped
22
+
23
+
24
def to_utc_naive(dt: datetime, /) -> datetime:
    """Convert ``dt`` to UTC and then drop the timezone marker.

    :param dt: datetime to convert
    :return: naive datetime holding the UTC wall-clock time
    """
    utc_dt = to_utc(dt)
    return to_naive(utc_dt)
@@ -0,0 +1,46 @@
1
+ from collections.abc import Callable
2
+ from functools import wraps
3
+ from typing import Generic, NamedTuple, TypeVar
4
+
5
+ from u_toolkit.signature import list_parameters, update_parameters
6
+
7
+
8
# Type of the function wrapped by a DefineMethodDecorator.
_FnT = TypeVar("_FnT", bound=Callable)

# Type of the class that owns the decorated method.
_T = TypeVar("_T")
11
+
12
+
13
class DefineMethodParams(NamedTuple, Generic[_T, _FnT]):
    """Values passed to ``DefineMethodDecorator.register_method``."""

    # Class whose body contains the decorated method.
    method_class: type[_T]
    # Attribute name the method was assigned to in that class.
    method_name: str
    # The undecorated function.
    method: _FnT
17
+
18
+
19
class DefineMethodDecorator(Generic[_T, _FnT]):
    """Descriptor that wraps a method and reports it to ``register_method``.

    ``__set_name__`` runs when the owning class is created and forwards the
    owner class, attribute name and function to ``register_method``.
    Subclasses override ``register_method`` to react to that registration.
    """

    def __init__(self, fn: _FnT):
        self.fn = fn
        self.name = fn.__name__

    def register_method(self, params: DefineMethodParams[_T, _FnT]): ...

    def __set_name__(self, owner_class: type, name: str):
        self.register_method(DefineMethodParams(owner_class, name, self.fn))

    def __get__(self, instance: _T, owner_class: type[_T]):
        """Return a wrapper that calls ``fn`` with ``instance`` as first argument.

        The wrapper's reported signature omits the first parameter of ``fn``.
        """
        import inspect

        @wraps(self.fn)
        def wrapper(*args, **kwargs):
            return self.fn(instance, *args, **kwargs)

        # Set the trimmed signature on the wrapper only. Rewriting the
        # signature of ``self.fn`` itself (as before) stripped one more
        # leading parameter on every attribute access.
        signature = inspect.signature(self.fn)
        remaining = list(signature.parameters.values())[1:]
        wrapper.__signature__ = signature.replace(parameters=remaining)  # type: ignore

        return wrapper
37
+
38
+
39
def define_method_handler(
    handle: Callable[[DefineMethodParams[_T, _FnT]], None],
):
    """Create a ``DefineMethodDecorator`` subclass that reports to ``handle``.

    :param handle: callback invoked with the ``DefineMethodParams`` of each
        decorated method
    :return: the decorator class
    """

    class Decorator(DefineMethodDecorator):
        def register_method(self, params: DefineMethodParams):
            handle(params)

    return Decorator
u_toolkit/enum.py ADDED
@@ -0,0 +1,43 @@
1
+ from enum import StrEnum, auto
2
+
3
+ from .alias_generators import to_camel, to_pascal, to_snake
4
+
5
+
6
# Public names of this module; ``auto`` is re-exported from the standard
# ``enum`` module.
__all__ = [
    "CamelEnum",
    "NameEnum",
    "PascalEnum",
    "SnakeEnum",
    "TitleEnum",
    "auto",
]
14
+
15
+
16
class NameEnum(StrEnum):
    """String enum whose ``auto()`` value is the member name, unchanged."""

    @staticmethod
    def _generate_next_value_(name, *_, **__) -> str:
        return name
20
+
21
+
22
class PascalEnum(StrEnum):
    """String enum whose ``auto()`` value is ``to_pascal(member name)``."""

    @staticmethod
    def _generate_next_value_(name, *_, **__) -> str:
        return to_pascal(name)
26
+
27
+
28
class CamelEnum(StrEnum):
    """String enum whose ``auto()`` value is ``to_camel(member name)``."""

    @staticmethod
    def _generate_next_value_(name, *_, **__) -> str:
        return to_camel(name)
32
+
33
+
34
class SnakeEnum(StrEnum):
    """String enum whose ``auto()`` value is ``to_snake(member name)``."""

    @staticmethod
    def _generate_next_value_(name, *_, **__) -> str:
        return to_snake(name)
38
+
39
+
40
+ class TitleEnum(StrEnum):
41
+ @staticmethod
42
+ def _generate_next_value_(name, *_, **__) -> str:
43
+ return name.replace("_", " ").title()
File without changes
@@ -0,0 +1,342 @@
1
+ import inspect
2
+ import re
3
+ from collections.abc import Callable
4
+ from enum import Enum, StrEnum, auto
5
+ from functools import partial, update_wrapper, wraps
6
+ from typing import Any, Literal, NamedTuple, Protocol, Self, TypeVar, cast
7
+
8
+ from fastapi import APIRouter, Depends
9
+ from pydantic.alias_generators import to_snake
10
+
11
+ from u_toolkit.decorators import DefineMethodParams, define_method_handler
12
+ from u_toolkit.fastapi.helpers import get_depend_from_annotation, is_depend
13
+ from u_toolkit.fastapi.responses import Response, build_responses
14
+ from u_toolkit.helpers import is_annotated
15
+ from u_toolkit.merge import deep_merge_dict
16
+ from u_toolkit.signature import update_parameters, with_parameter
17
+
18
+
19
class EndpointsClassInterface(Protocol):
    """Optional class-level settings read from a class decorated by ``CBV``."""

    # Dependencies added to every route of the class.
    dependencies: tuple | None = None
    # Responses added to every route of the class.
    responses: tuple[Response, ...] | None = None
    # NOTE(review): declared here but not read by the code visible in this file.
    prefix: str | None = None
    # Tags added to every route of the class.
    tags: tuple[str | Enum, ...] | None = None
    # Default ``deprecated`` flag for the class's routes.
    deprecated: bool | None = None

    @classmethod
    def build_self(cls) -> Self: ...
28
+
29
+
30
_T = TypeVar("_T")
EndpointsClassInterfaceT = TypeVar(
    "EndpointsClassInterfaceT",
    bound=EndpointsClassInterface,
)


# HTTP method names accepted in upper case.
LiteralUpperMethods = Literal[
    "GET",
    "POST",
    "PATCH",
    "PUT",
    "DELETE",
    "OPTIONS",
    "HEAD",
    "TRACE",
]
# The same HTTP method names in lower case.
LiteralLowerMethods = Literal[
    "get",
    "post",
    "patch",
    "put",
    "delete",
    "options",
    "head",
    "trace",
]
57
+
58
+
59
class Methods(StrEnum):
    """HTTP methods recognised as endpoint name prefixes.

    With ``StrEnum``, ``auto()`` produces the lower-cased member name.
    """

    GET = auto()
    POST = auto()
    PATCH = auto()
    PUT = auto()
    DELETE = auto()
    OPTIONS = auto()
    HEAD = auto()
    TRACE = auto()
68
+
69
+
70
# One case-insensitive pattern per HTTP method, anchored at the start of the
# name, used to detect endpoint functions by their name prefix.
METHOD_PATTERNS = {
    method: re.compile(f"^{method}", re.IGNORECASE) for method in Methods
}

# Alias used in type hints for an endpoint function name.
_FnName = str
75
+
76
+
77
class EndpointInfo(NamedTuple):
    """Description of one endpoint discovered by ``iter_endpoints``."""

    # The class member found under ``original_name``.
    fn: Callable
    # Attribute name of the member on the class.
    original_name: str
    # HTTP method matched from the name prefix.
    method: Methods
    # Pattern that matched the name.
    method_pattern: re.Pattern
    # Route path built for the endpoint.
    path: str
83
+
84
+
85
def get_method(name: str):
    """Find the HTTP method whose name prefixes ``name`` (case-insensitive).

    :param name: function name to inspect
    :return: ``(method, pattern)`` for the first entry of ``METHOD_PATTERNS``
        that matches, or ``None`` when nothing matches
    """
    matches = (
        (method, method_pattern)
        for method, method_pattern in METHOD_PATTERNS.items()
        if method_pattern.search(name)
    )
    return next(matches, None)
90
+
91
+
92
def valid_endpoint(name: str):
    """Raise unless ``name`` starts with an HTTP method name.

    :param name: function name to check
    :raises ValueError: when ``get_method(name)`` finds no method
    """
    if get_method(name) is not None:
        return
    raise ValueError("Invalid endpoint function.")
95
+
96
+
97
def iter_endpoints(cls: type[_T]):
    """Yield an ``EndpointInfo`` for each member of ``cls`` named after an HTTP method.

    The path is ``/<snake_case class name>`` followed by whatever remains of
    the member name after the method prefix, with ``__`` turned into ``/``.
    The class segment is left out when the class name starts with ``_``.

    :param cls: class to scan
    :return: generator of ``EndpointInfo``
    """
    class_segments: list[str] = []
    if not cls.__name__.startswith("_"):
        class_segments.append(to_snake(cls.__name__))

    for name, fn in inspect.getmembers(
        cls,
        lambda arg: inspect.ismethoddescriptor(arg) or inspect.isfunction(arg),
    ):
        found = get_method(name)
        if found is None:
            continue
        http_method, pattern = found

        # ``Pattern.sub`` takes (repl, string). The previous call had the two
        # swapped, so the suffix was always empty.
        # NOTE(review): the underscore separating the method prefix from the
        # rest of the name is dropped (``get_users`` -> ``users``) — confirm
        # this is the intended route shape.
        suffix = pattern.sub("", name).lstrip("_").replace("__", "/").strip("/")
        segments = [*class_segments, suffix] if suffix else class_segments

        yield EndpointInfo(
            fn=fn,
            original_name=name,
            # Build from segments so a missing class segment cannot give "//".
            path="/" + "/".join(segments),
            method=http_method,
            method_pattern=pattern,
        )
121
+
122
+
123
def iter_dependencies(cls: type[_T]):
    """Yield ``(name, dependency)`` for class-level dependencies, in source order.

    Dependencies come from class attributes for which ``is_depend`` is true
    and from ``Annotated`` class annotations that carry a dependency. The
    class source is then scanned line by line, and a pair is yielded whenever
    the first token of a line equals a collected name.

    :param cls: class to inspect; its source must be available to
        ``inspect.getsource``
    """
    splitter = re.compile(r"\s+|:|=")
    dependencies: dict = dict(inspect.getmembers(cls, is_depend))
    for name, type_ in inspect.get_annotations(cls).items():
        if not is_annotated(type_):
            continue
        try:
            dependencies[name] = get_depend_from_annotation(type_)
        except ValueError:
            # The Annotated metadata holds no dependency; such an attribute
            # is not a dependency and should not abort collection.
            continue

    for line in inspect.getsource(cls).split("\n"):
        token: str = splitter.split(line.strip(), 1)[0]
        if token in dependencies:
            yield token, dependencies[token]
136
+
137
+
138
# Keys of the per-endpoint settings that ``CBV.info`` stores.
_CBVEndpointParamName = Literal[
    "tags",
    "dependencies",
    "responses",
    "response_model",
    "status",
    "deprecated",
    "methods",
]
147
+
148
+
149
class CBV:
    """Register the HTTP-method-named methods of a class as routes on a router.

    Use an instance as a class decorator. ``info`` returns a method decorator
    that records per-endpoint settings for the same class.
    """

    def __init__(self, router: APIRouter | None = None) -> None:
        self.router = router or APIRouter()

        # Per class, per endpoint function name: the settings given to info().
        self._state: dict[
            type[EndpointsClassInterface],
            dict[_FnName, dict[_CBVEndpointParamName, Any]],
        ] = {}

        # Per class: the instance built by _initial_state.
        self._initialed_state: dict[
            type[EndpointsClassInterface], EndpointsClassInterface
        ] = {}

    def create_route(
        self,
        *,
        cls: type[EndpointsClassInterfaceT],
        path: str,
        method: Methods | LiteralUpperMethods | LiteralLowerMethods,
        method_name: str,
    ):
        """Return the ``router.api_route`` decorator for one endpoint.

        Class-level tags, dependencies and responses are combined with the
        endpoint-level ones stored for ``method_name``.
        """
        endpoint_state = self._state[cls][method_name]

        class_tags = list(cls.tags) if cls.tags else []
        endpoint_tags: list[str | Enum] = endpoint_state.get("tags") or []
        tags = class_tags + endpoint_tags

        class_dependencies = list(cls.dependencies) if cls.dependencies else []
        endpoint_dependencies = endpoint_state.get("dependencies") or []
        dependencies = class_dependencies + endpoint_dependencies

        class_responses = cls.responses or []
        endpoint_responses = endpoint_state.get("responses") or []
        responses = build_responses(*class_responses, *endpoint_responses)

        status_code = endpoint_state.get("status")

        # info() always stores the key, with None when not given, so a plain
        # ``.get(key, default)`` never reached the class-level value.
        deprecated = endpoint_state.get("deprecated")
        if deprecated is None:
            deprecated = cls.deprecated

        response_model = endpoint_state.get("response_model")

        endpoint_methods = endpoint_state.get("methods") or [method]

        return self.router.api_route(
            path,
            methods=endpoint_methods,
            tags=tags,
            dependencies=dependencies,
            response_model=response_model,
            responses=responses,
            status_code=status_code,
            deprecated=deprecated,
        )

    def info(
        self,
        *,
        methods: list[Methods | LiteralUpperMethods | LiteralLowerMethods]
        | None = None,
        tags: list[str | Enum] | None = None,
        dependencies: list | None = None,
        responses: list[Response] | None = None,
        response_model: Any | None = None,
        status: int | None = None,
        deprecated: bool | None = None,
    ):
        """Return a method decorator that records settings for one endpoint.

        The settings are merged into the state of the owning class when that
        class is created.
        """
        state = self._state
        initial_state = self._initial_state
        data: dict[_CBVEndpointParamName, Any] = {
            "methods": methods,
            "tags": tags,
            "dependencies": dependencies,
            "responses": responses,
            "response_model": response_model,
            "status": status,
            "deprecated": deprecated,
        }

        def handle(params: DefineMethodParams):
            initial_state(params.method_class)
            deep_merge_dict(
                state,
                {params.method_class: {params.method_name: data}},
            )

        return define_method_handler(handle)

    def _initial_state(self, cls: type[_T]) -> EndpointsClassInterface:
        """Prepare ``cls`` once and return the instance built for it."""
        if result := self._initialed_state.get(cls):  # type: ignore
            return result

        self._update_cls(cls)
        n_cls = cast(type[EndpointsClassInterface], cls)

        default_data = {}
        for endpoint in iter_endpoints(n_cls):
            default_data[endpoint.original_name] = {}

        self._state.setdefault(n_cls, default_data)
        result = self._build_cls(n_cls)
        self._initialed_state[n_cls] = result
        return result

    def _update_cls(self, cls: type[_T]):
        """Set every missing ``EndpointsClassInterface`` attribute to ``None``."""
        for extra_name in EndpointsClassInterface.__annotations__:
            if not hasattr(cls, extra_name):
                setattr(cls, extra_name, None)

        # TODO: when the attribute already exists, validate that its type is
        # the expected one.

    def _build_cls(self, cls: type[_T]) -> _T:
        """Instantiate ``cls``.

        ``build_self`` is used when the class defines its own ``__init__``
        and provides ``build_self``; otherwise ``cls()`` is called.
        """
        if inspect.isfunction(cls.__init__) and hasattr(cls, "build_self"):
            return cast(type[EndpointsClassInterface], cls).build_self()  # type: ignore
        return cls()

    def __create_class_dependencies_injector(
        self, cls: type[EndpointsClassInterfaceT]
    ):
        """Attach the class-level dependencies to the instance serving a call.

        ```python
        @cbv
        class A:
            a = Depends(lambda: id(object()))

            def get(self):
                # self.a resolves to the dependency value of the current
                # request
                print(self.a)
        ```
        """

        def collect_cls_dependencies(**kwargs):
            return kwargs

        parameters = [
            inspect.Parameter(
                name=name,
                kind=inspect.Parameter.POSITIONAL_OR_KEYWORD,
                default=dep,
            )
            for name, dep in iter_dependencies(cls)
        ]
        update_parameters(collect_cls_dependencies, *parameters)

        def decorator(method: Callable):
            sign_fn = partial(method)
            update_wrapper(sign_fn, method)

            parameters, *_ = with_parameter(
                method,
                name=collect_cls_dependencies.__name__,
                default=Depends(collect_cls_dependencies),
            )
            update_parameters(sign_fn, *parameters)

            def call_endpoint(*args, **kwargs):
                # A fresh instance per call keeps request-specific dependency
                # values from being shared between calls.
                instance = self._build_cls(cls)
                dependencies = kwargs.pop(collect_cls_dependencies.__name__)
                for dep_name, dep_value in dependencies.items():
                    setattr(instance, dep_name, dep_value)
                fn = getattr(instance, method.__name__)
                return fn(*args, **kwargs)

            # A sync wrapper around a coroutine function would hand the
            # un-awaited coroutine back to the framework, so async endpoint
            # methods get an async wrapper.
            if inspect.iscoroutinefunction(inspect.unwrap(method)):

                @wraps(sign_fn)
                async def wrapper(*args, **kwargs):
                    return await call_endpoint(*args, **kwargs)

            else:

                @wraps(sign_fn)
                def wrapper(*args, **kwargs):
                    return call_endpoint(*args, **kwargs)

            return wrapper

        return decorator

    def __call__(self, cls: type[_T]) -> type[_T]:
        """Register every endpoint of ``cls`` on the router and return ``cls``."""
        instance = self._initial_state(cls)
        cls_ = cast(type[EndpointsClassInterface], cls)

        decorator = self.__create_class_dependencies_injector(cls_)

        for endpoint_info in iter_endpoints(cls):
            route = self.create_route(
                cls=cast(type[EndpointsClassInterface], cls),
                path=endpoint_info.path,
                method=endpoint_info.method,
                method_name=endpoint_info.original_name,
            )
            method = getattr(instance, endpoint_info.original_name)
            endpoint = decorator(method)
            route(endpoint)

        return cls
@@ -0,0 +1,9 @@
1
+ from pydantic_settings import BaseSettings, SettingsConfigDict
2
+
3
+
4
class FastAPISettings(BaseSettings):
    # Settings base class: variables use the FASTAPI_ prefix, values are also
    # read from the listed dotenv files, and unknown keys are ignored.
    # NOTE(review): when several env files exist, check the pydantic-settings
    # documentation for which one takes priority.
    model_config = SettingsConfigDict(
        env_prefix="FASTAPI_",
        env_file=(".env", ".env.prod", ".env.dev", ".env.test"),
        extra="ignore",
    )
@@ -0,0 +1,115 @@
1
+ from collections.abc import Callable, Sequence
2
+ from typing import Generic, Literal, Protocol, cast
3
+
4
+ from fastapi import status
5
+ from pydantic import BaseModel, create_model
6
+
7
+ from u_toolkit.pydantic.type_vars import BaseModelT
8
+
9
+
10
class WrapperError(BaseModel, Generic[BaseModelT]):  # type: ignore
    # Base class for response envelopes that wrap an error model.
    # Subclasses must override ``create``.
    @classmethod
    def create(
        cls: type["WrapperError[BaseModelT]"],
        model: BaseModelT,
    ) -> "WrapperError[BaseModelT]":
        raise NotImplementedError
17
+
18
+
19
class EndpointError(WrapperError, Generic[BaseModelT]):
    # Envelope that places the error model under the ``error`` key.
    error: BaseModelT

    @classmethod
    def create(cls, model: BaseModelT):
        return cls(error=model)
25
+
26
+
27
class HTTPErrorInterface(Protocol):
    """Structural type for error classes usable in response declarations."""

    # HTTP status code of the error.
    status: int

    @classmethod
    def response_class(cls) -> type[BaseModel]: ...
32
+
33
+
34
+ class NamedHTTPError(Exception, Generic[BaseModelT]):
35
+ status: int = status.HTTP_400_BAD_REQUEST
36
+ code: str | None = None
37
+ targets: Sequence[str] | None = None
38
+ target_transform: Callable[[str], str] | None = None
39
+ message: str | None = None
40
+ wrapper_class: type[WrapperError[BaseModelT]] | None = EndpointError
41
+
42
+ @classmethod
43
+ def error_name(cls):
44
+ return cls.__name__.removesuffix("Error")
45
+
46
+ @classmethod
47
+ def model_class(cls) -> type[BaseModelT]:
48
+ type_ = cls.error_name()
49
+ error_code = cls.code or type_
50
+ kwargs = {
51
+ "code": (Literal[error_code], ...),
52
+ "message": (Literal[cls.message] if cls.message else str, ...),
53
+ }
54
+ if cls.targets:
55
+ kwargs["target"] = (Literal[*cls.transformed_targets()], ...)
56
+
57
+ return cast(type[BaseModelT], create_model(f"{type_}Model", **kwargs))
58
+
59
+ @classmethod
60
+ def error_code(cls):
61
+ return cls.code or cls.error_name()
62
+
63
+ @classmethod
64
+ def transformed_targets(cls) -> list[str]:
65
+ if cls.targets:
66
+ result = []
67
+ for i in cls.targets:
68
+ if cls.target_transform:
69
+ result.append(cls.target_transform(i))
70
+ else:
71
+ result.append(i)
72
+ return result
73
+ return []
74
+
75
+ def __init__(
76
+ self,
77
+ *,
78
+ message: str,
79
+ target: str | None = None,
80
+ headers: dict[str, str] | None = None,
81
+ ) -> None:
82
+ kwargs = {
83
+ "code": self.error_code(),
84
+ "message": message,
85
+ }
86
+
87
+ if target:
88
+ if self.target_transform:
89
+ target = self.target_transform(target)
90
+ kwargs["target"] = target
91
+ kwargs["message"] = kwargs["message"].format(target=target)
92
+
93
+ self.model = self.model_class()(**kwargs)
94
+ self.data: BaseModel = (
95
+ self.wrapper_class.create(self.model)
96
+ if self.wrapper_class is not None
97
+ else self.model
98
+ )
99
+
100
+ self.headers = headers
101
+
102
+ def __str__(self) -> str:
103
+ return f"{self.status}: {self.data.error.code}" # type: ignore
104
+
105
+ def __repr__(self) -> str:
106
+ return f"{self.model_class: str(self.error)}"
107
+
108
+ @classmethod
109
+ def response_class(cls):
110
+ model = cls.model_class()
111
+ return cls.wrapper_class if cls.wrapper_class is not None else model
112
+
113
+ @classmethod
114
+ def response_schema(cls):
115
+ return {cls.status: {"model": cls.response_class()}}
@@ -0,0 +1,18 @@
1
+ from typing import Annotated, Any, get_args
2
+
3
+ from fastapi.params import Depends
4
+
5
+
6
def is_depend(value: Any):
    """Return whether ``value`` is an instance of ``fastapi.params.Depends``."""
    return isinstance(value, Depends)
8
+
9
+
10
def get_depend_from_annotation(annotation: Annotated):
    """Return the last dependency found in the arguments of an annotation.

    :param annotation: an ``Annotated[...]`` type
    :return: the right-most argument for which ``is_depend`` is true
    :raises ValueError: when no argument is a dependency
    """
    # Walk the arguments from the end: FastAPI appears to use the last
    # dependency as well, so the order is reversed to match.
    for candidate in reversed(get_args(annotation)):
        if is_depend(candidate):
            return candidate

    raise ValueError
@@ -0,0 +1,67 @@
1
+ import asyncio
2
+ from collections.abc import Awaitable, Callable, Coroutine
3
+ from contextlib import (
4
+ AbstractAsyncContextManager,
5
+ AbstractContextManager,
6
+ AsyncExitStack,
7
+ asynccontextmanager,
8
+ )
9
+ from typing import TypeVar
10
+
11
+ from fastapi import FastAPI
12
+
13
+
14
# A hook receives the application and returns either nothing or something
# to await.
Hook = Callable[
    [FastAPI],
    Awaitable[None] | Coroutine[None, None, None] | None,
]

HookT = TypeVar("HookT", bound=Hook)

# A context factory receives the application and returns a sync or async
# context manager.
ContextManager = Callable[
    [FastAPI],
    AbstractContextManager | AbstractAsyncContextManager,
]

ContextManagerT = TypeVar("ContextManagerT", bound=ContextManager)
27
+
28
+
29
class Lifespan:
    """Collect startup hooks, shutdown hooks and context managers for an app.

    Calling the instance with the application gives an async context manager
    that runs the startup hooks, enters the registered context managers,
    yields, and then runs the shutdown hooks.
    """

    def __init__(self) -> None:
        self._startup_hooks: list[Hook] = []
        self._shutdown_hooks: list[Hook] = []
        self._context_managers: list[ContextManager] = []

    def on_startup(self, fn: HookT) -> HookT:
        """Register ``fn`` as a startup hook and return it unchanged."""
        self._startup_hooks.append(fn)
        return fn

    def on_shutdown(self, fn: HookT) -> HookT:
        """Register ``fn`` as a shutdown hook and return it unchanged."""
        self._shutdown_hooks.append(fn)
        return fn

    def on_context(self, fn: ContextManagerT) -> ContextManagerT:
        """Register ``fn`` as a context-manager factory and return it unchanged."""
        self._context_managers.append(fn)
        return fn

    @asynccontextmanager
    async def __call__(self, _app: FastAPI):
        import inspect

        for hook in self._startup_hooks:
            ret = hook(_app)
            # ``Hook`` allows any awaitable, not only coroutine objects.
            if inspect.isawaitable(ret):
                await ret

        try:
            async with AsyncExitStack() as stack:
                for ctx in self._context_managers:
                    i = ctx(_app)
                    if isinstance(i, AbstractContextManager):
                        stack.enter_context(i)
                    elif isinstance(i, AbstractAsyncContextManager):
                        await stack.enter_async_context(i)

                yield
        finally:
            # Run the shutdown hooks even when the body raised.
            # NOTE(review): the hooks run after the context managers have
            # exited; the original nesting could not be recovered from the
            # flattened source — confirm the intended order.
            for hook in self._shutdown_hooks:
                ret = hook(_app)
                if inspect.isawaitable(ret):
                    await ret
@@ -0,0 +1,79 @@
1
+ from collections.abc import Sequence
2
+ from math import ceil
3
+ from typing import Annotated, Generic, overload
4
+
5
+ from fastapi import Depends, Query
6
+ from pydantic import BaseModel, Field, NonNegativeInt, PositiveInt
7
+ from sqlalchemy import MappingResult, ScalarResult
8
+
9
+ from u_toolkit.pydantic.type_vars import BaseModelT
10
+ from u_toolkit.sqlalchemy.type_vars import DeclarativeBaseT
11
+
12
+
13
class Page(BaseModel, Generic[BaseModelT]):
    # One page of results together with the paging figures.
    page_size: PositiveInt = Field(description="page size")
    page_no: PositiveInt = Field(description="page number")

    # Number of pages and number of results overall.
    page_count: NonNegativeInt = Field(description="page count")
    count: NonNegativeInt = Field(description="result count")

    results: list[BaseModelT] = Field(description="results")
21
+
22
+
23
class PageParamsModel(BaseModel):
    # Paging query parameters: page_size defaults to 50 and must be within
    # 1..100; page_no defaults to 1.
    page_size: int = Query(
        50,
        ge=1,
        le=100,
        description="page size",
    )
    page_no: PositiveInt = Query(
        1,
        description="page number",
    )
34
+
35
+
36
# Annotation for endpoint parameters that receive PageParamsModel as a
# dependency.
PageParams = Annotated[PageParamsModel, Depends()]
37
+
38
+
39
@overload
def page(
    model_class: type[BaseModelT],
    pagination: PageParamsModel,
    count: int,
    results: ScalarResult[DeclarativeBaseT],
) -> Page[BaseModelT]: ...


@overload
def page(
    model_class: type[BaseModelT],
    pagination: PageParamsModel,
    count: int,
    results: MappingResult,
) -> Page[BaseModelT]: ...


@overload
def page(
    model_class: type[BaseModelT],
    pagination: PageParamsModel,
    count: int,
    results: Sequence[DeclarativeBaseT],
) -> Page[BaseModelT]: ...


def page(
    model_class: type[BaseModelT],
    pagination: PageParamsModel,
    count: int,
    results,
) -> Page[BaseModelT]:
    """Build a ``Page`` from raw results.

    :param model_class: model each result is validated into
    :param pagination: requested page size and page number
    :param count: total number of results
    :param results: iterable of rows for the current page
    :return: page whose ``page_count`` is ``ceil(count / page_size)``
    """
    validated = []
    for row in results:
        validated.append(model_class.model_validate(row))

    size = pagination.page_size
    total_pages = ceil(count / size)
    return Page(
        page_size=size,
        page_no=pagination.page_no,
        page_count=total_pages,
        count=count,
        results=validated,  # type: ignore
    )
@@ -0,0 +1,66 @@
1
+ from pydantic import BaseModel
2
+
3
+ from u_toolkit.merge import deep_merge_dict
4
+
5
+ from .exception import HTTPErrorInterface
6
+
7
+
8
+ def _merge_responses(
9
+ target: dict,
10
+ source: dict,
11
+ ):
12
+ for status, response in target.items():
13
+ model_class = response.get("model")
14
+ if status in source:
15
+ source_model_class = source[status].get("model")
16
+ if source_model_class is not None:
17
+ target[status]["model"] = model_class | source_model_class
18
+
19
+ for status, response in source.items():
20
+ if status not in target:
21
+ target[status] = response
22
+
23
+
24
def error_responses(*errors: type[HTTPErrorInterface]):
    """Group the response classes of ``errors`` by status code.

    :param errors: error classes exposing ``status`` and ``response_class()``
    :return: ``{status: {"model": ...}}``; classes sharing a status are
        combined with ``|``
    """
    merged = {}

    for error in errors:
        model_class = error.response_class()
        existing = merged.get(error.status)
        if existing is None:
            merged[error.status] = {"model": model_class}
        else:
            current: type[BaseModel] = existing["model"]
            merged[error.status] = {"model": current | model_class}
    return merged
36
+
37
+
38
# Forms accepted by build_responses: a (status, model) pair, a
# {status: model} mapping, a bare status code, or an error class.
Response = (
    tuple[int, type[BaseModel]]
    | dict[int, type[BaseModel]]
    | int
    | type[HTTPErrorInterface]
)
44
+
45
+
46
def build_responses(*responses: Response):
    """Turn response declarations into a ``{status: {"model": ...}}`` mapping.

    :param responses: ``(status, model)`` tuples, ``{status: model}`` dicts,
        bare status codes, or error classes
    :return: mapping with the error classes merged in by status
    """
    result = {}
    errors: list[type[HTTPErrorInterface]] = []

    for arg in responses:
        if isinstance(arg, tuple):
            status, response = arg
            result[status] = {"model": response}
        elif isinstance(arg, dict):
            for status_, response_ in arg.items():
                result[status_] = {"model": response_}
        elif isinstance(arg, int):
            # A bare status keeps the empty placeholder model it had before.
            result[arg] = {"model": {}}
        else:
            # Error classes are merged below. Nothing is written here:
            # previously dict and error arguments also stored an entry under
            # the key ``None``.
            errors.append(arg)

    _merge_responses(result, error_responses(*errors))
    return result
u_toolkit/function.py ADDED
@@ -0,0 +1,12 @@
1
+ from collections.abc import Callable
2
+
3
+
4
def get_name(fn: Callable, /):
    """Return the ``__name__`` of ``fn``."""
    name = fn.__name__
    return name
6
+
7
+
8
def add_document(fn: Callable, document: str, /):
    """Append ``document`` to the docstring of ``fn``.

    :param fn: callable whose ``__doc__`` is updated in place
    :param document: text to add; separated from an existing docstring by a
        blank line
    """
    existing = fn.__doc__
    fn.__doc__ = document if existing is None else f"{existing}\n\n{document}"
u_toolkit/helpers.py ADDED
@@ -0,0 +1,5 @@
1
+ from typing import Annotated, get_origin
2
+
3
+
4
def is_annotated(target):
    """Return whether ``target`` is an ``Annotated[...]`` type."""
    origin = get_origin(target)
    return origin is Annotated
u_toolkit/logger.py ADDED
@@ -0,0 +1,11 @@
1
+ from u_toolkit.enum import NameEnum, auto
2
+
3
+
4
class LogLevel(NameEnum):
    """Log level names; each value is generated from the member name."""

    TRACE = auto()
    DEBUG = auto()
    INFO = auto()
    SUCCESS = auto()
    WARNING = auto()
    ERROR = auto()
    CRITICAL = auto()
u_toolkit/merge.py ADDED
@@ -0,0 +1,31 @@
1
+ from collections.abc import Mapping
2
+
3
+
4
+ def deep_merge_dict(
5
+ target: dict,
6
+ source: Mapping,
7
+ ):
8
+ """深层合并两个字典
9
+
10
+ :param target: 存放合并内容的字典
11
+ :param source: 来源, 因为不会修改, 所以只读映射就可以
12
+ :param exclude_keys: 需要排除的 keys
13
+ """
14
+
15
+ for ok, ov in source.items():
16
+ v = target.get(ok)
17
+ # 如果两边都是映射类型, 就可以合并
18
+ if isinstance(v, dict) and isinstance(ov, Mapping):
19
+ deep_merge_dict(v, ov)
20
+ # 如果当前值允许进行相加的操作
21
+ # 并且不是字符串和数字
22
+ # 并且旧字典与当前值类型相同
23
+ elif (
24
+ hasattr(v, "__add__")
25
+ and not isinstance(v, str | int)
26
+ and type(v) is type(ov)
27
+ ):
28
+ target[ok] = v + ov
29
+ # 否则使用有效的值
30
+ else:
31
+ target[ok] = v or ov
u_toolkit/object.py ADDED
File without changes
u_toolkit/path.py ADDED
@@ -0,0 +1,4 @@
1
def path(*subpath) -> str:
    """Join ``subpath`` into a single path that starts with ``/``.

    Each part is converted to text and stripped of leading and trailing
    slashes. Parts that end up empty (such as ``""`` or ``"/"``) are skipped,
    so the result never contains ``//`` at a join.

    :param subpath: path parts; any object usable in an f-string
    :return: the joined path, ``"/"`` when there are no parts
    """
    segments = (f"{part}".strip("/") for part in subpath)
    return "/" + "/".join(segment for segment in segments if segment)
File without changes
@@ -0,0 +1,24 @@
1
+ from datetime import datetime
2
+ from typing import Annotated
3
+
4
+ from pydantic import Field, PlainSerializer
5
+
6
+ from u_toolkit.datetime import to_naive, to_utc, to_utc_naive
7
+
8
+
9
# Integers limited to the signed 16-, 32- and 64-bit ranges.
DBSmallInt = Annotated[int, Field(ge=-32768, le=32767)]
DBInt = Annotated[int, Field(ge=-2147483648, le=2147483647)]
DBBigInt = Annotated[
    int, Field(ge=-9223372036854775808, le=9223372036854775807)
]
# The same upper bounds with a lower bound of 1.
DBSmallSerial = Annotated[int, Field(ge=1, le=32767)]
DBIntSerial = Annotated[int, Field(ge=1, le=2147483647)]
DBBigintSerial = Annotated[int, Field(ge=1, le=9223372036854775807)]


# Drop the timezone information.
NaiveDatetime = Annotated[datetime, PlainSerializer(to_naive)]
# Convert the timezone to UTC.
UTCDateTime = Annotated[datetime, PlainSerializer(to_utc)]
# Convert the time to UTC and drop the timezone information.
UTCNaiveDateTime = Annotated[datetime, PlainSerializer(to_utc_naive)]
@@ -0,0 +1,41 @@
1
+ from datetime import datetime
2
+ from typing import Annotated, Any
3
+
4
+ from pydantic import BaseModel, ConfigDict, WrapSerializer
5
+
6
+ from u_toolkit.alias_generators import to_camel
7
+ from u_toolkit.datetime import to_utc
8
+
9
+ from .type_vars import BaseModelT
10
+
11
+
12
class Model(BaseModel):
    # Base model that can be populated by field name and built from object
    # attributes.
    model_config = ConfigDict(
        populate_by_name=True,
        from_attributes=True,
    )
17
+
18
+
19
class CamelModel(Model):
    # Same as Model, with aliases and field titles generated by ``to_camel``.
    model_config = ConfigDict(
        alias_generator=to_camel,
        populate_by_name=True,
        from_attributes=True,
        field_title_generator=to_camel,
    )
26
+
27
+
28
def convert_to_utc(value: Any, handler, info) -> dict[str, datetime]:
    """Wrap-serializer that converts every serialized value to UTC.

    :param value: object being serialized
    :param handler: serializer that produces the partial result
    :param info: serialization info; ``info.mode`` selects how values are read
    :return: mapping of each key to a UTC datetime
    """
    # Note that `handler` can actually help serialize the `value` for
    # further custom serialization in case it's a subclass.
    partial_result = handler(value, info)
    if info.mode != "json":
        return {k: to_utc(v) for k, v in partial_result.items()}

    # In JSON mode the values arrive as ISO strings and are parsed first.
    return {
        k: to_utc(datetime.fromisoformat(v))
        for k, v in partial_result.items()
    }
39
+
40
+
41
# Annotation that serializes a model through ``convert_to_utc``.
ConvertToUTCModel = Annotated[BaseModelT, WrapSerializer(convert_to_utc)]
@@ -0,0 +1,6 @@
1
+ from typing import TypeVar
2
+
3
+ from pydantic import BaseModel
4
+
5
+
6
# Type variable for any pydantic ``BaseModel`` subclass.
BaseModelT = TypeVar("BaseModelT", bound=BaseModel)
u_toolkit/signature.py ADDED
@@ -0,0 +1,74 @@
1
+ import inspect
2
+ from collections.abc import Callable, Sequence
3
+ from functools import partial
4
+ from typing import Annotated, Any, overload
5
+
6
+
7
def list_parameters(fn: Callable, /) -> list[inspect.Parameter]:
    """Return the parameters of ``fn`` in declaration order."""
    return [*inspect.signature(fn).parameters.values()]
10
+
11
+
12
# The ``type | Annotated`` hints are written as strings so they are not
# evaluated when the module is imported: that union appears to raise on
# Python 3.13+, where ``Annotated`` is a special form.
@overload
def with_parameter(
    fn: Callable, *, name: str, annotation: "type | Annotated"
) -> tuple[list[inspect.Parameter], inspect.Parameter, int]: ...
@overload
def with_parameter(
    fn: Callable, *, name: str, default: Any
) -> tuple[list[inspect.Parameter], inspect.Parameter, int]: ...
@overload
def with_parameter(
    fn: Callable, *, name: str, annotation: "type | Annotated", default: Any
) -> tuple[list[inspect.Parameter], inspect.Parameter, int]: ...


def with_parameter(
    fn: Callable,
    *,
    name: str,
    annotation: "type | Annotated | None" = None,
    default: Any = None,
) -> tuple[list[inspect.Parameter], inspect.Parameter, int]:
    """Return the parameters of ``fn`` plus one new keyword-only parameter.

    ``fn`` itself is not modified. The new parameter goes last, or just
    before ``**kwargs`` when ``fn`` has one.

    :param fn: callable whose parameters are listed
    :param name: name of the new parameter
    :param annotation: annotation of the new parameter; ``None`` leaves it
        unannotated
    :param default: default of the new parameter; ``None`` leaves it without
        a default
    :return: ``(parameters, new_parameter, index)`` where
        ``parameters[index]`` is the new parameter
    """
    kwargs = {}
    if annotation is not None:
        kwargs["annotation"] = annotation
    if default is not None:
        kwargs["default"] = default

    parameters = list(inspect.signature(fn).parameters.values())
    parameter = inspect.Parameter(
        name=name, kind=inspect.Parameter.KEYWORD_ONLY, **kwargs
    )

    has_var_keyword = bool(parameters) and (
        parameters[-1].kind == inspect.Parameter.VAR_KEYWORD
    )
    if has_var_keyword:
        # Nothing may follow ``**kwargs``, so insert in front of it.
        parameters.insert(-1, parameter)
        return parameters, parameter, -2

    parameters.append(parameter)
    return parameters, parameter, -1
51
+
52
+
53
+ def update_signature(
54
+ fn: Callable,
55
+ *,
56
+ parameters: Sequence[inspect.Parameter] | None = None,
57
+ return_annotation: type | None = None,
58
+ ):
59
+ signature = inspect.signature(fn)
60
+ kwargs = {}
61
+ if parameters:
62
+ kwargs["parameters"] = parameters
63
+ if return_annotation:
64
+ kwargs["return_annotation"] = return_annotation
65
+
66
+ fn.__signature__ = signature.replace(**kwargs) # type: ignore
67
+
68
+
69
def update_parameters(fn: Callable, *parameters: inspect.Parameter):
    """Call ``update_signature`` with ``parameters`` as the new parameters."""
    update_signature(fn, parameters=parameters)
71
+
72
+
73
def update_return_annotation(fn: Callable, return_annotation: type, /):
    """Call ``update_signature`` with a new return annotation for ``fn``."""
    update_signature(fn, return_annotation=return_annotation)
File without changes
File without changes
@@ -0,0 +1,12 @@
1
+ import sqlalchemy as sa
2
+ from sqlalchemy.orm import DeclarativeBase, QueryableAttribute
3
+
4
+
5
def json_object_build(
    label: str,
    table: type[DeclarativeBase],
    *attrs: QueryableAttribute | sa.Column,
):
    """Build a labelled ``json_build_object(...)`` expression.

    Each attribute contributes a ``'<key>', <tablename>.<key>`` text
    fragment.

    :param label: label applied to the resulting expression
    :param table: mapped class whose ``__tablename__`` qualifies the columns
    :param attrs: attributes or columns to include, identified by ``.key``
    """
    fragments = []
    for attr in attrs:
        fragments.append(
            sa.text(f"'{attr.key}', {table.__tablename__}.{attr.key}")
        )
    expression = sa.func.json_build_object(*fragments)
    return expression.label(label)
File without changes
@@ -0,0 +1,20 @@
1
+ from functools import partial
2
+ from typing import Annotated
3
+ from uuid import UUID, uuid4
4
+
5
+ from sqlalchemy import BIGINT
6
+ from sqlalchemy.orm import mapped_column
7
+
8
+
9
# mapped_column preset for primary keys; the negative sort_order is
# presumably meant to place the column first — confirm against SQLAlchemy's
# ``sort_order`` documentation.
primary_key_column = partial(mapped_column, primary_key=True, sort_order=-9999)

IntPrimaryKey = Annotated[int, primary_key_column()]
BigIntPrimaryKey = Annotated[int, primary_key_column(BIGINT)]

# Primary-key preset with autoincrement enabled.
_auto_pk = partial(primary_key_column, autoincrement=True)

AutoIntPrimaryKey = Annotated[int, _auto_pk()]
AutoBigIntPrimaryKey = Annotated[int, _auto_pk(BIGINT)]


# UUID primary key whose default is generated by ``uuid4``.
UUID4PrimaryKey = Annotated[UUID, primary_key_column(default=uuid4)]
@@ -0,0 +1,23 @@
1
+ from datetime import datetime
2
+
3
+ from pydantic.alias_generators import to_snake
4
+ from sqlalchemy import func
5
+ from sqlalchemy.orm import Mapped, declared_attr, mapped_column
6
+
7
+
8
class TablenameMixin:
    # Classes inheriting this mixin get their table name derived from the
    # class name: snake_case plus the "_tb" suffix.
    @declared_attr.directive
    @classmethod
    def __tablename__(cls) -> str:
        return f"{to_snake(cls.__name__)}_tb"
14
+
15
+
16
class TimeStampMixin:
    # Classes inheriting this mixin get creation-time and update-time columns.
    created_at: Mapped[datetime] = mapped_column(
        server_default=func.now(), sort_order=9998
    )
    updated_at: Mapped[datetime | None] = mapped_column(
        onupdate=func.now(), sort_order=9999
    )
@@ -0,0 +1,17 @@
1
+ import sqlalchemy as sa
2
+ from pydantic import TypeAdapter
3
+
4
+ from .type_vars import DeclarativeBaseT
5
+
6
+
7
class TableInfo:
    # Bundles a declarative model with its inspected columns and, per column,
    # a pydantic TypeAdapter for the column's Python type.

    def __init__(self, table: type[DeclarativeBaseT]) -> None:
        """Inspect ``table`` and prepare one adapter entry per column.

        :param table: declarative model class to inspect
        """
        self.table = table
        self.columns = sa.inspect(table).c

        # Maps column name -> (TypeAdapter for the Python type, Python type).
        adapters = {}
        for column in self.columns:
            name = column.name
            python_type = column.type.python_type
            adapters[name] = (TypeAdapter(python_type), python_type)
        self.adapters = adapters
@@ -0,0 +1,6 @@
1
+ from typing import TypeVar
2
+
3
+ from sqlalchemy.orm import DeclarativeBase
4
+
5
+
6
# Type variable bound to SQLAlchemy's DeclarativeBase, for annotating code
# that accepts or returns a specific declarative model class.
DeclarativeBaseT = TypeVar("DeclarativeBaseT", bound=DeclarativeBase)
@@ -0,0 +1,10 @@
1
+ Metadata-Version: 2.4
2
+ Name: u-toolkit
3
+ Version: 0.1.0
4
+ Summary: Add your description here
5
+ Requires-Python: >=3.11
6
+ Requires-Dist: pydantic>=2.11.3
7
+ Provides-Extra: fastapi
8
+ Requires-Dist: pydantic-settings>=2.9.1; extra == 'fastapi'
9
+ Provides-Extra: sqlalchemy
10
+ Requires-Dist: sqlalchemy>=2.0.40; extra == 'sqlalchemy'
@@ -0,0 +1,36 @@
1
+ u_toolkit/__init__.py,sha256=7f_bFg1UERA4gyh3HjCAjOZo2cBUgElpwgX5x-Xk2WA,238
2
+ u_toolkit/alias_generators.py,sha256=KmPL1ViGJjptONffZUAX4ENvkldrwDylm_Ly5RYE8b4,963
3
+ u_toolkit/datetime.py,sha256=GOG0xa6yKeqvFXJkImK5Pt7wsPXnHMlNKju8deniGJ4,644
4
+ u_toolkit/decorators.py,sha256=i6dCahV4AHQPwpW48Q55vxMYvcq8GC1noq1SpVE-xjE,1229
5
+ u_toolkit/enum.py,sha256=2yWmK8V1q0P5S3ltBED-TF2H09BmRbsBuqu3Th3G1NE,862
6
+ u_toolkit/function.py,sha256=GAE6TIm5jpemUiNUPjzk7pIHrxwSSX7sHl1V5E735dE,252
7
+ u_toolkit/helpers.py,sha256=K3FCz93K1nT4o7gWKVpMKy3k3BnirTPCSb8F8EFUrWk,112
8
+ u_toolkit/logger.py,sha256=NOmdR24QfSuo1EMnetyFioa8pA8OYceYTlQ4qiQcBdE,209
9
+ u_toolkit/merge.py,sha256=kHrloud-nPti5j48zkdvaiy4mIJYqOVguixAjWC46kE,924
10
+ u_toolkit/object.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
11
+ u_toolkit/path.py,sha256=IkyIHcU9hKBCOZfF30FrKf4CfL-MH91fjeYF9EY7eos,128
12
+ u_toolkit/signature.py,sha256=jjIhyqTV3XL_tjPeA4l8dFnF5TybHd4AHF_g_2G2uAg,2168
13
+ u_toolkit/fastapi/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
14
+ u_toolkit/fastapi/cbv.py,sha256=fuZCe6An75D3ernxFk6BLta_27VpoXvxBgwhToiCMIA,9940
15
+ u_toolkit/fastapi/config.py,sha256=kGpokR9XXr1KxMA1GVKYkCdKwqIQAIwOJ-v6sGbqzAQ,267
16
+ u_toolkit/fastapi/exception.py,sha256=5E4wAJYwp0RJ4SEBVkchOgrgfwCgniQ8Mtg1O5sWUXE,3288
17
+ u_toolkit/fastapi/helpers.py,sha256=BCMMLxa1c6BMA_rKq-hCi0iyEjrR3Z5rPMeTvgaVJB0,447
18
+ u_toolkit/fastapi/lifespan.py,sha256=W1TwWymW7xtmntx59QBC4LQ6xQr3L7yuMMGj4U8hhTQ,1813
19
+ u_toolkit/fastapi/pagination.py,sha256=yOgEDUT04m_mZ0cPliuDbUHLFnmxGAmr5PyZlwfjT_s,1940
20
+ u_toolkit/fastapi/responses.py,sha256=ZxxFCQpahS1I0ilB-XYW6_cqq1Qscu950iMauq7Oi5E,1740
21
+ u_toolkit/pydantic/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
22
+ u_toolkit/pydantic/fields.py,sha256=9I8Jwek_oNn_niI98ZWaSPEUYIl5mw24YgbnpcEtgZk,839
23
+ u_toolkit/pydantic/models.py,sha256=Dqp3HnPlUU7ZpfBbYbERfcrLUb2CBJKHySRvM1Rv3HE,1101
24
+ u_toolkit/pydantic/type_vars.py,sha256=XBB3r2SipHOIeyCIuhaBwUDuIfm-M8aI3Gg8rx2uT_A,113
25
+ u_toolkit/sqlalchemy/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
26
+ u_toolkit/sqlalchemy/fields.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
27
+ u_toolkit/sqlalchemy/function.py,sha256=DRkszvqDfl-BKBW9IcYeN665pgDYlwWiC3DOCjBnxGY,345
28
+ u_toolkit/sqlalchemy/table_info.py,sha256=EjY2Vxbw2c6TUGBf3NeJl1BB2Tnn9NfaFikUx9I3BBk,441
29
+ u_toolkit/sqlalchemy/type_vars.py,sha256=m2VeV41CBIK_1QX3w2kgz-n556sILAGZ-Kaz3TDDDIY,143
30
+ u_toolkit/sqlalchemy/orm/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
31
+ u_toolkit/sqlalchemy/orm/fields.py,sha256=3zoYil23I6YLtc_59aHDt9w5l1NBTkePT9AfXI3DMiY,593
32
+ u_toolkit/sqlalchemy/orm/models.py,sha256=V8vf4ps3phAmwxyaFYK7pw8Igz7h097o4QBjKB0gwC8,705
33
+ u_toolkit-0.1.0.dist-info/METADATA,sha256=bVnbba44m98h7Pemhx8jpKivqb0w3458KX84204IkF8,312
34
+ u_toolkit-0.1.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
35
+ u_toolkit-0.1.0.dist-info/entry_points.txt,sha256=hTfAYCd5vvRiqgnJk2eBsoRIiIVB9pK8WZm3Q3jjKFU,45
36
+ u_toolkit-0.1.0.dist-info/RECORD,,
@@ -0,0 +1,4 @@
1
+ Wheel-Version: 1.0
2
+ Generator: hatchling 1.27.0
3
+ Root-Is-Purelib: true
4
+ Tag: py3-none-any
@@ -0,0 +1,2 @@
1
+ [console_scripts]
2
+ u-toolkit = u_toolkit:main