sqlframe 1.1.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (98)
  1. sqlframe/__init__.py +0 -0
  2. sqlframe/_version.py +16 -0
  3. sqlframe/base/__init__.py +0 -0
  4. sqlframe/base/_typing.py +39 -0
  5. sqlframe/base/catalog.py +1163 -0
  6. sqlframe/base/column.py +388 -0
  7. sqlframe/base/dataframe.py +1519 -0
  8. sqlframe/base/decorators.py +51 -0
  9. sqlframe/base/exceptions.py +14 -0
  10. sqlframe/base/function_alternatives.py +1055 -0
  11. sqlframe/base/functions.py +1678 -0
  12. sqlframe/base/group.py +102 -0
  13. sqlframe/base/mixins/__init__.py +0 -0
  14. sqlframe/base/mixins/catalog_mixins.py +419 -0
  15. sqlframe/base/mixins/readwriter_mixins.py +118 -0
  16. sqlframe/base/normalize.py +84 -0
  17. sqlframe/base/operations.py +87 -0
  18. sqlframe/base/readerwriter.py +679 -0
  19. sqlframe/base/session.py +585 -0
  20. sqlframe/base/transforms.py +13 -0
  21. sqlframe/base/types.py +418 -0
  22. sqlframe/base/util.py +242 -0
  23. sqlframe/base/window.py +139 -0
  24. sqlframe/bigquery/__init__.py +23 -0
  25. sqlframe/bigquery/catalog.py +255 -0
  26. sqlframe/bigquery/column.py +1 -0
  27. sqlframe/bigquery/dataframe.py +54 -0
  28. sqlframe/bigquery/functions.py +378 -0
  29. sqlframe/bigquery/group.py +14 -0
  30. sqlframe/bigquery/readwriter.py +29 -0
  31. sqlframe/bigquery/session.py +89 -0
  32. sqlframe/bigquery/types.py +1 -0
  33. sqlframe/bigquery/window.py +1 -0
  34. sqlframe/duckdb/__init__.py +20 -0
  35. sqlframe/duckdb/catalog.py +108 -0
  36. sqlframe/duckdb/column.py +1 -0
  37. sqlframe/duckdb/dataframe.py +55 -0
  38. sqlframe/duckdb/functions.py +47 -0
  39. sqlframe/duckdb/group.py +14 -0
  40. sqlframe/duckdb/readwriter.py +111 -0
  41. sqlframe/duckdb/session.py +65 -0
  42. sqlframe/duckdb/types.py +1 -0
  43. sqlframe/duckdb/window.py +1 -0
  44. sqlframe/postgres/__init__.py +23 -0
  45. sqlframe/postgres/catalog.py +106 -0
  46. sqlframe/postgres/column.py +1 -0
  47. sqlframe/postgres/dataframe.py +54 -0
  48. sqlframe/postgres/functions.py +61 -0
  49. sqlframe/postgres/group.py +14 -0
  50. sqlframe/postgres/readwriter.py +29 -0
  51. sqlframe/postgres/session.py +68 -0
  52. sqlframe/postgres/types.py +1 -0
  53. sqlframe/postgres/window.py +1 -0
  54. sqlframe/redshift/__init__.py +23 -0
  55. sqlframe/redshift/catalog.py +127 -0
  56. sqlframe/redshift/column.py +1 -0
  57. sqlframe/redshift/dataframe.py +54 -0
  58. sqlframe/redshift/functions.py +18 -0
  59. sqlframe/redshift/group.py +14 -0
  60. sqlframe/redshift/readwriter.py +29 -0
  61. sqlframe/redshift/session.py +53 -0
  62. sqlframe/redshift/types.py +1 -0
  63. sqlframe/redshift/window.py +1 -0
  64. sqlframe/snowflake/__init__.py +26 -0
  65. sqlframe/snowflake/catalog.py +134 -0
  66. sqlframe/snowflake/column.py +1 -0
  67. sqlframe/snowflake/dataframe.py +54 -0
  68. sqlframe/snowflake/functions.py +18 -0
  69. sqlframe/snowflake/group.py +14 -0
  70. sqlframe/snowflake/readwriter.py +29 -0
  71. sqlframe/snowflake/session.py +53 -0
  72. sqlframe/snowflake/types.py +1 -0
  73. sqlframe/snowflake/window.py +1 -0
  74. sqlframe/spark/__init__.py +23 -0
  75. sqlframe/spark/catalog.py +1028 -0
  76. sqlframe/spark/column.py +1 -0
  77. sqlframe/spark/dataframe.py +54 -0
  78. sqlframe/spark/functions.py +22 -0
  79. sqlframe/spark/group.py +14 -0
  80. sqlframe/spark/readwriter.py +29 -0
  81. sqlframe/spark/session.py +90 -0
  82. sqlframe/spark/types.py +1 -0
  83. sqlframe/spark/window.py +1 -0
  84. sqlframe/standalone/__init__.py +26 -0
  85. sqlframe/standalone/catalog.py +13 -0
  86. sqlframe/standalone/column.py +1 -0
  87. sqlframe/standalone/dataframe.py +36 -0
  88. sqlframe/standalone/functions.py +1 -0
  89. sqlframe/standalone/group.py +14 -0
  90. sqlframe/standalone/readwriter.py +19 -0
  91. sqlframe/standalone/session.py +40 -0
  92. sqlframe/standalone/types.py +1 -0
  93. sqlframe/standalone/window.py +1 -0
  94. sqlframe-1.1.3.dist-info/LICENSE +21 -0
  95. sqlframe-1.1.3.dist-info/METADATA +172 -0
  96. sqlframe-1.1.3.dist-info/RECORD +98 -0
  97. sqlframe-1.1.3.dist-info/WHEEL +5 -0
  98. sqlframe-1.1.3.dist-info/top_level.txt +1 -0
@@ -0,0 +1,1055 @@
1
+ from __future__ import annotations
2
+
3
+ import logging
4
+ import math
5
+ import re
6
+ import typing as t
7
+
8
+ from sqlglot import exp as expression
9
+ from sqlglot.helper import ensure_list
10
+
11
+ from sqlframe.base.column import Column
12
+ from sqlframe.base.util import get_func_from_session
13
+
14
+ if t.TYPE_CHECKING:
15
+ from sqlframe.base._typing import ColumnOrLiteral, ColumnOrName
16
+
17
+ logger = logging.getLogger(__name__)
18
+
19
+
20
def e_literal() -> Column:
    """Return ``math.e`` wrapped by the session's ``lit`` function."""
    make_literal = get_func_from_session("lit")
    return make_literal(math.e)
24
+
25
+
26
def expm1_from_exp(col: ColumnOrName) -> Column:
    """Emulate ``expm1`` as ``exp(col) - 1`` using the session's ``exp`` and ``lit``."""
    exponential = get_func_from_session("exp")
    literal = get_func_from_session("lit")
    exponentiated = exponential(col)
    return exponentiated - literal(1)
30
+
31
+
32
def log1p_from_log(col: ColumnOrName) -> Column:
    """Emulate ``log1p`` as ``log(col + 1)`` using the session's ``log`` and ``lit``."""
    from sqlframe.base.session import _BaseSession

    active_session: _BaseSession = _BaseSession()
    log_func = get_func_from_session("log", active_session)
    literal = get_func_from_session("lit", active_session)
    shifted = col + literal(1)
    return log_func(shifted)
39
+
40
+
41
def rint_from_round(col: ColumnOrName) -> Column:
    """Emulate ``rint`` by calling the session's ``round`` with a scale of 0."""
    from sqlframe.base.session import _BaseSession

    active_session = _BaseSession()
    round_func = get_func_from_session("round", active_session)
    return round_func(col, 0)
46
+
47
+
48
def kurtosis_from_kurtosis_pop(col: ColumnOrName) -> Column:
    """Invoke the anonymous SQL function ``KURTOSIS_POP`` on ``col``."""
    function_name = "KURTOSIS_POP"
    return Column.invoke_anonymous_function(col, function_name)
50
+
51
+
52
def collect_set_from_list_distinct(col: ColumnOrName) -> Column:
    """Emulate ``collect_set`` by passing a ``Distinct`` of ``col`` to ``collect_list``."""
    collect_list = get_func_from_session("collect_list")
    distinct_values = expression.Distinct(expressions=[Column(col).expression])
    return collect_list(Column(distinct_values))
55
+
56
+
57
def first_always_ignore_nulls(col: ColumnOrName, ignorenulls: t.Optional[bool] = None) -> Column:
    """Call the base ``first`` with only ``col``.

    ``ignorenulls`` is accepted for signature compatibility but is not
    forwarded to ``first``.
    """
    from sqlframe.base.functions import first as base_first

    return base_first(col)
61
+
62
+
63
def factorial_from_case_statement(col: ColumnOrName) -> Column:
    """Emulate ``factorial`` with a CASE expression covering inputs 1 through 20.

    Each branch compares ``col`` to ``n`` and yields ``n!``; any other input
    falls through to a ``NULL`` literal.
    """
    from sqlframe.base.session import _BaseSession

    active_session: _BaseSession = _BaseSession()
    when = get_func_from_session("when", active_session)
    col_func = get_func_from_session("col", active_session)
    lit = get_func_from_session("lit", active_session)

    # The first branch comes from the standalone ``when`` function; the
    # remaining branches are chained onto the resulting column.
    case_column = when(col_func(col) == lit(1), lit(1))
    for n in range(2, 21):
        case_column = case_column.when(col_func(col) == lit(n), lit(math.factorial(n)))
    return case_column.otherwise(lit(None))
155
+
156
+
157
def factorial_ensure_int(col: ColumnOrName) -> Column:
    """Invoke ``FACTORIAL`` on ``col`` after casting it to ``integer``."""
    to_column = get_func_from_session("col")
    integer_input = to_column(col).cast("integer")
    return Column.invoke_anonymous_function(integer_input, "FACTORIAL")
161
+
162
+
163
def isnan_using_equal(col: ColumnOrName) -> Column:
    """Emulate ``isnan`` as an equality comparison between ``col`` and a NaN literal."""
    literal = get_func_from_session("lit")
    left = Column(col).expression
    right = literal(float("nan")).expression
    return Column(expression.EQ(this=left, expression=right))
168
+
169
+
170
def isnull_using_equal(col: ColumnOrName) -> Column:
    """Emulate ``isnull`` with an ``Is`` expression comparing ``col`` to a ``None`` literal."""
    literal = get_func_from_session("lit")
    to_column = get_func_from_session("col")
    subject = to_column(col).expression
    null_literal = literal(None).expression
    return Column(expression.Is(this=subject, expression=null_literal))
174
+
175
+
176
def nanvl_as_case(col1: ColumnOrName, col2: ColumnOrName) -> Column:
    """Emulate ``nanvl``: yield ``col1`` when it is not NaN, otherwise ``col2``."""
    when = get_func_from_session("when")
    isnan = get_func_from_session("isnan")
    to_column = get_func_from_session("col")
    not_nan = ~isnan(col1)
    return when(not_nan, to_column(col1)).otherwise(to_column(col2))
181
+
182
+
183
def percentile_approx_without_accuracy(
    col: ColumnOrName,
    percentage: t.Union[ColumnOrLiteral, t.List[float], t.Tuple[float]],
    accuracy: t.Optional[float] = None,
) -> Column:
    """Call the base ``percentile_approx`` without the ``accuracy`` argument.

    A truthy ``accuracy`` only triggers a warning; it is never forwarded.
    """
    from sqlframe.base.functions import percentile_approx as base_percentile_approx

    if accuracy:
        logger.warning("Accuracy is ignored since it is not supported in this dialect")
    return base_percentile_approx(col, percentage)
193
+
194
+
195
def percentile_approx_without_accuracy_and_plural(
    col: ColumnOrName,
    percentage: t.Union[ColumnOrLiteral, t.List[float], t.Tuple[float]],
    accuracy: t.Optional[float] = None,
) -> Column:
    """Emulate ``percentile_approx`` by indexing into ``APPROX_QUANTILES(col, 100)``.

    Each percentage is scaled by 100 and truncated to an int to form the
    bracket index. A list or tuple of percentages yields an array with one
    lookup per entry. A truthy ``accuracy`` only triggers a warning.
    """
    lit = get_func_from_session("lit")
    array = get_func_from_session("array")
    col_func = get_func_from_session("col")

    def quantile_lookup(fraction: float) -> expression.Bracket:
        # APPROX_QUANTILES(col, 100)[<fraction * 100>]
        quantiles = expression.Anonymous(
            this="APPROX_QUANTILES",
            expressions=[col_func(col).expression, lit(100).expression],
        )
        index = lit(int(fraction * 100)).cast("int").expression
        return expression.Bracket(this=quantiles, expressions=[index], offset=0, safe=False)

    if accuracy:
        logger.warning("Accuracy is ignored since it is not supported in this dialect")
    if not isinstance(percentage, (list, tuple)):
        return Column(quantile_lookup(percentage))  # type: ignore
    lookups = [quantile_lookup(fraction) for fraction in percentage]
    return array(*lookups)
220
+
221
+
222
def percentile_without_disc(
    col: ColumnOrName,
    percentage: t.Union[ColumnOrLiteral, t.List[float], t.Tuple[float]],
    frequency: t.Optional[ColumnOrLiteral] = None,
) -> Column:
    """Build an anonymous ``PERCENTILE(col, percentage[, frequency])`` call.

    ``percentage`` and ``frequency`` are wrapped with the session's ``lit``
    unless they are already ``Column`` instances. ``frequency`` is only
    appended when it is truthy.
    """
    lit = get_func_from_session("lit")
    col_func = get_func_from_session("col")

    percentage_col = percentage if isinstance(percentage, Column) else lit(percentage)
    func_expressions = [
        col_func(col).expression,
        percentage_col.expression,
    ]
    if frequency:
        frequency_col = frequency if isinstance(frequency, Column) else lit(frequency)
        # Append the underlying expression, not the Column wrapper, so every
        # argument in the list is an expression like the two above.
        func_expressions.append(frequency_col.expression)
    return Column(
        expression.Anonymous(
            this="PERCENTILE",
            expressions=func_expressions,
        )
    )
243
+
244
+
245
def rand_no_seed(seed: t.Optional[ColumnOrLiteral] = None) -> Column:
    """Call the base ``rand`` without a seed.

    Any provided ``seed`` is ignored and a warning is logged. The check is
    ``is not None`` so that a seed of ``0`` still produces the warning.
    """
    from sqlframe.base.functions import rand

    if seed is not None:
        logger.warning("Seed is ignored since it is not supported in this dialect")
    return rand()
251
+
252
+
253
def round_cast_as_numeric(col: ColumnOrName, scale: t.Optional[int] = None) -> Column:
    """Call the base ``round`` on ``col`` cast to ``numeric``, forwarding ``scale``."""
    from sqlframe.base.functions import round as base_round

    to_column = get_func_from_session("col")
    numeric_input = to_column(col).cast("numeric")
    return base_round(numeric_input, scale)
259
+
260
+
261
def year_from_extract(col: ColumnOrName) -> Column:
    """Build an ``Extract`` of the ``year`` part from ``col`` cast to ``date``."""
    to_column = get_func_from_session("col")
    part = expression.Var(this="year")
    date_expr = to_column(col).cast("date").expression
    return Column(expression.Extract(this=part, expression=date_expr))
269
+
270
+
271
def quarter_from_extract(col: ColumnOrName) -> Column:
    """Build an ``Extract`` of the ``quarter`` part from ``col`` cast to ``date``."""
    to_column = get_func_from_session("col")
    part = expression.Var(this="quarter")
    date_expr = to_column(col).cast("date").expression
    return Column(expression.Extract(this=part, expression=date_expr))
279
+
280
+
281
def month_from_extract(col: ColumnOrName) -> Column:
    """Build an ``Extract`` of the ``month`` part from ``col`` cast to ``date``."""
    to_column = get_func_from_session("col")
    part = expression.Var(this="month")
    date_expr = to_column(col).cast("date").expression
    return Column(expression.Extract(this=part, expression=date_expr))
289
+
290
+
291
def dayofweek_from_extract(col: ColumnOrName) -> Column:
    """Build an ``Extract`` of the ``dayofweek`` part from ``col`` cast to ``date``."""
    to_column = get_func_from_session("col")
    part = expression.Var(this="dayofweek")
    date_expr = to_column(col).cast("date").expression
    return Column(expression.Extract(this=part, expression=date_expr))
299
+
300
+
301
def dayofweek_from_extract_with_isodow(col: ColumnOrName) -> Column:
    """Build an ``Extract`` of the ``isodow`` part from ``col`` cast to ``date``."""
    to_column = get_func_from_session("col")
    part = expression.Var(this="isodow")
    date_expr = to_column(col).cast("date").expression
    return Column(expression.Extract(this=part, expression=date_expr))
309
+
310
+
311
def dayofmonth_from_extract(col: ColumnOrName) -> Column:
    """Build an ``Extract`` of the ``dayofmonth`` part from ``col`` cast to ``date``."""
    to_column = get_func_from_session("col")
    part = expression.Var(this="dayofmonth")
    date_expr = to_column(col).cast("date").expression
    return Column(expression.Extract(this=part, expression=date_expr))
319
+
320
+
321
def dayofmonth_from_extract_with_day(col: ColumnOrName) -> Column:
    """Build an ``Extract`` of the ``day`` part from ``col`` cast to ``date``."""
    to_column = get_func_from_session("col")
    part = expression.Var(this="day")
    date_expr = to_column(col).cast("date").expression
    return Column(expression.Extract(this=part, expression=date_expr))
329
+
330
+
331
def dayofyear_from_extract(col: ColumnOrName) -> Column:
    """Build an ``Extract`` of the ``dayofyear`` part from ``col`` cast to ``date``."""
    to_column = get_func_from_session("col")
    part = expression.Var(this="dayofyear")
    date_expr = to_column(col).cast("date").expression
    return Column(expression.Extract(this=part, expression=date_expr))
339
+
340
+
341
def dayofyear_from_extract_doy(col: ColumnOrName) -> Column:
    """Build an ``Extract`` of the ``doy`` part from ``col`` cast to ``date``."""
    to_column = get_func_from_session("col")
    part = expression.Var(this="doy")
    date_expr = to_column(col).cast("date").expression
    return Column(expression.Extract(this=part, expression=date_expr))
349
+
350
+
351
def hour_from_extract(col: ColumnOrName) -> Column:
    """Build an ``Extract`` of the ``hour`` part from ``col`` (no cast applied)."""
    to_column = get_func_from_session("col")
    part = expression.Var(this="hour")
    source_expr = to_column(col).expression
    return Column(expression.Extract(this=part, expression=source_expr))
357
+
358
+
359
def minute_from_extract(col: ColumnOrName) -> Column:
    """Build an ``Extract`` of the ``minute`` part from ``col`` (no cast applied)."""
    to_column = get_func_from_session("col")
    part = expression.Var(this="minute")
    source_expr = to_column(col).expression
    return Column(expression.Extract(this=part, expression=source_expr))
365
+
366
+
367
def second_from_extract(col: ColumnOrName) -> Column:
    """Build an ``Extract`` of the ``second`` part from ``col`` (no cast applied)."""
    to_column = get_func_from_session("col")
    part = expression.Var(this="second")
    source_expr = to_column(col).expression
    return Column(expression.Extract(this=part, expression=source_expr))
373
+
374
+
375
def weekofyear_from_extract_as_week(col: ColumnOrName) -> Column:
    """Build an ``Extract`` of the ``week`` part from ``col`` cast to ``date``."""
    to_column = get_func_from_session("col")
    part = expression.Var(this="week")
    date_expr = to_column(col).cast("date").expression
    return Column(expression.Extract(this=part, expression=date_expr))
383
+
384
+
385
def weekofyear_from_extract_as_isoweek(col: ColumnOrName) -> Column:
    """Build an ``Extract`` of the ``ISOWEEK`` part from ``col`` cast to ``date``."""
    to_column = get_func_from_session("col")
    part = expression.Var(this="ISOWEEK")
    date_expr = to_column(col).cast("date").expression
    return Column(expression.Extract(this=part, expression=date_expr))
393
+
394
+
395
def make_date_casted_as_integer(
    year: ColumnOrName, month: ColumnOrName, day: ColumnOrName
) -> Column:
    """Call the base ``make_date`` with each component cast to ``integer``."""
    from sqlframe.base.functions import make_date as base_make_date

    to_column = get_func_from_session("col")
    integer_parts = [to_column(part).cast("integer") for part in (year, month, day)]
    return base_make_date(*integer_parts)
407
+
408
+
409
def make_date_from_date_func(year: ColumnOrName, month: ColumnOrName, day: ColumnOrName) -> Column:
    """Build an anonymous ``DATE(year, month, day)`` call with integer-cast arguments."""
    to_column = get_func_from_session("col")
    integer_args = [to_column(part).cast("integer").expression for part in (year, month, day)]
    return Column(expression.Anonymous(this="DATE", expressions=integer_args))
422
+
423
+
424
def to_date_from_timestamp(col: ColumnOrName, format: t.Optional[str] = None) -> Column:
    """Convert ``col`` with the session's ``to_timestamp`` and pass the result to ``to_date``."""
    from sqlframe.base.functions import to_date as base_to_date

    to_timestamp = get_func_from_session("to_timestamp")
    as_timestamp = to_timestamp(col, format)
    return base_to_date(as_timestamp)
430
+
431
+
432
def last_day_with_cast(col: ColumnOrName) -> Column:
    """Call the base ``last_day`` on ``col`` cast to ``date``."""
    from sqlframe.base.functions import last_day as base_last_day

    to_column = get_func_from_session("col")
    as_date = to_column(col).cast("date")
    return base_last_day(as_date)
438
+
439
+
440
def sha1_force_sha1_and_to_hex(col: ColumnOrName) -> Column:
    """Build an anonymous ``TO_HEX(SHA1(col))`` call."""
    to_column = get_func_from_session("col")
    digest = expression.Anonymous(this="SHA1", expressions=[to_column(col).expression])
    return Column(expression.Anonymous(this="TO_HEX", expressions=[digest]))
454
+
455
+
456
def hash_from_farm_fingerprint(*cols: ColumnOrName) -> Column:
    """Build an anonymous ``FARM_FINGERPRINT(col)`` call for a single column.

    Raises:
        ValueError: if the number of columns passed is not exactly one. This
            includes the zero-column case, which previously surfaced as an
            ``IndexError`` from ``cols[0]``.
    """
    if len(cols) != 1:
        raise ValueError("This dialect only supports a single column for calculating hash")

    col_func = get_func_from_session("col")

    return Column(
        expression.Anonymous(
            this="FARM_FINGERPRINT",
            expressions=[col_func(cols[0]).expression],
        )
    )
468
+
469
+
470
def date_add_by_multiplication(
    col: ColumnOrName, days: t.Union[ColumnOrName, int], cast_as_date: bool = True
) -> Column:
    """Add ``days`` to ``col`` via the base ``date_add``.

    An ``int`` is passed straight to ``date_add``. Otherwise a one-day
    ``date_add`` (with ``cast_as_date=False``) is multiplied by the ``days``
    column. The result is cast to ``date`` when ``cast_as_date`` is true.
    """
    from sqlframe.base.functions import date_add as base_date_add

    to_column = get_func_from_session("col")

    if not isinstance(days, int):
        shifted = base_date_add(col, 1, cast_as_date=False) * to_column(days)
    else:
        shifted = base_date_add(col, days)
    return shifted.cast("date") if cast_as_date else shifted
484
+
485
+
486
def date_sub_by_multiplication(
    col: ColumnOrName, days: t.Union[ColumnOrName, int], cast_as_date: bool = True
) -> Column:
    """Subtract ``days`` from ``col`` via the base ``date_sub``.

    An ``int`` is passed straight to ``date_sub``. Otherwise a one-day
    ``date_sub`` (with ``cast_as_date=False``) is multiplied by the ``days``
    column. The result is cast to ``date`` when ``cast_as_date`` is true.
    """
    from sqlframe.base.functions import date_sub as base_date_sub

    to_column = get_func_from_session("col")

    if not isinstance(days, int):
        shifted = base_date_sub(col, 1, cast_as_date=False) * to_column(days)
    else:
        shifted = base_date_sub(col, days)
    return shifted.cast("date") if cast_as_date else shifted
500
+
501
+
502
def date_diff_with_subtraction(end: ColumnOrName, start: ColumnOrName) -> Column:
    """Emulate ``date_diff`` by subtracting ``start`` from ``end``, both cast to ``date``."""
    to_column = get_func_from_session("col")
    end_date = to_column(end).cast("date")
    start_date = to_column(start).cast("date")
    return end_date - start_date
506
+
507
+
508
def add_months_by_multiplication(
    start: ColumnOrName, months: t.Union[ColumnOrName, int], cast_as_date: bool = True
) -> Column:
    """Add ``months`` to ``start`` by multiplying a one-month ``add_months``.

    The one-month expression (built with ``cast_as_date=False`` and then
    unnested) is multiplied by ``months``. An ``int`` becomes a literal and
    anything else is treated as a column. The result is cast to ``date`` when
    ``cast_as_date`` is true.
    """
    from sqlframe.base.functions import add_months as base_add_months

    to_column = get_func_from_session("col")
    literal = get_func_from_session("lit")

    if isinstance(months, int):
        multiplier = literal(months)
    else:
        multiplier = to_column(months)
    one_month = base_add_months(start, 1, cast_as_date=False).expression.unnest()
    shifted = to_column(one_month) * multiplier
    return shifted.cast("date") if cast_as_date else shifted
521
+
522
+
523
def months_between_from_age_and_extract(
    date1: ColumnOrName, date2: ColumnOrName, roundOff: t.Optional[bool] = None
) -> Column:
    """Emulate ``months_between`` from ``AGE(date1, date2)``.

    Computes ``year * 12 + month + 1`` from the extracted parts of the ``AGE``
    call and casts the sum to ``bigint``. A truthy ``roundOff`` only triggers
    a warning.
    """
    literal = get_func_from_session("lit")
    to_column = get_func_from_session("col")

    if roundOff:
        logger.warning("Round off is ignored since it is not supported in this dialect")

    date_args = [
        to_column(date1).cast("date").expression,
        to_column(date2).cast("date").expression,
    ]
    age_expression = expression.Anonymous(this="AGE", expressions=date_args)

    year_part = expression.Extract(this=expression.Var(this="year"), expression=age_expression)
    years_as_months = Column(year_part * expression.Literal.number(12))
    month_part = Column(
        expression.Extract(this=expression.Var(this="month"), expression=age_expression)
    )
    total = years_as_months + month_part + literal(1)
    return total.cast("bigint")
546
+
547
+
548
def from_unixtime_from_timestamp(col: ColumnOrName, format: t.Optional[str] = None) -> Column:
    """Emulate ``from_unixtime`` by applying ``TimeToStr`` to ``TO_TIMESTAMP(col)``.

    When ``format`` is ``None`` the session's ``DEFAULT_TIME_FORMAT`` is used.
    """
    from sqlframe.base.session import _BaseSession

    active_session: _BaseSession = _BaseSession()
    literal = get_func_from_session("lit")
    to_column = get_func_from_session("col")

    time_format = active_session.DEFAULT_TIME_FORMAT if format is None else format
    as_timestamp = Column(
        expression.Anonymous(this="TO_TIMESTAMP", expressions=[to_column(col).expression])
    )
    return Column.invoke_expression_over_column(
        as_timestamp,
        expression.TimeToStr,
        format=literal(time_format),
    )
567
+
568
+
569
def unix_timestamp_from_extract(
    timestamp: t.Optional[ColumnOrName] = None, format: t.Optional[str] = None
) -> Column:
    """Extract ``epoch`` from ``to_timestamp(timestamp, format)`` and cast it to ``bigint``."""
    to_timestamp = get_func_from_session("to_timestamp")
    part = expression.Var(this="epoch")
    timestamp_expr = to_timestamp(timestamp, format).expression
    epoch = Column(expression.Extract(this=part, expression=timestamp_expr))
    return epoch.cast("bigint")
579
+
580
+
581
def base64_from_blob(col: ColumnOrLiteral) -> Column:
    """Apply ``ToBase64`` to ``col`` after casting it to ``blob``."""
    as_blob = Column(col).cast("blob")
    return Column.invoke_expression_over_column(as_blob, expression.ToBase64)
583
+
584
+
585
def bas64_from_encode(col: ColumnOrLiteral) -> Column:
    """Build an ``Encode`` of ``col`` cast to ``bytea`` with the ``base64`` charset."""
    as_bytes = Column(col).cast("bytea").expression
    charset = expression.Literal.string("base64")
    return Column(expression.Encode(this=as_bytes, charset=charset))
591
+
592
+
593
def unbase64_from_decode(col: ColumnOrLiteral) -> Column:
    """Build a ``Decode`` of ``col`` with the ``base64`` charset."""
    encoded = Column(col).expression
    charset = expression.Literal.string("base64")
    return Column(expression.Decode(this=encoded, charset=charset))
597
+
598
+
599
def decode_from_blob(col: ColumnOrLiteral, charset: str) -> Column:
    """Build a ``Decode`` of ``col`` cast to ``blob`` using ``charset``."""
    as_blob = Column(col).cast("blob").expression
    charset_literal = expression.Literal.string(charset)
    return Column(expression.Decode(this=as_blob, charset=charset_literal))
605
+
606
+
607
def decode_from_convert_from(col: ColumnOrLiteral, charset: str) -> Column:
    """Build an anonymous ``CONVERT_FROM`` call on ``col`` cast to ``bytea`` with ``charset``."""
    as_bytes = Column(col).cast("bytea").expression
    charset_literal = expression.Literal.string(charset)
    call = expression.Anonymous(this="CONVERT_FROM", expressions=[as_bytes, charset_literal])
    return Column(call)
614
+
615
+
616
def encode_from_convert_to(col: ColumnOrName, charset: str) -> Column:
    """Build an anonymous ``CONVERT_TO(col, charset)`` call."""
    to_column = get_func_from_session("col")
    source_expr = to_column(col).expression
    charset_literal = expression.Literal.string(charset)
    call = expression.Anonymous(this="CONVERT_TO", expressions=[source_expr, charset_literal])
    return Column(call)
625
+
626
+
627
def concat_ws_from_array_to_string(sep: str, *cols: ColumnOrName) -> Column:
    """Emulate ``concat_ws`` as ``ARRAY_TO_STRING(array(*cols), sep)``."""
    array = get_func_from_session("array")
    literal = get_func_from_session("lit")
    packed = array(*cols).expression
    separator = literal(sep).expression
    return Column(expression.Anonymous(this="ARRAY_TO_STRING", expressions=[packed, separator]))
637
+
638
+
639
def format_number_from_to_char(col: ColumnOrName, d: int) -> Column:
    """Emulate ``format_number`` with ``ToChar`` over ``round(col, d)``.

    The format string is ``FM``, five ``999,`` groups, ``990``, ``D`` and then
    ``d`` zeros.
    """
    round_func = get_func_from_session("round")
    number_format = f"FM{'999,' * 5}990D{'0' * d}"
    rounded = round_func(col, d).expression
    return Column(
        expression.ToChar(this=rounded, format=expression.Literal.string(number_format))
    )
646
+
647
+
648
def format_string_with_format(format: str, *cols: ColumnOrName) -> Column:
    """Build an anonymous ``FORMAT`` call.

    Every ``%d`` in ``format`` is replaced with ``%s`` and each column is cast
    to ``string``.
    """
    to_column = get_func_from_session("col")
    template = expression.Literal.string(format.replace("%d", "%s"))
    string_args = [to_column(item).cast("string").expression for item in ensure_list(cols)]
    return Column(expression.Anonymous(this="FORMAT", expressions=[template, *string_args]))
660
+
661
+
662
def format_string_with_pipes(format: str, *cols: ColumnOrName) -> Column:
    """Emulate ``format_string`` by concatenating pieces with ``DPipe`` nodes.

    ``%d`` placeholders are treated like ``%s``. The format is split on
    ``%s`` and the literal fragments are interleaved with ``cols``:
    ``frag0 || col0 || frag1 || col1 || ... || fragN``.

    Raises:
        ValueError: if the number of placeholders does not match the number
            of columns.
    """
    lit = get_func_from_session("lit")
    col_func = get_func_from_session("col")

    values = format.replace("%d", "%s").split("%s")
    if len(values) != len(cols) + 1:
        raise ValueError("Number of values and columns do not match")
    if not cols:
        # No placeholders and no columns: nothing to concatenate. Previously
        # this case crashed with an IndexError on ``cols[0]``.
        return Column(lit(values[0]).expression)

    result = expression.DPipe(
        this=lit(values[0]).expression, expression=col_func(cols[0]).expression
    )
    for i, value in enumerate(values[1:], start=1):
        # Append the literal fragment, then the next column if one remains.
        result = expression.DPipe(this=result, expression=lit(value).expression)
        if i < len(cols):
            result = expression.DPipe(this=result, expression=col_func(cols[i]).expression)
    return Column(result)
681
+
682
+
683
def instr_using_strpos(col: ColumnOrName, substr: str) -> Column:
    """Build an anonymous ``STRPOS(col, substr)`` call."""
    literal = get_func_from_session("lit")
    to_column = get_func_from_session("col")
    haystack = to_column(col).expression
    needle = literal(substr).expression
    return Column(expression.Anonymous(this="STRPOS", expressions=[haystack, needle]))
693
+
694
+
695
def overlay_from_substr(
    src: ColumnOrName,
    replace: ColumnOrName,
    pos: t.Union[ColumnOrName, int],
    len: t.Optional[t.Union[ColumnOrName, int]] = None,
) -> Column:
    """Emulate ``overlay`` by concatenating three pieces.

    The pieces are the part of ``src`` before ``pos``, then ``replace``, then
    the part of ``src`` starting at ``pos + len``. When ``len`` is ``None`` the
    length of ``replace`` is used.
    """
    to_column = get_func_from_session("col")
    literal = get_func_from_session("lit")
    substring = get_func_from_session("substring")
    length_func = get_func_from_session("length")

    if len is not None:
        span = len
    else:
        span = length_func(replace)

    head = substring(to_column(src), 1, to_column(pos) - literal(1)).expression
    middle = to_column(replace).expression
    tail = substring(
        to_column(src), to_column(pos) + to_column(span), length_func(src)
    ).expression
    return Column(expression.Concat(expressions=[head, middle, tail]))
717
+
718
+
719
def split_no_limit(str: ColumnOrName, pattern: str, limit: t.Optional[int] = None) -> Column:
    """Call the base ``split`` without ``limit``; a non-``None`` limit only logs a warning."""
    from sqlframe.base.functions import split as base_split

    to_column = get_func_from_session("col")

    if limit is not None:
        logger.warning("Limit is ignored since it is not supported in this dialect")
    source = to_column(str)
    return base_split(source, pattern)
727
+
728
+
729
def split_from_regex_split_to_array(
    str: ColumnOrName, pattern: str, limit: t.Optional[int] = None
) -> Column:
    """Build an anonymous ``REGEXP_SPLIT_TO_ARRAY(str, pattern)`` call.

    A non-``None`` ``limit`` only logs a warning.
    """
    to_column = get_func_from_session("col")

    if limit is not None:
        logger.warning("Limit is ignored since it is not supported in this dialect")
    call_args = [to_column(str).expression, expression.Literal.string(pattern)]
    return Column(expression.Anonymous(this="REGEXP_SPLIT_TO_ARRAY", expressions=call_args))
745
+
746
+
747
def split_with_split(str: ColumnOrName, pattern: str, limit: t.Optional[int] = None) -> Column:
    """Build an anonymous ``SPLIT(str, pattern)`` call; a non-``None`` limit only logs a warning."""
    literal = get_func_from_session("lit")
    to_column = get_func_from_session("col")

    if limit is not None:
        logger.warning("Limit is ignored since it is not supported in this dialect")
    call_args = [to_column(str).expression, literal(pattern).expression]
    return Column(expression.Anonymous(this="SPLIT", expressions=call_args))
759
+
760
+
761
def array_contains_any(col: ColumnOrName, value: ColumnOrLiteral) -> Column:
    """Emulate ``array_contains`` as ``value = ANY(col)``.

    ``value`` is wrapped with the session's ``lit`` unless it is already a
    ``Column``.
    """
    literal = get_func_from_session("lit")
    needle = value if isinstance(value, Column) else literal(value)
    to_column = get_func_from_session("col")

    any_call = expression.Anonymous(this="ANY", expressions=[to_column(col).expression])
    return Column(expression.EQ(this=needle.expression, expression=any_call))
772
+
773
+
774
def arrays_overlap_using_intersect(col1: ColumnOrName, col2: ColumnOrName) -> Column:
    """Emulate ``arrays_overlap`` as ``ArraySize(ARRAY_INTERSECT(col1, col2)) > 0``."""
    to_column = get_func_from_session("col")
    intersection = expression.Anonymous(
        this="ARRAY_INTERSECT",
        expressions=[to_column(col1).expression, to_column(col2).expression],
    )
    intersection_size = expression.ArraySize(this=intersection)
    zero = expression.Literal.number(0)
    return Column(expression.GT(this=intersection_size, expression=zero))
788
+
789
+
790
def arrays_overlap_renamed(col1: ColumnOrName, col2: ColumnOrName) -> Column:
    """Build an anonymous ``ARRAYS_OVERLAP(col1, col2)`` call."""
    to_column = get_func_from_session("col")
    operands = [to_column(col1).expression, to_column(col2).expression]
    return Column(expression.Anonymous(this="ARRAYS_OVERLAP", expressions=operands))
799
+
800
+
801
def slice_as_list_slice(
    x: ColumnOrName, start: t.Union[ColumnOrName, int], length: t.Union[ColumnOrName, int]
) -> Column:
    """Invoke ``LIST_SLICE(x, start, start + length)``.

    ``start`` and ``length`` are wrapped with the session's ``lit`` unless
    they are already ``Column`` instances.
    """
    literal = get_func_from_session("lit")

    begin = start if isinstance(start, Column) else literal(start)
    count = length if isinstance(length, Column) else literal(length)
    end = begin + count
    return Column.invoke_anonymous_function(x, "LIST_SLICE", begin, end)
809
+
810
+
811
def slice_with_brackets(
    x: ColumnOrName, start: t.Union[ColumnOrName, int], length: t.Union[ColumnOrName, int]
) -> Column:
    """Emulate ``slice`` with a ``Bracket`` holding a ``Slice`` from ``start`` to ``start + length``.

    ``start`` and ``length`` are wrapped with the session's ``lit`` unless
    they are already ``Column`` instances.
    """
    literal = get_func_from_session("lit")

    begin = start if isinstance(start, Column) else literal(start)
    count = length if isinstance(length, Column) else literal(length)
    to_column = get_func_from_session("col")

    source_expr = to_column(x).expression
    bounds = expression.Slice(this=begin.expression, expression=(begin + count).expression)
    return Column(expression.Bracket(this=source_expr, expressions=[bounds]))
831
+
832
+
833
def array_join_null_replacement_with_transform(
    col: ColumnOrName, delimiter: str, null_replacement: t.Optional[str] = None
) -> Column:
    """Call ``array_join``, emulating ``null_replacement`` with ``LIST_TRANSFORM``.

    Without a replacement the call is forwarded unchanged. Otherwise each
    element ``x`` is mapped to ``COALESCE(CAST(x AS STRING), null_replacement)``
    before joining.
    """
    from sqlframe.base.functions import array_join as base_array_join

    to_column = get_func_from_session("col")

    if null_replacement is None:
        return base_array_join(col, delimiter, null_replacement)

    element_as_string = expression.Cast(
        this=expression.Identifier(this="x"),
        to=expression.DataType.build("STRING"),
    )
    fill_nulls = expression.Lambda(
        this=expression.Coalesce(
            this=element_as_string,
            expressions=[expression.Literal.string(null_replacement)],
        ),
        expressions=[expression.Identifier(this="x")],
    )
    transformed = Column(
        expression.Anonymous(
            this="LIST_TRANSFORM",
            expressions=[to_column(col).expression, fill_nulls],
        )
    )
    return base_array_join(transformed, delimiter)
861
+
862
+
863
def element_at_using_brackets(col: ColumnOrName, value: ColumnOrLiteral) -> Column:
    """Emulate ``element_at`` with a ``Bracket`` lookup on ``col``.

    Raises:
        ValueError: if ``value`` is not an ``int``.
    """
    to_column = get_func_from_session("col")
    literal = get_func_from_session("lit")
    if not isinstance(value, int):
        raise ValueError("This dialect requires the value must be an integer")
    # SQLGlot adds 1 to whatever is passed in the brackets even though the
    # value is already 1-based, so subtract 1 here to compensate.
    index = literal(value - 1)
    return Column(expression.Bracket(this=to_column(col).expression, expressions=[index.expression]))
873
+
874
+
875
def array_remove_using_filter(col: ColumnOrName, value: ColumnOrLiteral) -> Column:
    """Emulate ``array_remove`` as ``LIST_FILTER(col, x -> x <> value)``.

    ``value`` is wrapped with the session's ``lit`` unless it is already a
    ``Column``.
    """
    literal = get_func_from_session("lit")
    to_column = get_func_from_session("col")

    unwanted = value if isinstance(value, Column) else literal(value)
    keep_element = expression.Lambda(
        this=expression.NEQ(this=expression.Identifier(this="x"), expression=unwanted.expression),
        expressions=[expression.Identifier(this="x")],
    )
    return Column(
        expression.Anonymous(
            this="LIST_FILTER",
            expressions=[to_column(col).expression, keep_element],
        )
    )
894
+
895
+
896
def array_union_using_list_concat(col1: ColumnOrName, col2: ColumnOrName) -> Column:
    """Emulate ``array_union`` as ``LIST_DISTINCT(LIST_CONCAT(col1, col2))``."""
    to_column = get_func_from_session("col")
    combined = expression.Anonymous(
        this="LIST_CONCAT",
        expressions=[to_column(col1).expression, to_column(col2).expression],
    )
    return Column(expression.Anonymous(this="LIST_DISTINCT", expressions=[combined]))
910
+
911
+
912
def array_union_using_array_concat(col1: ColumnOrName, col2: ColumnOrName) -> Column:
    """Union two arrays by concatenating them and applying ``array_distinct``.

    Args:
        col1: First array column (or column name).
        col2: Second array column (or column name).

    Returns:
        Whatever the session's ``array_distinct`` returns for the
        ``ArrayConcat`` expression built from both inputs.
    """
    array_distinct = get_func_from_session("array_distinct")
    to_column = get_func_from_session("col")

    left_expr = to_column(col1).expression
    right_expr = to_column(col2).expression
    concatenated = expression.ArrayConcat(this=left_expr, expressions=[right_expr])
    return array_distinct(concatenated)
921
+
922
+
923
def get_json_object_using_arrow_op(col: ColumnOrName, path: str) -> Column:
    """Extract a value from a JSON column with a ``JSONExtract`` expression.

    Args:
        col: Column (or column name) holding the JSON document. It is cast to
            ``JSON`` before extraction.
        path: JSON path such as ``"$.key"``. A leading ``"$."`` is stripped and
            the remainder is used as a single ``JSONPathKey``; it is not split
            into nested segments.

    Returns:
        A Column wrapping ``JSONExtract`` over the cast column with the path
        ``[JSONPathRoot, JSONPathKey(<key>)]`` and ``only_json_types=True``.
    """
    col_func = get_func_from_session("col")
    # Strip only the leading root marker. A blanket ``str.replace`` would also
    # delete any later "$." occurring inside the key and corrupt the lookup.
    key = path[2:] if path.startswith("$.") else path
    return Column(
        expression.JSONExtract(
            this=expression.Cast(
                this=col_func(col).expression, to=expression.DataType.build("JSON")
            ),
            expression=expression.JSONPath(
                expressions=[expression.JSONPathRoot(), expression.JSONPathKey(this=key)]
            ),
            only_json_types=True,
        )
    )
937
+
938
+
939
def array_min_from_sort(col: ColumnOrName) -> Column:
    """Return the array minimum as the first element of the sorted array.

    Args:
        col: Column (or column name) holding the array.

    Returns:
        The result of the session's ``element_at`` applied at position 1 to the
        session's ``array_sort`` of ``col``.
    """
    element_at = get_func_from_session("element_at")
    array_sort = get_func_from_session("array_sort")

    sorted_array = array_sort(col)
    return element_at(sorted_array, 1)
944
+
945
+
946
def array_min_from_subquery(col: ColumnOrName) -> Column:
    """Return the array minimum through a scalar subquery over the exploded array.

    The subquery selects ``MIN(x)`` with the session's ``explode(col)``,
    aliased as ``x``, placed in its ``FROM`` clause.

    Args:
        col: Column (or column name) holding the array.

    Returns:
        A Column wrapping the subquery, aliased to the alias-or-name of ``col``.
    """
    to_column = get_func_from_session("col")

    explode = get_func_from_session("explode")
    minimum = expression.Min(this=to_column("x").expression)
    query = expression.Select(expressions=[minimum])
    exploded_source = explode(col).alias("x").expression
    query.set("from", expression.From(this=exploded_source))

    result_name = to_column(col).alias_or_name
    return Column(expression.Subquery(this=query)).alias(result_name)
965
+
966
+
967
def array_max_from_sort(col: ColumnOrName) -> Column:
    """Return the array maximum as the last element of the sorted array.

    Args:
        col: Column (or column name) holding the array.

    Returns:
        The result of the session's ``element_at`` applied at position -1 to
        the session's ``array_sort`` of ``col``.
    """
    element_at = get_func_from_session("element_at")
    array_sort = get_func_from_session("array_sort")

    sorted_array = array_sort(col)
    return element_at(sorted_array, -1)
972
+
973
+
974
def array_max_from_subquery(col: ColumnOrName) -> Column:
    """Return the array maximum through a scalar subquery over the exploded array.

    The subquery selects ``MAX(x)`` with the session's ``explode(col)``,
    aliased as ``x``, placed in its ``FROM`` clause.

    Args:
        col: Column (or column name) holding the array.

    Returns:
        A Column wrapping the subquery, aliased to the alias-or-name of ``col``.
    """
    to_column = get_func_from_session("col")

    explode = get_func_from_session("explode")
    maximum = expression.Max(this=to_column("x").expression)
    query = expression.Select(expressions=[maximum])
    exploded_source = explode(col).alias("x").expression
    query.set("from", expression.From(this=exploded_source))

    result_name = to_column(col).alias_or_name
    return Column(expression.Subquery(this=query)).alias(result_name)
993
+
994
+
995
def sequence_from_generate_series(
    start: ColumnOrName, stop: ColumnOrName, step: t.Optional[ColumnOrName] = None
) -> Column:
    """Build a sequence with a ``GENERATE_SERIES(start, stop, step)`` call.

    Args:
        start: Column (or column name) for the first argument.
        stop: Column (or column name) for the second argument.
        step: Column (or column name) for the third argument. When it is
            omitted (or otherwise falsy) the numeric literal ``1`` is used.

    Returns:
        A Column wrapping the anonymous ``GENERATE_SERIES`` function call.
    """
    to_column = get_func_from_session("col")

    start_expr = to_column(start).expression
    stop_expr = to_column(stop).expression
    if step:
        step_expr = to_column(step).expression
    else:
        step_expr = expression.Literal.number(1)

    series = expression.Anonymous(
        this="GENERATE_SERIES",
        expressions=[start_expr, stop_expr, step_expr],
    )
    return Column(series)
1010
+
1011
+
1012
def sequence_from_generate_array(
    start: ColumnOrName, stop: ColumnOrName, step: t.Optional[ColumnOrName] = None
) -> Column:
    """Build a sequence with a ``GENERATE_ARRAY(start, stop, step)`` call.

    Args:
        start: Column (or column name) for the first argument.
        stop: Column (or column name) for the second argument.
        step: Column (or column name) for the third argument. When it is
            omitted (or otherwise falsy) the numeric literal ``1`` is used.

    Returns:
        A Column wrapping the anonymous ``GENERATE_ARRAY`` function call.
    """
    to_column = get_func_from_session("col")

    start_expr = to_column(start).expression
    stop_expr = to_column(stop).expression
    if step:
        step_expr = to_column(step).expression
    else:
        step_expr = expression.Literal.number(1)

    generated = expression.Anonymous(
        this="GENERATE_ARRAY",
        expressions=[start_expr, stop_expr, step_expr],
    )
    return Column(generated)
1027
+
1028
+
1029
def regexp_extract_only_one_group(
    str: ColumnOrName, pattern: str, idx: t.Optional[int] = None
) -> Column:
    """Call ``regexp_extract`` for dialects limited to one capture group.

    Args:
        str: Column (or column name) to search.
        pattern: Regular expression; it may define at most one capture group.
        idx: Requested group index. ``None`` or any value up to 1 is accepted.

    Returns:
        The result of ``sqlframe.base.functions.regexp_extract`` called with
        group index 1, regardless of the ``idx`` passed in.

    Raises:
        ValueError: If ``pattern`` has more than one group or ``idx`` exceeds 1.
    """
    from sqlframe.base.functions import regexp_extract

    unsupported = "This dialect only supports regular expressions with a single group"
    # The group count is checked first so that ``idx`` is only compared when
    # the pattern itself is acceptable.
    if re.compile(pattern).groups > 1:
        raise ValueError(unsupported)
    if idx is not None and idx > 1:
        raise ValueError(unsupported)

    return regexp_extract(str, pattern, 1)
1038
+
1039
+
1040
def hex_casted_as_bytes(col: ColumnOrName) -> Column:
    """Build ``TO_HEX`` over ``col`` after casting it to ``bytes``.

    Args:
        col: Column (or column name) to encode.

    Returns:
        A Column wrapping the anonymous ``TO_HEX`` function call.
    """
    to_column = get_func_from_session("col")

    as_bytes = to_column(col).cast("bytes")
    to_hex = expression.Anonymous(this="TO_HEX", expressions=[as_bytes.expression])
    return Column(to_hex)
1049
+
1050
+
1051
def bit_length_from_length(col: ColumnOrName) -> Column:
    """Compute a bit length as the ``Length`` of ``col`` multiplied by 8.

    Args:
        col: Column (or column name) to measure.

    Returns:
        The Column produced by multiplying the ``Length`` expression by the
        session literal ``8``.
    """
    to_literal = get_func_from_session("lit")
    to_column = get_func_from_session("col")

    length = Column(expression.Length(this=to_column(col).expression))
    multiplier = to_literal(8)
    return length * multiplier