sqlframe 3.10.0__py3-none-any.whl → 3.11.0__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only.
@@ -21,12 +21,18 @@ if t.TYPE_CHECKING:
  from pyspark.sql.session import SparkContext

  from sqlframe.base._typing import ColumnOrLiteral, ColumnOrName
- from sqlframe.base.session import DF
+ from sqlframe.base.session import DF, _BaseSession

  from sqlframe.base.types import ArrayType, StructType

  logger = logging.getLogger(__name__)


+ def _get_session() -> _BaseSession:
+ from sqlframe.base.session import _BaseSession
+
+ return _BaseSession()
+
+
  @meta()
  def col(column_name: t.Union[ColumnOrName, t.Any]) -> Column:
  from sqlframe.base.session import _BaseSession
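Note: the new module-level `_get_session()` helper underpins most of the hunks below (the file appears to be `sqlframe/base/functions.py`). The `_BaseSession` import is deferred to call time to avoid a circular import, and constructing `_BaseSession()` hands back the currently active session, so each function can branch on engine flags such as `_is_bigquery` or `_is_snowflake` and fall back to an engine-specific rewrite from `sqlframe.base.function_alternatives`. A simplified sketch of the idea (illustrative only, not sqlframe source):

    class _BaseSession:
        _instance = None
        _is_bigquery = False
        _is_snowflake = False

        def __new__(cls):
            # assumption: the session behaves like a singleton, so re-constructing
            # it returns the session the user already created
            if cls._instance is None:
                cls._instance = super().__new__(cls)
            return cls._instance

    def _get_session() -> "_BaseSession":
        # the real helper imports _BaseSession inside its body, then constructs it
        return _BaseSession()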
@@ -68,7 +74,9 @@ def least(*cols: ColumnOrName) -> Column:
  def count_distinct(col: ColumnOrName, *cols: ColumnOrName) -> Column:
  columns = [Column.ensure_col(x) for x in [col] + list(cols)]
  return Column(
- expression.Count(this=expression.Distinct(expressions=[x.expression for x in columns]))
+ expression.Count(
+ this=expression.Distinct(expressions=[x.column_expression for x in columns])
+ )
  )


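Note: the other change that repeats throughout this diff is swapping `Column.expression` for `Column.column_expression` when building sqlglot nodes; judging by the property name, this uses the underlying column expression rather than any alias wrapper, so aliases do not leak into aggregates such as `COUNT(DISTINCT ...)` (inferred from the code, not confirmed by the sqlframe docs). Caller-facing behaviour is unchanged, e.g. (assuming the usual sqlframe entry points):

    from sqlframe.duckdb import DuckDBSession
    from sqlframe.duckdb import functions as F

    session = DuckDBSession()
    df = session.createDataFrame([(1, "a"), (1, "b"), (2, "b")], ["id", "val"])
    df.select(F.count_distinct(F.col("id")).alias("ids")).show()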
@@ -151,7 +159,9 @@ mean = avg
  @meta()
  def sumDistinct(col: ColumnOrName) -> Column:
  return Column(
- expression.Sum(this=expression.Distinct(expressions=[Column.ensure_col(col).expression]))
+ expression.Sum(
+ this=expression.Distinct(expressions=[Column.ensure_col(col).column_expression])
+ )
  )


@@ -237,6 +247,19 @@ def csc(col: ColumnOrName) -> Column:

  @meta()
  def e() -> Column:
+ from sqlframe.base.function_alternatives import e_literal
+
+ session = _get_session()
+
+ if (
+ session._is_bigquery
+ or session._is_duckdb
+ or session._is_postgres
+ or session._is_redshift
+ or session._is_snowflake
+ ):
+ return e_literal()
+
  return Column(expression.Anonymous(this="e"))


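Note: `e()` now returns a plain literal on engines without a native `E()` function. The alternative lives in `sqlframe.base.function_alternatives`; a plausible shape for it (a guess, not the actual source) is simply:

    import math

    def e_literal() -> Column:
        # emit the constant 2.71828... instead of calling an engine function
        return lit(math.e)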
@@ -247,11 +270,31 @@ def exp(col: ColumnOrName) -> Column:

  @meta()
  def expm1(col: ColumnOrName) -> Column:
+ from sqlframe.base.function_alternatives import expm1_from_exp
+
+ session = _get_session()
+
+ if session._is_bigquery or session._is_duckdb or session._is_postgres or session._is_snowflake:
+ return expm1_from_exp(col)
+
  return Column.invoke_anonymous_function(col, "EXPM1")


  @meta()
  def factorial(col: ColumnOrName) -> Column:
+ from sqlframe.base.function_alternatives import (
+ factorial_ensure_int,
+ factorial_from_case_statement,
+ )
+
+ session = _get_session()
+
+ if session._is_duckdb:
+ return factorial_ensure_int(col)
+
+ if session._is_bigquery:
+ return factorial_from_case_statement(col)
+
  return Column.invoke_anonymous_function(col, "FACTORIAL")


@@ -267,6 +310,13 @@ def log10(col: ColumnOrName) -> Column:

  @meta()
  def log1p(col: ColumnOrName) -> Column:
+ from sqlframe.base.function_alternatives import log1p_from_log
+
+ session = _get_session()
+
+ if session._is_bigquery or session._is_duckdb or session._is_postgres or session._is_snowflake:
+ return log1p_from_log(col)
+
  return Column.invoke_anonymous_function(col, "LOG1P")


@@ -286,6 +336,13 @@ def log(arg1: t.Union[ColumnOrName, float], arg2: t.Optional[ColumnOrName] = Non

  @meta()
  def rint(col: ColumnOrName) -> Column:
+ from sqlframe.base.function_alternatives import rint_from_round
+
+ session = _get_session()
+
+ if session._is_bigquery or session._is_duckdb or session._is_postgres or session._is_snowflake:
+ return rint_from_round(col)
+
  return Column.invoke_anonymous_function(col, "RINT")


@@ -321,6 +378,13 @@ def tanh(col: ColumnOrName) -> Column:

  @meta()
  def degrees(col: ColumnOrName) -> Column:
+ from sqlframe.base.function_alternatives import degrees_bgutil
+
+ session = _get_session()
+
+ if session._is_bigquery:
+ return degrees_bgutil(col)
+
  return Column.invoke_anonymous_function(col, "DEGREES")


@@ -329,6 +393,13 @@ toDegrees = degrees

  @meta()
  def radians(col: ColumnOrName) -> Column:
+ from sqlframe.base.function_alternatives import radians_bgutil
+
+ session = _get_session()
+
+ if session._is_bigquery:
+ return radians_bgutil(col)
+
  return Column.invoke_anonymous_function(col, "RADIANS")


@@ -342,6 +413,13 @@ def bitwiseNOT(col: ColumnOrName) -> Column:

  @meta()
  def bitwise_not(col: ColumnOrName) -> Column:
+ from sqlframe.base.function_alternatives import bitwise_not_from_bitnot
+
+ session = _get_session()
+
+ if session._is_snowflake:
+ return bitwise_not_from_bitnot(col)
+
  return Column.invoke_expression_over_column(col, expression.BitwiseNot)


@@ -397,11 +475,25 @@ def var_pop(col: ColumnOrName) -> Column:

  @meta(unsupported_engines=["bigquery", "postgres"])
  def skewness(col: ColumnOrName) -> Column:
+ from sqlframe.base.function_alternatives import skewness_from_skew
+
+ session = _get_session()
+
+ if session._is_snowflake:
+ return skewness_from_skew(col)
+
  return Column.invoke_anonymous_function(col, "SKEWNESS")


  @meta(unsupported_engines=["bigquery", "postgres"])
  def kurtosis(col: ColumnOrName) -> Column:
+ from sqlframe.base.function_alternatives import kurtosis_from_kurtosis_pop
+
+ session = _get_session()
+
+ if session._is_duckdb:
+ return kurtosis_from_kurtosis_pop(col)
+
  return Column.invoke_anonymous_function(col, "KURTOSIS")


@@ -412,6 +504,13 @@ def collect_list(col: ColumnOrName) -> Column:

  @meta()
  def collect_set(col: ColumnOrName) -> Column:
+ from sqlframe.base.function_alternatives import collect_set_from_list_distinct
+
+ session = _get_session()
+
+ if session._is_bigquery or session._is_duckdb or session._is_postgres:
+ return collect_set_from_list_distinct(col)
+
  return Column.invoke_expression_over_column(col, expression.ArrayUniqueAgg)


@@ -495,6 +594,11 @@ def covar_samp(col1: ColumnOrName, col2: ColumnOrName) -> Column:

  @meta(unsupported_engines=["bigquery", "postgres", "snowflake"])
  def first(col: ColumnOrName, ignorenulls: t.Optional[bool] = None) -> Column:
+ session = _get_session()
+
+ if session._is_duckdb:
+ ignorenulls = None
+
  this = Column.invoke_expression_over_column(col, expression.First)
  if ignorenulls:
  return Column.invoke_expression_over_column(this, expression.IgnoreNulls)
@@ -519,11 +623,25 @@ def input_file_name() -> Column:

  @meta()
  def isnan(col: ColumnOrName) -> Column:
+ from sqlframe.base.function_alternatives import isnan_using_equal
+
+ session = _get_session()
+
+ if session._is_postgres or session._is_snowflake:
+ return isnan_using_equal(col)
+
  return Column.invoke_expression_over_column(col, expression.IsNan)


  @meta()
  def isnull(col: ColumnOrName) -> Column:
+ from sqlframe.base.function_alternatives import isnull_using_equal
+
+ session = _get_session()
+
+ if session._is_bigquery or session._is_duckdb or session._is_postgres or session._is_snowflake:
+ return isnull_using_equal(col)
+
  return Column.invoke_anonymous_function(col, "ISNULL")


@@ -542,6 +660,13 @@ def monotonically_increasing_id() -> Column:

  @meta()
  def nanvl(col1: ColumnOrName, col2: ColumnOrName) -> Column:
+ from sqlframe.base.function_alternatives import nanvl_as_case
+
+ session = _get_session()
+
+ if session._is_bigquery or session._is_duckdb or session._is_postgres or session._is_snowflake:
+ return nanvl_as_case(col1, col2)
+
  return Column.invoke_anonymous_function(col1, "NANVL", col2)


@@ -551,6 +676,24 @@ def percentile_approx(
  percentage: t.Union[ColumnOrLiteral, t.List[float], t.Tuple[float]],
  accuracy: t.Optional[t.Union[ColumnOrLiteral, int]] = None,
  ) -> Column:
+ from sqlframe.base.function_alternatives import (
+ percentile_approx_without_accuracy_and_max_array,
+ percentile_approx_without_accuracy_and_plural,
+ )
+
+ session = _get_session()
+
+ if session._is_bigquery:
+ return percentile_approx_without_accuracy_and_plural(col, percentage, accuracy) # type: ignore
+
+ if session._is_duckdb:
+ if accuracy:
+ logger.warning("Accuracy is ignored since it is not supported in this dialect")
+ accuracy = None
+
+ if session._is_snowflake and isinstance(percentage, (list, tuple)):
+ return percentile_approx_without_accuracy_and_max_array(col, percentage, accuracy) # type: ignore
+
  if accuracy:
  return Column.invoke_expression_over_column(
  col, expression.ApproxQuantile, quantile=lit(percentage), accuracy=accuracy
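Note: `percentile_approx` illustrates the three dispatch styles used throughout this release: a full per-engine replacement (BigQuery, and Snowflake when `percentage` is a list/tuple), dropping an unsupported argument with a `logger.warning` (DuckDB's `accuracy`), and falling through to the generic `ApproxQuantile` expression otherwise. Caller code is unchanged apart from the possible warning, e.g. (sketch):

    # On DuckDB the accuracy argument is accepted but ignored with a warning.
    df.select(F.percentile_approx("price", [0.25, 0.5, 0.75], accuracy=100))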
@@ -566,6 +709,13 @@ def percentile(
  percentage: t.Union[ColumnOrLiteral, t.List[float], t.Tuple[float]],
  frequency: t.Optional[ColumnOrLiteral] = None,
  ) -> Column:
+ from sqlframe.base.function_alternatives import percentile_without_disc
+
+ session = _get_session()
+
+ if session._is_databricks or session._is_spark:
+ return percentile_without_disc(col, percentage, frequency)
+
  if frequency:
  logger.warning("Frequency is not supported in all engines")
  return Column.invoke_expression_over_column(
@@ -575,6 +725,13 @@ def percentile(

  @meta()
  def rand(seed: t.Optional[int] = None) -> Column:
+ session = _get_session()
+
+ if session._is_bigquery or session._is_duckdb or session._is_postgres:
+ if seed:
+ logger.warning("Seed is ignored since it is not supported in this dialect")
+ seed = None
+
  if seed is not None:
  return Column.invoke_expression_over_column(None, expression.Rand, this=lit(seed))
  return Column.invoke_expression_over_column(None, expression.Rand)
@@ -589,6 +746,11 @@ def randn(seed: t.Optional[int] = None) -> Column:

  @meta()
  def round(col: ColumnOrName, scale: t.Optional[int] = None) -> Column:
+ session = _get_session()
+
+ if session._is_postgres:
+ col = Column.ensure_col(col).cast("numeric")
+
  if scale is not None:
  return Column.invoke_expression_over_column(col, expression.Round, decimals=scale)
  return Column.invoke_expression_over_column(col, expression.Round)
@@ -596,6 +758,19 @@ def round(col: ColumnOrName, scale: t.Optional[int] = None) -> Column:

  @meta(unsupported_engines=["duckdb", "postgres"])
  def bround(col: ColumnOrName, scale: t.Optional[int] = None) -> Column:
+ from sqlframe.base.function_alternatives import (
+ bround_bgutil,
+ bround_using_half_even,
+ )
+
+ session = _get_session()
+
+ if session._is_bigquery:
+ return bround_bgutil(col, scale)
+
+ if session._is_snowflake:
+ return bround_using_half_even(col, scale)
+
  if scale is not None:
  return Column.invoke_anonymous_function(col, "BROUND", lit(scale))
  return Column.invoke_anonymous_function(col, "BROUND")
@@ -603,6 +778,13 @@ def bround(col: ColumnOrName, scale: t.Optional[int] = None) -> Column:

  @meta()
  def shiftleft(col: ColumnOrName, numBits: int) -> Column:
+ from sqlframe.base.function_alternatives import shiftleft_from_bitshiftleft
+
+ session = _get_session()
+
+ if session._is_snowflake:
+ return shiftleft_from_bitshiftleft(col, numBits)
+
  return Column.invoke_expression_over_column(
  col, expression.BitwiseLeftShift, expression=lit(numBits)
  )
@@ -613,6 +795,13 @@ shiftLeft = shiftleft

  @meta()
  def shiftright(col: ColumnOrName, numBits: int) -> Column:
+ from sqlframe.base.function_alternatives import shiftright_from_bitshiftright
+
+ session = _get_session()
+
+ if session._is_snowflake:
+ return shiftright_from_bitshiftright(col, numBits)
+
  return Column.invoke_expression_over_column(
  col, expression.BitwiseRightShift, expression=lit(numBits)
  )
@@ -638,6 +827,13 @@ def expr(str: str) -> Column:

  @meta(unsupported_engines=["postgres"])
  def struct(col: t.Union[ColumnOrName, t.Iterable[ColumnOrName]], *cols: ColumnOrName) -> Column:
+ from sqlframe.base.function_alternatives import struct_with_eq
+
+ session = _get_session()
+
+ if session._is_snowflake:
+ return struct_with_eq(col, *cols)
+
  columns = ensure_list(col) + list(cols)
  return Column.invoke_expression_over_column(None, expression.Struct, expressions=columns)

@@ -699,7 +895,7 @@ def date_format(col: ColumnOrName, format: str) -> Column:
  from sqlframe.base.session import _BaseSession

  return Column.invoke_expression_over_column(
- Column(expression.TimeStrToTime(this=Column.ensure_col(col).expression)),
+ Column(expression.TimeStrToTime(this=Column.ensure_col(col).column_expression)),
  expression.TimeToStr,
  format=_BaseSession().format_time(format),
  )
@@ -707,77 +903,185 @@ def date_format(col: ColumnOrName, format: str) -> Column:

  @meta()
  def year(col: ColumnOrName) -> Column:
+ from sqlframe.base.function_alternatives import year_from_extract
+
+ session = _get_session()
+
+ if session._is_bigquery or session._is_postgres:
+ return year_from_extract(col)
+
  return Column.invoke_expression_over_column(
- Column(expression.TsOrDsToDate(this=Column.ensure_col(col).expression)), expression.Year
+ Column(expression.TsOrDsToDate(this=Column.ensure_col(col).column_expression)),
+ expression.Year,
  )


  @meta()
  def quarter(col: ColumnOrName) -> Column:
+ from sqlframe.base.function_alternatives import quarter_from_extract
+
+ session = _get_session()
+
+ if session._is_bigquery or session._is_postgres:
+ return quarter_from_extract(col)
+
  return Column(
  expression.Anonymous(
  this="QUARTER",
- expressions=[expression.TsOrDsToDate(this=Column.ensure_col(col).expression)],
+ expressions=[expression.TsOrDsToDate(this=Column.ensure_col(col).column_expression)],
  )
  )


  @meta()
  def month(col: ColumnOrName) -> Column:
+ from sqlframe.base.function_alternatives import month_from_extract
+
+ session = _get_session()
+
+ if session._is_bigquery or session._is_postgres:
+ return month_from_extract(col)
+
  return Column.invoke_expression_over_column(
- Column(expression.TsOrDsToDate(this=Column.ensure_col(col).expression)), expression.Month
+ Column(expression.TsOrDsToDate(this=Column.ensure_col(col).column_expression)),
+ expression.Month,
  )


  @meta()
  def dayofweek(col: ColumnOrName) -> Column:
+ from sqlframe.base.function_alternatives import (
+ dayofweek_from_extract,
+ dayofweek_from_extract_with_isodow,
+ )
+
+ session = _get_session()
+
+ if session._is_bigquery:
+ return dayofweek_from_extract(col)
+
+ if session._is_postgres:
+ return dayofweek_from_extract_with_isodow(col)
+
  return Column.invoke_expression_over_column(
- Column(expression.TsOrDsToDate(this=Column.ensure_col(col).expression)),
+ Column(expression.TsOrDsToDate(this=Column.ensure_col(col).column_expression)),
  expression.DayOfWeek,
  )


  @meta()
  def dayofmonth(col: ColumnOrName) -> Column:
+ from sqlframe.base.function_alternatives import dayofmonth_from_extract_with_day
+
+ session = _get_session()
+
+ if session._is_bigquery or session._is_postgres:
+ return dayofmonth_from_extract_with_day(col)
+
  return Column.invoke_expression_over_column(
- Column(expression.TsOrDsToDate(this=Column.ensure_col(col).expression)),
+ Column(expression.TsOrDsToDate(this=Column.ensure_col(col).column_expression)),
  expression.DayOfMonth,
  )


  @meta()
  def dayofyear(col: ColumnOrName) -> Column:
+ from sqlframe.base.function_alternatives import (
+ dayofyear_from_extract,
+ dayofyear_from_extract_doy,
+ )
+
+ session = _get_session()
+
+ if session._is_bigquery:
+ return dayofyear_from_extract(col)
+
+ if session._is_postgres:
+ return dayofyear_from_extract_doy(col)
+
  return Column.invoke_expression_over_column(
- Column(expression.TsOrDsToDate(this=Column.ensure_col(col).expression)),
+ Column(expression.TsOrDsToDate(this=Column.ensure_col(col).column_expression)),
  expression.DayOfYear,
  )


  @meta()
  def hour(col: ColumnOrName) -> Column:
+ from sqlframe.base.function_alternatives import hour_from_extract
+
+ session = _get_session()
+
+ if session._is_bigquery or session._is_postgres:
+ return hour_from_extract(col)
+
  return Column.invoke_anonymous_function(col, "HOUR")


  @meta()
  def minute(col: ColumnOrName) -> Column:
+ from sqlframe.base.function_alternatives import minute_from_extract
+
+ session = _get_session()
+
+ if session._is_bigquery or session._is_postgres:
+ return minute_from_extract(col)
+
  return Column.invoke_anonymous_function(col, "MINUTE")


  @meta()
  def second(col: ColumnOrName) -> Column:
+ from sqlframe.base.function_alternatives import second_from_extract
+
+ session = _get_session()
+
+ if session._is_bigquery or session._is_postgres:
+ return second_from_extract(col)
+
  return Column.invoke_anonymous_function(col, "SECOND")


  @meta()
  def weekofyear(col: ColumnOrName) -> Column:
+ from sqlframe.base.function_alternatives import (
+ weekofyear_from_extract_as_isoweek,
+ weekofyear_from_extract_as_week,
+ )
+
+ session = _get_session()
+
+ if session._is_bigquery:
+ return weekofyear_from_extract_as_isoweek(col)
+
+ if session._is_postgres:
+ return weekofyear_from_extract_as_week(col)
+
  return Column.invoke_expression_over_column(
- Column(expression.TsOrDsToDate(this=Column.ensure_col(col).expression)),
+ Column(expression.TsOrDsToDate(this=Column.ensure_col(col).column_expression)),
  expression.WeekOfYear,
  )


  @meta()
  def make_date(year: ColumnOrName, month: ColumnOrName, day: ColumnOrName) -> Column:
+ from sqlframe.base.function_alternatives import (
+ make_date_date_from_parts,
+ make_date_from_date_func,
+ )
+
+ session = _get_session()
+
+ if session._is_bigquery:
+ return make_date_from_date_func(year, month, day)
+
+ if session._is_postgres:
+ year = Column.ensure_col(year).cast("integer")
+ month = Column.ensure_col(month).cast("integer")
+ day = Column.ensure_col(day).cast("integer")
+
+ if session._is_snowflake:
+ return make_date_date_from_parts(year, month, day)
+
  return Column.invoke_anonymous_function(year, "MAKE_DATE", month, day)


@@ -785,9 +1089,22 @@ def make_date(year: ColumnOrName, month: ColumnOrName, day: ColumnOrName) -> Col
  def date_add(
  col: ColumnOrName, days: t.Union[ColumnOrName, int], cast_as_date: bool = True
  ) -> Column:
+ from sqlframe.base.function_alternatives import date_add_no_date_sub
+
+ session = _get_session()
+ date_sub_func = get_func_from_session("date_sub")
+ original_days = None
+
+ if session._is_postgres and not isinstance(days, int):
+ original_days = days
+ days = 1
+
+ if session._is_snowflake:
+ return date_add_no_date_sub(col, days, cast_as_date)
+
  if isinstance(days, int):
  if days < 0:
- return date_sub(col, days * -1)
+ return date_sub_func(col, days * -1)
  days = lit(days)
  result = Column.invoke_expression_over_column(
  Column.ensure_col(col).cast("date"),
@@ -795,6 +1112,10 @@ def date_add(
  expression=days,
  unit=expression.Var(this="DAY"),
  )
+
+ if session._is_postgres and original_days is not None:
+ result = result * Column.ensure_col(original_days)
+
  if cast_as_date:
  return result.cast("date")
  return result
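Note: `date_add` now resolves `date_sub` through `get_func_from_session("date_sub")`, so an engine-specific override of `date_sub` is honoured when a negative day count flips the call. The Postgres branch also supports a column-valued `days`: the interval is built with a placeholder of one day and the result is multiplied by the original column afterwards, i.e. roughly `CAST(col AS DATE) + INTERVAL '1' DAY * days`. Usage sketch:

    df.select(
        F.date_add("order_date", -7),             # routed through the session's date_sub
        F.date_add("order_date", F.col("days")),  # Postgres: one-day interval scaled by "days"
    )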
@@ -807,9 +1128,22 @@ def date_sub(
  """
  Non-standard argument: cast_as_date
  """
+ from sqlframe.base.function_alternatives import date_sub_by_date_add
+
+ session = _get_session()
+ date_add_func = get_func_from_session("date_add")
+ original_days = None
+
+ if session._is_postgres and not isinstance(days, int):
+ original_days = days
+ days = 1
+
+ if session._is_snowflake:
+ return date_sub_by_date_add(col, days, cast_as_date)
+
  if isinstance(days, int):
  if days < 0:
- return date_add(col, days * -1)
+ return date_add_func(col, days * -1)
  days = lit(days)
  result = Column.invoke_expression_over_column(
  Column.ensure_col(col).cast("date"),
@@ -817,6 +1151,10 @@ def date_sub(
  expression=days,
  unit=expression.Var(this="DAY"),
  )
+
+ if session._is_postgres and original_days is not None:
+ result = result * Column.ensure_col(original_days)
+
  if cast_as_date:
  return result.cast("date")
  return result
@@ -824,6 +1162,13 @@ def date_sub(

  @meta()
  def date_diff(end: ColumnOrName, start: ColumnOrName) -> Column:
+ from sqlframe.base.function_alternatives import date_diff_with_subtraction
+
+ session = _get_session()
+
+ if session._is_postgres:
+ return date_diff_with_subtraction(end, start)
+
  return Column.invoke_expression_over_column(
  Column.ensure_col(end).cast("date"),
  expression.DateDiff,
@@ -838,28 +1183,51 @@ def add_months(
  """
  Non-standard argument: cast_as_date
  """
+ from sqlframe.base.function_alternatives import add_months_using_func
+
+ lit = get_func_from_session("lit")
+ session = _get_session()
+ original_months = months
+
+ if session._is_databricks or session._is_postgres or session._is_spark:
+ months = 1
+
+ if session._is_snowflake:
+ return add_months_using_func(start, months, cast_as_date)
+
  start_col = Column(start).cast("date")

  if isinstance(months, int):
  if months < 0:
  end_col = Column(
  expression.Interval(
- this=lit(months * -1).expression, unit=expression.Var(this="MONTH")
+ this=lit(months * -1).column_expression, unit=expression.Var(this="MONTH")
  )
  )
  result = start_col - end_col
  else:
  end_col = Column(
- expression.Interval(this=lit(months).expression, unit=expression.Var(this="MONTH"))
+ expression.Interval(
+ this=lit(months).column_expression, unit=expression.Var(this="MONTH")
+ )
  )
  result = start_col + end_col
  else:
  end_col = Column(
  expression.Interval(
- this=Column.ensure_col(months).expression, unit=expression.Var(this="MONTH")
+ this=Column.ensure_col(months).column_expression, unit=expression.Var(this="MONTH")
  )
  )
  result = start_col + end_col
+
+ if session._is_databricks or session._is_postgres or session._is_spark:
+ multiple_value = (
+ lit(original_months)
+ if isinstance(original_months, int)
+ else Column.ensure_col(original_months)
+ )
+ result = Column.ensure_col(result.column_expression.unnest()) * multiple_value
+
  if cast_as_date:
  return result.cast("date")
  return result
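Note: `add_months` applies the same trick on Spark, Databricks and Postgres: the month count is pinned to 1 while the interval is built, then the result is multiplied by the original `months` (literal or column) via `result.column_expression.unnest()`, relying on `start + INTERVAL '1' MONTH * n` being equivalent to adding `n` months. Usage sketch:

    df.select(F.add_months("start_date", 3), F.add_months("start_date", F.col("n_months")))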
@@ -869,35 +1237,79 @@ def add_months(
  def months_between(
  date1: ColumnOrName, date2: ColumnOrName, roundOff: t.Optional[bool] = None
  ) -> Column:
+ from sqlframe.base.function_alternatives import (
+ months_between_bgutils,
+ months_between_from_age_and_extract,
+ )
+
+ session = _get_session()
+ original_roundoff = roundOff
+
+ if session._is_bigquery:
+ return months_between_bgutils(date1, date2, roundOff)
+
+ if session._is_postgres:
+ return months_between_from_age_and_extract(date1, date2, roundOff)
+
+ if session._is_snowflake:
+ date1 = Column.ensure_col(date1).cast("date")
+ date2 = Column.ensure_col(date2).cast("date")
+ roundOff = None
+
  if roundOff is None:
- return Column.invoke_expression_over_column(
+ result = Column.invoke_expression_over_column(
  date1, expression.MonthsBetween, expression=date2
  )
+ else:
+ result = Column.invoke_expression_over_column(
+ date1, expression.MonthsBetween, expression=date2, roundoff=lit(roundOff)
+ )

- return Column.invoke_expression_over_column(
- date1, expression.MonthsBetween, expression=date2, roundoff=lit(roundOff)
- )
+ if session._is_snowflake and original_roundoff:
+ return result.cast("bigint")
+ return result


  @meta()
  def to_date(col: ColumnOrName, format: t.Optional[str] = None) -> Column:
- from sqlframe.base.session import _BaseSession
+ session = _get_session()
+
+ if session._is_bigquery:
+ to_timestamp_func = get_func_from_session("to_timestamp")
+ col = to_timestamp_func(col, format)
+
+ if session._is_snowflake:
+ format = format or session.default_time_format

- # format = lit(format or spark_default_date_format())
  if format is not None:
  return Column.invoke_expression_over_column(
- col, expression.TsOrDsToDate, format=_BaseSession().format_time(format)
+ col, expression.TsOrDsToDate, format=session.format_time(format)
  )
  return Column.invoke_expression_over_column(col, expression.TsOrDsToDate)


  @meta()
  def to_timestamp(col: ColumnOrName, format: t.Optional[str] = None) -> Column:
- from sqlframe.base.session import _BaseSession
+ from sqlframe.base.function_alternatives import (
+ to_timestamp_just_timestamp,
+ to_timestamp_tz,
+ to_timestamp_with_time_zone,
+ )
+
+ session = _get_session()
+
+ if session._is_duckdb:
+ return to_timestamp_tz(col, format)
+
+ if session._is_bigquery:
+ return to_timestamp_just_timestamp(col, format)
+
+ if session._is_postgres:
+ return to_timestamp_with_time_zone(col, format)

  if format is not None:
  return Column.invoke_expression_over_column(
- col, expression.StrToTime, format=_BaseSession().format_time(format)
+ col, expression.StrToTime, format=session.format_time(format)
  )

  return Column.ensure_col(col).cast("timestampltz")
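Note: `to_date` and `to_timestamp` no longer re-import `_BaseSession`; they go through `_get_session()` and call `session.format_time(format)`, with Snowflake defaulting the format to `session.default_time_format` and DuckDB/BigQuery/Postgres routed to dedicated alternatives. The same Spark-style call works on every engine; only the generated SQL differs, e.g.:

    df.select(
        F.to_date("ts_string", "yyyy-MM-dd"),
        F.to_timestamp("ts_string", "yyyy-MM-dd HH:mm:ss"),
    )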
@@ -919,22 +1331,45 @@ def date_trunc(format: str, timestamp: ColumnOrName) -> Column:

  @meta(unsupported_engines=["duckdb", "postgres"])
  def next_day(col: ColumnOrName, dayOfWeek: str) -> Column:
+ from sqlframe.base.function_alternatives import next_day_bgutil
+
+ session = _get_session()
+
+ if session._is_bigquery:
+ return next_day_bgutil(col, dayOfWeek)
+
  return Column.invoke_anonymous_function(col, "NEXT_DAY", lit(dayOfWeek))


  @meta()
  def last_day(col: ColumnOrName) -> Column:
+ session = _get_session()
+
+ if session._is_bigquery or session._is_duckdb or session._is_postgres or session._is_snowflake:
+ col = Column.ensure_col(col).cast("date")
+
  return Column.invoke_expression_over_column(col, expression.LastDay)


  @meta()
  def from_unixtime(col: ColumnOrName, format: t.Optional[str] = None) -> Column:
- from sqlframe.base.session import _BaseSession
+ from sqlframe.base.function_alternatives import (
+ from_unixtime_bigutil,
+ from_unixtime_from_timestamp,
+ )
+
+ session = _get_session()
+
+ if session._is_bigquery:
+ return from_unixtime_bigutil(col, format)
+
+ if session._is_postgres or session._is_snowflake:
+ return from_unixtime_from_timestamp(col, format)

  return Column.invoke_expression_over_column(
  col,
  expression.UnixToStr,
- format=_BaseSession().format_time(format),
+ format=session.format_time(format),
  )


@@ -942,12 +1377,23 @@ def from_unixtime(col: ColumnOrName, format: t.Optional[str] = None) -> Column:
  def unix_timestamp(
  timestamp: t.Optional[ColumnOrName] = None, format: t.Optional[str] = None
  ) -> Column:
- from sqlframe.base.session import _BaseSession
+ from sqlframe.base.function_alternatives import (
+ unix_timestamp_bgutil,
+ unix_timestamp_from_extract,
+ )
+
+ session = _get_session()
+
+ if session._is_bigquery:
+ return unix_timestamp_bgutil(timestamp, format)
+
+ if session._is_postgres or session._is_snowflake:
+ return unix_timestamp_from_extract(timestamp, format)

  return Column.invoke_expression_over_column(
  timestamp,
  expression.StrToUnix,
- format=_BaseSession().format_time(format),
+ format=session.format_time(format),
  ).cast("bigint")


@@ -1010,6 +1456,13 @@ def md5(col: ColumnOrName) -> Column:

  @meta(unsupported_engines=["duckdb", "postgres"])
  def sha1(col: ColumnOrName) -> Column:
+ from sqlframe.base.function_alternatives import sha1_force_sha1_and_to_hex
+
+ session = _get_session()
+
+ if session._is_bigquery:
+ return sha1_force_sha1_and_to_hex(col)
+
  return Column.invoke_expression_over_column(col, expression.SHA)


@@ -1020,6 +1473,13 @@ def sha2(col: ColumnOrName, numBits: int) -> Column:

  @meta(unsupported_engines=["postgres"])
  def hash(*cols: ColumnOrName) -> Column:
+ from sqlframe.base.function_alternatives import hash_from_farm_fingerprint
+
+ session = _get_session()
+
+ if session._is_bigquery:
+ return hash_from_farm_fingerprint(*cols)
+
  args = cols[1:] if len(cols) > 1 else []
  return Column.invoke_anonymous_function(cols[0], "HASH", *args)

@@ -1061,11 +1521,41 @@ def ascii(col: ColumnOrName) -> Column:

  @meta()
  def base64(col: ColumnOrName) -> Column:
+ from sqlframe.base.function_alternatives import (
+ base64_from_base64_encode,
+ base64_from_blob,
+ base64_from_encode,
+ )
+
+ session = _get_session()
+
+ if session._is_bigquery or session._is_duckdb:
+ return base64_from_blob(col)
+
+ if session._is_postgres:
+ return base64_from_encode(col)
+
+ if session._is_snowflake:
+ return base64_from_base64_encode(col)
+
  return Column.invoke_expression_over_column(col, expression.ToBase64)


  @meta()
  def unbase64(col: ColumnOrName) -> Column:
+ from sqlframe.base.function_alternatives import (
+ unbase64_from_base64_decode_string,
+ unbase64_from_decode,
+ )
+
+ session = _get_session()
+
+ if session._is_postgres:
+ return unbase64_from_decode(col)
+
+ if session._is_snowflake:
+ return unbase64_from_base64_decode_string(col)
+
  return Column.invoke_expression_over_column(col, expression.FromBase64)


@@ -1086,6 +1576,13 @@ def trim(col: ColumnOrName) -> Column:

  @meta()
  def concat_ws(sep: str, *cols: ColumnOrName) -> Column:
+ from sqlframe.base.function_alternatives import concat_ws_from_array_to_string
+
+ session = _get_session()
+
+ if session._is_bigquery:
+ return concat_ws_from_array_to_string(sep, *cols)
+
  return Column.invoke_expression_over_column(
  None, expression.ConcatWs, expressions=[lit(sep)] + list(cols)
  )
@@ -1093,6 +1590,19 @@ def concat_ws(sep: str, *cols: ColumnOrName) -> Column:

  @meta(unsupported_engines=["bigquery", "snowflake"])
  def decode(col: ColumnOrName, charset: str) -> Column:
+ from sqlframe.base.function_alternatives import (
+ decode_from_blob,
+ decode_from_convert_from,
+ )
+
+ session = _get_session()
+
+ if session._is_duckdb:
+ return decode_from_blob(col, charset)
+
+ if session._is_postgres:
+ return decode_from_convert_from(col, charset)
+
  return Column.invoke_expression_over_column(
  col, expression.Decode, charset=expression.Literal.string(charset)
  )
@@ -1100,6 +1610,13 @@ def decode(col: ColumnOrName, charset: str) -> Column:

  @meta(unsupported_engines=["bigquery", "snowflake"])
  def encode(col: ColumnOrName, charset: str) -> Column:
+ from sqlframe.base.function_alternatives import encode_from_convert_to
+
+ session = _get_session()
+
+ if session._is_postgres:
+ return encode_from_convert_to(col, charset)
+
  return Column.invoke_expression_over_column(
  col, expression.Encode, charset=expression.Literal.string(charset)
  )
@@ -1107,11 +1624,37 @@ def encode(col: ColumnOrName, charset: str) -> Column:

  @meta(unsupported_engines="duckdb")
  def format_number(col: ColumnOrName, d: int) -> Column:
+ from sqlframe.base.function_alternatives import (
+ format_number_bgutil,
+ format_number_from_to_char,
+ )
+
+ session = _get_session()
+
+ if session._is_bigquery:
+ return format_number_bgutil(col, d)
+
+ if session._is_postgres or session._is_snowflake:
+ return format_number_from_to_char(col, d)
+
  return Column.invoke_anonymous_function(col, "FORMAT_NUMBER", lit(d))


  @meta(unsupported_engines="snowflake")
  def format_string(format: str, *cols: ColumnOrName) -> Column:
+ from sqlframe.base.function_alternatives import (
+ format_string_with_format,
+ format_string_with_pipes,
+ )
+
+ session = _get_session()
+
+ if session._is_duckdb:
+ return format_string_with_pipes(format, *cols)
+
+ if session._is_bigquery or session._is_postgres:
+ return format_string_with_format(format, *cols)
+
  format_col = lit(format)
  columns = [Column.ensure_col(x) for x in cols]
  return Column.invoke_anonymous_function(format_col, "FORMAT_STRING", *columns)
@@ -1119,6 +1662,13 @@ def format_string(format: str, *cols: ColumnOrName) -> Column:

  @meta()
  def instr(col: ColumnOrName, substr: str) -> Column:
+ from sqlframe.base.function_alternatives import instr_using_strpos
+
+ session = _get_session()
+
+ if session._is_bigquery:
+ return instr_using_strpos(col, substr)
+
  return Column.invoke_expression_over_column(col, expression.StrPosition, substr=lit(substr))


@@ -1129,13 +1679,19 @@ def overlay(
  pos: t.Union[ColumnOrName, int],
  len: t.Optional[t.Union[ColumnOrName, int]] = None,
  ) -> Column:
+ from sqlframe.base.function_alternatives import overlay_from_substr
+
+ session = _get_session()
+
+ if session._is_bigquery or session._is_duckdb or session._is_snowflake:
+ return overlay_from_substr(src, replace, pos, len)
  return Column.invoke_expression_over_column(
  src,
  expression.Overlay,
  **{
- "expression": Column(replace).expression,
- "from": lit(pos).expression,
- "for": lit(len).expression if len is not None else None,
+ "expression": Column(replace).column_expression,
+ "from": lit(pos).column_expression,
+ "for": lit(len).column_expression if len is not None else None,
  },
  )

@@ -1162,6 +1718,13 @@ def substring(str: ColumnOrName, pos: int, len: int) -> Column:

  @meta(unsupported_engines=["duckdb", "postgres", "snowflake"])
  def substring_index(str: ColumnOrName, delim: str, count: int) -> Column:
+ from sqlframe.base.function_alternatives import substring_index_bgutil
+
+ session = _get_session()
+
+ if session._is_bigquery:
+ return substring_index_bgutil(str, delim, count)
+
  return Column.invoke_anonymous_function(str, "SUBSTRING_INDEX", lit(delim), lit(count))


@@ -1169,15 +1732,22 @@ def substring_index(str: ColumnOrName, delim: str, count: int) -> Column:
  def levenshtein(
  left: ColumnOrName, right: ColumnOrName, threshold: t.Optional[int] = None
  ) -> Column:
+ from sqlframe.base.function_alternatives import levenshtein_edit_distance
+
+ session = _get_session()
+
+ if session._is_snowflake:
+ return levenshtein_edit_distance(left, right, threshold)
+
  value: t.Union[expression.Case, expression.Levenshtein] = expression.Levenshtein(
- this=Column.ensure_col(left).expression,
- expression=Column.ensure_col(right).expression,
+ this=Column.ensure_col(left).column_expression,
+ expression=Column.ensure_col(right).column_expression,
  )
  if threshold is not None:
  value = (
  expression.case()
- .when(expression.LTE(this=value, expression=lit(threshold).expression), value)
- .else_(lit(-1).expression)
+ .when(expression.LTE(this=value, expression=lit(threshold).column_expression), value)
+ .else_(lit(-1).column_expression)
  )
  return Column(value)

@@ -1196,9 +1766,9 @@ def locate(substr: str, str: ColumnOrName, pos: t.Optional[int] = None) -> Colum
  def lpad(col: ColumnOrName, len: int, pad: str) -> Column:
  return Column(
  expression.Pad(
- this=Column.ensure_col(col).expression,
- expression=lit(len).expression,
- fill_pattern=lit(pad).expression,
+ this=Column.ensure_col(col).column_expression,
+ expression=lit(len).column_expression,
+ fill_pattern=lit(pad).column_expression,
  # We can use `invoke_expression_over_column` because this is an actual bool instead of literal bool
  is_left=True,
  )
@@ -1209,9 +1779,9 @@ def lpad(col: ColumnOrName, len: int, pad: str) -> Column:
  def rpad(col: ColumnOrName, len: int, pad: str) -> Column:
  return Column(
  expression.Pad(
- this=Column.ensure_col(col).expression,
- expression=lit(len).expression,
- fill_pattern=lit(pad).expression,
+ this=Column.ensure_col(col).column_expression,
+ expression=lit(len).column_expression,
+ fill_pattern=lit(pad).column_expression,
  # We can use `invoke_expression_over_column` because this is an actual bool instead of literal bool
  is_left=False,
  )
@@ -1225,6 +1795,24 @@ def repeat(col: ColumnOrName, n: int) -> Column:

  @meta()
  def split(str: ColumnOrName, pattern: str, limit: t.Optional[int] = None) -> Column:
+ from sqlframe.base.function_alternatives import (
+ split_from_regex_split_to_array,
+ split_with_split,
+ )
+
+ session = _get_session()
+
+ if session._is_duckdb:
+ if limit is not None:
+ logger.warning("Limit is ignored since it is not supported in this dialect")
+ limit = None
+
+ if session._is_bigquery or session._is_snowflake:
+ return split_with_split(str, pattern, limit)
+
+ if session._is_postgres:
+ return split_from_regex_split_to_array(str, pattern, limit)
+
  if limit is not None:
  return Column.invoke_expression_over_column(
  str, expression.RegexpSplit, expression=lit(pattern), limit=lit(limit)
@@ -1236,22 +1824,39 @@ def split(str: ColumnOrName, pattern: str, limit: t.Optional[int] = None) -> Col

  @meta(unsupported_engines="postgres")
  def regexp_extract(str: ColumnOrName, pattern: str, idx: t.Optional[int] = None) -> Column:
+ session = _get_session()
+
  if idx is not None:
- return Column.invoke_expression_over_column(
+ result = Column.invoke_expression_over_column(
  str,
  expression.RegexpExtract,
  expression=lit(pattern),
  group=lit(idx),
  )
- return Column.invoke_expression_over_column(
- str, expression.RegexpExtract, expression=lit(pattern)
- )
+ else:
+ result = Column.invoke_expression_over_column(
+ str, expression.RegexpExtract, expression=lit(pattern)
+ )
+
+ if session._is_snowflake:
+ coalesce_func = get_func_from_session("coalesce")
+
+ result = coalesce_func(result, lit(""))
+
+ return result


  @meta()
  def regexp_replace(
  str: ColumnOrName, pattern: str, replacement: str, position: t.Optional[int] = None
  ) -> Column:
+ from sqlframe.base.function_alternatives import regexp_replace_global_option
+
+ session = _get_session()
+
+ if session._is_duckdb or session._is_postgres:
+ return regexp_replace_global_option(str, pattern, replacement, position)
+
  if position is not None:
  return Column.invoke_expression_over_column(
  str,
@@ -1280,16 +1885,43 @@ def soundex(col: ColumnOrName) -> Column:

  @meta(unsupported_engines=["postgres", "snowflake"])
  def bin(col: ColumnOrName) -> Column:
+ from sqlframe.base.function_alternatives import bin_bgutil
+
+ session = _get_session()
+
+ if session._is_bigquery:
+ return bin_bgutil(col)
+
  return Column.invoke_anonymous_function(col, "BIN")


  @meta(unsupported_engines="postgres")
  def hex(col: ColumnOrName) -> Column:
+ from sqlframe.base.function_alternatives import (
+ hex_casted_as_bytes,
+ hex_using_encode,
+ )
+
+ session = _get_session()
+
+ if session._is_bigquery:
+ return hex_casted_as_bytes(col)
+
+ if session._is_snowflake:
+ return hex_using_encode(col)
+
  return Column.invoke_expression_over_column(col, expression.Hex)


  @meta(unsupported_engines="postgres")
  def unhex(col: ColumnOrName) -> Column:
+ from sqlframe.base.function_alternatives import unhex_hex_decode_str
+
+ session = _get_session()
+
+ if session._is_snowflake:
+ return unhex_hex_decode_str(col)
+
  return Column.invoke_expression_over_column(col, expression.Unhex)


@@ -1305,6 +1937,13 @@ def octet_length(col: ColumnOrName) -> Column:

  @meta()
  def bit_length(col: ColumnOrName) -> Column:
+ from sqlframe.base.function_alternatives import bit_length_from_length
+
+ session = _get_session()
+
+ if session._is_bigquery:
+ return bit_length_from_length(col)
+
  return Column.invoke_anonymous_function(col, "BIT_LENGTH")


@@ -1326,6 +1965,19 @@ def array_agg(col: ColumnOrName) -> Column:

  @meta()
  def array_append(col: ColumnOrName, value: ColumnOrLiteral) -> Column:
+ from sqlframe.base.function_alternatives import (
+ array_append_list_append,
+ array_append_using_array_cat,
+ )
+
+ session = _get_session()
+
+ if session._is_bigquery:
+ return array_append_using_array_cat(col, value)
+
+ if session._is_duckdb:
+ return array_append_list_append(col, value)
+
  value = value if isinstance(value, Column) else lit(value)
  return Column.invoke_anonymous_function(col, "ARRAY_APPEND", value)

@@ -1388,13 +2040,21 @@ def getbit(col: ColumnOrName, pos: ColumnOrName) -> Column:

  @meta(unsupported_engines=["bigquery", "postgres"])
  def create_map(*cols: t.Union[ColumnOrName, t.Iterable[ColumnOrName]]) -> Column:
+ session = _get_session()
+
  cols = list(_flatten(cols)) if not isinstance(cols[0], (str, Column)) else cols # type: ignore
- return Column.invoke_expression_over_column(
+ result = Column.invoke_expression_over_column(
  None,
  expression.VarMap,
- keys=array(*cols[::2]).expression,
- values=array(*cols[1::2]).expression,
+ keys=array(*cols[::2]).column_expression,
+ values=array(*cols[1::2]).column_expression,
  )
+ if not session._is_snowflake:
+ return result
+
+ col1_dtype = col(cols[0]).dtype or "VARCHAR"
+ col2_dtype = col(cols[1]).dtype or "VARCHAR"
+ return result.cast(f"MAP({col1_dtype}, {col2_dtype})")


  @meta(unsupported_engines=["bigquery", "postgres", "snowflake"])
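Note: on Snowflake, `create_map` now casts the generated `VarMap` to an explicit `MAP(key_type, value_type)`, taking the dtypes of the first key/value columns and falling back to `VARCHAR` when a dtype cannot be resolved. Keys and values alternate as in PySpark, e.g. (sketch):

    df.select(F.create_map(F.col("k1"), F.col("v1"), F.col("k2"), F.col("v2")).alias("m"))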
@@ -1404,14 +2064,43 @@ def map_from_arrays(col1: ColumnOrName, col2: ColumnOrName) -> Column:
1404
2064
 
1405
2065
  @meta()
1406
2066
  def array_contains(col: ColumnOrName, value: ColumnOrLiteral) -> Column:
1407
- value_col = value if isinstance(value, Column) else lit(value)
2067
+ from sqlframe.base.function_alternatives import array_contains_any
2068
+
2069
+ session = _get_session()
2070
+ lit_func = get_func_from_session("lit")
2071
+
2072
+ if session._is_postgres:
2073
+ return array_contains_any(col, value)
2074
+
2075
+ value = value if isinstance(value, Column) else lit_func(value)
2076
+
2077
+ if session._is_snowflake:
2078
+ value = value.cast("variant")
2079
+
1408
2080
  return Column.invoke_expression_over_column(
1409
- col, expression.ArrayContains, expression=value_col.expression
2081
+ col, expression.ArrayContains, expression=value.column_expression
1410
2082
  )
1411
2083
 
1412
2084
 
1413
2085
  @meta(unsupported_engines="bigquery")
1414
2086
  def arrays_overlap(col1: ColumnOrName, col2: ColumnOrName) -> Column:
2087
+ from sqlframe.base.function_alternatives import (
2088
+ arrays_overlap_as_plural,
2089
+ arrays_overlap_renamed,
2090
+ arrays_overlap_using_intersect,
2091
+ )
2092
+
2093
+ session = _get_session()
2094
+
2095
+ if session._is_duckdb:
2096
+ return arrays_overlap_using_intersect(col1, col2)
2097
+
2098
+ if session._is_databricks or session._is_spark:
2099
+ return arrays_overlap_renamed(col1, col2)
2100
+
2101
+ if session._is_snowflake:
2102
+ return arrays_overlap_as_plural(col1, col2)
2103
+
1415
2104
  return Column.invoke_expression_over_column(col1, expression.ArrayOverlaps, expression=col2)
1416
2105
 
1417
2106
 
@@ -1419,6 +2108,27 @@ def arrays_overlap(col1: ColumnOrName, col2: ColumnOrName) -> Column:
1419
2108
  def slice(
1420
2109
  x: ColumnOrName, start: t.Union[ColumnOrName, int], length: t.Union[ColumnOrName, int]
1421
2110
  ) -> Column:
2111
+ from sqlframe.base.function_alternatives import (
2112
+ slice_as_array_slice,
2113
+ slice_as_list_slice,
2114
+ slice_bgutil,
2115
+ slice_with_brackets,
2116
+ )
2117
+
2118
+ session = _get_session()
2119
+
2120
+ if session._is_bigquery:
2121
+ return slice_bgutil(x, start, length)
2122
+
2123
+ if session._is_duckdb:
2124
+ return slice_as_list_slice(x, start, length)
2125
+
2126
+ if session._is_postgres:
2127
+ return slice_with_brackets(x, start, length)
2128
+
2129
+ if session._is_snowflake:
2130
+ return slice_as_array_slice(x, start, length)
2131
+
1422
2132
  start_col = lit(start) if isinstance(start, int) else start
1423
2133
  length_col = lit(length) if isinstance(length, int) else length
1424
2134
  return Column.invoke_anonymous_function(x, "SLICE", start_col, length_col)
@@ -1428,6 +2138,13 @@ def slice(
1428
2138
  def array_join(
1429
2139
  col: ColumnOrName, delimiter: str, null_replacement: t.Optional[str] = None
1430
2140
  ) -> Column:
2141
+ session = _get_session()
2142
+
2143
+ if session._is_snowflake:
2144
+ if null_replacement is not None:
2145
+ logger.warning("Null replacement is ignored since it is not supported in this dialect")
2146
+ null_replacement = None
2147
+
1431
2148
  if null_replacement is not None:
1432
2149
  return Column.invoke_expression_over_column(
1433
2150
  col, expression.ArrayToString, expression=lit(delimiter), null=lit(null_replacement)
@@ -1444,6 +2161,19 @@ def concat(*cols: ColumnOrName) -> Column:
1444
2161
 
1445
2162
  @meta()
1446
2163
  def array_position(col: ColumnOrName, value: ColumnOrLiteral) -> Column:
2164
+ from sqlframe.base.function_alternatives import (
2165
+ array_position_bgutil,
2166
+ array_position_cast_variant_and_flip,
2167
+ )
2168
+
2169
+ session = _get_session()
2170
+
2171
+ if session._is_bigquery:
2172
+ return array_position_bgutil(col, value)
2173
+
2174
+ if session._is_snowflake:
2175
+ return array_position_cast_variant_and_flip(col, value)
2176
+
1447
2177
  value_col = value if isinstance(value, Column) else lit(value)
1448
2178
  # Some engines return NULL if item is not found but Spark expects 0 so we coalesce to 0
1449
2179
  return coalesce(Column.invoke_anonymous_function(col, "ARRAY_POSITION", value_col), lit(0))
@@ -1451,28 +2181,75 @@ def array_position(col: ColumnOrName, value: ColumnOrLiteral) -> Column:
1451
2181
 
1452
2182
  @meta()
1453
2183
  def element_at(col: ColumnOrName, value: ColumnOrLiteral) -> Column:
2184
+ from sqlframe.base.function_alternatives import element_at_using_brackets
2185
+
2186
+ session = _get_session()
2187
+
2188
+ if session._is_bigquery or session._is_duckdb or session._is_postgres or session._is_snowflake:
2189
+ return element_at_using_brackets(col, value)
2190
+
1454
2191
  value_col = value if isinstance(value, Column) else lit(value)
1455
2192
  return Column.invoke_anonymous_function(col, "ELEMENT_AT", value_col)
1456
2193
 
1457
2194
 
1458
2195
  @meta()
1459
2196
  def array_remove(col: ColumnOrName, value: ColumnOrLiteral) -> Column:
2197
+ from sqlframe.base.function_alternatives import (
2198
+ array_remove_bgutil,
2199
+ array_remove_using_filter,
2200
+ )
2201
+
2202
+ session = _get_session()
2203
+
2204
+ if session._is_bigquery:
2205
+ return array_remove_bgutil(col, value)
2206
+
2207
+ if session._is_duckdb:
2208
+ return array_remove_using_filter(col, value)
2209
+
1460
2210
  value_col = value if isinstance(value, Column) else lit(value)
1461
2211
  return Column.invoke_anonymous_function(col, "ARRAY_REMOVE", value_col)
1462
2212
 
1463
2213
 
1464
2214
  @meta(unsupported_engines="postgres")
1465
2215
  def array_distinct(col: ColumnOrName) -> Column:
2216
+ from sqlframe.base.function_alternatives import array_distinct_bgutil
2217
+
2218
+ session = _get_session()
2219
+
2220
+ if session._is_bigquery:
2221
+ return array_distinct_bgutil(col)
2222
+
1466
2223
  return Column.invoke_anonymous_function(col, "ARRAY_DISTINCT")
1467
2224
 
1468
2225
 
1469
2226
  @meta(unsupported_engines=["bigquery", "postgres"])
1470
2227
  def array_intersect(col1: ColumnOrName, col2: ColumnOrName) -> Column:
2228
+ from sqlframe.base.function_alternatives import array_intersect_using_intersection
2229
+
2230
+ session = _get_session()
2231
+
2232
+ if session._is_snowflake:
2233
+ return array_intersect_using_intersection(col1, col2)
2234
+
1471
2235
  return Column.invoke_anonymous_function(col1, "ARRAY_INTERSECT", Column.ensure_col(col2))
1472
2236
 
1473
2237
 
1474
2238
  @meta(unsupported_engines=["postgres"])
1475
2239
  def array_union(col1: ColumnOrName, col2: ColumnOrName) -> Column:
2240
+ from sqlframe.base.function_alternatives import (
2241
+ array_union_using_array_concat,
2242
+ array_union_using_list_concat,
2243
+ )
2244
+
2245
+ session = _get_session()
2246
+
2247
+ if session._is_duckdb:
2248
+ return array_union_using_list_concat(col1, col2)
2249
+
2250
+ if session._is_bigquery or session._is_snowflake:
2251
+ return array_union_using_array_concat(col1, col2)
2252
+
1476
2253
  return Column.invoke_anonymous_function(col1, "ARRAY_UNION", Column.ensure_col(col2))
1477
2254
 
1478
2255
 
@@ -1504,6 +2281,19 @@ def posexplode_outer(col: ColumnOrName) -> Column:
1504
2281
  # Snowflake doesn't support JSONPath which is what this function uses
1505
2282
  @meta(unsupported_engines="snowflake")
1506
2283
  def get_json_object(col: ColumnOrName, path: str) -> Column:
2284
+ from sqlframe.base.function_alternatives import (
2285
+ get_json_object_using_arrow_op,
2286
+ get_json_object_using_function,
2287
+ )
2288
+
2289
+ session = _get_session()
2290
+
2291
+ if session._is_databricks:
2292
+ return get_json_object_using_function(col, path)
2293
+
2294
+ if session._is_postgres:
2295
+ return get_json_object_using_arrow_op(col, path)
2296
+
1507
2297
  return Column.invoke_expression_over_column(col, expression.JSONExtract, expression=lit(path))
1508
2298
 
1509
2299
 
@@ -1572,16 +2362,63 @@ def size(col: ColumnOrName) -> Column:

 @meta()
 def array_min(col: ColumnOrName) -> Column:
+    from sqlframe.base.function_alternatives import (
+        array_min_bgutil,
+        array_min_from_sort,
+        array_min_from_subquery,
+    )
+
+    session = _get_session()
+
+    if session._is_bigquery:
+        return array_min_bgutil(col)
+
+    if session._is_duckdb:
+        return array_min_from_sort(col)
+
+    if session._is_postgres:
+        return array_min_from_subquery(col)
+
     return Column.invoke_anonymous_function(col, "ARRAY_MIN")


 @meta()
 def array_max(col: ColumnOrName) -> Column:
+    from sqlframe.base.function_alternatives import (
+        array_max_bgutil,
+        array_max_from_sort,
+        array_max_from_subquery,
+    )
+
+    session = _get_session()
+
+    if session._is_bigquery:
+        return array_max_bgutil(col)
+
+    if session._is_duckdb:
+        return array_max_from_sort(col)
+
+    if session._is_postgres:
+        return array_max_from_subquery(col)
+
     return Column.invoke_anonymous_function(col, "ARRAY_MAX")


 @meta(unsupported_engines="postgres")
 def sort_array(col: ColumnOrName, asc: t.Optional[bool] = None) -> Column:
+    from sqlframe.base.function_alternatives import (
+        sort_array_bgutil,
+        sort_array_using_array_sort,
+    )
+
+    session = _get_session()
+
+    if session._is_bigquery:
+        return sort_array_bgutil(col, asc)
+
+    if session._is_snowflake:
+        return sort_array_using_array_sort(col, asc)
+
     if asc is not None:
         return Column.invoke_expression_over_column(col, expression.SortArray, asc=lit(asc))
     return Column.invoke_expression_over_column(col, expression.SortArray)
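
array_min, array_max, and sort_array each gain several engine-specific fallbacks (BigQuery helpers plus sort- and subquery-based rewrites for DuckDB and Postgres). A quick end-to-end check of the DuckDB paths, assuming sqlframe and duckdb are installed; the expected values in the comments assume the fallbacks match Spark semantics:

    from sqlframe.duckdb import DuckDBSession
    from sqlframe.duckdb import functions as F

    session = DuckDBSession()
    df = session.sql("SELECT [3, 1, 2] AS xs")
    row = df.select(
        F.array_min("xs").alias("lo"),          # expected 1 (sort-based fallback)
        F.array_max("xs").alias("hi"),          # expected 3
        F.sort_array("xs").alias("xs_sorted"),  # expected [1, 2, 3]
    ).collect()[0]
    print(row)
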
@@ -1592,6 +2429,12 @@ def array_sort(
     col: ColumnOrName,
     comparator: t.Optional[t.Union[t.Callable[[Column, Column], Column]]] = None,
 ) -> Column:
+    session = _get_session()
+    sort_array_func = get_func_from_session("sort_array")
+
+    if session._is_bigquery:
+        return sort_array_func(col, comparator)
+
     if comparator is not None:
         f_expression = _get_lambda_from_func(comparator)
         return Column.invoke_expression_over_column(
@@ -1612,6 +2455,13 @@ def reverse(col: ColumnOrName) -> Column:

 @meta(unsupported_engines=["bigquery", "postgres"])
 def flatten(col: ColumnOrName) -> Column:
+    from sqlframe.base.function_alternatives import flatten_using_array_flatten
+
+    session = _get_session()
+
+    if session._is_snowflake:
+        return flatten_using_array_flatten(col)
+
     return Column.invoke_expression_over_column(col, expression.Flatten)

@@ -1650,6 +2500,13 @@ def arrays_zip(*cols: ColumnOrName) -> Column:

 @meta(unsupported_engines=["bigquery", "duckdb", "postgres"])
 def map_concat(*cols: t.Union[ColumnOrName, t.Iterable[ColumnOrName]]) -> Column:
+    from sqlframe.base.function_alternatives import map_concat_using_map_cat
+
+    session = _get_session()
+
+    if session._is_snowflake:
+        return map_concat_using_map_cat(*cols)
+
     columns = list(flatten(cols)) if not isinstance(cols[0], (str, Column)) else cols  # type: ignore
     if len(columns) == 1:
         return Column.invoke_anonymous_function(columns[0], "MAP_CONCAT")  # type: ignore
@@ -1660,11 +2517,28 @@ def map_concat(*cols: t.Union[ColumnOrName, t.Iterable[ColumnOrName]]) -> Column
 def sequence(
     start: ColumnOrName, stop: ColumnOrName, step: t.Optional[ColumnOrName] = None
 ) -> Column:
+    from sqlframe.base.function_alternatives import (
+        sequence_from_array_generate_range,
+        sequence_from_generate_array,
+        sequence_from_generate_series,
+    )
+
+    session = _get_session()
+
+    if session._is_bigquery:
+        return sequence_from_generate_array(start, stop, step)
+
+    if session._is_duckdb:
+        return sequence_from_generate_series(start, stop, step)
+
+    if session._is_snowflake:
+        return sequence_from_array_generate_range(start, stop, step)
+
     return Column(
         expression.GenerateSeries(
-            start=Column.ensure_col(start).expression,
-            end=Column.ensure_col(stop).expression,
-            step=Column.ensure_col(step).expression if step is not None else None,
+            start=Column.ensure_col(start).column_expression,
+            end=Column.ensure_col(stop).column_expression,
+            step=Column.ensure_col(step).column_expression if step is not None else None,
         )
     )
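
Alongside the engine dispatch, `sequence` (and most of the hunks below) swaps `.expression` for `.column_expression` when nesting a Column inside a new sqlglot node; the intent, presumably, is to embed the column's underlying expression rather than any alias wrapper around it. The distinction is easy to see with sqlglot alone:

    # Sketch of the alias-unwrapping distinction using sqlglot directly; how
    # sqlframe's column_expression is implemented is assumed, not shown here.
    import sqlglot

    aliased = sqlglot.parse_one("x + 1 AS y")
    print(aliased.sql())            # 'x + 1 AS y'  (alias wrapper included)
    print(aliased.unalias().sql())  # 'x + 1'       (safe to nest in another expression)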
 
@@ -1778,6 +2652,23 @@ def map_zip_with(

 @meta()
 def typeof(col: ColumnOrName) -> Column:
+    from sqlframe.base.function_alternatives import (
+        typeof_bgutil,
+        typeof_from_variant,
+        typeof_pg_typeof,
+    )
+
+    session = _get_session()
+
+    if session._is_bigquery:
+        return typeof_bgutil(col)
+
+    if session._is_postgres:
+        return typeof_pg_typeof(col)
+
+    if session._is_snowflake:
+        return typeof_from_variant(col)
+
     return Column.invoke_anonymous_function(col, "TYPEOF")

@@ -1913,9 +2804,21 @@ def to_binary(col: ColumnOrName, format: t.Optional[ColumnOrName] = None) -> Col

 @meta()
 def any_value(col: ColumnOrName, ignoreNulls: t.Optional[t.Union[bool, Column]] = None) -> Column:
+    session = _get_session()
+
+    if session._is_duckdb:
+        if not ignoreNulls:
+            logger.warning("Nulls are always ignored when using `ANY_VALUE` on this engine")
+        ignoreNulls = None
+
+    if session._is_bigquery or session._is_postgres or session._is_snowflake:
+        if ignoreNulls:
+            logger.warning("Ignoring nulls is not supported in this dialect")
+        ignoreNulls = None
+
     column = Column.invoke_expression_over_column(col, expression.AnyValue)
     if ignoreNulls:
-        return Column(expression.IgnoreNulls(this=column.expression))
+        return Column(expression.IgnoreNulls(this=column.column_expression))
     return column

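`any_value` handles the `ignoreNulls` flag per engine instead of rewriting the call: DuckDB's ANY_VALUE already skips nulls, and BigQuery, Postgres, and Snowflake have no way to express the modifier, so in both situations the flag is dropped with a warning rather than emitting syntax the engine would reject. On engines that keep the flag, the call is wrapped in an IgnoreNulls node, which renders roughly like this (built with sqlglot directly; the exact output depends on the sqlglot version):

    from sqlglot import exp

    wrapped = exp.IgnoreNulls(this=exp.func("ANY_VALUE", exp.column("v")))
    print(wrapped.sql(dialect="spark"))  # e.g. 'ANY_VALUE(v) IGNORE NULLS'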
 
@@ -2056,7 +2959,7 @@ def cardinality(col: ColumnOrName) -> Column:

 @meta()
 def char(col: ColumnOrName) -> Column:
-    return Column(expression.Chr(expressions=Column.ensure_col(col).expression))
+    return Column(expression.Chr(expressions=Column.ensure_col(col).column_expression))


 @meta(unsupported_engines="*")
@@ -2072,7 +2975,7 @@ def character_length(str: ColumnOrName) -> Column:
 @meta(unsupported_engines=["bigquery", "postgres"])
 def contains(left: ColumnOrName, right: ColumnOrName) -> Column:
     return Column.invoke_expression_over_column(
-        left, expression.Contains, expression=Column.ensure_col(right).expression
+        left, expression.Contains, expression=Column.ensure_col(right).column_expression
     )

@@ -2084,9 +2987,9 @@ def convert_timezone(

     return Column(
         expression.ConvertTimezone(
-            timestamp=to_timestamp(Column.ensure_col(sourceTs)).expression,
-            source_tz=sourceTz.expression if sourceTz else None,
-            target_tz=Column.ensure_col(targetTz).expression,
+            timestamp=to_timestamp(Column.ensure_col(sourceTs)).column_expression,
+            source_tz=sourceTz.column_expression if sourceTz else None,
+            target_tz=Column.ensure_col(targetTz).column_expression,
         )
     )

@@ -2180,6 +3083,13 @@ def current_timezone() -> Column:

 @meta()
 def current_user() -> Column:
+    from sqlframe.base.function_alternatives import current_user_from_session_user
+
+    session = _get_session()
+
+    if session._is_bigquery:
+        return current_user_from_session_user()
+
     return Column.invoke_expression_over_column(None, expression.CurrentUser)

@@ -2204,6 +3114,21 @@ def datepart(field: ColumnOrName, source: ColumnOrName) -> Column:

 @meta(unsupported_engines=["bigquery", "postgres", "snowflake"])
 def day(col: ColumnOrName) -> Column:
+    from sqlframe.base.function_alternatives import day_with_try_to_timestamp
+
+    session = _get_session()
+
+    if session._is_duckdb:
+        try_to_timestamp = get_func_from_session("try_to_timestamp")
+        to_date = get_func_from_session("to_date")
+        when = get_func_from_session("when")
+        _is_string = get_func_from_session("_is_string")
+        coalesce = get_func_from_session("coalesce")
+        col = when(
+            _is_string(col),
+            coalesce(try_to_timestamp(col), to_date(col)),
+        ).otherwise(col)
+
     return Column.invoke_expression_over_column(col, expression.Day)

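Note that the DuckDB branch of `day` resolves `when`, `coalesce`, `try_to_timestamp`, and friends through `get_func_from_session` rather than calling the module-level helpers, so each nested call is itself the engine-aware variant and the string-coercion logic composes with the other fallbacks in this file. The user-visible effect is that string input is parsed to a timestamp or date before the day part is extracted, for example (assuming sqlframe and duckdb are installed; the expected value is hedged):

    from sqlframe.duckdb import DuckDBSession
    from sqlframe.duckdb import functions as F

    session = DuckDBSession()
    df = session.sql("SELECT '2024-03-05' AS d")
    print(df.select(F.day("d").alias("day_of_month")).collect())  # expected [Row(day_of_month=5)]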
 
@@ -2222,6 +3147,19 @@ def elt(*inputs: ColumnOrName) -> Column:

 @meta()
 def endswith(str: ColumnOrName, suffix: ColumnOrName) -> Column:
+    from sqlframe.base.function_alternatives import (
+        endswith_using_like,
+        endswith_with_underscore,
+    )
+
+    session = _get_session()
+
+    if session._is_bigquery or session._is_duckdb:
+        return endswith_with_underscore(str, suffix)
+
+    if session._is_postgres:
+        return endswith_using_like(str, suffix)
+
     return Column.invoke_anonymous_function(str, "endswith", suffix)

@@ -2237,6 +3175,11 @@ def every(col: ColumnOrName) -> Column:

 @meta()
 def extract(field: ColumnOrName, source: ColumnOrName) -> Column:
+    session = _get_session()
+
+    if session._is_bigquery:
+        field = expression.Var(this=Column.ensure_col(field).alias_or_name)  # type: ignore
+
     return Column.invoke_expression_over_column(field, expression.Extract, expression=source)

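The `extract` change coerces the field argument to a bare sqlglot `Var` on BigQuery, presumably because BigQuery only accepts an unquoted date-part keyword inside EXTRACT and would reject a quoted column or string literal in that position. Constructing the same node with sqlglot shows the shape being targeted (output hedged by sqlglot version):

    from sqlglot import exp

    node = exp.Extract(this=exp.Var(this="YEAR"), expression=exp.column("ts"))
    print(node.sql(dialect="bigquery"))  # expected 'EXTRACT(YEAR FROM ts)'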
 
@@ -2250,7 +3193,7 @@ def first_value(col: ColumnOrName, ignoreNulls: t.Optional[t.Union[bool, Column]
     column = Column.invoke_expression_over_column(col, expression.FirstValue)

     if ignoreNulls:
-        return Column(expression.IgnoreNulls(this=column.expression))
+        return Column(expression.IgnoreNulls(this=column.column_expression))
     return column

@@ -2665,8 +3608,8 @@ def ilike(
     if escapeChar is not None:
         return Column(
             expression.Escape(
-                this=column.expression,
-                expression=Column.ensure_col(escapeChar).expression,
+                this=column.column_expression,
+                expression=Column.ensure_col(escapeChar).column_expression,
             )
         )
     return column
@@ -2913,7 +3856,7 @@ def last_value(col: ColumnOrName, ignoreNulls: t.Optional[t.Union[bool, Column]]
     column = Column.invoke_expression_over_column(col, expression.LastValue)

     if ignoreNulls:
-        return Column(expression.IgnoreNulls(this=column.expression))
+        return Column(expression.IgnoreNulls(this=column.column_expression))
     return column

@@ -2963,6 +3906,11 @@ def left(str: ColumnOrName, len: ColumnOrName) -> Column:
     >>> df.select(left(df.a, df.b).alias('r')).collect()
     [Row(r='Spa')]
     """
+    session = _get_session()
+
+    if session._is_postgres:
+        len = Column.ensure_col(len).cast("integer")
+
     return Column.invoke_expression_over_column(str, expression.Left, expression=len)

@@ -3014,8 +3962,8 @@ def like(
     if escapeChar is not None:
         return Column(
             expression.Escape(
-                this=column.expression,
-                expression=Column.ensure_col(escapeChar).expression,
+                this=column.column_expression,
+                expression=Column.ensure_col(escapeChar).column_expression,
             )
         )
     return column
@@ -3853,6 +4801,16 @@ def position(
     |                4|
     +-----------------+
     """
+    from sqlframe.base.function_alternatives import position_as_strpos
+
+    session = _get_session()
+
+    if session._is_bigquery:
+        return position_as_strpos(substr, str, start)
+
+    if session._is_postgres:
+        start = Column.ensure_col(start).cast("integer") if start else None
+
     if start is not None:
         return Column.invoke_expression_over_column(
             str, expression.StrPosition, substr=substr, position=start
@@ -4038,6 +4996,13 @@ def regexp(str: ColumnOrName, regexp: ColumnOrName) -> Column:
     |               true|
     +-------------------+
     """
+    from sqlframe.base.function_alternatives import regexp_extract_only_one_group
+
+    session = _get_session()
+
+    if session._is_bigquery:
+        return regexp_extract_only_one_group(str, regexp)  # type: ignore
+
     return Column.invoke_anonymous_function(str, "regexp", regexp)

@@ -4575,6 +5540,11 @@ def right(str: ColumnOrName, len: ColumnOrName) -> Column:
     >>> df.select(right(df.a, df.b).alias('r')).collect()
     [Row(r='SQL')]
     """
+    session = _get_session()
+
+    if session._is_postgres:
+        len = Column.ensure_col(len).cast("integer")
+
     return Column.invoke_expression_over_column(str, expression.Right, expression=len)

@@ -5030,6 +6000,13 @@ def to_number(col: ColumnOrName, format: ColumnOrName) -> Column:
     >>> df.select(to_number(df.e, lit("$99.99")).alias('r')).collect()
     [Row(r=Decimal('78.12'))]
     """
+    from sqlframe.base.function_alternatives import to_number_using_to_double
+
+    session = _get_session()
+
+    if session._is_snowflake:
+        return to_number_using_to_double(col, format)
+
     return Column.invoke_expression_over_column(col, expression.ToNumber, format=format)

@@ -5149,11 +6126,14 @@ def to_unix_timestamp(
     [Row(r=None)]
     >>> spark.conf.unset("spark.sql.session.timeZone")
     """
-    from sqlframe.base.session import _BaseSession
+    session = _get_session()
+
+    if session._is_duckdb:
+        format = format or _BaseSession().default_time_format

     if format is not None:
         return Column.invoke_expression_over_column(
-            timestamp, expression.StrToUnix, format=_BaseSession().format_time(format)
+            timestamp, expression.StrToUnix, format=session.format_time(format)
         )
     else:
         return Column.invoke_expression_over_column(timestamp, expression.StrToUnix)
@@ -5306,10 +6286,21 @@ def try_element_at(col: ColumnOrName, extraction: ColumnOrName) -> Column:
     >>> df.select(try_element_at(df.data, lit("a")).alias('r')).collect()
     [Row(r=1.0)]
     """
+    session = _get_session()
+
+    if session._is_databricks or session._is_duckdb or session._is_postgres or session._is_spark:
+        lit = get_func_from_session("lit")
+        extraction = Column.ensure_col(extraction)
+        if (
+            isinstance(extraction.column_expression, expression.Literal)
+            and extraction.column_expression.is_number
+        ):
+            extraction = extraction - lit(1)
+
     return Column(
         expression.Bracket(
-            this=Column.ensure_col(col).expression,
-            expressions=[Column.ensure_col(extraction).expression],
+            this=Column.ensure_col(col).column_expression,
+            expressions=[Column.ensure_col(extraction).column_expression],
             safe=True,
         )
     )
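
`try_element_at` now shifts numeric literal positions down by one on the listed engines before building the safe `Bracket` subscript, reconciling Spark's 1-based element_at positions with a 0-based canonical bracket index (the assumption being that sqlglot re-offsets the subscript per dialect at generation time). The bracket node is the same one returned above; rendering it across dialects with sqlglot makes the per-engine differences visible:

    # Outputs depend on the installed sqlglot version and are intentionally
    # not reproduced here.
    from sqlglot import exp

    node = exp.Bracket(
        this=exp.column("xs"),
        expressions=[exp.Literal.number(0)],
        safe=True,
    )
    for dialect in ("spark", "duckdb", "postgres"):
        print(dialect, "->", node.sql(dialect=dialect))
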
@@ -5340,12 +6331,27 @@ def try_to_timestamp(col: ColumnOrName, format: t.Optional[ColumnOrName] = None)
     >>> df.select(try_to_timestamp(df.t, lit('yyyy-MM-dd HH:mm:ss')).alias('dt')).collect()
     [Row(dt=datetime.datetime(1997, 2, 28, 10, 30))]
     """
-    from sqlframe.base.session import _BaseSession
+    from sqlframe.base.function_alternatives import (
+        try_to_timestamp_pgtemp,
+        try_to_timestamp_safe,
+        try_to_timestamp_strptime,
+    )
+
+    session = _get_session()
+
+    if session._is_bigquery:
+        return try_to_timestamp_safe(col, format)
+
+    if session._is_duckdb:
+        return try_to_timestamp_strptime(col, format)
+
+    if session._is_postgres:
+        return try_to_timestamp_pgtemp(col, format)

     return Column.invoke_anonymous_function(
         col,
         "try_to_timestamp",
-        _BaseSession().format_execution_time(format),  # type: ignore
+        session.format_execution_time(format),  # type: ignore
     )

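`try_to_timestamp` gains three rewrites; judging by the helper names, BigQuery uses a SAFE-style parse, DuckDB a strptime-based parse, and Postgres a temporary helper function, while other engines keep the anonymous TRY_TO_TIMESTAMP call with the session-formatted pattern. DuckDB's likely building blocks can be exercised on their own (requires the duckdb package):

    import duckdb

    # strptime parses with a C-style format; the TRY_* variants return NULL
    # instead of raising, which is the behavior the strptime-based alternative
    # presumably builds on.
    print(duckdb.sql("SELECT STRPTIME('1997-02-28 10:30:00', '%Y-%m-%d %H:%M:%S') AS ts").fetchall())
    print(duckdb.sql("SELECT TRY_CAST('not a date' AS TIMESTAMP) AS ts").fetchall())  # [(None,)]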
 
@@ -5841,6 +6847,27 @@ def years(col: ColumnOrName) -> Column:
 # SQLFrame specific
 @meta()
 def _is_string(col: ColumnOrName) -> Column:
+    from sqlframe.base.function_alternatives import (
+        _is_string_using_typeof_char_varying,
+        _is_string_using_typeof_string,
+        _is_string_using_typeof_string_lcase,
+        _is_string_using_typeof_varchar,
+    )
+
+    session = _get_session()
+
+    if session._is_bigquery:
+        return _is_string_using_typeof_string(col)
+
+    if session._is_duckdb:
+        return _is_string_using_typeof_varchar(col)
+
+    if session._is_postgres:
+        return _is_string_using_typeof_char_varying(col)
+
+    if session._is_databricks or session._is_spark:
+        return _is_string_using_typeof_string_lcase(col)
+
     col = Column.invoke_anonymous_function(col, "TO_VARIANT")
     return Column.invoke_anonymous_function(col, "IS_VARCHAR")
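
`_is_string` is the SQLFrame-internal predicate other fallbacks build on (the DuckDB branch of `day` above, for instance). The default implementation targets Snowflake's TO_VARIANT/IS_VARCHAR, and the new branches instead compare each engine's typeof() output against its string type name. The names each helper appears to match, inferred from the helper names rather than verified against every engine:

    # Assumptions inferred from the helper names, not verified engine output.
    EXPECTED_STRING_TYPENAME = {
        "bigquery": "STRING",             # _is_string_using_typeof_string
        "duckdb": "VARCHAR",              # _is_string_using_typeof_varchar
        "postgres": "character varying",  # _is_string_using_typeof_char_varying
        "databricks": "string",           # _is_string_using_typeof_string_lcase
        "spark": "string",
    }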
 
@@ -5889,7 +6916,7 @@ def _get_lambda_from_func(lambda_expression: t.Callable):
         for x in lambda_expression.__code__.co_varnames
     ]
     return expression.Lambda(
-        this=lambda_expression(*[Column(x) for x in variables]).expression,
+        this=lambda_expression(*[Column(x) for x in variables]).column_expression,
         expressions=variables,
     )