sqlframe 3.13.4__py3-none-any.whl → 3.14.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sqlframe/_version.py +2 -2
- sqlframe/base/dataframe.py +68 -51
- sqlframe/base/mixins/table_mixins.py +335 -0
- sqlframe/base/readerwriter.py +5 -4
- sqlframe/base/session.py +8 -2
- sqlframe/base/table.py +238 -0
- sqlframe/bigquery/catalog.py +1 -0
- sqlframe/bigquery/readwriter.py +2 -1
- sqlframe/bigquery/session.py +3 -0
- sqlframe/bigquery/table.py +24 -0
- sqlframe/databricks/readwriter.py +2 -1
- sqlframe/databricks/session.py +3 -0
- sqlframe/databricks/table.py +24 -0
- sqlframe/duckdb/readwriter.py +4 -1
- sqlframe/duckdb/session.py +3 -0
- sqlframe/duckdb/table.py +16 -0
- sqlframe/postgres/readwriter.py +2 -1
- sqlframe/postgres/session.py +3 -0
- sqlframe/postgres/table.py +24 -0
- sqlframe/redshift/readwriter.py +2 -1
- sqlframe/redshift/session.py +3 -0
- sqlframe/redshift/table.py +15 -0
- sqlframe/snowflake/readwriter.py +2 -1
- sqlframe/snowflake/session.py +3 -0
- sqlframe/snowflake/table.py +23 -0
- sqlframe/spark/readwriter.py +2 -1
- sqlframe/spark/session.py +3 -0
- sqlframe/spark/table.py +6 -0
- sqlframe/standalone/readwriter.py +4 -1
- sqlframe/standalone/session.py +3 -0
- sqlframe/standalone/table.py +6 -0
- {sqlframe-3.13.4.dist-info → sqlframe-3.14.0.dist-info}/METADATA +1 -1
- {sqlframe-3.13.4.dist-info → sqlframe-3.14.0.dist-info}/RECORD +36 -26
- {sqlframe-3.13.4.dist-info → sqlframe-3.14.0.dist-info}/LICENSE +0 -0
- {sqlframe-3.13.4.dist-info → sqlframe-3.14.0.dist-info}/WHEEL +0 -0
- {sqlframe-3.13.4.dist-info → sqlframe-3.14.0.dist-info}/top_level.txt +0 -0
sqlframe/_version.py
CHANGED
sqlframe/base/dataframe.py
CHANGED
|
@@ -872,6 +872,68 @@ class BaseDataFrame(t.Generic[SESSION, WRITER, NA, STAT, GROUP_DATA]):
|
|
|
872
872
|
"""
|
|
873
873
|
return self.join.__wrapped__(self, other, how="cross") # type: ignore
|
|
874
874
|
|
|
875
|
+
def _handle_self_join(self, other_df: DF, join_columns: t.List[Column]):
    """Re-qualify join columns that belong to ``other_df`` when both frames share a branch.

    If ``self`` and ``other_df`` come from the same branch (equal ``branch_id``), columns
    created through ``df["column_name"]`` would otherwise resolve to the same table name
    on both sides of the join. Any column whose ``join_on_uuid`` metadata is a UUID known
    to ``other_df`` but not to ``self`` is re-pointed at ``other_df.latest_cte_name``.

    ``join_columns`` is modified in place; nothing is returned.
    See ``test_self_join`` for an example.
    """
    if self.branch_id != other_df.branch_id:
        return

    uuids_only_in_other = other_df.known_uuids - self.known_uuids
    for join_column in join_columns:
        for column_node in join_column.expression.find_all(exp.Column):
            if "join_on_uuid" not in column_node.meta:
                continue
            if column_node.meta["join_on_uuid"] in uuids_only_in_other:
                column_node.set("table", exp.to_identifier(other_df.latest_cte_name))
|
|
891
|
+
|
|
892
|
+
@staticmethod
def _handle_join_column_names_only(
    join_columns: t.List[Column],
    join_expression: exp.Select,
    other_df: DF,
    table_names: t.List[str],
):
    """Build an equi-join condition from column names that exist on both sides.

    Left-side candidates are the CTEs of ``join_expression`` whose name is in
    ``table_names`` and is not ``other_df.latest_cte_name``. For each join column, the
    first candidate (in CTE order) that selects a column of that name is the left side.
    The right side is always ``other_df.latest_cte_name``.

    Returns:
        ``(join_column_pairs, join_clause)``. ``join_column_pairs`` is a list of
        ``(left_column, right_column)`` tuples. ``join_clause`` is the AND of
        ``left_column == right_column`` over every pair.

    Raises:
        ValueError: If a join column is not selected by any left-side candidate CTE.
    """
    left_candidates = [
        cte
        for cte in join_expression.ctes
        if cte.alias_or_name in table_names and cte.alias_or_name != other_df.latest_cte_name
    ]

    join_column_pairs = []
    for join_column in join_columns:
        column_name = join_column.alias_or_name
        # Candidates are checked left -> right; only the first table holding the column is used.
        left_cte = next(
            (cte for cte in left_candidates if column_name in cte.this.named_selects),
            None,
        )
        if left_cte is None:
            raise ValueError(
                f"Column `{join_column.alias_or_name}` does not exist in any of the tables."
            )
        join_column_pairs.append(
            (
                join_column.copy().set_table_name(left_cte.alias_or_name),
                join_column.copy().set_table_name(other_df.latest_cte_name),
            )
        )

    equalities = [left == right for left, right in join_column_pairs]
    join_clause = functools.reduce(lambda acc, cond: acc & cond, equalities)
    return join_column_pairs, join_clause
|
|
927
|
+
|
|
928
|
+
def _normalize_join_clause(
    self, join_columns: t.List[Column], join_expression: t.Optional[exp.Select]
) -> Column:
    """Normalize ``join_columns`` against ``join_expression`` and fold them into one Column.

    Several conditions are combined with ``&``; a single condition is returned unchanged.
    If normalization yields no columns, indexing the empty list raises ``IndexError``.
    """
    normalized = self._ensure_and_normalize_cols(join_columns, join_expression)
    if len(normalized) > 1:
        return functools.reduce(lambda acc, cond: acc & cond, normalized)
    return normalized[0]
|
|
936
|
+
|
|
875
937
|
@operation(Operation.FROM)
|
|
876
938
|
def join(
|
|
877
939
|
self,
|
|
@@ -895,21 +957,8 @@ class BaseDataFrame(t.Generic[SESSION, WRITER, NA, STAT, GROUP_DATA]):
|
|
|
895
957
|
self_columns = self._get_outer_select_columns(join_expression)
|
|
896
958
|
other_columns = self._get_outer_select_columns(other_df.expression)
|
|
897
959
|
join_columns = self._ensure_and_normalize_cols(on)
|
|
898
|
-
|
|
899
|
-
|
|
900
|
-
# the two columns since they would end up with the same table name. We do this by checking for the unique
|
|
901
|
-
# uuids in the other df and finding columns that have metadata on them that match the uuids. If so, we know
|
|
902
|
-
# it comes from the other df and we change the table name to the other df's table name.
|
|
903
|
-
# See `test_self_join` for an example of this.
|
|
904
|
-
if self.branch_id == other_df.branch_id:
|
|
905
|
-
other_df_unique_uuids = other_df.known_uuids - self.known_uuids
|
|
906
|
-
for col in join_columns:
|
|
907
|
-
for col_expr in col.expression.find_all(exp.Column):
|
|
908
|
-
if (
|
|
909
|
-
"join_on_uuid" in col_expr.meta
|
|
910
|
-
and col_expr.meta["join_on_uuid"] in other_df_unique_uuids
|
|
911
|
-
):
|
|
912
|
-
col_expr.set("table", exp.to_identifier(other_df.latest_cte_name))
|
|
960
|
+
self._handle_self_join(other_df, join_columns)
|
|
961
|
+
|
|
913
962
|
# Determines the join clause and select columns to be used passed on what type of columns were provided for
|
|
914
963
|
# the join. The columns returned changes based on how the on expression is provided.
|
|
915
964
|
if how != "cross":
|
|
@@ -923,38 +972,9 @@ class BaseDataFrame(t.Generic[SESSION, WRITER, NA, STAT, GROUP_DATA]):
|
|
|
923
972
|
table.alias_or_name
|
|
924
973
|
for table in get_tables_from_expression_with_join(join_expression)
|
|
925
974
|
]
|
|
926
|
-
|
|
927
|
-
|
|
928
|
-
|
|
929
|
-
if cte.alias_or_name in table_names
|
|
930
|
-
and cte.alias_or_name != other_df.latest_cte_name
|
|
931
|
-
]
|
|
932
|
-
# Determine the table to reference for the left side of the join by checking each of the left side
|
|
933
|
-
# tables and see if they have the column being referenced.
|
|
934
|
-
join_column_pairs = []
|
|
935
|
-
for join_column in join_columns:
|
|
936
|
-
num_matching_ctes = 0
|
|
937
|
-
for cte in potential_ctes:
|
|
938
|
-
if join_column.alias_or_name in cte.this.named_selects:
|
|
939
|
-
left_column = join_column.copy().set_table_name(cte.alias_or_name)
|
|
940
|
-
right_column = join_column.copy().set_table_name(
|
|
941
|
-
other_df.latest_cte_name
|
|
942
|
-
)
|
|
943
|
-
join_column_pairs.append((left_column, right_column))
|
|
944
|
-
num_matching_ctes += 1
|
|
945
|
-
# We only want to match one table to the column and that should be matched left -> right
|
|
946
|
-
# so we break after the first match
|
|
947
|
-
break
|
|
948
|
-
if num_matching_ctes == 0:
|
|
949
|
-
raise ValueError(
|
|
950
|
-
f"Column `{join_column.alias_or_name}` does not exist in any of the tables."
|
|
951
|
-
)
|
|
952
|
-
join_clause = functools.reduce(
|
|
953
|
-
lambda x, y: x & y,
|
|
954
|
-
[
|
|
955
|
-
left_column == right_column
|
|
956
|
-
for left_column, right_column in join_column_pairs
|
|
957
|
-
],
|
|
975
|
+
|
|
976
|
+
join_column_pairs, join_clause = self._handle_join_column_names_only(
|
|
977
|
+
join_columns, join_expression, other_df, table_names
|
|
958
978
|
)
|
|
959
979
|
join_column_names = [
|
|
960
980
|
coalesce(
|
|
@@ -989,10 +1009,7 @@ class BaseDataFrame(t.Generic[SESSION, WRITER, NA, STAT, GROUP_DATA]):
|
|
|
989
1009
|
* There is no deduplication of the results.
|
|
990
1010
|
* The left join dataframe columns go first and right come after. No sort preference is given to join columns
|
|
991
1011
|
"""
|
|
992
|
-
|
|
993
|
-
if len(join_columns) > 1:
|
|
994
|
-
join_columns = [functools.reduce(lambda x, y: x & y, join_columns)]
|
|
995
|
-
join_clause = join_columns[0]
|
|
1012
|
+
join_clause = self._normalize_join_clause(join_columns, join_expression)
|
|
996
1013
|
select_column_names = [
|
|
997
1014
|
column.alias_or_name for column in self_columns + other_columns
|
|
998
1015
|
]
|
|
@@ -0,0 +1,335 @@
|
|
|
1
|
+
import functools
|
|
2
|
+
import logging
|
|
3
|
+
import typing as t
|
|
4
|
+
|
|
5
|
+
from sqlglot import exp
|
|
6
|
+
|
|
7
|
+
try:
|
|
8
|
+
from sqlglot.expressions import Whens
|
|
9
|
+
except ImportError:
|
|
10
|
+
Whens = None # type: ignore
|
|
11
|
+
from sqlglot.helper import object_to_dict
|
|
12
|
+
|
|
13
|
+
from sqlframe.base.column import Column
|
|
14
|
+
from sqlframe.base.table import (
|
|
15
|
+
DF,
|
|
16
|
+
Clause,
|
|
17
|
+
LazyExpression,
|
|
18
|
+
WhenMatched,
|
|
19
|
+
WhenNotMatched,
|
|
20
|
+
WhenNotMatchedBySource,
|
|
21
|
+
_BaseTable,
|
|
22
|
+
)
|
|
23
|
+
|
|
24
|
+
if t.TYPE_CHECKING:
|
|
25
|
+
from sqlframe.base._typing import ColumnOrLiteral
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
logger = logging.getLogger(__name__)
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
def ensure_cte() -> t.Callable[[t.Callable], t.Callable]:
    """Return a decorator that makes sure ``self.expression`` has at least one CTE.

    If the decorated method is called on a table whose expression has no CTEs, the
    table is first converted with ``_convert_leaf_to_cte()``. A new instance of the same
    class is then rebuilt from ``object_to_dict`` of the converted object, and the
    method runs on that instance. Otherwise the method is called on ``self`` unchanged.
    """

    def decorator(func: t.Callable) -> t.Callable:
        # functools.wraps also records ``func`` on ``wrapper.__wrapped__``.
        @functools.wraps(func)
        def wrapper(self: _BaseTable, *args, **kwargs) -> t.Any:
            if not self.expression.ctes:
                table_cls = self.__class__
                converted = self._convert_leaf_to_cte()
                self = table_cls(**object_to_dict(converted))
            return func(self, *args, **kwargs)  # type: ignore

        return wrapper

    return decorator
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
class _BaseTableMixins(_BaseTable, t.Generic[DF]):
    """Helpers shared by the table mixins that emit ``UPDATE`` / ``DELETE`` statements."""

    def _ensure_where_condition(
        self, where: t.Optional[t.Union[Column, str, bool]] = None
    ) -> exp.Expression:
        """Turn ``where`` into a bare sqlglot condition aimed at the underlying table.

        Args:
            where: Filter condition(s). ``None`` logs a warning and becomes ``TRUE``.

        Returns:
            The condition expression. Multiple normalized conditions are AND-ed together.
            Column references to this frame's FROM source are re-qualified with the name
            of the table read by the first CTE. A top-level alias is stripped.
        """
        # Name of the table the first CTE selects from (the table being modified).
        self_name = self.expression.ctes[0].this.args["from"].this.alias_or_name

        if where is None:
            logger.warning("Empty value for `where` clause. Defaults to `True`.")
            return exp.Boolean(this=True)

        conditions = self._ensure_and_normalize_cols(where, self.expression)
        if len(conditions) > 1:
            conditions = [functools.reduce(lambda x, y: x & y, conditions)]

        # Normalized columns are qualified with the outer FROM source; DML statements must
        # reference the real table instead.
        from_name = self.expression.args["from"].this.alias_or_name
        for col_expr in conditions[0].expression.find_all(exp.Column):
            if col_expr.table == from_name:
                col_expr.set("table", exp.to_identifier(self_name))

        condition = conditions[0].expression
        if isinstance(condition, exp.Alias):
            condition = condition.this
        return condition
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
class UpdateSupportMixin(_BaseTableMixins, t.Generic[DF]):
    """Adds an ``update`` method that builds an ``UPDATE ... SET ... WHERE ...`` statement."""

    @ensure_cte()
    def update(
        self,
        set_: t.Dict[t.Union[Column, str], t.Union[Column, "ColumnOrLiteral", exp.Expression]],
        where: t.Optional[t.Union[Column, str, bool]] = None,
    ) -> LazyExpression:
        """Build a lazy ``UPDATE`` of the table this frame was read from.

        Args:
            set_: Mapping of target column -> new value. Each target must reference
                exactly one column. Values may only reference columns of this table.
            where: Row filter. ``None`` updates every row (condition defaults to ``TRUE``).

        Returns:
            A ``LazyExpression`` wrapping the ``exp.Update`` and this session.

        Raises:
            ValueError: If a target is not a single column, or a value references a
                column outside this table.
        """
        self_expr = self.expression.ctes[0].this.args["from"].this

        condition = self._ensure_where_condition(where)
        update_set = self._ensure_and_normalize_update_set(set_)
        update_expr = exp.Update(
            this=self_expr,
            expressions=[exp.EQ(this=key, expression=val) for key, val in update_set.items()],
            where=exp.Where(this=condition),
        )
        return LazyExpression(update_expr, self.session)

    def _ensure_and_normalize_update_set(
        self,
        set_: t.Dict[t.Union[Column, str], t.Union[Column, "ColumnOrLiteral", exp.Expression]],
    ) -> t.Dict[str, exp.Expression]:
        """Normalize ``set_`` into ``{column_name: value_expression}``.

        Column references in each value are re-qualified from this frame's FROM source to
        the table read by the first CTE.

        Raises:
            ValueError: If a target does not reference exactly one column, or a value
                references a column not qualified with this frame's FROM source.
        """
        self_name = self.expression.ctes[0].this.args["from"].this.alias_or_name
        from_name = self.expression.args["from"].this.alias_or_name
        update_set: t.Dict[str, exp.Expression] = {}
        for key, val in set_.items():
            key_column: Column = self._ensure_and_normalize_col(key)
            key_columns = list(key_column.expression.find_all(exp.Column))
            # Zero matches (e.g. a literal target) would otherwise surface as an IndexError.
            if len(key_columns) != 1:
                raise ValueError("Can only update a single column at a time.")
            column_name = key_columns[0].alias_or_name

            val_column: Column = self._ensure_and_normalize_col(val)
            for col_expr in val_column.expression.find_all(exp.Column):
                if col_expr.table != from_name:
                    raise ValueError(
                        f"Column `{col_expr.alias_or_name}` does not exist in the table."
                    )
                col_expr.set("table", exp.to_identifier(self_name))

            update_set[column_name] = val_column.expression
        return update_set
|
|
119
|
+
|
|
120
|
+
|
|
121
|
+
class DeleteSupportMixin(_BaseTableMixins, t.Generic[DF]):
    """Adds a ``delete`` method that builds a ``DELETE ... WHERE ...`` statement."""

    @ensure_cte()
    def delete(
        self,
        where: t.Optional[t.Union[Column, str, bool]] = None,
    ) -> LazyExpression:
        """Build a lazy ``DELETE`` against the table this frame was read from.

        Args:
            where: Row filter. ``None`` logs a warning and uses ``TRUE``, so every row
                matches.

        Returns:
            A ``LazyExpression`` wrapping the ``exp.Delete`` and this session.
        """
        target_table = self.expression.ctes[0].this.args["from"].this
        where_clause = exp.Where(this=self._ensure_where_condition(where))
        statement = exp.Delete(this=target_table, where=where_clause)
        return LazyExpression(statement, self.session)
|
|
136
|
+
|
|
137
|
+
|
|
138
|
+
class MergeSupportMixin(_BaseTable, t.Generic[DF]):
    """Adds a ``merge`` method that builds a ``MERGE INTO ... USING ... ON ...`` statement.

    Concrete table classes configure the mixin through two class attributes:

    * ``_merge_supported_clauses``: the ``WhenMatched`` / ``WhenNotMatched`` /
      ``WhenNotMatchedBySource`` types accepted by ``merge``.
    * ``_merge_support_star``: whether the ``UPDATE_ALL`` / ``INSERT_ALL`` (``*``)
      clause forms may be used.
    """

    _merge_supported_clauses: t.Iterable[
        t.Union[t.Type[WhenMatched], t.Type[WhenNotMatched], t.Type[WhenNotMatchedBySource]]
    ]
    _merge_support_star: bool

    @ensure_cte()
    def merge(
        self,
        other_df: DF,
        condition: t.Union[str, t.List[str], Column, t.List[Column], bool],
        clauses: t.Iterable[t.Union[WhenMatched, WhenNotMatched, WhenNotMatchedBySource]],
    ) -> LazyExpression:
        """Build a lazy ``MERGE`` of ``other_df`` (source) into this table (target).

        Args:
            other_df: Source dataframe. It is emitted as a subquery aliased with the name
                produced by ``_create_hash_from_expression``.
            condition: The ``ON`` condition. If the first normalized condition is a bare
                column, the conditions are matched by column name on both sides.
                Otherwise they are AND-ed together as given.
            clauses: ``WHEN`` clauses. Each must be an instance of one of
                ``_merge_supported_clauses``.

        Returns:
            A ``LazyExpression`` wrapping the merge expression and this session.

        Raises:
            ValueError: If ``condition`` is None, a clause type is unsupported, or an
                ``UPDATE_ALL`` / ``INSERT_ALL`` clause is used while
                ``_merge_support_star`` is false.
        """
        # Validate before doing any work on ``other_df``.
        if condition is None:
            raise ValueError("condition cannot be None")

        self_expr = self.expression.ctes[0].this.args["from"].this
        self_name = self_expr.alias_or_name

        other_df = other_df._convert_leaf_to_cte()

        condition_columns: Column = self._ensure_and_normalize_condition(condition, other_df)
        other_name = self._create_hash_from_expression(other_df.expression)
        other_expr = exp.Subquery(
            this=other_df.expression, alias=exp.TableAlias(this=exp.to_identifier(other_name))
        )
        self._retarget_merge_columns(condition_columns, other_df, self_name, other_name)

        supported_clauses = tuple(self._merge_supported_clauses)
        merge_expressions = []
        for clause in clauses:
            if not isinstance(clause, supported_clauses):
                # Report the type that was actually passed; an unsupported object need not
                # have a ``.clause`` attribute at all.
                raise ValueError(
                    f"Unsupported clause type {type(clause).__name__} for merge operation"
                )
            when_expr = self._build_when_expression(
                clause.clause, other_df, self_name, other_name
            )
            if when_expr is not None:
                merge_expressions.append(when_expr)

        if Whens is None:
            # Older sqlglot (no ``Whens`` node): WHEN clauses are passed positionally.
            merge_expr = exp.merge(
                *merge_expressions,
                into=self_expr,
                using=other_expr,
                on=condition_columns.expression,
            )
        else:
            merge_expr = exp.merge(
                Whens(expressions=merge_expressions),
                into=self_expr,
                using=other_expr,
                on=condition_columns.expression,
            )

        return LazyExpression(merge_expr, self.session)

    def _retarget_merge_columns(
        self, column: Column, other_df: DF, self_name: str, other_name: str
    ) -> None:
        """Re-qualify column references in ``column`` (in place) for the MERGE statement.

        References to this frame's FROM source become ``self_name`` (the target table).
        References to ``other_df.latest_cte_name`` become ``other_name`` (the USING alias).
        """
        from_name = self.expression.args["from"].this.alias_or_name
        for col_expr in column.expression.find_all(exp.Column):
            if col_expr.table == from_name:
                col_expr.set("table", exp.to_identifier(self_name))
            if col_expr.table == other_df.latest_cte_name:
                col_expr.set("table", exp.to_identifier(other_name))

    def _build_when_expression(
        self, when: t.Any, other_df: DF, self_name: str, other_name: str
    ) -> t.Optional[exp.When]:
        """Translate one clause description into a sqlglot ``WHEN`` node.

        ``when`` is the ``.clause`` attribute of a ``WhenMatched`` / ``WhenNotMatched`` /
        ``WhenNotMatchedBySource`` object. Its ``condition``, ``clause_type``, ``matched``,
        ``by_source`` and ``assignments`` attributes are read. Returns None for a
        ``clause_type`` not handled here.

        Raises:
            ValueError: For ``UPDATE_ALL`` / ``INSERT_ALL`` when ``_merge_support_star``
                is false.
        """
        when_condition: t.Optional[exp.Expression] = None
        if when.condition is not None:
            cond_column = self._ensure_and_normalize_condition(
                when.condition, other_df, clause=True
            )
            self._retarget_merge_columns(cond_column, other_df, self_name, other_name)
            when_condition = cond_column.expression

        clause_type = when.clause_type
        then: exp.Expression
        if clause_type == Clause.UPDATE:
            update_set = self._ensure_and_normalize_assignments(when.assignments, other_df)
            then = exp.Update(
                expressions=[
                    exp.EQ(this=key, expression=val) for key, val in update_set.items()
                ]
            )
        elif clause_type == Clause.UPDATE_ALL:
            if not self._merge_support_star:
                raise ValueError("Merge operation does not support UPDATE_ALL")
            then = exp.Update(expressions=[exp.Star()])
        elif clause_type == Clause.INSERT:
            insert_values = self._ensure_and_normalize_assignments(when.assignments, other_df)
            then = exp.Insert(
                this=exp.Tuple(expressions=list(insert_values.keys())),
                expression=exp.Tuple(expressions=list(insert_values.values())),
            )
        elif clause_type == Clause.INSERT_ALL:
            if not self._merge_support_star:
                raise ValueError("Merge operation does not support INSERT_ALL")
            then = exp.Insert(expression=exp.Star())
        elif clause_type == Clause.DELETE:
            then = exp.var("DELETE")
        else:
            return None

        return exp.When(
            matched=when.matched,
            source=when.by_source,
            condition=when_condition,
            then=then,
        )

    def _ensure_and_normalize_condition(
        self,
        condition: t.Union[str, t.List[str], Column, t.List[Column], bool],
        other_df: DF,
        clause: t.Optional[bool] = False,
    ):
        """Normalize a merge ``ON`` condition (or a WHEN-clause condition) into one Column.

        Args:
            condition: Column names, Columns or a boolean condition.
            other_df: The merge source frame.
            clause: True when normalizing a WHEN-clause condition. Such conditions are
                never treated as join-by-column-name.

        Returns:
            A single Column holding the combined condition.
        """
        join_expression = self._add_ctes_to_expression(
            self.expression, other_df.expression.copy().ctes
        )
        condition = self._ensure_and_normalize_cols(condition, self.expression)
        self._handle_self_join(other_df, condition)

        if isinstance(condition[0].expression, exp.Column) and not clause:
            # Bare column references: join on same-named columns of both sides.
            table_names = [
                self.expression.args["from"].this.alias_or_name,
                other_df.expression.args["from"].this.alias_or_name,
            ]
            _, join_clause = self._handle_join_column_names_only(
                condition, join_expression, other_df, table_names
            )
        else:
            join_clause = self._normalize_join_clause(condition, join_expression)
        return join_clause

    def _ensure_and_normalize_assignments(
        self,
        assignments: t.Dict[
            t.Union[Column, str], t.Union[Column, "ColumnOrLiteral", exp.Expression]
        ],
        other_df,
    ) -> t.Dict[exp.Column, exp.Expression]:
        """Normalize ``{target: value}`` assignments for a MERGE ``UPDATE`` / ``INSERT``.

        Each target must reference exactly one column and is emitted unqualified. In each
        value, column references are re-qualified as follows:

        * unqualified, or qualified with ``other_df.latest_cte_name``: the USING alias;
        * qualified with this frame's FROM source: the target table name.

        Any other qualifier raises. A top-level alias on the value is dropped.

        Raises:
            ValueError: If a target is not exactly one column, or a value references a
                column from an unknown table.
        """
        self_name = self.expression.ctes[0].this.args["from"].this.alias_or_name
        from_name = self.expression.args["from"].this.alias_or_name
        other_name = self._create_hash_from_expression(other_df.expression)
        normalized: t.Dict[exp.Column, exp.Expression] = {}
        for key, val in assignments.items():
            key_column: Column = self._ensure_and_normalize_col(key)
            key_exprs = list(key_column.expression.find_all(exp.Column))
            # Zero matches (e.g. a literal target) would otherwise surface as an IndexError.
            if len(key_exprs) != 1:
                raise ValueError(
                    f"Target expression `{key_exprs or key}` should be a single column."
                )
            column_key = exp.column(key_exprs[0].alias_or_name)

            value = self._ensure_and_normalize_col(val)
            value = self._ensure_and_normalize_cols(value, other_df.expression)[0]
            # Same-branch merge: re-point columns that were created through ``other_df[...]``.
            self._handle_self_join(other_df, [value])

            for col_expr in value.expression.find_all(exp.Column):
                if not col_expr.table or col_expr.table == other_df.latest_cte_name:
                    col_expr.set("table", exp.to_identifier(other_name))
                elif col_expr.table == from_name:
                    col_expr.set("table", exp.to_identifier(self_name))
                else:
                    raise ValueError(
                        f"Column `{col_expr.alias_or_name}` does not exist in any of the tables."
                    )

            # Strip the alias locally instead of mutating the (possibly caller-owned) Column.
            value_expr = value.expression
            if isinstance(value_expr, exp.Alias):
                value_expr = value_expr.this
            normalized[column_key] = value_expr
        return normalized
|
sqlframe/base/readerwriter.py
CHANGED
|
@@ -21,19 +21,20 @@ else:
|
|
|
21
21
|
if t.TYPE_CHECKING:
|
|
22
22
|
from sqlframe.base._typing import OptionalPrimitiveType, PathOrPaths
|
|
23
23
|
from sqlframe.base.column import Column
|
|
24
|
-
from sqlframe.base.session import DF, _BaseSession
|
|
24
|
+
from sqlframe.base.session import DF, TABLE, _BaseSession
|
|
25
25
|
from sqlframe.base.types import StructType
|
|
26
26
|
|
|
27
27
|
SESSION = t.TypeVar("SESSION", bound=_BaseSession)
|
|
28
28
|
else:
|
|
29
29
|
SESSION = t.TypeVar("SESSION")
|
|
30
30
|
DF = t.TypeVar("DF")
|
|
31
|
+
TABLE = t.TypeVar("TABLE")
|
|
31
32
|
|
|
32
33
|
|
|
33
34
|
logger = logging.getLogger(__name__)
|
|
34
35
|
|
|
35
36
|
|
|
36
|
-
class _BaseDataFrameReader(t.Generic[SESSION, DF]):
|
|
37
|
+
class _BaseDataFrameReader(t.Generic[SESSION, DF, TABLE]):
|
|
37
38
|
def __init__(self, spark: SESSION):
|
|
38
39
|
self._session = spark
|
|
39
40
|
self.state_format_to_read: t.Optional[str] = None
|
|
@@ -42,7 +43,7 @@ class _BaseDataFrameReader(t.Generic[SESSION, DF]):
|
|
|
42
43
|
def session(self) -> SESSION:
|
|
43
44
|
return self._session
|
|
44
45
|
|
|
45
|
-
def table(self, tableName: str) ->
|
|
46
|
+
def table(self, tableName: str) -> TABLE:
|
|
46
47
|
tableName = normalize_string(tableName, from_dialect="input", is_table=True)
|
|
47
48
|
if df := self.session.temp_views.get(tableName):
|
|
48
49
|
return df
|
|
@@ -50,7 +51,7 @@ class _BaseDataFrameReader(t.Generic[SESSION, DF]):
|
|
|
50
51
|
self.session.catalog.add_table(table)
|
|
51
52
|
columns = self.session.catalog.get_columns_from_schema(table)
|
|
52
53
|
|
|
53
|
-
return self.session.
|
|
54
|
+
return self.session._create_table(
|
|
54
55
|
exp.Select()
|
|
55
56
|
.from_(tableName, dialect=self.session.input_dialect)
|
|
56
57
|
.select(*columns, dialect=self.session.input_dialect)
|
sqlframe/base/session.py
CHANGED
|
@@ -27,6 +27,7 @@ from sqlframe.base.catalog import _BaseCatalog
|
|
|
27
27
|
from sqlframe.base.dataframe import BaseDataFrame
|
|
28
28
|
from sqlframe.base.normalize import normalize_dict
|
|
29
29
|
from sqlframe.base.readerwriter import _BaseDataFrameReader, _BaseDataFrameWriter
|
|
30
|
+
from sqlframe.base.table import _BaseTable
|
|
30
31
|
from sqlframe.base.udf import _BaseUDFRegistration
|
|
31
32
|
from sqlframe.base.util import (
|
|
32
33
|
get_column_mapping_from_schema_input,
|
|
@@ -65,17 +66,19 @@ CATALOG = t.TypeVar("CATALOG", bound=_BaseCatalog)
|
|
|
65
66
|
READER = t.TypeVar("READER", bound=_BaseDataFrameReader)
|
|
66
67
|
WRITER = t.TypeVar("WRITER", bound=_BaseDataFrameWriter)
|
|
67
68
|
DF = t.TypeVar("DF", bound=BaseDataFrame)
|
|
69
|
+
TABLE = t.TypeVar("TABLE", bound=_BaseTable)
|
|
68
70
|
UDF_REGISTRATION = t.TypeVar("UDF_REGISTRATION", bound=_BaseUDFRegistration)
|
|
69
71
|
|
|
70
72
|
_MISSING = "MISSING"
|
|
71
73
|
|
|
72
74
|
|
|
73
|
-
class _BaseSession(t.Generic[CATALOG, READER, WRITER, DF, CONN, UDF_REGISTRATION]):
|
|
75
|
+
class _BaseSession(t.Generic[CATALOG, READER, WRITER, DF, TABLE, CONN, UDF_REGISTRATION]):
|
|
74
76
|
_instance = None
|
|
75
77
|
_reader: t.Type[READER]
|
|
76
78
|
_writer: t.Type[WRITER]
|
|
77
79
|
_catalog: t.Type[CATALOG]
|
|
78
80
|
_df: t.Type[DF]
|
|
81
|
+
_table: t.Type[TABLE]
|
|
79
82
|
_udf_registration: t.Type[UDF_REGISTRATION]
|
|
80
83
|
|
|
81
84
|
SANITIZE_COLUMN_NAMES = False
|
|
@@ -158,12 +161,15 @@ class _BaseSession(t.Generic[CATALOG, READER, WRITER, DF, CONN, UDF_REGISTRATION
|
|
|
158
161
|
return name.replace("(", "_").replace(")", "_")
|
|
159
162
|
return name
|
|
160
163
|
|
|
161
|
-
def table(self, tableName: str) ->
|
|
164
|
+
def table(self, tableName: str) -> TABLE:
    """Return the table named ``tableName`` by delegating to this session's reader."""
    reader = self.read
    return reader.table(tableName)
|
|
163
166
|
|
|
164
167
|
def _create_df(self, *args, **kwargs) -> DF:
    """Instantiate the session's dataframe class, passing this session as the first argument."""
    df_cls = self._df
    return df_cls(self, *args, **kwargs)
|
|
166
169
|
|
|
170
|
+
def _create_table(self, *args, **kwargs) -> TABLE:
    """Instantiate the session's table class, passing this session as the first argument."""
    table_cls = self._table
    return table_cls(self, *args, **kwargs)
|
|
172
|
+
|
|
167
173
|
def __new__(cls, *args, **kwargs):
|
|
168
174
|
if _BaseSession._instance is None:
|
|
169
175
|
_BaseSession._instance = super().__new__(cls)
|