sqlframe 3.22.1__py3-none-any.whl → 3.24.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
sqlframe/__init__.py CHANGED
@@ -3,6 +3,7 @@ from __future__ import annotations
 import importlib
 import sys
 import typing as t
+from contextlib import contextmanager
 from unittest.mock import MagicMock
 
 if t.TYPE_CHECKING:
@@ -98,3 +99,14 @@ def deactivate() -> None:
     except ImportError:
         pass
     ACTIVATE_CONFIG.clear()
+
+
+@contextmanager
+def activate_context(
+    engine: t.Optional[str] = None,
+    conn: t.Optional[CONN] = None,
+    config: t.Optional[t.Dict[str, t.Any]] = None,
+):
+    activate(engine, conn, config)
+    yield
+    deactivate()
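
The new `activate_context` helper scopes the existing `activate()`/`deactivate()` pair to a `with` block. A minimal usage sketch, assuming the engine names documented for `activate()`; note that `deactivate()` only runs when the block exits normally, since the `yield` is not wrapped in `try`/`finally`:

```python
import sqlframe

# Scope sqlframe's PySpark-import takeover to a block.
with sqlframe.activate_context(engine="duckdb"):
    from pyspark.sql import SparkSession  # resolves to sqlframe's session

    spark = SparkSession.builder.getOrCreate()
    spark.sql("SELECT 1 AS x").show()
```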
sqlframe/_version.py CHANGED
@@ -17,5 +17,5 @@ __version__: str
 __version_tuple__: VERSION_TUPLE
 version_tuple: VERSION_TUPLE
 
-__version__ = version = '3.22.1'
-__version_tuple__ = version_tuple = (3, 22, 1)
+__version__ = version = '3.24.0'
+__version_tuple__ = version_tuple = (3, 24, 0)
sqlframe/base/catalog.py CHANGED
@@ -14,16 +14,17 @@ if t.TYPE_CHECKING:
     from sqlglot.schema import ColumnMapping
 
     from sqlframe.base._typing import StorageLevel, UserDefinedFunctionLike
-    from sqlframe.base.session import DF, _BaseSession
+    from sqlframe.base.session import DF, TABLE, _BaseSession
     from sqlframe.base.types import DataType, StructType
 
     SESSION = t.TypeVar("SESSION", bound=_BaseSession)
 else:
     DF = t.TypeVar("DF")
+    TABLE = t.TypeVar("TABLE")
     SESSION = t.TypeVar("SESSION")
 
 
-class _BaseCatalog(t.Generic[SESSION, DF]):
+class _BaseCatalog(t.Generic[SESSION, DF, TABLE]):
     """User-facing catalog API, accessible through `SparkSession.catalog`."""
 
     TEMP_CATALOG_FILTER: t.Optional[exp.Expression] = None
@@ -688,7 +689,7 @@ class _BaseCatalog(t.Generic[SESSION, DF]):
         source: t.Optional[str] = None,
         schema: t.Optional[StructType] = None,
         **options: str,
-    ) -> DF:
+    ) -> TABLE:
        """Creates a table based on the dataset in a data source.
 
        It returns the DataFrame associated with the external table.
@@ -716,7 +717,7 @@ class _BaseCatalog(t.Generic[SESSION, DF]):
         schema: t.Optional[StructType] = None,
         description: t.Optional[str] = None,
         **options: str,
-    ) -> DF:
+    ) -> TABLE:
        """Creates a table based on the dataset in a data source.
 
        .. versionadded:: 2.2.0
sqlframe/base/column.py CHANGED
@@ -128,6 +128,21 @@ class Column:
             "Tried to call a column which is unexpected. Did you mean to call a method on a DataFrame? If so, make sure the method is typed correctly and is supported. If not, please open an issue requesting support: https://github.com/eakmanrq/sqlframe/issues"
         )
 
+    def __getattr__(self, name: str) -> Column:
+        """
+        Enables accessing nested fields using dot notation for struct types.
+
+        For example:
+        df.select(df.r.a) # This is equivalent to df.select(df.r.getField("a"))
+
+        This method is called when the attribute doesn't exist in the class,
+        and delegates to getField method.
+        """
+        # Handle special method names (like __iter__) properly by raising AttributeError
+        if name.startswith("__") and name.endswith("__"):
+            raise AttributeError(f"{self.__class__.__name__} object has no attribute '{name}'")
+        return self.getField(name)
+
     @classmethod
     def ensure_col(cls, value: t.Optional[t.Union[ColumnOrName, exp.Expression]]) -> Column:
         col = get_func_from_session("col")
@@ -459,3 +474,45 @@ class Column:
         if isinstance(key.expression, exp.Literal) and key.expression.is_number:
             key = key + lit(1)
         return element_at(self, key)
+
+    def getField(self, name: t.Any) -> Column:
+        """
+        An expression that gets a field by name in a StructType.
+
+        .. versionadded:: 1.3.0
+
+        .. versionchanged:: 3.4.0
+            Supports Spark Connect.
+
+        Parameters
+        ----------
+        name
+            a literal value, or a :class:`Column` expression.
+            The result will only be true at a location if the field matches in the Column.
+
+            .. deprecated:: 3.0.0
+                :class:`Column` as a parameter is deprecated.
+
+        Returns
+        -------
+        :class:`Column`
+            Column representing whether each element of Column got by name.
+
+        Examples
+        --------
+        >>> from sqlframe.base.types import Row
+        >>> df = spark.createDataFrame([Row(r=Row(a=1, b="b"))])
+        >>> df.select(df.r.getField("b")).show()
+        +---+
+        |r.b|
+        +---+
+        |  b|
+        +---+
+        >>> df.select(df.r.a).show()
+        +---+
+        |r.a|
+        +---+
+        |  1|
+        +---+
+        """
+        return self.getItem(name)
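
A short sketch of the two access styles the new `__getattr__`/`getField` pair enables; the DuckDB session here is an assumption for illustration:

```python
from sqlframe.base.types import Row
from sqlframe.duckdb import DuckDBSession

spark = DuckDBSession()
df = spark.createDataFrame([Row(r=Row(a=1, b="b"))])

# Dot notation routes through Column.__getattr__ -> getField -> getItem,
# so these two selections are equivalent:
df.select(df.r.a, df.r.getField("b")).show()
```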
sqlframe/base/dataframe.py CHANGED
@@ -1730,8 +1730,8 @@ class BaseDataFrame(t.Generic[SESSION, WRITER, NA, STAT, GROUP_DATA]):
     @operation(Operation.SELECT)
     def unpivot(
         self,
-        ids: t.Union[ColumnOrName, t.List[ColumnOrName], t.Tuple[ColumnOrName, ...]],
-        values: t.Optional[t.Union[ColumnOrName, t.List[ColumnOrName], t.Tuple[ColumnOrName, ...]]],
+        ids: t.Union[ColumnOrName, t.Collection[ColumnOrName]],
+        values: t.Optional[t.Union[ColumnOrName, t.Collection[ColumnOrName]]],
         variableColumnName: str,
         valueColumnName: str,
     ) -> Self:
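
`unpivot` now accepts any `Collection` of columns for `ids`/`values`, not just lists and tuples. A hedged sketch (DuckDB session and engine support for unpivot assumed; the data is illustrative):

```python
from sqlframe.duckdb import DuckDBSession

spark = DuckDBSession()
df = spark.createDataFrame([(1, 11, 12), (2, 21, 22)], ["id", "q1", "q2"])

# Tuples (or any other Collection) are now accepted alongside lists:
df.unpivot(
    ids=("id",),
    values=("q1", "q2"),
    variableColumnName="quarter",
    valueColumnName="sales",
).show()
```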
sqlframe/base/functions.py CHANGED
@@ -856,15 +856,21 @@ def expr(str: str) -> Column:
 
 @meta(unsupported_engines=["postgres"])
 def struct(col: t.Union[ColumnOrName, t.Iterable[ColumnOrName]], *cols: ColumnOrName) -> Column:
-    from sqlframe.base.function_alternatives import struct_with_eq
-
     session = _get_session()
-
-    if session._is_snowflake:
-        return struct_with_eq(col, *cols)
-
-    columns = ensure_list(col) + list(cols)
-    return Column.invoke_expression_over_column(None, expression.Struct, expressions=columns)
+    col_func = get_func_from_session("col")
+
+    columns = [col_func(x) for x in ensure_list(col) + list(cols)]
+    expressions = []
+    for column in columns:
+        expressions.append(
+            expression.PropertyEQ(
+                this=expression.parse_identifier(
+                    column.alias_or_name, dialect=session.input_dialect
+                ),
+                expression=column.column_expression,
+            )
+        )
+    return Column(expression.Struct(expressions=expressions))
 
 
 @meta(unsupported_engines=["bigquery", "duckdb", "postgres", "snowflake"])
sqlframe/base/group.py CHANGED
@@ -16,6 +16,8 @@ else:
 # https://spark.apache.org/docs/latest/sql-ref-syntax-qry-select-groupby.html
 # https://stackoverflow.com/questions/37975227/what-is-the-difference-between-cube-rollup-and-groupby-operators
 class _BaseGroupedData(t.Generic[DF]):
+    last_op: Operation
+
     def __init__(
         self,
         df: DF,
sqlframe/base/mixins/catalog_mixins.py CHANGED
@@ -6,16 +6,23 @@ from sqlglot import exp
 from sqlframe.base.catalog import (
     DF,
     SESSION,
+    TABLE,
     CatalogMetadata,
     Column,
     Database,
     Table,
     _BaseCatalog,
 )
-from sqlframe.base.util import normalize_string, schema_, to_schema
+from sqlframe.base.types import StructType
+from sqlframe.base.util import (
+    get_column_mapping_from_schema_input,
+    normalize_string,
+    schema_,
+    to_schema,
+)
 
 
-class _BaseInfoSchemaMixin(_BaseCatalog, t.Generic[SESSION, DF]):
+class _BaseInfoSchemaMixin(_BaseCatalog, t.Generic[SESSION, DF, TABLE]):
     QUALIFY_INFO_SCHEMA_WITH_DATABASE = False
     UPPERCASE_INFO_SCHEMA = False
 
@@ -52,7 +59,7 @@ class _BaseInfoSchemaMixin(_BaseCatalog, t.Generic[SESSION, DF]):
         )
 
 
-class GetCurrentCatalogFromFunctionMixin(_BaseCatalog, t.Generic[SESSION, DF]):
+class GetCurrentCatalogFromFunctionMixin(_BaseCatalog, t.Generic[SESSION, DF, TABLE]):
     CURRENT_CATALOG_EXPRESSION: exp.Expression = exp.func("current_catalog")
 
     def currentCatalog(self) -> str:
@@ -74,7 +81,7 @@ class GetCurrentCatalogFromFunctionMixin(_BaseCatalog, t.Generic[SESSION, DF]):
         )
 
 
-class GetCurrentDatabaseFromFunctionMixin(_BaseCatalog, t.Generic[SESSION, DF]):
+class GetCurrentDatabaseFromFunctionMixin(_BaseCatalog, t.Generic[SESSION, DF, TABLE]):
     CURRENT_DATABASE_EXPRESSION: exp.Expression = exp.func("current_schema")
 
     def currentDatabase(self) -> str:
@@ -94,7 +101,7 @@ class GetCurrentDatabaseFromFunctionMixin(_BaseCatalog, t.Generic[SESSION, DF]):
         )
 
 
-class SetCurrentCatalogFromUseMixin(_BaseCatalog, t.Generic[SESSION, DF]):
+class SetCurrentCatalogFromUseMixin(_BaseCatalog, t.Generic[SESSION, DF, TABLE]):
     def setCurrentCatalog(self, catalogName: str) -> None:
         """Sets the current default catalog in this session.
 
@@ -114,7 +121,136 @@ class SetCurrentCatalogFromUseMixin(_BaseCatalog, t.Generic[SESSION, DF]):
         )
 
 
-class ListDatabasesFromInfoSchemaMixin(_BaseInfoSchemaMixin, t.Generic[SESSION, DF]):
+class CreateTableFromFunctionMixin(_BaseCatalog, t.Generic[SESSION, DF, TABLE]):
+    def createTable(
+        self,
+        tableName: str,
+        path: t.Optional[str] = None,
+        source: t.Optional[str] = None,
+        schema: t.Optional[StructType] = None,
+        description: t.Optional[str] = None,
+        **options: str,
+    ) -> TABLE:
+        """Creates a table based on the dataset in a data source.
+
+        .. versionadded:: 2.2.0
+
+        Parameters
+        ----------
+        tableName : str
+            name of the table to create.
+
+            .. versionchanged:: 3.4.0
+                Allow ``tableName`` to be qualified with catalog name.
+
+        path : str, t.Optional
+            the path in which the data for this table exists.
+            When ``path`` is specified, an external table is
+            created from the data at the given path. Otherwise a managed table is created.
+        source : str, t.Optional
+            the source of this table such as 'parquet, 'orc', etc.
+            If ``source`` is not specified, the default data source configured by
+            ``spark.sql.sources.default`` will be used.
+        schema : class:`StructType`, t.Optional
+            the schema for this table.
+        description : str, t.Optional
+            the description of this table.
+
+            .. versionchanged:: 3.1.0
+                Added the ``description`` parameter.
+
+        **options : dict, t.Optional
+            extra options to specify in the table.
+
+        Returns
+        -------
+        :class:`DataFrame`
+            The DataFrame associated with the table.
+
+        Examples
+        --------
+        Creating a managed table.
+
+        >>> _ = spark.catalog.createTable("tbl1", schema=spark.range(1).schema, source='parquet')
+        >>> _ = spark.sql("DROP TABLE tbl1")
+
+        Creating an external table
+
+        >>> import tempfile
+        >>> with tempfile.TemporaryDirectory() as d:
+        ...     _ = spark.catalog.createTable(
+        ...         "tbl2", schema=spark.range(1).schema, path=d, source='parquet')
+        >>> _ = spark.sql("DROP TABLE tbl2")
+        """
+        if source is not None:
+            raise NotImplementedError("Providing source to create table is not supported")
+        if path is not None:
+            raise NotImplementedError("Creating a external table is not supported")
+
+        replace: t.Union[str, bool, None] = options.pop("replace", None)
+        exists: t.Union[str, bool, None] = options.pop("exists", None)
+
+        if isinstance(replace, str) and replace.lower() == "true":
+            replace = True
+        if isinstance(exists, str) and exists.lower() == "true":
+            exists = True
+
+        if schema is None:
+            raise ValueError("schema must be specified.")
+
+        column_mapping = get_column_mapping_from_schema_input(
+            schema, dialect=self.session.input_dialect
+        )
+        expressions = [
+            exp.ColumnDef(this=exp.parse_identifier(k, dialect=self.session.input_dialect), kind=v)
+            for k, v in column_mapping.items()
+        ]
+
+        name = normalize_string(tableName, from_dialect="input", is_table=True)
+        output_expression_container = exp.Create(
+            this=exp.Schema(
+                this=exp.to_table(name, dialect=self.session.input_dialect),
+                expressions=expressions,
+            ),
+            kind="TABLE",
+            exists=exists,
+            replace=replace,
+        )
+        if self.session._has_connection:
+            self.session._collect(output_expression_container)
+
+        df = self.session.table(name)
+        return df
+
+    def createExternalTable(
+        self,
+        tableName: str,
+        path: t.Optional[str] = None,
+        source: t.Optional[str] = None,
+        schema: t.Optional[StructType] = None,
+        **options: str,
+    ) -> TABLE:
+        """Creates a table based on the dataset in a data source.
+
+        It returns the DataFrame associated with the external table.
+
+        The data source is specified by the ``source`` and a set of ``options``.
+        If ``source`` is not specified, the default data source configured by
+        ``spark.sql.sources.default`` will be used.
+
+        t.Optionally, a schema can be provided as the schema of the returned :class:`DataFrame` and
+        created external table.
+
+        .. versionadded:: 2.0.0
+
+        Returns
+        -------
+        :class:`DataFrame`
+        """
+        return self.createTable(tableName, path=path, source=source, schema=schema, **options)
+
+
+class ListDatabasesFromInfoSchemaMixin(_BaseInfoSchemaMixin, t.Generic[SESSION, DF, TABLE]):
     def listDatabases(self, pattern: t.Optional[str] = None) -> t.List[Database]:
         """
         Returns a t.List of databases available across all sessions.
@@ -169,7 +305,7 @@ class ListDatabasesFromInfoSchemaMixin(_BaseInfoSchemaMixin, t.Generic[SESSION,
         return databases
 
 
-class ListCatalogsFromInfoSchemaMixin(_BaseInfoSchemaMixin, t.Generic[SESSION, DF]):
+class ListCatalogsFromInfoSchemaMixin(_BaseInfoSchemaMixin, t.Generic[SESSION, DF, TABLE]):
     def listCatalogs(self, pattern: t.Optional[str] = None) -> t.List[CatalogMetadata]:
         """
         Returns a t.List of databases available across all sessions.
@@ -221,7 +357,7 @@ class ListCatalogsFromInfoSchemaMixin(_BaseInfoSchemaMixin, t.Generic[SESSION, D
         return catalogs
 
 
-class SetCurrentDatabaseFromSearchPathMixin(_BaseCatalog, t.Generic[SESSION, DF]):
+class SetCurrentDatabaseFromSearchPathMixin(_BaseCatalog, t.Generic[SESSION, DF, TABLE]):
     def setCurrentDatabase(self, dbName: str) -> None:
         """
         Sets the current default database in this session.
@@ -235,7 +371,7 @@ class SetCurrentDatabaseFromSearchPathMixin(_BaseCatalog, t.Generic[SESSION, DF]
         self.session._execute(f'SET search_path TO "{dbName}"')
 
 
-class SetCurrentDatabaseFromUseMixin(_BaseCatalog, t.Generic[SESSION, DF]):
+class SetCurrentDatabaseFromUseMixin(_BaseCatalog, t.Generic[SESSION, DF, TABLE]):
     def setCurrentDatabase(self, dbName: str) -> None:
         """
         Sets the current default database in this session.
@@ -257,7 +393,7 @@ class SetCurrentDatabaseFromUseMixin(_BaseCatalog, t.Generic[SESSION, DF]):
         self.session._collect(exp.Use(this=schema))
 
 
-class ListTablesFromInfoSchemaMixin(_BaseInfoSchemaMixin, t.Generic[SESSION, DF]):
+class ListTablesFromInfoSchemaMixin(_BaseInfoSchemaMixin, t.Generic[SESSION, DF, TABLE]):
     def listTables(
         self, dbName: t.Optional[str] = None, pattern: t.Optional[str] = None
     ) -> t.List[Table]:
@@ -395,7 +531,7 @@ class ListTablesFromInfoSchemaMixin(_BaseInfoSchemaMixin, t.Generic[SESSION, DF]
         return tables
 
 
-class ListColumnsFromInfoSchemaMixin(_BaseInfoSchemaMixin, t.Generic[SESSION, DF]):
+class ListColumnsFromInfoSchemaMixin(_BaseInfoSchemaMixin, t.Generic[SESSION, DF, TABLE]):
     def listColumns(
         self, tableName: str, dbName: t.Optional[str] = None, include_temp: bool = False
     ) -> t.List[Column]:
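
A hedged usage sketch for the new `CreateTableFromFunctionMixin.createTable`, assuming a DuckDB session whose catalog includes this mixin. Only schema-based creation is implemented; `path=`/`source=` raise `NotImplementedError` as shown above:

```python
from sqlframe.base import types
from sqlframe.duckdb import DuckDBSession

spark = DuckDBSession()
schema = types.StructType(
    [
        types.StructField("id", types.LongType()),
        types.StructField("name", types.StringType()),
    ]
)

# Issues CREATE OR REPLACE TABLE and returns a handle to the new table.
tbl = spark.catalog.createTable("tbl1", schema=schema, replace="true")
```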
sqlframe/base/mixins/dataframe_mixins.py CHANGED
@@ -4,6 +4,9 @@ import typing as t
 
 from sqlglot import exp
 
+if t.TYPE_CHECKING:
+    from sqlframe.base._typing import StorageLevel
+
 from sqlframe.base.catalog import Column
 from sqlframe.base.dataframe import (
     GROUP_DATA,
@@ -28,7 +31,7 @@ class NoCachePersistSupportMixin(BaseDataFrame, t.Generic[SESSION, WRITER, NA, S
         logger.warning("This engine does not support caching. Ignoring cache() call.")
         return self
 
-    def persist(self) -> Self:
+    def persist(self, storageLevel: "StorageLevel" = "MEMORY_AND_DISK_SER") -> Self:
         logger.warning("This engine does not support persist. Ignoring persist() call.")
         return self
 
sqlframe/base/operations.py CHANGED
@@ -6,10 +6,19 @@ import functools
 import typing as t
 from enum import IntEnum
 
+from typing_extensions import Concatenate, ParamSpec
+
 if t.TYPE_CHECKING:
     from sqlframe.base.dataframe import BaseDataFrame
     from sqlframe.base.group import _BaseGroupedData
 
+    DF = t.TypeVar("DF", bound=BaseDataFrame)
+    T = t.TypeVar("T", bound=t.Union[BaseDataFrame, _BaseGroupedData])
+else:
+    DF = t.TypeVar("DF")
+    T = t.TypeVar("T")
+P = ParamSpec("P")  # represents arbitrary args + kwargs
+
 
 class Operation(IntEnum):
     INIT = -1
@@ -23,7 +32,17 @@ class Operation(IntEnum):
     LIMIT = 7
 
 
-def operation(op: Operation) -> t.Callable[[t.Callable], t.Callable]:
+# We want to decorate a function (self: DF, *args, **kwargs) -> T
+# where DF is a subclass of BaseDataFrame
+# where T is a subclass of BaseDataFrame or _BaseGroupedData
+# And keep its signature, i.e. produce a function of the same shape
+# Hence we work with `t.Callable[Concatenate[DF, P], T]`
+def operation(
+    op: Operation,
+) -> t.Callable[
+    [t.Callable[Concatenate[DF, P], T]],  # accept such a function
+    t.Callable[Concatenate[DF, P], T],  # and return such a function
+]:
     """
     Decorator used around DataFrame methods to indicate what type of operation is being performed from the
     ordered Operation enums. This is used to determine which operations should be performed on a CTE vs.
@@ -35,9 +54,11 @@ def operation(op: Operation) -> t.Callable[[t.Callable], t.Callable]:
     in cases where there is overlap in names.
     """
 
-    def decorator(func: t.Callable) -> t.Callable:
+    def decorator(
+        func: t.Callable[Concatenate[DF, P], T],
+    ) -> t.Callable[Concatenate[DF, P], T]:
         @functools.wraps(func)
-        def wrapper(self: BaseDataFrame, *args, **kwargs) -> BaseDataFrame:
+        def wrapper(self: DF, *args, **kwargs) -> T:
             if self.last_op == Operation.INIT:
                 self = self._convert_leaf_to_cte()
                 self.last_op = Operation.NO_OP
@@ -45,17 +66,22 @@ def operation(op: Operation) -> t.Callable[[t.Callable], t.Callable]:
             new_op = op if op != Operation.NO_OP else last_op
             if new_op < last_op or (last_op == new_op == Operation.SELECT):
                 self = self._convert_leaf_to_cte()
-            df: t.Union[BaseDataFrame, _BaseGroupedData] = func(self, *args, **kwargs)
-            df.last_op = new_op  # type: ignore
-            return df  # type: ignore
+            df = func(self, *args, **kwargs)
+            df.last_op = new_op
+            return df
 
-        wrapper.__wrapped__ = func  # type: ignore
+        wrapper.__wrapped__ = func
         return wrapper
 
     return decorator
 
 
-def group_operation(op: Operation) -> t.Callable[[t.Callable], t.Callable]:
+# Here decorate a function (self: _BaseGroupedData[DF], *args, **kwargs) -> DF
+# Hence we work with t.Callable[Concatenate[_BaseGroupedData[DF], P], DF]
+# We simplify the parameters, as Pyright (used for VSCode autocomplete) doesn't unterstand this
+def group_operation(
+    op: Operation,
+) -> t.Callable[[t.Callable[P, DF]], t.Callable[P, DF]]:
     """
     Decorator used around DataFrame methods to indicate what type of operation is being performed from the
     ordered Operation enums. This is used to determine which operations should be performed on a CTE vs.
@@ -67,9 +93,11 @@ def group_operation(op: Operation) -> t.Callable[[t.Callable], t.Callable]:
     in cases where there is overlap in names.
    """
 
-    def decorator(func: t.Callable) -> t.Callable:
+    def decorator(
+        func: t.Callable[Concatenate[_BaseGroupedData[DF], P], DF],
+    ) -> t.Callable[Concatenate[_BaseGroupedData[DF], P], DF]:
         @functools.wraps(func)
-        def wrapper(self: _BaseGroupedData, *args, **kwargs) -> BaseDataFrame:
+        def wrapper(self: _BaseGroupedData[DF], *args, **kwargs) -> DF:
             if self._df.last_op == Operation.INIT:
                 self._df = self._df._convert_leaf_to_cte()
                 self._df.last_op = Operation.NO_OP
@@ -77,11 +105,11 @@ def group_operation(op: Operation) -> t.Callable[[t.Callable], t.Callable]:
             new_op = op if op != Operation.NO_OP else last_op
             if new_op < last_op or (last_op == new_op == Operation.SELECT):
                 self._df = self._df._convert_leaf_to_cte()
-            df: BaseDataFrame = func(self, *args, **kwargs)
-            df.last_op = new_op  # type: ignore
+            df = func(self, *args, **kwargs)
+            df.last_op = new_op
             return df
 
-        wrapper.__wrapped__ = func  # type: ignore
+        wrapper.__wrapped__ = func
         return wrapper
 
-    return decorator
+    return decorator  # type: ignore
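
For readers unfamiliar with the typing pattern above, here is a standalone sketch of how `ParamSpec` plus `Concatenate` let a decorator preserve the wrapped method's signature for type checkers; the names are illustrative, not from sqlframe:

```python
import functools
import typing as t

from typing_extensions import Concatenate, ParamSpec

P = ParamSpec("P")
R = t.TypeVar("R")
S = t.TypeVar("S")  # stands in for the "self" type of the decorated method


def tag(label: str) -> t.Callable[
    [t.Callable[Concatenate[S, P], R]],
    t.Callable[Concatenate[S, P], R],
]:
    # The decorator accepts and returns a callable of the same shape, so
    # autocomplete and argument checking survive the wrapping.
    def decorator(func: t.Callable[Concatenate[S, P], R]) -> t.Callable[Concatenate[S, P], R]:
        @functools.wraps(func)
        def wrapper(self: S, *args: P.args, **kwargs: P.kwargs) -> R:
            print(f"calling {func.__name__} [{label}]")
            return func(self, *args, **kwargs)

        return wrapper

    return decorator
```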
sqlframe/base/readerwriter.py CHANGED
@@ -444,7 +444,10 @@ class _BaseDataFrameWriter(t.Generic[SESSION, DF]):
         return self.copy(_df=df)
 
     def saveAsTable(
-        self, name: str, format: t.Optional[str] = None, mode: t.Optional[str] = None
+        self,
+        name: str,
+        format: t.Optional[str] = None,
+        mode: t.Optional[str] = None,
     ) -> Self:
         if format is not None:
             raise NotImplementedError("Providing Format in the save as table is not supported")
sqlframe/base/window.py CHANGED
@@ -27,11 +27,11 @@ class Window:
     currentRow: int = 0
 
     @classmethod
-    def partitionBy(cls, *cols: t.Union[ColumnOrName, t.List[ColumnOrName]]) -> WindowSpec:
+    def partitionBy(cls, *cols: t.Union[ColumnOrName, t.Collection[ColumnOrName]]) -> WindowSpec:
         return WindowSpec().partitionBy(*cols)
 
     @classmethod
-    def orderBy(cls, *cols: t.Union[ColumnOrName, t.List[ColumnOrName]]) -> WindowSpec:
+    def orderBy(cls, *cols: t.Union[ColumnOrName, t.Collection[ColumnOrName]]) -> WindowSpec:
         return WindowSpec().orderBy(*cols)
 
     @classmethod
@@ -55,10 +55,10 @@ class WindowSpec:
 
         return self.expression.sql(dialect=_BaseSession().input_dialect, **kwargs)
 
-    def partitionBy(self, *cols: t.Union[ColumnOrName, t.List[ColumnOrName]]) -> WindowSpec:
+    def partitionBy(self, *cols: t.Union[ColumnOrName, t.Collection[ColumnOrName]]) -> WindowSpec:
         from sqlframe.base.column import Column
 
-        cols = flatten(cols) if isinstance(cols[0], (list, set, tuple)) else cols  # type: ignore
+        cols = flatten(cols) if isinstance(cols[0], t.Collection) else cols  # type: ignore
         expressions = [Column.ensure_col(x).expression for x in cols]  # type: ignore
         window_spec = self.copy()
         partition_by_expressions = window_spec.expression.args.get("partition_by", [])
@@ -66,10 +66,10 @@ class WindowSpec:
         window_spec.expression.set("partition_by", partition_by_expressions)
         return window_spec
 
-    def orderBy(self, *cols: t.Union[ColumnOrName, t.List[ColumnOrName]]) -> WindowSpec:
+    def orderBy(self, *cols: t.Union[ColumnOrName, t.Collection[ColumnOrName]]) -> WindowSpec:
         from sqlframe.base.column import Column
 
-        cols = flatten(cols) if isinstance(cols[0], (list, set, tuple)) else cols  # type: ignore
+        cols = flatten(cols) if isinstance(cols[0], t.Collection) else cols  # type: ignore
         expressions = [Column.ensure_col(x).expression for x in cols]  # type: ignore
         window_spec = self.copy()
         if window_spec.expression.args.get("order") is None:
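
After the `Collection` widening, both calling styles type-check and flatten the same way. A brief sketch (DuckDB session assumed; column names are illustrative):

```python
from sqlframe.base.window import Window
from sqlframe.duckdb import DuckDBSession

spark = DuckDBSession()  # establishes the active session used to parse columns

w1 = Window.partitionBy("a", "b").orderBy("c")
w2 = Window.partitionBy(["a", "b"]).orderBy(("c",))  # any Collection works
```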
sqlframe/bigquery/catalog.py CHANGED
@@ -7,6 +7,7 @@ from sqlglot import exp
 
 from sqlframe.base.catalog import CatalogMetadata, Column, Function
 from sqlframe.base.mixins.catalog_mixins import (
+    CreateTableFromFunctionMixin,
     ListDatabasesFromInfoSchemaMixin,
     ListTablesFromInfoSchemaMixin,
     _BaseInfoSchemaMixin,
@@ -18,12 +19,14 @@ if t.TYPE_CHECKING:
 
     from sqlframe.bigquery.dataframe import BigQueryDataFrame  # noqa
     from sqlframe.bigquery.session import BigQuerySession  # noqa
+    from sqlframe.bigquery.table import BigQueryTable  # noqa
 
 
 class BigQueryCatalog(
-    ListDatabasesFromInfoSchemaMixin["BigQuerySession", "BigQueryDataFrame"],
-    ListTablesFromInfoSchemaMixin["BigQuerySession", "BigQueryDataFrame"],
-    _BaseInfoSchemaMixin["BigQuerySession", "BigQueryDataFrame"],
+    CreateTableFromFunctionMixin["BigQuerySession", "BigQueryDataFrame", "BigQueryTable"],
+    ListDatabasesFromInfoSchemaMixin["BigQuerySession", "BigQueryDataFrame", "BigQueryTable"],
+    ListTablesFromInfoSchemaMixin["BigQuerySession", "BigQueryDataFrame", "BigQueryTable"],
+    _BaseInfoSchemaMixin["BigQuerySession", "BigQueryDataFrame", "BigQueryTable"],
 ):
     QUALIFY_INFO_SCHEMA_WITH_DATABASE = True
     UPPERCASE_INFO_SCHEMA = True