sqlframe 1.2.0__py3-none-any.whl → 1.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sqlframe/_version.py +2 -2
- sqlframe/base/column.py +7 -3
- sqlframe/base/dataframe.py +94 -7
- sqlframe/base/decorators.py +17 -15
- sqlframe/base/mixins/catalog_mixins.py +1 -1
- sqlframe/base/mixins/readwriter_mixins.py +4 -3
- sqlframe/base/readerwriter.py +3 -0
- sqlframe/base/session.py +6 -9
- sqlframe/base/util.py +38 -1
- sqlframe/snowflake/catalog.py +3 -1
- sqlframe/snowflake/session.py +31 -0
- sqlframe/spark/session.py +3 -1
- {sqlframe-1.2.0.dist-info → sqlframe-1.4.0.dist-info}/METADATA +18 -11
- {sqlframe-1.2.0.dist-info → sqlframe-1.4.0.dist-info}/RECORD +17 -17
- {sqlframe-1.2.0.dist-info → sqlframe-1.4.0.dist-info}/LICENSE +0 -0
- {sqlframe-1.2.0.dist-info → sqlframe-1.4.0.dist-info}/WHEEL +0 -0
- {sqlframe-1.2.0.dist-info → sqlframe-1.4.0.dist-info}/top_level.txt +0 -0
sqlframe/_version.py
CHANGED
sqlframe/base/column.py
CHANGED
|
@@ -9,9 +9,11 @@ import typing as t
|
|
|
9
9
|
import sqlglot
|
|
10
10
|
from sqlglot import expressions as exp
|
|
11
11
|
from sqlglot.helper import flatten, is_iterable
|
|
12
|
+
from sqlglot.optimizer.normalize_identifiers import normalize_identifiers
|
|
12
13
|
|
|
14
|
+
from sqlframe.base.decorators import normalize
|
|
13
15
|
from sqlframe.base.types import DataType
|
|
14
|
-
from sqlframe.base.util import get_func_from_session
|
|
16
|
+
from sqlframe.base.util import get_func_from_session, quote_preserving_alias_or_name
|
|
15
17
|
|
|
16
18
|
if t.TYPE_CHECKING:
|
|
17
19
|
from sqlframe.base._typing import ColumnOrLiteral, ColumnOrName
|
|
@@ -237,7 +239,7 @@ class Column:
|
|
|
237
239
|
|
|
238
240
|
@property
|
|
239
241
|
def alias_or_name(self) -> str:
|
|
240
|
-
return self.expression
|
|
242
|
+
return quote_preserving_alias_or_name(self.expression) # type: ignore
|
|
241
243
|
|
|
242
244
|
@classmethod
|
|
243
245
|
def ensure_literal(cls, value) -> Column:
|
|
@@ -266,7 +268,9 @@ class Column:
|
|
|
266
268
|
from sqlframe.base.session import _BaseSession
|
|
267
269
|
|
|
268
270
|
dialect = _BaseSession().input_dialect
|
|
269
|
-
alias: exp.Expression =
|
|
271
|
+
alias: exp.Expression = normalize_identifiers(
|
|
272
|
+
exp.parse_identifier(name, dialect=dialect), dialect=dialect
|
|
273
|
+
)
|
|
270
274
|
new_expression = exp.Alias(
|
|
271
275
|
this=self.column_expression,
|
|
272
276
|
alias=alias.this if isinstance(alias, exp.Column) else alias,
|
sqlframe/base/dataframe.py
CHANGED
|
@@ -2,26 +2,34 @@
|
|
|
2
2
|
|
|
3
3
|
from __future__ import annotations
|
|
4
4
|
|
|
5
|
+
import enum
|
|
5
6
|
import functools
|
|
6
7
|
import itertools
|
|
8
|
+
import json
|
|
7
9
|
import logging
|
|
8
10
|
import sys
|
|
9
11
|
import typing as t
|
|
10
12
|
import zlib
|
|
11
13
|
from copy import copy
|
|
14
|
+
from dataclasses import dataclass
|
|
12
15
|
|
|
13
16
|
import sqlglot
|
|
14
17
|
from prettytable import PrettyTable
|
|
15
18
|
from sqlglot import Dialect
|
|
16
19
|
from sqlglot import expressions as exp
|
|
17
20
|
from sqlglot.helper import ensure_list, object_to_dict, seq_get
|
|
21
|
+
from sqlglot.optimizer.pushdown_projections import pushdown_projections
|
|
22
|
+
from sqlglot.optimizer.qualify import qualify
|
|
18
23
|
from sqlglot.optimizer.qualify_columns import quote_identifiers
|
|
19
24
|
|
|
25
|
+
from sqlframe.base.decorators import normalize
|
|
20
26
|
from sqlframe.base.operations import Operation, operation
|
|
21
27
|
from sqlframe.base.transforms import replace_id_value
|
|
22
28
|
from sqlframe.base.util import (
|
|
23
29
|
get_func_from_session,
|
|
24
30
|
get_tables_from_expression_with_join,
|
|
31
|
+
quote_preserving_alias_or_name,
|
|
32
|
+
verify_openai_installed,
|
|
25
33
|
)
|
|
26
34
|
|
|
27
35
|
if sys.version_info >= (3, 11):
|
|
@@ -70,6 +78,46 @@ JOIN_HINTS = {
|
|
|
70
78
|
DF = t.TypeVar("DF", bound="_BaseDataFrame")
|
|
71
79
|
|
|
72
80
|
|
|
81
|
+
class OpenAIMode(enum.Enum):
|
|
82
|
+
CTE_ONLY = "cte_only"
|
|
83
|
+
FULL = "full"
|
|
84
|
+
|
|
85
|
+
@property
|
|
86
|
+
def is_cte_only(self) -> bool:
|
|
87
|
+
return self == OpenAIMode.CTE_ONLY
|
|
88
|
+
|
|
89
|
+
@property
|
|
90
|
+
def is_full(self) -> bool:
|
|
91
|
+
return self == OpenAIMode.FULL
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
@dataclass
|
|
95
|
+
class OpenAIConfig:
|
|
96
|
+
mode: OpenAIMode = OpenAIMode.CTE_ONLY
|
|
97
|
+
model: str = "gpt-4o"
|
|
98
|
+
prompt_override: t.Optional[str] = None
|
|
99
|
+
|
|
100
|
+
@classmethod
|
|
101
|
+
def from_dict(cls, config: t.Dict[str, t.Any]) -> OpenAIConfig:
|
|
102
|
+
if "mode" in config:
|
|
103
|
+
config["mode"] = OpenAIMode(config["mode"].lower())
|
|
104
|
+
return cls(**config)
|
|
105
|
+
|
|
106
|
+
def get_prompt(self, dialect: Dialect) -> str:
|
|
107
|
+
if self.prompt_override:
|
|
108
|
+
return self.prompt_override
|
|
109
|
+
if self.mode.is_cte_only:
|
|
110
|
+
return f"You are a backend tool that creates unique CTE alias names match what a human would write and in snake case. You respond without code blocks and only a json payload with the key being the CTE name that is being replaced and the value being the new CTE human readable name."
|
|
111
|
+
return f"""
|
|
112
|
+
You are a backend tool that converts correct {dialect} SQL to simplified and more human readable version.
|
|
113
|
+
You respond without code block with rewritten {dialect} SQL.
|
|
114
|
+
You don't change any column names in the final select because the user expects those to remain the same.
|
|
115
|
+
You make unique CTE alias names match what a human would write and in snake case.
|
|
116
|
+
You improve formatting with spacing and line-breaks.
|
|
117
|
+
You remove redundant parenthesis and aliases.
|
|
118
|
+
When remove extra quotes, make sure to keep quotes around words that could be reserved words"""
|
|
119
|
+
|
|
120
|
+
|
|
73
121
|
class _BaseDataFrameNaFunctions(t.Generic[DF]):
|
|
74
122
|
def __init__(self, df: DF):
|
|
75
123
|
self.df = df
|
|
@@ -410,7 +458,7 @@ class _BaseDataFrame(t.Generic[SESSION, WRITER, NA, STAT, GROUP_DATA]):
|
|
|
410
458
|
|
|
411
459
|
outer_select = item.find(exp.Select)
|
|
412
460
|
if outer_select:
|
|
413
|
-
return [col(x
|
|
461
|
+
return [col(quote_preserving_alias_or_name(x)) for x in outer_select.expressions]
|
|
414
462
|
return []
|
|
415
463
|
|
|
416
464
|
def _create_hash_from_expression(self, expression: exp.Expression) -> str:
|
|
@@ -471,6 +519,7 @@ class _BaseDataFrame(t.Generic[SESSION, WRITER, NA, STAT, GROUP_DATA]):
|
|
|
471
519
|
dialect: DialectType = None,
|
|
472
520
|
optimize: bool = True,
|
|
473
521
|
pretty: bool = True,
|
|
522
|
+
openai_config: t.Optional[t.Union[t.Dict[str, t.Any], OpenAIConfig]] = None,
|
|
474
523
|
as_list: bool = False,
|
|
475
524
|
**kwargs,
|
|
476
525
|
) -> t.Union[str, t.List[str]]:
|
|
@@ -480,6 +529,11 @@ class _BaseDataFrame(t.Generic[SESSION, WRITER, NA, STAT, GROUP_DATA]):
|
|
|
480
529
|
select_expressions = df._get_select_expressions()
|
|
481
530
|
output_expressions: t.List[t.Union[exp.Select, exp.Cache, exp.Drop]] = []
|
|
482
531
|
replacement_mapping: t.Dict[exp.Identifier, exp.Identifier] = {}
|
|
532
|
+
openai_config = (
|
|
533
|
+
OpenAIConfig.from_dict(openai_config)
|
|
534
|
+
if openai_config is not None and isinstance(openai_config, dict)
|
|
535
|
+
else openai_config
|
|
536
|
+
)
|
|
483
537
|
|
|
484
538
|
for expression_type, select_expression in select_expressions:
|
|
485
539
|
select_expression = select_expression.transform(
|
|
@@ -490,6 +544,9 @@ class _BaseDataFrame(t.Generic[SESSION, WRITER, NA, STAT, GROUP_DATA]):
|
|
|
490
544
|
select_expression = t.cast(
|
|
491
545
|
exp.Select, self.session._optimize(select_expression, dialect=dialect)
|
|
492
546
|
)
|
|
547
|
+
elif openai_config:
|
|
548
|
+
qualify(select_expression, dialect=dialect, schema=self.session.catalog._schema)
|
|
549
|
+
pushdown_projections(select_expression, schema=self.session.catalog._schema)
|
|
493
550
|
|
|
494
551
|
select_expression = df._replace_cte_names_with_hashes(select_expression)
|
|
495
552
|
|
|
@@ -505,7 +562,9 @@ class _BaseDataFrame(t.Generic[SESSION, WRITER, NA, STAT, GROUP_DATA]):
|
|
|
505
562
|
self.session.catalog.add_table(
|
|
506
563
|
cache_table_name,
|
|
507
564
|
{
|
|
508
|
-
expression
|
|
565
|
+
quote_preserving_alias_or_name(expression): expression.type.sql(
|
|
566
|
+
dialect=dialect
|
|
567
|
+
)
|
|
509
568
|
if expression.type
|
|
510
569
|
else "UNKNOWN"
|
|
511
570
|
for expression in select_expression.expressions
|
|
@@ -541,10 +600,37 @@ class _BaseDataFrame(t.Generic[SESSION, WRITER, NA, STAT, GROUP_DATA]):
|
|
|
541
600
|
|
|
542
601
|
output_expressions.append(expression)
|
|
543
602
|
|
|
544
|
-
results = [
|
|
545
|
-
|
|
546
|
-
|
|
547
|
-
|
|
603
|
+
results = []
|
|
604
|
+
for expression in output_expressions:
|
|
605
|
+
sql = expression.sql(dialect=dialect, pretty=pretty, **kwargs)
|
|
606
|
+
if openai_config:
|
|
607
|
+
assert isinstance(openai_config, OpenAIConfig)
|
|
608
|
+
verify_openai_installed()
|
|
609
|
+
from openai import OpenAI
|
|
610
|
+
|
|
611
|
+
client = OpenAI()
|
|
612
|
+
chat_completed = client.chat.completions.create(
|
|
613
|
+
messages=[
|
|
614
|
+
{ # type: ignore
|
|
615
|
+
"role": "system",
|
|
616
|
+
"content": openai_config.get_prompt(dialect),
|
|
617
|
+
},
|
|
618
|
+
{
|
|
619
|
+
"role": "user",
|
|
620
|
+
"content": sql,
|
|
621
|
+
},
|
|
622
|
+
],
|
|
623
|
+
model=openai_config.model,
|
|
624
|
+
)
|
|
625
|
+
assert chat_completed.choices[0].message.content is not None
|
|
626
|
+
if openai_config.mode.is_cte_only:
|
|
627
|
+
cte_replacement_mapping = json.loads(chat_completed.choices[0].message.content)
|
|
628
|
+
for old_name, new_name in cte_replacement_mapping.items():
|
|
629
|
+
sql = sql.replace(old_name, new_name)
|
|
630
|
+
else:
|
|
631
|
+
sql = chat_completed.choices[0].message.content
|
|
632
|
+
results.append(sql)
|
|
633
|
+
|
|
548
634
|
if as_list:
|
|
549
635
|
return results
|
|
550
636
|
return ";\n".join(results)
|
|
@@ -688,7 +774,7 @@ class _BaseDataFrame(t.Generic[SESSION, WRITER, NA, STAT, GROUP_DATA]):
|
|
|
688
774
|
join_expression = self._add_ctes_to_expression(join_expression, other_df.expression.ctes)
|
|
689
775
|
self_columns = self._get_outer_select_columns(join_expression)
|
|
690
776
|
other_columns = self._get_outer_select_columns(other_df.expression)
|
|
691
|
-
join_columns = self.
|
|
777
|
+
join_columns = self._ensure_and_normalize_cols(on)
|
|
692
778
|
# Determines the join clause and select columns to be used passed on what type of columns were provided for
|
|
693
779
|
# the join. The columns returned changes based on how the on expression is provided.
|
|
694
780
|
if how != "cross":
|
|
@@ -1324,6 +1410,7 @@ class _BaseDataFrame(t.Generic[SESSION, WRITER, NA, STAT, GROUP_DATA]):
|
|
|
1324
1410
|
assert sqls[-1] is not None
|
|
1325
1411
|
return self.session._fetchdf(sqls[-1])
|
|
1326
1412
|
|
|
1413
|
+
@normalize("name")
|
|
1327
1414
|
def createOrReplaceTempView(self, name: str) -> None:
|
|
1328
1415
|
self.session.temp_views[name] = self.copy()._convert_leaf_to_cte()
|
|
1329
1416
|
|
sqlframe/base/decorators.py
CHANGED
|
@@ -10,31 +10,33 @@ from sqlglot.optimizer.normalize_identifiers import normalize_identifiers
|
|
|
10
10
|
if t.TYPE_CHECKING:
|
|
11
11
|
from sqlframe.base.catalog import _BaseCatalog
|
|
12
12
|
|
|
13
|
+
CALLING_CLASS = t.TypeVar("CALLING_CLASS")
|
|
13
14
|
|
|
14
|
-
|
|
15
|
+
|
|
16
|
+
def normalize(normalize_kwargs: t.Union[str, t.List[str]]) -> t.Callable[[t.Callable], t.Callable]:
|
|
15
17
|
"""
|
|
16
|
-
Decorator used
|
|
17
|
-
ordered Operation enums. This is used to determine which operations should be performed on a CTE vs.
|
|
18
|
-
included with the previous operation.
|
|
19
|
-
|
|
20
|
-
Ex: After a user does a join we want to allow them to select which columns for the different
|
|
21
|
-
tables that they want to carry through to the following operation. If we put that join in
|
|
22
|
-
a CTE preemptively then the user would not have a chance to select which column they want
|
|
23
|
-
in cases where there is overlap in names.
|
|
18
|
+
Decorator used to normalize identifiers in the kwargs of a method.
|
|
24
19
|
"""
|
|
25
20
|
|
|
26
21
|
def decorator(func: t.Callable) -> t.Callable:
|
|
27
22
|
@functools.wraps(func)
|
|
28
|
-
def wrapper(self:
|
|
23
|
+
def wrapper(self: CALLING_CLASS, *args, **kwargs) -> CALLING_CLASS:
|
|
24
|
+
from sqlframe.base.session import _BaseSession
|
|
25
|
+
|
|
26
|
+
input_dialect = _BaseSession().input_dialect
|
|
29
27
|
kwargs.update(dict(zip(func.__code__.co_varnames[1:], args)))
|
|
30
|
-
for kwarg in normalize_kwargs:
|
|
28
|
+
for kwarg in ensure_list(normalize_kwargs):
|
|
31
29
|
if kwarg in kwargs:
|
|
32
30
|
value = kwargs.get(kwarg)
|
|
33
31
|
if value:
|
|
34
|
-
expression =
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
|
|
32
|
+
expression = (
|
|
33
|
+
parse_one(value, dialect=input_dialect)
|
|
34
|
+
if isinstance(value, str)
|
|
35
|
+
else value
|
|
36
|
+
)
|
|
37
|
+
kwargs[kwarg] = normalize_identifiers(expression, input_dialect).sql(
|
|
38
|
+
dialect=input_dialect
|
|
39
|
+
)
|
|
38
40
|
return func(self, **kwargs)
|
|
39
41
|
|
|
40
42
|
wrapper.__wrapped__ = func # type: ignore
|
|
@@ -13,7 +13,7 @@ from sqlframe.base.catalog import (
|
|
|
13
13
|
_BaseCatalog,
|
|
14
14
|
)
|
|
15
15
|
from sqlframe.base.decorators import normalize
|
|
16
|
-
from sqlframe.base.util import
|
|
16
|
+
from sqlframe.base.util import schema_, to_schema
|
|
17
17
|
|
|
18
18
|
|
|
19
19
|
class _BaseInfoSchemaMixin(_BaseCatalog, t.Generic[SESSION, DF]):
|
|
@@ -3,8 +3,6 @@ from __future__ import annotations
|
|
|
3
3
|
import pathlib
|
|
4
4
|
import typing as t
|
|
5
5
|
|
|
6
|
-
import pandas as pd
|
|
7
|
-
|
|
8
6
|
from sqlframe.base.exceptions import UnsupportedOperationError
|
|
9
7
|
from sqlframe.base.readerwriter import (
|
|
10
8
|
DF,
|
|
@@ -13,7 +11,7 @@ from sqlframe.base.readerwriter import (
|
|
|
13
11
|
_BaseDataFrameWriter,
|
|
14
12
|
_infer_format,
|
|
15
13
|
)
|
|
16
|
-
from sqlframe.base.util import pandas_to_spark_schema
|
|
14
|
+
from sqlframe.base.util import pandas_to_spark_schema, verify_pandas_installed
|
|
17
15
|
|
|
18
16
|
if t.TYPE_CHECKING:
|
|
19
17
|
from sqlframe.base._typing import OptionalPrimitiveType, PathOrPaths
|
|
@@ -72,6 +70,9 @@ class PandasLoaderMixin(_BaseDataFrameReader, t.Generic[SESSION, DF]):
|
|
|
72
70
|
|100|NULL|
|
|
73
71
|
+---+----+
|
|
74
72
|
"""
|
|
73
|
+
verify_pandas_installed()
|
|
74
|
+
import pandas as pd
|
|
75
|
+
|
|
75
76
|
assert path is not None, "path is required"
|
|
76
77
|
assert isinstance(path, str), "path must be a string"
|
|
77
78
|
format = format or _infer_format(path)
|
sqlframe/base/readerwriter.py
CHANGED
|
@@ -11,6 +11,8 @@ from functools import reduce
|
|
|
11
11
|
from sqlglot import exp
|
|
12
12
|
from sqlglot.helper import object_to_dict
|
|
13
13
|
|
|
14
|
+
from sqlframe.base.decorators import normalize
|
|
15
|
+
|
|
14
16
|
if sys.version_info >= (3, 11):
|
|
15
17
|
from typing import Self
|
|
16
18
|
else:
|
|
@@ -39,6 +41,7 @@ class _BaseDataFrameReader(t.Generic[SESSION, DF]):
|
|
|
39
41
|
def session(self) -> SESSION:
|
|
40
42
|
return self._session
|
|
41
43
|
|
|
44
|
+
@normalize("tableName")
|
|
42
45
|
def table(self, tableName: str) -> DF:
|
|
43
46
|
if df := self.session.temp_views.get(tableName):
|
|
44
47
|
return df
|
sqlframe/base/session.py
CHANGED
|
@@ -24,7 +24,10 @@ from sqlglot.schema import MappingSchema
|
|
|
24
24
|
from sqlframe.base.catalog import _BaseCatalog
|
|
25
25
|
from sqlframe.base.dataframe import _BaseDataFrame
|
|
26
26
|
from sqlframe.base.readerwriter import _BaseDataFrameReader, _BaseDataFrameWriter
|
|
27
|
-
from sqlframe.base.util import
|
|
27
|
+
from sqlframe.base.util import (
|
|
28
|
+
get_column_mapping_from_schema_input,
|
|
29
|
+
verify_pandas_installed,
|
|
30
|
+
)
|
|
28
31
|
|
|
29
32
|
if sys.version_info >= (3, 11):
|
|
30
33
|
from typing import Self
|
|
@@ -412,6 +415,7 @@ class _BaseSession(t.Generic[CATALOG, READER, WRITER, DF, CONN]):
|
|
|
412
415
|
self, expression: exp.Expression, dialect: t.Optional[Dialect] = None
|
|
413
416
|
) -> exp.Expression:
|
|
414
417
|
dialect = dialect or self.output_dialect
|
|
418
|
+
normalize_identifiers(expression, dialect=self.input_dialect)
|
|
415
419
|
quote_identifiers_func(expression, dialect=dialect)
|
|
416
420
|
return optimize(expression, dialect=dialect, schema=self.catalog._schema)
|
|
417
421
|
|
|
@@ -446,14 +450,6 @@ class _BaseSession(t.Generic[CATALOG, READER, WRITER, DF, CONN]):
|
|
|
446
450
|
def _fetch_rows(
|
|
447
451
|
self, sql: t.Union[str, exp.Expression], *, quote_identifiers: bool = True
|
|
448
452
|
) -> t.List[Row]:
|
|
449
|
-
from sqlframe.base.types import Row
|
|
450
|
-
|
|
451
|
-
def _dict_to_row(row: t.Dict[str, t.Any]) -> Row:
|
|
452
|
-
for key, value in row.items():
|
|
453
|
-
if isinstance(value, dict):
|
|
454
|
-
row[key] = _dict_to_row(value)
|
|
455
|
-
return Row(**row)
|
|
456
|
-
|
|
457
453
|
self._execute(sql, quote_identifiers=quote_identifiers)
|
|
458
454
|
result = self._cur.fetchall()
|
|
459
455
|
if not self._cur.description:
|
|
@@ -464,6 +460,7 @@ class _BaseSession(t.Generic[CATALOG, READER, WRITER, DF, CONN]):
|
|
|
464
460
|
def _fetchdf(
|
|
465
461
|
self, sql: t.Union[str, exp.Expression], *, quote_identifiers: bool = True
|
|
466
462
|
) -> pd.DataFrame:
|
|
463
|
+
verify_pandas_installed()
|
|
467
464
|
from pandas.io.sql import read_sql_query
|
|
468
465
|
|
|
469
466
|
return read_sql_query(self._to_sql(sql, quote_identifiers=quote_identifiers), self._conn)
|
sqlframe/base/util.py
CHANGED
|
@@ -154,7 +154,12 @@ def pandas_to_spark_schema(pandas_df: PandasDataFrame) -> types.StructType:
|
|
|
154
154
|
"""
|
|
155
155
|
from sqlframe.base import types
|
|
156
156
|
|
|
157
|
-
columns = list(
|
|
157
|
+
columns = list(
|
|
158
|
+
[
|
|
159
|
+
x.replace("?column?", f"unknown_column_{i}").replace("NULL", f"unknown_column_{i}")
|
|
160
|
+
for i, x in enumerate(pandas_df.columns)
|
|
161
|
+
]
|
|
162
|
+
)
|
|
158
163
|
d_types = list(pandas_df.dtypes)
|
|
159
164
|
p_schema = types.StructType(
|
|
160
165
|
[
|
|
@@ -240,3 +245,35 @@ def soundex(s):
|
|
|
240
245
|
|
|
241
246
|
result += "0" * (4 - count)
|
|
242
247
|
return "".join(result)
|
|
248
|
+
|
|
249
|
+
|
|
250
|
+
def verify_pandas_installed():
|
|
251
|
+
try:
|
|
252
|
+
import pandas # noqa
|
|
253
|
+
except ImportError:
|
|
254
|
+
raise ImportError(
|
|
255
|
+
"""Pandas is required for this functionality. `pip install "sqlframe[pandas]"` (also include your engine if needed) to install pandas."""
|
|
256
|
+
)
|
|
257
|
+
|
|
258
|
+
|
|
259
|
+
def verify_openai_installed():
|
|
260
|
+
try:
|
|
261
|
+
import openai # noqa
|
|
262
|
+
except ImportError:
|
|
263
|
+
raise ImportError(
|
|
264
|
+
"""OpenAI is required for this functionality. `pip install "sqlframe[openai]"` (also include your engine if needed) to install openai."""
|
|
265
|
+
)
|
|
266
|
+
|
|
267
|
+
|
|
268
|
+
def quote_preserving_alias_or_name(col: t.Union[exp.Column, exp.Alias]) -> str:
|
|
269
|
+
from sqlframe.base.session import _BaseSession
|
|
270
|
+
|
|
271
|
+
if isinstance(col, exp.Alias):
|
|
272
|
+
col = col.args["alias"]
|
|
273
|
+
if isinstance(col, exp.Column):
|
|
274
|
+
col = col.copy()
|
|
275
|
+
col.set("table", None)
|
|
276
|
+
if isinstance(col, (exp.Identifier, exp.Column)):
|
|
277
|
+
return col.sql(dialect=_BaseSession().input_dialect)
|
|
278
|
+
# We may get things like `Null()` expression or maybe literals so we just return the alias or name in those cases
|
|
279
|
+
return col.alias_or_name
|
sqlframe/snowflake/catalog.py
CHANGED
|
@@ -127,7 +127,9 @@ class SnowflakeCatalog(
|
|
|
127
127
|
sql = f"SHOW COLUMNS IN TABLE {table.sql(dialect=self.session.input_dialect)}"
|
|
128
128
|
results = self.session._fetch_rows(sql)
|
|
129
129
|
return {
|
|
130
|
-
row["column_name"]
|
|
130
|
+
exp.column(row["column_name"], quoted=True).sql(
|
|
131
|
+
dialect=self.session.input_dialect
|
|
132
|
+
): exp.DataType.build(
|
|
131
133
|
json.loads(row["data_type"])["type"], dialect=self.session.input_dialect, udt=True
|
|
132
134
|
)
|
|
133
135
|
for row in results
|
sqlframe/snowflake/session.py
CHANGED
|
@@ -1,8 +1,14 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
|
+
import json
|
|
3
4
|
import typing as t
|
|
4
5
|
import warnings
|
|
5
6
|
|
|
7
|
+
try:
|
|
8
|
+
from snowflake.connector.converter import SnowflakeConverter
|
|
9
|
+
except ImportError:
|
|
10
|
+
SnowflakeConverter = object # type: ignore
|
|
11
|
+
|
|
6
12
|
from sqlframe.base.session import _BaseSession
|
|
7
13
|
from sqlframe.snowflake.catalog import SnowflakeCatalog
|
|
8
14
|
from sqlframe.snowflake.dataframe import SnowflakeDataFrame
|
|
@@ -17,6 +23,18 @@ else:
|
|
|
17
23
|
SnowflakeConnection = t.Any
|
|
18
24
|
|
|
19
25
|
|
|
26
|
+
class JsonLoadsSnowflakeConverter(SnowflakeConverter):
|
|
27
|
+
def _json_loads(self, ctx: dict[str, t.Any]) -> t.Callable:
|
|
28
|
+
def conv(value: str) -> t.List:
|
|
29
|
+
return json.loads(value)
|
|
30
|
+
|
|
31
|
+
return conv
|
|
32
|
+
|
|
33
|
+
_OBJECT_to_python = _json_loads # type: ignore
|
|
34
|
+
_VARIANT_to_python = _json_loads # type: ignore
|
|
35
|
+
_ARRAY_to_python = _json_loads # type: ignore
|
|
36
|
+
|
|
37
|
+
|
|
20
38
|
class SnowflakeSession(
|
|
21
39
|
_BaseSession[ # type: ignore
|
|
22
40
|
SnowflakeCatalog,
|
|
@@ -35,8 +53,21 @@ class SnowflakeSession(
|
|
|
35
53
|
warnings.warn(
|
|
36
54
|
"SnowflakeSession is still in active development. Functions may not work as expected."
|
|
37
55
|
)
|
|
56
|
+
import snowflake
|
|
57
|
+
|
|
58
|
+
snowflake.connector.cursor.CAN_USE_ARROW_RESULT_FORMAT = False
|
|
59
|
+
|
|
38
60
|
if not hasattr(self, "_conn"):
|
|
39
61
|
super().__init__(conn)
|
|
62
|
+
if self._conn.converter and not isinstance(
|
|
63
|
+
self._conn.converter, JsonLoadsSnowflakeConverter
|
|
64
|
+
):
|
|
65
|
+
self._conn.converter = JsonLoadsSnowflakeConverter(
|
|
66
|
+
use_numpy=self._conn._numpy,
|
|
67
|
+
support_negative_year=self._conn._support_negative_year,
|
|
68
|
+
)
|
|
69
|
+
else:
|
|
70
|
+
self._conn._converter_class = JsonLoadsSnowflakeConverter # type: ignore
|
|
40
71
|
|
|
41
72
|
class Builder(_BaseSession.Builder):
|
|
42
73
|
DEFAULT_INPUT_DIALECT = "snowflake"
|
sqlframe/spark/session.py
CHANGED
|
@@ -3,7 +3,6 @@ from __future__ import annotations
|
|
|
3
3
|
import typing as t
|
|
4
4
|
import warnings
|
|
5
5
|
|
|
6
|
-
import pandas as pd
|
|
7
6
|
from sqlglot import exp
|
|
8
7
|
|
|
9
8
|
from sqlframe.base.session import _BaseSession
|
|
@@ -15,6 +14,9 @@ from sqlframe.spark.readwriter import (
|
|
|
15
14
|
)
|
|
16
15
|
from sqlframe.spark.types import Row
|
|
17
16
|
|
|
17
|
+
if t.TYPE_CHECKING:
|
|
18
|
+
import pandas as pd
|
|
19
|
+
|
|
18
20
|
|
|
19
21
|
class SparkSession(
|
|
20
22
|
_BaseSession[ # type: ignore
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.1
|
|
2
2
|
Name: sqlframe
|
|
3
|
-
Version: 1.
|
|
3
|
+
Version: 1.4.0
|
|
4
4
|
Summary: Taking the Spark out of PySpark by converting to SQL
|
|
5
5
|
Home-page: https://github.com/eakmanrq/sqlframe
|
|
6
6
|
Author: Ryan Eakman
|
|
@@ -22,10 +22,10 @@ Requires-Dist: sqlglot (<24.1,>=24.0.0)
|
|
|
22
22
|
Provides-Extra: bigquery
|
|
23
23
|
Requires-Dist: google-cloud-bigquery-storage (<3,>=2) ; extra == 'bigquery'
|
|
24
24
|
Requires-Dist: google-cloud-bigquery[pandas] (<4,>=3) ; extra == 'bigquery'
|
|
25
|
-
Requires-Dist: pandas (<3,>=2) ; extra == 'bigquery'
|
|
26
25
|
Provides-Extra: dev
|
|
27
26
|
Requires-Dist: duckdb (<0.11,>=0.9) ; extra == 'dev'
|
|
28
27
|
Requires-Dist: mypy (<1.11,>=1.10.0) ; extra == 'dev'
|
|
28
|
+
Requires-Dist: openai (<1.31,>=1.30) ; extra == 'dev'
|
|
29
29
|
Requires-Dist: pandas-stubs (<3,>=2) ; extra == 'dev'
|
|
30
30
|
Requires-Dist: pandas (<3,>=2) ; extra == 'dev'
|
|
31
31
|
Requires-Dist: psycopg (<4,>=3.1) ; extra == 'dev'
|
|
@@ -48,32 +48,33 @@ Requires-Dist: pymdown-extensions ; extra == 'docs'
|
|
|
48
48
|
Provides-Extra: duckdb
|
|
49
49
|
Requires-Dist: duckdb (<0.11,>=0.9) ; extra == 'duckdb'
|
|
50
50
|
Requires-Dist: pandas (<3,>=2) ; extra == 'duckdb'
|
|
51
|
+
Provides-Extra: openai
|
|
52
|
+
Requires-Dist: openai (<1.31,>=1.30) ; extra == 'openai'
|
|
53
|
+
Provides-Extra: pandas
|
|
54
|
+
Requires-Dist: pandas (<3,>=2) ; extra == 'pandas'
|
|
51
55
|
Provides-Extra: postgres
|
|
52
|
-
Requires-Dist: pandas (<3,>=2) ; extra == 'postgres'
|
|
53
56
|
Requires-Dist: psycopg2 (<3,>=2.8) ; extra == 'postgres'
|
|
54
57
|
Provides-Extra: redshift
|
|
55
|
-
Requires-Dist: pandas (<3,>=2) ; extra == 'redshift'
|
|
56
58
|
Requires-Dist: redshift-connector (<2.2.0,>=2.1.1) ; extra == 'redshift'
|
|
57
59
|
Provides-Extra: snowflake
|
|
58
|
-
Requires-Dist:
|
|
59
|
-
Requires-Dist: snowflake-connector-python[pandas,secure-local-storage] (<3.11,>=3.10.0) ; extra == 'snowflake'
|
|
60
|
+
Requires-Dist: snowflake-connector-python[secure-local-storage] (<3.11,>=3.10.0) ; extra == 'snowflake'
|
|
60
61
|
Provides-Extra: spark
|
|
61
62
|
Requires-Dist: pyspark (<3.6,>=2) ; extra == 'spark'
|
|
62
63
|
|
|
63
64
|
<div align="center">
|
|
64
|
-
<img src="https://sqlframe.readthedocs.io/en/
|
|
65
|
+
<img src="https://sqlframe.readthedocs.io/en/stable/docs/images/sqlframe_logo.png" alt="SQLFrame Logo" width="400"/>
|
|
65
66
|
</div>
|
|
66
67
|
|
|
67
68
|
SQLFrame implements the PySpark DataFrame API in order to enable running transformation pipelines directly on database engines - no Spark clusters or dependencies required.
|
|
68
69
|
|
|
69
70
|
SQLFrame currently supports the following engines (many more in development):
|
|
70
71
|
|
|
71
|
-
* [BigQuery](https://sqlframe.readthedocs.io/en/
|
|
72
|
-
* [DuckDB](https://sqlframe.readthedocs.io/en/
|
|
73
|
-
* [Postgres](https://sqlframe.readthedocs.io/en/
|
|
72
|
+
* [BigQuery](https://sqlframe.readthedocs.io/en/stable/bigquery/)
|
|
73
|
+
* [DuckDB](https://sqlframe.readthedocs.io/en/stable/duckdb)
|
|
74
|
+
* [Postgres](https://sqlframe.readthedocs.io/en/stable/postgres)
|
|
74
75
|
|
|
75
76
|
SQLFrame also has a "Standalone" session that can be used to generate SQL without any connection to a database engine.
|
|
76
|
-
* [Standalone](https://sqlframe.readthedocs.io/en/
|
|
77
|
+
* [Standalone](https://sqlframe.readthedocs.io/en/stable/standalone)
|
|
77
78
|
|
|
78
79
|
SQLFrame is great for:
|
|
79
80
|
|
|
@@ -96,6 +97,12 @@ pip install sqlframe
|
|
|
96
97
|
|
|
97
98
|
See specific engine documentation for additional setup instructions.
|
|
98
99
|
|
|
100
|
+
## Configuration
|
|
101
|
+
|
|
102
|
+
SQLFrame generates consistently accurate yet complex SQL for engine execution.
|
|
103
|
+
However, when using df.sql(), it produces more human-readable SQL.
|
|
104
|
+
For details on how to configure this output and leverage OpenAI to enhance the SQL, see [Generated SQL Configuration](https://sqlframe.readthedocs.io/en/stable/configuration/#generated-sql).
|
|
105
|
+
|
|
99
106
|
## Example Usage
|
|
100
107
|
|
|
101
108
|
```python
|
|
@@ -1,27 +1,27 @@
|
|
|
1
1
|
sqlframe/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
2
|
-
sqlframe/_version.py,sha256=
|
|
2
|
+
sqlframe/_version.py,sha256=R8-T9fmURjcuoxYpHTAjyNAhgJPDtI2jogCjqYYkfCU,411
|
|
3
3
|
sqlframe/base/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
4
4
|
sqlframe/base/_typing.py,sha256=DuTay8-o9W-pw3RPZCgLunKNJLS9PkaV11G_pxXp9NY,1256
|
|
5
5
|
sqlframe/base/catalog.py,sha256=ATDGirouUjal05P4ymL-wIi8rgjg_8w4PoACamiO64A,37245
|
|
6
|
-
sqlframe/base/column.py,sha256=
|
|
7
|
-
sqlframe/base/dataframe.py,sha256=
|
|
8
|
-
sqlframe/base/decorators.py,sha256=
|
|
6
|
+
sqlframe/base/column.py,sha256=p3VrtATBmjAYHollFcsdps2UJTNC-Pvyg4Zt7y4CK9w,15358
|
|
7
|
+
sqlframe/base/dataframe.py,sha256=9PuqC9dBficSE-Y1v_BHyk4gK-Hd43SaVBmxBeyNnD8,62939
|
|
8
|
+
sqlframe/base/decorators.py,sha256=I5osMgx9BuCgbtp4jVM2DNwYJVLzCv-OtTedhQEik0g,1882
|
|
9
9
|
sqlframe/base/exceptions.py,sha256=pCB9hXX4jxZWzNg3JN1i38cv3BmpUlee5NoLYx3YXIQ,208
|
|
10
10
|
sqlframe/base/function_alternatives.py,sha256=to0kv3MTJmQFeVTMcitz0AxBIoUJC3cu5LkEY5aJpoo,31318
|
|
11
11
|
sqlframe/base/functions.py,sha256=iVe8AbXGX_gXnkQ1N-clX6rihsonfzJ84_YvWzhB2FM,53540
|
|
12
12
|
sqlframe/base/group.py,sha256=TES9CleVmH3x-0X-tqmuUKfCKSWjH5vg1aU3R6dDmFc,4059
|
|
13
13
|
sqlframe/base/normalize.py,sha256=gRBn-PziFdE-CHtPJMkMl7y_YH0mauUcD4zfgyyvlpw,3565
|
|
14
14
|
sqlframe/base/operations.py,sha256=-AhNuEzcV7ZExoP1oY3blaKip-joQyJeQVvfBTs_2g4,3456
|
|
15
|
-
sqlframe/base/readerwriter.py,sha256=
|
|
16
|
-
sqlframe/base/session.py,sha256
|
|
15
|
+
sqlframe/base/readerwriter.py,sha256=5NPQMiOrw6I54U243R_6-ynnWYsNksgqwRpPp4IFjIw,25288
|
|
16
|
+
sqlframe/base/session.py,sha256=-h7qcOPRw9KBJPg_V6Tlr8Z2SmcsgAWruBo34o6zfrQ,21795
|
|
17
17
|
sqlframe/base/transforms.py,sha256=y0j3SGDz3XCmNGrvassk1S-owllUWfkHyMgZlY6SFO4,467
|
|
18
18
|
sqlframe/base/types.py,sha256=aJT5YXr-M_LAfUM0uK4asfbrQFab_xmsp1CP2zkG8p0,11924
|
|
19
|
-
sqlframe/base/util.py,sha256=
|
|
19
|
+
sqlframe/base/util.py,sha256=wdATi7STt-FfXrX9TPRkw4PFJP7uAsK_K9YkKSrd0qU,8824
|
|
20
20
|
sqlframe/base/window.py,sha256=8hOv-ignPPIsZA9FzvYzcLE9J_glalVaYjIAUdRUX3o,4943
|
|
21
21
|
sqlframe/base/mixins/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
22
|
-
sqlframe/base/mixins/catalog_mixins.py,sha256=
|
|
22
|
+
sqlframe/base/mixins/catalog_mixins.py,sha256=NhuPGxIqPjyuC_V_NALN1sn9v9h0-xwFOlJyJgsvyek,14212
|
|
23
23
|
sqlframe/base/mixins/dataframe_mixins.py,sha256=U2tKIY5pCLnoPy1boAQ1YWLgK1E-ZT4x47oRWtGoYLQ,2360
|
|
24
|
-
sqlframe/base/mixins/readwriter_mixins.py,sha256=
|
|
24
|
+
sqlframe/base/mixins/readwriter_mixins.py,sha256=QnxGVL8ftZfYlBNG0Bl24N_bnA2YioSxUsTSgKIbuvQ,4723
|
|
25
25
|
sqlframe/bigquery/__init__.py,sha256=i2NsMbiXOj2xphCtPuNk6cVw4iYeq5_B1I9dVI9aGAk,712
|
|
26
26
|
sqlframe/bigquery/catalog.py,sha256=h3aQAQAJg6MMvFpP8Ku0S4pcx30n5qYrqHhWSomxb6A,9319
|
|
27
27
|
sqlframe/bigquery/column.py,sha256=E1tUa62Y5HajkhgFuebU9zohrGyieudcHzTT8gfalio,40
|
|
@@ -63,13 +63,13 @@ sqlframe/redshift/session.py,sha256=GA2CFGJckissPYmcXWR1R3QOOoSa9XuLOR6sWFFuC1k,
|
|
|
63
63
|
sqlframe/redshift/types.py,sha256=KwNyuXIo-2xVVd4bZED3YrQOobKCtemlxGrJL7DrTC8,34
|
|
64
64
|
sqlframe/redshift/window.py,sha256=6GKPzuxeSapJakBaKBeT9VpED1ACdjggDv9JRILDyV0,35
|
|
65
65
|
sqlframe/snowflake/__init__.py,sha256=nuQ3cuHjDpW4ELZfbd2qOYmtXmcYl7MtsrdOrRdozo0,746
|
|
66
|
-
sqlframe/snowflake/catalog.py,sha256=
|
|
66
|
+
sqlframe/snowflake/catalog.py,sha256=uDjBgDdCyxaDkGNX_8tb-lol7MwwazcClUBAZsOSj70,5014
|
|
67
67
|
sqlframe/snowflake/column.py,sha256=E1tUa62Y5HajkhgFuebU9zohrGyieudcHzTT8gfalio,40
|
|
68
68
|
sqlframe/snowflake/dataframe.py,sha256=OJ27NudBUE3XX9mc8ywooGhYV4ijF9nX2K_nkHRcTx4,1393
|
|
69
69
|
sqlframe/snowflake/functions.py,sha256=ZYX9gyPvmpKoLi_7uQdB0uPQNTREOAJD0aCcccX1iPc,456
|
|
70
70
|
sqlframe/snowflake/group.py,sha256=pPP1l2RRo_LgkXrji8a87n2PKo-63ZRPT-WUtvVcBME,395
|
|
71
71
|
sqlframe/snowflake/readwriter.py,sha256=yhRc2HcMq6PwV3ghZWC-q-qaE7LE4aEjZEXCip4OOlQ,884
|
|
72
|
-
sqlframe/snowflake/session.py,sha256=
|
|
72
|
+
sqlframe/snowflake/session.py,sha256=QKdxXgK9_YgxoyxzEd73ot4t0M6Dz4em09JdVMYxVPI,2584
|
|
73
73
|
sqlframe/snowflake/types.py,sha256=KwNyuXIo-2xVVd4bZED3YrQOobKCtemlxGrJL7DrTC8,34
|
|
74
74
|
sqlframe/snowflake/window.py,sha256=6GKPzuxeSapJakBaKBeT9VpED1ACdjggDv9JRILDyV0,35
|
|
75
75
|
sqlframe/spark/__init__.py,sha256=jamKYQtQaKjjXnQ01QGPHvatbrZSw9sWno_VOUGSz6I,712
|
|
@@ -79,7 +79,7 @@ sqlframe/spark/dataframe.py,sha256=V3z5Bx9snLgYh4bDwJfJb5mj1P7UsZF8DMlLwZXopBg,1
|
|
|
79
79
|
sqlframe/spark/functions.py,sha256=eSGMM2DXcj17nIPH5ZDLG95ZMuE7F8Qvn0IqGO_wQVw,586
|
|
80
80
|
sqlframe/spark/group.py,sha256=MrvV_v-YkBc6T1zz882WrEqtWjlooWIyHBCmTQg3fCA,379
|
|
81
81
|
sqlframe/spark/readwriter.py,sha256=w68EImTcGJv64X7pc1tk5tDjDxb1nAnn-MiIaaN9Dc8,812
|
|
82
|
-
sqlframe/spark/session.py,sha256=
|
|
82
|
+
sqlframe/spark/session.py,sha256=D7gss1QGSvSLAF86WrLKvIbn0UC2YiMZnmVdCqv1SZA,2628
|
|
83
83
|
sqlframe/spark/types.py,sha256=KwNyuXIo-2xVVd4bZED3YrQOobKCtemlxGrJL7DrTC8,34
|
|
84
84
|
sqlframe/spark/window.py,sha256=6GKPzuxeSapJakBaKBeT9VpED1ACdjggDv9JRILDyV0,35
|
|
85
85
|
sqlframe/standalone/__init__.py,sha256=yu4A97HwhyDwllDEzG7io4ScyWipWSAH2tqUKS545OA,767
|
|
@@ -92,8 +92,8 @@ sqlframe/standalone/readwriter.py,sha256=EZNyDJ4ID6sGNog3uP4-e9RvchX4biJJDNtc5hk
|
|
|
92
92
|
sqlframe/standalone/session.py,sha256=wQmdu2sv6KMTAv0LRFk7TY7yzlh3xvmsyqilEtRecbY,1191
|
|
93
93
|
sqlframe/standalone/types.py,sha256=KwNyuXIo-2xVVd4bZED3YrQOobKCtemlxGrJL7DrTC8,34
|
|
94
94
|
sqlframe/standalone/window.py,sha256=6GKPzuxeSapJakBaKBeT9VpED1ACdjggDv9JRILDyV0,35
|
|
95
|
-
sqlframe-1.
|
|
96
|
-
sqlframe-1.
|
|
97
|
-
sqlframe-1.
|
|
98
|
-
sqlframe-1.
|
|
99
|
-
sqlframe-1.
|
|
95
|
+
sqlframe-1.4.0.dist-info/LICENSE,sha256=VZu79YgW780qxaFJMr0t5ZgbOYEh04xWoxaWOaqIGWk,1068
|
|
96
|
+
sqlframe-1.4.0.dist-info/METADATA,sha256=nnz73ML6w8WyctFzwiaKVVNr9RQwmpmfckrcKqEX_PE,7219
|
|
97
|
+
sqlframe-1.4.0.dist-info/WHEEL,sha256=G16H4A3IeoQmnOrYV4ueZGKSjhipXx8zc8nu9FGlvMA,92
|
|
98
|
+
sqlframe-1.4.0.dist-info/top_level.txt,sha256=T0_RpoygaZSF6heeWwIDQgaP0varUdSK1pzjeJZRjM8,9
|
|
99
|
+
sqlframe-1.4.0.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|