sqlframe 3.35.1__py3-none-any.whl → 3.36.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sqlframe/_version.py +2 -2
- sqlframe/base/function_alternatives.py +0 -4
- sqlframe/base/functions.py +14 -17
- sqlframe/base/group.py +121 -2
- sqlframe/databricks/session.py +51 -2
- {sqlframe-3.35.1.dist-info → sqlframe-3.36.0.dist-info}/METADATA +3 -3
- {sqlframe-3.35.1.dist-info → sqlframe-3.36.0.dist-info}/RECORD +10 -10
- {sqlframe-3.35.1.dist-info → sqlframe-3.36.0.dist-info}/LICENSE +0 -0
- {sqlframe-3.35.1.dist-info → sqlframe-3.36.0.dist-info}/WHEEL +0 -0
- {sqlframe-3.35.1.dist-info → sqlframe-3.36.0.dist-info}/top_level.txt +0 -0
sqlframe/_version.py
CHANGED

sqlframe/base/function_alternatives.py
CHANGED

@@ -1300,10 +1300,6 @@ def day_with_try_to_timestamp(col: ColumnOrName) -> Column:
     )


-def endswith_with_underscore(str: ColumnOrName, suffix: ColumnOrName) -> Column:
-    return Column.invoke_anonymous_function(str, "ENDS_WITH", suffix)
-
-
 def endswith_using_like(str: ColumnOrName, suffix: ColumnOrName) -> Column:
     concat = get_func_from_session("concat")
     lit = get_func_from_session("lit")
sqlframe/base/functions.py
CHANGED

@@ -2288,14 +2288,14 @@ def array_distinct(col: ColumnOrName) -> Column:

 @meta(unsupported_engines=["bigquery", "postgres"])
 def array_intersect(col1: ColumnOrName, col2: ColumnOrName) -> Column:
-
-
-
-
-
-
-
-
+    return Column(
+        expression.ArrayIntersect(
+            expressions=[
+                Column.ensure_col(col1).column_expression,
+                Column.ensure_col(col2).column_expression,
+            ]
+        )
+    )


 @meta(unsupported_engines=["postgres"])

@@ -3226,18 +3226,16 @@ def elt(*inputs: ColumnOrName) -> Column:
 def endswith(str: ColumnOrName, suffix: ColumnOrName) -> Column:
     from sqlframe.base.function_alternatives import (
         endswith_using_like,
-        endswith_with_underscore,
     )

     session = _get_session()

-    if session._is_bigquery or session._is_duckdb:
-        return endswith_with_underscore(str, suffix)
-
     if session._is_postgres:
         return endswith_using_like(str, suffix)

-    return Column.
+    return Column.invoke_expression_over_column(
+        str, expression.EndsWith, expression=Column.ensure_col(suffix).column_expression
+    )


 @meta(unsupported_engines="*")

@@ -5655,10 +5653,9 @@ def replace(
     ):
         replace = expression.Literal.string("")  # type: ignore

-
-
-
-    return Column.invoke_anonymous_function(src, "replace", search)
+    return Column.invoke_expression_over_column(
+        src, expression.Replace, expression=search, replacement=replace
+    )


 @meta()
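The practical effect of the functions.py changes is that array_intersect, endswith, and replace now build typed sqlglot expression nodes (ArrayIntersect, EndsWith, Replace) instead of emitting anonymous SQL function calls, which lets each dialect render its own native syntax. Below is a minimal sketch of that idea at the sqlglot level; the column name and literal are illustrative, not from the diff, and it assumes the sqlglot version pinned by this release (which exposes exp.EndsWith, as the diff itself relies on).

```python
# Minimal sketch (illustrative values): a typed sqlglot node is transpiled
# per dialect, which is what replacing invoke_anonymous_function buys.
from sqlglot import exp

ends_with = exp.EndsWith(
    this=exp.column("file_name"),
    expression=exp.Literal.string(".csv"),
)

# The same node renders to each dialect's native form instead of a
# hard-coded ENDS_WITH(...) anonymous function.
for dialect in ("duckdb", "bigquery", "snowflake"):
    print(dialect, "->", ends_with.sql(dialect=dialect))
```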
sqlframe/base/group.py
CHANGED

@@ -2,10 +2,16 @@

 from __future__ import annotations

+import sys
 import typing as t

 from sqlframe.base.operations import Operation, group_operation, operation

+if sys.version_info >= (3, 11):
+    from typing import Self
+else:
+    from typing_extensions import Self
+
 if t.TYPE_CHECKING:
     from sqlframe.base.column import Column
     from sqlframe.base.session import DF

@@ -28,6 +34,8 @@ class _BaseGroupedData(t.Generic[DF]):
         self.session = df.session
         self.last_op = last_op
         self.group_by_cols = group_by_cols
+        self.pivot_col: t.Optional[str] = None
+        self.pivot_values: t.Optional[t.List[t.Any]] = None

     def _get_function_applied_columns(
         self, func_name: str, cols: t.Tuple[str, ...]

@@ -56,6 +64,79 @@ class _BaseGroupedData(t.Generic[DF]):
         )
         cols = self._df._ensure_and_normalize_cols(columns)

+        # Handle pivot transformation
+        if self.pivot_col is not None and self.pivot_values is not None:
+            from sqlglot import exp
+
+            from sqlframe.base import functions as F
+
+            # Build the pivot expression
+            # First, we need to convert the DataFrame to include the pivot logic
+            df = self._df.copy()
+
+            # Create the base query with group by columns, pivot column, and aggregation columns
+            select_cols = []
+            # Add group by columns
+            for col in self.group_by_cols:
+                select_cols.append(col.expression)  # type: ignore
+            # Add pivot column
+            select_cols.append(Column.ensure_col(self.pivot_col).expression)
+            # Add the value columns that will be aggregated
+            for agg_col in cols:
+                # Extract the column being aggregated from the aggregation function
+                # For example, from SUM(earnings), we want to extract 'earnings'
+                if (
+                    isinstance(agg_col.column_expression, exp.AggFunc)
+                    and agg_col.column_expression.this
+                ):
+                    if agg_col.column_expression.this not in select_cols:
+                        select_cols.append(agg_col.column_expression.this)
+
+            # Create the base query
+            base_query = df.expression.select(*select_cols, append=False)
+
+            # Build pivot expression
+            pivot_expressions = []
+            for agg_col in cols:
+                if isinstance(agg_col.column_expression, exp.AggFunc):
+                    # Clone the aggregation function
+                    # Snowflake doesn't support alias in the pivot, so we need to use the column_expression
+                    agg_func = (
+                        agg_col.column_expression.copy()
+                        if self.session._is_snowflake
+                        else agg_col.expression.copy()
+                    )
+                    pivot_expressions.append(agg_func)
+
+            # Create the IN clause with pivot values
+            in_values = []
+            for v in self.pivot_values:
+                if isinstance(v, str):
+                    in_values.append(exp.Literal.string(v))
+                else:
+                    in_values.append(exp.Literal.number(v))
+
+            # Build the pivot node with the fields parameter
+            pivot = exp.Pivot(
+                expressions=pivot_expressions,
+                fields=[
+                    exp.In(
+                        this=Column.ensure_col(self.pivot_col).column_expression,
+                        expressions=in_values,
+                    )
+                ],
+            )
+
+            # Create a subquery with the pivot attached
+            subquery = base_query.subquery()
+            subquery.set("pivots", [pivot])
+
+            # Create the final select from the pivoted subquery
+            expression = exp.select("*").from_(subquery)
+
+            return self._df.copy(expression=expression)
+
+        # Original non-pivot logic
         if not self.group_by_cols or not isinstance(self.group_by_cols[0], (list, tuple, set)):
             expression = self._df.expression.group_by(
                 # User column_expression for group by to avoid alias in group by

@@ -104,5 +185,43 @@ class _BaseGroupedData(t.Generic[DF]):
     def sum(self, *cols: str) -> DF:
         return self.agg(*self._get_function_applied_columns("sum", cols))

-    def pivot(self,
-
+    def pivot(self, pivot_col: str, values: t.Optional[t.List[t.Any]] = None) -> Self:
+        """
+        Pivots a column of the current DataFrame and perform the specified aggregation.
+
+        There are two versions of the pivot function: one that requires the caller
+        to specify the list of distinct values to pivot on, and one that does not.
+        The latter is more concise but less efficient, because Spark needs to first
+        compute the list of distinct values internally.
+
+        Parameters
+        ----------
+        pivot_col : str
+            Name of the column to pivot.
+        values : list, optional
+            List of values that will be translated to columns in the output DataFrame.
+
+        Returns
+        -------
+        GroupedData
+            Returns self to allow chaining with aggregation methods.
+        """
+        if self.session._is_postgres:
+            raise NotImplementedError(
+                "Pivot operation is not supported in Postgres. Please create an issue if you would like a workaround implemented."
+            )
+
+        self.pivot_col = pivot_col
+
+        if values is None:
+            # Eagerly compute distinct values
+            from sqlframe.base.column import Column
+
+            distinct_df = self._df.select(pivot_col).distinct()
+            distinct_rows = distinct_df.collect()
+            # Sort to make the results deterministic
+            self.pivot_values = sorted([row[0] for row in distinct_rows])
+        else:
+            self.pivot_values = values
+
+        return self
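Taken together, the group.py changes add a PySpark-style pivot: pivot() records the pivot column and values (eagerly computing distinct values when none are given) and returns self, and agg() then wraps the grouped query in a sqlglot PIVOT subquery. A hedged usage sketch follows; the DuckDB session, column names, and data are illustrative assumptions, not from the diff.

```python
# Illustrative sketch of the new GroupedData.pivot() chaining, mirroring
# PySpark's API. Engine, data, and column names are assumptions.
from sqlframe.duckdb import DuckDBSession
from sqlframe.duckdb import functions as F

session = DuckDBSession()  # in-memory DuckDB
df = session.createDataFrame(
    [("dotNET", 2012, 10000), ("Java", 2012, 20000), ("dotNET", 2013, 48000)],
    ["course", "year", "earnings"],
)

# Passing explicit values skips the eager distinct() query that the
# values=None path runs to discover the pivot columns.
pivoted = df.groupBy("year").pivot("course", ["dotNET", "Java"]).agg(F.sum("earnings"))
pivoted.show()
```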
sqlframe/databricks/session.py
CHANGED

@@ -1,7 +1,7 @@
 from __future__ import annotations

+import logging
 import typing as t
-import warnings

 from sqlframe.base.session import _BaseSession
 from sqlframe.databricks.catalog import DatabricksCatalog

@@ -19,6 +19,9 @@ else:
     DatabricksConnection = t.Any


+logger = logging.getLogger(__name__)
+
+
 class DatabricksSession(
     _BaseSession[  # type: ignore
         DatabricksCatalog,

@@ -43,14 +46,60 @@ class DatabricksSession(
         server_hostname: t.Optional[str] = None,
         http_path: t.Optional[str] = None,
         access_token: t.Optional[str] = None,
+        **kwargs: t.Any,
     ):
         from databricks import sql

+        self._conn_kwargs = (
+            {}
+            if conn
+            else {
+                "server_hostname": server_hostname,
+                "http_path": http_path,
+                "access_token": access_token,
+                "disable_pandas": True,
+                **kwargs,
+            }
+        )
+
         if not hasattr(self, "_conn"):
             super().__init__(
-                conn or sql.connect(
+                conn or sql.connect(**self._conn_kwargs),
             )

+    def _execute(self, sql: str) -> None:
+        from databricks.sql import connect
+        from databricks.sql.exc import DatabaseError, RequestError
+
+        try:
+            super()._execute(sql)
+        except (DatabaseError, RequestError) as e:
+            logger.warning("Failed to execute query")
+            if not self._is_session_expired_error(e):
+                logger.error("Error is not related to session expiration, re-raising")
+                raise e
+            if self._conn_kwargs:
+                logger.info("Attempting to reconnect with provided connection parameters")
+                self._connection = connect(**self._conn_kwargs)
+                # Clear the cached cursor
+                if hasattr(self, "_cur"):
+                    delattr(self, "_cur")
+                super()._execute(sql)
+            else:
+                logger.error("No connection parameters provided so could not reconnect")
+                raise
+
+    def _is_session_expired_error(self, error: Exception) -> bool:
+        error_str = str(error).lower()
+        session_keywords = [
+            "invalid sessionhandle",
+            "session is closed",
+            "session expired",
+            "session not found",
+            "sessionhandle",
+        ]
+        return any(keyword in error_str for keyword in session_keywords)
+
     @classmethod
     def _try_get_map(cls, value: t.Any) -> t.Optional[t.Dict[str, t.Any]]:
         if (
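The session.py changes mean DatabricksSession now remembers its connection parameters (including any extra **kwargs forwarded to databricks.sql.connect) and, when a statement fails with a session-expiration error, reconnects with those stored parameters and retries once. A hedged construction sketch follows; the hostname, path, token, and the extra session_configuration kwarg are placeholders/assumptions, not values from the diff.

```python
# Illustrative sketch; credentials are placeholders and the extra kwarg is an
# assumption about what you might forward to databricks.sql.connect().
from sqlframe.databricks import DatabricksSession

session = DatabricksSession(
    server_hostname="dbc-xxxxxxxx.cloud.databricks.com",
    http_path="/sql/1.0/warehouses/xxxxxxxxxxxx",
    access_token="dapiXXXXXXXX",
    # Any additional keyword arguments are passed straight through to
    # databricks.sql.connect() via **kwargs.
    session_configuration={"ansi_mode": "true"},
)

# If Databricks later reports an expired/invalid session, _execute() will
# reconnect with the stored kwargs and retry the statement once.
df = session.sql("SELECT 1 AS one")
df.show()
```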
{sqlframe-3.35.1.dist-info → sqlframe-3.36.0.dist-info}/METADATA
CHANGED

@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: sqlframe
-Version: 3.35.1
+Version: 3.36.0
 Summary: Turning PySpark Into a Universal DataFrame API
 Home-page: https://github.com/eakmanrq/sqlframe
 Author: Ryan Eakman

@@ -17,7 +17,7 @@ Requires-Python: >=3.9
 Description-Content-Type: text/markdown
 License-File: LICENSE
 Requires-Dist: prettytable <4
-Requires-Dist: sqlglot <26.
+Requires-Dist: sqlglot <26.32,>=24.0.0
 Requires-Dist: typing-extensions
 Provides-Extra: bigquery
 Requires-Dist: google-cloud-bigquery-storage <3,>=2 ; extra == 'bigquery'

@@ -39,7 +39,7 @@ Requires-Dist: pytest-forked ; extra == 'dev'
 Requires-Dist: pytest-postgresql <8,>=6 ; extra == 'dev'
 Requires-Dist: pytest-xdist <3.8,>=3.6 ; extra == 'dev'
 Requires-Dist: pytest <8.5,>=8.2.0 ; extra == 'dev'
-Requires-Dist: ruff <0.
+Requires-Dist: ruff <0.13,>=0.4.4 ; extra == 'dev'
 Requires-Dist: types-psycopg2 <3,>=2.9 ; extra == 'dev'
 Provides-Extra: docs
 Requires-Dist: mkdocs-include-markdown-plugin ==6.0.6 ; extra == 'docs'
{sqlframe-3.35.1.dist-info → sqlframe-3.36.0.dist-info}/RECORD
CHANGED

@@ -1,5 +1,5 @@
 sqlframe/__init__.py,sha256=SB80yLTITBXHI2GCDS6n6bN5ObHqgPjfpRPAUwxaots,3403
-sqlframe/_version.py,sha256=
+sqlframe/_version.py,sha256=bkUPQ6OdlXKrD5knIV3EChl0OWjLm_VJDu9m0db4vwg,513
 sqlframe/py.typed,sha256=Nqnn8clbgv-5l0PgxcTOldg8mkMKrFn4TvPL-rYUUGg,1
 sqlframe/base/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 sqlframe/base/_typing.py,sha256=b2clI5HI1zEZKB_3Msx3FeAJQyft44ubUifJwQRVXyQ,1298

@@ -8,9 +8,9 @@ sqlframe/base/column.py,sha256=5ZnZcn6gCCrAL53-EEHxVQWXG2oijN3RCOhlWmsjbJM,21147
 sqlframe/base/dataframe.py,sha256=0diYONDlet8iZt49LC3vcmfXHAAZ2MovPL2pTXYHj2U,85974
 sqlframe/base/decorators.py,sha256=IhE5xNQDkwJHacCvulq5WpUKyKmXm7dL2A3o5WuKGP4,2131
 sqlframe/base/exceptions.py,sha256=9Uwvqn2eAkDpqm4BrRgbL61qM-GMCbJEMAW8otxO46s,370
-sqlframe/base/function_alternatives.py,sha256=
-sqlframe/base/functions.py,sha256=
-sqlframe/base/group.py,sha256=
+sqlframe/base/function_alternatives.py,sha256=aTu3nQhIAkZoxrI1IpjpaHEAMxBNms0AnhS0EMR-TwY,51727
+sqlframe/base/functions.py,sha256=qyV-4R4CPSkuS-0S3dPza0BZykoKAanxjQq83tu8L34,225778
+sqlframe/base/group.py,sha256=PGxUAnZkNlYKBIVNzoEDtoHbsP9Rhy1bGcSg2eYuWF4,9015
 sqlframe/base/normalize.py,sha256=nXAJ5CwxVf4DV0GsH-q1w0p8gmjSMlv96k_ez1eVul8,3880
 sqlframe/base/operations.py,sha256=g-YNcbvNKTOBbYm23GKfB3fmydlR7ZZDAuZUtXIHtzw,4438
 sqlframe/base/readerwriter.py,sha256=Nb2VJ_HBmLQp5mK8JhnFooZh2ydAaboCAFVPb-4MNX4,31241

@@ -47,7 +47,7 @@ sqlframe/databricks/functions.py,sha256=La8rjAwO0hD4FBO0QxW5CtZtFAPvOrVc6lG4OtPG
 sqlframe/databricks/functions.pyi,sha256=FzVBpzXCJzxIp73sIAo_R8Wx8uOJrix-W12HsgyeTcQ,23799
 sqlframe/databricks/group.py,sha256=dU3g0DVLRlfOSCamKchQFXRd1WTFbdxoXkpEX8tPD6Y,399
 sqlframe/databricks/readwriter.py,sha256=cuGRI1G627JEZgGNtirrT8LAwT6xQCdgkSAETmLKNXU,14777
-sqlframe/databricks/session.py,sha256=
+sqlframe/databricks/session.py,sha256=i2CgrLIHJb53Cx1qu_rE1-cmmm19S-Sw1MhTISX1zYU,4013
 sqlframe/databricks/table.py,sha256=Q0Vnrl5aUqnqFTQpTwfWMRyQ9AQnagtpnSnXmP6IKRs,678
 sqlframe/databricks/types.py,sha256=KwNyuXIo-2xVVd4bZED3YrQOobKCtemlxGrJL7DrTC8,34
 sqlframe/databricks/udf.py,sha256=3rmxv_6zSLfIxH8P8P050ZO-ki0aqBb9wWuUQBtl4m8,272

@@ -130,8 +130,8 @@ sqlframe/standalone/udf.py,sha256=azmgtUjHNIPs0WMVNId05SHwiYn41MKVBhKXsQJ5dmY,27
 sqlframe/standalone/window.py,sha256=6GKPzuxeSapJakBaKBeT9VpED1ACdjggDv9JRILDyV0,35
 sqlframe/testing/__init__.py,sha256=VVCosQhitU74A3NnE52O4mNtGZONapuEXcc20QmSlnQ,132
 sqlframe/testing/utils.py,sha256=PFsGZpwNUE_4-g_f43_vstTqsK0AQ2lBneb5Eb6NkFo,13008
-sqlframe-3.
-sqlframe-3.
-sqlframe-3.
-sqlframe-3.
-sqlframe-3.
+sqlframe-3.36.0.dist-info/LICENSE,sha256=VZu79YgW780qxaFJMr0t5ZgbOYEh04xWoxaWOaqIGWk,1068
+sqlframe-3.36.0.dist-info/METADATA,sha256=F56M3UKMA8CZN2Ps3dAkputINvX8rhBcPKTiAuC5iEs,8987
+sqlframe-3.36.0.dist-info/WHEEL,sha256=oiQVh_5PnQM0E3gPdiz09WCNmwiHDMaGer_elqB3coM,92
+sqlframe-3.36.0.dist-info/top_level.txt,sha256=T0_RpoygaZSF6heeWwIDQgaP0varUdSK1pzjeJZRjM8,9
+sqlframe-3.36.0.dist-info/RECORD,,
{sqlframe-3.35.1.dist-info → sqlframe-3.36.0.dist-info}/LICENSE
File without changes

{sqlframe-3.35.1.dist-info → sqlframe-3.36.0.dist-info}/WHEEL
File without changes

{sqlframe-3.35.1.dist-info → sqlframe-3.36.0.dist-info}/top_level.txt
File without changes