sqlframe 3.35.1__py3-none-any.whl → 3.36.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
sqlframe/_version.py CHANGED
@@ -17,5 +17,5 @@ __version__: str
 __version_tuple__: VERSION_TUPLE
 version_tuple: VERSION_TUPLE
 
-__version__ = version = '3.35.1'
-__version_tuple__ = version_tuple = (3, 35, 1)
+__version__ = version = '3.36.0'
+__version_tuple__ = version_tuple = (3, 36, 0)
sqlframe/base/function_alternatives.py CHANGED
@@ -1300,10 +1300,6 @@ def day_with_try_to_timestamp(col: ColumnOrName) -> Column:
     )
 
 
-def endswith_with_underscore(str: ColumnOrName, suffix: ColumnOrName) -> Column:
-    return Column.invoke_anonymous_function(str, "ENDS_WITH", suffix)
-
-
 def endswith_using_like(str: ColumnOrName, suffix: ColumnOrName) -> Column:
     concat = get_func_from_session("concat")
     lit = get_func_from_session("lit")
sqlframe/base/functions.py CHANGED
@@ -2288,14 +2288,14 @@ def array_distinct(col: ColumnOrName) -> Column:
 
 @meta(unsupported_engines=["bigquery", "postgres"])
 def array_intersect(col1: ColumnOrName, col2: ColumnOrName) -> Column:
-    from sqlframe.base.function_alternatives import array_intersect_using_intersection
-
-    session = _get_session()
-
-    if session._is_snowflake:
-        return array_intersect_using_intersection(col1, col2)
-
-    return Column.invoke_anonymous_function(col1, "ARRAY_INTERSECT", Column.ensure_col(col2))
+    return Column(
+        expression.ArrayIntersect(
+            expressions=[
+                Column.ensure_col(col1).column_expression,
+                Column.ensure_col(col2).column_expression,
+            ]
+        )
+    )
 
 
 @meta(unsupported_engines=["postgres"])
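`array_intersect` now builds a sqlglot `ArrayIntersect` expression instead of invoking an anonymous `ARRAY_INTERSECT` function, so each engine's SQL generator can choose its own spelling and the Snowflake-specific alternative is no longer needed. A minimal sketch of that idea, assuming a sqlglot version in the range the new wheel pins; the rendered SQL per dialect is illustrative:

```python
from sqlglot import exp

# Build the same node the new array_intersect() constructs.
node = exp.ArrayIntersect(
    expressions=[exp.column("a"), exp.column("b")]
)

# One expression, dialect-specific output (e.g. Snowflake spells this
# ARRAY_INTERSECTION, which previously required a separate code path).
for dialect in ("duckdb", "snowflake", "spark"):
    print(dialect, "->", node.sql(dialect=dialect))
```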
@@ -3226,18 +3226,16 @@ def elt(*inputs: ColumnOrName) -> Column:
 def endswith(str: ColumnOrName, suffix: ColumnOrName) -> Column:
     from sqlframe.base.function_alternatives import (
         endswith_using_like,
-        endswith_with_underscore,
     )
 
     session = _get_session()
 
-    if session._is_bigquery or session._is_duckdb:
-        return endswith_with_underscore(str, suffix)
-
     if session._is_postgres:
         return endswith_using_like(str, suffix)
 
-    return Column.invoke_anonymous_function(str, "endswith", suffix)
+    return Column.invoke_expression_over_column(
+        str, expression.EndsWith, expression=Column.ensure_col(suffix).column_expression
+    )
 
 
 @meta(unsupported_engines="*")
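Similarly, `endswith` now goes through sqlglot's `exp.EndsWith` node, which is why the BigQuery/DuckDB-specific `endswith_with_underscore` alternative could be deleted: the dialect generator picks the engine's function name. A rough sketch under the same sqlglot assumption (printed SQL is illustrative):

```python
from sqlglot import exp

# The node endswith() now builds: this = string column, expression = suffix.
node = exp.EndsWith(
    this=exp.column("name"),
    expression=exp.Literal.string("son"),
)

# One expression, per-dialect rendering (e.g. ENDS_WITH on BigQuery/DuckDB),
# so no per-engine alternative function is needed anymore.
for dialect in ("bigquery", "duckdb", "spark"):
    print(dialect, "->", node.sql(dialect=dialect))
```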
@@ -5655,10 +5653,9 @@ def replace(
     ):
         replace = expression.Literal.string("")  # type: ignore
 
-    if replace is not None:
-        return Column.invoke_anonymous_function(src, "replace", search, replace)
-    else:
-        return Column.invoke_anonymous_function(src, "replace", search)
+    return Column.invoke_expression_over_column(
+        src, expression.Replace, expression=search, replacement=replace
+    )
 
 
 @meta()
sqlframe/base/group.py CHANGED
@@ -2,10 +2,16 @@
 
 from __future__ import annotations
 
+import sys
 import typing as t
 
 from sqlframe.base.operations import Operation, group_operation, operation
 
+if sys.version_info >= (3, 11):
+    from typing import Self
+else:
+    from typing_extensions import Self
+
 if t.TYPE_CHECKING:
     from sqlframe.base.column import Column
     from sqlframe.base.session import DF
@@ -28,6 +34,8 @@ class _BaseGroupedData(t.Generic[DF]):
         self.session = df.session
         self.last_op = last_op
         self.group_by_cols = group_by_cols
+        self.pivot_col: t.Optional[str] = None
+        self.pivot_values: t.Optional[t.List[t.Any]] = None
 
     def _get_function_applied_columns(
         self, func_name: str, cols: t.Tuple[str, ...]
@@ -56,6 +64,79 @@ class _BaseGroupedData(t.Generic[DF]):
         )
         cols = self._df._ensure_and_normalize_cols(columns)
 
+        # Handle pivot transformation
+        if self.pivot_col is not None and self.pivot_values is not None:
+            from sqlglot import exp
+
+            from sqlframe.base import functions as F
+
+            # Build the pivot expression
+            # First, we need to convert the DataFrame to include the pivot logic
+            df = self._df.copy()
+
+            # Create the base query with group by columns, pivot column, and aggregation columns
+            select_cols = []
+            # Add group by columns
+            for col in self.group_by_cols:
+                select_cols.append(col.expression)  # type: ignore
+            # Add pivot column
+            select_cols.append(Column.ensure_col(self.pivot_col).expression)
+            # Add the value columns that will be aggregated
+            for agg_col in cols:
+                # Extract the column being aggregated from the aggregation function
+                # For example, from SUM(earnings), we want to extract 'earnings'
+                if (
+                    isinstance(agg_col.column_expression, exp.AggFunc)
+                    and agg_col.column_expression.this
+                ):
+                    if agg_col.column_expression.this not in select_cols:
+                        select_cols.append(agg_col.column_expression.this)
+
+            # Create the base query
+            base_query = df.expression.select(*select_cols, append=False)
+
+            # Build pivot expression
+            pivot_expressions = []
+            for agg_col in cols:
+                if isinstance(agg_col.column_expression, exp.AggFunc):
+                    # Clone the aggregation function
+                    # Snowflake doesn't support alias in the pivot, so we need to use the column_expression
+                    agg_func = (
+                        agg_col.column_expression.copy()
+                        if self.session._is_snowflake
+                        else agg_col.expression.copy()
+                    )
+                    pivot_expressions.append(agg_func)
+
+            # Create the IN clause with pivot values
+            in_values = []
+            for v in self.pivot_values:
+                if isinstance(v, str):
+                    in_values.append(exp.Literal.string(v))
+                else:
+                    in_values.append(exp.Literal.number(v))
+
+            # Build the pivot node with the fields parameter
+            pivot = exp.Pivot(
+                expressions=pivot_expressions,
+                fields=[
+                    exp.In(
+                        this=Column.ensure_col(self.pivot_col).column_expression,
+                        expressions=in_values,
+                    )
+                ],
+            )
+
+            # Create a subquery with the pivot attached
+            subquery = base_query.subquery()
+            subquery.set("pivots", [pivot])
+
+            # Create the final select from the pivoted subquery
+            expression = exp.select("*").from_(subquery)
+
+            return self._df.copy(expression=expression)
+
+        # Original non-pivot logic
         if not self.group_by_cols or not isinstance(self.group_by_cols[0], (list, tuple, set)):
             expression = self._df.expression.group_by(
                 # User column_expression for group by to avoid alias in group by
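The pivot branch above selects the group-by columns, the pivot column, and the aggregated columns into a base query, then attaches an `exp.Pivot` node (with an `IN` list of pivot values) to that subquery, so the generated SQL takes the familiar `SELECT * FROM (...) PIVOT(agg FOR col IN (...))` shape. A standalone sketch of the same construction with made-up table and column names, assuming a sqlglot version in the pinned range; the SQL in the comment is approximate:

```python
from sqlglot import exp

# Group by name, pivot on course, aggregate earnings (illustrative names).
base_query = exp.select("name", "course", "earnings").from_("courses")

pivot = exp.Pivot(
    expressions=[exp.Sum(this=exp.column("earnings"))],
    fields=[
        exp.In(
            this=exp.column("course"),
            expressions=[exp.Literal.string("dotNET"), exp.Literal.string("Java")],
        )
    ],
)

# Attach the pivot to the subquery, exactly as the branch above does.
subquery = base_query.subquery()
subquery.set("pivots", [pivot])

# Roughly: SELECT * FROM (SELECT name, course, earnings FROM courses)
#          PIVOT(SUM(earnings) FOR course IN ('dotNET', 'Java'))
print(exp.select("*").from_(subquery).sql(dialect="spark"))
```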
@@ -104,5 +185,43 @@ class _BaseGroupedData(t.Generic[DF]):
     def sum(self, *cols: str) -> DF:
         return self.agg(*self._get_function_applied_columns("sum", cols))
 
-    def pivot(self, *cols: str) -> DF:
-        raise NotImplementedError("Sum distinct is not currently implemented")
+    def pivot(self, pivot_col: str, values: t.Optional[t.List[t.Any]] = None) -> Self:
+        """
+        Pivots a column of the current DataFrame and perform the specified aggregation.
+
+        There are two versions of the pivot function: one that requires the caller
+        to specify the list of distinct values to pivot on, and one that does not.
+        The latter is more concise but less efficient, because Spark needs to first
+        compute the list of distinct values internally.
+
+        Parameters
+        ----------
+        pivot_col : str
+            Name of the column to pivot.
+        values : list, optional
+            List of values that will be translated to columns in the output DataFrame.
+
+        Returns
+        -------
+        GroupedData
+            Returns self to allow chaining with aggregation methods.
+        """
+        if self.session._is_postgres:
+            raise NotImplementedError(
+                "Pivot operation is not supported in Postgres. Please create an issue if you would like a workaround implemented."
+            )
+
+        self.pivot_col = pivot_col
+
+        if values is None:
+            # Eagerly compute distinct values
+            from sqlframe.base.column import Column
+
+            distinct_df = self._df.select(pivot_col).distinct()
+            distinct_rows = distinct_df.collect()
+            # Sort to make the results deterministic
+            self.pivot_values = sorted([row[0] for row in distinct_rows])
+        else:
+            self.pivot_values = values
+
+        return self
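Usage mirrors PySpark's `GroupedData.pivot`: call it between `groupBy` and the aggregation, optionally passing the pivot values up front to skip the eager `SELECT DISTINCT`. A minimal sketch against the DuckDB backend, assuming sqlframe's usual session and `createDataFrame` API; the data is made up:

```python
from sqlframe.duckdb import DuckDBSession, functions as F

session = DuckDBSession()
df = session.createDataFrame(
    [("dotNET", 2012, 10000), ("Java", 2012, 20000), ("dotNET", 2013, 48000)],
    schema=["course", "year", "earnings"],
)

# Explicit pivot values avoid the extra distinct-values round trip.
pivoted = df.groupBy("year").pivot("course", ["dotNET", "Java"]).agg(F.sum("earnings"))
pivoted.show()
```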
sqlframe/databricks/session.py CHANGED
@@ -1,7 +1,7 @@
 from __future__ import annotations
 
+import logging
 import typing as t
-import warnings
 
 from sqlframe.base.session import _BaseSession
 from sqlframe.databricks.catalog import DatabricksCatalog
@@ -19,6 +19,9 @@ else:
     DatabricksConnection = t.Any
 
 
+logger = logging.getLogger(__name__)
+
+
 class DatabricksSession(
     _BaseSession[  # type: ignore
         DatabricksCatalog,
@@ -43,14 +46,60 @@ class DatabricksSession(
         server_hostname: t.Optional[str] = None,
         http_path: t.Optional[str] = None,
         access_token: t.Optional[str] = None,
+        **kwargs: t.Any,
     ):
         from databricks import sql
 
+        self._conn_kwargs = (
+            {}
+            if conn
+            else {
+                "server_hostname": server_hostname,
+                "http_path": http_path,
+                "access_token": access_token,
+                "disable_pandas": True,
+                **kwargs,
+            }
+        )
+
         if not hasattr(self, "_conn"):
             super().__init__(
-                conn or sql.connect(server_hostname, http_path, access_token, disable_pandas=True)
+                conn or sql.connect(**self._conn_kwargs),
             )
 
+    def _execute(self, sql: str) -> None:
+        from databricks.sql import connect
+        from databricks.sql.exc import DatabaseError, RequestError
+
+        try:
+            super()._execute(sql)
+        except (DatabaseError, RequestError) as e:
+            logger.warning("Failed to execute query")
+            if not self._is_session_expired_error(e):
+                logger.error("Error is not related to session expiration, re-raising")
+                raise e
+            if self._conn_kwargs:
+                logger.info("Attempting to reconnect with provided connection parameters")
+                self._connection = connect(**self._conn_kwargs)
+                # Clear the cached cursor
+                if hasattr(self, "_cur"):
+                    delattr(self, "_cur")
+                super()._execute(sql)
+            else:
+                logger.error("No connection parameters provided so could not reconnect")
+                raise
+
+    def _is_session_expired_error(self, error: Exception) -> bool:
+        error_str = str(error).lower()
+        session_keywords = [
+            "invalid sessionhandle",
+            "session is closed",
+            "session expired",
+            "session not found",
+            "sessionhandle",
+        ]
+        return any(keyword in error_str for keyword in session_keywords)
+
     @classmethod
     def _try_get_map(cls, value: t.Any) -> t.Optional[t.Dict[str, t.Any]]:
         if (
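Two behavior changes are visible here: extra keyword arguments are forwarded to `databricks.sql.connect`, and `_execute` retries once after reconnecting when the error text looks like an expired or closed session (provided connection parameters were supplied rather than a pre-built `conn`). A minimal constructor sketch with placeholder credentials:

```python
from sqlframe.databricks import DatabricksSession

# Standard connection parameters plus arbitrary kwargs, which now flow
# through to databricks.sql.connect() unchanged.
session = DatabricksSession(
    server_hostname="dbc-xxxxxxxx-xxxx.cloud.databricks.com",  # placeholder
    http_path="/sql/1.0/warehouses/xxxxxxxxxxxxxxxx",           # placeholder
    access_token="dapiXXXXXXXXXXXXXXXX",                        # placeholder
    # Any other keyword accepted by databricks.sql.connect can go here,
    # e.g. a session_configuration mapping.
)

# If the warehouse session later expires, _execute() catches the
# DatabaseError/RequestError, reconnects with the stored kwargs, and retries.
session.sql("SELECT 1").collect()
```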
sqlframe-3.35.1.dist-info/METADATA → sqlframe-3.36.0.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: sqlframe
-Version: 3.35.1
+Version: 3.36.0
 Summary: Turning PySpark Into a Universal DataFrame API
 Home-page: https://github.com/eakmanrq/sqlframe
 Author: Ryan Eakman
@@ -17,7 +17,7 @@ Requires-Python: >=3.9
 Description-Content-Type: text/markdown
 License-File: LICENSE
 Requires-Dist: prettytable <4
-Requires-Dist: sqlglot <26.26,>=24.0.0
+Requires-Dist: sqlglot <26.32,>=24.0.0
 Requires-Dist: typing-extensions
 Provides-Extra: bigquery
 Requires-Dist: google-cloud-bigquery-storage <3,>=2 ; extra == 'bigquery'
@@ -39,7 +39,7 @@ Requires-Dist: pytest-forked ; extra == 'dev'
 Requires-Dist: pytest-postgresql <8,>=6 ; extra == 'dev'
 Requires-Dist: pytest-xdist <3.8,>=3.6 ; extra == 'dev'
 Requires-Dist: pytest <8.5,>=8.2.0 ; extra == 'dev'
-Requires-Dist: ruff <0.12,>=0.4.4 ; extra == 'dev'
+Requires-Dist: ruff <0.13,>=0.4.4 ; extra == 'dev'
 Requires-Dist: types-psycopg2 <3,>=2.9 ; extra == 'dev'
 Provides-Extra: docs
 Requires-Dist: mkdocs-include-markdown-plugin ==6.0.6 ; extra == 'docs'
sqlframe-3.35.1.dist-info/RECORD → sqlframe-3.36.0.dist-info/RECORD CHANGED
@@ -1,5 +1,5 @@
 sqlframe/__init__.py,sha256=SB80yLTITBXHI2GCDS6n6bN5ObHqgPjfpRPAUwxaots,3403
-sqlframe/_version.py,sha256=kPcRtrGIJvBSXjEXIsPZ4vA33McEBwn6hXm6zOraFmM,513
+sqlframe/_version.py,sha256=bkUPQ6OdlXKrD5knIV3EChl0OWjLm_VJDu9m0db4vwg,513
 sqlframe/py.typed,sha256=Nqnn8clbgv-5l0PgxcTOldg8mkMKrFn4TvPL-rYUUGg,1
 sqlframe/base/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 sqlframe/base/_typing.py,sha256=b2clI5HI1zEZKB_3Msx3FeAJQyft44ubUifJwQRVXyQ,1298
@@ -8,9 +8,9 @@ sqlframe/base/column.py,sha256=5ZnZcn6gCCrAL53-EEHxVQWXG2oijN3RCOhlWmsjbJM,21147
 sqlframe/base/dataframe.py,sha256=0diYONDlet8iZt49LC3vcmfXHAAZ2MovPL2pTXYHj2U,85974
 sqlframe/base/decorators.py,sha256=IhE5xNQDkwJHacCvulq5WpUKyKmXm7dL2A3o5WuKGP4,2131
 sqlframe/base/exceptions.py,sha256=9Uwvqn2eAkDpqm4BrRgbL61qM-GMCbJEMAW8otxO46s,370
-sqlframe/base/function_alternatives.py,sha256=EKtDgYyaJSfaSfhs_IemDkpy6VK2E8V6fDvjAqKR_tM,51880
-sqlframe/base/functions.py,sha256=geB8QRQvyOipB3v_gOC5KhhB--UpKJH0z2dbyRNCNaI,225983
-sqlframe/base/group.py,sha256=OY4w1WRsCqLgW-Pi7DjF63zbbxSLISCF3qjAbzI2CQ4,4283
+sqlframe/base/function_alternatives.py,sha256=aTu3nQhIAkZoxrI1IpjpaHEAMxBNms0AnhS0EMR-TwY,51727
+sqlframe/base/functions.py,sha256=qyV-4R4CPSkuS-0S3dPza0BZykoKAanxjQq83tu8L34,225778
+sqlframe/base/group.py,sha256=PGxUAnZkNlYKBIVNzoEDtoHbsP9Rhy1bGcSg2eYuWF4,9015
 sqlframe/base/normalize.py,sha256=nXAJ5CwxVf4DV0GsH-q1w0p8gmjSMlv96k_ez1eVul8,3880
 sqlframe/base/operations.py,sha256=g-YNcbvNKTOBbYm23GKfB3fmydlR7ZZDAuZUtXIHtzw,4438
 sqlframe/base/readerwriter.py,sha256=Nb2VJ_HBmLQp5mK8JhnFooZh2ydAaboCAFVPb-4MNX4,31241
@@ -47,7 +47,7 @@ sqlframe/databricks/functions.py,sha256=La8rjAwO0hD4FBO0QxW5CtZtFAPvOrVc6lG4OtPG
 sqlframe/databricks/functions.pyi,sha256=FzVBpzXCJzxIp73sIAo_R8Wx8uOJrix-W12HsgyeTcQ,23799
 sqlframe/databricks/group.py,sha256=dU3g0DVLRlfOSCamKchQFXRd1WTFbdxoXkpEX8tPD6Y,399
 sqlframe/databricks/readwriter.py,sha256=cuGRI1G627JEZgGNtirrT8LAwT6xQCdgkSAETmLKNXU,14777
-sqlframe/databricks/session.py,sha256=iw4uczkJHkpVO8vusEEmfCrhxHWyAHpCFmOZ-0qlkms,2343
+sqlframe/databricks/session.py,sha256=i2CgrLIHJb53Cx1qu_rE1-cmmm19S-Sw1MhTISX1zYU,4013
 sqlframe/databricks/table.py,sha256=Q0Vnrl5aUqnqFTQpTwfWMRyQ9AQnagtpnSnXmP6IKRs,678
 sqlframe/databricks/types.py,sha256=KwNyuXIo-2xVVd4bZED3YrQOobKCtemlxGrJL7DrTC8,34
 sqlframe/databricks/udf.py,sha256=3rmxv_6zSLfIxH8P8P050ZO-ki0aqBb9wWuUQBtl4m8,272
@@ -130,8 +130,8 @@ sqlframe/standalone/udf.py,sha256=azmgtUjHNIPs0WMVNId05SHwiYn41MKVBhKXsQJ5dmY,27
 sqlframe/standalone/window.py,sha256=6GKPzuxeSapJakBaKBeT9VpED1ACdjggDv9JRILDyV0,35
 sqlframe/testing/__init__.py,sha256=VVCosQhitU74A3NnE52O4mNtGZONapuEXcc20QmSlnQ,132
 sqlframe/testing/utils.py,sha256=PFsGZpwNUE_4-g_f43_vstTqsK0AQ2lBneb5Eb6NkFo,13008
-sqlframe-3.35.1.dist-info/LICENSE,sha256=VZu79YgW780qxaFJMr0t5ZgbOYEh04xWoxaWOaqIGWk,1068
-sqlframe-3.35.1.dist-info/METADATA,sha256=T1Zjv7wX8XssCXYaXa0mrRAR8Br1Udv8Brw0ZqeWj3I,8987
-sqlframe-3.35.1.dist-info/WHEEL,sha256=oiQVh_5PnQM0E3gPdiz09WCNmwiHDMaGer_elqB3coM,92
-sqlframe-3.35.1.dist-info/top_level.txt,sha256=T0_RpoygaZSF6heeWwIDQgaP0varUdSK1pzjeJZRjM8,9
-sqlframe-3.35.1.dist-info/RECORD,,
+sqlframe-3.36.0.dist-info/LICENSE,sha256=VZu79YgW780qxaFJMr0t5ZgbOYEh04xWoxaWOaqIGWk,1068
+sqlframe-3.36.0.dist-info/METADATA,sha256=F56M3UKMA8CZN2Ps3dAkputINvX8rhBcPKTiAuC5iEs,8987
+sqlframe-3.36.0.dist-info/WHEEL,sha256=oiQVh_5PnQM0E3gPdiz09WCNmwiHDMaGer_elqB3coM,92
+sqlframe-3.36.0.dist-info/top_level.txt,sha256=T0_RpoygaZSF6heeWwIDQgaP0varUdSK1pzjeJZRjM8,9
+sqlframe-3.36.0.dist-info/RECORD,,