sqlspec 0.11.1__py3-none-any.whl → 0.12.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of sqlspec might be problematic; see the registry's advisory page for more details.

Files changed (155)
  1. sqlspec/__init__.py +16 -3
  2. sqlspec/_serialization.py +3 -10
  3. sqlspec/_sql.py +1147 -0
  4. sqlspec/_typing.py +343 -41
  5. sqlspec/adapters/adbc/__init__.py +2 -6
  6. sqlspec/adapters/adbc/config.py +474 -149
  7. sqlspec/adapters/adbc/driver.py +330 -621
  8. sqlspec/adapters/aiosqlite/__init__.py +2 -6
  9. sqlspec/adapters/aiosqlite/config.py +143 -57
  10. sqlspec/adapters/aiosqlite/driver.py +269 -431
  11. sqlspec/adapters/asyncmy/__init__.py +3 -8
  12. sqlspec/adapters/asyncmy/config.py +247 -202
  13. sqlspec/adapters/asyncmy/driver.py +218 -436
  14. sqlspec/adapters/asyncpg/__init__.py +4 -7
  15. sqlspec/adapters/asyncpg/config.py +329 -176
  16. sqlspec/adapters/asyncpg/driver.py +417 -487
  17. sqlspec/adapters/bigquery/__init__.py +2 -2
  18. sqlspec/adapters/bigquery/config.py +407 -0
  19. sqlspec/adapters/bigquery/driver.py +600 -553
  20. sqlspec/adapters/duckdb/__init__.py +4 -1
  21. sqlspec/adapters/duckdb/config.py +432 -321
  22. sqlspec/adapters/duckdb/driver.py +392 -406
  23. sqlspec/adapters/oracledb/__init__.py +3 -8
  24. sqlspec/adapters/oracledb/config.py +625 -0
  25. sqlspec/adapters/oracledb/driver.py +548 -921
  26. sqlspec/adapters/psqlpy/__init__.py +4 -7
  27. sqlspec/adapters/psqlpy/config.py +372 -203
  28. sqlspec/adapters/psqlpy/driver.py +197 -533
  29. sqlspec/adapters/psycopg/__init__.py +3 -8
  30. sqlspec/adapters/psycopg/config.py +725 -0
  31. sqlspec/adapters/psycopg/driver.py +734 -694
  32. sqlspec/adapters/sqlite/__init__.py +2 -6
  33. sqlspec/adapters/sqlite/config.py +146 -81
  34. sqlspec/adapters/sqlite/driver.py +242 -405
  35. sqlspec/base.py +220 -784
  36. sqlspec/config.py +354 -0
  37. sqlspec/driver/__init__.py +22 -0
  38. sqlspec/driver/_async.py +252 -0
  39. sqlspec/driver/_common.py +338 -0
  40. sqlspec/driver/_sync.py +261 -0
  41. sqlspec/driver/mixins/__init__.py +17 -0
  42. sqlspec/driver/mixins/_pipeline.py +523 -0
  43. sqlspec/driver/mixins/_result_utils.py +122 -0
  44. sqlspec/driver/mixins/_sql_translator.py +35 -0
  45. sqlspec/driver/mixins/_storage.py +993 -0
  46. sqlspec/driver/mixins/_type_coercion.py +131 -0
  47. sqlspec/exceptions.py +299 -7
  48. sqlspec/extensions/aiosql/__init__.py +10 -0
  49. sqlspec/extensions/aiosql/adapter.py +474 -0
  50. sqlspec/extensions/litestar/__init__.py +1 -6
  51. sqlspec/extensions/litestar/_utils.py +1 -5
  52. sqlspec/extensions/litestar/config.py +5 -6
  53. sqlspec/extensions/litestar/handlers.py +13 -12
  54. sqlspec/extensions/litestar/plugin.py +22 -24
  55. sqlspec/extensions/litestar/providers.py +37 -55
  56. sqlspec/loader.py +528 -0
  57. sqlspec/service/__init__.py +3 -0
  58. sqlspec/service/base.py +24 -0
  59. sqlspec/service/pagination.py +26 -0
  60. sqlspec/statement/__init__.py +21 -0
  61. sqlspec/statement/builder/__init__.py +54 -0
  62. sqlspec/statement/builder/_ddl_utils.py +119 -0
  63. sqlspec/statement/builder/_parsing_utils.py +135 -0
  64. sqlspec/statement/builder/base.py +328 -0
  65. sqlspec/statement/builder/ddl.py +1379 -0
  66. sqlspec/statement/builder/delete.py +80 -0
  67. sqlspec/statement/builder/insert.py +274 -0
  68. sqlspec/statement/builder/merge.py +95 -0
  69. sqlspec/statement/builder/mixins/__init__.py +65 -0
  70. sqlspec/statement/builder/mixins/_aggregate_functions.py +151 -0
  71. sqlspec/statement/builder/mixins/_case_builder.py +91 -0
  72. sqlspec/statement/builder/mixins/_common_table_expr.py +91 -0
  73. sqlspec/statement/builder/mixins/_delete_from.py +34 -0
  74. sqlspec/statement/builder/mixins/_from.py +61 -0
  75. sqlspec/statement/builder/mixins/_group_by.py +119 -0
  76. sqlspec/statement/builder/mixins/_having.py +35 -0
  77. sqlspec/statement/builder/mixins/_insert_from_select.py +48 -0
  78. sqlspec/statement/builder/mixins/_insert_into.py +36 -0
  79. sqlspec/statement/builder/mixins/_insert_values.py +69 -0
  80. sqlspec/statement/builder/mixins/_join.py +110 -0
  81. sqlspec/statement/builder/mixins/_limit_offset.py +53 -0
  82. sqlspec/statement/builder/mixins/_merge_clauses.py +405 -0
  83. sqlspec/statement/builder/mixins/_order_by.py +46 -0
  84. sqlspec/statement/builder/mixins/_pivot.py +82 -0
  85. sqlspec/statement/builder/mixins/_returning.py +37 -0
  86. sqlspec/statement/builder/mixins/_select_columns.py +60 -0
  87. sqlspec/statement/builder/mixins/_set_ops.py +122 -0
  88. sqlspec/statement/builder/mixins/_unpivot.py +80 -0
  89. sqlspec/statement/builder/mixins/_update_from.py +54 -0
  90. sqlspec/statement/builder/mixins/_update_set.py +91 -0
  91. sqlspec/statement/builder/mixins/_update_table.py +29 -0
  92. sqlspec/statement/builder/mixins/_where.py +374 -0
  93. sqlspec/statement/builder/mixins/_window_functions.py +86 -0
  94. sqlspec/statement/builder/protocols.py +20 -0
  95. sqlspec/statement/builder/select.py +206 -0
  96. sqlspec/statement/builder/update.py +178 -0
  97. sqlspec/statement/filters.py +571 -0
  98. sqlspec/statement/parameters.py +736 -0
  99. sqlspec/statement/pipelines/__init__.py +67 -0
  100. sqlspec/statement/pipelines/analyzers/__init__.py +9 -0
  101. sqlspec/statement/pipelines/analyzers/_analyzer.py +649 -0
  102. sqlspec/statement/pipelines/base.py +315 -0
  103. sqlspec/statement/pipelines/context.py +119 -0
  104. sqlspec/statement/pipelines/result_types.py +41 -0
  105. sqlspec/statement/pipelines/transformers/__init__.py +8 -0
  106. sqlspec/statement/pipelines/transformers/_expression_simplifier.py +256 -0
  107. sqlspec/statement/pipelines/transformers/_literal_parameterizer.py +623 -0
  108. sqlspec/statement/pipelines/transformers/_remove_comments.py +66 -0
  109. sqlspec/statement/pipelines/transformers/_remove_hints.py +81 -0
  110. sqlspec/statement/pipelines/validators/__init__.py +23 -0
  111. sqlspec/statement/pipelines/validators/_dml_safety.py +275 -0
  112. sqlspec/statement/pipelines/validators/_parameter_style.py +297 -0
  113. sqlspec/statement/pipelines/validators/_performance.py +703 -0
  114. sqlspec/statement/pipelines/validators/_security.py +990 -0
  115. sqlspec/statement/pipelines/validators/base.py +67 -0
  116. sqlspec/statement/result.py +527 -0
  117. sqlspec/statement/splitter.py +701 -0
  118. sqlspec/statement/sql.py +1198 -0
  119. sqlspec/storage/__init__.py +15 -0
  120. sqlspec/storage/backends/__init__.py +0 -0
  121. sqlspec/storage/backends/base.py +166 -0
  122. sqlspec/storage/backends/fsspec.py +315 -0
  123. sqlspec/storage/backends/obstore.py +464 -0
  124. sqlspec/storage/protocol.py +170 -0
  125. sqlspec/storage/registry.py +315 -0
  126. sqlspec/typing.py +157 -36
  127. sqlspec/utils/correlation.py +155 -0
  128. sqlspec/utils/deprecation.py +3 -6
  129. sqlspec/utils/fixtures.py +6 -11
  130. sqlspec/utils/logging.py +135 -0
  131. sqlspec/utils/module_loader.py +45 -43
  132. sqlspec/utils/serializers.py +4 -0
  133. sqlspec/utils/singleton.py +6 -8
  134. sqlspec/utils/sync_tools.py +15 -27
  135. sqlspec/utils/text.py +58 -26
  136. {sqlspec-0.11.1.dist-info → sqlspec-0.12.1.dist-info}/METADATA +97 -26
  137. sqlspec-0.12.1.dist-info/RECORD +145 -0
  138. sqlspec/adapters/bigquery/config/__init__.py +0 -3
  139. sqlspec/adapters/bigquery/config/_common.py +0 -40
  140. sqlspec/adapters/bigquery/config/_sync.py +0 -87
  141. sqlspec/adapters/oracledb/config/__init__.py +0 -9
  142. sqlspec/adapters/oracledb/config/_asyncio.py +0 -186
  143. sqlspec/adapters/oracledb/config/_common.py +0 -131
  144. sqlspec/adapters/oracledb/config/_sync.py +0 -186
  145. sqlspec/adapters/psycopg/config/__init__.py +0 -19
  146. sqlspec/adapters/psycopg/config/_async.py +0 -169
  147. sqlspec/adapters/psycopg/config/_common.py +0 -56
  148. sqlspec/adapters/psycopg/config/_sync.py +0 -168
  149. sqlspec/filters.py +0 -331
  150. sqlspec/mixins.py +0 -305
  151. sqlspec/statement.py +0 -378
  152. sqlspec-0.11.1.dist-info/RECORD +0 -69
  153. {sqlspec-0.11.1.dist-info → sqlspec-0.12.1.dist-info}/WHEEL +0 -0
  154. {sqlspec-0.11.1.dist-info → sqlspec-0.12.1.dist-info}/licenses/LICENSE +0 -0
  155. {sqlspec-0.11.1.dist-info → sqlspec-0.12.1.dist-info}/licenses/NOTICE +0 -0
sqlspec/filters.py DELETED
@@ -1,331 +0,0 @@
1
- """Collection filter datastructures."""
2
-
3
- from abc import ABC, abstractmethod
4
- from collections import abc
5
- from dataclasses import dataclass
6
- from datetime import datetime
7
- from typing import Any, Generic, Literal, Optional, Protocol, Union, cast, runtime_checkable
8
-
9
- from sqlglot import exp
10
- from typing_extensions import TypeAlias, TypeVar
11
-
12
- from sqlspec.statement import SQLStatement
13
-
14
# Public names exported by ``from sqlspec.filters import *``.
__all__ = (
    "BeforeAfter",
    "CollectionFilter",
    "FilterTypes",
    "InAnyFilter",
    "LimitOffset",
    "NotInCollectionFilter",
    "NotInSearchFilter",
    "OnBeforeAfter",
    "OrderBy",
    "PaginationFilter",
    "SearchFilter",
    "StatementFilter",
    "apply_filter",
)

# Element type of the value collections held by the IN / NOT IN filters.
T = TypeVar("T")
31
-
32
-
33
@runtime_checkable
class StatementFilter(Protocol):
    """Structural interface for objects that modify a SQL statement.

    Any object exposing an ``append_to_statement`` method satisfies this protocol.
    Because it is ``runtime_checkable``, ``isinstance`` checks against it are allowed
    (they verify only that the method is present).
    """

    @abstractmethod
    def append_to_statement(self, statement: SQLStatement) -> SQLStatement:
        """Apply this filter to ``statement``.

        Args:
            statement: The SQL statement to modify.

        Returns:
            The statement after the filter has been applied.

        Raises:
            NotImplementedError: Always, in this abstract definition.
        """
        raise NotImplementedError
48
-
49
-
50
@dataclass
class BeforeAfter(StatementFilter):
    """Keep rows whose ``datetime`` column lies strictly between two optional bounds.

    ``before`` produces ``field < :param`` and ``after`` produces ``field > :param``.
    When both are set, the two comparisons are AND-ed. A falsy or unset bound is skipped.
    """

    # Name of the column to compare.
    field_name: str
    # Exclusive upper bound: keep rows where the column is earlier than this.
    before: Optional[datetime] = None
    # Exclusive lower bound: keep rows where the column is later than this.
    after: Optional[datetime] = None

    def append_to_statement(self, statement: SQLStatement) -> SQLStatement:
        """Attach the bound comparisons to ``statement`` as a single condition.

        Args:
            statement: The SQL statement to modify.

        Returns:
            The same statement object, left untouched when neither bound is set.
        """
        column = exp.column(self.field_name)
        # The order matters: "before" must be processed first so parameter names are reserved in that order.
        bounds = (
            (self.before, "before", exp.LT),
            (self.after, "after", exp.GT),
        )
        bound_params: dict[str, Any] = {}
        combined: "Optional[exp.Expression]" = None
        for bound, suffix, comparison in bounds:
            if not bound:
                continue
            placeholder_name = statement.generate_param_name(f"{self.field_name}_{suffix}")
            predicate = comparison(this=column, expression=exp.Placeholder(this=placeholder_name))
            bound_params[placeholder_name] = bound
            combined = predicate if combined is None else exp.And(this=combined, expression=predicate)
        if combined is not None:
            statement.add_condition(combined, bound_params)
        return statement
81
-
82
-
83
@dataclass
class OnBeforeAfter(StatementFilter):
    """Keep rows whose ``datetime`` column lies within two optional inclusive bounds.

    ``on_or_before`` produces ``field <= :param`` and ``on_or_after`` produces
    ``field >= :param``. When both are set, the two comparisons are AND-ed.
    A falsy or unset bound is skipped.
    """

    # Name of the column to compare.
    field_name: str
    # Inclusive upper bound: keep rows where the column is on or earlier than this.
    on_or_before: Optional[datetime] = None
    # Inclusive lower bound: keep rows where the column is on or later than this.
    on_or_after: Optional[datetime] = None

    def append_to_statement(self, statement: SQLStatement) -> SQLStatement:
        """Attach the inclusive bound comparisons to ``statement`` as a single condition.

        Args:
            statement: The SQL statement to modify.

        Returns:
            The same statement object, left untouched when neither bound is set.
        """
        column = exp.column(self.field_name)
        # The order matters: "on_or_before" must be processed first so parameter names are reserved in that order.
        bounds = (
            (self.on_or_before, "on_or_before", exp.LTE),
            (self.on_or_after, "on_or_after", exp.GTE),
        )
        bound_params: dict[str, Any] = {}
        combined: "Optional[exp.Expression]" = None
        for bound, suffix, comparison in bounds:
            if not bound:
                continue
            placeholder_name = statement.generate_param_name(f"{self.field_name}_{suffix}")
            predicate = comparison(this=column, expression=exp.Placeholder(this=placeholder_name))
            bound_params[placeholder_name] = bound
            combined = predicate if combined is None else exp.And(this=combined, expression=predicate)
        if combined is not None:
            statement.add_condition(combined, bound_params)
        return statement
114
-
115
-
116
class InAnyFilter(StatementFilter, ABC, Generic[T]):
    """Abstract base for membership filters (``IN`` / ``NOT IN``) over values of type ``T``.

    Concrete subclasses must implement :meth:`append_to_statement`.
    """

    @abstractmethod
    def append_to_statement(self, statement: SQLStatement) -> SQLStatement:
        """Apply the membership condition to ``statement``; subclasses provide the logic.

        Raises:
            NotImplementedError: Always, in this abstract definition.
        """
        raise NotImplementedError
122
-
123
-
124
@dataclass
class CollectionFilter(InAnyFilter[T]):
    """Restrict rows to those whose column value appears in a collection (``WHERE col IN (...)``).

    ``values=None`` disables the filter entirely. An empty collection instead adds a
    constant-false condition, so no rows match.
    """

    # Column to test for membership.
    field_name: str
    # Candidate values; each one is bound as its own named parameter.
    values: Optional[abc.Collection[T]]

    def append_to_statement(self, statement: SQLStatement) -> SQLStatement:
        """Add ``field_name IN (:p0, :p1, ...)`` to ``statement``.

        Args:
            statement: The SQL statement to modify.

        Returns:
            The same statement object, possibly with an added condition.
        """
        candidates = self.values
        if candidates is None:
            return statement
        if not candidates:
            # An empty collection must match nothing: add FALSE rather than an empty IN list.
            statement.add_condition(exp.false())
            return statement

        bound_values: dict[str, Any] = {}
        placeholders: list[exp.Placeholder] = []
        for position, candidate in enumerate(candidates):
            key = statement.generate_param_name(f"{self.field_name}_in_{position}")
            placeholders.append(exp.Placeholder(this=key))
            bound_values[key] = candidate

        membership = exp.In(this=exp.column(self.field_name), expressions=placeholders)
        statement.add_condition(membership, bound_values)
        return statement
155
-
156
-
157
@dataclass
class NotInCollectionFilter(InAnyFilter[T]):
    """Exclude rows whose column value appears in a collection (``WHERE NOT col IN (...)``).

    Both ``None`` and an empty collection leave the statement unfiltered.
    """

    # Column to test for membership.
    field_name: str
    # Values to exclude; each one is bound as its own named parameter.
    values: Optional[abc.Collection[T]]

    def append_to_statement(self, statement: SQLStatement) -> SQLStatement:
        """Add ``NOT (field_name IN (:p0, :p1, ...))`` to ``statement``.

        Args:
            statement: The SQL statement to modify.

        Returns:
            The same statement object, possibly with an added condition.
        """
        # ``not None`` is True, so this one test covers both "no filter" cases.
        if not self.values:
            return statement

        bound_values: dict[str, Any] = {}
        placeholders: list[exp.Placeholder] = []
        for position, excluded in enumerate(self.values):
            key = statement.generate_param_name(f"{self.field_name}_notin_{position}")
            placeholders.append(exp.Placeholder(this=key))
            bound_values[key] = excluded

        membership = exp.In(this=exp.column(self.field_name), expressions=placeholders)
        statement.add_condition(exp.Not(this=membership), bound_values)
        return statement
184
-
185
-
186
class PaginationFilter(StatementFilter, ABC):
    """Abstract base for filters that paginate a statement's result set.

    Concrete subclasses must implement :meth:`append_to_statement`.
    """

    @abstractmethod
    def append_to_statement(self, statement: SQLStatement) -> SQLStatement:
        """Apply the pagination clauses to ``statement``; subclasses provide the logic.

        Raises:
            NotImplementedError: Always, in this abstract definition.
        """
        raise NotImplementedError
192
-
193
-
194
@dataclass
class LimitOffset(PaginationFilter):
    """Page through results with ``LIMIT`` / ``OFFSET`` clauses."""

    # Maximum number of rows to return (LIMIT).
    limit: int
    # Number of rows to skip before returning results (OFFSET).
    offset: int

    def append_to_statement(self, statement: SQLStatement) -> SQLStatement:
        """Add the limit and offset to ``statement`` as named parameters.

        Args:
            statement: The SQL statement to modify.

        Returns:
            The same statement object.
        """
        # Both names are reserved before either clause is added, in limit-then-offset order.
        limit_key = statement.generate_param_name("limit_val")
        offset_key = statement.generate_param_name("offset_val")

        statement.add_limit(self.limit, param_name=limit_key)
        statement.add_offset(self.offset, param_name=offset_key)
        return statement
212
-
213
-
214
@dataclass
class OrderBy(StatementFilter):
    """Sort results by a single column (``ORDER BY field ASC|DESC``)."""

    # Column to sort on.
    field_name: str
    # Sort direction; anything other than "asc"/"desc" (case-insensitive) falls back to "asc".
    sort_order: Literal["asc", "desc"] = "asc"

    def append_to_statement(self, statement: SQLStatement) -> SQLStatement:
        """Add the ``ORDER BY`` clause to ``statement``.

        Args:
            statement: The SQL statement to modify.

        Returns:
            The same statement object.
        """
        requested = self.sort_order.lower()
        # Literal only constrains static type checkers, so sanitize the runtime value as well.
        direction = requested if requested in {"asc", "desc"} else "asc"
        statement.add_order_by(self.field_name, direction=cast("Literal['asc', 'desc']", direction))
        return statement
232
-
233
-
234
@dataclass
class SearchFilter(StatementFilter):
    """Substring search: ``WHERE field LIKE :param`` with the parameter bound to ``'%<value>%'``.

    When ``field_name`` is a set of columns, a row matches if *any* column matches;
    the per-column conditions are OR-ed together, in sorted column order.
    """

    # Column name, or a set of column names, to search in.
    field_name: Union[str, set[str]]
    # Text to look for. ``%`` and ``_`` inside it are NOT escaped, so they act as LIKE wildcards.
    value: str
    # Use ILIKE instead of LIKE when true.
    ignore_case: Optional[bool] = False

    def append_to_statement(self, statement: SQLStatement) -> SQLStatement:
        """Add the (I)LIKE search condition to ``statement``.

        Args:
            statement: The SQL statement to modify.

        Returns:
            The same statement object, unchanged when ``value`` is empty or
            ``field_name`` is neither a string nor a non-empty set.
        """
        if not self.value:
            return statement

        search_val_param_name = statement.generate_param_name("search_val")
        params = {search_val_param_name: f"%{self.value}%"}

        if isinstance(self.field_name, str):
            fields = [self.field_name]
        elif isinstance(self.field_name, abc.Set) and self.field_name:
            # Sorting makes the generated SQL stable across processes. The iteration order of a
            # str set depends on hash randomization.
            fields = sorted(self.field_name)
        else:
            return statement

        like_op = exp.ILike if self.ignore_case else exp.Like
        # Each condition gets its own Placeholder node: sqlglot nodes track a single parent,
        # so one node must not be shared by several parents.
        field_conditions = [
            like_op(this=exp.column(field), expression=exp.Placeholder(this=search_val_param_name))
            for field in fields
        ]

        final_condition = field_conditions[0]
        for cond in field_conditions[1:]:
            final_condition = exp.Or(this=final_condition, expression=cond)  # type: ignore[assignment]
        statement.add_condition(final_condition, params)
        return statement
271
-
272
-
273
@dataclass
class NotInSearchFilter(SearchFilter):  # Inherits field_name, value, ignore_case
    """Negated substring search: ``WHERE NOT field LIKE :param``, with ``:param`` bound to ``'%<value>%'``.

    When ``field_name`` is a set of columns, a row is kept only if *no* column matches;
    the negated per-column conditions are AND-ed together, in sorted column order.
    """

    def append_to_statement(self, statement: SQLStatement) -> SQLStatement:
        """Add the negated (I)LIKE search condition to ``statement``.

        Args:
            statement: The SQL statement to modify.

        Returns:
            The same statement object, unchanged when ``value`` is empty or
            ``field_name`` is neither a string nor a non-empty set.
        """
        if not self.value:
            return statement

        search_val_param_name = statement.generate_param_name("not_search_val")
        params = {search_val_param_name: f"%{self.value}%"}

        if isinstance(self.field_name, str):
            fields = [self.field_name]
        elif isinstance(self.field_name, abc.Set) and self.field_name:
            # Sorting makes the generated SQL stable across processes. The iteration order of a
            # str set depends on hash randomization.
            fields = sorted(self.field_name)
        else:
            return statement

        like_op = exp.ILike if self.ignore_case else exp.Like
        # Each condition gets its own Placeholder node: sqlglot nodes track a single parent.
        field_conditions = [
            exp.Not(this=like_op(this=exp.column(field), expression=exp.Placeholder(this=search_val_param_name)))
            for field in fields
        ]

        # Combine with AND: (NOT field1 LIKE pattern) AND (NOT field2 LIKE pattern) ...
        final_condition = field_conditions[0]
        for cond in field_conditions[1:]:
            final_condition = exp.And(this=final_condition, expression=cond)  # type: ignore[assignment]
        statement.add_condition(final_condition, params)
        return statement
305
-
306
-
307
# Module-level helper; an earlier note says the SQLStatement module imports it.
def apply_filter(statement: SQLStatement, filter_obj: StatementFilter) -> SQLStatement:
    """Run ``filter_obj`` against ``statement`` and return the result.

    This is a thin functional wrapper around ``filter_obj.append_to_statement``.

    Args:
        statement: The SQL statement to modify.
        filter_obj: The filter to apply.

    Returns:
        Whatever ``filter_obj.append_to_statement`` returns.
    """
    filtered = filter_obj.append_to_statement(statement)
    return filtered
319
-
320
-
321
# Union of every concrete filter class defined in this module. It is useful for annotating
# parameters that accept any supported filter.
FilterTypes: TypeAlias = Union[
    BeforeAfter,
    OnBeforeAfter,
    CollectionFilter[Any],
    LimitOffset,
    OrderBy,
    SearchFilter,
    NotInCollectionFilter[Any],
    NotInSearchFilter,
]
"""Aggregate type alias covering every filter type supported for collection filtering."""
sqlspec/mixins.py DELETED
@@ -1,305 +0,0 @@
1
- import datetime
2
- from abc import abstractmethod
3
- from collections.abc import Sequence
4
- from enum import Enum
5
- from functools import partial
6
- from pathlib import Path, PurePath
7
- from typing import (
8
- TYPE_CHECKING,
9
- Any,
10
- Callable,
11
- ClassVar,
12
- Generic,
13
- Optional,
14
- Union,
15
- cast,
16
- overload,
17
- )
18
- from uuid import UUID
19
-
20
- from sqlglot import parse_one
21
- from sqlglot.dialects.dialect import DialectType
22
-
23
- from sqlspec.exceptions import SQLConversionError, SQLParsingError, SQLSpecError
24
- from sqlspec.typing import (
25
- ConnectionT,
26
- ModelDTOT,
27
- ModelT,
28
- StatementParameterType,
29
- convert,
30
- get_type_adapter,
31
- is_dataclass,
32
- is_msgspec_struct,
33
- is_pydantic_model,
34
- )
35
-
36
- if TYPE_CHECKING:
37
- from sqlspec.filters import StatementFilter
38
- from sqlspec.typing import ArrowTable
39
-
40
# Public names exported by ``from sqlspec.mixins import *``. ``ResultConverter`` is a public
# class defined in this module, so it is listed alongside the driver mixins.
__all__ = (
    "AsyncArrowBulkOperationsMixin",
    "AsyncParquetExportMixin",
    "ResultConverter",
    "SQLTranslatorMixin",
    "SyncArrowBulkOperationsMixin",
    "SyncParquetExportMixin",
)
47
-
48
-
49
class SyncArrowBulkOperationsMixin(Generic[ConnectionT]):
    """Capability mixin for synchronous drivers that return query results as Apache Arrow tables.

    Note: this class does not use ``ABCMeta``, so ``@abstractmethod`` is only enforced if a
    concrete driver class mixes in an ABC.
    """

    # Capability marker; presumably inspected by callers to detect Arrow support (confirm).
    __supports_arrow__: "ClassVar[bool]" = True

    @abstractmethod
    def select_arrow(
        self,
        sql: str,
        parameters: "Optional[StatementParameterType]" = None,
        *filters: "StatementFilter",
        connection: "Optional[ConnectionT]" = None,
        **kwargs: Any,
    ) -> "ArrowTable":  # pyright: ignore[reportUnknownReturnType]
        """Run ``sql`` and materialize the whole result set as an Arrow table.

        Args:
            sql: The SQL query string.
            parameters: Parameters for the query.
            filters: Optional filters to apply to the query.
            connection: Optional connection override.
            **kwargs: Extra keyword arguments that implementations merge into ``parameters`` when it is a dict.

        Returns:
            An Apache Arrow Table with the query results.

        Raises:
            NotImplementedError: In this abstract definition; drivers override it.
        """
        raise NotImplementedError
76
-
77
-
78
class AsyncArrowBulkOperationsMixin(Generic[ConnectionT]):
    """Capability mixin for asynchronous drivers that return query results as Apache Arrow tables.

    Note: this class does not use ``ABCMeta``, so ``@abstractmethod`` is only enforced if a
    concrete driver class mixes in an ABC.
    """

    # Capability marker; presumably inspected by callers to detect Arrow support (confirm).
    __supports_arrow__: "ClassVar[bool]" = True

    @abstractmethod
    async def select_arrow(
        self,
        sql: str,
        parameters: "Optional[StatementParameterType]" = None,
        *filters: "StatementFilter",
        connection: "Optional[ConnectionT]" = None,
        **kwargs: Any,
    ) -> "ArrowTable":  # pyright: ignore[reportUnknownReturnType]
        """Run ``sql`` asynchronously and materialize the whole result set as an Arrow table.

        Args:
            sql: The SQL query string.
            parameters: Parameters for the query.
            filters: Optional filters to apply to the query.
            connection: Optional connection override.
            **kwargs: Extra keyword arguments that implementations merge into ``parameters`` when it is a dict.

        Returns:
            An Apache Arrow Table with the query results.

        Raises:
            NotImplementedError: In this abstract definition; drivers override it.
        """
        raise NotImplementedError
105
-
106
-
107
class SyncParquetExportMixin(Generic[ConnectionT]):
    """Capability mixin for synchronous drivers that can export query results to Parquet."""

    @abstractmethod
    def select_to_parquet(
        self,
        sql: str,
        parameters: "Optional[StatementParameterType]" = None,
        *filters: "StatementFilter",
        connection: "Optional[ConnectionT]" = None,
        **kwargs: Any,
    ) -> None:
        """Run ``sql`` and write its result set to a Parquet file.

        Args:
            sql: The SQL query string.
            parameters: Parameters for the query.
            filters: Optional filters to apply to the query.
            connection: Optional connection override.
            **kwargs: Driver-specific options (for example, the output destination; defined by implementations).

        Raises:
            NotImplementedError: In this abstract definition; drivers override it.
        """
        raise NotImplementedError
121
-
122
-
123
class AsyncParquetExportMixin(Generic[ConnectionT]):
    """Capability mixin for asynchronous drivers that can export query results to Parquet."""

    @abstractmethod
    async def select_to_parquet(
        self,
        sql: str,
        parameters: "Optional[StatementParameterType]" = None,
        *filters: "StatementFilter",
        connection: "Optional[ConnectionT]" = None,
        **kwargs: Any,
    ) -> None:
        """Run ``sql`` asynchronously and write its result set to a Parquet file.

        Args:
            sql: The SQL query string.
            parameters: Parameters for the query.
            filters: Optional filters to apply to the query.
            connection: Optional connection override.
            **kwargs: Driver-specific options (for example, the output destination; defined by implementations).

        Raises:
            NotImplementedError: In this abstract definition; drivers override it.
        """
        raise NotImplementedError
137
-
138
-
139
class SQLTranslatorMixin(Generic[ConnectionT]):
    """Mixin that lets a driver transpile SQL from its own dialect into another one via sqlglot."""

    # The driver's native SQL dialect, used to parse incoming SQL.
    dialect: str

    def convert_to_dialect(
        self,
        sql: str,
        to_dialect: DialectType = None,
        pretty: bool = True,
    ) -> str:
        """Parse ``sql`` in the driver's dialect and render it in ``to_dialect``.

        Args:
            sql: The SQL query string to convert.
            to_dialect: Target dialect; ``None`` re-renders in the driver's own dialect.
            pretty: Whether to pretty-print the output SQL.

        Returns:
            The SQL text rendered for the target dialect.

        Raises:
            SQLParsingError: If ``sql`` cannot be parsed.
            SQLConversionError: If the parsed statement cannot be rendered for the target dialect.
        """
        try:
            expression = parse_one(sql, dialect=self.dialect)
        except Exception as exc:
            error_msg = f"Failed to parse SQL: {exc!s}"
            raise SQLParsingError(error_msg) from exc

        target = self.dialect if to_dialect is None else to_dialect
        try:
            return expression.sql(dialect=target, pretty=pretty)
        except Exception as exc:
            error_msg = f"Failed to convert SQL to {target}: {exc!s}"
            raise SQLConversionError(error_msg) from exc
176
-
177
-
178
def _subclass_of(base: type) -> "Callable[[Any], bool]":
    """Return a predicate that is true for classes deriving from ``base`` and false for anything else."""
    return lambda candidate: isinstance(candidate, type) and issubclass(candidate, base)


def _decode_uuid(target_type: Any, value: Any) -> Any:
    """Build a UUID from its string form, or from an object exposing a ``hex`` string attribute."""
    return target_type(value if isinstance(value, str) else value.hex)


def _decode_isoformat(target_type: Any, value: Any) -> Any:
    """Build a date/time/datetime from an ISO 8601 string, or from an object exposing ``isoformat()``."""
    # The date/time constructors do not accept strings; ``fromisoformat`` is the parsing entry point.
    return target_type.fromisoformat(value if isinstance(value, str) else value.isoformat())


def _decode_enum(target_type: Any, value: Any) -> Any:
    """Look up an enum member by value; a member of another enum is matched by its ``.value``."""
    return target_type(value.value if isinstance(value, Enum) else value)


# (predicate, decoder) pairs passed to ``_default_msgspec_deserializer``, which uses the first
# pair whose predicate accepts the target type. Predicates test subclassing, because target
# types are concrete subclasses (for example, a user's Enum), never the base class itself.
# ``datetime`` precedes ``date`` because ``datetime`` is a subclass of ``date``.
_DEFAULT_TYPE_DECODERS = [  # pyright: ignore[reportUnknownVariableType]
    (_subclass_of(UUID), _decode_uuid),
    (_subclass_of(datetime.datetime), _decode_isoformat),
    (_subclass_of(datetime.date), _decode_isoformat),
    (_subclass_of(datetime.time), _decode_isoformat),
    (_subclass_of(Enum), _decode_enum),
]
185
-
186
-
187
def _default_msgspec_deserializer(
    target_type: Any,
    value: Any,
    type_decoders: "Union[Sequence[tuple[Callable[[Any], bool], Callable[[Any, Any], Any]]], None]" = None,
) -> Any:  # pragma: no cover
    """Coerce ``value`` into ``target_type`` for types that ``msgspec`` does not handle natively.

    Resolution order:
    1. Return ``value`` unchanged if it is already a ``target_type`` instance.
    2. Use the first decoder in ``type_decoders`` whose predicate accepts ``target_type``.
    3. Fall back to calling ``target_type(value)``.

    Args:
        target_type: The type being decoded into.
        value: The raw value to coerce.
        type_decoders: Optional ``(predicate, decoder)`` pairs; ``decoder(target_type, value)`` is called
            for the first pair whose ``predicate(target_type)`` is true.

    Raises:
        TypeError: If the fallback constructor call fails. Exceptions raised by a matching
            ``type_decoders`` entry propagate unchanged.

    Returns:
        The coerced value.
    """
    if isinstance(value, target_type):
        return value

    if type_decoders:
        for predicate, decoder in type_decoders:
            if predicate(target_type):
                return decoder(target_type, value)

    # Path/PurePath/UUID need no special case: they are built with ``target_type(value)`` here too.
    # Routing them through this try block keeps the documented TypeError contract. For example,
    # UUID("bad") would otherwise leak an unwrapped ValueError.
    try:
        return target_type(value)
    except Exception as e:
        msg = f"Unsupported type: {type(value)!r}"
        raise TypeError(msg) from e
222
-
223
-
224
class ResultConverter:
    """Mixin that converts row mappings (or model instances) into a requested schema type.

    A single item becomes one ``schema_type`` instance, and a sequence becomes a sequence
    of them. Dataclasses, msgspec structs and Pydantic models are supported as targets.
    With ``schema_type=None`` the data is returned unchanged.
    """

    @overload
    @staticmethod
    def to_schema(data: "ModelT", *, schema_type: None = None) -> "ModelT": ...
    @overload
    @staticmethod
    def to_schema(data: "dict[str, Any]", *, schema_type: "type[ModelDTOT]") -> "ModelDTOT": ...
    @overload
    @staticmethod
    def to_schema(data: "Sequence[ModelT]", *, schema_type: None = None) -> "Sequence[ModelT]": ...
    @overload
    @staticmethod
    def to_schema(data: "Sequence[dict[str, Any]]", *, schema_type: "type[ModelDTOT]") -> "Sequence[ModelDTOT]": ...

    @staticmethod
    def to_schema(
        data: "Union[ModelT, Sequence[ModelT], dict[str, Any], Sequence[dict[str, Any]], ModelDTOT, Sequence[ModelDTOT]]",
        *,
        schema_type: "Optional[type[ModelDTOT]]" = None,
    ) -> "Union[ModelT, Sequence[ModelT], ModelDTOT, Sequence[ModelDTOT]]":
        """Convert ``data`` into ``schema_type``.

        Args:
            data: A row mapping or model instance, or a sequence of them.
            schema_type: Target type; ``None`` returns ``data`` as is.

        Returns:
            ``data`` unchanged, a single ``schema_type`` instance, or a sequence of them.

        Raises:
            SQLSpecError: If ``schema_type`` is not a dataclass, msgspec struct or Pydantic model.
        """
        # Any Sequence (list, tuple, ...) is treated as "many rows"; a dict is not a Sequence.
        many = isinstance(data, Sequence)

        if schema_type is None:
            return cast("Sequence[ModelT]", data) if many else cast("ModelT", data)

        if is_dataclass(schema_type):
            # Rows are assumed to be dict[str, Any], as the overloads declare.
            if many:
                return cast("Sequence[ModelDTOT]", [schema_type(**row) for row in data])  # type: ignore[operator]
            return cast("ModelDTOT", schema_type(**data))  # type: ignore[operator]

        if is_msgspec_struct(schema_type):
            hook = partial(_default_msgspec_deserializer, type_decoders=_DEFAULT_TYPE_DECODERS)
            target = list[schema_type] if many else schema_type  # type: ignore[valid-type]
            converted = convert(obj=data, type=target, from_attributes=True, dec_hook=hook)
            return cast("Sequence[ModelDTOT]", converted) if many else cast("ModelDTOT", converted)

        if is_pydantic_model(schema_type):
            adapter_type = list[schema_type] if many else schema_type  # type: ignore[valid-type]
            validated = get_type_adapter(adapter_type).validate_python(data, from_attributes=True)  # pyright: ignore
            return cast("Sequence[ModelDTOT]", validated) if many else cast("ModelDTOT", validated)

        msg = "`schema_type` should be a valid Dataclass, Pydantic model or Msgspec struct"
        raise SQLSpecError(msg)