sqlframe-1.1.3-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (98)
  1. sqlframe/__init__.py +0 -0
  2. sqlframe/_version.py +16 -0
  3. sqlframe/base/__init__.py +0 -0
  4. sqlframe/base/_typing.py +39 -0
  5. sqlframe/base/catalog.py +1163 -0
  6. sqlframe/base/column.py +388 -0
  7. sqlframe/base/dataframe.py +1519 -0
  8. sqlframe/base/decorators.py +51 -0
  9. sqlframe/base/exceptions.py +14 -0
  10. sqlframe/base/function_alternatives.py +1055 -0
  11. sqlframe/base/functions.py +1678 -0
  12. sqlframe/base/group.py +102 -0
  13. sqlframe/base/mixins/__init__.py +0 -0
  14. sqlframe/base/mixins/catalog_mixins.py +419 -0
  15. sqlframe/base/mixins/readwriter_mixins.py +118 -0
  16. sqlframe/base/normalize.py +84 -0
  17. sqlframe/base/operations.py +87 -0
  18. sqlframe/base/readerwriter.py +679 -0
  19. sqlframe/base/session.py +585 -0
  20. sqlframe/base/transforms.py +13 -0
  21. sqlframe/base/types.py +418 -0
  22. sqlframe/base/util.py +242 -0
  23. sqlframe/base/window.py +139 -0
  24. sqlframe/bigquery/__init__.py +23 -0
  25. sqlframe/bigquery/catalog.py +255 -0
  26. sqlframe/bigquery/column.py +1 -0
  27. sqlframe/bigquery/dataframe.py +54 -0
  28. sqlframe/bigquery/functions.py +378 -0
  29. sqlframe/bigquery/group.py +14 -0
  30. sqlframe/bigquery/readwriter.py +29 -0
  31. sqlframe/bigquery/session.py +89 -0
  32. sqlframe/bigquery/types.py +1 -0
  33. sqlframe/bigquery/window.py +1 -0
  34. sqlframe/duckdb/__init__.py +20 -0
  35. sqlframe/duckdb/catalog.py +108 -0
  36. sqlframe/duckdb/column.py +1 -0
  37. sqlframe/duckdb/dataframe.py +55 -0
  38. sqlframe/duckdb/functions.py +47 -0
  39. sqlframe/duckdb/group.py +14 -0
  40. sqlframe/duckdb/readwriter.py +111 -0
  41. sqlframe/duckdb/session.py +65 -0
  42. sqlframe/duckdb/types.py +1 -0
  43. sqlframe/duckdb/window.py +1 -0
  44. sqlframe/postgres/__init__.py +23 -0
  45. sqlframe/postgres/catalog.py +106 -0
  46. sqlframe/postgres/column.py +1 -0
  47. sqlframe/postgres/dataframe.py +54 -0
  48. sqlframe/postgres/functions.py +61 -0
  49. sqlframe/postgres/group.py +14 -0
  50. sqlframe/postgres/readwriter.py +29 -0
  51. sqlframe/postgres/session.py +68 -0
  52. sqlframe/postgres/types.py +1 -0
  53. sqlframe/postgres/window.py +1 -0
  54. sqlframe/redshift/__init__.py +23 -0
  55. sqlframe/redshift/catalog.py +127 -0
  56. sqlframe/redshift/column.py +1 -0
  57. sqlframe/redshift/dataframe.py +54 -0
  58. sqlframe/redshift/functions.py +18 -0
  59. sqlframe/redshift/group.py +14 -0
  60. sqlframe/redshift/readwriter.py +29 -0
  61. sqlframe/redshift/session.py +53 -0
  62. sqlframe/redshift/types.py +1 -0
  63. sqlframe/redshift/window.py +1 -0
  64. sqlframe/snowflake/__init__.py +26 -0
  65. sqlframe/snowflake/catalog.py +134 -0
  66. sqlframe/snowflake/column.py +1 -0
  67. sqlframe/snowflake/dataframe.py +54 -0
  68. sqlframe/snowflake/functions.py +18 -0
  69. sqlframe/snowflake/group.py +14 -0
  70. sqlframe/snowflake/readwriter.py +29 -0
  71. sqlframe/snowflake/session.py +53 -0
  72. sqlframe/snowflake/types.py +1 -0
  73. sqlframe/snowflake/window.py +1 -0
  74. sqlframe/spark/__init__.py +23 -0
  75. sqlframe/spark/catalog.py +1028 -0
  76. sqlframe/spark/column.py +1 -0
  77. sqlframe/spark/dataframe.py +54 -0
  78. sqlframe/spark/functions.py +22 -0
  79. sqlframe/spark/group.py +14 -0
  80. sqlframe/spark/readwriter.py +29 -0
  81. sqlframe/spark/session.py +90 -0
  82. sqlframe/spark/types.py +1 -0
  83. sqlframe/spark/window.py +1 -0
  84. sqlframe/standalone/__init__.py +26 -0
  85. sqlframe/standalone/catalog.py +13 -0
  86. sqlframe/standalone/column.py +1 -0
  87. sqlframe/standalone/dataframe.py +36 -0
  88. sqlframe/standalone/functions.py +1 -0
  89. sqlframe/standalone/group.py +14 -0
  90. sqlframe/standalone/readwriter.py +19 -0
  91. sqlframe/standalone/session.py +40 -0
  92. sqlframe/standalone/types.py +1 -0
  93. sqlframe/standalone/window.py +1 -0
  94. sqlframe-1.1.3.dist-info/LICENSE +21 -0
  95. sqlframe-1.1.3.dist-info/METADATA +172 -0
  96. sqlframe-1.1.3.dist-info/RECORD +98 -0
  97. sqlframe-1.1.3.dist-info/WHEEL +5 -0
  98. sqlframe-1.1.3.dist-info/top_level.txt +1 -0
@@ -0,0 +1,1519 @@
1
+ # This code is based on code from Apache Spark under the license found in the LICENSE file located in the 'sqlframe' folder.
2
+
3
+ from __future__ import annotations
4
+
5
+ import functools
6
+ import itertools
7
+ import logging
8
+ import sys
9
+ import typing as t
10
+ import zlib
11
+ from copy import copy
12
+
13
+ import sqlglot
14
+ from prettytable import PrettyTable
15
+ from sqlglot import Dialect
16
+ from sqlglot import expressions as exp
17
+ from sqlglot.helper import ensure_list, object_to_dict, seq_get
18
+ from sqlglot.optimizer.qualify_columns import quote_identifiers
19
+
20
+ from sqlframe.base.operations import Operation, operation
21
+ from sqlframe.base.transforms import replace_id_value
22
+ from sqlframe.base.util import (
23
+ get_func_from_session,
24
+ get_tables_from_expression_with_join,
25
+ )
26
+
27
+ if sys.version_info >= (3, 11):
28
+ from typing import Self
29
+ else:
30
+ from typing_extensions import Self
31
+
32
+ if t.TYPE_CHECKING:
33
+ import pandas as pd
34
+ from sqlglot.dialects.dialect import DialectType
35
+
36
+ from sqlframe.base._typing import (
37
+ ColumnOrLiteral,
38
+ ColumnOrName,
39
+ OutputExpressionContainer,
40
+ PrimitiveType,
41
+ StorageLevel,
42
+ )
43
+ from sqlframe.base.column import Column
44
+ from sqlframe.base.group import _BaseGroupedData
45
+ from sqlframe.base.session import WRITER, _BaseSession
46
+ from sqlframe.base.types import Row, StructType
47
+
48
+ SESSION = t.TypeVar("SESSION", bound=_BaseSession)
49
+ GROUP_DATA = t.TypeVar("GROUP_DATA", bound=_BaseGroupedData)
50
+ else:
51
+ WRITER = t.TypeVar("WRITER")
52
+ SESSION = t.TypeVar("SESSION")
53
+ GROUP_DATA = t.TypeVar("GROUP_DATA")
54
+
55
+
56
+ logger = logging.getLogger(__name__)
57
+
58
+ JOIN_HINTS = {
59
+ "BROADCAST",
60
+ "BROADCASTJOIN",
61
+ "MAPJOIN",
62
+ "MERGE",
63
+ "SHUFFLEMERGE",
64
+ "MERGEJOIN",
65
+ "SHUFFLE_HASH",
66
+ "SHUFFLE_REPLICATE_NL",
67
+ }
68
+
69
+
70
+ DF = t.TypeVar("DF", bound="_BaseDataFrame")
71
+
72
+
73
+ class _BaseDataFrameNaFunctions(t.Generic[DF]):
74
+ def __init__(self, df: DF):
75
+ self.df = df
76
+
77
+ def drop(
78
+ self,
79
+ how: str = "any",
80
+ thresh: t.Optional[int] = None,
81
+ subset: t.Optional[t.Union[str, t.Tuple[str, ...], t.List[str]]] = None,
82
+ ) -> DF:
83
+ return self.df.dropna(how=how, thresh=thresh, subset=subset)
84
+
85
+ @t.overload
86
+ def fill(self, value: PrimitiveType, subset: t.Optional[t.List[str]] = ...) -> DF: ...
87
+
88
+ @t.overload
89
+ def fill(self, value: t.Dict[str, PrimitiveType]) -> DF: ...
90
+
91
+ def fill(
92
+ self,
93
+ value: t.Union[PrimitiveType, t.Dict[str, PrimitiveType]],
94
+ subset: t.Optional[t.Union[str, t.Tuple[str, ...], t.List[str]]] = None,
95
+ ) -> DF:
96
+ return self.df.fillna(value=value, subset=subset)
97
+
98
+ def replace(
99
+ self,
100
+ to_replace: t.Union[bool, int, float, str, t.List, t.Dict],
101
+ value: t.Optional[t.Union[bool, int, float, str, t.List]] = None,
102
+ subset: t.Optional[t.Union[str, t.List[str]]] = None,
103
+ ) -> DF:
104
+ return self.df.replace(to_replace=to_replace, value=value, subset=subset)
105
+
106
+
107
+ NA = t.TypeVar("NA", bound=_BaseDataFrameNaFunctions)
108
+
109
+
110
+ class _BaseDataFrameStatFunctions(t.Generic[DF]):
111
+ def __init__(self, df: DF):
112
+ self.df = df
113
+
114
+ @t.overload
115
+ def approxQuantile(
116
+ self,
117
+ col: str,
118
+ probabilities: t.Union[t.List[float], t.Tuple[float]],
119
+ relativeError: float,
120
+ ) -> t.List[float]: ...
121
+
122
+ @t.overload
123
+ def approxQuantile(
124
+ self,
125
+ col: t.Union[t.List[str], t.Tuple[str]],
126
+ probabilities: t.Union[t.List[float], t.Tuple[float]],
127
+ relativeError: float,
128
+ ) -> t.List[t.List[float]]: ...
129
+
130
+ def approxQuantile(
131
+ self,
132
+ col: t.Union[str, t.List[str], t.Tuple[str]],
133
+ probabilities: t.Union[t.List[float], t.Tuple[float]],
134
+ relativeError: float,
135
+ ) -> t.Union[t.List[float], t.List[t.List[float]]]:
136
+ return self.df.approxQuantile(col, probabilities, relativeError)
137
+
138
+ def corr(self, col1: str, col2: str, method: str = "pearson") -> float:
139
+ return self.df.corr(col1, col2, method)
140
+
141
+ def cov(self, col1: str, col2: str) -> float:
142
+ return self.df.cov(col1, col2)
143
+
144
+
145
+ STAT = t.TypeVar("STAT", bound=_BaseDataFrameStatFunctions)
146
+
147
+
148
+ class _BaseDataFrame(t.Generic[SESSION, WRITER, NA, STAT, GROUP_DATA]):
149
+ _na: t.Type[NA]
150
+ _stat: t.Type[STAT]
151
+ _group_data: t.Type[GROUP_DATA]
152
+
153
+ def __init__(
154
+ self,
155
+ session: SESSION,
156
+ expression: exp.Select,
157
+ branch_id: t.Optional[str] = None,
158
+ sequence_id: t.Optional[str] = None,
159
+ last_op: Operation = Operation.INIT,
160
+ pending_hints: t.Optional[t.List[exp.Expression]] = None,
161
+ output_expression_container: t.Optional[OutputExpressionContainer] = None,
162
+ **kwargs,
163
+ ):
164
+ self.session = session
165
+ self.expression: exp.Select = expression
166
+ self.branch_id = branch_id or self.session._random_branch_id
167
+ self.sequence_id = sequence_id or self.session._random_sequence_id
168
+ self.last_op = last_op
169
+ self.pending_hints = pending_hints or []
170
+ self.output_expression_container = output_expression_container or exp.Select()
171
+ self.temp_views: t.List[exp.Select] = []
172
+
173
+ def __getattr__(self, column_name: str) -> Column:
174
+ return self[column_name]
175
+
176
+ def __getitem__(self, column_name: str) -> Column:
177
+ from sqlframe.base.util import get_func_from_session
178
+
179
+ col = get_func_from_session("col", self.session)
180
+
181
+ column_name = f"{self.branch_id}.{column_name}"
182
+ return col(column_name)
183
+
184
+ def __copy__(self):
185
+ return self.copy()
186
+
187
+ @property
188
+ def write(self) -> WRITER:
189
+ return self.session._writer(self)
190
+
191
+ @property
192
+ def latest_cte_name(self) -> str:
193
+ if not self.expression.ctes:
194
+ from_exp = self.expression.args["from"]
195
+ if from_exp.alias_or_name:
196
+ return from_exp.alias_or_name
197
+ table_alias = from_exp.find(exp.TableAlias)
198
+ if not table_alias:
199
+ raise RuntimeError(
200
+ f"Could not find an alias name for this expression: {self.expression}"
201
+ )
202
+ return table_alias.alias_or_name
203
+ return self.expression.ctes[-1].alias
204
+
205
+ @property
206
+ def pending_join_hints(self):
207
+ return [hint for hint in self.pending_hints if isinstance(hint, exp.JoinHint)]
208
+
209
+ @property
210
+ def pending_partition_hints(self):
211
+ return [hint for hint in self.pending_hints if isinstance(hint, exp.Anonymous)]
212
+
213
+ @property
214
+ def columns(self) -> t.List[str]:
215
+ return self.expression.named_selects
216
+
217
+ @property
218
+ def na(self) -> NA:
219
+ return self._na(self)
220
+
221
+ @property
222
+ def stat(self) -> STAT:
223
+ return self._stat(self)
224
+
225
+ @property
226
+ def schema(self) -> StructType:
227
+ """Returns the schema of this :class:`DataFrame` as a :class:`pyspark.sql.types.StructType`.
228
+
229
+ .. versionadded:: 1.3.0
230
+
231
+ .. versionchanged:: 3.4.0
232
+ Supports Spark Connect.
233
+
234
+ Returns
235
+ -------
236
+ :class:`StructType`
237
+
238
+ Examples
239
+ --------
240
+ >>> df = spark.createDataFrame(
241
+ ... [(14, "Tom"), (23, "Alice"), (16, "Bob")], ["age", "name"])
242
+
243
+ Retrieve the schema of the current DataFrame.
244
+
245
+ >>> df.schema
246
+ StructType([StructField('age', LongType(), True),
247
+ StructField('name', StringType(), True)])
248
+ """
249
+ raise NotImplementedError
250
+
251
+ def _replace_cte_names_with_hashes(self, expression: exp.Select):
252
+ replacement_mapping = {}
253
+ for cte in expression.ctes:
254
+ old_name_id = cte.args["alias"].this
255
+ new_hashed_id = exp.to_identifier(
256
+ self._create_hash_from_expression(cte.this), quoted=old_name_id.args["quoted"]
257
+ )
258
+ replacement_mapping[old_name_id] = new_hashed_id
259
+ expression = expression.transform(replace_id_value, replacement_mapping).assert_is(
260
+ exp.Select
261
+ )
262
+ return expression
263
+
264
+ def _create_cte_from_expression(
265
+ self,
266
+ expression: exp.Expression,
267
+ branch_id: str,
268
+ sequence_id: str,
269
+ name: t.Optional[str] = None,
270
+ **kwargs,
271
+ ) -> t.Tuple[exp.CTE, str]:
272
+ name = name or self._create_hash_from_expression(expression)
273
+ expression_to_cte = expression.copy()
274
+ expression_to_cte.set("with", None)
275
+ cte = exp.Select().with_(name, as_=expression_to_cte, **kwargs).ctes[0]
276
+ cte.set("branch_id", branch_id)
277
+ cte.set("sequence_id", sequence_id)
278
+ return cte, name
279
+
280
+ def _ensure_list_of_columns(
281
+ self, cols: t.Optional[t.Union[ColumnOrLiteral, t.Collection[ColumnOrLiteral]]]
282
+ ) -> t.List[Column]:
283
+ from sqlframe.base.column import Column
284
+
285
+ return Column.ensure_cols(ensure_list(cols)) # type: ignore
286
+
287
+ def _ensure_and_normalize_cols(
288
+ self, cols, expression: t.Optional[exp.Select] = None
289
+ ) -> t.List[Column]:
290
+ from sqlframe.base.normalize import normalize
291
+
292
+ cols = self._ensure_list_of_columns(cols)
293
+ normalize(self.session, expression or self.expression, cols)
294
+ return cols
295
+
296
+ def _ensure_and_normalize_col(self, col):
297
+ from sqlframe.base.column import Column
298
+ from sqlframe.base.normalize import normalize
299
+
300
+ col = Column.ensure_col(col)
301
+ normalize(self.session, self.expression, col)
302
+ return col
303
+
304
+ def _convert_leaf_to_cte(
305
+ self, sequence_id: t.Optional[str] = None, name: t.Optional[str] = None
306
+ ) -> Self:
307
+ df = self._resolve_pending_hints()
308
+ sequence_id = sequence_id or df.sequence_id
309
+ expression = df.expression.copy()
310
+ cte_expression, cte_name = df._create_cte_from_expression(
311
+ expression=expression, branch_id=self.branch_id, sequence_id=sequence_id, name=name
312
+ )
313
+ new_expression = df._add_ctes_to_expression(
314
+ exp.Select(), expression.ctes + [cte_expression]
315
+ )
316
+ sel_columns = df._get_outer_select_columns(cte_expression)
317
+ new_expression = new_expression.from_(cte_name).select(*[x.expression for x in sel_columns])
318
+ return df.copy(expression=new_expression, sequence_id=sequence_id)
319
+
320
+ def _resolve_pending_hints(self) -> Self:
321
+ df = self.copy()
322
+ if not self.pending_hints:
323
+ return df
324
+ expression = df.expression
325
+ hint_expression = expression.args.get("hint") or exp.Hint(expressions=[])
326
+ for hint in df.pending_partition_hints:
327
+ hint_expression.append("expressions", hint)
328
+ df.pending_hints.remove(hint)
329
+
330
+ join_aliases = {
331
+ join_table.alias_or_name
332
+ for join_table in get_tables_from_expression_with_join(expression)
333
+ }
334
+ if join_aliases:
335
+ for hint in df.pending_join_hints:
336
+ for sequence_id_expression in hint.expressions:
337
+ sequence_id_or_name = sequence_id_expression.alias_or_name
338
+ sequence_ids_to_match = [sequence_id_or_name]
339
+ if sequence_id_or_name in df.session.name_to_sequence_id_mapping:
340
+ sequence_ids_to_match = df.session.name_to_sequence_id_mapping[
341
+ sequence_id_or_name
342
+ ]
343
+ matching_ctes = [
344
+ cte
345
+ for cte in reversed(expression.ctes)
346
+ if cte.args["sequence_id"] in sequence_ids_to_match
347
+ ]
348
+ for matching_cte in matching_ctes:
349
+ if matching_cte.alias_or_name in join_aliases:
350
+ sequence_id_expression.set("this", matching_cte.args["alias"].this)
351
+ df.pending_hints.remove(hint)
352
+ break
353
+ hint_expression.append("expressions", hint)
354
+ if hint_expression.expressions:
355
+ expression.set("hint", hint_expression)
356
+ return df
357
+
358
+ def _hint(self, hint_name: str, args: t.List[Column]) -> Self:
359
+ hint_name = hint_name.upper()
360
+ hint_expression = (
361
+ exp.JoinHint(
362
+ this=hint_name,
363
+ expressions=[exp.to_table(parameter.alias_or_name) for parameter in args],
364
+ )
365
+ if hint_name in JOIN_HINTS
366
+ else exp.Anonymous(
367
+ this=hint_name, expressions=[parameter.expression for parameter in args]
368
+ )
369
+ )
370
+ new_df = self.copy()
371
+ new_df.pending_hints.append(hint_expression)
372
+ return new_df
373
+
374
+ def _set_operation(self, klass: t.Callable, other: Self, distinct: bool) -> Self:
375
+ other_df = other._convert_leaf_to_cte()
376
+ base_expression = self.expression.copy()
377
+ base_expression = self._add_ctes_to_expression(base_expression, other_df.expression.ctes)
378
+ all_ctes = base_expression.ctes
379
+ other_df.expression.set("with", None)
380
+ base_expression.set("with", None)
381
+ operation = klass(this=base_expression, distinct=distinct, expression=other_df.expression)
382
+ operation.set("with", exp.With(expressions=all_ctes))
383
+ return self.copy(expression=operation)._convert_leaf_to_cte()
384
+
385
+ def _cache(self, storage_level: str) -> Self:
386
+ df = self._convert_leaf_to_cte()
387
+ df.expression.ctes[-1].set("cache_storage_level", storage_level)
388
+ return df
389
+
390
+ @classmethod
391
+ def _add_ctes_to_expression(cls, expression: exp.Select, ctes: t.List[exp.CTE]) -> exp.Select:
392
+ expression = expression.copy()
393
+ with_expression = expression.args.get("with")
394
+ if with_expression:
395
+ existing_ctes = with_expression.expressions
396
+ existing_cte_names = {x.alias_or_name for x in existing_ctes}
397
+ for cte in ctes:
398
+ if cte.alias_or_name not in existing_cte_names:
399
+ existing_ctes.append(cte)
400
+ else:
401
+ existing_ctes = ctes
402
+ expression.set("with", exp.With(expressions=existing_ctes))
403
+ return expression
404
+
405
+ @classmethod
406
+ def _get_outer_select_columns(cls, item: exp.Expression) -> t.List[Column]:
407
+ from sqlframe.base.session import _BaseSession
408
+
409
+ col = get_func_from_session("col", _BaseSession())
410
+
411
+ outer_select = item.find(exp.Select)
412
+ if outer_select:
413
+ return [col(x.alias_or_name) for x in outer_select.expressions]
414
+ return []
415
+
416
+ def _create_hash_from_expression(self, expression: exp.Expression) -> str:
417
+ from sqlframe.base.session import _BaseSession
418
+
419
+ value = expression.sql(dialect=_BaseSession().input_dialect).encode("utf-8")
420
+ hash = f"t{zlib.crc32(value)}"[:9]
421
+ return self.session._normalize_string(hash)
422
+
423
+ def _get_select_expressions(
424
+ self,
425
+ ) -> t.List[t.Tuple[t.Union[t.Type[exp.Cache], OutputExpressionContainer], exp.Select]]:
426
+ select_expressions: t.List[
427
+ t.Tuple[t.Union[t.Type[exp.Cache], OutputExpressionContainer], exp.Select]
428
+ ] = []
429
+ main_select_ctes: t.List[exp.CTE] = []
430
+ for cte in self.expression.ctes:
431
+ cache_storage_level = cte.args.get("cache_storage_level")
432
+ if cache_storage_level:
433
+ select_expression = cte.this.copy()
434
+ select_expression.set("with", exp.With(expressions=copy(main_select_ctes)))
435
+ select_expression.set("cte_alias_name", cte.alias_or_name)
436
+ select_expression.set("cache_storage_level", cache_storage_level)
437
+ select_expressions.append((exp.Cache, select_expression))
438
+ else:
439
+ main_select_ctes.append(cte)
440
+ main_select = self.expression.copy()
441
+ if main_select_ctes:
442
+ main_select.set("with", exp.With(expressions=main_select_ctes))
443
+ expression_select_pair = (type(self.output_expression_container), main_select)
444
+ select_expressions.append(expression_select_pair) # type: ignore
445
+ return select_expressions
446
+
447
+ @t.overload
448
+ def sql(
449
+ self,
450
+ dialect: DialectType = ...,
451
+ optimize: bool = ...,
452
+ pretty: bool = ...,
453
+ *,
454
+ as_list: t.Literal[False],
455
+ **kwargs: t.Any,
456
+ ) -> str: ...
457
+
458
+ @t.overload
459
+ def sql(
460
+ self,
461
+ dialect: DialectType = ...,
462
+ optimize: bool = ...,
463
+ pretty: bool = ...,
464
+ *,
465
+ as_list: t.Literal[True],
466
+ **kwargs: t.Any,
467
+ ) -> t.List[str]: ...
468
+
469
+ def sql(
470
+ self,
471
+ dialect: DialectType = None,
472
+ optimize: bool = True,
473
+ pretty: bool = True,
474
+ as_list: bool = False,
475
+ **kwargs,
476
+ ) -> t.Union[str, t.List[str]]:
477
+ dialect = Dialect.get_or_raise(dialect or self.session.output_dialect)
478
+
479
+ df = self._resolve_pending_hints()
480
+ select_expressions = df._get_select_expressions()
481
+ output_expressions: t.List[t.Union[exp.Select, exp.Cache, exp.Drop]] = []
482
+ replacement_mapping: t.Dict[exp.Identifier, exp.Identifier] = {}
483
+
484
+ for expression_type, select_expression in select_expressions:
485
+ select_expression = select_expression.transform(
486
+ replace_id_value, replacement_mapping
487
+ ).assert_is(exp.Select)
488
+ if optimize:
489
+ quote_identifiers(select_expression, dialect=dialect)
490
+ select_expression = t.cast(
491
+ exp.Select, self.session._optimize(select_expression, dialect=dialect)
492
+ )
493
+
494
+ select_expression = df._replace_cte_names_with_hashes(select_expression)
495
+
496
+ expression: t.Union[exp.Select, exp.Cache, exp.Drop]
497
+ if expression_type == exp.Cache:
498
+ cache_table_name = df._create_hash_from_expression(select_expression)
499
+ cache_table = exp.to_table(cache_table_name)
500
+ original_alias_name = select_expression.args["cte_alias_name"]
501
+
502
+ replacement_mapping[exp.to_identifier(original_alias_name)] = exp.to_identifier( # type: ignore
503
+ cache_table_name
504
+ )
505
+ self.session.catalog.add_table(
506
+ cache_table_name,
507
+ {
508
+ expression.alias_or_name: expression.type.sql(dialect=dialect)
509
+ if expression.type
510
+ else "UNKNOWN"
511
+ for expression in select_expression.expressions
512
+ },
513
+ )
514
+
515
+ cache_storage_level = select_expression.args["cache_storage_level"]
516
+ options = [
517
+ exp.Literal.string("storageLevel"),
518
+ exp.Literal.string(cache_storage_level),
519
+ ]
520
+ expression = exp.Cache(
521
+ this=cache_table, expression=select_expression, lazy=True, options=options
522
+ )
523
+
524
+ # We will drop the "view" if it exists before running the cache table
525
+ output_expressions.append(exp.Drop(this=cache_table, exists=True, kind="VIEW"))
526
+ elif expression_type == exp.Create:
527
+ expression = df.output_expression_container.copy()
528
+ expression.set("expression", select_expression)
529
+ elif expression_type == exp.Insert:
530
+ expression = df.output_expression_container.copy()
531
+ select_without_ctes = select_expression.copy()
532
+ select_without_ctes.set("with", None)
533
+ expression.set("expression", select_without_ctes)
534
+
535
+ if select_expression.ctes:
536
+ expression.set("with", exp.With(expressions=select_expression.ctes))
537
+ elif expression_type == exp.Select:
538
+ expression = select_expression
539
+ else:
540
+ raise ValueError(f"Invalid expression type: {expression_type}")
541
+
542
+ output_expressions.append(expression)
543
+
544
+ results = [
545
+ expression.sql(dialect=dialect, pretty=pretty, **kwargs)
546
+ for expression in output_expressions
547
+ ]
548
+ if as_list:
549
+ return results
550
+ return ";\n".join(results)
551
+
552
+ def copy(self, **kwargs) -> Self:
553
+ return self.__class__(**object_to_dict(self, **kwargs))
554
+
555
+ @operation(Operation.SELECT)
556
+ def select(self, *cols, **kwargs) -> Self:
557
+ from sqlframe.base.column import Column
558
+
559
+ columns = self._ensure_and_normalize_cols(cols)
560
+ kwargs["append"] = kwargs.get("append", False)
561
+ if self.expression.args.get("joins"):
562
+ ambiguous_cols = [
563
+ col
564
+ for col in columns
565
+ if isinstance(col.column_expression, exp.Column) and not col.column_expression.table
566
+ ]
567
+ if ambiguous_cols:
568
+ join_table_identifiers = [
569
+ x.this for x in get_tables_from_expression_with_join(self.expression)
570
+ ]
571
+ cte_names_in_join = [x.this for x in join_table_identifiers]
572
+ # If we have columns that resolve to multiple CTE expressions then we want to use each CTE left-to-right
573
+ # and therefore we allow multiple columns with the same name in the result. This matches the behavior
574
+ # of Spark.
575
+ resolved_column_position: t.Dict[Column, int] = {col: -1 for col in ambiguous_cols}
576
+ for ambiguous_col in ambiguous_cols:
577
+ ctes_with_column = [
578
+ cte
579
+ for cte in self.expression.ctes
580
+ if cte.alias_or_name in cte_names_in_join
581
+ and ambiguous_col.alias_or_name in cte.this.named_selects
582
+ ]
583
+ # Check if there is a CTE with this column that we haven't used before. If so, use it. Otherwise,
584
+ # use the same CTE we used before
585
+ cte = seq_get(ctes_with_column, resolved_column_position[ambiguous_col] + 1)
586
+ if cte:
587
+ resolved_column_position[ambiguous_col] += 1
588
+ else:
589
+ cte = ctes_with_column[resolved_column_position[ambiguous_col]]
590
+ ambiguous_col.expression.set("table", exp.to_identifier(cte.alias_or_name))
591
+ return self.copy(
592
+ expression=self.expression.select(*[x.expression for x in columns], **kwargs), **kwargs
593
+ )
594
+
595
+ @operation(Operation.NO_OP)
596
+ def alias(self, name: str, **kwargs) -> Self:
597
+ from sqlframe.base.column import Column
598
+
599
+ new_sequence_id = self.session._random_sequence_id
600
+ df = self.copy()
601
+ for join_hint in df.pending_join_hints:
602
+ for expression in join_hint.expressions:
603
+ if expression.alias_or_name == self.sequence_id:
604
+ expression.set("this", Column.ensure_col(new_sequence_id).expression)
605
+ df.session._add_alias_to_mapping(name, new_sequence_id)
606
+ return df._convert_leaf_to_cte(sequence_id=new_sequence_id)
607
+
608
+ @operation(Operation.WHERE)
609
+ def where(self, column: t.Union[Column, str, bool], **kwargs) -> Self:
610
+ if isinstance(column, str):
611
+ col = self._ensure_and_normalize_col(
612
+ sqlglot.parse_one(column, dialect=self.session.input_dialect)
613
+ )
614
+ else:
615
+ col = self._ensure_and_normalize_col(column)
616
+ return self.copy(expression=self.expression.where(col.expression))
617
+
618
+ filter = where
619
+
620
+ @operation(Operation.GROUP_BY)
621
+ def groupBy(self, *cols, **kwargs) -> GROUP_DATA:
622
+ columns = self._ensure_and_normalize_cols(cols)
623
+ return self._group_data(self, columns, self.last_op)
624
+
625
+ groupby = groupBy
626
+
627
+ @operation(Operation.SELECT)
628
+ def agg(self, *exprs, **kwargs) -> Self:
629
+ cols = self._ensure_and_normalize_cols(exprs)
630
+ return self.groupBy().agg(*cols)
631
+
632
+ @operation(Operation.FROM)
633
+ def crossJoin(self, other: DF) -> Self:
634
+ """Returns the cartesian product with another :class:`DataFrame`.
635
+
636
+ .. versionadded:: 2.1.0
637
+
638
+ .. versionchanged:: 3.4.0
639
+ Supports Spark Connect.
640
+
641
+ Parameters
642
+ ----------
643
+ other : :class:`DataFrame`
644
+ Right side of the cartesian product.
645
+
646
+ Returns
647
+ -------
648
+ :class:`DataFrame`
649
+ Joined DataFrame.
650
+
651
+ Examples
652
+ --------
653
+ >>> from pyspark.sql import Row
654
+ >>> df = spark.createDataFrame(
655
+ ... [(14, "Tom"), (23, "Alice"), (16, "Bob")], ["age", "name"])
656
+ >>> df2 = spark.createDataFrame(
657
+ ... [Row(height=80, name="Tom"), Row(height=85, name="Bob")])
658
+ >>> df.crossJoin(df2.select("height")).select("age", "name", "height").show()
659
+ +---+-----+------+
660
+ |age| name|height|
661
+ +---+-----+------+
662
+ | 14| Tom| 80|
663
+ | 14| Tom| 85|
664
+ | 23|Alice| 80|
665
+ | 23|Alice| 85|
666
+ | 16| Bob| 80|
667
+ | 16| Bob| 85|
668
+ +---+-----+------+
669
+ """
670
+ return self.join.__wrapped__(self, other, how="cross") # type: ignore
671
+
672
+ @operation(Operation.FROM)
673
+ def join(
674
+ self,
675
+ other_df: Self,
676
+ on: t.Optional[t.Union[str, t.List[str], Column, t.List[Column]]] = None,
677
+ how: str = "inner",
678
+ **kwargs,
679
+ ) -> Self:
680
+ if on is None:
681
+ logger.warning("Got no value for on. This appears change the join to a cross join.")
682
+ how = "cross"
683
+ other_df = other_df._convert_leaf_to_cte()
684
+ # We will determine the actual "join on" expression later, so we don't provide it at first
685
+ join_expression = self.expression.join(
686
+ other_df.latest_cte_name, join_type=how.replace("_", " ")
687
+ )
688
+ join_expression = self._add_ctes_to_expression(join_expression, other_df.expression.ctes)
689
+ self_columns = self._get_outer_select_columns(join_expression)
690
+ other_columns = self._get_outer_select_columns(other_df.expression)
691
+ join_columns = self._ensure_list_of_columns(on)
692
+ # Determines the join clause and select columns to be used based on what type of columns were provided for
693
+ # the join. The columns returned change based on how the on expression is provided.
694
+ if how != "cross":
695
+ if isinstance(join_columns[0].expression, exp.Column):
696
+ """
697
+ Unique characteristics of join on column names only:
698
+ * The column names are put at the front of the select list
699
+ * Only the join column names are deduplicated across the select list (other duplicate column names are allowed)
700
+ """
701
+ table_names = [
702
+ table.alias_or_name
703
+ for table in get_tables_from_expression_with_join(join_expression)
704
+ ]
705
+ potential_ctes = [
706
+ cte
707
+ for cte in join_expression.ctes
708
+ if cte.alias_or_name in table_names
709
+ and cte.alias_or_name != other_df.latest_cte_name
710
+ ]
711
+ # Determine the table to reference for the left side of the join by checking each of the left side
712
+ # tables and seeing if they have the column being referenced.
713
+ join_column_pairs = []
714
+ for join_column in join_columns:
715
+ num_matching_ctes = 0
716
+ for cte in potential_ctes:
717
+ if join_column.alias_or_name in cte.this.named_selects:
718
+ left_column = join_column.copy().set_table_name(cte.alias_or_name)
719
+ right_column = join_column.copy().set_table_name(
720
+ other_df.latest_cte_name
721
+ )
722
+ join_column_pairs.append((left_column, right_column))
723
+ num_matching_ctes += 1
724
+ if num_matching_ctes > 1:
725
+ raise ValueError(
726
+ f"Column {join_column.alias_or_name} is ambiguous. Please specify the table name."
727
+ )
728
+ elif num_matching_ctes == 0:
729
+ raise ValueError(
730
+ f"Column {join_column.alias_or_name} does not exist in any of the tables."
731
+ )
732
+ join_clause = functools.reduce(
733
+ lambda x, y: x & y,
734
+ [
735
+ left_column == right_column
736
+ for left_column, right_column in join_column_pairs
737
+ ],
738
+ )
739
+ join_column_names = [left_col.alias_or_name for left_col, _ in join_column_pairs]
740
+ # To match Spark behavior, only the join clause columns get deduplicated, and they get put at the front of the column list
741
+ select_column_names = [
742
+ (
743
+ column.alias_or_name
744
+ if not isinstance(column.expression.this, exp.Star)
745
+ else column.sql()
746
+ )
747
+ for column in self_columns + other_columns
748
+ ]
749
+ select_column_names = [
750
+ column_name
751
+ for column_name in select_column_names
752
+ if column_name not in join_column_names
753
+ ]
754
+ select_column_names = join_column_names + select_column_names
755
+ else:
756
+ """
757
+ Unique characteristics of join on expressions:
758
+ * There is no deduplication of the results.
759
+ * The left DataFrame's columns go first and the right DataFrame's come after. No sort preference is given to join columns
760
+ """
761
+ join_columns = self._ensure_and_normalize_cols(join_columns, join_expression)
762
+ if len(join_columns) > 1:
763
+ join_columns = [functools.reduce(lambda x, y: x & y, join_columns)]
764
+ join_clause = join_columns[0]
765
+ select_column_names = [
766
+ column.alias_or_name for column in self_columns + other_columns
767
+ ]
768
+
769
+ # Update the on expression with the actual join clause to replace the dummy one from before
770
+ else:
771
+ select_column_names = [column.alias_or_name for column in self_columns + other_columns]
772
+ join_clause = None
773
+ join_expression.args["joins"][-1].set("on", join_clause.expression if join_clause else None)
774
+ new_df = self.copy(expression=join_expression)
775
+ new_df.pending_join_hints.extend(self.pending_join_hints)
776
+ new_df.pending_hints.extend(other_df.pending_hints)
777
+ new_df = new_df.select.__wrapped__(new_df, *select_column_names) # type: ignore
778
+ return new_df
779
+
780
+ @operation(Operation.ORDER_BY)
781
+ def orderBy(
782
+ self,
783
+ *cols: t.Union[str, Column],
784
+ ascending: t.Optional[t.Union[t.Any, t.List[t.Any]]] = None,
785
+ ) -> Self:
786
+ """
787
+ This implementation lets any pre-ordered columns take priority over whatever is provided in `ascending`. Spark's
788
+ behavior here is irregular and can result in runtime errors. Users shouldn't mix the two anyway, so this
789
+ is unlikely to come up.
790
+ """
791
+ columns = self._ensure_and_normalize_cols(cols)
792
+ pre_ordered_col_indexes = [
793
+ i for i, col in enumerate(columns) if isinstance(col.expression, exp.Ordered)
794
+ ]
795
+ if ascending is None:
796
+ ascending = [True] * len(columns)
797
+ elif not isinstance(ascending, list):
798
+ ascending = [ascending] * len(columns)
799
+ ascending = [bool(x) for x in ascending]
800
+ assert len(columns) == len(
801
+ ascending
802
+ ), "The length of items in ascending must equal the number of columns provided"
803
+ col_and_ascending = list(zip(columns, ascending))
804
+ order_by_columns = [
805
+ (
806
+ sqlglot.parse_one(
807
+ f"{col.expression.sql(dialect=self.session.input_dialect)} {'DESC' if not asc else ''}",
808
+ dialect=self.session.input_dialect,
809
+ into=exp.Ordered,
810
+ )
811
+ if i not in pre_ordered_col_indexes
812
+ else columns[i].column_expression
813
+ )
814
+ for i, (col, asc) in enumerate(col_and_ascending)
815
+ ]
816
+ return self.copy(expression=self.expression.order_by(*order_by_columns))
817
+
818
+ sort = orderBy
819
+
820
+ @operation(Operation.FROM)
821
+ def union(self, other: Self) -> Self:
822
+ return self._set_operation(exp.Union, other, False)
823
+
824
+ unionAll = union
825
+
826
+ @operation(Operation.FROM)
827
+ def unionByName(self, other: Self, allowMissingColumns: bool = False) -> Self:
828
+ l_columns = self.columns
829
+ r_columns = other.columns
830
+ if not allowMissingColumns:
831
+ l_expressions = l_columns
832
+ r_expressions = l_columns
833
+ else:
834
+ l_expressions = []
835
+ r_expressions = []
836
+ r_columns_unused = copy(r_columns)
837
+ for l_column in l_columns:
838
+ l_expressions.append(l_column)
839
+ if l_column in r_columns:
840
+ r_expressions.append(l_column)
841
+ r_columns_unused.remove(l_column)
842
+ else:
843
+ r_expressions.append(exp.alias_(exp.Null(), l_column, copy=False))
844
+ for r_column in r_columns_unused:
845
+ l_expressions.append(exp.alias_(exp.Null(), r_column, copy=False))
846
+ r_expressions.append(r_column)
847
+ r_df = (
848
+ other.copy()._convert_leaf_to_cte().select(*self._ensure_list_of_columns(r_expressions))
849
+ )
850
+ l_df = self.copy()
851
+ if allowMissingColumns:
852
+ l_df = l_df._convert_leaf_to_cte().select(*self._ensure_list_of_columns(l_expressions))
853
+ return l_df._set_operation(exp.Union, r_df, False)
854
+
855
+ @operation(Operation.FROM)
856
+ def intersect(self, other: Self) -> Self:
857
+ return self._set_operation(exp.Intersect, other, True)
858
+
859
+ @operation(Operation.FROM)
860
+ def intersectAll(self, other: Self) -> Self:
861
+ return self._set_operation(exp.Intersect, other, False)
862
+
863
+ @operation(Operation.FROM)
864
+ def exceptAll(self, other: Self) -> Self:
865
+ return self._set_operation(exp.Except, other, False)
866
+
867
+ @operation(Operation.SELECT)
868
+ def distinct(self) -> Self:
869
+ return self.copy(expression=self.expression.distinct())
870
+
871
+ @operation(Operation.SELECT)
872
+ def dropDuplicates(self, subset: t.Optional[t.List[str]] = None):
873
+ from sqlframe.base import functions as F
874
+ from sqlframe.base.window import Window
875
+
876
+ if not subset:
877
+ return self.distinct()
878
+ column_names = ensure_list(subset)
879
+ window = Window.partitionBy(*column_names).orderBy(*column_names)
880
+ return (
881
+ self.copy()
882
+ .withColumn("row_num", F.row_number().over(window))
883
+ .where(F.col("row_num") == F.lit(1))
884
+ .drop("row_num")
885
+ )
886
+
887
+ drop_duplicates = dropDuplicates
888
+
889
+ @operation(Operation.FROM)
890
+ def dropna(
891
+ self,
892
+ how: str = "any",
893
+ thresh: t.Optional[int] = None,
894
+ subset: t.Optional[t.Union[str, t.Tuple[str, ...], t.List[str]]] = None,
895
+ ) -> Self:
896
+ from sqlframe.base import functions as F
897
+
898
+ minimum_non_null = thresh or 0 # will be determined later if thresh is null
899
+ new_df = self.copy()
900
+ all_columns = self._get_outer_select_columns(new_df.expression)
901
+ if subset:
902
+ null_check_columns = self._ensure_and_normalize_cols(subset)
903
+ else:
904
+ null_check_columns = all_columns
905
+ if thresh is None:
906
+ minimum_num_nulls = 1 if how == "any" else len(null_check_columns)
907
+ else:
908
+ minimum_num_nulls = len(null_check_columns) - minimum_non_null + 1
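+ # e.g. with 3 checked columns and thresh=2 (keep rows having at least 2 non-null values),
+ # a row is dropped once it contains 3 - 2 + 1 = 2 or more nulls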
909
+ if minimum_num_nulls > len(null_check_columns):
910
+ raise RuntimeError(
911
+ f"The minimum num nulls for dropna must be less than or equal to the number of columns. "
912
+ f"Minimum num nulls: {minimum_num_nulls}, Num Columns: {len(null_check_columns)}"
913
+ )
914
+ if_null_checks = [
915
+ F.when(column.isNull(), F.lit(1)).otherwise(F.lit(0)) for column in null_check_columns
916
+ ]
917
+ nulls_added_together = functools.reduce(lambda x, y: x + y, if_null_checks)
918
+ num_nulls = nulls_added_together.alias("num_nulls")
919
+ new_df = new_df.select(num_nulls, append=True)
920
+ filtered_df = new_df.where(F.col("num_nulls") < F.lit(minimum_num_nulls))
921
+ final_df = filtered_df.select(*all_columns)
922
+ return final_df
923
+
924
+ def explain(
925
+ self, extended: t.Optional[t.Union[bool, str]] = None, mode: t.Optional[str] = None
926
+ ) -> None:
927
+ """Prints the (logical and physical) plans to the console for debugging purposes.
928
+
929
+ .. versionadded:: 1.3.0
930
+
931
+ .. versionchanged:: 3.4.0
932
+ Supports Spark Connect.
933
+
934
+ Parameters
935
+ ----------
936
+ extended : bool, optional
937
+ default ``False``. If ``False``, prints only the physical plan.
938
+ When this is a string without specifying the ``mode``, it works as the mode is
939
+ specified.
940
+ mode : str, optional
941
+ specifies the expected output format of plans.
942
+
943
+ * ``simple``: Print only a physical plan.
944
+ * ``extended``: Print both logical and physical plans.
945
+ * ``codegen``: Print a physical plan and generated codes if they are available.
946
+ * ``cost``: Print a logical plan and statistics if they are available.
947
+ * ``formatted``: Split explain output into two sections: a physical plan outline \
948
+ and node details.
949
+
950
+ .. versionchanged:: 3.0.0
951
+ Added optional argument `mode` to specify the expected output format of plans.
952
+
953
+ Examples
954
+ --------
955
+ >>> df = spark.createDataFrame(
956
+ ... [(14, "Tom"), (23, "Alice"), (16, "Bob")], ["age", "name"])
957
+
958
+ Print out the physical plan only (default).
959
+
960
+ >>> df.explain() # doctest: +SKIP
961
+ == Physical Plan ==
962
+ *(1) Scan ExistingRDD[age...,name...]
963
+
964
+ Print out all of the parsed, analyzed, optimized and physical plans.
965
+
966
+ >>> df.explain(True)
967
+ == Parsed Logical Plan ==
968
+ ...
969
+ == Analyzed Logical Plan ==
970
+ ...
971
+ == Optimized Logical Plan ==
972
+ ...
973
+ == Physical Plan ==
974
+ ...
975
+
976
+ Print out the plans with two sections: a physical plan outline and node details
977
+
978
+ >>> df.explain(mode="formatted") # doctest: +SKIP
979
+ == Physical Plan ==
980
+ * Scan ExistingRDD (...)
981
+ (1) Scan ExistingRDD [codegen id : ...]
982
+ Output [2]: [age..., name...]
983
+ ...
984
+
985
+ Print a logical plan and statistics if they are available.
986
+
987
+ >>> df.explain("cost")
988
+ == Optimized Logical Plan ==
989
+ ...Statistics...
990
+ ...
991
+ """
992
+ sql_queries = self.sql(pretty=False, optimize=False, as_list=True)
993
+ if len(sql_queries) > 1:
994
+ raise ValueError("Cannot explain a DataFrame with multiple queries")
995
+ sql_query = "EXPLAIN " + sql_queries[0]
996
+ self.session._execute(sql_query, quote_identifiers=False)
997
+
998
+ @operation(Operation.FROM)
999
+ def fillna(
1000
+ self,
1001
+ value: t.Union[PrimitiveType, t.Dict[str, PrimitiveType]],
1002
+ subset: t.Optional[t.Union[str, t.Tuple[str, ...], t.List[str]]] = None,
1003
+ ) -> Self:
1004
+ """
1005
+ Functionality Difference: If you provide a value to replace a null and that type conflicts
1006
+ with the type of the column, then PySpark will simply ignore your replacement.
1007
+ This implementation will try to cast them to the same type in some cases, so the results won't always match.
1008
+ It is best not to mix types, so make sure the replacement is the same type as the column.
1009
+
1010
+ Possibility for improvement: Use `typeof` function to get the type of the column
1011
+ and check if it matches the type of the value provided. If not then make it null.
1012
+ """
1013
+ from sqlframe.base import functions as F
1014
+
1015
+ values = None
1016
+ columns = None
1017
+ new_df = self.copy()
1018
+ all_columns = self._get_outer_select_columns(new_df.expression)
1019
+ all_column_mapping = {column.alias_or_name: column for column in all_columns}
1020
+ if isinstance(value, dict):
1021
+ values = list(value.values())
1022
+ columns = self._ensure_and_normalize_cols(list(value))
1023
+ if not columns:
1024
+ columns = self._ensure_and_normalize_cols(subset) if subset else all_columns
1025
+ if not values:
1026
+ assert not isinstance(value, dict)
1027
+ values = [value] * len(columns)
1028
+ value_columns = [F.lit(value) for value in values]
1029
+
1030
+ null_replacement_mapping = {
1031
+ column.alias_or_name: (
1032
+ F.when(column.isNull(), value).otherwise(column).alias(column.alias_or_name)
1033
+ )
1034
+ for column, value in zip(columns, value_columns)
1035
+ }
1036
+ null_replacement_mapping = {**all_column_mapping, **null_replacement_mapping}
1037
+ null_replacement_columns = [
1038
+ null_replacement_mapping[column.alias_or_name] for column in all_columns
1039
+ ]
1040
+ new_df = new_df.select(*null_replacement_columns)
1041
+ return new_df
1042
+
1043
+ @operation(Operation.FROM)
1044
+ def replace(
1045
+ self,
1046
+ to_replace: t.Union[bool, int, float, str, t.List, t.Dict],
1047
+ value: t.Optional[t.Union[bool, int, float, str, t.List]] = None,
1048
+ subset: t.Optional[t.Collection[ColumnOrName] | ColumnOrName] = None,
1049
+ ) -> Self:
1050
+ from sqlframe.base import functions as F
1051
+ from sqlframe.base.column import Column
1052
+
1053
+ old_values = None
1054
+ new_df = self.copy()
1055
+ all_columns = self._get_outer_select_columns(new_df.expression)
1056
+ all_column_mapping = {column.alias_or_name: column for column in all_columns}
1057
+
1058
+ columns = self._ensure_and_normalize_cols(subset) if subset else all_columns
1059
+ if isinstance(to_replace, dict):
1060
+ old_values = list(to_replace)
1061
+ new_values = list(to_replace.values())
1062
+ elif not old_values and isinstance(to_replace, list):
1063
+ assert isinstance(value, list), "value must be a list since the replacements are a list"
1064
+ assert len(to_replace) == len(
1065
+ value
1066
+ ), "the replacements and values must be the same length"
1067
+ old_values = to_replace
1068
+ new_values = value
1069
+ else:
1070
+ old_values = [to_replace] * len(columns)
1071
+ new_values = [value] * len(columns)
1072
+ old_values = [F.lit(value) for value in old_values]
1073
+ new_values = [F.lit(value) for value in new_values]
1074
+
1075
+ replacement_mapping = {}
1076
+ for column in columns:
1077
+ # expression = Column(None)
1078
+ expression = F.lit(None)
1079
+ for i, (old_value, new_value) in enumerate(zip(old_values, new_values)):
1080
+ if i == 0:
1081
+ expression = F.when(column == old_value, new_value)
1082
+ else:
1083
+ expression = expression.when(column == old_value, new_value) # type: ignore
1084
+ replacement_mapping[column.alias_or_name] = expression.otherwise(column).alias(
1085
+ column.expression.alias_or_name
1086
+ )
1087
+
1088
+ replacement_mapping = {**all_column_mapping, **replacement_mapping}
1089
+ replacement_columns = [replacement_mapping[column.alias_or_name] for column in all_columns]
1090
+ new_df = new_df.select(*replacement_columns)
1091
+ return new_df
1092
+
1093
+ @operation(Operation.SELECT)
1094
+ def withColumn(self, colName: str, col: Column) -> Self:
1095
+ col = self._ensure_and_normalize_col(col)
1096
+ col_name = self._ensure_and_normalize_col(colName).alias_or_name
1097
+ existing_col_names = self.expression.named_selects
1098
+ existing_col_index = (
1099
+ existing_col_names.index(col_name) if col_name in existing_col_names else None
1100
+ )
1101
+ if existing_col_index is not None:
1102
+ expression = self.expression.copy()
1103
+ expression.expressions[existing_col_index] = col.alias(col_name).expression
1104
+ return self.copy(expression=expression)
1105
+ return self.select.__wrapped__(self, col.alias(col_name), append=True) # type: ignore
1106
+
1107
+ @operation(Operation.SELECT)
1108
+ def withColumnRenamed(self, existing: str, new: str) -> Self:
1109
+ expression = self.expression.copy()
1110
+ existing = self.session._normalize_string(existing)
1111
+ new = self.session._normalize_string(new)
1112
+ existing_columns = [
1113
+ expression
1114
+ for expression in expression.expressions
1115
+ if expression.alias_or_name == existing
1116
+ ]
1117
+ if not existing_columns:
1118
+ raise ValueError("Tried to rename a column that doesn't exist")
1119
+ for existing_column in existing_columns:
1120
+ if isinstance(existing_column, exp.Column):
1121
+ existing_column.replace(exp.alias_(existing_column, new))
1122
+ else:
1123
+ existing_column.set("alias", exp.to_identifier(new))
1124
+ return self.copy(expression=expression)
1125
+
1126
+ @operation(Operation.SELECT)
1127
+ def drop(self, *cols: t.Union[str, Column]) -> Self:
1128
+ all_columns = self._get_outer_select_columns(self.expression)
1129
+ drop_cols = self._ensure_and_normalize_cols(cols)
1130
+ new_columns = [
1131
+ col
1132
+ for col in all_columns
1133
+ if col.alias_or_name not in [drop_column.alias_or_name for drop_column in drop_cols]
1134
+ ]
1135
+ return self.copy().select(*new_columns, append=False)
1136
+
1137
+ @operation(Operation.LIMIT)
1138
+ def limit(self, num: int) -> Self:
1139
+ return self.copy(expression=self.expression.limit(num))
1140
+
1141
+ def toDF(self, *cols: str) -> Self:
1142
+ """Returns a new :class:`DataFrame` that with new specified column names
1143
+
1144
+ .. versionadded:: 1.6.0
1145
+
1146
+ .. versionchanged:: 3.4.0
1147
+ Supports Spark Connect.
1148
+
1149
+ Parameters
1150
+ ----------
1151
+ *cols : tuple
1152
+ a tuple of string new column name. The length of the
1153
+ list needs to be the same as the number of columns in the initial
1154
+ :class:`DataFrame`
1155
+
1156
+ Returns
1157
+ -------
1158
+ :class:`DataFrame`
1159
+ DataFrame with new column names.
1160
+
1161
+ Examples
1162
+ --------
1163
+ >>> df = spark.createDataFrame([(14, "Tom"), (23, "Alice"),
1164
+ ... (16, "Bob")], ["age", "name"])
1165
+ >>> df.toDF('f1', 'f2').show()
1166
+ +---+-----+
1167
+ | f1| f2|
1168
+ +---+-----+
1169
+ | 14| Tom|
1170
+ | 23|Alice|
1171
+ | 16| Bob|
1172
+ +---+-----+
1173
+ """
1174
+ if len(cols) != len(self.columns):
1175
+ raise ValueError(
1176
+ f"Number of column names does not match number of columns: {len(cols)} != {len(self.columns)}"
1177
+ )
1178
+ expression = self.expression.copy()
1179
+ expression = expression.select(
1180
+ *[exp.alias_(col, new_col) for col, new_col in zip(expression.expressions, cols)],
1181
+ append=False,
1182
+ )
1183
+ return self.copy(expression=expression)
1184
+
1185
+ @operation(Operation.NO_OP)
1186
+ def hint(self, name: str, *parameters: t.Optional[t.Union[str, int]]) -> Self:
1187
+ from sqlframe.base.column import Column
1188
+
1189
+ parameter_list = ensure_list(parameters)
1190
+ parameter_columns = (
1191
+ self._ensure_list_of_columns(parameter_list)
1192
+ if parameters
1193
+ else Column.ensure_cols([self.sequence_id])
1194
+ )
1195
+ return self._hint(name, parameter_columns)
1196
+
1197
+ @operation(Operation.NO_OP)
1198
+ def repartition(self, numPartitions: t.Union[int, ColumnOrName], *cols: ColumnOrName) -> Self:
1199
+ num_partition_cols = self._ensure_list_of_columns(numPartitions)
1200
+ columns = self._ensure_and_normalize_cols(cols)
1201
+ args = num_partition_cols + columns
1202
+ return self._hint("repartition", args)
1203
+
1204
+ @operation(Operation.NO_OP)
1205
+ def coalesce(self, numPartitions: int) -> Self:
1206
+ lit = get_func_from_session("lit")
1207
+
1208
+ num_partitions = lit(numPartitions)
1209
+ return self._hint("coalesce", [num_partitions])
1210
+
1211
+ @operation(Operation.NO_OP)
1212
+ def cache(self) -> Self:
1213
+ return self._cache(storage_level="MEMORY_AND_DISK")
1214
+
1215
+ @operation(Operation.NO_OP)
1216
+ def persist(self, storageLevel: StorageLevel = "MEMORY_AND_DISK_SER") -> Self:
1217
+ """
1218
+ Storage Level Options: https://spark.apache.org/docs/3.0.0-preview/sql-ref-syntax-aux-cache-cache-table.html
1219
+ """
1220
+ return self._cache(storageLevel)
1221
+
1222
+ @t.overload
1223
+ def cube(self, *cols: ColumnOrName) -> GROUP_DATA: ...
1224
+
1225
+ @t.overload
1226
+ def cube(self, __cols: t.Union[t.List[Column], t.List[str]]) -> GROUP_DATA: ...
1227
+
1228
+ def cube(self, *cols: ColumnOrName) -> GROUP_DATA: # type: ignore[misc]
1229
+ """
1230
+ Create a multi-dimensional cube for the current :class:`DataFrame` using
1231
+ the specified columns, so we can run aggregations on them.
1232
+
1233
+ .. versionadded:: 1.4.0
1234
+
1235
+ .. versionchanged:: 3.4.0
1236
+ Supports Spark Connect.
1237
+
1238
+ Parameters
1239
+ ----------
1240
+ cols : list, str or :class:`Column`
1241
+ columns to create cube by.
1242
+ Each element should be a column name (string) or an expression (:class:`Column`)
1243
+ or list of them.
1244
+
1245
+ Returns
1246
+ -------
1247
+ :class:`GroupedData`
1248
+ Cube of the data by given columns.
1249
+
1250
+ Examples
1251
+ --------
1252
+ >>> df = spark.createDataFrame([(2, "Alice"), (5, "Bob")], schema=["age", "name"])
1253
+ >>> df.cube("name", df.age).count().orderBy("name", "age").show()
1254
+ +-----+----+-----+
1255
+ | name| age|count|
1256
+ +-----+----+-----+
1257
+ | NULL|NULL| 2|
1258
+ | NULL| 2| 1|
1259
+ | NULL| 5| 1|
1260
+ |Alice|NULL| 1|
1261
+ |Alice| 2| 1|
1262
+ | Bob|NULL| 1|
1263
+ | Bob| 5| 1|
1264
+ +-----+----+-----+
1265
+ """
1266
+
1267
+ columns = self._ensure_and_normalize_cols(cols)
1268
+ grouping_columns: t.List[t.List[Column]] = []
1269
+ for i in reversed(range(len(columns) + 1)):
1270
+ grouping_columns.extend([list(x) for x in itertools.combinations(columns, i)])
1271
+ return self._group_data(self, grouping_columns, self.last_op)
1272
+
1273
+ def collect(self) -> t.List[Row]:
1274
+ result = []
1275
+ for sql in self.sql(pretty=False, optimize=False, as_list=True):
1276
+ result = self.session._fetch_rows(sql)
1277
+ return result
1278
+
1279
+ @t.overload
1280
+ def head(self) -> t.Optional[Row]: ...
1281
+
1282
+ @t.overload
1283
+ def head(self, n: int) -> t.List[Row]: ...
1284
+
1285
+ def head(self, n: t.Optional[int] = None) -> t.Union[t.Optional[Row], t.List[Row]]:
1286
+ n = n or 1
1287
+ df = self.limit(n)
1288
+ if n == 1:
1289
+ return df.collect()[0]
1290
+ return df.collect()
1291
+
1292
+ def first(self) -> t.Optional[Row]:
1293
+ return self.head()
1294
+
1295
+ def show(
1296
+ self, n: int = 20, truncate: t.Optional[t.Union[bool, int]] = None, vertical: bool = False
1297
+ ):
1298
+ if vertical:
1299
+ raise NotImplementedError("Vertical show is not yet supported")
1300
+ if truncate:
1301
+ logger.warning("Truncate is ignored so full results will be displayed")
1302
+ # Make sure that the limit we add doesn't affect the results
1303
+ df = self._convert_leaf_to_cte()
1304
+ sql = df.limit(n).sql(
1305
+ pretty=False, optimize=False, dialect=self.session.output_dialect, as_list=True
1306
+ )
1307
+ for sql in ensure_list(sql):
1308
+ result = self.session._fetch_rows(sql)
1309
+ table = PrettyTable()
1310
+ if row := seq_get(result, 0):
1311
+ table.field_names = list(row.asDict().keys())
1312
+ for row in result:
1313
+ table.add_row(list(row))
1314
+ print(table)
1315
+
1316
+ def toPandas(self) -> pd.DataFrame:
1317
+ sql_kwargs = dict(
1318
+ pretty=False, optimize=False, dialect=self.session.output_dialect, as_list=True
1319
+ )
1320
+ sqls = [None] + self.sql(**sql_kwargs) # type: ignore
1321
+ for sql in self.sql(**sql_kwargs)[:-1]: # type: ignore
1322
+ if sql:
1323
+ self.session._execute(sql)
1324
+ assert sqls[-1] is not None
1325
+ return self.session._fetchdf(sqls[-1])
1326
+
1327
+ def createOrReplaceTempView(self, name: str) -> None:
1328
+ self.session.temp_views[name] = self.copy()._convert_leaf_to_cte()
1329
+
1330
+ def count(self) -> int:
1331
+ if not self.session._has_connection:
1332
+ raise RuntimeError("Cannot count without a connection")
1333
+
1334
+ df = self._convert_leaf_to_cte()
1335
+ df = self.copy(expression=df.expression.select("count(*)", append=False))
1336
+ for sql in df.sql(
1337
+ dialect=self.session.output_dialect, pretty=False, optimize=False, as_list=True
1338
+ ):
1339
+ result = self.session._fetch_rows(sql)
1340
+ return result[0][0]
1341
+
1342
+ def createGlobalTempView(self, name: str) -> None:
1343
+ raise NotImplementedError("Global temp views are not yet supported")
1344
+
1345
+ """
1346
+ Stat Functions
1347
+ """
1348
+
1349
+ @t.overload
1350
+ def approxQuantile(
1351
+ self,
1352
+ col: str,
1353
+ probabilities: t.Union[t.List[float], t.Tuple[float]],
1354
+ relativeError: float,
1355
+ ) -> t.List[float]: ...
1356
+
1357
+ @t.overload
1358
+ def approxQuantile(
1359
+ self,
1360
+ col: t.Union[t.List[str], t.Tuple[str]],
1361
+ probabilities: t.Union[t.List[float], t.Tuple[float]],
1362
+ relativeError: float,
1363
+ ) -> t.List[t.List[float]]: ...
1364
+
1365
+ def approxQuantile(
1366
+ self,
1367
+ col: t.Union[str, t.List[str], t.Tuple[str]],
1368
+ probabilities: t.Union[t.List[float], t.Tuple[float]],
1369
+ relativeError: float,
1370
+ ) -> t.Union[t.List[float], t.List[t.List[float]]]:
1371
+ """
1372
+ Calculates the approximate quantiles of numerical columns of a
1373
+ :class:`DataFrame`.
1374
+
1375
+ The result of this algorithm has the following deterministic bound:
1376
+ If the :class:`DataFrame` has N elements and if we request the quantile at
1377
+ probability `p` up to error `err`, then the algorithm will return
1378
+ a sample `x` from the :class:`DataFrame` so that the *exact* rank of `x` is
1379
+ close to (p * N). More precisely,
1380
+
1381
+ floor((p - err) * N) <= rank(x) <= ceil((p + err) * N).
1382
+
1383
+ This method implements a variation of the Greenwald-Khanna
1384
+ algorithm (with some speed optimizations). The algorithm was first
1385
+ present in [[https://doi.org/10.1145/375663.375670
1386
+ Space-efficient Online Computation of Quantile Summaries]]
1387
+ by Greenwald and Khanna.
1388
+
1389
+ .. versionadded:: 2.0.0
1390
+
1391
+ .. versionchanged:: 3.4.0
1392
+ Supports Spark Connect.
1393
+
1394
+ Parameters
1395
+ ----------
1396
+ col: str, tuple or list
1397
+ Can be a single column name, or a list of names for multiple columns.
1398
+
1399
+ .. versionchanged:: 2.2.0
1400
+ Added support for multiple columns.
1401
+ probabilities : list or tuple
1402
+ a list of quantile probabilities
1403
+ Each number must belong to [0, 1].
1404
+ For example 0 is the minimum, 0.5 is the median, 1 is the maximum.
1405
+ relativeError : float
1406
+ The relative target precision to achieve
1407
+ (>= 0). If set to zero, the exact quantiles are computed, which
1408
+ could be very expensive. Note that values greater than 1 are
1409
+ accepted but gives the same result as 1.
1410
+
1411
+ Returns
1412
+ -------
1413
+ list
1414
+ the approximate quantiles at the given probabilities.
1415
+
1416
+ * If the input `col` is a string, the output is a list of floats.
1417
+
1418
+ * If the input `col` is a list or tuple of strings, the output is also a
1419
+ list, but each element in it is a list of floats, i.e., the output
1420
+ is a list of list of floats.
1421
+
1422
+ Notes
1423
+ -----
1424
+ Null values will be ignored in numerical columns before calculation.
1425
+ For columns only containing null values, an empty list is returned.
1426
+ """
1427
+
1428
+ percentile_approx = get_func_from_session("percentile_approx")
1429
+ col_func = get_func_from_session("col")
1430
+
1431
+ accuracy = 1.0 / relativeError if relativeError > 0.0 else 10000
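+ # e.g. relativeError=0.01 becomes accuracy = 1.0 / 0.01 = 100; non-positive values fall
+ # back to an accuracy of 10000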
1432
+
1433
+ df = self.select(
1434
+ *[
1435
+ percentile_approx(col_func(x), probabilities, accuracy).alias(f"val_{i}")
1436
+ for i, x in enumerate(ensure_list(col))
1437
+ ]
1438
+ )
1439
+ rows = df.collect()
1440
+ return [[float(y) for y in x] for row in rows for x in row.asDict().values()]
1441
+
1442
+ def corr(self, col1: str, col2: str, method: t.Optional[str] = None) -> float:
1443
+ """
1444
+ Calculates the correlation of two columns of a :class:`DataFrame` as a double value.
1445
+ Currently only supports the Pearson Correlation Coefficient.
1446
+ :func:`DataFrame.corr` and :func:`DataFrameStatFunctions.corr` are aliases of each other.
1447
+
1448
+ .. versionadded:: 1.4.0
1449
+
1450
+ .. versionchanged:: 3.4.0
1451
+ Supports Spark Connect.
1452
+
1453
+ Parameters
1454
+ ----------
1455
+ col1 : str
1456
+ The name of the first column
1457
+ col2 : str
1458
+ The name of the second column
1459
+ method : str, optional
1460
+ The correlation method. Currently only supports "pearson"
1461
+
1462
+ Returns
1463
+ -------
1464
+ float
1465
+ Pearson Correlation Coefficient of two columns.
1466
+
1467
+ Examples
1468
+ --------
1469
+ >>> df = spark.createDataFrame([(1, 12), (10, 1), (19, 8)], ["c1", "c2"])
1470
+ >>> df.corr("c1", "c2")
1471
+ -0.3592106040535498
1472
+ >>> df = spark.createDataFrame([(11, 12), (10, 11), (9, 10)], ["small", "bigger"])
1473
+ >>> df.corr("small", "bigger")
1474
+ 1.0
1475
+ """
1476
+ if method != "pearson":
1477
+ raise ValueError(f"Currently only the Pearson Correlation Coefficient is supported")
1478
+
1479
+ corr = get_func_from_session("corr")
1480
+ col_func = get_func_from_session("col")
1481
+
1482
+ return self.select(corr(col_func(col1), col_func(col2))).collect()[0][0]
1483
+
1484
+ def cov(self, col1: str, col2: str) -> float:
1485
+ """
1486
+ Calculate the sample covariance for the given columns, specified by their names, as a
1487
+ double value. :func:`DataFrame.cov` and :func:`DataFrameStatFunctions.cov` are aliases.
1488
+
1489
+ .. versionadded:: 1.4.0
1490
+
1491
+ .. versionchanged:: 3.4.0
1492
+ Supports Spark Connect.
1493
+
1494
+ Parameters
1495
+ ----------
1496
+ col1 : str
1497
+ The name of the first column
1498
+ col2 : str
1499
+ The name of the second column
1500
+
1501
+ Returns
1502
+ -------
1503
+ float
1504
+ Covariance of two columns.
1505
+
1506
+ Examples
1507
+ --------
1508
+ >>> df = spark.createDataFrame([(1, 12), (10, 1), (19, 8)], ["c1", "c2"])
1509
+ >>> df.cov("c1", "c2")
1510
+ -18.0
1511
+ >>> df = spark.createDataFrame([(11, 12), (10, 11), (9, 10)], ["small", "bigger"])
1512
+ >>> df.cov("small", "bigger")
1513
+ 1.0
1514
+
1515
+ """
1516
+ covar_samp = get_func_from_session("covar_samp")
1517
+ col_func = get_func_from_session("col")
1518
+
1519
+ return self.select(covar_samp(col_func(col1), col_func(col2))).collect()[0][0]