python-basekit 0.0.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,53 @@
1
+ from collections.abc import Awaitable, Callable
2
+ from contextlib import AbstractAsyncContextManager, suppress
3
+ from functools import wraps
4
+
5
+ from pydantic import BaseModel
6
+
7
+ from basekit.cache.utils import (
8
+ CacheValueError,
9
+ deserialize_cache_value,
10
+ get_instance,
11
+ serialize_cache_key,
12
+ serialize_cache_value,
13
+ )
14
+
15
+
16
+ class Cache(BaseModel):
17
+ def lifespan(self) -> AbstractAsyncContextManager[None]:
18
+ raise NotImplementedError
19
+
20
+ async def set(self, key: bytes, value: bytes, /) -> None:
21
+ raise NotImplementedError
22
+
23
+ async def get(self, key: bytes, /) -> bytes | None:
24
+ raise NotImplementedError
25
+
26
+ async def delete(self, key: bytes, /) -> None:
27
+ raise NotImplementedError
28
+
29
+ def __call__[**P, T](
30
+ self, func: Callable[P, Awaitable[T]], /
31
+ ) -> Callable[P, Awaitable[T]]:
32
+ @wraps(wrapped=func)
33
+ async def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
34
+ data: dict = {
35
+ "module": func.__module__,
36
+ "function": func.__qualname__,
37
+ "instance": get_instance(func),
38
+ "args": args,
39
+ "kwargs": kwargs,
40
+ }
41
+ key: bytes = serialize_cache_key(data)
42
+
43
+ cached_value: bytes | None = await self.get(key)
44
+ if cached_value is not None:
45
+ with suppress(CacheValueError):
46
+ return deserialize_cache_value(cached_value)
47
+ await self.delete(key)
48
+
49
+ value: T = await func(*args, **kwargs)
50
+ await self.set(key, serialize_cache_value(value))
51
+ return value
52
+
53
+ return wrapper
basekit/cache/utils.py ADDED
@@ -0,0 +1,113 @@
1
+ import pickle
2
+ import pickletools
3
+ from collections.abc import Callable
4
+ from typing import Any
5
+
6
+ from pydantic import BaseModel
7
+
8
+
9
class CacheKeyError(Exception):
    """Raised when a value cannot be serialized into a cache key."""
11
+
12
+
13
class CacheValueError(Exception):
    """Raised when a cache value cannot be serialized or deserialized."""
15
+
16
+
17
def get_instance(func: Callable, /) -> Any:
    """Return the first bound ``__self__`` found along *func*'s ``__wrapped__`` chain.

    Returns None when neither *func* nor any wrapped callable is a bound method.
    """
    candidate: Callable | None = func
    while candidate is not None:
        bound = getattr(candidate, "__self__", None)
        if bound is not None:
            return bound
        # functools.wraps exposes the wrapped callable via __wrapped__.
        candidate = getattr(candidate, "__wrapped__", None)
    return None
25
+
26
+
27
def cache_key_fields(cls: type[BaseModel], /) -> list[str]:
    """Collect the field names that participate in cache keys for *cls*.

    Walks the MRO from most basic to most derived; each pydantic base
    contributes either its explicit ``cache_key_fields`` class attribute
    or, failing that, every model field it introduces itself.
    """
    seen_fields: set[str] = set()
    selected: set[str] = set()
    for base in reversed(cls.__mro__):
        if not issubclass(base, BaseModel):
            continue
        # Fields declared by this base only (not inherited ones).
        introduced: set[str] = set(base.model_fields.keys()) - seen_fields
        declared = base.__dict__.get("cache_key_fields", introduced)
        seen_fields |= introduced
        selected |= set(declared)
    return list(selected)
40
+
41
+
42
def build_cache_key(obj: Any, /) -> dict:
    """Convert *obj* into a type-tagged, pickle-friendly structure for cache keys.

    Raises ValueError for non-string dict keys or unsupported types.
    """
    if isinstance(obj, type) and issubclass(obj, BaseModel):
        # Model classes are keyed by identity plus their JSON schema.
        return {
            "module": obj.__module__,
            "class": obj.__qualname__,
            "schema": obj.model_json_schema(),
        }
    if isinstance(obj, BaseModel):
        model_cls: type[BaseModel] = obj.__class__
        selected: dict = {
            name: getattr(obj, name) for name in cache_key_fields(model_cls)
        }
        return {
            "module": model_cls.__module__,
            "class": model_cls.__qualname__,
            "instance": build_cache_key(selected),
        }
    if isinstance(obj, dict):
        if any(not isinstance(key, str) for key in obj.keys()):
            raise ValueError(f"字典键必须是字符串:{obj.keys()}")
        return {"dict": {key: build_cache_key(item) for key, item in obj.items()}}
    if isinstance(obj, list):
        return {"list": [build_cache_key(item) for item in obj]}
    if isinstance(obj, tuple):
        # Tuples are tagged separately from lists so they key differently.
        return {"tuple": [build_cache_key(item) for item in obj]}
    if isinstance(obj, bytes):
        return {"bytes": obj}
    if isinstance(obj, str):
        return {"str": obj}
    # bool must be checked before int: bool is a subclass of int.
    if isinstance(obj, bool):
        return {"bool": obj}
    if isinstance(obj, int):
        return {"int": obj}
    if isinstance(obj, float):
        return {"float": obj}
    if obj is None:
        return {"none": None}
    raise ValueError(f"类型不支持序列化:{obj.__class__}")
78
+
79
+
80
+ def normalize_cache_key(value: Any, /) -> Any:
81
+ if isinstance(value, dict):
82
+ return {k: normalize_cache_key(value[k]) for k in sorted(value.keys())}
83
+ if isinstance(value, list):
84
+ return [normalize_cache_key(item) for item in value]
85
+ if isinstance(value, tuple):
86
+ return tuple(normalize_cache_key(item) for item in value)
87
+ return value
88
+
89
+
90
def serialize_cache_key(value: Any, /) -> bytes:
    """Serialize *value* into deterministic cache-key bytes.

    Raises CacheKeyError if any step of building/pickling the key fails.
    """
    try:
        canonical: dict = normalize_cache_key(build_cache_key(value))
        raw: bytes = pickle.dumps(canonical, protocol=pickle.HIGHEST_PROTOCOL)
        # optimize() shrinks the pickle by eliminating unused PUT opcodes.
        return pickletools.optimize(raw)
    except Exception as exc:
        raise CacheKeyError("缓存键序列化失败") from exc
98
+
99
+
100
def serialize_cache_value(value: Any, /) -> bytes:
    """Pickle *value* for cache storage, raising CacheValueError on failure."""
    try:
        payload: bytes = pickle.dumps(value, protocol=pickle.HIGHEST_PROTOCOL)
        # Strip unused PUT opcodes to keep stored blobs compact.
        return pickletools.optimize(payload)
    except Exception as exc:
        raise CacheValueError("缓存值序列化失败") from exc
107
+
108
+
109
def deserialize_cache_value(value: bytes, /) -> Any:
    """Unpickle cached bytes, raising CacheValueError on failure.

    NOTE(review): pickle.loads executes arbitrary code — this must only ever
    see trusted cache contents written by serialize_cache_value.
    """
    try:
        return pickle.loads(value)
    except Exception as exc:
        raise CacheValueError("缓存值反序列化失败") from exc
basekit/database.py ADDED
@@ -0,0 +1,33 @@
1
+ from collections.abc import AsyncIterator
2
+ from contextlib import asynccontextmanager
3
+ from typing import override
4
+
5
+ from pydantic import BaseModel, PrivateAttr
6
+ from sqlalchemy import MetaData
7
+ from sqlalchemy.ext.asyncio import AsyncConnection, AsyncEngine, create_async_engine
8
+
9
+
10
class DatabaseClient(BaseModel):
    """Thin async wrapper around a SQLAlchemy engine built from a URL."""

    # Connection URL handed to create_async_engine.
    url: str

    _engine: AsyncEngine = PrivateAttr()

    @override
    def model_post_init(self, context) -> None:
        # Create the engine eagerly once pydantic validation has finished.
        self._engine = create_async_engine(self.url)

    @asynccontextmanager
    async def lifespan(self) -> AsyncIterator[None]:
        """Yield for the application's lifetime; dispose the engine on exit."""
        try:
            yield
        finally:
            await self._engine.dispose()

    async def create_schema(self, metadata: MetaData, /) -> None:
        """Create every table described by *metadata* inside one transaction."""
        async with self._engine.begin() as tx_conn:
            await tx_conn.run_sync(metadata.create_all)

    @asynccontextmanager
    async def connection(self) -> AsyncIterator[AsyncConnection]:
        """Yield a pooled connection that is released when the context exits."""
        async with self._engine.connect() as conn:
            yield conn
@@ -0,0 +1,99 @@
1
+ from collections.abc import AsyncIterator
2
+ from contextlib import asynccontextmanager, suppress
3
+ from http.cookiejar import CookieJar
4
+ from typing import Any, override
5
+
6
+ import logfire
7
+ from curl_cffi import AsyncSession, BrowserTypeLiteral, Response
8
+ from curl_cffi.requests.exceptions import HTTPError
9
+ from pydantic import PrivateAttr
10
+
11
+ from basekit.http.schema import HttpClient, HttpMethod
12
+ from basekit.http.utils import clone_cookiejar, truncate_data
13
+
14
+
15
class CurlCffiClient(HttpClient[Response]):
    """HttpClient implementation backed by curl_cffi with browser impersonation.

    Fix: added the missing ``@override`` decorators on the methods that
    implement the HttpClient contract, matching the sibling HttpxClient.
    """

    # Browser fingerprint used by curl_cffi for the TLS/HTTP handshake.
    impersonate: BrowserTypeLiteral = "chrome"

    _client: AsyncSession = PrivateAttr()

    @override
    def model_post_init(self, context: Any) -> None:
        # Build the shared session once pydantic field validation completes.
        self._client = AsyncSession(
            headers=self.headers,
            verify=self.verify,
            http_version="v2",
            timeout=self.timeout,
            allow_redirects=True,
            max_clients=self.limit,
            impersonate=self.impersonate,
        )

    @override
    @asynccontextmanager
    async def lifespan(self) -> AsyncIterator[None]:
        """Keep the session open while the context is active; close it on exit."""
        async with self._client:
            yield

    @override
    def save_cookies(self) -> CookieJar:
        """Snapshot the session's cookies into an independent CookieJar."""
        return clone_cookiejar(self._client.cookies.jar)

    @override
    def load_cookies(self, cookies: CookieJar) -> None:
        """Replace the session's cookies with a copy of *cookies*."""
        self._client.cookies = clone_cookiejar(cookies)

    def _extract_content(self, response: Response, /) -> dict | str | bytes:
        """Best-effort decode of the body for logging, driven by Content-Type.

        Falls back to raw bytes, and finally to b"" if decoding fails entirely.
        """
        content_type: str = response.headers.get(key="Content-Type", default="") or ""
        # Drop any "; charset=..." parameters before comparing.
        content_type = content_type.split(sep=";")[0].strip().lower()
        if content_type == "application/json":
            with suppress(Exception):
                return response.json()
        if content_type.startswith("text/"):
            with suppress(Exception):
                return response.text
        with suppress(Exception):
            return response.content
        return b""

    @override
    async def _request_once(
        self,
        method: HttpMethod,
        url: str,
        /,
        *,
        params: dict | None = None,
        data: dict | None = None,
        headers: dict[str, str] | None = None,
    ) -> Response:
        """Send a single request inside a logfire span.

        Records truncated request/response data on the span and raises
        via raise_for_status() on error statuses.
        """
        with logfire.span(f"curl cffi client | request | {method} | {url}") as span:
            span.set_attribute(key="request_params", value=truncate_data(params))
            span.set_attribute(key="request_data", value=truncate_data(data))
            span.set_attribute(key="request_headers", value=truncate_data(headers))

            response: Response = await self._client.request(
                method=method, url=url, params=params, json=data, headers=headers
            )

            status_code: int = response.status_code
            span.set_attribute(key="status_code", value=status_code)
            span.message = f"{span.message} -> {status_code}"

            response_headers: dict = dict(response.headers.items())
            span.set_attribute(
                key="response_headers", value=truncate_data(response_headers)
            )

            response_content: dict | str | bytes = self._extract_content(response)
            span.set_attribute(
                key="response_content", value=truncate_data(response_content)
            )

            response.raise_for_status()
            return response

    @override
    def _should_retry(self, exc: BaseException, /) -> bool:
        """Retry transport-level failures and 429/5xx statuses only."""
        if not isinstance(exc, HTTPError):
            return True
        if exc.response is None:
            return True
        status_code: int = exc.response.status_code
        return status_code == 429 or status_code >= 500
@@ -0,0 +1,100 @@
1
+ from collections.abc import AsyncIterator
2
+ from contextlib import asynccontextmanager, suppress
3
+ from http.cookiejar import CookieJar
4
+ from typing import Any, override
5
+
6
+ import logfire
7
+ from httpx import AsyncClient, HTTPStatusError, Limits, Response
8
+ from pydantic import PrivateAttr
9
+
10
+ from basekit.http.schema import HttpClient, HttpMethod
11
+ from basekit.http.utils import clone_cookiejar, truncate_data
12
+
13
+
14
class HttpxClient(HttpClient[Response]):
    """HttpClient implementation backed by httpx with HTTP/2 enabled."""

    # Underlying async client; created in model_post_init, closed via lifespan().
    _client: AsyncClient = PrivateAttr()

    @override
    def model_post_init(self, context: Any) -> None:
        # Build the shared AsyncClient once pydantic field validation completes.
        self._client = AsyncClient(
            headers=self.headers,
            verify=self.verify,
            http2=True,
            timeout=self.timeout,
            follow_redirects=True,
            limits=Limits(
                max_connections=self.limit,
                max_keepalive_connections=self.limit,
            ),
        )

    @override
    @asynccontextmanager
    async def lifespan(self) -> AsyncIterator[None]:
        """Keep the AsyncClient open while the context is active; close it on exit."""
        async with self._client:
            yield

    @override
    def save_cookies(self) -> CookieJar:
        """Snapshot the client's cookies into an independent CookieJar."""
        return clone_cookiejar(self._client.cookies.jar)

    @override
    def load_cookies(self, cookies: CookieJar) -> None:
        """Replace the client's cookies with a copy of *cookies*."""
        self._client.cookies = clone_cookiejar(cookies)

    def _extract_content(self, response: Response, /) -> dict | str | bytes:
        """Best-effort decode of the body for logging, driven by Content-Type.

        Falls back to raw bytes, and finally to b"" if decoding fails entirely.
        """
        content_type: str = response.headers.get(key="Content-Type", default="") or ""
        # Drop any "; charset=..." parameters before comparing.
        content_type = content_type.split(sep=";")[0].strip().lower()
        if content_type == "application/json":
            with suppress(Exception):
                return response.json()
        if content_type.startswith("text/"):
            with suppress(Exception):
                return response.text
        with suppress(Exception):
            return response.content
        return b""

    @override
    async def _request_once(
        self,
        method: HttpMethod,
        url: str,
        /,
        *,
        params: dict | None = None,
        data: dict | None = None,
        headers: dict[str, str] | None = None,
    ) -> Response:
        """Send a single request inside a logfire span.

        Records truncated request/response data on the span and raises
        via raise_for_status() on error statuses.
        """
        with logfire.span(f"httpx client | request | {method} | {url}") as span:
            span.set_attribute(key="request_params", value=truncate_data(params))
            span.set_attribute(key="request_data", value=truncate_data(data))
            span.set_attribute(key="request_headers", value=truncate_data(headers))

            response: Response = await self._client.request(
                method=method, url=url, params=params, json=data, headers=headers
            )

            status_code: int = response.status_code
            span.set_attribute(key="status_code", value=status_code)
            # Append the status to the span title for readable traces.
            span.message = f"{span.message} -> {status_code}"

            response_headers: dict = dict(response.headers.items())
            span.set_attribute(
                key="response_headers", value=truncate_data(response_headers)
            )

            response_content: dict | str | bytes = self._extract_content(response)
            span.set_attribute(
                key="response_content", value=truncate_data(response_content)
            )

            response.raise_for_status()
            return response

    @override
    def _should_retry(self, exc: BaseException, /) -> bool:
        """Retry anything that is not an HTTP status error, plus 429/5xx statuses."""
        if not isinstance(exc, HTTPStatusError):
            return True
        status_code: int = exc.response.status_code
        return status_code == 429 or status_code >= 500
basekit/http/schema.py ADDED
@@ -0,0 +1,111 @@
1
+ from contextlib import AbstractAsyncContextManager
2
+ from http.cookiejar import CookieJar
3
+ from typing import ClassVar, Literal
4
+
5
+ from pydantic import BaseModel, ConfigDict
6
+ from tenacity import (
7
+ retry,
8
+ retry_if_exception,
9
+ stop_after_attempt,
10
+ wait_exponential_jitter,
11
+ )
12
+ from tenacity.stop import stop_base
13
+ from tenacity.wait import wait_base
14
+
15
# Default request timeout handed to the underlying clients.
DEFAULT_TIMEOUT = 5.0
# Default connection limit for the underlying clients' pools.
DEFAULT_LIMIT = 100
# Default retry policy: at most 3 attempts with exponential jitter backoff.
DEFAULT_RETRY_STOP = stop_after_attempt(max_attempt_number=3)
DEFAULT_RETRY_WAIT = wait_exponential_jitter()

# HTTP verbs supported by HttpClient.request.
type HttpMethod = Literal["GET", "POST"]
21
+
22
+
23
+ class HttpClient[ResponseT](BaseModel):
24
+ model_config: ClassVar[ConfigDict] = ConfigDict(arbitrary_types_allowed=True)
25
+
26
+ headers: dict[str, str] | None = None
27
+ verify: bool = True
28
+ timeout: float = DEFAULT_TIMEOUT
29
+ limit: int = DEFAULT_LIMIT
30
+ retry_stop: stop_base = DEFAULT_RETRY_STOP
31
+ retry_wait: wait_base = DEFAULT_RETRY_WAIT
32
+
33
+ def lifespan(self) -> AbstractAsyncContextManager[None]:
34
+ raise NotImplementedError
35
+
36
+ def save_cookies(self) -> CookieJar:
37
+ raise NotImplementedError
38
+
39
+ def load_cookies(self, cookies: CookieJar) -> None:
40
+ raise NotImplementedError
41
+
42
+ async def _request_once(
43
+ self,
44
+ method: HttpMethod,
45
+ url: str,
46
+ /,
47
+ *,
48
+ params: dict | None = None,
49
+ data: dict | None = None,
50
+ headers: dict[str, str] | None = None,
51
+ ) -> ResponseT:
52
+ raise NotImplementedError
53
+
54
+ def _should_retry(self, exc: BaseException, /) -> bool:
55
+ raise NotImplementedError
56
+
57
+ async def request(
58
+ self,
59
+ method: HttpMethod,
60
+ url: str,
61
+ /,
62
+ *,
63
+ params: dict | None = None,
64
+ data: dict | None = None,
65
+ headers: dict[str, str] | None = None,
66
+ retry_on_failure: bool = False,
67
+ ) -> ResponseT:
68
+ if not retry_on_failure:
69
+ return await self._request_once(
70
+ method, url, params=params, data=data, headers=headers
71
+ )
72
+ return await retry(
73
+ retry=retry_if_exception(predicate=self._should_retry),
74
+ stop=self.retry_stop,
75
+ wait=self.retry_wait,
76
+ reraise=True,
77
+ )(self._request_once)(method, url, params=params, data=data, headers=headers)
78
+
79
+ async def get(
80
+ self,
81
+ url: str,
82
+ /,
83
+ *,
84
+ params: dict | None = None,
85
+ headers: dict[str, str] | None = None,
86
+ retry_on_failure: bool = True,
87
+ ) -> ResponseT:
88
+ return await self.request(
89
+ "GET",
90
+ url,
91
+ params=params,
92
+ headers=headers,
93
+ retry_on_failure=retry_on_failure,
94
+ )
95
+
96
+ async def post(
97
+ self,
98
+ url: str,
99
+ /,
100
+ *,
101
+ data: dict | None = None,
102
+ headers: dict[str, str] | None = None,
103
+ retry_on_failure: bool = False,
104
+ ) -> ResponseT:
105
+ return await self.request(
106
+ "POST",
107
+ url,
108
+ data=data,
109
+ headers=headers,
110
+ retry_on_failure=retry_on_failure,
111
+ )
basekit/http/utils.py ADDED
@@ -0,0 +1,31 @@
1
+ from http.cookiejar import CookieJar
2
+ from typing import Any
3
+
4
SIZE_LIMIT = 1048576  # 1MB


def truncate_data(value: Any, /) -> Any:
    """Recursively shorten oversized str/bytes payloads for safe logging.

    Dicts and lists are walked; str/bytes longer than SIZE_LIMIT are replaced
    by a marker plus a truncated prefix; everything else passes through as-is.
    """
    if isinstance(value, dict):
        return {key: truncate_data(item) for key, item in value.items()}
    if isinstance(value, list):
        return [truncate_data(item) for item in value]
    if isinstance(value, str):
        length: int = len(value)
        if length > SIZE_LIMIT:
            return f"[TOO LARGE: {length} chars]{value[:SIZE_LIMIT]}..."
        return value
    if isinstance(value, bytes):
        length: int = len(value)
        if length > SIZE_LIMIT:
            # Note: embedding bytes in an f-string renders their repr (b'...').
            return f"[TOO LARGE: {length} bytes]{value[:SIZE_LIMIT]}..."
        return value
    return value
25
+
26
+
27
def clone_cookiejar(cookiejar: CookieJar, /) -> CookieJar:
    """Return a new CookieJar holding the same Cookie objects as *cookiejar*."""
    clone = CookieJar()
    for item in cookiejar:
        clone.set_cookie(item)
    return clone