sqlframe 3.13.4__py3-none-any.whl → 3.14.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
sqlframe/_version.py CHANGED
@@ -12,5 +12,5 @@ __version__: str
12
12
  __version_tuple__: VERSION_TUPLE
13
13
  version_tuple: VERSION_TUPLE
14
14
 
15
- __version__ = version = '3.13.4'
16
- __version_tuple__ = version_tuple = (3, 13, 4)
15
+ __version__ = version = '3.14.1'
16
+ __version_tuple__ = version_tuple = (3, 14, 1)
@@ -79,6 +79,23 @@ JOIN_HINTS = {
79
79
  "SHUFFLE_REPLICATE_NL",
80
80
  }
81
81
 
82
+ JOIN_TYPE_MAPPING = {
83
+ "inner": "inner",
84
+ "cross": "cross",
85
+ "outer": "full_outer",
86
+ "full": "full_outer",
87
+ "fullouter": "full_outer",
88
+ "left": "left_outer",
89
+ "leftouter": "left_outer",
90
+ "right": "right_outer",
91
+ "rightouter": "right_outer",
92
+ "semi": "left_semi",
93
+ "leftsemi": "left_semi",
94
+ "left_semi": "left_semi",
95
+ "anti": "left_anti",
96
+ "leftanti": "left_anti",
97
+ "left_anti": "left_anti",
98
+ }
82
99
 
83
100
  DF = t.TypeVar("DF", bound="BaseDataFrame")
84
101
 
@@ -872,6 +889,68 @@ class BaseDataFrame(t.Generic[SESSION, WRITER, NA, STAT, GROUP_DATA]):
872
889
  """
873
890
  return self.join.__wrapped__(self, other, how="cross") # type: ignore
874
891
 
892
+ def _handle_self_join(self, other_df: DF, join_columns: t.List[Column]):
893
+ # If the two dataframes being joined come from the same branch, we then check if they have any columns that
894
+ # were created using the "branch_id" (df["column_name"]). If so, we know that we need to differentiate
895
+ # the two columns since they would end up with the same table name. We do this by checking for the unique
896
+ # uuids in the other df and finding columns that have metadata on them that match the uuids. If so, we know
897
+ # it comes from the other df and we change the table name to the other df's table name.
898
+ # See `test_self_join` for an example of this.
899
+ if self.branch_id == other_df.branch_id:
900
+ other_df_unique_uuids = other_df.known_uuids - self.known_uuids
901
+ for col in join_columns:
902
+ for col_expr in col.expression.find_all(exp.Column):
903
+ if (
904
+ "join_on_uuid" in col_expr.meta
905
+ and col_expr.meta["join_on_uuid"] in other_df_unique_uuids
906
+ ):
907
+ col_expr.set("table", exp.to_identifier(other_df.latest_cte_name))
908
+
909
+ @staticmethod
910
+ def _handle_join_column_names_only(
911
+ join_columns: t.List[Column],
912
+ join_expression: exp.Select,
913
+ other_df: DF,
914
+ table_names: t.List[str],
915
+ ):
916
+ potential_ctes = [
917
+ cte
918
+ for cte in join_expression.ctes
919
+ if cte.alias_or_name in table_names and cte.alias_or_name != other_df.latest_cte_name
920
+ ]
921
+ # Determine the table to reference for the left side of the join by checking each of the left side
922
+ # tables to see whether it has the column being referenced.
923
+ join_column_pairs = []
924
+ for join_column in join_columns:
925
+ num_matching_ctes = 0
926
+ for cte in potential_ctes:
927
+ if join_column.alias_or_name in cte.this.named_selects:
928
+ left_column = join_column.copy().set_table_name(cte.alias_or_name)
929
+ right_column = join_column.copy().set_table_name(other_df.latest_cte_name)
930
+ join_column_pairs.append((left_column, right_column))
931
+ num_matching_ctes += 1
932
+ # We only want to match one table to the column and that should be matched left -> right
933
+ # so we break after the first match
934
+ break
935
+ if num_matching_ctes == 0:
936
+ raise ValueError(
937
+ f"Column `{join_column.alias_or_name}` does not exist in any of the tables."
938
+ )
939
+ join_clause = functools.reduce(
940
+ lambda x, y: x & y,
941
+ [left_column == right_column for left_column, right_column in join_column_pairs],
942
+ )
943
+ return join_column_pairs, join_clause
944
+
945
+ def _normalize_join_clause(
946
+ self, join_columns: t.List[Column], join_expression: t.Optional[exp.Select]
947
+ ) -> Column:
948
+ join_columns = self._ensure_and_normalize_cols(join_columns, join_expression)
949
+ if len(join_columns) > 1:
950
+ join_columns = [functools.reduce(lambda x, y: x & y, join_columns)]
951
+ join_clause = join_columns[0]
952
+ return join_clause
953
+
875
954
  @operation(Operation.FROM)
876
955
  def join(
877
956
  self,
@@ -882,37 +961,33 @@ class BaseDataFrame(t.Generic[SESSION, WRITER, NA, STAT, GROUP_DATA]):
882
961
  ) -> Self:
883
962
  from sqlframe.base.functions import coalesce
884
963
 
885
- if on is None:
964
+ if (on is None) and ("cross" not in how):
886
965
  logger.warning("Got no value for on. This appears to change the join to a cross join.")
887
966
  how = "cross"
967
+ if (on is not None) and ("cross" in how):
968
+ # Not a lot of doc, but Spark handles cross with predicate as an inner join
969
+ # https://learn.microsoft.com/en-us/dotnet/api/microsoft.spark.sql.dataframe.join
970
+ logger.warning("Got cross join with an 'on' value. This will result in an inner join.")
971
+ how = "inner"
888
972
 
889
973
  other_df = other_df._convert_leaf_to_cte()
890
974
  join_expression = self._add_ctes_to_expression(self.expression, other_df.expression.ctes)
891
975
  # We will determine actual "join on" expression later so we don't provide it at first
892
- join_expression = join_expression.join(
893
- join_expression.ctes[-1].alias, join_type=how.replace("_", " ")
894
- )
976
+ join_type = JOIN_TYPE_MAPPING.get(how, how).replace("_", " ")
977
+ join_expression = join_expression.join(join_expression.ctes[-1].alias, join_type=join_type)
895
978
  self_columns = self._get_outer_select_columns(join_expression)
896
979
  other_columns = self._get_outer_select_columns(other_df.expression)
897
980
  join_columns = self._ensure_and_normalize_cols(on)
898
- # If the two dataframes being joined come from the same branch, we then check if they have any columns that
899
- # were created using the "branch_id" (df["column_name"]). If so, we know that we need to differentiate
900
- # the two columns since they would end up with the same table name. We do this by checking for the unique
901
- # uuids in the other df and finding columns that have metadata on them that match the uuids. If so, we know
902
- # it comes from the other df and we change the table name to the other df's table name.
903
- # See `test_self_join` for an example of this.
904
- if self.branch_id == other_df.branch_id:
905
- other_df_unique_uuids = other_df.known_uuids - self.known_uuids
906
- for col in join_columns:
907
- for col_expr in col.expression.find_all(exp.Column):
908
- if (
909
- "join_on_uuid" in col_expr.meta
910
- and col_expr.meta["join_on_uuid"] in other_df_unique_uuids
911
- ):
912
- col_expr.set("table", exp.to_identifier(other_df.latest_cte_name))
981
+ self._handle_self_join(other_df, join_columns)
982
+
913
983
  # Determines the join clause and select columns to be used based on what type of columns were provided for
914
984
  # the join. The columns returned change based on how the on expression is provided.
915
- if how != "cross":
985
+ select_columns = (
986
+ self_columns
987
+ if join_type in ["left anti", "left semi"]
988
+ else self_columns + other_columns
989
+ )
990
+ if join_type != "cross":
916
991
  if isinstance(join_columns[0].expression, exp.Column):
917
992
  """
918
993
  Unique characteristics of join on column names only:
@@ -923,38 +998,9 @@ class BaseDataFrame(t.Generic[SESSION, WRITER, NA, STAT, GROUP_DATA]):
923
998
  table.alias_or_name
924
999
  for table in get_tables_from_expression_with_join(join_expression)
925
1000
  ]
926
- potential_ctes = [
927
- cte
928
- for cte in join_expression.ctes
929
- if cte.alias_or_name in table_names
930
- and cte.alias_or_name != other_df.latest_cte_name
931
- ]
932
- # Determine the table to reference for the left side of the join by checking each of the left side
933
- # tables and see if they have the column being referenced.
934
- join_column_pairs = []
935
- for join_column in join_columns:
936
- num_matching_ctes = 0
937
- for cte in potential_ctes:
938
- if join_column.alias_or_name in cte.this.named_selects:
939
- left_column = join_column.copy().set_table_name(cte.alias_or_name)
940
- right_column = join_column.copy().set_table_name(
941
- other_df.latest_cte_name
942
- )
943
- join_column_pairs.append((left_column, right_column))
944
- num_matching_ctes += 1
945
- # We only want to match one table to the column and that should be matched left -> right
946
- # so we break after the first match
947
- break
948
- if num_matching_ctes == 0:
949
- raise ValueError(
950
- f"Column `{join_column.alias_or_name}` does not exist in any of the tables."
951
- )
952
- join_clause = functools.reduce(
953
- lambda x, y: x & y,
954
- [
955
- left_column == right_column
956
- for left_column, right_column in join_column_pairs
957
- ],
1001
+
1002
+ join_column_pairs, join_clause = self._handle_join_column_names_only(
1003
+ join_columns, join_expression, other_df, table_names
958
1004
  )
959
1005
  join_column_names = [
960
1006
  coalesce(
@@ -972,7 +1018,7 @@ class BaseDataFrame(t.Generic[SESSION, WRITER, NA, STAT, GROUP_DATA]):
972
1018
  if not isinstance(column.expression.this, exp.Star)
973
1019
  else column.sql()
974
1020
  )
975
- for column in self_columns + other_columns
1021
+ for column in select_columns
976
1022
  ]
977
1023
  select_column_names = [
978
1024
  column_name
@@ -989,17 +1035,12 @@ class BaseDataFrame(t.Generic[SESSION, WRITER, NA, STAT, GROUP_DATA]):
989
1035
  * There is no deduplication of the results.
990
1036
  * The left join dataframe columns go first and right come after. No sort preference is given to join columns
991
1037
  """
992
- join_columns = self._ensure_and_normalize_cols(join_columns, join_expression)
993
- if len(join_columns) > 1:
994
- join_columns = [functools.reduce(lambda x, y: x & y, join_columns)]
995
- join_clause = join_columns[0]
996
- select_column_names = [
997
- column.alias_or_name for column in self_columns + other_columns
998
- ]
1038
+ join_clause = self._normalize_join_clause(join_columns, join_expression)
1039
+ select_column_names = [column.alias_or_name for column in select_columns]
999
1040
 
1000
1041
  # Update the on expression with the actual join clause to replace the dummy one from before
1001
1042
  else:
1002
- select_column_names = [column.alias_or_name for column in self_columns + other_columns]
1043
+ select_column_names = [column.alias_or_name for column in select_columns]
1003
1044
  join_clause = None
1004
1045
  join_expression.args["joins"][-1].set("on", join_clause.expression if join_clause else None)
1005
1046
  new_df = self.copy(expression=join_expression)
@@ -0,0 +1,335 @@
1
+ import functools
2
+ import logging
3
+ import typing as t
4
+
5
+ from sqlglot import exp
6
+
7
+ try:
8
+ from sqlglot.expressions import Whens
9
+ except ImportError:
10
+ Whens = None # type: ignore
11
+ from sqlglot.helper import object_to_dict
12
+
13
+ from sqlframe.base.column import Column
14
+ from sqlframe.base.table import (
15
+ DF,
16
+ Clause,
17
+ LazyExpression,
18
+ WhenMatched,
19
+ WhenNotMatched,
20
+ WhenNotMatchedBySource,
21
+ _BaseTable,
22
+ )
23
+
24
+ if t.TYPE_CHECKING:
25
+ from sqlframe.base._typing import ColumnOrLiteral
26
+
27
+
28
+ logger = logging.getLogger(__name__)
29
+
30
+
31
+ def ensure_cte() -> t.Callable[[t.Callable], t.Callable]:
32
+ def decorator(func: t.Callable) -> t.Callable:
33
+ @functools.wraps(func)
34
+ def wrapper(self: _BaseTable, *args, **kwargs) -> t.Any:
35
+ if len(self.expression.ctes) > 0:
36
+ return func(self, *args, **kwargs) # type: ignore
37
+ self_class = self.__class__
38
+ self = self._convert_leaf_to_cte()
39
+ self = self_class(**object_to_dict(self))
40
+ return func(self, *args, **kwargs) # type: ignore
41
+
42
+ wrapper.__wrapped__ = func # type: ignore
43
+ return wrapper
44
+
45
+ return decorator
46
+
47
+
48
+ class _BaseTableMixins(_BaseTable, t.Generic[DF]):
49
+ def _ensure_where_condition(
50
+ self, where: t.Optional[t.Union[Column, str, bool]] = None
51
+ ) -> exp.Expression:
52
+ self_name = self.expression.ctes[0].this.args["from"].this.alias_or_name
53
+
54
+ if where is None:
55
+ logger.warning("Empty value for `where`clause. Defaults to `True`.")
56
+ condition: exp.Expression = exp.Boolean(this=True)
57
+ else:
58
+ condition_list = self._ensure_and_normalize_cols(where, self.expression)
59
+ if len(condition_list) > 1:
60
+ condition_list = [functools.reduce(lambda x, y: x & y, condition_list)]
61
+ for col_expr in condition_list[0].expression.find_all(exp.Column):
62
+ if col_expr.table == self.expression.args["from"].this.alias_or_name:
63
+ col_expr.set("table", exp.to_identifier(self_name))
64
+ condition = condition_list[0].expression
65
+ if isinstance(condition, exp.Alias):
66
+ condition = condition.this
67
+ return condition
68
+
69
+
70
+ class UpdateSupportMixin(_BaseTableMixins, t.Generic[DF]):
71
+ @ensure_cte()
72
+ def update(
73
+ self,
74
+ set_: t.Dict[t.Union[Column, str], t.Union[Column, "ColumnOrLiteral", exp.Expression]],
75
+ where: t.Optional[t.Union[Column, str, bool]] = None,
76
+ ) -> LazyExpression:
77
+ self_expr = self.expression.ctes[0].this.args["from"].this
78
+
79
+ condition = self._ensure_where_condition(where)
80
+ update_set = self._ensure_and_normalize_update_set(set_)
81
+ update_expr = exp.Update(
82
+ this=self_expr,
83
+ expressions=[
84
+ exp.EQ(
85
+ this=key,
86
+ expression=val,
87
+ )
88
+ for key, val in update_set.items()
89
+ ],
90
+ where=exp.Where(this=condition),
91
+ )
92
+
93
+ return LazyExpression(update_expr, self.session)
94
+
95
+ def _ensure_and_normalize_update_set(
96
+ self,
97
+ set_: t.Dict[t.Union[Column, str], t.Union[Column, "ColumnOrLiteral", exp.Expression]],
98
+ ) -> t.Dict[str, exp.Expression]:
99
+ self_name = self.expression.ctes[0].this.args["from"].this.alias_or_name
100
+ update_set = {}
101
+ for key, val in set_.items():
102
+ key_column: Column = self._ensure_and_normalize_col(key)
103
+ key_expr = list(key_column.expression.find_all(exp.Column))
104
+ if len(key_expr) > 1:
105
+ raise ValueError(f"Can only update one a single column at a time.")
106
+ key = key_expr[0].alias_or_name
107
+
108
+ val_column: Column = self._ensure_and_normalize_col(val)
109
+ for col_expr in val_column.expression.find_all(exp.Column):
110
+ if col_expr.table == self.expression.args["from"].this.alias_or_name:
111
+ col_expr.set("table", exp.to_identifier(self_name))
112
+ else:
113
+ raise ValueError(
114
+ f"Column `{col_expr.alias_or_name}` does not exist in the table."
115
+ )
116
+
117
+ update_set[key] = val_column.expression
118
+ return update_set
119
+
120
+
121
+ class DeleteSupportMixin(_BaseTableMixins, t.Generic[DF]):
122
+ @ensure_cte()
123
+ def delete(
124
+ self,
125
+ where: t.Optional[t.Union[Column, str, bool]] = None,
126
+ ) -> LazyExpression:
127
+ self_expr = self.expression.ctes[0].this.args["from"].this
128
+
129
+ condition = self._ensure_where_condition(where)
130
+ delete_expr = exp.Delete(
131
+ this=self_expr,
132
+ where=exp.Where(this=condition),
133
+ )
134
+
135
+ return LazyExpression(delete_expr, self.session)
136
+
137
+
138
+ class MergeSupportMixin(_BaseTable, t.Generic[DF]):
139
+ _merge_supported_clauses: t.Iterable[
140
+ t.Union[t.Type[WhenMatched], t.Type[WhenNotMatched], t.Type[WhenNotMatchedBySource]]
141
+ ]
142
+ _merge_support_star: bool
143
+
144
+ @ensure_cte()
145
+ def merge(
146
+ self,
147
+ other_df: DF,
148
+ condition: t.Union[str, t.List[str], Column, t.List[Column], bool],
149
+ clauses: t.Iterable[t.Union[WhenMatched, WhenNotMatched, WhenNotMatchedBySource]],
150
+ ) -> LazyExpression:
151
+ self_name = self.expression.ctes[0].this.args["from"].this.alias_or_name
152
+ self_expr = self.expression.ctes[0].this.args["from"].this
153
+
154
+ other_df = other_df._convert_leaf_to_cte()
155
+
156
+ if condition is None:
157
+ raise ValueError("condition cannot be None")
158
+
159
+ condition_columns: Column = self._ensure_and_normalize_condition(condition, other_df)
160
+ other_name = self._create_hash_from_expression(other_df.expression)
161
+ other_expr = exp.Subquery(
162
+ this=other_df.expression, alias=exp.TableAlias(this=exp.to_identifier(other_name))
163
+ )
164
+
165
+ for col_expr in condition_columns.expression.find_all(exp.Column):
166
+ if col_expr.table == self.expression.args["from"].this.alias_or_name:
167
+ col_expr.set("table", exp.to_identifier(self_name))
168
+ if col_expr.table == other_df.latest_cte_name:
169
+ col_expr.set("table", exp.to_identifier(other_name))
170
+
171
+ merge_expressions = []
172
+ for clause in clauses:
173
+ if not isinstance(clause, tuple(self._merge_supported_clauses)):
174
+ raise ValueError(
175
+ f"Unsupported clause type {type(clause.clause)} for merge operation"
176
+ )
177
+ expression = None
178
+
179
+ if clause.clause.condition is not None:
180
+ cond_clause = self._ensure_and_normalize_condition(
181
+ clause.clause.condition, other_df, True
182
+ )
183
+ for col_expr in cond_clause.expression.find_all(exp.Column):
184
+ if col_expr.table == self.expression.args["from"].this.alias_or_name:
185
+ col_expr.set("table", exp.to_identifier(self_name))
186
+ if col_expr.table == other_df.latest_cte_name:
187
+ col_expr.set("table", exp.to_identifier(other_name))
188
+ else:
189
+ cond_clause = None
190
+ if clause.clause.clause_type == Clause.UPDATE:
191
+ update_set = self._ensure_and_normalize_assignments(
192
+ clause.clause.assignments, other_df
193
+ )
194
+ expression = exp.When(
195
+ matched=clause.clause.matched,
196
+ source=clause.clause.by_source,
197
+ condition=cond_clause.expression if cond_clause else None,
198
+ then=exp.Update(
199
+ expressions=[
200
+ exp.EQ(
201
+ this=key,
202
+ expression=val,
203
+ )
204
+ for key, val in update_set.items()
205
+ ]
206
+ ),
207
+ )
208
+ if clause.clause.clause_type == Clause.UPDATE_ALL:
209
+ if not self._support_star:
210
+ raise ValueError("Merge operation does not support UPDATE_ALL")
211
+ expression = exp.When(
212
+ matched=clause.clause.matched,
213
+ source=clause.clause.by_source,
214
+ condition=cond_clause.expression if cond_clause else None,
215
+ then=exp.Update(expressions=[exp.Star()]),
216
+ )
217
+ elif clause.clause.clause_type == Clause.INSERT:
218
+ insert_values = self._ensure_and_normalize_assignments(
219
+ clause.clause.assignments, other_df
220
+ )
221
+ expression = exp.When(
222
+ matched=clause.clause.matched,
223
+ source=clause.clause.by_source,
224
+ condition=cond_clause.expression if cond_clause else None,
225
+ then=exp.Insert(
226
+ this=exp.Tuple(expressions=[key for key in insert_values.keys()]),
227
+ expression=exp.Tuple(expressions=[val for val in insert_values.values()]),
228
+ ),
229
+ )
230
+ elif clause.clause.clause_type == Clause.INSERT_ALL:
231
+ if not self._support_star:
232
+ raise ValueError("Merge operation does not support INSERT_ALL")
233
+ expression = exp.When(
234
+ matched=clause.clause.matched,
235
+ source=clause.clause.by_source,
236
+ condition=cond_clause.expression if cond_clause else None,
237
+ then=exp.Insert(expression=exp.Star()),
238
+ )
239
+ elif clause.clause.clause_type == Clause.DELETE:
240
+ expression = exp.When(
241
+ matched=clause.clause.matched,
242
+ source=clause.clause.by_source,
243
+ condition=cond_clause.expression if cond_clause else None,
244
+ then=exp.var("DELETE"),
245
+ )
246
+
247
+ if expression:
248
+ merge_expressions.append(expression)
249
+
250
+ if Whens is None:
251
+ merge_expr = exp.merge(
252
+ *merge_expressions,
253
+ into=self_expr,
254
+ using=other_expr,
255
+ on=condition_columns.expression,
256
+ )
257
+ else:
258
+ merge_expr = exp.merge(
259
+ Whens(expressions=merge_expressions),
260
+ into=self_expr,
261
+ using=other_expr,
262
+ on=condition_columns.expression,
263
+ )
264
+
265
+ return LazyExpression(merge_expr, self.session)
266
+
267
+ def _ensure_and_normalize_condition(
268
+ self,
269
+ condition: t.Union[str, t.List[str], Column, t.List[Column], bool],
270
+ other_df: DF,
271
+ clause: t.Optional[bool] = False,
272
+ ):
273
+ join_expression = self._add_ctes_to_expression(
274
+ self.expression, other_df.expression.copy().ctes
275
+ )
276
+ condition = self._ensure_and_normalize_cols(condition, self.expression)
277
+ self._handle_self_join(other_df, condition)
278
+
279
+ if isinstance(condition[0].expression, exp.Column) and not clause:
280
+ table_names = [
281
+ table.alias_or_name
282
+ for table in [
283
+ self.expression.args["from"].this,
284
+ other_df.expression.args["from"].this,
285
+ ]
286
+ ]
287
+
288
+ join_column_pairs, join_clause = self._handle_join_column_names_only(
289
+ condition, join_expression, other_df, table_names
290
+ )
291
+ else:
292
+ join_clause = self._normalize_join_clause(condition, join_expression)
293
+ return join_clause
294
+
295
+ def _ensure_and_normalize_assignments(
296
+ self,
297
+ assignments: t.Dict[
298
+ t.Union[Column, str], t.Union[Column, "ColumnOrLiteral", exp.Expression]
299
+ ],
300
+ other_df,
301
+ ) -> t.Dict[exp.Column, exp.Expression]:
302
+ self_name = self.expression.ctes[0].this.args["from"].this.alias_or_name
303
+ other_name = self._create_hash_from_expression(other_df.expression)
304
+ update_set = {}
305
+ for key, val in assignments.items():
306
+ key_column: Column = self._ensure_and_normalize_col(key)
307
+ key_expr = list(key_column.expression.find_all(exp.Column))
308
+ if len(key_expr) > 1:
309
+ raise ValueError(f"Target expression `{key_expr}` should be a single column.")
310
+ column_key = exp.column(key_expr[0].alias_or_name)
311
+
312
+ val = self._ensure_and_normalize_col(val)
313
+ val = self._ensure_and_normalize_cols(val, other_df.expression)[0]
314
+ if self.branch_id == other_df.branch_id:
315
+ other_df_unique_uuids = other_df.known_uuids - self.known_uuids
316
+ for col_expr in val.expression.find_all(exp.Column):
317
+ if (
318
+ "join_on_uuid" in col_expr.meta
319
+ and col_expr.meta["join_on_uuid"] in other_df_unique_uuids
320
+ ):
321
+ col_expr.set("table", exp.to_identifier(other_df.latest_cte_name))
322
+
323
+ for col_expr in val.expression.find_all(exp.Column):
324
+ if not col_expr.table or col_expr.table == other_df.latest_cte_name:
325
+ col_expr.set("table", exp.to_identifier(other_name))
326
+ elif col_expr.table == self.expression.args["from"].this.alias_or_name:
327
+ col_expr.set("table", exp.to_identifier(self_name))
328
+ else:
329
+ raise ValueError(
330
+ f"Column `{col_expr.alias_or_name}` does not exist in any of the tables."
331
+ )
332
+ if isinstance(val.expression, exp.Alias):
333
+ val.expression = val.expression.this
334
+ update_set[column_key] = val.expression
335
+ return update_set
@@ -21,19 +21,20 @@ else:
21
21
  if t.TYPE_CHECKING:
22
22
  from sqlframe.base._typing import OptionalPrimitiveType, PathOrPaths
23
23
  from sqlframe.base.column import Column
24
- from sqlframe.base.session import DF, _BaseSession
24
+ from sqlframe.base.session import DF, TABLE, _BaseSession
25
25
  from sqlframe.base.types import StructType
26
26
 
27
27
  SESSION = t.TypeVar("SESSION", bound=_BaseSession)
28
28
  else:
29
29
  SESSION = t.TypeVar("SESSION")
30
30
  DF = t.TypeVar("DF")
31
+ TABLE = t.TypeVar("TABLE")
31
32
 
32
33
 
33
34
  logger = logging.getLogger(__name__)
34
35
 
35
36
 
36
- class _BaseDataFrameReader(t.Generic[SESSION, DF]):
37
+ class _BaseDataFrameReader(t.Generic[SESSION, DF, TABLE]):
37
38
  def __init__(self, spark: SESSION):
38
39
  self._session = spark
39
40
  self.state_format_to_read: t.Optional[str] = None
@@ -42,7 +43,7 @@ class _BaseDataFrameReader(t.Generic[SESSION, DF]):
42
43
  def session(self) -> SESSION:
43
44
  return self._session
44
45
 
45
- def table(self, tableName: str) -> DF:
46
+ def table(self, tableName: str) -> TABLE:
46
47
  tableName = normalize_string(tableName, from_dialect="input", is_table=True)
47
48
  if df := self.session.temp_views.get(tableName):
48
49
  return df
@@ -50,7 +51,7 @@ class _BaseDataFrameReader(t.Generic[SESSION, DF]):
50
51
  self.session.catalog.add_table(table)
51
52
  columns = self.session.catalog.get_columns_from_schema(table)
52
53
 
53
- return self.session._create_df(
54
+ return self.session._create_table(
54
55
  exp.Select()
55
56
  .from_(tableName, dialect=self.session.input_dialect)
56
57
  .select(*columns, dialect=self.session.input_dialect)
sqlframe/base/session.py CHANGED
@@ -27,6 +27,7 @@ from sqlframe.base.catalog import _BaseCatalog
27
27
  from sqlframe.base.dataframe import BaseDataFrame
28
28
  from sqlframe.base.normalize import normalize_dict
29
29
  from sqlframe.base.readerwriter import _BaseDataFrameReader, _BaseDataFrameWriter
30
+ from sqlframe.base.table import _BaseTable
30
31
  from sqlframe.base.udf import _BaseUDFRegistration
31
32
  from sqlframe.base.util import (
32
33
  get_column_mapping_from_schema_input,
@@ -65,17 +66,19 @@ CATALOG = t.TypeVar("CATALOG", bound=_BaseCatalog)
65
66
  READER = t.TypeVar("READER", bound=_BaseDataFrameReader)
66
67
  WRITER = t.TypeVar("WRITER", bound=_BaseDataFrameWriter)
67
68
  DF = t.TypeVar("DF", bound=BaseDataFrame)
69
+ TABLE = t.TypeVar("TABLE", bound=_BaseTable)
68
70
  UDF_REGISTRATION = t.TypeVar("UDF_REGISTRATION", bound=_BaseUDFRegistration)
69
71
 
70
72
  _MISSING = "MISSING"
71
73
 
72
74
 
73
- class _BaseSession(t.Generic[CATALOG, READER, WRITER, DF, CONN, UDF_REGISTRATION]):
75
+ class _BaseSession(t.Generic[CATALOG, READER, WRITER, DF, TABLE, CONN, UDF_REGISTRATION]):
74
76
  _instance = None
75
77
  _reader: t.Type[READER]
76
78
  _writer: t.Type[WRITER]
77
79
  _catalog: t.Type[CATALOG]
78
80
  _df: t.Type[DF]
81
+ _table: t.Type[TABLE]
79
82
  _udf_registration: t.Type[UDF_REGISTRATION]
80
83
 
81
84
  SANITIZE_COLUMN_NAMES = False
@@ -158,12 +161,15 @@ class _BaseSession(t.Generic[CATALOG, READER, WRITER, DF, CONN, UDF_REGISTRATION
158
161
  return name.replace("(", "_").replace(")", "_")
159
162
  return name
160
163
 
161
- def table(self, tableName: str) -> DF:
164
+ def table(self, tableName: str) -> TABLE:
162
165
  return self.read.table(tableName)
163
166
 
164
167
  def _create_df(self, *args, **kwargs) -> DF:
165
168
  return self._df(self, *args, **kwargs)
166
169
 
170
+ def _create_table(self, *args, **kwargs) -> TABLE:
171
+ return self._table(self, *args, **kwargs)
172
+
167
173
  def __new__(cls, *args, **kwargs):
168
174
  if _BaseSession._instance is None:
169
175
  _BaseSession._instance = super().__new__(cls)