sqlframe 3.13.3__py3-none-any.whl → 3.14.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sqlframe/_version.py +2 -2
- sqlframe/base/dataframe.py +78 -54
- sqlframe/base/mixins/table_mixins.py +335 -0
- sqlframe/base/readerwriter.py +5 -4
- sqlframe/base/session.py +8 -2
- sqlframe/base/table.py +238 -0
- sqlframe/bigquery/catalog.py +1 -0
- sqlframe/bigquery/readwriter.py +2 -1
- sqlframe/bigquery/session.py +3 -0
- sqlframe/bigquery/table.py +24 -0
- sqlframe/databricks/readwriter.py +2 -1
- sqlframe/databricks/session.py +3 -0
- sqlframe/databricks/table.py +24 -0
- sqlframe/duckdb/readwriter.py +4 -1
- sqlframe/duckdb/session.py +3 -0
- sqlframe/duckdb/table.py +16 -0
- sqlframe/postgres/readwriter.py +2 -1
- sqlframe/postgres/session.py +3 -0
- sqlframe/postgres/table.py +24 -0
- sqlframe/redshift/readwriter.py +2 -1
- sqlframe/redshift/session.py +3 -0
- sqlframe/redshift/table.py +15 -0
- sqlframe/snowflake/readwriter.py +2 -1
- sqlframe/snowflake/session.py +3 -0
- sqlframe/snowflake/table.py +23 -0
- sqlframe/spark/readwriter.py +2 -1
- sqlframe/spark/session.py +3 -0
- sqlframe/spark/table.py +6 -0
- sqlframe/standalone/readwriter.py +4 -1
- sqlframe/standalone/session.py +3 -0
- sqlframe/standalone/table.py +6 -0
- {sqlframe-3.13.3.dist-info → sqlframe-3.14.0.dist-info}/METADATA +1 -1
- {sqlframe-3.13.3.dist-info → sqlframe-3.14.0.dist-info}/RECORD +36 -26
- {sqlframe-3.13.3.dist-info → sqlframe-3.14.0.dist-info}/LICENSE +0 -0
- {sqlframe-3.13.3.dist-info → sqlframe-3.14.0.dist-info}/WHEEL +0 -0
- {sqlframe-3.13.3.dist-info → sqlframe-3.14.0.dist-info}/top_level.txt +0 -0
sqlframe/_version.py
CHANGED
sqlframe/base/dataframe.py
CHANGED
@@ -481,6 +481,7 @@ class BaseDataFrame(t.Generic[SESSION, WRITER, NA, STAT, GROUP_DATA]):
             cte = cte.transform(replace_id_value, replaced_cte_names)  # type: ignore
             if cte.alias_or_name in existing_cte_counts:
                 existing_cte_counts[cte.alias_or_name] += 10
+                # Add unique where filter to ensure that the hash of the CTE is unique
                 cte.set(
                     "this",
                     cte.this.where(
@@ -502,6 +503,7 @@ class BaseDataFrame(t.Generic[SESSION, WRITER, NA, STAT, GROUP_DATA]):
                     new_cte_alias, dialect=self.session.input_dialect, into=exp.TableAlias
                 ),
             )
+            existing_cte_counts[new_cte_alias] = 0
             existing_ctes.append(cte)
         else:
             existing_ctes = ctes
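
The two hunks above keep CTE hashes unique when alias names repeat: a repeated name bumps a counter and gains a no-op where filter, and a fresh alias now seeds the counter at zero. A toy sqlglot sketch of the idea (not sqlframe's exact filter expression):

from sqlglot import exp, parse_one

cte = parse_one("SELECT a FROM t")
count = 2  # how many times this CTE name has been seen
# A no-op predicate like "2 = 2" changes the generated SQL text (and thus
# any hash derived from it) without changing the result set.
unique = cte.where(exp.EQ(this=exp.Literal.number(count), expression=exp.Literal.number(count)))
print(unique.sql())  # SELECT a FROM t WHERE 2 = 2
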
@@ -755,15 +757,20 @@ class BaseDataFrame(t.Generic[SESSION, WRITER, NA, STAT, GROUP_DATA]):
         ]
         cte_names_in_join = [x.this for x in join_table_identifiers]
         # If we have columns that resolve to multiple CTE expressions then we want to use each CTE left-to-right
-        # and therefore we allow multiple columns with the same name in the result. This matches the behavior
-        # of Spark.
+        # (or right to left if a right join) and therefore we allow multiple columns with the same
+        # name in the result. This matches the behavior of Spark.
         resolved_column_position: t.Dict[exp.Column, int] = {
             col.copy(): -1 for col in ambiguous_cols
         }
         for ambiguous_col in ambiguous_cols:
+            ctes = (
+                list(reversed(self.expression.ctes))
+                if self.expression.args["joins"][0].args.get("side", "") == "right"
+                else self.expression.ctes
+            )
             ctes_with_column = [
                 cte
-                for cte in self.expression.ctes
+                for cte in ctes
                 if cte.alias_or_name in cte_names_in_join
                 and ambiguous_col.alias_or_name in cte.this.named_selects
             ]
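
The reversed-CTE branch above changes how ambiguous column names resolve when the first join is a right join: candidate CTEs are scanned right to left instead of left to right, mirroring Spark. A minimal sketch of the case it targets, assuming sqlframe's PySpark-compatible API (the DuckDB session is used only for illustration):

from sqlframe.duckdb import DuckDBSession

session = DuckDBSession()
left = session.createDataFrame([(1, "l")], ["id", "side"])
right = session.createDataFrame([(1, "r"), (2, "r")], ["id", "side"])
# With how="right", the ambiguous "side" column should resolve against the
# right-hand CTE first, matching Spark's resolution order.
left.join(right, on="id", how="right").show()
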
@@ -865,6 +872,68 @@ class BaseDataFrame(t.Generic[SESSION, WRITER, NA, STAT, GROUP_DATA]):
         """
         return self.join.__wrapped__(self, other, how="cross")  # type: ignore
 
+    def _handle_self_join(self, other_df: DF, join_columns: t.List[Column]):
+        # If the two dataframes being joined come from the same branch, we then check if they have any columns that
+        # were created using the "branch_id" (df["column_name"]). If so, we know that we need to differentiate
+        # the two columns since they would end up with the same table name. We do this by checking for the unique
+        # uuids in the other df and finding columns that have metadata on them that match the uuids. If so, we know
+        # it comes from the other df and we change the table name to the other df's table name.
+        # See `test_self_join` for an example of this.
+        if self.branch_id == other_df.branch_id:
+            other_df_unique_uuids = other_df.known_uuids - self.known_uuids
+            for col in join_columns:
+                for col_expr in col.expression.find_all(exp.Column):
+                    if (
+                        "join_on_uuid" in col_expr.meta
+                        and col_expr.meta["join_on_uuid"] in other_df_unique_uuids
+                    ):
+                        col_expr.set("table", exp.to_identifier(other_df.latest_cte_name))
+
+    @staticmethod
+    def _handle_join_column_names_only(
+        join_columns: t.List[Column],
+        join_expression: exp.Select,
+        other_df: DF,
+        table_names: t.List[str],
+    ):
+        potential_ctes = [
+            cte
+            for cte in join_expression.ctes
+            if cte.alias_or_name in table_names and cte.alias_or_name != other_df.latest_cte_name
+        ]
+        # Determine the table to reference for the left side of the join by checking each of the left side
+        # tables and see if they have the column being referenced.
+        join_column_pairs = []
+        for join_column in join_columns:
+            num_matching_ctes = 0
+            for cte in potential_ctes:
+                if join_column.alias_or_name in cte.this.named_selects:
+                    left_column = join_column.copy().set_table_name(cte.alias_or_name)
+                    right_column = join_column.copy().set_table_name(other_df.latest_cte_name)
+                    join_column_pairs.append((left_column, right_column))
+                    num_matching_ctes += 1
+                    # We only want to match one table to the column and that should be matched left -> right
+                    # so we break after the first match
+                    break
+            if num_matching_ctes == 0:
+                raise ValueError(
+                    f"Column `{join_column.alias_or_name}` does not exist in any of the tables."
+                )
+        join_clause = functools.reduce(
+            lambda x, y: x & y,
+            [left_column == right_column for left_column, right_column in join_column_pairs],
+        )
+        return join_column_pairs, join_clause
+
+    def _normalize_join_clause(
+        self, join_columns: t.List[Column], join_expression: t.Optional[exp.Select]
+    ) -> Column:
+        join_columns = self._ensure_and_normalize_cols(join_columns, join_expression)
+        if len(join_columns) > 1:
+            join_columns = [functools.reduce(lambda x, y: x & y, join_columns)]
+        join_clause = join_columns[0]
+        return join_clause
+
     @operation(Operation.FROM)
     def join(
         self,
@@ -888,21 +957,8 @@ class BaseDataFrame(t.Generic[SESSION, WRITER, NA, STAT, GROUP_DATA]):
         self_columns = self._get_outer_select_columns(join_expression)
         other_columns = self._get_outer_select_columns(other_df.expression)
         join_columns = self._ensure_and_normalize_cols(on)
-        # If the two dataframes being joined come from the same branch, we then check if they have any columns that
-        # were created using the "branch_id" (df["column_name"]). If so, we know that we need to differentiate
-        # the two columns since they would end up with the same table name. We do this by checking for the unique
-        # uuids in the other df and finding columns that have metadata on them that match the uuids. If so, we know
-        # it comes from the other df and we change the table name to the other df's table name.
-        # See `test_self_join` for an example of this.
-        if self.branch_id == other_df.branch_id:
-            other_df_unique_uuids = other_df.known_uuids - self.known_uuids
-            for col in join_columns:
-                for col_expr in col.expression.find_all(exp.Column):
-                    if (
-                        "join_on_uuid" in col_expr.meta
-                        and col_expr.meta["join_on_uuid"] in other_df_unique_uuids
-                    ):
-                        col_expr.set("table", exp.to_identifier(other_df.latest_cte_name))
+        self._handle_self_join(other_df, join_columns)
+
         # Determines the join clause and select columns to be used passed on what type of columns were provided for
         # the join. The columns returned changes based on how the on expression is provided.
         if how != "cross":
@@ -916,38 +972,9 @@ class BaseDataFrame(t.Generic[SESSION, WRITER, NA, STAT, GROUP_DATA]):
                 table.alias_or_name
                 for table in get_tables_from_expression_with_join(join_expression)
             ]
-            potential_ctes = [
-                cte
-                for cte in join_expression.ctes
-                if cte.alias_or_name in table_names
-                and cte.alias_or_name != other_df.latest_cte_name
-            ]
-            # Determine the table to reference for the left side of the join by checking each of the left side
-            # tables and see if they have the column being referenced.
-            join_column_pairs = []
-            for join_column in join_columns:
-                num_matching_ctes = 0
-                for cte in potential_ctes:
-                    if join_column.alias_or_name in cte.this.named_selects:
-                        left_column = join_column.copy().set_table_name(cte.alias_or_name)
-                        right_column = join_column.copy().set_table_name(
-                            other_df.latest_cte_name
-                        )
-                        join_column_pairs.append((left_column, right_column))
-                        num_matching_ctes += 1
-                        # We only want to match one table to the column and that should be matched left -> right
-                        # so we break after the first match
-                        break
-                if num_matching_ctes == 0:
-                    raise ValueError(
-                        f"Column `{join_column.alias_or_name}` does not exist in any of the tables."
-                    )
-            join_clause = functools.reduce(
-                lambda x, y: x & y,
-                [
-                    left_column == right_column
-                    for left_column, right_column in join_column_pairs
-                ],
+
+            join_column_pairs, join_clause = self._handle_join_column_names_only(
+                join_columns, join_expression, other_df, table_names
             )
             join_column_names = [
                 coalesce(
@@ -982,10 +1009,7 @@ class BaseDataFrame(t.Generic[SESSION, WRITER, NA, STAT, GROUP_DATA]):
         * There is no deduplication of the results.
         * The left join dataframe columns go first and right come after. No sort preference is given to join columns
         """
-
-        if len(join_columns) > 1:
-            join_columns = [functools.reduce(lambda x, y: x & y, join_columns)]
-        join_clause = join_columns[0]
+        join_clause = self._normalize_join_clause(join_columns, join_expression)
         select_column_names = [
             column.alias_or_name for column in self_columns + other_columns
         ]
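
Net effect of the dataframe.py changes: the self-join uuid handling, the column-name-only join resolution, and the join-clause normalization move out of join() into the reusable helpers _handle_self_join, _handle_join_column_names_only, and _normalize_join_clause, which the new table merge support (below) calls as well. A minimal sketch of the self-join case the first helper covers, assuming sqlframe's PySpark-compatible API:

from sqlframe.duckdb import DuckDBSession

session = DuckDBSession()
df = session.createDataFrame([(1, "a"), (2, "b")], ["id", "val"])
derived = df.where(df["id"] > 1)  # shares df's branch_id
# Column references like df["id"] carry "join_on_uuid" metadata, which
# _handle_self_join uses to requalify the right-hand side's columns.
df.join(derived, on=df["id"] == derived["id"], how="inner").show()
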
sqlframe/base/mixins/table_mixins.py
ADDED
@@ -0,0 +1,335 @@
+import functools
+import logging
+import typing as t
+
+from sqlglot import exp
+
+try:
+    from sqlglot.expressions import Whens
+except ImportError:
+    Whens = None  # type: ignore
+from sqlglot.helper import object_to_dict
+
+from sqlframe.base.column import Column
+from sqlframe.base.table import (
+    DF,
+    Clause,
+    LazyExpression,
+    WhenMatched,
+    WhenNotMatched,
+    WhenNotMatchedBySource,
+    _BaseTable,
+)
+
+if t.TYPE_CHECKING:
+    from sqlframe.base._typing import ColumnOrLiteral
+
+
+logger = logging.getLogger(__name__)
+
+
+def ensure_cte() -> t.Callable[[t.Callable], t.Callable]:
+    def decorator(func: t.Callable) -> t.Callable:
+        @functools.wraps(func)
+        def wrapper(self: _BaseTable, *args, **kwargs) -> t.Any:
+            if len(self.expression.ctes) > 0:
+                return func(self, *args, **kwargs)  # type: ignore
+            self_class = self.__class__
+            self = self._convert_leaf_to_cte()
+            self = self_class(**object_to_dict(self))
+            return func(self, *args, **kwargs)  # type: ignore
+
+        wrapper.__wrapped__ = func  # type: ignore
+        return wrapper
+
+    return decorator
+
+
+class _BaseTableMixins(_BaseTable, t.Generic[DF]):
+    def _ensure_where_condition(
+        self, where: t.Optional[t.Union[Column, str, bool]] = None
+    ) -> exp.Expression:
+        self_name = self.expression.ctes[0].this.args["from"].this.alias_or_name
+
+        if where is None:
+            logger.warning("Empty value for `where` clause. Defaults to `True`.")
+            condition: exp.Expression = exp.Boolean(this=True)
+        else:
+            condition_list = self._ensure_and_normalize_cols(where, self.expression)
+            if len(condition_list) > 1:
+                condition_list = [functools.reduce(lambda x, y: x & y, condition_list)]
+            for col_expr in condition_list[0].expression.find_all(exp.Column):
+                if col_expr.table == self.expression.args["from"].this.alias_or_name:
+                    col_expr.set("table", exp.to_identifier(self_name))
+            condition = condition_list[0].expression
+            if isinstance(condition, exp.Alias):
+                condition = condition.this
+        return condition
+
+
+class UpdateSupportMixin(_BaseTableMixins, t.Generic[DF]):
+    @ensure_cte()
+    def update(
+        self,
+        set_: t.Dict[t.Union[Column, str], t.Union[Column, "ColumnOrLiteral", exp.Expression]],
+        where: t.Optional[t.Union[Column, str, bool]] = None,
+    ) -> LazyExpression:
+        self_expr = self.expression.ctes[0].this.args["from"].this
+
+        condition = self._ensure_where_condition(where)
+        update_set = self._ensure_and_normalize_update_set(set_)
+        update_expr = exp.Update(
+            this=self_expr,
+            expressions=[
+                exp.EQ(
+                    this=key,
+                    expression=val,
+                )
+                for key, val in update_set.items()
+            ],
+            where=exp.Where(this=condition),
+        )
+
+        return LazyExpression(update_expr, self.session)
+
+    def _ensure_and_normalize_update_set(
+        self,
+        set_: t.Dict[t.Union[Column, str], t.Union[Column, "ColumnOrLiteral", exp.Expression]],
+    ) -> t.Dict[str, exp.Expression]:
+        self_name = self.expression.ctes[0].this.args["from"].this.alias_or_name
+        update_set = {}
+        for key, val in set_.items():
+            key_column: Column = self._ensure_and_normalize_col(key)
+            key_expr = list(key_column.expression.find_all(exp.Column))
+            if len(key_expr) > 1:
+                raise ValueError("Can only update a single column at a time.")
+            key = key_expr[0].alias_or_name
+
+            val_column: Column = self._ensure_and_normalize_col(val)
+            for col_expr in val_column.expression.find_all(exp.Column):
+                if col_expr.table == self.expression.args["from"].this.alias_or_name:
+                    col_expr.set("table", exp.to_identifier(self_name))
+                else:
+                    raise ValueError(
+                        f"Column `{col_expr.alias_or_name}` does not exist in the table."
+                    )
+
+            update_set[key] = val_column.expression
+        return update_set
+
+
+class DeleteSupportMixin(_BaseTableMixins, t.Generic[DF]):
+    @ensure_cte()
+    def delete(
+        self,
+        where: t.Optional[t.Union[Column, str, bool]] = None,
+    ) -> LazyExpression:
+        self_expr = self.expression.ctes[0].this.args["from"].this
+
+        condition = self._ensure_where_condition(where)
+        delete_expr = exp.Delete(
+            this=self_expr,
+            where=exp.Where(this=condition),
+        )
+
+        return LazyExpression(delete_expr, self.session)
+
+
+class MergeSupportMixin(_BaseTable, t.Generic[DF]):
+    _merge_supported_clauses: t.Iterable[
+        t.Union[t.Type[WhenMatched], t.Type[WhenNotMatched], t.Type[WhenNotMatchedBySource]]
+    ]
+    _merge_support_star: bool
+
+    @ensure_cte()
+    def merge(
+        self,
+        other_df: DF,
+        condition: t.Union[str, t.List[str], Column, t.List[Column], bool],
+        clauses: t.Iterable[t.Union[WhenMatched, WhenNotMatched, WhenNotMatchedBySource]],
+    ) -> LazyExpression:
+        self_name = self.expression.ctes[0].this.args["from"].this.alias_or_name
+        self_expr = self.expression.ctes[0].this.args["from"].this
+
+        other_df = other_df._convert_leaf_to_cte()
+
+        if condition is None:
+            raise ValueError("condition cannot be None")
+
+        condition_columns: Column = self._ensure_and_normalize_condition(condition, other_df)
+        other_name = self._create_hash_from_expression(other_df.expression)
+        other_expr = exp.Subquery(
+            this=other_df.expression, alias=exp.TableAlias(this=exp.to_identifier(other_name))
+        )
+
+        for col_expr in condition_columns.expression.find_all(exp.Column):
+            if col_expr.table == self.expression.args["from"].this.alias_or_name:
+                col_expr.set("table", exp.to_identifier(self_name))
+            if col_expr.table == other_df.latest_cte_name:
+                col_expr.set("table", exp.to_identifier(other_name))
+
+        merge_expressions = []
+        for clause in clauses:
+            if not isinstance(clause, tuple(self._merge_supported_clauses)):
+                raise ValueError(
+                    f"Unsupported clause type {type(clause.clause)} for merge operation"
+                )
+            expression = None
+
+            if clause.clause.condition is not None:
+                cond_clause = self._ensure_and_normalize_condition(
+                    clause.clause.condition, other_df, True
+                )
+                for col_expr in cond_clause.expression.find_all(exp.Column):
+                    if col_expr.table == self.expression.args["from"].this.alias_or_name:
+                        col_expr.set("table", exp.to_identifier(self_name))
+                    if col_expr.table == other_df.latest_cte_name:
+                        col_expr.set("table", exp.to_identifier(other_name))
+            else:
+                cond_clause = None
+            if clause.clause.clause_type == Clause.UPDATE:
+                update_set = self._ensure_and_normalize_assignments(
+                    clause.clause.assignments, other_df
+                )
+                expression = exp.When(
+                    matched=clause.clause.matched,
+                    source=clause.clause.by_source,
+                    condition=cond_clause.expression if cond_clause else None,
+                    then=exp.Update(
+                        expressions=[
+                            exp.EQ(
+                                this=key,
+                                expression=val,
+                            )
+                            for key, val in update_set.items()
+                        ]
+                    ),
+                )
+            if clause.clause.clause_type == Clause.UPDATE_ALL:
+                if not self._merge_support_star:
+                    raise ValueError("Merge operation does not support UPDATE_ALL")
+                expression = exp.When(
+                    matched=clause.clause.matched,
+                    source=clause.clause.by_source,
+                    condition=cond_clause.expression if cond_clause else None,
+                    then=exp.Update(expressions=[exp.Star()]),
+                )
+            elif clause.clause.clause_type == Clause.INSERT:
+                insert_values = self._ensure_and_normalize_assignments(
+                    clause.clause.assignments, other_df
+                )
+                expression = exp.When(
+                    matched=clause.clause.matched,
+                    source=clause.clause.by_source,
+                    condition=cond_clause.expression if cond_clause else None,
+                    then=exp.Insert(
+                        this=exp.Tuple(expressions=[key for key in insert_values.keys()]),
+                        expression=exp.Tuple(expressions=[val for val in insert_values.values()]),
+                    ),
+                )
+            elif clause.clause.clause_type == Clause.INSERT_ALL:
+                if not self._merge_support_star:
+                    raise ValueError("Merge operation does not support INSERT_ALL")
+                expression = exp.When(
+                    matched=clause.clause.matched,
+                    source=clause.clause.by_source,
+                    condition=cond_clause.expression if cond_clause else None,
+                    then=exp.Insert(expression=exp.Star()),
+                )
+            elif clause.clause.clause_type == Clause.DELETE:
+                expression = exp.When(
+                    matched=clause.clause.matched,
+                    source=clause.clause.by_source,
+                    condition=cond_clause.expression if cond_clause else None,
+                    then=exp.var("DELETE"),
+                )
+
+            if expression:
+                merge_expressions.append(expression)
+
+        if Whens is None:
+            merge_expr = exp.merge(
+                *merge_expressions,
+                into=self_expr,
+                using=other_expr,
+                on=condition_columns.expression,
+            )
+        else:
+            merge_expr = exp.merge(
+                Whens(expressions=merge_expressions),
+                into=self_expr,
+                using=other_expr,
+                on=condition_columns.expression,
+            )
+
+        return LazyExpression(merge_expr, self.session)
+
+    def _ensure_and_normalize_condition(
+        self,
+        condition: t.Union[str, t.List[str], Column, t.List[Column], bool],
+        other_df: DF,
+        clause: t.Optional[bool] = False,
+    ):
+        join_expression = self._add_ctes_to_expression(
+            self.expression, other_df.expression.copy().ctes
+        )
+        condition = self._ensure_and_normalize_cols(condition, self.expression)
+        self._handle_self_join(other_df, condition)
+
+        if isinstance(condition[0].expression, exp.Column) and not clause:
+            table_names = [
+                table.alias_or_name
+                for table in [
+                    self.expression.args["from"].this,
+                    other_df.expression.args["from"].this,
+                ]
+            ]
+
+            join_column_pairs, join_clause = self._handle_join_column_names_only(
+                condition, join_expression, other_df, table_names
+            )
+        else:
+            join_clause = self._normalize_join_clause(condition, join_expression)
+        return join_clause
+
+    def _ensure_and_normalize_assignments(
+        self,
+        assignments: t.Dict[
+            t.Union[Column, str], t.Union[Column, "ColumnOrLiteral", exp.Expression]
+        ],
+        other_df,
+    ) -> t.Dict[exp.Column, exp.Expression]:
+        self_name = self.expression.ctes[0].this.args["from"].this.alias_or_name
+        other_name = self._create_hash_from_expression(other_df.expression)
+        update_set = {}
+        for key, val in assignments.items():
+            key_column: Column = self._ensure_and_normalize_col(key)
+            key_expr = list(key_column.expression.find_all(exp.Column))
+            if len(key_expr) > 1:
+                raise ValueError(f"Target expression `{key_expr}` should be a single column.")
+            column_key = exp.column(key_expr[0].alias_or_name)
+
+            val = self._ensure_and_normalize_col(val)
+            val = self._ensure_and_normalize_cols(val, other_df.expression)[0]
+            if self.branch_id == other_df.branch_id:
+                other_df_unique_uuids = other_df.known_uuids - self.known_uuids
+                for col_expr in val.expression.find_all(exp.Column):
+                    if (
+                        "join_on_uuid" in col_expr.meta
+                        and col_expr.meta["join_on_uuid"] in other_df_unique_uuids
+                    ):
+                        col_expr.set("table", exp.to_identifier(other_df.latest_cte_name))
+
+            for col_expr in val.expression.find_all(exp.Column):
+                if not col_expr.table or col_expr.table == other_df.latest_cte_name:
+                    col_expr.set("table", exp.to_identifier(other_name))
+                elif col_expr.table == self.expression.args["from"].this.alias_or_name:
+                    col_expr.set("table", exp.to_identifier(self_name))
+                else:
+                    raise ValueError(
+                        f"Column `{col_expr.alias_or_name}` does not exist in any of the tables."
+                    )
+            if isinstance(val.expression, exp.Alias):
+                val.expression = val.expression.this
+            update_set[column_key] = val.expression
+        return update_set
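
The mixins above assemble UPDATE, DELETE, and MERGE statements as sqlglot expressions and hand them back wrapped in a LazyExpression rather than executing eagerly. A hedged usage sketch based on the signatures in this diff; the builder-style WhenMatched().update(...) / WhenNotMatched().insert(...) calls and LazyExpression.execute() are assumptions, since table.py itself is not shown here, and the table setup is illustrative:

from sqlframe.duckdb import DuckDBSession
from sqlframe.base.table import WhenMatched, WhenNotMatched

session = DuckDBSession()
session.createDataFrame([(1, "a"), (2, "b")], ["id", "name"]).write.saveAsTable("employees")
employees = session.table("employees")  # now returns a Table, not a plain DataFrame

# UPDATE employees SET name = 'x' WHERE id = 1
employees.update(set_={"name": "x"}, where=employees["id"] == 1).execute()

# DELETE FROM employees WHERE id = 2
employees.delete(where=employees["id"] == 2).execute()

# MERGE INTO employees USING <source> ON employees.id = source.id ...
source = session.createDataFrame([(1, "new"), (3, "c")], ["id", "name"])
employees.merge(
    source,
    condition=employees["id"] == source["id"],
    clauses=[
        WhenMatched().update(set_={"name": source["name"]}),
        WhenNotMatched().insert(values={"id": source["id"], "name": source["name"]}),
    ],
).execute()
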
sqlframe/base/readerwriter.py
CHANGED
@@ -21,19 +21,20 @@ else:
 if t.TYPE_CHECKING:
     from sqlframe.base._typing import OptionalPrimitiveType, PathOrPaths
     from sqlframe.base.column import Column
-    from sqlframe.base.session import DF, _BaseSession
+    from sqlframe.base.session import DF, TABLE, _BaseSession
     from sqlframe.base.types import StructType
 
     SESSION = t.TypeVar("SESSION", bound=_BaseSession)
 else:
     SESSION = t.TypeVar("SESSION")
     DF = t.TypeVar("DF")
+    TABLE = t.TypeVar("TABLE")
 
 
 logger = logging.getLogger(__name__)
 
 
-class _BaseDataFrameReader(t.Generic[SESSION, DF]):
+class _BaseDataFrameReader(t.Generic[SESSION, DF, TABLE]):
     def __init__(self, spark: SESSION):
         self._session = spark
         self.state_format_to_read: t.Optional[str] = None
@@ -42,7 +43,7 @@ class _BaseDataFrameReader(t.Generic[SESSION, DF]):
     def session(self) -> SESSION:
         return self._session
 
-    def table(self, tableName: str) -> DF:
+    def table(self, tableName: str) -> TABLE:
         tableName = normalize_string(tableName, from_dialect="input", is_table=True)
         if df := self.session.temp_views.get(tableName):
             return df
@@ -50,7 +51,7 @@ class _BaseDataFrameReader(t.Generic[SESSION, DF]):
         self.session.catalog.add_table(table)
         columns = self.session.catalog.get_columns_from_schema(table)
 
-        return self.session._create_df(
+        return self.session._create_table(
            exp.Select()
            .from_(tableName, dialect=self.session.input_dialect)
            .select(*columns, dialect=self.session.input_dialect)
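
With the reader now generic over SESSION, DF, and TABLE, each dialect's readwriter.py (the +2/-1 changes in the file list above) only needs to bind the third parameter. A sketch of the shape with assumed class names:

from sqlframe.base.readerwriter import _BaseDataFrameReader

# Hypothetical concrete reader; the real per-dialect classes follow this pattern.
class DuckDBDataFrameReader(
    _BaseDataFrameReader["DuckDBSession", "DuckDBDataFrame", "DuckDBTable"]
):
    pass  # inherits table(), which is now typed to return a DuckDBTable
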
sqlframe/base/session.py
CHANGED
@@ -27,6 +27,7 @@ from sqlframe.base.catalog import _BaseCatalog
 from sqlframe.base.dataframe import BaseDataFrame
 from sqlframe.base.normalize import normalize_dict
 from sqlframe.base.readerwriter import _BaseDataFrameReader, _BaseDataFrameWriter
+from sqlframe.base.table import _BaseTable
 from sqlframe.base.udf import _BaseUDFRegistration
 from sqlframe.base.util import (
     get_column_mapping_from_schema_input,
@@ -65,17 +66,19 @@ CATALOG = t.TypeVar("CATALOG", bound=_BaseCatalog)
 READER = t.TypeVar("READER", bound=_BaseDataFrameReader)
 WRITER = t.TypeVar("WRITER", bound=_BaseDataFrameWriter)
 DF = t.TypeVar("DF", bound=BaseDataFrame)
+TABLE = t.TypeVar("TABLE", bound=_BaseTable)
 UDF_REGISTRATION = t.TypeVar("UDF_REGISTRATION", bound=_BaseUDFRegistration)
 
 _MISSING = "MISSING"
 
 
-class _BaseSession(t.Generic[CATALOG, READER, WRITER, DF, CONN, UDF_REGISTRATION]):
+class _BaseSession(t.Generic[CATALOG, READER, WRITER, DF, TABLE, CONN, UDF_REGISTRATION]):
     _instance = None
     _reader: t.Type[READER]
     _writer: t.Type[WRITER]
     _catalog: t.Type[CATALOG]
     _df: t.Type[DF]
+    _table: t.Type[TABLE]
     _udf_registration: t.Type[UDF_REGISTRATION]
 
     SANITIZE_COLUMN_NAMES = False
@@ -158,12 +161,15 @@ class _BaseSession(t.Generic[CATALOG, READER, WRITER, DF, CONN, UDF_REGISTRATION]):
             return name.replace("(", "_").replace(")", "_")
         return name
 
-    def table(self, tableName: str) -> DF:
+    def table(self, tableName: str) -> TABLE:
         return self.read.table(tableName)
 
     def _create_df(self, *args, **kwargs) -> DF:
         return self._df(self, *args, **kwargs)
 
+    def _create_table(self, *args, **kwargs) -> TABLE:
+        return self._table(self, *args, **kwargs)
+
     def __new__(cls, *args, **kwargs):
         if _BaseSession._instance is None:
             _BaseSession._instance = super().__new__(cls)