sqlframe 1.9.0__py3-none-any.whl → 1.11.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sqlframe/_version.py +2 -2
- sqlframe/base/dataframe.py +54 -1
- sqlframe/base/exceptions.py +12 -0
- sqlframe/base/function_alternatives.py +96 -0
- sqlframe/base/functions.py +4013 -1
- sqlframe/base/mixins/dataframe_mixins.py +24 -33
- sqlframe/base/session.py +2 -2
- sqlframe/base/types.py +3 -3
- sqlframe/base/util.py +56 -0
- sqlframe/bigquery/dataframe.py +33 -13
- sqlframe/bigquery/functions.py +4 -0
- sqlframe/bigquery/functions.pyi +37 -1
- sqlframe/duckdb/dataframe.py +6 -15
- sqlframe/duckdb/functions.py +3 -0
- sqlframe/duckdb/functions.pyi +29 -0
- sqlframe/postgres/catalog.py +123 -3
- sqlframe/postgres/dataframe.py +6 -10
- sqlframe/postgres/functions.py +6 -0
- sqlframe/postgres/functions.pyi +28 -0
- sqlframe/redshift/dataframe.py +3 -14
- sqlframe/snowflake/dataframe.py +23 -13
- sqlframe/snowflake/functions.py +3 -0
- sqlframe/snowflake/functions.pyi +27 -0
- sqlframe/spark/dataframe.py +25 -15
- sqlframe/spark/functions.pyi +161 -1
- sqlframe/testing/__init__.py +3 -0
- sqlframe/testing/utils.py +320 -0
- {sqlframe-1.9.0.dist-info → sqlframe-1.11.0.dist-info}/METADATA +1 -1
- {sqlframe-1.9.0.dist-info → sqlframe-1.11.0.dist-info}/RECORD +32 -30
- {sqlframe-1.9.0.dist-info → sqlframe-1.11.0.dist-info}/LICENSE +0 -0
- {sqlframe-1.9.0.dist-info → sqlframe-1.11.0.dist-info}/WHEEL +0 -0
- {sqlframe-1.9.0.dist-info → sqlframe-1.11.0.dist-info}/top_level.txt +0 -0
sqlframe/postgres/catalog.py
CHANGED
|
@@ -7,16 +7,17 @@ import typing as t
|
|
|
7
7
|
|
|
8
8
|
from sqlglot import exp, parse_one
|
|
9
9
|
|
|
10
|
-
from sqlframe.base.catalog import Function, _BaseCatalog
|
|
10
|
+
from sqlframe.base.catalog import Column, Function, _BaseCatalog
|
|
11
|
+
from sqlframe.base.decorators import normalize
|
|
11
12
|
from sqlframe.base.mixins.catalog_mixins import (
|
|
12
13
|
GetCurrentCatalogFromFunctionMixin,
|
|
13
14
|
GetCurrentDatabaseFromFunctionMixin,
|
|
14
15
|
ListCatalogsFromInfoSchemaMixin,
|
|
15
|
-
ListColumnsFromInfoSchemaMixin,
|
|
16
16
|
ListDatabasesFromInfoSchemaMixin,
|
|
17
17
|
ListTablesFromInfoSchemaMixin,
|
|
18
18
|
SetCurrentDatabaseFromSearchPathMixin,
|
|
19
19
|
)
|
|
20
|
+
from sqlframe.base.util import to_schema
|
|
20
21
|
|
|
21
22
|
if t.TYPE_CHECKING:
|
|
22
23
|
from sqlframe.postgres.session import PostgresSession # noqa
|
|
@@ -30,12 +31,131 @@ class PostgresCatalog(
|
|
|
30
31
|
ListCatalogsFromInfoSchemaMixin["PostgresSession", "PostgresDataFrame"],
|
|
31
32
|
SetCurrentDatabaseFromSearchPathMixin["PostgresSession", "PostgresDataFrame"],
|
|
32
33
|
ListTablesFromInfoSchemaMixin["PostgresSession", "PostgresDataFrame"],
|
|
33
|
-
ListColumnsFromInfoSchemaMixin["PostgresSession", "PostgresDataFrame"],
|
|
34
34
|
_BaseCatalog["PostgresSession", "PostgresDataFrame"],
|
|
35
35
|
):
|
|
36
36
|
CURRENT_CATALOG_EXPRESSION: exp.Expression = exp.column("current_catalog")
|
|
37
37
|
TEMP_SCHEMA_FILTER = exp.column("table_schema").like("pg_temp_%")
|
|
38
38
|
|
|
39
|
+
@normalize(["tableName", "dbName"])
|
|
40
|
+
def listColumns(
|
|
41
|
+
self, tableName: str, dbName: t.Optional[str] = None, include_temp: bool = False
|
|
42
|
+
) -> t.List[Column]:
|
|
43
|
+
"""Returns a t.List of columns for the given table/view in the specified database.
|
|
44
|
+
|
|
45
|
+
.. versionadded:: 2.0.0
|
|
46
|
+
|
|
47
|
+
Parameters
|
|
48
|
+
----------
|
|
49
|
+
tableName : str
|
|
50
|
+
name of the table to t.List columns.
|
|
51
|
+
|
|
52
|
+
.. versionchanged:: 3.4.0
|
|
53
|
+
Allow ``tableName`` to be qualified with catalog name when ``dbName`` is None.
|
|
54
|
+
|
|
55
|
+
dbName : str, t.Optional
|
|
56
|
+
name of the database to find the table to t.List columns.
|
|
57
|
+
|
|
58
|
+
Returns
|
|
59
|
+
-------
|
|
60
|
+
t.List
|
|
61
|
+
A t.List of :class:`Column`.
|
|
62
|
+
|
|
63
|
+
Notes
|
|
64
|
+
-----
|
|
65
|
+
The order of arguments here is different from that of its JVM counterpart
|
|
66
|
+
because Python does not support method overloading.
|
|
67
|
+
|
|
68
|
+
If no database is specified, the current database and catalog
|
|
69
|
+
are used. This API includes all temporary views.
|
|
70
|
+
|
|
71
|
+
Examples
|
|
72
|
+
--------
|
|
73
|
+
>>> _ = spark.sql("DROP TABLE IF EXISTS tbl1")
|
|
74
|
+
>>> _ = spark.sql("CREATE TABLE tblA (name STRING, age INT) USING parquet")
|
|
75
|
+
>>> spark.catalog.t.listColumns("tblA")
|
|
76
|
+
[Column(name='name', description=None, dataType='string', nullable=True, ...
|
|
77
|
+
>>> _ = spark.sql("DROP TABLE tblA")
|
|
78
|
+
"""
|
|
79
|
+
if df := self.session.temp_views.get(tableName):
|
|
80
|
+
return [
|
|
81
|
+
Column(
|
|
82
|
+
name=x,
|
|
83
|
+
description=None,
|
|
84
|
+
dataType="",
|
|
85
|
+
nullable=True,
|
|
86
|
+
isPartition=False,
|
|
87
|
+
isBucket=False,
|
|
88
|
+
)
|
|
89
|
+
for x in df.columns
|
|
90
|
+
]
|
|
91
|
+
|
|
92
|
+
table = exp.to_table(tableName, dialect=self.session.input_dialect)
|
|
93
|
+
schema = to_schema(dbName, dialect=self.session.input_dialect) if dbName else None
|
|
94
|
+
if not table.db:
|
|
95
|
+
if schema and schema.db:
|
|
96
|
+
table.set("db", schema.args["db"])
|
|
97
|
+
else:
|
|
98
|
+
table.set(
|
|
99
|
+
"db",
|
|
100
|
+
exp.parse_identifier(
|
|
101
|
+
self.currentDatabase(), dialect=self.session.input_dialect
|
|
102
|
+
),
|
|
103
|
+
)
|
|
104
|
+
if not table.catalog:
|
|
105
|
+
if schema and schema.catalog:
|
|
106
|
+
table.set("catalog", schema.args["catalog"])
|
|
107
|
+
else:
|
|
108
|
+
table.set(
|
|
109
|
+
"catalog",
|
|
110
|
+
exp.parse_identifier(self.currentCatalog(), dialect=self.session.input_dialect),
|
|
111
|
+
)
|
|
112
|
+
source_table = self._get_info_schema_table("columns", database=table.db)
|
|
113
|
+
select = parse_one(
|
|
114
|
+
f"""
|
|
115
|
+
SELECT
|
|
116
|
+
att.attname AS column_name,
|
|
117
|
+
pg_catalog.format_type(att.atttypid, NULL) AS data_type,
|
|
118
|
+
col.is_nullable
|
|
119
|
+
FROM
|
|
120
|
+
pg_catalog.pg_attribute att
|
|
121
|
+
JOIN
|
|
122
|
+
pg_catalog.pg_class cls ON cls.oid = att.attrelid
|
|
123
|
+
JOIN
|
|
124
|
+
pg_catalog.pg_namespace nsp ON nsp.oid = cls.relnamespace
|
|
125
|
+
JOIN
|
|
126
|
+
information_schema.columns col ON col.table_schema = nsp.nspname AND col.table_name = cls.relname AND col.column_name = att.attname
|
|
127
|
+
WHERE
|
|
128
|
+
cls.relname = '{table.name}' AND -- replace with your table name
|
|
129
|
+
att.attnum > 0 AND
|
|
130
|
+
NOT att.attisdropped
|
|
131
|
+
ORDER BY
|
|
132
|
+
att.attnum;
|
|
133
|
+
""",
|
|
134
|
+
dialect="postgres",
|
|
135
|
+
)
|
|
136
|
+
if table.db:
|
|
137
|
+
schema_filter: exp.Expression = exp.column("table_schema").eq(table.db)
|
|
138
|
+
if include_temp and self.TEMP_SCHEMA_FILTER:
|
|
139
|
+
schema_filter = exp.Or(this=schema_filter, expression=self.TEMP_SCHEMA_FILTER)
|
|
140
|
+
select = select.where(schema_filter) # type: ignore
|
|
141
|
+
if table.catalog:
|
|
142
|
+
catalog_filter: exp.Expression = exp.column("table_catalog").eq(table.catalog)
|
|
143
|
+
if include_temp and self.TEMP_CATALOG_FILTER:
|
|
144
|
+
catalog_filter = exp.Or(this=catalog_filter, expression=self.TEMP_CATALOG_FILTER)
|
|
145
|
+
select = select.where(catalog_filter) # type: ignore
|
|
146
|
+
results = self.session._fetch_rows(select)
|
|
147
|
+
return [
|
|
148
|
+
Column(
|
|
149
|
+
name=x["column_name"],
|
|
150
|
+
description=None,
|
|
151
|
+
dataType=x["data_type"],
|
|
152
|
+
nullable=x["is_nullable"] == "YES",
|
|
153
|
+
isPartition=False,
|
|
154
|
+
isBucket=False,
|
|
155
|
+
)
|
|
156
|
+
for x in results
|
|
157
|
+
]
|
|
158
|
+
|
|
39
159
|
def listFunctions(
|
|
40
160
|
self, dbName: t.Optional[str] = None, pattern: t.Optional[str] = None
|
|
41
161
|
) -> t.List[Function]:
|
sqlframe/postgres/dataframe.py
CHANGED
|
@@ -9,7 +9,10 @@ from sqlframe.base.dataframe import (
|
|
|
9
9
|
_BaseDataFrameNaFunctions,
|
|
10
10
|
_BaseDataFrameStatFunctions,
|
|
11
11
|
)
|
|
12
|
-
from sqlframe.base.mixins.dataframe_mixins import
|
|
12
|
+
from sqlframe.base.mixins.dataframe_mixins import (
|
|
13
|
+
NoCachePersistSupportMixin,
|
|
14
|
+
TypedColumnsFromTempViewMixin,
|
|
15
|
+
)
|
|
13
16
|
from sqlframe.postgres.group import PostgresGroupedData
|
|
14
17
|
|
|
15
18
|
if sys.version_info >= (3, 11):
|
|
@@ -34,7 +37,8 @@ class PostgresDataFrameStatFunctions(_BaseDataFrameStatFunctions["PostgresDataFr
|
|
|
34
37
|
|
|
35
38
|
|
|
36
39
|
class PostgresDataFrame(
|
|
37
|
-
|
|
40
|
+
NoCachePersistSupportMixin,
|
|
41
|
+
TypedColumnsFromTempViewMixin,
|
|
38
42
|
_BaseDataFrame[
|
|
39
43
|
"PostgresSession",
|
|
40
44
|
"PostgresDataFrameWriter",
|
|
@@ -46,11 +50,3 @@ class PostgresDataFrame(
|
|
|
46
50
|
_na = PostgresDataFrameNaFunctions
|
|
47
51
|
_stat = PostgresDataFrameStatFunctions
|
|
48
52
|
_group_data = PostgresGroupedData
|
|
49
|
-
|
|
50
|
-
def cache(self) -> Self:
|
|
51
|
-
logger.warning("Postgres does not support caching. Ignoring cache() call.")
|
|
52
|
-
return self
|
|
53
|
-
|
|
54
|
-
def persist(self) -> Self:
|
|
55
|
-
logger.warning("Postgres does not support persist. Ignoring persist() call.")
|
|
56
|
-
return self
|
sqlframe/postgres/functions.py
CHANGED
|
@@ -16,6 +16,7 @@ globals().update(
|
|
|
16
16
|
|
|
17
17
|
|
|
18
18
|
from sqlframe.base.function_alternatives import ( # noqa
|
|
19
|
+
any_value_ignore_nulls_not_supported as any_value,
|
|
19
20
|
e_literal as e,
|
|
20
21
|
expm1_from_exp as expm1,
|
|
21
22
|
log1p_from_log as log1p,
|
|
@@ -40,6 +41,7 @@ from sqlframe.base.function_alternatives import ( # noqa
|
|
|
40
41
|
date_add_by_multiplication as date_add,
|
|
41
42
|
date_sub_by_multiplication as date_sub,
|
|
42
43
|
date_diff_with_subtraction as date_diff,
|
|
44
|
+
date_diff_with_subtraction as datediff,
|
|
43
45
|
add_months_by_multiplication as add_months,
|
|
44
46
|
months_between_from_age_and_extract as months_between,
|
|
45
47
|
from_unixtime_from_timestamp as from_unixtime,
|
|
@@ -58,4 +60,8 @@ from sqlframe.base.function_alternatives import ( # noqa
|
|
|
58
60
|
get_json_object_using_arrow_op as get_json_object,
|
|
59
61
|
array_min_from_subquery as array_min,
|
|
60
62
|
array_max_from_subquery as array_max,
|
|
63
|
+
left_cast_len as left,
|
|
64
|
+
right_cast_len as right,
|
|
65
|
+
position_cast_start as position,
|
|
66
|
+
try_element_at_zero_based as try_element_at,
|
|
61
67
|
)
|
sqlframe/postgres/functions.pyi
CHANGED
|
@@ -1,4 +1,5 @@
|
|
|
1
1
|
from sqlframe.base.function_alternatives import ( # noqa
|
|
2
|
+
any_value_ignore_nulls_not_supported as any_value,
|
|
2
3
|
e_literal as e,
|
|
3
4
|
expm1_from_exp as expm1,
|
|
4
5
|
log1p_from_log as log1p,
|
|
@@ -23,6 +24,7 @@ from sqlframe.base.function_alternatives import ( # noqa
|
|
|
23
24
|
date_add_by_multiplication as date_add,
|
|
24
25
|
date_sub_by_multiplication as date_sub,
|
|
25
26
|
date_diff_with_subtraction as date_diff,
|
|
27
|
+
date_diff_with_subtraction as datediff,
|
|
26
28
|
add_months_by_multiplication as add_months,
|
|
27
29
|
months_between_from_age_and_extract as months_between,
|
|
28
30
|
from_unixtime_from_timestamp as from_unixtime,
|
|
@@ -41,6 +43,10 @@ from sqlframe.base.function_alternatives import ( # noqa
|
|
|
41
43
|
get_json_object_using_arrow_op as get_json_object,
|
|
42
44
|
array_min_from_subquery as array_min,
|
|
43
45
|
array_max_from_subquery as array_max,
|
|
46
|
+
left_cast_len as left,
|
|
47
|
+
right_cast_len as right,
|
|
48
|
+
position_cast_start as position,
|
|
49
|
+
try_element_at_zero_based as try_element_at,
|
|
44
50
|
)
|
|
45
51
|
from sqlframe.base.functions import (
|
|
46
52
|
abs as abs,
|
|
@@ -64,9 +70,13 @@ from sqlframe.base.functions import (
|
|
|
64
70
|
bit_length as bit_length,
|
|
65
71
|
bitwiseNOT as bitwiseNOT,
|
|
66
72
|
bitwise_not as bitwise_not,
|
|
73
|
+
bool_and as bool_and,
|
|
74
|
+
bool_or as bool_or,
|
|
75
|
+
call_function as call_function,
|
|
67
76
|
cbrt as cbrt,
|
|
68
77
|
ceil as ceil,
|
|
69
78
|
ceiling as ceiling,
|
|
79
|
+
char as char,
|
|
70
80
|
coalesce as coalesce,
|
|
71
81
|
col as col,
|
|
72
82
|
collect_list as collect_list,
|
|
@@ -84,8 +94,10 @@ from sqlframe.base.functions import (
|
|
|
84
94
|
cume_dist as cume_dist,
|
|
85
95
|
current_date as current_date,
|
|
86
96
|
current_timestamp as current_timestamp,
|
|
97
|
+
current_user as current_user,
|
|
87
98
|
date_format as date_format,
|
|
88
99
|
date_trunc as date_trunc,
|
|
100
|
+
dateadd as dateadd,
|
|
89
101
|
degrees as degrees,
|
|
90
102
|
dense_rank as dense_rank,
|
|
91
103
|
desc as desc,
|
|
@@ -94,18 +106,22 @@ from sqlframe.base.functions import (
|
|
|
94
106
|
exp as exp,
|
|
95
107
|
explode as explode,
|
|
96
108
|
expr as expr,
|
|
109
|
+
extract as extract,
|
|
97
110
|
factorial as factorial,
|
|
98
111
|
floor as floor,
|
|
99
112
|
greatest as greatest,
|
|
113
|
+
ifnull as ifnull,
|
|
100
114
|
initcap as initcap,
|
|
101
115
|
input_file_name as input_file_name,
|
|
102
116
|
instr as instr,
|
|
103
117
|
lag as lag,
|
|
118
|
+
lcase as lcase,
|
|
104
119
|
lead as lead,
|
|
105
120
|
least as least,
|
|
106
121
|
length as length,
|
|
107
122
|
levenshtein as levenshtein,
|
|
108
123
|
lit as lit,
|
|
124
|
+
ln as ln,
|
|
109
125
|
locate as locate,
|
|
110
126
|
log as log,
|
|
111
127
|
log10 as log10,
|
|
@@ -117,19 +133,25 @@ from sqlframe.base.functions import (
|
|
|
117
133
|
md5 as md5,
|
|
118
134
|
mean as mean,
|
|
119
135
|
min as min,
|
|
136
|
+
now as now,
|
|
120
137
|
nth_value as nth_value,
|
|
121
138
|
ntile as ntile,
|
|
122
139
|
nullif as nullif,
|
|
140
|
+
nvl as nvl,
|
|
141
|
+
nvl2 as nvl2,
|
|
123
142
|
octet_length as octet_length,
|
|
124
143
|
overlay as overlay,
|
|
125
144
|
percent_rank as percent_rank,
|
|
126
145
|
percentile as percentile,
|
|
127
146
|
pow as pow,
|
|
147
|
+
power as power,
|
|
128
148
|
radians as radians,
|
|
129
149
|
rank as rank,
|
|
150
|
+
regexp_like as regexp_like,
|
|
130
151
|
regexp_replace as regexp_replace,
|
|
131
152
|
repeat as repeat,
|
|
132
153
|
reverse as reverse,
|
|
154
|
+
rlike as rlike,
|
|
133
155
|
row_number as row_number,
|
|
134
156
|
rpad as rpad,
|
|
135
157
|
rtrim as rtrim,
|
|
@@ -137,12 +159,14 @@ from sqlframe.base.functions import (
|
|
|
137
159
|
shiftRight as shiftRight,
|
|
138
160
|
shiftleft as shiftleft,
|
|
139
161
|
shiftright as shiftright,
|
|
162
|
+
sign as sign,
|
|
140
163
|
signum as signum,
|
|
141
164
|
sin as sin,
|
|
142
165
|
sinh as sinh,
|
|
143
166
|
size as size,
|
|
144
167
|
soundex as soundex,
|
|
145
168
|
sqrt as sqrt,
|
|
169
|
+
startswith as startswith,
|
|
146
170
|
stddev as stddev,
|
|
147
171
|
stddev_pop as stddev_pop,
|
|
148
172
|
stddev_samp as stddev_samp,
|
|
@@ -156,11 +180,15 @@ from sqlframe.base.functions import (
|
|
|
156
180
|
toDegrees as toDegrees,
|
|
157
181
|
toRadians as toRadians,
|
|
158
182
|
to_date as to_date,
|
|
183
|
+
to_number as to_number,
|
|
159
184
|
to_timestamp as to_timestamp,
|
|
160
185
|
translate as translate,
|
|
161
186
|
trim as trim,
|
|
162
187
|
trunc as trunc,
|
|
188
|
+
ucase as ucase,
|
|
189
|
+
unix_date as unix_date,
|
|
163
190
|
upper as upper,
|
|
191
|
+
user as user,
|
|
164
192
|
var_pop as var_pop,
|
|
165
193
|
var_samp as var_samp,
|
|
166
194
|
variance as variance,
|
sqlframe/redshift/dataframe.py
CHANGED
|
@@ -9,13 +9,9 @@ from sqlframe.base.dataframe import (
|
|
|
9
9
|
_BaseDataFrameNaFunctions,
|
|
10
10
|
_BaseDataFrameStatFunctions,
|
|
11
11
|
)
|
|
12
|
+
from sqlframe.base.mixins.dataframe_mixins import NoCachePersistSupportMixin
|
|
12
13
|
from sqlframe.redshift.group import RedshiftGroupedData
|
|
13
14
|
|
|
14
|
-
if sys.version_info >= (3, 11):
|
|
15
|
-
from typing import Self
|
|
16
|
-
else:
|
|
17
|
-
from typing_extensions import Self
|
|
18
|
-
|
|
19
15
|
if t.TYPE_CHECKING:
|
|
20
16
|
from sqlframe.redshift.readwriter import RedshiftDataFrameWriter
|
|
21
17
|
from sqlframe.redshift.session import RedshiftSession
|
|
@@ -33,22 +29,15 @@ class RedshiftDataFrameStatFunctions(_BaseDataFrameStatFunctions["RedshiftDataFr
|
|
|
33
29
|
|
|
34
30
|
|
|
35
31
|
class RedshiftDataFrame(
|
|
32
|
+
NoCachePersistSupportMixin,
|
|
36
33
|
_BaseDataFrame[
|
|
37
34
|
"RedshiftSession",
|
|
38
35
|
"RedshiftDataFrameWriter",
|
|
39
36
|
"RedshiftDataFrameNaFunctions",
|
|
40
37
|
"RedshiftDataFrameStatFunctions",
|
|
41
38
|
"RedshiftGroupedData",
|
|
42
|
-
]
|
|
39
|
+
],
|
|
43
40
|
):
|
|
44
41
|
_na = RedshiftDataFrameNaFunctions
|
|
45
42
|
_stat = RedshiftDataFrameStatFunctions
|
|
46
43
|
_group_data = RedshiftGroupedData
|
|
47
|
-
|
|
48
|
-
def cache(self) -> Self:
|
|
49
|
-
logger.warning("Redshift does not support caching. Ignoring cache() call.")
|
|
50
|
-
return self
|
|
51
|
-
|
|
52
|
-
def persist(self) -> Self:
|
|
53
|
-
logger.warning("Redshift does not support persist. Ignoring persist() call.")
|
|
54
|
-
return self
|
sqlframe/snowflake/dataframe.py
CHANGED
|
@@ -4,18 +4,15 @@ import logging
|
|
|
4
4
|
import sys
|
|
5
5
|
import typing as t
|
|
6
6
|
|
|
7
|
+
from sqlframe.base.catalog import Column as CatalogColumn
|
|
7
8
|
from sqlframe.base.dataframe import (
|
|
8
9
|
_BaseDataFrame,
|
|
9
10
|
_BaseDataFrameNaFunctions,
|
|
10
11
|
_BaseDataFrameStatFunctions,
|
|
11
12
|
)
|
|
13
|
+
from sqlframe.base.mixins.dataframe_mixins import NoCachePersistSupportMixin
|
|
12
14
|
from sqlframe.snowflake.group import SnowflakeGroupedData
|
|
13
15
|
|
|
14
|
-
if sys.version_info >= (3, 11):
|
|
15
|
-
from typing import Self
|
|
16
|
-
else:
|
|
17
|
-
from typing_extensions import Self
|
|
18
|
-
|
|
19
16
|
if t.TYPE_CHECKING:
|
|
20
17
|
from sqlframe.snowflake.readwriter import SnowflakeDataFrameWriter
|
|
21
18
|
from sqlframe.snowflake.session import SnowflakeSession
|
|
@@ -33,22 +30,35 @@ class SnowflakeDataFrameStatFunctions(_BaseDataFrameStatFunctions["SnowflakeData
|
|
|
33
30
|
|
|
34
31
|
|
|
35
32
|
class SnowflakeDataFrame(
|
|
33
|
+
NoCachePersistSupportMixin,
|
|
36
34
|
_BaseDataFrame[
|
|
37
35
|
"SnowflakeSession",
|
|
38
36
|
"SnowflakeDataFrameWriter",
|
|
39
37
|
"SnowflakeDataFrameNaFunctions",
|
|
40
38
|
"SnowflakeDataFrameStatFunctions",
|
|
41
39
|
"SnowflakeGroupedData",
|
|
42
|
-
]
|
|
40
|
+
],
|
|
43
41
|
):
|
|
44
42
|
_na = SnowflakeDataFrameNaFunctions
|
|
45
43
|
_stat = SnowflakeDataFrameStatFunctions
|
|
46
44
|
_group_data = SnowflakeGroupedData
|
|
47
45
|
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
46
|
+
@property
|
|
47
|
+
def _typed_columns(self) -> t.List[CatalogColumn]:
|
|
48
|
+
df = self._convert_leaf_to_cte()
|
|
49
|
+
df = df.limit(0)
|
|
50
|
+
self.session._execute(df.expression)
|
|
51
|
+
query_id = self.session._cur.sfqid
|
|
52
|
+
columns = []
|
|
53
|
+
for row in self.session._fetch_rows(f"DESCRIBE RESULT '{query_id}'"):
|
|
54
|
+
columns.append(
|
|
55
|
+
CatalogColumn(
|
|
56
|
+
name=row.name,
|
|
57
|
+
dataType=row.type,
|
|
58
|
+
nullable=row["null?"] == "Y",
|
|
59
|
+
description=row.comment,
|
|
60
|
+
isPartition=False,
|
|
61
|
+
isBucket=False,
|
|
62
|
+
)
|
|
63
|
+
)
|
|
64
|
+
return columns
|
sqlframe/snowflake/functions.py
CHANGED
|
@@ -16,6 +16,7 @@ globals().update(
|
|
|
16
16
|
|
|
17
17
|
|
|
18
18
|
from sqlframe.base.function_alternatives import ( # noqa
|
|
19
|
+
any_value_ignore_nulls_not_supported as any_value,
|
|
19
20
|
e_literal as e,
|
|
20
21
|
expm1_from_exp as expm1,
|
|
21
22
|
log1p_from_log as log1p,
|
|
@@ -32,6 +33,7 @@ from sqlframe.base.function_alternatives import ( # noqa
|
|
|
32
33
|
struct_with_eq as struct,
|
|
33
34
|
make_date_date_from_parts as make_date,
|
|
34
35
|
date_add_no_date_sub as date_add,
|
|
36
|
+
date_add_no_date_sub as dateadd,
|
|
35
37
|
date_sub_by_date_add as date_sub,
|
|
36
38
|
add_months_using_func as add_months,
|
|
37
39
|
months_between_cast_as_date_cast_roundoff as months_between,
|
|
@@ -60,4 +62,5 @@ from sqlframe.base.function_alternatives import ( # noqa
|
|
|
60
62
|
flatten_using_array_flatten as flatten,
|
|
61
63
|
map_concat_using_map_cat as map_concat,
|
|
62
64
|
sequence_from_array_generate_range as sequence,
|
|
65
|
+
to_number_using_to_double as to_number,
|
|
63
66
|
)
|
sqlframe/snowflake/functions.pyi
CHANGED
|
@@ -1,4 +1,5 @@
|
|
|
1
1
|
from sqlframe.base.function_alternatives import ( # noqa
|
|
2
|
+
any_value_ignore_nulls_not_supported as any_value,
|
|
2
3
|
e_literal as e,
|
|
3
4
|
expm1_from_exp as expm1,
|
|
4
5
|
log1p_from_log as log1p,
|
|
@@ -15,6 +16,7 @@ from sqlframe.base.function_alternatives import ( # noqa
|
|
|
15
16
|
struct_with_eq as struct,
|
|
16
17
|
make_date_date_from_parts as make_date,
|
|
17
18
|
date_add_no_date_sub as date_add,
|
|
19
|
+
date_add_no_date_sub as dateadd,
|
|
18
20
|
date_sub_by_date_add as date_sub,
|
|
19
21
|
add_months_using_func as add_months,
|
|
20
22
|
months_between_cast_as_date_cast_roundoff as months_between,
|
|
@@ -43,6 +45,7 @@ from sqlframe.base.function_alternatives import ( # noqa
|
|
|
43
45
|
flatten_using_array_flatten as flatten,
|
|
44
46
|
map_concat_using_map_cat as map_concat,
|
|
45
47
|
sequence_from_array_generate_range as sequence,
|
|
48
|
+
to_number_using_to_double as to_number,
|
|
46
49
|
)
|
|
47
50
|
from sqlframe.base.functions import (
|
|
48
51
|
abs as abs,
|
|
@@ -69,9 +72,13 @@ from sqlframe.base.functions import (
|
|
|
69
72
|
avg as avg,
|
|
70
73
|
bit_length as bit_length,
|
|
71
74
|
bitwiseNOT as bitwiseNOT,
|
|
75
|
+
bool_and as bool_and,
|
|
76
|
+
bool_or as bool_or,
|
|
77
|
+
call_function as call_function,
|
|
72
78
|
cbrt as cbrt,
|
|
73
79
|
ceil as ceil,
|
|
74
80
|
ceiling as ceiling,
|
|
81
|
+
char as char,
|
|
75
82
|
coalesce as coalesce,
|
|
76
83
|
col as col,
|
|
77
84
|
collect_list as collect_list,
|
|
@@ -85,14 +92,17 @@ from sqlframe.base.functions import (
|
|
|
85
92
|
count as count,
|
|
86
93
|
countDistinct as countDistinct,
|
|
87
94
|
count_distinct as count_distinct,
|
|
95
|
+
count_if as count_if,
|
|
88
96
|
covar_pop as covar_pop,
|
|
89
97
|
covar_samp as covar_samp,
|
|
90
98
|
cume_dist as cume_dist,
|
|
91
99
|
current_date as current_date,
|
|
92
100
|
current_timestamp as current_timestamp,
|
|
101
|
+
current_user as current_user,
|
|
93
102
|
date_diff as date_diff,
|
|
94
103
|
date_format as date_format,
|
|
95
104
|
date_trunc as date_trunc,
|
|
105
|
+
datediff as datediff,
|
|
96
106
|
dayofmonth as dayofmonth,
|
|
97
107
|
dayofweek as dayofweek,
|
|
98
108
|
dayofyear as dayofyear,
|
|
@@ -104,21 +114,26 @@ from sqlframe.base.functions import (
|
|
|
104
114
|
exp as exp,
|
|
105
115
|
explode as explode,
|
|
106
116
|
expr as expr,
|
|
117
|
+
extract as extract,
|
|
107
118
|
factorial as factorial,
|
|
108
119
|
floor as floor,
|
|
109
120
|
greatest as greatest,
|
|
110
121
|
grouping_id as grouping_id,
|
|
111
122
|
hash as hash,
|
|
112
123
|
hour as hour,
|
|
124
|
+
ifnull as ifnull,
|
|
113
125
|
initcap as initcap,
|
|
114
126
|
input_file_name as input_file_name,
|
|
115
127
|
instr as instr,
|
|
116
128
|
kurtosis as kurtosis,
|
|
117
129
|
lag as lag,
|
|
130
|
+
lcase as lcase,
|
|
118
131
|
lead as lead,
|
|
119
132
|
least as least,
|
|
133
|
+
left as left,
|
|
120
134
|
length as length,
|
|
121
135
|
lit as lit,
|
|
136
|
+
ln as ln,
|
|
122
137
|
locate as locate,
|
|
123
138
|
log as log,
|
|
124
139
|
log10 as log10,
|
|
@@ -136,35 +151,44 @@ from sqlframe.base.functions import (
|
|
|
136
151
|
minute as minute,
|
|
137
152
|
month as month,
|
|
138
153
|
next_day as next_day,
|
|
154
|
+
now as now,
|
|
139
155
|
nth_value as nth_value,
|
|
140
156
|
ntile as ntile,
|
|
141
157
|
nullif as nullif,
|
|
158
|
+
nvl as nvl,
|
|
159
|
+
nvl2 as nvl2,
|
|
142
160
|
octet_length as octet_length,
|
|
143
161
|
percent_rank as percent_rank,
|
|
144
162
|
percentile as percentile,
|
|
145
163
|
posexplode as posexplode,
|
|
164
|
+
position as position,
|
|
146
165
|
pow as pow,
|
|
166
|
+
power as power,
|
|
147
167
|
quarter as quarter,
|
|
148
168
|
radians as radians,
|
|
149
169
|
rand as rand,
|
|
150
170
|
rank as rank,
|
|
151
171
|
regexp_replace as regexp_replace,
|
|
152
172
|
repeat as repeat,
|
|
173
|
+
right as right,
|
|
153
174
|
round as round,
|
|
154
175
|
row_number as row_number,
|
|
155
176
|
rpad as rpad,
|
|
156
177
|
rtrim as rtrim,
|
|
157
178
|
second as second,
|
|
179
|
+
sha as sha,
|
|
158
180
|
sha1 as sha1,
|
|
159
181
|
sha2 as sha2,
|
|
160
182
|
shiftLeft as shiftLeft,
|
|
161
183
|
shiftRight as shiftRight,
|
|
184
|
+
sign as sign,
|
|
162
185
|
signum as signum,
|
|
163
186
|
sin as sin,
|
|
164
187
|
sinh as sinh,
|
|
165
188
|
size as size,
|
|
166
189
|
soundex as soundex,
|
|
167
190
|
sqrt as sqrt,
|
|
191
|
+
startswith as startswith,
|
|
168
192
|
stddev as stddev,
|
|
169
193
|
stddev_pop as stddev_pop,
|
|
170
194
|
stddev_samp as stddev_samp,
|
|
@@ -182,7 +206,10 @@ from sqlframe.base.functions import (
|
|
|
182
206
|
translate as translate,
|
|
183
207
|
trim as trim,
|
|
184
208
|
trunc as trunc,
|
|
209
|
+
ucase as ucase,
|
|
210
|
+
unix_date as unix_date,
|
|
185
211
|
upper as upper,
|
|
212
|
+
user as user,
|
|
186
213
|
var_pop as var_pop,
|
|
187
214
|
var_samp as var_samp,
|
|
188
215
|
variance as variance,
|
sqlframe/spark/dataframe.py
CHANGED
|
@@ -1,26 +1,23 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
3
|
import logging
|
|
4
|
-
import sys
|
|
5
4
|
import typing as t
|
|
6
5
|
|
|
6
|
+
from sqlglot import exp
|
|
7
|
+
|
|
8
|
+
from sqlframe.base.catalog import Column
|
|
7
9
|
from sqlframe.base.dataframe import (
|
|
8
10
|
_BaseDataFrame,
|
|
9
11
|
_BaseDataFrameNaFunctions,
|
|
10
12
|
_BaseDataFrameStatFunctions,
|
|
11
13
|
)
|
|
14
|
+
from sqlframe.base.mixins.dataframe_mixins import NoCachePersistSupportMixin
|
|
12
15
|
from sqlframe.spark.group import SparkGroupedData
|
|
13
16
|
|
|
14
|
-
if sys.version_info >= (3, 11):
|
|
15
|
-
from typing import Self
|
|
16
|
-
else:
|
|
17
|
-
from typing_extensions import Self
|
|
18
|
-
|
|
19
17
|
if t.TYPE_CHECKING:
|
|
20
18
|
from sqlframe.spark.readwriter import SparkDataFrameWriter
|
|
21
19
|
from sqlframe.spark.session import SparkSession
|
|
22
20
|
|
|
23
|
-
|
|
24
21
|
logger = logging.getLogger(__name__)
|
|
25
22
|
|
|
26
23
|
|
|
@@ -33,22 +30,35 @@ class SparkDataFrameStatFunctions(_BaseDataFrameStatFunctions["SparkDataFrame"])
|
|
|
33
30
|
|
|
34
31
|
|
|
35
32
|
class SparkDataFrame(
|
|
33
|
+
NoCachePersistSupportMixin,
|
|
36
34
|
_BaseDataFrame[
|
|
37
35
|
"SparkSession",
|
|
38
36
|
"SparkDataFrameWriter",
|
|
39
37
|
"SparkDataFrameNaFunctions",
|
|
40
38
|
"SparkDataFrameStatFunctions",
|
|
41
39
|
"SparkGroupedData",
|
|
42
|
-
]
|
|
40
|
+
],
|
|
43
41
|
):
|
|
44
42
|
_na = SparkDataFrameNaFunctions
|
|
45
43
|
_stat = SparkDataFrameStatFunctions
|
|
46
44
|
_group_data = SparkGroupedData
|
|
47
45
|
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
46
|
+
@property
|
|
47
|
+
def _typed_columns(self) -> t.List[Column]:
|
|
48
|
+
columns = []
|
|
49
|
+
for field in self.session.spark_session.sql(
|
|
50
|
+
self.session._to_sql(self.expression)
|
|
51
|
+
).schema.fields:
|
|
52
|
+
columns.append(
|
|
53
|
+
Column(
|
|
54
|
+
name=field.name,
|
|
55
|
+
dataType=exp.DataType.build(field.dataType.simpleString(), dialect="spark").sql(
|
|
56
|
+
dialect="spark"
|
|
57
|
+
),
|
|
58
|
+
nullable=field.nullable,
|
|
59
|
+
description=None,
|
|
60
|
+
isPartition=False,
|
|
61
|
+
isBucket=False,
|
|
62
|
+
)
|
|
63
|
+
)
|
|
64
|
+
return columns
|