sqlframe 1.4.0__py3-none-any.whl → 1.5.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
sqlframe/_version.py CHANGED
@@ -12,5 +12,5 @@ __version__: str
12
12
  __version_tuple__: VERSION_TUPLE
13
13
  version_tuple: VERSION_TUPLE
14
14
 
15
- __version__ = version = '1.4.0'
16
- __version_tuple__ = version_tuple = (1, 4, 0)
15
+ __version__ = version = '1.5.1'
16
+ __version_tuple__ = version_tuple = (1, 5, 1)
sqlframe/base/column.py CHANGED
@@ -251,6 +251,13 @@ class Column:
251
251
  return lit(value)
252
252
  return Column(value)
253
253
 
254
+ @property
255
+ def dtype(self) -> t.Optional[DataType]:
256
+ expression = self.expression.unalias()
257
+ if isinstance(expression, exp.Cast):
258
+ return expression.args.get("to")
259
+ return None
260
+
254
261
  def copy(self) -> Column:
255
262
  return Column(self.expression.copy())
256
263
 
@@ -7,6 +7,7 @@ import typing as t
7
7
 
8
8
  from sqlglot import exp as expression
9
9
  from sqlglot.helper import ensure_list
10
+ from sqlglot.helper import flatten as _flatten
10
11
 
11
12
  from sqlframe.base.column import Column
12
13
  from sqlframe.base.util import get_func_from_session
@@ -60,6 +61,10 @@ def first_always_ignore_nulls(col: ColumnOrName, ignorenulls: t.Optional[bool] =
60
61
  return first(col)
61
62
 
62
63
 
64
+ def bitwise_not_from_bitnot(col: ColumnOrName) -> Column:
65
+ return Column.invoke_anonymous_function(col, "BITNOT")
66
+
67
+
63
68
  def factorial_from_case_statement(col: ColumnOrName) -> Column:
64
69
  from sqlframe.base.session import _BaseSession
65
70
 
@@ -160,6 +165,10 @@ def factorial_ensure_int(col: ColumnOrName) -> Column:
160
165
  return Column.invoke_anonymous_function(col_func(col).cast("integer"), "FACTORIAL")
161
166
 
162
167
 
168
+ def skewness_from_skew(col: ColumnOrName) -> Column:
169
+ return Column.invoke_anonymous_function(col, "SKEW")
170
+
171
+
163
172
  def isnan_using_equal(col: ColumnOrName) -> Column:
164
173
  lit = get_func_from_session("lit")
165
174
  return Column(
@@ -219,6 +228,30 @@ def percentile_approx_without_accuracy_and_plural(
219
228
  return Column(make_bracket_approx_percentile(percentage)) # type: ignore
220
229
 
221
230
 
231
+ def percentile_approx_without_accuracy_and_max_array(
232
+ col: ColumnOrName,
233
+ percentage: t.Union[ColumnOrLiteral, t.List[float], t.Tuple[float]],
234
+ accuracy: t.Optional[float] = None,
235
+ ) -> Column:
236
+ from sqlframe.base.functions import percentile_approx
237
+
238
+ lit = get_func_from_session("lit")
239
+ array = get_func_from_session("array")
240
+ col_func = get_func_from_session("col")
241
+
242
+ def make_approx_percentile(percentage: float) -> expression.Anonymous:
243
+ return expression.Anonymous(
244
+ this="APPROX_PERCENTILE",
245
+ expressions=[col_func(col).expression, lit(percentage).expression],
246
+ )
247
+
248
+ if accuracy:
249
+ logger.warning("Accuracy is ignored since it is not supported in this dialect")
250
+ if isinstance(percentage, (list, tuple)):
251
+ return array(*[make_approx_percentile(p) for p in percentage])
252
+ return percentile_approx(col, percentage)
253
+
254
+
222
255
  def percentile_without_disc(
223
256
  col: ColumnOrName,
224
257
  percentage: t.Union[ColumnOrLiteral, t.List[float], t.Tuple[float]],
@@ -258,6 +291,57 @@ def round_cast_as_numeric(col: ColumnOrName, scale: t.Optional[int] = None) -> C
258
291
  return round(col_func(col).cast("numeric"), scale)
259
292
 
260
293
 
294
+ def bround_using_half_even(col: ColumnOrName, scale: t.Optional[int] = None) -> Column:
295
+ lit_func = get_func_from_session("lit")
296
+
297
+ return Column.invoke_anonymous_function(col, "ROUND", scale, lit_func("HALF_TO_EVEN")) # type: ignore
298
+
299
+
300
+ def shiftleft_from_bitshiftleft(col: ColumnOrName, numBits: int) -> Column:
301
+ col_func = get_func_from_session("col")
302
+ lit = get_func_from_session("lit")
303
+
304
+ return Column(
305
+ expression.Anonymous(
306
+ this="BITSHIFTLEFT",
307
+ expressions=[col_func(col).expression, lit(numBits).expression],
308
+ )
309
+ )
310
+
311
+
312
+ def shiftright_from_bitshiftright(col: ColumnOrName, numBits: int) -> Column:
313
+ col_func = get_func_from_session("col")
314
+ lit = get_func_from_session("lit")
315
+
316
+ return Column(
317
+ expression.Anonymous(
318
+ this="BITSHIFTRIGHT",
319
+ expressions=[col_func(col).expression, lit(numBits).expression],
320
+ )
321
+ )
322
+
323
+
324
+ def struct_with_eq(
325
+ col: t.Union[ColumnOrName, t.Iterable[ColumnOrName]], *cols: ColumnOrName
326
+ ) -> Column:
327
+ from sqlframe.base.session import _BaseSession
328
+
329
+ col_func = get_func_from_session("col")
330
+
331
+ columns = [col_func(x) for x in ensure_list(col) + list(cols)]
332
+ expressions = []
333
+ for column in columns:
334
+ expressions.append(
335
+ expression.PropertyEQ(
336
+ this=expression.parse_identifier(
337
+ column.alias_or_name, dialect=_BaseSession().input_dialect
338
+ ),
339
+ expression=column.expression,
340
+ )
341
+ )
342
+ return Column(expression.Struct(expressions=expressions))
343
+
344
+
261
345
  def year_from_extract(col: ColumnOrName) -> Column:
262
346
  col_func = get_func_from_session("col")
263
347
 
@@ -421,6 +505,49 @@ def make_date_from_date_func(year: ColumnOrName, month: ColumnOrName, day: Colum
421
505
  )
422
506
 
423
507
 
508
+ def make_date_date_from_parts(year: ColumnOrName, month: ColumnOrName, day: ColumnOrName) -> Column:
509
+ col_func = get_func_from_session("col")
510
+
511
+ return Column(
512
+ expression.Anonymous(
513
+ this="DATE_FROM_PARTS",
514
+ expressions=[
515
+ col_func(year).cast("integer").expression,
516
+ col_func(month).cast("integer").expression,
517
+ col_func(day).cast("integer").expression,
518
+ ],
519
+ )
520
+ )
521
+
522
+
523
+ def date_add_no_date_sub(
524
+ col: ColumnOrName, days: t.Union[ColumnOrName, int], cast_as_date: bool = True
525
+ ) -> Column:
526
+ lit_func = get_func_from_session("col")
527
+
528
+ if isinstance(days, int):
529
+ days = lit_func(days)
530
+
531
+ result = Column.invoke_expression_over_column(
532
+ Column.ensure_col(col).cast("date"),
533
+ expression.DateAdd,
534
+ expression=days,
535
+ unit=expression.Var(this="DAY"),
536
+ )
537
+ if cast_as_date:
538
+ return result.cast("date")
539
+ return result
540
+
541
+
542
+ def date_sub_by_date_add(
543
+ col: ColumnOrName, days: t.Union[ColumnOrName, int], cast_as_date: bool = True
544
+ ) -> Column:
545
+ lit_func = get_func_from_session("col")
546
+ date_add_func = get_func_from_session("date_add")
547
+
548
+ return date_add_func(col, days * lit_func(-1), cast_as_date)
549
+
550
+
424
551
  def to_date_from_timestamp(col: ColumnOrName, format: t.Optional[str] = None) -> Column:
425
552
  from sqlframe.base.functions import to_date
426
553
 
@@ -520,6 +647,31 @@ def add_months_by_multiplication(
520
647
  return value
521
648
 
522
649
 
650
+ def add_months_using_func(
651
+ start: ColumnOrName, months: t.Union[ColumnOrName, int], cast_as_date: bool = True
652
+ ) -> Column:
653
+ from sqlframe.base.functions import add_months
654
+
655
+ if isinstance(months, int):
656
+ months = get_func_from_session("lit")(months)
657
+ else:
658
+ months = Column.ensure_col(months)
659
+
660
+ value = Column(
661
+ expression.Anonymous(
662
+ this="ADD_MONTHS",
663
+ expressions=[
664
+ Column.ensure_col(start).expression,
665
+ months.expression, # type: ignore
666
+ ],
667
+ )
668
+ )
669
+
670
+ if cast_as_date:
671
+ return value.cast("date")
672
+ return value
673
+
674
+
523
675
  def months_between_from_age_and_extract(
524
676
  date1: ColumnOrName, date2: ColumnOrName, roundOff: t.Optional[bool] = None
525
677
  ) -> Column:
@@ -545,6 +697,23 @@ def months_between_from_age_and_extract(
545
697
  ).cast("bigint")
546
698
 
547
699
 
700
+ def months_between_cast_as_date_cast_roundoff(
701
+ date1: ColumnOrName, date2: ColumnOrName, roundOff: t.Optional[bool] = None
702
+ ) -> Column:
703
+ from sqlframe.base.functions import months_between
704
+
705
+ col_func = get_func_from_session("col")
706
+
707
+ date1 = col_func(date1).cast("date")
708
+ date2 = col_func(date2).cast("date")
709
+
710
+ value = months_between(date1, date2)
711
+
712
+ if roundOff:
713
+ return value.cast("bigint")
714
+ return value
715
+
716
+
548
717
  def from_unixtime_from_timestamp(col: ColumnOrName, format: t.Optional[str] = None) -> Column:
549
718
  from sqlframe.base.session import _BaseSession
550
719
 
@@ -590,12 +759,30 @@ def bas64_from_encode(col: ColumnOrLiteral) -> Column:
590
759
  )
591
760
 
592
761
 
762
+ def base64_from_base64_encode(col: ColumnOrLiteral) -> Column:
763
+ return Column(
764
+ expression.Anonymous(
765
+ this="BASE64_ENCODE",
766
+ expressions=[Column(col).expression],
767
+ )
768
+ )
769
+
770
+
593
771
  def unbase64_from_decode(col: ColumnOrLiteral) -> Column:
594
772
  return Column(
595
773
  expression.Decode(this=Column(col).expression, charset=expression.Literal.string("base64"))
596
774
  )
597
775
 
598
776
 
777
+ def unbase64_from_base64_decode_string(col: ColumnOrLiteral) -> Column:
778
+ return Column(
779
+ expression.Anonymous(
780
+ this="BASE64_DECODE_STRING",
781
+ expressions=[Column(col).expression],
782
+ )
783
+ )
784
+
785
+
599
786
  def decode_from_blob(col: ColumnOrLiteral, charset: str) -> Column:
600
787
  return Column(
601
788
  expression.Decode(
@@ -716,6 +903,19 @@ def overlay_from_substr(
716
903
  )
717
904
 
718
905
 
906
+ def levenshtein_edit_distance(
907
+ left: ColumnOrName, right: ColumnOrName, threshold: t.Optional[int] = None
908
+ ) -> Column:
909
+ if threshold is not None:
910
+ logger.warning("Threshold is ignored since it is not supported in this dialect")
911
+ return Column(
912
+ expression.Anonymous(
913
+ this="EDITDISTANCE",
914
+ expressions=[Column.ensure_col(left).expression, Column.ensure_col(right).expression],
915
+ )
916
+ )
917
+
918
+
719
919
  def split_no_limit(str: ColumnOrName, pattern: str, limit: t.Optional[int] = None) -> Column:
720
920
  from sqlframe.base.functions import split
721
921
 
@@ -758,6 +958,17 @@ def split_with_split(str: ColumnOrName, pattern: str, limit: t.Optional[int] = N
758
958
  )
759
959
 
760
960
 
961
+ def regexp_extract_coalesce_empty_str(
962
+ str: ColumnOrName, pattern: str, idx: t.Optional[int] = None
963
+ ) -> Column:
964
+ from sqlframe.base.functions import regexp_extract
965
+
966
+ coalesce = get_func_from_session("coalesce")
967
+ lit_func = get_func_from_session("lit")
968
+
969
+ return coalesce(regexp_extract(str, pattern, idx), lit_func(""))
970
+
971
+
761
972
  def array_contains_any(col: ColumnOrName, value: ColumnOrLiteral) -> Column:
762
973
  lit = get_func_from_session("lit")
763
974
  value_col = value if isinstance(value, Column) else lit(value)
@@ -830,6 +1041,16 @@ def slice_with_brackets(
830
1041
  )
831
1042
 
832
1043
 
1044
+ def array_join_no_null_replacement(
1045
+ col: ColumnOrName, delimiter: str, null_replacement: t.Optional[str] = None
1046
+ ) -> Column:
1047
+ from sqlframe.base.functions import array_join
1048
+
1049
+ if null_replacement is None:
1050
+ logger.warning("Null replacement is ignored since it is not supported in this dialect")
1051
+ return array_join(col, delimiter)
1052
+
1053
+
833
1054
  def array_join_null_replacement_with_transform(
834
1055
  col: ColumnOrName, delimiter: str, null_replacement: t.Optional[str] = None
835
1056
  ) -> Column:
@@ -860,6 +1081,57 @@ def array_join_null_replacement_with_transform(
860
1081
  return array_join(col, delimiter)
861
1082
 
862
1083
 
1084
+ def array_contains_cast_variant(col: ColumnOrName, value: ColumnOrLiteral) -> Column:
1085
+ from sqlframe.base.functions import array_contains
1086
+
1087
+ lit = get_func_from_session("lit")
1088
+ value_col = value if isinstance(value, Column) else lit(value)
1089
+ return array_contains(col, value_col.cast("variant"))
1090
+
1091
+
1092
+ def arrays_overlap_as_plural(col1: ColumnOrName, col2: ColumnOrName) -> Column:
1093
+ col_func = get_func_from_session("col")
1094
+
1095
+ return Column(
1096
+ expression.Anonymous(
1097
+ this="ARRAYS_OVERLAP",
1098
+ expressions=[col_func(col1).expression, col_func(col2).expression],
1099
+ )
1100
+ )
1101
+
1102
+
1103
+ def slice_as_array_slice(
1104
+ x: ColumnOrName, start: t.Union[ColumnOrName, int], length: t.Union[ColumnOrName, int]
1105
+ ) -> Column:
1106
+ lit = get_func_from_session("lit")
1107
+
1108
+ start_col = start if isinstance(start, Column) else lit(start)
1109
+ length_col = length if isinstance(length, Column) else lit(length)
1110
+ return Column.invoke_anonymous_function(
1111
+ x, "ARRAY_SLICE", start_col - lit(1), start_col + length_col
1112
+ )
1113
+
1114
+
1115
+ def array_position_cast_variant_and_flip(col: ColumnOrName, value: ColumnOrLiteral) -> Column:
1116
+ when = get_func_from_session("when")
1117
+ lit = get_func_from_session("lit")
1118
+ value_col = value if isinstance(value, Column) else lit(value)
1119
+ # Some engines return NULL if item is not found but Spark expects 0 so we coalesce to 0
1120
+ resp = Column.invoke_anonymous_function(value_col.cast("variant"), "ARRAY_POSITION", col)
1121
+ return when(resp.isNotNull(), resp + lit(1)).otherwise(lit(0))
1122
+
1123
+
1124
+ def array_intersect_using_intersection(col1: ColumnOrName, col2: ColumnOrName) -> Column:
1125
+ col_func = get_func_from_session("col")
1126
+
1127
+ return Column(
1128
+ expression.Anonymous(
1129
+ this="ARRAY_INTERSECTION",
1130
+ expressions=[col_func(col1).expression, col_func(col2).expression],
1131
+ )
1132
+ )
1133
+
1134
+
863
1135
  def element_at_using_brackets(col: ColumnOrName, value: ColumnOrLiteral) -> Column:
864
1136
  col_func = get_func_from_session("col")
865
1137
  lit = get_func_from_session("lit")
@@ -936,6 +1208,25 @@ def get_json_object_using_arrow_op(col: ColumnOrName, path: str) -> Column:
936
1208
  )
937
1209
 
938
1210
 
1211
+ def get_json_object_cast_object(col: ColumnOrName, path: str) -> Column:
1212
+ from sqlframe.base.functions import get_json_object
1213
+
1214
+ col_func = get_func_from_session("col")
1215
+
1216
+ return get_json_object(col_func(col).cast("variant"), path)
1217
+
1218
+
1219
+ def create_map_with_cast(*cols: t.Union[ColumnOrName, t.Iterable[ColumnOrName]]) -> Column:
1220
+ from sqlframe.base.functions import create_map
1221
+
1222
+ col = get_func_from_session("col")
1223
+
1224
+ columns = list(_flatten(cols)) if not isinstance(cols[0], (str, Column)) else cols
1225
+ col1_dtype = col(columns[0]).dtype or "VARCHAR"
1226
+ col2_dtype = col(columns[1]).dtype or "VARCHAR"
1227
+ return create_map(*cols).cast(f"MAP({col1_dtype}, {col2_dtype})")
1228
+
1229
+
939
1230
  def array_min_from_sort(col: ColumnOrName) -> Column:
940
1231
  element_at = get_func_from_session("element_at")
941
1232
  array_sort = get_func_from_session("array_sort")
@@ -992,6 +1283,43 @@ def array_max_from_subquery(col: ColumnOrName) -> Column:
992
1283
  return Column(expression.Subquery(this=select)).alias(col_func(col).alias_or_name)
993
1284
 
994
1285
 
1286
+ def sort_array_using_array_sort(col: ColumnOrName, asc: t.Optional[bool] = None) -> Column:
1287
+ col_func = get_func_from_session("col")
1288
+ lit_func = get_func_from_session("lit")
1289
+ expressions = [col_func(col).expression]
1290
+ asc = asc if asc is not None else True
1291
+ expressions.append(lit_func(asc).expression)
1292
+ if asc:
1293
+ expressions.append(lit_func(True).expression)
1294
+ else:
1295
+ expressions.append(lit_func(False).expression)
1296
+
1297
+ return Column(
1298
+ expression.Anonymous(
1299
+ this="ARRAY_SORT",
1300
+ expressions=expressions,
1301
+ )
1302
+ )
1303
+
1304
+
1305
+ def flatten_using_array_flatten(col: ColumnOrName) -> Column:
1306
+ col_func = get_func_from_session("col")
1307
+
1308
+ return Column(
1309
+ expression.Anonymous(
1310
+ this="ARRAY_FLATTEN",
1311
+ expressions=[col_func(col).expression],
1312
+ )
1313
+ )
1314
+
1315
+
1316
+ def map_concat_using_map_cat(*cols: t.Union[ColumnOrName, t.Iterable[ColumnOrName]]) -> Column:
1317
+ columns = list(flatten(cols)) if not isinstance(cols[0], (str, Column)) else cols # type: ignore
1318
+ if len(columns) == 1:
1319
+ return Column.invoke_anonymous_function(columns[0], "MAP_CAT")
1320
+ return Column.invoke_anonymous_function(columns[0], "MAP_CAT", *columns[1:])
1321
+
1322
+
995
1323
  def sequence_from_generate_series(
996
1324
  start: ColumnOrName, stop: ColumnOrName, step: t.Optional[ColumnOrName] = None
997
1325
  ) -> Column:
@@ -1026,6 +1354,27 @@ def sequence_from_generate_array(
1026
1354
  )
1027
1355
 
1028
1356
 
1357
+ def sequence_from_array_generate_range(
1358
+ start: ColumnOrName, stop: ColumnOrName, step: t.Optional[ColumnOrName] = None
1359
+ ) -> Column:
1360
+ col_func = get_func_from_session("col")
1361
+ when = get_func_from_session("when")
1362
+ lit = get_func_from_session("lit")
1363
+
1364
+ return Column(
1365
+ expression.Anonymous(
1366
+ this="ARRAY_GENERATE_RANGE",
1367
+ expressions=[
1368
+ col_func(start).expression,
1369
+ (
1370
+ col_func(stop) + when(col_func(stop) > lit(0), lit(1)).otherwise(lit(-1))
1371
+ ).expression,
1372
+ col_func(step).expression if step else lit(1).expression,
1373
+ ],
1374
+ )
1375
+ )
1376
+
1377
+
1029
1378
  def regexp_extract_only_one_group(
1030
1379
  str: ColumnOrName, pattern: str, idx: t.Optional[int] = None
1031
1380
  ) -> Column:
@@ -1048,6 +1397,28 @@ def hex_casted_as_bytes(col: ColumnOrName) -> Column:
1048
1397
  )
1049
1398
 
1050
1399
 
1400
+ def hex_using_encode(col: ColumnOrName) -> Column:
1401
+ col_func = get_func_from_session("col")
1402
+
1403
+ return Column(
1404
+ expression.Anonymous(
1405
+ this="HEX_ENCODE",
1406
+ expressions=[col_func(col).expression],
1407
+ )
1408
+ )
1409
+
1410
+
1411
+ def unhex_hex_decode_str(col: ColumnOrName) -> Column:
1412
+ col_func = get_func_from_session("col")
1413
+
1414
+ return Column(
1415
+ expression.Anonymous(
1416
+ this="HEX_DECODE_STRING",
1417
+ expressions=[col_func(col).expression],
1418
+ )
1419
+ )
1420
+
1421
+
1051
1422
  def bit_length_from_length(col: ColumnOrName) -> Column:
1052
1423
  lit = get_func_from_session("lit")
1053
1424
  col_func = get_func_from_session("col")
@@ -222,7 +222,7 @@ def cot(col: ColumnOrName) -> Column:
222
222
  return Column.invoke_anonymous_function(col, "COT")
223
223
 
224
224
 
225
- @meta(unsupported_engines=["duckdb", "postgres"])
225
+ @meta(unsupported_engines=["duckdb", "postgres", "snowflake"])
226
226
  def csc(col: ColumnOrName) -> Column:
227
227
  return Column.invoke_anonymous_function(col, "CSC")
228
228
 
@@ -281,7 +281,7 @@ def rint(col: ColumnOrName) -> Column:
281
281
  return Column.invoke_anonymous_function(col, "RINT")
282
282
 
283
283
 
284
- @meta(unsupported_engines=["duckdb", "postgres"])
284
+ @meta(unsupported_engines=["duckdb", "postgres", "snowflake"])
285
285
  def sec(col: ColumnOrName) -> Column:
286
286
  return Column.invoke_anonymous_function(col, "SEC")
287
287
 
@@ -407,7 +407,7 @@ def collect_set(col: ColumnOrName) -> Column:
407
407
  return Column.invoke_expression_over_column(col, expression.ArrayUniqueAgg)
408
408
 
409
409
 
410
- @meta(unsupported_engines=["bigquery", "duckdb", "postgres"])
410
+ @meta(unsupported_engines=["bigquery", "duckdb", "postgres", "snowflake"])
411
411
  def hypot(col1: t.Union[ColumnOrName, float], col2: t.Union[ColumnOrName, float]) -> Column:
412
412
  col1_value = lit(col1) if isinstance(col1, (int, float)) else col1
413
413
  col2_value = lit(col2) if isinstance(col2, (int, float)) else col2
@@ -482,7 +482,7 @@ def covar_samp(col1: ColumnOrName, col2: ColumnOrName) -> Column:
482
482
  return Column.invoke_expression_over_column(col1, expression.CovarSamp, expression=col2)
483
483
 
484
484
 
485
- @meta(unsupported_engines=["bigquery", "postgres"])
485
+ @meta(unsupported_engines=["bigquery", "postgres", "snowflake"])
486
486
  def first(col: ColumnOrName, ignorenulls: t.Optional[bool] = None) -> Column:
487
487
  this = Column.invoke_expression_over_column(col, expression.First)
488
488
  if ignorenulls:
@@ -516,7 +516,7 @@ def isnull(col: ColumnOrName) -> Column:
516
516
  return Column.invoke_anonymous_function(col, "ISNULL")
517
517
 
518
518
 
519
- @meta(unsupported_engines=["bigquery", "postgres"])
519
+ @meta(unsupported_engines=["bigquery", "postgres", "snowflake"])
520
520
  def last(col: ColumnOrName, ignorenulls: t.Optional[bool] = None) -> Column:
521
521
  this = Column.invoke_expression_over_column(col, expression.Last)
522
522
  if ignorenulls:
@@ -569,7 +569,7 @@ def rand(seed: t.Optional[int] = None) -> Column:
569
569
  return Column.invoke_expression_over_column(None, expression.Rand)
570
570
 
571
571
 
572
- @meta(unsupported_engines=["bigquery", "duckdb", "postgres"])
572
+ @meta(unsupported_engines=["bigquery", "duckdb", "postgres", "snowflake"])
573
573
  def randn(seed: t.Optional[int] = None) -> Column:
574
574
  if seed is not None:
575
575
  return Column.invoke_expression_over_column(None, expression.Randn, this=lit(seed))
@@ -610,7 +610,7 @@ def shiftright(col: ColumnOrName, numBits: int) -> Column:
610
610
  shiftRight = shiftright
611
611
 
612
612
 
613
- @meta(unsupported_engines=["bigquery", "duckdb", "postgres"])
613
+ @meta(unsupported_engines=["bigquery", "duckdb", "postgres", "snowflake"])
614
614
  def shiftrightunsigned(col: ColumnOrName, numBits: int) -> Column:
615
615
  return Column.invoke_anonymous_function(
616
616
  Column.ensure_col(col).cast("bigint"), "SHIFTRIGHTUNSIGNED", lit(numBits)
@@ -631,7 +631,7 @@ def struct(col: t.Union[ColumnOrName, t.Iterable[ColumnOrName]], *cols: ColumnOr
631
631
  return Column.invoke_expression_over_column(None, expression.Struct, expressions=columns)
632
632
 
633
633
 
634
- @meta(unsupported_engines=["bigquery", "duckdb", "postgres"])
634
+ @meta(unsupported_engines=["bigquery", "duckdb", "postgres", "snowflake"])
635
635
  def conv(col: ColumnOrName, fromBase: int, toBase: int) -> Column:
636
636
  return Column.invoke_anonymous_function(col, "CONV", lit(fromBase), lit(toBase))
637
637
 
@@ -976,7 +976,7 @@ def session_window(timeColumn: ColumnOrName, gapDuration: ColumnOrName) -> Colum
976
976
  return Column.invoke_anonymous_function(timeColumn, "SESSION_WINDOW", gap_duration_column)
977
977
 
978
978
 
979
- @meta(unsupported_engines=["bigquery", "duckdb", "postgres"])
979
+ @meta(unsupported_engines=["bigquery", "duckdb", "postgres", "snowflake"])
980
980
  def crc32(col: ColumnOrName) -> Column:
981
981
  return Column.invoke_anonymous_function(col, "CRC32")
982
982
 
@@ -1002,7 +1002,7 @@ def hash(*cols: ColumnOrName) -> Column:
1002
1002
  return Column.invoke_anonymous_function(cols[0], "HASH", *args)
1003
1003
 
1004
1004
 
1005
- @meta(unsupported_engines=["bigquery", "duckdb", "postgres"])
1005
+ @meta(unsupported_engines=["bigquery", "duckdb", "postgres", "snowflake"])
1006
1006
  def xxhash64(*cols: ColumnOrName) -> Column:
1007
1007
  args = cols[1:] if len(cols) > 1 else []
1008
1008
  return Column.invoke_anonymous_function(cols[0], "XXHASH64", *args)
@@ -1069,14 +1069,14 @@ def concat_ws(sep: str, *cols: ColumnOrName) -> Column:
1069
1069
  )
1070
1070
 
1071
1071
 
1072
- @meta(unsupported_engines="bigquery")
1072
+ @meta(unsupported_engines=["bigquery", "snowflake"])
1073
1073
  def decode(col: ColumnOrName, charset: str) -> Column:
1074
1074
  return Column.invoke_expression_over_column(
1075
1075
  col, expression.Decode, charset=expression.Literal.string(charset)
1076
1076
  )
1077
1077
 
1078
1078
 
1079
- @meta(unsupported_engines="bigquery")
1079
+ @meta(unsupported_engines=["bigquery", "snowflake"])
1080
1080
  def encode(col: ColumnOrName, charset: str) -> Column:
1081
1081
  return Column.invoke_expression_over_column(
1082
1082
  col, expression.Encode, charset=expression.Literal.string(charset)
@@ -1088,7 +1088,7 @@ def format_number(col: ColumnOrName, d: int) -> Column:
1088
1088
  return Column.invoke_anonymous_function(col, "FORMAT_NUMBER", lit(d))
1089
1089
 
1090
1090
 
1091
- @meta()
1091
+ @meta(unsupported_engines="snowflake")
1092
1092
  def format_string(format: str, *cols: ColumnOrName) -> Column:
1093
1093
  format_col = lit(format)
1094
1094
  columns = [Column.ensure_col(x) for x in cols]
@@ -1114,7 +1114,7 @@ def overlay(
1114
1114
  return Column.invoke_anonymous_function(src, "OVERLAY", replace, pos_value)
1115
1115
 
1116
1116
 
1117
- @meta(unsupported_engines=["bigquery", "duckdb", "postgres"])
1117
+ @meta(unsupported_engines=["bigquery", "duckdb", "postgres", "snowflake"])
1118
1118
  def sentences(
1119
1119
  string: ColumnOrName,
1120
1120
  language: t.Optional[ColumnOrName] = None,
@@ -1134,7 +1134,7 @@ def substring(str: ColumnOrName, pos: int, len: int) -> Column:
1134
1134
  return Column.ensure_col(str).substr(pos, len)
1135
1135
 
1136
1136
 
1137
- @meta(unsupported_engines=["duckdb", "postgres"])
1137
+ @meta(unsupported_engines=["duckdb", "postgres", "snowflake"])
1138
1138
  def substring_index(str: ColumnOrName, delim: str, count: int) -> Column:
1139
1139
  return Column.invoke_anonymous_function(str, "SUBSTRING_INDEX", lit(delim), lit(count))
1140
1140
 
@@ -1236,7 +1236,7 @@ def soundex(col: ColumnOrName) -> Column:
1236
1236
  return Column.invoke_anonymous_function(col, "SOUNDEX")
1237
1237
 
1238
1238
 
1239
- @meta(unsupported_engines="postgres")
1239
+ @meta(unsupported_engines=["postgres", "snowflake"])
1240
1240
  def bin(col: ColumnOrName) -> Column:
1241
1241
  return Column.invoke_anonymous_function(col, "BIN")
1242
1242
 
@@ -1288,7 +1288,7 @@ def create_map(*cols: t.Union[ColumnOrName, t.Iterable[ColumnOrName]]) -> Column
1288
1288
  )
1289
1289
 
1290
1290
 
1291
- @meta(unsupported_engines=["bigquery", "postgres"])
1291
+ @meta(unsupported_engines=["bigquery", "postgres", "snowflake"])
1292
1292
  def map_from_arrays(col1: ColumnOrName, col2: ColumnOrName) -> Column:
1293
1293
  return Column.invoke_expression_over_column(None, expression.Map, keys=col1, values=col2)
1294
1294
 
@@ -1382,27 +1382,28 @@ def posexplode(col: ColumnOrName) -> Column:
1382
1382
  return Column.invoke_expression_over_column(col, expression.Posexplode)
1383
1383
 
1384
1384
 
1385
- @meta(unsupported_engines=["duckdb", "postgres"])
1385
+ @meta(unsupported_engines=["duckdb", "postgres", "snowflake"])
1386
1386
  def explode_outer(col: ColumnOrName) -> Column:
1387
1387
  return Column.invoke_expression_over_column(col, expression.ExplodeOuter)
1388
1388
 
1389
1389
 
1390
- @meta(unsupported_engines=["duckdb", "postgres"])
1390
+ @meta(unsupported_engines=["duckdb", "postgres", "snowflake"])
1391
1391
  def posexplode_outer(col: ColumnOrName) -> Column:
1392
1392
  return Column.invoke_expression_over_column(col, expression.PosexplodeOuter)
1393
1393
 
1394
1394
 
1395
- @meta()
1395
+ # Snowflake doesn't support JSONPath which is what this function uses
1396
+ @meta(unsupported_engines="snowflake")
1396
1397
  def get_json_object(col: ColumnOrName, path: str) -> Column:
1397
1398
  return Column.invoke_expression_over_column(col, expression.JSONExtract, expression=lit(path))
1398
1399
 
1399
1400
 
1400
- @meta(unsupported_engines=["bigquery", "duckdb", "postgres"])
1401
+ @meta(unsupported_engines=["bigquery", "duckdb", "postgres", "snowflake"])
1401
1402
  def json_tuple(col: ColumnOrName, *fields: str) -> Column:
1402
1403
  return Column.invoke_anonymous_function(col, "JSON_TUPLE", *[lit(field) for field in fields])
1403
1404
 
1404
1405
 
1405
- @meta(unsupported_engines=["bigquery", "duckdb", "postgres"])
1406
+ @meta(unsupported_engines=["bigquery", "duckdb", "postgres", "snowflake"])
1406
1407
  def from_json(
1407
1408
  col: ColumnOrName,
1408
1409
  schema: t.Union[ArrayType, StructType, Column, str],
@@ -1419,7 +1420,7 @@ def from_json(
1419
1420
  return Column.invoke_anonymous_function(col, "FROM_JSON", schema)
1420
1421
 
1421
1422
 
1422
- @meta(unsupported_engines=["bigquery", "duckdb", "postgres"])
1423
+ @meta(unsupported_engines=["bigquery", "duckdb", "postgres", "snowflake"])
1423
1424
  def to_json(col: ColumnOrName, options: t.Optional[t.Dict[str, str]] = None) -> Column:
1424
1425
  if options is not None:
1425
1426
  options_col = create_map([lit(x) for x in _flatten(options.items())])
@@ -1443,7 +1444,7 @@ def schema_of_csv(col: ColumnOrName, options: t.Optional[t.Dict[str, str]] = Non
1443
1444
  return Column.invoke_anonymous_function(col, "SCHEMA_OF_CSV")
1444
1445
 
1445
1446
 
1446
- @meta(unsupported_engines=["bigquery", "duckdb", "postgres"])
1447
+ @meta(unsupported_engines=["bigquery", "duckdb", "postgres", "snowflake"])
1447
1448
  def to_csv(col: ColumnOrName, options: t.Optional[t.Dict[str, str]] = None) -> Column:
1448
1449
  if options is not None:
1449
1450
  options_col = create_map([lit(x) for x in _flatten(options.items())])
@@ -1486,12 +1487,12 @@ def array_sort(
1486
1487
  return Column.invoke_expression_over_column(col, expression.ArraySort)
1487
1488
 
1488
1489
 
1489
- @meta(unsupported_engines=["bigquery", "duckdb", "postgres"])
1490
+ @meta(unsupported_engines=["bigquery", "duckdb", "postgres", "snowflake"])
1490
1491
  def shuffle(col: ColumnOrName) -> Column:
1491
1492
  return Column.invoke_anonymous_function(col, "SHUFFLE")
1492
1493
 
1493
1494
 
1494
- @meta()
1495
+ @meta(unsupported_engines="snowflake")
1495
1496
  def reverse(col: ColumnOrName) -> Column:
1496
1497
  return Column.invoke_anonymous_function(col, "REVERSE")
1497
1498
 
@@ -1506,28 +1507,28 @@ def map_keys(col: ColumnOrName) -> Column:
1506
1507
  return Column.invoke_anonymous_function(col, "MAP_KEYS")
1507
1508
 
1508
1509
 
1509
- @meta(unsupported_engines=["bigquery", "duckdb", "postgres"])
1510
+ @meta(unsupported_engines=["bigquery", "duckdb", "postgres", "snowflake"])
1510
1511
  def map_values(col: ColumnOrName) -> Column:
1511
1512
  return Column.invoke_anonymous_function(col, "MAP_VALUES")
1512
1513
 
1513
1514
 
1514
- @meta(unsupported_engines=["bigquery", "duckdb", "postgres"])
1515
+ @meta(unsupported_engines=["bigquery", "duckdb", "postgres", "snowflake"])
1515
1516
  def map_entries(col: ColumnOrName) -> Column:
1516
1517
  return Column.invoke_anonymous_function(col, "MAP_ENTRIES")
1517
1518
 
1518
1519
 
1519
- @meta(unsupported_engines=["bigquery", "duckdb", "postgres"])
1520
+ @meta(unsupported_engines=["bigquery", "duckdb", "postgres", "snowflake"])
1520
1521
  def map_from_entries(col: ColumnOrName) -> Column:
1521
1522
  return Column.invoke_expression_over_column(col, expression.MapFromEntries)
1522
1523
 
1523
1524
 
1524
- @meta(unsupported_engines=["bigquery", "duckdb", "postgres"])
1525
+ @meta(unsupported_engines=["bigquery", "duckdb", "postgres", "snowflake"])
1525
1526
  def array_repeat(col: ColumnOrName, count: t.Union[ColumnOrName, int]) -> Column:
1526
1527
  count_col = count if isinstance(count, Column) else lit(count)
1527
1528
  return Column.invoke_anonymous_function(col, "ARRAY_REPEAT", count_col)
1528
1529
 
1529
1530
 
1530
- @meta(unsupported_engines=["bigquery", "duckdb", "postgres"])
1531
+ @meta(unsupported_engines=["bigquery", "duckdb", "postgres", "snowflake"])
1531
1532
  def arrays_zip(*cols: ColumnOrName) -> Column:
1532
1533
  if len(cols) == 1:
1533
1534
  return Column.invoke_anonymous_function(cols[0], "ARRAYS_ZIP")
@@ -1551,7 +1552,7 @@ def sequence(
1551
1552
  return Column.invoke_anonymous_function(start, "SEQUENCE", stop)
1552
1553
 
1553
1554
 
1554
- @meta(unsupported_engines=["bigquery", "duckdb", "postgres"])
1555
+ @meta(unsupported_engines=["bigquery", "duckdb", "postgres", "snowflake"])
1555
1556
  def from_csv(
1556
1557
  col: ColumnOrName,
1557
1558
  schema: t.Union[Column, str],
@@ -1564,7 +1565,7 @@ def from_csv(
1564
1565
  return Column.invoke_anonymous_function(col, "FROM_CSV", schema)
1565
1566
 
1566
1567
 
1567
- @meta(unsupported_engines=["bigquery", "duckdb", "postgres"])
1568
+ @meta(unsupported_engines=["bigquery", "duckdb", "postgres", "snowflake"])
1568
1569
  def aggregate(
1569
1570
  col: ColumnOrName,
1570
1571
  initialValue: ColumnOrName,
@@ -1586,7 +1587,7 @@ def aggregate(
1586
1587
  )
1587
1588
 
1588
1589
 
1589
- @meta(unsupported_engines=["bigquery", "duckdb", "postgres"])
1590
+ @meta(unsupported_engines=["bigquery", "duckdb", "postgres", "snowflake"])
1590
1591
  def transform(
1591
1592
  col: ColumnOrName,
1592
1593
  f: t.Union[t.Callable[[Column], Column], t.Callable[[Column, Column], Column]],
@@ -1597,19 +1598,19 @@ def transform(
1597
1598
  )
1598
1599
 
1599
1600
 
1600
- @meta(unsupported_engines=["bigquery", "duckdb", "postgres"])
1601
+ @meta(unsupported_engines=["bigquery", "duckdb", "postgres", "snowflake"])
1601
1602
  def exists(col: ColumnOrName, f: t.Callable[[Column], Column]) -> Column:
1602
1603
  f_expression = _get_lambda_from_func(f)
1603
1604
  return Column.invoke_anonymous_function(col, "EXISTS", Column(f_expression))
1604
1605
 
1605
1606
 
1606
- @meta(unsupported_engines=["bigquery", "duckdb", "postgres"])
1607
+ @meta(unsupported_engines=["bigquery", "duckdb", "postgres", "snowflake"])
1607
1608
  def forall(col: ColumnOrName, f: t.Callable[[Column], Column]) -> Column:
1608
1609
  f_expression = _get_lambda_from_func(f)
1609
1610
  return Column.invoke_anonymous_function(col, "FORALL", Column(f_expression))
1610
1611
 
1611
1612
 
1612
- @meta(unsupported_engines=["bigquery", "duckdb", "postgres"])
1613
+ @meta(unsupported_engines=["bigquery", "duckdb", "postgres", "snowflake"])
1613
1614
  def filter(
1614
1615
  col: ColumnOrName,
1615
1616
  f: t.Union[t.Callable[[Column], Column], t.Callable[[Column, Column], Column]],
@@ -1620,7 +1621,7 @@ def filter(
1620
1621
  )
1621
1622
 
1622
1623
 
1623
- @meta(unsupported_engines=["bigquery", "duckdb", "postgres"])
1624
+ @meta(unsupported_engines=["bigquery", "duckdb", "postgres", "snowflake"])
1624
1625
  def zip_with(
1625
1626
  left: ColumnOrName, right: ColumnOrName, f: t.Callable[[Column, Column], Column]
1626
1627
  ) -> Column:
@@ -1628,25 +1629,25 @@ def zip_with(
1628
1629
  return Column.invoke_anonymous_function(left, "ZIP_WITH", right, Column(f_expression))
1629
1630
 
1630
1631
 
1631
- @meta(unsupported_engines=["bigquery", "duckdb", "postgres"])
1632
+ @meta(unsupported_engines=["bigquery", "duckdb", "postgres", "snowflake"])
1632
1633
  def transform_keys(col: ColumnOrName, f: t.Union[t.Callable[[Column, Column], Column]]) -> Column:
1633
1634
  f_expression = _get_lambda_from_func(f)
1634
1635
  return Column.invoke_anonymous_function(col, "TRANSFORM_KEYS", Column(f_expression))
1635
1636
 
1636
1637
 
1637
- @meta(unsupported_engines=["bigquery", "duckdb", "postgres"])
1638
+ @meta(unsupported_engines=["bigquery", "duckdb", "postgres", "snowflake"])
1638
1639
  def transform_values(col: ColumnOrName, f: t.Union[t.Callable[[Column, Column], Column]]) -> Column:
1639
1640
  f_expression = _get_lambda_from_func(f)
1640
1641
  return Column.invoke_anonymous_function(col, "TRANSFORM_VALUES", Column(f_expression))
1641
1642
 
1642
1643
 
1643
- @meta(unsupported_engines=["bigquery", "duckdb", "postgres"])
1644
+ @meta(unsupported_engines=["bigquery", "duckdb", "postgres", "snowflake"])
1644
1645
  def map_filter(col: ColumnOrName, f: t.Union[t.Callable[[Column, Column], Column]]) -> Column:
1645
1646
  f_expression = _get_lambda_from_func(f)
1646
1647
  return Column.invoke_anonymous_function(col, "MAP_FILTER", Column(f_expression))
1647
1648
 
1648
1649
 
1649
- @meta(unsupported_engines=["bigquery", "duckdb", "postgres"])
1650
+ @meta(unsupported_engines=["bigquery", "duckdb", "postgres", "snowflake"])
1650
1651
  def map_zip_with(
1651
1652
  col1: ColumnOrName,
1652
1653
  col2: ColumnOrName,
@@ -1656,11 +1657,16 @@ def map_zip_with(
1656
1657
  return Column.invoke_anonymous_function(col1, "MAP_ZIP_WITH", col2, Column(f_expression))
1657
1658
 
1658
1659
 
1659
- @meta(unsupported_engines="postgres")
1660
+ @meta(unsupported_engines=["postgres", "snowflake"])
1660
1661
  def typeof(col: ColumnOrName) -> Column:
1661
1662
  return Column.invoke_anonymous_function(col, "TYPEOF")
1662
1663
 
1663
1664
 
1665
+ @meta()
1666
+ def nullif(col1: ColumnOrName, col2: ColumnOrName) -> Column:
1667
+ return Column.invoke_expression_over_column(col1, expression.Nullif, expression=col2)
1668
+
1669
+
1664
1670
  @meta()
1665
1671
  def _lambda_quoted(value: str) -> t.Optional[bool]:
1666
1672
  return False if value == "_" else None
@@ -65,6 +65,15 @@ def replace_branch_and_sequence_ids_with_cte_name(
65
65
  return
66
66
 
67
67
 
68
+ def normalize_dict(session: SESSION, data: t.Dict) -> t.Dict:
69
+ if isinstance(data, dict):
70
+ return {session._normalize_string(k): normalize_dict(session, v) for k, v in data.items()}
71
+ elif isinstance(data, list):
72
+ return [normalize_dict(session, v) for v in data]
73
+ else:
74
+ return data
75
+
76
+
68
77
  def _set_alias_name(id: exp.Identifier, name: str):
69
78
  id.set("this", name)
70
79
  id.set("quoted", False)
sqlframe/base/session.py CHANGED
@@ -23,6 +23,7 @@ from sqlglot.schema import MappingSchema
23
23
 
24
24
  from sqlframe.base.catalog import _BaseCatalog
25
25
  from sqlframe.base.dataframe import _BaseDataFrame
26
+ from sqlframe.base.normalize import normalize_dict
26
27
  from sqlframe.base.readerwriter import _BaseDataFrameReader, _BaseDataFrameWriter
27
28
  from sqlframe.base.util import (
28
29
  get_column_mapping_from_schema_input,
@@ -257,6 +258,7 @@ class _BaseSession(t.Generic[CATALOG, READER, WRITER, DF, CONN]):
257
258
  if isinstance(sample_row, Row):
258
259
  sample_row = sample_row.asDict()
259
260
  if isinstance(sample_row, dict):
261
+ sample_row = normalize_dict(self, sample_row)
260
262
  default_data_type = get_default_data_type(sample_row[name])
261
263
  updated_mapping[name] = (
262
264
  exp.DataType.build(default_data_type, dialect="spark")
@@ -15,4 +15,49 @@ globals().update(
15
15
  )
16
16
 
17
17
 
18
- from sqlframe.base.function_alternatives import e_literal as e # noqa
18
+ from sqlframe.base.function_alternatives import ( # noqa
19
+ e_literal as e,
20
+ expm1_from_exp as expm1,
21
+ log1p_from_log as log1p,
22
+ rint_from_round as rint,
23
+ bitwise_not_from_bitnot as bitwise_not,
24
+ skewness_from_skew as skewness,
25
+ isnan_using_equal as isnan,
26
+ isnull_using_equal as isnull,
27
+ nanvl_as_case as nanvl,
28
+ percentile_approx_without_accuracy_and_max_array as percentile_approx,
29
+ bround_using_half_even as bround,
30
+ shiftleft_from_bitshiftleft as shiftleft,
31
+ shiftright_from_bitshiftright as shiftright,
32
+ struct_with_eq as struct,
33
+ make_date_date_from_parts as make_date,
34
+ date_add_no_date_sub as date_add,
35
+ date_sub_by_date_add as date_sub,
36
+ add_months_using_func as add_months,
37
+ months_between_cast_as_date_cast_roundoff as months_between,
38
+ last_day_with_cast as last_day,
39
+ from_unixtime_from_timestamp as from_unixtime,
40
+ unix_timestamp_from_extract as unix_timestamp,
41
+ base64_from_base64_encode as base64,
42
+ unbase64_from_base64_decode_string as unbase64,
43
+ format_number_from_to_char as format_number,
44
+ overlay_from_substr as overlay,
45
+ levenshtein_edit_distance as levenshtein,
46
+ split_with_split as split,
47
+ regexp_extract_coalesce_empty_str as regexp_extract,
48
+ hex_using_encode as hex,
49
+ unhex_hex_decode_str as unhex,
50
+ create_map_with_cast as create_map,
51
+ array_contains_cast_variant as array_contains,
52
+ arrays_overlap_as_plural as arrays_overlap,
53
+ slice_as_array_slice as slice,
54
+ array_join_no_null_replacement as array_join,
55
+ array_position_cast_variant_and_flip as array_position,
56
+ element_at_using_brackets as element_at,
57
+ array_intersect_using_intersection as array_intersect,
58
+ array_union_using_array_concat as array_union,
59
+ sort_array_using_array_sort as sort_array,
60
+ flatten_using_array_flatten as flatten,
61
+ map_concat_using_map_cat as map_concat,
62
+ sequence_from_array_generate_range as sequence,
63
+ )
@@ -24,9 +24,17 @@ else:
24
24
 
25
25
 
26
26
  class JsonLoadsSnowflakeConverter(SnowflakeConverter):
27
+ # This might not be needed once proper arrow types are supported.
28
+ # Checkout this PR: https://github.com/snowflakedb/snowflake-connector-python/pull/1853/files
29
+ # Specifically see if `alter session set enable_structured_types_in_client_response=true` and
30
+ # `alter session set force_enable_structured_types_native_arrow_format=true` are supported then it might work.
31
+ # At the time of writing these were not supported parameters on my version on Snowflake.
27
32
  def _json_loads(self, ctx: dict[str, t.Any]) -> t.Callable:
28
33
  def conv(value: str) -> t.List:
29
- return json.loads(value)
34
+ # Snowflake returns "undefined" for null values when inside an array
35
+ # We check if we replaced "'undefined'" string and if so we switch it back
36
+ # this is a lazy approach compared to writing a proper regex replace
37
+ return json.loads(value.replace("undefined", "null").replace("'null'", "undefined"))
30
38
 
31
39
  return conv
32
40
 
@@ -49,10 +57,9 @@ class SnowflakeSession(
49
57
  _writer = SnowflakeDataFrameWriter
50
58
  _df = SnowflakeDataFrame
51
59
 
60
+ DEFAULT_TIME_FORMAT = "YYYY-MM-DD HH:MI:SS"
61
+
52
62
  def __init__(self, conn: t.Optional[SnowflakeConnection] = None):
53
- warnings.warn(
54
- "SnowflakeSession is still in active development. Functions may not work as expected."
55
- )
56
63
  import snowflake
57
64
 
58
65
  snowflake.connector.cursor.CAN_USE_ARROW_RESULT_FORMAT = False
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: sqlframe
3
- Version: 1.4.0
3
+ Version: 1.5.1
4
4
  Summary: Taking the Spark out of PySpark by converting to SQL
5
5
  Home-page: https://github.com/eakmanrq/sqlframe
6
6
  Author: Ryan Eakman
@@ -18,7 +18,7 @@ Requires-Python: >=3.8
18
18
  Description-Content-Type: text/markdown
19
19
  License-File: LICENSE
20
20
  Requires-Dist: prettytable (<3.11.0)
21
- Requires-Dist: sqlglot (<24.1,>=24.0.0)
21
+ Requires-Dist: sqlglot (<24.2,>=24.0.0)
22
22
  Provides-Extra: bigquery
23
23
  Requires-Dist: google-cloud-bigquery-storage (<3,>=2) ; extra == 'bigquery'
24
24
  Requires-Dist: google-cloud-bigquery[pandas] (<4,>=3) ; extra == 'bigquery'
@@ -72,6 +72,7 @@ SQLFrame currently supports the following engines (many more in development):
72
72
  * [BigQuery](https://sqlframe.readthedocs.io/en/stable/bigquery/)
73
73
  * [DuckDB](https://sqlframe.readthedocs.io/en/stable/duckdb)
74
74
  * [Postgres](https://sqlframe.readthedocs.io/en/stable/postgres)
75
+ * [Snowflake](https://sqlframe.readthedocs.io/en/stable/snowflake)
75
76
 
76
77
  SQLFrame also has a "Standalone" session that be used to generate SQL without any connection to a database engine.
77
78
  * [Standalone](https://sqlframe.readthedocs.io/en/stable/standalone)
@@ -91,6 +92,8 @@ pip install "sqlframe[bigquery]"
91
92
  pip install "sqlframe[duckdb]"
92
93
  # Postgres
93
94
  pip install "sqlframe[postgres]"
95
+ # Snowflake
96
+ pip install "sqlframe[snowflake]"
94
97
  # Standalone
95
98
  pip install sqlframe
96
99
  ```
@@ -1,19 +1,19 @@
1
1
  sqlframe/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
2
- sqlframe/_version.py,sha256=R8-T9fmURjcuoxYpHTAjyNAhgJPDtI2jogCjqYYkfCU,411
2
+ sqlframe/_version.py,sha256=W6YuN1JOd6M-rSt9HDXK91AutRDYXTjJT_LQg3rCsjk,411
3
3
  sqlframe/base/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
4
4
  sqlframe/base/_typing.py,sha256=DuTay8-o9W-pw3RPZCgLunKNJLS9PkaV11G_pxXp9NY,1256
5
5
  sqlframe/base/catalog.py,sha256=ATDGirouUjal05P4ymL-wIi8rgjg_8w4PoACamiO64A,37245
6
- sqlframe/base/column.py,sha256=p3VrtATBmjAYHollFcsdps2UJTNC-Pvyg4Zt7y4CK9w,15358
6
+ sqlframe/base/column.py,sha256=kNp8Hfeozgs-OuZBky8-kvRq5uuRMAkkiLpb3fs4XUE,15575
7
7
  sqlframe/base/dataframe.py,sha256=9PuqC9dBficSE-Y1v_BHyk4gK-Hd43SaVBmxBeyNnD8,62939
8
8
  sqlframe/base/decorators.py,sha256=I5osMgx9BuCgbtp4jVM2DNwYJVLzCv-OtTedhQEik0g,1882
9
9
  sqlframe/base/exceptions.py,sha256=pCB9hXX4jxZWzNg3JN1i38cv3BmpUlee5NoLYx3YXIQ,208
10
- sqlframe/base/function_alternatives.py,sha256=to0kv3MTJmQFeVTMcitz0AxBIoUJC3cu5LkEY5aJpoo,31318
11
- sqlframe/base/functions.py,sha256=iVe8AbXGX_gXnkQ1N-clX6rihsonfzJ84_YvWzhB2FM,53540
10
+ sqlframe/base/function_alternatives.py,sha256=NDXs2igY7PBsStzTSRZvJcCshBOJkPQl2GbhpVFU6To,42931
11
+ sqlframe/base/functions.py,sha256=QgVMWnZFClxfbiOV4CpILtOtdo7-Ey5wWTehdGy0qkA,54393
12
12
  sqlframe/base/group.py,sha256=TES9CleVmH3x-0X-tqmuUKfCKSWjH5vg1aU3R6dDmFc,4059
13
- sqlframe/base/normalize.py,sha256=gRBn-PziFdE-CHtPJMkMl7y_YH0mauUcD4zfgyyvlpw,3565
13
+ sqlframe/base/normalize.py,sha256=nXAJ5CwxVf4DV0GsH-q1w0p8gmjSMlv96k_ez1eVul8,3880
14
14
  sqlframe/base/operations.py,sha256=-AhNuEzcV7ZExoP1oY3blaKip-joQyJeQVvfBTs_2g4,3456
15
15
  sqlframe/base/readerwriter.py,sha256=5NPQMiOrw6I54U243R_6-ynnWYsNksgqwRpPp4IFjIw,25288
16
- sqlframe/base/session.py,sha256=-h7qcOPRw9KBJPg_V6Tlr8Z2SmcsgAWruBo34o6zfrQ,21795
16
+ sqlframe/base/session.py,sha256=EhQb2oVRApU_xcbdVc5m0WrwKAbDjcUCRLAVzcg_JH8,21908
17
17
  sqlframe/base/transforms.py,sha256=y0j3SGDz3XCmNGrvassk1S-owllUWfkHyMgZlY6SFO4,467
18
18
  sqlframe/base/types.py,sha256=aJT5YXr-M_LAfUM0uK4asfbrQFab_xmsp1CP2zkG8p0,11924
19
19
  sqlframe/base/util.py,sha256=wdATi7STt-FfXrX9TPRkw4PFJP7uAsK_K9YkKSrd0qU,8824
@@ -66,10 +66,10 @@ sqlframe/snowflake/__init__.py,sha256=nuQ3cuHjDpW4ELZfbd2qOYmtXmcYl7MtsrdOrRdozo
66
66
  sqlframe/snowflake/catalog.py,sha256=uDjBgDdCyxaDkGNX_8tb-lol7MwwazcClUBAZsOSj70,5014
67
67
  sqlframe/snowflake/column.py,sha256=E1tUa62Y5HajkhgFuebU9zohrGyieudcHzTT8gfalio,40
68
68
  sqlframe/snowflake/dataframe.py,sha256=OJ27NudBUE3XX9mc8ywooGhYV4ijF9nX2K_nkHRcTx4,1393
69
- sqlframe/snowflake/functions.py,sha256=ZYX9gyPvmpKoLi_7uQdB0uPQNTREOAJD0aCcccX1iPc,456
69
+ sqlframe/snowflake/functions.py,sha256=HXxt-wM05vcbgmu06uGApGd-Z9bWOwWwjqPfg38fF0M,2330
70
70
  sqlframe/snowflake/group.py,sha256=pPP1l2RRo_LgkXrji8a87n2PKo-63ZRPT-WUtvVcBME,395
71
71
  sqlframe/snowflake/readwriter.py,sha256=yhRc2HcMq6PwV3ghZWC-q-qaE7LE4aEjZEXCip4OOlQ,884
72
- sqlframe/snowflake/session.py,sha256=QKdxXgK9_YgxoyxzEd73ot4t0M6Dz4em09JdVMYxVPI,2584
72
+ sqlframe/snowflake/session.py,sha256=bDOlnuIiQ9j_zfF7F5H1gTLmpHUjruIxr2CfXcS_7YU,3284
73
73
  sqlframe/snowflake/types.py,sha256=KwNyuXIo-2xVVd4bZED3YrQOobKCtemlxGrJL7DrTC8,34
74
74
  sqlframe/snowflake/window.py,sha256=6GKPzuxeSapJakBaKBeT9VpED1ACdjggDv9JRILDyV0,35
75
75
  sqlframe/spark/__init__.py,sha256=jamKYQtQaKjjXnQ01QGPHvatbrZSw9sWno_VOUGSz6I,712
@@ -92,8 +92,8 @@ sqlframe/standalone/readwriter.py,sha256=EZNyDJ4ID6sGNog3uP4-e9RvchX4biJJDNtc5hk
92
92
  sqlframe/standalone/session.py,sha256=wQmdu2sv6KMTAv0LRFk7TY7yzlh3xvmsyqilEtRecbY,1191
93
93
  sqlframe/standalone/types.py,sha256=KwNyuXIo-2xVVd4bZED3YrQOobKCtemlxGrJL7DrTC8,34
94
94
  sqlframe/standalone/window.py,sha256=6GKPzuxeSapJakBaKBeT9VpED1ACdjggDv9JRILDyV0,35
95
- sqlframe-1.4.0.dist-info/LICENSE,sha256=VZu79YgW780qxaFJMr0t5ZgbOYEh04xWoxaWOaqIGWk,1068
96
- sqlframe-1.4.0.dist-info/METADATA,sha256=nnz73ML6w8WyctFzwiaKVVNr9RQwmpmfckrcKqEX_PE,7219
97
- sqlframe-1.4.0.dist-info/WHEEL,sha256=G16H4A3IeoQmnOrYV4ueZGKSjhipXx8zc8nu9FGlvMA,92
98
- sqlframe-1.4.0.dist-info/top_level.txt,sha256=T0_RpoygaZSF6heeWwIDQgaP0varUdSK1pzjeJZRjM8,9
99
- sqlframe-1.4.0.dist-info/RECORD,,
95
+ sqlframe-1.5.1.dist-info/LICENSE,sha256=VZu79YgW780qxaFJMr0t5ZgbOYEh04xWoxaWOaqIGWk,1068
96
+ sqlframe-1.5.1.dist-info/METADATA,sha256=dHIN2Vu-zGlN7rNFnbeJt3oy_ZajB9Y63pQgmi-NK_w,7332
97
+ sqlframe-1.5.1.dist-info/WHEEL,sha256=G16H4A3IeoQmnOrYV4ueZGKSjhipXx8zc8nu9FGlvMA,92
98
+ sqlframe-1.5.1.dist-info/top_level.txt,sha256=T0_RpoygaZSF6heeWwIDQgaP0varUdSK1pzjeJZRjM8,9
99
+ sqlframe-1.5.1.dist-info/RECORD,,