sqlspec 0.16.1__cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of sqlspec might be problematic. Click here for more details.

Files changed (148)
  1. 51ff5a9eadfdefd49f98__mypyc.cpython-310-aarch64-linux-gnu.so +0 -0
  2. sqlspec/__init__.py +92 -0
  3. sqlspec/__main__.py +12 -0
  4. sqlspec/__metadata__.py +14 -0
  5. sqlspec/_serialization.py +77 -0
  6. sqlspec/_sql.py +1780 -0
  7. sqlspec/_typing.py +680 -0
  8. sqlspec/adapters/__init__.py +0 -0
  9. sqlspec/adapters/adbc/__init__.py +5 -0
  10. sqlspec/adapters/adbc/_types.py +12 -0
  11. sqlspec/adapters/adbc/config.py +361 -0
  12. sqlspec/adapters/adbc/driver.py +512 -0
  13. sqlspec/adapters/aiosqlite/__init__.py +19 -0
  14. sqlspec/adapters/aiosqlite/_types.py +13 -0
  15. sqlspec/adapters/aiosqlite/config.py +253 -0
  16. sqlspec/adapters/aiosqlite/driver.py +248 -0
  17. sqlspec/adapters/asyncmy/__init__.py +19 -0
  18. sqlspec/adapters/asyncmy/_types.py +12 -0
  19. sqlspec/adapters/asyncmy/config.py +180 -0
  20. sqlspec/adapters/asyncmy/driver.py +274 -0
  21. sqlspec/adapters/asyncpg/__init__.py +21 -0
  22. sqlspec/adapters/asyncpg/_types.py +17 -0
  23. sqlspec/adapters/asyncpg/config.py +229 -0
  24. sqlspec/adapters/asyncpg/driver.py +344 -0
  25. sqlspec/adapters/bigquery/__init__.py +18 -0
  26. sqlspec/adapters/bigquery/_types.py +12 -0
  27. sqlspec/adapters/bigquery/config.py +298 -0
  28. sqlspec/adapters/bigquery/driver.py +558 -0
  29. sqlspec/adapters/duckdb/__init__.py +22 -0
  30. sqlspec/adapters/duckdb/_types.py +12 -0
  31. sqlspec/adapters/duckdb/config.py +504 -0
  32. sqlspec/adapters/duckdb/driver.py +368 -0
  33. sqlspec/adapters/oracledb/__init__.py +32 -0
  34. sqlspec/adapters/oracledb/_types.py +14 -0
  35. sqlspec/adapters/oracledb/config.py +317 -0
  36. sqlspec/adapters/oracledb/driver.py +538 -0
  37. sqlspec/adapters/psqlpy/__init__.py +16 -0
  38. sqlspec/adapters/psqlpy/_types.py +11 -0
  39. sqlspec/adapters/psqlpy/config.py +214 -0
  40. sqlspec/adapters/psqlpy/driver.py +530 -0
  41. sqlspec/adapters/psycopg/__init__.py +32 -0
  42. sqlspec/adapters/psycopg/_types.py +17 -0
  43. sqlspec/adapters/psycopg/config.py +426 -0
  44. sqlspec/adapters/psycopg/driver.py +796 -0
  45. sqlspec/adapters/sqlite/__init__.py +15 -0
  46. sqlspec/adapters/sqlite/_types.py +11 -0
  47. sqlspec/adapters/sqlite/config.py +240 -0
  48. sqlspec/adapters/sqlite/driver.py +294 -0
  49. sqlspec/base.py +571 -0
  50. sqlspec/builder/__init__.py +62 -0
  51. sqlspec/builder/_base.py +473 -0
  52. sqlspec/builder/_column.py +320 -0
  53. sqlspec/builder/_ddl.py +1346 -0
  54. sqlspec/builder/_ddl_utils.py +103 -0
  55. sqlspec/builder/_delete.py +76 -0
  56. sqlspec/builder/_insert.py +256 -0
  57. sqlspec/builder/_merge.py +71 -0
  58. sqlspec/builder/_parsing_utils.py +140 -0
  59. sqlspec/builder/_select.py +170 -0
  60. sqlspec/builder/_update.py +188 -0
  61. sqlspec/builder/mixins/__init__.py +55 -0
  62. sqlspec/builder/mixins/_cte_and_set_ops.py +222 -0
  63. sqlspec/builder/mixins/_delete_operations.py +41 -0
  64. sqlspec/builder/mixins/_insert_operations.py +244 -0
  65. sqlspec/builder/mixins/_join_operations.py +122 -0
  66. sqlspec/builder/mixins/_merge_operations.py +476 -0
  67. sqlspec/builder/mixins/_order_limit_operations.py +135 -0
  68. sqlspec/builder/mixins/_pivot_operations.py +153 -0
  69. sqlspec/builder/mixins/_select_operations.py +603 -0
  70. sqlspec/builder/mixins/_update_operations.py +187 -0
  71. sqlspec/builder/mixins/_where_clause.py +621 -0
  72. sqlspec/cli.py +247 -0
  73. sqlspec/config.py +395 -0
  74. sqlspec/core/__init__.py +63 -0
  75. sqlspec/core/cache.cpython-310-aarch64-linux-gnu.so +0 -0
  76. sqlspec/core/cache.py +871 -0
  77. sqlspec/core/compiler.cpython-310-aarch64-linux-gnu.so +0 -0
  78. sqlspec/core/compiler.py +417 -0
  79. sqlspec/core/filters.cpython-310-aarch64-linux-gnu.so +0 -0
  80. sqlspec/core/filters.py +830 -0
  81. sqlspec/core/hashing.cpython-310-aarch64-linux-gnu.so +0 -0
  82. sqlspec/core/hashing.py +310 -0
  83. sqlspec/core/parameters.cpython-310-aarch64-linux-gnu.so +0 -0
  84. sqlspec/core/parameters.py +1237 -0
  85. sqlspec/core/result.cpython-310-aarch64-linux-gnu.so +0 -0
  86. sqlspec/core/result.py +677 -0
  87. sqlspec/core/splitter.cpython-310-aarch64-linux-gnu.so +0 -0
  88. sqlspec/core/splitter.py +819 -0
  89. sqlspec/core/statement.cpython-310-aarch64-linux-gnu.so +0 -0
  90. sqlspec/core/statement.py +676 -0
  91. sqlspec/driver/__init__.py +19 -0
  92. sqlspec/driver/_async.py +502 -0
  93. sqlspec/driver/_common.py +631 -0
  94. sqlspec/driver/_sync.py +503 -0
  95. sqlspec/driver/mixins/__init__.py +6 -0
  96. sqlspec/driver/mixins/_result_tools.py +193 -0
  97. sqlspec/driver/mixins/_sql_translator.py +86 -0
  98. sqlspec/exceptions.py +193 -0
  99. sqlspec/extensions/__init__.py +0 -0
  100. sqlspec/extensions/aiosql/__init__.py +10 -0
  101. sqlspec/extensions/aiosql/adapter.py +461 -0
  102. sqlspec/extensions/litestar/__init__.py +6 -0
  103. sqlspec/extensions/litestar/_utils.py +52 -0
  104. sqlspec/extensions/litestar/cli.py +48 -0
  105. sqlspec/extensions/litestar/config.py +92 -0
  106. sqlspec/extensions/litestar/handlers.py +260 -0
  107. sqlspec/extensions/litestar/plugin.py +145 -0
  108. sqlspec/extensions/litestar/providers.py +454 -0
  109. sqlspec/loader.cpython-310-aarch64-linux-gnu.so +0 -0
  110. sqlspec/loader.py +760 -0
  111. sqlspec/migrations/__init__.py +35 -0
  112. sqlspec/migrations/base.py +414 -0
  113. sqlspec/migrations/commands.py +443 -0
  114. sqlspec/migrations/loaders.py +402 -0
  115. sqlspec/migrations/runner.py +213 -0
  116. sqlspec/migrations/tracker.py +140 -0
  117. sqlspec/migrations/utils.py +129 -0
  118. sqlspec/protocols.py +407 -0
  119. sqlspec/py.typed +0 -0
  120. sqlspec/storage/__init__.py +23 -0
  121. sqlspec/storage/backends/__init__.py +0 -0
  122. sqlspec/storage/backends/base.py +163 -0
  123. sqlspec/storage/backends/fsspec.py +386 -0
  124. sqlspec/storage/backends/obstore.py +459 -0
  125. sqlspec/storage/capabilities.py +102 -0
  126. sqlspec/storage/registry.py +239 -0
  127. sqlspec/typing.py +299 -0
  128. sqlspec/utils/__init__.py +3 -0
  129. sqlspec/utils/correlation.py +150 -0
  130. sqlspec/utils/deprecation.py +106 -0
  131. sqlspec/utils/fixtures.cpython-310-aarch64-linux-gnu.so +0 -0
  132. sqlspec/utils/fixtures.py +58 -0
  133. sqlspec/utils/logging.py +127 -0
  134. sqlspec/utils/module_loader.py +89 -0
  135. sqlspec/utils/serializers.py +4 -0
  136. sqlspec/utils/singleton.py +32 -0
  137. sqlspec/utils/sync_tools.cpython-310-aarch64-linux-gnu.so +0 -0
  138. sqlspec/utils/sync_tools.py +237 -0
  139. sqlspec/utils/text.cpython-310-aarch64-linux-gnu.so +0 -0
  140. sqlspec/utils/text.py +96 -0
  141. sqlspec/utils/type_guards.cpython-310-aarch64-linux-gnu.so +0 -0
  142. sqlspec/utils/type_guards.py +1139 -0
  143. sqlspec-0.16.1.dist-info/METADATA +365 -0
  144. sqlspec-0.16.1.dist-info/RECORD +148 -0
  145. sqlspec-0.16.1.dist-info/WHEEL +7 -0
  146. sqlspec-0.16.1.dist-info/entry_points.txt +2 -0
  147. sqlspec-0.16.1.dist-info/licenses/LICENSE +21 -0
  148. sqlspec-0.16.1.dist-info/licenses/NOTICE +29 -0
@@ -0,0 +1,106 @@
1
+ import inspect
2
+ from functools import wraps
3
+ from typing import Callable, Literal, Optional
4
+ from warnings import warn
5
+
6
+ from typing_extensions import ParamSpec, TypeVar
7
+
8
+ __all__ = ("deprecated", "warn_deprecation")
9
+
10
+
11
+ T = TypeVar("T")
12
+ P = ParamSpec("P")
13
+ DeprecatedKind = Literal["function", "method", "classmethod", "attribute", "property", "class", "parameter", "import"]
14
+
15
+
16
+ def warn_deprecation(
17
+ version: str,
18
+ deprecated_name: str,
19
+ kind: DeprecatedKind,
20
+ *,
21
+ removal_in: Optional[str] = None,
22
+ alternative: Optional[str] = None,
23
+ info: Optional[str] = None,
24
+ pending: bool = False,
25
+ ) -> None:
26
+ """Warn about a call to a deprecated function.
27
+
28
+ Args:
29
+ version: SQLSpec version where the deprecation will occur
30
+ deprecated_name: Name of the deprecated function
31
+ removal_in: SQLSpec version where the deprecated function will be removed
32
+ alternative: Name of a function that should be used instead
33
+ info: Additional information
34
+ pending: Use :class:`warnings.PendingDeprecationWarning` instead of :class:`warnings.DeprecationWarning`
35
+ kind: Type of the deprecated thing
36
+ """
37
+ parts = []
38
+
39
+ if kind == "import":
40
+ access_type = "Import of"
41
+ elif kind in {"function", "method"}:
42
+ access_type = "Call to"
43
+ else:
44
+ access_type = "Use of"
45
+
46
+ if pending:
47
+ parts.append(f"{access_type} {kind} awaiting deprecation '{deprecated_name}'") # pyright: ignore[reportUnknownMemberType]
48
+ else:
49
+ parts.append(f"{access_type} deprecated {kind} '{deprecated_name}'") # pyright: ignore[reportUnknownMemberType]
50
+
51
+ parts.extend( # pyright: ignore[reportUnknownMemberType]
52
+ (f"Deprecated in SQLSpec {version}", f"This {kind} will be removed in {removal_in or 'the next major version'}")
53
+ )
54
+ if alternative:
55
+ parts.append(f"Use {alternative!r} instead") # pyright: ignore[reportUnknownMemberType]
56
+
57
+ if info:
58
+ parts.append(info) # pyright: ignore[reportUnknownMemberType]
59
+
60
+ text = ". ".join(parts) # pyright: ignore[reportUnknownArgumentType]
61
+ warning_class = PendingDeprecationWarning if pending else DeprecationWarning
62
+
63
+ warn(text, warning_class, stacklevel=2)
64
+
65
+
66
+ def deprecated(
67
+ version: str,
68
+ *,
69
+ removal_in: Optional[str] = None,
70
+ alternative: Optional[str] = None,
71
+ info: Optional[str] = None,
72
+ pending: bool = False,
73
+ kind: Optional[Literal["function", "method", "classmethod", "property"]] = None,
74
+ ) -> Callable[[Callable[P, T]], Callable[P, T]]:
75
+ """Create a decorator wrapping a function, method or property with a deprecation warning.
76
+
77
+ Args:
78
+ version: SQLSpec version where the deprecation will occur
79
+ removal_in: SQLSpec version where the deprecated function will be removed
80
+ alternative: Name of a function that should be used instead
81
+ info: Additional information
82
+ pending: Use :class:`warnings.PendingDeprecationWarning` instead of :class:`warnings.DeprecationWarning`
83
+ kind: Type of the deprecated callable. If ``None``, will use ``inspect`` to figure
84
+ out if it's a function or method
85
+
86
+ Returns:
87
+ A decorator wrapping the function call with a warning
88
+ """
89
+
90
+ def decorator(func: Callable[P, T]) -> Callable[P, T]:
91
+ @wraps(func)
92
+ def wrapped(*args: P.args, **kwargs: P.kwargs) -> T:
93
+ warn_deprecation(
94
+ version=version,
95
+ deprecated_name=func.__name__,
96
+ info=info,
97
+ alternative=alternative,
98
+ pending=pending,
99
+ removal_in=removal_in,
100
+ kind=kind or ("method" if inspect.ismethod(func) else "function"),
101
+ )
102
+ return func(*args, **kwargs)
103
+
104
+ return wrapped
105
+
106
+ return decorator
@@ -0,0 +1,58 @@
1
+ from pathlib import Path
2
+ from typing import Any
3
+
4
+ from sqlspec._serialization import decode_json
5
+ from sqlspec.exceptions import MissingDependencyError
6
+
7
+ __all__ = ("open_fixture", "open_fixture_async")
8
+
9
+
10
def open_fixture(fixtures_path: Any, fixture_name: str) -> Any:
    """Load and parse a JSON fixture file.

    Args:
        fixtures_path: Directory to look for fixtures in. Any path-like value is
            accepted (``str``, :class:`pathlib.Path`, ``anyio.Path``, ...).
        fixture_name: The fixture name to load, without the ``.json`` suffix.

    Raises:
        FileNotFoundError: Fixtures not found.

    Returns:
        The parsed JSON data
    """
    # Path(...) first so plain strings work too; ``str / str`` would raise TypeError.
    fixture = Path(fixtures_path) / f"{fixture_name}.json"
    try:
        f_data = fixture.read_text(encoding="utf-8")
    except (FileNotFoundError, NotADirectoryError) as exc:
        # Same cases the previous exists() check reported as a missing fixture.
        msg = f"Could not find the {fixture_name} fixture"
        raise FileNotFoundError(msg) from exc
    return decode_json(f_data)
31
+
32
+
33
async def open_fixture_async(fixtures_path: Any, fixture_name: str) -> Any:
    """Load and parse a JSON fixture file asynchronously.

    Args:
        fixtures_path: Directory to look for fixtures in. Any path-like value is
            accepted (``str``, :class:`pathlib.Path`, ``anyio.Path``, ...).
        fixture_name: The fixture name to load, without the ``.json`` suffix.

    Raises:
        FileNotFoundError: Fixtures not found.
        MissingDependencyError: The `anyio` library is required to use this function.

    Returns:
        The parsed JSON data
    """
    try:
        from anyio import Path as AsyncPath
    except ImportError as exc:
        raise MissingDependencyError(package="anyio") from exc

    # Wrap first so plain strings work too; ``str / str`` would raise TypeError.
    fixture = AsyncPath(fixtures_path) / f"{fixture_name}.json"
    try:
        f_data = await fixture.read_text(encoding="utf-8")
    except (FileNotFoundError, NotADirectoryError) as exc:
        # Same cases the previous exists() check reported as a missing fixture.
        msg = f"Could not find the {fixture_name} fixture"
        raise FileNotFoundError(msg) from exc
    return decode_json(f_data)
@@ -0,0 +1,127 @@
1
+ """Logging utilities for SQLSpec.
2
+
3
+ This module provides utilities for structured logging with correlation IDs.
4
+ Users should configure their own logging handlers and levels as needed.
5
+ SQLSpec provides StructuredFormatter for JSON-formatted logs if desired.
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ import logging
11
+ from contextvars import ContextVar
12
+ from typing import TYPE_CHECKING, Any
13
+
14
+ from sqlspec._serialization import encode_json
15
+
16
+ if TYPE_CHECKING:
17
+ from logging import LogRecord
18
+
19
+ __all__ = ("StructuredFormatter", "correlation_id_var", "get_correlation_id", "get_logger", "set_correlation_id")
20
+
21
+ correlation_id_var: ContextVar[str | None] = ContextVar("correlation_id", default=None)
22
+
23
+
24
+ def set_correlation_id(correlation_id: str | None) -> None:
25
+ """Set the correlation ID for the current context.
26
+
27
+ Args:
28
+ correlation_id: The correlation ID to set, or None to clear
29
+ """
30
+ correlation_id_var.set(correlation_id)
31
+
32
+
33
+ def get_correlation_id() -> str | None:
34
+ """Get the current correlation ID.
35
+
36
+ Returns:
37
+ The current correlation ID or None if not set
38
+ """
39
+ return correlation_id_var.get()
40
+
41
+
42
class StructuredFormatter(logging.Formatter):
    """Structured JSON formatter with correlation ID support.

    Each record becomes one JSON object (serialized with ``encode_json``) holding the
    timestamp, level, logger name, message and source location, plus, when present,
    the correlation ID, any ``extra_fields`` dict attached to the record, the
    formatted exception, and the stack info.
    """

    def format(self, record: LogRecord) -> str:
        """Format log record as structured JSON.

        Args:
            record: The log record to format

        Returns:
            JSON formatted log entry
        """
        log_entry: dict[str, Any] = {
            "timestamp": self.formatTime(record, self.datefmt),
            "level": record.levelname,
            "logger": record.name,
            "message": record.getMessage(),
            "module": record.module,
            "function": record.funcName,
            "line": record.lineno,
        }

        # Prefer the ID captured on the record at logging time (CorrelationIDFilter
        # stores it as ``record.correlation_id``). Formatting can happen later in a
        # different context, e.g. a QueueListener thread, where the context variable
        # is unset; fall back to the current context otherwise.
        correlation_id = getattr(record, "correlation_id", None) or get_correlation_id()
        if correlation_id:
            log_entry["correlation_id"] = correlation_id

        # ``extra_fields`` is attached by log_with_context(); tolerate None/empty.
        extra_fields = getattr(record, "extra_fields", None)
        if extra_fields:
            log_entry.update(extra_fields)

        if record.exc_info:
            log_entry["exception"] = self.formatException(record.exc_info)

        if record.stack_info:
            log_entry["stack_info"] = self.formatStack(record.stack_info)

        return encode_json(log_entry)
74
+
75
+
76
class CorrelationIDFilter(logging.Filter):
    """Logging filter that stamps records with the active correlation ID.

    It never rejects a record; it only sets ``record.correlation_id`` when an ID is
    present in the current context.
    """

    def filter(self, record: LogRecord) -> bool:
        """Attach the current correlation ID, if any, and let the record through.

        Args:
            record: The record being processed

        Returns:
            ``True`` in every case
        """
        active_id = get_correlation_id()
        if not active_id:
            return True
        record.correlation_id = active_id
        return True
91
+
92
+
93
def get_logger(name: str | None = None) -> logging.Logger:
    """Get a logger instance with standardized configuration.

    Names are placed under the ``sqlspec`` logger hierarchy: ``"foo"`` becomes
    ``"sqlspec.foo"``, while ``"sqlspec"`` and ``"sqlspec.<anything>"`` are used as
    given. Every returned logger, including the root ``sqlspec`` logger, carries
    exactly one :class:`CorrelationIDFilter`.

    Args:
        name: Logger name. If not provided, returns the root sqlspec logger.

    Returns:
        Configured logger instance
    """
    if name is None:
        name = "sqlspec"
    elif name != "sqlspec" and not name.startswith("sqlspec."):
        # A bare ``startswith("sqlspec")`` would leave names such as "sqlspecial"
        # outside the sqlspec hierarchy.
        name = f"sqlspec.{name}"

    logger = logging.getLogger(name)

    if not any(isinstance(f, CorrelationIDFilter) for f in logger.filters):
        logger.addFilter(CorrelationIDFilter())

    return logger
114
+
115
+
116
def log_with_context(logger: logging.Logger, level: int, message: str, **extra_fields: Any) -> None:
    """Log a message with structured extra fields.

    The fields are attached to the record as ``record.extra_fields`` (a dict). The
    record goes through :meth:`logging.Logger.log`, so the logger's effective level
    and ``logging.disable`` are honoured. The record also carries the file, line and
    function of the code that called this helper. ``message`` is used verbatim:
    no ``%``-formatting is applied because no args are passed.

    Args:
        logger: The logger to use
        level: Log level
        message: Log message
        **extra_fields: Additional fields to include in structured logs
    """
    # stacklevel=2 attributes the record to our caller rather than to this helper.
    logger.log(level, message, extra={"extra_fields": extra_fields}, stacklevel=2)
@@ -0,0 +1,89 @@
1
+ """General utility functions."""
2
+
3
+ import importlib
4
+ from importlib.util import find_spec
5
+ from pathlib import Path
6
+ from typing import Any, Optional
7
+
8
+ __all__ = ("import_string", "module_to_os_path")
9
+
10
+
11
def module_to_os_path(dotted_path: str = "app") -> "Path":
    """Convert a module dotted path to filesystem path.

    For a regular package this is the package directory (the parent of its
    ``__init__.py``); for a plain module it is the directory containing the module
    file. Namespace packages (PEP 420) have no origin file and resolve to their first
    search location.

    Args:
        dotted_path: The path to the module.

    Raises:
        TypeError: The module could not be found.

    Returns:
        The path to the module.
    """
    msg = f"Couldn't find the path for {dotted_path}"
    try:
        spec = find_spec(dotted_path)
    except ModuleNotFoundError as e:
        raise TypeError(msg) from e
    if spec is None:  # pragma: no cover
        raise TypeError(msg)

    if spec.origin is None:
        # Namespace packages: origin is None and the directories live in
        # submodule_search_locations. Without this branch the result was Path("None").
        locations = list(spec.submodule_search_locations or ())
        if not locations:
            raise TypeError(msg)
        return Path(locations[0])

    path = Path(spec.origin)
    return path.parent if path.is_file() else path
33
+
34
+
35
def import_string(dotted_path: str) -> "Any":
    """Import a module or attribute from a dotted path string.

    The longest importable module prefix of ``dotted_path`` is imported and the
    remaining components are resolved as attributes on it; for example
    ``"os.path.join"`` imports ``os.path`` and returns its ``join`` attribute.

    Args:
        dotted_path: The path of the module to import.

    Raises:
        ImportError: No prefix can be imported, an attribute is missing, or the
            target module itself fails to import (for example because one of its own
            dependencies is not installed).

    Returns:
        The imported object.
    """

    def _raise_import_error(msg: str, exc: "Optional[Exception]" = None) -> None:
        if exc is not None:
            raise ImportError(msg) from exc
        raise ImportError(msg)

    def _is_missing_prefix(missing: "Optional[str]", module_path: str) -> bool:
        # True when the ModuleNotFoundError is about ``module_path`` itself or one of
        # its parent packages (or carries no name): a shorter prefix should be tried.
        # Otherwise the module exists but one of *its* imports failed, and that error
        # must not be masked by falling back to attribute lookup.
        return missing is None or module_path == missing or module_path.startswith(f"{missing}.")

    obj: Any = None
    try:
        parts = dotted_path.split(".")
        module = None
        i = len(parts)

        for i in range(len(parts), 0, -1):
            module_path = ".".join(parts[:i])
            try:
                module = importlib.import_module(module_path)
                break
            except ModuleNotFoundError as exc:
                if not _is_missing_prefix(exc.name, module_path):
                    raise
        else:
            _raise_import_error(f"{dotted_path} doesn't look like a module path")

        if module is None:
            _raise_import_error(f"Failed to import any module from {dotted_path}")

        obj = module
        attrs = parts[i:]
        if not attrs and i == len(parts) and len(parts) > 1:
            # The whole path imported as a module; double-check it is reachable as an
            # attribute of its parent package.
            parent_module_path = ".".join(parts[:-1])
            attr = parts[-1]
            try:
                parent_module = importlib.import_module(parent_module_path)
            except Exception:
                return obj
            if not hasattr(parent_module, attr):
                _raise_import_error(f"Module '{parent_module_path}' has no attribute '{attr}' in '{dotted_path}'")

        for attr in attrs:
            if not hasattr(obj, attr):
                _raise_import_error(
                    f"Module '{module.__name__ if module is not None else 'unknown'}' has no attribute '{attr}' in '{dotted_path}'"
                )
            obj = getattr(obj, attr)
    except Exception as e:  # pylint: disable=broad-exception-caught
        _raise_import_error(f"Could not import '{dotted_path}': {e}", e)
    return obj
@@ -0,0 +1,4 @@
1
+ from sqlspec._serialization import decode_json as from_json
2
+ from sqlspec._serialization import encode_json as to_json
3
+
4
+ __all__ = ("from_json", "to_json")
@@ -0,0 +1,32 @@
1
+ import threading
2
+ from typing import Any, TypeVar
3
+
4
+ __all__ = ("SingletonMeta",)
5
+
6
+
7
_T = TypeVar("_T")


class SingletonMeta(type):
    """Metaclass that creates at most one instance per class.

    Instances live in a registry shared by all classes using this metaclass, keyed
    by the concrete class. The first construction happens under a lock and the
    registry is re-checked inside it, so concurrent first calls build one instance.
    """

    _instances: dict[type, object] = {}
    _lock = threading.Lock()

    def __call__(cls: type[_T], *args: Any, **kwargs: Any) -> _T:
        """Return the cached instance of ``cls``, constructing it on first use.

        Args:
            cls: The class being instantiated.
            *args: Positional arguments for the class constructor (first call only).
            **kwargs: Keyword arguments for the class constructor (first call only).

        Returns:
            The singleton instance of the class.
        """
        registry = SingletonMeta._instances
        if cls in registry:  # pyright: ignore[reportUnnecessaryContains]
            return registry[cls]  # type: ignore[return-value]
        with SingletonMeta._lock:
            if cls not in registry:
                registry[cls] = super().__call__(*args, **kwargs)  # type: ignore[misc]
        return registry[cls]  # type: ignore[return-value]
@@ -0,0 +1,237 @@
1
+ import asyncio
2
+ import functools
3
+ import inspect
4
+ import sys
5
+ from contextlib import AbstractAsyncContextManager, AbstractContextManager
6
+ from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar, Union, cast
7
+
8
+ from typing_extensions import ParamSpec
9
+
10
+ if TYPE_CHECKING:
11
+ from collections.abc import Awaitable, Callable, Coroutine
12
+ from types import TracebackType
13
+
14
+ try:
15
+ import uvloop # pyright: ignore[reportMissingImports]
16
+ except ImportError:
17
+ uvloop = None # type: ignore[assignment,unused-ignore]
18
+
19
+
20
+ ReturnT = TypeVar("ReturnT")
21
+ ParamSpecT = ParamSpec("ParamSpecT")
22
+ T = TypeVar("T")
23
+
24
+
25
+ class CapacityLimiter:
26
+ """Limits the number of concurrent operations using a semaphore."""
27
+
28
+ def __init__(self, total_tokens: int) -> None:
29
+ self._semaphore = asyncio.Semaphore(total_tokens)
30
+
31
+ async def acquire(self) -> None:
32
+ await self._semaphore.acquire()
33
+
34
+ def release(self) -> None:
35
+ self._semaphore.release()
36
+
37
+ @property
38
+ def total_tokens(self) -> int:
39
+ return self._semaphore._value
40
+
41
+ @total_tokens.setter
42
+ def total_tokens(self, value: int) -> None:
43
+ self._semaphore = asyncio.Semaphore(value)
44
+
45
+ async def __aenter__(self) -> None:
46
+ await self.acquire()
47
+
48
+ async def __aexit__(
49
+ self,
50
+ exc_type: "Optional[type[BaseException]]",
51
+ exc_val: "Optional[BaseException]",
52
+ exc_tb: "Optional[TracebackType]",
53
+ ) -> None:
54
+ self.release()
55
+
56
+
57
+ _default_limiter = CapacityLimiter(15)
58
+
59
+
60
def run_(async_function: "Callable[ParamSpecT, Coroutine[Any, Any, ReturnT]]") -> "Callable[ParamSpecT, ReturnT]":
    """Convert an async function to a blocking function using asyncio.run().

    With no event loop running in the calling thread, the coroutine runs through
    :func:`asyncio.run` (installing uvloop first when it is importable and the
    platform is not Windows). :func:`asyncio.run` refuses to start while the calling
    thread already runs a loop, so in that case the coroutine is run with
    :func:`asyncio.run` in a short-lived worker thread. The caller blocks until it
    finishes, which also blocks the caller's own event loop for that time.

    Args:
        async_function: The async function to convert.

    Returns:
        A blocking function that runs the async function.
    """
    import concurrent.futures

    @functools.wraps(async_function)
    def wrapper(*args: "ParamSpecT.args", **kwargs: "ParamSpecT.kwargs") -> "ReturnT":
        partial_f = functools.partial(async_function, *args, **kwargs)
        try:
            asyncio.get_running_loop()
            in_running_loop = True
        except RuntimeError:
            in_running_loop = False

        if in_running_loop:
            # The coroutine is created inside the worker so it is never left
            # un-awaited if submission fails.
            with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
                return executor.submit(lambda: asyncio.run(partial_f())).result()

        if uvloop and sys.platform != "win32":
            uvloop.install()  # pyright: ignore[reportUnknownMemberType]
        return asyncio.run(partial_f())

    return wrapper
85
+
86
+
87
def await_(
    async_function: "Callable[ParamSpecT, Coroutine[Any, Any, ReturnT]]", raise_sync_error: bool = True
) -> "Callable[ParamSpecT, ReturnT]":
    """Convert an async function to a blocking one.

    The returned function can only run the coroutine when the calling thread has no
    running event loop and ``raise_sync_error`` is False; a new loop is then created
    with :func:`asyncio.run`. :func:`asyncio.get_running_loop` only reports a loop
    running in the *current* thread, and blocking that thread until the coroutine
    finishes would stall the very loop that must run it. That case therefore raises
    instead of deadlocking.

    Args:
        async_function: The async function to convert.
        raise_sync_error: If False, runs in a new event loop if no loop is present.
            If True (default), raises RuntimeError if no loop is running.

    Returns:
        A blocking function that runs the async function. Calling it raises
        RuntimeError when no loop is running and ``raise_sync_error`` is True, or when
        the calling thread is running an event loop.
    """

    @functools.wraps(async_function)
    def wrapper(*args: "ParamSpecT.args", **kwargs: "ParamSpecT.kwargs") -> "ReturnT":
        try:
            loop = asyncio.get_running_loop()
        except RuntimeError:
            if raise_sync_error:
                msg = "Cannot run async function"
                raise RuntimeError(msg) from None
            return asyncio.run(async_function(*args, **kwargs))

        try:
            current_task = asyncio.current_task(loop=loop)
        except RuntimeError:
            current_task = None

        if current_task is not None:
            msg = "await_ cannot be called from within an async task running on the same event loop. Use 'await' instead."
            raise RuntimeError(msg)
        # Not inside a task (e.g. a plain loop callback): waiting here, for instance via
        # run_coroutine_threadsafe(...).result(), would block the loop thread forever.
        msg = "await_ cannot block the thread that is running the event loop. Use 'await' or schedule the coroutine on the loop instead."
        raise RuntimeError(msg)

    return wrapper
129
+
130
+
131
def async_(
    function: "Callable[ParamSpecT, ReturnT]", *, limiter: "Optional[CapacityLimiter]" = None
) -> "Callable[ParamSpecT, Awaitable[ReturnT]]":
    """Wrap a blocking callable so that it can be awaited.

    Every call runs ``function`` in a worker thread via :func:`asyncio.to_thread`
    while holding a token from ``limiter``, or from the module-level default limiter
    when ``limiter`` is falsy.

    Args:
        function: The blocking callable to wrap.
        limiter: Limit the total number of threads.

    Returns:
        A coroutine function accepting the same arguments as ``function``.
    """

    @functools.wraps(function)
    async def wrapper(*args: "ParamSpecT.args", **kwargs: "ParamSpecT.kwargs") -> "ReturnT":
        bound_call = functools.partial(function, *args, **kwargs)
        async with limiter or _default_limiter:
            outcome = await asyncio.to_thread(bound_call)
        return outcome

    return wrapper
152
+
153
+
154
def ensure_async_(
    function: "Callable[ParamSpecT, Union[Awaitable[ReturnT], ReturnT]]",
) -> "Callable[ParamSpecT, Awaitable[ReturnT]]":
    """Convert a function to an async one if it is not already.

    Coroutine functions are returned unchanged. Any other callable is wrapped in a
    coroutine function that calls it directly on the event loop's thread (so a slow
    blocking call blocks the loop). If the call returns an awaitable, such as a sync
    function returning a coroutine, that awaitable is awaited.

    Args:
        function: The function to convert.

    Returns:
        An async function that runs the original function.
    """
    if inspect.iscoroutinefunction(function):
        return function

    @functools.wraps(function)
    async def wrapper(*args: "ParamSpecT.args", **kwargs: "ParamSpecT.kwargs") -> "ReturnT":
        result = function(*args, **kwargs)
        if inspect.isawaitable(result):
            return await result
        # The value is already computed; returning it directly avoids a pointless
        # worker-thread round trip that would also occupy a limiter slot.
        return cast("ReturnT", result)

    return wrapper
176
+
177
+
178
class _ContextManagerWrapper(Generic[T]):
    """Expose a synchronous context manager through the async context-manager protocol.

    ``__aenter__``/``__aexit__`` delegate straight to the wrapped manager's
    ``__enter__``/``__exit__``; the synchronous calls run inside the awaiting coroutine.
    """

    def __init__(self, cm: AbstractContextManager[T]) -> None:
        self._cm = cm

    async def __aenter__(self) -> T:
        inner = self._cm
        return inner.__enter__()

    async def __aexit__(
        self,
        exc_type: "Optional[type[BaseException]]",
        exc_val: "Optional[BaseException]",
        exc_tb: "Optional[TracebackType]",
    ) -> "Optional[bool]":
        inner = self._cm
        suppress = inner.__exit__(exc_type, exc_val, exc_tb)
        return suppress
192
+
193
+
194
def with_ensure_async_(
    obj: "Union[AbstractContextManager[T], AbstractAsyncContextManager[T]]",
) -> "AbstractAsyncContextManager[T]":
    """Convert a context manager to an async one if it is not already.

    Objects that already implement the async protocol are returned unchanged, even if
    they also implement the sync one. Sync-only context managers are wrapped. Anything
    else is returned as-is.

    Args:
        obj: The context manager to convert.

    Returns:
        An async context manager that runs the original context manager.
    """
    # Check the async protocol first: objects supporting both (e.g. some DB
    # connections) should keep their native async enter/exit.
    if isinstance(obj, AbstractAsyncContextManager):
        return obj
    if isinstance(obj, AbstractContextManager):
        return cast("AbstractAsyncContextManager[T]", _ContextManagerWrapper(obj))
    return obj
209
+
210
+
211
class NoValue:
    """Sentinel class for missing values."""


async def get_next(iterable: Any, default: Any = NoValue, *args: Any) -> Any:  # pragma: no cover
    """Return the next item from an async iterator.

    Args:
        iterable: An async iterator (an object with ``__anext__``).
        default: Value to return once the iterator is exhausted. If omitted,
            :class:`StopAsyncIteration` is raised instead.
        *args: Ignored.

    Returns:
        The next value of the iterator, or ``default`` if it is exhausted.

    Raises:
        StopAsyncIteration: The iterator is exhausted and no default was given.
    """
    # The default is the NoValue *class*, so compare by identity; an explicit
    # NoValue() instance is also treated as "no default".
    has_default = default is not NoValue and not isinstance(default, NoValue)
    try:
        return await iterable.__anext__()
    except StopAsyncIteration as exc:
        if has_default:
            return default
        raise StopAsyncIteration from exc