sqlspec 0.16.2__py3-none-any.whl → 0.17.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of sqlspec might be problematic. Click here for more details.
- sqlspec/__init__.py +11 -1
- sqlspec/_sql.py +152 -489
- sqlspec/adapters/aiosqlite/__init__.py +11 -1
- sqlspec/adapters/aiosqlite/config.py +137 -165
- sqlspec/adapters/aiosqlite/driver.py +21 -10
- sqlspec/adapters/aiosqlite/pool.py +492 -0
- sqlspec/adapters/duckdb/__init__.py +2 -0
- sqlspec/adapters/duckdb/config.py +11 -235
- sqlspec/adapters/duckdb/pool.py +243 -0
- sqlspec/adapters/sqlite/__init__.py +2 -0
- sqlspec/adapters/sqlite/config.py +4 -115
- sqlspec/adapters/sqlite/pool.py +140 -0
- sqlspec/base.py +147 -26
- sqlspec/builder/__init__.py +6 -0
- sqlspec/builder/_column.py +5 -1
- sqlspec/builder/_expression_wrappers.py +46 -0
- sqlspec/builder/_insert.py +1 -3
- sqlspec/builder/_parsing_utils.py +27 -0
- sqlspec/builder/_update.py +5 -5
- sqlspec/builder/mixins/_join_operations.py +115 -1
- sqlspec/builder/mixins/_order_limit_operations.py +16 -4
- sqlspec/builder/mixins/_select_operations.py +307 -3
- sqlspec/builder/mixins/_update_operations.py +4 -4
- sqlspec/builder/mixins/_where_clause.py +60 -11
- sqlspec/core/compiler.py +7 -5
- sqlspec/driver/_common.py +9 -1
- sqlspec/loader.py +27 -54
- sqlspec/protocols.py +10 -0
- sqlspec/storage/registry.py +2 -2
- sqlspec/typing.py +53 -99
- {sqlspec-0.16.2.dist-info → sqlspec-0.17.1.dist-info}/METADATA +1 -1
- {sqlspec-0.16.2.dist-info → sqlspec-0.17.1.dist-info}/RECORD +36 -32
- {sqlspec-0.16.2.dist-info → sqlspec-0.17.1.dist-info}/WHEEL +0 -0
- {sqlspec-0.16.2.dist-info → sqlspec-0.17.1.dist-info}/entry_points.txt +0 -0
- {sqlspec-0.16.2.dist-info → sqlspec-0.17.1.dist-info}/licenses/LICENSE +0 -0
- {sqlspec-0.16.2.dist-info → sqlspec-0.17.1.dist-info}/licenses/NOTICE +0 -0
sqlspec/_sql.py
CHANGED
|
@@ -4,10 +4,9 @@ Provides both statement builders (select, insert, update, etc.) and column expre
|
|
|
4
4
|
"""
|
|
5
5
|
|
|
6
6
|
import logging
|
|
7
|
-
from typing import TYPE_CHECKING, Any, Optional, Union
|
|
7
|
+
from typing import TYPE_CHECKING, Any, Optional, Union
|
|
8
8
|
|
|
9
9
|
import sqlglot
|
|
10
|
-
from mypy_extensions import trait
|
|
11
10
|
from sqlglot import exp
|
|
12
11
|
from sqlglot.dialects.dialect import DialectType
|
|
13
12
|
from sqlglot.errors import ParseError as SQLGlotParseError
|
|
@@ -34,12 +33,22 @@ from sqlspec.builder import (
|
|
|
34
33
|
Truncate,
|
|
35
34
|
Update,
|
|
36
35
|
)
|
|
36
|
+
from sqlspec.builder._expression_wrappers import (
|
|
37
|
+
AggregateExpression,
|
|
38
|
+
ConversionExpression,
|
|
39
|
+
FunctionExpression,
|
|
40
|
+
MathExpression,
|
|
41
|
+
StringExpression,
|
|
42
|
+
)
|
|
43
|
+
from sqlspec.builder.mixins._join_operations import JoinBuilder
|
|
44
|
+
from sqlspec.builder.mixins._select_operations import Case, SubqueryBuilder, WindowFunctionBuilder
|
|
37
45
|
from sqlspec.exceptions import SQLBuilderError
|
|
38
46
|
|
|
39
47
|
if TYPE_CHECKING:
|
|
40
|
-
from sqlspec.builder.
|
|
48
|
+
from sqlspec.builder._expression_wrappers import ExpressionWrapper
|
|
41
49
|
from sqlspec.core.statement import SQL
|
|
42
50
|
|
|
51
|
+
|
|
43
52
|
__all__ = (
|
|
44
53
|
"AlterTable",
|
|
45
54
|
"Case",
|
|
@@ -179,7 +188,7 @@ class SQLFactory:
|
|
|
179
188
|
# Statement Builders
|
|
180
189
|
# ===================
|
|
181
190
|
def select(
|
|
182
|
-
self, *columns_or_sql: Union[str, exp.Expression, Column, "SQL"], dialect: DialectType = None
|
|
191
|
+
self, *columns_or_sql: Union[str, exp.Expression, Column, "SQL", "Case"], dialect: DialectType = None
|
|
183
192
|
) -> "Select":
|
|
184
193
|
builder_dialect = dialect or self.dialect
|
|
185
194
|
if len(columns_or_sql) == 1 and isinstance(columns_or_sql[0], str):
|
|
@@ -223,7 +232,10 @@ class SQLFactory:
|
|
|
223
232
|
if self._looks_like_sql(table_or_sql):
|
|
224
233
|
detected = self.detect_sql_type(table_or_sql, dialect=builder_dialect)
|
|
225
234
|
if detected != "UPDATE":
|
|
226
|
-
msg =
|
|
235
|
+
msg = (
|
|
236
|
+
f"sql.update() expects UPDATE statement, got {detected}. "
|
|
237
|
+
f"Use sql.{detected.lower()}() if a dedicated builder exists."
|
|
238
|
+
)
|
|
227
239
|
raise SQLBuilderError(msg)
|
|
228
240
|
return self._populate_update_from_sql(builder, table_or_sql)
|
|
229
241
|
return builder.table(table_or_sql)
|
|
@@ -235,7 +247,10 @@ class SQLFactory:
|
|
|
235
247
|
if table_or_sql and self._looks_like_sql(table_or_sql):
|
|
236
248
|
detected = self.detect_sql_type(table_or_sql, dialect=builder_dialect)
|
|
237
249
|
if detected != "DELETE":
|
|
238
|
-
msg =
|
|
250
|
+
msg = (
|
|
251
|
+
f"sql.delete() expects DELETE statement, got {detected}. "
|
|
252
|
+
f"Use sql.{detected.lower()}() if a dedicated builder exists."
|
|
253
|
+
)
|
|
239
254
|
raise SQLBuilderError(msg)
|
|
240
255
|
return self._populate_delete_from_sql(builder, table_or_sql)
|
|
241
256
|
return builder
|
|
@@ -247,7 +262,10 @@ class SQLFactory:
|
|
|
247
262
|
if self._looks_like_sql(table_or_sql):
|
|
248
263
|
detected = self.detect_sql_type(table_or_sql, dialect=builder_dialect)
|
|
249
264
|
if detected != "MERGE":
|
|
250
|
-
msg =
|
|
265
|
+
msg = (
|
|
266
|
+
f"sql.merge() expects MERGE statement, got {detected}. "
|
|
267
|
+
f"Use sql.{detected.lower()}() if a dedicated builder exists."
|
|
268
|
+
)
|
|
251
269
|
raise SQLBuilderError(msg)
|
|
252
270
|
return self._populate_merge_from_sql(builder, table_or_sql)
|
|
253
271
|
return builder.into(table_or_sql)
|
|
@@ -737,7 +755,9 @@ class SQLFactory:
|
|
|
737
755
|
# ===================
|
|
738
756
|
|
|
739
757
|
@staticmethod
|
|
740
|
-
def count(
|
|
758
|
+
def count(
|
|
759
|
+
column: Union[str, exp.Expression, "ExpressionWrapper", "Case", "Column"] = "*", distinct: bool = False
|
|
760
|
+
) -> AggregateExpression:
|
|
741
761
|
"""Create a COUNT expression.
|
|
742
762
|
|
|
743
763
|
Args:
|
|
@@ -747,12 +767,14 @@ class SQLFactory:
|
|
|
747
767
|
Returns:
|
|
748
768
|
COUNT expression.
|
|
749
769
|
"""
|
|
750
|
-
if column == "*":
|
|
751
|
-
|
|
752
|
-
|
|
753
|
-
|
|
770
|
+
if isinstance(column, str) and column == "*":
|
|
771
|
+
expr = exp.Count(this=exp.Star(), distinct=distinct)
|
|
772
|
+
else:
|
|
773
|
+
col_expr = SQLFactory._extract_expression(column)
|
|
774
|
+
expr = exp.Count(this=col_expr, distinct=distinct)
|
|
775
|
+
return AggregateExpression(expr)
|
|
754
776
|
|
|
755
|
-
def count_distinct(self, column: Union[str, exp.Expression]) ->
|
|
777
|
+
def count_distinct(self, column: Union[str, exp.Expression, "ExpressionWrapper", "Case"]) -> AggregateExpression:
|
|
756
778
|
"""Create a COUNT(DISTINCT column) expression.
|
|
757
779
|
|
|
758
780
|
Args:
|
|
@@ -764,7 +786,9 @@ class SQLFactory:
|
|
|
764
786
|
return self.count(column, distinct=True)
|
|
765
787
|
|
|
766
788
|
@staticmethod
|
|
767
|
-
def sum(
|
|
789
|
+
def sum(
|
|
790
|
+
column: Union[str, exp.Expression, "ExpressionWrapper", "Case"], distinct: bool = False
|
|
791
|
+
) -> AggregateExpression:
|
|
768
792
|
"""Create a SUM expression.
|
|
769
793
|
|
|
770
794
|
Args:
|
|
@@ -774,11 +798,11 @@ class SQLFactory:
|
|
|
774
798
|
Returns:
|
|
775
799
|
SUM expression.
|
|
776
800
|
"""
|
|
777
|
-
col_expr =
|
|
778
|
-
return exp.Sum(this=col_expr, distinct=distinct)
|
|
801
|
+
col_expr = SQLFactory._extract_expression(column)
|
|
802
|
+
return AggregateExpression(exp.Sum(this=col_expr, distinct=distinct))
|
|
779
803
|
|
|
780
804
|
@staticmethod
|
|
781
|
-
def avg(column: Union[str, exp.Expression]) ->
|
|
805
|
+
def avg(column: Union[str, exp.Expression, "ExpressionWrapper", "Case"]) -> AggregateExpression:
|
|
782
806
|
"""Create an AVG expression.
|
|
783
807
|
|
|
784
808
|
Args:
|
|
@@ -787,11 +811,11 @@ class SQLFactory:
|
|
|
787
811
|
Returns:
|
|
788
812
|
AVG expression.
|
|
789
813
|
"""
|
|
790
|
-
col_expr =
|
|
791
|
-
return exp.Avg(this=col_expr)
|
|
814
|
+
col_expr = SQLFactory._extract_expression(column)
|
|
815
|
+
return AggregateExpression(exp.Avg(this=col_expr))
|
|
792
816
|
|
|
793
817
|
@staticmethod
|
|
794
|
-
def max(column: Union[str, exp.Expression]) ->
|
|
818
|
+
def max(column: Union[str, exp.Expression, "ExpressionWrapper", "Case"]) -> AggregateExpression:
|
|
795
819
|
"""Create a MAX expression.
|
|
796
820
|
|
|
797
821
|
Args:
|
|
@@ -800,11 +824,11 @@ class SQLFactory:
|
|
|
800
824
|
Returns:
|
|
801
825
|
MAX expression.
|
|
802
826
|
"""
|
|
803
|
-
col_expr =
|
|
804
|
-
return exp.Max(this=col_expr)
|
|
827
|
+
col_expr = SQLFactory._extract_expression(column)
|
|
828
|
+
return AggregateExpression(exp.Max(this=col_expr))
|
|
805
829
|
|
|
806
830
|
@staticmethod
|
|
807
|
-
def min(column: Union[str, exp.Expression]) ->
|
|
831
|
+
def min(column: Union[str, exp.Expression, "ExpressionWrapper", "Case"]) -> AggregateExpression:
|
|
808
832
|
"""Create a MIN expression.
|
|
809
833
|
|
|
810
834
|
Args:
|
|
@@ -813,15 +837,15 @@ class SQLFactory:
|
|
|
813
837
|
Returns:
|
|
814
838
|
MIN expression.
|
|
815
839
|
"""
|
|
816
|
-
col_expr =
|
|
817
|
-
return exp.Min(this=col_expr)
|
|
840
|
+
col_expr = SQLFactory._extract_expression(column)
|
|
841
|
+
return AggregateExpression(exp.Min(this=col_expr))
|
|
818
842
|
|
|
819
843
|
# ===================
|
|
820
844
|
# Advanced SQL Operations
|
|
821
845
|
# ===================
|
|
822
846
|
|
|
823
847
|
@staticmethod
|
|
824
|
-
def rollup(*columns: Union[str, exp.Expression]) ->
|
|
848
|
+
def rollup(*columns: Union[str, exp.Expression]) -> FunctionExpression:
|
|
825
849
|
"""Create a ROLLUP expression for GROUP BY clauses.
|
|
826
850
|
|
|
827
851
|
Args:
|
|
@@ -841,10 +865,10 @@ class SQLFactory:
|
|
|
841
865
|
```
|
|
842
866
|
"""
|
|
843
867
|
column_exprs = [exp.column(col) if isinstance(col, str) else col for col in columns]
|
|
844
|
-
return exp.Rollup(expressions=column_exprs)
|
|
868
|
+
return FunctionExpression(exp.Rollup(expressions=column_exprs))
|
|
845
869
|
|
|
846
870
|
@staticmethod
|
|
847
|
-
def cube(*columns: Union[str, exp.Expression]) ->
|
|
871
|
+
def cube(*columns: Union[str, exp.Expression]) -> FunctionExpression:
|
|
848
872
|
"""Create a CUBE expression for GROUP BY clauses.
|
|
849
873
|
|
|
850
874
|
Args:
|
|
@@ -864,10 +888,10 @@ class SQLFactory:
|
|
|
864
888
|
```
|
|
865
889
|
"""
|
|
866
890
|
column_exprs = [exp.column(col) if isinstance(col, str) else col for col in columns]
|
|
867
|
-
return exp.Cube(expressions=column_exprs)
|
|
891
|
+
return FunctionExpression(exp.Cube(expressions=column_exprs))
|
|
868
892
|
|
|
869
893
|
@staticmethod
|
|
870
|
-
def grouping_sets(*column_sets: Union[tuple[str, ...], list[str]]) ->
|
|
894
|
+
def grouping_sets(*column_sets: Union[tuple[str, ...], list[str]]) -> FunctionExpression:
|
|
871
895
|
"""Create a GROUPING SETS expression for GROUP BY clauses.
|
|
872
896
|
|
|
873
897
|
Args:
|
|
@@ -899,10 +923,10 @@ class SQLFactory:
|
|
|
899
923
|
else:
|
|
900
924
|
set_expressions.append(exp.column(column_set))
|
|
901
925
|
|
|
902
|
-
return exp.GroupingSets(expressions=set_expressions)
|
|
926
|
+
return FunctionExpression(exp.GroupingSets(expressions=set_expressions))
|
|
903
927
|
|
|
904
928
|
@staticmethod
|
|
905
|
-
def any(values: Union[list[Any], exp.Expression, str]) ->
|
|
929
|
+
def any(values: Union[list[Any], exp.Expression, str]) -> FunctionExpression:
|
|
906
930
|
"""Create an ANY expression for use with comparison operators.
|
|
907
931
|
|
|
908
932
|
Args:
|
|
@@ -923,18 +947,18 @@ class SQLFactory:
|
|
|
923
947
|
```
|
|
924
948
|
"""
|
|
925
949
|
if isinstance(values, list):
|
|
926
|
-
literals = [SQLFactory.
|
|
927
|
-
return exp.Any(this=exp.Array(expressions=literals))
|
|
950
|
+
literals = [SQLFactory.to_literal(v) for v in values]
|
|
951
|
+
return FunctionExpression(exp.Any(this=exp.Array(expressions=literals)))
|
|
928
952
|
if isinstance(values, str):
|
|
929
953
|
# Parse as SQL
|
|
930
954
|
parsed = exp.maybe_parse(values) # type: ignore[var-annotated]
|
|
931
955
|
if parsed:
|
|
932
|
-
return exp.Any(this=parsed)
|
|
933
|
-
return exp.Any(this=exp.Literal.string(values))
|
|
934
|
-
return exp.Any(this=values)
|
|
956
|
+
return FunctionExpression(exp.Any(this=parsed))
|
|
957
|
+
return FunctionExpression(exp.Any(this=exp.Literal.string(values)))
|
|
958
|
+
return FunctionExpression(exp.Any(this=values))
|
|
935
959
|
|
|
936
960
|
@staticmethod
|
|
937
|
-
def not_any_(values: Union[list[Any], exp.Expression, str]) ->
|
|
961
|
+
def not_any_(values: Union[list[Any], exp.Expression, str]) -> FunctionExpression:
|
|
938
962
|
"""Create a NOT ANY expression for use with comparison operators.
|
|
939
963
|
|
|
940
964
|
Args:
|
|
@@ -954,14 +978,14 @@ class SQLFactory:
|
|
|
954
978
|
)
|
|
955
979
|
```
|
|
956
980
|
"""
|
|
957
|
-
return SQLFactory.any(values)
|
|
981
|
+
return SQLFactory.any(values)
|
|
958
982
|
|
|
959
983
|
# ===================
|
|
960
984
|
# String Functions
|
|
961
985
|
# ===================
|
|
962
986
|
|
|
963
987
|
@staticmethod
|
|
964
|
-
def concat(*expressions: Union[str, exp.Expression]) ->
|
|
988
|
+
def concat(*expressions: Union[str, exp.Expression]) -> StringExpression:
|
|
965
989
|
"""Create a CONCAT expression.
|
|
966
990
|
|
|
967
991
|
Args:
|
|
@@ -971,10 +995,10 @@ class SQLFactory:
|
|
|
971
995
|
CONCAT expression.
|
|
972
996
|
"""
|
|
973
997
|
exprs = [exp.column(expr) if isinstance(expr, str) else expr for expr in expressions]
|
|
974
|
-
return exp.Concat(expressions=exprs)
|
|
998
|
+
return StringExpression(exp.Concat(expressions=exprs))
|
|
975
999
|
|
|
976
1000
|
@staticmethod
|
|
977
|
-
def upper(column: Union[str, exp.Expression]) ->
|
|
1001
|
+
def upper(column: Union[str, exp.Expression]) -> StringExpression:
|
|
978
1002
|
"""Create an UPPER expression.
|
|
979
1003
|
|
|
980
1004
|
Args:
|
|
@@ -984,10 +1008,10 @@ class SQLFactory:
|
|
|
984
1008
|
UPPER expression.
|
|
985
1009
|
"""
|
|
986
1010
|
col_expr = exp.column(column) if isinstance(column, str) else column
|
|
987
|
-
return exp.Upper(this=col_expr)
|
|
1011
|
+
return StringExpression(exp.Upper(this=col_expr))
|
|
988
1012
|
|
|
989
1013
|
@staticmethod
|
|
990
|
-
def lower(column: Union[str, exp.Expression]) ->
|
|
1014
|
+
def lower(column: Union[str, exp.Expression]) -> StringExpression:
|
|
991
1015
|
"""Create a LOWER expression.
|
|
992
1016
|
|
|
993
1017
|
Args:
|
|
@@ -997,10 +1021,10 @@ class SQLFactory:
|
|
|
997
1021
|
LOWER expression.
|
|
998
1022
|
"""
|
|
999
1023
|
col_expr = exp.column(column) if isinstance(column, str) else column
|
|
1000
|
-
return exp.Lower(this=col_expr)
|
|
1024
|
+
return StringExpression(exp.Lower(this=col_expr))
|
|
1001
1025
|
|
|
1002
1026
|
@staticmethod
|
|
1003
|
-
def length(column: Union[str, exp.Expression]) ->
|
|
1027
|
+
def length(column: Union[str, exp.Expression]) -> StringExpression:
|
|
1004
1028
|
"""Create a LENGTH expression.
|
|
1005
1029
|
|
|
1006
1030
|
Args:
|
|
@@ -1010,14 +1034,14 @@ class SQLFactory:
|
|
|
1010
1034
|
LENGTH expression.
|
|
1011
1035
|
"""
|
|
1012
1036
|
col_expr = exp.column(column) if isinstance(column, str) else column
|
|
1013
|
-
return exp.Length(this=col_expr)
|
|
1037
|
+
return StringExpression(exp.Length(this=col_expr))
|
|
1014
1038
|
|
|
1015
1039
|
# ===================
|
|
1016
1040
|
# Math Functions
|
|
1017
1041
|
# ===================
|
|
1018
1042
|
|
|
1019
1043
|
@staticmethod
|
|
1020
|
-
def round(column: Union[str, exp.Expression], decimals: int = 0) ->
|
|
1044
|
+
def round(column: Union[str, exp.Expression], decimals: int = 0) -> MathExpression:
|
|
1021
1045
|
"""Create a ROUND expression.
|
|
1022
1046
|
|
|
1023
1047
|
Args:
|
|
@@ -1029,15 +1053,15 @@ class SQLFactory:
|
|
|
1029
1053
|
"""
|
|
1030
1054
|
col_expr = exp.column(column) if isinstance(column, str) else column
|
|
1031
1055
|
if decimals == 0:
|
|
1032
|
-
return exp.Round(this=col_expr)
|
|
1033
|
-
return exp.Round(this=col_expr, expression=exp.Literal.number(decimals))
|
|
1056
|
+
return MathExpression(exp.Round(this=col_expr))
|
|
1057
|
+
return MathExpression(exp.Round(this=col_expr, expression=exp.Literal.number(decimals)))
|
|
1034
1058
|
|
|
1035
1059
|
# ===================
|
|
1036
1060
|
# Conversion Functions
|
|
1037
1061
|
# ===================
|
|
1038
1062
|
|
|
1039
1063
|
@staticmethod
|
|
1040
|
-
def
|
|
1064
|
+
def to_literal(value: Any) -> FunctionExpression:
|
|
1041
1065
|
"""Convert a Python value to a SQLGlot literal expression.
|
|
1042
1066
|
|
|
1043
1067
|
Uses SQLGlot's built-in exp.convert() function for optimal dialect-agnostic
|
|
@@ -1054,12 +1078,52 @@ class SQLFactory:
|
|
|
1054
1078
|
Returns:
|
|
1055
1079
|
SQLGlot expression representing the literal value.
|
|
1056
1080
|
"""
|
|
1081
|
+
if isinstance(value, exp.Expression):
|
|
1082
|
+
return FunctionExpression(value)
|
|
1083
|
+
return FunctionExpression(exp.convert(value))
|
|
1084
|
+
|
|
1085
|
+
@staticmethod
|
|
1086
|
+
def _to_expression(value: Any) -> exp.Expression:
|
|
1087
|
+
"""Convert a Python value to a raw SQLGlot expression.
|
|
1088
|
+
|
|
1089
|
+
Args:
|
|
1090
|
+
value: Python value or SQLGlot expression to convert.
|
|
1091
|
+
|
|
1092
|
+
Returns:
|
|
1093
|
+
Raw SQLGlot expression.
|
|
1094
|
+
"""
|
|
1095
|
+
if isinstance(value, exp.Expression):
|
|
1096
|
+
return value
|
|
1097
|
+
return exp.convert(value)
|
|
1098
|
+
|
|
1099
|
+
@staticmethod
|
|
1100
|
+
def _extract_expression(value: Any) -> exp.Expression:
|
|
1101
|
+
"""Extract SQLGlot expression from value, handling our wrapper types.
|
|
1102
|
+
|
|
1103
|
+
Args:
|
|
1104
|
+
value: String, SQLGlot expression, or our wrapper type.
|
|
1105
|
+
|
|
1106
|
+
Returns:
|
|
1107
|
+
Raw SQLGlot expression.
|
|
1108
|
+
"""
|
|
1109
|
+
from sqlspec.builder._expression_wrappers import ExpressionWrapper
|
|
1110
|
+
from sqlspec.builder.mixins._select_operations import Case
|
|
1111
|
+
|
|
1112
|
+
if isinstance(value, str):
|
|
1113
|
+
return exp.column(value)
|
|
1114
|
+
if isinstance(value, Column):
|
|
1115
|
+
return value._expression
|
|
1116
|
+
if isinstance(value, ExpressionWrapper):
|
|
1117
|
+
return value.expression
|
|
1118
|
+
if isinstance(value, Case):
|
|
1119
|
+
# Case has _expression property via trait
|
|
1120
|
+
return exp.Case(ifs=value._conditions, default=value._default)
|
|
1057
1121
|
if isinstance(value, exp.Expression):
|
|
1058
1122
|
return value
|
|
1059
1123
|
return exp.convert(value)
|
|
1060
1124
|
|
|
1061
1125
|
@staticmethod
|
|
1062
|
-
def decode(column: Union[str, exp.Expression], *args: Union[str, exp.Expression, Any]) ->
|
|
1126
|
+
def decode(column: Union[str, exp.Expression], *args: Union[str, exp.Expression, Any]) -> FunctionExpression:
|
|
1063
1127
|
"""Create a DECODE expression (Oracle-style conditional logic).
|
|
1064
1128
|
|
|
1065
1129
|
DECODE compares column to each search value and returns the corresponding result.
|
|
@@ -1096,22 +1160,22 @@ class SQLFactory:
|
|
|
1096
1160
|
for i in range(0, len(args) - 1, 2):
|
|
1097
1161
|
if i + 1 >= len(args):
|
|
1098
1162
|
# Odd number of args means last one is default
|
|
1099
|
-
default = SQLFactory.
|
|
1163
|
+
default = SQLFactory._to_expression(args[i])
|
|
1100
1164
|
break
|
|
1101
1165
|
|
|
1102
1166
|
search_val = args[i]
|
|
1103
1167
|
result_val = args[i + 1]
|
|
1104
1168
|
|
|
1105
|
-
search_expr = SQLFactory.
|
|
1106
|
-
result_expr = SQLFactory.
|
|
1169
|
+
search_expr = SQLFactory._to_expression(search_val)
|
|
1170
|
+
result_expr = SQLFactory._to_expression(result_val)
|
|
1107
1171
|
|
|
1108
1172
|
condition = exp.EQ(this=col_expr, expression=search_expr)
|
|
1109
|
-
conditions.append(exp.
|
|
1173
|
+
conditions.append(exp.If(this=condition, true=result_expr))
|
|
1110
1174
|
|
|
1111
|
-
return exp.Case(ifs=conditions, default=default)
|
|
1175
|
+
return FunctionExpression(exp.Case(ifs=conditions, default=default))
|
|
1112
1176
|
|
|
1113
1177
|
@staticmethod
|
|
1114
|
-
def cast(column: Union[str, exp.Expression], data_type: str) ->
|
|
1178
|
+
def cast(column: Union[str, exp.Expression], data_type: str) -> ConversionExpression:
|
|
1115
1179
|
"""Create a CAST expression for type conversion.
|
|
1116
1180
|
|
|
1117
1181
|
Args:
|
|
@@ -1122,10 +1186,10 @@ class SQLFactory:
|
|
|
1122
1186
|
CAST expression.
|
|
1123
1187
|
"""
|
|
1124
1188
|
col_expr = exp.column(column) if isinstance(column, str) else column
|
|
1125
|
-
return exp.Cast(this=col_expr, to=exp.DataType.build(data_type))
|
|
1189
|
+
return ConversionExpression(exp.Cast(this=col_expr, to=exp.DataType.build(data_type)))
|
|
1126
1190
|
|
|
1127
1191
|
@staticmethod
|
|
1128
|
-
def coalesce(*expressions: Union[str, exp.Expression]) ->
|
|
1192
|
+
def coalesce(*expressions: Union[str, exp.Expression]) -> ConversionExpression:
|
|
1129
1193
|
"""Create a COALESCE expression.
|
|
1130
1194
|
|
|
1131
1195
|
Args:
|
|
@@ -1135,10 +1199,12 @@ class SQLFactory:
|
|
|
1135
1199
|
COALESCE expression.
|
|
1136
1200
|
"""
|
|
1137
1201
|
exprs = [exp.column(expr) if isinstance(expr, str) else expr for expr in expressions]
|
|
1138
|
-
return exp.Coalesce(expressions=exprs)
|
|
1202
|
+
return ConversionExpression(exp.Coalesce(expressions=exprs))
|
|
1139
1203
|
|
|
1140
1204
|
@staticmethod
|
|
1141
|
-
def nvl(
|
|
1205
|
+
def nvl(
|
|
1206
|
+
column: Union[str, exp.Expression], substitute_value: Union[str, exp.Expression, Any]
|
|
1207
|
+
) -> ConversionExpression:
|
|
1142
1208
|
"""Create an NVL (Oracle-style) expression using COALESCE.
|
|
1143
1209
|
|
|
1144
1210
|
Args:
|
|
@@ -1149,15 +1215,15 @@ class SQLFactory:
|
|
|
1149
1215
|
COALESCE expression equivalent to NVL.
|
|
1150
1216
|
"""
|
|
1151
1217
|
col_expr = exp.column(column) if isinstance(column, str) else column
|
|
1152
|
-
sub_expr = SQLFactory.
|
|
1153
|
-
return exp.Coalesce(expressions=[col_expr, sub_expr])
|
|
1218
|
+
sub_expr = SQLFactory._to_expression(substitute_value)
|
|
1219
|
+
return ConversionExpression(exp.Coalesce(expressions=[col_expr, sub_expr]))
|
|
1154
1220
|
|
|
1155
1221
|
@staticmethod
|
|
1156
1222
|
def nvl2(
|
|
1157
1223
|
column: Union[str, exp.Expression],
|
|
1158
1224
|
value_if_not_null: Union[str, exp.Expression, Any],
|
|
1159
1225
|
value_if_null: Union[str, exp.Expression, Any],
|
|
1160
|
-
) ->
|
|
1226
|
+
) -> ConversionExpression:
|
|
1161
1227
|
"""Create an NVL2 (Oracle-style) expression using CASE.
|
|
1162
1228
|
|
|
1163
1229
|
NVL2 returns value_if_not_null if column is not NULL,
|
|
@@ -1178,22 +1244,22 @@ class SQLFactory:
|
|
|
1178
1244
|
```
|
|
1179
1245
|
"""
|
|
1180
1246
|
col_expr = exp.column(column) if isinstance(column, str) else column
|
|
1181
|
-
not_null_expr = SQLFactory.
|
|
1182
|
-
null_expr = SQLFactory.
|
|
1247
|
+
not_null_expr = SQLFactory._to_expression(value_if_not_null)
|
|
1248
|
+
null_expr = SQLFactory._to_expression(value_if_null)
|
|
1183
1249
|
|
|
1184
1250
|
# Create CASE WHEN column IS NOT NULL THEN value_if_not_null ELSE value_if_null END
|
|
1185
1251
|
is_null = exp.Is(this=col_expr, expression=exp.Null())
|
|
1186
1252
|
condition = exp.Not(this=is_null)
|
|
1187
1253
|
when_clause = exp.If(this=condition, true=not_null_expr)
|
|
1188
1254
|
|
|
1189
|
-
return exp.Case(ifs=[when_clause], default=null_expr)
|
|
1255
|
+
return ConversionExpression(exp.Case(ifs=[when_clause], default=null_expr))
|
|
1190
1256
|
|
|
1191
1257
|
# ===================
|
|
1192
1258
|
# Bulk Operations
|
|
1193
1259
|
# ===================
|
|
1194
1260
|
|
|
1195
1261
|
@staticmethod
|
|
1196
|
-
def bulk_insert(table_name: str, column_count: int, placeholder_style: str = "?") ->
|
|
1262
|
+
def bulk_insert(table_name: str, column_count: int, placeholder_style: str = "?") -> FunctionExpression:
|
|
1197
1263
|
"""Create bulk INSERT expression for executemany operations.
|
|
1198
1264
|
|
|
1199
1265
|
This is specifically for bulk loading operations like CSV ingestion where
|
|
@@ -1228,13 +1294,15 @@ class SQLFactory:
|
|
|
1228
1294
|
# Creates: INSERT INTO "my_table" VALUES (:1, :2, :3)
|
|
1229
1295
|
```
|
|
1230
1296
|
"""
|
|
1231
|
-
return
|
|
1232
|
-
|
|
1233
|
-
|
|
1234
|
-
|
|
1235
|
-
|
|
1236
|
-
|
|
1237
|
-
|
|
1297
|
+
return FunctionExpression(
|
|
1298
|
+
exp.Insert(
|
|
1299
|
+
this=exp.Table(this=exp.to_identifier(table_name)),
|
|
1300
|
+
expression=exp.Values(
|
|
1301
|
+
expressions=[
|
|
1302
|
+
exp.Tuple(expressions=[exp.Placeholder(this=placeholder_style) for _ in range(column_count)])
|
|
1303
|
+
]
|
|
1304
|
+
),
|
|
1305
|
+
)
|
|
1238
1306
|
)
|
|
1239
1307
|
|
|
1240
1308
|
def truncate(self, table_name: str) -> "Truncate":
|
|
@@ -1288,7 +1356,7 @@ class SQLFactory:
|
|
|
1288
1356
|
self,
|
|
1289
1357
|
partition_by: Optional[Union[str, list[str], exp.Expression]] = None,
|
|
1290
1358
|
order_by: Optional[Union[str, list[str], exp.Expression]] = None,
|
|
1291
|
-
) ->
|
|
1359
|
+
) -> FunctionExpression:
|
|
1292
1360
|
"""Create a ROW_NUMBER() window function.
|
|
1293
1361
|
|
|
1294
1362
|
Args:
|
|
@@ -1304,7 +1372,7 @@ class SQLFactory:
|
|
|
1304
1372
|
self,
|
|
1305
1373
|
partition_by: Optional[Union[str, list[str], exp.Expression]] = None,
|
|
1306
1374
|
order_by: Optional[Union[str, list[str], exp.Expression]] = None,
|
|
1307
|
-
) ->
|
|
1375
|
+
) -> FunctionExpression:
|
|
1308
1376
|
"""Create a RANK() window function.
|
|
1309
1377
|
|
|
1310
1378
|
Args:
|
|
@@ -1320,7 +1388,7 @@ class SQLFactory:
|
|
|
1320
1388
|
self,
|
|
1321
1389
|
partition_by: Optional[Union[str, list[str], exp.Expression]] = None,
|
|
1322
1390
|
order_by: Optional[Union[str, list[str], exp.Expression]] = None,
|
|
1323
|
-
) ->
|
|
1391
|
+
) -> FunctionExpression:
|
|
1324
1392
|
"""Create a DENSE_RANK() window function.
|
|
1325
1393
|
|
|
1326
1394
|
Args:
|
|
@@ -1338,7 +1406,7 @@ class SQLFactory:
|
|
|
1338
1406
|
func_args: list[exp.Expression],
|
|
1339
1407
|
partition_by: Optional[Union[str, list[str], exp.Expression]] = None,
|
|
1340
1408
|
order_by: Optional[Union[str, list[str], exp.Expression]] = None,
|
|
1341
|
-
) ->
|
|
1409
|
+
) -> FunctionExpression:
|
|
1342
1410
|
"""Helper to create window function expressions.
|
|
1343
1411
|
|
|
1344
1412
|
Args:
|
|
@@ -1364,418 +1432,13 @@ class SQLFactory:
|
|
|
1364
1432
|
|
|
1365
1433
|
if order_by:
|
|
1366
1434
|
if isinstance(order_by, str):
|
|
1367
|
-
over_args["order"] = [exp.column(order_by).asc()]
|
|
1435
|
+
over_args["order"] = exp.Order(expressions=[exp.column(order_by).asc()])
|
|
1368
1436
|
elif isinstance(order_by, list):
|
|
1369
|
-
over_args["order"] = [exp.column(col).asc() for col in order_by]
|
|
1437
|
+
over_args["order"] = exp.Order(expressions=[exp.column(col).asc() for col in order_by])
|
|
1370
1438
|
elif isinstance(order_by, exp.Expression):
|
|
1371
|
-
over_args["order"] = [order_by]
|
|
1372
|
-
|
|
1373
|
-
return exp.Window(this=func_expr, **over_args)
|
|
1374
|
-
|
|
1375
|
-
|
|
1376
|
-
@trait
|
|
1377
|
-
class Case:
|
|
1378
|
-
"""Builder for CASE expressions using the SQL factory.
|
|
1379
|
-
|
|
1380
|
-
Example:
|
|
1381
|
-
```python
|
|
1382
|
-
from sqlspec import sql
|
|
1383
|
-
|
|
1384
|
-
case_expr = (
|
|
1385
|
-
sql.case()
|
|
1386
|
-
.when(sql.age < 18, "Minor")
|
|
1387
|
-
.when(sql.age < 65, "Adult")
|
|
1388
|
-
.else_("Senior")
|
|
1389
|
-
.end()
|
|
1390
|
-
)
|
|
1391
|
-
```
|
|
1392
|
-
"""
|
|
1393
|
-
|
|
1394
|
-
def __init__(self) -> None:
|
|
1395
|
-
"""Initialize the CASE expression builder."""
|
|
1396
|
-
self._conditions: list[exp.If] = []
|
|
1397
|
-
self._default: Optional[exp.Expression] = None
|
|
1398
|
-
|
|
1399
|
-
def __eq__(self, other: object) -> "ColumnExpression": # type: ignore[override]
|
|
1400
|
-
"""Equal to (==) - convert to expression then compare."""
|
|
1401
|
-
from sqlspec.builder._column import ColumnExpression
|
|
1402
|
-
|
|
1403
|
-
case_expr = exp.Case(ifs=self._conditions, default=self._default)
|
|
1404
|
-
if other is None:
|
|
1405
|
-
return ColumnExpression(exp.Is(this=case_expr, expression=exp.Null()))
|
|
1406
|
-
return ColumnExpression(exp.EQ(this=case_expr, expression=exp.convert(other)))
|
|
1407
|
-
|
|
1408
|
-
def __hash__(self) -> int:
|
|
1409
|
-
"""Make Case hashable."""
|
|
1410
|
-
return hash(id(self))
|
|
1411
|
-
|
|
1412
|
-
def when(self, condition: Union[str, exp.Expression], value: Union[str, exp.Expression, Any]) -> "Case":
|
|
1413
|
-
"""Add a WHEN clause.
|
|
1414
|
-
|
|
1415
|
-
Args:
|
|
1416
|
-
condition: Condition to test.
|
|
1417
|
-
value: Value to return if condition is true.
|
|
1418
|
-
|
|
1419
|
-
Returns:
|
|
1420
|
-
Self for method chaining.
|
|
1421
|
-
"""
|
|
1422
|
-
cond_expr = exp.maybe_parse(condition) or exp.column(condition) if isinstance(condition, str) else condition
|
|
1423
|
-
val_expr = SQLFactory._to_literal(value)
|
|
1424
|
-
|
|
1425
|
-
# SQLGlot uses exp.If for CASE WHEN clauses, not exp.When
|
|
1426
|
-
when_clause = exp.If(this=cond_expr, true=val_expr)
|
|
1427
|
-
self._conditions.append(when_clause)
|
|
1428
|
-
return self
|
|
1429
|
-
|
|
1430
|
-
def else_(self, value: Union[str, exp.Expression, Any]) -> "Case":
|
|
1431
|
-
"""Add an ELSE clause.
|
|
1432
|
-
|
|
1433
|
-
Args:
|
|
1434
|
-
value: Default value to return.
|
|
1435
|
-
|
|
1436
|
-
Returns:
|
|
1437
|
-
Self for method chaining.
|
|
1438
|
-
"""
|
|
1439
|
-
self._default = SQLFactory._to_literal(value)
|
|
1440
|
-
return self
|
|
1441
|
-
|
|
1442
|
-
def end(self) -> exp.Expression:
|
|
1443
|
-
"""Complete the CASE expression.
|
|
1444
|
-
|
|
1445
|
-
Returns:
|
|
1446
|
-
Complete CASE expression.
|
|
1447
|
-
"""
|
|
1448
|
-
return exp.Case(ifs=self._conditions, default=self._default)
|
|
1449
|
-
|
|
1450
|
-
def as_(self, alias: str) -> exp.Alias:
|
|
1451
|
-
"""Complete the CASE expression with an alias.
|
|
1452
|
-
|
|
1453
|
-
Args:
|
|
1454
|
-
alias: Alias name for the CASE expression.
|
|
1455
|
-
|
|
1456
|
-
Returns:
|
|
1457
|
-
Aliased CASE expression.
|
|
1458
|
-
"""
|
|
1459
|
-
case_expr = exp.Case(ifs=self._conditions, default=self._default)
|
|
1460
|
-
return cast("exp.Alias", exp.alias_(case_expr, alias))
|
|
1461
|
-
|
|
1462
|
-
|
|
1463
|
-
@trait
class WindowFunctionBuilder:
    """Builder for window functions with fluent syntax.

    Example:
        ```python
        from sqlspec import sql

        # sql.row_number_.partition_by("department").order_by("salary")
        window_func = (
            sql.row_number_.partition_by("department")
            .order_by("salary")
            .as_("row_num")
        )
        ```
    """

    def __init__(self, function_name: str) -> None:
        """Initialize the window function builder.

        Args:
            function_name: Name of the window function (row_number, rank, etc.)
        """
        self._function_name = function_name
        self._partition_by_cols: list[exp.Expression] = []
        self._order_by_cols: list[exp.Expression] = []
        # Not read by any method of this class; aliasing is applied by ``as_``.
        self._alias: Optional[str] = None

    def __eq__(self, other: object) -> "ColumnExpression":  # type: ignore[override]
        """Equal to (==) - convert to expression then compare."""
        from sqlspec.builder._column import ColumnExpression

        window_expr = self._build_expression()
        if other is None:
            return ColumnExpression(exp.Is(this=window_expr, expression=exp.Null()))
        return ColumnExpression(exp.EQ(this=window_expr, expression=exp.convert(other)))

    def __hash__(self) -> int:
        """Make WindowFunctionBuilder hashable."""
        return hash(id(self))

    def partition_by(self, *columns: Union[str, exp.Expression]) -> "WindowFunctionBuilder":
        """Add PARTITION BY clause.

        Args:
            *columns: Columns to partition by. Strings become column references;
                expressions are used as given.

        Returns:
            Self for method chaining.
        """
        for col in columns:
            col_expr = exp.column(col) if isinstance(col, str) else col
            self._partition_by_cols.append(col_expr)
        return self

    def order_by(self, *columns: Union[str, exp.Expression]) -> "WindowFunctionBuilder":
        """Add ORDER BY clause.

        Args:
            *columns: Columns to order by. Strings are ordered ascending. An
                expression that is already an ``exp.Ordered`` (e.g. built with
                ``.desc()``) keeps its own direction. Any other expression is
                ordered ascending.

        Returns:
            Self for method chaining.
        """
        for col in columns:
            if isinstance(col, str):
                self._order_by_cols.append(exp.column(col).asc())
            elif isinstance(col, exp.Ordered):
                # Already an ORDER term; wrapping it in another Ordered would nest
                # one ordering inside another and override nothing useful.
                self._order_by_cols.append(col)
            else:
                # Convert to ordered expression
                self._order_by_cols.append(exp.Ordered(this=col, desc=False))
        return self

    def as_(self, alias: str) -> exp.Alias:
        """Complete the window function with an alias.

        Args:
            alias: Alias name for the window function.

        Returns:
            Aliased window function expression.
        """
        window_expr = self._build_expression()
        return cast("exp.Alias", exp.alias_(window_expr, alias))

    def build(self) -> exp.Expression:
        """Complete the window function without an alias.

        Returns:
            Window function expression.
        """
        return self._build_expression()

    def _build_expression(self) -> exp.Expression:
        """Build the complete window function expression.

        The returned tree holds copies of the accumulated PARTITION BY / ORDER BY
        nodes. Later builder calls therefore cannot mutate an expression that was
        already handed out, and building twice does not re-parent shared nodes.
        """
        # Zero-argument function call, e.g. ROW_NUMBER()
        func_expr = exp.Anonymous(this=self._function_name.upper(), expressions=[])

        over_args: dict[str, Any] = {}
        if self._partition_by_cols:
            over_args["partition_by"] = [col.copy() for col in self._partition_by_cols]
        if self._order_by_cols:
            over_args["order"] = exp.Order(expressions=[col.copy() for col in self._order_by_cols])

        return exp.Window(this=func_expr, **over_args)
|
|
1571
|
-
|
|
1572
|
-
|
|
1573
|
-
@trait
class SubqueryBuilder:
    """Builder for subquery operations with fluent syntax.

    Example:
        ```python
        from sqlspec import sql

        # sql.exists_(subquery)
        exists_check = sql.exists_(
            sql.select("1")
            .from_("orders")
            .where_eq("user_id", sql.users.id)
        )

        # sql.in_(subquery)
        in_check = sql.in_(
            sql.select("category_id")
            .from_("categories")
            .where_eq("active", True)
        )
        ```
    """

    def __init__(self, operation: str) -> None:
        """Initialize the subquery builder.

        Args:
            operation: Type of subquery operation (exists, in, any, all)
        """
        self._operation = operation

    def __eq__(self, other: object) -> "ColumnExpression":  # type: ignore[override]
        """Equal to (==) - not typically used but needed for type consistency."""
        from sqlspec.builder._column import ColumnExpression

        # SubqueryBuilder doesn't have a direct expression, so this is a placeholder
        # In practice, this shouldn't be called as subqueries are used differently
        placeholder_expr = exp.Literal.string(f"subquery_{self._operation}")
        if other is None:
            return ColumnExpression(exp.Is(this=placeholder_expr, expression=exp.Null()))
        return ColumnExpression(exp.EQ(this=placeholder_expr, expression=exp.convert(other)))

    def __hash__(self) -> int:
        """Make SubqueryBuilder hashable."""
        return hash(id(self))

    def __call__(self, subquery: Union[str, exp.Expression, Any]) -> exp.Expression:
        """Build the subquery expression.

        Args:
            subquery: The subquery - can be a SQL string, SelectBuilder, or expression

        Returns:
            The subquery expression (EXISTS, IN, ANY, ALL, etc.)

        Raises:
            SQLBuilderError: If the subquery cannot be converted to an expression
                or the operation is unknown.
        """
        subquery_expr = self._to_expression(subquery)

        # Build the appropriate expression based on operation
        if self._operation == "exists":
            return exp.Exists(this=subquery_expr)
        if self._operation == "in":
            # For IN, we create a subquery that can be used with WHERE column IN (subquery).
            # NOTE(review): no left-hand operand (``this``) is set here, so the caller
            # presumably attaches the column before rendering; confirm.
            return exp.In(expressions=[subquery_expr])
        if self._operation == "any":
            return exp.Any(this=subquery_expr)
        if self._operation == "all":
            return exp.All(this=subquery_expr)
        msg = f"Unknown subquery operation: {self._operation}"
        raise SQLBuilderError(msg)

    @staticmethod
    def _to_expression(subquery: Any) -> exp.Expression:
        """Convert a SQL string, query builder, or expression into an expression.

        Raises:
            SQLBuilderError: If parsing yields no expression.
        """
        # Check for expressions before the duck-typed builder branch. Some sqlglot
        # expression classes (e.g. exp.Dot) expose a ``build`` classmethod that
        # needs arguments and would otherwise be mistaken for a query builder.
        if isinstance(subquery, exp.Expression):
            return subquery

        # NOTE(review): sqlglot's maybe_parse generally raises ParseError on
        # malformed input instead of returning None, so the falsy checks below
        # are a defensive fallback; confirm.
        if isinstance(subquery, str):
            # Parse as SQL
            parsed: Optional[exp.Expression] = exp.maybe_parse(subquery)
            if not parsed:
                msg = f"Could not parse subquery SQL: {subquery}"
                raise SQLBuilderError(msg)
            return parsed

        if callable(getattr(subquery, "build", None)):
            # It's a query builder - build it to get the SQL and parse.
            # NOTE(review): only ``built_query.sql`` is re-parsed; any bound
            # parameters the built query carries are not propagated here.
            built_query = subquery.build()  # pyright: ignore[reportAttributeAccessIssue]
            built_expr: Optional[exp.Expression] = exp.maybe_parse(built_query.sql)
            if not built_expr:
                msg = f"Could not parse built query: {built_query.sql}"
                raise SQLBuilderError(msg)
            return built_expr

        # Try to convert to expression
        parsed = exp.maybe_parse(str(subquery))
        if not parsed:
            msg = f"Could not convert subquery to expression: {subquery}"
            raise SQLBuilderError(msg)
        return parsed
|
|
1666
|
-
|
|
1667
|
-
|
|
1668
|
-
@trait
class JoinBuilder:
    """Builder for JOIN operations with fluent syntax.

    Example:
        ```python
        from sqlspec import sql

        # sql.left_join_("posts").on("users.id = posts.user_id")
        join_clause = sql.left_join_("posts").on(
            "users.id = posts.user_id"
        )

        # Or with query builder
        query = (
            sql.select("users.name", "posts.title")
            .from_("users")
            .join(
                sql.left_join_("posts").on(
                    "users.id = posts.user_id"
                )
            )
        )
        ```
    """

    def __init__(self, join_type: str) -> None:
        """Initialize the join builder.

        Args:
            join_type: Type of join. Either the short form (inner, left, right,
                full, cross) or the full form ("LEFT JOIN", "FULL OUTER JOIN", ...),
                case-insensitive.
        """
        self._join_type = join_type.upper()
        self._table: Optional[Union[str, exp.Expression]] = None
        self._condition: Optional[exp.Expression] = None
        self._alias: Optional[str] = None

    def __eq__(self, other: object) -> "ColumnExpression":  # type: ignore[override]
        """Equal to (==) - not typically used but needed for type consistency."""
        from sqlspec.builder._column import ColumnExpression

        # JoinBuilder doesn't have a direct expression, so this is a placeholder
        # In practice, this shouldn't be called as joins are used differently
        placeholder_expr = exp.Literal.string(f"join_{self._join_type.lower()}")
        if other is None:
            return ColumnExpression(exp.Is(this=placeholder_expr, expression=exp.Null()))
        return ColumnExpression(exp.EQ(this=placeholder_expr, expression=exp.convert(other)))

    def __hash__(self) -> int:
        """Make JoinBuilder hashable."""
        return hash(id(self))

    def __call__(self, table: Union[str, exp.Expression], alias: Optional[str] = None) -> "JoinBuilder":
        """Set the table to join.

        Args:
            table: Table name or expression to join
            alias: Optional alias for the table

        Returns:
            Self for method chaining
        """
        self._table = table
        self._alias = alias
        return self

    def on(self, condition: Union[str, exp.Expression]) -> exp.Expression:
        """Set the join condition and build the JOIN expression.

        Args:
            condition: JOIN condition (e.g., "users.id = posts.user_id"). Ignored
                for CROSS joins.

        Returns:
            Complete JOIN expression

        Raises:
            SQLBuilderError: If no table was set via ``__call__``.
        """
        if not self._table:
            msg = "Table must be set before calling .on()"
            raise SQLBuilderError(msg)

        # Parse the condition
        condition_expr: exp.Expression
        if isinstance(condition, str):
            parsed: Optional[exp.Expression] = exp.maybe_parse(condition)
            condition_expr = parsed or exp.condition(condition)
        else:
            condition_expr = condition

        # Build table expression (names become tables; expressions are used as-is)
        table_expr: exp.Expression = exp.to_table(self._table) if isinstance(self._table, str) else self._table
        if self._alias:
            table_expr = exp.alias_(table_expr, self._alias)

        join_kind = self._normalized_join_type()
        if join_kind == "CROSS":
            # CROSS JOIN doesn't use ON condition
            return exp.Join(this=table_expr, kind="CROSS")
        if join_kind in {"LEFT", "RIGHT"}:
            return exp.Join(this=table_expr, on=condition_expr, side=join_kind)
        if join_kind == "FULL":
            return exp.Join(this=table_expr, on=condition_expr, side="FULL", kind="OUTER")
        # INNER and any unrecognized type fall back to a plain JOIN ... ON.
        return exp.Join(this=table_expr, on=condition_expr)

    def _normalized_join_type(self) -> str:
        """Return the join keyword with any trailing ``JOIN`` / ``OUTER`` words removed.

        Both "LEFT" and "LEFT JOIN" (and "LEFT OUTER JOIN") map to "LEFT".
        Previously only the "... JOIN" form was recognized, so the short form
        documented in ``__init__`` silently produced a plain JOIN.
        """
        words = self._join_type.split()
        if words and words[-1] == "JOIN":
            words = words[:-1]
        if words and words[-1] == "OUTER":
            words = words[:-1]
        return " ".join(words)
|
|
1441
|
+
return FunctionExpression(exp.Window(this=func_expr, **over_args))
|
|
1779
1442
|
|
|
1780
1443
|
|
|
1781
1444
|
# Create a default SQL factory instance
|