sqlspec 0.11.0__py3-none-any.whl → 0.12.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of sqlspec might be problematic. Click here for more details.
- sqlspec/__init__.py +16 -3
- sqlspec/_serialization.py +3 -10
- sqlspec/_sql.py +1147 -0
- sqlspec/_typing.py +343 -41
- sqlspec/adapters/adbc/__init__.py +2 -6
- sqlspec/adapters/adbc/config.py +474 -149
- sqlspec/adapters/adbc/driver.py +330 -644
- sqlspec/adapters/aiosqlite/__init__.py +2 -6
- sqlspec/adapters/aiosqlite/config.py +143 -57
- sqlspec/adapters/aiosqlite/driver.py +269 -462
- sqlspec/adapters/asyncmy/__init__.py +3 -8
- sqlspec/adapters/asyncmy/config.py +247 -202
- sqlspec/adapters/asyncmy/driver.py +217 -451
- sqlspec/adapters/asyncpg/__init__.py +4 -7
- sqlspec/adapters/asyncpg/config.py +329 -176
- sqlspec/adapters/asyncpg/driver.py +418 -498
- sqlspec/adapters/bigquery/__init__.py +2 -2
- sqlspec/adapters/bigquery/config.py +407 -0
- sqlspec/adapters/bigquery/driver.py +592 -634
- sqlspec/adapters/duckdb/__init__.py +4 -1
- sqlspec/adapters/duckdb/config.py +432 -321
- sqlspec/adapters/duckdb/driver.py +393 -436
- sqlspec/adapters/oracledb/__init__.py +3 -8
- sqlspec/adapters/oracledb/config.py +625 -0
- sqlspec/adapters/oracledb/driver.py +549 -942
- sqlspec/adapters/psqlpy/__init__.py +4 -7
- sqlspec/adapters/psqlpy/config.py +372 -203
- sqlspec/adapters/psqlpy/driver.py +197 -550
- sqlspec/adapters/psycopg/__init__.py +3 -8
- sqlspec/adapters/psycopg/config.py +741 -0
- sqlspec/adapters/psycopg/driver.py +732 -733
- sqlspec/adapters/sqlite/__init__.py +2 -6
- sqlspec/adapters/sqlite/config.py +146 -81
- sqlspec/adapters/sqlite/driver.py +243 -426
- sqlspec/base.py +220 -825
- sqlspec/config.py +354 -0
- sqlspec/driver/__init__.py +22 -0
- sqlspec/driver/_async.py +252 -0
- sqlspec/driver/_common.py +338 -0
- sqlspec/driver/_sync.py +261 -0
- sqlspec/driver/mixins/__init__.py +17 -0
- sqlspec/driver/mixins/_pipeline.py +523 -0
- sqlspec/driver/mixins/_result_utils.py +122 -0
- sqlspec/driver/mixins/_sql_translator.py +35 -0
- sqlspec/driver/mixins/_storage.py +993 -0
- sqlspec/driver/mixins/_type_coercion.py +131 -0
- sqlspec/exceptions.py +299 -7
- sqlspec/extensions/aiosql/__init__.py +10 -0
- sqlspec/extensions/aiosql/adapter.py +474 -0
- sqlspec/extensions/litestar/__init__.py +1 -6
- sqlspec/extensions/litestar/_utils.py +1 -5
- sqlspec/extensions/litestar/config.py +5 -6
- sqlspec/extensions/litestar/handlers.py +13 -12
- sqlspec/extensions/litestar/plugin.py +22 -24
- sqlspec/extensions/litestar/providers.py +37 -55
- sqlspec/loader.py +528 -0
- sqlspec/service/__init__.py +3 -0
- sqlspec/service/base.py +24 -0
- sqlspec/service/pagination.py +26 -0
- sqlspec/statement/__init__.py +21 -0
- sqlspec/statement/builder/__init__.py +54 -0
- sqlspec/statement/builder/_ddl_utils.py +119 -0
- sqlspec/statement/builder/_parsing_utils.py +135 -0
- sqlspec/statement/builder/base.py +328 -0
- sqlspec/statement/builder/ddl.py +1379 -0
- sqlspec/statement/builder/delete.py +80 -0
- sqlspec/statement/builder/insert.py +274 -0
- sqlspec/statement/builder/merge.py +95 -0
- sqlspec/statement/builder/mixins/__init__.py +65 -0
- sqlspec/statement/builder/mixins/_aggregate_functions.py +151 -0
- sqlspec/statement/builder/mixins/_case_builder.py +91 -0
- sqlspec/statement/builder/mixins/_common_table_expr.py +91 -0
- sqlspec/statement/builder/mixins/_delete_from.py +34 -0
- sqlspec/statement/builder/mixins/_from.py +61 -0
- sqlspec/statement/builder/mixins/_group_by.py +119 -0
- sqlspec/statement/builder/mixins/_having.py +35 -0
- sqlspec/statement/builder/mixins/_insert_from_select.py +48 -0
- sqlspec/statement/builder/mixins/_insert_into.py +36 -0
- sqlspec/statement/builder/mixins/_insert_values.py +69 -0
- sqlspec/statement/builder/mixins/_join.py +110 -0
- sqlspec/statement/builder/mixins/_limit_offset.py +53 -0
- sqlspec/statement/builder/mixins/_merge_clauses.py +405 -0
- sqlspec/statement/builder/mixins/_order_by.py +46 -0
- sqlspec/statement/builder/mixins/_pivot.py +82 -0
- sqlspec/statement/builder/mixins/_returning.py +37 -0
- sqlspec/statement/builder/mixins/_select_columns.py +60 -0
- sqlspec/statement/builder/mixins/_set_ops.py +122 -0
- sqlspec/statement/builder/mixins/_unpivot.py +80 -0
- sqlspec/statement/builder/mixins/_update_from.py +54 -0
- sqlspec/statement/builder/mixins/_update_set.py +91 -0
- sqlspec/statement/builder/mixins/_update_table.py +29 -0
- sqlspec/statement/builder/mixins/_where.py +374 -0
- sqlspec/statement/builder/mixins/_window_functions.py +86 -0
- sqlspec/statement/builder/protocols.py +20 -0
- sqlspec/statement/builder/select.py +206 -0
- sqlspec/statement/builder/update.py +178 -0
- sqlspec/statement/filters.py +571 -0
- sqlspec/statement/parameters.py +736 -0
- sqlspec/statement/pipelines/__init__.py +67 -0
- sqlspec/statement/pipelines/analyzers/__init__.py +9 -0
- sqlspec/statement/pipelines/analyzers/_analyzer.py +649 -0
- sqlspec/statement/pipelines/base.py +315 -0
- sqlspec/statement/pipelines/context.py +119 -0
- sqlspec/statement/pipelines/result_types.py +41 -0
- sqlspec/statement/pipelines/transformers/__init__.py +8 -0
- sqlspec/statement/pipelines/transformers/_expression_simplifier.py +256 -0
- sqlspec/statement/pipelines/transformers/_literal_parameterizer.py +623 -0
- sqlspec/statement/pipelines/transformers/_remove_comments.py +66 -0
- sqlspec/statement/pipelines/transformers/_remove_hints.py +81 -0
- sqlspec/statement/pipelines/validators/__init__.py +23 -0
- sqlspec/statement/pipelines/validators/_dml_safety.py +275 -0
- sqlspec/statement/pipelines/validators/_parameter_style.py +297 -0
- sqlspec/statement/pipelines/validators/_performance.py +703 -0
- sqlspec/statement/pipelines/validators/_security.py +990 -0
- sqlspec/statement/pipelines/validators/base.py +67 -0
- sqlspec/statement/result.py +527 -0
- sqlspec/statement/splitter.py +701 -0
- sqlspec/statement/sql.py +1198 -0
- sqlspec/storage/__init__.py +15 -0
- sqlspec/storage/backends/__init__.py +0 -0
- sqlspec/storage/backends/base.py +166 -0
- sqlspec/storage/backends/fsspec.py +315 -0
- sqlspec/storage/backends/obstore.py +464 -0
- sqlspec/storage/protocol.py +170 -0
- sqlspec/storage/registry.py +315 -0
- sqlspec/typing.py +157 -36
- sqlspec/utils/correlation.py +155 -0
- sqlspec/utils/deprecation.py +3 -6
- sqlspec/utils/fixtures.py +6 -11
- sqlspec/utils/logging.py +135 -0
- sqlspec/utils/module_loader.py +45 -43
- sqlspec/utils/serializers.py +4 -0
- sqlspec/utils/singleton.py +6 -8
- sqlspec/utils/sync_tools.py +15 -27
- sqlspec/utils/text.py +58 -26
- {sqlspec-0.11.0.dist-info → sqlspec-0.12.0.dist-info}/METADATA +100 -26
- sqlspec-0.12.0.dist-info/RECORD +145 -0
- sqlspec/adapters/bigquery/config/__init__.py +0 -3
- sqlspec/adapters/bigquery/config/_common.py +0 -40
- sqlspec/adapters/bigquery/config/_sync.py +0 -87
- sqlspec/adapters/oracledb/config/__init__.py +0 -9
- sqlspec/adapters/oracledb/config/_asyncio.py +0 -186
- sqlspec/adapters/oracledb/config/_common.py +0 -131
- sqlspec/adapters/oracledb/config/_sync.py +0 -186
- sqlspec/adapters/psycopg/config/__init__.py +0 -19
- sqlspec/adapters/psycopg/config/_async.py +0 -169
- sqlspec/adapters/psycopg/config/_common.py +0 -56
- sqlspec/adapters/psycopg/config/_sync.py +0 -168
- sqlspec/filters.py +0 -330
- sqlspec/mixins.py +0 -306
- sqlspec/statement.py +0 -378
- sqlspec-0.11.0.dist-info/RECORD +0 -69
- {sqlspec-0.11.0.dist-info → sqlspec-0.12.0.dist-info}/WHEEL +0 -0
- {sqlspec-0.11.0.dist-info → sqlspec-0.12.0.dist-info}/licenses/LICENSE +0 -0
- {sqlspec-0.11.0.dist-info → sqlspec-0.12.0.dist-info}/licenses/NOTICE +0 -0
|
@@ -0,0 +1,67 @@
|
|
|
1
|
+
# Base class for validators
|
|
2
|
+
from abc import ABC, abstractmethod
|
|
3
|
+
from typing import TYPE_CHECKING, Optional
|
|
4
|
+
|
|
5
|
+
from sqlspec.exceptions import RiskLevel
|
|
6
|
+
from sqlspec.statement.pipelines.base import ProcessorProtocol
|
|
7
|
+
from sqlspec.statement.pipelines.result_types import ValidationError
|
|
8
|
+
|
|
9
|
+
if TYPE_CHECKING:
|
|
10
|
+
from sqlglot import exp
|
|
11
|
+
|
|
12
|
+
from sqlspec.statement.pipelines.context import SQLProcessingContext
|
|
13
|
+
|
|
14
|
+
# Public API of this module: only the validator base class is exported.
__all__ = ("BaseValidator",)
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class BaseValidator(ProcessorProtocol, ABC):
    """Base class for all validators.

    A validator inspects a SQL expression and records findings on the
    processing context; it never transforms the expression. Subclasses
    implement :meth:`validate` and can use :meth:`add_error` to record
    findings on the context.
    """

    def process(
        self, expression: "Optional[exp.Expression]", context: "SQLProcessingContext"
    ) -> "Optional[exp.Expression]":
        """Process the SQL expression through this validator.

        Args:
            expression: The SQL expression to validate. ``None`` is passed
                straight through without calling :meth:`validate`.
            context: The SQL processing context.

        Returns:
            The expression unchanged (validators don't transform), or ``None``
            when no expression was supplied.
        """
        if expression is None:
            return None
        self.validate(expression, context)
        return expression

    @abstractmethod
    def validate(self, expression: "exp.Expression", context: "SQLProcessingContext") -> None:
        """Validate the expression and add any errors to the context.

        Args:
            expression: The SQL expression to validate (never ``None`` when
                called from :meth:`process`).
            context: The SQL processing context.
        """
        raise NotImplementedError

    def add_error(
        self,
        context: "SQLProcessingContext",
        message: str,
        code: str,
        risk_level: RiskLevel,
        expression: "Optional[exp.Expression]" = None,
    ) -> None:
        """Helper to add a validation error to the context.

        The error is tagged with this validator's class name as its
        ``processor`` and appended to ``context.validation_errors``.

        Args:
            context: The SQL processing context.
            message: The error message.
            code: The error code.
            risk_level: The risk level.
            expression: The specific expression with the error (optional).
        """
        error = ValidationError(
            message=message,
            code=code,
            risk_level=risk_level,
            processor=self.__class__.__name__,
            expression=expression,
        )
        context.validation_errors.append(error)
|
|
@@ -0,0 +1,527 @@
|
|
|
1
|
+
"""SQL statement result classes for handling different types of SQL operations."""
|
|
2
|
+
|
|
3
|
+
from abc import ABC, abstractmethod
|
|
4
|
+
|
|
5
|
+
# Mapping and Sequence are used for runtime row-shape checks (__post_init__ and the scalar helpers)
|
|
6
|
+
from collections.abc import Mapping, Sequence
|
|
7
|
+
from dataclasses import dataclass, field
|
|
8
|
+
from typing import TYPE_CHECKING, Any, Generic, Optional, Union, cast
|
|
9
|
+
|
|
10
|
+
from typing_extensions import TypedDict, TypeVar
|
|
11
|
+
|
|
12
|
+
from sqlspec.typing import ArrowTable, RowT
|
|
13
|
+
|
|
14
|
+
if TYPE_CHECKING:
|
|
15
|
+
from sqlspec.statement.sql import SQL
|
|
16
|
+
|
|
17
|
+
# Public API of this module (result containers and the driver-result TypedDicts).
__all__ = ("ArrowResult", "DMLResultDict", "SQLResult", "ScriptResultDict", "SelectResultDict", "StatementResult")
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
# NOTE(review): T is not referenced anywhere else in this module; confirm no
# other module imports it from here before removing it.
T = TypeVar("T")
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class SelectResultDict(TypedDict):
    """Shape of a driver result for row-returning statements.

    Drivers produce this mapping for SELECT queries and for DML statements
    that carry a RETURNING clause. Every key is required.
    """

    # Rows produced by the statement.
    data: "list[Any]"
    # Names of the columns in the result set.
    column_names: "list[str]"
    # Count of affected rows; -1 signals the driver could not report it.
    rows_affected: int
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
class DMLResultDict(TypedDict, total=False):
    """TypedDict for DML (INSERT/UPDATE/DELETE) results without RETURNING.

    This structure is returned by drivers when executing DML operations
    that don't return data (no RETURNING clause). It is declared with
    ``total=False``, so every key is optional.
    """

    rows_affected: int
    """Number of rows affected by the operation (-1 when unsupported)."""
    status_message: str
    """Status message from the database."""
    description: str
    """Optional description of the operation."""
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
class ScriptResultDict(TypedDict, total=False):
    """Shape of a driver result for multi-statement SQL scripts.

    Declared with ``total=False``: drivers may populate any subset of keys.
    """

    # How many statements the script ran.
    statements_executed: int
    # Overall status text reported for the script run.
    status_message: str
    # Free-form description of the script execution.
    description: str
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
@dataclass
class StatementResult(ABC, Generic[RowT]):
    """Common base for the outcome of executing a SQL statement.

    Concrete subclasses decide what "success" means for their kind of
    operation and how the raw ``data`` payload is presented to callers.

    Args:
        statement: The SQL statement that produced this result.
        data: Raw result payload returned by the operation.
        rows_affected: Count of rows touched by the operation, where applicable.
        last_inserted_id: Identifier of the last inserted row, where applicable.
        execution_time: Seconds taken to execute the statement, if recorded.
        metadata: Free-form extra information about the operation.
    """

    # The SQL statement that produced this result.
    statement: "SQL"
    # Raw result payload returned by the operation.
    data: "Any"
    # Count of rows touched by the operation.
    rows_affected: int = 0
    # Identifier of the last inserted row (None when not applicable).
    last_inserted_id: Optional[Union[int, str]] = None
    # Seconds taken to execute the statement (None when not recorded).
    execution_time: Optional[float] = None
    # Free-form extra information; a fresh dict per instance.
    metadata: "dict[str, Any]" = field(default_factory=dict)

    @abstractmethod
    def is_success(self) -> bool:
        """Report whether the operation succeeded.

        Returns:
            ``True`` when the operation completed successfully, else ``False``.
        """

    @abstractmethod
    def get_data(self) -> "Any":
        """Return the result payload in the subclass's preferred form.

        Returns:
            The processed result data.
        """

    def get_metadata(self, key: str, default: Any = None) -> Any:
        """Look up a single metadata entry.

        Args:
            key: Name of the metadata entry.
            default: Value returned when ``key`` is absent.

        Returns:
            The stored value, or ``default`` if there is none.
        """
        store = self.metadata
        return store.get(key, default)

    def set_metadata(self, key: str, value: Any) -> None:
        """Store a metadata entry, replacing any existing value for ``key``.

        Args:
            key: Name of the metadata entry.
            value: Value to store under ``key``.
        """
        store = self.metadata
        store[key] = value
|
|
134
|
+
|
|
135
|
+
|
|
136
|
+
# SQLResult is generic over RowT (imported from sqlspec.typing), the type of a single row.
|
|
137
|
+
|
|
138
|
+
|
|
139
|
+
@dataclass
class SQLResult(StatementResult[RowT], Generic[RowT]):
    """Unified result class for SQL operations that return a list of rows
    or affect rows (e.g., SELECT, INSERT, UPDATE, DELETE).

    For DML operations with RETURNING clauses, the returned data will be in `self.data`.
    The `operation_type` attribute helps distinguish the nature of the operation.

    For script execution, this class also tracks multiple statement results and errors.
    """

    error: Optional[Exception] = None
    operation_index: Optional[int] = None
    pipeline_sql: Optional["SQL"] = None
    parameters: Optional[Any] = None

    # Attributes primarily for SELECT-like results or results with column structure
    column_names: "list[str]" = field(default_factory=list)
    total_count: Optional[int] = None  # Total rows if pagination/limit was involved
    has_more: bool = False  # For pagination

    # Attributes primarily for DML-like results
    operation_type: str = "SELECT"  # Default, override for DML
    inserted_ids: "list[Union[int, str]]" = field(default_factory=list)
    # rows_affected and last_inserted_id are inherited from StatementResult

    # Attributes for script execution
    statement_results: "list[SQLResult[Any]]" = field(default_factory=list)
    """Individual statement results when executing scripts."""
    errors: "list[str]" = field(default_factory=list)
    """Errors encountered during script execution."""
    total_statements: int = 0
    """Total number of statements in the script."""
    successful_statements: int = 0
    """Number of statements that executed successfully."""

    def __post_init__(self) -> None:
        """Infer ``column_names`` and ``total_count`` when they were not provided.

        Column names are taken from the keys of the first row, but only when
        that row is a Mapping. ``total_count`` defaults to ``len(self.data)``
        (0 when ``data`` is None).
        """
        if not self.column_names and self.data and isinstance(self.data[0], Mapping):
            self.column_names = list(self.data[0].keys())

        if self.total_count is None:
            self.total_count = len(self.data) if self.data is not None else 0

        # If data is populated for a DML, it implies returning data.
        # No separate returning_data field needed; self.data serves this purpose.

    def is_success(self) -> bool:
        """Check if the operation was successful.

        - For SCRIPT (or any result carrying ``statement_results``): True if no
          errors were recorded and every statement succeeded.
        - For SELECT: True if data is not None and rows_affected is not negative.
        - For DML (INSERT, UPDATE, DELETE, EXECUTE): True if rows_affected is >= 0.
        - Any other ``operation_type``: False.
        """
        op_type_upper = self.operation_type.upper()

        # For script execution, check if there are no errors and all statements succeeded
        if op_type_upper == "SCRIPT" or self.statement_results:
            return len(self.errors) == 0 and self.total_statements == self.successful_statements

        if op_type_upper == "SELECT":
            # For SELECT, success means we got some data container and rows_affected is not negative
            data_success = self.data is not None
            rows_success = self.rows_affected is None or self.rows_affected >= 0
            return data_success and rows_success
        if op_type_upper in {"INSERT", "UPDATE", "DELETE", "EXECUTE"}:
            return self.rows_affected is not None and self.rows_affected >= 0
        return False  # Unrecognised operation types are never reported as successful

    def get_data(self) -> "Union[list[RowT], dict[str, Any]]":
        """Get the data from the result.

        For regular operations, returns the list of rows.
        For script operations, returns a summary dictionary.
        """
        # For script execution, return summary data
        if self.operation_type.upper() == "SCRIPT" or self.statement_results:
            return {
                "total_statements": self.total_statements,
                "successful_statements": self.successful_statements,
                "failed_statements": self.total_statements - self.successful_statements,
                "errors": self.errors,
                "statement_results": self.statement_results,
                "total_rows_affected": self.get_total_rows_affected(),
            }

        # For regular operations, return the data as usual
        return cast("list[RowT]", self.data)

    # --- Script execution methods ---

    def add_statement_result(self, result: "SQLResult[Any]") -> None:
        """Add a statement result to the script execution results.

        Increments ``total_statements`` and, when ``result.is_success()`` is
        True, ``successful_statements``.
        """
        self.statement_results.append(result)
        self.total_statements += 1
        if result.is_success():
            self.successful_statements += 1

    def add_error(self, error: str) -> None:
        """Add an error message to the script execution errors."""
        self.errors.append(error)

    def get_statement_result(self, index: int) -> "Optional[SQLResult[Any]]":
        """Get a statement result by index (negative indices are not supported).

        Returns:
            The result at ``index``, or None if the index is out of range.
        """
        if 0 <= index < len(self.statement_results):
            return self.statement_results[index]
        return None

    def get_total_rows_affected(self) -> int:
        """Get the total number of rows affected across all statements.

        Negative counts (e.g. -1 sentinels) are treated as 0.
        """
        if self.statement_results:
            # For script execution, sum up rows affected from all statements
            total = 0
            for stmt_result in self.statement_results:
                if stmt_result.rows_affected is not None and stmt_result.rows_affected >= 0:
                    # Only count non-negative values, -1 indicates failure
                    total += stmt_result.rows_affected
            return total
        # For single statement execution
        return max(self.rows_affected or 0, 0)  # Treat negative values as 0

    @property
    def num_rows(self) -> int:
        """Total rows affected, as returned by :meth:`get_total_rows_affected`.

        Note: this is the affected-row total, not the number of rows held in
        ``data``; use :meth:`get_count` or ``len(result)`` for that.
        """
        return self.get_total_rows_affected()

    @property
    def num_columns(self) -> int:
        """Get the number of columns in the result data."""
        return len(self.column_names) if self.column_names else 0

    def get_errors(self) -> "list[str]":
        """Get a copy of all errors from script execution."""
        return self.errors.copy()

    def has_errors(self) -> bool:
        """Check if there are any errors from script execution."""
        return len(self.errors) > 0

    # --- Existing methods for regular operations ---

    def get_first(self) -> "Optional[RowT]":
        """Get the first row from the result, if any."""
        return self.data[0] if self.data else None

    def get_count(self) -> int:
        """Get the number of rows in the current result set (e.g., a page of data)."""
        return len(self.data) if self.data is not None else 0

    def is_empty(self) -> bool:
        """Check if the result set (self.data) is empty."""
        return not self.data

    # --- Methods related to DML operations ---
    def get_affected_count(self) -> int:
        """Get the number of rows affected by a DML operation.

        Unlike :meth:`get_total_rows_affected`, a negative driver sentinel
        (e.g. -1) is returned as-is; only None/0 map to 0.
        """
        return self.rows_affected or 0

    def get_inserted_id(self) -> "Optional[Union[int, str]]":
        """Get the last inserted ID (typically for single row inserts)."""
        return self.last_inserted_id

    def get_inserted_ids(self) -> "list[Union[int, str]]":
        """Get all inserted IDs (useful for batch inserts).

        Returns the underlying list itself, not a copy.
        """
        return self.inserted_ids

    def get_returning_data(self) -> "list[RowT]":
        """Get data returned by RETURNING clauses.

        This is effectively self.data for this unified class.
        """
        return cast("list[RowT]", self.data)

    def was_inserted(self) -> bool:
        """Check if this was an INSERT operation."""
        return self.operation_type.upper() == "INSERT"

    def was_updated(self) -> bool:
        """Check if this was an UPDATE operation."""
        return self.operation_type.upper() == "UPDATE"

    def was_deleted(self) -> bool:
        """Check if this was a DELETE operation."""
        return self.operation_type.upper() == "DELETE"

    def __len__(self) -> int:
        """Get the number of rows in the result set.

        Returns:
            Number of rows in the data (0 when data is None).
        """
        return len(self.data) if self.data is not None else 0

    def __getitem__(self, index: int) -> "RowT":
        """Get a row by index.

        Args:
            index: Row index

        Returns:
            The row at the specified index

        Raises:
            TypeError: If data is None
        """
        if self.data is None:
            msg = "No data available"
            raise TypeError(msg)
        return cast("RowT", self.data[index])

    # --- SQLAlchemy-style convenience methods ---

    def all(self) -> "list[RowT]":
        """Return all rows as a list.

        Returns:
            List of all rows in the result (empty when data is None)
        """
        if self.data is None:
            return []
        return cast("list[RowT]", self.data)

    def one(self) -> "RowT":
        """Return exactly one row.

        Returns:
            The single row

        Raises:
            ValueError: If no results or more than one result
        """
        if self.data is None or len(self.data) == 0:
            msg = "No result found, exactly one row expected"
            raise ValueError(msg)
        if len(self.data) > 1:
            msg = f"Multiple results found ({len(self.data)}), exactly one row expected"
            raise ValueError(msg)
        return cast("RowT", self.data[0])

    def one_or_none(self) -> "Optional[RowT]":
        """Return at most one row.

        Returns:
            The single row or None if no results

        Raises:
            ValueError: If more than one result
        """
        if self.data is None or len(self.data) == 0:
            return None
        if len(self.data) > 1:
            msg = f"Multiple results found ({len(self.data)}), at most one row expected"
            raise ValueError(msg)
        return cast("RowT", self.data[0])

    def scalar(self) -> Any:
        """Return the first column of the single result row.

        The row is obtained via :meth:`one`, so exactly one row must exist.
        Mapping rows yield their first value, non-string sequences their first
        element, and any other row object is returned unchanged.

        Returns:
            The scalar value from the first column of the single row

        Raises:
            ValueError: If there are no rows, more than one row, or the row
                has no columns
        """
        row = self.one()
        if isinstance(row, Mapping):
            # For dict-like rows, get the first column value
            if not row:
                msg = "Row has no columns"
                raise ValueError(msg)
            first_key = cast("str", next(iter(row.keys())))
            return cast("Any", row[first_key])
        if isinstance(row, Sequence) and not isinstance(row, (str, bytes)):
            # For tuple/list-like rows
            if len(row) == 0:
                msg = "Row has no columns"
                raise ValueError(msg)
            return cast("Any", row[0])
        # For scalar values returned directly
        return row

    def scalar_or_none(self) -> Any:
        """Return the first column of the single result row, or None.

        The row is obtained via :meth:`one_or_none`. None is returned when
        there are no rows or the row has no columns.

        Returns:
            The scalar value from the first column of the single row, or None

        Raises:
            ValueError: If more than one row is present
        """
        row = self.one_or_none()
        if row is None:
            return None

        if isinstance(row, Mapping):
            if not row:
                return None
            first_key = next(iter(row.keys()))
            return cast("Any", row[first_key])
        if isinstance(row, Sequence) and not isinstance(row, (str, bytes)):
            # For tuple/list-like rows
            if len(row) == 0:
                return None
            return cast("Any", row[0])
        # For scalar values returned directly
        return row
|
|
438
|
+
|
|
439
|
+
|
|
440
|
+
@dataclass
class ArrowResult(StatementResult[ArrowTable]):
    """Result class for SQL operations that return Apache Arrow data.

    This class is used when database drivers support returning results as
    Apache Arrow format for high-performance data interchange, especially
    useful for analytics workloads and data science applications.

    Args:
        statement: The original SQL statement that was executed.
        data: The Apache Arrow Table containing the result data.
        schema: Optional Arrow schema information.
    """

    schema: Optional["dict[str, Any]"] = None
    """Optional Arrow schema information."""
    # Redeclared only to narrow the type; dataclasses keep an inherited field's
    # original position, so `data` stays a required field right after `statement`.
    data: "ArrowTable"
    """The result data from the operation."""

    def is_success(self) -> bool:
        """Check if the Arrow operation was successful.

        A zero-row table is a valid, successful result. Success is therefore
        judged by the presence of a table, not by its truthiness. This matches
        the SELECT rule used by SQLResult.

        Returns:
            True if an Arrow table is present.
        """
        # bool(table) is False for an empty pyarrow.Table (it defines __len__ as
        # its row count), which would misreport successful empty queries.
        return self.data is not None

    def get_data(self) -> "ArrowTable":
        """Get the Apache Arrow Table from the result.

        Returns:
            The Arrow table containing the result data.

        Raises:
            ValueError: If no Arrow table is available.
        """
        if self.data is None:
            msg = "No Arrow table available for this result"
            raise ValueError(msg)
        return self.data

    @property
    def column_names(self) -> "list[str]":
        """Get the column names from the Arrow table.

        Returns:
            List of column names.

        Raises:
            ValueError: If no Arrow table is available.
        """
        if self.data is None:
            msg = "No Arrow table available"
            raise ValueError(msg)

        return self.data.column_names

    @property
    def num_rows(self) -> int:
        """Get the number of rows in the Arrow table.

        Returns:
            Number of rows.

        Raises:
            ValueError: If no Arrow table is available.
        """
        if self.data is None:
            msg = "No Arrow table available"
            raise ValueError(msg)

        return self.data.num_rows

    @property
    def num_columns(self) -> int:
        """Get the number of columns in the Arrow table.

        Returns:
            Number of columns.

        Raises:
            ValueError: If no Arrow table is available.
        """
        if self.data is None:
            msg = "No Arrow table available"
            raise ValueError(msg)

        return self.data.num_columns
|