sqlspec 0.11.0__py3-none-any.whl → 0.12.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of sqlspec might be problematic. Click here for more details.
- sqlspec/__init__.py +16 -3
- sqlspec/_serialization.py +3 -10
- sqlspec/_sql.py +1147 -0
- sqlspec/_typing.py +343 -41
- sqlspec/adapters/adbc/__init__.py +2 -6
- sqlspec/adapters/adbc/config.py +474 -149
- sqlspec/adapters/adbc/driver.py +330 -644
- sqlspec/adapters/aiosqlite/__init__.py +2 -6
- sqlspec/adapters/aiosqlite/config.py +143 -57
- sqlspec/adapters/aiosqlite/driver.py +269 -462
- sqlspec/adapters/asyncmy/__init__.py +3 -8
- sqlspec/adapters/asyncmy/config.py +247 -202
- sqlspec/adapters/asyncmy/driver.py +217 -451
- sqlspec/adapters/asyncpg/__init__.py +4 -7
- sqlspec/adapters/asyncpg/config.py +329 -176
- sqlspec/adapters/asyncpg/driver.py +418 -498
- sqlspec/adapters/bigquery/__init__.py +2 -2
- sqlspec/adapters/bigquery/config.py +407 -0
- sqlspec/adapters/bigquery/driver.py +592 -634
- sqlspec/adapters/duckdb/__init__.py +4 -1
- sqlspec/adapters/duckdb/config.py +432 -321
- sqlspec/adapters/duckdb/driver.py +393 -436
- sqlspec/adapters/oracledb/__init__.py +3 -8
- sqlspec/adapters/oracledb/config.py +625 -0
- sqlspec/adapters/oracledb/driver.py +549 -942
- sqlspec/adapters/psqlpy/__init__.py +4 -7
- sqlspec/adapters/psqlpy/config.py +372 -203
- sqlspec/adapters/psqlpy/driver.py +197 -550
- sqlspec/adapters/psycopg/__init__.py +3 -8
- sqlspec/adapters/psycopg/config.py +741 -0
- sqlspec/adapters/psycopg/driver.py +732 -733
- sqlspec/adapters/sqlite/__init__.py +2 -6
- sqlspec/adapters/sqlite/config.py +146 -81
- sqlspec/adapters/sqlite/driver.py +243 -426
- sqlspec/base.py +220 -825
- sqlspec/config.py +354 -0
- sqlspec/driver/__init__.py +22 -0
- sqlspec/driver/_async.py +252 -0
- sqlspec/driver/_common.py +338 -0
- sqlspec/driver/_sync.py +261 -0
- sqlspec/driver/mixins/__init__.py +17 -0
- sqlspec/driver/mixins/_pipeline.py +523 -0
- sqlspec/driver/mixins/_result_utils.py +122 -0
- sqlspec/driver/mixins/_sql_translator.py +35 -0
- sqlspec/driver/mixins/_storage.py +993 -0
- sqlspec/driver/mixins/_type_coercion.py +131 -0
- sqlspec/exceptions.py +299 -7
- sqlspec/extensions/aiosql/__init__.py +10 -0
- sqlspec/extensions/aiosql/adapter.py +474 -0
- sqlspec/extensions/litestar/__init__.py +1 -6
- sqlspec/extensions/litestar/_utils.py +1 -5
- sqlspec/extensions/litestar/config.py +5 -6
- sqlspec/extensions/litestar/handlers.py +13 -12
- sqlspec/extensions/litestar/plugin.py +22 -24
- sqlspec/extensions/litestar/providers.py +37 -55
- sqlspec/loader.py +528 -0
- sqlspec/service/__init__.py +3 -0
- sqlspec/service/base.py +24 -0
- sqlspec/service/pagination.py +26 -0
- sqlspec/statement/__init__.py +21 -0
- sqlspec/statement/builder/__init__.py +54 -0
- sqlspec/statement/builder/_ddl_utils.py +119 -0
- sqlspec/statement/builder/_parsing_utils.py +135 -0
- sqlspec/statement/builder/base.py +328 -0
- sqlspec/statement/builder/ddl.py +1379 -0
- sqlspec/statement/builder/delete.py +80 -0
- sqlspec/statement/builder/insert.py +274 -0
- sqlspec/statement/builder/merge.py +95 -0
- sqlspec/statement/builder/mixins/__init__.py +65 -0
- sqlspec/statement/builder/mixins/_aggregate_functions.py +151 -0
- sqlspec/statement/builder/mixins/_case_builder.py +91 -0
- sqlspec/statement/builder/mixins/_common_table_expr.py +91 -0
- sqlspec/statement/builder/mixins/_delete_from.py +34 -0
- sqlspec/statement/builder/mixins/_from.py +61 -0
- sqlspec/statement/builder/mixins/_group_by.py +119 -0
- sqlspec/statement/builder/mixins/_having.py +35 -0
- sqlspec/statement/builder/mixins/_insert_from_select.py +48 -0
- sqlspec/statement/builder/mixins/_insert_into.py +36 -0
- sqlspec/statement/builder/mixins/_insert_values.py +69 -0
- sqlspec/statement/builder/mixins/_join.py +110 -0
- sqlspec/statement/builder/mixins/_limit_offset.py +53 -0
- sqlspec/statement/builder/mixins/_merge_clauses.py +405 -0
- sqlspec/statement/builder/mixins/_order_by.py +46 -0
- sqlspec/statement/builder/mixins/_pivot.py +82 -0
- sqlspec/statement/builder/mixins/_returning.py +37 -0
- sqlspec/statement/builder/mixins/_select_columns.py +60 -0
- sqlspec/statement/builder/mixins/_set_ops.py +122 -0
- sqlspec/statement/builder/mixins/_unpivot.py +80 -0
- sqlspec/statement/builder/mixins/_update_from.py +54 -0
- sqlspec/statement/builder/mixins/_update_set.py +91 -0
- sqlspec/statement/builder/mixins/_update_table.py +29 -0
- sqlspec/statement/builder/mixins/_where.py +374 -0
- sqlspec/statement/builder/mixins/_window_functions.py +86 -0
- sqlspec/statement/builder/protocols.py +20 -0
- sqlspec/statement/builder/select.py +206 -0
- sqlspec/statement/builder/update.py +178 -0
- sqlspec/statement/filters.py +571 -0
- sqlspec/statement/parameters.py +736 -0
- sqlspec/statement/pipelines/__init__.py +67 -0
- sqlspec/statement/pipelines/analyzers/__init__.py +9 -0
- sqlspec/statement/pipelines/analyzers/_analyzer.py +649 -0
- sqlspec/statement/pipelines/base.py +315 -0
- sqlspec/statement/pipelines/context.py +119 -0
- sqlspec/statement/pipelines/result_types.py +41 -0
- sqlspec/statement/pipelines/transformers/__init__.py +8 -0
- sqlspec/statement/pipelines/transformers/_expression_simplifier.py +256 -0
- sqlspec/statement/pipelines/transformers/_literal_parameterizer.py +623 -0
- sqlspec/statement/pipelines/transformers/_remove_comments.py +66 -0
- sqlspec/statement/pipelines/transformers/_remove_hints.py +81 -0
- sqlspec/statement/pipelines/validators/__init__.py +23 -0
- sqlspec/statement/pipelines/validators/_dml_safety.py +275 -0
- sqlspec/statement/pipelines/validators/_parameter_style.py +297 -0
- sqlspec/statement/pipelines/validators/_performance.py +703 -0
- sqlspec/statement/pipelines/validators/_security.py +990 -0
- sqlspec/statement/pipelines/validators/base.py +67 -0
- sqlspec/statement/result.py +527 -0
- sqlspec/statement/splitter.py +701 -0
- sqlspec/statement/sql.py +1198 -0
- sqlspec/storage/__init__.py +15 -0
- sqlspec/storage/backends/__init__.py +0 -0
- sqlspec/storage/backends/base.py +166 -0
- sqlspec/storage/backends/fsspec.py +315 -0
- sqlspec/storage/backends/obstore.py +464 -0
- sqlspec/storage/protocol.py +170 -0
- sqlspec/storage/registry.py +315 -0
- sqlspec/typing.py +157 -36
- sqlspec/utils/correlation.py +155 -0
- sqlspec/utils/deprecation.py +3 -6
- sqlspec/utils/fixtures.py +6 -11
- sqlspec/utils/logging.py +135 -0
- sqlspec/utils/module_loader.py +45 -43
- sqlspec/utils/serializers.py +4 -0
- sqlspec/utils/singleton.py +6 -8
- sqlspec/utils/sync_tools.py +15 -27
- sqlspec/utils/text.py +58 -26
- {sqlspec-0.11.0.dist-info → sqlspec-0.12.0.dist-info}/METADATA +100 -26
- sqlspec-0.12.0.dist-info/RECORD +145 -0
- sqlspec/adapters/bigquery/config/__init__.py +0 -3
- sqlspec/adapters/bigquery/config/_common.py +0 -40
- sqlspec/adapters/bigquery/config/_sync.py +0 -87
- sqlspec/adapters/oracledb/config/__init__.py +0 -9
- sqlspec/adapters/oracledb/config/_asyncio.py +0 -186
- sqlspec/adapters/oracledb/config/_common.py +0 -131
- sqlspec/adapters/oracledb/config/_sync.py +0 -186
- sqlspec/adapters/psycopg/config/__init__.py +0 -19
- sqlspec/adapters/psycopg/config/_async.py +0 -169
- sqlspec/adapters/psycopg/config/_common.py +0 -56
- sqlspec/adapters/psycopg/config/_sync.py +0 -168
- sqlspec/filters.py +0 -330
- sqlspec/mixins.py +0 -306
- sqlspec/statement.py +0 -378
- sqlspec-0.11.0.dist-info/RECORD +0 -69
- {sqlspec-0.11.0.dist-info → sqlspec-0.12.0.dist-info}/WHEEL +0 -0
- {sqlspec-0.11.0.dist-info → sqlspec-0.12.0.dist-info}/licenses/LICENSE +0 -0
- {sqlspec-0.11.0.dist-info → sqlspec-0.12.0.dist-info}/licenses/NOTICE +0 -0
|
@@ -0,0 +1,119 @@
|
|
|
1
|
+
"""DDL builder utilities."""
|
|
2
|
+
|
|
3
|
+
from typing import TYPE_CHECKING, Optional
|
|
4
|
+
|
|
5
|
+
from sqlglot import exp
|
|
6
|
+
|
|
7
|
+
if TYPE_CHECKING:
|
|
8
|
+
from sqlspec.statement.builder.ddl import ColumnDefinition, ConstraintDefinition
|
|
9
|
+
|
|
10
|
+
__all__ = ("build_column_expression", "build_constraint_expression")
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def build_column_expression(col: "ColumnDefinition") -> "exp.Expression":
    """Build the SQLGlot ``ColumnDef`` expression for a column definition.

    Args:
        col: Column definition. The attributes read here are ``name``, ``dtype``,
            ``not_null``, ``primary_key``, ``unique``, ``default``, ``check``,
            ``comment``, ``generated`` and ``collate``.

    Returns:
        exp.Expression: An ``exp.ColumnDef`` whose ``constraints`` arg holds one
        ``exp.ColumnConstraint`` per option that is set on ``col``.
    """
    # Start with column name and type
    col_def = exp.ColumnDef(this=exp.to_identifier(col.name), kind=exp.DataType.build(col.dtype))

    # Collect constraints in a fixed order so generated DDL is deterministic
    constraints: list[exp.ColumnConstraint] = []

    if col.not_null:
        constraints.append(exp.ColumnConstraint(kind=exp.NotNullColumnConstraint()))

    if col.primary_key:
        constraints.append(exp.ColumnConstraint(kind=exp.PrimaryKeyColumnConstraint()))

    if col.unique:
        constraints.append(exp.ColumnConstraint(kind=exp.UniqueColumnConstraint()))

    if col.default is not None:
        default_expr: Optional[exp.Expression]
        if isinstance(col.default, bool):
            # bool must be tested before int/float: True/False are int instances,
            # so the numeric branch would otherwise swallow them.
            default_expr = exp.true() if col.default else exp.false()
        elif isinstance(col.default, str):
            # Strings that look like a function call or a well-known keyword are
            # parsed as SQL; everything else is emitted as a quoted string literal.
            if col.default.upper() in {"CURRENT_TIMESTAMP", "CURRENT_DATE", "CURRENT_TIME"} or "(" in col.default:
                default_expr = exp.maybe_parse(col.default)
            else:
                default_expr = exp.Literal.string(col.default)
        elif isinstance(col.default, (int, float)):
            default_expr = exp.Literal.number(col.default)
        else:
            default_expr = exp.Literal.string(str(col.default))

        # Wrap the value in DefaultColumnConstraint so the generator emits the
        # DEFAULT keyword; a bare value as the constraint kind renders without it.
        constraints.append(exp.ColumnConstraint(kind=exp.DefaultColumnConstraint(this=default_expr)))

    if col.check:
        check_expr = exp.Check(this=exp.maybe_parse(col.check))
        constraints.append(exp.ColumnConstraint(kind=check_expr))

    if col.comment:
        constraints.append(exp.ColumnConstraint(kind=exp.CommentColumnConstraint(this=exp.Literal.string(col.comment))))

    if col.generated:
        # NOTE(review): this is meant for generated (computed) columns, but it builds an
        # identity constraint node - confirm whether exp.ComputedColumnConstraint is intended.
        generated_expr = exp.GeneratedAsIdentityColumnConstraint(this=exp.maybe_parse(col.generated))
        constraints.append(exp.ColumnConstraint(kind=generated_expr))

    if col.collate:
        constraints.append(exp.ColumnConstraint(kind=exp.CollateColumnConstraint(this=exp.to_identifier(col.collate))))

    # Only set the arg when there is something to attach
    if constraints:
        col_def.set("constraints", constraints)

    return col_def
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
def build_constraint_expression(constraint: "ConstraintDefinition") -> "Optional[exp.Expression]":
    """Build the SQLGlot expression for a table-level constraint.

    Args:
        constraint: Constraint definition. ``constraint_type`` selects the branch;
            ``columns``, ``references_table``, ``references_columns``, ``on_delete``,
            ``on_update``, ``condition`` and ``name`` are read as needed.

    Returns:
        Optional[exp.Expression]: The constraint expression, wrapped in a named
        ``exp.Constraint`` when ``constraint.name`` is set, or ``None`` when
        ``constraint_type`` is not one of ``PRIMARY KEY``, ``FOREIGN KEY``,
        ``UNIQUE`` or ``CHECK``.
    """
    constraint_type = constraint.constraint_type
    body: exp.Expression

    if constraint_type == "PRIMARY KEY":
        body = exp.PrimaryKey(expressions=[exp.to_identifier(col) for col in constraint.columns])

    elif constraint_type == "FOREIGN KEY":
        body = exp.ForeignKey(
            expressions=[exp.to_identifier(col) for col in constraint.columns],
            reference=exp.Reference(
                this=exp.to_table(constraint.references_table) if constraint.references_table else None,
                expressions=[exp.to_identifier(col) for col in constraint.references_columns],
            ),
            # Referential actions belong to the ForeignKey node ("delete"/"update");
            # Reference has no such args, so passing them there drops them from the SQL.
            delete=constraint.on_delete,
            update=constraint.on_update,
        )

    elif constraint_type == "UNIQUE":
        # NOTE(review): UniqueKeyProperty is a table property node rather than a
        # constraint node - confirm it renders as intended for the target dialects.
        body = exp.UniqueKeyProperty(expressions=[exp.to_identifier(col) for col in constraint.columns])

    elif constraint_type == "CHECK":
        body = exp.Check(this=exp.maybe_parse(constraint.condition) if constraint.condition else None)

    else:
        return None

    if constraint.name:
        # exp.Constraint carries its body in the "expressions" list; an "expression"
        # keyword is not one of its args and would be ignored by the generator.
        return exp.Constraint(this=exp.to_identifier(constraint.name), expressions=[body])
    return body
|
|
@@ -0,0 +1,135 @@
|
|
|
1
|
+
"""Centralized parsing utilities for SQLSpec builders.
|
|
2
|
+
|
|
3
|
+
This module provides common parsing functions to handle complex SQL expressions
|
|
4
|
+
that users might pass as strings to various builder methods.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
import contextlib
|
|
8
|
+
from typing import Any, Optional, Union, cast
|
|
9
|
+
|
|
10
|
+
from sqlglot import exp, maybe_parse, parse_one
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def parse_column_expression(column_input: Union[str, exp.Expression]) -> exp.Expression:
    """Turn a column-like input into a SQLGlot expression.

    Examples of accepted strings:
    - "name" (simple column)
    - "users.name" (qualified column)
    - "name AS user_name" (aliased column)
    - "MAX(price)" (function call)
    - "CASE WHEN ... END" (complex expression)

    Args:
        column_input: A SQL string, or an already-built SQLGlot expression.

    Returns:
        exp.Expression: ``column_input`` itself when it is already an expression;
        otherwise the parse result, or a plain column built from the string when
        parsing yields a falsy result.
    """
    if not isinstance(column_input, exp.Expression):
        parsed = exp.maybe_parse(column_input)
        return parsed if parsed else exp.column(str(column_input))
    return column_input
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
def parse_table_expression(table_input: str, explicit_alias: Optional[str] = None) -> exp.Expression:
    """Parse a table string: a name, a name with an alias, or a subquery string.

    Args:
        table_input: The table-like SQL fragment.
        explicit_alias: Optional alias to apply to the resulting expression.

    Returns:
        exp.Expression: The FROM target extracted from a wrapping SELECT, aliased
        when ``explicit_alias`` is given; falls back to ``exp.to_table`` when the
        wrapped parse raises or produces no FROM clause.
    """
    try:
        # Embedding the input in a SELECT lets the parser handle every table-like syntax
        wrapped = parse_one(f"SELECT * FROM {table_input}")
        from_clause = wrapped.args.get("from") if isinstance(wrapped, exp.Select) else None
        if from_clause:
            source = from_clause.this
            if explicit_alias:
                return exp.alias_(source, explicit_alias)  # type:ignore[no-any-return]
            return source  # type:ignore[no-any-return]
    except Exception:  # noqa: BLE001, S110
        # Any failure above is deliberately ignored in favour of the fallback below
        pass

    return exp.to_table(table_input, alias=explicit_alias)
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
def parse_order_expression(order_input: Union[str, exp.Expression]) -> exp.Expression:
    """Turn an ORDER BY item, optionally with a direction, into an expression.

    Examples of accepted strings:
    - "name" (simple column)
    - "name DESC" (with direction)
    - "users.name ASC" (qualified, with direction)
    - "COUNT(*) DESC" (function, with direction)

    Args:
        order_input: A SQL string, or an already-built SQLGlot expression.

    Returns:
        exp.Expression: ``order_input`` itself when it is already an expression;
        otherwise the result of parsing it as ``exp.Ordered``, or the result of
        ``parse_column_expression`` when that parse raises or is falsy.
    """
    if not isinstance(order_input, exp.Expression):
        try:
            ordered = maybe_parse(str(order_input), into=exp.Ordered)
        except Exception:  # noqa: BLE001
            # A failed ORDER BY parse is not fatal; fall back to column parsing
            ordered = None
        return ordered if ordered else parse_column_expression(order_input)
    return order_input
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
def parse_condition_expression(
    condition_input: Union[str, exp.Expression, tuple[str, Any]], builder: "Any" = None
) -> exp.Expression:
    """Turn a condition given as SQL text, a pair, or an expression into an expression.

    Examples of accepted inputs:
    - "name = 'John'" (simple condition)
    - ("name", "John") (column/value pair, compared with ``=``)
    - "age > 18 AND status = 'active'" (compound condition)
    - "LENGTH(name) > 5" (function condition)

    Args:
        condition_input: A SQL string, a two-item ``(column, value)`` tuple, or a
            SQLGlot expression.
        builder: Optional object exposing ``add_parameter``; when given, the tuple
            value is bound through it instead of being inlined as a literal.

    Returns:
        exp.Expression: The parsed condition. An expression input is returned as is.
    """
    if isinstance(condition_input, exp.Expression):
        return condition_input

    if isinstance(condition_input, tuple) and len(condition_input) == 2:  # noqa: PLR2004
        column, value = condition_input
        left = parse_column_expression(column)

        # A None value becomes an IS NULL test rather than an equality
        if value is None:
            return exp.Is(this=left, expression=exp.null())

        right: exp.Expression
        if builder and hasattr(builder, "add_parameter"):
            # Bind through the builder so the value never appears in the SQL text
            _, param_name = builder.add_parameter(value)
            right = exp.Placeholder(this=param_name)
        elif isinstance(value, (int, float)):
            right = exp.Literal.number(str(value))
        else:
            # Strings and every remaining type are inlined as a string literal
            right = exp.Literal.string(str(value))
        return exp.EQ(this=left, expression=right)

    condition_text = condition_input if isinstance(condition_input, str) else str(condition_input)

    try:
        # First choice: SQLGlot's dedicated condition parser
        return exp.condition(condition_text)
    except Exception:  # noqa: BLE001
        # Second choice: a general parse; any error here is ignored
        general = None
        with contextlib.suppress(Exception):
            general = exp.maybe_parse(condition_text)  # type: ignore[var-annotated]
        if general:
            return general  # type:ignore[no-any-return]

    # Last resort: retry the condition parser (an error here propagates to the caller)
    return exp.condition(condition_text)
|
|
133
|
+
|
|
134
|
+
|
|
135
|
+
__all__ = ("parse_column_expression", "parse_condition_expression", "parse_order_expression", "parse_table_expression")
|
|
@@ -0,0 +1,328 @@
|
|
|
1
|
+
"""Safe SQL query builder with validation and parameter binding.
|
|
2
|
+
|
|
3
|
+
This module provides a fluent interface for building SQL queries safely,
|
|
4
|
+
with automatic parameter binding and validation. Enhanced with SQLGlot's
|
|
5
|
+
advanced builder patterns and optimization capabilities.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from abc import ABC, abstractmethod
|
|
9
|
+
from dataclasses import dataclass, field
|
|
10
|
+
from typing import TYPE_CHECKING, Any, Generic, NoReturn, Optional, Union
|
|
11
|
+
|
|
12
|
+
import sqlglot
|
|
13
|
+
from sqlglot import Dialect, exp
|
|
14
|
+
from sqlglot.dialects.dialect import DialectType
|
|
15
|
+
from sqlglot.errors import ParseError as SQLGlotParseError
|
|
16
|
+
from sqlglot.optimizer import optimize
|
|
17
|
+
from typing_extensions import Self
|
|
18
|
+
|
|
19
|
+
from sqlspec.exceptions import SQLBuilderError
|
|
20
|
+
from sqlspec.statement.sql import SQL, SQLConfig
|
|
21
|
+
from sqlspec.typing import RowT
|
|
22
|
+
from sqlspec.utils.logging import get_logger
|
|
23
|
+
|
|
24
|
+
if TYPE_CHECKING:
|
|
25
|
+
from sqlspec.statement.result import SQLResult
|
|
26
|
+
|
|
27
|
+
__all__ = ("QueryBuilder", "SafeQuery")
|
|
28
|
+
|
|
29
|
+
logger = get_logger(__name__)
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
@dataclass(frozen=True)
class SafeQuery:
    """Immutable result of building a query: SQL text plus its bound parameters."""

    # Generated SQL string
    sql: str
    # Parameter values keyed by placeholder name; a fresh dict per instance
    parameters: dict[str, Any] = field(default_factory=dict)
    # Dialect associated with the SQL, or None when unspecified
    dialect: DialectType = None
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
@dataclass
class QueryBuilder(ABC, Generic[RowT]):
    """Abstract base class for SQL query builders backed by SQLGlot expressions.

    Provides the functionality shared by concrete builders: dialect handling,
    named-parameter bookkeeping, CTE registration, SQL generation and an optional
    pass through SQLGlot's optimizer.

    Subclasses must implement ``_create_base_expression`` and
    ``_expected_result_type``.
    """

    dialect: DialectType = field(default=None)
    schema: Optional[dict[str, dict[str, str]]] = field(default=None)
    # Internal state below is excluded from __init__, repr and comparisons
    _expression: Optional[exp.Expression] = field(default=None, init=False, repr=False, compare=False, hash=False)
    _parameters: dict[str, Any] = field(default_factory=dict, init=False, repr=False, compare=False, hash=False)
    _parameter_counter: int = field(default=0, init=False, repr=False, compare=False, hash=False)
    _with_ctes: dict[str, exp.CTE] = field(default_factory=dict, init=False, repr=False, compare=False, hash=False)
    enable_optimization: bool = field(default=True, init=True)
    optimize_joins: bool = field(default=True, init=True)
    optimize_predicates: bool = field(default=True, init=True)
    simplify_expressions: bool = field(default=True, init=True)

    def __post_init__(self) -> None:
        """Create the base expression and fail fast when the subclass returns nothing."""
        self._expression = self._create_base_expression()
        if not self._expression:
            self._raise_sql_builder_error(
                "QueryBuilder._create_base_expression must return a valid sqlglot expression."
            )

    @abstractmethod
    def _create_base_expression(self) -> exp.Expression:
        """Create the base sqlglot expression for the specific query type.

        Examples:
            For a SELECT query, this would return `exp.Select()`.
            For an INSERT query, this would return `exp.Insert()`.

        Returns:
            exp.Expression: A new sqlglot expression.
        """

    @property
    @abstractmethod
    def _expected_result_type(self) -> "type[SQLResult[RowT]]":
        """The expected result type for the query being built.

        Returns:
            type[SQLResult[RowT]]: The type of the result.
        """

    @staticmethod
    def _raise_sql_builder_error(message: str, cause: Optional[BaseException] = None) -> NoReturn:
        """Helper to raise SQLBuilderError, potentially with a cause.

        Args:
            message: The error message.
            cause: The optional original exception to chain.

        Raises:
            SQLBuilderError: Always raises this exception.
        """
        raise SQLBuilderError(message) from cause

    def _add_parameter(self, value: Any, context: Optional[str] = None) -> str:
        """Register a parameter under a generated name and return that name.

        Args:
            value: The value of the parameter.
            context: Optional hint used as a name prefix (e.g., "where", "join").

        Returns:
            str: The generated name, ``param_<n>`` or ``<context>_param_<n>``.
        """
        self._parameter_counter += 1

        # Use context-aware naming if provided
        param_name = f"{context}_param_{self._parameter_counter}" if context else f"param_{self._parameter_counter}"

        self._parameters[param_name] = value
        return param_name

    def add_parameter(self: Self, value: Any, name: Optional[str] = None) -> tuple[Self, str]:
        """Explicitly adds a parameter to the query.

        This is useful for parameters that are not directly tied to a
        builder method like `where` or `values`.

        Args:
            value: The value of the parameter.
            name: Optional explicit name for the parameter. If None, a name
                will be generated.

        Returns:
            tuple[Self, str]: The builder instance and the parameter name.

        Raises:
            SQLBuilderError: If ``name`` is given and already registered.
        """
        if name:
            if name in self._parameters:
                self._raise_sql_builder_error(f"Parameter name '{name}' already exists.")
            param_name_to_use = name
        else:
            self._parameter_counter += 1
            param_name_to_use = f"param_{self._parameter_counter}"

        self._parameters[param_name_to_use] = value
        return self, param_name_to_use

    def _generate_unique_parameter_name(self, base_name: str) -> str:
        """Generate unique parameter name when collision occurs.

        Args:
            base_name: The desired base name for the parameter

        Returns:
            ``base_name`` when it is free, otherwise the first ``<base_name>_<i>``
            (``i`` counting up from 1) not present in the current parameters.
        """
        if base_name not in self._parameters:
            return base_name

        i = 1
        while True:
            name = f"{base_name}_{i}"
            if name not in self._parameters:
                return name
            i += 1

    def with_cte(self: Self, alias: str, query: "Union[QueryBuilder[Any], exp.Select, str]") -> Self:
        """Adds a Common Table Expression (CTE) to the query.

        Args:
            alias: The alias for the CTE.
            query: The CTE query, which can be another QueryBuilder instance,
                a raw SQL string, or a sqlglot Select expression.

        Returns:
            Self: The current builder instance for method chaining.

        Raises:
            SQLBuilderError: If the alias is already registered, the query cannot
                be parsed, or it does not resolve to a SELECT expression.
        """
        if alias in self._with_ctes:
            self._raise_sql_builder_error(f"CTE with alias '{alias}' already exists.")

        cte_select_expression: exp.Select

        if isinstance(query, QueryBuilder):
            if query._expression is None:
                self._raise_sql_builder_error("CTE query builder has no expression.")
            if not isinstance(query._expression, exp.Select):
                msg = f"CTE query builder expression must be a Select, got {type(query._expression).__name__}."
                self._raise_sql_builder_error(msg)
            cte_select_expression = query._expression.copy()
            # NOTE(review): values are re-registered under a prefixed name, but the
            # placeholders inside the copied expression are not renamed here - confirm
            # the names are reconciled elsewhere.
            for p_name, p_value in query._parameters.items():
                self.add_parameter(p_value, f"cte_{alias}_{p_name}")

        elif isinstance(query, str):
            # Keep the try body to the parse call only, so the "not a SELECT" error
            # below is not re-wrapped as an "unexpected error" by the broad handler.
            try:
                parsed_expression = sqlglot.parse_one(query, read=self.dialect_name)
            except SQLGlotParseError as e:
                self._raise_sql_builder_error(f"Failed to parse CTE query string: {e!s}", e)
            except Exception as e:
                msg = f"An unexpected error occurred while parsing CTE query string: {e!s}"
                self._raise_sql_builder_error(msg, e)
            if not isinstance(parsed_expression, exp.Select):
                msg = f"CTE query string must parse to a SELECT statement, got {type(parsed_expression).__name__}."
                self._raise_sql_builder_error(msg)
            # parsed_expression is now known to be exp.Select
            cte_select_expression = parsed_expression
        elif isinstance(query, exp.Select):
            cte_select_expression = query.copy()
        else:
            msg = f"Invalid query type for CTE: {type(query).__name__}"
            self._raise_sql_builder_error(msg)
            return self  # This line won't be reached but satisfies type checkers

        self._with_ctes[alias] = exp.CTE(this=cte_select_expression, alias=exp.to_table(alias))
        return self

    def build(self) -> "SafeQuery":
        """Builds the SQL query string and parameters.

        Returns:
            SafeQuery: A dataclass containing the SQL string and parameters.

        Raises:
            SQLBuilderError: If the expression is missing or SQL generation fails.
        """
        if self._expression is None:
            self._raise_sql_builder_error("QueryBuilder expression not initialized.")

        # Work on a copy so repeated build() calls never mutate builder state
        final_expression = self._expression.copy()

        if self._with_ctes:
            if hasattr(final_expression, "with_") and callable(getattr(final_expression, "with_", None)):
                for alias, cte_node in self._with_ctes.items():
                    # sqlglot's with_() takes the CTE name first and its query as ``as_``.
                    # The query is copied so the stored CTE node keeps its own child.
                    final_expression = final_expression.with_(  # pyright: ignore
                        alias, as_=cte_node.args["this"].copy(), copy=False
                    )
            elif isinstance(final_expression, (exp.Select, exp.Insert, exp.Update, exp.Delete, exp.Union)):
                # No with_() helper on this node: attach the WITH clause to the statement
                # itself. Wrapping the statement inside exp.With would drop it, because
                # exp.With only renders its CTE list.
                final_expression.set(
                    "with", exp.With(expressions=[cte_node.copy() for cte_node in self._with_ctes.values()])
                )

        # Apply SQLGlot optimizations if enabled
        if self.enable_optimization:
            final_expression = self._optimize_expression(final_expression)

        try:
            sql_string = final_expression.sql(dialect=self.dialect_name, pretty=True)
        except Exception as e:
            err_msg = f"Error generating SQL from expression: {e!s}"
            logger.exception("SQL generation failed")
            self._raise_sql_builder_error(err_msg, e)

        # Hand out a copy of the parameters so callers cannot mutate builder state
        return SafeQuery(sql=sql_string, parameters=self._parameters.copy(), dialect=self.dialect)

    def _optimize_expression(self, expression: exp.Expression) -> exp.Expression:
        """Apply SQLGlot optimizations to the expression.

        Args:
            expression: The expression to optimize

        Returns:
            The optimized expression, or ``expression`` unchanged when optimization
            is disabled or the optimizer raises.
        """
        if not self.enable_optimization:
            return expression

        try:
            # NOTE(review): confirm that optimize() actually consumes the
            # ``optimizer_settings`` keyword; if it does not, the three flags below
            # have no effect on the result.
            return optimize(
                expression.copy(),
                schema=self.schema,
                dialect=self.dialect_name,
                optimizer_settings={
                    "optimize_joins": self.optimize_joins,
                    "pushdown_predicates": self.optimize_predicates,
                    "simplify_expressions": self.simplify_expressions,
                },
            )
        except Exception:
            # Optimization is best-effort: continue with the unoptimized query on failure
            return expression

    def to_statement(self, config: "Optional[SQLConfig]" = None) -> "SQL":
        """Converts the built query into a SQL statement object.

        Args:
            config: Optional SQL configuration.

        Returns:
            SQL: A SQL statement object.
        """
        safe_query = self.build()

        return SQL(
            statement=safe_query.sql,
            parameters=safe_query.parameters,
            _dialect=safe_query.dialect,
            _config=config,
            _builder_result_type=self._expected_result_type,
        )

    def __str__(self) -> str:
        """Return the SQL string representation of the query.

        Returns:
            str: The SQL string for this query, or the default object
            representation when building fails.
        """
        try:
            return self.build().sql
        except Exception:
            # Fallback to default representation if build fails
            return super().__str__()

    @property
    def dialect_name(self) -> "Optional[str]":
        """Returns the name of the dialect, if set.

        A string dialect is returned as is; a ``Dialect`` subclass or instance (or
        any other object with ``__name__``) yields its lower-cased class name.
        """
        if isinstance(self.dialect, str):
            return self.dialect
        if self.dialect is not None:
            if isinstance(self.dialect, type) and issubclass(self.dialect, Dialect):
                return self.dialect.__name__.lower()
            if isinstance(self.dialect, Dialect):
                return type(self.dialect).__name__.lower()
            # Handle case where dialect might have a __name__ attribute
            if hasattr(self.dialect, "__name__"):
                return self.dialect.__name__.lower()
        return None
|