sqlframe 1.1.3__py3-none-any.whl

This diff shows the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Files changed (98)
  1. sqlframe/__init__.py +0 -0
  2. sqlframe/_version.py +16 -0
  3. sqlframe/base/__init__.py +0 -0
  4. sqlframe/base/_typing.py +39 -0
  5. sqlframe/base/catalog.py +1163 -0
  6. sqlframe/base/column.py +388 -0
  7. sqlframe/base/dataframe.py +1519 -0
  8. sqlframe/base/decorators.py +51 -0
  9. sqlframe/base/exceptions.py +14 -0
  10. sqlframe/base/function_alternatives.py +1055 -0
  11. sqlframe/base/functions.py +1678 -0
  12. sqlframe/base/group.py +102 -0
  13. sqlframe/base/mixins/__init__.py +0 -0
  14. sqlframe/base/mixins/catalog_mixins.py +419 -0
  15. sqlframe/base/mixins/readwriter_mixins.py +118 -0
  16. sqlframe/base/normalize.py +84 -0
  17. sqlframe/base/operations.py +87 -0
  18. sqlframe/base/readerwriter.py +679 -0
  19. sqlframe/base/session.py +585 -0
  20. sqlframe/base/transforms.py +13 -0
  21. sqlframe/base/types.py +418 -0
  22. sqlframe/base/util.py +242 -0
  23. sqlframe/base/window.py +139 -0
  24. sqlframe/bigquery/__init__.py +23 -0
  25. sqlframe/bigquery/catalog.py +255 -0
  26. sqlframe/bigquery/column.py +1 -0
  27. sqlframe/bigquery/dataframe.py +54 -0
  28. sqlframe/bigquery/functions.py +378 -0
  29. sqlframe/bigquery/group.py +14 -0
  30. sqlframe/bigquery/readwriter.py +29 -0
  31. sqlframe/bigquery/session.py +89 -0
  32. sqlframe/bigquery/types.py +1 -0
  33. sqlframe/bigquery/window.py +1 -0
  34. sqlframe/duckdb/__init__.py +20 -0
  35. sqlframe/duckdb/catalog.py +108 -0
  36. sqlframe/duckdb/column.py +1 -0
  37. sqlframe/duckdb/dataframe.py +55 -0
  38. sqlframe/duckdb/functions.py +47 -0
  39. sqlframe/duckdb/group.py +14 -0
  40. sqlframe/duckdb/readwriter.py +111 -0
  41. sqlframe/duckdb/session.py +65 -0
  42. sqlframe/duckdb/types.py +1 -0
  43. sqlframe/duckdb/window.py +1 -0
  44. sqlframe/postgres/__init__.py +23 -0
  45. sqlframe/postgres/catalog.py +106 -0
  46. sqlframe/postgres/column.py +1 -0
  47. sqlframe/postgres/dataframe.py +54 -0
  48. sqlframe/postgres/functions.py +61 -0
  49. sqlframe/postgres/group.py +14 -0
  50. sqlframe/postgres/readwriter.py +29 -0
  51. sqlframe/postgres/session.py +68 -0
  52. sqlframe/postgres/types.py +1 -0
  53. sqlframe/postgres/window.py +1 -0
  54. sqlframe/redshift/__init__.py +23 -0
  55. sqlframe/redshift/catalog.py +127 -0
  56. sqlframe/redshift/column.py +1 -0
  57. sqlframe/redshift/dataframe.py +54 -0
  58. sqlframe/redshift/functions.py +18 -0
  59. sqlframe/redshift/group.py +14 -0
  60. sqlframe/redshift/readwriter.py +29 -0
  61. sqlframe/redshift/session.py +53 -0
  62. sqlframe/redshift/types.py +1 -0
  63. sqlframe/redshift/window.py +1 -0
  64. sqlframe/snowflake/__init__.py +26 -0
  65. sqlframe/snowflake/catalog.py +134 -0
  66. sqlframe/snowflake/column.py +1 -0
  67. sqlframe/snowflake/dataframe.py +54 -0
  68. sqlframe/snowflake/functions.py +18 -0
  69. sqlframe/snowflake/group.py +14 -0
  70. sqlframe/snowflake/readwriter.py +29 -0
  71. sqlframe/snowflake/session.py +53 -0
  72. sqlframe/snowflake/types.py +1 -0
  73. sqlframe/snowflake/window.py +1 -0
  74. sqlframe/spark/__init__.py +23 -0
  75. sqlframe/spark/catalog.py +1028 -0
  76. sqlframe/spark/column.py +1 -0
  77. sqlframe/spark/dataframe.py +54 -0
  78. sqlframe/spark/functions.py +22 -0
  79. sqlframe/spark/group.py +14 -0
  80. sqlframe/spark/readwriter.py +29 -0
  81. sqlframe/spark/session.py +90 -0
  82. sqlframe/spark/types.py +1 -0
  83. sqlframe/spark/window.py +1 -0
  84. sqlframe/standalone/__init__.py +26 -0
  85. sqlframe/standalone/catalog.py +13 -0
  86. sqlframe/standalone/column.py +1 -0
  87. sqlframe/standalone/dataframe.py +36 -0
  88. sqlframe/standalone/functions.py +1 -0
  89. sqlframe/standalone/group.py +14 -0
  90. sqlframe/standalone/readwriter.py +19 -0
  91. sqlframe/standalone/session.py +40 -0
  92. sqlframe/standalone/types.py +1 -0
  93. sqlframe/standalone/window.py +1 -0
  94. sqlframe-1.1.3.dist-info/LICENSE +21 -0
  95. sqlframe-1.1.3.dist-info/METADATA +172 -0
  96. sqlframe-1.1.3.dist-info/RECORD +98 -0
  97. sqlframe-1.1.3.dist-info/WHEEL +5 -0
  98. sqlframe-1.1.3.dist-info/top_level.txt +1 -0
sqlframe/base/session.py
@@ -0,0 +1,585 @@
+# This code is based on code from Apache Spark under the license found in the LICENSE file located in the 'sqlframe' folder.
+
+from __future__ import annotations
+
+import datetime
+import logging
+import sys
+import typing as t
+import uuid
+from collections import defaultdict
+from functools import cached_property
+
+import sqlglot
+from sqlglot import Dialect, exp
+from sqlglot.expressions import parse_identifier
+from sqlglot.helper import seq_get
+from sqlglot.optimizer import optimize
+from sqlglot.optimizer.normalize_identifiers import normalize_identifiers
+from sqlglot.optimizer.qualify_columns import (
+    quote_identifiers as quote_identifiers_func,
+)
+from sqlglot.schema import MappingSchema
+
+from sqlframe.base.catalog import _BaseCatalog
+from sqlframe.base.dataframe import _BaseDataFrame
+from sqlframe.base.readerwriter import _BaseDataFrameReader, _BaseDataFrameWriter
+from sqlframe.base.util import get_column_mapping_from_schema_input
+
+if sys.version_info >= (3, 11):
+    from typing import Self
+else:
+    from typing_extensions import Self
+
+if t.TYPE_CHECKING:
+    import pandas as pd
+    from _typeshed.dbapi import DBAPIConnection, DBAPICursor
+
+    from sqlframe.base._typing import ColumnLiterals, SchemaInput
+    from sqlframe.base.types import Row, StructType
+
+    class DBAPIConnectionWithPandas(DBAPIConnection):
+        def cursor(self) -> DBAPICursorWithPandas: ...
+
+    class DBAPICursorWithPandas(DBAPICursor):
+        def fetchdf(self) -> pd.DataFrame: ...
+
+    CONN = t.TypeVar("CONN", bound=DBAPIConnectionWithPandas)
+else:
+    CONN = t.TypeVar("CONN")
+
+
+logger = logging.getLogger(__name__)
+
+
+CATALOG = t.TypeVar("CATALOG", bound=_BaseCatalog)
+READER = t.TypeVar("READER", bound=_BaseDataFrameReader)
+WRITER = t.TypeVar("WRITER", bound=_BaseDataFrameWriter)
+DF = t.TypeVar("DF", bound=_BaseDataFrame)
+
+_MISSING = "MISSING"
+
+
+class _BaseSession(t.Generic[CATALOG, READER, WRITER, DF, CONN]):
+    _instance = None
+    _reader: t.Type[READER]
+    _writer: t.Type[WRITER]
+    _catalog: t.Type[CATALOG]
+    _df: t.Type[DF]
+
+    SANITIZE_COLUMN_NAMES = False
+    DEFAULT_TIME_FORMAT = "yyyy-MM-dd HH:mm:ss"
+
+    def __init__(
+        self,
+        conn: t.Optional[CONN] = None,
+        schema: t.Optional[MappingSchema] = None,
+        *args,
+        **kwargs,
+    ):
+        if not hasattr(self, "input_dialect"):
+            self.input_dialect: Dialect = Dialect.get_or_raise(self.builder.DEFAULT_INPUT_DIALECT)
+            self.output_dialect: Dialect = Dialect.get_or_raise(self.builder.DEFAULT_OUTPUT_DIALECT)
+            self.known_ids: t.Set[str] = set()
+            self.known_branch_ids: t.Set[str] = set()
+            self.known_sequence_ids: t.Set[str] = set()
+            self.name_to_sequence_id_mapping: t.Dict[str, t.List[str]] = defaultdict(list)
+            self.incrementing_id: int = 1
+            self._last_loaded_file: t.Optional[str] = None
+            self.temp_views: t.Dict[str, DF] = {}
+        if not self._has_connection or conn:
+            self._connection = conn
+        if not getattr(self, "schema", None) or schema:
+            self._schema = schema
+
+    @property
+    def read(self) -> READER:
+        return self._reader(self)
+
+    @cached_property
+    def catalog(self) -> CATALOG:
+        return self._catalog(self, self._schema)
+
+    @property
+    def _conn(self) -> CONN:
+        if self._connection is None:
+            raise ValueError("Connection not set")
+        return self._connection
+
+    @cached_property
+    def _cur(self) -> DBAPICursorWithPandas:
+        return self._conn.cursor()
+
+    def _sanitize_column_name(self, name: str) -> str:
+        if self.SANITIZE_COLUMN_NAMES:
+            return name.replace("(", "_").replace(")", "_")
+        return name
+
+    def table(self, tableName: str) -> DF:
+        return self.read.table(tableName)
+
+    def _create_df(self, *args, **kwargs) -> DF:
+        return self._df(self, *args, **kwargs)
+
+    def __new__(cls, *args, **kwargs):
+        if _BaseSession._instance is None:
+            _BaseSession._instance = super().__new__(cls)
+        return _BaseSession._instance
+
+    @property
+    def _has_connection(self) -> bool:
+        return hasattr(self, "_connection") and bool(self._connection)
+
+    def range(self, *args):
+        start = 0
+        step = 1
+        numPartitions = None
+        if len(args) == 1:
+            end = args[0]
+        elif len(args) == 2:
+            start, end = args
+        elif len(args) == 3:
+            start, end, step = args
+        elif len(args) == 4:
+            start, end, step, numPartitions = args
+        else:
+            raise ValueError(
+                "range() takes 1 to 4 positional arguments but {} were given".format(len(args))
+            )
+        if numPartitions is not None:
+            logger.warning("numPartitions is not supported")
+        return self.createDataFrame([[x] for x in range(start, end, step)], schema={"id": "long"})
+
+    def createDataFrame(
+        self,
+        data: t.Sequence[
+            t.Union[
+                t.Dict[str, ColumnLiterals],
+                t.List[ColumnLiterals],
+                t.Tuple[ColumnLiterals, ...],
+                ColumnLiterals,
+            ]
+        ],
+        schema: t.Optional[SchemaInput] = None,
+        samplingRatio: t.Optional[float] = None,
+        verifySchema: bool = False,
+    ) -> DF:
+        from sqlframe.base import functions as F
+        from sqlframe.base.types import Row, StructType
+
+        if samplingRatio is not None or verifySchema:
+            raise NotImplementedError("Sampling Ratio and Verify Schema are not supported")
+        if (
+            schema is not None
+            and not isinstance(schema, dict)
+            and (
+                not isinstance(schema, (StructType, str, list, tuple))
+                or (isinstance(schema, (list, tuple)) and not isinstance(schema[0], str))
+            )
+        ):
+            raise NotImplementedError("Only schema of either list or string of list supported")
+
+        column_mapping: t.Mapping[str, t.Optional[exp.DataType]]
+        if schema is not None:
+            column_mapping = get_column_mapping_from_schema_input(
+                schema, dialect=self.input_dialect
+            )
+        elif data:
+            if isinstance(data[0], Row):
+                column_mapping = {col_name.strip(): None for col_name in data[0].__fields__}
+            elif isinstance(data[0], dict):
+                column_mapping = {col_name.strip(): None for col_name in data[0]}
+            else:
+                column_mapping = {f"_{i}": None for i in range(1, len(data[0]) + 1)}  # type: ignore
+        else:
+            column_mapping = {}
+
+        column_mapping = {
+            normalize_identifiers(k, self.input_dialect).sql(dialect=self.input_dialect): v
+            for k, v in column_mapping.items()
+        }
+        empty_df = not data
+        rows = [[None] * len(column_mapping)] if empty_df else list(data)  # type: ignore
+
+        def get_default_data_type(value: t.Any) -> t.Optional[str]:
+            if isinstance(value, Row):
+                row_types = []
+                for row_name, row_dtype in zip(value.__fields__, value):
+                    default_type = get_default_data_type(row_dtype)
+                    if not default_type:
+                        continue
+                    row_types.append((row_name, default_type))
+                return "struct<" + ", ".join(f"{k}: {v}" for (k, v) in row_types) + ">"
+            elif isinstance(value, dict):
+                sample_row = seq_get(list(value.items()), 0)
+                if not sample_row:
+                    return None
+                key, value = sample_row
+                default_key = get_default_data_type(key)
+                default_value = get_default_data_type(value)
+                if not default_key or not default_value:
+                    return None
+                return f"map<{default_key}, {default_value}>"
+            elif isinstance(value, (list, set, tuple)):
+                if not value:
+                    return None
+                default_type = get_default_data_type(next(iter(value)))
+                if not default_type:
+                    return None
+                return f"array<{default_type}>"
+            elif isinstance(value, bool):
+                return "boolean"
+            elif isinstance(value, bytes):
+                return "binary"
+            elif isinstance(value, int):
+                return "bigint"
+            elif isinstance(value, float):
+                return "double"
+            elif isinstance(value, datetime.datetime):
+                if value.tzinfo:
+                    return "timestamptz"
+                return "timestamp"
+            elif isinstance(value, datetime.date):
+                return "date"
+            elif isinstance(value, str):
+                return "string"
+            return None
+
+        updated_mapping: t.Dict[str, t.Optional[exp.DataType]] = {}
+        sample_row = rows[0]
+        for i, (name, dtype) in enumerate(column_mapping.items()):
+            if dtype is not None:
+                updated_mapping[name] = dtype
+                continue
+            if isinstance(sample_row, Row):
+                sample_row = sample_row.asDict()
+            if isinstance(sample_row, dict):
+                default_data_type = get_default_data_type(sample_row[name])
+                updated_mapping[name] = (
+                    exp.DataType.build(default_data_type, dialect="spark")
+                    if default_data_type
+                    else None
+                )
+            else:
+                default_data_type = get_default_data_type(sample_row[i])
+                updated_mapping[name] = (
+                    exp.DataType.build(default_data_type, dialect="spark")
+                    if default_data_type
+                    else None
+                )
+        column_mapping = updated_mapping
+        data_expressions = []
+        for row in rows:
+            if isinstance(row, (list, tuple, dict)):
+                if not row:
+                    data_expressions.append(exp.tuple_(exp.Null()))
+                    continue
+                if isinstance(row, Row):
+                    row = row.asDict()
+                if isinstance(row, dict):
+                    row = row.values()  # type: ignore
+                data_expressions.append(exp.tuple_(*[F.lit(x).expression for x in row]))
+            else:
+                data_expressions.append(exp.tuple_(*[F.lit(row).expression]))
+
+        if column_mapping:
+            sel_columns = [
+                (
+                    F.col(name).cast(data_type).alias(name).expression
+                    if data_type is not None
+                    else F.col(name).expression
+                )
+                for name, data_type in column_mapping.items()
+            ]
+        else:
+            sel_columns = [F.lit(None).expression]
+
+        select_kwargs = {
+            "expressions": sel_columns,
+            "from": exp.From(
+                this=exp.Values(
+                    expressions=data_expressions,
+                    alias=exp.TableAlias(
+                        this=exp.to_identifier(self._auto_incrementing_name),
+                        columns=[
+                            exp.parse_identifier(col_name, dialect=self.input_dialect)
+                            for col_name in column_mapping
+                        ],
+                    ),
+                ),
+            ),
+        }
+
+        sel_expression = exp.Select(**select_kwargs)
+        if empty_df:
+            sel_expression = sel_expression.where(exp.false())
+        return self._create_df(sel_expression)
+
+    def sql(self, sqlQuery: t.Union[str, exp.Expression], optimize: bool = True) -> DF:
+        expression = (
+            sqlglot.parse_one(sqlQuery, read=self.input_dialect)
+            if isinstance(sqlQuery, str)
+            else sqlQuery
+        )
+        if optimize:
+            expression = self._optimize(expression)
+        if self.temp_views:
+            replacement_mapping = {}
+            for table in expression.find_all(exp.Table):
+                if not (df := self.temp_views.get(table.name)):
+                    continue
+                expression_ctes = {cte.alias_or_name: cte for cte in expression.ctes}  # type: ignore
+                replacement_mapping[table] = df.expression.ctes[-1].alias_or_name
+                ctes_to_add = []
+                for cte in df.expression.ctes:
+                    if cte.alias_or_name not in expression_ctes:
+                        ctes_to_add.append(cte)
+                expression.set("with", exp.With(expressions=expression.ctes + ctes_to_add))  # type: ignore
+
+            def replace_temp_view_name_with_cte(node: exp.Expression) -> exp.Expression:
+                if isinstance(node, exp.Table):
+                    if node in replacement_mapping:
+                        node.set("this", exp.to_identifier(replacement_mapping[node]))
+                return node
+
+            if replacement_mapping:
+                expression = expression.transform(replace_temp_view_name_with_cte)
+
+        if isinstance(expression, exp.Select):
+            df = self._create_df(expression)
+            df = df._convert_leaf_to_cte()
+        elif isinstance(expression, (exp.Create, exp.Insert)):
+            select_expression = expression.expression.copy()
+            if isinstance(expression, exp.Insert):
+                select_expression.set("with", expression.args.get("with"))
+                expression.set("with", None)
+            del expression.args["expression"]
+            df = self._create_df(select_expression, output_expression_container=expression)  # type: ignore
+            df = df._convert_leaf_to_cte()
+        else:
+            raise ValueError(
+                "Unknown expression type provided in the SQL. Please create an issue with the SQL."
+            )
+        return df
+
+    @property
+    def _auto_incrementing_name(self) -> str:
+        name = f"a{self.incrementing_id}"
+        self.incrementing_id += 1
+        return name
+
+    @property
+    def _random_branch_id(self) -> str:
+        id = self._random_id
+        self.known_branch_ids.add(id)
+        return id
+
+    @property
+    def _random_sequence_id(self):
+        id = self._random_id
+        self.known_sequence_ids.add(id)
+        return id
+
+    @property
+    def _random_id(self) -> str:
+        id = "r" + uuid.uuid4().hex
+        normalized_id = self._normalize_string(id)
+        self.known_ids.add(normalized_id)
+        return normalized_id
+
+    @property
+    def _join_hint_names(self) -> t.Set[str]:
+        return {"BROADCAST", "MERGE", "SHUFFLE_HASH", "SHUFFLE_REPLICATE_NL"}
+
+    def _normalize_string(self, value: str) -> str:
+        expression = parse_identifier(value, dialect=self.input_dialect)
+        normalize_identifiers(expression, dialect=self.input_dialect)
+        return expression.sql(dialect=self.input_dialect)
+
+    def _add_alias_to_mapping(self, name: str, sequence_id: str):
+        self.name_to_sequence_id_mapping[self._normalize_string(name)].append(sequence_id)
+
+    def _to_sql(self, sql: t.Union[str, exp.Expression], *, quote_identifiers: bool = True) -> str:
+        if isinstance(sql, exp.Expression):
+            expression = sql.copy()
+            if quote_identifiers:
+                normalize_identifiers(expression, dialect=self.input_dialect)
+                quote_identifiers_func(expression, dialect=self.input_dialect)
+            sql = expression.sql(dialect=self.output_dialect)
+        return t.cast(str, sql)
+
+    def _optimize(
+        self, expression: exp.Expression, dialect: t.Optional[Dialect] = None
+    ) -> exp.Expression:
+        dialect = dialect or self.output_dialect
+        quote_identifiers_func(expression, dialect=dialect)
+        return optimize(expression, dialect=dialect, schema=self.catalog._schema)
+
+    def _execute(
+        self, sql: t.Union[str, exp.Expression], *, quote_identifiers: bool = True
+    ) -> None:
+        self._cur.execute(self._to_sql(sql, quote_identifiers=quote_identifiers))
+
+    @classmethod
+    def _try_get_map(cls, value: t.Any) -> t.Optional[t.Dict[str, t.Any]]:
+        return None if not isinstance(value, dict) else value
+
+    @classmethod
+    def _to_value(cls, value: t.Any) -> t.Any:
+        if (map_value := cls._try_get_map(value)) is not None:
+            return map_value
+        elif isinstance(value, dict):
+            return cls._to_row(list(value.keys()), list(value.values()))
+        elif isinstance(value, (list, set, tuple)) and value:
+            return [cls._to_value(x) for x in value]
+        return value
+
+    @classmethod
+    def _to_row(cls, columns: t.List[str], values: t.Iterable[t.Any]) -> Row:
+        from sqlframe.base.types import Row
+
+        converted_values = []
+        for value in values:
+            converted_values.append(cls._to_value(value))
+        return Row(**dict(zip(columns, converted_values)))
+
+    def _fetch_rows(
+        self, sql: t.Union[str, exp.Expression], *, quote_identifiers: bool = True
+    ) -> t.List[Row]:
+        from sqlframe.base.types import Row
+
+        def _dict_to_row(row: t.Dict[str, t.Any]) -> Row:
+            for key, value in row.items():
+                if isinstance(value, dict):
+                    row[key] = _dict_to_row(value)
+            return Row(**row)
+
+        self._execute(sql, quote_identifiers=quote_identifiers)
+        result = self._cur.fetchall()
+        if not self._cur.description:
+            return []
+        columns = [x[0] for x in self._cur.description]
+        return [self._to_row(columns, row) for row in result]
+
+    def _fetchdf(
+        self, sql: t.Union[str, exp.Expression], *, quote_identifiers: bool = True
+    ) -> pd.DataFrame:
+        from pandas.io.sql import read_sql_query
+
+        return read_sql_query(self._to_sql(sql, quote_identifiers=quote_identifiers), self._conn)
+
+    @property
+    def _is_standalone(self) -> bool:
+        from sqlframe.standalone.session import StandaloneSession
+
+        return isinstance(self, StandaloneSession)
+
+    @property
+    def _is_duckdb(self) -> bool:
+        from sqlframe.duckdb.session import DuckDBSession
+
+        return isinstance(self, DuckDBSession)
+
+    @property
+    def _is_postgres(self) -> bool:
+        from sqlframe.postgres.session import PostgresSession
+
+        return isinstance(self, PostgresSession)
+
+    @property
+    def _is_spark(self) -> bool:
+        from sqlframe.spark.session import SparkSession
+
+        return isinstance(self, SparkSession)
+
+    @property
+    def _is_bigquery(self) -> bool:
+        from sqlframe.bigquery.session import BigQuerySession
+
+        return isinstance(self, BigQuerySession)
+
+    @property
+    def _is_redshift(self) -> bool:
+        from sqlframe.redshift.session import RedshiftSession
+
+        return isinstance(self, RedshiftSession)
+
+    @property
+    def _is_snowflake(self) -> bool:
+        from sqlframe.snowflake.session import SnowflakeSession
+
+        return isinstance(self, SnowflakeSession)
+
+    class Builder:
+        SQLFRAME_INPUT_DIALECT_KEY = "sqlframe.input.dialect"
+        SQLFRAME_OUTPUT_DIALECT_KEY = "sqlframe.output.dialect"
+        SQLFRAME_CONN_KEY = "sqlframe.conn"
+        SQLFRAME_SCHEMA_KEY = "sqlframe.schema"
+        DEFAULT_INPUT_DIALECT = "spark"
+        DEFAULT_OUTPUT_DIALECT = "spark"
+
+        def __init__(self):
+            self.input_dialect = self.DEFAULT_INPUT_DIALECT
+            self.output_dialect = self.DEFAULT_OUTPUT_DIALECT
+            self._conn = None
+            self._session_kwargs = {}
+
+        def __getattr__(self, item) -> Self:
+            return self
+
+        def __call__(self, *args, **kwargs):
+            return self
+
+        @property
+        def session(self) -> _BaseSession:
+            return _BaseSession(**self._session_kwargs)
+
+        def getOrCreate(self) -> _BaseSession:
+            self._set_session_properties()
+            return self.session
+
+        def _set_config(
+            self,
+            key: t.Optional[str] = None,
+            value: t.Optional[t.Any] = None,
+            *,
+            map: t.Optional[t.Dict[str, t.Any]] = None,
+        ) -> None:
+            if value is not None:
+                if key == self.SQLFRAME_INPUT_DIALECT_KEY:
+                    self.input_dialect = value
+                elif key == self.SQLFRAME_OUTPUT_DIALECT_KEY:
+                    self.output_dialect = value
+                elif key == self.SQLFRAME_CONN_KEY:
+                    self._session_kwargs["conn"] = value
+                elif key == self.SQLFRAME_SCHEMA_KEY:
+                    self._session_kwargs["schema"] = value
+                else:
+                    self._session_kwargs[key] = value
+            if map:
+                if self.SQLFRAME_INPUT_DIALECT_KEY in map:
+                    self.input_dialect = map[self.SQLFRAME_INPUT_DIALECT_KEY]
+                if self.SQLFRAME_OUTPUT_DIALECT_KEY in map:
+                    self.output_dialect = map[self.SQLFRAME_OUTPUT_DIALECT_KEY]
+                if self.SQLFRAME_CONN_KEY in map:
+                    self._session_kwargs["conn"] = map[self.SQLFRAME_CONN_KEY]
+                if self.SQLFRAME_SCHEMA_KEY in map:
+                    self._session_kwargs["schema"] = map[self.SQLFRAME_SCHEMA_KEY]
+
+        def config(
+            self,
+            key: t.Optional[str] = None,
+            value: t.Optional[t.Any] = None,
+            *,
+            map: t.Optional[t.Dict[str, t.Any]] = None,
+        ) -> Self:
+            self._set_config(key, value, map=map)
+            return self
+
+        def _set_session_properties(self) -> None:
+            self.session.input_dialect = Dialect.get_or_raise(self.input_dialect)
+            self.session.output_dialect = Dialect.get_or_raise(self.output_dialect)
+            if not self.session._connection:
+                self.session._connection = self._conn
+
+    builder = Builder()
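For orientation, the pieces in this file compose as follows: a concrete session (for example DuckDBSession, whose module appears in the file list above and which subclasses this base with its own Builder) is obtained through the builder, and createDataFrame turns plain Python rows into a sqlglot SELECT ... FROM (VALUES ...) expression. A minimal sketch, assuming a DuckDB-backed session and an in-memory duckdb connection; everything outside this file is an assumption:

    import duckdb

    from sqlframe.duckdb.session import DuckDBSession

    # Builder.config() routes the "sqlframe.conn" key into the session's
    # constructor kwargs via _set_config(); getOrCreate() then applies the
    # configured dialects and returns the singleton created by __new__.
    session = (
        DuckDBSession.builder.config("sqlframe.conn", duckdb.connect())
        .config("sqlframe.input.dialect", "spark")
        .config("sqlframe.output.dialect", "duckdb")
        .getOrCreate()
    )

    # createDataFrame() infers a sqlglot type per column from the first row
    # (get_default_data_type) and casts each projected column accordingly.
    df = session.createDataFrame([{"id": 1, "name": "Jack"}])

    # range() is a thin wrapper over createDataFrame() with one "long" column.
    ids = session.range(1, 5)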
sqlframe/base/transforms.py
@@ -0,0 +1,13 @@
+# This code is based on code from Apache Spark under the license found in the LICENSE file located in the 'sqlframe' folder.
+
+import typing as t
+
+from sqlglot import expressions as exp
+
+
+def replace_id_value(
+    node: exp.Expression, replacement_mapping: t.Dict[exp.Identifier, exp.Identifier]
+) -> exp.Expression:
+    if isinstance(node, exp.Identifier) and node in replacement_mapping:
+        node = node.replace(replacement_mapping[node].copy())
+    return node
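replace_id_value is shaped to be driven by sqlglot's Expression.transform, which visits every node in a tree and forwards extra positional arguments to the callback. A hedged sketch of that usage; the query and the identifier mapping here are illustrative, not taken from the package:

    import typing as t

    from sqlglot import exp, parse_one

    from sqlframe.base.transforms import replace_id_value

    # Unquoted sqlglot identifiers hash by their (case-insensitive) name, so
    # the parsed identifier for "t1" matches the key constructed here.
    mapping: t.Dict[exp.Identifier, exp.Identifier] = {
        exp.to_identifier("t1"): exp.to_identifier("t2"),
    }

    tree = parse_one("SELECT a FROM t1")
    renamed = tree.transform(replace_id_value, mapping)
    print(renamed.sql())  # SELECT a FROM t2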