@altimateai/altimate-code 0.5.1 → 0.5.3

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (101)
  1. package/CHANGELOG.md +35 -0
  2. package/bin/altimate +6 -0
  3. package/bin/altimate-code +6 -0
  4. package/dbt-tools/bin/altimate-dbt +2 -0
  5. package/dbt-tools/dist/altimate_python_packages/altimate_packages/altimate/__init__.py +0 -0
  6. package/dbt-tools/dist/altimate_python_packages/altimate_packages/altimate/fetch_schema.py +35 -0
  7. package/dbt-tools/dist/altimate_python_packages/altimate_packages/altimate/utils.py +353 -0
  8. package/dbt-tools/dist/altimate_python_packages/altimate_packages/altimate/validate_sql.py +114 -0
  9. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/__init__.py +178 -0
  10. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/__main__.py +96 -0
  11. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/_typing.py +17 -0
  12. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dataframe/__init__.py +3 -0
  13. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dataframe/sql/__init__.py +18 -0
  14. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dataframe/sql/_typing.py +18 -0
  15. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dataframe/sql/column.py +332 -0
  16. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dataframe/sql/dataframe.py +866 -0
  17. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dataframe/sql/functions.py +1267 -0
  18. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dataframe/sql/group.py +59 -0
  19. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dataframe/sql/normalize.py +78 -0
  20. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dataframe/sql/operations.py +53 -0
  21. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dataframe/sql/readwriter.py +108 -0
  22. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dataframe/sql/session.py +190 -0
  23. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dataframe/sql/transforms.py +9 -0
  24. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dataframe/sql/types.py +212 -0
  25. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dataframe/sql/util.py +32 -0
  26. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dataframe/sql/window.py +134 -0
  27. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/__init__.py +118 -0
  28. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/athena.py +166 -0
  29. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/bigquery.py +1331 -0
  30. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/clickhouse.py +1393 -0
  31. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/databricks.py +131 -0
  32. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/dialect.py +1915 -0
  33. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/doris.py +561 -0
  34. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/drill.py +157 -0
  35. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/druid.py +20 -0
  36. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/duckdb.py +1159 -0
  37. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/dune.py +16 -0
  38. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/hive.py +787 -0
  39. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/materialize.py +94 -0
  40. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/mysql.py +1324 -0
  41. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/oracle.py +378 -0
  42. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/postgres.py +778 -0
  43. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/presto.py +788 -0
  44. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/prql.py +203 -0
  45. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/redshift.py +448 -0
  46. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/risingwave.py +78 -0
  47. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/snowflake.py +1464 -0
  48. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/spark.py +202 -0
  49. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/spark2.py +349 -0
  50. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/sqlite.py +320 -0
  51. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/starrocks.py +343 -0
  52. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/tableau.py +61 -0
  53. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/teradata.py +356 -0
  54. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/trino.py +115 -0
  55. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/dialects/tsql.py +1403 -0
  56. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/diff.py +456 -0
  57. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/errors.py +93 -0
  58. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/executor/__init__.py +95 -0
  59. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/executor/context.py +101 -0
  60. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/executor/env.py +246 -0
  61. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/executor/python.py +460 -0
  62. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/executor/table.py +155 -0
  63. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/expressions.py +8870 -0
  64. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/generator.py +4993 -0
  65. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/helper.py +582 -0
  66. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/jsonpath.py +227 -0
  67. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/lineage.py +423 -0
  68. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/__init__.py +11 -0
  69. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/annotate_types.py +589 -0
  70. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/canonicalize.py +222 -0
  71. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/eliminate_ctes.py +43 -0
  72. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/eliminate_joins.py +181 -0
  73. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/eliminate_subqueries.py +189 -0
  74. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/isolate_table_selects.py +50 -0
  75. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/merge_subqueries.py +415 -0
  76. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/normalize.py +200 -0
  77. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/normalize_identifiers.py +64 -0
  78. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/optimize_joins.py +91 -0
  79. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/optimizer.py +94 -0
  80. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/pushdown_predicates.py +222 -0
  81. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/pushdown_projections.py +172 -0
  82. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/qualify.py +104 -0
  83. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/qualify_columns.py +1024 -0
  84. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/qualify_tables.py +155 -0
  85. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/scope.py +904 -0
  86. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/simplify.py +1587 -0
  87. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/optimizer/unnest_subqueries.py +302 -0
  88. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/parser.py +8501 -0
  89. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/planner.py +463 -0
  90. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/schema.py +588 -0
  91. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/serde.py +68 -0
  92. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/time.py +687 -0
  93. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/tokens.py +1520 -0
  94. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/transforms.py +1020 -0
  95. package/dbt-tools/dist/altimate_python_packages/altimate_packages/sqlglot/trie.py +81 -0
  96. package/dbt-tools/dist/altimate_python_packages/dbt_core_integration.py +825 -0
  97. package/dbt-tools/dist/altimate_python_packages/dbt_utils.py +157 -0
  98. package/dbt-tools/dist/index.js +23859 -0
  99. package/package.json +13 -13
  100. package/postinstall.mjs +42 -0
  101. package/skills/altimate-setup/SKILL.md +31 -0
@@ -0,0 +1,866 @@
1
+ from __future__ import annotations
2
+
3
+ import functools
4
+ import logging
5
+ import typing as t
6
+ import zlib
7
+ from copy import copy
8
+
9
+ import sqlglot as sqlglot
10
+ from sqlglot import Dialect, expressions as exp
11
+ from sqlglot.dataframe.sql import functions as F
12
+ from sqlglot.dataframe.sql.column import Column
13
+ from sqlglot.dataframe.sql.group import GroupedData
14
+ from sqlglot.dataframe.sql.normalize import normalize
15
+ from sqlglot.dataframe.sql.operations import Operation, operation
16
+ from sqlglot.dataframe.sql.readwriter import DataFrameWriter
17
+ from sqlglot.dataframe.sql.transforms import replace_id_value
18
+ from sqlglot.dataframe.sql.util import get_tables_from_expression_with_join
19
+ from sqlglot.dataframe.sql.window import Window
20
+ from sqlglot.helper import ensure_list, object_to_dict, seq_get
21
+ from sqlglot.optimizer import optimize as optimize_func
22
+ from sqlglot.optimizer.qualify_columns import quote_identifiers
23
+
24
+ if t.TYPE_CHECKING:
25
+ from sqlglot.dataframe.sql._typing import (
26
+ ColumnLiterals,
27
+ ColumnOrLiteral,
28
+ ColumnOrName,
29
+ OutputExpressionContainer,
30
+ )
31
+ from sqlglot.dataframe.sql.session import SparkSession
32
+ from sqlglot.dialects.dialect import DialectType
33
+
34
logger = logging.getLogger("sqlglot")

# Hint names treated as join-strategy hints: DataFrame._hint() emits these as
# exp.JoinHint nodes; any other hint name becomes a generic exp.Anonymous hint.
JOIN_HINTS = {
    "BROADCAST",
    "BROADCASTJOIN",
    "MAPJOIN",
    "MERGE",
    "SHUFFLEMERGE",
    "MERGEJOIN",
    "SHUFFLE_HASH",
    "SHUFFLE_REPLICATE_NL",
}
46
+
47
+
48
+ class DataFrame:
49
+ def __init__(
50
+ self,
51
+ spark: SparkSession,
52
+ expression: exp.Select,
53
+ branch_id: t.Optional[str] = None,
54
+ sequence_id: t.Optional[str] = None,
55
+ last_op: Operation = Operation.INIT,
56
+ pending_hints: t.Optional[t.List[exp.Expression]] = None,
57
+ output_expression_container: t.Optional[OutputExpressionContainer] = None,
58
+ **kwargs,
59
+ ):
60
+ self.spark = spark
61
+ self.expression = expression
62
+ self.branch_id = branch_id or self.spark._random_branch_id
63
+ self.sequence_id = sequence_id or self.spark._random_sequence_id
64
+ self.last_op = last_op
65
+ self.pending_hints = pending_hints or []
66
+ self.output_expression_container = output_expression_container or exp.Select()
67
+
68
    def __getattr__(self, column_name: str) -> Column:
        """Fallback attribute access: `df.col_name` resolves to `df["col_name"]`."""
        return self[column_name]
70
+
71
+ def __getitem__(self, column_name: str) -> Column:
72
+ column_name = f"{self.branch_id}.{column_name}"
73
+ return Column(column_name)
74
+
75
    def __copy__(self):
        """Support `copy.copy()` by delegating to DataFrame.copy()."""
        return self.copy()
77
+
78
    @property
    def sparkSession(self):
        """The SparkSession this DataFrame belongs to (PySpark-compatible name)."""
        return self.spark
81
+
82
    @property
    def write(self):
        """Writer entry point, mirroring pyspark's `df.write`."""
        return DataFrameWriter(self)
85
+
86
+ @property
87
+ def latest_cte_name(self) -> str:
88
+ if not self.expression.ctes:
89
+ from_exp = self.expression.args["from"]
90
+ if from_exp.alias_or_name:
91
+ return from_exp.alias_or_name
92
+ table_alias = from_exp.find(exp.TableAlias)
93
+ if not table_alias:
94
+ raise RuntimeError(
95
+ f"Could not find an alias name for this expression: {self.expression}"
96
+ )
97
+ return table_alias.alias_or_name
98
+ return self.expression.ctes[-1].alias
99
+
100
+ @property
101
+ def pending_join_hints(self):
102
+ return [hint for hint in self.pending_hints if isinstance(hint, exp.JoinHint)]
103
+
104
+ @property
105
+ def pending_partition_hints(self):
106
+ return [hint for hint in self.pending_hints if isinstance(hint, exp.Anonymous)]
107
+
108
    @property
    def columns(self) -> t.List[str]:
        """Names of the projections in the outer SELECT."""
        return self.expression.named_selects
111
+
112
    @property
    def na(self) -> DataFrameNaFunctions:
        """Accessor for null-value handling functions, mirroring pyspark's `df.na`."""
        return DataFrameNaFunctions(self)
115
+
116
+ def _replace_cte_names_with_hashes(self, expression: exp.Select):
117
+ replacement_mapping = {}
118
+ for cte in expression.ctes:
119
+ old_name_id = cte.args["alias"].this
120
+ new_hashed_id = exp.to_identifier(
121
+ self._create_hash_from_expression(cte.this), quoted=old_name_id.args["quoted"]
122
+ )
123
+ replacement_mapping[old_name_id] = new_hashed_id
124
+ expression = expression.transform(replace_id_value, replacement_mapping)
125
+ return expression
126
+
127
+ def _create_cte_from_expression(
128
+ self,
129
+ expression: exp.Expression,
130
+ branch_id: t.Optional[str] = None,
131
+ sequence_id: t.Optional[str] = None,
132
+ **kwargs,
133
+ ) -> t.Tuple[exp.CTE, str]:
134
+ name = self._create_hash_from_expression(expression)
135
+ expression_to_cte = expression.copy()
136
+ expression_to_cte.set("with", None)
137
+ cte = exp.Select().with_(name, as_=expression_to_cte, **kwargs).ctes[0]
138
+ cte.set("branch_id", branch_id or self.branch_id)
139
+ cte.set("sequence_id", sequence_id or self.sequence_id)
140
+ return cte, name
141
+
142
    @t.overload
    def _ensure_list_of_columns(self, cols: t.Collection[ColumnOrLiteral]) -> t.List[Column]:
        ...

    @t.overload
    def _ensure_list_of_columns(self, cols: ColumnOrLiteral) -> t.List[Column]:
        ...

    def _ensure_list_of_columns(self, cols):
        # Accept a single column/literal or a collection of them and always
        # return a list of Column objects.
        return Column.ensure_cols(ensure_list(cols))
152
+
153
+ def _ensure_and_normalize_cols(self, cols, expression: t.Optional[exp.Select] = None):
154
+ cols = self._ensure_list_of_columns(cols)
155
+ normalize(self.spark, expression or self.expression, cols)
156
+ return cols
157
+
158
+ def _ensure_and_normalize_col(self, col):
159
+ col = Column.ensure_col(col)
160
+ normalize(self.spark, self.expression, col)
161
+ return col
162
+
163
+ def _convert_leaf_to_cte(self, sequence_id: t.Optional[str] = None) -> DataFrame:
164
+ df = self._resolve_pending_hints()
165
+ sequence_id = sequence_id or df.sequence_id
166
+ expression = df.expression.copy()
167
+ cte_expression, cte_name = df._create_cte_from_expression(
168
+ expression=expression, sequence_id=sequence_id
169
+ )
170
+ new_expression = df._add_ctes_to_expression(
171
+ exp.Select(), expression.ctes + [cte_expression]
172
+ )
173
+ sel_columns = df._get_outer_select_columns(cte_expression)
174
+ new_expression = new_expression.from_(cte_name).select(
175
+ *[x.alias_or_name for x in sel_columns]
176
+ )
177
+ return df.copy(expression=new_expression, sequence_id=sequence_id)
178
+
179
    def _resolve_pending_hints(self) -> DataFrame:
        """Attach pending hints to the expression, returning a copy of this DataFrame.

        Partition hints are attached unconditionally. Join hints are rewritten so
        their arguments reference the alias of the matching CTE (looked up via the
        hint's sequence id or registered alias name) before being attached; a join
        hint is only removed from the pending list once such a match is found.
        """
        df = self.copy()
        if not self.pending_hints:
            return df
        expression = df.expression
        hint_expression = expression.args.get("hint") or exp.Hint(expressions=[])
        for hint in df.pending_partition_hints:
            hint_expression.append("expressions", hint)
            df.pending_hints.remove(hint)

        join_aliases = {
            join_table.alias_or_name
            for join_table in get_tables_from_expression_with_join(expression)
        }
        if join_aliases:
            for hint in df.pending_join_hints:
                for sequence_id_expression in hint.expressions:
                    sequence_id_or_name = sequence_id_expression.alias_or_name
                    sequence_ids_to_match = [sequence_id_or_name]
                    # An aliased DataFrame may map one name to several sequence ids.
                    if sequence_id_or_name in df.spark.name_to_sequence_id_mapping:
                        sequence_ids_to_match = df.spark.name_to_sequence_id_mapping[
                            sequence_id_or_name
                        ]
                    # Most recent CTEs first so the latest matching state wins.
                    matching_ctes = [
                        cte
                        for cte in reversed(expression.ctes)
                        if cte.args["sequence_id"] in sequence_ids_to_match
                    ]
                    for matching_cte in matching_ctes:
                        if matching_cte.alias_or_name in join_aliases:
                            sequence_id_expression.set("this", matching_cte.args["alias"].this)
                            df.pending_hints.remove(hint)
                            break
                hint_expression.append("expressions", hint)
        if hint_expression.expressions:
            expression.set("hint", hint_expression)
        return df
216
+
217
+ def _hint(self, hint_name: str, args: t.List[Column]) -> DataFrame:
218
+ hint_name = hint_name.upper()
219
+ hint_expression = (
220
+ exp.JoinHint(
221
+ this=hint_name,
222
+ expressions=[exp.to_table(parameter.alias_or_name) for parameter in args],
223
+ )
224
+ if hint_name in JOIN_HINTS
225
+ else exp.Anonymous(
226
+ this=hint_name, expressions=[parameter.expression for parameter in args]
227
+ )
228
+ )
229
+ new_df = self.copy()
230
+ new_df.pending_hints.append(hint_expression)
231
+ return new_df
232
+
233
+ def _set_operation(self, klass: t.Callable, other: DataFrame, distinct: bool):
234
+ other_df = other._convert_leaf_to_cte()
235
+ base_expression = self.expression.copy()
236
+ base_expression = self._add_ctes_to_expression(base_expression, other_df.expression.ctes)
237
+ all_ctes = base_expression.ctes
238
+ other_df.expression.set("with", None)
239
+ base_expression.set("with", None)
240
+ operation = klass(this=base_expression, distinct=distinct, expression=other_df.expression)
241
+ operation.set("with", exp.With(expressions=all_ctes))
242
+ return self.copy(expression=operation)._convert_leaf_to_cte()
243
+
244
    def _cache(self, storage_level: str):
        """Convert the current expression to a CTE and tag it with `storage_level`.

        The tag is later turned into a CACHE TABLE statement by `sql()`.
        """
        df = self._convert_leaf_to_cte()
        df.expression.ctes[-1].set("cache_storage_level", storage_level)
        return df
248
+
249
+ @classmethod
250
+ def _add_ctes_to_expression(cls, expression: exp.Select, ctes: t.List[exp.CTE]) -> exp.Select:
251
+ expression = expression.copy()
252
+ with_expression = expression.args.get("with")
253
+ if with_expression:
254
+ existing_ctes = with_expression.expressions
255
+ existsing_cte_names = {x.alias_or_name for x in existing_ctes}
256
+ for cte in ctes:
257
+ if cte.alias_or_name not in existsing_cte_names:
258
+ existing_ctes.append(cte)
259
+ else:
260
+ existing_ctes = ctes
261
+ expression.set("with", exp.With(expressions=existing_ctes))
262
+ return expression
263
+
264
+ @classmethod
265
+ def _get_outer_select_columns(cls, item: t.Union[exp.Expression, DataFrame]) -> t.List[Column]:
266
+ expression = item.expression if isinstance(item, DataFrame) else item
267
+ return [Column(x) for x in (expression.find(exp.Select) or exp.Select()).expressions]
268
+
269
+ @classmethod
270
+ def _create_hash_from_expression(cls, expression: exp.Expression) -> str:
271
+ from sqlglot.dataframe.sql.session import SparkSession
272
+
273
+ value = expression.sql(dialect=SparkSession().dialect).encode("utf-8")
274
+ return f"t{zlib.crc32(value)}"[:6]
275
+
276
+ def _get_select_expressions(
277
+ self,
278
+ ) -> t.List[t.Tuple[t.Union[t.Type[exp.Cache], OutputExpressionContainer], exp.Select]]:
279
+ select_expressions: t.List[
280
+ t.Tuple[t.Union[t.Type[exp.Cache], OutputExpressionContainer], exp.Select]
281
+ ] = []
282
+ main_select_ctes: t.List[exp.CTE] = []
283
+ for cte in self.expression.ctes:
284
+ cache_storage_level = cte.args.get("cache_storage_level")
285
+ if cache_storage_level:
286
+ select_expression = cte.this.copy()
287
+ select_expression.set("with", exp.With(expressions=copy(main_select_ctes)))
288
+ select_expression.set("cte_alias_name", cte.alias_or_name)
289
+ select_expression.set("cache_storage_level", cache_storage_level)
290
+ select_expressions.append((exp.Cache, select_expression))
291
+ else:
292
+ main_select_ctes.append(cte)
293
+ main_select = self.expression.copy()
294
+ if main_select_ctes:
295
+ main_select.set("with", exp.With(expressions=main_select_ctes))
296
+ expression_select_pair = (type(self.output_expression_container), main_select)
297
+ select_expressions.append(expression_select_pair) # type: ignore
298
+ return select_expressions
299
+
300
    def sql(
        self, dialect: t.Optional[DialectType] = None, optimize: bool = True, **kwargs
    ) -> t.List[str]:
        """Render this DataFrame as a list of SQL statements.

        Cached CTEs become DROP VIEW + CACHE TABLE statement pairs; the final
        statement is a SELECT, CREATE, or INSERT depending on the output
        expression container. When `optimize` is True, identifiers are quoted and
        the sqlglot optimizer is run over each statement.
        """
        from sqlglot.dataframe.sql.session import SparkSession

        if dialect and Dialect.get_or_raise(dialect)() != SparkSession().dialect:
            logger.warning(
                f"The recommended way of defining a dialect is by doing `SparkSession.builder.config('sqlframe.dialect', '{dialect}').getOrCreate()`. It is no longer needed then when calling `sql`. If you run into issues try updating your query to use this pattern."
            )
        df = self._resolve_pending_hints()
        select_expressions = df._get_select_expressions()
        output_expressions: t.List[t.Union[exp.Select, exp.Cache, exp.Drop]] = []
        # Maps original CTE alias identifiers to cache-table identifiers so later
        # statements reference the cached table instead of the original CTE.
        replacement_mapping: t.Dict[exp.Identifier, exp.Identifier] = {}
        for expression_type, select_expression in select_expressions:
            select_expression = select_expression.transform(replace_id_value, replacement_mapping)
            if optimize:
                quote_identifiers(select_expression)
                select_expression = t.cast(
                    exp.Select, optimize_func(select_expression, dialect=SparkSession().dialect)
                )
            select_expression = df._replace_cte_names_with_hashes(select_expression)
            expression: t.Union[exp.Select, exp.Cache, exp.Drop]
            if expression_type == exp.Cache:
                cache_table_name = df._create_hash_from_expression(select_expression)
                cache_table = exp.to_table(cache_table_name)
                original_alias_name = select_expression.args["cte_alias_name"]

                replacement_mapping[exp.to_identifier(original_alias_name)] = exp.to_identifier(  # type: ignore
                    cache_table_name
                )
                # Register the cached table's schema so later optimizer passes can resolve it.
                sqlglot.schema.add_table(
                    cache_table_name,
                    {
                        expression.alias_or_name: expression.type.sql(
                            dialect=SparkSession().dialect
                        )
                        for expression in select_expression.expressions
                    },
                    dialect=SparkSession().dialect,
                )
                cache_storage_level = select_expression.args["cache_storage_level"]
                options = [
                    exp.Literal.string("storageLevel"),
                    exp.Literal.string(cache_storage_level),
                ]
                expression = exp.Cache(
                    this=cache_table, expression=select_expression, lazy=True, options=options
                )
                # We will drop the "view" if it exists before running the cache table
                output_expressions.append(exp.Drop(this=cache_table, exists=True, kind="VIEW"))
            elif expression_type == exp.Create:
                expression = df.output_expression_container.copy()
                expression.set("expression", select_expression)
            elif expression_type == exp.Insert:
                expression = df.output_expression_container.copy()
                # INSERT carries the WITH clause itself, not inside its SELECT body.
                select_without_ctes = select_expression.copy()
                select_without_ctes.set("with", None)
                expression.set("expression", select_without_ctes)
                if select_expression.ctes:
                    expression.set("with", exp.With(expressions=select_expression.ctes))
            elif expression_type == exp.Select:
                expression = select_expression
            else:
                raise ValueError(f"Invalid expression type: {expression_type}")
            output_expressions.append(expression)

        return [
            expression.sql(**{"dialect": SparkSession().dialect, **kwargs})
            for expression in output_expressions
        ]
370
+
371
+ def copy(self, **kwargs) -> DataFrame:
372
+ return DataFrame(**object_to_dict(self, **kwargs))
373
+
374
    @operation(Operation.SELECT)
    def select(self, *cols, **kwargs) -> DataFrame:
        """Project the given columns, resolving unqualified names against joined CTEs.

        When the expression has joins, columns with no table qualifier are assigned
        to joined CTEs left-to-right, allowing repeated names to resolve to
        successive CTEs (matching Spark's behavior).
        """
        cols = self._ensure_and_normalize_cols(cols)
        kwargs["append"] = kwargs.get("append", False)
        if self.expression.args.get("joins"):
            ambiguous_cols = [
                col
                for col in cols
                if isinstance(col.column_expression, exp.Column) and not col.column_expression.table
            ]
            if ambiguous_cols:
                join_table_identifiers = [
                    x.this for x in get_tables_from_expression_with_join(self.expression)
                ]
                cte_names_in_join = [x.this for x in join_table_identifiers]
                # If we have columns that resolve to multiple CTE expressions then we want to use each CTE left-to-right
                # and therefore we allow multiple columns with the same name in the result. This matches the behavior
                # of Spark.
                resolved_column_position: t.Dict[Column, int] = {col: -1 for col in ambiguous_cols}
                for ambiguous_col in ambiguous_cols:
                    ctes_with_column = [
                        cte
                        for cte in self.expression.ctes
                        if cte.alias_or_name in cte_names_in_join
                        and ambiguous_col.alias_or_name in cte.this.named_selects
                    ]
                    # Check if there is a CTE with this column that we haven't used before. If so, use it. Otherwise,
                    # use the same CTE we used before
                    cte = seq_get(ctes_with_column, resolved_column_position[ambiguous_col] + 1)
                    if cte:
                        resolved_column_position[ambiguous_col] += 1
                    else:
                        cte = ctes_with_column[resolved_column_position[ambiguous_col]]
                    ambiguous_col.expression.set("table", cte.alias_or_name)
        return self.copy(
            expression=self.expression.select(*[x.expression for x in cols], **kwargs), **kwargs
        )
411
+
412
+ @operation(Operation.NO_OP)
413
+ def alias(self, name: str, **kwargs) -> DataFrame:
414
+ new_sequence_id = self.spark._random_sequence_id
415
+ df = self.copy()
416
+ for join_hint in df.pending_join_hints:
417
+ for expression in join_hint.expressions:
418
+ if expression.alias_or_name == self.sequence_id:
419
+ expression.set("this", Column.ensure_col(new_sequence_id).expression)
420
+ df.spark._add_alias_to_mapping(name, new_sequence_id)
421
+ return df._convert_leaf_to_cte(sequence_id=new_sequence_id)
422
+
423
+ @operation(Operation.WHERE)
424
+ def where(self, column: t.Union[Column, bool], **kwargs) -> DataFrame:
425
+ col = self._ensure_and_normalize_col(column)
426
+ return self.copy(expression=self.expression.where(col.expression))
427
+
428
+ filter = where
429
+
430
+ @operation(Operation.GROUP_BY)
431
+ def groupBy(self, *cols, **kwargs) -> GroupedData:
432
+ columns = self._ensure_and_normalize_cols(cols)
433
+ return GroupedData(self, columns, self.last_op)
434
+
435
+ @operation(Operation.SELECT)
436
+ def agg(self, *exprs, **kwargs) -> DataFrame:
437
+ cols = self._ensure_and_normalize_cols(exprs)
438
+ return self.groupBy().agg(*cols)
439
+
440
    @operation(Operation.FROM)
    def join(
        self,
        other_df: DataFrame,
        on: t.Union[str, t.List[str], Column, t.List[Column]],
        how: str = "inner",
        **kwargs,
    ) -> DataFrame:
        """Join this DataFrame with `other_df` on column names or an expression.

        Column-name joins deduplicate the join columns and put them first in the
        result; expression joins keep left columns then right columns unchanged.
        Raises ValueError when a join column name is ambiguous or missing.
        """
        other_df = other_df._convert_leaf_to_cte()
        join_columns = self._ensure_list_of_columns(on)
        # We will determine actual "join on" expression later so we don't provide it at first
        join_expression = self.expression.join(
            other_df.latest_cte_name, join_type=how.replace("_", " ")
        )
        join_expression = self._add_ctes_to_expression(join_expression, other_df.expression.ctes)
        self_columns = self._get_outer_select_columns(join_expression)
        other_columns = self._get_outer_select_columns(other_df)
        # Determines the join clause and select columns to be used based on what type of columns were provided for
        # the join. The columns returned changes based on how the on expression is provided.
        if isinstance(join_columns[0].expression, exp.Column):
            """
            Unique characteristics of join on column names only:
            * The column names are put at the front of the select list
            * The column names are deduplicated across the entire select list and only the column names (other dups are allowed)
            """
            table_names = [
                table.alias_or_name
                for table in get_tables_from_expression_with_join(join_expression)
            ]
            potential_ctes = [
                cte
                for cte in join_expression.ctes
                if cte.alias_or_name in table_names
                and cte.alias_or_name != other_df.latest_cte_name
            ]
            # Determine the table to reference for the left side of the join by checking each of the left side
            # tables and see if they have the column being referenced.
            join_column_pairs = []
            for join_column in join_columns:
                num_matching_ctes = 0
                for cte in potential_ctes:
                    if join_column.alias_or_name in cte.this.named_selects:
                        left_column = join_column.copy().set_table_name(cte.alias_or_name)
                        right_column = join_column.copy().set_table_name(other_df.latest_cte_name)
                        join_column_pairs.append((left_column, right_column))
                        num_matching_ctes += 1
                if num_matching_ctes > 1:
                    raise ValueError(
                        f"Column {join_column.alias_or_name} is ambiguous. Please specify the table name."
                    )
                elif num_matching_ctes == 0:
                    raise ValueError(
                        f"Column {join_column.alias_or_name} does not exist in any of the tables."
                    )
            join_clause = functools.reduce(
                lambda x, y: x & y,
                [left_column == right_column for left_column, right_column in join_column_pairs],
            )
            join_column_names = [left_col.alias_or_name for left_col, _ in join_column_pairs]
            # To match spark behavior only the join clause gets deduplicated and it gets put in the front of the column list
            select_column_names = [
                column.alias_or_name
                if not isinstance(column.expression.this, exp.Star)
                else column.sql()
                for column in self_columns + other_columns
            ]
            select_column_names = [
                column_name
                for column_name in select_column_names
                if column_name not in join_column_names
            ]
            select_column_names = join_column_names + select_column_names
        else:
            """
            Unique characteristics of join on expressions:
            * There is no deduplication of the results.
            * The left join dataframe columns go first and right come after. No sort preference is given to join columns
            """
            join_columns = self._ensure_and_normalize_cols(join_columns, join_expression)
            if len(join_columns) > 1:
                join_columns = [functools.reduce(lambda x, y: x & y, join_columns)]
            join_clause = join_columns[0]
            select_column_names = [column.alias_or_name for column in self_columns + other_columns]

        # Update the on expression with the actual join clause to replace the dummy one from before
        join_expression.args["joins"][-1].set("on", join_clause.expression)
        new_df = self.copy(expression=join_expression)
        # NOTE(review): pending_join_hints is a freshly-built list property, so this
        # extend looks like a no-op — pending hints appear to be carried over by
        # self.copy() already; confirm before relying on it.
        new_df.pending_join_hints.extend(self.pending_join_hints)
        new_df.pending_hints.extend(other_df.pending_hints)
        # Bypass the @operation wrapper so the select doesn't register as a new user operation.
        new_df = new_df.select.__wrapped__(new_df, *select_column_names)
        return new_df
531
+
532
+ @operation(Operation.ORDER_BY)
533
+ def orderBy(
534
+ self,
535
+ *cols: t.Union[str, Column],
536
+ ascending: t.Optional[t.Union[t.Any, t.List[t.Any]]] = None,
537
+ ) -> DataFrame:
538
+ """
539
+ This implementation lets any ordered columns take priority over whatever is provided in `ascending`. Spark
540
+ has irregular behavior and can result in runtime errors. Users shouldn't be mixing the two anyways so this
541
+ is unlikely to come up.
542
+ """
543
+ columns = self._ensure_and_normalize_cols(cols)
544
+ pre_ordered_col_indexes = [
545
+ x
546
+ for x in [
547
+ i if isinstance(col.expression, exp.Ordered) else None
548
+ for i, col in enumerate(columns)
549
+ ]
550
+ if x is not None
551
+ ]
552
+ if ascending is None:
553
+ ascending = [True] * len(columns)
554
+ elif not isinstance(ascending, list):
555
+ ascending = [ascending] * len(columns)
556
+ ascending = [bool(x) for i, x in enumerate(ascending)]
557
+ assert len(columns) == len(
558
+ ascending
559
+ ), "The length of items in ascending must equal the number of columns provided"
560
+ col_and_ascending = list(zip(columns, ascending))
561
+ order_by_columns = [
562
+ exp.Ordered(this=col.expression, desc=not asc)
563
+ if i not in pre_ordered_col_indexes
564
+ else columns[i].column_expression
565
+ for i, (col, asc) in enumerate(col_and_ascending)
566
+ ]
567
+ return self.copy(expression=self.expression.order_by(*order_by_columns))
568
+
569
+ sort = orderBy
570
+
571
    @operation(Operation.FROM)
    def union(self, other: DataFrame) -> DataFrame:
        """Combine the rows of both DataFrames by position, keeping duplicates (UNION ALL semantics)."""
        return self._set_operation(exp.Union, other, False)

    # PySpark alias: unionAll behaves identically to union.
    unionAll = union
576
+
577
    @operation(Operation.FROM)
    def unionByName(self, other: DataFrame, allowMissingColumns: bool = False):
        """
        Union `self` and `other` by column name rather than by position.

        With allowMissingColumns=False, both sides are projected onto this
        DataFrame's column list so the select lists line up by name.
        With allowMissingColumns=True, a name present on only one side is
        NULL-padded on the other side: left-only columns keep their position,
        right-only columns are appended at the end.
        """
        l_columns = self.columns
        r_columns = other.columns
        if not allowMissingColumns:
            # Both sides select the left-hand column list, aligning by name.
            l_expressions = l_columns
            r_expressions = l_columns
        else:
            l_expressions = []
            r_expressions = []
            r_columns_unused = copy(r_columns)
            for l_column in l_columns:
                l_expressions.append(l_column)
                if l_column in r_columns:
                    r_expressions.append(l_column)
                    r_columns_unused.remove(l_column)
                else:
                    # Left-only column: NULL placeholder on the right.
                    r_expressions.append(exp.alias_(exp.Null(), l_column, copy=False))
            for r_column in r_columns_unused:
                # Right-only columns go last, NULL-padded on the left.
                l_expressions.append(exp.alias_(exp.Null(), r_column, copy=False))
                r_expressions.append(r_column)
        r_df = (
            other.copy()._convert_leaf_to_cte().select(*self._ensure_list_of_columns(r_expressions))
        )
        l_df = self.copy()
        if allowMissingColumns:
            l_df = l_df._convert_leaf_to_cte().select(*self._ensure_list_of_columns(l_expressions))
        return l_df._set_operation(exp.Union, r_df, False)
605
+
606
    @operation(Operation.FROM)
    def intersect(self, other: DataFrame) -> DataFrame:
        """Rows present in both DataFrames; the True flag presumably requests DISTINCT semantics (contrast intersectAll)."""
        return self._set_operation(exp.Intersect, other, True)
609
+
610
    @operation(Operation.FROM)
    def intersectAll(self, other: DataFrame) -> DataFrame:
        """Rows present in both DataFrames, keeping duplicates (INTERSECT ALL)."""
        return self._set_operation(exp.Intersect, other, False)
613
+
614
    @operation(Operation.FROM)
    def exceptAll(self, other: DataFrame) -> DataFrame:
        """Rows of this DataFrame not present in `other`, keeping duplicates (EXCEPT ALL)."""
        return self._set_operation(exp.Except, other, False)
617
+
618
    @operation(Operation.SELECT)
    def distinct(self) -> DataFrame:
        """Return a copy whose SELECT is marked DISTINCT."""
        return self.copy(expression=self.expression.distinct())
621
+
622
+ @operation(Operation.SELECT)
623
+ def dropDuplicates(self, subset: t.Optional[t.List[str]] = None):
624
+ if not subset:
625
+ return self.distinct()
626
+ column_names = ensure_list(subset)
627
+ window = Window.partitionBy(*column_names).orderBy(*column_names)
628
+ return (
629
+ self.copy()
630
+ .withColumn("row_num", F.row_number().over(window))
631
+ .where(F.col("row_num") == F.lit(1))
632
+ .drop("row_num")
633
+ )
634
+
635
    @operation(Operation.FROM)
    def dropna(
        self,
        how: str = "any",
        thresh: t.Optional[int] = None,
        subset: t.Optional[t.Union[str, t.Tuple[str, ...], t.List[str]]] = None,
    ) -> DataFrame:
        """
        Drop rows containing nulls.

        how: "any" drops a row with at least one null in the checked columns;
            anything else requires every checked column to be null. Ignored
            when `thresh` is given.
        thresh: keep only rows with at least `thresh` non-null checked values.
        subset: optional column name(s) restricting which columns are checked.
        """
        minimum_non_null = thresh or 0  # will be determined later if thresh is null
        new_df = self.copy()
        all_columns = self._get_outer_select_columns(new_df.expression)
        if subset:
            null_check_columns = self._ensure_and_normalize_cols(subset)
        else:
            null_check_columns = all_columns
        if thresh is None:
            minimum_num_nulls = 1 if how == "any" else len(null_check_columns)
        else:
            # A row fails `thresh` once it has this many nulls: one more null
            # than the checked-column count can absorb while still leaving
            # `thresh` non-null values.
            minimum_num_nulls = len(null_check_columns) - minimum_non_null + 1
            if minimum_num_nulls > len(null_check_columns):
                raise RuntimeError(
                    f"The minimum num nulls for dropna must be less than or equal to the number of columns. "
                    f"Minimum num nulls: {minimum_num_nulls}, Num Columns: {len(null_check_columns)}"
                )
        # Per-row null count: a 1/0 flag per checked column, summed together.
        if_null_checks = [
            F.when(column.isNull(), F.lit(1)).otherwise(F.lit(0)) for column in null_check_columns
        ]
        nulls_added_together = functools.reduce(lambda x, y: x + y, if_null_checks)
        num_nulls = nulls_added_together.alias("num_nulls")
        new_df = new_df.select(num_nulls, append=True)
        # Keep rows whose null count stays below the drop threshold.
        filtered_df = new_df.where(F.col("num_nulls") < F.lit(minimum_num_nulls))
        # Re-select the original columns, discarding the helper num_nulls.
        final_df = filtered_df.select(*all_columns)
        return final_df
667
+
668
    @operation(Operation.FROM)
    def fillna(
        self,
        value: t.Union[ColumnLiterals],
        subset: t.Optional[t.Union[str, t.Tuple[str, ...], t.List[str]]] = None,
    ) -> DataFrame:
        """
        Replace nulls with `value` (or per-column values when `value` is a dict).

        Functionality Difference: If you provide a value to replace a null and that type conflicts
        with the type of the column then PySpark will just ignore your replacement.
        This will try to cast them to be the same in some cases. So they won't always match.
        Best to not mix types so make sure replacement is the same type as the column

        Possibility for improvement: Use `typeof` function to get the type of the column
        and check if it matches the type of the value provided. If not then make it null.
        """
        from sqlglot.dataframe.sql.functions import lit

        values = None
        columns = None
        new_df = self.copy()
        all_columns = self._get_outer_select_columns(new_df.expression)
        all_column_mapping = {column.alias_or_name: column for column in all_columns}
        if isinstance(value, dict):
            # Dict form: keys name the target columns, values are replacements.
            values = list(value.values())
            columns = self._ensure_and_normalize_cols(list(value))
        if not columns:
            # Scalar form: target `subset` if given, otherwise every column.
            columns = self._ensure_and_normalize_cols(subset) if subset else all_columns
        if not values:
            # Broadcast the single scalar across all target columns.
            values = [value] * len(columns)
        value_columns = [lit(value) for value in values]

        null_replacement_mapping = {
            column.alias_or_name: (
                F.when(column.isNull(), value).otherwise(column).alias(column.alias_or_name)
            )
            for column, value in zip(columns, value_columns)
        }
        # Untouched columns pass through unchanged; targeted ones override.
        null_replacement_mapping = {**all_column_mapping, **null_replacement_mapping}
        null_replacement_columns = [
            null_replacement_mapping[column.alias_or_name] for column in all_columns
        ]
        new_df = new_df.select(*null_replacement_columns)
        return new_df
711
+
712
    @operation(Operation.FROM)
    def replace(
        self,
        to_replace: t.Union[bool, int, float, str, t.List, t.Dict],
        value: t.Optional[t.Union[bool, int, float, str, t.List]] = None,
        subset: t.Optional[t.Collection[ColumnOrName] | ColumnOrName] = None,
    ) -> DataFrame:
        """
        Replace matching values with new ones, column by column.

        to_replace: a scalar, a list of old values (paired with a `value` list
            of equal length), or a dict mapping old value -> new value.
        value: the replacement(s) when `to_replace` is not a dict.
        subset: optional column(s) restricting where replacement happens.
        """
        from sqlglot.dataframe.sql.functions import lit

        old_values = None
        new_df = self.copy()
        all_columns = self._get_outer_select_columns(new_df.expression)
        all_column_mapping = {column.alias_or_name: column for column in all_columns}

        columns = self._ensure_and_normalize_cols(subset) if subset else all_columns
        if isinstance(to_replace, dict):
            old_values = list(to_replace)
            new_values = list(to_replace.values())
        # NOTE(review): `not old_values` is always True here (old_values is
        # still None on this branch) — the isinstance check alone decides.
        elif not old_values and isinstance(to_replace, list):
            assert isinstance(value, list), "value must be a list since the replacements are a list"
            assert len(to_replace) == len(
                value
            ), "the replacements and values must be the same length"
            old_values = to_replace
            new_values = value
        else:
            # Scalar form: one old/new pair broadcast across all target columns.
            old_values = [to_replace] * len(columns)
            new_values = [value] * len(columns)
        old_values = [lit(value) for value in old_values]
        new_values = [lit(value) for value in new_values]

        # Build a CASE WHEN chain per column covering every old/new pair,
        # falling back to the original column value.
        replacement_mapping = {}
        for column in columns:
            expression = Column(None)
            for i, (old_value, new_value) in enumerate(zip(old_values, new_values)):
                if i == 0:
                    expression = F.when(column == old_value, new_value)
                else:
                    expression = expression.when(column == old_value, new_value)  # type: ignore
            replacement_mapping[column.alias_or_name] = expression.otherwise(column).alias(
                column.expression.alias_or_name
            )

        # Untouched columns pass through unchanged; targeted ones override.
        replacement_mapping = {**all_column_mapping, **replacement_mapping}
        replacement_columns = [replacement_mapping[column.alias_or_name] for column in all_columns]
        new_df = new_df.select(*replacement_columns)
        return new_df
759
+
760
+ @operation(Operation.SELECT)
761
+ def withColumn(self, colName: str, col: Column) -> DataFrame:
762
+ col = self._ensure_and_normalize_col(col)
763
+ existing_col_names = self.expression.named_selects
764
+ existing_col_index = (
765
+ existing_col_names.index(colName) if colName in existing_col_names else None
766
+ )
767
+ if existing_col_index:
768
+ expression = self.expression.copy()
769
+ expression.expressions[existing_col_index] = col.expression
770
+ return self.copy(expression=expression)
771
+ return self.copy().select(col.alias(colName), append=True)
772
+
773
+ @operation(Operation.SELECT)
774
+ def withColumnRenamed(self, existing: str, new: str):
775
+ expression = self.expression.copy()
776
+ existing_columns = [
777
+ expression
778
+ for expression in expression.expressions
779
+ if expression.alias_or_name == existing
780
+ ]
781
+ if not existing_columns:
782
+ raise ValueError("Tried to rename a column that doesn't exist")
783
+ for existing_column in existing_columns:
784
+ if isinstance(existing_column, exp.Column):
785
+ existing_column.replace(exp.alias_(existing_column, new))
786
+ else:
787
+ existing_column.set("alias", exp.to_identifier(new))
788
+ return self.copy(expression=expression)
789
+
790
+ @operation(Operation.SELECT)
791
+ def drop(self, *cols: t.Union[str, Column]) -> DataFrame:
792
+ all_columns = self._get_outer_select_columns(self.expression)
793
+ drop_cols = self._ensure_and_normalize_cols(cols)
794
+ new_columns = [
795
+ col
796
+ for col in all_columns
797
+ if col.alias_or_name not in [drop_column.alias_or_name for drop_column in drop_cols]
798
+ ]
799
+ return self.copy().select(*new_columns, append=False)
800
+
801
    @operation(Operation.LIMIT)
    def limit(self, num: int) -> DataFrame:
        """Return a copy with a LIMIT of `num` rows applied."""
        return self.copy(expression=self.expression.limit(num))
804
+
805
    @operation(Operation.NO_OP)
    def hint(self, name: str, *parameters: t.Optional[t.Union[str, int]]) -> DataFrame:
        """
        Register a SQL hint by name.

        When no parameters are given, this DataFrame's sequence id is used as
        the hint argument — presumably so join hints can be matched back to
        this DataFrame later (confirm against `_hint`).
        """
        parameter_list = ensure_list(parameters)
        parameter_columns = (
            self._ensure_list_of_columns(parameter_list)
            if parameters
            else Column.ensure_cols([self.sequence_id])
        )
        return self._hint(name, parameter_columns)
814
+
815
    @operation(Operation.NO_OP)
    def repartition(
        self, numPartitions: t.Union[int, ColumnOrName], *cols: ColumnOrName
    ) -> DataFrame:
        """Register a `repartition` hint; a NO_OP for the generated SQL itself."""
        num_partition_cols = self._ensure_list_of_columns(numPartitions)
        columns = self._ensure_and_normalize_cols(cols)
        # Hint arguments: partition count first, then the partitioning columns.
        args = num_partition_cols + columns
        return self._hint("repartition", args)
823
+
824
    @operation(Operation.NO_OP)
    def coalesce(self, numPartitions: int) -> DataFrame:
        """Register a `coalesce` hint; a NO_OP for the generated SQL itself."""
        num_partitions = Column.ensure_cols([numPartitions])
        return self._hint("coalesce", num_partitions)
828
+
829
    @operation(Operation.NO_OP)
    def cache(self) -> DataFrame:
        """Cache with the MEMORY_AND_DISK storage level (PySpark's `cache` default)."""
        return self._cache(storage_level="MEMORY_AND_DISK")
832
+
833
    @operation(Operation.NO_OP)
    def persist(self, storageLevel: str = "MEMORY_AND_DISK_SER") -> DataFrame:
        """
        Cache with an explicit storage level.

        Storage Level Options: https://spark.apache.org/docs/3.0.0-preview/sql-ref-syntax-aux-cache-cache-table.html
        """
        return self._cache(storageLevel)
839
+
840
+
841
class DataFrameNaFunctions:
    """Null-handling namespace mirroring PySpark's `DataFrame.na`; each method delegates to the wrapped DataFrame."""

    def __init__(self, df: DataFrame):
        # The DataFrame whose null-handling methods this object exposes.
        self.df = df

    def drop(
        self,
        how: str = "any",
        thresh: t.Optional[int] = None,
        subset: t.Optional[t.Union[str, t.Tuple[str, ...], t.List[str]]] = None,
    ) -> DataFrame:
        """Delegate to `DataFrame.dropna`."""
        return self.df.dropna(how=how, thresh=thresh, subset=subset)

    def fill(
        self,
        value: t.Union[int, bool, float, str, t.Dict[str, t.Any]],
        subset: t.Optional[t.Union[str, t.Tuple[str, ...], t.List[str]]] = None,
    ) -> DataFrame:
        """Delegate to `DataFrame.fillna`."""
        return self.df.fillna(value=value, subset=subset)

    def replace(
        self,
        to_replace: t.Union[bool, int, float, str, t.List, t.Dict],
        value: t.Optional[t.Union[bool, int, float, str, t.List]] = None,
        subset: t.Optional[t.Union[str, t.List[str]]] = None,
    ) -> DataFrame:
        """Delegate to `DataFrame.replace`."""
        return self.df.replace(to_replace=to_replace, value=value, subset=subset)