sqlspec 0.16.1__cp39-cp39-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of sqlspec might be problematic. Click here for more details.
- 51ff5a9eadfdefd49f98__mypyc.cpython-39-aarch64-linux-gnu.so +0 -0
- sqlspec/__init__.py +92 -0
- sqlspec/__main__.py +12 -0
- sqlspec/__metadata__.py +14 -0
- sqlspec/_serialization.py +77 -0
- sqlspec/_sql.py +1780 -0
- sqlspec/_typing.py +680 -0
- sqlspec/adapters/__init__.py +0 -0
- sqlspec/adapters/adbc/__init__.py +5 -0
- sqlspec/adapters/adbc/_types.py +12 -0
- sqlspec/adapters/adbc/config.py +361 -0
- sqlspec/adapters/adbc/driver.py +512 -0
- sqlspec/adapters/aiosqlite/__init__.py +19 -0
- sqlspec/adapters/aiosqlite/_types.py +13 -0
- sqlspec/adapters/aiosqlite/config.py +253 -0
- sqlspec/adapters/aiosqlite/driver.py +248 -0
- sqlspec/adapters/asyncmy/__init__.py +19 -0
- sqlspec/adapters/asyncmy/_types.py +12 -0
- sqlspec/adapters/asyncmy/config.py +180 -0
- sqlspec/adapters/asyncmy/driver.py +274 -0
- sqlspec/adapters/asyncpg/__init__.py +21 -0
- sqlspec/adapters/asyncpg/_types.py +17 -0
- sqlspec/adapters/asyncpg/config.py +229 -0
- sqlspec/adapters/asyncpg/driver.py +344 -0
- sqlspec/adapters/bigquery/__init__.py +18 -0
- sqlspec/adapters/bigquery/_types.py +12 -0
- sqlspec/adapters/bigquery/config.py +298 -0
- sqlspec/adapters/bigquery/driver.py +558 -0
- sqlspec/adapters/duckdb/__init__.py +22 -0
- sqlspec/adapters/duckdb/_types.py +12 -0
- sqlspec/adapters/duckdb/config.py +504 -0
- sqlspec/adapters/duckdb/driver.py +368 -0
- sqlspec/adapters/oracledb/__init__.py +32 -0
- sqlspec/adapters/oracledb/_types.py +14 -0
- sqlspec/adapters/oracledb/config.py +317 -0
- sqlspec/adapters/oracledb/driver.py +538 -0
- sqlspec/adapters/psqlpy/__init__.py +16 -0
- sqlspec/adapters/psqlpy/_types.py +11 -0
- sqlspec/adapters/psqlpy/config.py +214 -0
- sqlspec/adapters/psqlpy/driver.py +530 -0
- sqlspec/adapters/psycopg/__init__.py +32 -0
- sqlspec/adapters/psycopg/_types.py +17 -0
- sqlspec/adapters/psycopg/config.py +426 -0
- sqlspec/adapters/psycopg/driver.py +796 -0
- sqlspec/adapters/sqlite/__init__.py +15 -0
- sqlspec/adapters/sqlite/_types.py +11 -0
- sqlspec/adapters/sqlite/config.py +240 -0
- sqlspec/adapters/sqlite/driver.py +294 -0
- sqlspec/base.py +571 -0
- sqlspec/builder/__init__.py +62 -0
- sqlspec/builder/_base.py +473 -0
- sqlspec/builder/_column.py +320 -0
- sqlspec/builder/_ddl.py +1346 -0
- sqlspec/builder/_ddl_utils.py +103 -0
- sqlspec/builder/_delete.py +76 -0
- sqlspec/builder/_insert.py +256 -0
- sqlspec/builder/_merge.py +71 -0
- sqlspec/builder/_parsing_utils.py +140 -0
- sqlspec/builder/_select.py +170 -0
- sqlspec/builder/_update.py +188 -0
- sqlspec/builder/mixins/__init__.py +55 -0
- sqlspec/builder/mixins/_cte_and_set_ops.py +222 -0
- sqlspec/builder/mixins/_delete_operations.py +41 -0
- sqlspec/builder/mixins/_insert_operations.py +244 -0
- sqlspec/builder/mixins/_join_operations.py +122 -0
- sqlspec/builder/mixins/_merge_operations.py +476 -0
- sqlspec/builder/mixins/_order_limit_operations.py +135 -0
- sqlspec/builder/mixins/_pivot_operations.py +153 -0
- sqlspec/builder/mixins/_select_operations.py +603 -0
- sqlspec/builder/mixins/_update_operations.py +187 -0
- sqlspec/builder/mixins/_where_clause.py +621 -0
- sqlspec/cli.py +247 -0
- sqlspec/config.py +395 -0
- sqlspec/core/__init__.py +63 -0
- sqlspec/core/cache.cpython-39-aarch64-linux-gnu.so +0 -0
- sqlspec/core/cache.py +871 -0
- sqlspec/core/compiler.cpython-39-aarch64-linux-gnu.so +0 -0
- sqlspec/core/compiler.py +417 -0
- sqlspec/core/filters.cpython-39-aarch64-linux-gnu.so +0 -0
- sqlspec/core/filters.py +830 -0
- sqlspec/core/hashing.cpython-39-aarch64-linux-gnu.so +0 -0
- sqlspec/core/hashing.py +310 -0
- sqlspec/core/parameters.cpython-39-aarch64-linux-gnu.so +0 -0
- sqlspec/core/parameters.py +1237 -0
- sqlspec/core/result.cpython-39-aarch64-linux-gnu.so +0 -0
- sqlspec/core/result.py +677 -0
- sqlspec/core/splitter.cpython-39-aarch64-linux-gnu.so +0 -0
- sqlspec/core/splitter.py +819 -0
- sqlspec/core/statement.cpython-39-aarch64-linux-gnu.so +0 -0
- sqlspec/core/statement.py +676 -0
- sqlspec/driver/__init__.py +19 -0
- sqlspec/driver/_async.py +502 -0
- sqlspec/driver/_common.py +631 -0
- sqlspec/driver/_sync.py +503 -0
- sqlspec/driver/mixins/__init__.py +6 -0
- sqlspec/driver/mixins/_result_tools.py +193 -0
- sqlspec/driver/mixins/_sql_translator.py +86 -0
- sqlspec/exceptions.py +193 -0
- sqlspec/extensions/__init__.py +0 -0
- sqlspec/extensions/aiosql/__init__.py +10 -0
- sqlspec/extensions/aiosql/adapter.py +461 -0
- sqlspec/extensions/litestar/__init__.py +6 -0
- sqlspec/extensions/litestar/_utils.py +52 -0
- sqlspec/extensions/litestar/cli.py +48 -0
- sqlspec/extensions/litestar/config.py +92 -0
- sqlspec/extensions/litestar/handlers.py +260 -0
- sqlspec/extensions/litestar/plugin.py +145 -0
- sqlspec/extensions/litestar/providers.py +454 -0
- sqlspec/loader.cpython-39-aarch64-linux-gnu.so +0 -0
- sqlspec/loader.py +760 -0
- sqlspec/migrations/__init__.py +35 -0
- sqlspec/migrations/base.py +414 -0
- sqlspec/migrations/commands.py +443 -0
- sqlspec/migrations/loaders.py +402 -0
- sqlspec/migrations/runner.py +213 -0
- sqlspec/migrations/tracker.py +140 -0
- sqlspec/migrations/utils.py +129 -0
- sqlspec/protocols.py +407 -0
- sqlspec/py.typed +0 -0
- sqlspec/storage/__init__.py +23 -0
- sqlspec/storage/backends/__init__.py +0 -0
- sqlspec/storage/backends/base.py +163 -0
- sqlspec/storage/backends/fsspec.py +386 -0
- sqlspec/storage/backends/obstore.py +459 -0
- sqlspec/storage/capabilities.py +102 -0
- sqlspec/storage/registry.py +239 -0
- sqlspec/typing.py +299 -0
- sqlspec/utils/__init__.py +3 -0
- sqlspec/utils/correlation.py +150 -0
- sqlspec/utils/deprecation.py +106 -0
- sqlspec/utils/fixtures.cpython-39-aarch64-linux-gnu.so +0 -0
- sqlspec/utils/fixtures.py +58 -0
- sqlspec/utils/logging.py +127 -0
- sqlspec/utils/module_loader.py +89 -0
- sqlspec/utils/serializers.py +4 -0
- sqlspec/utils/singleton.py +32 -0
- sqlspec/utils/sync_tools.cpython-39-aarch64-linux-gnu.so +0 -0
- sqlspec/utils/sync_tools.py +237 -0
- sqlspec/utils/text.cpython-39-aarch64-linux-gnu.so +0 -0
- sqlspec/utils/text.py +96 -0
- sqlspec/utils/type_guards.cpython-39-aarch64-linux-gnu.so +0 -0
- sqlspec/utils/type_guards.py +1139 -0
- sqlspec-0.16.1.dist-info/METADATA +365 -0
- sqlspec-0.16.1.dist-info/RECORD +148 -0
- sqlspec-0.16.1.dist-info/WHEEL +7 -0
- sqlspec-0.16.1.dist-info/entry_points.txt +2 -0
- sqlspec-0.16.1.dist-info/licenses/LICENSE +21 -0
- sqlspec-0.16.1.dist-info/licenses/NOTICE +29 -0
|
@@ -0,0 +1,106 @@
|
|
|
1
|
+
import inspect
|
|
2
|
+
from functools import wraps
|
|
3
|
+
from typing import Callable, Literal, Optional
|
|
4
|
+
from warnings import warn
|
|
5
|
+
|
|
6
|
+
from typing_extensions import ParamSpec, TypeVar
|
|
7
|
+
|
|
8
|
+
__all__ = ("deprecated", "warn_deprecation")
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
T = TypeVar("T")
|
|
12
|
+
P = ParamSpec("P")
|
|
13
|
+
DeprecatedKind = Literal["function", "method", "classmethod", "attribute", "property", "class", "parameter", "import"]
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def warn_deprecation(
    version: str,
    deprecated_name: str,
    kind: DeprecatedKind,
    *,
    removal_in: Optional[str] = None,
    alternative: Optional[str] = None,
    info: Optional[str] = None,
    pending: bool = False,
) -> None:
    """Emit a deprecation (or pending-deprecation) warning for a named API.

    Args:
        version: SQLSpec version where the deprecation will occur
        deprecated_name: Name of the deprecated function
        kind: Type of the deprecated thing
        removal_in: SQLSpec version where the deprecated function will be removed
        alternative: Name of a function that should be used instead
        info: Additional information
        pending: Use :class:`warnings.PendingDeprecationWarning` instead of :class:`warnings.DeprecationWarning`
    """
    # Pick the verb phrase that matches how the deprecated thing is reached.
    if kind == "import":
        access_type = "Import of"
    elif kind in {"function", "method"}:
        access_type = "Call to"
    else:
        access_type = "Use of"

    lead = (
        f"{access_type} {kind} awaiting deprecation '{deprecated_name}'"
        if pending
        else f"{access_type} deprecated {kind} '{deprecated_name}'"
    )
    segments = [
        lead,
        f"Deprecated in SQLSpec {version}",
        f"This {kind} will be removed in {removal_in or 'the next major version'}",
    ]
    if alternative:
        segments.append(f"Use {alternative!r} instead")
    if info:
        segments.append(info)

    category = PendingDeprecationWarning if pending else DeprecationWarning
    # stacklevel=2 attributes the warning to the caller of the deprecated API.
    warn(". ".join(segments), category, stacklevel=2)
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
def deprecated(
    version: str,
    *,
    removal_in: Optional[str] = None,
    alternative: Optional[str] = None,
    info: Optional[str] = None,
    pending: bool = False,
    kind: Optional[Literal["function", "method", "classmethod", "property"]] = None,
) -> Callable[[Callable[P, T]], Callable[P, T]]:
    """Create a decorator wrapping a function, method or property with a deprecation warning.

    Args:
        version: SQLSpec version where the deprecation will occur
        removal_in: SQLSpec version where the deprecated function will be removed
        alternative: Name of a function that should be used instead
        info: Additional information
        pending: Use :class:`warnings.PendingDeprecationWarning` instead of :class:`warnings.DeprecationWarning`
        kind: Type of the deprecated callable. If ``None``, will use ``inspect`` to figure
            out if it's a function or method

    Returns:
        A decorator wrapping the function call with a warning
    """

    def decorator(func: Callable[P, T]) -> Callable[P, T]:
        @wraps(func)
        def wrapped(*args: P.args, **kwargs: P.kwargs) -> T:
            # Resolve the kind lazily from the wrapped callable when not given.
            resolved_kind = kind or ("method" if inspect.ismethod(func) else "function")
            warn_deprecation(
                version=version,
                deprecated_name=func.__name__,
                kind=resolved_kind,
                removal_in=removal_in,
                alternative=alternative,
                info=info,
                pending=pending,
            )
            return func(*args, **kwargs)

        return wrapped

    return decorator
|
|
Binary file
|
|
@@ -0,0 +1,58 @@
|
|
|
1
|
+
from pathlib import Path
|
|
2
|
+
from typing import Any
|
|
3
|
+
|
|
4
|
+
from sqlspec._serialization import decode_json
|
|
5
|
+
from sqlspec.exceptions import MissingDependencyError
|
|
6
|
+
|
|
7
|
+
__all__ = ("open_fixture", "open_fixture_async")
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
def open_fixture(fixtures_path: Any, fixture_name: str) -> Any:
    """Load and parse a JSON fixture file.

    Args:
        fixtures_path: The path to look for fixtures (pathlib.Path or anyio.Path)
        fixture_name: The fixture name to load.

    Raises:
        FileNotFoundError: Fixtures not found.

    Returns:
        The parsed JSON data
    """
    fixture_file = Path(fixtures_path / f"{fixture_name}.json")
    # Guard clause: fail fast when the fixture file is absent.
    if not fixture_file.exists():
        msg = f"Could not find the {fixture_name} fixture"
        raise FileNotFoundError(msg)
    with fixture_file.open(mode="r", encoding="utf-8") as fixture_fp:
        return decode_json(fixture_fp.read())
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
async def open_fixture_async(fixtures_path: Any, fixture_name: str) -> Any:
    """Load and parse a JSON fixture file asynchronously.

    Args:
        fixtures_path: The path to look for fixtures (pathlib.Path or anyio.Path)
        fixture_name: The fixture name to load.

    Raises:
        FileNotFoundError: Fixtures not found.
        MissingDependencyError: The `anyio` library is required to use this function.

    Returns:
        The parsed JSON data
    """
    # anyio is an optional dependency; surface a targeted error when absent.
    try:
        from anyio import Path as AsyncPath
    except ImportError as exc:
        raise MissingDependencyError(package="anyio") from exc

    fixture_file = AsyncPath(fixtures_path / f"{fixture_name}.json")
    if not await fixture_file.exists():
        msg = f"Could not find the {fixture_name} fixture"
        raise FileNotFoundError(msg)
    async with await fixture_file.open(mode="r", encoding="utf-8") as fixture_fp:
        raw = await fixture_fp.read()
    return decode_json(raw)
|
sqlspec/utils/logging.py
ADDED
|
@@ -0,0 +1,127 @@
|
|
|
1
|
+
"""Logging utilities for SQLSpec.
|
|
2
|
+
|
|
3
|
+
This module provides utilities for structured logging with correlation IDs.
|
|
4
|
+
Users should configure their own logging handlers and levels as needed.
|
|
5
|
+
SQLSpec provides StructuredFormatter for JSON-formatted logs if desired.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from __future__ import annotations
|
|
9
|
+
|
|
10
|
+
import logging
|
|
11
|
+
from contextvars import ContextVar
|
|
12
|
+
from typing import TYPE_CHECKING, Any
|
|
13
|
+
|
|
14
|
+
from sqlspec._serialization import encode_json
|
|
15
|
+
|
|
16
|
+
if TYPE_CHECKING:
|
|
17
|
+
from logging import LogRecord
|
|
18
|
+
|
|
19
|
+
__all__ = ("StructuredFormatter", "correlation_id_var", "get_correlation_id", "get_logger", "set_correlation_id")
|
|
20
|
+
|
|
21
|
+
# Context-local storage for the request/operation correlation ID.
correlation_id_var: ContextVar[str | None] = ContextVar("correlation_id", default=None)


def set_correlation_id(correlation_id: str | None) -> None:
    """Bind a correlation ID to the current execution context.

    Args:
        correlation_id: The correlation ID to set, or None to clear
    """
    correlation_id_var.set(correlation_id)


def get_correlation_id() -> str | None:
    """Return the correlation ID bound to the current context.

    Returns:
        The current correlation ID or None if not set
    """
    return correlation_id_var.get()
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
class StructuredFormatter(logging.Formatter):
    """Formatter that renders records as JSON, attaching the correlation ID when one is set."""

    def format(self, record: LogRecord) -> str:
        """Render a log record as a structured JSON string.

        Args:
            record: The log record to format

        Returns:
            JSON formatted log entry
        """
        payload = {
            "timestamp": self.formatTime(record, self.datefmt),
            "level": record.levelname,
            "logger": record.name,
            "message": record.getMessage(),
            "module": record.module,
            "function": record.funcName,
            "line": record.lineno,
        }

        correlation_id = get_correlation_id()
        if correlation_id:
            payload["correlation_id"] = correlation_id

        # extra_fields is attached by log_with_context(); merge it verbatim.
        if hasattr(record, "extra_fields"):
            payload.update(record.extra_fields)  # pyright: ignore

        if record.exc_info:
            payload["exception"] = self.formatException(record.exc_info)

        return encode_json(payload)
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
class CorrelationIDFilter(logging.Filter):
    """Filter that stamps the active correlation ID onto each record."""

    def filter(self, record: LogRecord) -> bool:
        """Attach the current correlation ID to the record when one is set.

        Args:
            record: The log record to filter

        Returns:
            Always True to pass the record through
        """
        correlation_id = get_correlation_id()
        if correlation_id:
            record.correlation_id = correlation_id
        # Never suppress records; this filter only annotates them.
        return True
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
def get_logger(name: str | None = None) -> logging.Logger:
    """Get a logger instance with standardized configuration.

    Args:
        name: Logger name. If not provided, returns the root sqlspec logger.

    Returns:
        Configured logger instance
    """
    if name is None:
        return logging.getLogger("sqlspec")

    # Namespace every logger under "sqlspec." so levels propagate predictably.
    qualified = name if name.startswith("sqlspec") else f"sqlspec.{name}"
    logger = logging.getLogger(qualified)

    # Attach the correlation-ID filter exactly once per logger.
    already_filtered = any(isinstance(existing, CorrelationIDFilter) for existing in logger.filters)
    if not already_filtered:
        logger.addFilter(CorrelationIDFilter())

    return logger
|
|
114
|
+
|
|
115
|
+
|
|
116
|
+
def log_with_context(logger: logging.Logger, level: int, message: str, **extra_fields: Any) -> None:
    """Log a message with structured extra fields attached.

    Args:
        logger: The logger to use
        level: Log level
        message: Log message
        **extra_fields: Additional fields to include in structured logs
    """
    # Build the record by hand so arbitrary structured fields can be attached
    # before handlers (e.g. StructuredFormatter) see it.
    record = logger.makeRecord(logger.name, level, "(unknown file)", 0, message, (), None)
    record.extra_fields = extra_fields
    logger.handle(record)
|
|
@@ -0,0 +1,89 @@
|
|
|
1
|
+
"""General utility functions."""
|
|
2
|
+
|
|
3
|
+
import importlib
|
|
4
|
+
from importlib.util import find_spec
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
from typing import Any, Optional
|
|
7
|
+
|
|
8
|
+
__all__ = ("import_string", "module_to_os_path")
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def module_to_os_path(dotted_path: str = "app") -> "Path":
    """Convert a module dotted path to filesystem path.

    Args:
        dotted_path: The path to the module.

    Raises:
        TypeError: The module could not be found.

    Returns:
        The path to the module.
    """
    not_found_msg = f"Couldn't find the path for {dotted_path}"
    try:
        spec = find_spec(dotted_path)
    except ModuleNotFoundError as e:
        # A missing parent package raises rather than returning None.
        raise TypeError(not_found_msg) from e
    if spec is None:  # pragma: no cover
        raise TypeError(not_found_msg)

    resolved = Path(str(spec.origin))
    # For a regular module/package, origin points at a file; callers want
    # the containing directory.
    return resolved.parent if resolved.is_file() else resolved
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
def import_string(dotted_path: str) -> "Any":
    """Import a module or attribute from a dotted path string.

    Tries progressively shorter prefixes of ``dotted_path`` as a module, then
    resolves the remaining segments as attributes of the imported module.

    Args:
        dotted_path: The path of the module to import.

    Raises:
        ImportError: The path could not be resolved to a module or attribute.

    Returns:
        The imported object.
    """

    # Helper so every failure surfaces as ImportError, optionally chained.
    def _raise_import_error(msg: str, exc: "Optional[Exception]" = None) -> None:
        if exc is not None:
            raise ImportError(msg) from exc
        raise ImportError(msg)

    obj: Any = None
    try:
        parts = dotted_path.split(".")
        module = None
        i = len(parts)

        # Longest-prefix-first: import the longest importable module prefix.
        # `i` is left at the length of the prefix that succeeded.
        for i in range(len(parts), 0, -1):
            module_path = ".".join(parts[:i])
            try:
                module = importlib.import_module(module_path)
                break
            except ModuleNotFoundError:
                continue
        else:
            # Loop exhausted without a break: no prefix was importable.
            _raise_import_error(f"{dotted_path} doesn't look like a module path")

        if module is None:
            _raise_import_error(f"Failed to import any module from {dotted_path}")

        obj = module
        attrs = parts[i:]
        # When the full path imported as a module, re-validate the last segment
        # against its parent module; a parent that fails to import is treated
        # as best-effort and the module itself is returned.
        if not attrs and i == len(parts) and len(parts) > 1:
            parent_module_path = ".".join(parts[:-1])
            attr = parts[-1]
            try:
                parent_module = importlib.import_module(parent_module_path)
            except Exception:
                return obj
            if not hasattr(parent_module, attr):
                _raise_import_error(f"Module '{parent_module_path}' has no attribute '{attr}' in '{dotted_path}'")

        # Walk the remaining segments as attribute accesses.
        for attr in attrs:
            if not hasattr(obj, attr):
                _raise_import_error(
                    f"Module '{module.__name__ if module is not None else 'unknown'}' has no attribute '{attr}' in '{dotted_path}'"
                )
            obj = getattr(obj, attr)
    except Exception as e:  # pylint: disable=broad-exception-caught
        # Normalize any failure (including re-raised ImportErrors) into one
        # ImportError carrying the full dotted path.
        _raise_import_error(f"Could not import '{dotted_path}': {e}", e)
    return obj
|
|
@@ -0,0 +1,32 @@
|
|
|
1
|
+
import threading
|
|
2
|
+
from typing import Any, TypeVar
|
|
3
|
+
|
|
4
|
+
__all__ = ("SingletonMeta",)
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
_T = TypeVar("_T")


class SingletonMeta(type):
    """Metaclass implementing a thread-safe singleton pattern."""

    # One shared registry and lock across all classes using this metaclass.
    _instances: dict[type, object] = {}
    _lock = threading.Lock()

    def __call__(cls: type[_T], *args: Any, **kwargs: Any) -> _T:
        """Return the single shared instance of the class, creating it on first use.

        Args:
            cls: The class being instantiated.
            *args: Positional arguments for the class constructor.
            **kwargs: Keyword arguments for the class constructor.

        Returns:
            The singleton instance of the class.
        """
        registry = SingletonMeta._instances
        # Lock-free fast path; fall into double-checked locking on first call.
        if cls not in registry:  # pyright: ignore[reportUnnecessaryContains]
            with SingletonMeta._lock:
                if cls not in registry:
                    registry[cls] = super().__call__(*args, **kwargs)  # type: ignore[misc]
        return registry[cls]  # type: ignore[return-value]
|
|
Binary file
|
|
@@ -0,0 +1,237 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
import functools
|
|
3
|
+
import inspect
|
|
4
|
+
import sys
|
|
5
|
+
from contextlib import AbstractAsyncContextManager, AbstractContextManager
|
|
6
|
+
from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar, Union, cast
|
|
7
|
+
|
|
8
|
+
from typing_extensions import ParamSpec
|
|
9
|
+
|
|
10
|
+
if TYPE_CHECKING:
|
|
11
|
+
from collections.abc import Awaitable, Callable, Coroutine
|
|
12
|
+
from types import TracebackType
|
|
13
|
+
|
|
14
|
+
try:
|
|
15
|
+
import uvloop # pyright: ignore[reportMissingImports]
|
|
16
|
+
except ImportError:
|
|
17
|
+
uvloop = None # type: ignore[assignment,unused-ignore]
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
ReturnT = TypeVar("ReturnT")
|
|
21
|
+
ParamSpecT = ParamSpec("ParamSpecT")
|
|
22
|
+
T = TypeVar("T")
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class CapacityLimiter:
|
|
26
|
+
"""Limits the number of concurrent operations using a semaphore."""
|
|
27
|
+
|
|
28
|
+
def __init__(self, total_tokens: int) -> None:
|
|
29
|
+
self._semaphore = asyncio.Semaphore(total_tokens)
|
|
30
|
+
|
|
31
|
+
async def acquire(self) -> None:
|
|
32
|
+
await self._semaphore.acquire()
|
|
33
|
+
|
|
34
|
+
def release(self) -> None:
|
|
35
|
+
self._semaphore.release()
|
|
36
|
+
|
|
37
|
+
@property
|
|
38
|
+
def total_tokens(self) -> int:
|
|
39
|
+
return self._semaphore._value
|
|
40
|
+
|
|
41
|
+
@total_tokens.setter
|
|
42
|
+
def total_tokens(self, value: int) -> None:
|
|
43
|
+
self._semaphore = asyncio.Semaphore(value)
|
|
44
|
+
|
|
45
|
+
async def __aenter__(self) -> None:
|
|
46
|
+
await self.acquire()
|
|
47
|
+
|
|
48
|
+
async def __aexit__(
|
|
49
|
+
self,
|
|
50
|
+
exc_type: "Optional[type[BaseException]]",
|
|
51
|
+
exc_val: "Optional[BaseException]",
|
|
52
|
+
exc_tb: "Optional[TracebackType]",
|
|
53
|
+
) -> None:
|
|
54
|
+
self.release()
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
_default_limiter = CapacityLimiter(15)
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
def run_(async_function: "Callable[ParamSpecT, Coroutine[Any, Any, ReturnT]]") -> "Callable[ParamSpecT, ReturnT]":
    """Convert an async function to a blocking function using asyncio.run().

    Args:
        async_function: The async function to convert.

    Returns:
        A blocking function that runs the async function.
    """

    @functools.wraps(async_function)
    def wrapper(*args: "ParamSpecT.args", **kwargs: "ParamSpecT.kwargs") -> "ReturnT":
        bound = functools.partial(async_function, *args, **kwargs)
        try:
            running_loop = asyncio.get_running_loop()
        except RuntimeError:
            running_loop = None

        if running_loop is not None:
            # NOTE(review): asyncio.run() refuses to start from inside a
            # running loop, so this branch surfaces a RuntimeError to the
            # caller rather than nesting loops.
            return asyncio.run(bound())
        # Prefer uvloop's faster event loop when installed (POSIX only).
        if uvloop and sys.platform != "win32":
            uvloop.install()  # pyright: ignore[reportUnknownMemberType]
        return asyncio.run(bound())

    return wrapper
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
def await_(
    async_function: "Callable[ParamSpecT, Coroutine[Any, Any, ReturnT]]", raise_sync_error: bool = True
) -> "Callable[ParamSpecT, ReturnT]":
    """Convert an async function to a blocking one, running in the main async loop.

    Args:
        async_function: The async function to convert.
        raise_sync_error: If False, runs in a new event loop if no loop is present.
            If True (default), raises RuntimeError if no loop is running.

    Returns:
        A blocking function that runs the async function.
    """

    @functools.wraps(async_function)
    def wrapper(*args: "ParamSpecT.args", **kwargs: "ParamSpecT.kwargs") -> "ReturnT":
        partial_f = functools.partial(async_function, *args, **kwargs)
        try:
            loop = asyncio.get_running_loop()
        except RuntimeError:
            # No event loop in this thread: either refuse, or spin up a
            # throwaway loop via asyncio.run() when raise_sync_error=False.
            if raise_sync_error:
                msg = "Cannot run async function"
                raise RuntimeError(msg) from None
            return asyncio.run(partial_f())
        else:
            if loop.is_running():
                # Detect whether we are executing *inside* a task on this same
                # loop -- blocking on future.result() there would deadlock.
                try:
                    current_task = asyncio.current_task(loop=loop)
                except RuntimeError:
                    current_task = None

                if current_task is not None:
                    msg = "await_ cannot be called from within an async task running on the same event loop. Use 'await' instead."
                    raise RuntimeError(msg)
                # Called from a different thread than the loop's: schedule the
                # coroutine onto the loop and block this thread on the result.
                future = asyncio.run_coroutine_threadsafe(partial_f(), loop)
                return future.result()
            # A loop object exists but is not running (e.g. a closed/new loop).
            if raise_sync_error:
                msg = "Cannot run async function"
                raise RuntimeError(msg)
            return asyncio.run(partial_f())

    return wrapper
|
|
129
|
+
|
|
130
|
+
|
|
131
|
+
def async_(
    function: "Callable[ParamSpecT, ReturnT]", *, limiter: "Optional[CapacityLimiter]" = None
) -> "Callable[ParamSpecT, Awaitable[ReturnT]]":
    """Convert a blocking function to an async one using asyncio.to_thread().

    Args:
        function: The blocking function to convert.
        limiter: Limit the total number of threads.

    Returns:
        An async function that runs the original function in a thread.
    """

    @functools.wraps(function)
    async def wrapper(*args: "ParamSpecT.args", **kwargs: "ParamSpecT.kwargs") -> "ReturnT":
        call = functools.partial(function, *args, **kwargs)
        # Throttle thread fan-out through the caller's limiter, falling back
        # to the module-wide default.
        async with limiter or _default_limiter:
            return await asyncio.to_thread(call)

    return wrapper
|
|
152
|
+
|
|
153
|
+
|
|
154
|
+
def ensure_async_(
    function: "Callable[ParamSpecT, Union[Awaitable[ReturnT], ReturnT]]",
) -> "Callable[ParamSpecT, Awaitable[ReturnT]]":
    """Convert a function to an async one if it is not already.

    Args:
        function: The function to convert.

    Returns:
        An async function that runs the original function.
    """
    # Coroutine functions already satisfy the contract; return them untouched.
    if inspect.iscoroutinefunction(function):
        return function

    @functools.wraps(function)
    async def wrapper(*args: "ParamSpecT.args", **kwargs: "ParamSpecT.kwargs") -> "ReturnT":
        outcome = function(*args, **kwargs)
        if inspect.isawaitable(outcome):
            return await outcome
        # Plain value from a sync callable: the call already ran above; route
        # the result through async_() to match the original's thread offload.
        return await async_(lambda: outcome)()

    return wrapper
|
|
176
|
+
|
|
177
|
+
|
|
178
|
+
class _ContextManagerWrapper(Generic[T]):
    """Adapts a synchronous context manager to the async context-manager protocol."""

    def __init__(self, cm: AbstractContextManager[T]) -> None:
        self._cm = cm

    async def __aenter__(self) -> T:
        # Runs the sync __enter__ directly on the event-loop thread.
        return self._cm.__enter__()

    async def __aexit__(
        self,
        exc_type: "Optional[type[BaseException]]",
        exc_val: "Optional[BaseException]",
        exc_tb: "Optional[TracebackType]",
    ) -> "Optional[bool]":
        # Propagate exception-suppression semantics from the wrapped manager.
        return self._cm.__exit__(exc_type, exc_val, exc_tb)
|
|
192
|
+
|
|
193
|
+
|
|
194
|
+
def with_ensure_async_(
    obj: "Union[AbstractContextManager[T], AbstractAsyncContextManager[T]]",
) -> "AbstractAsyncContextManager[T]":
    """Convert a context manager to an async one if it is not already.

    Args:
        obj: The context manager to convert.

    Returns:
        An async context manager that runs the original context manager.
    """
    if not isinstance(obj, AbstractContextManager):
        # Already async-capable; pass through untouched.
        return obj
    return cast("AbstractAsyncContextManager[T]", _ContextManagerWrapper(obj))
|
|
209
|
+
|
|
210
|
+
|
|
211
|
+
class NoValue:
    """Sentinel class for missing values."""


async def get_next(iterable: Any, default: Any = NoValue, *args: Any) -> Any:  # pragma: no cover
    """Return the next item from an async iterator.

    Args:
        iterable: An async iterable.
        default: An optional default value to return if the iterable is empty.
        *args: The remaining args

    Returns:
        The next value of the iterable.

    Raises:
        StopAsyncIteration: The iterator is exhausted and no default was given,
            or the iterable given is not async.
    """
    # The sentinel is the NoValue *class* itself, so detect it by identity.
    # (The previous isinstance(default, NoValue) check was always False for
    # the sentinel, which made exhaustion return the NoValue class instead of
    # raising StopAsyncIteration as documented.)
    has_default = default is not NoValue
    try:
        return await iterable.__anext__()

    except StopAsyncIteration as exc:
        if has_default:
            return default

        raise StopAsyncIteration from exc
|
|
Binary file
|