sqlspec 0.11.0__py3-none-any.whl → 0.12.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of sqlspec might be problematic. Click here for more details.

Files changed (155)
  1. sqlspec/__init__.py +16 -3
  2. sqlspec/_serialization.py +3 -10
  3. sqlspec/_sql.py +1147 -0
  4. sqlspec/_typing.py +343 -41
  5. sqlspec/adapters/adbc/__init__.py +2 -6
  6. sqlspec/adapters/adbc/config.py +474 -149
  7. sqlspec/adapters/adbc/driver.py +330 -644
  8. sqlspec/adapters/aiosqlite/__init__.py +2 -6
  9. sqlspec/adapters/aiosqlite/config.py +143 -57
  10. sqlspec/adapters/aiosqlite/driver.py +269 -462
  11. sqlspec/adapters/asyncmy/__init__.py +3 -8
  12. sqlspec/adapters/asyncmy/config.py +247 -202
  13. sqlspec/adapters/asyncmy/driver.py +217 -451
  14. sqlspec/adapters/asyncpg/__init__.py +4 -7
  15. sqlspec/adapters/asyncpg/config.py +329 -176
  16. sqlspec/adapters/asyncpg/driver.py +418 -498
  17. sqlspec/adapters/bigquery/__init__.py +2 -2
  18. sqlspec/adapters/bigquery/config.py +407 -0
  19. sqlspec/adapters/bigquery/driver.py +592 -634
  20. sqlspec/adapters/duckdb/__init__.py +4 -1
  21. sqlspec/adapters/duckdb/config.py +432 -321
  22. sqlspec/adapters/duckdb/driver.py +393 -436
  23. sqlspec/adapters/oracledb/__init__.py +3 -8
  24. sqlspec/adapters/oracledb/config.py +625 -0
  25. sqlspec/adapters/oracledb/driver.py +549 -942
  26. sqlspec/adapters/psqlpy/__init__.py +4 -7
  27. sqlspec/adapters/psqlpy/config.py +372 -203
  28. sqlspec/adapters/psqlpy/driver.py +197 -550
  29. sqlspec/adapters/psycopg/__init__.py +3 -8
  30. sqlspec/adapters/psycopg/config.py +741 -0
  31. sqlspec/adapters/psycopg/driver.py +732 -733
  32. sqlspec/adapters/sqlite/__init__.py +2 -6
  33. sqlspec/adapters/sqlite/config.py +146 -81
  34. sqlspec/adapters/sqlite/driver.py +243 -426
  35. sqlspec/base.py +220 -825
  36. sqlspec/config.py +354 -0
  37. sqlspec/driver/__init__.py +22 -0
  38. sqlspec/driver/_async.py +252 -0
  39. sqlspec/driver/_common.py +338 -0
  40. sqlspec/driver/_sync.py +261 -0
  41. sqlspec/driver/mixins/__init__.py +17 -0
  42. sqlspec/driver/mixins/_pipeline.py +523 -0
  43. sqlspec/driver/mixins/_result_utils.py +122 -0
  44. sqlspec/driver/mixins/_sql_translator.py +35 -0
  45. sqlspec/driver/mixins/_storage.py +993 -0
  46. sqlspec/driver/mixins/_type_coercion.py +131 -0
  47. sqlspec/exceptions.py +299 -7
  48. sqlspec/extensions/aiosql/__init__.py +10 -0
  49. sqlspec/extensions/aiosql/adapter.py +474 -0
  50. sqlspec/extensions/litestar/__init__.py +1 -6
  51. sqlspec/extensions/litestar/_utils.py +1 -5
  52. sqlspec/extensions/litestar/config.py +5 -6
  53. sqlspec/extensions/litestar/handlers.py +13 -12
  54. sqlspec/extensions/litestar/plugin.py +22 -24
  55. sqlspec/extensions/litestar/providers.py +37 -55
  56. sqlspec/loader.py +528 -0
  57. sqlspec/service/__init__.py +3 -0
  58. sqlspec/service/base.py +24 -0
  59. sqlspec/service/pagination.py +26 -0
  60. sqlspec/statement/__init__.py +21 -0
  61. sqlspec/statement/builder/__init__.py +54 -0
  62. sqlspec/statement/builder/_ddl_utils.py +119 -0
  63. sqlspec/statement/builder/_parsing_utils.py +135 -0
  64. sqlspec/statement/builder/base.py +328 -0
  65. sqlspec/statement/builder/ddl.py +1379 -0
  66. sqlspec/statement/builder/delete.py +80 -0
  67. sqlspec/statement/builder/insert.py +274 -0
  68. sqlspec/statement/builder/merge.py +95 -0
  69. sqlspec/statement/builder/mixins/__init__.py +65 -0
  70. sqlspec/statement/builder/mixins/_aggregate_functions.py +151 -0
  71. sqlspec/statement/builder/mixins/_case_builder.py +91 -0
  72. sqlspec/statement/builder/mixins/_common_table_expr.py +91 -0
  73. sqlspec/statement/builder/mixins/_delete_from.py +34 -0
  74. sqlspec/statement/builder/mixins/_from.py +61 -0
  75. sqlspec/statement/builder/mixins/_group_by.py +119 -0
  76. sqlspec/statement/builder/mixins/_having.py +35 -0
  77. sqlspec/statement/builder/mixins/_insert_from_select.py +48 -0
  78. sqlspec/statement/builder/mixins/_insert_into.py +36 -0
  79. sqlspec/statement/builder/mixins/_insert_values.py +69 -0
  80. sqlspec/statement/builder/mixins/_join.py +110 -0
  81. sqlspec/statement/builder/mixins/_limit_offset.py +53 -0
  82. sqlspec/statement/builder/mixins/_merge_clauses.py +405 -0
  83. sqlspec/statement/builder/mixins/_order_by.py +46 -0
  84. sqlspec/statement/builder/mixins/_pivot.py +82 -0
  85. sqlspec/statement/builder/mixins/_returning.py +37 -0
  86. sqlspec/statement/builder/mixins/_select_columns.py +60 -0
  87. sqlspec/statement/builder/mixins/_set_ops.py +122 -0
  88. sqlspec/statement/builder/mixins/_unpivot.py +80 -0
  89. sqlspec/statement/builder/mixins/_update_from.py +54 -0
  90. sqlspec/statement/builder/mixins/_update_set.py +91 -0
  91. sqlspec/statement/builder/mixins/_update_table.py +29 -0
  92. sqlspec/statement/builder/mixins/_where.py +374 -0
  93. sqlspec/statement/builder/mixins/_window_functions.py +86 -0
  94. sqlspec/statement/builder/protocols.py +20 -0
  95. sqlspec/statement/builder/select.py +206 -0
  96. sqlspec/statement/builder/update.py +178 -0
  97. sqlspec/statement/filters.py +571 -0
  98. sqlspec/statement/parameters.py +736 -0
  99. sqlspec/statement/pipelines/__init__.py +67 -0
  100. sqlspec/statement/pipelines/analyzers/__init__.py +9 -0
  101. sqlspec/statement/pipelines/analyzers/_analyzer.py +649 -0
  102. sqlspec/statement/pipelines/base.py +315 -0
  103. sqlspec/statement/pipelines/context.py +119 -0
  104. sqlspec/statement/pipelines/result_types.py +41 -0
  105. sqlspec/statement/pipelines/transformers/__init__.py +8 -0
  106. sqlspec/statement/pipelines/transformers/_expression_simplifier.py +256 -0
  107. sqlspec/statement/pipelines/transformers/_literal_parameterizer.py +623 -0
  108. sqlspec/statement/pipelines/transformers/_remove_comments.py +66 -0
  109. sqlspec/statement/pipelines/transformers/_remove_hints.py +81 -0
  110. sqlspec/statement/pipelines/validators/__init__.py +23 -0
  111. sqlspec/statement/pipelines/validators/_dml_safety.py +275 -0
  112. sqlspec/statement/pipelines/validators/_parameter_style.py +297 -0
  113. sqlspec/statement/pipelines/validators/_performance.py +703 -0
  114. sqlspec/statement/pipelines/validators/_security.py +990 -0
  115. sqlspec/statement/pipelines/validators/base.py +67 -0
  116. sqlspec/statement/result.py +527 -0
  117. sqlspec/statement/splitter.py +701 -0
  118. sqlspec/statement/sql.py +1198 -0
  119. sqlspec/storage/__init__.py +15 -0
  120. sqlspec/storage/backends/__init__.py +0 -0
  121. sqlspec/storage/backends/base.py +166 -0
  122. sqlspec/storage/backends/fsspec.py +315 -0
  123. sqlspec/storage/backends/obstore.py +464 -0
  124. sqlspec/storage/protocol.py +170 -0
  125. sqlspec/storage/registry.py +315 -0
  126. sqlspec/typing.py +157 -36
  127. sqlspec/utils/correlation.py +155 -0
  128. sqlspec/utils/deprecation.py +3 -6
  129. sqlspec/utils/fixtures.py +6 -11
  130. sqlspec/utils/logging.py +135 -0
  131. sqlspec/utils/module_loader.py +45 -43
  132. sqlspec/utils/serializers.py +4 -0
  133. sqlspec/utils/singleton.py +6 -8
  134. sqlspec/utils/sync_tools.py +15 -27
  135. sqlspec/utils/text.py +58 -26
  136. {sqlspec-0.11.0.dist-info → sqlspec-0.12.0.dist-info}/METADATA +100 -26
  137. sqlspec-0.12.0.dist-info/RECORD +145 -0
  138. sqlspec/adapters/bigquery/config/__init__.py +0 -3
  139. sqlspec/adapters/bigquery/config/_common.py +0 -40
  140. sqlspec/adapters/bigquery/config/_sync.py +0 -87
  141. sqlspec/adapters/oracledb/config/__init__.py +0 -9
  142. sqlspec/adapters/oracledb/config/_asyncio.py +0 -186
  143. sqlspec/adapters/oracledb/config/_common.py +0 -131
  144. sqlspec/adapters/oracledb/config/_sync.py +0 -186
  145. sqlspec/adapters/psycopg/config/__init__.py +0 -19
  146. sqlspec/adapters/psycopg/config/_async.py +0 -169
  147. sqlspec/adapters/psycopg/config/_common.py +0 -56
  148. sqlspec/adapters/psycopg/config/_sync.py +0 -168
  149. sqlspec/filters.py +0 -330
  150. sqlspec/mixins.py +0 -306
  151. sqlspec/statement.py +0 -378
  152. sqlspec-0.11.0.dist-info/RECORD +0 -69
  153. {sqlspec-0.11.0.dist-info → sqlspec-0.12.0.dist-info}/WHEEL +0 -0
  154. {sqlspec-0.11.0.dist-info → sqlspec-0.12.0.dist-info}/licenses/LICENSE +0 -0
  155. {sqlspec-0.11.0.dist-info → sqlspec-0.12.0.dist-info}/licenses/NOTICE +0 -0
sqlspec/filters.py DELETED
@@ -1,330 +0,0 @@
1
- """Collection filter datastructures."""
2
-
3
- from abc import ABC, abstractmethod
4
- from collections import abc
5
- from dataclasses import dataclass
6
- from datetime import datetime
7
- from typing import Any, Generic, Literal, Optional, Protocol, Union, cast
8
-
9
- from sqlglot import exp
10
- from typing_extensions import TypeAlias, TypeVar
11
-
12
- from sqlspec.statement import SQLStatement
13
-
14
- __all__ = (
15
- "BeforeAfter",
16
- "CollectionFilter",
17
- "FilterTypes",
18
- "InAnyFilter",
19
- "LimitOffset",
20
- "NotInCollectionFilter",
21
- "NotInSearchFilter",
22
- "OnBeforeAfter",
23
- "OrderBy",
24
- "PaginationFilter",
25
- "SearchFilter",
26
- "StatementFilter",
27
- "apply_filter",
28
- )
29
-
30
- T = TypeVar("T")
31
-
32
-
33
- class StatementFilter(Protocol):
34
- """Protocol for filters that can be appended to a statement."""
35
-
36
- @abstractmethod
37
- def append_to_statement(self, statement: SQLStatement) -> SQLStatement:
38
- """Append the filter to the statement.
39
-
40
- Args:
41
- statement: The SQL statement to modify.
42
-
43
- Returns:
44
- The modified statement.
45
- """
46
- raise NotImplementedError
47
-
48
-
49
- @dataclass
50
- class BeforeAfter(StatementFilter):
51
- """Data required to filter a query on a ``datetime`` column."""
52
-
53
- field_name: str
54
- """Name of the model attribute to filter on."""
55
- before: Optional[datetime] = None
56
- """Filter results where field earlier than this."""
57
- after: Optional[datetime] = None
58
- """Filter results where field later than this."""
59
-
60
- def append_to_statement(self, statement: SQLStatement) -> SQLStatement:
61
- conditions = []
62
- params: dict[str, Any] = {}
63
- col_expr = exp.column(self.field_name)
64
-
65
- if self.before:
66
- param_name = statement.generate_param_name(f"{self.field_name}_before")
67
- conditions.append(exp.LT(this=col_expr, expression=exp.Placeholder(this=param_name)))
68
- params[param_name] = self.before
69
- if self.after:
70
- param_name = statement.generate_param_name(f"{self.field_name}_after")
71
- conditions.append(exp.GT(this=col_expr, expression=exp.Placeholder(this=param_name))) # type: ignore[arg-type]
72
- params[param_name] = self.after
73
-
74
- if conditions:
75
- final_condition = conditions[0]
76
- for cond in conditions[1:]:
77
- final_condition = exp.And(this=final_condition, expression=cond) # type: ignore[assignment]
78
- statement.add_condition(final_condition, params)
79
- return statement
80
-
81
-
82
- @dataclass
83
- class OnBeforeAfter(StatementFilter):
84
- """Data required to filter a query on a ``datetime`` column."""
85
-
86
- field_name: str
87
- """Name of the model attribute to filter on."""
88
- on_or_before: Optional[datetime] = None
89
- """Filter results where field is on or earlier than this."""
90
- on_or_after: Optional[datetime] = None
91
- """Filter results where field on or later than this."""
92
-
93
- def append_to_statement(self, statement: SQLStatement) -> SQLStatement:
94
- conditions = []
95
- params: dict[str, Any] = {}
96
- col_expr = exp.column(self.field_name)
97
-
98
- if self.on_or_before:
99
- param_name = statement.generate_param_name(f"{self.field_name}_on_or_before")
100
- conditions.append(exp.LTE(this=col_expr, expression=exp.Placeholder(this=param_name)))
101
- params[param_name] = self.on_or_before
102
- if self.on_or_after:
103
- param_name = statement.generate_param_name(f"{self.field_name}_on_or_after")
104
- conditions.append(exp.GTE(this=col_expr, expression=exp.Placeholder(this=param_name))) # type: ignore[arg-type]
105
- params[param_name] = self.on_or_after
106
-
107
- if conditions:
108
- final_condition = conditions[0]
109
- for cond in conditions[1:]:
110
- final_condition = exp.And(this=final_condition, expression=cond) # type: ignore[assignment]
111
- statement.add_condition(final_condition, params)
112
- return statement
113
-
114
-
115
- class InAnyFilter(StatementFilter, ABC, Generic[T]):
116
- """Subclass for methods that have a `prefer_any` attribute."""
117
-
118
- @abstractmethod
119
- def append_to_statement(self, statement: SQLStatement) -> SQLStatement:
120
- raise NotImplementedError
121
-
122
-
123
- @dataclass
124
- class CollectionFilter(InAnyFilter[T]):
125
- """Data required to construct a ``WHERE ... IN (...)`` clause."""
126
-
127
- field_name: str
128
- """Name of the model attribute to filter on."""
129
- values: Optional[abc.Collection[T]]
130
- """Values for ``IN`` clause.
131
-
132
- An empty list will return an empty result set, however, if ``None``, the filter is not applied to the query, and all rows are returned. """
133
-
134
- def append_to_statement(self, statement: SQLStatement) -> SQLStatement:
135
- if self.values is None:
136
- return statement
137
-
138
- if not self.values: # Empty collection
139
- # Add a condition that is always false
140
- statement.add_condition(exp.false())
141
- return statement
142
-
143
- placeholder_expressions: list[exp.Placeholder] = []
144
- current_params: dict[str, Any] = {}
145
-
146
- for i, value_item in enumerate(self.values):
147
- param_key = statement.generate_param_name(f"{self.field_name}_in_{i}")
148
- placeholder_expressions.append(exp.Placeholder(this=param_key))
149
- current_params[param_key] = value_item
150
-
151
- in_condition = exp.In(this=exp.column(self.field_name), expressions=placeholder_expressions)
152
- statement.add_condition(in_condition, current_params)
153
- return statement
154
-
155
-
156
- @dataclass
157
- class NotInCollectionFilter(InAnyFilter[T]):
158
- """Data required to construct a ``WHERE ... NOT IN (...)`` clause."""
159
-
160
- field_name: str
161
- """Name of the model attribute to filter on."""
162
- values: Optional[abc.Collection[T]]
163
- """Values for ``NOT IN`` clause.
164
-
165
- An empty list or ``None`` will return all rows."""
166
-
167
- def append_to_statement(self, statement: SQLStatement) -> SQLStatement:
168
- if self.values is None or not self.values: # Empty list or None, no filter applied
169
- return statement
170
-
171
- placeholder_expressions: list[exp.Placeholder] = []
172
- current_params: dict[str, Any] = {}
173
-
174
- for i, value_item in enumerate(self.values):
175
- param_key = statement.generate_param_name(f"{self.field_name}_notin_{i}")
176
- placeholder_expressions.append(exp.Placeholder(this=param_key))
177
- current_params[param_key] = value_item
178
-
179
- in_expr = exp.In(this=exp.column(self.field_name), expressions=placeholder_expressions)
180
- not_in_condition = exp.Not(this=in_expr)
181
- statement.add_condition(not_in_condition, current_params)
182
- return statement
183
-
184
-
185
- class PaginationFilter(StatementFilter, ABC):
186
- """Subclass for methods that function as a pagination type."""
187
-
188
- @abstractmethod
189
- def append_to_statement(self, statement: SQLStatement) -> SQLStatement:
190
- raise NotImplementedError
191
-
192
-
193
- @dataclass
194
- class LimitOffset(PaginationFilter):
195
- """Data required to add limit/offset filtering to a query."""
196
-
197
- limit: int
198
- """Value for ``LIMIT`` clause of query."""
199
- offset: int
200
- """Value for ``OFFSET`` clause of query."""
201
-
202
- def append_to_statement(self, statement: SQLStatement) -> SQLStatement:
203
- # Generate parameter names for limit and offset
204
- limit_param_name = statement.generate_param_name("limit_val")
205
- offset_param_name = statement.generate_param_name("offset_val")
206
-
207
- statement.add_limit(self.limit, param_name=limit_param_name)
208
- statement.add_offset(self.offset, param_name=offset_param_name)
209
-
210
- return statement
211
-
212
-
213
- @dataclass
214
- class OrderBy(StatementFilter):
215
- """Data required to construct a ``ORDER BY ...`` clause."""
216
-
217
- field_name: str
218
- """Name of the model attribute to sort on."""
219
- sort_order: Literal["asc", "desc"] = "asc"
220
- """Sort ascending or descending"""
221
-
222
- def append_to_statement(self, statement: SQLStatement) -> SQLStatement:
223
- # Basic validation for sort_order, though Literal helps at type checking time
224
- normalized_sort_order = self.sort_order.lower()
225
- if normalized_sort_order not in {"asc", "desc"}:
226
- normalized_sort_order = "asc"
227
-
228
- statement.add_order_by(self.field_name, direction=cast("Literal['asc', 'desc']", normalized_sort_order))
229
-
230
- return statement
231
-
232
-
233
- @dataclass
234
- class SearchFilter(StatementFilter):
235
- """Data required to construct a ``WHERE field_name LIKE '%' || :value || '%'`` clause."""
236
-
237
- field_name: Union[str, set[str]]
238
- """Name of the model attribute to search on."""
239
- value: str
240
- """Search value."""
241
- ignore_case: Optional[bool] = False
242
- """Should the search be case insensitive."""
243
-
244
- def append_to_statement(self, statement: SQLStatement) -> SQLStatement:
245
- if not self.value:
246
- return statement
247
-
248
- search_val_param_name = statement.generate_param_name("search_val")
249
-
250
- # The pattern %value% needs to be handled carefully.
251
- params = {search_val_param_name: f"%{self.value}%"}
252
- pattern_expr = exp.Placeholder(this=search_val_param_name)
253
-
254
- like_op = exp.ILike if self.ignore_case else exp.Like
255
-
256
- if isinstance(self.field_name, str):
257
- condition = like_op(this=exp.column(self.field_name), expression=pattern_expr)
258
- statement.add_condition(condition, params)
259
- elif isinstance(self.field_name, set) and self.field_name:
260
- field_conditions = [like_op(this=exp.column(field), expression=pattern_expr) for field in self.field_name]
261
- if not field_conditions:
262
- return statement
263
-
264
- final_condition = field_conditions[0]
265
- for cond in field_conditions[1:]:
266
- final_condition = exp.Or(this=final_condition, expression=cond) # type: ignore[assignment]
267
- statement.add_condition(final_condition, params)
268
-
269
- return statement
270
-
271
-
272
- @dataclass
273
- class NotInSearchFilter(SearchFilter): # Inherits field_name, value, ignore_case
274
- """Data required to construct a ``WHERE field_name NOT LIKE '%' || :value || '%'`` clause."""
275
-
276
- def append_to_statement(self, statement: SQLStatement) -> SQLStatement:
277
- if not self.value:
278
- return statement
279
-
280
- search_val_param_name = statement.generate_param_name("not_search_val")
281
-
282
- params = {search_val_param_name: f"%{self.value}%"}
283
- pattern_expr = exp.Placeholder(this=search_val_param_name)
284
-
285
- like_op = exp.ILike if self.ignore_case else exp.Like
286
-
287
- if isinstance(self.field_name, str):
288
- condition = exp.Not(this=like_op(this=exp.column(self.field_name), expression=pattern_expr))
289
- statement.add_condition(condition, params)
290
- elif isinstance(self.field_name, set) and self.field_name:
291
- field_conditions = [
292
- exp.Not(this=like_op(this=exp.column(field), expression=pattern_expr)) for field in self.field_name
293
- ]
294
- if not field_conditions:
295
- return statement
296
-
297
- # Combine with AND: (field1 NOT LIKE pattern) AND (field2 NOT LIKE pattern) ...
298
- final_condition = field_conditions[0]
299
- for cond in field_conditions[1:]:
300
- final_condition = exp.And(this=final_condition, expression=cond) # type: ignore[assignment]
301
- statement.add_condition(final_condition, params)
302
-
303
- return statement
304
-
305
-
306
- # Function to be imported in SQLStatement module
307
- def apply_filter(statement: SQLStatement, filter_obj: StatementFilter) -> SQLStatement:
308
- """Apply a statement filter to a SQL statement.
309
-
310
- Args:
311
- statement: The SQL statement to modify.
312
- filter_obj: The filter to apply.
313
-
314
- Returns:
315
- The modified statement.
316
- """
317
- return filter_obj.append_to_statement(statement)
318
-
319
-
320
- FilterTypes: TypeAlias = Union[
321
- BeforeAfter,
322
- OnBeforeAfter,
323
- CollectionFilter[Any],
324
- LimitOffset,
325
- OrderBy,
326
- SearchFilter,
327
- NotInCollectionFilter[Any],
328
- NotInSearchFilter,
329
- ]
330
- """Aggregate type alias of the types supported for collection filtering."""
sqlspec/mixins.py DELETED
@@ -1,306 +0,0 @@
1
- import datetime
2
- from abc import abstractmethod
3
- from collections.abc import Sequence
4
- from enum import Enum
5
- from functools import partial
6
- from pathlib import Path, PurePath
7
- from typing import (
8
- TYPE_CHECKING,
9
- Any,
10
- Callable,
11
- ClassVar,
12
- Generic,
13
- Optional,
14
- Union,
15
- cast,
16
- overload,
17
- )
18
- from uuid import UUID
19
-
20
- from sqlglot import parse_one
21
- from sqlglot.dialects.dialect import DialectType
22
-
23
- from sqlspec.exceptions import SQLConversionError, SQLParsingError, SQLSpecError
24
- from sqlspec.typing import (
25
- ConnectionT,
26
- ModelDTOT,
27
- ModelT,
28
- StatementParameterType,
29
- convert,
30
- get_type_adapter,
31
- is_dataclass,
32
- is_msgspec_struct,
33
- is_pydantic_model,
34
- )
35
-
36
- if TYPE_CHECKING:
37
- from sqlspec.typing import ArrowTable
38
-
39
- __all__ = (
40
- "AsyncArrowBulkOperationsMixin",
41
- "AsyncParquetExportMixin",
42
- "SQLTranslatorMixin",
43
- "SyncArrowBulkOperationsMixin",
44
- "SyncParquetExportMixin",
45
- )
46
-
47
-
48
- class SyncArrowBulkOperationsMixin(Generic[ConnectionT]):
49
- """Mixin for sync drivers supporting bulk Apache Arrow operations."""
50
-
51
- __supports_arrow__: "ClassVar[bool]" = True
52
-
53
- @abstractmethod
54
- def select_arrow( # pyright: ignore[reportUnknownParameterType]
55
- self,
56
- sql: str,
57
- parameters: "Optional[StatementParameterType]" = None,
58
- /,
59
- *,
60
- connection: "Optional[ConnectionT]" = None,
61
- **kwargs: Any,
62
- ) -> "ArrowTable": # pyright: ignore[reportUnknownReturnType]
63
- """Execute a SQL query and return results as an Apache Arrow Table.
64
-
65
- Args:
66
- sql: The SQL query string.
67
- parameters: Parameters for the query.
68
- connection: Optional connection override.
69
- **kwargs: Additional keyword arguments to merge with parameters if parameters is a dict.
70
-
71
- Returns:
72
- An Apache Arrow Table containing the query results.
73
- """
74
- raise NotImplementedError
75
-
76
-
77
- class AsyncArrowBulkOperationsMixin(Generic[ConnectionT]):
78
- """Mixin for async drivers supporting bulk Apache Arrow operations."""
79
-
80
- __supports_arrow__: "ClassVar[bool]" = True
81
-
82
- @abstractmethod
83
- async def select_arrow( # pyright: ignore[reportUnknownParameterType]
84
- self,
85
- sql: str,
86
- parameters: "Optional[StatementParameterType]" = None,
87
- /,
88
- *,
89
- connection: "Optional[ConnectionT]" = None,
90
- **kwargs: Any,
91
- ) -> "ArrowTable": # pyright: ignore[reportUnknownReturnType]
92
- """Execute a SQL query and return results as an Apache Arrow Table.
93
-
94
- Args:
95
- sql: The SQL query string.
96
- parameters: Parameters for the query.
97
- connection: Optional connection override.
98
- **kwargs: Additional keyword arguments to merge with parameters if parameters is a dict.
99
-
100
- Returns:
101
- An Apache Arrow Table containing the query results.
102
- """
103
- raise NotImplementedError
104
-
105
-
106
- class SyncParquetExportMixin(Generic[ConnectionT]):
107
- """Mixin for sync drivers supporting Parquet export."""
108
-
109
- @abstractmethod
110
- def select_to_parquet(
111
- self,
112
- sql: str,
113
- parameters: "Optional[StatementParameterType]" = None,
114
- /,
115
- *,
116
- connection: "Optional[ConnectionT]" = None,
117
- **kwargs: Any,
118
- ) -> None:
119
- """Export a SQL query to a Parquet file."""
120
- raise NotImplementedError
121
-
122
-
123
- class AsyncParquetExportMixin(Generic[ConnectionT]):
124
- """Mixin for async drivers supporting Parquet export."""
125
-
126
- @abstractmethod
127
- async def select_to_parquet(
128
- self,
129
- sql: str,
130
- parameters: "Optional[StatementParameterType]" = None,
131
- /,
132
- *,
133
- connection: "Optional[ConnectionT]" = None,
134
- **kwargs: Any,
135
- ) -> None:
136
- """Export a SQL query to a Parquet file."""
137
- raise NotImplementedError
138
-
139
-
140
- class SQLTranslatorMixin(Generic[ConnectionT]):
141
- """Mixin for drivers supporting SQL translation."""
142
-
143
- dialect: str
144
-
145
- def convert_to_dialect(
146
- self,
147
- sql: str,
148
- to_dialect: DialectType = None,
149
- pretty: bool = True,
150
- ) -> str:
151
- """Convert a SQL query to a different dialect.
152
-
153
- Args:
154
- sql: The SQL query string to convert.
155
- to_dialect: The target dialect to convert to.
156
- pretty: Whether to pretty-print the SQL query.
157
-
158
- Returns:
159
- The converted SQL query string.
160
-
161
- Raises:
162
- SQLParsingError: If the SQL query cannot be parsed.
163
- SQLConversionError: If the SQL query cannot be converted to the target dialect.
164
- """
165
- try:
166
- parsed = parse_one(sql, dialect=self.dialect)
167
- except Exception as e:
168
- error_msg = f"Failed to parse SQL: {e!s}"
169
- raise SQLParsingError(error_msg) from e
170
- if to_dialect is None:
171
- to_dialect = self.dialect
172
- try:
173
- return parsed.sql(dialect=to_dialect, pretty=pretty)
174
- except Exception as e:
175
- error_msg = f"Failed to convert SQL to {to_dialect}: {e!s}"
176
- raise SQLConversionError(error_msg) from e
177
-
178
-
179
- _DEFAULT_TYPE_DECODERS = [ # pyright: ignore[reportUnknownVariableType]
180
- (lambda x: x is UUID, lambda t, v: t(v.hex)), # pyright: ignore[reportUnknownLambdaType,reportUnknownMemberType]
181
- (lambda x: x is datetime.datetime, lambda t, v: t(v.isoformat())), # pyright: ignore[reportUnknownLambdaType,reportUnknownMemberType]
182
- (lambda x: x is datetime.date, lambda t, v: t(v.isoformat())), # pyright: ignore[reportUnknownLambdaType,reportUnknownMemberType]
183
- (lambda x: x is datetime.time, lambda t, v: t(v.isoformat())), # pyright: ignore[reportUnknownLambdaType,reportUnknownMemberType]
184
- (lambda x: x is Enum, lambda t, v: t(v.value)), # pyright: ignore[reportUnknownLambdaType,reportUnknownMemberType]
185
- ]
186
-
187
-
188
- def _default_msgspec_deserializer(
189
- target_type: Any,
190
- value: Any,
191
- type_decoders: "Union[Sequence[tuple[Callable[[Any], bool], Callable[[Any, Any], Any]]], None]" = None,
192
- ) -> Any: # pragma: no cover
193
- """Transform values non-natively supported by ``msgspec``
194
-
195
- Args:
196
- target_type: Encountered type
197
- value: Value to coerce
198
- type_decoders: Optional sequence of type decoders
199
-
200
- Raises:
201
- TypeError: If the value cannot be coerced to the target type
202
-
203
- Returns:
204
- A ``msgspec``-supported type
205
- """
206
-
207
- if isinstance(value, target_type):
208
- return value
209
-
210
- if type_decoders:
211
- for predicate, decoder in type_decoders:
212
- if predicate(target_type):
213
- return decoder(target_type, value)
214
-
215
- if issubclass(target_type, (Path, PurePath, UUID)):
216
- return target_type(value)
217
-
218
- try:
219
- return target_type(value)
220
- except Exception as e:
221
- msg = f"Unsupported type: {type(value)!r}"
222
- raise TypeError(msg) from e
223
-
224
-
225
- class ResultConverter:
226
- """Simple mixin to help convert to dictionary or list of dictionaries to specified schema type.
227
-
228
- Single objects are transformed to the supplied schema type, and lists of objects are transformed into a list of the supplied schema type.
229
-
230
- Args:
231
- data: A database model instance or row mapping.
232
- Type: :class:`~sqlspec.typing.ModelDictT`
233
-
234
- Returns:
235
- The converted schema object.
236
- """
237
-
238
- @overload
239
- @staticmethod
240
- def to_schema(data: "ModelT", *, schema_type: None = None) -> "ModelT": ...
241
- @overload
242
- @staticmethod
243
- def to_schema(data: "dict[str, Any]", *, schema_type: "type[ModelDTOT]") -> "ModelDTOT": ...
244
- @overload
245
- @staticmethod
246
- def to_schema(data: "Sequence[ModelT]", *, schema_type: None = None) -> "Sequence[ModelT]": ...
247
- @overload
248
- @staticmethod
249
- def to_schema(data: "Sequence[dict[str, Any]]", *, schema_type: "type[ModelDTOT]") -> "Sequence[ModelDTOT]": ...
250
-
251
- @staticmethod
252
- def to_schema(
253
- data: "Union[ModelT, Sequence[ModelT], dict[str, Any], Sequence[dict[str, Any]], ModelDTOT, Sequence[ModelDTOT]]",
254
- *,
255
- schema_type: "Optional[type[ModelDTOT]]" = None,
256
- ) -> "Union[ModelT, Sequence[ModelT], ModelDTOT, Sequence[ModelDTOT]]":
257
- if schema_type is None:
258
- if not isinstance(data, Sequence):
259
- return cast("ModelT", data)
260
- return cast("Sequence[ModelT]", data)
261
- if is_dataclass(schema_type):
262
- if not isinstance(data, Sequence):
263
- # data is assumed to be dict[str, Any] as per the method's overloads
264
- return cast("ModelDTOT", schema_type(**data)) # type: ignore[operator]
265
- # data is assumed to be Sequence[dict[str, Any]]
266
- return cast("Sequence[ModelDTOT]", [schema_type(**item) for item in data]) # type: ignore[operator]
267
- if is_msgspec_struct(schema_type):
268
- if not isinstance(data, Sequence):
269
- return cast(
270
- "ModelDTOT",
271
- convert(
272
- obj=data,
273
- type=schema_type,
274
- from_attributes=True,
275
- dec_hook=partial(
276
- _default_msgspec_deserializer,
277
- type_decoders=_DEFAULT_TYPE_DECODERS,
278
- ),
279
- ),
280
- )
281
- return cast(
282
- "Sequence[ModelDTOT]",
283
- convert(
284
- obj=data,
285
- type=list[schema_type], # type: ignore[valid-type]
286
- from_attributes=True,
287
- dec_hook=partial(
288
- _default_msgspec_deserializer,
289
- type_decoders=_DEFAULT_TYPE_DECODERS,
290
- ),
291
- ),
292
- )
293
-
294
- if schema_type is not None and is_pydantic_model(schema_type):
295
- if not isinstance(data, Sequence):
296
- return cast(
297
- "ModelDTOT",
298
- get_type_adapter(schema_type).validate_python(data, from_attributes=True), # pyright: ignore
299
- )
300
- return cast(
301
- "Sequence[ModelDTOT]",
302
- get_type_adapter(list[schema_type]).validate_python(data, from_attributes=True), # type: ignore[valid-type] # pyright: ignore[reportUnknownArgumentType]
303
- )
304
-
305
- msg = "`schema_type` should be a valid Dataclass, Pydantic model or Msgspec struct"
306
- raise SQLSpecError(msg)