sqlspec 0.17.0__py3-none-any.whl → 0.17.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of sqlspec might be problematic. Click here for more details.

sqlspec/_sql.py CHANGED
@@ -33,13 +33,22 @@ from sqlspec.builder import (
33
33
  Truncate,
34
34
  Update,
35
35
  )
36
+ from sqlspec.builder._expression_wrappers import (
37
+ AggregateExpression,
38
+ ConversionExpression,
39
+ FunctionExpression,
40
+ MathExpression,
41
+ StringExpression,
42
+ )
36
43
  from sqlspec.builder.mixins._join_operations import JoinBuilder
37
44
  from sqlspec.builder.mixins._select_operations import Case, SubqueryBuilder, WindowFunctionBuilder
38
45
  from sqlspec.exceptions import SQLBuilderError
39
46
 
40
47
  if TYPE_CHECKING:
48
+ from sqlspec.builder._expression_wrappers import ExpressionWrapper
41
49
  from sqlspec.core.statement import SQL
42
50
 
51
+
43
52
  __all__ = (
44
53
  "AlterTable",
45
54
  "Case",
@@ -746,7 +755,9 @@ class SQLFactory:
746
755
  # ===================
747
756
 
748
757
  @staticmethod
749
- def count(column: Union[str, exp.Expression] = "*", distinct: bool = False) -> exp.Expression:
758
+ def count(
759
+ column: Union[str, exp.Expression, "ExpressionWrapper", "Case", "Column"] = "*", distinct: bool = False
760
+ ) -> AggregateExpression:
750
761
  """Create a COUNT expression.
751
762
 
752
763
  Args:
@@ -756,12 +767,14 @@ class SQLFactory:
756
767
  Returns:
757
768
  COUNT expression.
758
769
  """
759
- if column == "*":
760
- return exp.Count(this=exp.Star(), distinct=distinct)
761
- col_expr = exp.column(column) if isinstance(column, str) else column
762
- return exp.Count(this=col_expr, distinct=distinct)
763
-
764
- def count_distinct(self, column: Union[str, exp.Expression]) -> exp.Expression:
770
+ if isinstance(column, str) and column == "*":
771
+ expr = exp.Count(this=exp.Star(), distinct=distinct)
772
+ else:
773
+ col_expr = SQLFactory._extract_expression(column)
774
+ expr = exp.Count(this=col_expr, distinct=distinct)
775
+ return AggregateExpression(expr)
776
+
777
+ def count_distinct(self, column: Union[str, exp.Expression, "ExpressionWrapper", "Case"]) -> AggregateExpression:
765
778
  """Create a COUNT(DISTINCT column) expression.
766
779
 
767
780
  Args:
@@ -773,7 +786,9 @@ class SQLFactory:
773
786
  return self.count(column, distinct=True)
774
787
 
775
788
  @staticmethod
776
- def sum(column: Union[str, exp.Expression], distinct: bool = False) -> exp.Expression:
789
+ def sum(
790
+ column: Union[str, exp.Expression, "ExpressionWrapper", "Case"], distinct: bool = False
791
+ ) -> AggregateExpression:
777
792
  """Create a SUM expression.
778
793
 
779
794
  Args:
@@ -783,11 +798,11 @@ class SQLFactory:
783
798
  Returns:
784
799
  SUM expression.
785
800
  """
786
- col_expr = exp.column(column) if isinstance(column, str) else column
787
- return exp.Sum(this=col_expr, distinct=distinct)
801
+ col_expr = SQLFactory._extract_expression(column)
802
+ return AggregateExpression(exp.Sum(this=col_expr, distinct=distinct))
788
803
 
789
804
  @staticmethod
790
- def avg(column: Union[str, exp.Expression]) -> exp.Expression:
805
+ def avg(column: Union[str, exp.Expression, "ExpressionWrapper", "Case"]) -> AggregateExpression:
791
806
  """Create an AVG expression.
792
807
 
793
808
  Args:
@@ -796,11 +811,11 @@ class SQLFactory:
796
811
  Returns:
797
812
  AVG expression.
798
813
  """
799
- col_expr = exp.column(column) if isinstance(column, str) else column
800
- return exp.Avg(this=col_expr)
814
+ col_expr = SQLFactory._extract_expression(column)
815
+ return AggregateExpression(exp.Avg(this=col_expr))
801
816
 
802
817
  @staticmethod
803
- def max(column: Union[str, exp.Expression]) -> exp.Expression:
818
+ def max(column: Union[str, exp.Expression, "ExpressionWrapper", "Case"]) -> AggregateExpression:
804
819
  """Create a MAX expression.
805
820
 
806
821
  Args:
@@ -809,11 +824,11 @@ class SQLFactory:
809
824
  Returns:
810
825
  MAX expression.
811
826
  """
812
- col_expr = exp.column(column) if isinstance(column, str) else column
813
- return exp.Max(this=col_expr)
827
+ col_expr = SQLFactory._extract_expression(column)
828
+ return AggregateExpression(exp.Max(this=col_expr))
814
829
 
815
830
  @staticmethod
816
- def min(column: Union[str, exp.Expression]) -> exp.Expression:
831
+ def min(column: Union[str, exp.Expression, "ExpressionWrapper", "Case"]) -> AggregateExpression:
817
832
  """Create a MIN expression.
818
833
 
819
834
  Args:
@@ -822,15 +837,15 @@ class SQLFactory:
822
837
  Returns:
823
838
  MIN expression.
824
839
  """
825
- col_expr = exp.column(column) if isinstance(column, str) else column
826
- return exp.Min(this=col_expr)
840
+ col_expr = SQLFactory._extract_expression(column)
841
+ return AggregateExpression(exp.Min(this=col_expr))
827
842
 
828
843
  # ===================
829
844
  # Advanced SQL Operations
830
845
  # ===================
831
846
 
832
847
  @staticmethod
833
- def rollup(*columns: Union[str, exp.Expression]) -> exp.Expression:
848
+ def rollup(*columns: Union[str, exp.Expression]) -> FunctionExpression:
834
849
  """Create a ROLLUP expression for GROUP BY clauses.
835
850
 
836
851
  Args:
@@ -850,10 +865,10 @@ class SQLFactory:
850
865
  ```
851
866
  """
852
867
  column_exprs = [exp.column(col) if isinstance(col, str) else col for col in columns]
853
- return exp.Rollup(expressions=column_exprs)
868
+ return FunctionExpression(exp.Rollup(expressions=column_exprs))
854
869
 
855
870
  @staticmethod
856
- def cube(*columns: Union[str, exp.Expression]) -> exp.Expression:
871
+ def cube(*columns: Union[str, exp.Expression]) -> FunctionExpression:
857
872
  """Create a CUBE expression for GROUP BY clauses.
858
873
 
859
874
  Args:
@@ -873,10 +888,10 @@ class SQLFactory:
873
888
  ```
874
889
  """
875
890
  column_exprs = [exp.column(col) if isinstance(col, str) else col for col in columns]
876
- return exp.Cube(expressions=column_exprs)
891
+ return FunctionExpression(exp.Cube(expressions=column_exprs))
877
892
 
878
893
  @staticmethod
879
- def grouping_sets(*column_sets: Union[tuple[str, ...], list[str]]) -> exp.Expression:
894
+ def grouping_sets(*column_sets: Union[tuple[str, ...], list[str]]) -> FunctionExpression:
880
895
  """Create a GROUPING SETS expression for GROUP BY clauses.
881
896
 
882
897
  Args:
@@ -908,10 +923,10 @@ class SQLFactory:
908
923
  else:
909
924
  set_expressions.append(exp.column(column_set))
910
925
 
911
- return exp.GroupingSets(expressions=set_expressions)
926
+ return FunctionExpression(exp.GroupingSets(expressions=set_expressions))
912
927
 
913
928
  @staticmethod
914
- def any(values: Union[list[Any], exp.Expression, str]) -> exp.Expression:
929
+ def any(values: Union[list[Any], exp.Expression, str]) -> FunctionExpression:
915
930
  """Create an ANY expression for use with comparison operators.
916
931
 
917
932
  Args:
@@ -932,18 +947,18 @@ class SQLFactory:
932
947
  ```
933
948
  """
934
949
  if isinstance(values, list):
935
- literals = [SQLFactory._to_literal(v) for v in values]
936
- return exp.Any(this=exp.Array(expressions=literals))
950
+ literals = [SQLFactory.to_literal(v) for v in values]
951
+ return FunctionExpression(exp.Any(this=exp.Array(expressions=literals)))
937
952
  if isinstance(values, str):
938
953
  # Parse as SQL
939
954
  parsed = exp.maybe_parse(values) # type: ignore[var-annotated]
940
955
  if parsed:
941
- return exp.Any(this=parsed)
942
- return exp.Any(this=exp.Literal.string(values))
943
- return exp.Any(this=values)
956
+ return FunctionExpression(exp.Any(this=parsed))
957
+ return FunctionExpression(exp.Any(this=exp.Literal.string(values)))
958
+ return FunctionExpression(exp.Any(this=values))
944
959
 
945
960
  @staticmethod
946
- def not_any_(values: Union[list[Any], exp.Expression, str]) -> exp.Expression:
961
+ def not_any_(values: Union[list[Any], exp.Expression, str]) -> FunctionExpression:
947
962
  """Create a NOT ANY expression for use with comparison operators.
948
963
 
949
964
  Args:
@@ -963,14 +978,14 @@ class SQLFactory:
963
978
  )
964
979
  ```
965
980
  """
966
- return SQLFactory.any(values) # NOT ANY is handled by the comparison operator
981
+ return SQLFactory.any(values)
967
982
 
968
983
  # ===================
969
984
  # String Functions
970
985
  # ===================
971
986
 
972
987
  @staticmethod
973
- def concat(*expressions: Union[str, exp.Expression]) -> exp.Expression:
988
+ def concat(*expressions: Union[str, exp.Expression]) -> StringExpression:
974
989
  """Create a CONCAT expression.
975
990
 
976
991
  Args:
@@ -980,10 +995,10 @@ class SQLFactory:
980
995
  CONCAT expression.
981
996
  """
982
997
  exprs = [exp.column(expr) if isinstance(expr, str) else expr for expr in expressions]
983
- return exp.Concat(expressions=exprs)
998
+ return StringExpression(exp.Concat(expressions=exprs))
984
999
 
985
1000
  @staticmethod
986
- def upper(column: Union[str, exp.Expression]) -> exp.Expression:
1001
+ def upper(column: Union[str, exp.Expression]) -> StringExpression:
987
1002
  """Create an UPPER expression.
988
1003
 
989
1004
  Args:
@@ -993,10 +1008,10 @@ class SQLFactory:
993
1008
  UPPER expression.
994
1009
  """
995
1010
  col_expr = exp.column(column) if isinstance(column, str) else column
996
- return exp.Upper(this=col_expr)
1011
+ return StringExpression(exp.Upper(this=col_expr))
997
1012
 
998
1013
  @staticmethod
999
- def lower(column: Union[str, exp.Expression]) -> exp.Expression:
1014
+ def lower(column: Union[str, exp.Expression]) -> StringExpression:
1000
1015
  """Create a LOWER expression.
1001
1016
 
1002
1017
  Args:
@@ -1006,10 +1021,10 @@ class SQLFactory:
1006
1021
  LOWER expression.
1007
1022
  """
1008
1023
  col_expr = exp.column(column) if isinstance(column, str) else column
1009
- return exp.Lower(this=col_expr)
1024
+ return StringExpression(exp.Lower(this=col_expr))
1010
1025
 
1011
1026
  @staticmethod
1012
- def length(column: Union[str, exp.Expression]) -> exp.Expression:
1027
+ def length(column: Union[str, exp.Expression]) -> StringExpression:
1013
1028
  """Create a LENGTH expression.
1014
1029
 
1015
1030
  Args:
@@ -1019,14 +1034,14 @@ class SQLFactory:
1019
1034
  LENGTH expression.
1020
1035
  """
1021
1036
  col_expr = exp.column(column) if isinstance(column, str) else column
1022
- return exp.Length(this=col_expr)
1037
+ return StringExpression(exp.Length(this=col_expr))
1023
1038
 
1024
1039
  # ===================
1025
1040
  # Math Functions
1026
1041
  # ===================
1027
1042
 
1028
1043
  @staticmethod
1029
- def round(column: Union[str, exp.Expression], decimals: int = 0) -> exp.Expression:
1044
+ def round(column: Union[str, exp.Expression], decimals: int = 0) -> MathExpression:
1030
1045
  """Create a ROUND expression.
1031
1046
 
1032
1047
  Args:
@@ -1038,15 +1053,15 @@ class SQLFactory:
1038
1053
  """
1039
1054
  col_expr = exp.column(column) if isinstance(column, str) else column
1040
1055
  if decimals == 0:
1041
- return exp.Round(this=col_expr)
1042
- return exp.Round(this=col_expr, expression=exp.Literal.number(decimals))
1056
+ return MathExpression(exp.Round(this=col_expr))
1057
+ return MathExpression(exp.Round(this=col_expr, expression=exp.Literal.number(decimals)))
1043
1058
 
1044
1059
  # ===================
1045
1060
  # Conversion Functions
1046
1061
  # ===================
1047
1062
 
1048
1063
  @staticmethod
1049
- def _to_literal(value: Any) -> exp.Expression:
1064
+ def to_literal(value: Any) -> FunctionExpression:
1050
1065
  """Convert a Python value to a SQLGlot literal expression.
1051
1066
 
1052
1067
  Uses SQLGlot's built-in exp.convert() function for optimal dialect-agnostic
@@ -1063,12 +1078,52 @@ class SQLFactory:
1063
1078
  Returns:
1064
1079
  SQLGlot expression representing the literal value.
1065
1080
  """
1081
+ if isinstance(value, exp.Expression):
1082
+ return FunctionExpression(value)
1083
+ return FunctionExpression(exp.convert(value))
1084
+
1085
+ @staticmethod
1086
+ def _to_expression(value: Any) -> exp.Expression:
1087
+ """Convert a Python value to a raw SQLGlot expression.
1088
+
1089
+ Args:
1090
+ value: Python value or SQLGlot expression to convert.
1091
+
1092
+ Returns:
1093
+ Raw SQLGlot expression.
1094
+ """
1095
+ if isinstance(value, exp.Expression):
1096
+ return value
1097
+ return exp.convert(value)
1098
+
1099
+ @staticmethod
1100
+ def _extract_expression(value: Any) -> exp.Expression:
1101
+ """Extract SQLGlot expression from value, handling our wrapper types.
1102
+
1103
+ Args:
1104
+ value: String, SQLGlot expression, or our wrapper type.
1105
+
1106
+ Returns:
1107
+ Raw SQLGlot expression.
1108
+ """
1109
+ from sqlspec.builder._expression_wrappers import ExpressionWrapper
1110
+ from sqlspec.builder.mixins._select_operations import Case
1111
+
1112
+ if isinstance(value, str):
1113
+ return exp.column(value)
1114
+ if isinstance(value, Column):
1115
+ return value._expression
1116
+ if isinstance(value, ExpressionWrapper):
1117
+ return value.expression
1118
+ if isinstance(value, Case):
1119
+ # Case has _expression property via trait
1120
+ return exp.Case(ifs=value._conditions, default=value._default)
1066
1121
  if isinstance(value, exp.Expression):
1067
1122
  return value
1068
1123
  return exp.convert(value)
1069
1124
 
1070
1125
  @staticmethod
1071
- def decode(column: Union[str, exp.Expression], *args: Union[str, exp.Expression, Any]) -> exp.Expression:
1126
+ def decode(column: Union[str, exp.Expression], *args: Union[str, exp.Expression, Any]) -> FunctionExpression:
1072
1127
  """Create a DECODE expression (Oracle-style conditional logic).
1073
1128
 
1074
1129
  DECODE compares column to each search value and returns the corresponding result.
@@ -1105,22 +1160,22 @@ class SQLFactory:
1105
1160
  for i in range(0, len(args) - 1, 2):
1106
1161
  if i + 1 >= len(args):
1107
1162
  # Odd number of args means last one is default
1108
- default = SQLFactory._to_literal(args[i])
1163
+ default = SQLFactory._to_expression(args[i])
1109
1164
  break
1110
1165
 
1111
1166
  search_val = args[i]
1112
1167
  result_val = args[i + 1]
1113
1168
 
1114
- search_expr = SQLFactory._to_literal(search_val)
1115
- result_expr = SQLFactory._to_literal(result_val)
1169
+ search_expr = SQLFactory._to_expression(search_val)
1170
+ result_expr = SQLFactory._to_expression(result_val)
1116
1171
 
1117
1172
  condition = exp.EQ(this=col_expr, expression=search_expr)
1118
- conditions.append(exp.When(this=condition, then=result_expr))
1173
+ conditions.append(exp.If(this=condition, true=result_expr))
1119
1174
 
1120
- return exp.Case(ifs=conditions, default=default)
1175
+ return FunctionExpression(exp.Case(ifs=conditions, default=default))
1121
1176
 
1122
1177
  @staticmethod
1123
- def cast(column: Union[str, exp.Expression], data_type: str) -> exp.Expression:
1178
+ def cast(column: Union[str, exp.Expression], data_type: str) -> ConversionExpression:
1124
1179
  """Create a CAST expression for type conversion.
1125
1180
 
1126
1181
  Args:
@@ -1131,10 +1186,10 @@ class SQLFactory:
1131
1186
  CAST expression.
1132
1187
  """
1133
1188
  col_expr = exp.column(column) if isinstance(column, str) else column
1134
- return exp.Cast(this=col_expr, to=exp.DataType.build(data_type))
1189
+ return ConversionExpression(exp.Cast(this=col_expr, to=exp.DataType.build(data_type)))
1135
1190
 
1136
1191
  @staticmethod
1137
- def coalesce(*expressions: Union[str, exp.Expression]) -> exp.Expression:
1192
+ def coalesce(*expressions: Union[str, exp.Expression]) -> ConversionExpression:
1138
1193
  """Create a COALESCE expression.
1139
1194
 
1140
1195
  Args:
@@ -1144,10 +1199,12 @@ class SQLFactory:
1144
1199
  COALESCE expression.
1145
1200
  """
1146
1201
  exprs = [exp.column(expr) if isinstance(expr, str) else expr for expr in expressions]
1147
- return exp.Coalesce(expressions=exprs)
1202
+ return ConversionExpression(exp.Coalesce(expressions=exprs))
1148
1203
 
1149
1204
  @staticmethod
1150
- def nvl(column: Union[str, exp.Expression], substitute_value: Union[str, exp.Expression, Any]) -> exp.Expression:
1205
+ def nvl(
1206
+ column: Union[str, exp.Expression], substitute_value: Union[str, exp.Expression, Any]
1207
+ ) -> ConversionExpression:
1151
1208
  """Create an NVL (Oracle-style) expression using COALESCE.
1152
1209
 
1153
1210
  Args:
@@ -1158,15 +1215,15 @@ class SQLFactory:
1158
1215
  COALESCE expression equivalent to NVL.
1159
1216
  """
1160
1217
  col_expr = exp.column(column) if isinstance(column, str) else column
1161
- sub_expr = SQLFactory._to_literal(substitute_value)
1162
- return exp.Coalesce(expressions=[col_expr, sub_expr])
1218
+ sub_expr = SQLFactory._to_expression(substitute_value)
1219
+ return ConversionExpression(exp.Coalesce(expressions=[col_expr, sub_expr]))
1163
1220
 
1164
1221
  @staticmethod
1165
1222
  def nvl2(
1166
1223
  column: Union[str, exp.Expression],
1167
1224
  value_if_not_null: Union[str, exp.Expression, Any],
1168
1225
  value_if_null: Union[str, exp.Expression, Any],
1169
- ) -> exp.Expression:
1226
+ ) -> ConversionExpression:
1170
1227
  """Create an NVL2 (Oracle-style) expression using CASE.
1171
1228
 
1172
1229
  NVL2 returns value_if_not_null if column is not NULL,
@@ -1187,22 +1244,22 @@ class SQLFactory:
1187
1244
  ```
1188
1245
  """
1189
1246
  col_expr = exp.column(column) if isinstance(column, str) else column
1190
- not_null_expr = SQLFactory._to_literal(value_if_not_null)
1191
- null_expr = SQLFactory._to_literal(value_if_null)
1247
+ not_null_expr = SQLFactory._to_expression(value_if_not_null)
1248
+ null_expr = SQLFactory._to_expression(value_if_null)
1192
1249
 
1193
1250
  # Create CASE WHEN column IS NOT NULL THEN value_if_not_null ELSE value_if_null END
1194
1251
  is_null = exp.Is(this=col_expr, expression=exp.Null())
1195
1252
  condition = exp.Not(this=is_null)
1196
1253
  when_clause = exp.If(this=condition, true=not_null_expr)
1197
1254
 
1198
- return exp.Case(ifs=[when_clause], default=null_expr)
1255
+ return ConversionExpression(exp.Case(ifs=[when_clause], default=null_expr))
1199
1256
 
1200
1257
  # ===================
1201
1258
  # Bulk Operations
1202
1259
  # ===================
1203
1260
 
1204
1261
  @staticmethod
1205
- def bulk_insert(table_name: str, column_count: int, placeholder_style: str = "?") -> exp.Expression:
1262
+ def bulk_insert(table_name: str, column_count: int, placeholder_style: str = "?") -> FunctionExpression:
1206
1263
  """Create bulk INSERT expression for executemany operations.
1207
1264
 
1208
1265
  This is specifically for bulk loading operations like CSV ingestion where
@@ -1237,13 +1294,15 @@ class SQLFactory:
1237
1294
  # Creates: INSERT INTO "my_table" VALUES (:1, :2, :3)
1238
1295
  ```
1239
1296
  """
1240
- return exp.Insert(
1241
- this=exp.Table(this=exp.to_identifier(table_name)),
1242
- expression=exp.Values(
1243
- expressions=[
1244
- exp.Tuple(expressions=[exp.Placeholder(this=placeholder_style) for _ in range(column_count)])
1245
- ]
1246
- ),
1297
+ return FunctionExpression(
1298
+ exp.Insert(
1299
+ this=exp.Table(this=exp.to_identifier(table_name)),
1300
+ expression=exp.Values(
1301
+ expressions=[
1302
+ exp.Tuple(expressions=[exp.Placeholder(this=placeholder_style) for _ in range(column_count)])
1303
+ ]
1304
+ ),
1305
+ )
1247
1306
  )
1248
1307
 
1249
1308
  def truncate(self, table_name: str) -> "Truncate":
@@ -1297,7 +1356,7 @@ class SQLFactory:
1297
1356
  self,
1298
1357
  partition_by: Optional[Union[str, list[str], exp.Expression]] = None,
1299
1358
  order_by: Optional[Union[str, list[str], exp.Expression]] = None,
1300
- ) -> exp.Expression:
1359
+ ) -> FunctionExpression:
1301
1360
  """Create a ROW_NUMBER() window function.
1302
1361
 
1303
1362
  Args:
@@ -1313,7 +1372,7 @@ class SQLFactory:
1313
1372
  self,
1314
1373
  partition_by: Optional[Union[str, list[str], exp.Expression]] = None,
1315
1374
  order_by: Optional[Union[str, list[str], exp.Expression]] = None,
1316
- ) -> exp.Expression:
1375
+ ) -> FunctionExpression:
1317
1376
  """Create a RANK() window function.
1318
1377
 
1319
1378
  Args:
@@ -1329,7 +1388,7 @@ class SQLFactory:
1329
1388
  self,
1330
1389
  partition_by: Optional[Union[str, list[str], exp.Expression]] = None,
1331
1390
  order_by: Optional[Union[str, list[str], exp.Expression]] = None,
1332
- ) -> exp.Expression:
1391
+ ) -> FunctionExpression:
1333
1392
  """Create a DENSE_RANK() window function.
1334
1393
 
1335
1394
  Args:
@@ -1347,7 +1406,7 @@ class SQLFactory:
1347
1406
  func_args: list[exp.Expression],
1348
1407
  partition_by: Optional[Union[str, list[str], exp.Expression]] = None,
1349
1408
  order_by: Optional[Union[str, list[str], exp.Expression]] = None,
1350
- ) -> exp.Expression:
1409
+ ) -> FunctionExpression:
1351
1410
  """Helper to create window function expressions.
1352
1411
 
1353
1412
  Args:
@@ -1373,13 +1432,13 @@ class SQLFactory:
1373
1432
 
1374
1433
  if order_by:
1375
1434
  if isinstance(order_by, str):
1376
- over_args["order"] = [exp.column(order_by).asc()]
1435
+ over_args["order"] = exp.Order(expressions=[exp.column(order_by).asc()])
1377
1436
  elif isinstance(order_by, list):
1378
- over_args["order"] = [exp.column(col).asc() for col in order_by]
1437
+ over_args["order"] = exp.Order(expressions=[exp.column(col).asc() for col in order_by])
1379
1438
  elif isinstance(order_by, exp.Expression):
1380
- over_args["order"] = [order_by]
1439
+ over_args["order"] = exp.Order(expressions=[order_by])
1381
1440
 
1382
- return exp.Window(this=func_expr, **over_args)
1441
+ return FunctionExpression(exp.Window(this=func_expr, **over_args))
1383
1442
 
1384
1443
 
1385
1444
  # Create a default SQL factory instance
@@ -5,7 +5,7 @@ for building SQL conditions with type safety and parameter binding.
5
5
  """
6
6
 
7
7
  from collections.abc import Iterable
8
- from typing import Any, Optional
8
+ from typing import Any, Optional, cast
9
9
 
10
10
  from sqlglot import exp
11
11
 
@@ -241,6 +241,10 @@ class Column:
241
241
  """Create a DESC ordering expression."""
242
242
  return exp.Ordered(this=self._expression, desc=True)
243
243
 
244
+ def as_(self, alias: str) -> exp.Alias:
245
+ """Create an aliased expression."""
246
+ return cast("exp.Alias", exp.alias_(self._expression, alias))
247
+
244
248
  def __repr__(self) -> str:
245
249
  if self.table:
246
250
  return f"Column<{self.table}.{self.name}>"
@@ -0,0 +1,46 @@
1
+ """Expression wrapper classes for proper type annotations."""
2
+
3
+ from typing import cast
4
+
5
+ from sqlglot import exp
6
+
7
+ __all__ = ("AggregateExpression", "ConversionExpression", "FunctionExpression", "MathExpression", "StringExpression")
8
+
9
+
10
+ class ExpressionWrapper:
11
+ """Base wrapper for SQLGlot expressions."""
12
+
13
+ def __init__(self, expression: exp.Expression) -> None:
14
+ self._expression = expression
15
+
16
+ def as_(self, alias: str) -> exp.Alias:
17
+ """Create an aliased expression."""
18
+ return cast("exp.Alias", exp.alias_(self._expression, alias))
19
+
20
+ @property
21
+ def expression(self) -> exp.Expression:
22
+ """Get the underlying SQLGlot expression."""
23
+ return self._expression
24
+
25
+ def __str__(self) -> str:
26
+ return str(self._expression)
27
+
28
+
29
+ class AggregateExpression(ExpressionWrapper):
30
+ """Aggregate functions like COUNT, SUM, AVG."""
31
+
32
+
33
+ class FunctionExpression(ExpressionWrapper):
34
+ """General SQL functions."""
35
+
36
+
37
+ class MathExpression(ExpressionWrapper):
38
+ """Mathematical functions like ROUND."""
39
+
40
+
41
+ class StringExpression(ExpressionWrapper):
42
+ """String functions like UPPER, LOWER, LENGTH."""
43
+
44
+
45
+ class ConversionExpression(ExpressionWrapper):
46
+ """Conversion functions like CAST, COALESCE."""
@@ -412,9 +412,7 @@ class ConflictBuilder:
412
412
  # Create ON CONFLICT with proper structure
413
413
  conflict_keys = [exp.to_identifier(col) for col in self._columns] if self._columns else None
414
414
  on_conflict = exp.OnConflict(
415
- conflict_keys=conflict_keys,
416
- action=exp.var("DO UPDATE"),
417
- expressions=set_expressions if set_expressions else None,
415
+ conflict_keys=conflict_keys, action=exp.var("DO UPDATE"), expressions=set_expressions or None
418
416
  )
419
417
 
420
418
  insert_expr.set("conflict", on_conflict)
@@ -44,26 +44,26 @@ class Update(
44
44
  update_query = (
45
45
  Update()
46
46
  .table("users")
47
- .set(name="John Doe")
48
- .set(email="john@example.com")
47
+ .set_(name="John Doe")
48
+ .set_(email="john@example.com")
49
49
  .where("id = 1")
50
50
  )
51
51
 
52
52
  update_query = (
53
- Update("users").set(name="John Doe").where("id = 1")
53
+ Update("users").set_(name="John Doe").where("id = 1")
54
54
  )
55
55
 
56
56
  update_query = (
57
57
  Update()
58
58
  .table("users")
59
- .set(status="active")
59
+ .set_(status="active")
60
60
  .where_eq("id", 123)
61
61
  )
62
62
 
63
63
  update_query = (
64
64
  Update()
65
65
  .table("users", "u")
66
- .set(name="Updated Name")
66
+ .set_(name="Updated Name")
67
67
  .from_("profiles", "p")
68
68
  .where("u.id = p.user_id AND p.is_verified = true")
69
69
  )
@@ -10,6 +10,9 @@ from sqlspec.builder._parsing_utils import parse_order_expression
10
10
  from sqlspec.exceptions import SQLBuilderError
11
11
 
12
12
  if TYPE_CHECKING:
13
+ from sqlspec.builder._column import Column
14
+ from sqlspec.builder._expression_wrappers import ExpressionWrapper
15
+ from sqlspec.builder.mixins._select_operations import Case
13
16
  from sqlspec.protocols import SQLBuilderProtocol
14
17
 
15
18
  __all__ = ("LimitOffsetClauseMixin", "OrderByClauseMixin", "ReturningClauseMixin")
@@ -24,7 +27,7 @@ class OrderByClauseMixin:
24
27
  # Type annotation for PyRight - this will be provided by the base class
25
28
  _expression: Optional[exp.Expression]
26
29
 
27
- def order_by(self, *items: Union[str, exp.Ordered], desc: bool = False) -> Self:
30
+ def order_by(self, *items: Union[str, exp.Ordered, "Column"], desc: bool = False) -> Self:
28
31
  """Add ORDER BY clause.
29
32
 
30
33
  Args:
@@ -49,7 +52,13 @@ class OrderByClauseMixin:
49
52
  if desc:
50
53
  order_item = order_item.desc()
51
54
  else:
52
- order_item = item
55
+ # Extract expression from Column objects or use as-is for sqlglot expressions
56
+ from sqlspec._sql import SQLFactory
57
+
58
+ extracted_item = SQLFactory._extract_expression(item)
59
+ order_item = extracted_item
60
+ if desc and not isinstance(item, exp.Ordered):
61
+ order_item = order_item.desc()
53
62
  current_expr = current_expr.order_by(order_item, copy=False)
54
63
  builder._expression = current_expr
55
64
  return cast("Self", builder)
@@ -111,7 +120,7 @@ class ReturningClauseMixin:
111
120
  # Type annotation for PyRight - this will be provided by the base class
112
121
  _expression: Optional[exp.Expression]
113
122
 
114
- def returning(self, *columns: Union[str, exp.Expression]) -> Self:
123
+ def returning(self, *columns: Union[str, exp.Expression, "Column", "ExpressionWrapper", "Case"]) -> Self:
115
124
  """Add RETURNING clause to the statement.
116
125
 
117
126
  Args:
@@ -130,6 +139,9 @@ class ReturningClauseMixin:
130
139
  if not isinstance(self._expression, valid_types):
131
140
  msg = "RETURNING is only supported for INSERT, UPDATE, and DELETE statements."
132
141
  raise SQLBuilderError(msg)
133
- returning_exprs = [exp.column(c) if isinstance(c, str) else c for c in columns]
142
+ # Extract expressions from various wrapper types
143
+ from sqlspec._sql import SQLFactory
144
+
145
+ returning_exprs = [SQLFactory._extract_expression(c) for c in columns]
134
146
  self._expression.set("returning", exp.Returning(expressions=returning_exprs))
135
147
  return self
@@ -858,7 +858,7 @@ class Case:
858
858
  from sqlspec._sql import SQLFactory
859
859
 
860
860
  cond_expr = exp.maybe_parse(condition) or exp.column(condition) if isinstance(condition, str) else condition
861
- val_expr = SQLFactory._to_literal(value)
861
+ val_expr = SQLFactory._to_expression(value)
862
862
 
863
863
  # SQLGlot uses exp.If for CASE WHEN clauses, not exp.When
864
864
  when_clause = exp.If(this=cond_expr, true=val_expr)
@@ -876,7 +876,7 @@ class Case:
876
876
  """
877
877
  from sqlspec._sql import SQLFactory
878
878
 
879
- self._default = SQLFactory._to_literal(value)
879
+ self._default = SQLFactory._to_expression(value)
880
880
  return self
881
881
 
882
882
  def end(self) -> Self:
@@ -111,10 +111,10 @@ class UpdateSetClauseMixin:
111
111
  """Set columns and values for the UPDATE statement.
112
112
 
113
113
  Supports:
114
- - set(column, value)
115
- - set(mapping)
116
- - set(**kwargs)
117
- - set(mapping, **kwargs)
114
+ - set_(column, value)
115
+ - set_(mapping)
116
+ - set_(**kwargs)
117
+ - set_(mapping, **kwargs)
118
118
 
119
119
  Args:
120
120
  *args: Either (column, value) or a mapping.
sqlspec/protocols.py CHANGED
@@ -20,6 +20,7 @@ if TYPE_CHECKING:
20
20
  __all__ = (
21
21
  "BytesConvertibleProtocol",
22
22
  "DictProtocol",
23
+ "ExpressionWithAliasProtocol",
23
24
  "FilterAppenderProtocol",
24
25
  "FilterParameterProtocol",
25
26
  "HasExpressionProtocol",
@@ -172,6 +173,15 @@ class BytesConvertibleProtocol(Protocol):
172
173
  ...
173
174
 
174
175
 
176
+ @runtime_checkable
177
+ class ExpressionWithAliasProtocol(Protocol):
178
+ """Protocol for SQL expressions that support aliasing with as_() method."""
179
+
180
+ def as_(self, alias: str, **kwargs: Any) -> "exp.Alias":
181
+ """Create an aliased expression."""
182
+ ...
183
+
184
+
175
185
  @runtime_checkable
176
186
  class ObjectStoreItemProtocol(Protocol):
177
187
  """Protocol for object store items with path/key attributes."""
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: sqlspec
3
- Version: 0.17.0
3
+ Version: 0.17.1
4
4
  Summary: SQL Experiments in Python
5
5
  Project-URL: Discord, https://discord.gg/litestar
6
6
  Project-URL: Issue, https://github.com/litestar-org/sqlspec/issues/
@@ -2,14 +2,14 @@ sqlspec/__init__.py,sha256=wy9ZqukBvr2WZkUmrljRHPGFjpUKWkEXxFWlTLuciJc,2123
2
2
  sqlspec/__main__.py,sha256=lXBKZMOXA1uY735Rnsb-GS7aXy0nt22tYmd2X9FcxrY,253
3
3
  sqlspec/__metadata__.py,sha256=IUw6MCTy1oeUJ1jAVYbuJLkOWbiAWorZ5W-E-SAD9N4,395
4
4
  sqlspec/_serialization.py,sha256=6U5-smk2h2yl0i6am2prtOLJTdu4NJQdcLlSfSUMaUQ,2590
5
- sqlspec/_sql.py,sha256=UrJCc_iriWw9Dm1eeGswWoFvkeRtfueuTG_f01O7CAk,46623
5
+ sqlspec/_sql.py,sha256=M1sirpRds7Bs3r72gRUTxO8Yp7qup-GPD3WLiEEAwVg,49135
6
6
  sqlspec/_typing.py,sha256=jv-7QHGLrJLfnP86bR-Xcmj3PDoddNZEKDz_vYRBiAU,22684
7
7
  sqlspec/base.py,sha256=OhFSpDaweCjZEabpTtl95pg91WzMko76d7sFOjyZSoo,25730
8
8
  sqlspec/cli.py,sha256=3ZxPwl4neNWyrAkM9J9ccC_gaFigDJbhuZfx15JVE7E,9903
9
9
  sqlspec/config.py,sha256=s7csxGK0SlTvB9jOvHlKKm4Y272RInQrUd6hGXwy31Q,14974
10
10
  sqlspec/exceptions.py,sha256=mCqNJ0JSPA-TUPpAfdctwwqJWbiNsWap5ATNNRdczwU,6159
11
11
  sqlspec/loader.py,sha256=KSL5OsjPsuZZJrgohdhdmimwDqVPj_BvHqHIpP2Fq_0,25818
12
- sqlspec/protocols.py,sha256=iwwy7zdIBV7TcoxIYpKuTvN5fGiULQac2f4a-saxyKU,12937
12
+ sqlspec/protocols.py,sha256=GxKDn-Uw4fBVaKprYE3rCuhSz1CF2yV4awIrXbV2_j4,13236
13
13
  sqlspec/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
14
14
  sqlspec/typing.py,sha256=yj8D8O-pkfUVZDfVHEgQaB95-5alwgQbp_sqNJOVhvQ,6301
15
15
  sqlspec/adapters/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -58,25 +58,26 @@ sqlspec/adapters/sqlite/driver.py,sha256=uAhasoCNOV30gTvl1EUpofRcm8YiEW5RnVy07Xy
58
58
  sqlspec/adapters/sqlite/pool.py,sha256=FxlnBksmiPDLyMB0C7j3wjfc45PU_RrukTwNAzN3DDQ,4977
59
59
  sqlspec/builder/__init__.py,sha256=AY_D6TlSpYFJ7sU39S9xubq-eXZTpuNDlisoeK0LojU,1675
60
60
  sqlspec/builder/_base.py,sha256=yz-6e-x66vrNne6z5zq4Ae0C3p0mHEEIe1Y8er-A0pg,17812
61
- sqlspec/builder/_column.py,sha256=46baZj403BKfGjZcMc9LtQfMLeMQ7ROPyFL64V7dDM0,13124
61
+ sqlspec/builder/_column.py,sha256=GfTMFPzJIhwDVKefMbyom-HoEj-KSKB92-XoZfchnEo,13289
62
62
  sqlspec/builder/_ddl.py,sha256=A_fV4d92o2ZOhX150YMSsQDm3veQTQrwlxgLdFMBBfg,48184
63
63
  sqlspec/builder/_ddl_utils.py,sha256=1mFSNe9w5rZXA1Ud4CTuca7eibi0XayHrIPcnEgRB7s,4034
64
64
  sqlspec/builder/_delete.py,sha256=xWA5nQB3UB8kpEGXN2k5ynt4cGZ7blkNoURpI0bKoeg,2264
65
- sqlspec/builder/_insert.py,sha256=rARWh5olbur6oPP_3FoAuJp8irj5cRLW0mKdWEx2cqU,16896
65
+ sqlspec/builder/_expression_wrappers.py,sha256=HTl8qAFD4sZNXLD2akkJWGCPse2HWS237mTGE4Cx_7I,1244
66
+ sqlspec/builder/_insert.py,sha256=p0UegiB9WNkHWafTPgfyWsQlt76oRp7VvuMK3_zIjDw,16850
66
67
  sqlspec/builder/_merge.py,sha256=95PLQSKA3zjk0wTZG3m817fTZpsS95PrS2qF34iLAP8,2004
67
68
  sqlspec/builder/_parsing_utils.py,sha256=RH8OFBFAetalEgHW5JLcEyyCdW_awVdy07MjboOkqL4,8383
68
69
  sqlspec/builder/_select.py,sha256=m5sfyuAssjlNimLLNBAeFooVIfM2FgKN1boPfdsOkaA,5785
69
- sqlspec/builder/_update.py,sha256=UFHM_uWVY5RnZQ6winiyjKNtBryKRAXJlXtCVQdifyw,6015
70
+ sqlspec/builder/_update.py,sha256=QieiguEq9T_UECv10f1xwQJp58gc3w246cvtCDpPwuw,6020
70
71
  sqlspec/builder/mixins/__init__.py,sha256=YXhAzKmQbQtne5j26SKWY8PUxwosl0RhlhLoahAdkj0,1885
71
72
  sqlspec/builder/mixins/_cte_and_set_ops.py,sha256=p5O9m_jvpaWxv1XP9Ys2DRI-qOTq30rr2EwYjAbIT8o,9088
72
73
  sqlspec/builder/mixins/_delete_operations.py,sha256=l0liajnoAfRgtWtyStuAIfxreEFRkNO4DtBwyGqAfic,1198
73
74
  sqlspec/builder/mixins/_insert_operations.py,sha256=3ZuVNAPgJG0fzOPaprwUPa0Un3NP7erHwtCg8AGZWD8,9500
74
75
  sqlspec/builder/mixins/_join_operations.py,sha256=8o_aApK5cmJbNCNfWa4bs5fR2zgQUjon5p-oyiW41Qw,11440
75
76
  sqlspec/builder/mixins/_merge_operations.py,sha256=e9QDv1s84-2F2ZAZrr7UJtKXhy3X0NDN7AZ--8mOTKw,24193
76
- sqlspec/builder/mixins/_order_limit_operations.py,sha256=ABPuFSqHRv7XaS9-3HNZO3Jn0QovhJrkYT158xxduns,4835
77
+ sqlspec/builder/mixins/_order_limit_operations.py,sha256=KrTXE0HmKCMn4AkNFyMhEZrmJ2Dl-o5kSiCjk2jNIKM,5499
77
78
  sqlspec/builder/mixins/_pivot_operations.py,sha256=j5vdzXuEqB1Jn3Ie_QjVwSH2_OEi65oZ64bQJHd3jXo,6108
78
- sqlspec/builder/mixins/_select_operations.py,sha256=m7iejdCw04mfxohNiHWeQSKQyI94vrhJ_JcYRhUPYw8,35314
79
- sqlspec/builder/mixins/_update_operations.py,sha256=lk9VRM0KGmYhofbWChemJxSZF6I0LhrRgqMXVmXZMeU,8650
79
+ sqlspec/builder/mixins/_select_operations.py,sha256=JznnA0bS6yXlylSLpLQ9xNdnWOo0ckqirOSDkceRmJw,35320
80
+ sqlspec/builder/mixins/_update_operations.py,sha256=rcmnmSSwHSujqe0gd-dKsWHLX9nPMS1we1Y3_h4S7G4,8654
80
81
  sqlspec/builder/mixins/_where_clause.py,sha256=1iz7Y2x_ooG2bOCu2zX0v5_bkGFpAckVQKvnyrR1JNQ,36373
81
82
  sqlspec/core/__init__.py,sha256=rU_xGsXhqIOnBbyB2InhJknYePm5NQ2DSWdBigror4g,1775
82
83
  sqlspec/core/cache.py,sha256=cLL9bd5wn1oeMzn5E5Ym0sAemA8U4QP6B55x4L9-26I,27044
@@ -129,9 +130,9 @@ sqlspec/utils/singleton.py,sha256=SKnszJi1NPeERgX7IjVIGYAYx4XqR1E_rph3bU6olAU,10
129
130
  sqlspec/utils/sync_tools.py,sha256=WRuk1ZEhb_0CRrumAdnmi-i-dV6qVd3cgJyZw8RY9QQ,7390
130
131
  sqlspec/utils/text.py,sha256=n5K0gvXvyCc8jNteNKsBOymwf_JnQ65f3lu0YaYq4Ys,2898
131
132
  sqlspec/utils/type_guards.py,sha256=9C4SRebO4JiQrMzcJZFUA0KjSU48G26RmX6lbijyjBg,30476
132
- sqlspec-0.17.0.dist-info/METADATA,sha256=3VTksHJXA8Azr2boYZpQtPO27y_QXZq_jYk-zuftYzE,16822
133
- sqlspec-0.17.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
134
- sqlspec-0.17.0.dist-info/entry_points.txt,sha256=G-ZqY1Nuuw3Iys7nXw23f6ILenk_Lt47VdK2mhJCWHg,53
135
- sqlspec-0.17.0.dist-info/licenses/LICENSE,sha256=MdujfZ6l5HuLz4mElxlu049itenOR3gnhN1_Nd3nVcM,1078
136
- sqlspec-0.17.0.dist-info/licenses/NOTICE,sha256=Lyir8ozXWov7CyYS4huVaOCNrtgL17P-bNV-5daLntQ,1634
137
- sqlspec-0.17.0.dist-info/RECORD,,
133
+ sqlspec-0.17.1.dist-info/METADATA,sha256=mjoS62X8qL5LIdUs0VRuK1kCZPw2Rsk8GVjyGByNQzg,16822
134
+ sqlspec-0.17.1.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
135
+ sqlspec-0.17.1.dist-info/entry_points.txt,sha256=G-ZqY1Nuuw3Iys7nXw23f6ILenk_Lt47VdK2mhJCWHg,53
136
+ sqlspec-0.17.1.dist-info/licenses/LICENSE,sha256=MdujfZ6l5HuLz4mElxlu049itenOR3gnhN1_Nd3nVcM,1078
137
+ sqlspec-0.17.1.dist-info/licenses/NOTICE,sha256=Lyir8ozXWov7CyYS4huVaOCNrtgL17P-bNV-5daLntQ,1634
138
+ sqlspec-0.17.1.dist-info/RECORD,,