altimate-code 0.5.1 → 0.5.3
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/CHANGELOG.md +35 -0
- package/README.md +1 -5
- package/bin/altimate +6 -0
- package/bin/altimate-code +6 -0
- package/dbt-tools/bin/altimate-dbt +2 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/altimate/__init__.py +0 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/altimate/fetch_schema.py +35 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/altimate/utils.py +353 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/altimate/validate_sql.py +114 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/__init__.py +178 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/__main__.py +96 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/_typing.py +17 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dataframe/__init__.py +3 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dataframe/sql/__init__.py +18 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dataframe/sql/_typing.py +18 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dataframe/sql/column.py +332 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dataframe/sql/dataframe.py +866 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dataframe/sql/functions.py +1267 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dataframe/sql/group.py +59 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dataframe/sql/normalize.py +78 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dataframe/sql/operations.py +53 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dataframe/sql/readwriter.py +108 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dataframe/sql/session.py +190 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dataframe/sql/transforms.py +9 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dataframe/sql/types.py +212 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dataframe/sql/util.py +32 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dataframe/sql/window.py +134 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/__init__.py +118 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/athena.py +166 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/bigquery.py +1331 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/clickhouse.py +1393 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/databricks.py +131 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/dialect.py +1915 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/doris.py +561 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/drill.py +157 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/druid.py +20 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/duckdb.py +1159 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/dune.py +16 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/hive.py +787 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/materialize.py +94 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/mysql.py +1324 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/oracle.py +378 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/postgres.py +778 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/presto.py +788 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/prql.py +203 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/redshift.py +448 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/risingwave.py +78 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/snowflake.py +1464 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/spark.py +202 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/spark2.py +349 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/sqlite.py +320 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/starrocks.py +343 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/tableau.py +61 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/teradata.py +356 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/trino.py +115 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/tsql.py +1403 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/diff.py +456 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/errors.py +93 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/executor/__init__.py +95 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/executor/context.py +101 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/executor/env.py +246 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/executor/python.py +460 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/executor/table.py +155 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/expressions.py +8870 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/generator.py +4993 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/helper.py +582 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/jsonpath.py +227 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/lineage.py +423 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/__init__.py +11 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/annotate_types.py +589 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/canonicalize.py +222 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/eliminate_ctes.py +43 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/eliminate_joins.py +181 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/eliminate_subqueries.py +189 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/isolate_table_selects.py +50 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/merge_subqueries.py +415 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/normalize.py +200 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/normalize_identifiers.py +64 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/optimize_joins.py +91 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/optimizer.py +94 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/pushdown_predicates.py +222 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/pushdown_projections.py +172 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/qualify.py +104 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/qualify_columns.py +1024 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/qualify_tables.py +155 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/scope.py +904 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/simplify.py +1587 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/unnest_subqueries.py +302 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/parser.py +8501 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/planner.py +463 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/schema.py +588 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/serde.py +68 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/time.py +687 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/tokens.py +1520 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/transforms.py +1020 -0
- package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/trie.py +81 -0
- package/dbt-tools/dist/altimate_python_packages/dbt_core_integration.py +825 -0
- package/dbt-tools/dist/altimate_python_packages/dbt_utils.py +157 -0
- package/dbt-tools/dist/index.js +23859 -0
- package/package.json +13 -13
- package/postinstall.mjs +42 -0
- package/skills/altimate-setup/SKILL.md +31 -0
package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dataframe/sql/dataframe.py
ADDED
|
@@ -0,0 +1,866 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import functools
|
|
4
|
+
import logging
|
|
5
|
+
import typing as t
|
|
6
|
+
import zlib
|
|
7
|
+
from copy import copy
|
|
8
|
+
|
|
9
|
+
import sqlglot as sqlglot
|
|
10
|
+
from sqlglot import Dialect, expressions as exp
|
|
11
|
+
from sqlglot.dataframe.sql import functions as F
|
|
12
|
+
from sqlglot.dataframe.sql.column import Column
|
|
13
|
+
from sqlglot.dataframe.sql.group import GroupedData
|
|
14
|
+
from sqlglot.dataframe.sql.normalize import normalize
|
|
15
|
+
from sqlglot.dataframe.sql.operations import Operation, operation
|
|
16
|
+
from sqlglot.dataframe.sql.readwriter import DataFrameWriter
|
|
17
|
+
from sqlglot.dataframe.sql.transforms import replace_id_value
|
|
18
|
+
from sqlglot.dataframe.sql.util import get_tables_from_expression_with_join
|
|
19
|
+
from sqlglot.dataframe.sql.window import Window
|
|
20
|
+
from sqlglot.helper import ensure_list, object_to_dict, seq_get
|
|
21
|
+
from sqlglot.optimizer import optimize as optimize_func
|
|
22
|
+
from sqlglot.optimizer.qualify_columns import quote_identifiers
|
|
23
|
+
|
|
24
|
+
if t.TYPE_CHECKING:
|
|
25
|
+
from sqlglot.dataframe.sql._typing import (
|
|
26
|
+
ColumnLiterals,
|
|
27
|
+
ColumnOrLiteral,
|
|
28
|
+
ColumnOrName,
|
|
29
|
+
OutputExpressionContainer,
|
|
30
|
+
)
|
|
31
|
+
from sqlglot.dataframe.sql.session import SparkSession
|
|
32
|
+
from sqlglot.dialects.dialect import DialectType
|
|
33
|
+
|
|
34
|
+
logger = logging.getLogger("sqlglot")
|
|
35
|
+
|
|
36
|
+
JOIN_HINTS = {
|
|
37
|
+
"BROADCAST",
|
|
38
|
+
"BROADCASTJOIN",
|
|
39
|
+
"MAPJOIN",
|
|
40
|
+
"MERGE",
|
|
41
|
+
"SHUFFLEMERGE",
|
|
42
|
+
"MERGEJOIN",
|
|
43
|
+
"SHUFFLE_HASH",
|
|
44
|
+
"SHUFFLE_REPLICATE_NL",
|
|
45
|
+
}
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
class DataFrame:
|
|
49
|
+
def __init__(
|
|
50
|
+
self,
|
|
51
|
+
spark: SparkSession,
|
|
52
|
+
expression: exp.Select,
|
|
53
|
+
branch_id: t.Optional[str] = None,
|
|
54
|
+
sequence_id: t.Optional[str] = None,
|
|
55
|
+
last_op: Operation = Operation.INIT,
|
|
56
|
+
pending_hints: t.Optional[t.List[exp.Expression]] = None,
|
|
57
|
+
output_expression_container: t.Optional[OutputExpressionContainer] = None,
|
|
58
|
+
**kwargs,
|
|
59
|
+
):
|
|
60
|
+
self.spark = spark
|
|
61
|
+
self.expression = expression
|
|
62
|
+
self.branch_id = branch_id or self.spark._random_branch_id
|
|
63
|
+
self.sequence_id = sequence_id or self.spark._random_sequence_id
|
|
64
|
+
self.last_op = last_op
|
|
65
|
+
self.pending_hints = pending_hints or []
|
|
66
|
+
self.output_expression_container = output_expression_container or exp.Select()
|
|
67
|
+
|
|
68
|
+
def __getattr__(self, column_name: str) -> Column:
|
|
69
|
+
return self[column_name]
|
|
70
|
+
|
|
71
|
+
def __getitem__(self, column_name: str) -> Column:
|
|
72
|
+
column_name = f"{self.branch_id}.{column_name}"
|
|
73
|
+
return Column(column_name)
|
|
74
|
+
|
|
75
|
+
def __copy__(self):
|
|
76
|
+
return self.copy()
|
|
77
|
+
|
|
78
|
+
@property
|
|
79
|
+
def sparkSession(self):
|
|
80
|
+
return self.spark
|
|
81
|
+
|
|
82
|
+
@property
|
|
83
|
+
def write(self):
|
|
84
|
+
return DataFrameWriter(self)
|
|
85
|
+
|
|
86
|
+
@property
|
|
87
|
+
def latest_cte_name(self) -> str:
|
|
88
|
+
if not self.expression.ctes:
|
|
89
|
+
from_exp = self.expression.args["from"]
|
|
90
|
+
if from_exp.alias_or_name:
|
|
91
|
+
return from_exp.alias_or_name
|
|
92
|
+
table_alias = from_exp.find(exp.TableAlias)
|
|
93
|
+
if not table_alias:
|
|
94
|
+
raise RuntimeError(
|
|
95
|
+
f"Could not find an alias name for this expression: {self.expression}"
|
|
96
|
+
)
|
|
97
|
+
return table_alias.alias_or_name
|
|
98
|
+
return self.expression.ctes[-1].alias
|
|
99
|
+
|
|
100
|
+
@property
|
|
101
|
+
def pending_join_hints(self):
|
|
102
|
+
return [hint for hint in self.pending_hints if isinstance(hint, exp.JoinHint)]
|
|
103
|
+
|
|
104
|
+
@property
|
|
105
|
+
def pending_partition_hints(self):
|
|
106
|
+
return [hint for hint in self.pending_hints if isinstance(hint, exp.Anonymous)]
|
|
107
|
+
|
|
108
|
+
@property
|
|
109
|
+
def columns(self) -> t.List[str]:
|
|
110
|
+
return self.expression.named_selects
|
|
111
|
+
|
|
112
|
+
@property
|
|
113
|
+
def na(self) -> DataFrameNaFunctions:
|
|
114
|
+
return DataFrameNaFunctions(self)
|
|
115
|
+
|
|
116
|
+
def _replace_cte_names_with_hashes(self, expression: exp.Select):
|
|
117
|
+
replacement_mapping = {}
|
|
118
|
+
for cte in expression.ctes:
|
|
119
|
+
old_name_id = cte.args["alias"].this
|
|
120
|
+
new_hashed_id = exp.to_identifier(
|
|
121
|
+
self._create_hash_from_expression(cte.this), quoted=old_name_id.args["quoted"]
|
|
122
|
+
)
|
|
123
|
+
replacement_mapping[old_name_id] = new_hashed_id
|
|
124
|
+
expression = expression.transform(replace_id_value, replacement_mapping)
|
|
125
|
+
return expression
|
|
126
|
+
|
|
127
|
+
def _create_cte_from_expression(
|
|
128
|
+
self,
|
|
129
|
+
expression: exp.Expression,
|
|
130
|
+
branch_id: t.Optional[str] = None,
|
|
131
|
+
sequence_id: t.Optional[str] = None,
|
|
132
|
+
**kwargs,
|
|
133
|
+
) -> t.Tuple[exp.CTE, str]:
|
|
134
|
+
name = self._create_hash_from_expression(expression)
|
|
135
|
+
expression_to_cte = expression.copy()
|
|
136
|
+
expression_to_cte.set("with", None)
|
|
137
|
+
cte = exp.Select().with_(name, as_=expression_to_cte, **kwargs).ctes[0]
|
|
138
|
+
cte.set("branch_id", branch_id or self.branch_id)
|
|
139
|
+
cte.set("sequence_id", sequence_id or self.sequence_id)
|
|
140
|
+
return cte, name
|
|
141
|
+
|
|
142
|
+
@t.overload
|
|
143
|
+
def _ensure_list_of_columns(self, cols: t.Collection[ColumnOrLiteral]) -> t.List[Column]:
|
|
144
|
+
...
|
|
145
|
+
|
|
146
|
+
@t.overload
|
|
147
|
+
def _ensure_list_of_columns(self, cols: ColumnOrLiteral) -> t.List[Column]:
|
|
148
|
+
...
|
|
149
|
+
|
|
150
|
+
def _ensure_list_of_columns(self, cols):
|
|
151
|
+
return Column.ensure_cols(ensure_list(cols))
|
|
152
|
+
|
|
153
|
+
def _ensure_and_normalize_cols(self, cols, expression: t.Optional[exp.Select] = None):
|
|
154
|
+
cols = self._ensure_list_of_columns(cols)
|
|
155
|
+
normalize(self.spark, expression or self.expression, cols)
|
|
156
|
+
return cols
|
|
157
|
+
|
|
158
|
+
def _ensure_and_normalize_col(self, col):
|
|
159
|
+
col = Column.ensure_col(col)
|
|
160
|
+
normalize(self.spark, self.expression, col)
|
|
161
|
+
return col
|
|
162
|
+
|
|
163
|
+
def _convert_leaf_to_cte(self, sequence_id: t.Optional[str] = None) -> DataFrame:
|
|
164
|
+
df = self._resolve_pending_hints()
|
|
165
|
+
sequence_id = sequence_id or df.sequence_id
|
|
166
|
+
expression = df.expression.copy()
|
|
167
|
+
cte_expression, cte_name = df._create_cte_from_expression(
|
|
168
|
+
expression=expression, sequence_id=sequence_id
|
|
169
|
+
)
|
|
170
|
+
new_expression = df._add_ctes_to_expression(
|
|
171
|
+
exp.Select(), expression.ctes + [cte_expression]
|
|
172
|
+
)
|
|
173
|
+
sel_columns = df._get_outer_select_columns(cte_expression)
|
|
174
|
+
new_expression = new_expression.from_(cte_name).select(
|
|
175
|
+
*[x.alias_or_name for x in sel_columns]
|
|
176
|
+
)
|
|
177
|
+
return df.copy(expression=new_expression, sequence_id=sequence_id)
|
|
178
|
+
|
|
179
|
+
def _resolve_pending_hints(self) -> DataFrame:
|
|
180
|
+
df = self.copy()
|
|
181
|
+
if not self.pending_hints:
|
|
182
|
+
return df
|
|
183
|
+
expression = df.expression
|
|
184
|
+
hint_expression = expression.args.get("hint") or exp.Hint(expressions=[])
|
|
185
|
+
for hint in df.pending_partition_hints:
|
|
186
|
+
hint_expression.append("expressions", hint)
|
|
187
|
+
df.pending_hints.remove(hint)
|
|
188
|
+
|
|
189
|
+
join_aliases = {
|
|
190
|
+
join_table.alias_or_name
|
|
191
|
+
for join_table in get_tables_from_expression_with_join(expression)
|
|
192
|
+
}
|
|
193
|
+
if join_aliases:
|
|
194
|
+
for hint in df.pending_join_hints:
|
|
195
|
+
for sequence_id_expression in hint.expressions:
|
|
196
|
+
sequence_id_or_name = sequence_id_expression.alias_or_name
|
|
197
|
+
sequence_ids_to_match = [sequence_id_or_name]
|
|
198
|
+
if sequence_id_or_name in df.spark.name_to_sequence_id_mapping:
|
|
199
|
+
sequence_ids_to_match = df.spark.name_to_sequence_id_mapping[
|
|
200
|
+
sequence_id_or_name
|
|
201
|
+
]
|
|
202
|
+
matching_ctes = [
|
|
203
|
+
cte
|
|
204
|
+
for cte in reversed(expression.ctes)
|
|
205
|
+
if cte.args["sequence_id"] in sequence_ids_to_match
|
|
206
|
+
]
|
|
207
|
+
for matching_cte in matching_ctes:
|
|
208
|
+
if matching_cte.alias_or_name in join_aliases:
|
|
209
|
+
sequence_id_expression.set("this", matching_cte.args["alias"].this)
|
|
210
|
+
df.pending_hints.remove(hint)
|
|
211
|
+
break
|
|
212
|
+
hint_expression.append("expressions", hint)
|
|
213
|
+
if hint_expression.expressions:
|
|
214
|
+
expression.set("hint", hint_expression)
|
|
215
|
+
return df
|
|
216
|
+
|
|
217
|
+
def _hint(self, hint_name: str, args: t.List[Column]) -> DataFrame:
|
|
218
|
+
hint_name = hint_name.upper()
|
|
219
|
+
hint_expression = (
|
|
220
|
+
exp.JoinHint(
|
|
221
|
+
this=hint_name,
|
|
222
|
+
expressions=[exp.to_table(parameter.alias_or_name) for parameter in args],
|
|
223
|
+
)
|
|
224
|
+
if hint_name in JOIN_HINTS
|
|
225
|
+
else exp.Anonymous(
|
|
226
|
+
this=hint_name, expressions=[parameter.expression for parameter in args]
|
|
227
|
+
)
|
|
228
|
+
)
|
|
229
|
+
new_df = self.copy()
|
|
230
|
+
new_df.pending_hints.append(hint_expression)
|
|
231
|
+
return new_df
|
|
232
|
+
|
|
233
|
+
def _set_operation(self, klass: t.Callable, other: DataFrame, distinct: bool):
|
|
234
|
+
other_df = other._convert_leaf_to_cte()
|
|
235
|
+
base_expression = self.expression.copy()
|
|
236
|
+
base_expression = self._add_ctes_to_expression(base_expression, other_df.expression.ctes)
|
|
237
|
+
all_ctes = base_expression.ctes
|
|
238
|
+
other_df.expression.set("with", None)
|
|
239
|
+
base_expression.set("with", None)
|
|
240
|
+
operation = klass(this=base_expression, distinct=distinct, expression=other_df.expression)
|
|
241
|
+
operation.set("with", exp.With(expressions=all_ctes))
|
|
242
|
+
return self.copy(expression=operation)._convert_leaf_to_cte()
|
|
243
|
+
|
|
244
|
+
def _cache(self, storage_level: str):
|
|
245
|
+
df = self._convert_leaf_to_cte()
|
|
246
|
+
df.expression.ctes[-1].set("cache_storage_level", storage_level)
|
|
247
|
+
return df
|
|
248
|
+
|
|
249
|
+
@classmethod
|
|
250
|
+
def _add_ctes_to_expression(cls, expression: exp.Select, ctes: t.List[exp.CTE]) -> exp.Select:
|
|
251
|
+
expression = expression.copy()
|
|
252
|
+
with_expression = expression.args.get("with")
|
|
253
|
+
if with_expression:
|
|
254
|
+
existing_ctes = with_expression.expressions
|
|
255
|
+
existsing_cte_names = {x.alias_or_name for x in existing_ctes}
|
|
256
|
+
for cte in ctes:
|
|
257
|
+
if cte.alias_or_name not in existsing_cte_names:
|
|
258
|
+
existing_ctes.append(cte)
|
|
259
|
+
else:
|
|
260
|
+
existing_ctes = ctes
|
|
261
|
+
expression.set("with", exp.With(expressions=existing_ctes))
|
|
262
|
+
return expression
|
|
263
|
+
|
|
264
|
+
@classmethod
|
|
265
|
+
def _get_outer_select_columns(cls, item: t.Union[exp.Expression, DataFrame]) -> t.List[Column]:
|
|
266
|
+
expression = item.expression if isinstance(item, DataFrame) else item
|
|
267
|
+
return [Column(x) for x in (expression.find(exp.Select) or exp.Select()).expressions]
|
|
268
|
+
|
|
269
|
+
@classmethod
|
|
270
|
+
def _create_hash_from_expression(cls, expression: exp.Expression) -> str:
|
|
271
|
+
from sqlglot.dataframe.sql.session import SparkSession
|
|
272
|
+
|
|
273
|
+
value = expression.sql(dialect=SparkSession().dialect).encode("utf-8")
|
|
274
|
+
return f"t{zlib.crc32(value)}"[:6]
|
|
275
|
+
|
|
276
|
+
def _get_select_expressions(
|
|
277
|
+
self,
|
|
278
|
+
) -> t.List[t.Tuple[t.Union[t.Type[exp.Cache], OutputExpressionContainer], exp.Select]]:
|
|
279
|
+
select_expressions: t.List[
|
|
280
|
+
t.Tuple[t.Union[t.Type[exp.Cache], OutputExpressionContainer], exp.Select]
|
|
281
|
+
] = []
|
|
282
|
+
main_select_ctes: t.List[exp.CTE] = []
|
|
283
|
+
for cte in self.expression.ctes:
|
|
284
|
+
cache_storage_level = cte.args.get("cache_storage_level")
|
|
285
|
+
if cache_storage_level:
|
|
286
|
+
select_expression = cte.this.copy()
|
|
287
|
+
select_expression.set("with", exp.With(expressions=copy(main_select_ctes)))
|
|
288
|
+
select_expression.set("cte_alias_name", cte.alias_or_name)
|
|
289
|
+
select_expression.set("cache_storage_level", cache_storage_level)
|
|
290
|
+
select_expressions.append((exp.Cache, select_expression))
|
|
291
|
+
else:
|
|
292
|
+
main_select_ctes.append(cte)
|
|
293
|
+
main_select = self.expression.copy()
|
|
294
|
+
if main_select_ctes:
|
|
295
|
+
main_select.set("with", exp.With(expressions=main_select_ctes))
|
|
296
|
+
expression_select_pair = (type(self.output_expression_container), main_select)
|
|
297
|
+
select_expressions.append(expression_select_pair) # type: ignore
|
|
298
|
+
return select_expressions
|
|
299
|
+
|
|
300
|
+
def sql(
|
|
301
|
+
self, dialect: t.Optional[DialectType] = None, optimize: bool = True, **kwargs
|
|
302
|
+
) -> t.List[str]:
|
|
303
|
+
from sqlglot.dataframe.sql.session import SparkSession
|
|
304
|
+
|
|
305
|
+
if dialect and Dialect.get_or_raise(dialect)() != SparkSession().dialect:
|
|
306
|
+
logger.warning(
|
|
307
|
+
f"The recommended way of defining a dialect is by doing `SparkSession.builder.config('sqlframe.dialect', '{dialect}').getOrCreate()`. It is no longer needed then when calling `sql`. If you run into issues try updating your query to use this pattern."
|
|
308
|
+
)
|
|
309
|
+
df = self._resolve_pending_hints()
|
|
310
|
+
select_expressions = df._get_select_expressions()
|
|
311
|
+
output_expressions: t.List[t.Union[exp.Select, exp.Cache, exp.Drop]] = []
|
|
312
|
+
replacement_mapping: t.Dict[exp.Identifier, exp.Identifier] = {}
|
|
313
|
+
for expression_type, select_expression in select_expressions:
|
|
314
|
+
select_expression = select_expression.transform(replace_id_value, replacement_mapping)
|
|
315
|
+
if optimize:
|
|
316
|
+
quote_identifiers(select_expression)
|
|
317
|
+
select_expression = t.cast(
|
|
318
|
+
exp.Select, optimize_func(select_expression, dialect=SparkSession().dialect)
|
|
319
|
+
)
|
|
320
|
+
select_expression = df._replace_cte_names_with_hashes(select_expression)
|
|
321
|
+
expression: t.Union[exp.Select, exp.Cache, exp.Drop]
|
|
322
|
+
if expression_type == exp.Cache:
|
|
323
|
+
cache_table_name = df._create_hash_from_expression(select_expression)
|
|
324
|
+
cache_table = exp.to_table(cache_table_name)
|
|
325
|
+
original_alias_name = select_expression.args["cte_alias_name"]
|
|
326
|
+
|
|
327
|
+
replacement_mapping[exp.to_identifier(original_alias_name)] = exp.to_identifier( # type: ignore
|
|
328
|
+
cache_table_name
|
|
329
|
+
)
|
|
330
|
+
sqlglot.schema.add_table(
|
|
331
|
+
cache_table_name,
|
|
332
|
+
{
|
|
333
|
+
expression.alias_or_name: expression.type.sql(
|
|
334
|
+
dialect=SparkSession().dialect
|
|
335
|
+
)
|
|
336
|
+
for expression in select_expression.expressions
|
|
337
|
+
},
|
|
338
|
+
dialect=SparkSession().dialect,
|
|
339
|
+
)
|
|
340
|
+
cache_storage_level = select_expression.args["cache_storage_level"]
|
|
341
|
+
options = [
|
|
342
|
+
exp.Literal.string("storageLevel"),
|
|
343
|
+
exp.Literal.string(cache_storage_level),
|
|
344
|
+
]
|
|
345
|
+
expression = exp.Cache(
|
|
346
|
+
this=cache_table, expression=select_expression, lazy=True, options=options
|
|
347
|
+
)
|
|
348
|
+
# We will drop the "view" if it exists before running the cache table
|
|
349
|
+
output_expressions.append(exp.Drop(this=cache_table, exists=True, kind="VIEW"))
|
|
350
|
+
elif expression_type == exp.Create:
|
|
351
|
+
expression = df.output_expression_container.copy()
|
|
352
|
+
expression.set("expression", select_expression)
|
|
353
|
+
elif expression_type == exp.Insert:
|
|
354
|
+
expression = df.output_expression_container.copy()
|
|
355
|
+
select_without_ctes = select_expression.copy()
|
|
356
|
+
select_without_ctes.set("with", None)
|
|
357
|
+
expression.set("expression", select_without_ctes)
|
|
358
|
+
if select_expression.ctes:
|
|
359
|
+
expression.set("with", exp.With(expressions=select_expression.ctes))
|
|
360
|
+
elif expression_type == exp.Select:
|
|
361
|
+
expression = select_expression
|
|
362
|
+
else:
|
|
363
|
+
raise ValueError(f"Invalid expression type: {expression_type}")
|
|
364
|
+
output_expressions.append(expression)
|
|
365
|
+
|
|
366
|
+
return [
|
|
367
|
+
expression.sql(**{"dialect": SparkSession().dialect, **kwargs})
|
|
368
|
+
for expression in output_expressions
|
|
369
|
+
]
|
|
370
|
+
|
|
371
|
+
def copy(self, **kwargs) -> DataFrame:
|
|
372
|
+
return DataFrame(**object_to_dict(self, **kwargs))
|
|
373
|
+
|
|
374
|
+
@operation(Operation.SELECT)
|
|
375
|
+
def select(self, *cols, **kwargs) -> DataFrame:
|
|
376
|
+
cols = self._ensure_and_normalize_cols(cols)
|
|
377
|
+
kwargs["append"] = kwargs.get("append", False)
|
|
378
|
+
if self.expression.args.get("joins"):
|
|
379
|
+
ambiguous_cols = [
|
|
380
|
+
col
|
|
381
|
+
for col in cols
|
|
382
|
+
if isinstance(col.column_expression, exp.Column) and not col.column_expression.table
|
|
383
|
+
]
|
|
384
|
+
if ambiguous_cols:
|
|
385
|
+
join_table_identifiers = [
|
|
386
|
+
x.this for x in get_tables_from_expression_with_join(self.expression)
|
|
387
|
+
]
|
|
388
|
+
cte_names_in_join = [x.this for x in join_table_identifiers]
|
|
389
|
+
# If we have columns that resolve to multiple CTE expressions then we want to use each CTE left-to-right
|
|
390
|
+
# and therefore we allow multiple columns with the same name in the result. This matches the behavior
|
|
391
|
+
# of Spark.
|
|
392
|
+
resolved_column_position: t.Dict[Column, int] = {col: -1 for col in ambiguous_cols}
|
|
393
|
+
for ambiguous_col in ambiguous_cols:
|
|
394
|
+
ctes_with_column = [
|
|
395
|
+
cte
|
|
396
|
+
for cte in self.expression.ctes
|
|
397
|
+
if cte.alias_or_name in cte_names_in_join
|
|
398
|
+
and ambiguous_col.alias_or_name in cte.this.named_selects
|
|
399
|
+
]
|
|
400
|
+
# Check if there is a CTE with this column that we haven't used before. If so, use it. Otherwise,
|
|
401
|
+
# use the same CTE we used before
|
|
402
|
+
cte = seq_get(ctes_with_column, resolved_column_position[ambiguous_col] + 1)
|
|
403
|
+
if cte:
|
|
404
|
+
resolved_column_position[ambiguous_col] += 1
|
|
405
|
+
else:
|
|
406
|
+
cte = ctes_with_column[resolved_column_position[ambiguous_col]]
|
|
407
|
+
ambiguous_col.expression.set("table", cte.alias_or_name)
|
|
408
|
+
return self.copy(
|
|
409
|
+
expression=self.expression.select(*[x.expression for x in cols], **kwargs), **kwargs
|
|
410
|
+
)
|
|
411
|
+
|
|
412
|
+
@operation(Operation.NO_OP)
|
|
413
|
+
def alias(self, name: str, **kwargs) -> DataFrame:
|
|
414
|
+
new_sequence_id = self.spark._random_sequence_id
|
|
415
|
+
df = self.copy()
|
|
416
|
+
for join_hint in df.pending_join_hints:
|
|
417
|
+
for expression in join_hint.expressions:
|
|
418
|
+
if expression.alias_or_name == self.sequence_id:
|
|
419
|
+
expression.set("this", Column.ensure_col(new_sequence_id).expression)
|
|
420
|
+
df.spark._add_alias_to_mapping(name, new_sequence_id)
|
|
421
|
+
return df._convert_leaf_to_cte(sequence_id=new_sequence_id)
|
|
422
|
+
|
|
423
|
+
@operation(Operation.WHERE)
|
|
424
|
+
def where(self, column: t.Union[Column, bool], **kwargs) -> DataFrame:
|
|
425
|
+
col = self._ensure_and_normalize_col(column)
|
|
426
|
+
return self.copy(expression=self.expression.where(col.expression))
|
|
427
|
+
|
|
428
|
+
filter = where
|
|
429
|
+
|
|
430
|
+
@operation(Operation.GROUP_BY)
|
|
431
|
+
def groupBy(self, *cols, **kwargs) -> GroupedData:
|
|
432
|
+
columns = self._ensure_and_normalize_cols(cols)
|
|
433
|
+
return GroupedData(self, columns, self.last_op)
|
|
434
|
+
|
|
435
|
+
@operation(Operation.SELECT)
|
|
436
|
+
def agg(self, *exprs, **kwargs) -> DataFrame:
|
|
437
|
+
cols = self._ensure_and_normalize_cols(exprs)
|
|
438
|
+
return self.groupBy().agg(*cols)
|
|
439
|
+
|
|
440
|
+
@operation(Operation.FROM)
def join(
    self,
    other_df: DataFrame,
    on: t.Union[str, t.List[str], Column, t.List[Column]],
    how: str = "inner",
    **kwargs,
) -> DataFrame:
    """Join this DataFrame with ``other_df``.

    ``other_df`` is first converted to a CTE and joined by its latest CTE name. The join type
    is ``how`` with underscores replaced by spaces (e.g. ``"left_outer"`` -> ``"left outer"``).
    The ON clause is filled in only after the join columns have been resolved.

    Args:
        other_df: The right-hand DataFrame.
        on: Column name(s) or Column expression(s).
            - If the first entry is a plain ``exp.Column``, this is a by-name join. Each name
              is resolved against the left-side CTEs. The join columns are deduplicated and
              moved to the front of the output.
            - Otherwise the normalized expressions are AND-ed into the ON clause. All left
              columns, then all right columns, are selected without deduplication.
        how: Join type.
        **kwargs: Not read in this method body.

    Returns:
        A new DataFrame that selects the joined columns.

    Raises:
        ValueError: For a by-name join, if a column name is found in more than one left-side
            CTE, or in none of them.

    NOTE(review): an empty ``on`` list would raise IndexError at ``join_columns[0]``.
    """
    other_df = other_df._convert_leaf_to_cte()
    join_columns = self._ensure_list_of_columns(on)
    # We will determine actual "join on" expression later so we don't provide it at first
    join_expression = self.expression.join(
        other_df.latest_cte_name, join_type=how.replace("_", " ")
    )
    # Add other_df's CTEs to the joined expression so the joined CTE name resolves
    join_expression = self._add_ctes_to_expression(join_expression, other_df.expression.ctes)
    self_columns = self._get_outer_select_columns(join_expression)
    other_columns = self._get_outer_select_columns(other_df)
    # Determines the join clause and select columns to be used based on what type of columns were provided for
    # the join. The columns returned changes based on how the on expression is provided.
    if isinstance(join_columns[0].expression, exp.Column):
        """
        Unique characteristics of join on column names only:
        * The column names are put at the front of the select list
        * The column names are deduplicated across the entire select list and only the column names (other dups are allowed)
        """
        table_names = [
            table.alias_or_name
            for table in get_tables_from_expression_with_join(join_expression)
        ]
        # Candidate left-side CTEs: tables taking part in the join, excluding the right side's CTE
        potential_ctes = [
            cte
            for cte in join_expression.ctes
            if cte.alias_or_name in table_names
            and cte.alias_or_name != other_df.latest_cte_name
        ]
        # Determine the table to reference for the left side of the join by checking each of the left side
        # tables and see if they have the column being referenced.
        join_column_pairs = []
        for join_column in join_columns:
            num_matching_ctes = 0
            for cte in potential_ctes:
                if join_column.alias_or_name in cte.this.named_selects:
                    left_column = join_column.copy().set_table_name(cte.alias_or_name)
                    right_column = join_column.copy().set_table_name(other_df.latest_cte_name)
                    join_column_pairs.append((left_column, right_column))
                    num_matching_ctes += 1
            # Each join name must resolve to exactly one left-side CTE
            if num_matching_ctes > 1:
                raise ValueError(
                    f"Column {join_column.alias_or_name} is ambiguous. Please specify the table name."
                )
            elif num_matching_ctes == 0:
                raise ValueError(
                    f"Column {join_column.alias_or_name} does not exist in any of the tables."
                )
        # AND together one equality per (left, right) pair to form the ON condition
        join_clause = functools.reduce(
            lambda x, y: x & y,
            [left_column == right_column for left_column, right_column in join_column_pairs],
        )
        join_column_names = [left_col.alias_or_name for left_col, _ in join_column_pairs]
        # To match spark behavior only the join clause gets deduplicated and it gets put in the front of the column list
        # Star projections are kept as their rendered SQL text instead of a name
        select_column_names = [
            column.alias_or_name
            if not isinstance(column.expression.this, exp.Star)
            else column.sql()
            for column in self_columns + other_columns
        ]
        select_column_names = [
            column_name
            for column_name in select_column_names
            if column_name not in join_column_names
        ]
        select_column_names = join_column_names + select_column_names
    else:
        """
        Unique characteristics of join on expressions:
        * There is no deduplication of the results.
        * The left join dataframe columns go first and right come after. No sort preference is given to join columns
        """
        join_columns = self._ensure_and_normalize_cols(join_columns, join_expression)
        # Several condition expressions are combined with AND into a single ON clause
        if len(join_columns) > 1:
            join_columns = [functools.reduce(lambda x, y: x & y, join_columns)]
        join_clause = join_columns[0]
        select_column_names = [column.alias_or_name for column in self_columns + other_columns]

    # Update the on expression with the actual join clause to replace the dummy one from before
    join_expression.args["joins"][-1].set("on", join_clause.expression)
    new_df = self.copy(expression=join_expression)
    # Keep join hints from the left side and regular hints from the right side
    new_df.pending_join_hints.extend(self.pending_join_hints)
    new_df.pending_hints.extend(other_df.pending_hints)
    # Calls the undecorated select via __wrapped__, presumably to skip the @operation wrapper's bookkeeping
    new_df = new_df.select.__wrapped__(new_df, *select_column_names)
    return new_df
|
|
531
|
+
|
|
532
|
+
@operation(Operation.ORDER_BY)
def orderBy(
    self,
    *cols: t.Union[str, Column],
    ascending: t.Optional[t.Union[t.Any, t.List[t.Any]]] = None,
) -> DataFrame:
    """Sort by ``cols``.

    A column whose expression is already an ``exp.Ordered`` (e.g. built with ``.desc()``)
    keeps its own ordering and ignores ``ascending``. Spark behaves irregularly when the two
    are mixed, and users are unlikely to combine them.

    ``ascending`` may be None (all ascending), a single truthy/falsy value applied to every
    column, or a list with one truthy/falsy entry per column.
    """
    columns = self._ensure_and_normalize_cols(cols)
    already_ordered = {
        position
        for position, candidate in enumerate(columns)
        if isinstance(candidate.expression, exp.Ordered)
    }
    if ascending is None:
        ascending = [True] * len(columns)
    elif not isinstance(ascending, list):
        ascending = [ascending] * len(columns)
    ascending = list(map(bool, ascending))
    assert len(columns) == len(
        ascending
    ), "The length of items in ascending must equal the number of columns provided"

    order_by_columns = []
    for position, (candidate, is_ascending) in enumerate(zip(columns, ascending)):
        if position in already_ordered:
            order_by_columns.append(candidate.column_expression)
        else:
            order_by_columns.append(exp.Ordered(this=candidate.expression, desc=not is_ascending))
    return self.copy(expression=self.expression.order_by(*order_by_columns))

sort = orderBy
|
|
570
|
+
|
|
571
|
+
@operation(Operation.FROM)
def union(self, other: DataFrame) -> DataFrame:
    """Positional UNION with ``other``, keeping duplicates (Spark's ``union`` semantics).

    ``_set_operation`` is called with ``False`` as its third argument; ``intersect`` passes
    ``True``, so the flag presumably selects DISTINCT semantics.
    """
    distinct = False
    return self._set_operation(exp.Union, other, distinct)

unionAll = union
|
|
576
|
+
|
|
577
|
+
@operation(Operation.FROM)
def unionByName(self, other: DataFrame, allowMissingColumns: bool = False):
    """UNION ``other`` into this DataFrame, matching columns by name.

    Without ``allowMissingColumns``, the right side selects this DataFrame's column names in
    this DataFrame's order.

    With ``allowMissingColumns``, both sides select the union of their column names, in this
    order: left columns first, then right-only columns. Any column missing on one side is
    filled with an aliased NULL.
    """
    left_names = self.columns
    right_names = other.columns

    def null_as(column_name):
        # NULL placeholder aliased to the missing column's name
        return exp.alias_(exp.Null(), column_name, copy=False)

    if allowMissingColumns:
        left_selects = []
        right_selects = []
        unmatched_right = copy(right_names)
        for column_name in left_names:
            left_selects.append(column_name)
            if column_name in right_names:
                right_selects.append(column_name)
                unmatched_right.remove(column_name)
            else:
                right_selects.append(null_as(column_name))
        for column_name in unmatched_right:
            left_selects.append(null_as(column_name))
            right_selects.append(column_name)
    else:
        left_selects = left_names
        right_selects = left_names

    right_df = other.copy()._convert_leaf_to_cte().select(
        *self._ensure_list_of_columns(right_selects)
    )
    left_df = self.copy()
    if allowMissingColumns:
        left_df = left_df._convert_leaf_to_cte().select(
            *self._ensure_list_of_columns(left_selects)
        )
    return left_df._set_operation(exp.Union, right_df, False)
|
|
605
|
+
|
|
606
|
+
@operation(Operation.FROM)
def intersect(self, other: DataFrame) -> DataFrame:
    """INTERSECT with ``other``; passes ``True`` as the ``_set_operation`` flag (distinct rows)."""
    distinct = True
    return self._set_operation(exp.Intersect, other, distinct)
|
|
609
|
+
|
|
610
|
+
@operation(Operation.FROM)
def intersectAll(self, other: DataFrame) -> DataFrame:
    """INTERSECT with ``other``; passes ``False`` as the ``_set_operation`` flag (keeps duplicates)."""
    distinct = False
    return self._set_operation(exp.Intersect, other, distinct)
|
|
613
|
+
|
|
614
|
+
@operation(Operation.FROM)
def exceptAll(self, other: DataFrame) -> DataFrame:
    """EXCEPT ``other``; passes ``False`` as the ``_set_operation`` flag (keeps duplicates)."""
    distinct = False
    return self._set_operation(exp.Except, other, distinct)
|
|
617
|
+
|
|
618
|
+
@operation(Operation.SELECT)
def distinct(self) -> DataFrame:
    """Return a copy whose query is made DISTINCT via the expression's ``distinct()`` builder."""
    deduplicated = self.expression.distinct()
    return self.copy(expression=deduplicated)
|
|
621
|
+
|
|
622
|
+
@operation(Operation.SELECT)
def dropDuplicates(self, subset: t.Optional[t.List[str]] = None):
    """Drop duplicate rows, optionally considering only ``subset`` columns.

    With no subset this is ``distinct()``. Otherwise rows are numbered with ROW_NUMBER over a
    window partitioned (and ordered) by the subset columns. Only row 1 of each partition is
    kept, and the helper ``row_num`` column is dropped again.
    """
    if not subset:
        return self.distinct()
    key_columns = ensure_list(subset)
    dedupe_window = Window.partitionBy(*key_columns).orderBy(*key_columns)
    numbered = self.copy().withColumn("row_num", F.row_number().over(dedupe_window))
    first_per_key = numbered.where(F.col("row_num") == F.lit(1))
    return first_per_key.drop("row_num")
|
|
634
|
+
|
|
635
|
+
@operation(Operation.FROM)
def dropna(
    self,
    how: str = "any",
    thresh: t.Optional[int] = None,
    subset: t.Optional[t.Union[str, t.Tuple[str, ...], t.List[str]]] = None,
) -> DataFrame:
    """Drop rows containing NULLs.

    The checked columns are ``subset`` if given, otherwise every outer select column. A
    NULL count over those columns is computed, and rows are kept when the count is below a
    cutoff:

    * ``thresh`` is None and ``how == "any"``: cutoff 1, so any NULL drops the row;
    * ``thresh`` is None and any other ``how``: cutoff is the number of checked columns, so
      only all-NULL rows are dropped;
    * otherwise: cutoff is ``num_checked - thresh + 1``, so rows with fewer than ``thresh``
      non-NULL values are dropped.

    The output selects the original outer columns again, so the helper count is not exposed.

    Raises:
        RuntimeError: if the computed cutoff exceeds the number of checked columns
            (i.e. ``thresh`` <= 0).
    """
    working = self.copy()
    every_column = self._get_outer_select_columns(working.expression)
    checked = self._ensure_and_normalize_cols(subset) if subset else every_column
    checked_count = len(checked)

    if thresh is not None:
        null_cutoff = checked_count - (thresh or 0) + 1
    else:
        null_cutoff = 1 if how == "any" else checked_count
    if null_cutoff > checked_count:
        raise RuntimeError(
            f"The minimum num nulls for dropna must be less than or equal to the number of columns. "
            f"Minimum num nulls: {null_cutoff}, Num Columns: {checked_count}"
        )

    # 1 for each NULL column in the row, 0 otherwise; summed into a per-row NULL count
    null_flags = [F.when(c.isNull(), F.lit(1)).otherwise(F.lit(0)) for c in checked]
    null_total = functools.reduce(lambda acc, flag: acc + flag, null_flags).alias("num_nulls")
    with_count = working.select(null_total, append=True)
    kept_rows = with_count.where(F.col("num_nulls") < F.lit(null_cutoff))
    return kept_rows.select(*every_column)
|
|
667
|
+
|
|
668
|
+
@operation(Operation.FROM)
def fillna(
    self,
    value: t.Union[ColumnLiterals],
    subset: t.Optional[t.Union[str, t.Tuple[str, ...], t.List[str]]] = None,
) -> DataFrame:
    """Replace NULLs with literal values.

    ``value`` is either a dict mapping column names to fill values, or a single value applied
    to ``subset`` (default: every outer select column). Each targeted column becomes
    ``when(col IS NULL, fill).otherwise(col)`` aliased back to its own name. All other
    columns pass through unchanged, and column order is preserved.

    Functionality difference: PySpark silently ignores a fill value whose type conflicts with
    the column's type. Here the value is used as given, so results may not always match;
    keep replacement types the same as the column types. A possible improvement would be to
    compare against ``typeof`` of the column and use NULL on mismatch.
    """
    from sqlglot.dataframe.sql.functions import lit

    target = self.copy()
    outer_columns = self._get_outer_select_columns(target.expression)
    by_name = {c.alias_or_name: c for c in outer_columns}

    fill_values = None
    fill_columns = None
    if isinstance(value, dict):
        fill_values = list(value.values())
        fill_columns = self._ensure_and_normalize_cols(list(value))
    if not fill_columns:
        fill_columns = self._ensure_and_normalize_cols(subset) if subset else outer_columns
    if not fill_values:
        fill_values = [value] * len(fill_columns)
    fill_literals = [lit(v) for v in fill_values]

    filled = {
        c.alias_or_name: F.when(c.isNull(), fill).otherwise(c).alias(c.alias_or_name)
        for c, fill in zip(fill_columns, fill_literals)
    }
    merged = {**by_name, **filled}
    return target.select(*(merged[c.alias_or_name] for c in outer_columns))
|
|
711
|
+
|
|
712
|
+
@operation(Operation.FROM)
def replace(
    self,
    to_replace: t.Union[bool, int, float, str, t.List, t.Dict],
    value: t.Optional[t.Union[bool, int, float, str, t.List]] = None,
    subset: t.Optional[t.Collection[ColumnOrName] | ColumnOrName] = None,
) -> DataFrame:
    """Replace values in the selected columns.

    ``to_replace`` may be:

    * a dict mapping old values to new values (``value`` is ignored);
    * a list/tuple of old values, paired either with a list/tuple ``value`` of the same
      length, or with a single scalar ``value`` used for every item (as PySpark does);
    * a single scalar, replaced by ``value``.

    Each column in ``subset`` (default: every outer select column) becomes an
    ``F.when(col == old, new)...otherwise(col)`` chain, aliased back to its own name. Other
    columns pass through unchanged, and column order is preserved. An empty mapping or list
    leaves every column unchanged.

    Raises:
        AssertionError: if ``to_replace`` and ``value`` are both sequences of different lengths.
    """
    from sqlglot.dataframe.sql.functions import lit

    new_df = self.copy()
    all_columns = self._get_outer_select_columns(new_df.expression)
    all_column_mapping = {column.alias_or_name: column for column in all_columns}

    columns = self._ensure_and_normalize_cols(subset) if subset else all_columns
    if isinstance(to_replace, dict):
        old_values = list(to_replace)
        new_values = list(to_replace.values())
    elif isinstance(to_replace, (list, tuple)):
        old_values = list(to_replace)
        if isinstance(value, (list, tuple)):
            assert len(to_replace) == len(
                value
            ), "the replacements and values must be the same length"
            new_values = list(value)
        else:
            # A scalar replacement applies to every item being replaced (PySpark semantics)
            new_values = [value] * len(old_values)
    else:
        # A single pair is enough: it is applied to every selected column below. Repeating
        # it once per column would only emit duplicate WHEN branches.
        old_values = [to_replace]
        new_values = [value]
    replacement_pairs = [(lit(old), lit(new)) for old, new in zip(old_values, new_values)]

    replacement_mapping = {}
    # With nothing to replace there is no WHEN branch to build, so every column passes through
    if replacement_pairs:
        (first_old, first_new), *remaining_pairs = replacement_pairs
        for column in columns:
            expression = F.when(column == first_old, first_new)
            for old_value, new_value in remaining_pairs:
                expression = expression.when(column == old_value, new_value)  # type: ignore
            replacement_mapping[column.alias_or_name] = expression.otherwise(column).alias(
                column.expression.alias_or_name
            )

    replacement_mapping = {**all_column_mapping, **replacement_mapping}
    replacement_columns = [replacement_mapping[column.alias_or_name] for column in all_columns]
    return new_df.select(*replacement_columns)
|
|
759
|
+
|
|
760
|
+
@operation(Operation.SELECT)
def withColumn(self, colName: str, col: Column) -> DataFrame:
    """Add a column named ``colName``, or replace the existing projection with that name.

    If ``colName`` already appears among the query's named selects, that projection is
    replaced in place (keeping its position) by ``col`` aliased to ``colName``. Otherwise
    ``col`` aliased to ``colName`` is appended to the select list.
    """
    col = self._ensure_and_normalize_col(col)
    existing_col_names = self.expression.named_selects
    # Test membership, not the truthiness of the index: a match at position 0 must still
    # be replaced rather than appended as a duplicate column.
    if colName in existing_col_names:
        existing_col_index = existing_col_names.index(colName)
        expression = self.expression.copy()
        # Alias the replacement so the output column keeps the requested name
        expression.expressions[existing_col_index] = col.alias(colName).expression
        return self.copy(expression=expression)
    return self.copy().select(col.alias(colName), append=True)
|
|
772
|
+
|
|
773
|
+
@operation(Operation.SELECT)
def withColumnRenamed(self, existing: str, new: str):
    """Rename every projection whose alias-or-name equals ``existing`` to ``new``.

    A bare column reference is wrapped in a new alias. Any other projection has its ``alias``
    argument set to ``new``.

    Raises:
        ValueError: if no projection is named ``existing``.
    """
    renamed = self.expression.copy()
    matches = [
        projection
        for projection in renamed.expressions
        if projection.alias_or_name == existing
    ]
    if not matches:
        raise ValueError("Tried to rename a column that doesn't exist")
    for projection in matches:
        if not isinstance(projection, exp.Column):
            projection.set("alias", exp.to_identifier(new))
            continue
        projection.replace(exp.alias_(projection, new))
    return self.copy(expression=renamed)
|
|
789
|
+
|
|
790
|
+
@operation(Operation.SELECT)
def drop(self, *cols: t.Union[str, Column]) -> DataFrame:
    """Return a DataFrame without the given columns.

    Columns are matched by ``alias_or_name`` against the current outer select list. Names
    that match nothing are ignored, and the remaining columns keep their original order.
    """
    all_columns = self._get_outer_select_columns(self.expression)
    drop_cols = self._ensure_and_normalize_cols(cols)
    # Build the set of names to remove once, instead of rebuilding a list for every column
    drop_names = {drop_column.alias_or_name for drop_column in drop_cols}
    new_columns = [col for col in all_columns if col.alias_or_name not in drop_names]
    return self.copy().select(*new_columns, append=False)
|
|
800
|
+
|
|
801
|
+
@operation(Operation.LIMIT)
def limit(self, num: int) -> DataFrame:
    """Return a copy whose query is limited to ``num`` rows."""
    limited = self.expression.limit(num)
    return self.copy(expression=limited)
|
|
804
|
+
|
|
805
|
+
@operation(Operation.NO_OP)
def hint(self, name: str, *parameters: t.Optional[t.Union[str, int]]) -> DataFrame:
    """Attach hint ``name`` to this DataFrame.

    If ``parameters`` are given they become the hint's arguments. Otherwise the hint
    defaults to this DataFrame's own sequence id.
    """
    if parameters:
        hint_args = self._ensure_list_of_columns(ensure_list(parameters))
    else:
        hint_args = Column.ensure_cols([self.sequence_id])
    return self._hint(name, hint_args)
|
|
814
|
+
|
|
815
|
+
@operation(Operation.NO_OP)
def repartition(
    self, numPartitions: t.Union[int, ColumnOrName], *cols: ColumnOrName
) -> DataFrame:
    """Add a ``repartition`` hint.

    The hint's arguments are the partition count (or column) followed by the normalized
    partitioning columns.
    """
    count_args = self._ensure_list_of_columns(numPartitions)
    column_args = self._ensure_and_normalize_cols(cols)
    return self._hint("repartition", count_args + column_args)
|
|
823
|
+
|
|
824
|
+
@operation(Operation.NO_OP)
def coalesce(self, numPartitions: int) -> DataFrame:
    """Add a ``coalesce`` hint whose single argument is ``numPartitions``."""
    return self._hint("coalesce", Column.ensure_cols([numPartitions]))
|
|
828
|
+
|
|
829
|
+
@operation(Operation.NO_OP)
def cache(self) -> DataFrame:
    """Cache this DataFrame with the ``MEMORY_AND_DISK`` storage level."""
    default_level = "MEMORY_AND_DISK"
    return self._cache(storage_level=default_level)
|
|
832
|
+
|
|
833
|
+
@operation(Operation.NO_OP)
def persist(self, storageLevel: str = "MEMORY_AND_DISK_SER") -> DataFrame:
    """Cache this DataFrame with ``storageLevel`` (default ``MEMORY_AND_DISK_SER``).

    Storage level options:
    https://spark.apache.org/docs/3.0.0-preview/sql-ref-syntax-aux-cache-cache-table.html
    """
    level = storageLevel
    return self._cache(level)
|
|
839
|
+
|
|
840
|
+
|
|
841
|
+
class DataFrameNaFunctions:
    """Thin wrapper that forwards null-handling calls to a wrapped DataFrame.

    Mirrors PySpark's ``DataFrameNaFunctions``: ``drop`` -> ``dropna``, ``fill`` -> ``fillna``,
    and ``replace`` -> ``replace``, always passing arguments by keyword.
    """

    def __init__(self, df: DataFrame):
        self.df = df

    def _forward(self, method_name: str, **arguments) -> DataFrame:
        # Call the same-named DataFrame method with keyword arguments only
        return getattr(self.df, method_name)(**arguments)

    def drop(
        self,
        how: str = "any",
        thresh: t.Optional[int] = None,
        subset: t.Optional[t.Union[str, t.Tuple[str, ...], t.List[str]]] = None,
    ) -> DataFrame:
        """Forward to ``DataFrame.dropna``."""
        return self._forward("dropna", how=how, thresh=thresh, subset=subset)

    def fill(
        self,
        value: t.Union[int, bool, float, str, t.Dict[str, t.Any]],
        subset: t.Optional[t.Union[str, t.Tuple[str, ...], t.List[str]]] = None,
    ) -> DataFrame:
        """Forward to ``DataFrame.fillna``."""
        return self._forward("fillna", value=value, subset=subset)

    def replace(
        self,
        to_replace: t.Union[bool, int, float, str, t.List, t.Dict],
        value: t.Optional[t.Union[bool, int, float, str, t.List]] = None,
        subset: t.Optional[t.Union[str, t.List[str]]] = None,
    ) -> DataFrame:
        """Forward to ``DataFrame.replace``."""
        return self._forward("replace", to_replace=to_replace, value=value, subset=subset)
|