sqlframe 3.39.3__py3-none-any.whl → 3.40.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
sqlframe/_version.py CHANGED
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
  commit_id: COMMIT_ID
  __commit_id__: COMMIT_ID

- __version__ = version = '3.39.3'
- __version_tuple__ = version_tuple = (3, 39, 3)
+ __version__ = version = '3.40.0'
+ __version_tuple__ = version_tuple = (3, 40, 0)

- __commit_id__ = commit_id = 'g9d915cb1e'
+ __commit_id__ = commit_id = 'g93abcd907'
sqlframe/base/dataframe.py CHANGED
@@ -16,7 +16,6 @@ from dataclasses import dataclass
  from uuid import uuid4

  import sqlglot
- from more_itertools import partition
  from prettytable import PrettyTable
  from sqlglot import Dialect, maybe_parse
  from sqlglot import expressions as exp
@@ -397,12 +396,21 @@ class BaseDataFrame(t.Generic[SESSION, WRITER, NA, STAT, GROUP_DATA]):
  return Column.ensure_cols(ensure_list(cols)) # type: ignore

  def _ensure_and_normalize_cols(
- self, cols, expression: t.Optional[exp.Select] = None, skip_star_expansion: bool = False
+ self,
+ cols,
+ expression: t.Optional[exp.Select] = None,
+ skip_star_expansion: bool = False,
+ remove_identifier_if_possible: bool = True,
  ) -> t.List[Column]:
  from sqlframe.base.normalize import normalize

  cols = self._ensure_list_of_columns(cols)
- normalize(self.session, expression or self.expression, cols)
+ normalize(
+ self.session,
+ expression or self.expression,
+ cols,
+ remove_identifier_if_possible=remove_identifier_if_possible,
+ )
  if not skip_star_expansion:
  cols = list(flatten([self._expand_star(col) for col in cols]))
  self._resolve_ambiguous_columns(cols)
@@ -542,23 +550,16 @@ class BaseDataFrame(t.Generic[SESSION, WRITER, NA, STAT, GROUP_DATA]):
  expression.set("with", exp.With(expressions=existing_ctes))
  return expression

- @classmethod
- def _get_outer_select_expressions(
- cls, item: exp.Expression
- ) -> t.List[t.Union[exp.Column, exp.Alias]]:
- outer_select = item.find(exp.Select)
- if outer_select:
- return outer_select.expressions
- return []
-
  @classmethod
  def _get_outer_select_columns(cls, item: exp.Expression) -> t.List[Column]:
  from sqlframe.base.session import _BaseSession

  col = get_func_from_session("col", _BaseSession())

- outer_expressions = cls._get_outer_select_expressions(item)
- return [col(quote_preserving_alias_or_name(x)) for x in outer_expressions]
+ outer_select = item.find(exp.Select)
+ if outer_select:
+ return [col(quote_preserving_alias_or_name(x)) for x in outer_select.expressions]
+ return []

  def _create_hash_from_expression(self, expression: exp.Expression) -> str:
  from sqlframe.base.session import _BaseSession
@@ -1025,9 +1026,17 @@ class BaseDataFrame(t.Generic[SESSION, WRITER, NA, STAT, GROUP_DATA]):
  return join_column_pairs, join_clause

  def _normalize_join_clause(
- self, join_columns: t.List[Column], join_expression: t.Optional[exp.Select]
+ self,
+ join_columns: t.List[Column],
+ join_expression: t.Optional[exp.Select],
+ *,
+ remove_identifier_if_possible: bool = True,
  ) -> Column:
- join_columns = self._ensure_and_normalize_cols(join_columns, join_expression)
+ join_columns = self._ensure_and_normalize_cols(
+ join_columns,
+ join_expression,
+ remove_identifier_if_possible=remove_identifier_if_possible,
+ )
  if len(join_columns) > 1:
  join_columns = [functools.reduce(lambda x, y: x & y, join_columns)]
  join_clause = join_columns[0]
@@ -1512,23 +1521,20 @@ class BaseDataFrame(t.Generic[SESSION, WRITER, NA, STAT, GROUP_DATA]):
  """
  return func(self, *args, **kwargs) # type: ignore

- @operation(Operation.SELECT_CONSTRAINED)
+ @operation(Operation.SELECT)
  def withColumn(self, colName: str, col: Column) -> Self:
  return self.withColumns.__wrapped__(self, {colName: col}) # type: ignore

- @operation(Operation.SELECT_CONSTRAINED)
+ @operation(Operation.SELECT)
  def withColumnRenamed(self, existing: str, new: str) -> Self:
- col_func = get_func_from_session("col", self.session)
  expression = self.expression.copy()
  existing = self.session._normalize_string(existing)
- outer_expressions = self._get_outer_select_expressions(expression)
+ columns = self._get_outer_select_columns(expression)
  results = []
  found_match = False
- for expr in outer_expressions:
- column = col_func(expr.copy())
- if existing == quote_preserving_alias_or_name(expr):
- if isinstance(column.expression, exp.Alias):
- column.expression.set("alias", exp.to_identifier(new))
+ for column in columns:
+ if column.alias_or_name == existing:
+ column = column.alias(new)
  self._update_display_name_mapping([column], [new])
  found_match = True
  results.append(column)
@@ -1536,7 +1542,7 @@ class BaseDataFrame(t.Generic[SESSION, WRITER, NA, STAT, GROUP_DATA]):
  raise ValueError("Tried to rename a column that doesn't exist")
  return self.select.__wrapped__(self, *results, skip_update_display_name_mapping=True) # type: ignore

- @operation(Operation.SELECT_CONSTRAINED)
+ @operation(Operation.SELECT)
  def withColumnsRenamed(self, colsMap: t.Dict[str, str]) -> Self:
  """
  Returns a new :class:`DataFrame` by renaming multiple columns. If a non-existing column is
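In the 3.40.0 code above, `withColumnRenamed` renames by re-aliasing the matched output column (`column.alias(new)`) instead of mutating the alias identifier in place, and it still raises `ValueError` when the column is missing. A rough usage sketch, assuming the standalone session and `DataFrame.sql()` behave as in earlier releases (the names below are illustrative, not taken from this diff):

```python
# Hedged sketch of the withColumnRenamed path touched above.
from sqlframe.standalone import StandaloneSession

session = StandaloneSession()
df = session.createDataFrame([(1, 2)], schema=["a", "b"])

# Effectively select(col("a").alias("c"), col("b")); an unknown name still raises
# ValueError("Tried to rename a column that doesn't exist").
renamed = df.withColumnRenamed("a", "c")
print(renamed.sql())  # expected: a SELECT projecting `a` aliased as `c`
```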
@@ -1582,7 +1588,7 @@ class BaseDataFrame(t.Generic[SESSION, WRITER, NA, STAT, GROUP_DATA]):

  return self.select.__wrapped__(self, *results, skip_update_display_name_mapping=True) # type: ignore

- @operation(Operation.SELECT_CONSTRAINED)
+ @operation(Operation.SELECT)
  def withColumns(self, *colsMap: t.Dict[str, Column]) -> Self:
  """
  Returns a new :class:`DataFrame` by adding multiple columns or replacing the
@@ -1620,14 +1626,13 @@ class BaseDataFrame(t.Generic[SESSION, WRITER, NA, STAT, GROUP_DATA]):
  """
  if len(colsMap) != 1:
  raise ValueError("Only a single map is supported")
- col_func = get_func_from_session("col")
  col_map = {
  self._ensure_and_normalize_col(k): (self._ensure_and_normalize_col(v), k)
  for k, v in colsMap[0].items()
  }
- existing_expr = self._get_outer_select_expressions(self.expression)
- existing_col_names = [x.alias_or_name for x in existing_expr]
- select_columns = [col_func(x) for x in existing_expr]
+ existing_cols = self._get_outer_select_columns(self.expression)
+ existing_col_names = [x.alias_or_name for x in existing_cols]
+ select_columns = existing_cols
  for col, (col_value, display_name) in col_map.items():
  column_name = col.alias_or_name
  existing_col_index = (
@@ -1644,7 +1649,7 @@ class BaseDataFrame(t.Generic[SESSION, WRITER, NA, STAT, GROUP_DATA]):
  )
  return self.select.__wrapped__(self, *select_columns, skip_update_display_name_mapping=True) # type: ignore

- @operation(Operation.SELECT_CONSTRAINED)
+ @operation(Operation.SELECT)
  def drop(self, *cols: t.Union[str, Column]) -> Self:
  # Separate string column names from Column objects for different handling
  column_objs, column_names = partition_to(lambda x: isinstance(x, str), cols, list, set)
sqlframe/base/functions.py CHANGED
@@ -37,9 +37,7 @@ def _get_session() -> _BaseSession:

  @meta()
  def col(column_name: t.Union[ColumnOrName, t.Any]) -> Column:
- from sqlframe.base.session import _BaseSession
-
- dialect = _BaseSession().input_dialect
+ dialect = _get_session().input_dialect
  if isinstance(column_name, str):
  col_expression = expression.to_column(column_name, dialect=dialect).transform(
  dialect.normalize_identifier
@@ -192,27 +190,27 @@ sum_distinct = sumDistinct

  @meta()
  def acos(col: ColumnOrName) -> Column:
- return Column.invoke_anonymous_function(col, "ACOS")
+ return Column.invoke_expression_over_column(col, expression.Acos)


  @meta(unsupported_engines="duckdb")
  def acosh(col: ColumnOrName) -> Column:
- return Column.invoke_anonymous_function(col, "ACOSH")
+ return Column.invoke_expression_over_column(col, expression.Acosh)


  @meta()
  def asin(col: ColumnOrName) -> Column:
- return Column.invoke_anonymous_function(col, "ASIN")
+ return Column.invoke_expression_over_column(col, expression.Asin)


  @meta(unsupported_engines="duckdb")
  def asinh(col: ColumnOrName) -> Column:
- return Column.invoke_anonymous_function(col, "ASINH")
+ return Column.invoke_expression_over_column(col, expression.Asinh)


  @meta()
  def atan(col: ColumnOrName) -> Column:
- return Column.invoke_anonymous_function(col, "ATAN")
+ return Column.invoke_expression_over_column(col, expression.Atan)


  @meta()
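The switch from `Column.invoke_anonymous_function` to `Column.invoke_expression_over_column` means these math functions are now built as typed sqlglot expression nodes instead of verbatim function calls, so each output dialect's generator can render or rewrite them. A minimal sketch of the difference at the sqlglot level (assuming `exp.Acos` exists in the sqlglot range this release pins, as the diff relies on it):

```python
# Hedged sketch: typed nodes are transpiled per dialect, Anonymous calls are emitted verbatim.
import sqlglot
from sqlglot import expressions as exp

anon = exp.Anonymous(this="ACOS", expressions=[exp.column("x")])  # always rendered as ACOS(x)
typed = exp.Acos(this=exp.column("x"))                            # dialect-aware rendering

print(sqlglot.select(anon).from_("t").sql(dialect="duckdb"))
print(sqlglot.select(typed).from_("t").sql(dialect="duckdb"))
```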
@@ -220,12 +218,12 @@ def atan2(col1: t.Union[ColumnOrName, float], col2: t.Union[ColumnOrName, float]
  col1_value = lit(col1) if isinstance(col1, (int, float)) else col1
  col2_value = lit(col2) if isinstance(col2, (int, float)) else col2

- return Column.invoke_anonymous_function(col1_value, "ATAN2", col2_value)
+ return Column.invoke_expression_over_column(col1_value, expression.Atan2, expression=col2_value)


  @meta(unsupported_engines="duckdb")
  def atanh(col: ColumnOrName) -> Column:
- return Column.invoke_anonymous_function(col, "ATANH")
+ return Column.invoke_expression_over_column(col, expression.Atanh)


  @meta()
@@ -253,12 +251,12 @@ def cosh(col: ColumnOrName) -> Column:

  @meta()
  def cot(col: ColumnOrName) -> Column:
- return Column.invoke_anonymous_function(col, "COT")
+ return Column.invoke_expression_over_column(col, expression.Cot)


  @meta(unsupported_engines=["duckdb", "postgres", "snowflake"])
  def csc(col: ColumnOrName) -> Column:
- return Column.invoke_anonymous_function(col, "CSC")
+ return Column.invoke_expression_over_column(col, expression.Csc)


  @meta()
@@ -364,7 +362,7 @@ def rint(col: ColumnOrName) -> Column:

  @meta(unsupported_engines=["duckdb", "postgres", "snowflake"])
  def sec(col: ColumnOrName) -> Column:
- return Column.invoke_anonymous_function(col, "SEC")
+ return Column.invoke_expression_over_column(col, expression.Sec)


  @meta()
@@ -374,12 +372,12 @@ def signum(col: ColumnOrName) -> Column:

  @meta()
  def sin(col: ColumnOrName) -> Column:
- return Column.invoke_anonymous_function(col, "SIN")
+ return Column.invoke_expression_over_column(col, expression.Sin)


  @meta(unsupported_engines="duckdb")
  def sinh(col: ColumnOrName) -> Column:
- return Column.invoke_anonymous_function(col, "SINH")
+ return Column.invoke_expression_over_column(col, expression.Sinh)


  @meta()
@@ -662,9 +660,7 @@ def grouping_id(*cols: ColumnOrName) -> Column:

  @meta()
  def input_file_name() -> Column:
- from sqlframe.base.session import _BaseSession
-
- return Column(expression.Literal.string(_BaseSession()._last_loaded_file or ""))
+ return Column(expression.Literal.string(_get_session()._last_loaded_file or ""))


  @meta()
@@ -944,7 +940,7 @@ def nth_value(

  @meta()
  def ntile(n: int) -> Column:
- return Column.invoke_anonymous_function(None, "NTILE", lit(n))
+ return Column.invoke_expression_over_column(lit(n), expression.Ntile)


  @meta()
@@ -959,12 +955,10 @@ def current_timestamp() -> Column:

  @meta()
  def date_format(col: ColumnOrName, format: str) -> Column:
- from sqlframe.base.session import _BaseSession
-
  return Column.invoke_expression_over_column(
  Column(expression.TimeStrToTime(this=Column.ensure_col(col).column_expression)),
  expression.TimeToStr,
- format=_BaseSession().format_time(format),
+ format=_get_session().format_time(format),
  )

@@ -2832,7 +2826,7 @@ def make_interval(

  @meta(unsupported_engines="*")
  def try_add(left: ColumnOrName, right: ColumnOrName) -> Column:
- return Column.invoke_anonymous_function(left, "TRY_ADD", right)
+ return Column.invoke_expression_over_column(left, expression.SafeAdd, expression=right)


  @meta(unsupported_engines="*")
@@ -2849,12 +2843,12 @@ def try_divide(left: ColumnOrName, right: ColumnOrName) -> Column:

  @meta(unsupported_engines="*")
  def try_multiply(left: ColumnOrName, right: ColumnOrName) -> Column:
- return Column.invoke_anonymous_function(left, "TRY_MULTIPLY", right)
+ return Column.invoke_expression_over_column(left, expression.SafeMultiply, expression=right)


  @meta(unsupported_engines="*")
  def try_subtract(left: ColumnOrName, right: ColumnOrName) -> Column:
- return Column.invoke_anonymous_function(left, "TRY_SUBTRACT", right)
+ return Column.invoke_expression_over_column(left, expression.SafeSubtract, expression=right)


  @meta(unsupported_engines="*")
@@ -3378,10 +3372,9 @@ def get(col: ColumnOrName, index: t.Union[ColumnOrName, int]) -> Column:
  def get_active_spark_context() -> SparkContext:
  """Raise RuntimeError if SparkContext is not initialized,
  otherwise, returns the active SparkContext."""
- from sqlframe.base.session import _BaseSession
  from sqlframe.spark.session import SparkSession

- session: _BaseSession = _BaseSession()
+ session = _get_session()
  if not isinstance(session, SparkSession):
  raise RuntimeError("This function is only available in SparkSession.")
  return session.spark_session.sparkContext
@@ -5263,7 +5256,7 @@ def regexp_extract_all(
  )


- @meta(unsupported_engines="*")
+ @meta(unsupported_engines=["duckdb", "bigquery", "postgres", "snowflake"])
  def regexp_instr(
  str: ColumnOrName, regexp: ColumnOrName, idx: t.Optional[t.Union[int, Column]] = None
  ) -> Column:
@@ -5298,11 +5291,9 @@ def regexp_instr(
  >>> df.select(regexp_instr('str', col("regexp")).alias('d')).collect()
  [Row(d=1)]
  """
- if idx is None:
- return Column.invoke_anonymous_function(str, "regexp_instr", regexp)
- else:
- idx = lit(idx) if isinstance(idx, int) else idx
- return Column.invoke_anonymous_function(str, "regexp_instr", regexp, idx)
+ return Column.invoke_expression_over_column(
+ str, expression.RegexpInstr, expression=regexp, group=idx
+ )


  @meta(unsupported_engines="snowflake")
@@ -6344,7 +6335,7 @@ def to_unix_timestamp(
  session = _get_session()

  if session._is_duckdb:
- format = format or _BaseSession().default_time_format
+ format = format or session.default_time_format
  timestamp = Column.ensure_col(timestamp).cast("string")

  if format is not None:
sqlframe/base/mixins/table_mixins.py CHANGED
@@ -275,7 +275,9 @@ class MergeSupportMixin(_BaseTable, t.Generic[DF]):
  join_expression = self._add_ctes_to_expression(
  self.expression, other_df.expression.copy().ctes
  )
- condition = self._ensure_and_normalize_cols(condition, self.expression)
+ condition = self._ensure_and_normalize_cols(
+ condition, self.expression, remove_identifier_if_possible=False
+ )
  self._handle_self_join(other_df, condition)

  if isinstance(condition[0].expression, exp.Column) and not clause:
@@ -291,7 +293,9 @@ class MergeSupportMixin(_BaseTable, t.Generic[DF]):
  condition, join_expression, other_df, table_names
  )
  else:
- join_clause = self._normalize_join_clause(condition, join_expression)
+ join_clause = self._normalize_join_clause(
+ condition, join_expression, remove_identifier_if_possible=False
+ )
  return join_clause

  def _ensure_and_normalize_assignments(
sqlframe/base/normalize.py CHANGED
@@ -16,33 +16,103 @@ if t.TYPE_CHECKING:
  NORMALIZE_INPUT = t.TypeVar("NORMALIZE_INPUT", bound=t.Union[str, exp.Expression, Column])


- def normalize(session: SESSION, expression_context: exp.Select, expr: t.List[NORMALIZE_INPUT]):
+ def normalize(
+ session: SESSION,
+ expression_context: exp.Select,
+ expr: t.List[NORMALIZE_INPUT],
+ *,
+ remove_identifier_if_possible: bool = True,
+ ):
  expr = ensure_list(expr)
  expressions = _ensure_expressions(expr)
  for expression in expressions:
  identifiers = expression.find_all(exp.Identifier)
  for identifier in identifiers:
  identifier.transform(session.input_dialect.normalize_identifier)
- replace_alias_name_with_cte_name(session, expression_context, identifier)
- replace_branch_and_sequence_ids_with_cte_name(session, expression_context, identifier)
+ replace_alias_name_with_cte_name(
+ session,
+ expression_context,
+ identifier,
+ remove_identifier_if_possible=remove_identifier_if_possible,
+ )
+ replace_branch_and_sequence_ids_with_cte_name(
+ session,
+ expression_context,
+ identifier,
+ remove_identifier_if_possible=remove_identifier_if_possible,
+ )


  def replace_alias_name_with_cte_name(
- session: SESSION, expression_context: exp.Select, id: exp.Identifier
+ session: SESSION,
+ expression_context: exp.Select,
+ id: exp.Identifier,
+ *,
+ remove_identifier_if_possible: bool,
  ):
  normalized_id = session._normalize_string(id.alias_or_name)
  if normalized_id in session.name_to_sequence_id_mapping:
- for cte in reversed(expression_context.ctes):
+ # Get CTEs that are referenced in the FROM clause
+ referenced_cte_names = get_cte_names_from_from_clause(expression_context)
+
+ # Filter CTEs to only include those defined and referenced by the FROM clause
+ filtered_ctes = [
+ cte
+ for cte in reversed(expression_context.ctes)
+ if cte.alias_or_name in referenced_cte_names
+ ]
+
+ for cte in filtered_ctes:
  if cte.args["sequence_id"] in session.name_to_sequence_id_mapping[normalized_id]:
  _set_alias_name(id, cte.alias_or_name)
  break
+ else:
+ # Fallback: If not found in filtered CTEs, search through ALL CTEs unfiltered
+ for cte in reversed(expression_context.ctes):
+ if cte.args["sequence_id"] in session.name_to_sequence_id_mapping[normalized_id]:
+ _set_alias_name(id, cte.alias_or_name)
+ break
+ else:
+ # Final fallback: If this is a qualified column reference (table.column)
+ # and the table doesn't exist in FROM clause, remove the qualifier IF the column is unambiguously available
+ parent = id.parent
+ if parent and isinstance(parent, exp.Column) and remove_identifier_if_possible:
+ # Check if this table is not available in current context
+ current_tables = get_cte_names_from_from_clause(expression_context)
+ if normalized_id not in current_tables:
+ # Check if this table ID matches any CTE name directly (cross-context CTE reference)
+ cte_exists = any(
+ cte.alias_or_name == normalized_id for cte in expression_context.ctes
+ )
+
+ if cte_exists:
+ # This is a reference to a CTE that exists but is not in the current FROM clause
+ # Get the column name being referenced
+ column_name = (
+ _extract_column_name(parent.this)
+ if hasattr(parent, "this")
+ else None
+ )
+
+ # Only remove qualifier if the column is unambiguously available in current context
+ if column_name and is_column_unambiguously_available(
+ expression_context, column_name
+ ):
+ parent.set("table", None)


  def replace_branch_and_sequence_ids_with_cte_name(
- session: SESSION, expression_context: exp.Select, id: exp.Identifier
+ session: SESSION,
+ expression_context: exp.Select,
+ id: exp.Identifier,
+ *,
+ remove_identifier_if_possible: bool,
  ):
  normalized_id = session._normalize_string(id.alias_or_name)
  if normalized_id in session.known_ids:
+ # Get CTEs that are referenced in the FROM clause
+ referenced_cte_names = get_cte_names_from_from_clause(expression_context)
+
  # Check if we have a join and if both the tables in that join share a common branch id
  # If so we need to have this reference the left table by default unless the id is a sequence
  # id then it keeps that reference. This handles the weird edge case in spark that shouldn't
@@ -51,19 +121,138 @@ def replace_branch_and_sequence_ids_with_cte_name(
  join_table_aliases = [
  x.alias_or_name for x in get_tables_from_expression_with_join(expression_context)
  ]
+ # Filter CTEs to only include those referenced in the FROM clause
  ctes_in_join = [
- cte for cte in expression_context.ctes if cte.alias_or_name in join_table_aliases
+ cte
+ for cte in expression_context.ctes
+ if cte.alias_or_name in join_table_aliases
+ and cte.alias_or_name in referenced_cte_names
  ]
- if ctes_in_join[0].args["branch_id"] == ctes_in_join[1].args["branch_id"]:
+ if (
+ len(ctes_in_join) >= 2
+ and ctes_in_join[0].args["branch_id"] == ctes_in_join[1].args["branch_id"]
+ ):
  assert len(ctes_in_join) == 2
  _set_alias_name(id, ctes_in_join[0].alias_or_name)
  return

+ # Filter CTEs to only include those defined and referenced by the FROM clause
+ filtered_ctes = [
+ cte
+ for cte in reversed(expression_context.ctes)
+ if cte.alias_or_name in referenced_cte_names
+ ]
+
+ for cte in filtered_ctes:
+ if normalized_id in (cte.args["branch_id"], cte.args["sequence_id"]):
+ _set_alias_name(id, cte.alias_or_name)
+ return
+
+ # Fallback: If not found in filtered CTEs, search through ALL CTEs unfiltered
  for cte in reversed(expression_context.ctes):
  if normalized_id in (cte.args["branch_id"], cte.args["sequence_id"]):
  _set_alias_name(id, cte.alias_or_name)
  return

+ # Final fallback: If this is a qualified column reference (table.column)
+ # and the table doesn't exist in FROM clause, remove the qualifier IF the column is unambiguously available
+ parent = id.parent
+ if parent and isinstance(parent, exp.Column) and remove_identifier_if_possible:
+ # Check if this table is not available in current context
+ current_tables = get_cte_names_from_from_clause(expression_context)
+ if normalized_id not in current_tables:
+ # Check if this table ID matches any CTE name directly (cross-context CTE reference)
+ cte_exists = any(cte.alias_or_name == normalized_id for cte in expression_context.ctes)
+
+ if cte_exists:
+ # This is a reference to a CTE that exists but is not in the current FROM clause
+ # Get the column name being referenced
+ column_name = _extract_column_name(parent.this) if hasattr(parent, "this") else None
+
+ # Only remove qualifier if the column is unambiguously available in current context
+ if column_name and is_column_unambiguously_available(
+ expression_context, column_name
+ ):
+ parent.set("table", None)
+
+
+ def is_column_unambiguously_available(expression_context: exp.Select, column_name: str) -> bool:
+ """
+ Check if a column name is unambiguously available in the current context.
+ Returns True if the column appears exactly once across all accessible CTEs.
+
+ Enhanced to handle more column expression types and edge cases.
+ """
+ current_tables = get_cte_names_from_from_clause(expression_context)
+ column_count_in_from = 0
+
+ # If no tables in FROM clause, be conservative
+ if not current_tables:
+ return False
+
+ # Count how many times this column appears in accessible CTEs
+ for cte in expression_context.ctes:
+ if cte.alias_or_name in current_tables:
+ if hasattr(cte, "this") and hasattr(cte.this, "expressions"):
+ for expr in cte.this.expressions:
+ expr_column_name = _extract_column_name(expr)
+
+ # Case-insensitive comparison for robustness
+ if expr_column_name and expr_column_name.lower() == column_name.lower():
+ column_count_in_from += 1
+
+ # Column is unambiguous if it appears exactly once in the FROM clause CTEs
+ return column_count_in_from == 1
+
+
+ def _extract_column_name(expr) -> str:
+ """
+ Extract column name from various expression types.
+ Enhanced to handle more SQLGlot expression types.
+ """
+ if hasattr(expr, "alias_or_name") and expr.alias_or_name:
+ return expr.alias_or_name
+ elif hasattr(expr, "this"):
+ if hasattr(expr.this, "this"):
+ return str(expr.this.this)
+ elif hasattr(expr.this, "name"):
+ return str(expr.this.name)
+ else:
+ return str(expr.this)
+ elif hasattr(expr, "name"):
+ return str(expr.name)
+ else:
+ return str(expr)
+
+
+ def get_cte_names_from_from_clause(expression_context: exp.Select) -> t.Set[str]:
+ """
+ Get the set of CTE names that are referenced in the FROM clause of the expression.
+
+ Args:
+ expression_context: The SELECT expression to analyze
+
+ Returns:
+ Set of CTE alias names referenced in the FROM clause (including joins)
+ """
+ referenced_cte_names = set()
+
+ # Get the main table from FROM clause
+ from_clause = expression_context.args.get("from")
+ if from_clause and from_clause.this:
+ main_table = from_clause.this
+ if hasattr(main_table, "alias_or_name") and main_table.alias_or_name:
+ referenced_cte_names.add(main_table.alias_or_name)
+
+ # Get tables from joins
+ if expression_context.args.get("joins"):
+ join_tables = get_tables_from_expression_with_join(expression_context)
+ for table in join_tables:
+ if hasattr(table, "alias_or_name") and table.alias_or_name:
+ referenced_cte_names.add(table.alias_or_name)
+
+ return referenced_cte_names
+

  def normalize_dict(session: SESSION, data: t.Dict) -> t.Dict:
  if isinstance(data, dict):
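The new helpers above drive the CTE-matching fallbacks: `get_cte_names_from_from_clause` limits the search to CTEs the FROM clause (and joins) actually reference, and `is_column_unambiguously_available` gates whether a stale table qualifier may be dropped. A minimal sketch of the FROM-clause helper, assuming it is importable from `sqlframe.base.normalize` as defined above:

```python
# Hedged sketch: only CTEs referenced by FROM/JOIN are reported, not every CTE defined.
import sqlglot
from sqlframe.base.normalize import get_cte_names_from_from_clause

query = sqlglot.parse_one("WITH a AS (SELECT 1 AS x), b AS (SELECT 2 AS y) SELECT x FROM a")
print(get_cte_names_from_from_clause(query))  # expected: {"a"} (b is defined but never read)
```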
sqlframe/base/operations.py CHANGED
@@ -27,10 +27,9 @@ class Operation(IntEnum):
  WHERE = 2
  GROUP_BY = 3
  HAVING = 4
- SELECT_CONSTRAINED = 5
- SELECT = 6
- ORDER_BY = 7
- LIMIT = 8
+ SELECT = 5
+ ORDER_BY = 6
+ LIMIT = 7


  # We want to decorate a function (self: DF, *args, **kwargs) -> T
sqlframe-3.39.3.dist-info/METADATA → sqlframe-3.40.0.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: sqlframe
- Version: 3.39.3
+ Version: 3.40.0
  Summary: Turning PySpark Into a Universal DataFrame API
  Home-page: https://github.com/eakmanrq/sqlframe
  Author: Ryan Eakman
@@ -18,7 +18,7 @@ Description-Content-Type: text/markdown
  License-File: LICENSE
  Requires-Dist: more-itertools
  Requires-Dist: prettytable <4
- Requires-Dist: sqlglot <27.9,>=24.0.0
+ Requires-Dist: sqlglot <27.13,>=24.0.0
  Requires-Dist: typing-extensions
  Provides-Extra: bigquery
  Requires-Dist: google-cloud-bigquery-storage <3,>=2 ; extra == 'bigquery'
sqlframe-3.39.3.dist-info/RECORD → sqlframe-3.40.0.dist-info/RECORD CHANGED
@@ -1,18 +1,18 @@
  sqlframe/__init__.py,sha256=SB80yLTITBXHI2GCDS6n6bN5ObHqgPjfpRPAUwxaots,3403
- sqlframe/_version.py,sha256=Vixv4hfZnHHXCXSmZD4wlHJUBkhCMzDLIyo5HqkJdes,714
+ sqlframe/_version.py,sha256=fOWY_ffL74_A_EHPCU75GCzXq1ZU-sD4WyU4XHtJjlI,714
  sqlframe/py.typed,sha256=Nqnn8clbgv-5l0PgxcTOldg8mkMKrFn4TvPL-rYUUGg,1
  sqlframe/base/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
  sqlframe/base/_typing.py,sha256=b2clI5HI1zEZKB_3Msx3FeAJQyft44ubUifJwQRVXyQ,1298
  sqlframe/base/catalog.py,sha256=-YulM2BMK8MoWbXi05AsJIPxd4AuiZDBCZuk4HoeMlE,38900
  sqlframe/base/column.py,sha256=f6rK6-hTiNx9WwJP7t6tqL3xEC2gwERPDlhWCS5iCBw,21417
- sqlframe/base/dataframe.py,sha256=HHjDaeap4_w4HRRj87lhQjFTczxLKhFD8b-9vhK2KsY,87592
+ sqlframe/base/dataframe.py,sha256=fveiwPH-JQyUJdyB9PxzjHTvwwBBzBY4pUWq2OraH9A,87328
  sqlframe/base/decorators.py,sha256=IhE5xNQDkwJHacCvulq5WpUKyKmXm7dL2A3o5WuKGP4,2131
  sqlframe/base/exceptions.py,sha256=9Uwvqn2eAkDpqm4BrRgbL61qM-GMCbJEMAW8otxO46s,370
  sqlframe/base/function_alternatives.py,sha256=aTu3nQhIAkZoxrI1IpjpaHEAMxBNms0AnhS0EMR-TwY,51727
- sqlframe/base/functions.py,sha256=RVNoRzM19BUwypdc0izYrrQe2Fe4_e9SbtpDkdD2bec,227981
+ sqlframe/base/functions.py,sha256=fc3jLuPAIJ3Hl4Bezm9Kgzsk4e5uFfgMgfajUCBKQG0,227919
  sqlframe/base/group.py,sha256=fBm8EUve7W7xz11nybTXr09ih-yZxL_vvEiZVE1eb_0,12025
- sqlframe/base/normalize.py,sha256=nXAJ5CwxVf4DV0GsH-q1w0p8gmjSMlv96k_ez1eVul8,3880
- sqlframe/base/operations.py,sha256=8dkMNqjG3xP1w_6euAj8FpwweD7t590HYjoeoCr5LqI,4465
+ sqlframe/base/normalize.py,sha256=YPeopWr8ZRjevArYfrM-DZBkQp4t4UfAEwynoj4VvcU,11773
+ sqlframe/base/operations.py,sha256=g-YNcbvNKTOBbYm23GKfB3fmydlR7ZZDAuZUtXIHtzw,4438
  sqlframe/base/readerwriter.py,sha256=Nb2VJ_HBmLQp5mK8JhnFooZh2ydAaboCAFVPb-4MNX4,31241
  sqlframe/base/session.py,sha256=99X-ShK9ohHCX6WdIJs0HhjfK23snaE3Gv6RYc5wqUI,27687
  sqlframe/base/table.py,sha256=rCeh1W5SWbtEVfkLAUiexzrZwNgmZeptLEmLcM1ABkE,6961
@@ -25,7 +25,7 @@ sqlframe/base/mixins/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hS
  sqlframe/base/mixins/catalog_mixins.py,sha256=9fZGWToz9xMJSzUl1vsVtj6TH3TysP3fBCKJLnGUQzE,23353
  sqlframe/base/mixins/dataframe_mixins.py,sha256=8D2AFtfc0tj9Q5qzlNAXdXOYw9RuD8kpe8wixo8pY5o,1534
  sqlframe/base/mixins/readwriter_mixins.py,sha256=ItQ_0jZ5RljgmLjGDIzLMRP_NQdy3wAyKwJ6K5NjaqA,4954
- sqlframe/base/mixins/table_mixins.py,sha256=3MhsOARkplwED1GRD0wq1vR8GNuop34kt3Jg8MATIjc,13791
+ sqlframe/base/mixins/table_mixins.py,sha256=zoqrgaH1fOgnHkC6C4L8IUyspDa5SETP3OXVdKWxcUM,13917
  sqlframe/bigquery/__init__.py,sha256=kbaomhYAANPdxeDQhajv8IHfMg_ENKivtYK-rPwaV08,939
  sqlframe/bigquery/catalog.py,sha256=Dcpp1JKftc3ukdYpn6M1ujqixA-6_1k8aY21Y5Johyc,11899
  sqlframe/bigquery/column.py,sha256=E1tUa62Y5HajkhgFuebU9zohrGyieudcHzTT8gfalio,40
@@ -130,8 +130,8 @@ sqlframe/standalone/udf.py,sha256=azmgtUjHNIPs0WMVNId05SHwiYn41MKVBhKXsQJ5dmY,27
  sqlframe/standalone/window.py,sha256=6GKPzuxeSapJakBaKBeT9VpED1ACdjggDv9JRILDyV0,35
  sqlframe/testing/__init__.py,sha256=VVCosQhitU74A3NnE52O4mNtGZONapuEXcc20QmSlnQ,132
  sqlframe/testing/utils.py,sha256=PFsGZpwNUE_4-g_f43_vstTqsK0AQ2lBneb5Eb6NkFo,13008
- sqlframe-3.39.3.dist-info/LICENSE,sha256=VZu79YgW780qxaFJMr0t5ZgbOYEh04xWoxaWOaqIGWk,1068
- sqlframe-3.39.3.dist-info/METADATA,sha256=eyKm8nGawKAujUOiCBn4PEFpSh_UzsnEV7LpKQVecRM,9069
- sqlframe-3.39.3.dist-info/WHEEL,sha256=oiQVh_5PnQM0E3gPdiz09WCNmwiHDMaGer_elqB3coM,92
- sqlframe-3.39.3.dist-info/top_level.txt,sha256=T0_RpoygaZSF6heeWwIDQgaP0varUdSK1pzjeJZRjM8,9
- sqlframe-3.39.3.dist-info/RECORD,,
+ sqlframe-3.40.0.dist-info/LICENSE,sha256=VZu79YgW780qxaFJMr0t5ZgbOYEh04xWoxaWOaqIGWk,1068
+ sqlframe-3.40.0.dist-info/METADATA,sha256=43WXPdp-_riwus7pJqzCv6ct0oAEmZden39JcI-hKVU,9070
+ sqlframe-3.40.0.dist-info/WHEEL,sha256=oiQVh_5PnQM0E3gPdiz09WCNmwiHDMaGer_elqB3coM,92
+ sqlframe-3.40.0.dist-info/top_level.txt,sha256=T0_RpoygaZSF6heeWwIDQgaP0varUdSK1pzjeJZRjM8,9
+ sqlframe-3.40.0.dist-info/RECORD,,