sqlspec 0.32.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (262)
  1. sqlspec/__init__.py +104 -0
  2. sqlspec/__main__.py +12 -0
  3. sqlspec/__metadata__.py +14 -0
  4. sqlspec/_serialization.py +312 -0
  5. sqlspec/_typing.py +784 -0
  6. sqlspec/adapters/__init__.py +0 -0
  7. sqlspec/adapters/adbc/__init__.py +5 -0
  8. sqlspec/adapters/adbc/_types.py +12 -0
  9. sqlspec/adapters/adbc/adk/__init__.py +5 -0
  10. sqlspec/adapters/adbc/adk/store.py +880 -0
  11. sqlspec/adapters/adbc/config.py +436 -0
  12. sqlspec/adapters/adbc/data_dictionary.py +537 -0
  13. sqlspec/adapters/adbc/driver.py +841 -0
  14. sqlspec/adapters/adbc/litestar/__init__.py +5 -0
  15. sqlspec/adapters/adbc/litestar/store.py +504 -0
  16. sqlspec/adapters/adbc/type_converter.py +153 -0
  17. sqlspec/adapters/aiosqlite/__init__.py +29 -0
  18. sqlspec/adapters/aiosqlite/_types.py +13 -0
  19. sqlspec/adapters/aiosqlite/adk/__init__.py +5 -0
  20. sqlspec/adapters/aiosqlite/adk/store.py +536 -0
  21. sqlspec/adapters/aiosqlite/config.py +310 -0
  22. sqlspec/adapters/aiosqlite/data_dictionary.py +260 -0
  23. sqlspec/adapters/aiosqlite/driver.py +463 -0
  24. sqlspec/adapters/aiosqlite/litestar/__init__.py +5 -0
  25. sqlspec/adapters/aiosqlite/litestar/store.py +281 -0
  26. sqlspec/adapters/aiosqlite/pool.py +500 -0
  27. sqlspec/adapters/asyncmy/__init__.py +25 -0
  28. sqlspec/adapters/asyncmy/_types.py +12 -0
  29. sqlspec/adapters/asyncmy/adk/__init__.py +5 -0
  30. sqlspec/adapters/asyncmy/adk/store.py +503 -0
  31. sqlspec/adapters/asyncmy/config.py +246 -0
  32. sqlspec/adapters/asyncmy/data_dictionary.py +241 -0
  33. sqlspec/adapters/asyncmy/driver.py +632 -0
  34. sqlspec/adapters/asyncmy/litestar/__init__.py +5 -0
  35. sqlspec/adapters/asyncmy/litestar/store.py +296 -0
  36. sqlspec/adapters/asyncpg/__init__.py +23 -0
  37. sqlspec/adapters/asyncpg/_type_handlers.py +76 -0
  38. sqlspec/adapters/asyncpg/_types.py +23 -0
  39. sqlspec/adapters/asyncpg/adk/__init__.py +5 -0
  40. sqlspec/adapters/asyncpg/adk/store.py +460 -0
  41. sqlspec/adapters/asyncpg/config.py +464 -0
  42. sqlspec/adapters/asyncpg/data_dictionary.py +321 -0
  43. sqlspec/adapters/asyncpg/driver.py +720 -0
  44. sqlspec/adapters/asyncpg/litestar/__init__.py +5 -0
  45. sqlspec/adapters/asyncpg/litestar/store.py +253 -0
  46. sqlspec/adapters/bigquery/__init__.py +18 -0
  47. sqlspec/adapters/bigquery/_types.py +12 -0
  48. sqlspec/adapters/bigquery/adk/__init__.py +5 -0
  49. sqlspec/adapters/bigquery/adk/store.py +585 -0
  50. sqlspec/adapters/bigquery/config.py +298 -0
  51. sqlspec/adapters/bigquery/data_dictionary.py +256 -0
  52. sqlspec/adapters/bigquery/driver.py +1073 -0
  53. sqlspec/adapters/bigquery/litestar/__init__.py +5 -0
  54. sqlspec/adapters/bigquery/litestar/store.py +327 -0
  55. sqlspec/adapters/bigquery/type_converter.py +125 -0
  56. sqlspec/adapters/duckdb/__init__.py +24 -0
  57. sqlspec/adapters/duckdb/_types.py +12 -0
  58. sqlspec/adapters/duckdb/adk/__init__.py +14 -0
  59. sqlspec/adapters/duckdb/adk/store.py +563 -0
  60. sqlspec/adapters/duckdb/config.py +396 -0
  61. sqlspec/adapters/duckdb/data_dictionary.py +264 -0
  62. sqlspec/adapters/duckdb/driver.py +604 -0
  63. sqlspec/adapters/duckdb/litestar/__init__.py +5 -0
  64. sqlspec/adapters/duckdb/litestar/store.py +332 -0
  65. sqlspec/adapters/duckdb/pool.py +273 -0
  66. sqlspec/adapters/duckdb/type_converter.py +133 -0
  67. sqlspec/adapters/oracledb/__init__.py +32 -0
  68. sqlspec/adapters/oracledb/_numpy_handlers.py +133 -0
  69. sqlspec/adapters/oracledb/_types.py +39 -0
  70. sqlspec/adapters/oracledb/_uuid_handlers.py +130 -0
  71. sqlspec/adapters/oracledb/adk/__init__.py +5 -0
  72. sqlspec/adapters/oracledb/adk/store.py +1632 -0
  73. sqlspec/adapters/oracledb/config.py +469 -0
  74. sqlspec/adapters/oracledb/data_dictionary.py +717 -0
  75. sqlspec/adapters/oracledb/driver.py +1493 -0
  76. sqlspec/adapters/oracledb/litestar/__init__.py +5 -0
  77. sqlspec/adapters/oracledb/litestar/store.py +765 -0
  78. sqlspec/adapters/oracledb/migrations.py +532 -0
  79. sqlspec/adapters/oracledb/type_converter.py +207 -0
  80. sqlspec/adapters/psqlpy/__init__.py +16 -0
  81. sqlspec/adapters/psqlpy/_type_handlers.py +44 -0
  82. sqlspec/adapters/psqlpy/_types.py +12 -0
  83. sqlspec/adapters/psqlpy/adk/__init__.py +5 -0
  84. sqlspec/adapters/psqlpy/adk/store.py +483 -0
  85. sqlspec/adapters/psqlpy/config.py +271 -0
  86. sqlspec/adapters/psqlpy/data_dictionary.py +179 -0
  87. sqlspec/adapters/psqlpy/driver.py +892 -0
  88. sqlspec/adapters/psqlpy/litestar/__init__.py +5 -0
  89. sqlspec/adapters/psqlpy/litestar/store.py +272 -0
  90. sqlspec/adapters/psqlpy/type_converter.py +102 -0
  91. sqlspec/adapters/psycopg/__init__.py +32 -0
  92. sqlspec/adapters/psycopg/_type_handlers.py +90 -0
  93. sqlspec/adapters/psycopg/_types.py +18 -0
  94. sqlspec/adapters/psycopg/adk/__init__.py +5 -0
  95. sqlspec/adapters/psycopg/adk/store.py +962 -0
  96. sqlspec/adapters/psycopg/config.py +487 -0
  97. sqlspec/adapters/psycopg/data_dictionary.py +630 -0
  98. sqlspec/adapters/psycopg/driver.py +1336 -0
  99. sqlspec/adapters/psycopg/litestar/__init__.py +5 -0
  100. sqlspec/adapters/psycopg/litestar/store.py +554 -0
  101. sqlspec/adapters/spanner/__init__.py +38 -0
  102. sqlspec/adapters/spanner/_type_handlers.py +186 -0
  103. sqlspec/adapters/spanner/_types.py +12 -0
  104. sqlspec/adapters/spanner/adk/__init__.py +5 -0
  105. sqlspec/adapters/spanner/adk/store.py +435 -0
  106. sqlspec/adapters/spanner/config.py +241 -0
  107. sqlspec/adapters/spanner/data_dictionary.py +95 -0
  108. sqlspec/adapters/spanner/dialect/__init__.py +6 -0
  109. sqlspec/adapters/spanner/dialect/_spangres.py +52 -0
  110. sqlspec/adapters/spanner/dialect/_spanner.py +123 -0
  111. sqlspec/adapters/spanner/driver.py +366 -0
  112. sqlspec/adapters/spanner/litestar/__init__.py +5 -0
  113. sqlspec/adapters/spanner/litestar/store.py +266 -0
  114. sqlspec/adapters/spanner/type_converter.py +46 -0
  115. sqlspec/adapters/sqlite/__init__.py +18 -0
  116. sqlspec/adapters/sqlite/_type_handlers.py +86 -0
  117. sqlspec/adapters/sqlite/_types.py +11 -0
  118. sqlspec/adapters/sqlite/adk/__init__.py +5 -0
  119. sqlspec/adapters/sqlite/adk/store.py +582 -0
  120. sqlspec/adapters/sqlite/config.py +221 -0
  121. sqlspec/adapters/sqlite/data_dictionary.py +256 -0
  122. sqlspec/adapters/sqlite/driver.py +527 -0
  123. sqlspec/adapters/sqlite/litestar/__init__.py +5 -0
  124. sqlspec/adapters/sqlite/litestar/store.py +318 -0
  125. sqlspec/adapters/sqlite/pool.py +140 -0
  126. sqlspec/base.py +811 -0
  127. sqlspec/builder/__init__.py +146 -0
  128. sqlspec/builder/_base.py +900 -0
  129. sqlspec/builder/_column.py +517 -0
  130. sqlspec/builder/_ddl.py +1642 -0
  131. sqlspec/builder/_delete.py +84 -0
  132. sqlspec/builder/_dml.py +381 -0
  133. sqlspec/builder/_expression_wrappers.py +46 -0
  134. sqlspec/builder/_factory.py +1537 -0
  135. sqlspec/builder/_insert.py +315 -0
  136. sqlspec/builder/_join.py +375 -0
  137. sqlspec/builder/_merge.py +848 -0
  138. sqlspec/builder/_parsing_utils.py +297 -0
  139. sqlspec/builder/_select.py +1615 -0
  140. sqlspec/builder/_update.py +161 -0
  141. sqlspec/builder/_vector_expressions.py +259 -0
  142. sqlspec/cli.py +764 -0
  143. sqlspec/config.py +1540 -0
  144. sqlspec/core/__init__.py +305 -0
  145. sqlspec/core/cache.py +785 -0
  146. sqlspec/core/compiler.py +603 -0
  147. sqlspec/core/filters.py +872 -0
  148. sqlspec/core/hashing.py +274 -0
  149. sqlspec/core/metrics.py +83 -0
  150. sqlspec/core/parameters/__init__.py +64 -0
  151. sqlspec/core/parameters/_alignment.py +266 -0
  152. sqlspec/core/parameters/_converter.py +413 -0
  153. sqlspec/core/parameters/_processor.py +341 -0
  154. sqlspec/core/parameters/_registry.py +201 -0
  155. sqlspec/core/parameters/_transformers.py +226 -0
  156. sqlspec/core/parameters/_types.py +430 -0
  157. sqlspec/core/parameters/_validator.py +123 -0
  158. sqlspec/core/pipeline.py +187 -0
  159. sqlspec/core/result.py +1124 -0
  160. sqlspec/core/splitter.py +940 -0
  161. sqlspec/core/stack.py +163 -0
  162. sqlspec/core/statement.py +835 -0
  163. sqlspec/core/type_conversion.py +235 -0
  164. sqlspec/driver/__init__.py +36 -0
  165. sqlspec/driver/_async.py +1027 -0
  166. sqlspec/driver/_common.py +1236 -0
  167. sqlspec/driver/_sync.py +1025 -0
  168. sqlspec/driver/mixins/__init__.py +7 -0
  169. sqlspec/driver/mixins/_result_tools.py +61 -0
  170. sqlspec/driver/mixins/_sql_translator.py +122 -0
  171. sqlspec/driver/mixins/_storage.py +311 -0
  172. sqlspec/exceptions.py +321 -0
  173. sqlspec/extensions/__init__.py +0 -0
  174. sqlspec/extensions/adk/__init__.py +53 -0
  175. sqlspec/extensions/adk/_types.py +51 -0
  176. sqlspec/extensions/adk/converters.py +172 -0
  177. sqlspec/extensions/adk/migrations/0001_create_adk_tables.py +144 -0
  178. sqlspec/extensions/adk/migrations/__init__.py +0 -0
  179. sqlspec/extensions/adk/service.py +181 -0
  180. sqlspec/extensions/adk/store.py +536 -0
  181. sqlspec/extensions/aiosql/__init__.py +10 -0
  182. sqlspec/extensions/aiosql/adapter.py +471 -0
  183. sqlspec/extensions/fastapi/__init__.py +19 -0
  184. sqlspec/extensions/fastapi/extension.py +341 -0
  185. sqlspec/extensions/fastapi/providers.py +543 -0
  186. sqlspec/extensions/flask/__init__.py +36 -0
  187. sqlspec/extensions/flask/_state.py +72 -0
  188. sqlspec/extensions/flask/_utils.py +40 -0
  189. sqlspec/extensions/flask/extension.py +402 -0
  190. sqlspec/extensions/litestar/__init__.py +23 -0
  191. sqlspec/extensions/litestar/_utils.py +52 -0
  192. sqlspec/extensions/litestar/cli.py +92 -0
  193. sqlspec/extensions/litestar/config.py +90 -0
  194. sqlspec/extensions/litestar/handlers.py +316 -0
  195. sqlspec/extensions/litestar/migrations/0001_create_session_table.py +137 -0
  196. sqlspec/extensions/litestar/migrations/__init__.py +3 -0
  197. sqlspec/extensions/litestar/plugin.py +638 -0
  198. sqlspec/extensions/litestar/providers.py +454 -0
  199. sqlspec/extensions/litestar/store.py +265 -0
  200. sqlspec/extensions/otel/__init__.py +58 -0
  201. sqlspec/extensions/prometheus/__init__.py +107 -0
  202. sqlspec/extensions/starlette/__init__.py +10 -0
  203. sqlspec/extensions/starlette/_state.py +26 -0
  204. sqlspec/extensions/starlette/_utils.py +52 -0
  205. sqlspec/extensions/starlette/extension.py +257 -0
  206. sqlspec/extensions/starlette/middleware.py +154 -0
  207. sqlspec/loader.py +716 -0
  208. sqlspec/migrations/__init__.py +36 -0
  209. sqlspec/migrations/base.py +728 -0
  210. sqlspec/migrations/commands.py +1140 -0
  211. sqlspec/migrations/context.py +142 -0
  212. sqlspec/migrations/fix.py +203 -0
  213. sqlspec/migrations/loaders.py +450 -0
  214. sqlspec/migrations/runner.py +1024 -0
  215. sqlspec/migrations/templates.py +234 -0
  216. sqlspec/migrations/tracker.py +403 -0
  217. sqlspec/migrations/utils.py +256 -0
  218. sqlspec/migrations/validation.py +203 -0
  219. sqlspec/observability/__init__.py +22 -0
  220. sqlspec/observability/_config.py +228 -0
  221. sqlspec/observability/_diagnostics.py +67 -0
  222. sqlspec/observability/_dispatcher.py +151 -0
  223. sqlspec/observability/_observer.py +180 -0
  224. sqlspec/observability/_runtime.py +381 -0
  225. sqlspec/observability/_spans.py +158 -0
  226. sqlspec/protocols.py +530 -0
  227. sqlspec/py.typed +0 -0
  228. sqlspec/storage/__init__.py +46 -0
  229. sqlspec/storage/_utils.py +104 -0
  230. sqlspec/storage/backends/__init__.py +1 -0
  231. sqlspec/storage/backends/base.py +163 -0
  232. sqlspec/storage/backends/fsspec.py +398 -0
  233. sqlspec/storage/backends/local.py +377 -0
  234. sqlspec/storage/backends/obstore.py +580 -0
  235. sqlspec/storage/errors.py +104 -0
  236. sqlspec/storage/pipeline.py +604 -0
  237. sqlspec/storage/registry.py +289 -0
  238. sqlspec/typing.py +219 -0
  239. sqlspec/utils/__init__.py +31 -0
  240. sqlspec/utils/arrow_helpers.py +95 -0
  241. sqlspec/utils/config_resolver.py +153 -0
  242. sqlspec/utils/correlation.py +132 -0
  243. sqlspec/utils/data_transformation.py +114 -0
  244. sqlspec/utils/dependencies.py +79 -0
  245. sqlspec/utils/deprecation.py +113 -0
  246. sqlspec/utils/fixtures.py +250 -0
  247. sqlspec/utils/logging.py +172 -0
  248. sqlspec/utils/module_loader.py +273 -0
  249. sqlspec/utils/portal.py +325 -0
  250. sqlspec/utils/schema.py +288 -0
  251. sqlspec/utils/serializers.py +396 -0
  252. sqlspec/utils/singleton.py +41 -0
  253. sqlspec/utils/sync_tools.py +277 -0
  254. sqlspec/utils/text.py +108 -0
  255. sqlspec/utils/type_converters.py +99 -0
  256. sqlspec/utils/type_guards.py +1324 -0
  257. sqlspec/utils/version.py +444 -0
  258. sqlspec-0.32.0.dist-info/METADATA +202 -0
  259. sqlspec-0.32.0.dist-info/RECORD +262 -0
  260. sqlspec-0.32.0.dist-info/WHEEL +4 -0
  261. sqlspec-0.32.0.dist-info/entry_points.txt +2 -0
  262. sqlspec-0.32.0.dist-info/licenses/LICENSE +21 -0
@@ -0,0 +1,7 @@
1
+ """Driver mixins for instrumentation, storage, and utilities."""
2
+
3
+ from sqlspec.driver.mixins._result_tools import ToSchemaMixin
4
+ from sqlspec.driver.mixins._sql_translator import SQLTranslatorMixin
5
+ from sqlspec.driver.mixins._storage import StorageDriverMixin
6
+
7
+ __all__ = ("SQLTranslatorMixin", "StorageDriverMixin", "ToSchemaMixin")
@@ -0,0 +1,61 @@
1
+ """Result handling and schema conversion mixins for database drivers."""
2
+
3
+ from typing import TYPE_CHECKING, Any, overload
4
+
5
+ from mypy_extensions import trait
6
+
7
+ from sqlspec.utils.schema import to_schema
8
+
9
+ if TYPE_CHECKING:
10
+ from sqlspec.typing import SchemaT
11
+
12
+ __all__ = ("ToSchemaMixin",)
13
+
14
+
15
@trait
class ToSchemaMixin:
    """Mixin providing data transformation methods for various schema types."""

    # No instance state: keeps subclasses free to declare their own __slots__.
    __slots__ = ()

    # The overloads below exist purely for static type checkers: they map the
    # input shape (dict vs list-of-dicts vs anything) and the presence of
    # schema_type onto a precise return type. The single runtime
    # implementation is the non-overloaded variant at the bottom.
    @overload
    @staticmethod
    def to_schema(data: "list[dict[str, Any]]", *, schema_type: "type[SchemaT]") -> "list[SchemaT]": ...
    @overload
    @staticmethod
    def to_schema(data: "list[dict[str, Any]]", *, schema_type: None = None) -> "list[dict[str, Any]]": ...
    @overload
    @staticmethod
    def to_schema(data: "dict[str, Any]", *, schema_type: "type[SchemaT]") -> "SchemaT": ...
    @overload
    @staticmethod
    def to_schema(data: "dict[str, Any]", *, schema_type: None = None) -> "dict[str, Any]": ...
    @overload
    @staticmethod
    def to_schema(data: Any, *, schema_type: "type[SchemaT]") -> Any: ...
    @overload
    @staticmethod
    def to_schema(data: Any, *, schema_type: None = None) -> Any: ...

    @staticmethod
    def to_schema(data: Any, *, schema_type: "type[Any] | None" = None) -> Any:
        """Convert data to a specified schema type.

        Supports transformation to various schema types including:
        - TypedDict
        - dataclasses
        - msgspec Structs
        - Pydantic models
        - attrs classes

        Args:
            data: Input data to convert (dict, list of dicts, or other)
            schema_type: Target schema type for conversion. If None, returns data unchanged.

        Returns:
            Converted data in the specified schema type, or original data if schema_type is None

        Raises:
            SQLSpecError: If schema_type is not a supported type
        """
        # Delegate to the shared converter so drivers and result objects use
        # one conversion implementation.
        return to_schema(data, schema_type=schema_type)
@@ -0,0 +1,122 @@
1
+ """SQL translation mixin for cross-database compatibility."""
2
+
3
+ from typing import Final, NoReturn
4
+
5
+ from mypy_extensions import trait
6
+ from sqlglot import exp, parse_one
7
+ from sqlglot.dialects.dialect import DialectType
8
+
9
+ from sqlspec.core import SQL, Statement
10
+ from sqlspec.exceptions import SQLConversionError
11
+
12
+ __all__ = ("SQLTranslatorMixin",)
13
+
14
+
15
_DEFAULT_PRETTY: Final[bool] = True


@trait
class SQLTranslatorMixin:
    """Mixin for drivers supporting SQL translation.

    Re-renders a statement in a target SQL dialect via sqlglot, wrapping
    parse and generation failures in SQLConversionError.
    """

    __slots__ = ()
    # The driver's own dialect: used to parse raw SQL strings and as the
    # default target dialect when to_dialect is not supplied.
    dialect: "DialectType | None"

    def convert_to_dialect(
        self, statement: "Statement", to_dialect: "DialectType | None" = None, pretty: bool = _DEFAULT_PRETTY
    ) -> str:
        """Convert a statement to a target SQL dialect.

        Args:
            statement: SQL statement to convert (SQL object, sqlglot expression, or raw string)
            to_dialect: Target dialect (defaults to current dialect)
            pretty: Whether to format the output SQL

        Returns:
            SQL string in target dialect

        Raises:
            SQLConversionError: If the statement cannot be parsed or rendered
        """
        parsed_expression: exp.Expression
        if isinstance(statement, SQL):
            # isinstance() already rejects None, so no separate None check is
            # needed. A pre-built SQL object must carry an expression tree.
            if statement.expression is None:
                self._raise_statement_parse_error()
            parsed_expression = statement.expression
        elif isinstance(statement, exp.Expression):
            parsed_expression = statement
        else:
            parsed_expression = self._parse_statement_safely(statement)

        target_dialect = to_dialect or self.dialect
        return self._generate_sql_safely(parsed_expression, target_dialect, pretty)

    def _parse_statement_safely(self, statement: "Statement") -> "exp.Expression":
        """Parse statement with error handling.

        Args:
            statement: SQL statement to parse

        Returns:
            Parsed expression

        Raises:
            SQLConversionError: If sqlglot cannot parse the statement
        """
        try:
            sql_string = str(statement)
            return parse_one(sql_string, dialect=self.dialect, copy=False)
        except Exception as e:
            self._raise_parse_error(e)

    def _generate_sql_safely(self, expression: "exp.Expression", dialect: DialectType, pretty: bool) -> str:
        """Generate SQL with error handling.

        Args:
            expression: Parsed expression to convert
            dialect: Target SQL dialect
            pretty: Whether to format the output SQL

        Returns:
            Generated SQL string

        Raises:
            SQLConversionError: If the expression cannot be rendered in the target dialect
        """
        try:
            return expression.sql(dialect=dialect, pretty=pretty)
        except Exception as e:
            self._raise_conversion_error(dialect, e)

    def _raise_statement_parse_error(self) -> NoReturn:
        """Raise error for unparsable statements.

        Raises:
            SQLConversionError: Always raised
        """
        msg = "Statement could not be parsed"
        raise SQLConversionError(msg)

    def _raise_parse_error(self, e: Exception) -> NoReturn:
        """Raise error for parsing failures.

        Args:
            e: Original exception that caused the failure

        Raises:
            SQLConversionError: Always raised
        """
        error_msg = f"Failed to parse SQL statement: {e!s}"
        raise SQLConversionError(error_msg) from e

    def _raise_conversion_error(self, dialect: DialectType, e: Exception) -> NoReturn:
        """Raise error for conversion failures.

        Args:
            dialect: Target dialect that caused the failure
            e: Original exception that caused the failure

        Raises:
            SQLConversionError: Always raised
        """
        error_msg = f"Failed to convert SQL expression to {dialect}: {e!s}"
        raise SQLConversionError(error_msg) from e
@@ -0,0 +1,311 @@
1
+ """Storage bridge mixin shared by sync and async drivers."""
2
+
3
+ from collections.abc import Iterable
4
+ from pathlib import Path
5
+ from typing import TYPE_CHECKING, Any, cast
6
+
7
+ from mypy_extensions import trait
8
+
9
+ from sqlspec.exceptions import StorageCapabilityError
10
+ from sqlspec.storage import (
11
+ AsyncStoragePipeline,
12
+ StorageBridgeJob,
13
+ StorageCapabilities,
14
+ StorageDestination,
15
+ StorageFormat,
16
+ StorageTelemetry,
17
+ SyncStoragePipeline,
18
+ create_storage_bridge_job,
19
+ )
20
+ from sqlspec.utils.module_loader import ensure_pyarrow
21
+
22
+ if TYPE_CHECKING:
23
+ from collections.abc import Awaitable
24
+
25
+ from sqlspec.core import StatementConfig, StatementFilter
26
+ from sqlspec.core.result import ArrowResult
27
+ from sqlspec.core.statement import SQL
28
+ from sqlspec.observability import ObservabilityRuntime
29
+ from sqlspec.typing import ArrowTable, StatementParameters
30
+
31
+ __all__ = ("StorageDriverMixin",)
32
+
33
+
34
# Human-readable labels for capability flags, used to build the error
# message raised by StorageDriverMixin._require_capability(). Flags missing
# here fall back to the raw flag name.
CAPABILITY_HINTS: dict[str, str] = {
    "arrow_export_enabled": "native Arrow export",
    "arrow_import_enabled": "native Arrow import",
    "parquet_export_enabled": "native Parquet export",
    "parquet_import_enabled": "native Parquet import",
}
40
+
41
+
42
@trait
class StorageDriverMixin:
    """Mixin providing capability-aware storage bridge helpers.

    The public methods (select_to_storage, select_to_arrow, load_from_arrow,
    load_from_storage, stage_artifact) are stubs that always raise
    StorageCapabilityError here; concrete adapters override the ones they
    support. The underscore helpers implement shared telemetry, capability
    checks, and Arrow conversion used by those overrides.
    """

    __slots__ = ()
    # Optional pipeline class override; when None, _storage_pipeline() picks a
    # sync or async pipeline based on the driver's is_async attribute.
    storage_pipeline_factory: "type[SyncStoragePipeline | AsyncStoragePipeline] | None" = None
    # Supplied by the concrete driver; storage_capabilities() expects it to
    # hold a "storage_capabilities" mapping.
    driver_features: dict[str, Any]

    if TYPE_CHECKING:
        # Declared for type checkers only; the concrete driver provides the
        # real observability runtime.
        @property
        def observability(self) -> "ObservabilityRuntime": ...

    def storage_capabilities(self) -> StorageCapabilities:
        """Return cached storage capabilities for the active driver.

        Raises:
            StorageCapabilityError: If the driver never configured
                "storage_capabilities" in driver_features.
        """

        capabilities = self.driver_features.get("storage_capabilities")
        if capabilities is None:
            msg = "Storage capabilities are not configured for this driver."
            raise StorageCapabilityError(msg, capability="storage_capabilities")
        # Copy so callers cannot mutate the cached mapping in place.
        return cast("StorageCapabilities", dict(capabilities))

    def select_to_storage(
        self,
        statement: "SQL | str",
        destination: StorageDestination,
        /,
        *parameters: "StatementParameters | StatementFilter",
        statement_config: "StatementConfig | None" = None,
        partitioner: "dict[str, Any] | None" = None,
        format_hint: StorageFormat | None = None,
        telemetry: StorageTelemetry | None = None,
    ) -> "StorageBridgeJob | Awaitable[StorageBridgeJob]":
        """Stream a SELECT statement directly into storage.

        Base implementation always raises StorageCapabilityError; adapters
        override to enable.
        """

        self._raise_not_implemented("select_to_storage")
        # Unreachable at runtime (_raise_not_implemented always raises);
        # presumably kept so return-type analysis sees no implicit None.
        raise NotImplementedError

    def select_to_arrow(
        self,
        statement: "SQL | str",
        /,
        *parameters: "StatementParameters | StatementFilter",
        partitioner: "dict[str, Any] | None" = None,
        memory_pool: Any | None = None,
        statement_config: "StatementConfig | None" = None,
    ) -> "ArrowResult | Awaitable[ArrowResult]":
        """Execute a SELECT that returns an ArrowResult.

        Base implementation always raises StorageCapabilityError; adapters
        override to enable.
        """

        self._raise_not_implemented("select_to_arrow")
        # Unreachable at runtime; see select_to_storage.
        raise NotImplementedError

    def load_from_arrow(
        self,
        table: str,
        source: "ArrowResult | Any",
        *,
        partitioner: "dict[str, Any] | None" = None,
        overwrite: bool = False,
    ) -> "StorageBridgeJob | Awaitable[StorageBridgeJob]":
        """Load Arrow data into the target table.

        Base implementation always raises StorageCapabilityError; adapters
        override to enable.
        """

        self._raise_not_implemented("load_from_arrow")
        # Unreachable at runtime; see select_to_storage.
        raise NotImplementedError

    def load_from_storage(
        self,
        table: str,
        source: StorageDestination,
        *,
        file_format: StorageFormat,
        partitioner: "dict[str, Any] | None" = None,
        overwrite: bool = False,
    ) -> "StorageBridgeJob | Awaitable[StorageBridgeJob]":
        """Load artifacts from storage into the target table.

        Base implementation always raises StorageCapabilityError; adapters
        override to enable.
        """

        self._raise_not_implemented("load_from_storage")
        # Unreachable at runtime; see select_to_storage.
        raise NotImplementedError

    def stage_artifact(self, request: "dict[str, Any]") -> "dict[str, Any]":
        """Provision staging metadata for adapters that require remote URIs.

        Base implementation always raises StorageCapabilityError; adapters
        override to enable.
        """

        self._raise_not_implemented("stage_artifact")
        # Unreachable at runtime; see select_to_storage.
        raise NotImplementedError

    def flush_staging_artifacts(self, artifacts: "list[dict[str, Any]]", *, error: Exception | None = None) -> None:
        """Clean up staged artifacts after a job completes.

        An empty artifact list is a no-op; a non-empty list on a driver that
        never staged anything is an error. The `error` argument is unused
        here — presumably overrides use it to distinguish success/failure
        cleanup; confirm against adapter implementations.
        """

        if artifacts:
            self._raise_not_implemented("flush_staging_artifacts")

    def get_storage_job(self, job_id: str) -> StorageBridgeJob | None:
        """Fetch a previously created job handle.

        Base implementation tracks no jobs and always returns None.
        """

        return None

    def _storage_pipeline(self) -> "SyncStoragePipeline | AsyncStoragePipeline":
        """Build a storage pipeline, honoring storage_pipeline_factory.

        Without a factory, chooses async vs sync based on the driver's
        is_async attribute (defaulting to sync when absent).
        """
        factory = self.storage_pipeline_factory
        if factory is None:
            if getattr(self, "is_async", False):
                return AsyncStoragePipeline()
            return SyncStoragePipeline()
        return factory()

    def _raise_not_implemented(self, capability: str) -> None:
        """Raise StorageCapabilityError for an unimplemented capability."""
        msg = f"{capability} is not implemented for this driver"
        remediation = "Override StorageDriverMixin methods on the adapter to enable this capability."
        raise StorageCapabilityError(msg, capability=capability, remediation=remediation)

    def _require_capability(self, capability_flag: str) -> None:
        """Raise unless the driver advertises the given capability flag.

        Raises:
            StorageCapabilityError: If the flag is absent or falsy in
                storage_capabilities().
        """
        capabilities = self.storage_capabilities()
        if capabilities.get(capability_flag, False):
            return
        # Prefer the human-readable label from CAPABILITY_HINTS; fall back to
        # the raw flag name.
        human_label = CAPABILITY_HINTS.get(capability_flag, capability_flag)
        remediation = "Check adapter supports this capability or stage artifacts via storage pipeline."
        msg = f"{human_label} is not available for this adapter"
        raise StorageCapabilityError(msg, capability=capability_flag, remediation=remediation)

    def _attach_partition_telemetry(self, telemetry: StorageTelemetry, partitioner: "dict[str, Any] | None") -> None:
        """Record the partitioner under telemetry["extra"] (mutates telemetry)."""
        if not partitioner:
            return
        # Copy the nested dict so a shared "extra" mapping is not mutated.
        extra = dict(telemetry.get("extra", {}))
        extra["partitioner"] = partitioner
        telemetry["extra"] = extra

    def _create_storage_job(
        self, produced: StorageTelemetry, provided: StorageTelemetry | None = None, *, status: str = "completed"
    ) -> StorageBridgeJob:
        """Merge produced/provided telemetry and create a job handle.

        Byte counts from `provided` are added to the produced total, and the
        full `provided` telemetry is preserved under extra["source"].
        """
        merged = cast("StorageTelemetry", dict(produced))
        if provided:
            source_bytes = provided.get("bytes_processed")
            if source_bytes is not None:
                merged["bytes_processed"] = int(merged.get("bytes_processed", 0)) + int(source_bytes)
            extra = dict(merged.get("extra", {}))
            extra["source"] = provided
            merged["extra"] = extra
        return create_storage_bridge_job(status, merged)

    def _write_result_to_storage_sync(
        self,
        result: "ArrowResult",
        destination: StorageDestination,
        *,
        format_hint: StorageFormat | None = None,
        storage_options: "dict[str, Any] | None" = None,
        pipeline: "SyncStoragePipeline | None" = None,
    ) -> StorageTelemetry:
        """Write an ArrowResult to storage inside an observability span.

        The span is closed with the error on failure, or with annotated
        telemetry on success.
        """
        runtime = self.observability
        span = runtime.start_storage_span(
            "write", destination=self._stringify_storage_target(destination), format_label=format_hint
        )
        try:
            telemetry = result.write_to_storage_sync(
                destination, format_hint=format_hint, storage_options=storage_options, pipeline=pipeline
            )
        except Exception as exc:  # pragma: no cover - passthrough
            # Close the span with the error, then let it propagate unchanged.
            runtime.end_storage_span(span, error=exc)
            raise
        telemetry = runtime.annotate_storage_telemetry(telemetry)
        runtime.end_storage_span(span, telemetry=telemetry)
        return telemetry

    async def _write_result_to_storage_async(
        self,
        result: "ArrowResult",
        destination: StorageDestination,
        *,
        format_hint: StorageFormat | None = None,
        storage_options: "dict[str, Any] | None" = None,
        pipeline: "AsyncStoragePipeline | None" = None,
    ) -> StorageTelemetry:
        """Async counterpart of _write_result_to_storage_sync."""
        runtime = self.observability
        span = runtime.start_storage_span(
            "write", destination=self._stringify_storage_target(destination), format_label=format_hint
        )
        try:
            telemetry = await result.write_to_storage_async(
                destination, format_hint=format_hint, storage_options=storage_options, pipeline=pipeline
            )
        except Exception as exc:  # pragma: no cover - passthrough
            runtime.end_storage_span(span, error=exc)
            raise
        telemetry = runtime.annotate_storage_telemetry(telemetry)
        runtime.end_storage_span(span, telemetry=telemetry)
        return telemetry

    def _read_arrow_from_storage_sync(
        self, source: StorageDestination, *, file_format: StorageFormat, storage_options: "dict[str, Any] | None" = None
    ) -> "tuple[ArrowTable, StorageTelemetry]":
        """Read an Arrow table from storage inside an observability span."""
        runtime = self.observability
        span = runtime.start_storage_span(
            "read", destination=self._stringify_storage_target(source), format_label=file_format
        )
        # _storage_pipeline() returns a sync pipeline for sync drivers; the
        # cast asserts that invariant for the type checker.
        pipeline = cast("SyncStoragePipeline", self._storage_pipeline())
        try:
            table, telemetry = pipeline.read_arrow(source, file_format=file_format, storage_options=storage_options)
        except Exception as exc:  # pragma: no cover - passthrough
            runtime.end_storage_span(span, error=exc)
            raise
        telemetry = runtime.annotate_storage_telemetry(telemetry)
        runtime.end_storage_span(span, telemetry=telemetry)
        return table, telemetry

    async def _read_arrow_from_storage_async(
        self, source: StorageDestination, *, file_format: StorageFormat, storage_options: "dict[str, Any] | None" = None
    ) -> "tuple[ArrowTable, StorageTelemetry]":
        """Async counterpart of _read_arrow_from_storage_sync."""
        runtime = self.observability
        span = runtime.start_storage_span(
            "read", destination=self._stringify_storage_target(source), format_label=file_format
        )
        pipeline = cast("AsyncStoragePipeline", self._storage_pipeline())
        try:
            table, telemetry = await pipeline.read_arrow_async(
                source, file_format=file_format, storage_options=storage_options
            )
        except Exception as exc:  # pragma: no cover - passthrough
            runtime.end_storage_span(span, error=exc)
            raise
        telemetry = runtime.annotate_storage_telemetry(telemetry)
        runtime.end_storage_span(span, telemetry=telemetry)
        return table, telemetry

    @staticmethod
    def _build_ingest_telemetry(table: "ArrowTable", *, format_label: str = "arrow") -> StorageTelemetry:
        """Build row/byte telemetry for an ingested Arrow table.

        getattr defaults mean a table-like object without num_rows/nbytes
        reports zeros rather than raising.
        """
        rows = int(getattr(table, "num_rows", 0))
        bytes_processed = int(getattr(table, "nbytes", 0))
        return {"rows_processed": rows, "bytes_processed": bytes_processed, "format": format_label}

    def _coerce_arrow_table(self, source: "ArrowResult | Any") -> "ArrowTable":
        """Coerce an ArrowResult, Table, RecordBatch, or iterable of rows to a pyarrow.Table.

        Raises:
            TypeError: If the source cannot be converted to a pyarrow.Table.
        """
        ensure_pyarrow()
        import pyarrow as pa

        # Duck-typed ArrowResult: anything exposing get_data() must yield a
        # real pyarrow.Table.
        if hasattr(source, "get_data"):
            table = source.get_data()
            if isinstance(table, pa.Table):
                return table
            msg = "ArrowResult did not return a pyarrow.Table instance"
            raise TypeError(msg)
        if isinstance(source, pa.Table):
            return source
        if isinstance(source, pa.RecordBatch):
            return pa.Table.from_batches([source])
        if isinstance(source, Iterable):
            # Assumes the iterable yields row dicts (pyarrow's from_pylist
            # contract) — TODO confirm callers never pass other iterables.
            return pa.Table.from_pylist(list(source))
        msg = f"Unsupported Arrow source type: {type(source).__name__}"
        raise TypeError(msg)

    @staticmethod
    def _stringify_storage_target(target: StorageDestination | None) -> str | None:
        """Render a storage destination as a string for span/telemetry labels."""
        if target is None:
            return None
        if isinstance(target, Path):
            # POSIX form keeps labels consistent across platforms.
            return target.as_posix()
        return str(target)

    @staticmethod
    def _arrow_table_to_rows(
        table: "ArrowTable", columns: "list[str] | None" = None
    ) -> "tuple[list[str], list[tuple[Any, ...]]]":
        """Flatten an Arrow table into (column names, row tuples) for SQL inserts.

        Args:
            table: Source Arrow table.
            columns: Optional explicit column order; defaults to the table's
                column order.

        Returns:
            Resolved column names and one tuple per row in that column order.

        Raises:
            ValueError: If no columns can be resolved.
        """
        ensure_pyarrow()
        resolved_columns = columns or list(table.column_names)
        if not resolved_columns:
            msg = "Arrow table has no columns to import"
            raise ValueError(msg)
        # to_pylist() yields one dict per row; the name "batches" is historical.
        batches = table.to_pylist()
        records: list[tuple[Any, ...]] = []
        for row in batches:
            # dict.get tolerates rows missing a requested column (yields None).
            record = tuple(row.get(col) for col in resolved_columns)
            records.append(record)
        return resolved_columns, records
+ return resolved_columns, records