sqlframe 3.13.4__py3-none-any.whl → 3.14.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
sqlframe/_version.py CHANGED
@@ -12,5 +12,5 @@ __version__: str
12
12
  __version_tuple__: VERSION_TUPLE
13
13
  version_tuple: VERSION_TUPLE
14
14
 
15
- __version__ = version = '3.13.4'
16
- __version_tuple__ = version_tuple = (3, 13, 4)
15
+ __version__ = version = '3.14.0'
16
+ __version_tuple__ = version_tuple = (3, 14, 0)
@@ -872,6 +872,68 @@ class BaseDataFrame(t.Generic[SESSION, WRITER, NA, STAT, GROUP_DATA]):
872
872
  """
873
873
  return self.join.__wrapped__(self, other, how="cross") # type: ignore
874
874
 
875
def _handle_self_join(self, other_df: DF, join_columns: t.List[Column]):
    """Re-point join columns that belong to ``other_df`` when both frames share a branch.

    When the two dataframes come from the same branch, columns created through the
    "branch_id" form (``df["column_name"]``) would end up with the same table name on
    both sides. Columns whose ``join_on_uuid`` metadata matches a uuid known only to
    ``other_df`` are therefore rewritten in place to use ``other_df``'s latest CTE name.
    See `test_self_join` for an example of this.

    Args:
        other_df: The dataframe on the other side of the join.
        join_columns: Normalized join columns; their expressions are mutated in place.
    """
    # Frames from different branches never collide on table names.
    if self.branch_id != other_df.branch_id:
        return

    uuids_only_in_other = other_df.known_uuids - self.known_uuids
    for join_column in join_columns:
        for column_expr in join_column.expression.find_all(exp.Column):
            if "join_on_uuid" not in column_expr.meta:
                continue
            if column_expr.meta["join_on_uuid"] not in uuids_only_in_other:
                continue
            column_expr.set("table", exp.to_identifier(other_df.latest_cte_name))
891
+
892
@staticmethod
def _handle_join_column_names_only(
    join_columns: t.List[Column],
    join_expression: exp.Select,
    other_df: DF,
    table_names: t.List[str],
):
    """Build the join condition for a join expressed purely as column names.

    Args:
        join_columns: Columns to join on, one per name.
        join_expression: Select whose CTEs are searched for the left side of each pair.
        other_df: Right-hand dataframe; its latest CTE is the right side of every pair.
        table_names: Names of the tables taking part in the join.

    Returns:
        A tuple ``(join_column_pairs, join_clause)`` where ``join_column_pairs`` is a
        list of ``(left_column, right_column)`` tuples and ``join_clause`` is the
        ``&``-combination of ``left_column == right_column`` over every pair.

    Raises:
        ValueError: If a join column is not selected by any left-side CTE.
    """
    # Left-side candidates: CTEs taking part in the join, excluding the right side.
    left_side_ctes = []
    for cte in join_expression.ctes:
        cte_name = cte.alias_or_name
        if cte_name in table_names and cte_name != other_df.latest_cte_name:
            left_side_ctes.append(cte)

    join_column_pairs = []
    for join_column in join_columns:
        for cte in left_side_ctes:
            if join_column.alias_or_name not in cte.this.named_selects:
                continue
            left_column = join_column.copy().set_table_name(cte.alias_or_name)
            right_column = join_column.copy().set_table_name(other_df.latest_cte_name)
            join_column_pairs.append((left_column, right_column))
            # Only one table is matched per column, searching left -> right,
            # so stop at the first CTE that exposes it.
            break
        else:
            raise ValueError(
                f"Column `{join_column.alias_or_name}` does not exist in any of the tables."
            )

    equalities = [left == right for left, right in join_column_pairs]
    join_clause = functools.reduce(lambda acc, cond: acc & cond, equalities)
    return join_column_pairs, join_clause
927
+
928
def _normalize_join_clause(
    self, join_columns: t.List[Column], join_expression: t.Optional[exp.Select]
) -> Column:
    """Normalize the join columns and collapse them into one join condition.

    Args:
        join_columns: Columns or conditions supplied for the join.
        join_expression: Expression passed to ``_ensure_and_normalize_cols`` alongside
            the columns.

    Returns:
        The single normalized column, or the ``&``-combination of all of them when
        more than one was supplied.
    """
    normalized = self._ensure_and_normalize_cols(join_columns, join_expression)
    if len(normalized) > 1:
        return functools.reduce(lambda lhs, rhs: lhs & rhs, normalized)
    return normalized[0]
936
+
875
937
  @operation(Operation.FROM)
876
938
  def join(
877
939
  self,
@@ -895,21 +957,8 @@ class BaseDataFrame(t.Generic[SESSION, WRITER, NA, STAT, GROUP_DATA]):
895
957
  self_columns = self._get_outer_select_columns(join_expression)
896
958
  other_columns = self._get_outer_select_columns(other_df.expression)
897
959
  join_columns = self._ensure_and_normalize_cols(on)
898
- # If the two dataframes being joined come from the same branch, we then check if they have any columns that
899
- # were created using the "branch_id" (df["column_name"]). If so, we know that we need to differentiate
900
- # the two columns since they would end up with the same table name. We do this by checking for the unique
901
- # uuids in the other df and finding columns that have metadata on them that match the uuids. If so, we know
902
- # it comes from the other df and we change the table name to the other df's table name.
903
- # See `test_self_join` for an example of this.
904
- if self.branch_id == other_df.branch_id:
905
- other_df_unique_uuids = other_df.known_uuids - self.known_uuids
906
- for col in join_columns:
907
- for col_expr in col.expression.find_all(exp.Column):
908
- if (
909
- "join_on_uuid" in col_expr.meta
910
- and col_expr.meta["join_on_uuid"] in other_df_unique_uuids
911
- ):
912
- col_expr.set("table", exp.to_identifier(other_df.latest_cte_name))
960
+ self._handle_self_join(other_df, join_columns)
961
+
913
962
  # Determines the join clause and select columns to be used based on what type of columns were provided for
914
963
  # the join. The columns returned changes based on how the on expression is provided.
915
964
  if how != "cross":
@@ -923,38 +972,9 @@ class BaseDataFrame(t.Generic[SESSION, WRITER, NA, STAT, GROUP_DATA]):
923
972
  table.alias_or_name
924
973
  for table in get_tables_from_expression_with_join(join_expression)
925
974
  ]
926
- potential_ctes = [
927
- cte
928
- for cte in join_expression.ctes
929
- if cte.alias_or_name in table_names
930
- and cte.alias_or_name != other_df.latest_cte_name
931
- ]
932
- # Determine the table to reference for the left side of the join by checking each of the left side
933
- # tables and see if they have the column being referenced.
934
- join_column_pairs = []
935
- for join_column in join_columns:
936
- num_matching_ctes = 0
937
- for cte in potential_ctes:
938
- if join_column.alias_or_name in cte.this.named_selects:
939
- left_column = join_column.copy().set_table_name(cte.alias_or_name)
940
- right_column = join_column.copy().set_table_name(
941
- other_df.latest_cte_name
942
- )
943
- join_column_pairs.append((left_column, right_column))
944
- num_matching_ctes += 1
945
- # We only want to match one table to the column and that should be matched left -> right
946
- # so we break after the first match
947
- break
948
- if num_matching_ctes == 0:
949
- raise ValueError(
950
- f"Column `{join_column.alias_or_name}` does not exist in any of the tables."
951
- )
952
- join_clause = functools.reduce(
953
- lambda x, y: x & y,
954
- [
955
- left_column == right_column
956
- for left_column, right_column in join_column_pairs
957
- ],
975
+
976
+ join_column_pairs, join_clause = self._handle_join_column_names_only(
977
+ join_columns, join_expression, other_df, table_names
958
978
  )
959
979
  join_column_names = [
960
980
  coalesce(
@@ -989,10 +1009,7 @@ class BaseDataFrame(t.Generic[SESSION, WRITER, NA, STAT, GROUP_DATA]):
989
1009
  * There is no deduplication of the results.
990
1010
  * The left join dataframe columns go first and right come after. No sort preference is given to join columns
991
1011
  """
992
- join_columns = self._ensure_and_normalize_cols(join_columns, join_expression)
993
- if len(join_columns) > 1:
994
- join_columns = [functools.reduce(lambda x, y: x & y, join_columns)]
995
- join_clause = join_columns[0]
1012
+ join_clause = self._normalize_join_clause(join_columns, join_expression)
996
1013
  select_column_names = [
997
1014
  column.alias_or_name for column in self_columns + other_columns
998
1015
  ]
@@ -0,0 +1,335 @@
1
+ import functools
2
+ import logging
3
+ import typing as t
4
+
5
+ from sqlglot import exp
6
+
7
+ try:
8
+ from sqlglot.expressions import Whens
9
+ except ImportError:
10
+ Whens = None # type: ignore
11
+ from sqlglot.helper import object_to_dict
12
+
13
+ from sqlframe.base.column import Column
14
+ from sqlframe.base.table import (
15
+ DF,
16
+ Clause,
17
+ LazyExpression,
18
+ WhenMatched,
19
+ WhenNotMatched,
20
+ WhenNotMatchedBySource,
21
+ _BaseTable,
22
+ )
23
+
24
+ if t.TYPE_CHECKING:
25
+ from sqlframe.base._typing import ColumnOrLiteral
26
+
27
+
28
+ logger = logging.getLogger(__name__)
29
+
30
+
31
def ensure_cte() -> t.Callable[[t.Callable], t.Callable]:
    """Return a method decorator that guarantees ``self.expression`` has a CTE.

    If the table's expression already has CTEs, the wrapped method is called
    unchanged. Otherwise the table is first converted with ``_convert_leaf_to_cte()``
    and rebuilt as an instance of its original class before the method is invoked on
    the rebuilt instance.
    """

    def decorator(func: t.Callable) -> t.Callable:
        @functools.wraps(func)
        def wrapper(self: "_BaseTable", *args, **kwargs) -> t.Any:
            if not self.expression.ctes:
                # Remember the concrete class so the converted object can be
                # re-created as the same table type.
                table_cls = type(self)
                converted = self._convert_leaf_to_cte()
                self = table_cls(**object_to_dict(converted))
            return func(self, *args, **kwargs)  # type: ignore

        wrapper.__wrapped__ = func  # type: ignore
        return wrapper

    return decorator
46
+
47
+
48
class _BaseTableMixins(_BaseTable, t.Generic[DF]):
    """Helpers shared by the table DML mixins."""

    def _ensure_where_condition(
        self, where: t.Optional[t.Union[Column, str, bool]] = None
    ) -> exp.Expression:
        """Turn a ``where`` argument into an expression qualified with the target table.

        Args:
            where: Filter as a column, string or bool. ``None`` logs a warning and
                yields a literal ``TRUE``.

        Returns:
            The condition expression. Columns qualified with the outer select's FROM
            alias are re-pointed at the table read by the first CTE, and a top-level
            alias is unwrapped.
        """
        target_name = self.expression.ctes[0].this.args["from"].this.alias_or_name

        if where is None:
            logger.warning("Empty value for `where`clause. Defaults to `True`.")
            return exp.Boolean(this=True)

        predicates = self._ensure_and_normalize_cols(where, self.expression)
        if len(predicates) > 1:
            predicate = functools.reduce(lambda lhs, rhs: lhs & rhs, predicates)
        else:
            predicate = predicates[0]

        for column_expr in predicate.expression.find_all(exp.Column):
            if column_expr.table == self.expression.args["from"].this.alias_or_name:
                column_expr.set("table", exp.to_identifier(target_name))

        condition: exp.Expression = predicate.expression
        return condition.this if isinstance(condition, exp.Alias) else condition
68
+
69
+
70
class UpdateSupportMixin(_BaseTableMixins, t.Generic[DF]):
    """Adds an ``update`` operation that builds a lazy SQL ``UPDATE`` statement."""

    @ensure_cte()
    def update(
        self,
        set_: t.Dict[t.Union[Column, str], t.Union[Column, "ColumnOrLiteral", exp.Expression]],
        where: t.Optional[t.Union[Column, str, bool]] = None,
    ) -> LazyExpression:
        """Build an ``UPDATE`` of this table.

        Args:
            set_: Mapping of target column to the new value (column, literal or
                expression).
            where: Optional filter; when ``None`` the condition defaults to ``TRUE``
                (see ``_ensure_where_condition``).

        Returns:
            A ``LazyExpression`` wrapping the ``exp.Update`` and this table's session.

        Raises:
            ValueError: If a key does not reference exactly one column, or a value
                references a column that is not qualified with this table.
        """
        self_expr = self.expression.ctes[0].this.args["from"].this

        condition = self._ensure_where_condition(where)
        update_set = self._ensure_and_normalize_update_set(set_)
        update_expr = exp.Update(
            this=self_expr,
            expressions=[exp.EQ(this=key, expression=val) for key, val in update_set.items()],
            where=exp.Where(this=condition),
        )

        return LazyExpression(update_expr, self.session)

    def _ensure_and_normalize_update_set(
        self,
        set_: t.Dict[t.Union[Column, str], t.Union[Column, "ColumnOrLiteral", exp.Expression]],
    ) -> t.Dict[str, exp.Expression]:
        """Normalize an ``update`` mapping into ``{column name: value expression}``.

        Columns in each value that are qualified with the outer select's FROM alias
        are re-pointed at the table read by the first CTE.

        Raises:
            ValueError: If a key does not reference exactly one column, or a value
                references a column with any other table qualifier.
        """
        self_name = self.expression.ctes[0].this.args["from"].this.alias_or_name
        # Alias the normalized columns are expected to carry; invariant for the loop.
        from_alias = self.expression.args["from"].this.alias_or_name

        update_set = {}
        for key, val in set_.items():
            key_column: Column = self._ensure_and_normalize_col(key)
            key_expr = list(key_column.expression.find_all(exp.Column))
            # Exactly one target column is required. Zero columns used to fall
            # through to `key_expr[0]` and surface as an IndexError.
            if len(key_expr) != 1:
                raise ValueError("Can only update a single column at a time.")
            key = key_expr[0].alias_or_name

            val_column: Column = self._ensure_and_normalize_col(val)
            for col_expr in val_column.expression.find_all(exp.Column):
                if col_expr.table == from_alias:
                    col_expr.set("table", exp.to_identifier(self_name))
                else:
                    raise ValueError(
                        f"Column `{col_expr.alias_or_name}` does not exist in the table."
                    )

            update_set[key] = val_column.expression
        return update_set
119
+
120
+
121
class DeleteSupportMixin(_BaseTableMixins, t.Generic[DF]):
    """Adds a ``delete`` operation that builds a lazy SQL ``DELETE`` statement."""

    @ensure_cte()
    def delete(
        self,
        where: t.Optional[t.Union[Column, str, bool]] = None,
    ) -> LazyExpression:
        """Build a ``DELETE`` against this table.

        Args:
            where: Optional filter; when ``None`` the condition defaults to ``TRUE``
                (see ``_ensure_where_condition``).

        Returns:
            A ``LazyExpression`` wrapping the ``exp.Delete`` and this table's session.
        """
        target = self.expression.ctes[0].this.args["from"].this
        predicate = self._ensure_where_condition(where)
        statement = exp.Delete(this=target, where=exp.Where(this=predicate))
        return LazyExpression(statement, self.session)
136
+
137
+
138
class MergeSupportMixin(_BaseTable, t.Generic[DF]):
    """Adds a ``merge`` operation that builds a lazy SQL ``MERGE`` statement."""

    # Clause classes accepted by ``merge``; anything else is rejected.
    _merge_supported_clauses: t.Iterable[
        t.Union[t.Type[WhenMatched], t.Type[WhenNotMatched], t.Type[WhenNotMatchedBySource]]
    ]
    # Whether the ``UPDATE_ALL`` / ``INSERT_ALL`` (star) clauses are allowed.
    _merge_support_star: bool

    @ensure_cte()
    def merge(
        self,
        other_df: DF,
        condition: t.Union[str, t.List[str], Column, t.List[Column], bool],
        clauses: t.Iterable[t.Union[WhenMatched, WhenNotMatched, WhenNotMatchedBySource]],
    ) -> LazyExpression:
        """Build a ``MERGE`` of ``other_df`` into this table.

        Args:
            other_df: Source dataframe; it is embedded as an aliased subquery.
            condition: The ``ON`` condition (column names, columns, or a bool).
            clauses: ``WHEN ...`` clauses, applied in the given order.

        Returns:
            A ``LazyExpression`` wrapping the merge expression and this table's session.

        Raises:
            ValueError: If ``condition`` is ``None``, a clause type is not in
                ``_merge_supported_clauses``, or a star clause is used while star
                support is disabled.
        """
        if condition is None:
            raise ValueError("condition cannot be None")

        self_expr = self.expression.ctes[0].this.args["from"].this
        self_name = self_expr.alias_or_name

        other_df = other_df._convert_leaf_to_cte()

        condition_columns: Column = self._ensure_and_normalize_condition(condition, other_df)
        other_name = self._create_hash_from_expression(other_df.expression)
        other_expr = exp.Subquery(
            this=other_df.expression, alias=exp.TableAlias(this=exp.to_identifier(other_name))
        )
        self._requalify_merge_columns(condition_columns, other_df, self_name, other_name)

        merge_expressions = []
        for clause in clauses:
            if not isinstance(clause, tuple(self._merge_supported_clauses)):
                # The rejected object may not be a When* wrapper at all, so do not
                # assume it has a `.clause` attribute while building the message.
                raise ValueError(
                    f"Unsupported clause type {type(getattr(clause, 'clause', clause))} "
                    "for merge operation"
                )

            cond_expr = None
            if clause.clause.condition is not None:
                cond_clause = self._ensure_and_normalize_condition(
                    clause.clause.condition, other_df, True
                )
                self._requalify_merge_columns(cond_clause, other_df, self_name, other_name)
                cond_expr = cond_clause.expression

            expression = self._build_merge_when(clause, other_df, cond_expr)
            if expression is not None:
                merge_expressions.append(expression)

        # Where `Whens` could be imported the clauses are wrapped in a single node;
        # otherwise they are passed to `exp.merge` individually.
        when_args = (
            merge_expressions if Whens is None else [Whens(expressions=merge_expressions)]
        )
        merge_expr = exp.merge(
            *when_args,
            into=self_expr,
            using=other_expr,
            on=condition_columns.expression,
        )

        return LazyExpression(merge_expr, self.session)

    def _merge_star_supported(self) -> bool:
        """Return whether the star clauses (``UPDATE_ALL`` / ``INSERT_ALL``) are allowed."""
        # The class declares `_merge_support_star`, but the flag used to be read as
        # `_support_star`. Prefer the declared name and fall back to the old one.
        # NOTE(review): confirm whether `_support_star` is defined anywhere else.
        return bool(
            getattr(self, "_merge_support_star", getattr(self, "_support_star", False))
        )

    def _requalify_merge_columns(
        self, column: Column, other_df: DF, self_name: str, other_name: str
    ) -> None:
        """Re-point the table qualifiers in ``column`` at the names used by the MERGE.

        Columns qualified with this frame's FROM alias get ``self_name``; columns
        qualified with ``other_df``'s latest CTE name get ``other_name``.
        """
        from_alias = self.expression.args["from"].this.alias_or_name
        other_cte_name = other_df.latest_cte_name
        for col_expr in column.expression.find_all(exp.Column):
            # Two independent checks on purpose: the second also sees the result
            # of the first.
            if col_expr.table == from_alias:
                col_expr.set("table", exp.to_identifier(self_name))
            if col_expr.table == other_cte_name:
                col_expr.set("table", exp.to_identifier(other_name))

    def _build_merge_when(
        self,
        clause: t.Union[WhenMatched, WhenNotMatched, WhenNotMatchedBySource],
        other_df: DF,
        cond_expr: t.Optional[exp.Expression],
    ) -> t.Optional[exp.Expression]:
        """Translate one merge clause into an ``exp.When``.

        Returns ``None`` when the clause type is not one of the handled ``Clause``
        values.
        """
        spec = clause.clause
        clause_type = spec.clause_type

        then: exp.Expression
        if clause_type == Clause.UPDATE:
            update_set = self._ensure_and_normalize_assignments(spec.assignments, other_df)
            then = exp.Update(
                expressions=[exp.EQ(this=key, expression=val) for key, val in update_set.items()]
            )
        elif clause_type == Clause.UPDATE_ALL:
            if not self._merge_star_supported():
                raise ValueError("Merge operation does not support UPDATE_ALL")
            then = exp.Update(expressions=[exp.Star()])
        elif clause_type == Clause.INSERT:
            insert_values = self._ensure_and_normalize_assignments(spec.assignments, other_df)
            then = exp.Insert(
                this=exp.Tuple(expressions=list(insert_values.keys())),
                expression=exp.Tuple(expressions=list(insert_values.values())),
            )
        elif clause_type == Clause.INSERT_ALL:
            if not self._merge_star_supported():
                raise ValueError("Merge operation does not support INSERT_ALL")
            then = exp.Insert(expression=exp.Star())
        elif clause_type == Clause.DELETE:
            then = exp.var("DELETE")
        else:
            return None

        return exp.When(
            matched=spec.matched,
            source=spec.by_source,
            condition=cond_expr,
            then=then,
        )

    def _ensure_and_normalize_condition(
        self,
        condition: t.Union[str, t.List[str], Column, t.List[Column], bool],
        other_df: DF,
        clause: t.Optional[bool] = False,
    ) -> Column:
        """Normalize a merge condition into a single ``Column``.

        Args:
            condition: The ``ON`` condition or a per-clause condition.
            other_df: Source dataframe of the merge.
            clause: ``True`` for a per-clause condition, which is never treated as
                a bare join-column name.

        Returns:
            The join clause. For a bare column with ``clause`` falsy this is the
            equality built by ``_handle_join_column_names_only``; otherwise it is
            the result of ``_normalize_join_clause``.

        Raises:
            ValueError: If the condition normalizes to no columns.
        """
        join_expression = self._add_ctes_to_expression(
            self.expression, other_df.expression.copy().ctes
        )
        columns = self._ensure_and_normalize_cols(condition, self.expression)
        if not columns:
            # Previously surfaced as an IndexError on `condition[0]`.
            raise ValueError("condition cannot be empty")
        self._handle_self_join(other_df, columns)

        if isinstance(columns[0].expression, exp.Column) and not clause:
            table_names = [
                self.expression.args["from"].this.alias_or_name,
                other_df.expression.args["from"].this.alias_or_name,
            ]
            _, join_clause = self._handle_join_column_names_only(
                columns, join_expression, other_df, table_names
            )
        else:
            join_clause = self._normalize_join_clause(columns, join_expression)
        return join_clause

    def _ensure_and_normalize_assignments(
        self,
        assignments: t.Dict[
            t.Union[Column, str], t.Union[Column, "ColumnOrLiteral", exp.Expression]
        ],
        other_df,
    ) -> t.Dict[exp.Column, exp.Expression]:
        """Normalize clause assignments into ``{target column: value expression}``.

        Value columns that are unqualified, or qualified with ``other_df``'s latest
        CTE name, are pointed at the merge source alias. Columns qualified with this
        frame's FROM alias are pointed at the target table.

        Raises:
            ValueError: If a key does not reference exactly one column, or a value
                references a column with any other table qualifier.
        """
        self_name = self.expression.ctes[0].this.args["from"].this.alias_or_name
        other_name = self._create_hash_from_expression(other_df.expression)
        update_set = {}
        for key, val in assignments.items():
            key_column: Column = self._ensure_and_normalize_col(key)
            key_expr = list(key_column.expression.find_all(exp.Column))
            # Exactly one target column is required. Zero columns used to fall
            # through to `key_expr[0]` and surface as an IndexError.
            if len(key_expr) != 1:
                raise ValueError(f"Target expression `{key_expr}` should be a single column.")
            column_key = exp.column(key_expr[0].alias_or_name)

            val = self._ensure_and_normalize_col(val)
            val = self._ensure_and_normalize_cols(val, other_df.expression)[0]
            # Same self-join retargeting used for conditions, via the shared helper
            # rather than an inline copy of its loop.
            self._handle_self_join(other_df, [val])

            for col_expr in val.expression.find_all(exp.Column):
                if not col_expr.table or col_expr.table == other_df.latest_cte_name:
                    col_expr.set("table", exp.to_identifier(other_name))
                elif col_expr.table == self.expression.args["from"].this.alias_or_name:
                    col_expr.set("table", exp.to_identifier(self_name))
                else:
                    raise ValueError(
                        f"Column `{col_expr.alias_or_name}` does not exist in any of the tables."
                    )
            if isinstance(val.expression, exp.Alias):
                val.expression = val.expression.this
            update_set[column_key] = val.expression
        return update_set
@@ -21,19 +21,20 @@ else:
21
21
  if t.TYPE_CHECKING:
22
22
  from sqlframe.base._typing import OptionalPrimitiveType, PathOrPaths
23
23
  from sqlframe.base.column import Column
24
- from sqlframe.base.session import DF, _BaseSession
24
+ from sqlframe.base.session import DF, TABLE, _BaseSession
25
25
  from sqlframe.base.types import StructType
26
26
 
27
27
  SESSION = t.TypeVar("SESSION", bound=_BaseSession)
28
28
  else:
29
29
  SESSION = t.TypeVar("SESSION")
30
30
  DF = t.TypeVar("DF")
31
+ TABLE = t.TypeVar("TABLE")
31
32
 
32
33
 
33
34
  logger = logging.getLogger(__name__)
34
35
 
35
36
 
36
- class _BaseDataFrameReader(t.Generic[SESSION, DF]):
37
+ class _BaseDataFrameReader(t.Generic[SESSION, DF, TABLE]):
37
38
  def __init__(self, spark: SESSION):
38
39
  self._session = spark
39
40
  self.state_format_to_read: t.Optional[str] = None
@@ -42,7 +43,7 @@ class _BaseDataFrameReader(t.Generic[SESSION, DF]):
42
43
  def session(self) -> SESSION:
43
44
  return self._session
44
45
 
45
- def table(self, tableName: str) -> DF:
46
+ def table(self, tableName: str) -> TABLE:
46
47
  tableName = normalize_string(tableName, from_dialect="input", is_table=True)
47
48
  if df := self.session.temp_views.get(tableName):
48
49
  return df
@@ -50,7 +51,7 @@ class _BaseDataFrameReader(t.Generic[SESSION, DF]):
50
51
  self.session.catalog.add_table(table)
51
52
  columns = self.session.catalog.get_columns_from_schema(table)
52
53
 
53
- return self.session._create_df(
54
+ return self.session._create_table(
54
55
  exp.Select()
55
56
  .from_(tableName, dialect=self.session.input_dialect)
56
57
  .select(*columns, dialect=self.session.input_dialect)
sqlframe/base/session.py CHANGED
@@ -27,6 +27,7 @@ from sqlframe.base.catalog import _BaseCatalog
27
27
  from sqlframe.base.dataframe import BaseDataFrame
28
28
  from sqlframe.base.normalize import normalize_dict
29
29
  from sqlframe.base.readerwriter import _BaseDataFrameReader, _BaseDataFrameWriter
30
+ from sqlframe.base.table import _BaseTable
30
31
  from sqlframe.base.udf import _BaseUDFRegistration
31
32
  from sqlframe.base.util import (
32
33
  get_column_mapping_from_schema_input,
@@ -65,17 +66,19 @@ CATALOG = t.TypeVar("CATALOG", bound=_BaseCatalog)
65
66
  READER = t.TypeVar("READER", bound=_BaseDataFrameReader)
66
67
  WRITER = t.TypeVar("WRITER", bound=_BaseDataFrameWriter)
67
68
  DF = t.TypeVar("DF", bound=BaseDataFrame)
69
+ TABLE = t.TypeVar("TABLE", bound=_BaseTable)
68
70
  UDF_REGISTRATION = t.TypeVar("UDF_REGISTRATION", bound=_BaseUDFRegistration)
69
71
 
70
72
  _MISSING = "MISSING"
71
73
 
72
74
 
73
- class _BaseSession(t.Generic[CATALOG, READER, WRITER, DF, CONN, UDF_REGISTRATION]):
75
+ class _BaseSession(t.Generic[CATALOG, READER, WRITER, DF, TABLE, CONN, UDF_REGISTRATION]):
74
76
  _instance = None
75
77
  _reader: t.Type[READER]
76
78
  _writer: t.Type[WRITER]
77
79
  _catalog: t.Type[CATALOG]
78
80
  _df: t.Type[DF]
81
+ _table: t.Type[TABLE]
79
82
  _udf_registration: t.Type[UDF_REGISTRATION]
80
83
 
81
84
  SANITIZE_COLUMN_NAMES = False
@@ -158,12 +161,15 @@ class _BaseSession(t.Generic[CATALOG, READER, WRITER, DF, CONN, UDF_REGISTRATION
158
161
  return name.replace("(", "_").replace(")", "_")
159
162
  return name
160
163
 
161
- def table(self, tableName: str) -> DF:
164
+ def table(self, tableName: str) -> TABLE:
162
165
  return self.read.table(tableName)
163
166
 
164
167
  def _create_df(self, *args, **kwargs) -> DF:
165
168
  return self._df(self, *args, **kwargs)
166
169
 
170
def _create_table(self, *args, **kwargs) -> TABLE:
    """Instantiate the session's table class, passing this session as the first argument."""
    table_cls = self._table
    return table_cls(self, *args, **kwargs)
172
+
167
173
  def __new__(cls, *args, **kwargs):
168
174
  if _BaseSession._instance is None:
169
175
  _BaseSession._instance = super().__new__(cls)