sqlframe-1.14.0-py3-none-any.whl → sqlframe-2.1.0-py3-none-any.whl
This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
- sqlframe/_version.py +2 -2
- sqlframe/base/_typing.py +1 -0
- sqlframe/base/catalog.py +36 -13
- sqlframe/base/column.py +11 -9
- sqlframe/base/dataframe.py +72 -79
- sqlframe/base/decorators.py +0 -38
- sqlframe/base/function_alternatives.py +36 -25
- sqlframe/base/functions.py +88 -28
- sqlframe/base/mixins/catalog_mixins.py +156 -45
- sqlframe/base/mixins/dataframe_mixins.py +1 -1
- sqlframe/base/readerwriter.py +12 -14
- sqlframe/base/session.py +157 -84
- sqlframe/base/udf.py +36 -0
- sqlframe/base/util.py +71 -20
- sqlframe/bigquery/catalog.py +79 -28
- sqlframe/bigquery/functions.py +5 -8
- sqlframe/bigquery/session.py +4 -2
- sqlframe/bigquery/udf.py +11 -0
- sqlframe/duckdb/catalog.py +30 -13
- sqlframe/duckdb/dataframe.py +5 -0
- sqlframe/duckdb/functions.py +2 -0
- sqlframe/duckdb/readwriter.py +7 -6
- sqlframe/duckdb/session.py +8 -2
- sqlframe/duckdb/udf.py +19 -0
- sqlframe/postgres/catalog.py +30 -18
- sqlframe/postgres/functions.py +2 -0
- sqlframe/postgres/session.py +16 -5
- sqlframe/postgres/udf.py +11 -0
- sqlframe/redshift/catalog.py +28 -13
- sqlframe/redshift/session.py +4 -2
- sqlframe/redshift/udf.py +11 -0
- sqlframe/snowflake/catalog.py +64 -24
- sqlframe/snowflake/dataframe.py +9 -5
- sqlframe/snowflake/functions.py +1 -0
- sqlframe/snowflake/session.py +5 -2
- sqlframe/snowflake/udf.py +11 -0
- sqlframe/spark/catalog.py +180 -10
- sqlframe/spark/session.py +46 -14
- sqlframe/spark/udf.py +34 -0
- sqlframe/standalone/session.py +3 -0
- sqlframe/standalone/udf.py +11 -0
- {sqlframe-1.14.0.dist-info → sqlframe-2.1.0.dist-info}/METADATA +14 -10
- {sqlframe-1.14.0.dist-info → sqlframe-2.1.0.dist-info}/RECORD +46 -38
- {sqlframe-1.14.0.dist-info → sqlframe-2.1.0.dist-info}/LICENSE +0 -0
- {sqlframe-1.14.0.dist-info → sqlframe-2.1.0.dist-info}/WHEEL +0 -0
- {sqlframe-1.14.0.dist-info → sqlframe-2.1.0.dist-info}/top_level.txt +0 -0
sqlframe/_version.py
CHANGED
sqlframe/base/_typing.py
CHANGED
|
@@ -24,6 +24,7 @@ OutputExpressionContainer = t.Union[exp.Select, exp.Create, exp.Insert]
|
|
|
24
24
|
StorageLevel = str
|
|
25
25
|
PathOrPaths = t.Union[str, t.List[str]]
|
|
26
26
|
OptionalPrimitiveType = t.Optional[PrimitiveType]
|
|
27
|
+
DataTypeOrString = t.Union[DataType, str]
|
|
27
28
|
|
|
28
29
|
|
|
29
30
|
class UserDefinedFunctionLike(t.Protocol):
|
sqlframe/base/catalog.py
CHANGED
|
@@ -3,12 +3,12 @@
|
|
|
3
3
|
from __future__ import annotations
|
|
4
4
|
|
|
5
5
|
import typing as t
|
|
6
|
+
from collections import defaultdict
|
|
6
7
|
|
|
7
8
|
from sqlglot import MappingSchema, exp
|
|
8
9
|
|
|
9
|
-
from sqlframe.base.decorators import normalize
|
|
10
10
|
from sqlframe.base.exceptions import TableSchemaError
|
|
11
|
-
from sqlframe.base.util import ensure_column_mapping, to_schema
|
|
11
|
+
from sqlframe.base.util import ensure_column_mapping, normalize_string, to_schema
|
|
12
12
|
|
|
13
13
|
if t.TYPE_CHECKING:
|
|
14
14
|
from sqlglot.schema import ColumnMapping
|
|
@@ -33,6 +33,7 @@ class _BaseCatalog(t.Generic[SESSION, DF]):
|
|
|
33
33
|
"""Create a new Catalog that wraps the underlying JVM object."""
|
|
34
34
|
self.session = sparkSession
|
|
35
35
|
self._schema = schema or MappingSchema()
|
|
36
|
+
self._quoted_columns: t.Dict[exp.Table, t.List[str]] = defaultdict(list)
|
|
36
37
|
|
|
37
38
|
@property
|
|
38
39
|
def spark(self) -> SESSION:
|
|
@@ -52,7 +53,7 @@ class _BaseCatalog(t.Generic[SESSION, DF]):
|
|
|
52
53
|
def get_columns_from_schema(self, table: exp.Table | str) -> t.Dict[str, exp.DataType]:
|
|
53
54
|
table = self.ensure_table(table)
|
|
54
55
|
return {
|
|
55
|
-
exp.column(name, quoted=
|
|
56
|
+
exp.column(name, quoted=name in self._quoted_columns[table]).sql(
|
|
56
57
|
dialect=self.session.input_dialect
|
|
57
58
|
): exp.DataType.build(dtype, dialect=self.session.input_dialect)
|
|
58
59
|
for name, dtype in self._schema.find(table, raise_on_missing=True).items() # type: ignore
|
|
@@ -64,9 +65,7 @@ class _BaseCatalog(t.Generic[SESSION, DF]):
|
|
|
64
65
|
if not columns:
|
|
65
66
|
return {}
|
|
66
67
|
return {
|
|
67
|
-
exp.
|
|
68
|
-
dialect=self.session.input_dialect
|
|
69
|
-
): exp.DataType.build(c.dataType, dialect=self.session.input_dialect)
|
|
68
|
+
c.name: exp.DataType.build(c.dataType, dialect=self.session.output_dialect)
|
|
70
69
|
for c in columns
|
|
71
70
|
}
|
|
72
71
|
|
|
@@ -79,16 +78,30 @@ class _BaseCatalog(t.Generic[SESSION, DF]):
|
|
|
79
78
|
return
|
|
80
79
|
if not column_mapping:
|
|
81
80
|
try:
|
|
82
|
-
column_mapping =
|
|
81
|
+
column_mapping = {
|
|
82
|
+
normalize_string(
|
|
83
|
+
k, from_dialect="output", to_dialect="input", is_column=True
|
|
84
|
+
): normalize_string(
|
|
85
|
+
v.sql(dialect=self.session.output_dialect),
|
|
86
|
+
from_dialect="output",
|
|
87
|
+
to_dialect="input",
|
|
88
|
+
is_datatype=True,
|
|
89
|
+
)
|
|
90
|
+
for k, v in self.get_columns(table).items()
|
|
91
|
+
}
|
|
83
92
|
except NotImplementedError:
|
|
84
93
|
# TODO: Add doc link
|
|
85
94
|
raise TableSchemaError(
|
|
86
95
|
"This session does not have access to a catalog that can lookup column information. See docs for explicitly defining columns or using a session that can automatically determine this."
|
|
87
96
|
)
|
|
88
97
|
column_mapping = ensure_column_mapping(column_mapping) # type: ignore
|
|
98
|
+
for column_name in column_mapping:
|
|
99
|
+
column = exp.to_column(column_name, dialect=self.session.input_dialect)
|
|
100
|
+
if column.this.quoted:
|
|
101
|
+
self._quoted_columns[table].append(column.this.name)
|
|
102
|
+
|
|
89
103
|
self._schema.add_table(table, column_mapping, dialect=self.session.input_dialect)
|
|
90
104
|
|
|
91
|
-
@normalize(["dbName"])
|
|
92
105
|
def getDatabase(self, dbName: str) -> Database:
|
|
93
106
|
"""Get the database with the specified name.
|
|
94
107
|
This throws an :class:`AnalysisException` when the database cannot be found.
|
|
@@ -115,6 +128,7 @@ class _BaseCatalog(t.Generic[SESSION, DF]):
|
|
|
115
128
|
>>> spark.catalog.getDatabase("spark_catalog.default")
|
|
116
129
|
Database(name='default', catalog='spark_catalog', description='default database', ...
|
|
117
130
|
"""
|
|
131
|
+
dbName = normalize_string(dbName, from_dialect="input", is_schema=True)
|
|
118
132
|
schema = to_schema(dbName, dialect=self.session.input_dialect)
|
|
119
133
|
database_name = schema.db
|
|
120
134
|
databases = self.listDatabases(pattern=database_name)
|
|
@@ -122,12 +136,16 @@ class _BaseCatalog(t.Generic[SESSION, DF]):
|
|
|
122
136
|
raise ValueError(f"Database '{dbName}' not found")
|
|
123
137
|
if len(databases) > 1:
|
|
124
138
|
if schema.catalog is not None:
|
|
125
|
-
filtered_databases = [
|
|
139
|
+
filtered_databases = [
|
|
140
|
+
db
|
|
141
|
+
for db in databases
|
|
142
|
+
if normalize_string(db.catalog, from_dialect="output", to_dialect="input") # type: ignore
|
|
143
|
+
== schema.catalog
|
|
144
|
+
]
|
|
126
145
|
if filtered_databases:
|
|
127
146
|
return filtered_databases[0]
|
|
128
147
|
return databases[0]
|
|
129
148
|
|
|
130
|
-
@normalize(["dbName"])
|
|
131
149
|
def databaseExists(self, dbName: str) -> bool:
|
|
132
150
|
"""Check if the database with the specified name exists.
|
|
133
151
|
|
|
@@ -168,7 +186,6 @@ class _BaseCatalog(t.Generic[SESSION, DF]):
|
|
|
168
186
|
except ValueError:
|
|
169
187
|
return False
|
|
170
188
|
|
|
171
|
-
@normalize(["tableName"])
|
|
172
189
|
def getTable(self, tableName: str) -> Table:
|
|
173
190
|
"""Get the table or view with the specified name. This table can be a temporary view or a
|
|
174
191
|
table/view. This throws an :class:`AnalysisException` when no Table can be found.
|
|
@@ -210,13 +227,18 @@ class _BaseCatalog(t.Generic[SESSION, DF]):
|
|
|
210
227
|
...
|
|
211
228
|
AnalysisException: ...
|
|
212
229
|
"""
|
|
230
|
+
tableName = normalize_string(tableName, from_dialect="input", is_table=True)
|
|
213
231
|
table = exp.to_table(tableName, dialect=self.session.input_dialect)
|
|
214
232
|
schema = table.copy()
|
|
215
233
|
schema.set("this", None)
|
|
216
234
|
tables = self.listTables(
|
|
217
235
|
schema.sql(dialect=self.session.input_dialect) if schema.db else None
|
|
218
236
|
)
|
|
219
|
-
matching_tables = [
|
|
237
|
+
matching_tables = [
|
|
238
|
+
t
|
|
239
|
+
for t in tables
|
|
240
|
+
if normalize_string(t.name, from_dialect="output", to_dialect="input") == table.name
|
|
241
|
+
]
|
|
220
242
|
if not matching_tables:
|
|
221
243
|
raise ValueError(f"Table '{tableName}' not found")
|
|
222
244
|
return matching_tables[0]
|
|
@@ -315,7 +337,6 @@ class _BaseCatalog(t.Generic[SESSION, DF]):
|
|
|
315
337
|
raise ValueError(f"Function '{functionName}' not found")
|
|
316
338
|
return matching_functions[0]
|
|
317
339
|
|
|
318
|
-
@normalize(["tableName", "dbName"])
|
|
319
340
|
def tableExists(self, tableName: str, dbName: t.Optional[str] = None) -> bool:
|
|
320
341
|
"""Check if the table or view with the specified name exists.
|
|
321
342
|
This can either be a temporary view or a table/view.
|
|
@@ -389,6 +410,8 @@ class _BaseCatalog(t.Generic[SESSION, DF]):
|
|
|
389
410
|
>>> spark.catalog.tableExists("view1")
|
|
390
411
|
False
|
|
391
412
|
"""
|
|
413
|
+
tableName = normalize_string(tableName, from_dialect="input", is_table=True)
|
|
414
|
+
dbName = normalize_string(dbName, from_dialect="input", is_schema=True) if dbName else None
|
|
392
415
|
table = exp.to_table(tableName, dialect=self.session.input_dialect)
|
|
393
416
|
schema_arg = to_schema(dbName, dialect=self.session.input_dialect) if dbName else None
|
|
394
417
|
if not table.db:
|
sqlframe/base/column.py
CHANGED
|
@@ -7,11 +7,11 @@ import math
|
|
|
7
7
|
import typing as t
|
|
8
8
|
|
|
9
9
|
import sqlglot
|
|
10
|
+
from sqlglot import Dialect
|
|
10
11
|
from sqlglot import expressions as exp
|
|
11
12
|
from sqlglot.helper import flatten, is_iterable
|
|
12
13
|
from sqlglot.optimizer.normalize_identifiers import normalize_identifiers
|
|
13
14
|
|
|
14
|
-
from sqlframe.base.decorators import normalize
|
|
15
15
|
from sqlframe.base.exceptions import UnsupportedOperationError
|
|
16
16
|
from sqlframe.base.types import DataType
|
|
17
17
|
from sqlframe.base.util import get_func_from_session, quote_preserving_alias_or_name
|
|
@@ -211,9 +211,8 @@ class Column:
|
|
|
211
211
|
def binary_op(
|
|
212
212
|
self, klass: t.Callable, other: ColumnOrLiteral, paren: bool = False, **kwargs
|
|
213
213
|
) -> Column:
|
|
214
|
-
|
|
215
|
-
|
|
216
|
-
)
|
|
214
|
+
other = self._lit(other) if isinstance(other, str) else Column(other)
|
|
215
|
+
op = klass(this=self.column_expression, expression=other.column_expression, **kwargs)
|
|
217
216
|
if paren:
|
|
218
217
|
return Column(exp.Paren(this=op))
|
|
219
218
|
return Column(op)
|
|
@@ -221,9 +220,8 @@ class Column:
|
|
|
221
220
|
def inverse_binary_op(
|
|
222
221
|
self, klass: t.Callable, other: ColumnOrLiteral, paren: bool = False, **kwargs
|
|
223
222
|
) -> Column:
|
|
224
|
-
|
|
225
|
-
|
|
226
|
-
)
|
|
223
|
+
other = self._lit(other) if isinstance(other, str) else Column(other)
|
|
224
|
+
op = klass(this=other.column_expression, expression=self.column_expression, **kwargs)
|
|
227
225
|
if paren:
|
|
228
226
|
return Column(exp.Paren(this=op))
|
|
229
227
|
return Column(op)
|
|
@@ -340,13 +338,17 @@ class Column:
|
|
|
340
338
|
new_expression = exp.Not(this=exp.Is(this=self.column_expression, expression=exp.Null()))
|
|
341
339
|
return Column(new_expression)
|
|
342
340
|
|
|
343
|
-
def cast(
|
|
341
|
+
def cast(
|
|
342
|
+
self, dataType: t.Union[str, DataType], dialect: t.Optional[t.Union[str, Dialect]] = None
|
|
343
|
+
) -> Column:
|
|
344
344
|
from sqlframe.base.session import _BaseSession
|
|
345
345
|
|
|
346
346
|
if isinstance(dataType, DataType):
|
|
347
347
|
dataType = dataType.simpleString()
|
|
348
348
|
return Column(
|
|
349
|
-
exp.cast(
|
|
349
|
+
exp.cast(
|
|
350
|
+
self.column_expression, dataType, dialect=dialect or _BaseSession().input_dialect
|
|
351
|
+
)
|
|
350
352
|
)
|
|
351
353
|
|
|
352
354
|
def startswith(self, value: t.Union[str, Column]) -> Column:
|
sqlframe/base/dataframe.py
CHANGED
|
@@ -23,12 +23,12 @@ from sqlglot.optimizer.qualify import qualify
|
|
|
23
23
|
from sqlglot.optimizer.qualify_columns import quote_identifiers
|
|
24
24
|
|
|
25
25
|
from sqlframe.base.catalog import Column as CatalogColumn
|
|
26
|
-
from sqlframe.base.decorators import normalize
|
|
27
26
|
from sqlframe.base.operations import Operation, operation
|
|
28
27
|
from sqlframe.base.transforms import replace_id_value
|
|
29
28
|
from sqlframe.base.util import (
|
|
30
29
|
get_func_from_session,
|
|
31
30
|
get_tables_from_expression_with_join,
|
|
31
|
+
normalize_string,
|
|
32
32
|
quote_preserving_alias_or_name,
|
|
33
33
|
sqlglot_to_spark,
|
|
34
34
|
verify_openai_installed,
|
|
@@ -41,6 +41,7 @@ else:
|
|
|
41
41
|
|
|
42
42
|
if t.TYPE_CHECKING:
|
|
43
43
|
import pandas as pd
|
|
44
|
+
from pyarrow import Table as ArrowTable
|
|
44
45
|
from sqlglot.dialects.dialect import DialectType
|
|
45
46
|
|
|
46
47
|
from sqlframe.base._typing import (
|
|
@@ -535,39 +536,11 @@ class _BaseDataFrame(t.Generic[SESSION, WRITER, NA, STAT, GROUP_DATA]):
|
|
|
535
536
|
)
|
|
536
537
|
return [col]
|
|
537
538
|
|
|
538
|
-
|
|
539
|
-
def sql(
|
|
540
|
-
self,
|
|
541
|
-
dialect: DialectType = ...,
|
|
542
|
-
optimize: bool = ...,
|
|
543
|
-
pretty: bool = ...,
|
|
544
|
-
*,
|
|
545
|
-
as_list: t.Literal[False],
|
|
546
|
-
**kwargs: t.Any,
|
|
547
|
-
) -> str: ...
|
|
548
|
-
|
|
549
|
-
@t.overload
|
|
550
|
-
def sql(
|
|
551
|
-
self,
|
|
552
|
-
dialect: DialectType = ...,
|
|
553
|
-
optimize: bool = ...,
|
|
554
|
-
pretty: bool = ...,
|
|
555
|
-
*,
|
|
556
|
-
as_list: t.Literal[True],
|
|
557
|
-
**kwargs: t.Any,
|
|
558
|
-
) -> t.List[str]: ...
|
|
559
|
-
|
|
560
|
-
def sql(
|
|
539
|
+
def _get_expressions(
|
|
561
540
|
self,
|
|
562
|
-
dialect: DialectType = None,
|
|
563
541
|
optimize: bool = True,
|
|
564
|
-
pretty: bool = True,
|
|
565
542
|
openai_config: t.Optional[t.Union[t.Dict[str, t.Any], OpenAIConfig]] = None,
|
|
566
|
-
|
|
567
|
-
**kwargs,
|
|
568
|
-
) -> t.Union[str, t.List[str]]:
|
|
569
|
-
dialect = Dialect.get_or_raise(dialect or self.session.output_dialect)
|
|
570
|
-
|
|
543
|
+
) -> t.List[exp.Expression]:
|
|
571
544
|
df = self._resolve_pending_hints()
|
|
572
545
|
select_expressions = df._get_select_expressions()
|
|
573
546
|
output_expressions: t.List[t.Union[exp.Select, exp.Cache, exp.Drop]] = []
|
|
@@ -583,12 +556,16 @@ class _BaseDataFrame(t.Generic[SESSION, WRITER, NA, STAT, GROUP_DATA]):
|
|
|
583
556
|
replace_id_value, replacement_mapping
|
|
584
557
|
).assert_is(exp.Select)
|
|
585
558
|
if optimize:
|
|
586
|
-
quote_identifiers(select_expression, dialect=dialect)
|
|
587
559
|
select_expression = t.cast(
|
|
588
|
-
exp.Select,
|
|
560
|
+
exp.Select,
|
|
561
|
+
self.session._optimize(select_expression),
|
|
589
562
|
)
|
|
590
563
|
elif openai_config:
|
|
591
|
-
qualify(
|
|
564
|
+
qualify(
|
|
565
|
+
select_expression,
|
|
566
|
+
dialect=self.session.input_dialect,
|
|
567
|
+
schema=self.session.catalog._schema,
|
|
568
|
+
)
|
|
592
569
|
pushdown_projections(select_expression, schema=self.session.catalog._schema)
|
|
593
570
|
|
|
594
571
|
select_expression = df._replace_cte_names_with_hashes(select_expression)
|
|
@@ -606,7 +583,7 @@ class _BaseDataFrame(t.Generic[SESSION, WRITER, NA, STAT, GROUP_DATA]):
|
|
|
606
583
|
cache_table_name,
|
|
607
584
|
{
|
|
608
585
|
quote_preserving_alias_or_name(expression): expression.type.sql(
|
|
609
|
-
dialect=
|
|
586
|
+
dialect=self.session.input_dialect
|
|
610
587
|
)
|
|
611
588
|
if expression.type
|
|
612
589
|
else "UNKNOWN"
|
|
@@ -642,10 +619,43 @@ class _BaseDataFrame(t.Generic[SESSION, WRITER, NA, STAT, GROUP_DATA]):
|
|
|
642
619
|
raise ValueError(f"Invalid expression type: {expression_type}")
|
|
643
620
|
|
|
644
621
|
output_expressions.append(expression)
|
|
622
|
+
return output_expressions # type: ignore
|
|
623
|
+
|
|
624
|
+
@t.overload
|
|
625
|
+
def sql(
|
|
626
|
+
self,
|
|
627
|
+
dialect: DialectType = ...,
|
|
628
|
+
optimize: bool = ...,
|
|
629
|
+
pretty: bool = ...,
|
|
630
|
+
*,
|
|
631
|
+
as_list: t.Literal[False],
|
|
632
|
+
**kwargs: t.Any,
|
|
633
|
+
) -> str: ...
|
|
634
|
+
|
|
635
|
+
@t.overload
|
|
636
|
+
def sql(
|
|
637
|
+
self,
|
|
638
|
+
dialect: DialectType = ...,
|
|
639
|
+
optimize: bool = ...,
|
|
640
|
+
pretty: bool = ...,
|
|
641
|
+
*,
|
|
642
|
+
as_list: t.Literal[True],
|
|
643
|
+
**kwargs: t.Any,
|
|
644
|
+
) -> t.List[str]: ...
|
|
645
645
|
|
|
646
|
+
def sql(
|
|
647
|
+
self,
|
|
648
|
+
dialect: DialectType = None,
|
|
649
|
+
optimize: bool = True,
|
|
650
|
+
pretty: bool = True,
|
|
651
|
+
openai_config: t.Optional[t.Union[t.Dict[str, t.Any], OpenAIConfig]] = None,
|
|
652
|
+
as_list: bool = False,
|
|
653
|
+
**kwargs,
|
|
654
|
+
) -> t.Union[str, t.List[str]]:
|
|
655
|
+
dialect = Dialect.get_or_raise(dialect) if dialect else self.session.output_dialect
|
|
646
656
|
results = []
|
|
647
|
-
for expression in
|
|
648
|
-
sql =
|
|
657
|
+
for expression in self._get_expressions(optimize=optimize, openai_config=openai_config):
|
|
658
|
+
sql = self.session._to_sql(expression, dialect=dialect, pretty=pretty, **kwargs)
|
|
649
659
|
if openai_config:
|
|
650
660
|
assert isinstance(openai_config, OpenAIConfig)
|
|
651
661
|
verify_openai_installed()
|
|
@@ -828,7 +838,7 @@ class _BaseDataFrame(t.Generic[SESSION, WRITER, NA, STAT, GROUP_DATA]):
|
|
|
828
838
|
from sqlframe.base.functions import coalesce
|
|
829
839
|
|
|
830
840
|
if on is None:
|
|
831
|
-
logger.warning("Got no value for on. This appears change the join to a cross join.")
|
|
841
|
+
logger.warning("Got no value for on. This appears to change the join to a cross join.")
|
|
832
842
|
how = "cross"
|
|
833
843
|
other_df = other_df._convert_leaf_to_cte()
|
|
834
844
|
# We will determine actual "join on" expression later so we don't provide it at first
|
|
@@ -871,13 +881,12 @@ class _BaseDataFrame(t.Generic[SESSION, WRITER, NA, STAT, GROUP_DATA]):
|
|
|
871
881
|
)
|
|
872
882
|
join_column_pairs.append((left_column, right_column))
|
|
873
883
|
num_matching_ctes += 1
|
|
874
|
-
|
|
875
|
-
|
|
876
|
-
|
|
877
|
-
|
|
878
|
-
elif num_matching_ctes == 0:
|
|
884
|
+
# We only want to match one table to the column and that should be matched left -> right
|
|
885
|
+
# so we break after the first match
|
|
886
|
+
break
|
|
887
|
+
if num_matching_ctes == 0:
|
|
879
888
|
raise ValueError(
|
|
880
|
-
f"Column {join_column.alias_or_name} does not exist in any of the tables."
|
|
889
|
+
f"Column `{join_column.alias_or_name}` does not exist in any of the tables."
|
|
881
890
|
)
|
|
882
891
|
join_clause = functools.reduce(
|
|
883
892
|
lambda x, y: x & y,
|
|
@@ -1154,7 +1163,7 @@ class _BaseDataFrame(t.Generic[SESSION, WRITER, NA, STAT, GROUP_DATA]):
|
|
|
1154
1163
|
if len(sql_queries) > 1:
|
|
1155
1164
|
raise ValueError("Cannot explain a DataFrame with multiple queries")
|
|
1156
1165
|
sql_query = "EXPLAIN " + sql_queries[0]
|
|
1157
|
-
self.session._execute(sql_query
|
|
1166
|
+
self.session._execute(sql_query)
|
|
1158
1167
|
|
|
1159
1168
|
@operation(Operation.FROM)
|
|
1160
1169
|
def fillna(
|
|
@@ -1373,23 +1382,20 @@ class _BaseDataFrame(t.Generic[SESSION, WRITER, NA, STAT, GROUP_DATA]):
|
|
|
1373
1382
|
self._ensure_and_normalize_col(k).alias_or_name: self._ensure_and_normalize_col(v)
|
|
1374
1383
|
for k, v in colsMap[0].items()
|
|
1375
1384
|
}
|
|
1376
|
-
|
|
1377
|
-
|
|
1378
|
-
|
|
1379
|
-
updated_expression = self.expression.copy()
|
|
1380
|
-
select_columns = []
|
|
1385
|
+
existing_cols = self._get_outer_select_columns(self.expression)
|
|
1386
|
+
existing_col_names = [x.alias_or_name for x in existing_cols]
|
|
1387
|
+
select_columns = existing_cols
|
|
1381
1388
|
for column_name, col_value in col_map.items():
|
|
1382
1389
|
existing_col_index = (
|
|
1383
1390
|
existing_col_names.index(column_name) if column_name in existing_col_names else None
|
|
1384
1391
|
)
|
|
1385
|
-
if existing_col_index:
|
|
1386
|
-
|
|
1392
|
+
if existing_col_index is not None:
|
|
1393
|
+
select_columns[existing_col_index] = col_value.alias( # type: ignore
|
|
1387
1394
|
column_name
|
|
1388
1395
|
).expression
|
|
1389
1396
|
else:
|
|
1390
1397
|
select_columns.append(col_value.alias(column_name))
|
|
1391
|
-
|
|
1392
|
-
return df.select.__wrapped__(df, *select_columns, append=True) # type: ignore
|
|
1398
|
+
return self.select.__wrapped__(self, *select_columns) # type: ignore
|
|
1393
1399
|
|
|
1394
1400
|
@operation(Operation.SELECT)
|
|
1395
1401
|
def drop(self, *cols: t.Union[str, Column]) -> Self:
|
|
@@ -1539,10 +1545,10 @@ class _BaseDataFrame(t.Generic[SESSION, WRITER, NA, STAT, GROUP_DATA]):
|
|
|
1539
1545
|
return self._group_data(self, grouping_columns, self.last_op)
|
|
1540
1546
|
|
|
1541
1547
|
def collect(self) -> t.List[Row]:
|
|
1542
|
-
|
|
1543
|
-
|
|
1544
|
-
|
|
1545
|
-
return
|
|
1548
|
+
return self._collect()
|
|
1549
|
+
|
|
1550
|
+
def _collect(self, **kwargs) -> t.List[Row]:
|
|
1551
|
+
return self.session._collect(self._get_expressions(optimize=False), **kwargs)
|
|
1546
1552
|
|
|
1547
1553
|
@t.overload
|
|
1548
1554
|
def head(self) -> t.Optional[Row]: ...
|
|
@@ -1569,11 +1575,7 @@ class _BaseDataFrame(t.Generic[SESSION, WRITER, NA, STAT, GROUP_DATA]):
|
|
|
1569
1575
|
logger.warning("Truncate is ignored so full results will be displayed")
|
|
1570
1576
|
# Make sure that the limit we add doesn't affect the results
|
|
1571
1577
|
df = self._convert_leaf_to_cte()
|
|
1572
|
-
|
|
1573
|
-
pretty=False, optimize=False, dialect=self.session.output_dialect, as_list=True
|
|
1574
|
-
)
|
|
1575
|
-
for sql in ensure_list(sql):
|
|
1576
|
-
result = self.session._fetch_rows(sql)
|
|
1578
|
+
result = df.limit(n).collect()
|
|
1577
1579
|
table = PrettyTable()
|
|
1578
1580
|
if row := seq_get(result, 0):
|
|
1579
1581
|
table.field_names = row._unique_field_names
|
|
@@ -1612,18 +1614,10 @@ class _BaseDataFrame(t.Generic[SESSION, WRITER, NA, STAT, GROUP_DATA]):
|
|
|
1612
1614
|
)
|
|
1613
1615
|
|
|
1614
1616
|
def toPandas(self) -> pd.DataFrame:
|
|
1615
|
-
|
|
1616
|
-
|
|
1617
|
-
)
|
|
1618
|
-
sqls = [None] + self.sql(**sql_kwargs) # type: ignore
|
|
1619
|
-
for sql in self.sql(**sql_kwargs)[:-1]: # type: ignore
|
|
1620
|
-
if sql:
|
|
1621
|
-
self.session._execute(sql)
|
|
1622
|
-
assert sqls[-1] is not None
|
|
1623
|
-
return self.session._fetchdf(sqls[-1])
|
|
1624
|
-
|
|
1625
|
-
@normalize("name")
|
|
1617
|
+
return self.session._fetchdf(self._get_expressions(optimize=False))
|
|
1618
|
+
|
|
1626
1619
|
def createOrReplaceTempView(self, name: str) -> None:
|
|
1620
|
+
name = normalize_string(name, from_dialect="input")
|
|
1627
1621
|
self.session.temp_views[name] = self.copy()._convert_leaf_to_cte()
|
|
1628
1622
|
|
|
1629
1623
|
def count(self) -> int:
|
|
@@ -1632,11 +1626,7 @@ class _BaseDataFrame(t.Generic[SESSION, WRITER, NA, STAT, GROUP_DATA]):
|
|
|
1632
1626
|
|
|
1633
1627
|
df = self._convert_leaf_to_cte()
|
|
1634
1628
|
df = self.copy(expression=df.expression.select("count(*)", append=False))
|
|
1635
|
-
|
|
1636
|
-
dialect=self.session.output_dialect, pretty=False, optimize=False, as_list=True
|
|
1637
|
-
):
|
|
1638
|
-
result = self.session._fetch_rows(sql)
|
|
1639
|
-
return result[0][0]
|
|
1629
|
+
return df.collect()[0][0]
|
|
1640
1630
|
|
|
1641
1631
|
def createGlobalTempView(self, name: str) -> None:
|
|
1642
1632
|
raise NotImplementedError("Global temp views are not yet supported")
|
|
@@ -1816,3 +1806,6 @@ class _BaseDataFrame(t.Generic[SESSION, WRITER, NA, STAT, GROUP_DATA]):
|
|
|
1816
1806
|
col_func = get_func_from_session("col")
|
|
1817
1807
|
|
|
1818
1808
|
return self.select(covar_samp(col_func(col1), col_func(col2))).collect()[0][0]
|
|
1809
|
+
|
|
1810
|
+
def toArrow(self) -> ArrowTable:
|
|
1811
|
+
raise NotImplementedError("Arrow conversion is not supported by this engine")
|
sqlframe/base/decorators.py
CHANGED
|
@@ -1,50 +1,12 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
|
-
import functools
|
|
4
3
|
import typing as t
|
|
5
4
|
|
|
6
|
-
from sqlglot import parse_one
|
|
7
5
|
from sqlglot.helper import ensure_list
|
|
8
|
-
from sqlglot.optimizer.normalize_identifiers import normalize_identifiers
|
|
9
|
-
|
|
10
|
-
if t.TYPE_CHECKING:
|
|
11
|
-
from sqlframe.base.catalog import _BaseCatalog
|
|
12
6
|
|
|
13
7
|
CALLING_CLASS = t.TypeVar("CALLING_CLASS")
|
|
14
8
|
|
|
15
9
|
|
|
16
|
-
def normalize(normalize_kwargs: t.Union[str, t.List[str]]) -> t.Callable[[t.Callable], t.Callable]:
|
|
17
|
-
"""
|
|
18
|
-
Decorator used to normalize identifiers in the kwargs of a method.
|
|
19
|
-
"""
|
|
20
|
-
|
|
21
|
-
def decorator(func: t.Callable) -> t.Callable:
|
|
22
|
-
@functools.wraps(func)
|
|
23
|
-
def wrapper(self: CALLING_CLASS, *args, **kwargs) -> CALLING_CLASS:
|
|
24
|
-
from sqlframe.base.session import _BaseSession
|
|
25
|
-
|
|
26
|
-
input_dialect = _BaseSession().input_dialect
|
|
27
|
-
kwargs.update(dict(zip(func.__code__.co_varnames[1:], args)))
|
|
28
|
-
for kwarg in ensure_list(normalize_kwargs):
|
|
29
|
-
if kwarg in kwargs:
|
|
30
|
-
value = kwargs.get(kwarg)
|
|
31
|
-
if value:
|
|
32
|
-
expression = (
|
|
33
|
-
parse_one(value, dialect=input_dialect)
|
|
34
|
-
if isinstance(value, str)
|
|
35
|
-
else value
|
|
36
|
-
)
|
|
37
|
-
kwargs[kwarg] = normalize_identifiers(expression, input_dialect).sql(
|
|
38
|
-
dialect=input_dialect
|
|
39
|
-
)
|
|
40
|
-
return func(self, **kwargs)
|
|
41
|
-
|
|
42
|
-
wrapper.__wrapped__ = func # type: ignore
|
|
43
|
-
return wrapper
|
|
44
|
-
|
|
45
|
-
return decorator
|
|
46
|
-
|
|
47
|
-
|
|
48
10
|
def func_metadata(unsupported_engines: t.Optional[t.Union[str, t.List[str]]] = None) -> t.Callable:
|
|
49
11
|
def _metadata(func: t.Callable) -> t.Callable:
|
|
50
12
|
func.unsupported_engines = ensure_list(unsupported_engines) if unsupported_engines else [] # type: ignore
|