sqlframe 1.1.3__py3-none-any.whl
This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
- sqlframe/__init__.py +0 -0
- sqlframe/_version.py +16 -0
- sqlframe/base/__init__.py +0 -0
- sqlframe/base/_typing.py +39 -0
- sqlframe/base/catalog.py +1163 -0
- sqlframe/base/column.py +388 -0
- sqlframe/base/dataframe.py +1519 -0
- sqlframe/base/decorators.py +51 -0
- sqlframe/base/exceptions.py +14 -0
- sqlframe/base/function_alternatives.py +1055 -0
- sqlframe/base/functions.py +1678 -0
- sqlframe/base/group.py +102 -0
- sqlframe/base/mixins/__init__.py +0 -0
- sqlframe/base/mixins/catalog_mixins.py +419 -0
- sqlframe/base/mixins/readwriter_mixins.py +118 -0
- sqlframe/base/normalize.py +84 -0
- sqlframe/base/operations.py +87 -0
- sqlframe/base/readerwriter.py +679 -0
- sqlframe/base/session.py +585 -0
- sqlframe/base/transforms.py +13 -0
- sqlframe/base/types.py +418 -0
- sqlframe/base/util.py +242 -0
- sqlframe/base/window.py +139 -0
- sqlframe/bigquery/__init__.py +23 -0
- sqlframe/bigquery/catalog.py +255 -0
- sqlframe/bigquery/column.py +1 -0
- sqlframe/bigquery/dataframe.py +54 -0
- sqlframe/bigquery/functions.py +378 -0
- sqlframe/bigquery/group.py +14 -0
- sqlframe/bigquery/readwriter.py +29 -0
- sqlframe/bigquery/session.py +89 -0
- sqlframe/bigquery/types.py +1 -0
- sqlframe/bigquery/window.py +1 -0
- sqlframe/duckdb/__init__.py +20 -0
- sqlframe/duckdb/catalog.py +108 -0
- sqlframe/duckdb/column.py +1 -0
- sqlframe/duckdb/dataframe.py +55 -0
- sqlframe/duckdb/functions.py +47 -0
- sqlframe/duckdb/group.py +14 -0
- sqlframe/duckdb/readwriter.py +111 -0
- sqlframe/duckdb/session.py +65 -0
- sqlframe/duckdb/types.py +1 -0
- sqlframe/duckdb/window.py +1 -0
- sqlframe/postgres/__init__.py +23 -0
- sqlframe/postgres/catalog.py +106 -0
- sqlframe/postgres/column.py +1 -0
- sqlframe/postgres/dataframe.py +54 -0
- sqlframe/postgres/functions.py +61 -0
- sqlframe/postgres/group.py +14 -0
- sqlframe/postgres/readwriter.py +29 -0
- sqlframe/postgres/session.py +68 -0
- sqlframe/postgres/types.py +1 -0
- sqlframe/postgres/window.py +1 -0
- sqlframe/redshift/__init__.py +23 -0
- sqlframe/redshift/catalog.py +127 -0
- sqlframe/redshift/column.py +1 -0
- sqlframe/redshift/dataframe.py +54 -0
- sqlframe/redshift/functions.py +18 -0
- sqlframe/redshift/group.py +14 -0
- sqlframe/redshift/readwriter.py +29 -0
- sqlframe/redshift/session.py +53 -0
- sqlframe/redshift/types.py +1 -0
- sqlframe/redshift/window.py +1 -0
- sqlframe/snowflake/__init__.py +26 -0
- sqlframe/snowflake/catalog.py +134 -0
- sqlframe/snowflake/column.py +1 -0
- sqlframe/snowflake/dataframe.py +54 -0
- sqlframe/snowflake/functions.py +18 -0
- sqlframe/snowflake/group.py +14 -0
- sqlframe/snowflake/readwriter.py +29 -0
- sqlframe/snowflake/session.py +53 -0
- sqlframe/snowflake/types.py +1 -0
- sqlframe/snowflake/window.py +1 -0
- sqlframe/spark/__init__.py +23 -0
- sqlframe/spark/catalog.py +1028 -0
- sqlframe/spark/column.py +1 -0
- sqlframe/spark/dataframe.py +54 -0
- sqlframe/spark/functions.py +22 -0
- sqlframe/spark/group.py +14 -0
- sqlframe/spark/readwriter.py +29 -0
- sqlframe/spark/session.py +90 -0
- sqlframe/spark/types.py +1 -0
- sqlframe/spark/window.py +1 -0
- sqlframe/standalone/__init__.py +26 -0
- sqlframe/standalone/catalog.py +13 -0
- sqlframe/standalone/column.py +1 -0
- sqlframe/standalone/dataframe.py +36 -0
- sqlframe/standalone/functions.py +1 -0
- sqlframe/standalone/group.py +14 -0
- sqlframe/standalone/readwriter.py +19 -0
- sqlframe/standalone/session.py +40 -0
- sqlframe/standalone/types.py +1 -0
- sqlframe/standalone/window.py +1 -0
- sqlframe-1.1.3.dist-info/LICENSE +21 -0
- sqlframe-1.1.3.dist-info/METADATA +172 -0
- sqlframe-1.1.3.dist-info/RECORD +98 -0
- sqlframe-1.1.3.dist-info/WHEEL +5 -0
- sqlframe-1.1.3.dist-info/top_level.txt +1 -0
sqlframe/base/session.py
ADDED
@@ -0,0 +1,585 @@
# This code is based on code from Apache Spark under the license found in the LICENSE file located in the 'sqlframe' folder.

from __future__ import annotations

import datetime
import logging
import sys
import typing as t
import uuid
from collections import defaultdict
from functools import cached_property

import sqlglot
from sqlglot import Dialect, exp
from sqlglot.expressions import parse_identifier
from sqlglot.helper import seq_get
from sqlglot.optimizer import optimize
from sqlglot.optimizer.normalize_identifiers import normalize_identifiers
from sqlglot.optimizer.qualify_columns import (
    quote_identifiers as quote_identifiers_func,
)
from sqlglot.schema import MappingSchema

from sqlframe.base.catalog import _BaseCatalog
from sqlframe.base.dataframe import _BaseDataFrame
from sqlframe.base.readerwriter import _BaseDataFrameReader, _BaseDataFrameWriter
from sqlframe.base.util import get_column_mapping_from_schema_input

if sys.version_info >= (3, 11):
    from typing import Self
else:
    from typing_extensions import Self

if t.TYPE_CHECKING:
    import pandas as pd
    from _typeshed.dbapi import DBAPIConnection, DBAPICursor

    from sqlframe.base._typing import ColumnLiterals, SchemaInput
    from sqlframe.base.types import Row, StructType

    class DBAPIConnectionWithPandas(DBAPIConnection):
        def cursor(self) -> DBAPICursorWithPandas: ...

    class DBAPICursorWithPandas(DBAPICursor):
        def fetchdf(self) -> pd.DataFrame: ...

    CONN = t.TypeVar("CONN", bound=DBAPIConnectionWithPandas)
else:
    CONN = t.TypeVar("CONN")


logger = logging.getLogger(__name__)


CATALOG = t.TypeVar("CATALOG", bound=_BaseCatalog)
READER = t.TypeVar("READER", bound=_BaseDataFrameReader)
WRITER = t.TypeVar("WRITER", bound=_BaseDataFrameWriter)
DF = t.TypeVar("DF", bound=_BaseDataFrame)

_MISSING = "MISSING"


class _BaseSession(t.Generic[CATALOG, READER, WRITER, DF, CONN]):
    _instance = None
    _reader: t.Type[READER]
    _writer: t.Type[WRITER]
    _catalog: t.Type[CATALOG]
    _df: t.Type[DF]

    SANITIZE_COLUMN_NAMES = False
    DEFAULT_TIME_FORMAT = "yyyy-MM-dd HH:mm:ss"

    def __init__(
        self,
        conn: t.Optional[CONN] = None,
        schema: t.Optional[MappingSchema] = None,
        *args,
        **kwargs,
    ):
        if not hasattr(self, "input_dialect"):
            self.input_dialect: Dialect = Dialect.get_or_raise(self.builder.DEFAULT_INPUT_DIALECT)
            self.output_dialect: Dialect = Dialect.get_or_raise(self.builder.DEFAULT_OUTPUT_DIALECT)
            self.known_ids: t.Set[str] = set()
            self.known_branch_ids: t.Set[str] = set()
            self.known_sequence_ids: t.Set[str] = set()
            self.name_to_sequence_id_mapping: t.Dict[str, t.List[str]] = defaultdict(list)
            self.incrementing_id: int = 1
            self._last_loaded_file: t.Optional[str] = None
            self.temp_views: t.Dict[str, DF] = {}
        if not self._has_connection or conn:
            self._connection = conn
        if not getattr(self, "schema", None) or schema:
            self._schema = schema

    @property
    def read(self) -> READER:
        return self._reader(self)

    @cached_property
    def catalog(self) -> CATALOG:
        return self._catalog(self, self._schema)

    @property
    def _conn(self) -> CONN:
        if self._connection is None:
            raise ValueError("Connection not set")
        return self._connection

    @cached_property
    def _cur(self) -> DBAPICursorWithPandas:
        return self._conn.cursor()

    def _sanitize_column_name(self, name: str) -> str:
        if self.SANITIZE_COLUMN_NAMES:
            return name.replace("(", "_").replace(")", "_")
        return name

    def table(self, tableName: str) -> DF:
        return self.read.table(tableName)

    def _create_df(self, *args, **kwargs) -> DF:
        return self._df(self, *args, **kwargs)

    def __new__(cls, *args, **kwargs):
        if _BaseSession._instance is None:
            _BaseSession._instance = super().__new__(cls)
        return _BaseSession._instance

    @property
    def _has_connection(self) -> bool:
        return hasattr(self, "_connection") and bool(self._connection)

    def range(self, *args):
        start = 0
        step = 1
        numPartitions = None
        if len(args) == 1:
            end = args[0]
        elif len(args) == 2:
            start, end = args
        elif len(args) == 3:
            start, end, step = args
        elif len(args) == 4:
            start, end, step, numPartitions = args
        else:
            raise ValueError(
                "range() takes 1 to 4 positional arguments but {} were given".format(len(args))
            )
        if numPartitions is not None:
            logger.warning("numPartitions is not supported")
        return self.createDataFrame([[x] for x in range(start, end, step)], schema={"id": "long"})

    def createDataFrame(
        self,
        data: t.Sequence[
            t.Union[
                t.Dict[str, ColumnLiterals],
                t.List[ColumnLiterals],
                t.Tuple[ColumnLiterals, ...],
                ColumnLiterals,
            ]
        ],
        schema: t.Optional[SchemaInput] = None,
        samplingRatio: t.Optional[float] = None,
        verifySchema: bool = False,
    ) -> DF:
        from sqlframe.base import functions as F
        from sqlframe.base.types import Row, StructType

        if samplingRatio is not None or verifySchema:
            raise NotImplementedError("Sampling Ratio and Verify Schema are not supported")
        if (
            schema is not None
            and not isinstance(schema, dict)
            and (
                not isinstance(schema, (StructType, str, list, tuple))
                or (isinstance(schema, (list, tuple)) and not isinstance(schema[0], str))
            )
        ):
            raise NotImplementedError("Only schema of either list or string of list supported")

        column_mapping: t.Mapping[str, t.Optional[exp.DataType]]
        if schema is not None:
            column_mapping = get_column_mapping_from_schema_input(
                schema, dialect=self.input_dialect
            )
        elif data:
            if isinstance(data[0], Row):
                column_mapping = {col_name.strip(): None for col_name in data[0].__fields__}
            elif isinstance(data[0], dict):
                column_mapping = {col_name.strip(): None for col_name in data[0]}
            else:
                column_mapping = {f"_{i}": None for i in range(1, len(data[0]) + 1)}  # type: ignore
        else:
            column_mapping = {}

        column_mapping = {
            normalize_identifiers(k, self.input_dialect).sql(dialect=self.input_dialect): v
            for k, v in column_mapping.items()
        }
        empty_df = not data
        rows = [[None] * len(column_mapping)] if empty_df else list(data)  # type: ignore

        def get_default_data_type(value: t.Any) -> t.Optional[str]:
            if isinstance(value, Row):
                row_types = []
                for row_name, row_dtype in zip(value.__fields__, value):
                    default_type = get_default_data_type(row_dtype)
                    if not default_type:
                        continue
                    row_types.append((row_name, default_type))
                return "struct<" + ", ".join(f"{k}: {v}" for (k, v) in row_types) + ">"
            elif isinstance(value, dict):
                sample_row = seq_get(list(value.items()), 0)
                if not sample_row:
                    return None
                key, value = sample_row
                default_key = get_default_data_type(key)
                default_value = get_default_data_type(value)
                if not default_key or not default_value:
                    return None
                return f"map<{default_key}, {default_value}>"
            elif isinstance(value, (list, set, tuple)):
                if not value:
                    return None
                default_type = get_default_data_type(next(iter(value)))
                if not default_type:
                    return None
                return f"array<{default_type}>"
            elif isinstance(value, bool):
                return "boolean"
            elif isinstance(value, bytes):
                return "binary"
            elif isinstance(value, int):
                return "bigint"
            elif isinstance(value, float):
                return "double"
            elif isinstance(value, datetime.datetime):
                if value.tzinfo:
                    return "timestamptz"
                return "timestamp"
            elif isinstance(value, datetime.date):
                return "date"
            elif isinstance(value, str):
                return "string"
            return None

        updated_mapping: t.Dict[str, t.Optional[exp.DataType]] = {}
        sample_row = rows[0]
        for i, (name, dtype) in enumerate(column_mapping.items()):
            if dtype is not None:
                updated_mapping[name] = dtype
                continue
            if isinstance(sample_row, Row):
                sample_row = sample_row.asDict()
            if isinstance(sample_row, dict):
                default_data_type = get_default_data_type(sample_row[name])
                updated_mapping[name] = (
                    exp.DataType.build(default_data_type, dialect="spark")
                    if default_data_type
                    else None
                )
            else:
                default_data_type = get_default_data_type(sample_row[i])
                updated_mapping[name] = (
                    exp.DataType.build(default_data_type, dialect="spark")
                    if default_data_type
                    else None
                )
        column_mapping = updated_mapping
        data_expressions = []
        for row in rows:
            if isinstance(row, (list, tuple, dict)):
                if not row:
                    data_expressions.append(exp.tuple_(exp.Null()))
                    continue
                if isinstance(row, Row):
                    row = row.asDict()
                if isinstance(row, dict):
                    row = row.values()  # type: ignore
                data_expressions.append(exp.tuple_(*[F.lit(x).expression for x in row]))
            else:
                data_expressions.append(exp.tuple_(*[F.lit(row).expression]))

        if column_mapping:
            sel_columns = [
                (
                    F.col(name).cast(data_type).alias(name).expression
                    if data_type is not None
                    else F.col(name).expression
                )
                for name, data_type in column_mapping.items()
            ]
        else:
            sel_columns = [F.lit(None).expression]

        select_kwargs = {
            "expressions": sel_columns,
            "from": exp.From(
                this=exp.Values(
                    expressions=data_expressions,
                    alias=exp.TableAlias(
                        this=exp.to_identifier(self._auto_incrementing_name),
                        columns=[
                            exp.parse_identifier(col_name, dialect=self.input_dialect)
                            for col_name in column_mapping
                        ],
                    ),
                ),
            ),
        }

        sel_expression = exp.Select(**select_kwargs)
        if empty_df:
            sel_expression = sel_expression.where(exp.false())
        return self._create_df(sel_expression)

    def sql(self, sqlQuery: t.Union[str, exp.Expression], optimize: bool = True) -> DF:
        expression = (
            sqlglot.parse_one(sqlQuery, read=self.input_dialect)
            if isinstance(sqlQuery, str)
            else sqlQuery
        )
        if optimize:
            expression = self._optimize(expression)
        if self.temp_views:
            replacement_mapping = {}
            for table in expression.find_all(exp.Table):
                if not (df := self.temp_views.get(table.name)):
                    continue
                expression_ctes = {cte.alias_or_name: cte for cte in expression.ctes}  # type: ignore
                replacement_mapping[table] = df.expression.ctes[-1].alias_or_name
                ctes_to_add = []
                for cte in df.expression.ctes:
                    if cte.alias_or_name not in expression_ctes:
                        ctes_to_add.append(cte)
                expression.set("with", exp.With(expressions=expression.ctes + ctes_to_add))  # type: ignore

            def replace_temp_view_name_with_cte(node: exp.Expression) -> exp.Expression:
                if isinstance(node, exp.Table):
                    if node in replacement_mapping:
                        node.set("this", exp.to_identifier(replacement_mapping[node]))
                return node

            if replacement_mapping:
                expression = expression.transform(replace_temp_view_name_with_cte)

        if isinstance(expression, exp.Select):
            df = self._create_df(expression)
            df = df._convert_leaf_to_cte()
        elif isinstance(expression, (exp.Create, exp.Insert)):
            select_expression = expression.expression.copy()
            if isinstance(expression, exp.Insert):
                select_expression.set("with", expression.args.get("with"))
                expression.set("with", None)
            del expression.args["expression"]
            df = self._create_df(select_expression, output_expression_container=expression)  # type: ignore
            df = df._convert_leaf_to_cte()
        else:
            raise ValueError(
                "Unknown expression type provided in the SQL. Please create an issue with the SQL."
            )
        return df

    @property
    def _auto_incrementing_name(self) -> str:
        name = f"a{self.incrementing_id}"
        self.incrementing_id += 1
        return name

    @property
    def _random_branch_id(self) -> str:
        id = self._random_id
        self.known_branch_ids.add(id)
        return id

    @property
    def _random_sequence_id(self):
        id = self._random_id
        self.known_sequence_ids.add(id)
        return id

    @property
    def _random_id(self) -> str:
        id = "r" + uuid.uuid4().hex
        normalized_id = self._normalize_string(id)
        self.known_ids.add(normalized_id)
        return normalized_id

    @property
    def _join_hint_names(self) -> t.Set[str]:
        return {"BROADCAST", "MERGE", "SHUFFLE_HASH", "SHUFFLE_REPLICATE_NL"}

    def _normalize_string(self, value: str) -> str:
        expression = parse_identifier(value, dialect=self.input_dialect)
        normalize_identifiers(expression, dialect=self.input_dialect)
        return expression.sql(dialect=self.input_dialect)

    def _add_alias_to_mapping(self, name: str, sequence_id: str):
        self.name_to_sequence_id_mapping[self._normalize_string(name)].append(sequence_id)

    def _to_sql(self, sql: t.Union[str, exp.Expression], *, quote_identifiers: bool = True) -> str:
        if isinstance(sql, exp.Expression):
            expression = sql.copy()
            if quote_identifiers:
                normalize_identifiers(expression, dialect=self.input_dialect)
                quote_identifiers_func(expression, dialect=self.input_dialect)
            sql = expression.sql(dialect=self.output_dialect)
        return t.cast(str, sql)

    def _optimize(
        self, expression: exp.Expression, dialect: t.Optional[Dialect] = None
    ) -> exp.Expression:
        dialect = dialect or self.output_dialect
        quote_identifiers_func(expression, dialect=dialect)
        return optimize(expression, dialect=dialect, schema=self.catalog._schema)

    def _execute(
        self, sql: t.Union[str, exp.Expression], *, quote_identifiers: bool = True
    ) -> None:
        self._cur.execute(self._to_sql(sql, quote_identifiers=quote_identifiers))

    @classmethod
    def _try_get_map(cls, value: t.Any) -> t.Optional[t.Dict[str, t.Any]]:
        return None if not isinstance(value, dict) else value

    @classmethod
    def _to_value(cls, value: t.Any) -> t.Any:
        if (map_value := cls._try_get_map(value)) is not None:
            return map_value
        elif isinstance(value, dict):
            return cls._to_row(list(value.keys()), list(value.values()))
        elif isinstance(value, (list, set, tuple)) and value:
            return [cls._to_value(x) for x in value]
        return value

    @classmethod
    def _to_row(cls, columns: t.List[str], values: t.Iterable[t.Any]) -> Row:
        from sqlframe.base.types import Row

        converted_values = []
        for value in values:
            converted_values.append(cls._to_value(value))
        return Row(**dict(zip(columns, converted_values)))

    def _fetch_rows(
        self, sql: t.Union[str, exp.Expression], *, quote_identifiers: bool = True
    ) -> t.List[Row]:
        from sqlframe.base.types import Row

        def _dict_to_row(row: t.Dict[str, t.Any]) -> Row:
            for key, value in row.items():
                if isinstance(value, dict):
                    row[key] = _dict_to_row(value)
            return Row(**row)

        self._execute(sql, quote_identifiers=quote_identifiers)
        result = self._cur.fetchall()
        if not self._cur.description:
            return []
        columns = [x[0] for x in self._cur.description]
        return [self._to_row(columns, row) for row in result]

    def _fetchdf(
        self, sql: t.Union[str, exp.Expression], *, quote_identifiers: bool = True
    ) -> pd.DataFrame:
        from pandas.io.sql import read_sql_query

        return read_sql_query(self._to_sql(sql, quote_identifiers=quote_identifiers), self._conn)

    @property
    def _is_standalone(self) -> bool:
        from sqlframe.standalone.session import StandaloneSession

        return isinstance(self, StandaloneSession)

    @property
    def _is_duckdb(self) -> bool:
        from sqlframe.duckdb.session import DuckDBSession

        return isinstance(self, DuckDBSession)

    @property
    def _is_postgres(self) -> bool:
        from sqlframe.postgres.session import PostgresSession

        return isinstance(self, PostgresSession)

    @property
    def _is_spark(self) -> bool:
        from sqlframe.spark.session import SparkSession

        return isinstance(self, SparkSession)

    @property
    def _is_bigquery(self) -> bool:
        from sqlframe.bigquery.session import BigQuerySession

        return isinstance(self, BigQuerySession)

    @property
    def _is_redshift(self) -> bool:
        from sqlframe.redshift.session import RedshiftSession

        return isinstance(self, RedshiftSession)

    @property
    def _is_snowflake(self) -> bool:
        from sqlframe.snowflake.session import SnowflakeSession

        return isinstance(self, SnowflakeSession)

    class Builder:
        SQLFRAME_INPUT_DIALECT_KEY = "sqlframe.input.dialect"
        SQLFRAME_OUTPUT_DIALECT_KEY = "sqlframe.output.dialect"
        SQLFRAME_CONN_KEY = "sqlframe.conn"
        SQLFRAME_SCHEMA_KEY = "sqlframe.schema"
        DEFAULT_INPUT_DIALECT = "spark"
        DEFAULT_OUTPUT_DIALECT = "spark"

        def __init__(self):
            self.input_dialect = self.DEFAULT_INPUT_DIALECT
            self.output_dialect = self.DEFAULT_OUTPUT_DIALECT
            self._conn = None
            self._session_kwargs = {}

        def __getattr__(self, item) -> Self:
            return self

        def __call__(self, *args, **kwargs):
            return self

        @property
        def session(self) -> _BaseSession:
            return _BaseSession(**self._session_kwargs)

        def getOrCreate(self) -> _BaseSession:
            self._set_session_properties()
            return self.session

        def _set_config(
            self,
            key: t.Optional[str] = None,
            value: t.Optional[t.Any] = None,
            *,
            map: t.Optional[t.Dict[str, t.Any]] = None,
        ) -> None:
            if value is not None:
                if key == self.SQLFRAME_INPUT_DIALECT_KEY:
                    self.input_dialect = value
                elif key == self.SQLFRAME_OUTPUT_DIALECT_KEY:
                    self.output_dialect = value
                elif key == self.SQLFRAME_CONN_KEY:
                    self._session_kwargs["conn"] = value
                elif key == self.SQLFRAME_SCHEMA_KEY:
                    self._session_kwargs["schema"] = value
                else:
                    self._session_kwargs[key] = value
            if map:
                if self.SQLFRAME_INPUT_DIALECT_KEY in map:
                    self.input_dialect = map[self.SQLFRAME_INPUT_DIALECT_KEY]
                if self.SQLFRAME_OUTPUT_DIALECT_KEY in map:
                    self.output_dialect = map[self.SQLFRAME_OUTPUT_DIALECT_KEY]
                if self.SQLFRAME_CONN_KEY in map:
                    self._session_kwargs["conn"] = map[self.SQLFRAME_CONN_KEY]
                if self.SQLFRAME_SCHEMA_KEY in map:
                    self._session_kwargs["schema"] = map[self.SQLFRAME_SCHEMA_KEY]

        def config(
            self,
            key: t.Optional[str] = None,
            value: t.Optional[t.Any] = None,
            *,
            map: t.Optional[t.Dict[str, t.Any]] = None,
        ) -> Self:
            self._set_config(key, value, map=map)
            return self

        def _set_session_properties(self) -> None:
            self.session.input_dialect = Dialect.get_or_raise(self.input_dialect)
            self.session.output_dialect = Dialect.get_or_raise(self.output_dialect)
            if not self.session._connection:
                self.session._connection = self._conn

    builder = Builder()
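For orientation, a minimal usage sketch of the session API defined above (not part of the package files), using the DuckDB backend whose import path appears in `_is_duckdb`. That `DuckDBSession` supplies a default in-memory connection is an assumption; pass a `conn=` explicitly via `config()` otherwise.

# Hypothetical usage sketch -- not part of sqlframe 1.1.3's files.
from sqlframe.duckdb.session import DuckDBSession  # import path from _is_duckdb above

# Builder.getOrCreate() applies config() overrides, then returns the
# singleton session (_BaseSession.__new__ caches the first instance).
session = DuckDBSession.builder.getOrCreate()

# createDataFrame builds a SELECT ... FROM (VALUES ...) expression; with an
# explicit schema each column is cast, otherwise types are inferred from the
# first row via get_default_data_type.
df = session.createDataFrame(
    [{"id": 1, "name": "a"}, {"id": 2, "name": "b"}],
    schema={"id": "bigint", "name": "string"},
)

# range() is sugar over createDataFrame with a single "long" id column.
ids = session.range(1, 4)

Note that the Builder mirrors PySpark's `SparkSession.builder` contract: since `__getattr__` and `__call__` both return `self`, unsupported chained options such as `.appName(...)` are silently accepted.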
sqlframe/base/transforms.py
ADDED
@@ -0,0 +1,13 @@
# This code is based on code from Apache Spark under the license found in the LICENSE file located in the 'sqlframe' folder.

import typing as t

from sqlglot import expressions as exp


def replace_id_value(
    node: exp.Expression, replacement_mapping: t.Dict[exp.Identifier, exp.Identifier]
) -> exp.Expression:
    if isinstance(node, exp.Identifier) and node in replacement_mapping:
        node = node.replace(replacement_mapping[node].copy())
    return node
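For orientation, a hypothetical sketch (not part of the package files) of how `replace_id_value` is meant to be driven: sqlglot's `Expression.transform` calls the callback on every node, forwarding any extra positional arguments.

# Hypothetical illustration -- not part of sqlframe 1.1.3's files.
import sqlglot
from sqlglot import expressions as exp

from sqlframe.base.transforms import replace_id_value

tree = sqlglot.parse_one("SELECT a FROM t")
# Map the identifier `a` to `b`; sqlglot expressions hash by their args,
# which is what makes the `node in replacement_mapping` lookup work.
mapping = {exp.to_identifier("a"): exp.to_identifier("b")}

renamed = tree.transform(replace_id_value, mapping)
print(renamed.sql())  # should print: SELECT b FROM t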