sqlframe 3.13.3__py3-none-any.whl → 3.14.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
sqlframe/_version.py CHANGED
@@ -12,5 +12,5 @@ __version__: str
12
12
  __version_tuple__: VERSION_TUPLE
13
13
  version_tuple: VERSION_TUPLE
14
14
 
15
- __version__ = version = '3.13.3'
16
- __version_tuple__ = version_tuple = (3, 13, 3)
15
+ __version__ = version = '3.14.0'
16
+ __version_tuple__ = version_tuple = (3, 14, 0)
@@ -481,6 +481,7 @@ class BaseDataFrame(t.Generic[SESSION, WRITER, NA, STAT, GROUP_DATA]):
481
481
  cte = cte.transform(replace_id_value, replaced_cte_names) # type: ignore
482
482
  if cte.alias_or_name in existing_cte_counts:
483
483
  existing_cte_counts[cte.alias_or_name] += 10
484
+ # Add unique where filter to ensure that the hash of the CTE is unique
484
485
  cte.set(
485
486
  "this",
486
487
  cte.this.where(
@@ -502,6 +503,7 @@ class BaseDataFrame(t.Generic[SESSION, WRITER, NA, STAT, GROUP_DATA]):
502
503
  new_cte_alias, dialect=self.session.input_dialect, into=exp.TableAlias
503
504
  ),
504
505
  )
506
+ existing_cte_counts[new_cte_alias] = 0
505
507
  existing_ctes.append(cte)
506
508
  else:
507
509
  existing_ctes = ctes
@@ -755,15 +757,20 @@ class BaseDataFrame(t.Generic[SESSION, WRITER, NA, STAT, GROUP_DATA]):
755
757
  ]
756
758
  cte_names_in_join = [x.this for x in join_table_identifiers]
757
759
  # If we have columns that resolve to multiple CTE expressions then we want to use each CTE left-to-right
758
- # and therefore we allow multiple columns with the same name in the result. This matches the behavior
759
- # of Spark.
760
+ # (or right to left if a right join) and therefore we allow multiple columns with the same
761
+ # name in the result. This matches the behavior of Spark.
760
762
  resolved_column_position: t.Dict[exp.Column, int] = {
761
763
  col.copy(): -1 for col in ambiguous_cols
762
764
  }
763
765
  for ambiguous_col in ambiguous_cols:
766
+ ctes = (
767
+ list(reversed(self.expression.ctes))
768
+ if self.expression.args["joins"][0].args.get("side", "") == "right"
769
+ else self.expression.ctes
770
+ )
764
771
  ctes_with_column = [
765
772
  cte
766
- for cte in self.expression.ctes
773
+ for cte in ctes
767
774
  if cte.alias_or_name in cte_names_in_join
768
775
  and ambiguous_col.alias_or_name in cte.this.named_selects
769
776
  ]
@@ -865,6 +872,68 @@ class BaseDataFrame(t.Generic[SESSION, WRITER, NA, STAT, GROUP_DATA]):
865
872
  """
866
873
  return self.join.__wrapped__(self, other, how="cross") # type: ignore
867
874
 
875
+ def _handle_self_join(self, other_df: DF, join_columns: t.List[Column]):
876
+ # If the two dataframes being joined come from the same branch, we then check if they have any columns that
877
+ # were created using the "branch_id" (df["column_name"]). If so, we know that we need to differentiate
878
+ # the two columns since they would end up with the same table name. We do this by checking for the unique
879
+ # uuids in the other df and finding columns that have metadata on them that match the uuids. If so, we know
880
+ # it comes from the other df and we change the table name to the other df's table name.
881
+ # See `test_self_join` for an example of this.
882
+ if self.branch_id == other_df.branch_id:
883
+ other_df_unique_uuids = other_df.known_uuids - self.known_uuids
884
+ for col in join_columns:
885
+ for col_expr in col.expression.find_all(exp.Column):
886
+ if (
887
+ "join_on_uuid" in col_expr.meta
888
+ and col_expr.meta["join_on_uuid"] in other_df_unique_uuids
889
+ ):
890
+ col_expr.set("table", exp.to_identifier(other_df.latest_cte_name))
891
+
892
+ @staticmethod
893
+ def _handle_join_column_names_only(
894
+ join_columns: t.List[Column],
895
+ join_expression: exp.Select,
896
+ other_df: DF,
897
+ table_names: t.List[str],
898
+ ):
899
+ potential_ctes = [
900
+ cte
901
+ for cte in join_expression.ctes
902
+ if cte.alias_or_name in table_names and cte.alias_or_name != other_df.latest_cte_name
903
+ ]
904
+ # Determine the table to reference for the left side of the join by checking each of the left side
905
+ # tables and see if they have the column being referenced.
906
+ join_column_pairs = []
907
+ for join_column in join_columns:
908
+ num_matching_ctes = 0
909
+ for cte in potential_ctes:
910
+ if join_column.alias_or_name in cte.this.named_selects:
911
+ left_column = join_column.copy().set_table_name(cte.alias_or_name)
912
+ right_column = join_column.copy().set_table_name(other_df.latest_cte_name)
913
+ join_column_pairs.append((left_column, right_column))
914
+ num_matching_ctes += 1
915
+ # We only want to match one table to the column and that should be matched left -> right
916
+ # so we break after the first match
917
+ break
918
+ if num_matching_ctes == 0:
919
+ raise ValueError(
920
+ f"Column `{join_column.alias_or_name}` does not exist in any of the tables."
921
+ )
922
+ join_clause = functools.reduce(
923
+ lambda x, y: x & y,
924
+ [left_column == right_column for left_column, right_column in join_column_pairs],
925
+ )
926
+ return join_column_pairs, join_clause
927
+
928
+ def _normalize_join_clause(
929
+ self, join_columns: t.List[Column], join_expression: t.Optional[exp.Select]
930
+ ) -> Column:
931
+ join_columns = self._ensure_and_normalize_cols(join_columns, join_expression)
932
+ if len(join_columns) > 1:
933
+ join_columns = [functools.reduce(lambda x, y: x & y, join_columns)]
934
+ join_clause = join_columns[0]
935
+ return join_clause
936
+
868
937
  @operation(Operation.FROM)
869
938
  def join(
870
939
  self,
@@ -888,21 +957,8 @@ class BaseDataFrame(t.Generic[SESSION, WRITER, NA, STAT, GROUP_DATA]):
888
957
  self_columns = self._get_outer_select_columns(join_expression)
889
958
  other_columns = self._get_outer_select_columns(other_df.expression)
890
959
  join_columns = self._ensure_and_normalize_cols(on)
891
- # If the two dataframes being joined come from the same branch, we then check if they have any columns that
892
- # were created using the "branch_id" (df["column_name"]). If so, we know that we need to differentiate
893
- # the two columns since they would end up with the same table name. We do this by checking for the unique
894
- # uuids in the other df and finding columns that have metadata on them that match the uuids. If so, we know
895
- # it comes from the other df and we change the table name to the other df's table name.
896
- # See `test_self_join` for an example of this.
897
- if self.branch_id == other_df.branch_id:
898
- other_df_unique_uuids = other_df.known_uuids - self.known_uuids
899
- for col in join_columns:
900
- for col_expr in col.expression.find_all(exp.Column):
901
- if (
902
- "join_on_uuid" in col_expr.meta
903
- and col_expr.meta["join_on_uuid"] in other_df_unique_uuids
904
- ):
905
- col_expr.set("table", exp.to_identifier(other_df.latest_cte_name))
960
+ self._handle_self_join(other_df, join_columns)
961
+
906
962
  # Determines the join clause and select columns to be used passed on what type of columns were provided for
907
963
  # the join. The columns returned changes based on how the on expression is provided.
908
964
  if how != "cross":
@@ -916,38 +972,9 @@ class BaseDataFrame(t.Generic[SESSION, WRITER, NA, STAT, GROUP_DATA]):
916
972
  table.alias_or_name
917
973
  for table in get_tables_from_expression_with_join(join_expression)
918
974
  ]
919
- potential_ctes = [
920
- cte
921
- for cte in join_expression.ctes
922
- if cte.alias_or_name in table_names
923
- and cte.alias_or_name != other_df.latest_cte_name
924
- ]
925
- # Determine the table to reference for the left side of the join by checking each of the left side
926
- # tables and see if they have the column being referenced.
927
- join_column_pairs = []
928
- for join_column in join_columns:
929
- num_matching_ctes = 0
930
- for cte in potential_ctes:
931
- if join_column.alias_or_name in cte.this.named_selects:
932
- left_column = join_column.copy().set_table_name(cte.alias_or_name)
933
- right_column = join_column.copy().set_table_name(
934
- other_df.latest_cte_name
935
- )
936
- join_column_pairs.append((left_column, right_column))
937
- num_matching_ctes += 1
938
- # We only want to match one table to the column and that should be matched left -> right
939
- # so we break after the first match
940
- break
941
- if num_matching_ctes == 0:
942
- raise ValueError(
943
- f"Column `{join_column.alias_or_name}` does not exist in any of the tables."
944
- )
945
- join_clause = functools.reduce(
946
- lambda x, y: x & y,
947
- [
948
- left_column == right_column
949
- for left_column, right_column in join_column_pairs
950
- ],
975
+
976
+ join_column_pairs, join_clause = self._handle_join_column_names_only(
977
+ join_columns, join_expression, other_df, table_names
951
978
  )
952
979
  join_column_names = [
953
980
  coalesce(
@@ -982,10 +1009,7 @@ class BaseDataFrame(t.Generic[SESSION, WRITER, NA, STAT, GROUP_DATA]):
982
1009
  * There is no deduplication of the results.
983
1010
  * The left join dataframe columns go first and right come after. No sort preference is given to join columns
984
1011
  """
985
- join_columns = self._ensure_and_normalize_cols(join_columns, join_expression)
986
- if len(join_columns) > 1:
987
- join_columns = [functools.reduce(lambda x, y: x & y, join_columns)]
988
- join_clause = join_columns[0]
1012
+ join_clause = self._normalize_join_clause(join_columns, join_expression)
989
1013
  select_column_names = [
990
1014
  column.alias_or_name for column in self_columns + other_columns
991
1015
  ]
@@ -0,0 +1,335 @@
1
+ import functools
2
+ import logging
3
+ import typing as t
4
+
5
+ from sqlglot import exp
6
+
7
+ try:
8
+ from sqlglot.expressions import Whens
9
+ except ImportError:
10
+ Whens = None # type: ignore
11
+ from sqlglot.helper import object_to_dict
12
+
13
+ from sqlframe.base.column import Column
14
+ from sqlframe.base.table import (
15
+ DF,
16
+ Clause,
17
+ LazyExpression,
18
+ WhenMatched,
19
+ WhenNotMatched,
20
+ WhenNotMatchedBySource,
21
+ _BaseTable,
22
+ )
23
+
24
+ if t.TYPE_CHECKING:
25
+ from sqlframe.base._typing import ColumnOrLiteral
26
+
27
+
28
+ logger = logging.getLogger(__name__)
29
+
30
+
31
def ensure_cte() -> t.Callable[[t.Callable], t.Callable]:
    """Decorator factory guaranteeing the table is backed by at least one CTE.

    If the wrapped method is called on a table whose expression has no CTEs,
    the table is first converted so its leaf select becomes a CTE, then rebuilt
    as a fresh instance of the same class before the method runs. The original
    function stays reachable via ``__wrapped__``.
    """

    def decorator(func: t.Callable) -> t.Callable:
        @functools.wraps(func)
        def wrapper(self: _BaseTable, *args, **kwargs) -> t.Any:
            # Fast path: already CTE-backed, call straight through.
            if self.expression.ctes:
                return func(self, *args, **kwargs)  # type: ignore
            # Convert the leaf select to a CTE and rebuild as the same class.
            klass = self.__class__
            converted = self._convert_leaf_to_cte()
            rebuilt = klass(**object_to_dict(converted))
            return func(rebuilt, *args, **kwargs)  # type: ignore

        wrapper.__wrapped__ = func  # type: ignore
        return wrapper

    return decorator
46
+
47
+
48
class _BaseTableMixins(_BaseTable, t.Generic[DF]):
    """Shared helpers for the table-mutation mixins (UPDATE / DELETE support)."""

    def _ensure_where_condition(
        self, where: t.Optional[t.Union[Column, str, bool]] = None
    ) -> exp.Expression:
        """Normalize ``where`` into a single boolean expression for the target table.

        Args:
            where: Optional filter; a list-like of conditions is AND-ed together.
                ``None`` defaults to ``TRUE`` (matches every row) with a warning.

        Returns:
            The normalized condition expression, unwrapped from any alias.
        """
        # The physical target table is the FROM of the first CTE.
        self_name = self.expression.ctes[0].this.args["from"].this.alias_or_name

        if where is None:
            # Fix: added the missing space in the log message ("`where`clause").
            logger.warning("Empty value for `where` clause. Defaults to `True`.")
            condition: exp.Expression = exp.Boolean(this=True)
        else:
            condition_list = self._ensure_and_normalize_cols(where, self.expression)
            if len(condition_list) > 1:
                # AND together multiple conditions into one expression.
                condition_list = [functools.reduce(lambda x, y: x & y, condition_list)]
            # Re-qualify column references from the outer FROM alias to the table name.
            for col_expr in condition_list[0].expression.find_all(exp.Column):
                if col_expr.table == self.expression.args["from"].this.alias_or_name:
                    col_expr.set("table", exp.to_identifier(self_name))
            condition = condition_list[0].expression
            if isinstance(condition, exp.Alias):
                condition = condition.this
        return condition
68
+
69
+
70
class UpdateSupportMixin(_BaseTableMixins, t.Generic[DF]):
    """Mixin that adds a lazy SQL ``UPDATE`` operation to table dataframes."""

    @ensure_cte()
    def update(
        self,
        set_: t.Dict[t.Union[Column, str], t.Union[Column, "ColumnOrLiteral", exp.Expression]],
        where: t.Optional[t.Union[Column, str, bool]] = None,
    ) -> LazyExpression:
        """Build a lazy UPDATE statement for this table.

        Args:
            set_: Mapping of target column -> new value expression.
            where: Optional filter; defaults to ``TRUE`` (all rows) with a warning.

        Returns:
            A ``LazyExpression`` wrapping the UPDATE, executed on demand.
        """
        # The physical target table is the FROM of the first CTE (guaranteed by @ensure_cte).
        self_expr = self.expression.ctes[0].this.args["from"].this

        condition = self._ensure_where_condition(where)
        update_set = self._ensure_and_normalize_update_set(set_)
        update_expr = exp.Update(
            this=self_expr,
            expressions=[
                exp.EQ(
                    this=key,
                    expression=val,
                )
                for key, val in update_set.items()
            ],
            where=exp.Where(this=condition),
        )

        return LazyExpression(update_expr, self.session)

    def _ensure_and_normalize_update_set(
        self,
        set_: t.Dict[t.Union[Column, str], t.Union[Column, "ColumnOrLiteral", exp.Expression]],
    ) -> t.Dict[str, exp.Expression]:
        """Normalize the SET mapping: keys become bare column names, values become
        expressions re-qualified against the physical table name.

        Raises:
            ValueError: If a key resolves to more than one column, or a value
                references a column from some other table.
        """
        self_name = self.expression.ctes[0].this.args["from"].this.alias_or_name
        update_set = {}
        for key, val in set_.items():
            key_column: Column = self._ensure_and_normalize_col(key)
            key_expr = list(key_column.expression.find_all(exp.Column))
            if len(key_expr) > 1:
                # Fix: was a placeholder-free f-string with broken grammar
                # ("Can only update one a single column at a time.").
                raise ValueError("Can only update a single column at a time.")
            key = key_expr[0].alias_or_name

            val_column: Column = self._ensure_and_normalize_col(val)
            for col_expr in val_column.expression.find_all(exp.Column):
                if col_expr.table == self.expression.args["from"].this.alias_or_name:
                    col_expr.set("table", exp.to_identifier(self_name))
                else:
                    raise ValueError(
                        f"Column `{col_expr.alias_or_name}` does not exist in the table."
                    )

            update_set[key] = val_column.expression
        return update_set
119
+
120
+
121
class DeleteSupportMixin(_BaseTableMixins, t.Generic[DF]):
    """Mixin that adds a lazy SQL ``DELETE`` operation to table dataframes."""

    @ensure_cte()
    def delete(
        self,
        where: t.Optional[t.Union[Column, str, bool]] = None,
    ) -> LazyExpression:
        """Build a lazy DELETE statement for this table.

        Args:
            where: Optional filter; defaults to ``TRUE`` (all rows) with a warning.

        Returns:
            A ``LazyExpression`` wrapping the DELETE, executed on demand.
        """
        # The physical target table is the FROM of the first CTE (guaranteed by @ensure_cte).
        target_table = self.expression.ctes[0].this.args["from"].this
        predicate = self._ensure_where_condition(where)
        delete_stmt = exp.Delete(
            this=target_table,
            where=exp.Where(this=predicate),
        )
        return LazyExpression(delete_stmt, self.session)
136
+
137
+
138
class MergeSupportMixin(_BaseTable, t.Generic[DF]):
    """Mixin that adds a lazy SQL ``MERGE`` operation to table dataframes.

    Subclasses declare dialect capabilities via the two class attributes below.
    """

    # Clause classes this dialect's MERGE accepts (checked with isinstance in merge()).
    _merge_supported_clauses: t.Iterable[
        t.Union[t.Type[WhenMatched], t.Type[WhenNotMatched], t.Type[WhenNotMatchedBySource]]
    ]
    # Whether the dialect supports star forms (UPDATE * / INSERT *) in MERGE.
    # NOTE(review): merge() below checks `self._support_star`, not this attribute —
    # confirm which name is authoritative (possible naming drift).
    _merge_support_star: bool

    @ensure_cte()
    def merge(
        self,
        other_df: DF,
        condition: t.Union[str, t.List[str], Column, t.List[Column], bool],
        clauses: t.Iterable[t.Union[WhenMatched, WhenNotMatched, WhenNotMatchedBySource]],
    ) -> LazyExpression:
        """Build a lazy MERGE statement with this table as target and ``other_df`` as source.

        Args:
            other_df: Source dataframe; converted to a CTE-backed form and embedded
                as an aliased subquery.
            condition: The ON condition (columns, strings, or a boolean expression).
            clauses: WHEN [NOT] MATCHED [BY SOURCE] clauses describing the actions.

        Returns:
            A ``LazyExpression`` wrapping the MERGE, executed on demand.

        Raises:
            ValueError: If ``condition`` is None, a clause type is unsupported, or a
                star clause is used when the dialect does not support it.
        """
        # The physical target table is the FROM of the first CTE (guaranteed by @ensure_cte).
        self_name = self.expression.ctes[0].this.args["from"].this.alias_or_name
        self_expr = self.expression.ctes[0].this.args["from"].this

        other_df = other_df._convert_leaf_to_cte()

        if condition is None:
            raise ValueError("condition cannot be None")

        condition_columns: Column = self._ensure_and_normalize_condition(condition, other_df)
        # The source subquery is aliased with a deterministic hash of its expression.
        other_name = self._create_hash_from_expression(other_df.expression)
        other_expr = exp.Subquery(
            this=other_df.expression, alias=exp.TableAlias(this=exp.to_identifier(other_name))
        )

        # Re-qualify ON-condition columns to the final target/source aliases.
        for col_expr in condition_columns.expression.find_all(exp.Column):
            if col_expr.table == self.expression.args["from"].this.alias_or_name:
                col_expr.set("table", exp.to_identifier(self_name))
            if col_expr.table == other_df.latest_cte_name:
                col_expr.set("table", exp.to_identifier(other_name))

        merge_expressions = []
        for clause in clauses:
            if not isinstance(clause, tuple(self._merge_supported_clauses)):
                raise ValueError(
                    f"Unsupported clause type {type(clause.clause)} for merge operation"
                )
            expression = None

            # Normalize the optional per-clause condition the same way as the ON
            # condition (clause=True skips the bare-column-name join path).
            if clause.clause.condition is not None:
                cond_clause = self._ensure_and_normalize_condition(
                    clause.clause.condition, other_df, True
                )
                for col_expr in cond_clause.expression.find_all(exp.Column):
                    if col_expr.table == self.expression.args["from"].this.alias_or_name:
                        col_expr.set("table", exp.to_identifier(self_name))
                    if col_expr.table == other_df.latest_cte_name:
                        col_expr.set("table", exp.to_identifier(other_name))
            else:
                cond_clause = None
            # NOTE(review): UPDATE and UPDATE_ALL are separate `if`s (not `elif`) while
            # the remaining branches chain with `elif`; harmless if clause_type is
            # single-valued, but confirm the asymmetry is intentional.
            if clause.clause.clause_type == Clause.UPDATE:
                update_set = self._ensure_and_normalize_assignments(
                    clause.clause.assignments, other_df
                )
                expression = exp.When(
                    matched=clause.clause.matched,
                    source=clause.clause.by_source,
                    condition=cond_clause.expression if cond_clause else None,
                    then=exp.Update(
                        expressions=[
                            exp.EQ(
                                this=key,
                                expression=val,
                            )
                            for key, val in update_set.items()
                        ]
                    ),
                )
            if clause.clause.clause_type == Clause.UPDATE_ALL:
                if not self._support_star:
                    raise ValueError("Merge operation does not support UPDATE_ALL")
                expression = exp.When(
                    matched=clause.clause.matched,
                    source=clause.clause.by_source,
                    condition=cond_clause.expression if cond_clause else None,
                    then=exp.Update(expressions=[exp.Star()]),
                )
            elif clause.clause.clause_type == Clause.INSERT:
                insert_values = self._ensure_and_normalize_assignments(
                    clause.clause.assignments, other_df
                )
                expression = exp.When(
                    matched=clause.clause.matched,
                    source=clause.clause.by_source,
                    condition=cond_clause.expression if cond_clause else None,
                    then=exp.Insert(
                        this=exp.Tuple(expressions=[key for key in insert_values.keys()]),
                        expression=exp.Tuple(expressions=[val for val in insert_values.values()]),
                    ),
                )
            elif clause.clause.clause_type == Clause.INSERT_ALL:
                if not self._support_star:
                    raise ValueError("Merge operation does not support INSERT_ALL")
                expression = exp.When(
                    matched=clause.clause.matched,
                    source=clause.clause.by_source,
                    condition=cond_clause.expression if cond_clause else None,
                    then=exp.Insert(expression=exp.Star()),
                )
            elif clause.clause.clause_type == Clause.DELETE:
                expression = exp.When(
                    matched=clause.clause.matched,
                    source=clause.clause.by_source,
                    condition=cond_clause.expression if cond_clause else None,
                    then=exp.var("DELETE"),
                )

            if expression:
                merge_expressions.append(expression)

        # Older sqlglot versions have no Whens container; fall back to varargs form.
        if Whens is None:
            merge_expr = exp.merge(
                *merge_expressions,
                into=self_expr,
                using=other_expr,
                on=condition_columns.expression,
            )
        else:
            merge_expr = exp.merge(
                Whens(expressions=merge_expressions),
                into=self_expr,
                using=other_expr,
                on=condition_columns.expression,
            )

        return LazyExpression(merge_expr, self.session)

    def _ensure_and_normalize_condition(
        self,
        condition: t.Union[str, t.List[str], Column, t.List[Column], bool],
        other_df: DF,
        clause: t.Optional[bool] = False,
    ):
        """Normalize a merge condition into a single join clause.

        A bare column name (and ``clause=False``) resolves through the name-only
        join path against both tables' FROMs; anything else is AND-ed into one
        clause via ``_normalize_join_clause``.
        """
        join_expression = self._add_ctes_to_expression(
            self.expression, other_df.expression.copy().ctes
        )
        condition = self._ensure_and_normalize_cols(condition, self.expression)
        self._handle_self_join(other_df, condition)

        if isinstance(condition[0].expression, exp.Column) and not clause:
            table_names = [
                table.alias_or_name
                for table in [
                    self.expression.args["from"].this,
                    other_df.expression.args["from"].this,
                ]
            ]

            # The column pairs are not needed here — only the combined clause.
            join_column_pairs, join_clause = self._handle_join_column_names_only(
                condition, join_expression, other_df, table_names
            )
        else:
            join_clause = self._normalize_join_clause(condition, join_expression)
        return join_clause

    def _ensure_and_normalize_assignments(
        self,
        assignments: t.Dict[
            t.Union[Column, str], t.Union[Column, "ColumnOrLiteral", exp.Expression]
        ],
        other_df,
    ) -> t.Dict[exp.Column, exp.Expression]:
        """Normalize merge SET/VALUES assignments to {target column: source expression}.

        Keys become unqualified target columns; value columns are re-qualified to
        the hashed source alias or the target table name, with self-join uuid
        disambiguation applied first.

        Raises:
            ValueError: If a key resolves to more than one column, or a value
                references a column from neither table.
        """
        self_name = self.expression.ctes[0].this.args["from"].this.alias_or_name
        other_name = self._create_hash_from_expression(other_df.expression)
        update_set = {}
        for key, val in assignments.items():
            key_column: Column = self._ensure_and_normalize_col(key)
            key_expr = list(key_column.expression.find_all(exp.Column))
            if len(key_expr) > 1:
                raise ValueError(f"Target expression `{key_expr}` should be a single column.")
            column_key = exp.column(key_expr[0].alias_or_name)

            val = self._ensure_and_normalize_col(val)
            val = self._ensure_and_normalize_cols(val, other_df.expression)[0]
            # Same-branch self-join: re-qualify columns tagged with uuids unique to
            # the source dataframe (mirrors BaseDataFrame._handle_self_join).
            if self.branch_id == other_df.branch_id:
                other_df_unique_uuids = other_df.known_uuids - self.known_uuids
                for col_expr in val.expression.find_all(exp.Column):
                    if (
                        "join_on_uuid" in col_expr.meta
                        and col_expr.meta["join_on_uuid"] in other_df_unique_uuids
                    ):
                        col_expr.set("table", exp.to_identifier(other_df.latest_cte_name))

            # Unqualified or source-qualified columns point at the source alias;
            # target-qualified columns point at the target; anything else is an error.
            for col_expr in val.expression.find_all(exp.Column):
                if not col_expr.table or col_expr.table == other_df.latest_cte_name:
                    col_expr.set("table", exp.to_identifier(other_name))
                elif col_expr.table == self.expression.args["from"].this.alias_or_name:
                    col_expr.set("table", exp.to_identifier(self_name))
                else:
                    raise ValueError(
                        f"Column `{col_expr.alias_or_name}` does not exist in any of the tables."
                    )
            if isinstance(val.expression, exp.Alias):
                val.expression = val.expression.this
            update_set[column_key] = val.expression
        return update_set
@@ -21,19 +21,20 @@ else:
21
21
  if t.TYPE_CHECKING:
22
22
  from sqlframe.base._typing import OptionalPrimitiveType, PathOrPaths
23
23
  from sqlframe.base.column import Column
24
- from sqlframe.base.session import DF, _BaseSession
24
+ from sqlframe.base.session import DF, TABLE, _BaseSession
25
25
  from sqlframe.base.types import StructType
26
26
 
27
27
  SESSION = t.TypeVar("SESSION", bound=_BaseSession)
28
28
  else:
29
29
  SESSION = t.TypeVar("SESSION")
30
30
  DF = t.TypeVar("DF")
31
+ TABLE = t.TypeVar("TABLE")
31
32
 
32
33
 
33
34
  logger = logging.getLogger(__name__)
34
35
 
35
36
 
36
- class _BaseDataFrameReader(t.Generic[SESSION, DF]):
37
+ class _BaseDataFrameReader(t.Generic[SESSION, DF, TABLE]):
37
38
  def __init__(self, spark: SESSION):
38
39
  self._session = spark
39
40
  self.state_format_to_read: t.Optional[str] = None
@@ -42,7 +43,7 @@ class _BaseDataFrameReader(t.Generic[SESSION, DF]):
42
43
  def session(self) -> SESSION:
43
44
  return self._session
44
45
 
45
- def table(self, tableName: str) -> DF:
46
+ def table(self, tableName: str) -> TABLE:
46
47
  tableName = normalize_string(tableName, from_dialect="input", is_table=True)
47
48
  if df := self.session.temp_views.get(tableName):
48
49
  return df
@@ -50,7 +51,7 @@ class _BaseDataFrameReader(t.Generic[SESSION, DF]):
50
51
  self.session.catalog.add_table(table)
51
52
  columns = self.session.catalog.get_columns_from_schema(table)
52
53
 
53
- return self.session._create_df(
54
+ return self.session._create_table(
54
55
  exp.Select()
55
56
  .from_(tableName, dialect=self.session.input_dialect)
56
57
  .select(*columns, dialect=self.session.input_dialect)
sqlframe/base/session.py CHANGED
@@ -27,6 +27,7 @@ from sqlframe.base.catalog import _BaseCatalog
27
27
  from sqlframe.base.dataframe import BaseDataFrame
28
28
  from sqlframe.base.normalize import normalize_dict
29
29
  from sqlframe.base.readerwriter import _BaseDataFrameReader, _BaseDataFrameWriter
30
+ from sqlframe.base.table import _BaseTable
30
31
  from sqlframe.base.udf import _BaseUDFRegistration
31
32
  from sqlframe.base.util import (
32
33
  get_column_mapping_from_schema_input,
@@ -65,17 +66,19 @@ CATALOG = t.TypeVar("CATALOG", bound=_BaseCatalog)
65
66
  READER = t.TypeVar("READER", bound=_BaseDataFrameReader)
66
67
  WRITER = t.TypeVar("WRITER", bound=_BaseDataFrameWriter)
67
68
  DF = t.TypeVar("DF", bound=BaseDataFrame)
69
+ TABLE = t.TypeVar("TABLE", bound=_BaseTable)
68
70
  UDF_REGISTRATION = t.TypeVar("UDF_REGISTRATION", bound=_BaseUDFRegistration)
69
71
 
70
72
  _MISSING = "MISSING"
71
73
 
72
74
 
73
- class _BaseSession(t.Generic[CATALOG, READER, WRITER, DF, CONN, UDF_REGISTRATION]):
75
+ class _BaseSession(t.Generic[CATALOG, READER, WRITER, DF, TABLE, CONN, UDF_REGISTRATION]):
74
76
  _instance = None
75
77
  _reader: t.Type[READER]
76
78
  _writer: t.Type[WRITER]
77
79
  _catalog: t.Type[CATALOG]
78
80
  _df: t.Type[DF]
81
+ _table: t.Type[TABLE]
79
82
  _udf_registration: t.Type[UDF_REGISTRATION]
80
83
 
81
84
  SANITIZE_COLUMN_NAMES = False
@@ -158,12 +161,15 @@ class _BaseSession(t.Generic[CATALOG, READER, WRITER, DF, CONN, UDF_REGISTRATION
158
161
  return name.replace("(", "_").replace(")", "_")
159
162
  return name
160
163
 
161
- def table(self, tableName: str) -> DF:
164
+ def table(self, tableName: str) -> TABLE:
162
165
  return self.read.table(tableName)
163
166
 
164
167
  def _create_df(self, *args, **kwargs) -> DF:
165
168
  return self._df(self, *args, **kwargs)
166
169
 
170
+ def _create_table(self, *args, **kwargs) -> TABLE:
171
+ return self._table(self, *args, **kwargs)
172
+
167
173
  def __new__(cls, *args, **kwargs):
168
174
  if _BaseSession._instance is None:
169
175
  _BaseSession._instance = super().__new__(cls)