sqlframe 3.13.4__py3-none-any.whl → 3.14.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sqlframe/_version.py +2 -2
- sqlframe/base/dataframe.py +102 -61
- sqlframe/base/mixins/table_mixins.py +335 -0
- sqlframe/base/readerwriter.py +5 -4
- sqlframe/base/session.py +8 -2
- sqlframe/base/table.py +238 -0
- sqlframe/bigquery/catalog.py +1 -0
- sqlframe/bigquery/readwriter.py +2 -1
- sqlframe/bigquery/session.py +3 -0
- sqlframe/bigquery/table.py +24 -0
- sqlframe/databricks/readwriter.py +2 -1
- sqlframe/databricks/session.py +3 -0
- sqlframe/databricks/table.py +24 -0
- sqlframe/duckdb/readwriter.py +4 -1
- sqlframe/duckdb/session.py +3 -0
- sqlframe/duckdb/table.py +16 -0
- sqlframe/postgres/readwriter.py +2 -1
- sqlframe/postgres/session.py +3 -0
- sqlframe/postgres/table.py +24 -0
- sqlframe/redshift/readwriter.py +2 -1
- sqlframe/redshift/session.py +3 -0
- sqlframe/redshift/table.py +15 -0
- sqlframe/snowflake/readwriter.py +2 -1
- sqlframe/snowflake/session.py +3 -0
- sqlframe/snowflake/table.py +23 -0
- sqlframe/spark/readwriter.py +2 -1
- sqlframe/spark/session.py +3 -0
- sqlframe/spark/table.py +6 -0
- sqlframe/standalone/readwriter.py +4 -1
- sqlframe/standalone/session.py +3 -0
- sqlframe/standalone/table.py +6 -0
- {sqlframe-3.13.4.dist-info → sqlframe-3.14.1.dist-info}/METADATA +4 -4
- {sqlframe-3.13.4.dist-info → sqlframe-3.14.1.dist-info}/RECORD +36 -26
- {sqlframe-3.13.4.dist-info → sqlframe-3.14.1.dist-info}/LICENSE +0 -0
- {sqlframe-3.13.4.dist-info → sqlframe-3.14.1.dist-info}/WHEEL +0 -0
- {sqlframe-3.13.4.dist-info → sqlframe-3.14.1.dist-info}/top_level.txt +0 -0
sqlframe/_version.py
CHANGED
sqlframe/base/dataframe.py
CHANGED
|
@@ -79,6 +79,23 @@ JOIN_HINTS = {
|
|
|
79
79
|
"SHUFFLE_REPLICATE_NL",
|
|
80
80
|
}
|
|
81
81
|
|
|
82
|
+
JOIN_TYPE_MAPPING = {
|
|
83
|
+
"inner": "inner",
|
|
84
|
+
"cross": "cross",
|
|
85
|
+
"outer": "full_outer",
|
|
86
|
+
"full": "full_outer",
|
|
87
|
+
"fullouter": "full_outer",
|
|
88
|
+
"left": "left_outer",
|
|
89
|
+
"leftouter": "left_outer",
|
|
90
|
+
"right": "right_outer",
|
|
91
|
+
"rightouter": "right_outer",
|
|
92
|
+
"semi": "left_semi",
|
|
93
|
+
"leftsemi": "left_semi",
|
|
94
|
+
"left_semi": "left_semi",
|
|
95
|
+
"anti": "left_anti",
|
|
96
|
+
"leftanti": "left_anti",
|
|
97
|
+
"left_anti": "left_anti",
|
|
98
|
+
}
|
|
82
99
|
|
|
83
100
|
DF = t.TypeVar("DF", bound="BaseDataFrame")
|
|
84
101
|
|
|
@@ -872,6 +889,68 @@ class BaseDataFrame(t.Generic[SESSION, WRITER, NA, STAT, GROUP_DATA]):
|
|
|
872
889
|
"""
|
|
873
890
|
return self.join.__wrapped__(self, other, how="cross") # type: ignore
|
|
874
891
|
|
|
892
|
+
def _handle_self_join(self, other_df: DF, join_columns: t.List[Column]):
|
|
893
|
+
# If the two dataframes being joined come from the same branch, we then check if they have any columns that
|
|
894
|
+
# were created using the "branch_id" (df["column_name"]). If so, we know that we need to differentiate
|
|
895
|
+
# the two columns since they would end up with the same table name. We do this by checking for the unique
|
|
896
|
+
# uuids in the other df and finding columns that have metadata on them that match the uuids. If so, we know
|
|
897
|
+
# it comes from the other df and we change the table name to the other df's table name.
|
|
898
|
+
# See `test_self_join` for an example of this.
|
|
899
|
+
if self.branch_id == other_df.branch_id:
|
|
900
|
+
other_df_unique_uuids = other_df.known_uuids - self.known_uuids
|
|
901
|
+
for col in join_columns:
|
|
902
|
+
for col_expr in col.expression.find_all(exp.Column):
|
|
903
|
+
if (
|
|
904
|
+
"join_on_uuid" in col_expr.meta
|
|
905
|
+
and col_expr.meta["join_on_uuid"] in other_df_unique_uuids
|
|
906
|
+
):
|
|
907
|
+
col_expr.set("table", exp.to_identifier(other_df.latest_cte_name))
|
|
908
|
+
|
|
909
|
+
@staticmethod
|
|
910
|
+
def _handle_join_column_names_only(
|
|
911
|
+
join_columns: t.List[Column],
|
|
912
|
+
join_expression: exp.Select,
|
|
913
|
+
other_df: DF,
|
|
914
|
+
table_names: t.List[str],
|
|
915
|
+
):
|
|
916
|
+
potential_ctes = [
|
|
917
|
+
cte
|
|
918
|
+
for cte in join_expression.ctes
|
|
919
|
+
if cte.alias_or_name in table_names and cte.alias_or_name != other_df.latest_cte_name
|
|
920
|
+
]
|
|
921
|
+
# Determine the table to reference for the left side of the join by checking each of the left side
|
|
922
|
+
# tables and see if they have the column being referenced.
|
|
923
|
+
join_column_pairs = []
|
|
924
|
+
for join_column in join_columns:
|
|
925
|
+
num_matching_ctes = 0
|
|
926
|
+
for cte in potential_ctes:
|
|
927
|
+
if join_column.alias_or_name in cte.this.named_selects:
|
|
928
|
+
left_column = join_column.copy().set_table_name(cte.alias_or_name)
|
|
929
|
+
right_column = join_column.copy().set_table_name(other_df.latest_cte_name)
|
|
930
|
+
join_column_pairs.append((left_column, right_column))
|
|
931
|
+
num_matching_ctes += 1
|
|
932
|
+
# We only want to match one table to the column and that should be matched left -> right
|
|
933
|
+
# so we break after the first match
|
|
934
|
+
break
|
|
935
|
+
if num_matching_ctes == 0:
|
|
936
|
+
raise ValueError(
|
|
937
|
+
f"Column `{join_column.alias_or_name}` does not exist in any of the tables."
|
|
938
|
+
)
|
|
939
|
+
join_clause = functools.reduce(
|
|
940
|
+
lambda x, y: x & y,
|
|
941
|
+
[left_column == right_column for left_column, right_column in join_column_pairs],
|
|
942
|
+
)
|
|
943
|
+
return join_column_pairs, join_clause
|
|
944
|
+
|
|
945
|
+
def _normalize_join_clause(
|
|
946
|
+
self, join_columns: t.List[Column], join_expression: t.Optional[exp.Select]
|
|
947
|
+
) -> Column:
|
|
948
|
+
join_columns = self._ensure_and_normalize_cols(join_columns, join_expression)
|
|
949
|
+
if len(join_columns) > 1:
|
|
950
|
+
join_columns = [functools.reduce(lambda x, y: x & y, join_columns)]
|
|
951
|
+
join_clause = join_columns[0]
|
|
952
|
+
return join_clause
|
|
953
|
+
|
|
875
954
|
@operation(Operation.FROM)
|
|
876
955
|
def join(
|
|
877
956
|
self,
|
|
@@ -882,37 +961,33 @@ class BaseDataFrame(t.Generic[SESSION, WRITER, NA, STAT, GROUP_DATA]):
|
|
|
882
961
|
) -> Self:
|
|
883
962
|
from sqlframe.base.functions import coalesce
|
|
884
963
|
|
|
885
|
-
if on is None:
|
|
964
|
+
if (on is None) and ("cross" not in how):
|
|
886
965
|
logger.warning("Got no value for on. This appears to change the join to a cross join.")
|
|
887
966
|
how = "cross"
|
|
967
|
+
if (on is not None) and ("cross" in how):
|
|
968
|
+
# Not a lot of doc, but Spark handles cross with predicate as an inner join
|
|
969
|
+
# https://learn.microsoft.com/en-us/dotnet/api/microsoft.spark.sql.dataframe.join
|
|
970
|
+
logger.warning("Got cross join with an 'on' value. This will result in an inner join.")
|
|
971
|
+
how = "inner"
|
|
888
972
|
|
|
889
973
|
other_df = other_df._convert_leaf_to_cte()
|
|
890
974
|
join_expression = self._add_ctes_to_expression(self.expression, other_df.expression.ctes)
|
|
891
975
|
# We will determine actual "join on" expression later so we don't provide it at first
|
|
892
|
-
|
|
893
|
-
|
|
894
|
-
)
|
|
976
|
+
join_type = JOIN_TYPE_MAPPING.get(how, how).replace("_", " ")
|
|
977
|
+
join_expression = join_expression.join(join_expression.ctes[-1].alias, join_type=join_type)
|
|
895
978
|
self_columns = self._get_outer_select_columns(join_expression)
|
|
896
979
|
other_columns = self._get_outer_select_columns(other_df.expression)
|
|
897
980
|
join_columns = self._ensure_and_normalize_cols(on)
|
|
898
|
-
|
|
899
|
-
|
|
900
|
-
# the two columns since they would end up with the same table name. We do this by checking for the unique
|
|
901
|
-
# uuids in the other df and finding columns that have metadata on them that match the uuids. If so, we know
|
|
902
|
-
# it comes from the other df and we change the table name to the other df's table name.
|
|
903
|
-
# See `test_self_join` for an example of this.
|
|
904
|
-
if self.branch_id == other_df.branch_id:
|
|
905
|
-
other_df_unique_uuids = other_df.known_uuids - self.known_uuids
|
|
906
|
-
for col in join_columns:
|
|
907
|
-
for col_expr in col.expression.find_all(exp.Column):
|
|
908
|
-
if (
|
|
909
|
-
"join_on_uuid" in col_expr.meta
|
|
910
|
-
and col_expr.meta["join_on_uuid"] in other_df_unique_uuids
|
|
911
|
-
):
|
|
912
|
-
col_expr.set("table", exp.to_identifier(other_df.latest_cte_name))
|
|
981
|
+
self._handle_self_join(other_df, join_columns)
|
|
982
|
+
|
|
913
983
|
# Determines the join clause and select columns to be used passed on what type of columns were provided for
|
|
914
984
|
# the join. The columns returned changes based on how the on expression is provided.
|
|
915
|
-
|
|
985
|
+
select_columns = (
|
|
986
|
+
self_columns
|
|
987
|
+
if join_type in ["left anti", "left semi"]
|
|
988
|
+
else self_columns + other_columns
|
|
989
|
+
)
|
|
990
|
+
if join_type != "cross":
|
|
916
991
|
if isinstance(join_columns[0].expression, exp.Column):
|
|
917
992
|
"""
|
|
918
993
|
Unique characteristics of join on column names only:
|
|
@@ -923,38 +998,9 @@ class BaseDataFrame(t.Generic[SESSION, WRITER, NA, STAT, GROUP_DATA]):
|
|
|
923
998
|
table.alias_or_name
|
|
924
999
|
for table in get_tables_from_expression_with_join(join_expression)
|
|
925
1000
|
]
|
|
926
|
-
|
|
927
|
-
|
|
928
|
-
|
|
929
|
-
if cte.alias_or_name in table_names
|
|
930
|
-
and cte.alias_or_name != other_df.latest_cte_name
|
|
931
|
-
]
|
|
932
|
-
# Determine the table to reference for the left side of the join by checking each of the left side
|
|
933
|
-
# tables and see if they have the column being referenced.
|
|
934
|
-
join_column_pairs = []
|
|
935
|
-
for join_column in join_columns:
|
|
936
|
-
num_matching_ctes = 0
|
|
937
|
-
for cte in potential_ctes:
|
|
938
|
-
if join_column.alias_or_name in cte.this.named_selects:
|
|
939
|
-
left_column = join_column.copy().set_table_name(cte.alias_or_name)
|
|
940
|
-
right_column = join_column.copy().set_table_name(
|
|
941
|
-
other_df.latest_cte_name
|
|
942
|
-
)
|
|
943
|
-
join_column_pairs.append((left_column, right_column))
|
|
944
|
-
num_matching_ctes += 1
|
|
945
|
-
# We only want to match one table to the column and that should be matched left -> right
|
|
946
|
-
# so we break after the first match
|
|
947
|
-
break
|
|
948
|
-
if num_matching_ctes == 0:
|
|
949
|
-
raise ValueError(
|
|
950
|
-
f"Column `{join_column.alias_or_name}` does not exist in any of the tables."
|
|
951
|
-
)
|
|
952
|
-
join_clause = functools.reduce(
|
|
953
|
-
lambda x, y: x & y,
|
|
954
|
-
[
|
|
955
|
-
left_column == right_column
|
|
956
|
-
for left_column, right_column in join_column_pairs
|
|
957
|
-
],
|
|
1001
|
+
|
|
1002
|
+
join_column_pairs, join_clause = self._handle_join_column_names_only(
|
|
1003
|
+
join_columns, join_expression, other_df, table_names
|
|
958
1004
|
)
|
|
959
1005
|
join_column_names = [
|
|
960
1006
|
coalesce(
|
|
@@ -972,7 +1018,7 @@ class BaseDataFrame(t.Generic[SESSION, WRITER, NA, STAT, GROUP_DATA]):
|
|
|
972
1018
|
if not isinstance(column.expression.this, exp.Star)
|
|
973
1019
|
else column.sql()
|
|
974
1020
|
)
|
|
975
|
-
for column in
|
|
1021
|
+
for column in select_columns
|
|
976
1022
|
]
|
|
977
1023
|
select_column_names = [
|
|
978
1024
|
column_name
|
|
@@ -989,17 +1035,12 @@ class BaseDataFrame(t.Generic[SESSION, WRITER, NA, STAT, GROUP_DATA]):
|
|
|
989
1035
|
* There is no deduplication of the results.
|
|
990
1036
|
* The left join dataframe columns go first and right come after. No sort preference is given to join columns
|
|
991
1037
|
"""
|
|
992
|
-
|
|
993
|
-
|
|
994
|
-
join_columns = [functools.reduce(lambda x, y: x & y, join_columns)]
|
|
995
|
-
join_clause = join_columns[0]
|
|
996
|
-
select_column_names = [
|
|
997
|
-
column.alias_or_name for column in self_columns + other_columns
|
|
998
|
-
]
|
|
1038
|
+
join_clause = self._normalize_join_clause(join_columns, join_expression)
|
|
1039
|
+
select_column_names = [column.alias_or_name for column in select_columns]
|
|
999
1040
|
|
|
1000
1041
|
# Update the on expression with the actual join clause to replace the dummy one from before
|
|
1001
1042
|
else:
|
|
1002
|
-
select_column_names = [column.alias_or_name for column in
|
|
1043
|
+
select_column_names = [column.alias_or_name for column in select_columns]
|
|
1003
1044
|
join_clause = None
|
|
1004
1045
|
join_expression.args["joins"][-1].set("on", join_clause.expression if join_clause else None)
|
|
1005
1046
|
new_df = self.copy(expression=join_expression)
|
|
@@ -0,0 +1,335 @@
|
|
|
1
|
+
import functools
|
|
2
|
+
import logging
|
|
3
|
+
import typing as t
|
|
4
|
+
|
|
5
|
+
from sqlglot import exp
|
|
6
|
+
|
|
7
|
+
try:
|
|
8
|
+
from sqlglot.expressions import Whens
|
|
9
|
+
except ImportError:
|
|
10
|
+
Whens = None # type: ignore
|
|
11
|
+
from sqlglot.helper import object_to_dict
|
|
12
|
+
|
|
13
|
+
from sqlframe.base.column import Column
|
|
14
|
+
from sqlframe.base.table import (
|
|
15
|
+
DF,
|
|
16
|
+
Clause,
|
|
17
|
+
LazyExpression,
|
|
18
|
+
WhenMatched,
|
|
19
|
+
WhenNotMatched,
|
|
20
|
+
WhenNotMatchedBySource,
|
|
21
|
+
_BaseTable,
|
|
22
|
+
)
|
|
23
|
+
|
|
24
|
+
if t.TYPE_CHECKING:
|
|
25
|
+
from sqlframe.base._typing import ColumnOrLiteral
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
logger = logging.getLogger(__name__)
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
def ensure_cte() -> t.Callable[[t.Callable], t.Callable]:
|
|
32
|
+
def decorator(func: t.Callable) -> t.Callable:
|
|
33
|
+
@functools.wraps(func)
|
|
34
|
+
def wrapper(self: _BaseTable, *args, **kwargs) -> t.Any:
|
|
35
|
+
if len(self.expression.ctes) > 0:
|
|
36
|
+
return func(self, *args, **kwargs) # type: ignore
|
|
37
|
+
self_class = self.__class__
|
|
38
|
+
self = self._convert_leaf_to_cte()
|
|
39
|
+
self = self_class(**object_to_dict(self))
|
|
40
|
+
return func(self, *args, **kwargs) # type: ignore
|
|
41
|
+
|
|
42
|
+
wrapper.__wrapped__ = func # type: ignore
|
|
43
|
+
return wrapper
|
|
44
|
+
|
|
45
|
+
return decorator
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
class _BaseTableMixins(_BaseTable, t.Generic[DF]):
|
|
49
|
+
def _ensure_where_condition(
|
|
50
|
+
self, where: t.Optional[t.Union[Column, str, bool]] = None
|
|
51
|
+
) -> exp.Expression:
|
|
52
|
+
self_name = self.expression.ctes[0].this.args["from"].this.alias_or_name
|
|
53
|
+
|
|
54
|
+
if where is None:
|
|
55
|
+
logger.warning("Empty value for `where`clause. Defaults to `True`.")
|
|
56
|
+
condition: exp.Expression = exp.Boolean(this=True)
|
|
57
|
+
else:
|
|
58
|
+
condition_list = self._ensure_and_normalize_cols(where, self.expression)
|
|
59
|
+
if len(condition_list) > 1:
|
|
60
|
+
condition_list = [functools.reduce(lambda x, y: x & y, condition_list)]
|
|
61
|
+
for col_expr in condition_list[0].expression.find_all(exp.Column):
|
|
62
|
+
if col_expr.table == self.expression.args["from"].this.alias_or_name:
|
|
63
|
+
col_expr.set("table", exp.to_identifier(self_name))
|
|
64
|
+
condition = condition_list[0].expression
|
|
65
|
+
if isinstance(condition, exp.Alias):
|
|
66
|
+
condition = condition.this
|
|
67
|
+
return condition
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
class UpdateSupportMixin(_BaseTableMixins, t.Generic[DF]):
|
|
71
|
+
@ensure_cte()
|
|
72
|
+
def update(
|
|
73
|
+
self,
|
|
74
|
+
set_: t.Dict[t.Union[Column, str], t.Union[Column, "ColumnOrLiteral", exp.Expression]],
|
|
75
|
+
where: t.Optional[t.Union[Column, str, bool]] = None,
|
|
76
|
+
) -> LazyExpression:
|
|
77
|
+
self_expr = self.expression.ctes[0].this.args["from"].this
|
|
78
|
+
|
|
79
|
+
condition = self._ensure_where_condition(where)
|
|
80
|
+
update_set = self._ensure_and_normalize_update_set(set_)
|
|
81
|
+
update_expr = exp.Update(
|
|
82
|
+
this=self_expr,
|
|
83
|
+
expressions=[
|
|
84
|
+
exp.EQ(
|
|
85
|
+
this=key,
|
|
86
|
+
expression=val,
|
|
87
|
+
)
|
|
88
|
+
for key, val in update_set.items()
|
|
89
|
+
],
|
|
90
|
+
where=exp.Where(this=condition),
|
|
91
|
+
)
|
|
92
|
+
|
|
93
|
+
return LazyExpression(update_expr, self.session)
|
|
94
|
+
|
|
95
|
+
def _ensure_and_normalize_update_set(
|
|
96
|
+
self,
|
|
97
|
+
set_: t.Dict[t.Union[Column, str], t.Union[Column, "ColumnOrLiteral", exp.Expression]],
|
|
98
|
+
) -> t.Dict[str, exp.Expression]:
|
|
99
|
+
self_name = self.expression.ctes[0].this.args["from"].this.alias_or_name
|
|
100
|
+
update_set = {}
|
|
101
|
+
for key, val in set_.items():
|
|
102
|
+
key_column: Column = self._ensure_and_normalize_col(key)
|
|
103
|
+
key_expr = list(key_column.expression.find_all(exp.Column))
|
|
104
|
+
if len(key_expr) > 1:
|
|
105
|
+
raise ValueError(f"Can only update one a single column at a time.")
|
|
106
|
+
key = key_expr[0].alias_or_name
|
|
107
|
+
|
|
108
|
+
val_column: Column = self._ensure_and_normalize_col(val)
|
|
109
|
+
for col_expr in val_column.expression.find_all(exp.Column):
|
|
110
|
+
if col_expr.table == self.expression.args["from"].this.alias_or_name:
|
|
111
|
+
col_expr.set("table", exp.to_identifier(self_name))
|
|
112
|
+
else:
|
|
113
|
+
raise ValueError(
|
|
114
|
+
f"Column `{col_expr.alias_or_name}` does not exist in the table."
|
|
115
|
+
)
|
|
116
|
+
|
|
117
|
+
update_set[key] = val_column.expression
|
|
118
|
+
return update_set
|
|
119
|
+
|
|
120
|
+
|
|
121
|
+
class DeleteSupportMixin(_BaseTableMixins, t.Generic[DF]):
|
|
122
|
+
@ensure_cte()
|
|
123
|
+
def delete(
|
|
124
|
+
self,
|
|
125
|
+
where: t.Optional[t.Union[Column, str, bool]] = None,
|
|
126
|
+
) -> LazyExpression:
|
|
127
|
+
self_expr = self.expression.ctes[0].this.args["from"].this
|
|
128
|
+
|
|
129
|
+
condition = self._ensure_where_condition(where)
|
|
130
|
+
delete_expr = exp.Delete(
|
|
131
|
+
this=self_expr,
|
|
132
|
+
where=exp.Where(this=condition),
|
|
133
|
+
)
|
|
134
|
+
|
|
135
|
+
return LazyExpression(delete_expr, self.session)
|
|
136
|
+
|
|
137
|
+
|
|
138
|
+
class MergeSupportMixin(_BaseTable, t.Generic[DF]):
|
|
139
|
+
_merge_supported_clauses: t.Iterable[
|
|
140
|
+
t.Union[t.Type[WhenMatched], t.Type[WhenNotMatched], t.Type[WhenNotMatchedBySource]]
|
|
141
|
+
]
|
|
142
|
+
_merge_support_star: bool
|
|
143
|
+
|
|
144
|
+
@ensure_cte()
|
|
145
|
+
def merge(
|
|
146
|
+
self,
|
|
147
|
+
other_df: DF,
|
|
148
|
+
condition: t.Union[str, t.List[str], Column, t.List[Column], bool],
|
|
149
|
+
clauses: t.Iterable[t.Union[WhenMatched, WhenNotMatched, WhenNotMatchedBySource]],
|
|
150
|
+
) -> LazyExpression:
|
|
151
|
+
self_name = self.expression.ctes[0].this.args["from"].this.alias_or_name
|
|
152
|
+
self_expr = self.expression.ctes[0].this.args["from"].this
|
|
153
|
+
|
|
154
|
+
other_df = other_df._convert_leaf_to_cte()
|
|
155
|
+
|
|
156
|
+
if condition is None:
|
|
157
|
+
raise ValueError("condition cannot be None")
|
|
158
|
+
|
|
159
|
+
condition_columns: Column = self._ensure_and_normalize_condition(condition, other_df)
|
|
160
|
+
other_name = self._create_hash_from_expression(other_df.expression)
|
|
161
|
+
other_expr = exp.Subquery(
|
|
162
|
+
this=other_df.expression, alias=exp.TableAlias(this=exp.to_identifier(other_name))
|
|
163
|
+
)
|
|
164
|
+
|
|
165
|
+
for col_expr in condition_columns.expression.find_all(exp.Column):
|
|
166
|
+
if col_expr.table == self.expression.args["from"].this.alias_or_name:
|
|
167
|
+
col_expr.set("table", exp.to_identifier(self_name))
|
|
168
|
+
if col_expr.table == other_df.latest_cte_name:
|
|
169
|
+
col_expr.set("table", exp.to_identifier(other_name))
|
|
170
|
+
|
|
171
|
+
merge_expressions = []
|
|
172
|
+
for clause in clauses:
|
|
173
|
+
if not isinstance(clause, tuple(self._merge_supported_clauses)):
|
|
174
|
+
raise ValueError(
|
|
175
|
+
f"Unsupported clause type {type(clause.clause)} for merge operation"
|
|
176
|
+
)
|
|
177
|
+
expression = None
|
|
178
|
+
|
|
179
|
+
if clause.clause.condition is not None:
|
|
180
|
+
cond_clause = self._ensure_and_normalize_condition(
|
|
181
|
+
clause.clause.condition, other_df, True
|
|
182
|
+
)
|
|
183
|
+
for col_expr in cond_clause.expression.find_all(exp.Column):
|
|
184
|
+
if col_expr.table == self.expression.args["from"].this.alias_or_name:
|
|
185
|
+
col_expr.set("table", exp.to_identifier(self_name))
|
|
186
|
+
if col_expr.table == other_df.latest_cte_name:
|
|
187
|
+
col_expr.set("table", exp.to_identifier(other_name))
|
|
188
|
+
else:
|
|
189
|
+
cond_clause = None
|
|
190
|
+
if clause.clause.clause_type == Clause.UPDATE:
|
|
191
|
+
update_set = self._ensure_and_normalize_assignments(
|
|
192
|
+
clause.clause.assignments, other_df
|
|
193
|
+
)
|
|
194
|
+
expression = exp.When(
|
|
195
|
+
matched=clause.clause.matched,
|
|
196
|
+
source=clause.clause.by_source,
|
|
197
|
+
condition=cond_clause.expression if cond_clause else None,
|
|
198
|
+
then=exp.Update(
|
|
199
|
+
expressions=[
|
|
200
|
+
exp.EQ(
|
|
201
|
+
this=key,
|
|
202
|
+
expression=val,
|
|
203
|
+
)
|
|
204
|
+
for key, val in update_set.items()
|
|
205
|
+
]
|
|
206
|
+
),
|
|
207
|
+
)
|
|
208
|
+
if clause.clause.clause_type == Clause.UPDATE_ALL:
|
|
209
|
+
if not self._support_star:
|
|
210
|
+
raise ValueError("Merge operation does not support UPDATE_ALL")
|
|
211
|
+
expression = exp.When(
|
|
212
|
+
matched=clause.clause.matched,
|
|
213
|
+
source=clause.clause.by_source,
|
|
214
|
+
condition=cond_clause.expression if cond_clause else None,
|
|
215
|
+
then=exp.Update(expressions=[exp.Star()]),
|
|
216
|
+
)
|
|
217
|
+
elif clause.clause.clause_type == Clause.INSERT:
|
|
218
|
+
insert_values = self._ensure_and_normalize_assignments(
|
|
219
|
+
clause.clause.assignments, other_df
|
|
220
|
+
)
|
|
221
|
+
expression = exp.When(
|
|
222
|
+
matched=clause.clause.matched,
|
|
223
|
+
source=clause.clause.by_source,
|
|
224
|
+
condition=cond_clause.expression if cond_clause else None,
|
|
225
|
+
then=exp.Insert(
|
|
226
|
+
this=exp.Tuple(expressions=[key for key in insert_values.keys()]),
|
|
227
|
+
expression=exp.Tuple(expressions=[val for val in insert_values.values()]),
|
|
228
|
+
),
|
|
229
|
+
)
|
|
230
|
+
elif clause.clause.clause_type == Clause.INSERT_ALL:
|
|
231
|
+
if not self._support_star:
|
|
232
|
+
raise ValueError("Merge operation does not support INSERT_ALL")
|
|
233
|
+
expression = exp.When(
|
|
234
|
+
matched=clause.clause.matched,
|
|
235
|
+
source=clause.clause.by_source,
|
|
236
|
+
condition=cond_clause.expression if cond_clause else None,
|
|
237
|
+
then=exp.Insert(expression=exp.Star()),
|
|
238
|
+
)
|
|
239
|
+
elif clause.clause.clause_type == Clause.DELETE:
|
|
240
|
+
expression = exp.When(
|
|
241
|
+
matched=clause.clause.matched,
|
|
242
|
+
source=clause.clause.by_source,
|
|
243
|
+
condition=cond_clause.expression if cond_clause else None,
|
|
244
|
+
then=exp.var("DELETE"),
|
|
245
|
+
)
|
|
246
|
+
|
|
247
|
+
if expression:
|
|
248
|
+
merge_expressions.append(expression)
|
|
249
|
+
|
|
250
|
+
if Whens is None:
|
|
251
|
+
merge_expr = exp.merge(
|
|
252
|
+
*merge_expressions,
|
|
253
|
+
into=self_expr,
|
|
254
|
+
using=other_expr,
|
|
255
|
+
on=condition_columns.expression,
|
|
256
|
+
)
|
|
257
|
+
else:
|
|
258
|
+
merge_expr = exp.merge(
|
|
259
|
+
Whens(expressions=merge_expressions),
|
|
260
|
+
into=self_expr,
|
|
261
|
+
using=other_expr,
|
|
262
|
+
on=condition_columns.expression,
|
|
263
|
+
)
|
|
264
|
+
|
|
265
|
+
return LazyExpression(merge_expr, self.session)
|
|
266
|
+
|
|
267
|
+
def _ensure_and_normalize_condition(
|
|
268
|
+
self,
|
|
269
|
+
condition: t.Union[str, t.List[str], Column, t.List[Column], bool],
|
|
270
|
+
other_df: DF,
|
|
271
|
+
clause: t.Optional[bool] = False,
|
|
272
|
+
):
|
|
273
|
+
join_expression = self._add_ctes_to_expression(
|
|
274
|
+
self.expression, other_df.expression.copy().ctes
|
|
275
|
+
)
|
|
276
|
+
condition = self._ensure_and_normalize_cols(condition, self.expression)
|
|
277
|
+
self._handle_self_join(other_df, condition)
|
|
278
|
+
|
|
279
|
+
if isinstance(condition[0].expression, exp.Column) and not clause:
|
|
280
|
+
table_names = [
|
|
281
|
+
table.alias_or_name
|
|
282
|
+
for table in [
|
|
283
|
+
self.expression.args["from"].this,
|
|
284
|
+
other_df.expression.args["from"].this,
|
|
285
|
+
]
|
|
286
|
+
]
|
|
287
|
+
|
|
288
|
+
join_column_pairs, join_clause = self._handle_join_column_names_only(
|
|
289
|
+
condition, join_expression, other_df, table_names
|
|
290
|
+
)
|
|
291
|
+
else:
|
|
292
|
+
join_clause = self._normalize_join_clause(condition, join_expression)
|
|
293
|
+
return join_clause
|
|
294
|
+
|
|
295
|
+
def _ensure_and_normalize_assignments(
|
|
296
|
+
self,
|
|
297
|
+
assignments: t.Dict[
|
|
298
|
+
t.Union[Column, str], t.Union[Column, "ColumnOrLiteral", exp.Expression]
|
|
299
|
+
],
|
|
300
|
+
other_df,
|
|
301
|
+
) -> t.Dict[exp.Column, exp.Expression]:
|
|
302
|
+
self_name = self.expression.ctes[0].this.args["from"].this.alias_or_name
|
|
303
|
+
other_name = self._create_hash_from_expression(other_df.expression)
|
|
304
|
+
update_set = {}
|
|
305
|
+
for key, val in assignments.items():
|
|
306
|
+
key_column: Column = self._ensure_and_normalize_col(key)
|
|
307
|
+
key_expr = list(key_column.expression.find_all(exp.Column))
|
|
308
|
+
if len(key_expr) > 1:
|
|
309
|
+
raise ValueError(f"Target expression `{key_expr}` should be a single column.")
|
|
310
|
+
column_key = exp.column(key_expr[0].alias_or_name)
|
|
311
|
+
|
|
312
|
+
val = self._ensure_and_normalize_col(val)
|
|
313
|
+
val = self._ensure_and_normalize_cols(val, other_df.expression)[0]
|
|
314
|
+
if self.branch_id == other_df.branch_id:
|
|
315
|
+
other_df_unique_uuids = other_df.known_uuids - self.known_uuids
|
|
316
|
+
for col_expr in val.expression.find_all(exp.Column):
|
|
317
|
+
if (
|
|
318
|
+
"join_on_uuid" in col_expr.meta
|
|
319
|
+
and col_expr.meta["join_on_uuid"] in other_df_unique_uuids
|
|
320
|
+
):
|
|
321
|
+
col_expr.set("table", exp.to_identifier(other_df.latest_cte_name))
|
|
322
|
+
|
|
323
|
+
for col_expr in val.expression.find_all(exp.Column):
|
|
324
|
+
if not col_expr.table or col_expr.table == other_df.latest_cte_name:
|
|
325
|
+
col_expr.set("table", exp.to_identifier(other_name))
|
|
326
|
+
elif col_expr.table == self.expression.args["from"].this.alias_or_name:
|
|
327
|
+
col_expr.set("table", exp.to_identifier(self_name))
|
|
328
|
+
else:
|
|
329
|
+
raise ValueError(
|
|
330
|
+
f"Column `{col_expr.alias_or_name}` does not exist in any of the tables."
|
|
331
|
+
)
|
|
332
|
+
if isinstance(val.expression, exp.Alias):
|
|
333
|
+
val.expression = val.expression.this
|
|
334
|
+
update_set[column_key] = val.expression
|
|
335
|
+
return update_set
|
sqlframe/base/readerwriter.py
CHANGED
|
@@ -21,19 +21,20 @@ else:
|
|
|
21
21
|
if t.TYPE_CHECKING:
|
|
22
22
|
from sqlframe.base._typing import OptionalPrimitiveType, PathOrPaths
|
|
23
23
|
from sqlframe.base.column import Column
|
|
24
|
-
from sqlframe.base.session import DF, _BaseSession
|
|
24
|
+
from sqlframe.base.session import DF, TABLE, _BaseSession
|
|
25
25
|
from sqlframe.base.types import StructType
|
|
26
26
|
|
|
27
27
|
SESSION = t.TypeVar("SESSION", bound=_BaseSession)
|
|
28
28
|
else:
|
|
29
29
|
SESSION = t.TypeVar("SESSION")
|
|
30
30
|
DF = t.TypeVar("DF")
|
|
31
|
+
TABLE = t.TypeVar("TABLE")
|
|
31
32
|
|
|
32
33
|
|
|
33
34
|
logger = logging.getLogger(__name__)
|
|
34
35
|
|
|
35
36
|
|
|
36
|
-
class _BaseDataFrameReader(t.Generic[SESSION, DF]):
|
|
37
|
+
class _BaseDataFrameReader(t.Generic[SESSION, DF, TABLE]):
|
|
37
38
|
def __init__(self, spark: SESSION):
|
|
38
39
|
self._session = spark
|
|
39
40
|
self.state_format_to_read: t.Optional[str] = None
|
|
@@ -42,7 +43,7 @@ class _BaseDataFrameReader(t.Generic[SESSION, DF]):
|
|
|
42
43
|
def session(self) -> SESSION:
|
|
43
44
|
return self._session
|
|
44
45
|
|
|
45
|
-
def table(self, tableName: str) ->
|
|
46
|
+
def table(self, tableName: str) -> TABLE:
|
|
46
47
|
tableName = normalize_string(tableName, from_dialect="input", is_table=True)
|
|
47
48
|
if df := self.session.temp_views.get(tableName):
|
|
48
49
|
return df
|
|
@@ -50,7 +51,7 @@ class _BaseDataFrameReader(t.Generic[SESSION, DF]):
|
|
|
50
51
|
self.session.catalog.add_table(table)
|
|
51
52
|
columns = self.session.catalog.get_columns_from_schema(table)
|
|
52
53
|
|
|
53
|
-
return self.session.
|
|
54
|
+
return self.session._create_table(
|
|
54
55
|
exp.Select()
|
|
55
56
|
.from_(tableName, dialect=self.session.input_dialect)
|
|
56
57
|
.select(*columns, dialect=self.session.input_dialect)
|
sqlframe/base/session.py
CHANGED
|
@@ -27,6 +27,7 @@ from sqlframe.base.catalog import _BaseCatalog
|
|
|
27
27
|
from sqlframe.base.dataframe import BaseDataFrame
|
|
28
28
|
from sqlframe.base.normalize import normalize_dict
|
|
29
29
|
from sqlframe.base.readerwriter import _BaseDataFrameReader, _BaseDataFrameWriter
|
|
30
|
+
from sqlframe.base.table import _BaseTable
|
|
30
31
|
from sqlframe.base.udf import _BaseUDFRegistration
|
|
31
32
|
from sqlframe.base.util import (
|
|
32
33
|
get_column_mapping_from_schema_input,
|
|
@@ -65,17 +66,19 @@ CATALOG = t.TypeVar("CATALOG", bound=_BaseCatalog)
|
|
|
65
66
|
READER = t.TypeVar("READER", bound=_BaseDataFrameReader)
|
|
66
67
|
WRITER = t.TypeVar("WRITER", bound=_BaseDataFrameWriter)
|
|
67
68
|
DF = t.TypeVar("DF", bound=BaseDataFrame)
|
|
69
|
+
TABLE = t.TypeVar("TABLE", bound=_BaseTable)
|
|
68
70
|
UDF_REGISTRATION = t.TypeVar("UDF_REGISTRATION", bound=_BaseUDFRegistration)
|
|
69
71
|
|
|
70
72
|
_MISSING = "MISSING"
|
|
71
73
|
|
|
72
74
|
|
|
73
|
-
class _BaseSession(t.Generic[CATALOG, READER, WRITER, DF, CONN, UDF_REGISTRATION]):
|
|
75
|
+
class _BaseSession(t.Generic[CATALOG, READER, WRITER, DF, TABLE, CONN, UDF_REGISTRATION]):
|
|
74
76
|
_instance = None
|
|
75
77
|
_reader: t.Type[READER]
|
|
76
78
|
_writer: t.Type[WRITER]
|
|
77
79
|
_catalog: t.Type[CATALOG]
|
|
78
80
|
_df: t.Type[DF]
|
|
81
|
+
_table: t.Type[TABLE]
|
|
79
82
|
_udf_registration: t.Type[UDF_REGISTRATION]
|
|
80
83
|
|
|
81
84
|
SANITIZE_COLUMN_NAMES = False
|
|
@@ -158,12 +161,15 @@ class _BaseSession(t.Generic[CATALOG, READER, WRITER, DF, CONN, UDF_REGISTRATION
|
|
|
158
161
|
return name.replace("(", "_").replace(")", "_")
|
|
159
162
|
return name
|
|
160
163
|
|
|
161
|
-
def table(self, tableName: str) ->
|
|
164
|
+
def table(self, tableName: str) -> TABLE:
|
|
162
165
|
return self.read.table(tableName)
|
|
163
166
|
|
|
164
167
|
def _create_df(self, *args, **kwargs) -> DF:
|
|
165
168
|
return self._df(self, *args, **kwargs)
|
|
166
169
|
|
|
170
|
+
def _create_table(self, *args, **kwargs) -> TABLE:
|
|
171
|
+
return self._table(self, *args, **kwargs)
|
|
172
|
+
|
|
167
173
|
def __new__(cls, *args, **kwargs):
|
|
168
174
|
if _BaseSession._instance is None:
|
|
169
175
|
_BaseSession._instance = super().__new__(cls)
|