sqlframe 1.1.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sqlframe/__init__.py +0 -0
- sqlframe/_version.py +16 -0
- sqlframe/base/__init__.py +0 -0
- sqlframe/base/_typing.py +39 -0
- sqlframe/base/catalog.py +1163 -0
- sqlframe/base/column.py +388 -0
- sqlframe/base/dataframe.py +1519 -0
- sqlframe/base/decorators.py +51 -0
- sqlframe/base/exceptions.py +14 -0
- sqlframe/base/function_alternatives.py +1055 -0
- sqlframe/base/functions.py +1678 -0
- sqlframe/base/group.py +102 -0
- sqlframe/base/mixins/__init__.py +0 -0
- sqlframe/base/mixins/catalog_mixins.py +419 -0
- sqlframe/base/mixins/readwriter_mixins.py +118 -0
- sqlframe/base/normalize.py +84 -0
- sqlframe/base/operations.py +87 -0
- sqlframe/base/readerwriter.py +679 -0
- sqlframe/base/session.py +585 -0
- sqlframe/base/transforms.py +13 -0
- sqlframe/base/types.py +418 -0
- sqlframe/base/util.py +242 -0
- sqlframe/base/window.py +139 -0
- sqlframe/bigquery/__init__.py +23 -0
- sqlframe/bigquery/catalog.py +255 -0
- sqlframe/bigquery/column.py +1 -0
- sqlframe/bigquery/dataframe.py +54 -0
- sqlframe/bigquery/functions.py +378 -0
- sqlframe/bigquery/group.py +14 -0
- sqlframe/bigquery/readwriter.py +29 -0
- sqlframe/bigquery/session.py +89 -0
- sqlframe/bigquery/types.py +1 -0
- sqlframe/bigquery/window.py +1 -0
- sqlframe/duckdb/__init__.py +20 -0
- sqlframe/duckdb/catalog.py +108 -0
- sqlframe/duckdb/column.py +1 -0
- sqlframe/duckdb/dataframe.py +55 -0
- sqlframe/duckdb/functions.py +47 -0
- sqlframe/duckdb/group.py +14 -0
- sqlframe/duckdb/readwriter.py +111 -0
- sqlframe/duckdb/session.py +65 -0
- sqlframe/duckdb/types.py +1 -0
- sqlframe/duckdb/window.py +1 -0
- sqlframe/postgres/__init__.py +23 -0
- sqlframe/postgres/catalog.py +106 -0
- sqlframe/postgres/column.py +1 -0
- sqlframe/postgres/dataframe.py +54 -0
- sqlframe/postgres/functions.py +61 -0
- sqlframe/postgres/group.py +14 -0
- sqlframe/postgres/readwriter.py +29 -0
- sqlframe/postgres/session.py +68 -0
- sqlframe/postgres/types.py +1 -0
- sqlframe/postgres/window.py +1 -0
- sqlframe/redshift/__init__.py +23 -0
- sqlframe/redshift/catalog.py +127 -0
- sqlframe/redshift/column.py +1 -0
- sqlframe/redshift/dataframe.py +54 -0
- sqlframe/redshift/functions.py +18 -0
- sqlframe/redshift/group.py +14 -0
- sqlframe/redshift/readwriter.py +29 -0
- sqlframe/redshift/session.py +53 -0
- sqlframe/redshift/types.py +1 -0
- sqlframe/redshift/window.py +1 -0
- sqlframe/snowflake/__init__.py +26 -0
- sqlframe/snowflake/catalog.py +134 -0
- sqlframe/snowflake/column.py +1 -0
- sqlframe/snowflake/dataframe.py +54 -0
- sqlframe/snowflake/functions.py +18 -0
- sqlframe/snowflake/group.py +14 -0
- sqlframe/snowflake/readwriter.py +29 -0
- sqlframe/snowflake/session.py +53 -0
- sqlframe/snowflake/types.py +1 -0
- sqlframe/snowflake/window.py +1 -0
- sqlframe/spark/__init__.py +23 -0
- sqlframe/spark/catalog.py +1028 -0
- sqlframe/spark/column.py +1 -0
- sqlframe/spark/dataframe.py +54 -0
- sqlframe/spark/functions.py +22 -0
- sqlframe/spark/group.py +14 -0
- sqlframe/spark/readwriter.py +29 -0
- sqlframe/spark/session.py +90 -0
- sqlframe/spark/types.py +1 -0
- sqlframe/spark/window.py +1 -0
- sqlframe/standalone/__init__.py +26 -0
- sqlframe/standalone/catalog.py +13 -0
- sqlframe/standalone/column.py +1 -0
- sqlframe/standalone/dataframe.py +36 -0
- sqlframe/standalone/functions.py +1 -0
- sqlframe/standalone/group.py +14 -0
- sqlframe/standalone/readwriter.py +19 -0
- sqlframe/standalone/session.py +40 -0
- sqlframe/standalone/types.py +1 -0
- sqlframe/standalone/window.py +1 -0
- sqlframe-1.1.3.dist-info/LICENSE +21 -0
- sqlframe-1.1.3.dist-info/METADATA +172 -0
- sqlframe-1.1.3.dist-info/RECORD +98 -0
- sqlframe-1.1.3.dist-info/WHEEL +5 -0
- sqlframe-1.1.3.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,1519 @@
|
|
|
1
|
+
# This code is based on code from Apache Spark under the license found in the LICENSE file located in the 'sqlframe' folder.
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import functools
|
|
6
|
+
import itertools
|
|
7
|
+
import logging
|
|
8
|
+
import sys
|
|
9
|
+
import typing as t
|
|
10
|
+
import zlib
|
|
11
|
+
from copy import copy
|
|
12
|
+
|
|
13
|
+
import sqlglot
|
|
14
|
+
from prettytable import PrettyTable
|
|
15
|
+
from sqlglot import Dialect
|
|
16
|
+
from sqlglot import expressions as exp
|
|
17
|
+
from sqlglot.helper import ensure_list, object_to_dict, seq_get
|
|
18
|
+
from sqlglot.optimizer.qualify_columns import quote_identifiers
|
|
19
|
+
|
|
20
|
+
from sqlframe.base.operations import Operation, operation
|
|
21
|
+
from sqlframe.base.transforms import replace_id_value
|
|
22
|
+
from sqlframe.base.util import (
|
|
23
|
+
get_func_from_session,
|
|
24
|
+
get_tables_from_expression_with_join,
|
|
25
|
+
)
|
|
26
|
+
|
|
27
|
+
if sys.version_info >= (3, 11):
|
|
28
|
+
from typing import Self
|
|
29
|
+
else:
|
|
30
|
+
from typing_extensions import Self
|
|
31
|
+
|
|
32
|
+
if t.TYPE_CHECKING:
|
|
33
|
+
import pandas as pd
|
|
34
|
+
from sqlglot.dialects.dialect import DialectType
|
|
35
|
+
|
|
36
|
+
from sqlframe.base._typing import (
|
|
37
|
+
ColumnOrLiteral,
|
|
38
|
+
ColumnOrName,
|
|
39
|
+
OutputExpressionContainer,
|
|
40
|
+
PrimitiveType,
|
|
41
|
+
StorageLevel,
|
|
42
|
+
)
|
|
43
|
+
from sqlframe.base.column import Column
|
|
44
|
+
from sqlframe.base.group import _BaseGroupedData
|
|
45
|
+
from sqlframe.base.session import WRITER, _BaseSession
|
|
46
|
+
from sqlframe.base.types import Row, StructType
|
|
47
|
+
|
|
48
|
+
SESSION = t.TypeVar("SESSION", bound=_BaseSession)
|
|
49
|
+
GROUP_DATA = t.TypeVar("GROUP_DATA", bound=_BaseGroupedData)
|
|
50
|
+
else:
|
|
51
|
+
WRITER = t.TypeVar("WRITER")
|
|
52
|
+
SESSION = t.TypeVar("SESSION")
|
|
53
|
+
GROUP_DATA = t.TypeVar("GROUP_DATA")
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
logger = logging.getLogger(__name__)

# Hint names (upper-cased) that `_hint` turns into `exp.JoinHint` nodes; any other hint
# name becomes an `exp.Anonymous` (partition-style) hint. These mirror Spark SQL's
# join-strategy hints and their aliases.
JOIN_HINTS = {
    # Broadcast hash join.
    "BROADCAST",
    "BROADCASTJOIN",
    "MAPJOIN",
    # Sort-merge join. Spark's documented alias is SHUFFLE_MERGE; the historical
    # SHUFFLEMERGE spelling is kept so existing callers keep working.
    "MERGE",
    "SHUFFLE_MERGE",
    "SHUFFLEMERGE",
    "MERGEJOIN",
    # Shuffle hash join and shuffle-and-replicate nested loop join.
    "SHUFFLE_HASH",
    "SHUFFLE_REPLICATE_NL",
}

# Type variable for concrete DataFrame subclasses, used by the NA/stat accessor classes.
DF = t.TypeVar("DF", bound="_BaseDataFrame")
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
class _BaseDataFrameNaFunctions(t.Generic[DF]):
    """Accessor behind ``DataFrame.na``.

    Every method is a thin forwarder to the matching method of the wrapped DataFrame.
    """

    def __init__(self, df: DF):
        # DataFrame whose dropna / fillna / replace methods receive the delegated calls.
        self.df = df

    def drop(
        self,
        how: str = "any",
        thresh: t.Optional[int] = None,
        subset: t.Optional[t.Union[str, t.Tuple[str, ...], t.List[str]]] = None,
    ) -> DF:
        """Forward to ``self.df.dropna`` with the same keyword arguments."""
        options = dict(how=how, thresh=thresh, subset=subset)
        return self.df.dropna(**options)

    @t.overload
    def fill(self, value: PrimitiveType, subset: t.Optional[t.List[str]] = ...) -> DF: ...

    @t.overload
    def fill(self, value: t.Dict[str, PrimitiveType]) -> DF: ...

    def fill(
        self,
        value: t.Union[PrimitiveType, t.Dict[str, PrimitiveType]],
        subset: t.Optional[t.Union[str, t.Tuple[str, ...], t.List[str]]] = None,
    ) -> DF:
        """Forward to ``self.df.fillna`` with the same keyword arguments."""
        options = dict(value=value, subset=subset)
        return self.df.fillna(**options)

    def replace(
        self,
        to_replace: t.Union[bool, int, float, str, t.List, t.Dict],
        value: t.Optional[t.Union[bool, int, float, str, t.List]] = None,
        subset: t.Optional[t.Union[str, t.List[str]]] = None,
    ) -> DF:
        """Forward to ``self.df.replace`` with the same keyword arguments."""
        options = dict(to_replace=to_replace, value=value, subset=subset)
        return self.df.replace(**options)


# Type variable for concrete NA-accessor subclasses (see ``_BaseDataFrame._na``).
NA = t.TypeVar("NA", bound=_BaseDataFrameNaFunctions)
|
|
108
|
+
|
|
109
|
+
|
|
110
|
+
class _BaseDataFrameStatFunctions(t.Generic[DF]):
    """Accessor behind ``DataFrame.stat``; each method defers to the wrapped DataFrame."""

    def __init__(self, df: DF):
        # DataFrame that actually implements approxQuantile / corr / cov.
        self.df = df

    @t.overload
    def approxQuantile(
        self,
        col: str,
        probabilities: t.Union[t.List[float], t.Tuple[float]],
        relativeError: float,
    ) -> t.List[float]: ...

    @t.overload
    def approxQuantile(
        self,
        col: t.Union[t.List[str], t.Tuple[str]],
        probabilities: t.Union[t.List[float], t.Tuple[float]],
        relativeError: float,
    ) -> t.List[t.List[float]]: ...

    def approxQuantile(
        self,
        col: t.Union[str, t.List[str], t.Tuple[str]],
        probabilities: t.Union[t.List[float], t.Tuple[float]],
        relativeError: float,
    ) -> t.Union[t.List[float], t.List[t.List[float]]]:
        """Delegate to ``self.df.approxQuantile`` (positional arguments preserved)."""
        frame = self.df
        return frame.approxQuantile(col, probabilities, relativeError)

    def corr(self, col1: str, col2: str, method: str = "pearson") -> float:
        """Delegate to ``self.df.corr``."""
        frame = self.df
        return frame.corr(col1, col2, method)

    def cov(self, col1: str, col2: str) -> float:
        """Delegate to ``self.df.cov``."""
        frame = self.df
        return frame.cov(col1, col2)


# Type variable for concrete stat-accessor subclasses (see ``_BaseDataFrame._stat``).
STAT = t.TypeVar("STAT", bound=_BaseDataFrameStatFunctions)
|
|
146
|
+
|
|
147
|
+
|
|
148
|
+
class _BaseDataFrame(t.Generic[SESSION, WRITER, NA, STAT, GROUP_DATA]):
|
|
149
|
+
_na: t.Type[NA]  # class instantiated by the `na` property
_stat: t.Type[STAT]  # class instantiated by the `stat` property
_group_data: t.Type[GROUP_DATA]  # class instantiated by `groupBy`

def __init__(
    self,
    session: SESSION,
    expression: exp.Select,
    branch_id: t.Optional[str] = None,
    sequence_id: t.Optional[str] = None,
    last_op: Operation = Operation.INIT,
    pending_hints: t.Optional[t.List[exp.Expression]] = None,
    output_expression_container: t.Optional[OutputExpressionContainer] = None,
    **kwargs,
):
    """Bind this DataFrame to a session and the SELECT expression that defines it.

    Falsy ``branch_id`` / ``sequence_id`` values are replaced with fresh values from the
    session's ``_random_branch_id`` / ``_random_sequence_id``. ``pending_hints`` defaults to
    an empty list and ``output_expression_container`` to an empty ``exp.Select``.
    Extra keyword arguments are accepted and ignored; presumably this lets ``copy()``
    forward every instance attribute back through the constructor (TODO confirm).
    ``temp_views`` is always reset to an empty list.
    """
    self.session = session
    self.expression: exp.Select = expression
    self.branch_id = branch_id if branch_id else self.session._random_branch_id
    self.sequence_id = sequence_id if sequence_id else self.session._random_sequence_id
    self.last_op = last_op
    self.pending_hints = pending_hints if pending_hints else []
    self.output_expression_container = (
        output_expression_container if output_expression_container else exp.Select()
    )
    self.temp_views: t.List[exp.Select] = []
|
|
172
|
+
|
|
173
|
+
def __getattr__(self, column_name: str) -> Column:
    """Resolve ``df.<name>`` to a column of this DataFrame (same as ``df[<name>]``).

    Python only calls this hook when normal attribute lookup fails. Dunder names
    (``__xxx__``) raise AttributeError instead of producing a Column. Protocol probes
    such as ``getattr(df, "__deepcopy__", None)`` (``copy.deepcopy``) or
    ``getattr(obj, "__setstate__", None)`` (unpickling) then fall back to their defaults.
    Without this guard, deepcopy would receive a Column object. Unpickling a
    not-yet-initialised instance would recurse through ``self.session`` until
    RecursionError.

    Raises:
        AttributeError: for dunder names.
    """
    if column_name.startswith("__") and column_name.endswith("__"):
        raise AttributeError(
            f"'{type(self).__name__}' object has no attribute '{column_name}'"
        )
    return self[column_name]
|
|
175
|
+
|
|
176
|
+
def __getitem__(self, column_name: str) -> Column:
    """Return ``column_name`` as a Column qualified with this frame's branch id."""
    # get_func_from_session is already imported at module level.
    make_column = get_func_from_session("col", self.session)
    qualified_name = f"{self.branch_id}.{column_name}"
    return make_column(qualified_name)

def __copy__(self):
    # Route copy.copy(df) through the same path as df.copy().
    duplicate = self.copy()
    return duplicate

@property
def write(self) -> WRITER:
    """Return the session's writer object bound to this DataFrame."""
    build_writer = self.session._writer
    return build_writer(self)

@property
def latest_cte_name(self) -> str:
    """Name of the relation the current SELECT most recently builds on.

    This is the alias of the last CTE when there are CTEs. Otherwise it is the
    alias/name of the FROM clause, or of the first TableAlias found inside it.

    Raises:
        RuntimeError: if no name can be found.
    """
    ctes = self.expression.ctes
    if ctes:
        return ctes[-1].alias
    from_exp = self.expression.args["from"]
    direct_name = from_exp.alias_or_name
    if direct_name:
        return direct_name
    table_alias = from_exp.find(exp.TableAlias)
    if not table_alias:
        raise RuntimeError(
            f"Could not find an alias name for this expression: {self.expression}"
        )
    return table_alias.alias_or_name
|
|
204
|
+
|
|
205
|
+
@property
def pending_join_hints(self):
    """Pending hints that are ``exp.JoinHint`` nodes, in insertion order."""
    return list(filter(lambda hint: isinstance(hint, exp.JoinHint), self.pending_hints))

@property
def pending_partition_hints(self):
    """Pending hints that are ``exp.Anonymous`` nodes (non-join hints), in insertion order."""
    return list(filter(lambda hint: isinstance(hint, exp.Anonymous), self.pending_hints))

@property
def columns(self) -> t.List[str]:
    """Output column names of the current SELECT (``named_selects``)."""
    selected_names = self.expression.named_selects
    return selected_names

@property
def na(self) -> NA:
    """Return a new instance of the configured NA accessor class wrapping this frame."""
    accessor_cls = self._na
    return accessor_cls(self)

@property
def stat(self) -> STAT:
    """Return a new instance of the configured stat accessor class wrapping this frame."""
    accessor_cls = self._stat
    return accessor_cls(self)

@property
def schema(self) -> StructType:
    """Schema of this :class:`DataFrame` as a :class:`pyspark.sql.types.StructType`.

    The base implementation raises NotImplementedError; engine-specific subclasses
    are presumably expected to override it (TODO confirm).

    .. versionadded:: 1.3.0

    .. versionchanged:: 3.4.0
        Supports Spark Connect.

    Returns
    -------
    :class:`StructType`

    Examples
    --------
    >>> df = spark.createDataFrame(
    ...     [(14, "Tom"), (23, "Alice"), (16, "Bob")], ["age", "name"])

    Retrieve the schema of the current DataFrame.

    >>> df.schema
    StructType([StructField('age', LongType(), True),
                StructField('name', StringType(), True)])
    """
    raise NotImplementedError
|
|
250
|
+
|
|
251
|
+
def _replace_cte_names_with_hashes(self, expression: exp.Select):
    """Rename every CTE of ``expression`` to a hash of its body.

    Each new identifier keeps the ``quoted`` flag of the identifier it replaces. All
    references are rewritten through ``replace_id_value``. Returns a transformed
    ``exp.Select``.
    """
    renames = {
        cte.args["alias"].this: exp.to_identifier(
            self._create_hash_from_expression(cte.this),
            quoted=cte.args["alias"].this.args["quoted"],
        )
        for cte in expression.ctes
    }
    renamed = expression.transform(replace_id_value, renames)
    return renamed.assert_is(exp.Select)

def _create_cte_from_expression(
    self,
    expression: exp.Expression,
    branch_id: str,
    sequence_id: str,
    name: t.Optional[str] = None,
    **kwargs,
) -> t.Tuple[exp.CTE, str]:
    """Wrap a WITH-less copy of ``expression`` in a CTE tagged with branch/sequence ids.

    When ``name`` is falsy, the CTE is named after a hash of ``expression``. Extra
    keyword arguments are passed to ``Select.with_``.

    Returns:
        ``(cte, name)``.
    """
    if not name:
        name = self._create_hash_from_expression(expression)
    body = expression.copy()
    body.set("with", None)
    holder = exp.Select().with_(name, as_=body, **kwargs)
    cte = holder.ctes[0]
    cte.set("branch_id", branch_id)
    cte.set("sequence_id", sequence_id)
    return cte, name
|
|
279
|
+
|
|
280
|
+
def _ensure_list_of_columns(
    self, cols: t.Optional[t.Union[ColumnOrLiteral, t.Collection[ColumnOrLiteral]]]
) -> t.List[Column]:
    """Coerce ``cols`` (a single value or a collection) into a list of Column objects."""
    # At module level Column is only imported under TYPE_CHECKING, so import it at runtime here.
    from sqlframe.base.column import Column

    as_list = ensure_list(cols)
    return Column.ensure_cols(as_list)  # type: ignore

def _ensure_and_normalize_cols(
    self, cols, expression: t.Optional[exp.Select] = None
) -> t.List[Column]:
    """Coerce ``cols`` to Columns and normalize them against ``expression``.

    ``self.expression`` is used when ``expression`` is falsy.
    """
    from sqlframe.base.normalize import normalize

    columns = self._ensure_list_of_columns(cols)
    target = expression if expression else self.expression
    normalize(self.session, target, columns)
    return columns

def _ensure_and_normalize_col(self, col):
    """Coerce a single ``col`` to a Column and normalize it against ``self.expression``."""
    from sqlframe.base.column import Column
    from sqlframe.base.normalize import normalize

    column = Column.ensure_col(col)
    normalize(self.session, self.expression, column)
    return column
|
|
303
|
+
|
|
304
|
+
def _convert_leaf_to_cte(
    self, sequence_id: t.Optional[str] = None, name: t.Optional[str] = None
) -> Self:
    """Push the current SELECT into a new CTE and select its output columns from it.

    Pending hints are resolved first. The new CTE is appended after the expression's
    existing CTEs. ``sequence_id`` defaults to the frame's own sequence id.
    """
    resolved = self._resolve_pending_hints()
    seq_id = sequence_id or resolved.sequence_id
    leaf = resolved.expression.copy()
    new_cte, new_cte_name = resolved._create_cte_from_expression(
        expression=leaf, branch_id=self.branch_id, sequence_id=seq_id, name=name
    )
    wrapper = resolved._add_ctes_to_expression(exp.Select(), leaf.ctes + [new_cte])
    outer_columns = resolved._get_outer_select_columns(new_cte)
    projections = [column.expression for column in outer_columns]
    wrapper = wrapper.from_(new_cte_name).select(*projections)
    return resolved.copy(expression=wrapper, sequence_id=seq_id)
|
|
319
|
+
|
|
320
|
+
def _resolve_pending_hints(self) -> Self:
    """Return a copy of this DataFrame with pending hints folded into the SELECT's hint clause.

    Partition hints (``exp.Anonymous``) are always moved into the hint clause and removed
    from ``pending_hints``. Join hints (``exp.JoinHint``) are only processed when the
    expression joins tables. For each table reference in the hint, the reference is matched
    by sequence id, either directly or via ``session.name_to_sequence_id_mapping``. It is
    then re-pointed at the newest CTE with that id whose alias takes part in the join. The
    hint is removed from ``pending_hints`` once any of its references resolves. Every
    processed join hint is added to the hint clause whether or not it resolved.
    """
    df = self.copy()
    if not self.pending_hints:
        return df
    expression = df.expression
    hint_expression = expression.args.get("hint") or exp.Hint(expressions=[])
    for hint in df.pending_partition_hints:
        hint_expression.append("expressions", hint)
        df.pending_hints.remove(hint)

    join_aliases = {
        join_table.alias_or_name
        for join_table in get_tables_from_expression_with_join(expression)
    }
    if join_aliases:
        for hint in df.pending_join_hints:
            for sequence_id_expression in hint.expressions:
                sequence_id_or_name = sequence_id_expression.alias_or_name
                sequence_ids_to_match = [sequence_id_or_name]
                if sequence_id_or_name in df.session.name_to_sequence_id_mapping:
                    sequence_ids_to_match = df.session.name_to_sequence_id_mapping[
                        sequence_id_or_name
                    ]
                # Newest CTEs first, so the most recent matching CTE wins.
                matching_ctes = [
                    cte
                    for cte in reversed(expression.ctes)
                    if cte.args["sequence_id"] in sequence_ids_to_match
                ]
                for matching_cte in matching_ctes:
                    if matching_cte.alias_or_name in join_aliases:
                        sequence_id_expression.set("this", matching_cte.args["alias"].this)
                        # A hint can reference several tables; remove it only once, otherwise
                        # the second resolved reference makes list.remove raise ValueError.
                        if hint in df.pending_hints:
                            df.pending_hints.remove(hint)
                        break
            hint_expression.append("expressions", hint)
    if hint_expression.expressions:
        expression.set("hint", hint_expression)
    return df
|
|
357
|
+
|
|
358
|
+
def _hint(self, hint_name: str, args: t.List[Column]) -> Self:
    """Return a copy of this DataFrame with one more pending hint.

    Names in ``JOIN_HINTS`` (compared upper-cased) become ``exp.JoinHint`` nodes that
    reference tables named after each argument's ``alias_or_name``. Any other name
    becomes an ``exp.Anonymous`` node carrying the arguments' expressions.
    """
    normalized_name = hint_name.upper()
    if normalized_name in JOIN_HINTS:
        table_refs = [exp.to_table(arg.alias_or_name) for arg in args]
        new_hint = exp.JoinHint(this=normalized_name, expressions=table_refs)
    else:
        arg_expressions = [arg.expression for arg in args]
        new_hint = exp.Anonymous(this=normalized_name, expressions=arg_expressions)
    hinted_df = self.copy()
    hinted_df.pending_hints.append(new_hint)
    return hinted_df

def _set_operation(self, klass: t.Callable, other: Self, distinct: bool) -> Self:
    """Combine this frame with ``other`` using the set-operation node class ``klass``.

    The CTEs of both sides are merged, moved off the operands, and attached to the
    resulting set operation. The result is then wrapped in a fresh CTE.
    """
    right = other._convert_leaf_to_cte()
    left = self._add_ctes_to_expression(self.expression.copy(), right.expression.ctes)
    merged_ctes = left.ctes
    right.expression.set("with", None)
    left.set("with", None)
    combined = klass(this=left, distinct=distinct, expression=right.expression)
    combined.set("with", exp.With(expressions=merged_ctes))
    return self.copy(expression=combined)._convert_leaf_to_cte()

def _cache(self, storage_level: str) -> Self:
    """Wrap the frame in a CTE and tag that CTE with ``storage_level`` for caching."""
    cached = self._convert_leaf_to_cte()
    newest_cte = cached.expression.ctes[-1]
    newest_cte.set("cache_storage_level", storage_level)
    return cached
|
|
389
|
+
|
|
390
|
+
@classmethod
def _add_ctes_to_expression(cls, expression: exp.Select, ctes: t.List[exp.CTE]) -> exp.Select:
    """Return a copy of ``expression`` whose WITH clause also contains ``ctes``.

    When the copy already has a WITH clause, each CTE from ``ctes`` whose name is not among
    the names present before merging is appended. Otherwise ``ctes`` becomes the whole WITH
    clause.
    """
    result = expression.copy()
    current_with = result.args.get("with")
    if not current_with:
        merged = ctes
    else:
        merged = current_with.expressions
        # Names are snapshotted before merging, matching the original semantics.
        known_names = {existing.alias_or_name for existing in merged}
        merged.extend([cte for cte in ctes if cte.alias_or_name not in known_names])
    result.set("with", exp.With(expressions=merged))
    return result

@classmethod
def _get_outer_select_columns(cls, item: exp.Expression) -> t.List[Column]:
    """Columns named after the projections of the first SELECT found in ``item``.

    Returns an empty list when ``item`` contains no SELECT.
    """
    from sqlframe.base.session import _BaseSession

    make_column = get_func_from_session("col", _BaseSession())

    outer_select = item.find(exp.Select)
    if not outer_select:
        return []
    return [make_column(projection.alias_or_name) for projection in outer_select.expressions]

def _create_hash_from_expression(self, expression: exp.Expression) -> str:
    """Deterministic short name for ``expression``.

    The name is ``"t"`` followed by the CRC32 of its SQL (rendered in the base session's
    input dialect), truncated to 9 characters and passed through the session's
    ``_normalize_string``.
    """
    from sqlframe.base.session import _BaseSession

    sql_bytes = expression.sql(dialect=_BaseSession().input_dialect).encode("utf-8")
    short_name = f"t{zlib.crc32(sql_bytes)}"[:9]
    return self.session._normalize_string(short_name)
|
|
422
|
+
|
|
423
|
+
def _get_select_expressions(
    self,
) -> t.List[t.Tuple[t.Union[t.Type[exp.Cache], OutputExpressionContainer], exp.Select]]:
    """Split this frame into (statement type, SELECT) pairs in execution order.

    Every CTE tagged with ``cache_storage_level`` yields an ``(exp.Cache, select)`` pair.
    That select is a copy of the CTE body whose WITH clause holds the untagged CTEs seen so
    far; it also carries ``cte_alias_name`` and ``cache_storage_level`` args. The final pair
    is the main SELECT, restricted to the untagged CTEs and keyed by the type of
    ``output_expression_container``.
    """
    select_expressions: t.List[
        t.Tuple[t.Union[t.Type[exp.Cache], OutputExpressionContainer], exp.Select]
    ] = []
    plain_ctes: t.List[exp.CTE] = []
    for cte in self.expression.ctes:
        storage_level = cte.args.get("cache_storage_level")
        if not storage_level:
            plain_ctes.append(cte)
            continue
        cached_select = cte.this.copy()
        cached_select.set("with", exp.With(expressions=copy(plain_ctes)))
        cached_select.set("cte_alias_name", cte.alias_or_name)
        cached_select.set("cache_storage_level", storage_level)
        select_expressions.append((exp.Cache, cached_select))

    main_select = self.expression.copy()
    if plain_ctes:
        main_select.set("with", exp.With(expressions=plain_ctes))
    select_expressions.append((type(self.output_expression_container), main_select))  # type: ignore
    return select_expressions
|
|
446
|
+
|
|
447
|
+
@t.overload
def sql(
    self,
    dialect: DialectType = ...,
    optimize: bool = ...,
    pretty: bool = ...,
    *,
    as_list: t.Literal[False],
    **kwargs: t.Any,
) -> str: ...

@t.overload
def sql(
    self,
    dialect: DialectType = ...,
    optimize: bool = ...,
    pretty: bool = ...,
    *,
    as_list: t.Literal[True],
    **kwargs: t.Any,
) -> t.List[str]: ...

def sql(
    self,
    dialect: DialectType = None,
    optimize: bool = True,
    pretty: bool = True,
    as_list: bool = False,
    **kwargs,
) -> t.Union[str, t.List[str]]:
    """Render this DataFrame as SQL statement text.

    Args:
        dialect: output dialect; falls back to ``self.session.output_dialect``.
        optimize: when true, quote identifiers and run ``self.session._optimize`` on each
            statement before rendering.
        pretty: forwarded to ``Expression.sql``.
        as_list: when true, return one string per statement instead of a single string.
        **kwargs: forwarded to ``Expression.sql``.

    Returns:
        All statements joined by a semicolon plus newline, or, when ``as_list`` is true,
        the list of statements. Each cached CTE contributes a ``DROP VIEW IF EXISTS`` and a
        ``CACHE`` statement ahead of the main statement.

    Raises:
        ValueError: if a statement type other than Cache / Create / Insert / Select is
            encountered.
    """
    dialect = Dialect.get_or_raise(dialect or self.session.output_dialect)

    # Fold pending hints into the expression before generating statements.
    df = self._resolve_pending_hints()
    select_expressions = df._get_select_expressions()
    output_expressions: t.List[t.Union[exp.Select, exp.Cache, exp.Drop]] = []
    # Cached CTE alias -> cache table identifier. Filled in the Cache branch below and
    # applied to every statement processed afterwards, so later statements reference
    # the cache table instead of the CTE.
    replacement_mapping: t.Dict[exp.Identifier, exp.Identifier] = {}

    for expression_type, select_expression in select_expressions:
        select_expression = select_expression.transform(
            replace_id_value, replacement_mapping
        ).assert_is(exp.Select)
        if optimize:
            quote_identifiers(select_expression, dialect=dialect)
            select_expression = t.cast(
                exp.Select, self.session._optimize(select_expression, dialect=dialect)
            )

        # Give CTEs deterministic names derived from their bodies.
        select_expression = df._replace_cte_names_with_hashes(select_expression)

        expression: t.Union[exp.Select, exp.Cache, exp.Drop]
        if expression_type == exp.Cache:
            cache_table_name = df._create_hash_from_expression(select_expression)
            cache_table = exp.to_table(cache_table_name)
            original_alias_name = select_expression.args["cte_alias_name"]

            replacement_mapping[exp.to_identifier(original_alias_name)] = exp.to_identifier(  # type: ignore
                cache_table_name
            )
            # Register the cache table's columns (type SQL, or "UNKNOWN" when untyped)
            # in the session catalog.
            self.session.catalog.add_table(
                cache_table_name,
                {
                    expression.alias_or_name: expression.type.sql(dialect=dialect)
                    if expression.type
                    else "UNKNOWN"
                    for expression in select_expression.expressions
                },
            )

            cache_storage_level = select_expression.args["cache_storage_level"]
            options = [
                exp.Literal.string("storageLevel"),
                exp.Literal.string(cache_storage_level),
            ]
            expression = exp.Cache(
                this=cache_table, expression=select_expression, lazy=True, options=options
            )

            # We will drop the "view" if it exists before running the cache table
            output_expressions.append(exp.Drop(this=cache_table, exists=True, kind="VIEW"))
        elif expression_type == exp.Create:
            expression = df.output_expression_container.copy()
            expression.set("expression", select_expression)
        elif expression_type == exp.Insert:
            expression = df.output_expression_container.copy()
            # The CTEs are hoisted onto the INSERT itself; the inner SELECT is emitted without them.
            select_without_ctes = select_expression.copy()
            select_without_ctes.set("with", None)
            expression.set("expression", select_without_ctes)

            if select_expression.ctes:
                expression.set("with", exp.With(expressions=select_expression.ctes))
        elif expression_type == exp.Select:
            expression = select_expression
        else:
            raise ValueError(f"Invalid expression type: {expression_type}")

        output_expressions.append(expression)

    results = [
        expression.sql(dialect=dialect, pretty=pretty, **kwargs)
        for expression in output_expressions
    ]
    if as_list:
        return results
    return ";\n".join(results)
|
|
551
|
+
|
|
552
|
+
def copy(self, **kwargs) -> Self:
    """Build a new instance of the same class from this one's attributes.

    Any keyword arguments override the corresponding attributes.
    """
    state = object_to_dict(self, **kwargs)
    return self.__class__(**state)
|
|
554
|
+
|
|
555
|
+
@operation(Operation.SELECT)
def select(self, *cols, **kwargs) -> Self:
    """Project ``cols`` onto this DataFrame.

    When the current expression has joins, column references without a table qualifier
    are qualified with the joined CTE that exposes that column name. If the same
    (hash-equal) column is requested several times, successive requests take the
    matching CTEs left to right, as Spark does for duplicate names after a join.

    Args:
        *cols: columns or column names to select.
        **kwargs: forwarded to ``exp.Select.select`` and to ``copy``. ``append``
            defaults to False, so the projection replaces the existing one.

    Returns:
        A new DataFrame with the projection applied.
    """
    from sqlframe.base.column import Column

    columns = self._ensure_and_normalize_cols(cols)
    kwargs["append"] = kwargs.get("append", False)
    if self.expression.args.get("joins"):
        ambiguous_cols = [
            col
            for col in columns
            if isinstance(col.column_expression, exp.Column) and not col.column_expression.table
        ]
        if ambiguous_cols:
            join_table_identifiers = [
                x.this for x in get_tables_from_expression_with_join(self.expression)
            ]
            cte_names_in_join = [x.this for x in join_table_identifiers]
            # If we have columns that resolve to multiple CTE expressions then we want to use each CTE left-to-right
            # and therefore we allow multiple columns with the same name in the result. This matches the behavior
            # of Spark.
            resolved_column_position: t.Dict[Column, int] = {col: -1 for col in ambiguous_cols}
            for ambiguous_col in ambiguous_cols:
                # Use the underlying column reference, not an alias that may wrap it. The
                # CTE lookup must use the source column name (an alias would never be
                # found among the CTE's selects), and the table qualifier belongs on the
                # exp.Column node, not on an exp.Alias node.
                column_expression = ambiguous_col.column_expression
                ctes_with_column = [
                    cte
                    for cte in self.expression.ctes
                    if cte.alias_or_name in cte_names_in_join
                    and column_expression.alias_or_name in cte.this.named_selects
                ]
                if not ctes_with_column:
                    # No joined CTE exposes this name. Leave the reference unqualified
                    # instead of failing with an IndexError on the empty list below.
                    continue
                # Check if there is a CTE with this column that we haven't used before. If so, use it. Otherwise,
                # use the same CTE we used before
                cte = seq_get(ctes_with_column, resolved_column_position[ambiguous_col] + 1)
                if cte:
                    resolved_column_position[ambiguous_col] += 1
                else:
                    cte = ctes_with_column[resolved_column_position[ambiguous_col]]
                column_expression.set("table", exp.to_identifier(cte.alias_or_name))
    return self.copy(
        expression=self.expression.select(*[x.expression for x in columns], **kwargs), **kwargs
    )
|
|
594
|
+
|
|
595
|
+
@operation(Operation.NO_OP)
def alias(self, name: str, **kwargs) -> Self:
    """Return this DataFrame under the alias ``name``.

    A fresh sequence id is drawn once from the session. Pending join hints that referenced
    the old sequence id are re-pointed at it, and ``name`` is registered for it through
    ``session._add_alias_to_mapping``. The frame is then wrapped in a CTE with the new id.
    """
    from sqlframe.base.column import Column

    fresh_sequence_id = self.session._random_sequence_id
    df = self.copy()
    stale_refs = (
        reference
        for join_hint in df.pending_join_hints
        for reference in join_hint.expressions
        if reference.alias_or_name == self.sequence_id
    )
    for reference in stale_refs:
        reference.set("this", Column.ensure_col(fresh_sequence_id).expression)
    df.session._add_alias_to_mapping(name, fresh_sequence_id)
    return df._convert_leaf_to_cte(sequence_id=fresh_sequence_id)

@operation(Operation.WHERE)
def where(self, column: t.Union[Column, str, bool], **kwargs) -> Self:
    """Filter rows by ``column``.

    A string is parsed as a SQL predicate in the session's input dialect; anything else
    is coerced to a Column.
    """
    condition = (
        sqlglot.parse_one(column, dialect=self.session.input_dialect)
        if isinstance(column, str)
        else column
    )
    predicate = self._ensure_and_normalize_col(condition)
    return self.copy(expression=self.expression.where(predicate.expression))

filter = where

@operation(Operation.GROUP_BY)
def groupBy(self, *cols, **kwargs) -> GROUP_DATA:
    """Group by ``cols``; returns the configured grouped-data object."""
    grouping_columns = self._ensure_and_normalize_cols(cols)
    return self._group_data(self, grouping_columns, self.last_op)

groupby = groupBy

@operation(Operation.SELECT)
def agg(self, *exprs, **kwargs) -> Self:
    """Aggregate over the whole frame (an empty ``groupBy`` followed by ``agg``)."""
    aggregates = self._ensure_and_normalize_cols(exprs)
    whole_frame = self.groupBy()
    return whole_frame.agg(*aggregates)
|
|
631
|
+
|
|
632
|
+
@operation(Operation.FROM)
def crossJoin(self, other: DF) -> Self:
    """Return the cartesian product of this DataFrame with ``other``.

    .. versionadded:: 2.1.0

    .. versionchanged:: 3.4.0
        Supports Spark Connect.

    Parameters
    ----------
    other : :class:`DataFrame`
        Right side of the cartesian product.

    Returns
    -------
    :class:`DataFrame`
        Joined DataFrame.

    Examples
    --------
    >>> from pyspark.sql import Row
    >>> df = spark.createDataFrame(
    ...     [(14, "Tom"), (23, "Alice"), (16, "Bob")], ["age", "name"])
    >>> df2 = spark.createDataFrame(
    ...     [Row(height=80, name="Tom"), Row(height=85, name="Bob")])
    >>> df.crossJoin(df2.select("height")).select("age", "name", "height").show()
    +---+-----+------+
    |age| name|height|
    +---+-----+------+
    | 14|  Tom|    80|
    | 14|  Tom|    85|
    | 23|Alice|    80|
    | 23|Alice|    85|
    | 16|  Bob|    80|
    | 16|  Bob|    85|
    +---+-----+------+
    """
    # Call the undecorated join so this call is recorded only under crossJoin's own
    # operation decorator. No `on` is passed, so join's "no value for on" warning
    # branch also runs.
    undecorated_join = self.join.__wrapped__  # type: ignore
    return undecorated_join(self, other, how="cross")
|
|
671
|
+
|
|
672
|
+
@operation(Operation.FROM)
def join(
    self,
    other_df: Self,
    on: t.Optional[t.Union[str, t.List[str], Column, t.List[Column]]] = None,
    how: str = "inner",
    **kwargs,
) -> Self:
    """Join this DataFrame with ``other_df``.

    Parameters
    ----------
    other_df : DataFrame
        Right side of the join.
    on : str, list of str, Column or list of Column, optional
        Join keys. When the first key is a plain column reference (presumably what a
        column-name string becomes), each key is resolved against the left-side CTEs, the
        key columns are deduplicated and moved to the front of the output. Otherwise the
        keys are treated as boolean expressions, AND-ed together, and every column of both
        sides is kept. ``None`` or an empty list produces a cross join.
    how : str, default "inner"
        Join type; underscores are replaced with spaces (e.g. ``"left_outer"``).
    **kwargs
        Accepted for signature compatibility; not used here.

    Returns
    -------
    DataFrame
        The joined DataFrame.

    Raises
    ------
    ValueError
        If a join column name matches more than one left-side table, or none of them.
    """
    # No join keys (None or an empty list) means a cross join. Only warn when the caller
    # asked for a different join type; an explicit cross join (e.g. from crossJoin) is fine.
    if on is None or (isinstance(on, list) and not on):
        if how != "cross":
            logger.warning("Got no value for on. This appears change the join to a cross join.")
        how = "cross"
    other_df = other_df._convert_leaf_to_cte()
    # We will determine actual "join on" expression later so we don't provide it at first
    join_expression = self.expression.join(
        other_df.latest_cte_name, join_type=how.replace("_", " ")
    )
    join_expression = self._add_ctes_to_expression(join_expression, other_df.expression.ctes)
    self_columns = self._get_outer_select_columns(join_expression)
    other_columns = self._get_outer_select_columns(other_df.expression)
    join_columns = self._ensure_list_of_columns(on)
    # Determine the join clause and the select columns based on what kind of columns were
    # provided for the join. The columns returned change based on how `on` is provided.
    if how != "cross":
        if isinstance(join_columns[0].expression, exp.Column):
            # Unique characteristics of join on column names only:
            # * The column names are put at the front of the select list
            # * The column names are deduplicated across the entire select list and only
            #   the column names (other dups are allowed)
            table_names = [
                table.alias_or_name
                for table in get_tables_from_expression_with_join(join_expression)
            ]
            potential_ctes = [
                cte
                for cte in join_expression.ctes
                if cte.alias_or_name in table_names
                and cte.alias_or_name != other_df.latest_cte_name
            ]
            # Determine the table to reference for the left side of the join by checking
            # each of the left side tables and seeing if they have the column referenced.
            join_column_pairs = []
            for join_column in join_columns:
                num_matching_ctes = 0
                for cte in potential_ctes:
                    if join_column.alias_or_name in cte.this.named_selects:
                        left_column = join_column.copy().set_table_name(cte.alias_or_name)
                        right_column = join_column.copy().set_table_name(
                            other_df.latest_cte_name
                        )
                        join_column_pairs.append((left_column, right_column))
                        num_matching_ctes += 1
                if num_matching_ctes > 1:
                    raise ValueError(
                        f"Column {join_column.alias_or_name} is ambiguous. Please specify the table name."
                    )
                elif num_matching_ctes == 0:
                    raise ValueError(
                        f"Column {join_column.alias_or_name} does not exist in any of the tables."
                    )
            join_clause = functools.reduce(
                lambda x, y: x & y,
                [
                    left_column == right_column
                    for left_column, right_column in join_column_pairs
                ],
            )
            join_column_names = [left_col.alias_or_name for left_col, _ in join_column_pairs]
            # To match spark behavior only the join clause gets deduplicated and it gets put
            # in the front of the column list
            select_column_names = [
                (
                    column.alias_or_name
                    if not isinstance(column.expression.this, exp.Star)
                    else column.sql()
                )
                for column in self_columns + other_columns
            ]
            select_column_names = [
                column_name
                for column_name in select_column_names
                if column_name not in join_column_names
            ]
            select_column_names = join_column_names + select_column_names
        else:
            # Unique characteristics of join on expressions:
            # * There is no deduplication of the results.
            # * The left join dataframe columns go first and right come after. No sort
            #   preference is given to join columns
            join_columns = self._ensure_and_normalize_cols(join_columns, join_expression)
            if len(join_columns) > 1:
                join_columns = [functools.reduce(lambda x, y: x & y, join_columns)]
            join_clause = join_columns[0]
            select_column_names = [
                column.alias_or_name for column in self_columns + other_columns
            ]
    else:
        select_column_names = [column.alias_or_name for column in self_columns + other_columns]
        join_clause = None
    # Update the on expression with the actual join clause to replace the dummy one from before
    join_expression.args["joins"][-1].set(
        "on", join_clause.expression if join_clause is not None else None
    )
    new_df = self.copy(expression=join_expression)
    new_df.pending_join_hints.extend(self.pending_join_hints)
    new_df.pending_hints.extend(other_df.pending_hints)
    new_df = new_df.select.__wrapped__(new_df, *select_column_names)  # type: ignore
    return new_df
|
|
779
|
+
|
|
780
|
+
@operation(Operation.ORDER_BY)
def orderBy(
    self,
    *cols: t.Union[str, Column],
    ascending: t.Optional[t.Union[t.Any, t.List[t.Any]]] = None,
) -> Self:
    """Return a new DataFrame sorted by ``cols``.

    This implementation lets any ordered columns take priority over whatever is provided in
    `ascending`. Spark has irregular behavior and can result in runtime errors. Users
    shouldn't be mixing the two anyways so this is unlikely to come up.

    Parameters
    ----------
    *cols : str or Column
        Sort keys. Keys whose expression is already an ``exp.Ordered`` keep their own
        direction.
    ascending : scalar or list, optional
        Direction for the keys, interpreted with ``bool``. A scalar applies to every key; a
        list must have exactly one entry per key. Defaults to ascending for all keys.
    """
    columns = self._ensure_and_normalize_cols(cols)
    if ascending is None:
        ascending = [True] * len(columns)
    elif not isinstance(ascending, list):
        ascending = [ascending] * len(columns)
    ascending = [bool(x) for x in ascending]
    assert len(columns) == len(
        ascending
    ), "The length of items in ascending must equal the number of columns provided"
    order_by_columns = [
        # Already-ordered columns keep their own ordering; the rest are rendered and
        # re-parsed as an ORDER BY item in the input dialect.
        (
            col.column_expression
            if isinstance(col.expression, exp.Ordered)
            else sqlglot.parse_one(
                f"{col.expression.sql(dialect=self.session.input_dialect)} {'DESC' if not asc else ''}",
                dialect=self.session.input_dialect,
                into=exp.Ordered,
            )
        )
        for col, asc in zip(columns, ascending)
    ]
    return self.copy(expression=self.expression.order_by(*order_by_columns))

# PySpark-compatible alias.
sort = orderBy
|
|
819
|
+
|
|
820
|
+
@operation(Operation.FROM)
def union(self, other: Self) -> Self:
    """Combine the rows of this DataFrame with the rows of ``other`` via ``exp.Union``.

    ``False`` is passed as the third argument of ``_set_operation``; by analogy with
    ``intersect`` (True) vs ``intersectAll`` (False) this presumably keeps duplicates
    (UNION ALL) — confirm against ``_set_operation``.
    """
    set_kind = exp.Union
    return self._set_operation(set_kind, other, False)

# PySpark-compatible alias: same method object as `union`.
unionAll = union
|
|
825
|
+
|
|
826
|
+
@operation(Operation.FROM)
def unionByName(self, other: Self, allowMissingColumns: bool = False) -> Self:
    """Union ``other`` into this DataFrame, matching columns by name.

    Without ``allowMissingColumns`` the right side is re-selected using the left side's
    column names (so its columns are reordered to match). With it, a column missing on
    either side is filled with ``NULL`` aliased to that name: left-side columns come
    first, followed by right-only columns.
    """
    left_names = self.columns
    right_names = other.columns
    if allowMissingColumns:
        left_exprs = []
        right_exprs = []
        leftover_right = copy(right_names)
        for name in left_names:
            left_exprs.append(name)
            if name not in right_names:
                right_exprs.append(exp.alias_(exp.Null(), name, copy=False))
                continue
            right_exprs.append(name)
            leftover_right.remove(name)
        for name in leftover_right:
            left_exprs.append(exp.alias_(exp.Null(), name, copy=False))
            right_exprs.append(name)
    else:
        # Both sides are selected by the left side's names.
        left_exprs = right_exprs = left_names
    right_df = (
        other.copy()._convert_leaf_to_cte().select(*self._ensure_list_of_columns(right_exprs))
    )
    if allowMissingColumns:
        left_df = (
            self.copy()._convert_leaf_to_cte().select(*self._ensure_list_of_columns(left_exprs))
        )
    else:
        left_df = self.copy()
    return left_df._set_operation(exp.Union, right_df, False)
|
|
854
|
+
|
|
855
|
+
@operation(Operation.FROM)
def intersect(self, other: Self) -> Self:
    """Return rows present in both this DataFrame and ``other`` via ``exp.Intersect``.

    The ``True`` flag presumably requests DISTINCT semantics (``intersectAll`` passes
    ``False``) — confirm against ``_set_operation``.
    """
    set_kind = exp.Intersect
    return self._set_operation(set_kind, other, True)
|
|
858
|
+
|
|
859
|
+
@operation(Operation.FROM)
def intersectAll(self, other: Self) -> Self:
    """Return rows present in both DataFrames via ``exp.Intersect`` with the flag ``False``.

    Compared with ``intersect`` (which passes ``True``), this presumably keeps duplicates.
    """
    set_kind = exp.Intersect
    return self._set_operation(set_kind, other, False)
|
|
862
|
+
|
|
863
|
+
@operation(Operation.FROM)
def exceptAll(self, other: Self) -> Self:
    """Return rows of this DataFrame not in ``other`` via ``exp.Except`` with the flag ``False``.

    The ``False`` flag presumably keeps duplicates (EXCEPT ALL) — confirm against
    ``_set_operation``.
    """
    set_kind = exp.Except
    return self._set_operation(set_kind, other, False)
|
|
866
|
+
|
|
867
|
+
@operation(Operation.SELECT)
def distinct(self) -> Self:
    """Return a new DataFrame whose outer SELECT is marked DISTINCT."""
    deduplicated = self.expression.distinct()
    return self.copy(expression=deduplicated)
|
|
870
|
+
|
|
871
|
+
@operation(Operation.SELECT)
def dropDuplicates(self, subset: t.Optional[t.List[str]] = None):
    """Remove duplicate rows, optionally considering only the ``subset`` columns.

    With no subset (None or empty) this is ``distinct()``. Otherwise a ``row_number``
    over a window partitioned and ordered by the subset columns is stored in a helper
    column ``row_num``; only rows numbered 1 are kept, and the helper column is dropped.
    NOTE(review): an existing user column named ``row_num`` would likely be clobbered.
    """
    from sqlframe.base import functions as F
    from sqlframe.base.window import Window

    if not subset:
        return self.distinct()

    key_names = ensure_list(subset)
    per_key_window = Window.partitionBy(*key_names).orderBy(*key_names)
    numbered = self.copy().withColumn("row_num", F.row_number().over(per_key_window))
    first_per_key = numbered.where(F.col("row_num") == F.lit(1))
    return first_per_key.drop("row_num")

# snake_case alias.
drop_duplicates = dropDuplicates
|
|
888
|
+
|
|
889
|
+
@operation(Operation.FROM)
def dropna(
    self,
    how: str = "any",
    thresh: t.Optional[int] = None,
    subset: t.Optional[t.Union[str, t.Tuple[str, ...], t.List[str]]] = None,
) -> Self:
    """Drop rows containing nulls.

    Parameters
    ----------
    how : str, default "any"
        ``"any"`` drops a row with at least one null in the checked columns; any other
        value drops only rows where every checked column is null. Ignored when ``thresh``
        is given.
    thresh : int, optional
        Keep rows with at least this many non-null values in the checked columns.
    subset : str, tuple or list, optional
        Columns to check; defaults to every column of the outer SELECT.

    Raises
    ------
    RuntimeError
        When the derived null threshold exceeds the number of checked columns
        (i.e. ``thresh`` below 1).
    """
    from sqlframe.base import functions as F

    working_df = self.copy()
    all_columns = self._get_outer_select_columns(working_df.expression)
    null_check_columns = self._ensure_and_normalize_cols(subset) if subset else all_columns
    checked_count = len(null_check_columns)

    # A row is dropped once its null count reaches `minimum_num_nulls`.
    if thresh is not None:
        # At least `thresh` non-nulls <=> fewer than (checked_count - thresh + 1) nulls.
        minimum_num_nulls = checked_count - thresh + 1
    elif how == "any":
        minimum_num_nulls = 1
    else:
        minimum_num_nulls = checked_count

    if minimum_num_nulls > checked_count:
        raise RuntimeError(
            f"The minimum num nulls for dropna must be less than or equal to the number of columns. "
            f"Minimum num nulls: {minimum_num_nulls}, Num Columns: {len(null_check_columns)}"
        )

    null_flags = [
        F.when(checked.isNull(), F.lit(1)).otherwise(F.lit(0)) for checked in null_check_columns
    ]
    null_total = functools.reduce(lambda left, right: left + right, null_flags).alias("num_nulls")
    return (
        working_df.select(null_total, append=True)
        .where(F.col("num_nulls") < F.lit(minimum_num_nulls))
        .select(*all_columns)
    )
|
|
923
|
+
|
|
924
|
+
def explain(
    self, extended: t.Optional[t.Union[bool, str]] = None, mode: t.Optional[str] = None
) -> None:
    """Run ``EXPLAIN`` for this DataFrame's query through the session.

    The DataFrame's SQL is generated unoptimized as a list of statements; exactly one
    statement is supported, and ``"EXPLAIN " + <statement>`` is passed to
    ``session._execute`` with ``quote_identifiers=False``. ``extended`` and ``mode`` are
    accepted for PySpark API compatibility but are not consulted here.

    Raises
    ------
    ValueError
        If the DataFrame produces more than one SQL statement.
    """
    queries = self.sql(pretty=False, optimize=False, as_list=True)
    if len(queries) > 1:
        raise ValueError("Cannot explain a DataFrame with multiple queries")
    explain_statement = "EXPLAIN " + queries[0]
    self.session._execute(explain_statement, quote_identifiers=False)
|
|
997
|
+
|
|
998
|
+
@operation(Operation.FROM)
def fillna(
    self,
    value: t.Union[PrimitiveType, t.Dict[str, PrimitiveType]],
    subset: t.Optional[t.Union[str, t.Tuple[str, ...], t.List[str]]] = None,
) -> Self:
    """Replace nulls with ``value``.

    Parameters
    ----------
    value : primitive or dict
        A single replacement applied to every target column, or a mapping of column name
        to replacement (in which case ``subset`` is not consulted unless the dict is empty).
    subset : str, tuple or list, optional
        Columns to fill when ``value`` is not a dict; defaults to every outer column.

    Functionality Difference: If you provide a value to replace a null and that type conflicts
    with the type of the column then PySpark will just ignore your replacement.
    This will try to cast them to be the same in some cases. So they won't always match.
    Best to not mix types so make sure replacement is the same type as the column

    Possibility for improvement: Use `typeof` function to get the type of the column
    and check if it matches the type of the value provided. If not then make it null.
    """
    from sqlframe.base import functions as F

    base_df = self.copy()
    all_columns = self._get_outer_select_columns(base_df.expression)
    passthrough = {existing.alias_or_name: existing for existing in all_columns}

    fill_values = None
    target_columns = None
    if isinstance(value, dict):
        fill_values = list(value.values())
        target_columns = self._ensure_and_normalize_cols(list(value))
    if not target_columns:
        target_columns = self._ensure_and_normalize_cols(subset) if subset else all_columns
    if not fill_values:
        assert not isinstance(value, dict)
        fill_values = [value] * len(target_columns)
    fill_literals = [F.lit(fill) for fill in fill_values]

    # Start from the untouched columns and overwrite the targeted ones with a CASE WHEN.
    replacements = dict(passthrough)
    for target, literal in zip(target_columns, fill_literals):
        replacements[target.alias_or_name] = (
            F.when(target.isNull(), literal).otherwise(target).alias(target.alias_or_name)
        )
    # Preserve the original column order of the outer SELECT.
    ordered = [replacements[existing.alias_or_name] for existing in all_columns]
    return base_df.select(*ordered)
|
|
1042
|
+
|
|
1043
|
+
@operation(Operation.FROM)
def replace(
    self,
    to_replace: t.Union[bool, int, float, str, t.List, t.Dict],
    value: t.Optional[t.Union[bool, int, float, str, t.List]] = None,
    subset: t.Optional[t.Collection[ColumnOrName] | ColumnOrName] = None,
) -> Self:
    """Replace values in the selected columns.

    Parameters
    ----------
    to_replace : scalar, list/tuple or dict
        Value(s) to look for. A dict maps each old value to its replacement (``value`` is
        then ignored).
    value : scalar, list/tuple or None
        Replacement(s). With a list/tuple ``to_replace`` this may be a list/tuple of the
        same length, or a single value applied to every entry (PySpark semantics).
    subset : column(s), optional
        Columns to apply replacements to; defaults to every outer column.

    Raises
    ------
    AssertionError
        If ``to_replace`` and ``value`` are both sequences of different lengths.
    """
    from sqlframe.base import functions as F

    new_df = self.copy()
    all_columns = self._get_outer_select_columns(new_df.expression)
    all_column_mapping = {column.alias_or_name: column for column in all_columns}

    columns = self._ensure_and_normalize_cols(subset) if subset else all_columns
    if isinstance(to_replace, dict):
        old_values = list(to_replace)
        new_values = list(to_replace.values())
    elif isinstance(to_replace, (list, tuple)):
        old_values = list(to_replace)
        if isinstance(value, (list, tuple)):
            assert len(to_replace) == len(
                value
            ), "the replacements and values must be the same length"
            new_values = list(value)
        else:
            # A single replacement is applied to every value being replaced.
            new_values = [value] * len(old_values)
    else:
        # One scalar pair needs only one WHEN branch per column.
        old_values = [to_replace]
        new_values = [value]
    old_value_columns = [F.lit(old) for old in old_values]
    new_value_columns = [F.lit(new) for new in new_values]

    replacement_mapping = {}
    for column in columns:
        expression = F.lit(None)
        for i, (old_value, new_value) in enumerate(zip(old_value_columns, new_value_columns)):
            if i == 0:
                expression = F.when(column == old_value, new_value)
            else:
                expression = expression.when(column == old_value, new_value)  # type: ignore
        replacement_mapping[column.alias_or_name] = expression.otherwise(column).alias(
            column.expression.alias_or_name
        )

    # Untouched columns pass through; the original column order is preserved.
    replacement_mapping = {**all_column_mapping, **replacement_mapping}
    replacement_columns = [replacement_mapping[column.alias_or_name] for column in all_columns]
    new_df = new_df.select(*replacement_columns)
    return new_df
|
|
1092
|
+
|
|
1093
|
+
@operation(Operation.SELECT)
def withColumn(self, colName: str, col: Column) -> Self:
    """Return a new DataFrame with column ``colName`` set to ``col``.

    If the outer SELECT already has a projection with that (normalized) name, it is
    replaced in place and keeps its position; otherwise ``col`` is appended under that name.
    """
    col = self._ensure_and_normalize_col(col)
    col_name = self._ensure_and_normalize_col(colName).alias_or_name
    existing_col_names = self.expression.named_selects
    existing_col_index = (
        existing_col_names.index(col_name) if col_name in existing_col_names else None
    )
    # Compare with None explicitly: index 0 (the first column) is a valid match and must be
    # replaced rather than duplicated by an append.
    if existing_col_index is not None:
        expression = self.expression.copy()
        expression.expressions[existing_col_index] = col.alias(col_name).expression
        return self.copy(expression=expression)
    return self.select.__wrapped__(self, col.alias(col_name), append=True)  # type: ignore
|
|
1106
|
+
|
|
1107
|
+
@operation(Operation.SELECT)
def withColumnRenamed(self, existing: str, new: str) -> Self:
    """Rename every outer-SELECT projection named ``existing`` to ``new``.

    Both names are normalized by the session first. Plain column references are wrapped
    in an alias; other projections get their alias replaced.

    Raises
    ------
    ValueError
        If no projection is named ``existing``.
    """
    renamed_expression = self.expression.copy()
    old_name = self.session._normalize_string(existing)
    new_name = self.session._normalize_string(new)
    matches = [
        projection
        for projection in renamed_expression.expressions
        if projection.alias_or_name == old_name
    ]
    if not matches:
        raise ValueError("Tried to rename a column that doesn't exist")
    for projection in matches:
        if not isinstance(projection, exp.Column):
            projection.set("alias", exp.to_identifier(new_name))
            continue
        projection.replace(exp.alias_(projection, new_name))
    return self.copy(expression=renamed_expression)
|
|
1125
|
+
|
|
1126
|
+
@operation(Operation.SELECT)
def drop(self, *cols: t.Union[str, Column]) -> Self:
    """Return a new DataFrame without the outer columns whose name matches any of ``cols``.

    Names that match no column are simply ignored.
    """
    all_columns = self._get_outer_select_columns(self.expression)
    drop_cols = self._ensure_and_normalize_cols(cols)
    # Build the set of names once instead of re-listing them for every column.
    drop_names = {drop_column.alias_or_name for drop_column in drop_cols}
    new_columns = [col for col in all_columns if col.alias_or_name not in drop_names]
    return self.copy().select(*new_columns, append=False)
|
|
1136
|
+
|
|
1137
|
+
@operation(Operation.LIMIT)
def limit(self, num: int) -> Self:
    """Return a new DataFrame whose query has ``LIMIT num`` applied."""
    limited_expression = self.expression.limit(num)
    return self.copy(expression=limited_expression)
|
|
1140
|
+
|
|
1141
|
+
def toDF(self, *cols: str) -> Self:
    """Return a new DataFrame whose outer columns are renamed, positionally, to ``cols``.

    .. versionadded:: 1.6.0

    .. versionchanged:: 3.4.0
        Supports Spark Connect.

    Parameters
    ----------
    *cols : tuple
        New column names, one per existing column.

    Returns
    -------
    :class:`DataFrame`
        DataFrame with new column names.

    Raises
    ------
    ValueError
        If the number of names differs from the number of columns.
    """
    if len(self.columns) != len(cols):
        raise ValueError(
            f"Number of column names does not match number of columns: {len(cols)} != {len(self.columns)}"
        )
    snapshot = self.expression.copy()
    renamed = [exp.alias_(old, alias) for old, alias in zip(snapshot.expressions, cols)]
    return self.copy(expression=snapshot.select(*renamed, append=False))
|
|
1184
|
+
|
|
1185
|
+
@operation(Operation.NO_OP)
def hint(self, name: str, *parameters: t.Optional[t.Union[str, int]]) -> Self:
    """Attach the hint ``name`` with ``parameters``.

    When no parameters are given, this DataFrame's ``sequence_id`` is used as the single
    hint parameter.
    """
    from sqlframe.base.column import Column

    parameter_list = ensure_list(parameters)
    if not parameters:
        return self._hint(name, Column.ensure_cols([self.sequence_id]))
    return self._hint(name, self._ensure_list_of_columns(parameter_list))
|
|
1196
|
+
|
|
1197
|
+
@operation(Operation.NO_OP)
def repartition(self, numPartitions: t.Union[int, ColumnOrName], *cols: ColumnOrName) -> Self:
    """Attach a ``repartition`` hint built from ``numPartitions`` followed by ``cols``."""
    hint_args = self._ensure_list_of_columns(numPartitions)
    hint_args = hint_args + self._ensure_and_normalize_cols(cols)
    return self._hint("repartition", hint_args)
|
|
1203
|
+
|
|
1204
|
+
@operation(Operation.NO_OP)
def coalesce(self, numPartitions: int) -> Self:
    """Attach a ``coalesce`` hint whose single argument is ``numPartitions`` as a literal."""
    make_literal = get_func_from_session("lit")
    return self._hint("coalesce", [make_literal(numPartitions)])
|
|
1210
|
+
|
|
1211
|
+
@operation(Operation.NO_OP)
def cache(self) -> Self:
    """Cache this DataFrame with the ``MEMORY_AND_DISK`` storage level."""
    default_level = "MEMORY_AND_DISK"
    return self._cache(storage_level=default_level)
|
|
1214
|
+
|
|
1215
|
+
@operation(Operation.NO_OP)
def persist(self, storageLevel: StorageLevel = "MEMORY_AND_DISK_SER") -> Self:
    """
    Cache this DataFrame with the given storage level.

    Storage Level Options: https://spark.apache.org/docs/3.0.0-preview/sql-ref-syntax-aux-cache-cache-table.html
    """
    level = storageLevel
    return self._cache(level)
|
|
1221
|
+
|
|
1222
|
+
@t.overload
def cube(self, *cols: ColumnOrName) -> GROUP_DATA: ...

@t.overload
def cube(self, __cols: t.Union[t.List[Column], t.List[str]]) -> GROUP_DATA: ...

def cube(self, *cols: ColumnOrName) -> GROUP_DATA:  # type: ignore[misc]
    """
    Create a multi-dimensional cube for the current :class:`DataFrame` using
    the specified columns, so aggregations can be run on them.

    The grouping sets are every combination of the normalized columns, ordered from the
    full set of columns down to the empty set (the grand total). Within one size,
    combinations follow ``itertools.combinations`` order.

    .. versionadded:: 1.4.0

    .. versionchanged:: 3.4.0
        Supports Spark Connect.

    Parameters
    ----------
    cols : list, str or :class:`Column`
        Columns to create the cube by: column names (strings), :class:`Column`
        expressions, or a list of them.

    Returns
    -------
    :class:`GroupedData`
        Cube of the data by the given columns.
    """
    columns = self._ensure_and_normalize_cols(cols)
    grouping_columns: t.List[t.List[Column]] = [
        list(combo)
        for size in range(len(columns), -1, -1)
        for combo in itertools.combinations(columns, size)
    ]
    return self._group_data(self, grouping_columns, self.last_op)
|
|
1272
|
+
|
|
1273
|
+
def collect(self) -> t.List[Row]:
    """Run every SQL statement of this DataFrame and return the rows of the last one.

    Each statement is passed to ``session._fetch_rows`` in order; an empty statement list
    yields ``[]``.
    """
    last_rows: t.List[Row] = []
    statements = self.sql(pretty=False, optimize=False, as_list=True)
    for statement in statements:
        last_rows = self.session._fetch_rows(statement)
    return last_rows
|
|
1278
|
+
|
|
1279
|
+
@t.overload
def head(self) -> t.Optional[Row]: ...

@t.overload
def head(self, n: int) -> t.List[Row]: ...

def head(self, n: t.Optional[int] = None) -> t.Union[t.Optional[Row], t.List[Row]]:
    """Return the first row, or the first ``n`` rows.

    Without ``n`` a single Row is returned, or ``None`` when the DataFrame is empty.
    With an explicit ``n`` a list of at most ``n`` rows is always returned (possibly
    empty), matching the ``head(self, n: int) -> t.List[Row]`` overload.
    """
    if n is None:
        rows = self.limit(1).collect()
        return rows[0] if rows else None
    return self.limit(n).collect()
|
|
1291
|
+
|
|
1292
|
+
def first(self) -> t.Optional[Row]:
    """Return the first row; this simply delegates to ``head()`` with no argument."""
    first_row = self.head()
    return first_row
|
|
1294
|
+
|
|
1295
|
+
def show(
    self, n: int = 20, truncate: t.Optional[t.Union[bool, int]] = None, vertical: bool = False
):
    """Print up to ``n`` rows as a PrettyTable.

    ``vertical=True`` is not supported and raises ``NotImplementedError``; a truthy
    ``truncate`` only logs a warning, because values are never truncated here.
    """
    if vertical:
        raise NotImplementedError("Vertical show is not yet supported")
    if truncate:
        logger.warning("Truncate is ignored so full results will be displayed")
    # Make sure that the limit we add doesn't affect the results
    limited = self._convert_leaf_to_cte().limit(n)
    statements = ensure_list(
        limited.sql(
            pretty=False, optimize=False, dialect=self.session.output_dialect, as_list=True
        )
    )
    for statement in statements:
        result = self.session._fetch_rows(statement)
    table = PrettyTable()
    header_row = seq_get(result, 0)
    if header_row:
        table.field_names = list(header_row.asDict().keys())
        for record in result:
            table.add_row(list(record))
    print(table)
|
|
1315
|
+
|
|
1316
|
+
def toPandas(self) -> pd.DataFrame:
    """Execute this DataFrame and return the final result as a pandas DataFrame.

    The SQL is generated once. Every statement but the last is run through
    ``session._execute`` (empty statements are skipped); the last statement's result is
    fetched with ``session._fetchdf``.

    Raises
    ------
    AssertionError
        If no SQL statement was produced.
    """
    sqls = self.sql(
        pretty=False, optimize=False, dialect=self.session.output_dialect, as_list=True
    )  # type: ignore
    # Use one generated statement list for both the setup statements and the final query,
    # instead of calling `sql` twice.
    for sql in sqls[:-1]:
        if sql:
            self.session._execute(sql)
    final_sql = sqls[-1] if sqls else None
    assert final_sql is not None
    return self.session._fetchdf(final_sql)
|
|
1326
|
+
|
|
1327
|
+
def createOrReplaceTempView(self, name: str) -> None:
    """Register a CTE-wrapped copy of this DataFrame in ``session.temp_views`` under ``name``.

    Any view already registered under the same name is overwritten.
    """
    view_df = self.copy()._convert_leaf_to_cte()
    self.session.temp_views[name] = view_df
|
|
1329
|
+
|
|
1330
|
+
def count(self) -> int:
    """Return the number of rows by running ``SELECT count(*)`` over this DataFrame.

    Every generated statement is fetched in order; the first cell of the last result is
    returned.

    Raises
    ------
    RuntimeError
        If the session has no connection.
    """
    if not self.session._has_connection:
        raise RuntimeError("Cannot count without a connection")

    wrapped = self._convert_leaf_to_cte()
    counting_df = self.copy(expression=wrapped.expression.select("count(*)", append=False))
    statements = counting_df.sql(
        dialect=self.session.output_dialect, pretty=False, optimize=False, as_list=True
    )
    for statement in statements:
        result = self.session._fetch_rows(statement)
    return result[0][0]
|
|
1341
|
+
|
|
1342
|
+
def createGlobalTempView(self, name: str) -> None:
    """Not supported: always raises ``NotImplementedError``."""
    message = "Global temp views are not yet supported"
    raise NotImplementedError(message)
|
|
1344
|
+
|
|
1345
|
+
"""
|
|
1346
|
+
Stat Functions
|
|
1347
|
+
"""
|
|
1348
|
+
|
|
1349
|
+
# Typing overload: a single column name yields one list of quantiles (one per probability).
@t.overload
def approxQuantile(
    self,
    col: str,
    probabilities: t.Union[t.List[float], t.Tuple[float]],
    relativeError: float,
) -> t.List[float]: ...

# Typing overload: a list/tuple of column names yields one list of quantiles per column.
@t.overload
def approxQuantile(
    self,
    col: t.Union[t.List[str], t.Tuple[str]],
    probabilities: t.Union[t.List[float], t.Tuple[float]],
    relativeError: float,
) -> t.List[t.List[float]]: ...
|
|
1364
|
+
|
|
1365
|
+
def approxQuantile(
    self,
    col: t.Union[str, t.List[str], t.Tuple[str]],
    probabilities: t.Union[t.List[float], t.Tuple[float]],
    relativeError: float,
) -> t.Union[t.List[float], t.List[t.List[float]]]:
    """
    Calculates the approximate quantiles of numerical columns of a
    :class:`DataFrame`.

    The result of this algorithm has the following deterministic bound:
    If the :class:`DataFrame` has N elements and if we request the quantile at
    probability `p` up to error `err`, then the algorithm will return
    a sample `x` from the :class:`DataFrame` so that the *exact* rank of `x` is
    close to (p * N). More precisely,

      floor((p - err) * N) <= rank(x) <= ceil((p + err) * N).

    This method implements a variation of the Greenwald-Khanna
    algorithm (with some speed optimizations). The algorithm was first
    presented in [[https://doi.org/10.1145/375663.375670
    Space-efficient Online Computation of Quantile Summaries]]
    by Greenwald and Khanna.

    .. versionadded:: 2.0.0

    .. versionchanged:: 3.4.0
        Supports Spark Connect.

    Parameters
    ----------
    col: str, tuple or list
        Can be a single column name, or a list of names for multiple columns.

        .. versionchanged:: 2.2.0
           Added support for multiple columns.
    probabilities : list or tuple
        a list of quantile probabilities
        Each number must belong to [0, 1].
        For example 0 is the minimum, 0.5 is the median, 1 is the maximum.
    relativeError : float
        The relative target precision to achieve
        (>= 0). If set to zero, the exact quantiles are computed, which
        could be very expensive. Note that values greater than 1 are
        accepted but give the same result as 1.

    Returns
    -------
    list
        the approximate quantiles at the given probabilities.

        * If the input `col` is a string, the output is a list of floats.

        * If the input `col` is a list or tuple of strings, the output is also a
            list, but each element in it is a list of floats, i.e., the output
            is a list of list of floats.

    Raises
    ------
    ValueError
        If ``relativeError`` is negative.

    Notes
    -----
    Null values will be ignored in numerical columns before calculation.
    For columns only containing null values, an empty list is returned.
    """
    # Validate before touching the session: a negative error bound is invalid
    # (PySpark raises here too) rather than a request for exact quantiles.
    if relativeError < 0:
        raise ValueError(f"relativeError must be non-negative, got {relativeError}")

    percentile_approx = get_func_from_session("percentile_approx")
    col_func = get_func_from_session("col")

    # Spark's percentile_approx defines the relative error as 1.0 / accuracy.
    # Errors above 1 are clamped to 1 (as documented above), so accuracy never
    # drops below 1. Zero ("exact") uses a high default accuracy of 10000.
    accuracy = 1.0 / min(relativeError, 1.0) if relativeError > 0.0 else 10000

    df = self.select(
        *[
            percentile_approx(col_func(name), probabilities, accuracy).alias(f"val_{i}")
            for i, name in enumerate(ensure_list(col))
        ]
    )
    rows = df.collect()

    # percentile_approx yields NULL for an all-null column. Report that as an
    # empty list instead of failing while iterating over None.
    quantiles = [
        [float(value) for value in values] if values is not None else []
        for row in rows
        for values in row.asDict().values()
    ]

    # A single column name yields a flat list, matching the str overload.
    if isinstance(col, str):
        return quantiles[0] if quantiles else []
    return quantiles
|
|
1441
|
+
|
|
1442
|
+
def corr(self, col1: str, col2: str, method: t.Optional[str] = None) -> float:
    """
    Calculates the correlation of two columns of a :class:`DataFrame` as a double value.
    Currently only supports the Pearson Correlation Coefficient.
    :func:`DataFrame.corr` and :func:`DataFrameStatFunctions.corr` are aliases of each other.

    .. versionadded:: 1.4.0

    .. versionchanged:: 3.4.0
        Supports Spark Connect.

    Parameters
    ----------
    col1 : str
        The name of the first column
    col2 : str
        The name of the second column
    method : str, optional
        The correlation method. Currently only supports "pearson", which is
        also used when the argument is omitted or ``None``.

    Returns
    -------
    float
        Pearson Correlation Coefficient of two columns.

    Raises
    ------
    ValueError
        If ``method`` is given and is not "pearson".

    Examples
    --------
    >>> df = spark.createDataFrame([(1, 12), (10, 1), (19, 8)], ["c1", "c2"])
    >>> df.corr("c1", "c2")
    -0.3592106040535498
    >>> df = spark.createDataFrame([(11, 12), (10, 11), (9, 10)], ["small", "bigger"])
    >>> df.corr("small", "bigger")
    1.0
    """
    # ``None`` selects the default method, Pearson (as in PySpark), so the plain
    # ``df.corr("c1", "c2")`` call shown above works.
    if method is None:
        method = "pearson"
    if method != "pearson":
        raise ValueError("Currently only the Pearson Correlation Coefficient is supported")

    corr = get_func_from_session("corr")
    col_func = get_func_from_session("col")

    return self.select(corr(col_func(col1), col_func(col2))).collect()[0][0]
|
|
1483
|
+
|
|
1484
|
+
def cov(self, col1: str, col2: str) -> float:
    """
    Compute the sample covariance of two columns, identified by name, as a double.
    :func:`DataFrame.cov` and :func:`DataFrameStatFunctions.cov` are aliases.

    .. versionadded:: 1.4.0

    .. versionchanged:: 3.4.0
        Supports Spark Connect.

    Parameters
    ----------
    col1 : str
        Name of the first column.
    col2 : str
        Name of the second column.

    Returns
    -------
    float
        Sample covariance (``covar_samp``) of the two columns.

    Examples
    --------
    >>> df = spark.createDataFrame([(1, 12), (10, 1), (19, 8)], ["c1", "c2"])
    >>> df.cov("c1", "c2")
    -18.0
    >>> df = spark.createDataFrame([(11, 12), (10, 11), (9, 10)], ["small", "bigger"])
    >>> df.cov("small", "bigger")
    1.0
    """
    covar_samp = get_func_from_session("covar_samp")
    col_func = get_func_from_session("col")

    covariance_column = covar_samp(col_func(col1), col_func(col2))
    first_row = self.select(covariance_column).collect()[0]
    return first_row[0]
|