sqlspec 0.16.2__py3-none-any.whl → 0.17.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of sqlspec might be problematic. See the package's registry page for more details.

Files changed (36)
  1. sqlspec/__init__.py +11 -1
  2. sqlspec/_sql.py +152 -489
  3. sqlspec/adapters/aiosqlite/__init__.py +11 -1
  4. sqlspec/adapters/aiosqlite/config.py +137 -165
  5. sqlspec/adapters/aiosqlite/driver.py +21 -10
  6. sqlspec/adapters/aiosqlite/pool.py +492 -0
  7. sqlspec/adapters/duckdb/__init__.py +2 -0
  8. sqlspec/adapters/duckdb/config.py +11 -235
  9. sqlspec/adapters/duckdb/pool.py +243 -0
  10. sqlspec/adapters/sqlite/__init__.py +2 -0
  11. sqlspec/adapters/sqlite/config.py +4 -115
  12. sqlspec/adapters/sqlite/pool.py +140 -0
  13. sqlspec/base.py +147 -26
  14. sqlspec/builder/__init__.py +6 -0
  15. sqlspec/builder/_column.py +5 -1
  16. sqlspec/builder/_expression_wrappers.py +46 -0
  17. sqlspec/builder/_insert.py +1 -3
  18. sqlspec/builder/_parsing_utils.py +27 -0
  19. sqlspec/builder/_update.py +5 -5
  20. sqlspec/builder/mixins/_join_operations.py +115 -1
  21. sqlspec/builder/mixins/_order_limit_operations.py +16 -4
  22. sqlspec/builder/mixins/_select_operations.py +307 -3
  23. sqlspec/builder/mixins/_update_operations.py +4 -4
  24. sqlspec/builder/mixins/_where_clause.py +60 -11
  25. sqlspec/core/compiler.py +7 -5
  26. sqlspec/driver/_common.py +9 -1
  27. sqlspec/loader.py +27 -54
  28. sqlspec/protocols.py +10 -0
  29. sqlspec/storage/registry.py +2 -2
  30. sqlspec/typing.py +53 -99
  31. {sqlspec-0.16.2.dist-info → sqlspec-0.17.1.dist-info}/METADATA +1 -1
  32. {sqlspec-0.16.2.dist-info → sqlspec-0.17.1.dist-info}/RECORD +36 -32
  33. {sqlspec-0.16.2.dist-info → sqlspec-0.17.1.dist-info}/WHEEL +0 -0
  34. {sqlspec-0.16.2.dist-info → sqlspec-0.17.1.dist-info}/entry_points.txt +0 -0
  35. {sqlspec-0.16.2.dist-info → sqlspec-0.17.1.dist-info}/licenses/LICENSE +0 -0
  36. {sqlspec-0.16.2.dist-info → sqlspec-0.17.1.dist-info}/licenses/NOTICE +0 -0
sqlspec/_sql.py CHANGED
@@ -4,10 +4,9 @@ Provides both statement builders (select, insert, update, etc.) and column expre
4
4
  """
5
5
 
6
6
  import logging
7
- from typing import TYPE_CHECKING, Any, Optional, Union, cast
7
+ from typing import TYPE_CHECKING, Any, Optional, Union
8
8
 
9
9
  import sqlglot
10
- from mypy_extensions import trait
11
10
  from sqlglot import exp
12
11
  from sqlglot.dialects.dialect import DialectType
13
12
  from sqlglot.errors import ParseError as SQLGlotParseError
@@ -34,12 +33,22 @@ from sqlspec.builder import (
34
33
  Truncate,
35
34
  Update,
36
35
  )
36
+ from sqlspec.builder._expression_wrappers import (
37
+ AggregateExpression,
38
+ ConversionExpression,
39
+ FunctionExpression,
40
+ MathExpression,
41
+ StringExpression,
42
+ )
43
+ from sqlspec.builder.mixins._join_operations import JoinBuilder
44
+ from sqlspec.builder.mixins._select_operations import Case, SubqueryBuilder, WindowFunctionBuilder
37
45
  from sqlspec.exceptions import SQLBuilderError
38
46
 
39
47
  if TYPE_CHECKING:
40
- from sqlspec.builder._column import ColumnExpression
48
+ from sqlspec.builder._expression_wrappers import ExpressionWrapper
41
49
  from sqlspec.core.statement import SQL
42
50
 
51
+
43
52
  __all__ = (
44
53
  "AlterTable",
45
54
  "Case",
@@ -179,7 +188,7 @@ class SQLFactory:
179
188
  # Statement Builders
180
189
  # ===================
181
190
  def select(
182
- self, *columns_or_sql: Union[str, exp.Expression, Column, "SQL"], dialect: DialectType = None
191
+ self, *columns_or_sql: Union[str, exp.Expression, Column, "SQL", "Case"], dialect: DialectType = None
183
192
  ) -> "Select":
184
193
  builder_dialect = dialect or self.dialect
185
194
  if len(columns_or_sql) == 1 and isinstance(columns_or_sql[0], str):
@@ -223,7 +232,10 @@ class SQLFactory:
223
232
  if self._looks_like_sql(table_or_sql):
224
233
  detected = self.detect_sql_type(table_or_sql, dialect=builder_dialect)
225
234
  if detected != "UPDATE":
226
- msg = f"sql.update() expects UPDATE statement, got {detected}. Use sql.{detected.lower()}() if a dedicated builder exists."
235
+ msg = (
236
+ f"sql.update() expects UPDATE statement, got {detected}. "
237
+ f"Use sql.{detected.lower()}() if a dedicated builder exists."
238
+ )
227
239
  raise SQLBuilderError(msg)
228
240
  return self._populate_update_from_sql(builder, table_or_sql)
229
241
  return builder.table(table_or_sql)
@@ -235,7 +247,10 @@ class SQLFactory:
235
247
  if table_or_sql and self._looks_like_sql(table_or_sql):
236
248
  detected = self.detect_sql_type(table_or_sql, dialect=builder_dialect)
237
249
  if detected != "DELETE":
238
- msg = f"sql.delete() expects DELETE statement, got {detected}. Use sql.{detected.lower()}() if a dedicated builder exists."
250
+ msg = (
251
+ f"sql.delete() expects DELETE statement, got {detected}. "
252
+ f"Use sql.{detected.lower()}() if a dedicated builder exists."
253
+ )
239
254
  raise SQLBuilderError(msg)
240
255
  return self._populate_delete_from_sql(builder, table_or_sql)
241
256
  return builder
@@ -247,7 +262,10 @@ class SQLFactory:
247
262
  if self._looks_like_sql(table_or_sql):
248
263
  detected = self.detect_sql_type(table_or_sql, dialect=builder_dialect)
249
264
  if detected != "MERGE":
250
- msg = f"sql.merge() expects MERGE statement, got {detected}. Use sql.{detected.lower()}() if a dedicated builder exists."
265
+ msg = (
266
+ f"sql.merge() expects MERGE statement, got {detected}. "
267
+ f"Use sql.{detected.lower()}() if a dedicated builder exists."
268
+ )
251
269
  raise SQLBuilderError(msg)
252
270
  return self._populate_merge_from_sql(builder, table_or_sql)
253
271
  return builder.into(table_or_sql)
@@ -737,7 +755,9 @@ class SQLFactory:
737
755
  # ===================
738
756
 
739
757
  @staticmethod
740
- def count(column: Union[str, exp.Expression] = "*", distinct: bool = False) -> exp.Expression:
758
+ def count(
759
+ column: Union[str, exp.Expression, "ExpressionWrapper", "Case", "Column"] = "*", distinct: bool = False
760
+ ) -> AggregateExpression:
741
761
  """Create a COUNT expression.
742
762
 
743
763
  Args:
@@ -747,12 +767,14 @@ class SQLFactory:
747
767
  Returns:
748
768
  COUNT expression.
749
769
  """
750
- if column == "*":
751
- return exp.Count(this=exp.Star(), distinct=distinct)
752
- col_expr = exp.column(column) if isinstance(column, str) else column
753
- return exp.Count(this=col_expr, distinct=distinct)
770
+ if isinstance(column, str) and column == "*":
771
+ expr = exp.Count(this=exp.Star(), distinct=distinct)
772
+ else:
773
+ col_expr = SQLFactory._extract_expression(column)
774
+ expr = exp.Count(this=col_expr, distinct=distinct)
775
+ return AggregateExpression(expr)
754
776
 
755
- def count_distinct(self, column: Union[str, exp.Expression]) -> exp.Expression:
777
+ def count_distinct(self, column: Union[str, exp.Expression, "ExpressionWrapper", "Case"]) -> AggregateExpression:
756
778
  """Create a COUNT(DISTINCT column) expression.
757
779
 
758
780
  Args:
@@ -764,7 +786,9 @@ class SQLFactory:
764
786
  return self.count(column, distinct=True)
765
787
 
766
788
  @staticmethod
767
- def sum(column: Union[str, exp.Expression], distinct: bool = False) -> exp.Expression:
789
+ def sum(
790
+ column: Union[str, exp.Expression, "ExpressionWrapper", "Case"], distinct: bool = False
791
+ ) -> AggregateExpression:
768
792
  """Create a SUM expression.
769
793
 
770
794
  Args:
@@ -774,11 +798,11 @@ class SQLFactory:
774
798
  Returns:
775
799
  SUM expression.
776
800
  """
777
- col_expr = exp.column(column) if isinstance(column, str) else column
778
- return exp.Sum(this=col_expr, distinct=distinct)
801
+ col_expr = SQLFactory._extract_expression(column)
802
+ return AggregateExpression(exp.Sum(this=col_expr, distinct=distinct))
779
803
 
780
804
  @staticmethod
781
- def avg(column: Union[str, exp.Expression]) -> exp.Expression:
805
+ def avg(column: Union[str, exp.Expression, "ExpressionWrapper", "Case"]) -> AggregateExpression:
782
806
  """Create an AVG expression.
783
807
 
784
808
  Args:
@@ -787,11 +811,11 @@ class SQLFactory:
787
811
  Returns:
788
812
  AVG expression.
789
813
  """
790
- col_expr = exp.column(column) if isinstance(column, str) else column
791
- return exp.Avg(this=col_expr)
814
+ col_expr = SQLFactory._extract_expression(column)
815
+ return AggregateExpression(exp.Avg(this=col_expr))
792
816
 
793
817
  @staticmethod
794
- def max(column: Union[str, exp.Expression]) -> exp.Expression:
818
+ def max(column: Union[str, exp.Expression, "ExpressionWrapper", "Case"]) -> AggregateExpression:
795
819
  """Create a MAX expression.
796
820
 
797
821
  Args:
@@ -800,11 +824,11 @@ class SQLFactory:
800
824
  Returns:
801
825
  MAX expression.
802
826
  """
803
- col_expr = exp.column(column) if isinstance(column, str) else column
804
- return exp.Max(this=col_expr)
827
+ col_expr = SQLFactory._extract_expression(column)
828
+ return AggregateExpression(exp.Max(this=col_expr))
805
829
 
806
830
  @staticmethod
807
- def min(column: Union[str, exp.Expression]) -> exp.Expression:
831
+ def min(column: Union[str, exp.Expression, "ExpressionWrapper", "Case"]) -> AggregateExpression:
808
832
  """Create a MIN expression.
809
833
 
810
834
  Args:
@@ -813,15 +837,15 @@ class SQLFactory:
813
837
  Returns:
814
838
  MIN expression.
815
839
  """
816
- col_expr = exp.column(column) if isinstance(column, str) else column
817
- return exp.Min(this=col_expr)
840
+ col_expr = SQLFactory._extract_expression(column)
841
+ return AggregateExpression(exp.Min(this=col_expr))
818
842
 
819
843
  # ===================
820
844
  # Advanced SQL Operations
821
845
  # ===================
822
846
 
823
847
  @staticmethod
824
- def rollup(*columns: Union[str, exp.Expression]) -> exp.Expression:
848
+ def rollup(*columns: Union[str, exp.Expression]) -> FunctionExpression:
825
849
  """Create a ROLLUP expression for GROUP BY clauses.
826
850
 
827
851
  Args:
@@ -841,10 +865,10 @@ class SQLFactory:
841
865
  ```
842
866
  """
843
867
  column_exprs = [exp.column(col) if isinstance(col, str) else col for col in columns]
844
- return exp.Rollup(expressions=column_exprs)
868
+ return FunctionExpression(exp.Rollup(expressions=column_exprs))
845
869
 
846
870
  @staticmethod
847
- def cube(*columns: Union[str, exp.Expression]) -> exp.Expression:
871
+ def cube(*columns: Union[str, exp.Expression]) -> FunctionExpression:
848
872
  """Create a CUBE expression for GROUP BY clauses.
849
873
 
850
874
  Args:
@@ -864,10 +888,10 @@ class SQLFactory:
864
888
  ```
865
889
  """
866
890
  column_exprs = [exp.column(col) if isinstance(col, str) else col for col in columns]
867
- return exp.Cube(expressions=column_exprs)
891
+ return FunctionExpression(exp.Cube(expressions=column_exprs))
868
892
 
869
893
  @staticmethod
870
- def grouping_sets(*column_sets: Union[tuple[str, ...], list[str]]) -> exp.Expression:
894
+ def grouping_sets(*column_sets: Union[tuple[str, ...], list[str]]) -> FunctionExpression:
871
895
  """Create a GROUPING SETS expression for GROUP BY clauses.
872
896
 
873
897
  Args:
@@ -899,10 +923,10 @@ class SQLFactory:
899
923
  else:
900
924
  set_expressions.append(exp.column(column_set))
901
925
 
902
- return exp.GroupingSets(expressions=set_expressions)
926
+ return FunctionExpression(exp.GroupingSets(expressions=set_expressions))
903
927
 
904
928
  @staticmethod
905
- def any(values: Union[list[Any], exp.Expression, str]) -> exp.Expression:
929
+ def any(values: Union[list[Any], exp.Expression, str]) -> FunctionExpression:
906
930
  """Create an ANY expression for use with comparison operators.
907
931
 
908
932
  Args:
@@ -923,18 +947,18 @@ class SQLFactory:
923
947
  ```
924
948
  """
925
949
  if isinstance(values, list):
926
- literals = [SQLFactory._to_literal(v) for v in values]
927
- return exp.Any(this=exp.Array(expressions=literals))
950
+ literals = [SQLFactory.to_literal(v) for v in values]
951
+ return FunctionExpression(exp.Any(this=exp.Array(expressions=literals)))
928
952
  if isinstance(values, str):
929
953
  # Parse as SQL
930
954
  parsed = exp.maybe_parse(values) # type: ignore[var-annotated]
931
955
  if parsed:
932
- return exp.Any(this=parsed)
933
- return exp.Any(this=exp.Literal.string(values))
934
- return exp.Any(this=values)
956
+ return FunctionExpression(exp.Any(this=parsed))
957
+ return FunctionExpression(exp.Any(this=exp.Literal.string(values)))
958
+ return FunctionExpression(exp.Any(this=values))
935
959
 
936
960
  @staticmethod
937
- def not_any_(values: Union[list[Any], exp.Expression, str]) -> exp.Expression:
961
+ def not_any_(values: Union[list[Any], exp.Expression, str]) -> FunctionExpression:
938
962
  """Create a NOT ANY expression for use with comparison operators.
939
963
 
940
964
  Args:
@@ -954,14 +978,14 @@ class SQLFactory:
954
978
  )
955
979
  ```
956
980
  """
957
- return SQLFactory.any(values) # NOT ANY is handled by the comparison operator
981
+ return SQLFactory.any(values)
958
982
 
959
983
  # ===================
960
984
  # String Functions
961
985
  # ===================
962
986
 
963
987
  @staticmethod
964
- def concat(*expressions: Union[str, exp.Expression]) -> exp.Expression:
988
+ def concat(*expressions: Union[str, exp.Expression]) -> StringExpression:
965
989
  """Create a CONCAT expression.
966
990
 
967
991
  Args:
@@ -971,10 +995,10 @@ class SQLFactory:
971
995
  CONCAT expression.
972
996
  """
973
997
  exprs = [exp.column(expr) if isinstance(expr, str) else expr for expr in expressions]
974
- return exp.Concat(expressions=exprs)
998
+ return StringExpression(exp.Concat(expressions=exprs))
975
999
 
976
1000
  @staticmethod
977
- def upper(column: Union[str, exp.Expression]) -> exp.Expression:
1001
+ def upper(column: Union[str, exp.Expression]) -> StringExpression:
978
1002
  """Create an UPPER expression.
979
1003
 
980
1004
  Args:
@@ -984,10 +1008,10 @@ class SQLFactory:
984
1008
  UPPER expression.
985
1009
  """
986
1010
  col_expr = exp.column(column) if isinstance(column, str) else column
987
- return exp.Upper(this=col_expr)
1011
+ return StringExpression(exp.Upper(this=col_expr))
988
1012
 
989
1013
  @staticmethod
990
- def lower(column: Union[str, exp.Expression]) -> exp.Expression:
1014
+ def lower(column: Union[str, exp.Expression]) -> StringExpression:
991
1015
  """Create a LOWER expression.
992
1016
 
993
1017
  Args:
@@ -997,10 +1021,10 @@ class SQLFactory:
997
1021
  LOWER expression.
998
1022
  """
999
1023
  col_expr = exp.column(column) if isinstance(column, str) else column
1000
- return exp.Lower(this=col_expr)
1024
+ return StringExpression(exp.Lower(this=col_expr))
1001
1025
 
1002
1026
  @staticmethod
1003
- def length(column: Union[str, exp.Expression]) -> exp.Expression:
1027
+ def length(column: Union[str, exp.Expression]) -> StringExpression:
1004
1028
  """Create a LENGTH expression.
1005
1029
 
1006
1030
  Args:
@@ -1010,14 +1034,14 @@ class SQLFactory:
1010
1034
  LENGTH expression.
1011
1035
  """
1012
1036
  col_expr = exp.column(column) if isinstance(column, str) else column
1013
- return exp.Length(this=col_expr)
1037
+ return StringExpression(exp.Length(this=col_expr))
1014
1038
 
1015
1039
  # ===================
1016
1040
  # Math Functions
1017
1041
  # ===================
1018
1042
 
1019
1043
  @staticmethod
1020
- def round(column: Union[str, exp.Expression], decimals: int = 0) -> exp.Expression:
1044
+ def round(column: Union[str, exp.Expression], decimals: int = 0) -> MathExpression:
1021
1045
  """Create a ROUND expression.
1022
1046
 
1023
1047
  Args:
@@ -1029,15 +1053,15 @@ class SQLFactory:
1029
1053
  """
1030
1054
  col_expr = exp.column(column) if isinstance(column, str) else column
1031
1055
  if decimals == 0:
1032
- return exp.Round(this=col_expr)
1033
- return exp.Round(this=col_expr, expression=exp.Literal.number(decimals))
1056
+ return MathExpression(exp.Round(this=col_expr))
1057
+ return MathExpression(exp.Round(this=col_expr, expression=exp.Literal.number(decimals)))
1034
1058
 
1035
1059
  # ===================
1036
1060
  # Conversion Functions
1037
1061
  # ===================
1038
1062
 
1039
1063
  @staticmethod
1040
- def _to_literal(value: Any) -> exp.Expression:
1064
+ def to_literal(value: Any) -> FunctionExpression:
1041
1065
  """Convert a Python value to a SQLGlot literal expression.
1042
1066
 
1043
1067
  Uses SQLGlot's built-in exp.convert() function for optimal dialect-agnostic
@@ -1054,12 +1078,52 @@ class SQLFactory:
1054
1078
  Returns:
1055
1079
  SQLGlot expression representing the literal value.
1056
1080
  """
1081
+ if isinstance(value, exp.Expression):
1082
+ return FunctionExpression(value)
1083
+ return FunctionExpression(exp.convert(value))
1084
+
1085
+ @staticmethod
1086
+ def _to_expression(value: Any) -> exp.Expression:
1087
+ """Convert a Python value to a raw SQLGlot expression.
1088
+
1089
+ Args:
1090
+ value: Python value or SQLGlot expression to convert.
1091
+
1092
+ Returns:
1093
+ Raw SQLGlot expression.
1094
+ """
1095
+ if isinstance(value, exp.Expression):
1096
+ return value
1097
+ return exp.convert(value)
1098
+
1099
+ @staticmethod
1100
+ def _extract_expression(value: Any) -> exp.Expression:
1101
+ """Extract SQLGlot expression from value, handling our wrapper types.
1102
+
1103
+ Args:
1104
+ value: String, SQLGlot expression, or our wrapper type.
1105
+
1106
+ Returns:
1107
+ Raw SQLGlot expression.
1108
+ """
1109
+ from sqlspec.builder._expression_wrappers import ExpressionWrapper
1110
+ from sqlspec.builder.mixins._select_operations import Case
1111
+
1112
+ if isinstance(value, str):
1113
+ return exp.column(value)
1114
+ if isinstance(value, Column):
1115
+ return value._expression
1116
+ if isinstance(value, ExpressionWrapper):
1117
+ return value.expression
1118
+ if isinstance(value, Case):
1119
+ # Case has _expression property via trait
1120
+ return exp.Case(ifs=value._conditions, default=value._default)
1057
1121
  if isinstance(value, exp.Expression):
1058
1122
  return value
1059
1123
  return exp.convert(value)
1060
1124
 
1061
1125
  @staticmethod
1062
- def decode(column: Union[str, exp.Expression], *args: Union[str, exp.Expression, Any]) -> exp.Expression:
1126
+ def decode(column: Union[str, exp.Expression], *args: Union[str, exp.Expression, Any]) -> FunctionExpression:
1063
1127
  """Create a DECODE expression (Oracle-style conditional logic).
1064
1128
 
1065
1129
  DECODE compares column to each search value and returns the corresponding result.
@@ -1096,22 +1160,22 @@ class SQLFactory:
1096
1160
  for i in range(0, len(args) - 1, 2):
1097
1161
  if i + 1 >= len(args):
1098
1162
  # Odd number of args means last one is default
1099
- default = SQLFactory._to_literal(args[i])
1163
+ default = SQLFactory._to_expression(args[i])
1100
1164
  break
1101
1165
 
1102
1166
  search_val = args[i]
1103
1167
  result_val = args[i + 1]
1104
1168
 
1105
- search_expr = SQLFactory._to_literal(search_val)
1106
- result_expr = SQLFactory._to_literal(result_val)
1169
+ search_expr = SQLFactory._to_expression(search_val)
1170
+ result_expr = SQLFactory._to_expression(result_val)
1107
1171
 
1108
1172
  condition = exp.EQ(this=col_expr, expression=search_expr)
1109
- conditions.append(exp.When(this=condition, then=result_expr))
1173
+ conditions.append(exp.If(this=condition, true=result_expr))
1110
1174
 
1111
- return exp.Case(ifs=conditions, default=default)
1175
+ return FunctionExpression(exp.Case(ifs=conditions, default=default))
1112
1176
 
1113
1177
  @staticmethod
1114
- def cast(column: Union[str, exp.Expression], data_type: str) -> exp.Expression:
1178
+ def cast(column: Union[str, exp.Expression], data_type: str) -> ConversionExpression:
1115
1179
  """Create a CAST expression for type conversion.
1116
1180
 
1117
1181
  Args:
@@ -1122,10 +1186,10 @@ class SQLFactory:
1122
1186
  CAST expression.
1123
1187
  """
1124
1188
  col_expr = exp.column(column) if isinstance(column, str) else column
1125
- return exp.Cast(this=col_expr, to=exp.DataType.build(data_type))
1189
+ return ConversionExpression(exp.Cast(this=col_expr, to=exp.DataType.build(data_type)))
1126
1190
 
1127
1191
  @staticmethod
1128
- def coalesce(*expressions: Union[str, exp.Expression]) -> exp.Expression:
1192
+ def coalesce(*expressions: Union[str, exp.Expression]) -> ConversionExpression:
1129
1193
  """Create a COALESCE expression.
1130
1194
 
1131
1195
  Args:
@@ -1135,10 +1199,12 @@ class SQLFactory:
1135
1199
  COALESCE expression.
1136
1200
  """
1137
1201
  exprs = [exp.column(expr) if isinstance(expr, str) else expr for expr in expressions]
1138
- return exp.Coalesce(expressions=exprs)
1202
+ return ConversionExpression(exp.Coalesce(expressions=exprs))
1139
1203
 
1140
1204
  @staticmethod
1141
- def nvl(column: Union[str, exp.Expression], substitute_value: Union[str, exp.Expression, Any]) -> exp.Expression:
1205
+ def nvl(
1206
+ column: Union[str, exp.Expression], substitute_value: Union[str, exp.Expression, Any]
1207
+ ) -> ConversionExpression:
1142
1208
  """Create an NVL (Oracle-style) expression using COALESCE.
1143
1209
 
1144
1210
  Args:
@@ -1149,15 +1215,15 @@ class SQLFactory:
1149
1215
  COALESCE expression equivalent to NVL.
1150
1216
  """
1151
1217
  col_expr = exp.column(column) if isinstance(column, str) else column
1152
- sub_expr = SQLFactory._to_literal(substitute_value)
1153
- return exp.Coalesce(expressions=[col_expr, sub_expr])
1218
+ sub_expr = SQLFactory._to_expression(substitute_value)
1219
+ return ConversionExpression(exp.Coalesce(expressions=[col_expr, sub_expr]))
1154
1220
 
1155
1221
  @staticmethod
1156
1222
  def nvl2(
1157
1223
  column: Union[str, exp.Expression],
1158
1224
  value_if_not_null: Union[str, exp.Expression, Any],
1159
1225
  value_if_null: Union[str, exp.Expression, Any],
1160
- ) -> exp.Expression:
1226
+ ) -> ConversionExpression:
1161
1227
  """Create an NVL2 (Oracle-style) expression using CASE.
1162
1228
 
1163
1229
  NVL2 returns value_if_not_null if column is not NULL,
@@ -1178,22 +1244,22 @@ class SQLFactory:
1178
1244
  ```
1179
1245
  """
1180
1246
  col_expr = exp.column(column) if isinstance(column, str) else column
1181
- not_null_expr = SQLFactory._to_literal(value_if_not_null)
1182
- null_expr = SQLFactory._to_literal(value_if_null)
1247
+ not_null_expr = SQLFactory._to_expression(value_if_not_null)
1248
+ null_expr = SQLFactory._to_expression(value_if_null)
1183
1249
 
1184
1250
  # Create CASE WHEN column IS NOT NULL THEN value_if_not_null ELSE value_if_null END
1185
1251
  is_null = exp.Is(this=col_expr, expression=exp.Null())
1186
1252
  condition = exp.Not(this=is_null)
1187
1253
  when_clause = exp.If(this=condition, true=not_null_expr)
1188
1254
 
1189
- return exp.Case(ifs=[when_clause], default=null_expr)
1255
+ return ConversionExpression(exp.Case(ifs=[when_clause], default=null_expr))
1190
1256
 
1191
1257
  # ===================
1192
1258
  # Bulk Operations
1193
1259
  # ===================
1194
1260
 
1195
1261
  @staticmethod
1196
- def bulk_insert(table_name: str, column_count: int, placeholder_style: str = "?") -> exp.Expression:
1262
+ def bulk_insert(table_name: str, column_count: int, placeholder_style: str = "?") -> FunctionExpression:
1197
1263
  """Create bulk INSERT expression for executemany operations.
1198
1264
 
1199
1265
  This is specifically for bulk loading operations like CSV ingestion where
@@ -1228,13 +1294,15 @@ class SQLFactory:
1228
1294
  # Creates: INSERT INTO "my_table" VALUES (:1, :2, :3)
1229
1295
  ```
1230
1296
  """
1231
- return exp.Insert(
1232
- this=exp.Table(this=exp.to_identifier(table_name)),
1233
- expression=exp.Values(
1234
- expressions=[
1235
- exp.Tuple(expressions=[exp.Placeholder(this=placeholder_style) for _ in range(column_count)])
1236
- ]
1237
- ),
1297
+ return FunctionExpression(
1298
+ exp.Insert(
1299
+ this=exp.Table(this=exp.to_identifier(table_name)),
1300
+ expression=exp.Values(
1301
+ expressions=[
1302
+ exp.Tuple(expressions=[exp.Placeholder(this=placeholder_style) for _ in range(column_count)])
1303
+ ]
1304
+ ),
1305
+ )
1238
1306
  )
1239
1307
 
1240
1308
  def truncate(self, table_name: str) -> "Truncate":
@@ -1288,7 +1356,7 @@ class SQLFactory:
1288
1356
  self,
1289
1357
  partition_by: Optional[Union[str, list[str], exp.Expression]] = None,
1290
1358
  order_by: Optional[Union[str, list[str], exp.Expression]] = None,
1291
- ) -> exp.Expression:
1359
+ ) -> FunctionExpression:
1292
1360
  """Create a ROW_NUMBER() window function.
1293
1361
 
1294
1362
  Args:
@@ -1304,7 +1372,7 @@ class SQLFactory:
1304
1372
  self,
1305
1373
  partition_by: Optional[Union[str, list[str], exp.Expression]] = None,
1306
1374
  order_by: Optional[Union[str, list[str], exp.Expression]] = None,
1307
- ) -> exp.Expression:
1375
+ ) -> FunctionExpression:
1308
1376
  """Create a RANK() window function.
1309
1377
 
1310
1378
  Args:
@@ -1320,7 +1388,7 @@ class SQLFactory:
1320
1388
  self,
1321
1389
  partition_by: Optional[Union[str, list[str], exp.Expression]] = None,
1322
1390
  order_by: Optional[Union[str, list[str], exp.Expression]] = None,
1323
- ) -> exp.Expression:
1391
+ ) -> FunctionExpression:
1324
1392
  """Create a DENSE_RANK() window function.
1325
1393
 
1326
1394
  Args:
@@ -1338,7 +1406,7 @@ class SQLFactory:
1338
1406
  func_args: list[exp.Expression],
1339
1407
  partition_by: Optional[Union[str, list[str], exp.Expression]] = None,
1340
1408
  order_by: Optional[Union[str, list[str], exp.Expression]] = None,
1341
- ) -> exp.Expression:
1409
+ ) -> FunctionExpression:
1342
1410
  """Helper to create window function expressions.
1343
1411
 
1344
1412
  Args:
@@ -1364,418 +1432,13 @@ class SQLFactory:
1364
1432
 
1365
1433
  if order_by:
1366
1434
  if isinstance(order_by, str):
1367
- over_args["order"] = [exp.column(order_by).asc()]
1435
+ over_args["order"] = exp.Order(expressions=[exp.column(order_by).asc()])
1368
1436
  elif isinstance(order_by, list):
1369
- over_args["order"] = [exp.column(col).asc() for col in order_by]
1437
+ over_args["order"] = exp.Order(expressions=[exp.column(col).asc() for col in order_by])
1370
1438
  elif isinstance(order_by, exp.Expression):
1371
- over_args["order"] = [order_by]
1372
-
1373
- return exp.Window(this=func_expr, **over_args)
1374
-
1375
-
1376
- @trait
1377
- class Case:
1378
- """Builder for CASE expressions using the SQL factory.
1379
-
1380
- Example:
1381
- ```python
1382
- from sqlspec import sql
1383
-
1384
- case_expr = (
1385
- sql.case()
1386
- .when(sql.age < 18, "Minor")
1387
- .when(sql.age < 65, "Adult")
1388
- .else_("Senior")
1389
- .end()
1390
- )
1391
- ```
1392
- """
1393
-
1394
- def __init__(self) -> None:
1395
- """Initialize the CASE expression builder."""
1396
- self._conditions: list[exp.If] = []
1397
- self._default: Optional[exp.Expression] = None
1398
-
1399
- def __eq__(self, other: object) -> "ColumnExpression": # type: ignore[override]
1400
- """Equal to (==) - convert to expression then compare."""
1401
- from sqlspec.builder._column import ColumnExpression
1402
-
1403
- case_expr = exp.Case(ifs=self._conditions, default=self._default)
1404
- if other is None:
1405
- return ColumnExpression(exp.Is(this=case_expr, expression=exp.Null()))
1406
- return ColumnExpression(exp.EQ(this=case_expr, expression=exp.convert(other)))
1407
-
1408
- def __hash__(self) -> int:
1409
- """Make Case hashable."""
1410
- return hash(id(self))
1411
-
1412
- def when(self, condition: Union[str, exp.Expression], value: Union[str, exp.Expression, Any]) -> "Case":
1413
- """Add a WHEN clause.
1414
-
1415
- Args:
1416
- condition: Condition to test.
1417
- value: Value to return if condition is true.
1418
-
1419
- Returns:
1420
- Self for method chaining.
1421
- """
1422
- cond_expr = exp.maybe_parse(condition) or exp.column(condition) if isinstance(condition, str) else condition
1423
- val_expr = SQLFactory._to_literal(value)
1424
-
1425
- # SQLGlot uses exp.If for CASE WHEN clauses, not exp.When
1426
- when_clause = exp.If(this=cond_expr, true=val_expr)
1427
- self._conditions.append(when_clause)
1428
- return self
1429
-
1430
- def else_(self, value: Union[str, exp.Expression, Any]) -> "Case":
1431
- """Add an ELSE clause.
1432
-
1433
- Args:
1434
- value: Default value to return.
1435
-
1436
- Returns:
1437
- Self for method chaining.
1438
- """
1439
- self._default = SQLFactory._to_literal(value)
1440
- return self
1441
-
1442
- def end(self) -> exp.Expression:
1443
- """Complete the CASE expression.
1444
-
1445
- Returns:
1446
- Complete CASE expression.
1447
- """
1448
- return exp.Case(ifs=self._conditions, default=self._default)
1449
-
1450
- def as_(self, alias: str) -> exp.Alias:
1451
- """Complete the CASE expression with an alias.
1452
-
1453
- Args:
1454
- alias: Alias name for the CASE expression.
1455
-
1456
- Returns:
1457
- Aliased CASE expression.
1458
- """
1459
- case_expr = exp.Case(ifs=self._conditions, default=self._default)
1460
- return cast("exp.Alias", exp.alias_(case_expr, alias))
1461
-
1462
-
1463
- @trait
1464
- class WindowFunctionBuilder:
1465
- """Builder for window functions with fluent syntax.
1466
-
1467
- Example:
1468
- ```python
1469
- from sqlspec import sql
1470
-
1471
- # sql.row_number_.partition_by("department").order_by("salary")
1472
- window_func = (
1473
- sql.row_number_.partition_by("department")
1474
- .order_by("salary")
1475
- .as_("row_num")
1476
- )
1477
- ```
1478
- """
1479
-
1480
- def __init__(self, function_name: str) -> None:
1481
- """Initialize the window function builder.
1482
-
1483
- Args:
1484
- function_name: Name of the window function (row_number, rank, etc.)
1485
- """
1486
- self._function_name = function_name
1487
- self._partition_by_cols: list[exp.Expression] = []
1488
- self._order_by_cols: list[exp.Expression] = []
1489
- self._alias: Optional[str] = None
1490
-
1491
- def __eq__(self, other: object) -> "ColumnExpression": # type: ignore[override]
1492
- """Equal to (==) - convert to expression then compare."""
1493
- from sqlspec.builder._column import ColumnExpression
1494
-
1495
- window_expr = self._build_expression()
1496
- if other is None:
1497
- return ColumnExpression(exp.Is(this=window_expr, expression=exp.Null()))
1498
- return ColumnExpression(exp.EQ(this=window_expr, expression=exp.convert(other)))
1499
-
1500
- def __hash__(self) -> int:
1501
- """Make WindowFunctionBuilder hashable."""
1502
- return hash(id(self))
1503
-
1504
- def partition_by(self, *columns: Union[str, exp.Expression]) -> "WindowFunctionBuilder":
1505
- """Add PARTITION BY clause.
1506
-
1507
- Args:
1508
- *columns: Columns to partition by.
1439
+ over_args["order"] = exp.Order(expressions=[order_by])
1509
1440
 
1510
- Returns:
1511
- Self for method chaining.
1512
- """
1513
- for col in columns:
1514
- col_expr = exp.column(col) if isinstance(col, str) else col
1515
- self._partition_by_cols.append(col_expr)
1516
- return self
1517
-
1518
- def order_by(self, *columns: Union[str, exp.Expression]) -> "WindowFunctionBuilder":
1519
- """Add ORDER BY clause.
1520
-
1521
- Args:
1522
- *columns: Columns to order by.
1523
-
1524
- Returns:
1525
- Self for method chaining.
1526
- """
1527
- for col in columns:
1528
- if isinstance(col, str):
1529
- col_expr = exp.column(col).asc()
1530
- self._order_by_cols.append(col_expr)
1531
- else:
1532
- # Convert to ordered expression
1533
- self._order_by_cols.append(exp.Ordered(this=col, desc=False))
1534
- return self
1535
-
1536
- def as_(self, alias: str) -> exp.Alias:
1537
- """Complete the window function with an alias.
1538
-
1539
- Args:
1540
- alias: Alias name for the window function.
1541
-
1542
- Returns:
1543
- Aliased window function expression.
1544
- """
1545
- window_expr = self._build_expression()
1546
- return cast("exp.Alias", exp.alias_(window_expr, alias))
1547
-
1548
- def build(self) -> exp.Expression:
1549
- """Complete the window function without an alias.
1550
-
1551
- Returns:
1552
- Window function expression.
1553
- """
1554
- return self._build_expression()
1555
-
1556
- def _build_expression(self) -> exp.Expression:
1557
- """Build the complete window function expression."""
1558
- # Create the function expression
1559
- func_expr = exp.Anonymous(this=self._function_name.upper(), expressions=[])
1560
-
1561
- # Build the OVER clause arguments
1562
- over_args: dict[str, Any] = {}
1563
-
1564
- if self._partition_by_cols:
1565
- over_args["partition_by"] = self._partition_by_cols
1566
-
1567
- if self._order_by_cols:
1568
- over_args["order"] = exp.Order(expressions=self._order_by_cols)
1569
-
1570
- return exp.Window(this=func_expr, **over_args)
1571
-
1572
-
1573
- @trait
1574
- class SubqueryBuilder:
1575
- """Builder for subquery operations with fluent syntax.
1576
-
1577
- Example:
1578
- ```python
1579
- from sqlspec import sql
1580
-
1581
- # sql.exists_(subquery)
1582
- exists_check = sql.exists_(
1583
- sql.select("1")
1584
- .from_("orders")
1585
- .where_eq("user_id", sql.users.id)
1586
- )
1587
-
1588
- # sql.in_(subquery)
1589
- in_check = sql.in_(
1590
- sql.select("category_id")
1591
- .from_("categories")
1592
- .where_eq("active", True)
1593
- )
1594
- ```
1595
- """
1596
-
1597
- def __init__(self, operation: str) -> None:
1598
- """Initialize the subquery builder.
1599
-
1600
- Args:
1601
- operation: Type of subquery operation (exists, in, any, all)
1602
- """
1603
- self._operation = operation
1604
-
1605
- def __eq__(self, other: object) -> "ColumnExpression": # type: ignore[override]
1606
- """Equal to (==) - not typically used but needed for type consistency."""
1607
- from sqlspec.builder._column import ColumnExpression
1608
-
1609
- # SubqueryBuilder doesn't have a direct expression, so this is a placeholder
1610
- # In practice, this shouldn't be called as subqueries are used differently
1611
- placeholder_expr = exp.Literal.string(f"subquery_{self._operation}")
1612
- if other is None:
1613
- return ColumnExpression(exp.Is(this=placeholder_expr, expression=exp.Null()))
1614
- return ColumnExpression(exp.EQ(this=placeholder_expr, expression=exp.convert(other)))
1615
-
1616
- def __hash__(self) -> int:
1617
- """Make SubqueryBuilder hashable."""
1618
- return hash(id(self))
1619
-
1620
- def __call__(self, subquery: Union[str, exp.Expression, Any]) -> exp.Expression:
1621
- """Build the subquery expression.
1622
-
1623
- Args:
1624
- subquery: The subquery - can be a SQL string, SelectBuilder, or expression
1625
-
1626
- Returns:
1627
- The subquery expression (EXISTS, IN, ANY, ALL, etc.)
1628
- """
1629
- subquery_expr: exp.Expression
1630
- if isinstance(subquery, str):
1631
- # Parse as SQL
1632
- parsed: Optional[exp.Expression] = exp.maybe_parse(subquery)
1633
- if not parsed:
1634
- msg = f"Could not parse subquery SQL: {subquery}"
1635
- raise SQLBuilderError(msg)
1636
- subquery_expr = parsed
1637
- elif hasattr(subquery, "build") and callable(getattr(subquery, "build", None)):
1638
- # It's a query builder - build it to get the SQL and parse
1639
- built_query = subquery.build() # pyright: ignore[reportAttributeAccessIssue]
1640
- subquery_expr = exp.maybe_parse(built_query.sql)
1641
- if not subquery_expr:
1642
- msg = f"Could not parse built query: {built_query.sql}"
1643
- raise SQLBuilderError(msg)
1644
- elif isinstance(subquery, exp.Expression):
1645
- subquery_expr = subquery
1646
- else:
1647
- # Try to convert to expression
1648
- parsed = exp.maybe_parse(str(subquery))
1649
- if not parsed:
1650
- msg = f"Could not convert subquery to expression: {subquery}"
1651
- raise SQLBuilderError(msg)
1652
- subquery_expr = parsed
1653
-
1654
- # Build the appropriate expression based on operation
1655
- if self._operation == "exists":
1656
- return exp.Exists(this=subquery_expr)
1657
- if self._operation == "in":
1658
- # For IN, we create a subquery that can be used with WHERE column IN (subquery)
1659
- return exp.In(expressions=[subquery_expr])
1660
- if self._operation == "any":
1661
- return exp.Any(this=subquery_expr)
1662
- if self._operation == "all":
1663
- return exp.All(this=subquery_expr)
1664
- msg = f"Unknown subquery operation: {self._operation}"
1665
- raise SQLBuilderError(msg)
1666
-
1667
-
1668
- @trait
1669
- class JoinBuilder:
1670
- """Builder for JOIN operations with fluent syntax.
1671
-
1672
- Example:
1673
- ```python
1674
- from sqlspec import sql
1675
-
1676
- # sql.left_join_("posts").on("users.id = posts.user_id")
1677
- join_clause = sql.left_join_("posts").on(
1678
- "users.id = posts.user_id"
1679
- )
1680
-
1681
- # Or with query builder
1682
- query = (
1683
- sql.select("users.name", "posts.title")
1684
- .from_("users")
1685
- .join(
1686
- sql.left_join_("posts").on(
1687
- "users.id = posts.user_id"
1688
- )
1689
- )
1690
- )
1691
- ```
1692
- """
1693
-
1694
- def __init__(self, join_type: str) -> None:
1695
- """Initialize the join builder.
1696
-
1697
- Args:
1698
- join_type: Type of join (inner, left, right, full, cross)
1699
- """
1700
- self._join_type = join_type.upper()
1701
- self._table: Optional[Union[str, exp.Expression]] = None
1702
- self._condition: Optional[exp.Expression] = None
1703
- self._alias: Optional[str] = None
1704
-
1705
- def __eq__(self, other: object) -> "ColumnExpression": # type: ignore[override]
1706
- """Equal to (==) - not typically used but needed for type consistency."""
1707
- from sqlspec.builder._column import ColumnExpression
1708
-
1709
- # JoinBuilder doesn't have a direct expression, so this is a placeholder
1710
- # In practice, this shouldn't be called as joins are used differently
1711
- placeholder_expr = exp.Literal.string(f"join_{self._join_type.lower()}")
1712
- if other is None:
1713
- return ColumnExpression(exp.Is(this=placeholder_expr, expression=exp.Null()))
1714
- return ColumnExpression(exp.EQ(this=placeholder_expr, expression=exp.convert(other)))
1715
-
1716
- def __hash__(self) -> int:
1717
- """Make JoinBuilder hashable."""
1718
- return hash(id(self))
1719
-
1720
- def __call__(self, table: Union[str, exp.Expression], alias: Optional[str] = None) -> "JoinBuilder":
1721
- """Set the table to join.
1722
-
1723
- Args:
1724
- table: Table name or expression to join
1725
- alias: Optional alias for the table
1726
-
1727
- Returns:
1728
- Self for method chaining
1729
- """
1730
- self._table = table
1731
- self._alias = alias
1732
- return self
1733
-
1734
- def on(self, condition: Union[str, exp.Expression]) -> exp.Expression:
1735
- """Set the join condition and build the JOIN expression.
1736
-
1737
- Args:
1738
- condition: JOIN condition (e.g., "users.id = posts.user_id")
1739
-
1740
- Returns:
1741
- Complete JOIN expression
1742
- """
1743
- if not self._table:
1744
- msg = "Table must be set before calling .on()"
1745
- raise SQLBuilderError(msg)
1746
-
1747
- # Parse the condition
1748
- condition_expr: exp.Expression
1749
- if isinstance(condition, str):
1750
- parsed: Optional[exp.Expression] = exp.maybe_parse(condition)
1751
- condition_expr = parsed or exp.condition(condition)
1752
- else:
1753
- condition_expr = condition
1754
-
1755
- # Build table expression
1756
- table_expr: exp.Expression
1757
- if isinstance(self._table, str):
1758
- table_expr = exp.to_table(self._table)
1759
- if self._alias:
1760
- table_expr = exp.alias_(table_expr, self._alias)
1761
- else:
1762
- table_expr = self._table
1763
- if self._alias:
1764
- table_expr = exp.alias_(table_expr, self._alias)
1765
-
1766
- # Create the appropriate join type using same pattern as existing JoinClauseMixin
1767
- if self._join_type == "INNER JOIN":
1768
- return exp.Join(this=table_expr, on=condition_expr)
1769
- if self._join_type == "LEFT JOIN":
1770
- return exp.Join(this=table_expr, on=condition_expr, side="LEFT")
1771
- if self._join_type == "RIGHT JOIN":
1772
- return exp.Join(this=table_expr, on=condition_expr, side="RIGHT")
1773
- if self._join_type == "FULL JOIN":
1774
- return exp.Join(this=table_expr, on=condition_expr, side="FULL", kind="OUTER")
1775
- if self._join_type == "CROSS JOIN":
1776
- # CROSS JOIN doesn't use ON condition
1777
- return exp.Join(this=table_expr, kind="CROSS")
1778
- return exp.Join(this=table_expr, on=condition_expr)
1441
+ return FunctionExpression(exp.Window(this=func_expr, **over_args))
1779
1442
 
1780
1443
 
1781
1444
  # Create a default SQL factory instance