sqlspec 0.15.0__py3-none-any.whl → 0.16.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of sqlspec might be problematic. Click here for more details.
- sqlspec/_sql.py +702 -44
- sqlspec/builder/_base.py +77 -44
- sqlspec/builder/_column.py +0 -4
- sqlspec/builder/_ddl.py +15 -52
- sqlspec/builder/_ddl_utils.py +0 -1
- sqlspec/builder/_delete.py +4 -5
- sqlspec/builder/_insert.py +235 -44
- sqlspec/builder/_merge.py +17 -2
- sqlspec/builder/_parsing_utils.py +42 -14
- sqlspec/builder/_select.py +29 -33
- sqlspec/builder/_update.py +4 -2
- sqlspec/builder/mixins/_cte_and_set_ops.py +47 -20
- sqlspec/builder/mixins/_delete_operations.py +6 -1
- sqlspec/builder/mixins/_insert_operations.py +126 -24
- sqlspec/builder/mixins/_join_operations.py +44 -10
- sqlspec/builder/mixins/_merge_operations.py +183 -25
- sqlspec/builder/mixins/_order_limit_operations.py +15 -3
- sqlspec/builder/mixins/_pivot_operations.py +11 -2
- sqlspec/builder/mixins/_select_operations.py +21 -14
- sqlspec/builder/mixins/_update_operations.py +80 -32
- sqlspec/builder/mixins/_where_clause.py +201 -66
- sqlspec/core/cache.py +26 -28
- sqlspec/core/compiler.py +58 -37
- sqlspec/core/filters.py +12 -10
- sqlspec/core/parameters.py +80 -52
- sqlspec/core/result.py +30 -17
- sqlspec/core/statement.py +47 -22
- sqlspec/driver/_async.py +76 -46
- sqlspec/driver/_common.py +25 -6
- sqlspec/driver/_sync.py +73 -43
- sqlspec/driver/mixins/_result_tools.py +62 -37
- sqlspec/driver/mixins/_sql_translator.py +61 -11
- sqlspec/extensions/litestar/cli.py +1 -1
- sqlspec/extensions/litestar/plugin.py +2 -2
- sqlspec/protocols.py +7 -0
- sqlspec/utils/sync_tools.py +1 -1
- sqlspec/utils/type_guards.py +7 -3
- {sqlspec-0.15.0.dist-info → sqlspec-0.16.2.dist-info}/METADATA +1 -1
- {sqlspec-0.15.0.dist-info → sqlspec-0.16.2.dist-info}/RECORD +43 -43
- {sqlspec-0.15.0.dist-info → sqlspec-0.16.2.dist-info}/WHEEL +0 -0
- {sqlspec-0.15.0.dist-info → sqlspec-0.16.2.dist-info}/entry_points.txt +0 -0
- {sqlspec-0.15.0.dist-info → sqlspec-0.16.2.dist-info}/licenses/LICENSE +0 -0
- {sqlspec-0.15.0.dist-info → sqlspec-0.16.2.dist-info}/licenses/NOTICE +0 -0
|
@@ -2,6 +2,7 @@
|
|
|
2
2
|
|
|
3
3
|
from typing import Any, Optional, Union
|
|
4
4
|
|
|
5
|
+
from mypy_extensions import trait
|
|
5
6
|
from sqlglot import exp
|
|
6
7
|
from typing_extensions import Self
|
|
7
8
|
|
|
@@ -10,10 +11,21 @@ from sqlspec.exceptions import SQLBuilderError
|
|
|
10
11
|
__all__ = ("CommonTableExpressionMixin", "SetOperationMixin")
|
|
11
12
|
|
|
12
13
|
|
|
14
|
+
@trait
|
|
13
15
|
class CommonTableExpressionMixin:
|
|
14
16
|
"""Mixin providing WITH clause (Common Table Expressions) support for SQL builders."""
|
|
15
17
|
|
|
16
|
-
|
|
18
|
+
__slots__ = ()
|
|
19
|
+
# Type annotation for PyRight - this will be provided by the base class
|
|
20
|
+
_expression: Optional[exp.Expression]
|
|
21
|
+
|
|
22
|
+
_with_ctes: Any # Provided by QueryBuilder
|
|
23
|
+
dialect: Any # Provided by QueryBuilder
|
|
24
|
+
|
|
25
|
+
def add_parameter(self, value: Any, name: Optional[str] = None) -> tuple[Any, str]:
|
|
26
|
+
"""Add parameter - provided by QueryBuilder."""
|
|
27
|
+
msg = "Method must be provided by QueryBuilder subclass"
|
|
28
|
+
raise NotImplementedError(msg)
|
|
17
29
|
|
|
18
30
|
def with_(
|
|
19
31
|
self, name: str, query: Union[Any, str], recursive: bool = False, columns: Optional[list[str]] = None
|
|
@@ -42,22 +54,22 @@ class CommonTableExpressionMixin:
|
|
|
42
54
|
|
|
43
55
|
cte_expr: Optional[exp.Expression] = None
|
|
44
56
|
if isinstance(query, str):
|
|
45
|
-
cte_expr = exp.maybe_parse(query, dialect=self.dialect)
|
|
57
|
+
cte_expr = exp.maybe_parse(query, dialect=self.dialect)
|
|
46
58
|
elif isinstance(query, exp.Expression):
|
|
47
59
|
cte_expr = query
|
|
48
60
|
else:
|
|
49
|
-
built_query = query.to_statement()
|
|
61
|
+
built_query = query.to_statement()
|
|
50
62
|
cte_sql = built_query.sql
|
|
51
|
-
cte_expr = exp.maybe_parse(cte_sql, dialect=self.dialect)
|
|
63
|
+
cte_expr = exp.maybe_parse(cte_sql, dialect=self.dialect)
|
|
52
64
|
|
|
53
65
|
parameters = built_query.parameters
|
|
54
66
|
if parameters:
|
|
55
67
|
if isinstance(parameters, dict):
|
|
56
68
|
for param_name, param_value in parameters.items():
|
|
57
|
-
self.add_parameter(param_value, name=param_name)
|
|
69
|
+
self.add_parameter(param_value, name=param_name)
|
|
58
70
|
elif isinstance(parameters, (list, tuple)):
|
|
59
71
|
for param_value in parameters:
|
|
60
|
-
self.add_parameter(param_value)
|
|
72
|
+
self.add_parameter(param_value)
|
|
61
73
|
|
|
62
74
|
if not cte_expr:
|
|
63
75
|
msg = f"Could not parse CTE query: {query}"
|
|
@@ -68,29 +80,42 @@ class CommonTableExpressionMixin:
|
|
|
68
80
|
else:
|
|
69
81
|
cte_alias_expr = exp.alias_(cte_expr, name)
|
|
70
82
|
|
|
71
|
-
existing_with = self._expression.args.get("with")
|
|
83
|
+
existing_with = self._expression.args.get("with")
|
|
72
84
|
if existing_with:
|
|
73
85
|
existing_with.expressions.append(cte_alias_expr)
|
|
74
86
|
if recursive:
|
|
75
87
|
existing_with.set("recursive", recursive)
|
|
76
88
|
else:
|
|
77
|
-
|
|
78
|
-
if
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
|
|
89
|
+
# Only SELECT, INSERT, UPDATE support WITH clauses
|
|
90
|
+
if hasattr(self._expression, "with_") and isinstance(
|
|
91
|
+
self._expression, (exp.Select, exp.Insert, exp.Update)
|
|
92
|
+
):
|
|
93
|
+
self._expression = self._expression.with_(cte_alias_expr, as_=name, copy=False)
|
|
94
|
+
if recursive:
|
|
95
|
+
with_clause = self._expression.find(exp.With)
|
|
96
|
+
if with_clause:
|
|
97
|
+
with_clause.set("recursive", recursive)
|
|
98
|
+
self._with_ctes[name] = exp.CTE(this=cte_expr, alias=exp.to_table(name))
|
|
83
99
|
|
|
84
100
|
return self
|
|
85
101
|
|
|
86
102
|
|
|
103
|
+
@trait
|
|
87
104
|
class SetOperationMixin:
|
|
88
105
|
"""Mixin providing set operations (UNION, INTERSECT, EXCEPT) for SELECT builders."""
|
|
89
106
|
|
|
90
|
-
|
|
107
|
+
__slots__ = ()
|
|
108
|
+
# Type annotation for PyRight - this will be provided by the base class
|
|
109
|
+
_expression: Optional[exp.Expression]
|
|
110
|
+
|
|
91
111
|
_parameters: dict[str, Any]
|
|
92
112
|
dialect: Any = None
|
|
93
113
|
|
|
114
|
+
def build(self) -> Any:
|
|
115
|
+
"""Build the query - provided by QueryBuilder."""
|
|
116
|
+
msg = "Method must be provided by QueryBuilder subclass"
|
|
117
|
+
raise NotImplementedError(msg)
|
|
118
|
+
|
|
94
119
|
def union(self, other: Any, all_: bool = False) -> Self:
|
|
95
120
|
"""Combine this query with another using UNION.
|
|
96
121
|
|
|
@@ -104,7 +129,7 @@ class SetOperationMixin:
|
|
|
104
129
|
Returns:
|
|
105
130
|
The new builder instance for the union query.
|
|
106
131
|
"""
|
|
107
|
-
left_query = self.build()
|
|
132
|
+
left_query = self.build()
|
|
108
133
|
right_query = other.build()
|
|
109
134
|
left_expr: Optional[exp.Expression] = exp.maybe_parse(left_query.sql, dialect=self.dialect)
|
|
110
135
|
right_expr: Optional[exp.Expression] = exp.maybe_parse(right_query.sql, dialect=self.dialect)
|
|
@@ -124,9 +149,11 @@ class SetOperationMixin:
|
|
|
124
149
|
counter += 1
|
|
125
150
|
new_param_name = f"{param_name}_right_{counter}"
|
|
126
151
|
|
|
127
|
-
def rename_parameter(
|
|
128
|
-
|
|
129
|
-
|
|
152
|
+
def rename_parameter(
|
|
153
|
+
node: exp.Expression, old_name: str = param_name, new_name: str = new_param_name
|
|
154
|
+
) -> exp.Expression:
|
|
155
|
+
if isinstance(node, exp.Placeholder) and node.name == old_name:
|
|
156
|
+
return exp.Placeholder(this=new_name)
|
|
130
157
|
return node
|
|
131
158
|
|
|
132
159
|
right_expr = right_expr.transform(rename_parameter)
|
|
@@ -150,7 +177,7 @@ class SetOperationMixin:
|
|
|
150
177
|
Returns:
|
|
151
178
|
The new builder instance for the intersect query.
|
|
152
179
|
"""
|
|
153
|
-
left_query = self.build()
|
|
180
|
+
left_query = self.build()
|
|
154
181
|
right_query = other.build()
|
|
155
182
|
left_expr: Optional[exp.Expression] = exp.maybe_parse(left_query.sql, dialect=self.dialect)
|
|
156
183
|
right_expr: Optional[exp.Expression] = exp.maybe_parse(right_query.sql, dialect=self.dialect)
|
|
@@ -178,7 +205,7 @@ class SetOperationMixin:
|
|
|
178
205
|
Returns:
|
|
179
206
|
The new builder instance for the except query.
|
|
180
207
|
"""
|
|
181
|
-
left_query = self.build()
|
|
208
|
+
left_query = self.build()
|
|
182
209
|
right_query = other.build()
|
|
183
210
|
left_expr: Optional[exp.Expression] = exp.maybe_parse(left_query.sql, dialect=self.dialect)
|
|
184
211
|
right_expr: Optional[exp.Expression] = exp.maybe_parse(right_query.sql, dialect=self.dialect)
|
|
@@ -2,6 +2,7 @@
|
|
|
2
2
|
|
|
3
3
|
from typing import Optional
|
|
4
4
|
|
|
5
|
+
from mypy_extensions import trait
|
|
5
6
|
from sqlglot import exp
|
|
6
7
|
from typing_extensions import Self
|
|
7
8
|
|
|
@@ -10,10 +11,14 @@ from sqlspec.exceptions import SQLBuilderError
|
|
|
10
11
|
__all__ = ("DeleteFromClauseMixin",)
|
|
11
12
|
|
|
12
13
|
|
|
14
|
+
@trait
|
|
13
15
|
class DeleteFromClauseMixin:
|
|
14
16
|
"""Mixin providing FROM clause for DELETE builders."""
|
|
15
17
|
|
|
16
|
-
|
|
18
|
+
__slots__ = ()
|
|
19
|
+
|
|
20
|
+
# Type annotation for PyRight - this will be provided by the base class
|
|
21
|
+
_expression: Optional[exp.Expression]
|
|
17
22
|
|
|
18
23
|
def from_(self, table: str) -> Self:
|
|
19
24
|
"""Set the target table for the DELETE statement.
|
|
@@ -1,20 +1,28 @@
|
|
|
1
1
|
"""Insert operation mixins for SQL builders."""
|
|
2
2
|
|
|
3
3
|
from collections.abc import Sequence
|
|
4
|
-
from typing import Any, Optional, Union
|
|
4
|
+
from typing import Any, Optional, TypeVar, Union
|
|
5
5
|
|
|
6
|
+
from mypy_extensions import trait
|
|
6
7
|
from sqlglot import exp
|
|
7
8
|
from typing_extensions import Self
|
|
8
9
|
|
|
9
10
|
from sqlspec.exceptions import SQLBuilderError
|
|
11
|
+
from sqlspec.protocols import SQLBuilderProtocol
|
|
12
|
+
|
|
13
|
+
BuilderT = TypeVar("BuilderT", bound=SQLBuilderProtocol)
|
|
10
14
|
|
|
11
15
|
__all__ = ("InsertFromSelectMixin", "InsertIntoClauseMixin", "InsertValuesMixin")
|
|
12
16
|
|
|
13
17
|
|
|
18
|
+
@trait
|
|
14
19
|
class InsertIntoClauseMixin:
|
|
15
20
|
"""Mixin providing INTO clause for INSERT builders."""
|
|
16
21
|
|
|
17
|
-
|
|
22
|
+
__slots__ = ()
|
|
23
|
+
|
|
24
|
+
# Type annotation for PyRight - this will be provided by the base class
|
|
25
|
+
_expression: Optional[exp.Expression]
|
|
18
26
|
|
|
19
27
|
def into(self, table: str) -> Self:
|
|
20
28
|
"""Set the target table for the INSERT statement.
|
|
@@ -39,10 +47,26 @@ class InsertIntoClauseMixin:
|
|
|
39
47
|
return self
|
|
40
48
|
|
|
41
49
|
|
|
50
|
+
@trait
|
|
42
51
|
class InsertValuesMixin:
|
|
43
52
|
"""Mixin providing VALUES and columns methods for INSERT builders."""
|
|
44
53
|
|
|
45
|
-
|
|
54
|
+
__slots__ = ()
|
|
55
|
+
|
|
56
|
+
# Type annotation for PyRight - this will be provided by the base class
|
|
57
|
+
_expression: Optional[exp.Expression]
|
|
58
|
+
|
|
59
|
+
_columns: Any # Provided by QueryBuilder
|
|
60
|
+
|
|
61
|
+
def add_parameter(self, value: Any, name: Optional[str] = None) -> tuple[Any, str]:
|
|
62
|
+
"""Add parameter - provided by QueryBuilder."""
|
|
63
|
+
msg = "Method must be provided by QueryBuilder subclass"
|
|
64
|
+
raise NotImplementedError(msg)
|
|
65
|
+
|
|
66
|
+
def _generate_unique_parameter_name(self, base_name: str) -> str:
|
|
67
|
+
"""Generate unique parameter name - provided by QueryBuilder."""
|
|
68
|
+
msg = "Method must be provided by QueryBuilder subclass"
|
|
69
|
+
raise NotImplementedError(msg)
|
|
46
70
|
|
|
47
71
|
def columns(self, *columns: Union[str, exp.Expression]) -> Self:
|
|
48
72
|
"""Set the columns for the INSERT statement and synchronize the _columns attribute on the builder."""
|
|
@@ -54,7 +78,7 @@ class InsertValuesMixin:
|
|
|
54
78
|
column_exprs = [exp.column(col) if isinstance(col, str) else col for col in columns]
|
|
55
79
|
self._expression.set("columns", column_exprs)
|
|
56
80
|
try:
|
|
57
|
-
cols = self._columns
|
|
81
|
+
cols = self._columns
|
|
58
82
|
if not columns:
|
|
59
83
|
cols.clear()
|
|
60
84
|
else:
|
|
@@ -63,27 +87,94 @@ class InsertValuesMixin:
|
|
|
63
87
|
pass
|
|
64
88
|
return self
|
|
65
89
|
|
|
66
|
-
def values(self, *values: Any) -> Self:
|
|
67
|
-
"""Add a row of values to the INSERT statement
|
|
90
|
+
def values(self, *values: Any, **kwargs: Any) -> Self:
|
|
91
|
+
"""Add a row of values to the INSERT statement.
|
|
92
|
+
|
|
93
|
+
Supports:
|
|
94
|
+
- values(val1, val2, val3)
|
|
95
|
+
- values(col1=val1, col2=val2)
|
|
96
|
+
- values(mapping)
|
|
97
|
+
|
|
98
|
+
Args:
|
|
99
|
+
*values: Either positional values or a single mapping.
|
|
100
|
+
**kwargs: Column-value pairs.
|
|
101
|
+
|
|
102
|
+
Returns:
|
|
103
|
+
The current builder instance for method chaining.
|
|
104
|
+
"""
|
|
68
105
|
if self._expression is None:
|
|
69
106
|
self._expression = exp.Insert()
|
|
70
107
|
if not isinstance(self._expression, exp.Insert):
|
|
71
108
|
msg = "Cannot add values to a non-INSERT expression."
|
|
72
109
|
raise SQLBuilderError(msg)
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
if
|
|
76
|
-
msg =
|
|
110
|
+
|
|
111
|
+
if kwargs:
|
|
112
|
+
if values:
|
|
113
|
+
msg = "Cannot mix positional values with keyword values."
|
|
77
114
|
raise SQLBuilderError(msg)
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
115
|
+
try:
|
|
116
|
+
_columns = self._columns
|
|
117
|
+
if not _columns:
|
|
118
|
+
self.columns(*kwargs.keys())
|
|
119
|
+
except AttributeError:
|
|
120
|
+
pass
|
|
121
|
+
row_exprs = []
|
|
122
|
+
for col, val in kwargs.items():
|
|
123
|
+
if isinstance(val, exp.Expression):
|
|
124
|
+
row_exprs.append(val)
|
|
125
|
+
else:
|
|
126
|
+
column_name = col if isinstance(col, str) else str(col)
|
|
127
|
+
if "." in column_name:
|
|
128
|
+
column_name = column_name.split(".")[-1]
|
|
129
|
+
param_name = self._generate_unique_parameter_name(column_name)
|
|
130
|
+
_, param_name = self.add_parameter(val, name=param_name)
|
|
131
|
+
row_exprs.append(exp.var(param_name))
|
|
132
|
+
elif len(values) == 1 and hasattr(values[0], "items"):
|
|
133
|
+
mapping = values[0]
|
|
134
|
+
try:
|
|
135
|
+
_columns = self._columns
|
|
136
|
+
if not _columns:
|
|
137
|
+
self.columns(*mapping.keys())
|
|
138
|
+
except AttributeError:
|
|
139
|
+
pass
|
|
140
|
+
row_exprs = []
|
|
141
|
+
for col, val in mapping.items():
|
|
142
|
+
if isinstance(val, exp.Expression):
|
|
143
|
+
row_exprs.append(val)
|
|
144
|
+
else:
|
|
145
|
+
column_name = col if isinstance(col, str) else str(col)
|
|
146
|
+
if "." in column_name:
|
|
147
|
+
column_name = column_name.split(".")[-1]
|
|
148
|
+
param_name = self._generate_unique_parameter_name(column_name)
|
|
149
|
+
_, param_name = self.add_parameter(val, name=param_name)
|
|
150
|
+
row_exprs.append(exp.var(param_name))
|
|
151
|
+
else:
|
|
152
|
+
try:
|
|
153
|
+
_columns = self._columns
|
|
154
|
+
if _columns and len(values) != len(_columns):
|
|
155
|
+
msg = f"Number of values ({len(values)}) does not match the number of specified columns ({len(_columns)})."
|
|
156
|
+
raise SQLBuilderError(msg)
|
|
157
|
+
except AttributeError:
|
|
158
|
+
pass
|
|
159
|
+
row_exprs = []
|
|
160
|
+
for i, v in enumerate(values):
|
|
161
|
+
if isinstance(v, exp.Expression):
|
|
162
|
+
row_exprs.append(v)
|
|
163
|
+
else:
|
|
164
|
+
try:
|
|
165
|
+
_columns = self._columns
|
|
166
|
+
if _columns and i < len(_columns):
|
|
167
|
+
column_name = (
|
|
168
|
+
str(_columns[i]).split(".")[-1] if "." in str(_columns[i]) else str(_columns[i])
|
|
169
|
+
)
|
|
170
|
+
param_name = self._generate_unique_parameter_name(column_name)
|
|
171
|
+
else:
|
|
172
|
+
param_name = self._generate_unique_parameter_name(f"value_{i + 1}")
|
|
173
|
+
except AttributeError:
|
|
174
|
+
param_name = self._generate_unique_parameter_name(f"value_{i + 1}")
|
|
175
|
+
_, param_name = self.add_parameter(v, name=param_name)
|
|
176
|
+
row_exprs.append(exp.var(param_name))
|
|
177
|
+
|
|
87
178
|
values_expr = exp.Values(expressions=[row_exprs])
|
|
88
179
|
self._expression.set("expression", values_expr)
|
|
89
180
|
return self
|
|
@@ -100,10 +191,21 @@ class InsertValuesMixin:
|
|
|
100
191
|
return self.values(*values)
|
|
101
192
|
|
|
102
193
|
|
|
194
|
+
@trait
|
|
103
195
|
class InsertFromSelectMixin:
|
|
104
196
|
"""Mixin providing INSERT ... SELECT support for INSERT builders."""
|
|
105
197
|
|
|
106
|
-
|
|
198
|
+
__slots__ = ()
|
|
199
|
+
|
|
200
|
+
# Type annotation for PyRight - this will be provided by the base class
|
|
201
|
+
_expression: Optional[exp.Expression]
|
|
202
|
+
|
|
203
|
+
_table: Any # Provided by QueryBuilder
|
|
204
|
+
|
|
205
|
+
def add_parameter(self, value: Any, name: Optional[str] = None) -> tuple[Any, str]:
|
|
206
|
+
"""Add parameter - provided by QueryBuilder."""
|
|
207
|
+
msg = "Method must be provided by QueryBuilder subclass"
|
|
208
|
+
raise NotImplementedError(msg)
|
|
107
209
|
|
|
108
210
|
def from_select(self, select_builder: Any) -> Self:
|
|
109
211
|
"""Sets the INSERT source to a SELECT statement.
|
|
@@ -118,7 +220,7 @@ class InsertFromSelectMixin:
|
|
|
118
220
|
SQLBuilderError: If the table is not set or the select_builder is invalid.
|
|
119
221
|
"""
|
|
120
222
|
try:
|
|
121
|
-
if not self._table:
|
|
223
|
+
if not self._table:
|
|
122
224
|
msg = "The target table must be set using .into() before adding values."
|
|
123
225
|
raise SQLBuilderError(msg)
|
|
124
226
|
except AttributeError:
|
|
@@ -129,11 +231,11 @@ class InsertFromSelectMixin:
|
|
|
129
231
|
if not isinstance(self._expression, exp.Insert):
|
|
130
232
|
msg = "Cannot set INSERT source on a non-INSERT expression."
|
|
131
233
|
raise SQLBuilderError(msg)
|
|
132
|
-
subquery_parameters = select_builder._parameters
|
|
234
|
+
subquery_parameters = select_builder._parameters
|
|
133
235
|
if subquery_parameters:
|
|
134
236
|
for p_name, p_value in subquery_parameters.items():
|
|
135
|
-
self.add_parameter(p_value, name=p_name)
|
|
136
|
-
select_expr = select_builder._expression
|
|
237
|
+
self.add_parameter(p_value, name=p_name)
|
|
238
|
+
select_expr = select_builder._expression
|
|
137
239
|
if select_expr and isinstance(select_expr, exp.Select):
|
|
138
240
|
self._expression.set("expression", select_expr.copy())
|
|
139
241
|
else:
|
|
@@ -1,5 +1,6 @@
|
|
|
1
1
|
from typing import TYPE_CHECKING, Any, Optional, Union, cast
|
|
2
2
|
|
|
3
|
+
from mypy_extensions import trait
|
|
3
4
|
from sqlglot import exp
|
|
4
5
|
from typing_extensions import Self
|
|
5
6
|
|
|
@@ -8,18 +9,25 @@ from sqlspec.exceptions import SQLBuilderError
|
|
|
8
9
|
from sqlspec.utils.type_guards import has_query_builder_parameters
|
|
9
10
|
|
|
10
11
|
if TYPE_CHECKING:
|
|
12
|
+
from sqlspec.core.statement import SQL
|
|
11
13
|
from sqlspec.protocols import SQLBuilderProtocol
|
|
12
14
|
|
|
13
15
|
__all__ = ("JoinClauseMixin",)
|
|
14
16
|
|
|
15
17
|
|
|
18
|
+
@trait
|
|
16
19
|
class JoinClauseMixin:
|
|
17
20
|
"""Mixin providing JOIN clause methods for SELECT builders."""
|
|
18
21
|
|
|
22
|
+
__slots__ = ()
|
|
23
|
+
|
|
24
|
+
# Type annotation for PyRight - this will be provided by the base class
|
|
25
|
+
_expression: Optional[exp.Expression]
|
|
26
|
+
|
|
19
27
|
def join(
|
|
20
28
|
self,
|
|
21
29
|
table: Union[str, exp.Expression, Any],
|
|
22
|
-
on: Optional[Union[str, exp.Expression]] = None,
|
|
30
|
+
on: Optional[Union[str, exp.Expression, "SQL"]] = None,
|
|
23
31
|
alias: Optional[str] = None,
|
|
24
32
|
join_type: str = "INNER",
|
|
25
33
|
) -> Self:
|
|
@@ -36,12 +44,12 @@ class JoinClauseMixin:
|
|
|
36
44
|
if hasattr(table, "_expression") and getattr(table, "_expression", None) is not None:
|
|
37
45
|
table_expr_value = getattr(table, "_expression", None)
|
|
38
46
|
if table_expr_value is not None:
|
|
39
|
-
subquery_exp = exp.paren(table_expr_value
|
|
47
|
+
subquery_exp = exp.paren(table_expr_value)
|
|
40
48
|
else:
|
|
41
49
|
subquery_exp = exp.paren(exp.Anonymous(this=""))
|
|
42
50
|
table_expr = exp.alias_(subquery_exp, alias) if alias else subquery_exp
|
|
43
51
|
else:
|
|
44
|
-
subquery = table.build()
|
|
52
|
+
subquery = table.build()
|
|
45
53
|
sql_str = subquery.sql if hasattr(subquery, "sql") and not callable(subquery.sql) else str(subquery)
|
|
46
54
|
subquery_exp = exp.paren(exp.maybe_parse(sql_str, dialect=getattr(builder, "dialect", None)))
|
|
47
55
|
table_expr = exp.alias_(subquery_exp, alias) if alias else subquery_exp
|
|
@@ -49,7 +57,33 @@ class JoinClauseMixin:
|
|
|
49
57
|
table_expr = table
|
|
50
58
|
on_expr: Optional[exp.Expression] = None
|
|
51
59
|
if on is not None:
|
|
52
|
-
|
|
60
|
+
if isinstance(on, str):
|
|
61
|
+
on_expr = exp.condition(on)
|
|
62
|
+
elif hasattr(on, "expression") and hasattr(on, "sql"):
|
|
63
|
+
# Handle SQL objects (from sql.raw with parameters)
|
|
64
|
+
expression = getattr(on, "expression", None)
|
|
65
|
+
if expression is not None and isinstance(expression, exp.Expression):
|
|
66
|
+
# Merge parameters from SQL object into builder
|
|
67
|
+
if hasattr(on, "parameters") and hasattr(builder, "add_parameter"):
|
|
68
|
+
sql_parameters = getattr(on, "parameters", {})
|
|
69
|
+
for param_name, param_value in sql_parameters.items():
|
|
70
|
+
builder.add_parameter(param_value, name=param_name)
|
|
71
|
+
on_expr = expression
|
|
72
|
+
else:
|
|
73
|
+
# If expression is None, fall back to parsing the raw SQL
|
|
74
|
+
sql_text = getattr(on, "sql", "")
|
|
75
|
+
# Merge parameters even when parsing raw SQL
|
|
76
|
+
if hasattr(on, "parameters") and hasattr(builder, "add_parameter"):
|
|
77
|
+
sql_parameters = getattr(on, "parameters", {})
|
|
78
|
+
for param_name, param_value in sql_parameters.items():
|
|
79
|
+
builder.add_parameter(param_value, name=param_name)
|
|
80
|
+
on_expr = exp.maybe_parse(sql_text) or exp.condition(str(sql_text))
|
|
81
|
+
# For other types (should be exp.Expression)
|
|
82
|
+
elif isinstance(on, exp.Expression):
|
|
83
|
+
on_expr = on
|
|
84
|
+
else:
|
|
85
|
+
# Last resort - convert to string and parse
|
|
86
|
+
on_expr = exp.condition(str(on))
|
|
53
87
|
join_type_upper = join_type.upper()
|
|
54
88
|
if join_type_upper == "INNER":
|
|
55
89
|
join_expr = exp.Join(this=table_expr, on=on_expr)
|
|
@@ -66,22 +100,22 @@ class JoinClauseMixin:
|
|
|
66
100
|
return cast("Self", builder)
|
|
67
101
|
|
|
68
102
|
def inner_join(
|
|
69
|
-
self, table: Union[str, exp.Expression, Any], on: Union[str, exp.Expression], alias: Optional[str] = None
|
|
103
|
+
self, table: Union[str, exp.Expression, Any], on: Union[str, exp.Expression, "SQL"], alias: Optional[str] = None
|
|
70
104
|
) -> Self:
|
|
71
105
|
return self.join(table, on, alias, "INNER")
|
|
72
106
|
|
|
73
107
|
def left_join(
|
|
74
|
-
self, table: Union[str, exp.Expression, Any], on: Union[str, exp.Expression], alias: Optional[str] = None
|
|
108
|
+
self, table: Union[str, exp.Expression, Any], on: Union[str, exp.Expression, "SQL"], alias: Optional[str] = None
|
|
75
109
|
) -> Self:
|
|
76
110
|
return self.join(table, on, alias, "LEFT")
|
|
77
111
|
|
|
78
112
|
def right_join(
|
|
79
|
-
self, table: Union[str, exp.Expression, Any], on: Union[str, exp.Expression], alias: Optional[str] = None
|
|
113
|
+
self, table: Union[str, exp.Expression, Any], on: Union[str, exp.Expression, "SQL"], alias: Optional[str] = None
|
|
80
114
|
) -> Self:
|
|
81
115
|
return self.join(table, on, alias, "RIGHT")
|
|
82
116
|
|
|
83
117
|
def full_join(
|
|
84
|
-
self, table: Union[str, exp.Expression, Any], on: Union[str, exp.Expression], alias: Optional[str] = None
|
|
118
|
+
self, table: Union[str, exp.Expression, Any], on: Union[str, exp.Expression, "SQL"], alias: Optional[str] = None
|
|
85
119
|
) -> Self:
|
|
86
120
|
return self.join(table, on, alias, "FULL")
|
|
87
121
|
|
|
@@ -99,12 +133,12 @@ class JoinClauseMixin:
|
|
|
99
133
|
if hasattr(table, "_expression") and getattr(table, "_expression", None) is not None:
|
|
100
134
|
table_expr_value = getattr(table, "_expression", None)
|
|
101
135
|
if table_expr_value is not None:
|
|
102
|
-
subquery_exp = exp.paren(table_expr_value
|
|
136
|
+
subquery_exp = exp.paren(table_expr_value)
|
|
103
137
|
else:
|
|
104
138
|
subquery_exp = exp.paren(exp.Anonymous(this=""))
|
|
105
139
|
table_expr = exp.alias_(subquery_exp, alias) if alias else subquery_exp
|
|
106
140
|
else:
|
|
107
|
-
subquery = table.build()
|
|
141
|
+
subquery = table.build()
|
|
108
142
|
sql_str = subquery.sql if hasattr(subquery, "sql") and not callable(subquery.sql) else str(subquery)
|
|
109
143
|
subquery_exp = exp.paren(exp.maybe_parse(sql_str, dialect=getattr(builder, "dialect", None)))
|
|
110
144
|
table_expr = exp.alias_(subquery_exp, alias) if alias else subquery_exp
|