sqlspec 0.11.0__py3-none-any.whl → 0.12.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of sqlspec might be problematic. Click here for more details.

Files changed (155)
  1. sqlspec/__init__.py +16 -3
  2. sqlspec/_serialization.py +3 -10
  3. sqlspec/_sql.py +1147 -0
  4. sqlspec/_typing.py +343 -41
  5. sqlspec/adapters/adbc/__init__.py +2 -6
  6. sqlspec/adapters/adbc/config.py +474 -149
  7. sqlspec/adapters/adbc/driver.py +330 -644
  8. sqlspec/adapters/aiosqlite/__init__.py +2 -6
  9. sqlspec/adapters/aiosqlite/config.py +143 -57
  10. sqlspec/adapters/aiosqlite/driver.py +269 -462
  11. sqlspec/adapters/asyncmy/__init__.py +3 -8
  12. sqlspec/adapters/asyncmy/config.py +247 -202
  13. sqlspec/adapters/asyncmy/driver.py +217 -451
  14. sqlspec/adapters/asyncpg/__init__.py +4 -7
  15. sqlspec/adapters/asyncpg/config.py +329 -176
  16. sqlspec/adapters/asyncpg/driver.py +418 -498
  17. sqlspec/adapters/bigquery/__init__.py +2 -2
  18. sqlspec/adapters/bigquery/config.py +407 -0
  19. sqlspec/adapters/bigquery/driver.py +592 -634
  20. sqlspec/adapters/duckdb/__init__.py +4 -1
  21. sqlspec/adapters/duckdb/config.py +432 -321
  22. sqlspec/adapters/duckdb/driver.py +393 -436
  23. sqlspec/adapters/oracledb/__init__.py +3 -8
  24. sqlspec/adapters/oracledb/config.py +625 -0
  25. sqlspec/adapters/oracledb/driver.py +549 -942
  26. sqlspec/adapters/psqlpy/__init__.py +4 -7
  27. sqlspec/adapters/psqlpy/config.py +372 -203
  28. sqlspec/adapters/psqlpy/driver.py +197 -550
  29. sqlspec/adapters/psycopg/__init__.py +3 -8
  30. sqlspec/adapters/psycopg/config.py +741 -0
  31. sqlspec/adapters/psycopg/driver.py +732 -733
  32. sqlspec/adapters/sqlite/__init__.py +2 -6
  33. sqlspec/adapters/sqlite/config.py +146 -81
  34. sqlspec/adapters/sqlite/driver.py +243 -426
  35. sqlspec/base.py +220 -825
  36. sqlspec/config.py +354 -0
  37. sqlspec/driver/__init__.py +22 -0
  38. sqlspec/driver/_async.py +252 -0
  39. sqlspec/driver/_common.py +338 -0
  40. sqlspec/driver/_sync.py +261 -0
  41. sqlspec/driver/mixins/__init__.py +17 -0
  42. sqlspec/driver/mixins/_pipeline.py +523 -0
  43. sqlspec/driver/mixins/_result_utils.py +122 -0
  44. sqlspec/driver/mixins/_sql_translator.py +35 -0
  45. sqlspec/driver/mixins/_storage.py +993 -0
  46. sqlspec/driver/mixins/_type_coercion.py +131 -0
  47. sqlspec/exceptions.py +299 -7
  48. sqlspec/extensions/aiosql/__init__.py +10 -0
  49. sqlspec/extensions/aiosql/adapter.py +474 -0
  50. sqlspec/extensions/litestar/__init__.py +1 -6
  51. sqlspec/extensions/litestar/_utils.py +1 -5
  52. sqlspec/extensions/litestar/config.py +5 -6
  53. sqlspec/extensions/litestar/handlers.py +13 -12
  54. sqlspec/extensions/litestar/plugin.py +22 -24
  55. sqlspec/extensions/litestar/providers.py +37 -55
  56. sqlspec/loader.py +528 -0
  57. sqlspec/service/__init__.py +3 -0
  58. sqlspec/service/base.py +24 -0
  59. sqlspec/service/pagination.py +26 -0
  60. sqlspec/statement/__init__.py +21 -0
  61. sqlspec/statement/builder/__init__.py +54 -0
  62. sqlspec/statement/builder/_ddl_utils.py +119 -0
  63. sqlspec/statement/builder/_parsing_utils.py +135 -0
  64. sqlspec/statement/builder/base.py +328 -0
  65. sqlspec/statement/builder/ddl.py +1379 -0
  66. sqlspec/statement/builder/delete.py +80 -0
  67. sqlspec/statement/builder/insert.py +274 -0
  68. sqlspec/statement/builder/merge.py +95 -0
  69. sqlspec/statement/builder/mixins/__init__.py +65 -0
  70. sqlspec/statement/builder/mixins/_aggregate_functions.py +151 -0
  71. sqlspec/statement/builder/mixins/_case_builder.py +91 -0
  72. sqlspec/statement/builder/mixins/_common_table_expr.py +91 -0
  73. sqlspec/statement/builder/mixins/_delete_from.py +34 -0
  74. sqlspec/statement/builder/mixins/_from.py +61 -0
  75. sqlspec/statement/builder/mixins/_group_by.py +119 -0
  76. sqlspec/statement/builder/mixins/_having.py +35 -0
  77. sqlspec/statement/builder/mixins/_insert_from_select.py +48 -0
  78. sqlspec/statement/builder/mixins/_insert_into.py +36 -0
  79. sqlspec/statement/builder/mixins/_insert_values.py +69 -0
  80. sqlspec/statement/builder/mixins/_join.py +110 -0
  81. sqlspec/statement/builder/mixins/_limit_offset.py +53 -0
  82. sqlspec/statement/builder/mixins/_merge_clauses.py +405 -0
  83. sqlspec/statement/builder/mixins/_order_by.py +46 -0
  84. sqlspec/statement/builder/mixins/_pivot.py +82 -0
  85. sqlspec/statement/builder/mixins/_returning.py +37 -0
  86. sqlspec/statement/builder/mixins/_select_columns.py +60 -0
  87. sqlspec/statement/builder/mixins/_set_ops.py +122 -0
  88. sqlspec/statement/builder/mixins/_unpivot.py +80 -0
  89. sqlspec/statement/builder/mixins/_update_from.py +54 -0
  90. sqlspec/statement/builder/mixins/_update_set.py +91 -0
  91. sqlspec/statement/builder/mixins/_update_table.py +29 -0
  92. sqlspec/statement/builder/mixins/_where.py +374 -0
  93. sqlspec/statement/builder/mixins/_window_functions.py +86 -0
  94. sqlspec/statement/builder/protocols.py +20 -0
  95. sqlspec/statement/builder/select.py +206 -0
  96. sqlspec/statement/builder/update.py +178 -0
  97. sqlspec/statement/filters.py +571 -0
  98. sqlspec/statement/parameters.py +736 -0
  99. sqlspec/statement/pipelines/__init__.py +67 -0
  100. sqlspec/statement/pipelines/analyzers/__init__.py +9 -0
  101. sqlspec/statement/pipelines/analyzers/_analyzer.py +649 -0
  102. sqlspec/statement/pipelines/base.py +315 -0
  103. sqlspec/statement/pipelines/context.py +119 -0
  104. sqlspec/statement/pipelines/result_types.py +41 -0
  105. sqlspec/statement/pipelines/transformers/__init__.py +8 -0
  106. sqlspec/statement/pipelines/transformers/_expression_simplifier.py +256 -0
  107. sqlspec/statement/pipelines/transformers/_literal_parameterizer.py +623 -0
  108. sqlspec/statement/pipelines/transformers/_remove_comments.py +66 -0
  109. sqlspec/statement/pipelines/transformers/_remove_hints.py +81 -0
  110. sqlspec/statement/pipelines/validators/__init__.py +23 -0
  111. sqlspec/statement/pipelines/validators/_dml_safety.py +275 -0
  112. sqlspec/statement/pipelines/validators/_parameter_style.py +297 -0
  113. sqlspec/statement/pipelines/validators/_performance.py +703 -0
  114. sqlspec/statement/pipelines/validators/_security.py +990 -0
  115. sqlspec/statement/pipelines/validators/base.py +67 -0
  116. sqlspec/statement/result.py +527 -0
  117. sqlspec/statement/splitter.py +701 -0
  118. sqlspec/statement/sql.py +1198 -0
  119. sqlspec/storage/__init__.py +15 -0
  120. sqlspec/storage/backends/__init__.py +0 -0
  121. sqlspec/storage/backends/base.py +166 -0
  122. sqlspec/storage/backends/fsspec.py +315 -0
  123. sqlspec/storage/backends/obstore.py +464 -0
  124. sqlspec/storage/protocol.py +170 -0
  125. sqlspec/storage/registry.py +315 -0
  126. sqlspec/typing.py +157 -36
  127. sqlspec/utils/correlation.py +155 -0
  128. sqlspec/utils/deprecation.py +3 -6
  129. sqlspec/utils/fixtures.py +6 -11
  130. sqlspec/utils/logging.py +135 -0
  131. sqlspec/utils/module_loader.py +45 -43
  132. sqlspec/utils/serializers.py +4 -0
  133. sqlspec/utils/singleton.py +6 -8
  134. sqlspec/utils/sync_tools.py +15 -27
  135. sqlspec/utils/text.py +58 -26
  136. {sqlspec-0.11.0.dist-info → sqlspec-0.12.0.dist-info}/METADATA +100 -26
  137. sqlspec-0.12.0.dist-info/RECORD +145 -0
  138. sqlspec/adapters/bigquery/config/__init__.py +0 -3
  139. sqlspec/adapters/bigquery/config/_common.py +0 -40
  140. sqlspec/adapters/bigquery/config/_sync.py +0 -87
  141. sqlspec/adapters/oracledb/config/__init__.py +0 -9
  142. sqlspec/adapters/oracledb/config/_asyncio.py +0 -186
  143. sqlspec/adapters/oracledb/config/_common.py +0 -131
  144. sqlspec/adapters/oracledb/config/_sync.py +0 -186
  145. sqlspec/adapters/psycopg/config/__init__.py +0 -19
  146. sqlspec/adapters/psycopg/config/_async.py +0 -169
  147. sqlspec/adapters/psycopg/config/_common.py +0 -56
  148. sqlspec/adapters/psycopg/config/_sync.py +0 -168
  149. sqlspec/filters.py +0 -330
  150. sqlspec/mixins.py +0 -306
  151. sqlspec/statement.py +0 -378
  152. sqlspec-0.11.0.dist-info/RECORD +0 -69
  153. {sqlspec-0.11.0.dist-info → sqlspec-0.12.0.dist-info}/WHEEL +0 -0
  154. {sqlspec-0.11.0.dist-info → sqlspec-0.12.0.dist-info}/licenses/LICENSE +0 -0
  155. {sqlspec-0.11.0.dist-info → sqlspec-0.12.0.dist-info}/licenses/NOTICE +0 -0
@@ -0,0 +1,67 @@
1
+ # Base class for validators
2
+ from abc import ABC, abstractmethod
3
+ from typing import TYPE_CHECKING, Optional
4
+
5
+ from sqlspec.exceptions import RiskLevel
6
+ from sqlspec.statement.pipelines.base import ProcessorProtocol
7
+ from sqlspec.statement.pipelines.result_types import ValidationError
8
+
9
+ if TYPE_CHECKING:
10
+ from sqlglot import exp
11
+
12
+ from sqlspec.statement.pipelines.context import SQLProcessingContext
13
+
14
__all__ = ("BaseValidator",)


class BaseValidator(ProcessorProtocol, ABC):
    """Base class for all validators.

    Validators inspect a parsed SQL expression and record problems on the
    processing context; :meth:`process` always hands the expression back
    unchanged. Subclasses implement :meth:`validate` and can report findings
    through the :meth:`add_error` helper.
    """

    def process(
        self, expression: "Optional[exp.Expression]", context: "SQLProcessingContext"
    ) -> "Optional[exp.Expression]":
        """Process the SQL expression through this validator.

        Args:
            expression: The SQL expression to validate. ``None`` is passed
                straight through and :meth:`validate` is not called.
            context: The SQL processing context.

        Returns:
            The expression unchanged (validators don't transform), or ``None``
            when no expression was supplied.
        """
        if expression is None:
            return None
        self.validate(expression, context)
        return expression

    @abstractmethod
    def validate(self, expression: "exp.Expression", context: "SQLProcessingContext") -> None:
        """Validate the expression and add any errors to the context.

        Args:
            expression: The SQL expression to validate (never ``None`` when
                invoked via :meth:`process`).
            context: The SQL processing context.
        """
        raise NotImplementedError

    def add_error(
        self,
        context: "SQLProcessingContext",
        message: str,
        code: str,
        risk_level: RiskLevel,
        expression: "Optional[exp.Expression]" = None,
    ) -> None:
        """Helper to add a validation error to the context.

        The error is tagged with this validator's class name as its
        ``processor`` and appended to ``context.validation_errors``.

        Args:
            context: The SQL processing context.
            message: The error message.
            code: The error code.
            risk_level: The risk level.
            expression: The specific expression with the error (optional).
        """
        error = ValidationError(
            message=message,
            code=code,
            risk_level=risk_level,
            processor=self.__class__.__name__,
            expression=expression,
        )
        context.validation_errors.append(error)
@@ -0,0 +1,527 @@
1
+ """SQL statement result classes for handling different types of SQL operations."""
2
+
3
+ from abc import ABC, abstractmethod
4
+
5
+ # Import Mapping for type checking in __post_init__
6
+ from collections.abc import Mapping, Sequence
7
+ from dataclasses import dataclass, field
8
+ from typing import TYPE_CHECKING, Any, Generic, Optional, Union, cast
9
+
10
+ from typing_extensions import TypedDict, TypeVar
11
+
12
+ from sqlspec.typing import ArrowTable, RowT
13
+
14
+ if TYPE_CHECKING:
15
+ from sqlspec.statement.sql import SQL
16
+
17
__all__ = (
    "ArrowResult",
    "DMLResultDict",
    "SQLResult",
    "ScriptResultDict",
    "SelectResultDict",
    "StatementResult",
)


# Module-level generic type variable; not referenced by the definitions in this module.
T = TypeVar("T")
21
+
22
+
23
class SelectResultDict(TypedDict):
    """Payload shape a driver returns for statements that produce rows.

    Covers plain SELECT queries as well as DML statements carrying a
    RETURNING clause. Every key is required.
    """

    # Rows produced by the statement.
    data: "list[Any]"
    # Names of the columns in the result set.
    column_names: "list[str]"
    # Count of affected rows; -1 when the driver cannot report it.
    rows_affected: int
36
+
37
+
38
class DMLResultDict(TypedDict, total=False):
    """TypedDict for DML (INSERT/UPDATE/DELETE) results without RETURNING.

    This structure is returned by drivers when executing DML operations
    that don't return data (no RETURNING clause). Declared with
    ``total=False``, so every key is optional.
    """

    rows_affected: int
    """Number of rows affected by the operation (-1 when unsupported)."""
    status_message: str
    """Status message from the database."""
    description: str
    """Optional description of the operation."""
51
+
52
+
53
class ScriptResultDict(TypedDict, total=False):
    """Payload shape a driver reports after running a multi-statement SQL script.

    Declared with ``total=False``, so any key may be omitted.
    """

    # How many statements the script ran.
    statements_executed: int
    # Overall status text for the script run.
    status_message: str
    # Optional free-form description of the script run.
    description: str
66
+
67
+
68
@dataclass
class StatementResult(ABC, Generic[RowT]):
    """Abstract base for the outcome of executing a SQL statement.

    Concrete subclasses decide what "success" means and how the raw
    ``data`` payload is exposed; this base only carries the fields shared by
    every result plus a small per-instance key/value metadata store.

    Args:
        statement: The SQL statement that produced this result.
        data: Raw result payload from the operation.
        rows_affected: Rows touched by the operation, where applicable.
        last_inserted_id: Identifier of the last inserted row, where applicable.
        execution_time: Time spent executing the statement, in seconds.
        metadata: Arbitrary extra information about the operation.
    """

    # The SQL statement that was executed.
    statement: "SQL"
    # Raw payload produced by the operation.
    data: "Any"
    # Rows touched by the operation.
    rows_affected: int = 0
    # Identifier of the last inserted row, if one was reported.
    last_inserted_id: Optional[Union[int, str]] = None
    # Execution time in seconds, if measured.
    execution_time: Optional[float] = None
    # Free-form metadata; each instance gets its own dict.
    metadata: "dict[str, Any]" = field(default_factory=dict)

    @abstractmethod
    def is_success(self) -> bool:
        """Report whether the operation succeeded.

        Returns:
            ``True`` on success, ``False`` otherwise.
        """

    @abstractmethod
    def get_data(self) -> "Any":
        """Return the result payload in the form appropriate to the subclass.

        Returns:
            The processed result data.
        """

    def get_metadata(self, key: str, default: Any = None) -> Any:
        """Look up a single metadata entry.

        Args:
            key: Metadata key to read.
            default: Value handed back when ``key`` is absent.

        Returns:
            The stored value, or ``default`` when the key is missing.
        """
        if key in self.metadata:
            return self.metadata[key]
        return default

    def set_metadata(self, key: str, value: Any) -> None:
        """Store a metadata entry, replacing any previous value for ``key``.

        Args:
            key: Metadata key to write.
            value: Value to store under ``key``.
        """
        self.metadata.update({key: value})
134
+
135
+
136
+ # RowT (imported from sqlspec.typing) is the type of a single result row; it parameterizes StatementResult and SQLResult.
137
+
138
+
139
@dataclass
class SQLResult(StatementResult[RowT], Generic[RowT]):
    """Unified result class for SQL operations that return a list of rows
    or affect rows (e.g., SELECT, INSERT, UPDATE, DELETE).

    For DML operations with RETURNING clauses, the returned data will be in `self.data`.
    The `operation_type` attribute helps distinguish the nature of the operation.

    For script execution, this class also tracks multiple statement results and errors.
    """

    error: Optional[Exception] = None
    operation_index: Optional[int] = None
    pipeline_sql: Optional["SQL"] = None
    parameters: Optional[Any] = None

    # Attributes primarily for SELECT-like results or results with column structure
    column_names: "list[str]" = field(default_factory=list)
    total_count: Optional[int] = None  # Total rows if pagination/limit was involved
    has_more: bool = False  # For pagination

    # Attributes primarily for DML-like results
    operation_type: str = "SELECT"  # Default, override for DML
    inserted_ids: "list[Union[int, str]]" = field(default_factory=list)
    # rows_affected and last_inserted_id are inherited from StatementResult

    # Attributes for script execution
    statement_results: "list[SQLResult[Any]]" = field(default_factory=list)
    """Individual statement results when executing scripts."""
    errors: "list[str]" = field(default_factory=list)
    """Errors encountered during script execution."""
    total_statements: int = 0
    """Total number of statements in the script."""
    successful_statements: int = 0
    """Number of statements that executed successfully."""

    def __post_init__(self) -> None:
        """Infer ``column_names`` and ``total_count`` when they were not supplied.

        Column names are taken from the keys of the first row when rows are
        mappings. ``total_count`` defaults to ``len(data)`` (0 when ``data``
        is ``None``).
        """
        if not self.column_names and self.data and isinstance(self.data[0], Mapping):
            self.column_names = list(self.data[0].keys())

        if self.total_count is None:
            self.total_count = len(self.data) if self.data is not None else 0

        # If data is populated for a DML, it implies RETURNING data; self.data serves
        # that purpose, so no separate returning_data field is needed.

    def is_success(self) -> bool:
        """Check if the operation was successful.

        ``operation_type`` is compared case-insensitively.

        - For scripts (``operation_type == "SCRIPT"`` or any recorded
          ``statement_results``): True if no errors were added and every
          statement succeeded.
        - For SELECT: True if ``data`` is not None and ``rows_affected`` is not negative.
        - For INSERT, UPDATE, DELETE, EXECUTE: True if ``rows_affected`` is >= 0.
        - For any other ``operation_type``: False.
        """
        op_type_upper = self.operation_type.upper()

        # Script-style results are judged by their per-statement bookkeeping.
        if op_type_upper == "SCRIPT" or self.statement_results:
            return not self.errors and self.total_statements == self.successful_statements

        if op_type_upper == "SELECT":
            # Success means we got some data container and rows_affected is not negative.
            data_success = self.data is not None
            rows_success = self.rows_affected is None or self.rows_affected >= 0
            return data_success and rows_success
        if op_type_upper in {"INSERT", "UPDATE", "DELETE", "EXECUTE"}:
            return self.rows_affected is not None and self.rows_affected >= 0
        # NOTE(review): any other operation_type (e.g. a MERGE, if a driver ever
        # reports one) is treated as a failure; confirm this is intended.
        return False

    def get_data(self) -> "Union[list[RowT], dict[str, Any]]":
        """Get the data from the result.

        For regular operations, returns the list of rows.
        For script operations, returns a summary dictionary.
        """
        if self.operation_type.upper() == "SCRIPT" or self.statement_results:
            return {
                "total_statements": self.total_statements,
                "successful_statements": self.successful_statements,
                "failed_statements": self.total_statements - self.successful_statements,
                "errors": self.errors,
                "statement_results": self.statement_results,
                "total_rows_affected": self.get_total_rows_affected(),
            }

        return cast("list[RowT]", self.data)

    # --- Script execution methods ---

    def add_statement_result(self, result: "SQLResult[Any]") -> None:
        """Record a statement result and update the script's statement counters."""
        self.statement_results.append(result)
        self.total_statements += 1
        if result.is_success():
            self.successful_statements += 1

    def add_error(self, error: str) -> None:
        """Add an error message to the script execution errors."""
        self.errors.append(error)

    def get_statement_result(self, index: int) -> "Optional[SQLResult[Any]]":
        """Get a statement result by index, or None if the index is out of range.

        Negative indexes are treated as out of range.
        """
        if 0 <= index < len(self.statement_results):
            return self.statement_results[index]
        return None

    def get_total_rows_affected(self) -> int:
        """Get the total number of rows affected across all statements.

        For script results, this is the sum of each statement's non-negative
        ``rows_affected``. Negative counts, such as -1, are skipped. For a
        single statement, it is ``rows_affected`` clamped to a minimum of 0.
        """
        if self.statement_results:
            return sum(
                stmt_result.rows_affected
                for stmt_result in self.statement_results
                if stmt_result.rows_affected is not None and stmt_result.rows_affected >= 0
            )
        return max(self.rows_affected or 0, 0)

    @property
    def num_rows(self) -> int:
        """Total rows affected, as computed by :meth:`get_total_rows_affected`.

        Note that this is not the number of rows in ``data``. Use
        :meth:`get_count` or ``len(result)`` for that.
        """
        return self.get_total_rows_affected()

    @property
    def num_columns(self) -> int:
        """Get the number of columns in the result data."""
        return len(self.column_names) if self.column_names else 0

    def get_errors(self) -> "list[str]":
        """Get a copy of all errors from script execution."""
        return self.errors.copy()

    def has_errors(self) -> bool:
        """Check if there are any errors from script execution."""
        return bool(self.errors)

    # --- Methods for regular operations ---

    def get_first(self) -> "Optional[RowT]":
        """Get the first row from the result, if any."""
        return self.data[0] if self.data else None

    def get_count(self) -> int:
        """Get the number of rows in the current result set (e.g., a page of data)."""
        return len(self.data) if self.data is not None else 0

    def is_empty(self) -> bool:
        """Check if the result set (self.data) is empty."""
        return not self.data

    # --- Methods related to DML operations ---

    def get_affected_count(self) -> int:
        """Get the number of rows affected by a DML operation.

        Returns ``rows_affected``, or 0 when it is falsy. Unlike
        :meth:`get_total_rows_affected`, negative values are returned as-is.
        """
        return self.rows_affected or 0

    def get_inserted_id(self) -> "Optional[Union[int, str]]":
        """Get the last inserted ID (typically for single row inserts)."""
        return self.last_inserted_id

    def get_inserted_ids(self) -> "list[Union[int, str]]":
        """Get all inserted IDs (useful for batch inserts).

        The internal list is returned, not a copy.
        """
        return self.inserted_ids

    def get_returning_data(self) -> "list[RowT]":
        """Get data returned by RETURNING clauses.

        This is effectively self.data for this unified class.
        """
        return cast("list[RowT]", self.data)

    def was_inserted(self) -> bool:
        """Check if this was an INSERT operation."""
        return self.operation_type.upper() == "INSERT"

    def was_updated(self) -> bool:
        """Check if this was an UPDATE operation."""
        return self.operation_type.upper() == "UPDATE"

    def was_deleted(self) -> bool:
        """Check if this was a DELETE operation."""
        return self.operation_type.upper() == "DELETE"

    def __len__(self) -> int:
        """Get the number of rows in the result set.

        Returns:
            Number of rows in the data (0 when data is None).
        """
        return len(self.data) if self.data is not None else 0

    def __getitem__(self, index: int) -> "RowT":
        """Get a row by index.

        Args:
            index: Row index

        Returns:
            The row at the specified index

        Raises:
            TypeError: If data is None
        """
        if self.data is None:
            msg = "No data available"
            raise TypeError(msg)
        return cast("RowT", self.data[index])

    # --- SQLAlchemy-style convenience methods ---

    def all(self) -> "list[RowT]":
        """Return all rows as a list.

        Returns:
            List of all rows in the result (empty when data is None)
        """
        if self.data is None:
            return []
        return cast("list[RowT]", self.data)

    def one(self) -> "RowT":
        """Return exactly one row.

        Returns:
            The single row

        Raises:
            ValueError: If no results or more than one result
        """
        if self.data is None or len(self.data) == 0:
            msg = "No result found, exactly one row expected"
            raise ValueError(msg)
        if len(self.data) > 1:
            msg = f"Multiple results found ({len(self.data)}), exactly one row expected"
            raise ValueError(msg)
        return cast("RowT", self.data[0])

    def one_or_none(self) -> "Optional[RowT]":
        """Return at most one row.

        Returns:
            The single row or None if no results

        Raises:
            ValueError: If more than one result
        """
        if self.data is None or len(self.data) == 0:
            return None
        if len(self.data) > 1:
            msg = f"Multiple results found ({len(self.data)}), at most one row expected"
            raise ValueError(msg)
        return cast("RowT", self.data[0])

    def scalar(self) -> Any:
        """Return the first column of the single row in the result.

        The result must contain exactly one row (see :meth:`one`). Mapping
        rows yield the value of their first key. Sequence rows, excluding
        ``str``/``bytes``, yield element 0. Any other row is returned as-is.

        Returns:
            The scalar value from the first column of the row

        Raises:
            ValueError: If there are no rows, more than one row, or the row
                has no columns.
        """
        row = self.one()
        if isinstance(row, Mapping):
            if not row:
                msg = "Row has no columns"
                raise ValueError(msg)
            first_key = cast("str", next(iter(row.keys())))
            return cast("Any", row[first_key])
        if isinstance(row, Sequence) and not isinstance(row, (str, bytes)):
            if len(row) == 0:
                msg = "Row has no columns"
                raise ValueError(msg)
            return cast("Any", row[0])
        # Scalar values returned directly as the row.
        return row

    def scalar_or_none(self) -> Any:
        """Return the first column of the single row, or None if there are no rows.

        Row handling matches :meth:`scalar`, except that an empty row yields
        ``None`` instead of raising.

        Returns:
            The scalar value from the first column of the row, or None

        Raises:
            ValueError: If more than one row is present (see :meth:`one_or_none`).
        """
        row = self.one_or_none()
        if row is None:
            return None

        if isinstance(row, Mapping):
            if not row:
                return None
            first_key = next(iter(row.keys()))
            return row[first_key]
        if isinstance(row, Sequence) and not isinstance(row, (str, bytes)):
            if len(row) == 0:
                return None
            return cast("Any", row[0])
        # Scalar values returned directly as the row.
        return row
438
+
439
+
440
@dataclass
class ArrowResult(StatementResult[ArrowTable]):
    """Result class for SQL operations that return Apache Arrow data.

    This class is used when database drivers support returning results as
    Apache Arrow format for high-performance data interchange, especially
    useful for analytics workloads and data science applications.

    Args:
        statement: The original SQL statement that was executed.
        data: The Apache Arrow Table containing the result data.
        schema: Optional Arrow schema information.
    """

    schema: Optional["dict[str, Any]"] = None
    """Optional Arrow schema information."""
    data: "ArrowTable"
    """The result data from the operation."""

    def _require_data(self, msg: str = "No Arrow table available") -> "ArrowTable":
        """Return ``self.data``, raising ``ValueError(msg)`` when it is ``None``."""
        if self.data is None:
            raise ValueError(msg)
        return self.data

    def is_success(self) -> bool:
        """Check if the Arrow operation was successful.

        Any Arrow table counts as success, including an empty (zero-row)
        table. Truthiness is deliberately not used here: pyarrow tables
        define ``__len__`` as their row count, so a valid empty result would
        otherwise evaluate as falsy and be misreported as a failure.

        Returns:
            True if ``data`` holds an Arrow table, False if it is ``None``.
        """
        return self.data is not None

    def get_data(self) -> "ArrowTable":
        """Get the Apache Arrow Table from the result.

        Returns:
            The Arrow table containing the result data.

        Raises:
            ValueError: If no Arrow table is available.
        """
        return self._require_data("No Arrow table available for this result")

    @property
    def column_names(self) -> "list[str]":
        """Get the column names from the Arrow table.

        Returns:
            List of column names.

        Raises:
            ValueError: If no Arrow table is available.
        """
        return self._require_data().column_names

    @property
    def num_rows(self) -> int:
        """Get the number of rows in the Arrow table.

        Returns:
            Number of rows.

        Raises:
            ValueError: If no Arrow table is available.
        """
        return self._require_data().num_rows

    @property
    def num_columns(self) -> int:
        """Get the number of columns in the Arrow table.

        Returns:
            Number of columns.

        Raises:
            ValueError: If no Arrow table is available.
        """
        return self._require_data().num_columns