sqlframe-1.1.3-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (98)
  1. sqlframe/__init__.py +0 -0
  2. sqlframe/_version.py +16 -0
  3. sqlframe/base/__init__.py +0 -0
  4. sqlframe/base/_typing.py +39 -0
  5. sqlframe/base/catalog.py +1163 -0
  6. sqlframe/base/column.py +388 -0
  7. sqlframe/base/dataframe.py +1519 -0
  8. sqlframe/base/decorators.py +51 -0
  9. sqlframe/base/exceptions.py +14 -0
  10. sqlframe/base/function_alternatives.py +1055 -0
  11. sqlframe/base/functions.py +1678 -0
  12. sqlframe/base/group.py +102 -0
  13. sqlframe/base/mixins/__init__.py +0 -0
  14. sqlframe/base/mixins/catalog_mixins.py +419 -0
  15. sqlframe/base/mixins/readwriter_mixins.py +118 -0
  16. sqlframe/base/normalize.py +84 -0
  17. sqlframe/base/operations.py +87 -0
  18. sqlframe/base/readerwriter.py +679 -0
  19. sqlframe/base/session.py +585 -0
  20. sqlframe/base/transforms.py +13 -0
  21. sqlframe/base/types.py +418 -0
  22. sqlframe/base/util.py +242 -0
  23. sqlframe/base/window.py +139 -0
  24. sqlframe/bigquery/__init__.py +23 -0
  25. sqlframe/bigquery/catalog.py +255 -0
  26. sqlframe/bigquery/column.py +1 -0
  27. sqlframe/bigquery/dataframe.py +54 -0
  28. sqlframe/bigquery/functions.py +378 -0
  29. sqlframe/bigquery/group.py +14 -0
  30. sqlframe/bigquery/readwriter.py +29 -0
  31. sqlframe/bigquery/session.py +89 -0
  32. sqlframe/bigquery/types.py +1 -0
  33. sqlframe/bigquery/window.py +1 -0
  34. sqlframe/duckdb/__init__.py +20 -0
  35. sqlframe/duckdb/catalog.py +108 -0
  36. sqlframe/duckdb/column.py +1 -0
  37. sqlframe/duckdb/dataframe.py +55 -0
  38. sqlframe/duckdb/functions.py +47 -0
  39. sqlframe/duckdb/group.py +14 -0
  40. sqlframe/duckdb/readwriter.py +111 -0
  41. sqlframe/duckdb/session.py +65 -0
  42. sqlframe/duckdb/types.py +1 -0
  43. sqlframe/duckdb/window.py +1 -0
  44. sqlframe/postgres/__init__.py +23 -0
  45. sqlframe/postgres/catalog.py +106 -0
  46. sqlframe/postgres/column.py +1 -0
  47. sqlframe/postgres/dataframe.py +54 -0
  48. sqlframe/postgres/functions.py +61 -0
  49. sqlframe/postgres/group.py +14 -0
  50. sqlframe/postgres/readwriter.py +29 -0
  51. sqlframe/postgres/session.py +68 -0
  52. sqlframe/postgres/types.py +1 -0
  53. sqlframe/postgres/window.py +1 -0
  54. sqlframe/redshift/__init__.py +23 -0
  55. sqlframe/redshift/catalog.py +127 -0
  56. sqlframe/redshift/column.py +1 -0
  57. sqlframe/redshift/dataframe.py +54 -0
  58. sqlframe/redshift/functions.py +18 -0
  59. sqlframe/redshift/group.py +14 -0
  60. sqlframe/redshift/readwriter.py +29 -0
  61. sqlframe/redshift/session.py +53 -0
  62. sqlframe/redshift/types.py +1 -0
  63. sqlframe/redshift/window.py +1 -0
  64. sqlframe/snowflake/__init__.py +26 -0
  65. sqlframe/snowflake/catalog.py +134 -0
  66. sqlframe/snowflake/column.py +1 -0
  67. sqlframe/snowflake/dataframe.py +54 -0
  68. sqlframe/snowflake/functions.py +18 -0
  69. sqlframe/snowflake/group.py +14 -0
  70. sqlframe/snowflake/readwriter.py +29 -0
  71. sqlframe/snowflake/session.py +53 -0
  72. sqlframe/snowflake/types.py +1 -0
  73. sqlframe/snowflake/window.py +1 -0
  74. sqlframe/spark/__init__.py +23 -0
  75. sqlframe/spark/catalog.py +1028 -0
  76. sqlframe/spark/column.py +1 -0
  77. sqlframe/spark/dataframe.py +54 -0
  78. sqlframe/spark/functions.py +22 -0
  79. sqlframe/spark/group.py +14 -0
  80. sqlframe/spark/readwriter.py +29 -0
  81. sqlframe/spark/session.py +90 -0
  82. sqlframe/spark/types.py +1 -0
  83. sqlframe/spark/window.py +1 -0
  84. sqlframe/standalone/__init__.py +26 -0
  85. sqlframe/standalone/catalog.py +13 -0
  86. sqlframe/standalone/column.py +1 -0
  87. sqlframe/standalone/dataframe.py +36 -0
  88. sqlframe/standalone/functions.py +1 -0
  89. sqlframe/standalone/group.py +14 -0
  90. sqlframe/standalone/readwriter.py +19 -0
  91. sqlframe/standalone/session.py +40 -0
  92. sqlframe/standalone/types.py +1 -0
  93. sqlframe/standalone/window.py +1 -0
  94. sqlframe-1.1.3.dist-info/LICENSE +21 -0
  95. sqlframe-1.1.3.dist-info/METADATA +172 -0
  96. sqlframe-1.1.3.dist-info/RECORD +98 -0
  97. sqlframe-1.1.3.dist-info/WHEEL +5 -0
  98. sqlframe-1.1.3.dist-info/top_level.txt +1 -0
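
Each engine subpackage (bigquery, duckdb, postgres, redshift, snowflake, spark, standalone) ships the same module set (catalog, column, dataframe, functions, group, readwriter, session, types, window) layered on top of sqlframe/base. As a rough orientation, the sketch below shows how such a backend is typically driven. The StandaloneSession import is an assumption based on the package layout, not something shown in this diff; createDataFrame and DataFrame.sql(...) with these keyword arguments do appear in the code below.

# Sketch under assumptions: the standalone backend is assumed to export a
# PySpark-style session class (assumed name: StandaloneSession) that compiles
# DataFrame operations to SQL strings instead of executing them on a cluster.
from sqlframe.standalone import StandaloneSession  # assumed export

session = StandaloneSession()
df = session.createDataFrame([{"age": 100, "name": "Hyukjin Kwon"}])
# DataFrame.sql(...) with these keyword arguments is used by
# _BaseDataFrameWriter.sql in the diff below; it returns the generated SQL.
for statement in df.sql(pretty=True, optimize=False, as_list=True):
    print(statement)

The largest single addition, sqlframe/base/readerwriter.py (+679 lines), follows in full.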
sqlframe/base/readerwriter.py
@@ -0,0 +1,679 @@
+ # This code is based on code from Apache Spark under the license found in the LICENSE file located in the 'sqlframe' folder.
+
+ from __future__ import annotations
+
+ import logging
+ import pathlib
+ import sys
+ import typing as t
+ from functools import reduce
+
+ from sqlglot import exp
+ from sqlglot.helper import object_to_dict
+
+ if sys.version_info >= (3, 11):
+     from typing import Self
+ else:
+     from typing_extensions import Self
+
+ if t.TYPE_CHECKING:
+     from sqlframe.base._typing import OptionalPrimitiveType, PathOrPaths
+     from sqlframe.base.column import Column
+     from sqlframe.base.session import DF, _BaseSession
+     from sqlframe.base.types import StructType
+
+     SESSION = t.TypeVar("SESSION", bound=_BaseSession)
+ else:
+     SESSION = t.TypeVar("SESSION")
+     DF = t.TypeVar("DF")
+
+
+ logger = logging.getLogger(__name__)
+
+
+ class _BaseDataFrameReader(t.Generic[SESSION, DF]):
+     def __init__(self, spark: SESSION):
+         self._session = spark
+
+     @property
+     def session(self) -> SESSION:
+         return self._session
+
+     def table(self, tableName: str) -> DF:
+         if df := self.session.temp_views.get(tableName):
+             return df
+         table = (
+             exp.to_table(tableName, dialect=self.session.input_dialect)
+             .transform(self.session.input_dialect.normalize_identifier)
+             .assert_is(exp.Table)
+         )
+         self.session.catalog.add_table(table)
+         columns = self.session.catalog.get_columns_from_schema(table)
+
+         return self.session._create_df(
+             exp.Select().from_(table).select(*columns, dialect=self.session.input_dialect)
+         )
+
+     def _to_casted_columns(self, column_mapping: t.Dict) -> t.List[Column]:
+         from sqlframe.base.column import Column
+
+         return [
+             Column(
+                 exp.cast(
+                     exp.to_column(k), to=exp.DataType.build(v, dialect=self.session.input_dialect)
+                 ).as_(k)
+             )
+             for k, v in column_mapping.items()
+         ]
+
+     def load(
+         self,
+         path: t.Optional[PathOrPaths] = None,
+         format: t.Optional[str] = None,
+         schema: t.Optional[t.Union[StructType, str]] = None,
+         **options: OptionalPrimitiveType,
+     ) -> DF:
+         raise NotImplementedError()
+
+     def json(
+         self,
+         path: t.Union[str, t.List[str]],
+         schema: t.Optional[t.Union[StructType, str]] = None,
+         primitivesAsString: t.Optional[t.Union[bool, str]] = None,
+         prefersDecimal: t.Optional[t.Union[bool, str]] = None,
+         allowComments: t.Optional[t.Union[bool, str]] = None,
+         allowUnquotedFieldNames: t.Optional[t.Union[bool, str]] = None,
+         allowSingleQuotes: t.Optional[t.Union[bool, str]] = None,
+         allowNumericLeadingZero: t.Optional[t.Union[bool, str]] = None,
+         allowBackslashEscapingAnyCharacter: t.Optional[t.Union[bool, str]] = None,
+         mode: t.Optional[str] = None,
+         columnNameOfCorruptRecord: t.Optional[str] = None,
+         dateFormat: t.Optional[str] = None,
+         timestampFormat: t.Optional[str] = None,
+         multiLine: t.Optional[t.Union[bool, str]] = None,
+         allowUnquotedControlChars: t.Optional[t.Union[bool, str]] = None,
+         lineSep: t.Optional[str] = None,
+         samplingRatio: t.Optional[t.Union[float, str]] = None,
+         dropFieldIfAllNull: t.Optional[t.Union[bool, str]] = None,
+         encoding: t.Optional[str] = None,
+         locale: t.Optional[str] = None,
+         pathGlobFilter: t.Optional[t.Union[bool, str]] = None,
+         recursiveFileLookup: t.Optional[t.Union[bool, str]] = None,
+         modifiedBefore: t.Optional[t.Union[bool, str]] = None,
+         modifiedAfter: t.Optional[t.Union[bool, str]] = None,
+         allowNonNumericNumbers: t.Optional[t.Union[bool, str]] = None,
+     ) -> DF:
+         """
+         Loads JSON files and returns the results as a :class:`DataFrame`.
+
+         `JSON Lines <http://jsonlines.org/>`_ (newline-delimited JSON) is supported by default.
+         For JSON (one record per file), set the ``multiLine`` parameter to ``true``.
+
+         If the ``schema`` parameter is not specified, this function goes
+         through the input once to determine the input schema.
+
+         .. versionadded:: 1.4.0
+
+         .. versionchanged:: 3.4.0
+             Supports Spark Connect.
+
+         Parameters
+         ----------
+         path : str, list or :class:`RDD`
+             string represents path to the JSON dataset, or a list of paths,
+             or RDD of Strings storing JSON objects.
+         schema : :class:`pyspark.sql.types.StructType` or str, t.Optional
+             an t.Optional :class:`pyspark.sql.types.StructType` for the input schema or
+             a DDL-formatted string (For example ``col0 INT, col1 DOUBLE``).
+
+         Other Parameters
+         ----------------
+         Extra options
+             For the extra options, refer to
+             `Data Source Option <https://spark.apache.org/docs/latest/sql-data-sources-json.html#data-source-option>`_
+             for the version you use.
+
+             .. # noqa
+
+         Examples
+         --------
+         Write a DataFrame into a JSON file and read it back.
+
+         >>> import tempfile
+         >>> with tempfile.TemporaryDirectory() as d:
+         ...     # Write a DataFrame into a JSON file
+         ...     spark.createDataFrame(
+         ...         [{"age": 100, "name": "Hyukjin Kwon"}]
+         ...     ).write.mode("overwrite").format("json").save(d)
+         ...
+         ...     # Read the JSON file as a DataFrame.
+         ...     spark.read.json(d).show()
+         +---+------------+
+         |age|        name|
+         +---+------------+
+         |100|Hyukjin Kwon|
+         +---+------------+
+         """
+         options = dict(
+             primitivesAsString=primitivesAsString,
+             prefersDecimal=prefersDecimal,
+             allowComments=allowComments,
+             allowUnquotedFieldNames=allowUnquotedFieldNames,
+             allowSingleQuotes=allowSingleQuotes,
+             allowNumericLeadingZero=allowNumericLeadingZero,
+             allowBackslashEscapingAnyCharacter=allowBackslashEscapingAnyCharacter,
+             mode=mode,
+             columnNameOfCorruptRecord=columnNameOfCorruptRecord,
+             dateFormat=dateFormat,
+             timestampFormat=timestampFormat,
+             multiLine=multiLine,
+             allowUnquotedControlChars=allowUnquotedControlChars,
+             lineSep=lineSep,
+             samplingRatio=samplingRatio,
+             dropFieldIfAllNull=dropFieldIfAllNull,
+             encoding=encoding,
+             locale=locale,
+             pathGlobFilter=pathGlobFilter,
+             recursiveFileLookup=recursiveFileLookup,
+             modifiedBefore=modifiedBefore,
+             modifiedAfter=modifiedAfter,
+             allowNonNumericNumbers=allowNonNumericNumbers,
+         )
+         return self.load(path=path, format="json", schema=schema, **options)
+
+     def parquet(self, *paths: str, **options: OptionalPrimitiveType) -> DF:
+         """
+         Loads Parquet files, returning the result as a :class:`DataFrame`.
+
+         .. versionadded:: 1.4.0
+
+         .. versionchanged:: 3.4.0
+             Supports Spark Connect.
+
+         Parameters
+         ----------
+         paths : str
+
+         Other Parameters
+         ----------------
+         **options
+             For the extra options, refer to
+             `Data Source Option <https://spark.apache.org/docs/latest/sql-data-sources-parquet.html#data-source-option>`_
+             for the version you use.
+
+             .. # noqa
+
+         Examples
+         --------
+         Write a DataFrame into a Parquet file and read it back.
+
+         >>> import tempfile
+         >>> with tempfile.TemporaryDirectory() as d:
+         ...     # Write a DataFrame into a Parquet file
+         ...     spark.createDataFrame(
+         ...         [{"age": 100, "name": "Hyukjin Kwon"}]
+         ...     ).write.mode("overwrite").format("parquet").save(d)
+         ...
+         ...     # Read the Parquet file as a DataFrame.
+         ...     spark.read.parquet(d).show()
+         +---+------------+
+         |age|        name|
+         +---+------------+
+         |100|Hyukjin Kwon|
+         +---+------------+
+         """
+         dfs = [self.load(path=path, format="parquet", **options) for path in paths]  # type: ignore
+         return reduce(lambda a, b: a.union(b), dfs)
+
+     def csv(
+         self,
+         path: PathOrPaths,
+         schema: t.Optional[t.Union[StructType, str]] = None,
+         sep: t.Optional[str] = None,
+         encoding: t.Optional[str] = None,
+         quote: t.Optional[str] = None,
+         escape: t.Optional[str] = None,
+         comment: t.Optional[str] = None,
+         header: t.Optional[t.Union[bool, str]] = None,
+         inferSchema: t.Optional[t.Union[bool, str]] = None,
+         ignoreLeadingWhiteSpace: t.Optional[t.Union[bool, str]] = None,
+         ignoreTrailingWhiteSpace: t.Optional[t.Union[bool, str]] = None,
+         nullValue: t.Optional[str] = None,
+         nanValue: t.Optional[str] = None,
+         positiveInf: t.Optional[str] = None,
+         negativeInf: t.Optional[str] = None,
+         dateFormat: t.Optional[str] = None,
+         timestampFormat: t.Optional[str] = None,
+         maxColumns: t.Optional[t.Union[int, str]] = None,
+         maxCharsPerColumn: t.Optional[t.Union[int, str]] = None,
+         maxMalformedLogPerPartition: t.Optional[t.Union[int, str]] = None,
+         mode: t.Optional[str] = None,
+         columnNameOfCorruptRecord: t.Optional[str] = None,
+         multiLine: t.Optional[t.Union[bool, str]] = None,
+         charToEscapeQuoteEscaping: t.Optional[str] = None,
+         samplingRatio: t.Optional[t.Union[float, str]] = None,
+         enforceSchema: t.Optional[t.Union[bool, str]] = None,
+         emptyValue: t.Optional[str] = None,
+         locale: t.Optional[str] = None,
+         lineSep: t.Optional[str] = None,
+         pathGlobFilter: t.Optional[t.Union[bool, str]] = None,
+         recursiveFileLookup: t.Optional[t.Union[bool, str]] = None,
+         modifiedBefore: t.Optional[t.Union[bool, str]] = None,
+         modifiedAfter: t.Optional[t.Union[bool, str]] = None,
+         unescapedQuoteHandling: t.Optional[str] = None,
+     ) -> DF:
+         r"""Loads a CSV file and returns the result as a :class:`DataFrame`.
+
+         This function will go through the input once to determine the input schema if
+         ``inferSchema`` is enabled. To avoid going through the entire data once, disable
+         ``inferSchema`` option or specify the schema explicitly using ``schema``.
+
+         .. versionadded:: 2.0.0
+
+         .. versionchanged:: 3.4.0
+             Supports Spark Connect.
+
+         Parameters
+         ----------
+         path : str or list
+             string, or list of strings, for input path(s),
+             or RDD of Strings storing CSV rows.
+         schema : :class:`pyspark.sql.types.StructType` or str, t.Optional
+             an t.Optional :class:`pyspark.sql.types.StructType` for the input schema
+             or a DDL-formatted string (For example ``col0 INT, col1 DOUBLE``).
+
+         Other Parameters
+         ----------------
+         Extra options
+             For the extra options, refer to
+             `Data Source Option <https://spark.apache.org/docs/latest/sql-data-sources-csv.html#data-source-option>`_
+             for the version you use.
+
+             .. # noqa
+
+         Examples
+         --------
+         Write a DataFrame into a CSV file and read it back.
+
+         >>> import tempfile
+         >>> with tempfile.TemporaryDirectory() as d:
+         ...     # Write a DataFrame into a CSV file
+         ...     df = spark.createDataFrame([{"age": 100, "name": "Hyukjin Kwon"}])
+         ...     df.write.mode("overwrite").format("csv").save(d)
+         ...
+         ...     # Read the CSV file as a DataFrame with 'nullValue' option set to 'Hyukjin Kwon'.
+         ...     spark.read.csv(d, schema=df.schema, nullValue="Hyukjin Kwon").show()
+         +---+----+
+         |age|name|
+         +---+----+
+         |100|NULL|
+         +---+----+
+         """
+         options = dict(
+             sep=sep,
+             encoding=encoding,
+             quote=quote,
+             escape=escape,
+             comment=comment,
+             header=header,
+             inferSchema=inferSchema,
+             ignoreLeadingWhiteSpace=ignoreLeadingWhiteSpace,
+             ignoreTrailingWhiteSpace=ignoreTrailingWhiteSpace,
+             nullValue=nullValue,
+             nanValue=nanValue,
+             positiveInf=positiveInf,
+             negativeInf=negativeInf,
+             dateFormat=dateFormat,
+             timestampFormat=timestampFormat,
+             maxColumns=maxColumns,
+             maxCharsPerColumn=maxCharsPerColumn,
+             maxMalformedLogPerPartition=maxMalformedLogPerPartition,
+             mode=mode,
+             columnNameOfCorruptRecord=columnNameOfCorruptRecord,
+             multiLine=multiLine,
+             charToEscapeQuoteEscaping=charToEscapeQuoteEscaping,
+             samplingRatio=samplingRatio,
+             enforceSchema=enforceSchema,
+             emptyValue=emptyValue,
+             locale=locale,
+             lineSep=lineSep,
+             pathGlobFilter=pathGlobFilter,
+             recursiveFileLookup=recursiveFileLookup,
+             modifiedBefore=modifiedBefore,
+             modifiedAfter=modifiedAfter,
+             unescapedQuoteHandling=unescapedQuoteHandling,
+         )
+         return self.load(path=path, format="csv", schema=schema, **options)
+
+
+ class _BaseDataFrameWriter(t.Generic[SESSION, DF]):
+     def __init__(
+         self,
+         df: DF,
+         mode: t.Optional[str] = None,
+         by_name: bool = False,
+     ):
+         self._df = df
+         self._mode = mode
+         self._by_name = by_name
+
+     @property
+     def _session(self) -> SESSION:
+         return self._df.session
+
+     def copy(self, **kwargs) -> Self:
+         return self.__class__(
+             **{
+                 k[1:] if k.startswith("_") else k: v
+                 for k, v in object_to_dict(self, **kwargs).items()
+             }
+         )
+
+     def sql(self, **kwargs) -> t.List[str]:
+         return self._df.sql(**{**dict(pretty=False, optimize=False, as_list=True), **kwargs})
+
+     def mode(self, saveMode: t.Optional[str]) -> Self:
+         return self.copy(_mode=saveMode)
+
+     @property
+     def byName(self) -> Self:
+         return self.copy(by_name=True)
+
+     def insertInto(self, tableName: str, overwrite: t.Optional[bool] = None) -> Self:
+         from sqlframe.base.session import _BaseSession
+
+         output_expression_container = exp.Insert(
+             **{
+                 "this": exp.to_table(tableName),
+                 "overwrite": overwrite,
+             }
+         )
+         df = self._df.copy(output_expression_container=output_expression_container)
+         if self._by_name:
+             columns = self._session.catalog._schema.column_names(
+                 tableName, only_visible=True, dialect=_BaseSession().input_dialect
+             )
+             df = df._convert_leaf_to_cte().select(*columns)
+
+         if self._session._has_connection:
+             for sql in df.sql(pretty=False, optimize=False, as_list=True):
+                 self._session._execute(sql)
+         return self.copy(_df=df)
+
+     def saveAsTable(
+         self, name: str, format: t.Optional[str] = None, mode: t.Optional[str] = None
+     ) -> Self:
+         if format is not None:
+             raise NotImplementedError("Providing Format in the save as table is not supported")
+         exists, replace, mode = None, None, mode or str(self._mode)
+         if mode == "append":
+             return self.insertInto(name)
+         if mode == "ignore":
+             exists = True
+         if mode == "overwrite":
+             replace = True
+         output_expression_container = exp.Create(
+             this=exp.to_table(name),
+             kind="TABLE",
+             exists=exists,
+             replace=replace,
+         )
+         df = self._df.copy(output_expression_container=output_expression_container)
+         if self._session._has_connection:
+             for sql in df.sql(pretty=False, optimize=False, as_list=True):
+                 self._session._execute(sql)
+         return self.copy(_df=df)
+
+     @staticmethod
+     def _mode_to_pandas_mode(mode: t.Optional[str]) -> str:
+         if mode is None or mode in {"ignore", "error", "errorifexists", "overwrite"}:
+             return "w"
+         if mode == "append":
+             return "a"
+         raise ValueError(f"Unsupported mode: {mode}")
+
+     def _validate_mode(self, path: str, mode: t.Optional[str]) -> t.Tuple[str, bool]:
+         mode = mode or "error"
+         if mode in {"error", "errorifexists"} and pathlib.Path(path).exists():
+             raise FileExistsError(f"Path already exists: {path}")
+         if mode == "ignore" and pathlib.Path(path).exists():
+             return mode, True
+         return mode, False
+
+     def _write(self, path: str, mode: t.Optional[str], format: str, **options) -> None:
+         raise NotImplementedError
+
+     def json(
+         self,
+         path: str,
+         mode: t.Optional[str] = None,
+         compression: t.Optional[str] = None,
+         dateFormat: t.Optional[str] = None,
+         timestampFormat: t.Optional[str] = None,
+         lineSep: t.Optional[str] = None,
+         encoding: t.Optional[str] = None,
+         ignoreNullFields: t.Optional[t.Union[bool, str]] = None,
+     ) -> None:
+         """Saves the content of the :class:`DataFrame` in JSON format
+         (`JSON Lines text format or newline-delimited JSON <http://jsonlines.org/>`_) at the
+         specified path.
+
+         .. versionadded:: 1.4.0
+
+         .. versionchanged:: 3.4.0
+             Supports Spark Connect.
+
+         Parameters
+         ----------
+         path : str
+             the path in any Hadoop supported file system
+         mode : str, t.Optional
+             specifies the behavior of the save operation when data already exists.
+
+             * ``append``: Append contents of this :class:`DataFrame` to existing data.
+             * ``overwrite``: Overwrite existing data.
+             * ``ignore``: Silently ignore this operation if data already exists.
+             * ``error`` or ``errorifexists`` (default case): Throw an exception if data already \
+                 exists.
+
+         Other Parameters
+         ----------------
+         Extra options
+             For the extra options, refer to
+             `Data Source Option <https://spark.apache.org/docs/latest/sql-data-sources-json.html#data-source-option>`_
+             for the version you use.
+
+             .. # noqa
+
+         Examples
+         --------
+         Write a DataFrame into a JSON file and read it back.
+
+         >>> import tempfile
+         >>> with tempfile.TemporaryDirectory() as d:
+         ...     # Write a DataFrame into a JSON file
+         ...     spark.createDataFrame(
+         ...         [{"age": 100, "name": "Hyukjin Kwon"}]
+         ...     ).write.json(d, mode="overwrite")
+         ...
+         ...     # Read the JSON file as a DataFrame.
+         ...     spark.read.format("json").load(d).show()
+         +---+------------+
+         |age|        name|
+         +---+------------+
+         |100|Hyukjin Kwon|
+         +---+------------+
+         """
+         self._write(
+             path=path,
+             mode=mode,
+             format="json",
+             compression=compression,
+             dateFormat=dateFormat,
+             timestampFormat=timestampFormat,
+             lineSep=lineSep,
+             encoding=encoding,
+             ignoreNullFields=ignoreNullFields,
+         )
+
+     def parquet(
+         self,
+         path: str,
+         mode: t.Optional[str] = None,
+         partitionBy: t.Optional[t.Union[str, t.List[str]]] = None,
+         compression: t.Optional[str] = None,
+     ) -> None:
+         """Saves the content of the :class:`DataFrame` in Parquet format at the specified path.
+
+         .. versionadded:: 1.4.0
+
+         .. versionchanged:: 3.4.0
+             Supports Spark Connect.
+
+         Parameters
+         ----------
+         path : str
+             the path in any Hadoop supported file system
+         mode : str, t.Optional
+             specifies the behavior of the save operation when data already exists.
+
+             * ``append``: Append contents of this :class:`DataFrame` to existing data.
+             * ``overwrite``: Overwrite existing data.
+             * ``ignore``: Silently ignore this operation if data already exists.
+             * ``error`` or ``errorifexists`` (default case): Throw an exception if data already \
+                 exists.
+         partitionBy : str or list, t.Optional
+             names of partitioning columns
+
+         Other Parameters
+         ----------------
+         Extra options
+             For the extra options, refer to
+             `Data Source Option <https://spark.apache.org/docs/latest/sql-data-sources-parquet.html#data-source-option>`_
+             for the version you use.
+
+             .. # noqa
+
+         Examples
+         --------
+         Write a DataFrame into a Parquet file and read it back.
+
+         >>> import tempfile
+         >>> with tempfile.TemporaryDirectory() as d:
+         ...     # Write a DataFrame into a Parquet file
+         ...     spark.createDataFrame(
+         ...         [{"age": 100, "name": "Hyukjin Kwon"}]
+         ...     ).write.parquet(d, mode="overwrite")
+         ...
+         ...     # Read the Parquet file as a DataFrame.
+         ...     spark.read.format("parquet").load(d).show()
+         +---+------------+
+         |age|        name|
+         +---+------------+
+         |100|Hyukjin Kwon|
+         +---+------------+
+         """
+         self._write(
+             path=path, mode=mode, format="parquet", compression=compression, partitionBy=partitionBy
+         )
+
+     def csv(
+         self,
+         path: str,
+         mode: t.Optional[str] = None,
+         compression: t.Optional[str] = None,
+         sep: t.Optional[str] = None,
+         quote: t.Optional[str] = None,
+         escape: t.Optional[str] = None,
+         header: t.Optional[t.Union[bool, str]] = None,
+         nullValue: t.Optional[str] = None,
+         escapeQuotes: t.Optional[t.Union[bool, str]] = None,
+         quoteAll: t.Optional[t.Union[bool, str]] = None,
+         dateFormat: t.Optional[str] = None,
+         timestampFormat: t.Optional[str] = None,
+         ignoreLeadingWhiteSpace: t.Optional[t.Union[bool, str]] = None,
+         ignoreTrailingWhiteSpace: t.Optional[t.Union[bool, str]] = None,
+         charToEscapeQuoteEscaping: t.Optional[str] = None,
+         encoding: t.Optional[str] = None,
+         emptyValue: t.Optional[str] = None,
+         lineSep: t.Optional[str] = None,
+     ) -> None:
+         r"""Saves the content of the :class:`DataFrame` in CSV format at the specified path.
+
+         .. versionadded:: 2.0.0
+
+         .. versionchanged:: 3.4.0
+             Supports Spark Connect.
+
+         Parameters
+         ----------
+         path : str
+             the path in any Hadoop supported file system
+         mode : str, t.Optional
+             specifies the behavior of the save operation when data already exists.
+
+             * ``append``: Append contents of this :class:`DataFrame` to existing data.
+             * ``overwrite``: Overwrite existing data.
+             * ``ignore``: Silently ignore this operation if data already exists.
+             * ``error`` or ``errorifexists`` (default case): Throw an exception if data already \
+                 exists.
+
+         Other Parameters
+         ----------------
+         Extra options
+             For the extra options, refer to
+             `Data Source Option <https://spark.apache.org/docs/latest/sql-data-sources-csv.html#data-source-option>`_
+             for the version you use.
+
+             .. # noqa
+
+         Examples
+         --------
+         Write a DataFrame into a CSV file and read it back.
+
+         >>> import tempfile
+         >>> with tempfile.TemporaryDirectory() as d:
+         ...     # Write a DataFrame into a CSV file
+         ...     df = spark.createDataFrame([{"age": 100, "name": "Hyukjin Kwon"}])
+         ...     df.write.csv(d, mode="overwrite")
+         ...
+         ...     # Read the CSV file as a DataFrame with 'nullValue' option set to 'Hyukjin Kwon'.
+         ...     spark.read.schema(df.schema).format("csv").option(
+         ...         "nullValue", "Hyukjin Kwon").load(d).show()
+         +---+----+
+         |age|name|
+         +---+----+
+         |100|NULL|
+         +---+----+
+         """
+         self._write(
+             path=path,
+             mode=mode,
+             format="csv",
+             compression=compression,
+             sep=sep,
+             quote=quote,
+             escape=escape,
+             header=header,
+             nullValue=nullValue,
+             escapeQuotes=escapeQuotes,
+             quoteAll=quoteAll,
+             dateFormat=dateFormat,
+             timestampFormat=timestampFormat,
+             ignoreLeadingWhiteSpace=ignoreLeadingWhiteSpace,
+             ignoreTrailingWhiteSpace=ignoreTrailingWhiteSpace,
+             charToEscapeQuoteEscaping=charToEscapeQuoteEscaping,
+             encoding=encoding,
+             emptyValue=emptyValue,
+             lineSep=lineSep,
+         )
+
+
+ def _infer_format(path: str) -> str:
+     if path.endswith(".json"):
+         return "json"
+     if path.endswith(".parquet"):
+         return "parquet"
+     if path.endswith(".csv"):
+         return "csv"
+     raise ValueError(f"Cannot infer format from path: {path}")
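
The helpers above (_validate_mode, _mode_to_pandas_mode, _infer_format) are wired together by the engine-specific writers (see sqlframe/base/mixins/readwriter_mixins.py and the per-engine readwriter modules in the file list). The snippet below is only an illustrative sketch of that flow, not code from the wheel; the subclass name and the final hand-off to the engine are hypothetical.

import typing as t

class _ExampleDataFrameWriter(_BaseDataFrameWriter):  # hypothetical subclass, for illustration only
    def _write(self, path: str, mode: t.Optional[str], format: str, **options) -> None:
        # Resolve the save mode first: "error"/"errorifexists" raises if the path
        # already exists, and "ignore" on an existing path becomes a no-op.
        mode, skip = self._validate_mode(path, mode)
        if skip:
            return
        pandas_mode = self._mode_to_pandas_mode(mode)  # "w" for most modes, "a" for append
        file_format = format or _infer_format(path)    # fall back to the file extension
        # Generate the SQL statements for the underlying DataFrame and hand each one,
        # together with (path, pandas_mode, file_format, options), to the engine.
        for statement in self.sql(optimize=False):
            ...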