sqlspec 0.12.1__py3-none-any.whl → 0.13.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of sqlspec might be problematic. Click here for more details.
- sqlspec/_sql.py +21 -180
- sqlspec/adapters/adbc/config.py +10 -12
- sqlspec/adapters/adbc/driver.py +120 -118
- sqlspec/adapters/aiosqlite/config.py +3 -3
- sqlspec/adapters/aiosqlite/driver.py +116 -141
- sqlspec/adapters/asyncmy/config.py +3 -4
- sqlspec/adapters/asyncmy/driver.py +123 -135
- sqlspec/adapters/asyncpg/config.py +3 -7
- sqlspec/adapters/asyncpg/driver.py +98 -140
- sqlspec/adapters/bigquery/config.py +4 -5
- sqlspec/adapters/bigquery/driver.py +231 -181
- sqlspec/adapters/duckdb/config.py +3 -6
- sqlspec/adapters/duckdb/driver.py +132 -124
- sqlspec/adapters/oracledb/config.py +6 -5
- sqlspec/adapters/oracledb/driver.py +242 -259
- sqlspec/adapters/psqlpy/config.py +3 -7
- sqlspec/adapters/psqlpy/driver.py +118 -93
- sqlspec/adapters/psycopg/config.py +34 -30
- sqlspec/adapters/psycopg/driver.py +342 -214
- sqlspec/adapters/sqlite/config.py +3 -3
- sqlspec/adapters/sqlite/driver.py +150 -104
- sqlspec/config.py +0 -4
- sqlspec/driver/_async.py +89 -98
- sqlspec/driver/_common.py +52 -17
- sqlspec/driver/_sync.py +81 -105
- sqlspec/driver/connection.py +207 -0
- sqlspec/driver/mixins/_csv_writer.py +91 -0
- sqlspec/driver/mixins/_pipeline.py +38 -49
- sqlspec/driver/mixins/_result_utils.py +27 -9
- sqlspec/driver/mixins/_storage.py +149 -216
- sqlspec/driver/mixins/_type_coercion.py +3 -4
- sqlspec/driver/parameters.py +138 -0
- sqlspec/exceptions.py +10 -2
- sqlspec/extensions/aiosql/adapter.py +0 -10
- sqlspec/extensions/litestar/handlers.py +0 -1
- sqlspec/extensions/litestar/plugin.py +0 -3
- sqlspec/extensions/litestar/providers.py +0 -14
- sqlspec/loader.py +31 -118
- sqlspec/protocols.py +542 -0
- sqlspec/service/__init__.py +3 -2
- sqlspec/service/_util.py +147 -0
- sqlspec/service/base.py +1116 -9
- sqlspec/statement/builder/__init__.py +42 -32
- sqlspec/statement/builder/_ddl_utils.py +0 -10
- sqlspec/statement/builder/_parsing_utils.py +10 -4
- sqlspec/statement/builder/base.py +70 -23
- sqlspec/statement/builder/column.py +283 -0
- sqlspec/statement/builder/ddl.py +102 -65
- sqlspec/statement/builder/delete.py +23 -7
- sqlspec/statement/builder/insert.py +29 -15
- sqlspec/statement/builder/merge.py +4 -4
- sqlspec/statement/builder/mixins/_aggregate_functions.py +113 -14
- sqlspec/statement/builder/mixins/_common_table_expr.py +0 -1
- sqlspec/statement/builder/mixins/_delete_from.py +1 -1
- sqlspec/statement/builder/mixins/_from.py +10 -8
- sqlspec/statement/builder/mixins/_group_by.py +0 -1
- sqlspec/statement/builder/mixins/_insert_from_select.py +0 -1
- sqlspec/statement/builder/mixins/_insert_values.py +0 -2
- sqlspec/statement/builder/mixins/_join.py +20 -13
- sqlspec/statement/builder/mixins/_limit_offset.py +3 -3
- sqlspec/statement/builder/mixins/_merge_clauses.py +3 -4
- sqlspec/statement/builder/mixins/_order_by.py +2 -2
- sqlspec/statement/builder/mixins/_pivot.py +4 -7
- sqlspec/statement/builder/mixins/_select_columns.py +6 -5
- sqlspec/statement/builder/mixins/_unpivot.py +6 -9
- sqlspec/statement/builder/mixins/_update_from.py +2 -1
- sqlspec/statement/builder/mixins/_update_set.py +11 -8
- sqlspec/statement/builder/mixins/_where.py +61 -34
- sqlspec/statement/builder/select.py +32 -17
- sqlspec/statement/builder/update.py +25 -11
- sqlspec/statement/filters.py +39 -14
- sqlspec/statement/parameter_manager.py +220 -0
- sqlspec/statement/parameters.py +210 -79
- sqlspec/statement/pipelines/__init__.py +166 -23
- sqlspec/statement/pipelines/analyzers/_analyzer.py +22 -25
- sqlspec/statement/pipelines/context.py +35 -39
- sqlspec/statement/pipelines/transformers/__init__.py +2 -3
- sqlspec/statement/pipelines/transformers/_expression_simplifier.py +19 -187
- sqlspec/statement/pipelines/transformers/_literal_parameterizer.py +667 -43
- sqlspec/statement/pipelines/transformers/_remove_comments_and_hints.py +76 -0
- sqlspec/statement/pipelines/validators/_dml_safety.py +33 -18
- sqlspec/statement/pipelines/validators/_parameter_style.py +87 -14
- sqlspec/statement/pipelines/validators/_performance.py +38 -23
- sqlspec/statement/pipelines/validators/_security.py +39 -62
- sqlspec/statement/result.py +37 -129
- sqlspec/statement/splitter.py +0 -12
- sqlspec/statement/sql.py +885 -379
- sqlspec/statement/sql_compiler.py +140 -0
- sqlspec/storage/__init__.py +10 -2
- sqlspec/storage/backends/fsspec.py +82 -35
- sqlspec/storage/backends/obstore.py +66 -49
- sqlspec/storage/capabilities.py +101 -0
- sqlspec/storage/registry.py +56 -83
- sqlspec/typing.py +6 -434
- sqlspec/utils/cached_property.py +25 -0
- sqlspec/utils/correlation.py +0 -2
- sqlspec/utils/logging.py +0 -6
- sqlspec/utils/sync_tools.py +0 -4
- sqlspec/utils/text.py +0 -5
- sqlspec/utils/type_guards.py +892 -0
- {sqlspec-0.12.1.dist-info → sqlspec-0.13.0.dist-info}/METADATA +1 -1
- sqlspec-0.13.0.dist-info/RECORD +150 -0
- sqlspec/statement/builder/protocols.py +0 -20
- sqlspec/statement/pipelines/base.py +0 -315
- sqlspec/statement/pipelines/result_types.py +0 -41
- sqlspec/statement/pipelines/transformers/_remove_comments.py +0 -66
- sqlspec/statement/pipelines/transformers/_remove_hints.py +0 -81
- sqlspec/statement/pipelines/validators/base.py +0 -67
- sqlspec/storage/protocol.py +0 -170
- sqlspec-0.12.1.dist-info/RECORD +0 -145
- {sqlspec-0.12.1.dist-info → sqlspec-0.13.0.dist-info}/WHEEL +0 -0
- {sqlspec-0.12.1.dist-info → sqlspec-0.13.0.dist-info}/licenses/LICENSE +0 -0
- {sqlspec-0.12.1.dist-info → sqlspec-0.13.0.dist-info}/licenses/NOTICE +0 -0
sqlspec/service/base.py
CHANGED
|
@@ -1,24 +1,1131 @@
|
|
|
1
|
-
|
|
1
|
+
# mypy: disable-error-code="arg-type,misc,type-var"
|
|
2
|
+
# pyright: reportCallIssue=false, reportArgumentType=false
|
|
3
|
+
from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar, Union, overload
|
|
2
4
|
|
|
3
|
-
from
|
|
5
|
+
from sqlglot import exp, parse_one
|
|
4
6
|
|
|
5
|
-
|
|
7
|
+
from sqlspec.typing import ConnectionT
|
|
8
|
+
from sqlspec.utils.type_guards import (
|
|
9
|
+
is_dict_row,
|
|
10
|
+
is_indexable_row,
|
|
11
|
+
is_limit_offset_filter,
|
|
12
|
+
is_select_builder,
|
|
13
|
+
is_statement_filter,
|
|
14
|
+
)
|
|
15
|
+
|
|
16
|
+
if TYPE_CHECKING:
|
|
17
|
+
from sqlspec.driver import AsyncDriverAdapterProtocol, SyncDriverAdapterProtocol
|
|
18
|
+
from sqlspec.service.pagination import OffsetPagination
|
|
19
|
+
from sqlspec.statement import SQLConfig, Statement, StatementFilter
|
|
20
|
+
from sqlspec.statement.builder import Delete, Insert, QueryBuilder, Select, Update
|
|
21
|
+
from sqlspec.statement.sql import SQL
|
|
22
|
+
from sqlspec.typing import ModelDTOT, RowT, StatementParameters
|
|
23
|
+
|
|
24
|
+
# Public API of this module.
__all__ = (
    "SQLSpecAsyncService",
    "SQLSpecSyncService",
)

# Generic type variable.
T = TypeVar("T")

# Driver type variables. The bounds are string forward references because the
# driver protocol classes are only imported under TYPE_CHECKING.
SyncDriverT = TypeVar("SyncDriverT", bound="SyncDriverAdapterProtocol[Any]")
AsyncDriverT = TypeVar("AsyncDriverT", bound="AsyncDriverAdapterProtocol[Any]")
|
|
9
30
|
|
|
10
31
|
|
|
11
|
-
class
|
|
12
|
-
"""
|
|
32
|
+
class SQLSpecSyncService(Generic[SyncDriverT, ConnectionT]):
    """Sync service for database operations.

    A thin convenience layer over a synchronous driver adapter. Every query
    method forwards to ``self.driver`` and unwraps the driver result via
    ``get_data()``. The ``select_*`` helpers additionally validate the number
    of rows returned.
    """

    def __init__(self, driver: "SyncDriverT", connection: "ConnectionT") -> None:
        """Initialize the service.

        Args:
            driver: Synchronous driver adapter that executes statements.
            connection: Connection stored on the service and exposed via
                ``connection``. The query methods do not pass it to the driver
                automatically; they forward only an explicit ``_connection``.
        """
        self._driver = driver
        self._connection = connection

    @classmethod
    def new(cls, driver: "SyncDriverT", connection: "ConnectionT") -> "SQLSpecSyncService[SyncDriverT, ConnectionT]":
        """Alternate constructor, equivalent to ``cls(driver=driver, connection=connection)``."""
        return cls(driver=driver, connection=connection)

    @property
    def driver(self) -> "SyncDriverT":
        """Get the driver instance."""
        return self._driver

    @property
    def connection(self) -> "ConnectionT":
        """Get the connection instance."""
        return self._connection

    def _normalize_statement(
        self,
        statement: "Union[Statement, Select]",
        params: "Optional[Union[dict[str, Any], list[Any]]]" = None,
        config: "Optional[SQLConfig]" = None,
    ) -> "SQL":
        """Normalize a statement of any supported type into a SQL object.

        Args:
            statement: The statement to normalize (str, Expression, SQL, or Select).
            params: Optional parameters. They are ignored for Select and SQL
                objects. ``paginate`` passes a list of positional values here.
            config: Optional SQL configuration.

        Returns:
            A normalized SQL object.

        Raises:
            TypeError: If ``statement`` is not one of the supported types.
        """
        from sqlspec.statement.sql import SQL

        if is_select_builder(statement):
            # Select has its own parameters via build(), ignore external params
            safe_query = statement.build()
            return SQL(safe_query.sql, parameters=safe_query.parameters, config=config)

        if isinstance(statement, SQL):
            # SQL object is already complete, ignore external params
            return statement

        if isinstance(statement, (str, exp.Expression)):
            return SQL(statement, parameters=params, config=config)

        # Fallback for type safety
        msg = f"Unsupported statement type: {type(statement).__name__}"
        raise TypeError(msg)

    @overload
    def execute(
        self,
        statement: "Select",
        /,
        *parameters: "Union[StatementParameters, StatementFilter]",
        schema_type: "type[ModelDTOT]",
        _connection: "Optional[ConnectionT]" = None,
        _config: "Optional[SQLConfig]" = None,
        **kwargs: Any,
    ) -> "list[ModelDTOT]": ...

    @overload
    def execute(
        self,
        statement: "Select",
        /,
        *parameters: "Union[StatementParameters, StatementFilter]",
        schema_type: None = None,
        _connection: "Optional[ConnectionT]" = None,
        _config: "Optional[SQLConfig]" = None,
        **kwargs: Any,
    ) -> "list[RowT]": ...

    @overload
    def execute(
        self,
        statement: "Union[Insert, Update, Delete]",
        /,
        *parameters: "Union[StatementParameters, StatementFilter]",
        _connection: "Optional[ConnectionT]" = None,
        _config: "Optional[SQLConfig]" = None,
        **kwargs: Any,
    ) -> "list[RowT]": ...

    @overload
    def execute(
        self,
        statement: "Union[str, SQL]",  # exp.Expression
        /,
        *parameters: "Union[StatementParameters, StatementFilter]",
        schema_type: "type[ModelDTOT]",
        _connection: "Optional[ConnectionT]" = None,
        _config: "Optional[SQLConfig]" = None,
        **kwargs: Any,
    ) -> "list[ModelDTOT]": ...

    @overload
    def execute(
        self,
        statement: "Union[str, SQL]",
        /,
        *parameters: "Union[StatementParameters, StatementFilter]",
        schema_type: None = None,
        _connection: "Optional[ConnectionT]" = None,
        _config: "Optional[SQLConfig]" = None,
        **kwargs: Any,
    ) -> "list[RowT]": ...

    def execute(
        self,
        statement: "Union[Statement, QueryBuilder[Any]]",
        /,
        *parameters: "Union[StatementParameters, StatementFilter]",
        schema_type: "Optional[type[ModelDTOT]]" = None,
        _connection: "Optional[ConnectionT]" = None,
        _config: "Optional[SQLConfig]" = None,
        **kwargs: Any,
    ) -> Any:
        """Execute a statement and return the unwrapped result data.

        Args:
            statement: SQL string, expression, SQL object, or query builder.
            *parameters: Positional statement parameters and/or statement filters.
            schema_type: Optional model type passed to the driver for row conversion.
            _connection: Optional connection forwarded to the driver.
            _config: Optional SQL configuration forwarded to the driver.
            **kwargs: Additional arguments forwarded to the driver.

        Returns:
            The value of ``get_data()`` on the driver result.
        """
        result = self.driver.execute(
            statement, *parameters, schema_type=schema_type, _connection=_connection, _config=_config, **kwargs
        )
        return result.get_data()

    def execute_many(
        self,
        statement: "Union[Statement, QueryBuilder[Any]]",
        /,
        *parameters: "Union[StatementParameters, StatementFilter]",
        _connection: "Optional[ConnectionT]" = None,
        _config: "Optional[SQLConfig]" = None,
        **kwargs: Any,
    ) -> Any:
        """Execute a statement multiple times and return the unwrapped result data.

        All arguments are forwarded unchanged to ``driver.execute_many``.
        """
        result = self.driver.execute_many(statement, *parameters, _connection=_connection, _config=_config, **kwargs)
        return result.get_data()

    def execute_script(
        self,
        statement: "Statement",
        /,
        *parameters: "Union[StatementParameters, StatementFilter]",
        _connection: "Optional[ConnectionT]" = None,
        _config: "Optional[SQLConfig]" = None,
        **kwargs: Any,
    ) -> Any:
        """Execute a script statement and return the unwrapped result data.

        All arguments are forwarded unchanged to ``driver.execute_script``.
        """
        result = self.driver.execute_script(statement, *parameters, _connection=_connection, _config=_config, **kwargs)
        return result.get_data()

    @overload
    def select_one(
        self,
        statement: "Union[Statement, Select]",
        /,
        *parameters: "Union[StatementParameters, StatementFilter]",
        schema_type: "type[ModelDTOT]",
        _connection: "Optional[ConnectionT]" = None,
        _config: "Optional[SQLConfig]" = None,
        **kwargs: Any,
    ) -> "ModelDTOT": ...

    @overload
    def select_one(
        self,
        statement: "Union[Statement, Select]",
        /,
        *parameters: "Union[StatementParameters, StatementFilter]",
        schema_type: None = None,
        _connection: "Optional[ConnectionT]" = None,
        _config: "Optional[SQLConfig]" = None,
        **kwargs: Any,
    ) -> "RowT": ...

    def select_one(
        self,
        statement: "Union[Statement, Select]",
        /,
        *parameters: "Union[StatementParameters, StatementFilter]",
        schema_type: "Optional[type[ModelDTOT]]" = None,
        _connection: "Optional[ConnectionT]" = None,
        _config: "Optional[SQLConfig]" = None,
        **kwargs: Any,
    ) -> Any:
        """Execute a select statement and return exactly one row.

        Raises:
            TypeError: If the driver result data is not a list.
            ValueError: If no rows or more than one row is returned.
        """
        result = self.driver.execute(
            statement, *parameters, schema_type=schema_type, _connection=_connection, _config=_config, **kwargs
        )
        data = result.get_data()
        # For select operations, data should be a list
        if not isinstance(data, list):
            msg = "Expected list result from select operation"
            raise TypeError(msg)
        if not data:
            msg = "No rows found"
            raise ValueError(msg)
        if len(data) > 1:
            msg = f"Expected exactly one row, found {len(data)}"
            raise ValueError(msg)
        return data[0]

    @overload
    def select_one_or_none(
        self,
        statement: "Union[Statement, Select]",
        /,
        *parameters: "Union[StatementParameters, StatementFilter]",
        schema_type: "type[ModelDTOT]",
        _connection: "Optional[ConnectionT]" = None,
        _config: "Optional[SQLConfig]" = None,
        **kwargs: Any,
    ) -> "Optional[ModelDTOT]": ...

    @overload
    def select_one_or_none(
        self,
        statement: "Union[Statement, Select]",
        /,
        *parameters: "Union[StatementParameters, StatementFilter]",
        schema_type: None = None,
        _connection: "Optional[ConnectionT]" = None,
        _config: "Optional[SQLConfig]" = None,
        **kwargs: Any,
    ) -> "Optional[RowT]": ...

    def select_one_or_none(
        self,
        statement: "Union[Statement, Select]",
        /,
        *parameters: "Union[StatementParameters, StatementFilter]",
        schema_type: "Optional[type[ModelDTOT]]" = None,
        _connection: "Optional[ConnectionT]" = None,
        _config: "Optional[SQLConfig]" = None,
        **kwargs: Any,
    ) -> Any:
        """Execute a select statement and return at most one row.

        Returns None if no rows are found.

        Raises:
            TypeError: If the driver result data is not a list.
            ValueError: If more than one row is returned.
        """
        result = self.driver.execute(
            statement, *parameters, schema_type=schema_type, _connection=_connection, _config=_config, **kwargs
        )
        data = result.get_data()
        # For select operations, data should be a list
        if not isinstance(data, list):
            msg = "Expected list result from select operation"
            raise TypeError(msg)
        if not data:
            return None
        if len(data) > 1:
            msg = f"Expected at most one row, found {len(data)}"
            raise ValueError(msg)
        return data[0]

    @overload
    def select(
        self,
        statement: "Union[Statement, Select]",
        /,
        *parameters: "Union[StatementParameters, StatementFilter]",
        schema_type: "type[ModelDTOT]",
        _connection: "Optional[ConnectionT]" = None,
        _config: "Optional[SQLConfig]" = None,
        **kwargs: Any,
    ) -> "list[ModelDTOT]": ...

    @overload
    def select(
        self,
        statement: "Union[Statement, Select]",
        /,
        *parameters: "Union[StatementParameters, StatementFilter]",
        schema_type: None = None,
        _connection: "Optional[ConnectionT]" = None,
        _config: "Optional[SQLConfig]" = None,
        **kwargs: Any,
    ) -> "list[RowT]": ...

    def select(
        self,
        statement: "Union[Statement, Select]",
        /,
        *parameters: "Union[StatementParameters, StatementFilter]",
        schema_type: "Optional[type[ModelDTOT]]" = None,
        _connection: "Optional[ConnectionT]" = None,
        _config: "Optional[SQLConfig]" = None,
        **kwargs: Any,
    ) -> Any:
        """Execute a select statement and return all rows.

        Raises:
            TypeError: If the driver result data is not a list.
        """
        result = self.driver.execute(
            statement, *parameters, schema_type=schema_type, _connection=_connection, _config=_config, **kwargs
        )
        data = result.get_data()
        # For select operations, data should be a list
        if not isinstance(data, list):
            msg = "Expected list result from select operation"
            raise TypeError(msg)
        return data

    def select_value(
        self,
        statement: "Union[Statement, Select]",
        /,
        *parameters: "Union[StatementParameters, StatementFilter]",
        _connection: "Optional[ConnectionT]" = None,
        _config: "Optional[SQLConfig]" = None,
        **kwargs: Any,
    ) -> Any:
        """Execute a select statement and return a single scalar value.

        Expects exactly one row and returns its first column. Additional
        columns are not checked.

        Raises:
            TypeError: If the driver result data is not a list.
            ValueError: If no rows or more than one row is returned, if the row
                has no columns, or if the row type is neither dict-like nor
                indexable.
        """
        result = self.driver.execute(statement, *parameters, _connection=_connection, _config=_config, **kwargs)
        data = result.get_data()
        # For select operations, data should be a list
        if not isinstance(data, list):
            msg = "Expected list result from select operation"
            raise TypeError(msg)
        if not data:
            msg = "No rows found"
            raise ValueError(msg)
        if len(data) > 1:
            msg = f"Expected exactly one row, found {len(data)}"
            raise ValueError(msg)
        row = data[0]
        if is_dict_row(row):
            if not row:
                msg = "Row has no columns"
                raise ValueError(msg)
            return next(iter(row.values()))
        if is_indexable_row(row):
            # Tuple or list-like row
            if not row:
                msg = "Row has no columns"
                raise ValueError(msg)
            return row[0]
        msg = f"Unexpected row type: {type(row)}"
        raise ValueError(msg)

    def select_value_or_none(
        self,
        statement: "Union[Statement, Select]",
        /,
        *parameters: "Union[StatementParameters, StatementFilter]",
        _connection: "Optional[ConnectionT]" = None,
        _config: "Optional[SQLConfig]" = None,
        **kwargs: Any,
    ) -> Any:
        """Execute a select statement and return a single scalar value or None.

        Returns None if no rows are found or the single row has no columns.
        Otherwise returns the first column of the single row. Additional
        columns are not checked.

        Raises:
            TypeError: If the driver result data is not a list, or the row
                cannot be indexed.
            ValueError: If more than one row is returned.
        """
        result = self.driver.execute(statement, *parameters, _connection=_connection, _config=_config, **kwargs)
        data = result.get_data()
        # For select operations, data should be a list
        if not isinstance(data, list):
            msg = "Expected list result from select operation"
            raise TypeError(msg)
        if not data:
            return None
        if len(data) > 1:
            msg = f"Expected at most one row, found {len(data)}"
            raise ValueError(msg)
        row = data[0]
        if isinstance(row, dict):
            # First value in column order; an empty dict row yields None.
            return next(iter(row.values()), None)
        if isinstance(row, (tuple, list)):
            # An empty sequence row is treated like an empty dict row (None).
            # Previously it escaped as a bare IndexError.
            return row[0] if row else None
        try:
            return row[0]
        except (TypeError, IndexError) as e:
            msg = f"Cannot extract value from row type {type(row).__name__}: {e}"
            raise TypeError(msg) from e

    @overload
    def paginate(
        self,
        statement: "Union[Statement, Select]",
        /,
        *parameters: "Union[StatementParameters, StatementFilter]",
        schema_type: "type[ModelDTOT]",
        _connection: "Optional[ConnectionT]" = None,
        _config: "Optional[SQLConfig]" = None,
        **kwargs: Any,
    ) -> "OffsetPagination[ModelDTOT]": ...

    @overload
    def paginate(
        self,
        statement: "Union[Statement, Select]",
        /,
        *parameters: "Union[StatementParameters, StatementFilter]",
        schema_type: None = None,
        _connection: "Optional[ConnectionT]" = None,
        _config: "Optional[SQLConfig]" = None,
        **kwargs: Any,
    ) -> "OffsetPagination[RowT]": ...

    def paginate(
        self,
        statement: "Union[Statement, Select]",
        /,
        *parameters: "Union[StatementParameters, StatementFilter]",
        schema_type: "Optional[type[ModelDTOT]]" = None,
        _connection: "Optional[ConnectionT]" = None,
        _config: "Optional[SQLConfig]" = None,
        **kwargs: Any,
    ) -> Any:
        """Execute a paginated query with automatic counting.

        This method performs two queries:
        1. A count query to get the total number of results
        2. A data query with limit/offset applied

        Pagination can be specified either via LimitOffsetFilter in parameters
        or via 'limit' and 'offset' in kwargs. If several LimitOffsetFilters are
        passed, the last one wins.

        Args:
            statement: The SELECT statement to paginate
            *parameters: Statement parameters and filters (can include LimitOffsetFilter)
            schema_type: Optional model type for automatic schema conversion
            _connection: Optional connection to use
            _config: Optional SQL configuration
            **kwargs: Additional driver-specific arguments. Can include 'limit' and 'offset'
                if LimitOffsetFilter is not provided

        Returns:
            OffsetPagination object containing items, limit, offset, and total count

        Raises:
            ValueError: If neither LimitOffsetFilter nor limit/offset kwargs are provided

        Example:
            >>> # Using LimitOffsetFilter (recommended)
            >>> from sqlspec.statement.filters import LimitOffsetFilter
            >>> result = service.paginate(
            ...     sql.select("*").from_("users"),
            ...     LimitOffsetFilter(limit=10, offset=20),
            ... )
            >>> print(
            ...     f"Showing {len(result.items)} of {result.total} users"
            ... )

            >>> # Using kwargs (convenience)
            >>> result = service.paginate(
            ...     sql.select("*").from_("users"), limit=10, offset=20
            ... )

            >>> # With schema conversion
            >>> result = service.paginate(
            ...     sql.select("*").from_("users"),
            ...     LimitOffsetFilter(limit=10, offset=0),
            ...     schema_type=User,
            ... )
            >>> # result.items is list[User] with proper type inference

            >>> # With multiple filters
            >>> from sqlspec.statement.filters import (
            ...     LimitOffsetFilter,
            ...     OrderByFilter,
            ... )
            >>> result = service.paginate(
            ...     sql.select("*").from_("users"),
            ...     OrderByFilter("created_at", "desc"),
            ...     LimitOffsetFilter(limit=20, offset=40),
            ...     schema_type=User,
            ... )
        """
        from sqlspec.service.pagination import OffsetPagination
        from sqlspec.statement.sql import SQL

        # Separate filters from plain positional parameters.
        filters: list[StatementFilter] = []
        params: list[Any] = []
        for p in parameters:
            # Use type guard to check if it implements the StatementFilter protocol
            if is_statement_filter(p):
                filters.append(p)
            else:
                params.append(p)

        # Pull the LimitOffsetFilter out; every other filter is applied to both queries.
        limit_offset_filter = None
        other_filters: list[StatementFilter] = []
        for f in filters:
            if is_limit_offset_filter(f):
                limit_offset_filter = f
            else:
                other_filters.append(f)

        if limit_offset_filter is not None:
            limit = limit_offset_filter.limit
            offset = limit_offset_filter.offset
        elif "limit" in kwargs and "offset" in kwargs:
            limit = kwargs.pop("limit")
            offset = kwargs.pop("offset")
        else:
            msg = "Pagination requires either a LimitOffsetFilter in parameters or 'limit' and 'offset' in kwargs."
            raise ValueError(msg)

        # Normalize once and apply the non-pagination filters once. The same
        # filtered statement feeds both the count query and the data query.
        # This avoids building a Select twice and re-applying filters to a
        # caller-supplied SQL object.
        filtered_stmt = self._normalize_statement(statement, params, _config)
        for filter_obj in other_filters:
            filtered_stmt = filter_obj.append_to_statement(filtered_stmt)

        # Wrap the filtered query: SELECT COUNT(*) AS total FROM (<query>) AS _count_subquery
        # NOTE(review): the count statement is rebuilt from rendered SQL text and
        # does not forward filtered_stmt's bound parameters. parse_one() also uses
        # the default dialect. Confirm that parameterized queries count correctly.
        parsed = parse_one(filtered_stmt.to_sql())
        subquery = exp.Subquery(this=parsed, alias="_count_subquery")
        count_ast = exp.Select().select(exp.func("COUNT", exp.Star()).as_("total")).from_(subquery)

        # NOTE(review): _normalize_statement builds SQL with ``config=`` while this
        # call uses ``_config=``. Confirm which keyword the SQL constructor accepts.
        count_stmt = SQL(count_ast.sql(), _config=_config)

        # Execute count query
        total = self.select_value(count_stmt, _connection=_connection, _config=_config, **kwargs)

        data_stmt = filtered_stmt.limit(limit).offset(offset)

        # Execute data query
        items = self.select(data_stmt, schema_type=schema_type, _connection=_connection, _config=_config, **kwargs)

        return OffsetPagination(items=items, limit=limit, offset=offset, total=total)
|
|
580
|
+
|
|
581
|
+
|
|
582
|
+
class SQLSpecAsyncService(Generic[AsyncDriverT, ConnectionT]):
|
|
583
|
+
"""Async Service for database operations."""
|
|
584
|
+
|
|
585
|
+
def __init__(self, driver: "AsyncDriverT", connection: "ConnectionT") -> None:
|
|
586
|
+
self._driver = driver
|
|
587
|
+
self._connection = connection
|
|
588
|
+
|
|
589
|
+
@classmethod
|
|
590
|
+
def new(cls, driver: "AsyncDriverT", connection: "ConnectionT") -> "SQLSpecAsyncService[AsyncDriverT, ConnectionT]":
|
|
591
|
+
return cls(driver=driver, connection=connection)
|
|
592
|
+
|
|
593
|
+
@property
|
|
594
|
+
def driver(self) -> "AsyncDriverT":
|
|
595
|
+
"""Get the driver instance."""
|
|
596
|
+
return self._driver
|
|
597
|
+
|
|
598
|
+
@property
|
|
599
|
+
def connection(self) -> "ConnectionT":
|
|
600
|
+
"""Get the connection instance."""
|
|
601
|
+
return self._connection
|
|
602
|
+
|
|
603
|
+
def _normalize_statement(
|
|
604
|
+
self,
|
|
605
|
+
statement: "Union[Statement, Select]",
|
|
606
|
+
params: "Optional[dict[str, Any]]" = None,
|
|
607
|
+
config: "Optional[SQLConfig]" = None,
|
|
608
|
+
) -> "SQL":
|
|
609
|
+
"""Normalize a statement of any supported type into a SQL object.
|
|
610
|
+
|
|
611
|
+
Args:
|
|
612
|
+
statement: The statement to normalize (str, Expression, SQL, or Select)
|
|
613
|
+
params: Optional parameters (ignored for Select and SQL objects)
|
|
614
|
+
config: Optional SQL configuration
|
|
615
|
+
|
|
616
|
+
Returns:
|
|
617
|
+
A normalized SQL object
|
|
618
|
+
"""
|
|
619
|
+
from sqlspec.statement.sql import SQL
|
|
620
|
+
|
|
621
|
+
if is_select_builder(statement):
|
|
622
|
+
# Select has its own parameters via build(), ignore external params
|
|
623
|
+
safe_query = statement.build()
|
|
624
|
+
return SQL(safe_query.sql, parameters=safe_query.parameters, config=config)
|
|
625
|
+
|
|
626
|
+
if isinstance(statement, SQL):
|
|
627
|
+
# SQL object is already complete, ignore external params
|
|
628
|
+
return statement
|
|
629
|
+
|
|
630
|
+
if isinstance(statement, (str, exp.Expression)):
|
|
631
|
+
return SQL(statement, parameters=params, config=config)
|
|
632
|
+
|
|
633
|
+
# Fallback for type safety
|
|
634
|
+
msg = f"Unsupported statement type: {type(statement).__name__}"
|
|
635
|
+
raise TypeError(msg)
|
|
636
|
+
|
|
637
|
+
@overload
async def execute(
    self,
    statement: "Select",
    /,
    *parameters: "Union[StatementParameters, StatementFilter]",
    schema_type: "type[ModelDTOT]",
    _connection: "Optional[ConnectionT]" = None,
    _config: "Optional[SQLConfig]" = None,
    **kwargs: Any,
) -> "list[ModelDTOT]": ...


@overload
async def execute(
    self,
    statement: "Select",
    /,
    *parameters: "Union[StatementParameters, StatementFilter]",
    schema_type: None = None,
    _connection: "Optional[ConnectionT]" = None,
    _config: "Optional[SQLConfig]" = None,
    **kwargs: Any,
) -> "list[RowT]": ...


@overload
async def execute(
    self,
    statement: "Union[Insert, Update, Delete]",
    /,
    *parameters: "Union[StatementParameters, StatementFilter]",
    _connection: "Optional[ConnectionT]" = None,
    _config: "Optional[SQLConfig]" = None,
    **kwargs: Any,
) -> "list[RowT]": ...


@overload
async def execute(
    self,
    # NOTE(review): exp.Expression is handled by the implementation but is not listed here.
    statement: "Union[str, SQL]",
    /,
    *parameters: "Union[StatementParameters, StatementFilter]",
    schema_type: "type[ModelDTOT]",
    _connection: "Optional[ConnectionT]" = None,
    _config: "Optional[SQLConfig]" = None,
    **kwargs: Any,
) -> "list[ModelDTOT]": ...


@overload
async def execute(
    self,
    statement: "Union[str, SQL]",
    /,
    *parameters: "Union[StatementParameters, StatementFilter]",
    schema_type: None = None,
    _connection: "Optional[ConnectionT]" = None,
    _config: "Optional[SQLConfig]" = None,
    **kwargs: Any,
) -> "list[RowT]": ...


async def execute(
    self,
    statement: "Union[Statement, QueryBuilder[Any]]",
    /,
    *parameters: "Union[StatementParameters, StatementFilter]",
    schema_type: "Optional[type[ModelDTOT]]" = None,
    _connection: "Optional[ConnectionT]" = None,
    _config: "Optional[SQLConfig]" = None,
    **kwargs: Any,
) -> Any:
    """Delegate ``statement`` to the driver and return the unwrapped result data.

    Args:
        statement: Raw statement or query builder to run.
        *parameters: Parameters and/or statement filters forwarded to the driver.
        schema_type: Optional model type forwarded to the driver for row conversion.
        _connection: Optional connection override forwarded to the driver.
        _config: Optional SQL configuration forwarded to the driver.
        **kwargs: Extra keyword arguments forwarded to the driver unchanged.

    Returns:
        Whatever ``get_data()`` yields on the driver's result object.
    """
    driver_result = await self.driver.execute(
        statement,
        *parameters,
        schema_type=schema_type,
        _connection=_connection,
        _config=_config,
        **kwargs,
    )
    payload = driver_result.get_data()
    return payload
|
|
711
|
+
|
|
712
|
+
async def execute_many(
    self,
    statement: "Union[Statement, QueryBuilder[Any]]",
    /,
    *parameters: "Union[StatementParameters, StatementFilter]",
    _connection: "Optional[ConnectionT]" = None,
    _config: "Optional[SQLConfig]" = None,
    **kwargs: Any,
) -> Any:
    """Run ``statement`` through the driver's ``execute_many`` and unwrap the result.

    Args:
        statement: Raw statement or query builder to run for each parameter set.
        *parameters: Parameter sets and/or statement filters forwarded to the driver.
        _connection: Optional connection override forwarded to the driver.
        _config: Optional SQL configuration forwarded to the driver.
        **kwargs: Extra keyword arguments forwarded to the driver unchanged.

    Returns:
        Whatever ``get_data()`` yields on the driver's result object.
    """
    driver_kwargs = {"_connection": _connection, "_config": _config, **kwargs}
    batch_result = await self.driver.execute_many(statement, *parameters, **driver_kwargs)
    return batch_result.get_data()
|
|
726
|
+
|
|
727
|
+
async def execute_script(
    self,
    statement: "Statement",
    /,
    *parameters: "Union[StatementParameters, StatementFilter]",
    _connection: "Optional[ConnectionT]" = None,
    _config: "Optional[SQLConfig]" = None,
    **kwargs: Any,
) -> Any:
    """Run a script through the driver's ``execute_script`` and unwrap the result.

    Args:
        statement: Script statement to run.
        *parameters: Parameters and/or statement filters forwarded to the driver.
        _connection: Optional connection override forwarded to the driver.
        _config: Optional SQL configuration forwarded to the driver.
        **kwargs: Extra keyword arguments forwarded to the driver unchanged.

    Returns:
        Whatever ``get_data()`` yields on the driver's result object.
    """
    driver_kwargs = {"_connection": _connection, "_config": _config, **kwargs}
    script_result = await self.driver.execute_script(statement, *parameters, **driver_kwargs)
    return script_result.get_data()
|
|
741
|
+
|
|
742
|
+
@overload
async def select_one(
    self,
    statement: "Union[Statement, Select]",
    /,
    *parameters: "Union[StatementParameters, StatementFilter]",
    schema_type: "type[ModelDTOT]",
    _connection: "Optional[ConnectionT]" = None,
    _config: "Optional[SQLConfig]" = None,
    **kwargs: Any,
) -> "ModelDTOT": ...


@overload
async def select_one(
    self,
    statement: "Union[Statement, Select]",
    /,
    *parameters: "Union[StatementParameters, StatementFilter]",
    schema_type: None = None,
    _connection: "Optional[ConnectionT]" = None,
    _config: "Optional[SQLConfig]" = None,
    **kwargs: Any,
) -> "RowT": ...


async def select_one(
    self,
    statement: "Union[Statement, Select]",
    /,
    *parameters: "Union[StatementParameters, StatementFilter]",
    schema_type: "Optional[type[ModelDTOT]]" = None,
    _connection: "Optional[ConnectionT]" = None,
    _config: "Optional[SQLConfig]" = None,
    **kwargs: Any,
) -> Any:
    """Run a SELECT and return its single row.

    Raises:
        TypeError: If the driver result data is not a list.
        ValueError: If the query yields zero rows or more than one row.
    """
    outcome = await self.driver.execute(
        statement,
        *parameters,
        schema_type=schema_type,
        _connection=_connection,
        _config=_config,
        **kwargs,
    )
    rows = outcome.get_data()
    # SELECT results come back as a list; anything else means the statement was not a query.
    if not isinstance(rows, list):
        msg = "Expected list result from select operation"
        raise TypeError(msg)
    row_count = len(rows)
    if row_count == 1:
        return rows[0]
    msg = "No rows found" if row_count == 0 else f"Expected exactly one row, found {row_count}"
    raise ValueError(msg)
|
|
795
|
+
|
|
796
|
+
@overload
async def select_one_or_none(
    self,
    statement: "Union[Statement, Select]",
    /,
    *parameters: "Union[StatementParameters, StatementFilter]",
    schema_type: "type[ModelDTOT]",
    _connection: "Optional[ConnectionT]" = None,
    _config: "Optional[SQLConfig]" = None,
    **kwargs: Any,
) -> "Optional[ModelDTOT]": ...


@overload
async def select_one_or_none(
    self,
    statement: "Union[Statement, Select]",
    /,
    *parameters: "Union[StatementParameters, StatementFilter]",
    schema_type: None = None,
    _connection: "Optional[ConnectionT]" = None,
    _config: "Optional[SQLConfig]" = None,
    **kwargs: Any,
) -> "Optional[RowT]": ...


async def select_one_or_none(
    self,
    statement: "Union[Statement, Select]",
    /,
    *parameters: "Union[StatementParameters, StatementFilter]",
    schema_type: "Optional[type[ModelDTOT]]" = None,
    _connection: "Optional[ConnectionT]" = None,
    _config: "Optional[SQLConfig]" = None,
    **kwargs: Any,
) -> Any:
    """Run a SELECT and return its only row, or ``None`` when it yields no rows.

    Raises:
        TypeError: If the driver result data is not a list.
        ValueError: If the query yields more than one row.
    """
    outcome = await self.driver.execute(
        statement,
        *parameters,
        schema_type=schema_type,
        _connection=_connection,
        _config=_config,
        **kwargs,
    )
    rows = outcome.get_data()
    # SELECT results come back as a list; anything else means the statement was not a query.
    if not isinstance(rows, list):
        msg = "Expected list result from select operation"
        raise TypeError(msg)
    if len(rows) > 1:
        msg = f"Expected at most one row, found {len(rows)}"
        raise ValueError(msg)
    return rows[0] if rows else None
|
|
849
|
+
|
|
850
|
+
@overload
async def select(
    self,
    statement: "Union[Statement, Select]",
    /,
    *parameters: "Union[StatementParameters, StatementFilter]",
    schema_type: "type[ModelDTOT]",
    _connection: "Optional[ConnectionT]" = None,
    _config: "Optional[SQLConfig]" = None,
    **kwargs: Any,
) -> "list[ModelDTOT]": ...


@overload
async def select(
    self,
    statement: "Union[Statement, Select]",
    /,
    *parameters: "Union[StatementParameters, StatementFilter]",
    schema_type: None = None,
    _connection: "Optional[ConnectionT]" = None,
    _config: "Optional[SQLConfig]" = None,
    **kwargs: Any,
) -> "list[RowT]": ...


async def select(
    self,
    statement: "Union[Statement, Select]",
    /,
    *parameters: "Union[StatementParameters, StatementFilter]",
    schema_type: "Optional[type[ModelDTOT]]" = None,
    _connection: "Optional[ConnectionT]" = None,
    _config: "Optional[SQLConfig]" = None,
    **kwargs: Any,
) -> Any:
    """Run a SELECT and return every row it produces.

    Raises:
        TypeError: If the driver result data is not a list.
    """
    outcome = await self.driver.execute(
        statement,
        *parameters,
        schema_type=schema_type,
        _connection=_connection,
        _config=_config,
        **kwargs,
    )
    rows = outcome.get_data()
    if isinstance(rows, list):
        return rows
    # SELECT results come back as a list; anything else means the statement was not a query.
    msg = "Expected list result from select operation"
    raise TypeError(msg)
|
|
894
|
+
|
|
895
|
+
async def select_value(
    self,
    statement: "Union[Statement, Select]",
    /,
    *parameters: "Union[StatementParameters, StatementFilter]",
    _connection: "Optional[ConnectionT]" = None,
    _config: "Optional[SQLConfig]" = None,
    **kwargs: Any,
) -> Any:
    """Execute a select statement and return a single scalar value.

    Expects exactly one row. The value returned is the first column of that row;
    any additional columns are ignored rather than treated as an error.

    Args:
        statement: The SELECT statement (raw or builder) to run.
        *parameters: Parameters and/or statement filters forwarded to the driver.
        _connection: Optional connection override forwarded to the driver.
        _config: Optional SQL configuration forwarded to the driver.
        **kwargs: Extra keyword arguments forwarded to the driver unchanged.

    Returns:
        The first column value of the single returned row.

    Raises:
        TypeError: If the driver result data is not a list.
        ValueError: If no rows or more than one row is returned, if the row has no
            columns, or if the row is recognized by neither ``is_dict_row`` nor
            ``is_indexable_row``.
    """
    result = await self.driver.execute(statement, *parameters, _connection=_connection, _config=_config, **kwargs)
    data = result.get_data()
    # For select operations, data should be a list
    if not isinstance(data, list):
        msg = "Expected list result from select operation"
        raise TypeError(msg)
    if not data:
        msg = "No rows found"
        raise ValueError(msg)
    if len(data) > 1:
        msg = f"Expected exactly one row, found {len(data)}"
        raise ValueError(msg)
    row = data[0]
    if is_dict_row(row):
        if not row:
            msg = "Row has no columns"
            raise ValueError(msg)
        # First column by insertion order of the row mapping.
        return next(iter(row.values()))
    if is_indexable_row(row):
        # Tuple or list-like row
        if not row:
            msg = "Row has no columns"
            raise ValueError(msg)
        return row[0]
    msg = f"Unexpected row type: {type(row)}"
    raise ValueError(msg)
|
|
935
|
+
|
|
936
|
+
async def select_value_or_none(
    self,
    statement: "Union[Statement, Select]",
    /,
    *parameters: "Union[StatementParameters, StatementFilter]",
    _connection: "Optional[ConnectionT]" = None,
    _config: "Optional[SQLConfig]" = None,
    **kwargs: Any,
) -> Any:
    """Run a SELECT and return the first column of its only row, or ``None``.

    ``None`` is returned when the query yields no rows, or when the single row is
    an empty dict. Extra columns are ignored.

    Raises:
        TypeError: If the driver result data is not a list, or if a value cannot be
            extracted from the row by indexing.
        ValueError: If the query yields more than one row.
    """
    outcome = await self.driver.execute(statement, *parameters, _connection=_connection, _config=_config, **kwargs)
    rows = outcome.get_data()
    # SELECT results come back as a list; anything else means the statement was not a query.
    if not isinstance(rows, list):
        msg = "Expected list result from select operation"
        raise TypeError(msg)
    if not rows:
        return None
    if len(rows) > 1:
        msg = f"Expected at most one row, found {len(rows)}"
        raise ValueError(msg)
    first_row = rows[0]
    if isinstance(first_row, dict):
        # An empty mapping falls back to None; otherwise take the first value.
        return next(iter(first_row.values()), None)
    if isinstance(first_row, (tuple, list)):
        return first_row[0]
    # Any other row type: attempt positional access and report a clear error on failure.
    try:
        return first_row[0]
    except (TypeError, IndexError) as e:
        msg = f"Cannot extract value from row type {type(first_row).__name__}: {e}"
        raise TypeError(msg) from e
|
|
976
|
+
|
|
977
|
+
@overload
async def paginate(
    self,
    statement: "Union[Statement, Select]",
    /,
    *parameters: "Union[StatementParameters, StatementFilter]",
    schema_type: "type[ModelDTOT]",
    _connection: "Optional[ConnectionT]" = None,
    _config: "Optional[SQLConfig]" = None,
    **kwargs: Any,
) -> "OffsetPagination[ModelDTOT]": ...


@overload
async def paginate(
    self,
    statement: "Union[Statement, Select]",
    /,
    *parameters: "Union[StatementParameters, StatementFilter]",
    schema_type: None = None,
    _connection: "Optional[ConnectionT]" = None,
    _config: "Optional[SQLConfig]" = None,
    **kwargs: Any,
) -> "OffsetPagination[RowT]": ...


async def paginate(
    self,
    statement: "Union[Statement, Select]",
    /,
    *parameters: "Union[StatementParameters, StatementFilter]",
    schema_type: "Optional[type[ModelDTOT]]" = None,
    _connection: "Optional[ConnectionT]" = None,
    _config: "Optional[SQLConfig]" = None,
    **kwargs: Any,
) -> Any:
    """Execute a paginated query with automatic counting.

    This method performs two queries:
    1. A count query over the filtered statement (wrapped as a subquery), carrying
       the same bound parameters, to get the total number of results
    2. A data query with limit/offset applied

    Pagination can be specified either via LimitOffsetFilter in parameters
    or via 'limit' and 'offset' in kwargs.

    Args:
        statement: The SELECT statement to paginate
        *parameters: Statement parameters and filters (can include LimitOffsetFilter).
            A single non-filter argument is used as the parameter set itself
            (e.g. a dict of named values); several are used as positional values.
        schema_type: Optional model type for automatic schema conversion
        _connection: Optional connection to use
        _config: Optional SQL configuration
        **kwargs: Additional driver-specific arguments. Can include 'limit' and 'offset'
            if LimitOffsetFilter is not provided

    Returns:
        OffsetPagination object containing items, limit, offset, and total count

    Raises:
        ValueError: If neither LimitOffsetFilter nor limit/offset kwargs are provided

    Example:
        >>> # Basic pagination
        >>> from sqlspec.statement.filters import LimitOffsetFilter
        >>> result = await service.paginate(
        ...     sql.select("*").from_("users"),
        ...     LimitOffsetFilter(limit=10, offset=20),
        ... )
        >>> print(
        ...     f"Showing {len(result.items)} of {result.total} users"
        ... )

        >>> # With schema conversion
        >>> result = await service.paginate(
        ...     sql.select("*").from_("users"),
        ...     LimitOffsetFilter(limit=10, offset=0),
        ...     schema_type=User,
        ... )
        >>> # result.items is list[User] with proper type inference

        >>> # With multiple filters
        >>> from sqlspec.statement.filters import (
        ...     LimitOffsetFilter,
        ...     OrderByFilter,
        ... )
        >>> result = await service.paginate(
        ...     sql.select("*").from_("users"),
        ...     OrderByFilter("created_at", "desc"),
        ...     LimitOffsetFilter(limit=20, offset=40),
        ...     schema_type=User,
        ... )
    """
    from sqlspec.service.pagination import OffsetPagination
    from sqlspec.statement.sql import SQL

    # Separate filters from parameters
    filters: list[StatementFilter] = []
    params: list[Any] = []
    for p in parameters:
        # Use type guard to check if it implements the StatementFilter protocol
        if is_statement_filter(p):
            filters.append(p)
        else:
            params.append(p)

    # Pull out the LimitOffsetFilter (the last one wins); all other filters shape the query.
    limit_offset_filter = None
    other_filters: list[StatementFilter] = []
    for f in filters:
        if is_limit_offset_filter(f):
            limit_offset_filter = f
        else:
            other_filters.append(f)

    if limit_offset_filter is not None:
        limit = limit_offset_filter.limit
        offset = limit_offset_filter.offset
    elif "limit" in kwargs and "offset" in kwargs:
        limit = kwargs.pop("limit")
        offset = kwargs.pop("offset")
    else:
        msg = "Pagination requires either a LimitOffsetFilter in parameters or 'limit' and 'offset' in kwargs."
        raise ValueError(msg)

    # _normalize_statement takes ONE parameter set: unwrap a lone positional argument
    # (e.g. a dict of named values) instead of nesting it inside a list.
    statement_params: Any = params[0] if len(params) == 1 else (params or None)
    filtered_stmt = self._normalize_statement(statement, statement_params, _config)
    for filter_obj in other_filters:
        filtered_stmt = filter_obj.append_to_statement(filtered_stmt)

    # Build "SELECT COUNT(*) AS total FROM (<filtered query>) AS _count_subquery".
    # Compile once so the SQL text and its bound values stay paired: the count query must
    # carry the same parameters, otherwise its placeholders would be left unbound.
    # NOTE(review): relies on SQL.compile() returning (sql, parameters) — confirm against SQL.
    dialect = getattr(self.driver, "dialect", None)
    sql_str, sql_params = filtered_stmt.compile()
    parsed = parse_one(sql_str, read=dialect)
    subquery = exp.Subquery(this=parsed, alias="_count_subquery")
    count_ast = exp.Select().select(exp.func("COUNT", exp.Star()).as_("total")).from_(subquery)
    count_sql = count_ast.sql(dialect=dialect)
    if sql_params:
        count_stmt = SQL(count_sql, parameters=sql_params, _config=_config)
    else:
        count_stmt = SQL(count_sql, _config=_config)

    # Execute count query
    total = await self.select_value(count_stmt, _connection=_connection, _config=_config, **kwargs)

    # Reuse the already-filtered statement for the page query rather than rebuilding it.
    data_stmt = filtered_stmt.limit(limit).offset(offset)

    # Execute data query
    items = await self.select(data_stmt, schema_type=schema_type, _connection=_connection, _config=_config, **kwargs)

    return OffsetPagination(items=items, limit=limit, offset=offset, total=total)
|