sqlframe 3.23.0__py3-none-any.whl → 3.24.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sqlframe/__init__.py +12 -0
- sqlframe/_version.py +2 -2
- sqlframe/base/catalog.py +5 -4
- sqlframe/base/column.py +57 -0
- sqlframe/base/dataframe.py +6 -2
- sqlframe/base/group.py +2 -0
- sqlframe/base/mixins/catalog_mixins.py +147 -11
- sqlframe/base/mixins/dataframe_mixins.py +4 -1
- sqlframe/base/operations.py +42 -14
- sqlframe/base/readerwriter.py +4 -1
- sqlframe/base/window.py +6 -6
- sqlframe/bigquery/catalog.py +6 -3
- sqlframe/databricks/catalog.py +185 -11
- sqlframe/databricks/readwriter.py +293 -13
- sqlframe/duckdb/catalog.py +12 -9
- sqlframe/postgres/catalog.py +10 -7
- sqlframe/py.typed +1 -0
- sqlframe/redshift/catalog.py +11 -8
- sqlframe/snowflake/catalog.py +12 -9
- sqlframe/spark/catalog.py +21 -5
- sqlframe/standalone/catalog.py +4 -1
- {sqlframe-3.23.0.dist-info → sqlframe-3.24.1.dist-info}/METADATA +2 -2
- {sqlframe-3.23.0.dist-info → sqlframe-3.24.1.dist-info}/RECORD +26 -25
- {sqlframe-3.23.0.dist-info → sqlframe-3.24.1.dist-info}/LICENSE +0 -0
- {sqlframe-3.23.0.dist-info → sqlframe-3.24.1.dist-info}/WHEEL +0 -0
- {sqlframe-3.23.0.dist-info → sqlframe-3.24.1.dist-info}/top_level.txt +0 -0
sqlframe/databricks/catalog.py
CHANGED

@@ -3,36 +3,49 @@
 from __future__ import annotations

 import fnmatch
-import json
 import typing as t

+import sqlglot as sg
 from sqlglot import exp, parse_one

-from sqlframe.base.catalog import Column, Function, _BaseCatalog
+from sqlframe.base.catalog import TABLE, Column, Function, _BaseCatalog
 from sqlframe.base.mixins.catalog_mixins import (
+    CreateTableFromFunctionMixin,
     GetCurrentCatalogFromFunctionMixin,
     GetCurrentDatabaseFromFunctionMixin,
     ListCatalogsFromInfoSchemaMixin,
     ListDatabasesFromInfoSchemaMixin,
     ListTablesFromInfoSchemaMixin,
-    SetCurrentCatalogFromUseMixin,
     SetCurrentDatabaseFromUseMixin,
 )
-from sqlframe.base.
+from sqlframe.base.types import StructType
+from sqlframe.base.util import (
+    get_column_mapping_from_schema_input,
+    normalize_string,
+    schema_,
+    to_csv,
+    to_schema,
+)

 if t.TYPE_CHECKING:
     from sqlframe.databricks.session import DatabricksSession  # noqa
     from sqlframe.databricks.dataframe import DatabricksDataFrame  # noqa
+    from sqlframe.databricks.table import DatabricksTable  # noqa


 class DatabricksCatalog(
-    GetCurrentCatalogFromFunctionMixin[
-
-
-
-
-
-
+    GetCurrentCatalogFromFunctionMixin[
+        "DatabricksSession", "DatabricksDataFrame", "DatabricksTable"
+    ],
+    GetCurrentDatabaseFromFunctionMixin[
+        "DatabricksSession", "DatabricksDataFrame", "DatabricksTable"
+    ],
+    CreateTableFromFunctionMixin["DatabricksSession", "DatabricksDataFrame", "DatabricksTable"],
+    ListDatabasesFromInfoSchemaMixin["DatabricksSession", "DatabricksDataFrame", "DatabricksTable"],
+    ListCatalogsFromInfoSchemaMixin["DatabricksSession", "DatabricksDataFrame", "DatabricksTable"],
+    SetCurrentDatabaseFromUseMixin["DatabricksSession", "DatabricksDataFrame", "DatabricksTable"],
+    ListTablesFromInfoSchemaMixin["DatabricksSession", "DatabricksDataFrame", "DatabricksTable"],
+    _BaseCatalog["DatabricksSession", "DatabricksDataFrame", "DatabricksTable"],
 ):
     CURRENT_CATALOG_EXPRESSION: exp.Expression = exp.func("current_catalog")
     UPPERCASE_INFO_SCHEMA = True
@@ -310,3 +323,164 @@ class DatabricksCatalog(
         )

         return columns
+
+    def createTable(
+        self,
+        tableName: str,
+        path: t.Optional[str] = None,
+        source: t.Optional[str] = None,
+        schema: t.Optional[StructType] = None,
+        description: t.Optional[str] = None,
+        **options: str,
+    ) -> DatabricksTable:
+        """Creates a table based on the dataset in a data source.
+
+        .. versionadded:: 2.2.0
+
+        Parameters
+        ----------
+        tableName : str
+            name of the table to create.
+
+            .. versionchanged:: 3.4.0
+                Allow ``tableName`` to be qualified with catalog name.
+
+        path : str, t.Optional
+            the path in which the data for this table exists.
+            When ``path`` is specified, an external table is
+            created from the data at the given path. Otherwise a managed table is created.
+        source : str, t.Optional
+            the source of this table such as 'parquet, 'orc', etc.
+            If ``source`` is not specified, the default data source configured by
+            ``spark.sql.sources.default`` will be used.
+        schema : class:`StructType`, t.Optional
+            the schema for this table.
+        description : str, t.Optional
+            the description of this table.
+
+            .. versionchanged:: 3.1.0
+                Added the ``description`` parameter.
+
+        **options : dict, t.Optional
+            extra options to specify in the table.
+
+        Returns
+        -------
+        :class:`DataFrame`
+            The DataFrame associated with the table.
+
+        Examples
+        --------
+        Creating a managed table.
+
+        >>> _ = spark.catalog.createTable("tbl1", schema=spark.range(1).schema, source='parquet')
+        >>> _ = spark.sql("DROP TABLE tbl1")
+
+        Creating an external table
+
+        >>> import tempfile
+        >>> with tempfile.TemporaryDirectory() as d:
+        ...     _ = spark.catalog.createTable(
+        ...         "tbl2", schema=spark.range(1).schema, path=d, source='parquet')
+        >>> _ = spark.sql("DROP TABLE tbl2")
+        """
+        if not isinstance(tableName, str):
+            raise TypeError("tableName must be a string")
+        if path is not None and not isinstance(path, str):
+            raise TypeError("path must be a string")
+        if source is not None and not isinstance(source, str):
+            raise TypeError("source must be a string")
+        if schema is not None and not isinstance(schema, StructType):
+            raise TypeError("schema must be a StructType")
+        if description is not None and not isinstance(description, str):
+            raise TypeError("description must be a string")
+
+        source = (source or "delta").lower()
+        replace: t.Union[str, bool, None] = options.pop("replace", None)
+        exists: t.Union[str, bool, None] = options.pop("exists", None)
+        table_properties: t.Union[str, t.Dict[str, str]] = options.pop("properties", {})
+        partitionBy: t.Union[t.List[str], str, None] = options.pop("partitionBy", None)
+        clusterBy: t.Union[t.List[str], str, None] = options.pop("clusterBy", None)
+
+        if isinstance(replace, str) and replace.lower() == "true":
+            replace = True
+        if isinstance(exists, str) and exists.lower() == "true":
+            exists = True
+
+        schema_expressions: t.List[exp.Expression] = []
+        if schema is not None and isinstance(schema, StructType):
+            column_mapping = get_column_mapping_from_schema_input(
+                schema, dialect=self.session.input_dialect
+            )
+            schema_expressions = [
+                exp.ColumnDef(
+                    this=exp.parse_identifier(k, dialect=self.session.input_dialect), kind=v
+                )
+                for k, v in column_mapping.items()
+            ]
+
+        name = normalize_string(tableName, from_dialect="input", is_table=True)
+        properties: t.List[exp.Expression] = []
+        if source is not None:
+            properties.append(exp.FileFormatProperty(this=exp.Var(this=source.upper())))
+        if path is not None:
+            properties.append(exp.LocationProperty(this=exp.convert(path)))
+        if replace and source != "delta":
+            replace = None
+            drop_expression = exp.Drop(
+                this=exp.to_table(name, dialect=self.session.input_dialect),
+                kind="TABLE",
+                exists=True,
+            )
+            if self.session._has_connection:
+                self.session._collect(drop_expression)
+        if description is not None:
+            properties.append(exp.SchemaCommentProperty(this=exp.convert(description)))
+        if partitionBy is not None:
+            if isinstance(partitionBy, str):
+                partition_by = [partitionBy]
+            else:
+                partition_by = partitionBy
+            properties.append(
+                exp.PartitionedByProperty(
+                    this=exp.Tuple(expressions=list(map(sg.to_identifier, partition_by)))
+                )
+            )
+        if clusterBy is not None:
+            if isinstance(clusterBy, str):
+                cluster_by = [clusterBy]
+            else:
+                cluster_by = clusterBy
+            properties.append(
+                exp.Cluster(
+                    expressions=[exp.Tuple(expressions=list(map(sg.to_identifier, cluster_by)))]
+                )
+            )
+
+        properties.extend(
+            exp.Property(this=sg.to_identifier(name=k), value=exp.convert(value=v))
+            for k, v in (table_properties if isinstance(table_properties, dict) else {}).items()
+        )
+
+        format_options: dict[str, t.Union[bool, float, int, str, None]] = {
+            key: f"'{val}'" for key, val in options.items() if val is not None
+        }
+        format_options_str = to_csv(format_options, " ")
+
+        output_expression_container = exp.Create(
+            this=exp.Schema(
+                this=exp.to_table(name, dialect=self.session.input_dialect),
+                expressions=schema_expressions,
+            ),
+            kind="TABLE",
+            exists=exists,
+            replace=replace,
+            properties=exp.Properties(expressions=properties),
+        )
+        if self.session._has_connection:
+            sql = self.session._to_sql(output_expression_container, quote_identifiers=True)
+            sql += f" OPTIONS ({format_options_str})" if format_options_str else ""
+            self.session._collect(sql)
+
+        df = self.session.table(name)
+        return df
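
Usage note (not part of the diff): a minimal sketch of calling the createTable method added above, based on the docstring examples it ships with. Session construction is omitted, and the table/path names are placeholders rather than values from the package.

from sqlframe.databricks.session import DatabricksSession


def create_demo_tables(session: DatabricksSession) -> None:
    df = session.createDataFrame([{"id": 1}])

    # Managed table: source defaults to "delta" in the implementation above.
    session.catalog.createTable("demo_tbl", schema=df.schema)

    # External table: passing path adds a LOCATION property, and extra keyword
    # options such as partitionBy/clusterBy/properties are popped from **options
    # by the new method.
    session.catalog.createTable(
        "demo_ext",
        path="/Volumes/main/default/demo",  # hypothetical path, for illustration only
        source="parquet",
        schema=df.schema,
        partitionBy="id",
    )
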
sqlframe/databricks/readwriter.py
CHANGED

@@ -6,7 +6,9 @@ import sys
 import typing as t

 import sqlglot as sg
+from databricks.sql import ServerOperationError
 from sqlglot import exp
+from sqlglot.helper import ensure_list

 if sys.version_info >= (3, 11):
     from typing import Self
@@ -17,12 +19,21 @@ from sqlframe.base.mixins.readwriter_mixins import PandasLoaderMixin, PandasWrit
 from sqlframe.base.readerwriter import (
     _BaseDataFrameReader,
     _BaseDataFrameWriter,
+    _infer_format,
+)
+from sqlframe.base.util import (
+    ensure_column_mapping,
+    generate_random_identifier,
+    normalize_string,
+    split_filepath,
+    to_csv,
 )
-from sqlframe.base.util import normalize_string

 if t.TYPE_CHECKING:
-    from sqlframe.
+    from sqlframe.base._typing import OptionalPrimitiveType, PathOrPaths
+    from sqlframe.base.types import StructType
     from sqlframe.databricks.dataframe import DatabricksDataFrame  # noqa
+    from sqlframe.databricks.session import DatabricksSession  # noqa
     from sqlframe.databricks.table import DatabricksTable  # noqa


@@ -30,13 +41,232 @@ class DatabricksDataFrameReader(
     PandasLoaderMixin["DatabricksSession", "DatabricksDataFrame"],
     _BaseDataFrameReader["DatabricksSession", "DatabricksDataFrame", "DatabricksTable"],
 ):
-
+    def load(
+        self,
+        path: t.Optional[PathOrPaths] = None,
+        format: t.Optional[str] = None,
+        schema: t.Optional[t.Union[StructType, str]] = None,
+        **options: OptionalPrimitiveType,
+    ) -> DatabricksDataFrame:
+        """Loads data from a data source and returns it as a :class:`DataFrame`.
+
+        .. versionadded:: 1.4.0
+
+        .. versionchanged:: 3.4.0
+            Supports Spark Connect.
+
+        Parameters
+        ----------
+        path : str or list, t.Optional
+            t.Optional string or a list of string for file-system backed data sources.
+        format : str, t.Optional
+            t.Optional string for format of the data source. Default to 'parquet'.
+        schema : :class:`pyspark.sql.types.StructType` or str, t.Optional
+            t.Optional :class:`pyspark.sql.types.StructType` for the input schema
+            or a DDL-formatted string (For example ``col0 INT, col1 DOUBLE``).
+        **options : dict
+            all other string options
+
+        Examples
+        --------
+        Load a CSV file with format, schema and options specified.
+
+        >>> import tempfile
+        >>> with tempfile.TemporaryDirectory() as d:
+        ...     # Write a DataFrame into a CSV file with a header
+        ...     df = spark.createDataFrame([{"age": 100, "name": "Hyukjin Kwon"}])
+        ...     df.write.option("header", True).mode("overwrite").format("csv").save(d)
+        ...
+        ...     # Read the CSV file as a DataFrame with 'nullValue' option set to 'Hyukjin Kwon',
+        ...     # and 'header' option set to `True`.
+        ...     df = spark.read.load(
+        ...         d, schema=df.schema, format="csv", nullValue="Hyukjin Kwon", header=True)
+        ...     df.printSchema()
+        ...     df.show()
+        root
+         |-- age: long (nullable = true)
+         |-- name: string (nullable = true)
+        +---+----+
+        |age|name|
+        +---+----+
+        |100|NULL|
+        +---+----+
+        """
+        assert path is not None, "path is required"
+        assert isinstance(path, str), "path must be a string"
+        format = format or self.state_format_to_read or _infer_format(path)
+        fs_prefix, filepath = split_filepath(path)
+
+        if fs_prefix == "":
+            return super().load(path, format, schema, **options)
+
+        if schema:
+            column_mapping = ensure_column_mapping(schema)
+            select_column_mapping = column_mapping.copy()
+            select_columns = [x.expression for x in self._to_casted_columns(select_column_mapping)]
+
+            if hasattr(schema, "simpleString"):
+                schema = schema.simpleString()
+        else:
+            select_columns = [exp.Star()]
+
+        if format == "delta":
+            from_clause = f"delta.`{fs_prefix + filepath}`"
+        elif format:
+            paths = ",".join([f"{path}" for path in ensure_list(path)])
+
+            format_options: dict[str, OptionalPrimitiveType] = {
+                k: v for k, v in options.items() if v is not None
+            }
+            format_options["format"] = format
+            format_options["schemaEvolutionMode"] = "none"
+            if schema:
+                format_options["schema"] = f"{schema}"
+            if "inferSchema" in format_options:
+                format_options["inferColumnTypes"] = format_options.pop("inferSchema")
+
+            format_options = {key: f"'{val}'" for key, val in format_options.items()}
+            format_options_str = to_csv(format_options, " => ")
+
+            from_clause = f"read_files('{paths}', {format_options_str})"
+        else:
+            from_clause = f"'{path}'"
+
+        df = self.session.sql(
+            exp.select(*select_columns).from_(from_clause, dialect=self.session.input_dialect),
+            qualify=False,
+        )
+        if select_columns == [exp.Star()] and df.schema:
+            return self.load(path=path, format=format, schema=df.schema, **options)
+        self.session._last_loaded_file = path  # type: ignore
+        return df


 class DatabricksDataFrameWriter(
     PandasWriterMixin["DatabricksSession", "DatabricksDataFrame"],
     _BaseDataFrameWriter["DatabricksSession", "DatabricksDataFrame"],
 ):
+    def save(
+        self,
+        path: str,
+        mode: t.Optional[str] = None,
+        format: t.Optional[str] = None,
+        partitionBy: t.Optional[t.Union[str, t.List[str]]] = None,
+        **options,
+    ):
+        format = str(format or self._state_format_to_write)
+        self._write(path, mode, format, partitionBy=partitionBy, **options)
+
+    def _write(self, path: str, mode: t.Optional[str], format: str, **options):  # type: ignore
+        fs_prefix, filepath = split_filepath(path)
+        if fs_prefix == "":
+            super()._write(filepath, mode, format, **options)
+        elif format == "delta":
+            self.saveAsTable(f"delta.`{fs_prefix + filepath}`", format, mode, **options)
+        else:
+            mode = str(mode or self._mode or "error")
+            partition_by = options.pop("partitionBy", None)
+            tmp_table = f"_{generate_random_identifier()}_tmp"
+            drop_expr = exp.Drop(
+                this=exp.to_table(tmp_table, dialect=self._session.input_dialect),
+                kind="TABLE",
+                exists=True,
+            )
+            if mode == "append" or mode == "default":
+                try:
+                    self._session.catalog.createTable(
+                        tmp_table,
+                        path=fs_prefix + filepath,
+                        source=format,
+                        **options,
+                    )
+                    self.byName.insertInto(tmp_table)
+                except ServerOperationError as e:
+                    if "UNABLE_TO_INFER_SCHEMA" in str(e):
+                        self.saveAsTable(
+                            tmp_table,
+                            format=format,
+                            mode="error",
+                            path=fs_prefix + filepath,
+                            **options,
+                        )
+                    else:
+                        raise e
+                finally:
+                    self._df.session._collect(drop_expr)
+            elif mode == "error" or mode == "errorifexists":
+                try:
+                    self._session.catalog.createTable(
+                        tmp_table,
+                        path=fs_prefix + filepath,
+                        source=format,
+                        **options,
+                    )
+                    raise FileExistsError(f"Path already exists: {fs_prefix + filepath}")
+                except ServerOperationError as e:
+                    if "UNABLE_TO_INFER_SCHEMA" in str(e):
+                        self.saveAsTable(
+                            tmp_table,
+                            format=format,
+                            mode=mode,
+                            path=fs_prefix + filepath,
+                            **options,
+                        )
+                finally:
+                    self._df.session._collect(drop_expr)
+            elif mode == "overwrite":
+                try:
+                    self.saveAsTable(
+                        tmp_table,
+                        format=format,
+                        mode=mode,
+                        path=fs_prefix + filepath,
+                        partitionBy=partition_by,
+                        **options,
+                    )
+                finally:
+                    self._df.session._collect(drop_expr)
+            elif mode == "ignore":
+                pass
+            else:
+                raise RuntimeError(f"Unssuported mode: {mode}")
+
+    def insertInto(
+        self,
+        tableName: str,
+        overwrite: t.Optional[bool] = None,
+        replaceWhere: t.Optional[str] = None,
+    ) -> Self:
+        from sqlframe.base.session import _BaseSession
+
+        tableName = normalize_string(tableName, from_dialect="input", is_table=True)
+        output_expression_container = exp.Insert(
+            **{
+                **{
+                    "this": exp.to_table(tableName, dialect=_BaseSession().input_dialect),
+                    "overwrite": overwrite,
+                },
+                **(
+                    {
+                        "by_name": self._by_name,
+                    }
+                    if self._by_name
+                    else {}
+                ),
+                **({"where": sg.parse_one(replaceWhere)} if replaceWhere else {}),
+            }
+        )
+        df = self._df.copy(output_expression_container=output_expression_container)
+        if self._by_name:
+            columns = self._session.catalog._schema.column_names(
+                tableName, only_visible=True, dialect=_BaseSession().input_dialect
+            )
+            df = df._convert_leaf_to_cte().select(*columns)
+
+        if self._session._has_connection:
+            df.collect()
+        return self.copy(_df=df)
+
     def saveAsTable(
         self,
         name: str,
@@ -44,17 +274,39 @@ class DatabricksDataFrameWriter(
         mode: t.Optional[str] = None,
         partitionBy: t.Optional[t.Union[str, t.List[str]]] = None,
         clusterBy: t.Optional[t.Union[str, t.List[str]]] = None,
-        **options,
-    )
-
-
-
+        **options: OptionalPrimitiveType,
+    ):
+        format = (format or self._state_format_to_write or "delta").lower()
+        table_properties: t.Union[OptionalPrimitiveType, t.Dict[str, OptionalPrimitiveType]] = (
+            options.pop("properties", {})
+        )
+        path: OptionalPrimitiveType = options.pop("path", None)
+        if path is not None and not isinstance(path, str):
+            raise ValueError("path must be a string")
+
+        replace_where: OptionalPrimitiveType = options.pop("replaceWhere", None)
+        if replace_where is not None and not isinstance(replace_where, str):
+            raise ValueError("replaceWhere must be a string")
+
+        exists, replace, mode = None, None, str(mode or self._mode or "error")
         if mode == "append":
-
+            self._session.catalog.createTable(
+                name,
+                path=path,
+                source=format,
+                schema=self._df.schema,
+                partitionBy=partitionBy,
+                clusterBy=clusterBy,
+                exists="true",
+                **options,
+            )
+            self.insertInto(name, replaceWhere=replace_where)
+            return
         if mode == "ignore":
             exists = True
         if mode == "overwrite":
             replace = True
+
         name = normalize_string(name, from_dialect="input", is_table=True)

         properties: t.List[exp.Expression] = []
@@ -79,9 +331,31 @@ class DatabricksDataFrameWriter(
             )
         )

+        format_options_str = ""
+        if format is not None:
+            properties.append(exp.FileFormatProperty(this=exp.Var(this=format.upper())))
+            format_options: dict[str, OptionalPrimitiveType] = {
+                key: f"'{val}'" for key, val in options.items() if val is not None
+            }
+            format_options_str = to_csv(format_options, " ")
+
+        if path is not None and isinstance(path, str):
+            properties.append(exp.LocationProperty(this=exp.convert(path)))
+        if replace and format != "delta":
+            replace = None
+            drop_expression = exp.Drop(
+                this=exp.to_table(name, dialect=self._session.input_dialect),
+                kind="TABLE",
+                exists=True,
+            )
+            if self._session._has_connection:
+                self._session._collect(drop_expression)
+
         properties.extend(
             exp.Property(this=sg.to_identifier(name), value=exp.convert(value))
-            for name, value in (
+            for name, value in (
+                (table_properties if isinstance(table_properties, dict) else {}).items()
+            )
         )

         output_expression_container = exp.Create(
@@ -91,7 +365,13 @@ class DatabricksDataFrameWriter(
             replace=replace,
             properties=exp.Properties(expressions=properties),
         )
-        df = self._df.copy(output_expression_container=output_expression_container)
         if self._session._has_connection:
-
-
+            create_sql = self._session._to_sql(output_expression_container, quote_identifiers=True)
+            df_sql = self._df.sql(self._session.execution_dialect, False, False)
+            sql = (
+                create_sql
+                + (f" OPTIONS ({format_options_str})" if format_options_str else "")
+                + " AS "
+                + df_sql
+            )
+            self._session._collect(sql)
sqlframe/duckdb/catalog.py
CHANGED

@@ -9,6 +9,7 @@ from sqlglot import exp

 from sqlframe.base.catalog import Function, _BaseCatalog
 from sqlframe.base.mixins.catalog_mixins import (
+    CreateTableFromFunctionMixin,
     GetCurrentCatalogFromFunctionMixin,
     GetCurrentDatabaseFromFunctionMixin,
     ListCatalogsFromInfoSchemaMixin,
@@ -23,18 +24,20 @@ from sqlframe.base.util import normalize_string, schema_, to_schema
 if t.TYPE_CHECKING:
     from sqlframe.duckdb.session import DuckDBSession  # noqa
     from sqlframe.duckdb.dataframe import DuckDBDataFrame  # noqa
+    from sqlframe.duckdb.table import DuckDBTable  # noqa


 class DuckDBCatalog(
-    GetCurrentCatalogFromFunctionMixin["DuckDBSession", "DuckDBDataFrame"],
-    SetCurrentCatalogFromUseMixin["DuckDBSession", "DuckDBDataFrame"],
-    GetCurrentDatabaseFromFunctionMixin["DuckDBSession", "DuckDBDataFrame"],
-
-
-
-
-
-
+    GetCurrentCatalogFromFunctionMixin["DuckDBSession", "DuckDBDataFrame", "DuckDBTable"],
+    SetCurrentCatalogFromUseMixin["DuckDBSession", "DuckDBDataFrame", "DuckDBTable"],
+    GetCurrentDatabaseFromFunctionMixin["DuckDBSession", "DuckDBDataFrame", "DuckDBTable"],
+    CreateTableFromFunctionMixin["DuckDBSession", "DuckDBDataFrame", "DuckDBTable"],
+    ListDatabasesFromInfoSchemaMixin["DuckDBSession", "DuckDBDataFrame", "DuckDBTable"],
+    ListCatalogsFromInfoSchemaMixin["DuckDBSession", "DuckDBDataFrame", "DuckDBTable"],
+    SetCurrentDatabaseFromUseMixin["DuckDBSession", "DuckDBDataFrame", "DuckDBTable"],
+    ListTablesFromInfoSchemaMixin["DuckDBSession", "DuckDBDataFrame", "DuckDBTable"],
+    ListColumnsFromInfoSchemaMixin["DuckDBSession", "DuckDBDataFrame", "DuckDBTable"],
+    _BaseCatalog["DuckDBSession", "DuckDBDataFrame", "DuckDBTable"],
 ):
     TEMP_CATALOG_FILTER = exp.column("table_catalog").eq("temp")

sqlframe/postgres/catalog.py
CHANGED

@@ -9,6 +9,7 @@ from sqlglot import exp, parse_one

 from sqlframe.base.catalog import Column, Function, _BaseCatalog
 from sqlframe.base.mixins.catalog_mixins import (
+    CreateTableFromFunctionMixin,
     GetCurrentCatalogFromFunctionMixin,
     GetCurrentDatabaseFromFunctionMixin,
     ListCatalogsFromInfoSchemaMixin,
@@ -21,16 +22,18 @@ from sqlframe.base.util import normalize_string, to_schema
 if t.TYPE_CHECKING:
     from sqlframe.postgres.session import PostgresSession  # noqa
     from sqlframe.postgres.dataframe import PostgresDataFrame  # noqa
+    from sqlframe.postgres.table import PostgresTable  # noqa


 class PostgresCatalog(
-    GetCurrentCatalogFromFunctionMixin["PostgresSession", "PostgresDataFrame"],
-    GetCurrentDatabaseFromFunctionMixin["PostgresSession", "PostgresDataFrame"],
-
-
-
-
-
+    GetCurrentCatalogFromFunctionMixin["PostgresSession", "PostgresDataFrame", "PostgresTable"],
+    GetCurrentDatabaseFromFunctionMixin["PostgresSession", "PostgresDataFrame", "PostgresTable"],
+    CreateTableFromFunctionMixin["PostgresSession", "PostgresDataFrame", "PostgresTable"],
+    ListDatabasesFromInfoSchemaMixin["PostgresSession", "PostgresDataFrame", "PostgresTable"],
+    ListCatalogsFromInfoSchemaMixin["PostgresSession", "PostgresDataFrame", "PostgresTable"],
+    SetCurrentDatabaseFromSearchPathMixin["PostgresSession", "PostgresDataFrame", "PostgresTable"],
+    ListTablesFromInfoSchemaMixin["PostgresSession", "PostgresDataFrame", "PostgresTable"],
+    _BaseCatalog["PostgresSession", "PostgresDataFrame", "PostgresTable"],
 ):
     CURRENT_CATALOG_EXPRESSION: exp.Expression = exp.column("current_catalog")
     TEMP_SCHEMA_FILTER = exp.column("table_schema").like("pg_temp_%")
sqlframe/py.typed
ADDED

@@ -0,0 +1 @@
+
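
Note (not part of the diff): py.typed is the PEP 561 marker file; shipping it, even empty, tells type checkers that sqlframe's inline annotations may be used. A minimal sketch of the effect:

from sqlframe.duckdb.session import DuckDBSession


def open_session() -> DuckDBSession:
    # With py.typed present, a PEP 561-aware checker such as mypy resolves this
    # return type from sqlframe's own annotations instead of treating it as Any.
    return DuckDBSession()  # assumption: no-arg construction creates an in-memory session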
|