sqlframe 3.39.3__py3-none-any.whl → 3.40.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sqlframe/_version.py +3 -3
- sqlframe/base/dataframe.py +37 -32
- sqlframe/base/functions.py +25 -34
- sqlframe/base/mixins/table_mixins.py +6 -2
- sqlframe/base/normalize.py +197 -8
- sqlframe/base/operations.py +3 -4
- {sqlframe-3.39.3.dist-info → sqlframe-3.40.0.dist-info}/METADATA +2 -2
- {sqlframe-3.39.3.dist-info → sqlframe-3.40.0.dist-info}/RECORD +11 -11
- {sqlframe-3.39.3.dist-info → sqlframe-3.40.0.dist-info}/LICENSE +0 -0
- {sqlframe-3.39.3.dist-info → sqlframe-3.40.0.dist-info}/WHEEL +0 -0
- {sqlframe-3.39.3.dist-info → sqlframe-3.40.0.dist-info}/top_level.txt +0 -0
sqlframe/_version.py
CHANGED
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
|
|
28
28
|
commit_id: COMMIT_ID
|
29
29
|
__commit_id__: COMMIT_ID
|
30
30
|
|
31
|
-
__version__ = version = '3.39.3'
|
32
|
-
__version_tuple__ = version_tuple = (3, 39, 3)
|
31
|
+
__version__ = version = '3.40.0'
|
32
|
+
__version_tuple__ = version_tuple = (3, 40, 0)
|
33
33
|
|
34
|
-
__commit_id__ = commit_id = '
|
34
|
+
__commit_id__ = commit_id = 'g93abcd907'
|
sqlframe/base/dataframe.py
CHANGED
@@ -16,7 +16,6 @@ from dataclasses import dataclass
|
|
16
16
|
from uuid import uuid4
|
17
17
|
|
18
18
|
import sqlglot
|
19
|
-
from more_itertools import partition
|
20
19
|
from prettytable import PrettyTable
|
21
20
|
from sqlglot import Dialect, maybe_parse
|
22
21
|
from sqlglot import expressions as exp
|
@@ -397,12 +396,21 @@ class BaseDataFrame(t.Generic[SESSION, WRITER, NA, STAT, GROUP_DATA]):
|
|
397
396
|
return Column.ensure_cols(ensure_list(cols)) # type: ignore
|
398
397
|
|
399
398
|
def _ensure_and_normalize_cols(
|
400
|
-
self,
|
399
|
+
self,
|
400
|
+
cols,
|
401
|
+
expression: t.Optional[exp.Select] = None,
|
402
|
+
skip_star_expansion: bool = False,
|
403
|
+
remove_identifier_if_possible: bool = True,
|
401
404
|
) -> t.List[Column]:
|
402
405
|
from sqlframe.base.normalize import normalize
|
403
406
|
|
404
407
|
cols = self._ensure_list_of_columns(cols)
|
405
|
-
normalize(
|
408
|
+
normalize(
|
409
|
+
self.session,
|
410
|
+
expression or self.expression,
|
411
|
+
cols,
|
412
|
+
remove_identifier_if_possible=remove_identifier_if_possible,
|
413
|
+
)
|
406
414
|
if not skip_star_expansion:
|
407
415
|
cols = list(flatten([self._expand_star(col) for col in cols]))
|
408
416
|
self._resolve_ambiguous_columns(cols)
|
@@ -542,23 +550,16 @@ class BaseDataFrame(t.Generic[SESSION, WRITER, NA, STAT, GROUP_DATA]):
|
|
542
550
|
expression.set("with", exp.With(expressions=existing_ctes))
|
543
551
|
return expression
|
544
552
|
|
545
|
-
@classmethod
|
546
|
-
def _get_outer_select_expressions(
|
547
|
-
cls, item: exp.Expression
|
548
|
-
) -> t.List[t.Union[exp.Column, exp.Alias]]:
|
549
|
-
outer_select = item.find(exp.Select)
|
550
|
-
if outer_select:
|
551
|
-
return outer_select.expressions
|
552
|
-
return []
|
553
|
-
|
554
553
|
@classmethod
|
555
554
|
def _get_outer_select_columns(cls, item: exp.Expression) -> t.List[Column]:
|
556
555
|
from sqlframe.base.session import _BaseSession
|
557
556
|
|
558
557
|
col = get_func_from_session("col", _BaseSession())
|
559
558
|
|
560
|
-
|
561
|
-
|
559
|
+
outer_select = item.find(exp.Select)
|
560
|
+
if outer_select:
|
561
|
+
return [col(quote_preserving_alias_or_name(x)) for x in outer_select.expressions]
|
562
|
+
return []
|
562
563
|
|
563
564
|
def _create_hash_from_expression(self, expression: exp.Expression) -> str:
|
564
565
|
from sqlframe.base.session import _BaseSession
|
@@ -1025,9 +1026,17 @@ class BaseDataFrame(t.Generic[SESSION, WRITER, NA, STAT, GROUP_DATA]):
|
|
1025
1026
|
return join_column_pairs, join_clause
|
1026
1027
|
|
1027
1028
|
def _normalize_join_clause(
|
1028
|
-
self,
|
1029
|
+
self,
|
1030
|
+
join_columns: t.List[Column],
|
1031
|
+
join_expression: t.Optional[exp.Select],
|
1032
|
+
*,
|
1033
|
+
remove_identifier_if_possible: bool = True,
|
1029
1034
|
) -> Column:
|
1030
|
-
join_columns = self._ensure_and_normalize_cols(
|
1035
|
+
join_columns = self._ensure_and_normalize_cols(
|
1036
|
+
join_columns,
|
1037
|
+
join_expression,
|
1038
|
+
remove_identifier_if_possible=remove_identifier_if_possible,
|
1039
|
+
)
|
1031
1040
|
if len(join_columns) > 1:
|
1032
1041
|
join_columns = [functools.reduce(lambda x, y: x & y, join_columns)]
|
1033
1042
|
join_clause = join_columns[0]
|
@@ -1512,23 +1521,20 @@ class BaseDataFrame(t.Generic[SESSION, WRITER, NA, STAT, GROUP_DATA]):
|
|
1512
1521
|
"""
|
1513
1522
|
return func(self, *args, **kwargs) # type: ignore
|
1514
1523
|
|
1515
|
-
@operation(Operation.
|
1524
|
+
@operation(Operation.SELECT)
|
1516
1525
|
def withColumn(self, colName: str, col: Column) -> Self:
|
1517
1526
|
return self.withColumns.__wrapped__(self, {colName: col}) # type: ignore
|
1518
1527
|
|
1519
|
-
@operation(Operation.
|
1528
|
+
@operation(Operation.SELECT)
|
1520
1529
|
def withColumnRenamed(self, existing: str, new: str) -> Self:
|
1521
|
-
col_func = get_func_from_session("col", self.session)
|
1522
1530
|
expression = self.expression.copy()
|
1523
1531
|
existing = self.session._normalize_string(existing)
|
1524
|
-
|
1532
|
+
columns = self._get_outer_select_columns(expression)
|
1525
1533
|
results = []
|
1526
1534
|
found_match = False
|
1527
|
-
for
|
1528
|
-
column
|
1529
|
-
|
1530
|
-
if isinstance(column.expression, exp.Alias):
|
1531
|
-
column.expression.set("alias", exp.to_identifier(new))
|
1535
|
+
for column in columns:
|
1536
|
+
if column.alias_or_name == existing:
|
1537
|
+
column = column.alias(new)
|
1532
1538
|
self._update_display_name_mapping([column], [new])
|
1533
1539
|
found_match = True
|
1534
1540
|
results.append(column)
|
@@ -1536,7 +1542,7 @@ class BaseDataFrame(t.Generic[SESSION, WRITER, NA, STAT, GROUP_DATA]):
|
|
1536
1542
|
raise ValueError("Tried to rename a column that doesn't exist")
|
1537
1543
|
return self.select.__wrapped__(self, *results, skip_update_display_name_mapping=True) # type: ignore
|
1538
1544
|
|
1539
|
-
@operation(Operation.
|
1545
|
+
@operation(Operation.SELECT)
|
1540
1546
|
def withColumnsRenamed(self, colsMap: t.Dict[str, str]) -> Self:
|
1541
1547
|
"""
|
1542
1548
|
Returns a new :class:`DataFrame` by renaming multiple columns. If a non-existing column is
|
@@ -1582,7 +1588,7 @@ class BaseDataFrame(t.Generic[SESSION, WRITER, NA, STAT, GROUP_DATA]):
|
|
1582
1588
|
|
1583
1589
|
return self.select.__wrapped__(self, *results, skip_update_display_name_mapping=True) # type: ignore
|
1584
1590
|
|
1585
|
-
@operation(Operation.
|
1591
|
+
@operation(Operation.SELECT)
|
1586
1592
|
def withColumns(self, *colsMap: t.Dict[str, Column]) -> Self:
|
1587
1593
|
"""
|
1588
1594
|
Returns a new :class:`DataFrame` by adding multiple columns or replacing the
|
@@ -1620,14 +1626,13 @@ class BaseDataFrame(t.Generic[SESSION, WRITER, NA, STAT, GROUP_DATA]):
|
|
1620
1626
|
"""
|
1621
1627
|
if len(colsMap) != 1:
|
1622
1628
|
raise ValueError("Only a single map is supported")
|
1623
|
-
col_func = get_func_from_session("col")
|
1624
1629
|
col_map = {
|
1625
1630
|
self._ensure_and_normalize_col(k): (self._ensure_and_normalize_col(v), k)
|
1626
1631
|
for k, v in colsMap[0].items()
|
1627
1632
|
}
|
1628
|
-
|
1629
|
-
existing_col_names = [x.alias_or_name for x in
|
1630
|
-
select_columns =
|
1633
|
+
existing_cols = self._get_outer_select_columns(self.expression)
|
1634
|
+
existing_col_names = [x.alias_or_name for x in existing_cols]
|
1635
|
+
select_columns = existing_cols
|
1631
1636
|
for col, (col_value, display_name) in col_map.items():
|
1632
1637
|
column_name = col.alias_or_name
|
1633
1638
|
existing_col_index = (
|
@@ -1644,7 +1649,7 @@ class BaseDataFrame(t.Generic[SESSION, WRITER, NA, STAT, GROUP_DATA]):
|
|
1644
1649
|
)
|
1645
1650
|
return self.select.__wrapped__(self, *select_columns, skip_update_display_name_mapping=True) # type: ignore
|
1646
1651
|
|
1647
|
-
@operation(Operation.
|
1652
|
+
@operation(Operation.SELECT)
|
1648
1653
|
def drop(self, *cols: t.Union[str, Column]) -> Self:
|
1649
1654
|
# Separate string column names from Column objects for different handling
|
1650
1655
|
column_objs, column_names = partition_to(lambda x: isinstance(x, str), cols, list, set)
|
sqlframe/base/functions.py
CHANGED
@@ -37,9 +37,7 @@ def _get_session() -> _BaseSession:
|
|
37
37
|
|
38
38
|
@meta()
|
39
39
|
def col(column_name: t.Union[ColumnOrName, t.Any]) -> Column:
|
40
|
-
|
41
|
-
|
42
|
-
dialect = _BaseSession().input_dialect
|
40
|
+
dialect = _get_session().input_dialect
|
43
41
|
if isinstance(column_name, str):
|
44
42
|
col_expression = expression.to_column(column_name, dialect=dialect).transform(
|
45
43
|
dialect.normalize_identifier
|
@@ -192,27 +190,27 @@ sum_distinct = sumDistinct
|
|
192
190
|
|
193
191
|
@meta()
|
194
192
|
def acos(col: ColumnOrName) -> Column:
|
195
|
-
return Column.
|
193
|
+
return Column.invoke_expression_over_column(col, expression.Acos)
|
196
194
|
|
197
195
|
|
198
196
|
@meta(unsupported_engines="duckdb")
|
199
197
|
def acosh(col: ColumnOrName) -> Column:
|
200
|
-
return Column.
|
198
|
+
return Column.invoke_expression_over_column(col, expression.Acosh)
|
201
199
|
|
202
200
|
|
203
201
|
@meta()
|
204
202
|
def asin(col: ColumnOrName) -> Column:
|
205
|
-
return Column.
|
203
|
+
return Column.invoke_expression_over_column(col, expression.Asin)
|
206
204
|
|
207
205
|
|
208
206
|
@meta(unsupported_engines="duckdb")
|
209
207
|
def asinh(col: ColumnOrName) -> Column:
|
210
|
-
return Column.
|
208
|
+
return Column.invoke_expression_over_column(col, expression.Asinh)
|
211
209
|
|
212
210
|
|
213
211
|
@meta()
|
214
212
|
def atan(col: ColumnOrName) -> Column:
|
215
|
-
return Column.
|
213
|
+
return Column.invoke_expression_over_column(col, expression.Atan)
|
216
214
|
|
217
215
|
|
218
216
|
@meta()
|
@@ -220,12 +218,12 @@ def atan2(col1: t.Union[ColumnOrName, float], col2: t.Union[ColumnOrName, float]
|
|
220
218
|
col1_value = lit(col1) if isinstance(col1, (int, float)) else col1
|
221
219
|
col2_value = lit(col2) if isinstance(col2, (int, float)) else col2
|
222
220
|
|
223
|
-
return Column.
|
221
|
+
return Column.invoke_expression_over_column(col1_value, expression.Atan2, expression=col2_value)
|
224
222
|
|
225
223
|
|
226
224
|
@meta(unsupported_engines="duckdb")
|
227
225
|
def atanh(col: ColumnOrName) -> Column:
|
228
|
-
return Column.
|
226
|
+
return Column.invoke_expression_over_column(col, expression.Atanh)
|
229
227
|
|
230
228
|
|
231
229
|
@meta()
|
@@ -253,12 +251,12 @@ def cosh(col: ColumnOrName) -> Column:
|
|
253
251
|
|
254
252
|
@meta()
|
255
253
|
def cot(col: ColumnOrName) -> Column:
|
256
|
-
return Column.
|
254
|
+
return Column.invoke_expression_over_column(col, expression.Cot)
|
257
255
|
|
258
256
|
|
259
257
|
@meta(unsupported_engines=["duckdb", "postgres", "snowflake"])
|
260
258
|
def csc(col: ColumnOrName) -> Column:
|
261
|
-
return Column.
|
259
|
+
return Column.invoke_expression_over_column(col, expression.Csc)
|
262
260
|
|
263
261
|
|
264
262
|
@meta()
|
@@ -364,7 +362,7 @@ def rint(col: ColumnOrName) -> Column:
|
|
364
362
|
|
365
363
|
@meta(unsupported_engines=["duckdb", "postgres", "snowflake"])
|
366
364
|
def sec(col: ColumnOrName) -> Column:
|
367
|
-
return Column.
|
365
|
+
return Column.invoke_expression_over_column(col, expression.Sec)
|
368
366
|
|
369
367
|
|
370
368
|
@meta()
|
@@ -374,12 +372,12 @@ def signum(col: ColumnOrName) -> Column:
|
|
374
372
|
|
375
373
|
@meta()
|
376
374
|
def sin(col: ColumnOrName) -> Column:
|
377
|
-
return Column.
|
375
|
+
return Column.invoke_expression_over_column(col, expression.Sin)
|
378
376
|
|
379
377
|
|
380
378
|
@meta(unsupported_engines="duckdb")
|
381
379
|
def sinh(col: ColumnOrName) -> Column:
|
382
|
-
return Column.
|
380
|
+
return Column.invoke_expression_over_column(col, expression.Sinh)
|
383
381
|
|
384
382
|
|
385
383
|
@meta()
|
@@ -662,9 +660,7 @@ def grouping_id(*cols: ColumnOrName) -> Column:
|
|
662
660
|
|
663
661
|
@meta()
|
664
662
|
def input_file_name() -> Column:
|
665
|
-
|
666
|
-
|
667
|
-
return Column(expression.Literal.string(_BaseSession()._last_loaded_file or ""))
|
663
|
+
return Column(expression.Literal.string(_get_session()._last_loaded_file or ""))
|
668
664
|
|
669
665
|
|
670
666
|
@meta()
|
@@ -944,7 +940,7 @@ def nth_value(
|
|
944
940
|
|
945
941
|
@meta()
|
946
942
|
def ntile(n: int) -> Column:
|
947
|
-
return Column.
|
943
|
+
return Column.invoke_expression_over_column(lit(n), expression.Ntile)
|
948
944
|
|
949
945
|
|
950
946
|
@meta()
|
@@ -959,12 +955,10 @@ def current_timestamp() -> Column:
|
|
959
955
|
|
960
956
|
@meta()
|
961
957
|
def date_format(col: ColumnOrName, format: str) -> Column:
|
962
|
-
from sqlframe.base.session import _BaseSession
|
963
|
-
|
964
958
|
return Column.invoke_expression_over_column(
|
965
959
|
Column(expression.TimeStrToTime(this=Column.ensure_col(col).column_expression)),
|
966
960
|
expression.TimeToStr,
|
967
|
-
format=
|
961
|
+
format=_get_session().format_time(format),
|
968
962
|
)
|
969
963
|
|
970
964
|
|
@@ -2832,7 +2826,7 @@ def make_interval(
|
|
2832
2826
|
|
2833
2827
|
@meta(unsupported_engines="*")
|
2834
2828
|
def try_add(left: ColumnOrName, right: ColumnOrName) -> Column:
|
2835
|
-
return Column.
|
2829
|
+
return Column.invoke_expression_over_column(left, expression.SafeAdd, expression=right)
|
2836
2830
|
|
2837
2831
|
|
2838
2832
|
@meta(unsupported_engines="*")
|
@@ -2849,12 +2843,12 @@ def try_divide(left: ColumnOrName, right: ColumnOrName) -> Column:
|
|
2849
2843
|
|
2850
2844
|
@meta(unsupported_engines="*")
|
2851
2845
|
def try_multiply(left: ColumnOrName, right: ColumnOrName) -> Column:
|
2852
|
-
return Column.
|
2846
|
+
return Column.invoke_expression_over_column(left, expression.SafeMultiply, expression=right)
|
2853
2847
|
|
2854
2848
|
|
2855
2849
|
@meta(unsupported_engines="*")
|
2856
2850
|
def try_subtract(left: ColumnOrName, right: ColumnOrName) -> Column:
|
2857
|
-
return Column.
|
2851
|
+
return Column.invoke_expression_over_column(left, expression.SafeSubtract, expression=right)
|
2858
2852
|
|
2859
2853
|
|
2860
2854
|
@meta(unsupported_engines="*")
|
@@ -3378,10 +3372,9 @@ def get(col: ColumnOrName, index: t.Union[ColumnOrName, int]) -> Column:
|
|
3378
3372
|
def get_active_spark_context() -> SparkContext:
|
3379
3373
|
"""Raise RuntimeError if SparkContext is not initialized,
|
3380
3374
|
otherwise, returns the active SparkContext."""
|
3381
|
-
from sqlframe.base.session import _BaseSession
|
3382
3375
|
from sqlframe.spark.session import SparkSession
|
3383
3376
|
|
3384
|
-
session
|
3377
|
+
session = _get_session()
|
3385
3378
|
if not isinstance(session, SparkSession):
|
3386
3379
|
raise RuntimeError("This function is only available in SparkSession.")
|
3387
3380
|
return session.spark_session.sparkContext
|
@@ -5263,7 +5256,7 @@ def regexp_extract_all(
|
|
5263
5256
|
)
|
5264
5257
|
|
5265
5258
|
|
5266
|
-
@meta(unsupported_engines="
|
5259
|
+
@meta(unsupported_engines=["duckdb", "bigquery", "postgres", "snowflake"])
|
5267
5260
|
def regexp_instr(
|
5268
5261
|
str: ColumnOrName, regexp: ColumnOrName, idx: t.Optional[t.Union[int, Column]] = None
|
5269
5262
|
) -> Column:
|
@@ -5298,11 +5291,9 @@ def regexp_instr(
|
|
5298
5291
|
>>> df.select(regexp_instr('str', col("regexp")).alias('d')).collect()
|
5299
5292
|
[Row(d=1)]
|
5300
5293
|
"""
|
5301
|
-
|
5302
|
-
|
5303
|
-
|
5304
|
-
idx = lit(idx) if isinstance(idx, int) else idx
|
5305
|
-
return Column.invoke_anonymous_function(str, "regexp_instr", regexp, idx)
|
5294
|
+
return Column.invoke_expression_over_column(
|
5295
|
+
str, expression.RegexpInstr, expression=regexp, group=idx
|
5296
|
+
)
|
5306
5297
|
|
5307
5298
|
|
5308
5299
|
@meta(unsupported_engines="snowflake")
|
@@ -6344,7 +6335,7 @@ def to_unix_timestamp(
|
|
6344
6335
|
session = _get_session()
|
6345
6336
|
|
6346
6337
|
if session._is_duckdb:
|
6347
|
-
format = format or
|
6338
|
+
format = format or session.default_time_format
|
6348
6339
|
timestamp = Column.ensure_col(timestamp).cast("string")
|
6349
6340
|
|
6350
6341
|
if format is not None:
|
@@ -275,7 +275,9 @@ class MergeSupportMixin(_BaseTable, t.Generic[DF]):
|
|
275
275
|
join_expression = self._add_ctes_to_expression(
|
276
276
|
self.expression, other_df.expression.copy().ctes
|
277
277
|
)
|
278
|
-
condition = self._ensure_and_normalize_cols(
|
278
|
+
condition = self._ensure_and_normalize_cols(
|
279
|
+
condition, self.expression, remove_identifier_if_possible=False
|
280
|
+
)
|
279
281
|
self._handle_self_join(other_df, condition)
|
280
282
|
|
281
283
|
if isinstance(condition[0].expression, exp.Column) and not clause:
|
@@ -291,7 +293,9 @@ class MergeSupportMixin(_BaseTable, t.Generic[DF]):
|
|
291
293
|
condition, join_expression, other_df, table_names
|
292
294
|
)
|
293
295
|
else:
|
294
|
-
join_clause = self._normalize_join_clause(
|
296
|
+
join_clause = self._normalize_join_clause(
|
297
|
+
condition, join_expression, remove_identifier_if_possible=False
|
298
|
+
)
|
295
299
|
return join_clause
|
296
300
|
|
297
301
|
def _ensure_and_normalize_assignments(
|
sqlframe/base/normalize.py
CHANGED
@@ -16,33 +16,103 @@ if t.TYPE_CHECKING:
|
|
16
16
|
NORMALIZE_INPUT = t.TypeVar("NORMALIZE_INPUT", bound=t.Union[str, exp.Expression, Column])
|
17
17
|
|
18
18
|
|
19
|
-
def normalize(
|
19
|
+
def normalize(
|
20
|
+
session: SESSION,
|
21
|
+
expression_context: exp.Select,
|
22
|
+
expr: t.List[NORMALIZE_INPUT],
|
23
|
+
*,
|
24
|
+
remove_identifier_if_possible: bool = True,
|
25
|
+
):
|
20
26
|
expr = ensure_list(expr)
|
21
27
|
expressions = _ensure_expressions(expr)
|
22
28
|
for expression in expressions:
|
23
29
|
identifiers = expression.find_all(exp.Identifier)
|
24
30
|
for identifier in identifiers:
|
25
31
|
identifier.transform(session.input_dialect.normalize_identifier)
|
26
|
-
replace_alias_name_with_cte_name(
|
27
|
-
|
32
|
+
replace_alias_name_with_cte_name(
|
33
|
+
session,
|
34
|
+
expression_context,
|
35
|
+
identifier,
|
36
|
+
remove_identifier_if_possible=remove_identifier_if_possible,
|
37
|
+
)
|
38
|
+
replace_branch_and_sequence_ids_with_cte_name(
|
39
|
+
session,
|
40
|
+
expression_context,
|
41
|
+
identifier,
|
42
|
+
remove_identifier_if_possible=remove_identifier_if_possible,
|
43
|
+
)
|
28
44
|
|
29
45
|
|
30
46
|
def replace_alias_name_with_cte_name(
|
31
|
-
session: SESSION,
|
47
|
+
session: SESSION,
|
48
|
+
expression_context: exp.Select,
|
49
|
+
id: exp.Identifier,
|
50
|
+
*,
|
51
|
+
remove_identifier_if_possible: bool,
|
32
52
|
):
|
33
53
|
normalized_id = session._normalize_string(id.alias_or_name)
|
34
54
|
if normalized_id in session.name_to_sequence_id_mapping:
|
35
|
-
|
55
|
+
# Get CTEs that are referenced in the FROM clause
|
56
|
+
referenced_cte_names = get_cte_names_from_from_clause(expression_context)
|
57
|
+
|
58
|
+
# Filter CTEs to only include those defined and referenced by the FROM clause
|
59
|
+
filtered_ctes = [
|
60
|
+
cte
|
61
|
+
for cte in reversed(expression_context.ctes)
|
62
|
+
if cte.alias_or_name in referenced_cte_names
|
63
|
+
]
|
64
|
+
|
65
|
+
for cte in filtered_ctes:
|
36
66
|
if cte.args["sequence_id"] in session.name_to_sequence_id_mapping[normalized_id]:
|
37
67
|
_set_alias_name(id, cte.alias_or_name)
|
38
68
|
break
|
69
|
+
else:
|
70
|
+
# Fallback: If not found in filtered CTEs, search through ALL CTEs unfiltered
|
71
|
+
for cte in reversed(expression_context.ctes):
|
72
|
+
if cte.args["sequence_id"] in session.name_to_sequence_id_mapping[normalized_id]:
|
73
|
+
_set_alias_name(id, cte.alias_or_name)
|
74
|
+
break
|
75
|
+
else:
|
76
|
+
# Final fallback: If this is a qualified column reference (table.column)
|
77
|
+
# and the table doesn't exist in FROM clause, remove the qualifier IF the column is unambiguously available
|
78
|
+
parent = id.parent
|
79
|
+
if parent and isinstance(parent, exp.Column) and remove_identifier_if_possible:
|
80
|
+
# Check if this table is not available in current context
|
81
|
+
current_tables = get_cte_names_from_from_clause(expression_context)
|
82
|
+
if normalized_id not in current_tables:
|
83
|
+
# Check if this table ID matches any CTE name directly (cross-context CTE reference)
|
84
|
+
cte_exists = any(
|
85
|
+
cte.alias_or_name == normalized_id for cte in expression_context.ctes
|
86
|
+
)
|
87
|
+
|
88
|
+
if cte_exists:
|
89
|
+
# This is a reference to a CTE that exists but is not in the current FROM clause
|
90
|
+
# Get the column name being referenced
|
91
|
+
column_name = (
|
92
|
+
_extract_column_name(parent.this)
|
93
|
+
if hasattr(parent, "this")
|
94
|
+
else None
|
95
|
+
)
|
96
|
+
|
97
|
+
# Only remove qualifier if the column is unambiguously available in current context
|
98
|
+
if column_name and is_column_unambiguously_available(
|
99
|
+
expression_context, column_name
|
100
|
+
):
|
101
|
+
parent.set("table", None)
|
39
102
|
|
40
103
|
|
41
104
|
def replace_branch_and_sequence_ids_with_cte_name(
|
42
|
-
session: SESSION,
|
105
|
+
session: SESSION,
|
106
|
+
expression_context: exp.Select,
|
107
|
+
id: exp.Identifier,
|
108
|
+
*,
|
109
|
+
remove_identifier_if_possible: bool,
|
43
110
|
):
|
44
111
|
normalized_id = session._normalize_string(id.alias_or_name)
|
45
112
|
if normalized_id in session.known_ids:
|
113
|
+
# Get CTEs that are referenced in the FROM clause
|
114
|
+
referenced_cte_names = get_cte_names_from_from_clause(expression_context)
|
115
|
+
|
46
116
|
# Check if we have a join and if both the tables in that join share a common branch id
|
47
117
|
# If so we need to have this reference the left table by default unless the id is a sequence
|
48
118
|
# id then it keeps that reference. This handles the weird edge case in spark that shouldn't
|
@@ -51,19 +121,138 @@ def replace_branch_and_sequence_ids_with_cte_name(
|
|
51
121
|
join_table_aliases = [
|
52
122
|
x.alias_or_name for x in get_tables_from_expression_with_join(expression_context)
|
53
123
|
]
|
124
|
+
# Filter CTEs to only include those referenced in the FROM clause
|
54
125
|
ctes_in_join = [
|
55
|
-
cte
|
126
|
+
cte
|
127
|
+
for cte in expression_context.ctes
|
128
|
+
if cte.alias_or_name in join_table_aliases
|
129
|
+
and cte.alias_or_name in referenced_cte_names
|
56
130
|
]
|
57
|
-
if
|
131
|
+
if (
|
132
|
+
len(ctes_in_join) >= 2
|
133
|
+
and ctes_in_join[0].args["branch_id"] == ctes_in_join[1].args["branch_id"]
|
134
|
+
):
|
58
135
|
assert len(ctes_in_join) == 2
|
59
136
|
_set_alias_name(id, ctes_in_join[0].alias_or_name)
|
60
137
|
return
|
61
138
|
|
139
|
+
# Filter CTEs to only include those defined and referenced by the FROM clause
|
140
|
+
filtered_ctes = [
|
141
|
+
cte
|
142
|
+
for cte in reversed(expression_context.ctes)
|
143
|
+
if cte.alias_or_name in referenced_cte_names
|
144
|
+
]
|
145
|
+
|
146
|
+
for cte in filtered_ctes:
|
147
|
+
if normalized_id in (cte.args["branch_id"], cte.args["sequence_id"]):
|
148
|
+
_set_alias_name(id, cte.alias_or_name)
|
149
|
+
return
|
150
|
+
|
151
|
+
# Fallback: If not found in filtered CTEs, search through ALL CTEs unfiltered
|
62
152
|
for cte in reversed(expression_context.ctes):
|
63
153
|
if normalized_id in (cte.args["branch_id"], cte.args["sequence_id"]):
|
64
154
|
_set_alias_name(id, cte.alias_or_name)
|
65
155
|
return
|
66
156
|
|
157
|
+
# Final fallback: If this is a qualified column reference (table.column)
|
158
|
+
# and the table doesn't exist in FROM clause, remove the qualifier IF the column is unambiguously available
|
159
|
+
parent = id.parent
|
160
|
+
if parent and isinstance(parent, exp.Column) and remove_identifier_if_possible:
|
161
|
+
# Check if this table is not available in current context
|
162
|
+
current_tables = get_cte_names_from_from_clause(expression_context)
|
163
|
+
if normalized_id not in current_tables:
|
164
|
+
# Check if this table ID matches any CTE name directly (cross-context CTE reference)
|
165
|
+
cte_exists = any(cte.alias_or_name == normalized_id for cte in expression_context.ctes)
|
166
|
+
|
167
|
+
if cte_exists:
|
168
|
+
# This is a reference to a CTE that exists but is not in the current FROM clause
|
169
|
+
# Get the column name being referenced
|
170
|
+
column_name = _extract_column_name(parent.this) if hasattr(parent, "this") else None
|
171
|
+
|
172
|
+
# Only remove qualifier if the column is unambiguously available in current context
|
173
|
+
if column_name and is_column_unambiguously_available(
|
174
|
+
expression_context, column_name
|
175
|
+
):
|
176
|
+
parent.set("table", None)
|
177
|
+
|
178
|
+
|
179
|
+
def is_column_unambiguously_available(expression_context: exp.Select, column_name: str) -> bool:
|
180
|
+
"""
|
181
|
+
Check if a column name is unambiguously available in the current context.
|
182
|
+
Returns True if the column appears exactly once across all accessible CTEs.
|
183
|
+
|
184
|
+
Enhanced to handle more column expression types and edge cases.
|
185
|
+
"""
|
186
|
+
current_tables = get_cte_names_from_from_clause(expression_context)
|
187
|
+
column_count_in_from = 0
|
188
|
+
|
189
|
+
# If no tables in FROM clause, be conservative
|
190
|
+
if not current_tables:
|
191
|
+
return False
|
192
|
+
|
193
|
+
# Count how many times this column appears in accessible CTEs
|
194
|
+
for cte in expression_context.ctes:
|
195
|
+
if cte.alias_or_name in current_tables:
|
196
|
+
if hasattr(cte, "this") and hasattr(cte.this, "expressions"):
|
197
|
+
for expr in cte.this.expressions:
|
198
|
+
expr_column_name = _extract_column_name(expr)
|
199
|
+
|
200
|
+
# Case-insensitive comparison for robustness
|
201
|
+
if expr_column_name and expr_column_name.lower() == column_name.lower():
|
202
|
+
column_count_in_from += 1
|
203
|
+
|
204
|
+
# Column is unambiguous if it appears exactly once in the FROM clause CTEs
|
205
|
+
return column_count_in_from == 1
|
206
|
+
|
207
|
+
|
208
|
+
def _extract_column_name(expr) -> str:
|
209
|
+
"""
|
210
|
+
Extract column name from various expression types.
|
211
|
+
Enhanced to handle more SQLGlot expression types.
|
212
|
+
"""
|
213
|
+
if hasattr(expr, "alias_or_name") and expr.alias_or_name:
|
214
|
+
return expr.alias_or_name
|
215
|
+
elif hasattr(expr, "this"):
|
216
|
+
if hasattr(expr.this, "this"):
|
217
|
+
return str(expr.this.this)
|
218
|
+
elif hasattr(expr.this, "name"):
|
219
|
+
return str(expr.this.name)
|
220
|
+
else:
|
221
|
+
return str(expr.this)
|
222
|
+
elif hasattr(expr, "name"):
|
223
|
+
return str(expr.name)
|
224
|
+
else:
|
225
|
+
return str(expr)
|
226
|
+
|
227
|
+
|
228
|
+
def get_cte_names_from_from_clause(expression_context: exp.Select) -> t.Set[str]:
|
229
|
+
"""
|
230
|
+
Get the set of CTE names that are referenced in the FROM clause of the expression.
|
231
|
+
|
232
|
+
Args:
|
233
|
+
expression_context: The SELECT expression to analyze
|
234
|
+
|
235
|
+
Returns:
|
236
|
+
Set of CTE alias names referenced in the FROM clause (including joins)
|
237
|
+
"""
|
238
|
+
referenced_cte_names = set()
|
239
|
+
|
240
|
+
# Get the main table from FROM clause
|
241
|
+
from_clause = expression_context.args.get("from")
|
242
|
+
if from_clause and from_clause.this:
|
243
|
+
main_table = from_clause.this
|
244
|
+
if hasattr(main_table, "alias_or_name") and main_table.alias_or_name:
|
245
|
+
referenced_cte_names.add(main_table.alias_or_name)
|
246
|
+
|
247
|
+
# Get tables from joins
|
248
|
+
if expression_context.args.get("joins"):
|
249
|
+
join_tables = get_tables_from_expression_with_join(expression_context)
|
250
|
+
for table in join_tables:
|
251
|
+
if hasattr(table, "alias_or_name") and table.alias_or_name:
|
252
|
+
referenced_cte_names.add(table.alias_or_name)
|
253
|
+
|
254
|
+
return referenced_cte_names
|
255
|
+
|
67
256
|
|
68
257
|
def normalize_dict(session: SESSION, data: t.Dict) -> t.Dict:
|
69
258
|
if isinstance(data, dict):
|
sqlframe/base/operations.py
CHANGED
@@ -27,10 +27,9 @@ class Operation(IntEnum):
|
|
27
27
|
WHERE = 2
|
28
28
|
GROUP_BY = 3
|
29
29
|
HAVING = 4
|
30
|
-
|
31
|
-
|
32
|
-
|
33
|
-
LIMIT = 8
|
30
|
+
SELECT = 5
|
31
|
+
ORDER_BY = 6
|
32
|
+
LIMIT = 7
|
34
33
|
|
35
34
|
|
36
35
|
# We want to decorate a function (self: DF, *args, **kwargs) -> T
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: sqlframe
|
3
|
-
Version: 3.39.3
|
3
|
+
Version: 3.40.0
|
4
4
|
Summary: Turning PySpark Into a Universal DataFrame API
|
5
5
|
Home-page: https://github.com/eakmanrq/sqlframe
|
6
6
|
Author: Ryan Eakman
|
@@ -18,7 +18,7 @@ Description-Content-Type: text/markdown
|
|
18
18
|
License-File: LICENSE
|
19
19
|
Requires-Dist: more-itertools
|
20
20
|
Requires-Dist: prettytable <4
|
21
|
-
Requires-Dist: sqlglot <27.
|
21
|
+
Requires-Dist: sqlglot <27.13,>=24.0.0
|
22
22
|
Requires-Dist: typing-extensions
|
23
23
|
Provides-Extra: bigquery
|
24
24
|
Requires-Dist: google-cloud-bigquery-storage <3,>=2 ; extra == 'bigquery'
|
@@ -1,18 +1,18 @@
|
|
1
1
|
sqlframe/__init__.py,sha256=SB80yLTITBXHI2GCDS6n6bN5ObHqgPjfpRPAUwxaots,3403
|
2
|
-
sqlframe/_version.py,sha256=
|
2
|
+
sqlframe/_version.py,sha256=fOWY_ffL74_A_EHPCU75GCzXq1ZU-sD4WyU4XHtJjlI,714
|
3
3
|
sqlframe/py.typed,sha256=Nqnn8clbgv-5l0PgxcTOldg8mkMKrFn4TvPL-rYUUGg,1
|
4
4
|
sqlframe/base/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
5
5
|
sqlframe/base/_typing.py,sha256=b2clI5HI1zEZKB_3Msx3FeAJQyft44ubUifJwQRVXyQ,1298
|
6
6
|
sqlframe/base/catalog.py,sha256=-YulM2BMK8MoWbXi05AsJIPxd4AuiZDBCZuk4HoeMlE,38900
|
7
7
|
sqlframe/base/column.py,sha256=f6rK6-hTiNx9WwJP7t6tqL3xEC2gwERPDlhWCS5iCBw,21417
|
8
|
-
sqlframe/base/dataframe.py,sha256=
|
8
|
+
sqlframe/base/dataframe.py,sha256=fveiwPH-JQyUJdyB9PxzjHTvwwBBzBY4pUWq2OraH9A,87328
|
9
9
|
sqlframe/base/decorators.py,sha256=IhE5xNQDkwJHacCvulq5WpUKyKmXm7dL2A3o5WuKGP4,2131
|
10
10
|
sqlframe/base/exceptions.py,sha256=9Uwvqn2eAkDpqm4BrRgbL61qM-GMCbJEMAW8otxO46s,370
|
11
11
|
sqlframe/base/function_alternatives.py,sha256=aTu3nQhIAkZoxrI1IpjpaHEAMxBNms0AnhS0EMR-TwY,51727
|
12
|
-
sqlframe/base/functions.py,sha256=
|
12
|
+
sqlframe/base/functions.py,sha256=fc3jLuPAIJ3Hl4Bezm9Kgzsk4e5uFfgMgfajUCBKQG0,227919
|
13
13
|
sqlframe/base/group.py,sha256=fBm8EUve7W7xz11nybTXr09ih-yZxL_vvEiZVE1eb_0,12025
|
14
|
-
sqlframe/base/normalize.py,sha256=
|
15
|
-
sqlframe/base/operations.py,sha256=
|
14
|
+
sqlframe/base/normalize.py,sha256=YPeopWr8ZRjevArYfrM-DZBkQp4t4UfAEwynoj4VvcU,11773
|
15
|
+
sqlframe/base/operations.py,sha256=g-YNcbvNKTOBbYm23GKfB3fmydlR7ZZDAuZUtXIHtzw,4438
|
16
16
|
sqlframe/base/readerwriter.py,sha256=Nb2VJ_HBmLQp5mK8JhnFooZh2ydAaboCAFVPb-4MNX4,31241
|
17
17
|
sqlframe/base/session.py,sha256=99X-ShK9ohHCX6WdIJs0HhjfK23snaE3Gv6RYc5wqUI,27687
|
18
18
|
sqlframe/base/table.py,sha256=rCeh1W5SWbtEVfkLAUiexzrZwNgmZeptLEmLcM1ABkE,6961
|
@@ -25,7 +25,7 @@ sqlframe/base/mixins/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hS
|
|
25
25
|
sqlframe/base/mixins/catalog_mixins.py,sha256=9fZGWToz9xMJSzUl1vsVtj6TH3TysP3fBCKJLnGUQzE,23353
|
26
26
|
sqlframe/base/mixins/dataframe_mixins.py,sha256=8D2AFtfc0tj9Q5qzlNAXdXOYw9RuD8kpe8wixo8pY5o,1534
|
27
27
|
sqlframe/base/mixins/readwriter_mixins.py,sha256=ItQ_0jZ5RljgmLjGDIzLMRP_NQdy3wAyKwJ6K5NjaqA,4954
|
28
|
-
sqlframe/base/mixins/table_mixins.py,sha256=
|
28
|
+
sqlframe/base/mixins/table_mixins.py,sha256=zoqrgaH1fOgnHkC6C4L8IUyspDa5SETP3OXVdKWxcUM,13917
|
29
29
|
sqlframe/bigquery/__init__.py,sha256=kbaomhYAANPdxeDQhajv8IHfMg_ENKivtYK-rPwaV08,939
|
30
30
|
sqlframe/bigquery/catalog.py,sha256=Dcpp1JKftc3ukdYpn6M1ujqixA-6_1k8aY21Y5Johyc,11899
|
31
31
|
sqlframe/bigquery/column.py,sha256=E1tUa62Y5HajkhgFuebU9zohrGyieudcHzTT8gfalio,40
|
@@ -130,8 +130,8 @@ sqlframe/standalone/udf.py,sha256=azmgtUjHNIPs0WMVNId05SHwiYn41MKVBhKXsQJ5dmY,27
|
|
130
130
|
sqlframe/standalone/window.py,sha256=6GKPzuxeSapJakBaKBeT9VpED1ACdjggDv9JRILDyV0,35
|
131
131
|
sqlframe/testing/__init__.py,sha256=VVCosQhitU74A3NnE52O4mNtGZONapuEXcc20QmSlnQ,132
|
132
132
|
sqlframe/testing/utils.py,sha256=PFsGZpwNUE_4-g_f43_vstTqsK0AQ2lBneb5Eb6NkFo,13008
|
133
|
-
sqlframe-3.
|
134
|
-
sqlframe-3.
|
135
|
-
sqlframe-3.
|
136
|
-
sqlframe-3.
|
137
|
-
sqlframe-3.
|
133
|
+
sqlframe-3.40.0.dist-info/LICENSE,sha256=VZu79YgW780qxaFJMr0t5ZgbOYEh04xWoxaWOaqIGWk,1068
|
134
|
+
sqlframe-3.40.0.dist-info/METADATA,sha256=43WXPdp-_riwus7pJqzCv6ct0oAEmZden39JcI-hKVU,9070
|
135
|
+
sqlframe-3.40.0.dist-info/WHEEL,sha256=oiQVh_5PnQM0E3gPdiz09WCNmwiHDMaGer_elqB3coM,92
|
136
|
+
sqlframe-3.40.0.dist-info/top_level.txt,sha256=T0_RpoygaZSF6heeWwIDQgaP0varUdSK1pzjeJZRjM8,9
|
137
|
+
sqlframe-3.40.0.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|