sqlframe-1.1.3-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sqlframe/__init__.py +0 -0
- sqlframe/_version.py +16 -0
- sqlframe/base/__init__.py +0 -0
- sqlframe/base/_typing.py +39 -0
- sqlframe/base/catalog.py +1163 -0
- sqlframe/base/column.py +388 -0
- sqlframe/base/dataframe.py +1519 -0
- sqlframe/base/decorators.py +51 -0
- sqlframe/base/exceptions.py +14 -0
- sqlframe/base/function_alternatives.py +1055 -0
- sqlframe/base/functions.py +1678 -0
- sqlframe/base/group.py +102 -0
- sqlframe/base/mixins/__init__.py +0 -0
- sqlframe/base/mixins/catalog_mixins.py +419 -0
- sqlframe/base/mixins/readwriter_mixins.py +118 -0
- sqlframe/base/normalize.py +84 -0
- sqlframe/base/operations.py +87 -0
- sqlframe/base/readerwriter.py +679 -0
- sqlframe/base/session.py +585 -0
- sqlframe/base/transforms.py +13 -0
- sqlframe/base/types.py +418 -0
- sqlframe/base/util.py +242 -0
- sqlframe/base/window.py +139 -0
- sqlframe/bigquery/__init__.py +23 -0
- sqlframe/bigquery/catalog.py +255 -0
- sqlframe/bigquery/column.py +1 -0
- sqlframe/bigquery/dataframe.py +54 -0
- sqlframe/bigquery/functions.py +378 -0
- sqlframe/bigquery/group.py +14 -0
- sqlframe/bigquery/readwriter.py +29 -0
- sqlframe/bigquery/session.py +89 -0
- sqlframe/bigquery/types.py +1 -0
- sqlframe/bigquery/window.py +1 -0
- sqlframe/duckdb/__init__.py +20 -0
- sqlframe/duckdb/catalog.py +108 -0
- sqlframe/duckdb/column.py +1 -0
- sqlframe/duckdb/dataframe.py +55 -0
- sqlframe/duckdb/functions.py +47 -0
- sqlframe/duckdb/group.py +14 -0
- sqlframe/duckdb/readwriter.py +111 -0
- sqlframe/duckdb/session.py +65 -0
- sqlframe/duckdb/types.py +1 -0
- sqlframe/duckdb/window.py +1 -0
- sqlframe/postgres/__init__.py +23 -0
- sqlframe/postgres/catalog.py +106 -0
- sqlframe/postgres/column.py +1 -0
- sqlframe/postgres/dataframe.py +54 -0
- sqlframe/postgres/functions.py +61 -0
- sqlframe/postgres/group.py +14 -0
- sqlframe/postgres/readwriter.py +29 -0
- sqlframe/postgres/session.py +68 -0
- sqlframe/postgres/types.py +1 -0
- sqlframe/postgres/window.py +1 -0
- sqlframe/redshift/__init__.py +23 -0
- sqlframe/redshift/catalog.py +127 -0
- sqlframe/redshift/column.py +1 -0
- sqlframe/redshift/dataframe.py +54 -0
- sqlframe/redshift/functions.py +18 -0
- sqlframe/redshift/group.py +14 -0
- sqlframe/redshift/readwriter.py +29 -0
- sqlframe/redshift/session.py +53 -0
- sqlframe/redshift/types.py +1 -0
- sqlframe/redshift/window.py +1 -0
- sqlframe/snowflake/__init__.py +26 -0
- sqlframe/snowflake/catalog.py +134 -0
- sqlframe/snowflake/column.py +1 -0
- sqlframe/snowflake/dataframe.py +54 -0
- sqlframe/snowflake/functions.py +18 -0
- sqlframe/snowflake/group.py +14 -0
- sqlframe/snowflake/readwriter.py +29 -0
- sqlframe/snowflake/session.py +53 -0
- sqlframe/snowflake/types.py +1 -0
- sqlframe/snowflake/window.py +1 -0
- sqlframe/spark/__init__.py +23 -0
- sqlframe/spark/catalog.py +1028 -0
- sqlframe/spark/column.py +1 -0
- sqlframe/spark/dataframe.py +54 -0
- sqlframe/spark/functions.py +22 -0
- sqlframe/spark/group.py +14 -0
- sqlframe/spark/readwriter.py +29 -0
- sqlframe/spark/session.py +90 -0
- sqlframe/spark/types.py +1 -0
- sqlframe/spark/window.py +1 -0
- sqlframe/standalone/__init__.py +26 -0
- sqlframe/standalone/catalog.py +13 -0
- sqlframe/standalone/column.py +1 -0
- sqlframe/standalone/dataframe.py +36 -0
- sqlframe/standalone/functions.py +1 -0
- sqlframe/standalone/group.py +14 -0
- sqlframe/standalone/readwriter.py +19 -0
- sqlframe/standalone/session.py +40 -0
- sqlframe/standalone/types.py +1 -0
- sqlframe/standalone/window.py +1 -0
- sqlframe-1.1.3.dist-info/LICENSE +21 -0
- sqlframe-1.1.3.dist-info/METADATA +172 -0
- sqlframe-1.1.3.dist-info/RECORD +98 -0
- sqlframe-1.1.3.dist-info/WHEEL +5 -0
- sqlframe-1.1.3.dist-info/top_level.txt +1 -0
sqlframe/base/group.py
ADDED
@@ -0,0 +1,102 @@
# This code is based on code from Apache Spark under the license found in the LICENSE file located in the 'sqlframe' folder.

from __future__ import annotations

import typing as t

from sqlframe.base.operations import Operation, group_operation, operation

if t.TYPE_CHECKING:
    from sqlframe.base.column import Column
    from sqlframe.base.session import DF
else:
    DF = t.TypeVar("DF")


# https://spark.apache.org/docs/latest/sql-ref-syntax-qry-select-groupby.html
# https://stackoverflow.com/questions/37975227/what-is-the-difference-between-cube-rollup-and-groupby-operators
class _BaseGroupedData(t.Generic[DF]):
    def __init__(
        self,
        df: DF,
        group_by_cols: t.Union[t.List[Column], t.List[t.List[Column]]],
        last_op: Operation,
    ):
        self._df = df.copy()
        self.session = df.session
        self.last_op = last_op
        self.group_by_cols = group_by_cols

    def _get_function_applied_columns(
        self, func_name: str, cols: t.Tuple[str, ...]
    ) -> t.List[Column]:
        from sqlframe.base import functions as F

        func_name = func_name.lower()
        return [
            getattr(F, func_name)(name).alias(
                self.session._sanitize_column_name(f"{func_name}({name})")
            )
            for name in cols
        ]

    @group_operation(Operation.SELECT)
    def agg(self, *exprs: t.Union[Column, t.Dict[str, str]]) -> DF:
        from sqlframe.base.column import Column

        columns = (
            [Column(f"{agg_func}({column_name})") for column_name, agg_func in exprs[0].items()]
            if isinstance(exprs[0], dict)
            else exprs
        )
        cols = self._df._ensure_and_normalize_cols(columns)

        if not self.group_by_cols or not isinstance(self.group_by_cols[0], (list, tuple, set)):
            expression = self._df.expression.group_by(
                *[x.expression for x in self.group_by_cols]  # type: ignore
            ).select(*[x.expression for x in self.group_by_cols + cols], append=False)  # type: ignore
            group_by_cols = self.group_by_cols
        else:
            from sqlglot import exp

            expression = self._df.expression
            all_grouping_sets = []
            group_by_cols = []
            for grouping_set in self.group_by_cols:
                all_grouping_sets.append(
                    exp.Tuple(expressions=[x.expression for x in grouping_set])  # type: ignore
                )
                group_by_cols.extend(grouping_set)  # type: ignore
            group_by_cols = list(dict.fromkeys(group_by_cols))
            group_by = exp.Group(grouping_sets=all_grouping_sets)
            expression.set("group", group_by)
            for col in cols:
                # Spark supports having an empty grouping_id which means all of the columns but other dialects
                # like duckdb don't support this so we expand the grouping_id to include all of the columns
                if col.column_expression.this == "GROUPING_ID":
                    col.column_expression.set("expressions", [x.expression for x in group_by_cols])  # type: ignore
            expression = expression.select(*[x.expression for x in group_by_cols + cols], append=False)  # type: ignore
        return self._df.copy(expression=expression)

    def count(self) -> DF:
        from sqlframe.base import functions as F

        return self.agg(F.count("*").alias("count"))

    def mean(self, *cols: str) -> DF:
        return self.avg(*cols)

    def avg(self, *cols: str) -> DF:
        return self.agg(*self._get_function_applied_columns("avg", cols))

    def max(self, *cols: str) -> DF:
        return self.agg(*self._get_function_applied_columns("max", cols))

    def min(self, *cols: str) -> DF:
        return self.agg(*self._get_function_applied_columns("min", cols))

    def sum(self, *cols: str) -> DF:
        return self.agg(*self._get_function_applied_columns("sum", cols))

    def pivot(self, *cols: str) -> DF:
        raise NotImplementedError("Sum distinct is not currently implemented")
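Editor's note: the grouped-data API above is the base class every dialect reuses for groupBy/cube-style aggregation. A minimal usage sketch follows; it is not part of the wheel, and the direct StandaloneSession() construction, tuple-based createDataFrame, and the .sql() render call are assumptions taken from the package layout rather than from this diff.

from sqlframe.standalone import StandaloneSession
from sqlframe.standalone import functions as F

session = StandaloneSession()
df = session.createDataFrame([(1, 10.0), (1, 5.0), (2, 7.5)], ["id", "amount"])

# Column-based and dict-based forms both route through _BaseGroupedData.agg above
summary = df.groupBy("id").agg(F.sum("amount").alias("total"), F.count("*").alias("n"))
shorthand = df.groupBy("id").agg({"amount": "avg"})
print(summary.sql())  # assumed render method; should emit a SELECT ... GROUP BY statement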
sqlframe/base/mixins/__init__.py
File without changes
sqlframe/base/mixins/catalog_mixins.py
ADDED
@@ -0,0 +1,419 @@
import fnmatch
import typing as t

from sqlglot import exp

from sqlframe.base.catalog import (
    DF,
    SESSION,
    CatalogMetadata,
    Column,
    Database,
    Table,
    _BaseCatalog,
)
from sqlframe.base.decorators import normalize
from sqlframe.base.util import decoded_str, schema_, to_schema


class _BaseInfoSchemaMixin(_BaseCatalog, t.Generic[SESSION, DF]):
    QUALIFY_INFO_SCHEMA_WITH_DATABASE = False
    UPPERCASE_INFO_SCHEMA = False

    def _get_info_schema_table(
        self,
        table_name: str,
        database: t.Optional[str] = None,
        qualify_override: t.Optional[bool] = None,
    ) -> exp.Table:
        table = f"information_schema.{table_name}"
        if self.UPPERCASE_INFO_SCHEMA:
            table = table.upper()
        qualify = (
            qualify_override
            if qualify_override is not None
            else self.QUALIFY_INFO_SCHEMA_WITH_DATABASE
        )
        if qualify:
            db = database or self.currentDatabase()
            if not db:
                raise ValueError("Table name must be qualified with a database.")
            table = f"{db}.{table}"
        return exp.to_table(table)


class GetCurrentCatalogFromFunctionMixin(_BaseCatalog, t.Generic[SESSION, DF]):
    CURRENT_CATALOG_EXPRESSION: exp.Expression = exp.func("current_catalog")

    def currentCatalog(self) -> str:
        """Returns the current default catalog in this session.

        .. versionadded:: 3.4.0

        Examples
        --------
        >>> spark.catalog.currentCatalog()
        'spark_catalog'
        """
        return self.session._fetch_rows(
            exp.select(self.CURRENT_CATALOG_EXPRESSION), quote_identifiers=False
        )[0][0]


class GetCurrentDatabaseFromFunctionMixin(_BaseCatalog, t.Generic[SESSION, DF]):
    CURRENT_DATABASE_EXPRESSION: exp.Expression = exp.func("current_schema")

    def currentDatabase(self) -> str:
        """Returns the current default schema in this session.

        .. versionadded:: 3.4.0

        Examples
        --------
        >>> spark.catalog.currentDatabase()
        'default'
        """
        return self.session._fetch_rows(exp.select(self.CURRENT_DATABASE_EXPRESSION))[0][0]


class SetCurrentCatalogFromUseMixin(_BaseCatalog, t.Generic[SESSION, DF]):
    def setCurrentCatalog(self, catalogName: str) -> None:
        """Sets the current default catalog in this session.

        .. versionadded:: 3.4.0

        Parameters
        ----------
        catalogName : str
            name of the catalog to set

        Examples
        --------
        >>> spark.catalog.setCurrentCatalog("spark_catalog")
        """
        self.session._execute(
            exp.Use(this=exp.parse_identifier(catalogName, dialect=self.session.input_dialect))
        )


class ListDatabasesFromInfoSchemaMixin(_BaseInfoSchemaMixin, t.Generic[SESSION, DF]):
    def listDatabases(self, pattern: t.Optional[str] = None) -> t.List[Database]:
        """
        Returns a t.List of databases available across all sessions.

        .. versionadded:: 2.0.0

        Parameters
        ----------
        pattern : str
            The pattern that the database name needs to match.

            .. versionchanged: 3.5.0
                Adds ``pattern`` argument.

        Returns
        -------
        t.List
            A t.List of :class:`Database`.

        Examples
        --------
        >>> spark.catalog.t.listDatabases()
        [Database(name='default', catalog='spark_catalog', description='default database', ...

        >>> spark.catalog.t.listDatabases("def*")
        [Database(name='default', catalog='spark_catalog', description='default database', ...

        >>> spark.catalog.t.listDatabases("def2*")
        []
        """
        table = self._get_info_schema_table("schemata", qualify_override=False)
        results = self.session._fetch_rows(
            exp.Select().select("schema_name", "catalog_name").from_(table)
        )
        databases = [
            Database(name=x[0], catalog=x[1], description=None, locationUri="") for x in results
        ]
        if pattern:
            databases = [db for db in databases if fnmatch.fnmatch(db.name, pattern)]
        return databases


class ListCatalogsFromInfoSchemaMixin(_BaseInfoSchemaMixin, t.Generic[SESSION, DF]):
    def listCatalogs(self, pattern: t.Optional[str] = None) -> t.List[CatalogMetadata]:
        """
        Returns a t.List of databases available across all sessions.

        .. versionadded:: 2.0.0

        Parameters
        ----------
        pattern : str
            The pattern that the database name needs to match.

            .. versionchanged: 3.5.0
                Adds ``pattern`` argument.

        Returns
        -------
        t.List
            A t.List of :class:`Database`.

        Examples
        --------
        >>> spark.catalog.t.listDatabases()
        [Database(name='default', catalog='spark_catalog', description='default database', ...

        >>> spark.catalog.t.listDatabases("def*")
        [Database(name='default', catalog='spark_catalog', description='default database', ...

        >>> spark.catalog.t.listDatabases("def2*")
        []
        """
        table = self._get_info_schema_table("schemata")
        results = self.session._fetch_rows(
            exp.Select().select("catalog_name").from_(table).distinct()
        )
        catalogs = [CatalogMetadata(name=x[0], description=None) for x in results]
        if pattern:
            catalogs = [catalog for catalog in catalogs if fnmatch.fnmatch(catalog.name, pattern)]
        return catalogs


class SetCurrentDatabaseFromSearchPathMixin(_BaseCatalog, t.Generic[SESSION, DF]):
    def setCurrentDatabase(self, dbName: str) -> None:
        """
        Sets the current default database in this session.

        .. versionadded:: 2.0.0

        Examples
        --------
        >>> spark.catalog.setCurrentDatabase("default")
        """
        self.session._execute(f'SET search_path TO "{dbName}"')


class SetCurrentDatabaseFromUseMixin(_BaseCatalog, t.Generic[SESSION, DF]):
    def setCurrentDatabase(self, dbName: str) -> None:
        """
        Sets the current default database in this session.

        .. versionadded:: 2.0.0

        Examples
        --------
        >>> spark.catalog.setCurrentDatabase("default")
        """
        schema = to_schema(dbName, dialect=self.session.input_dialect)
        if not schema.catalog:
            schema.set(
                "catalog",
                exp.parse_identifier(self.currentCatalog(), dialect=self.session.input_dialect),
            )
        self.session._execute(exp.Use(this=schema))


class ListTablesFromInfoSchemaMixin(_BaseInfoSchemaMixin, t.Generic[SESSION, DF]):
    @normalize(["dbName"])
    def listTables(
        self, dbName: t.Optional[str] = None, pattern: t.Optional[str] = None
    ) -> t.List[Table]:
        """Returns a t.List of tables/views in the specified database.

        .. versionadded:: 2.0.0

        Parameters
        ----------
        dbName : str
            name of the database to t.List the tables.

            .. versionchanged:: 3.4.0
                Allow ``dbName`` to be qualified with catalog name.

        pattern : str
            The pattern that the database name needs to match.

            .. versionchanged: 3.5.0
                Adds ``pattern`` argument.

        Returns
        -------
        t.List
            A t.List of :class:`Table`.

        Notes
        -----
        If no database is specified, the current database and catalog
        are used. This API includes all temporary views.

        Examples
        --------
        >>> spark.range(1).createTempView("test_view")
        >>> spark.catalog.t.listTables()
        [Table(name='test_view', catalog=None, namespace=[], description=None, ...

        >>> spark.catalog.t.listTables(pattern="test*")
        [Table(name='test_view', catalog=None, namespace=[], description=None, ...

        >>> spark.catalog.t.listTables(pattern="table*")
        []

        >>> _ = spark.catalog.dropTempView("test_view")
        >>> spark.catalog.t.listTables()
        []
        """
        if dbName is None and pattern is None:
            schema = schema_(
                db=exp.parse_identifier(self.currentDatabase(), dialect=self.session.input_dialect),
                catalog=exp.parse_identifier(
                    self.currentCatalog(), dialect=self.session.input_dialect
                ),
            )
        elif dbName:
            schema = to_schema(dbName, dialect=self.session.input_dialect)
        else:
            schema = None
        table = self._get_info_schema_table("tables", database=schema.db if schema else None)
        select = exp.select(
            'table_name AS "table_name"',
            'table_schema AS "table_schema"',
            'table_catalog AS "table_catalog"',
            'table_type AS "table_type"',
        ).from_(table)
        if schema and schema.db:
            select = select.where(exp.column("table_schema").eq(schema.db))
        if schema and schema.catalog:
            select = select.where(exp.column("table_catalog").eq(schema.catalog))
        results = self.session._fetch_rows(select)
        tables = [
            Table(
                name=x["table_name"],
                catalog=x["table_catalog"],
                namespace=[x["table_schema"]],
                description=None,
                tableType="VIEW" if x["table_type"] == "VIEW" else "MANAGED",
                isTemporary=False,
            )
            for x in results
        ]
        for table in self.session.temp_views.keys():
            tables.append(
                Table(
                    name=table,  # type: ignore
                    catalog=None,
                    namespace=[],
                    description=None,
                    tableType="VIEW",
                    isTemporary=True,
                )
            )
        if pattern:
            tables = [x for x in tables if fnmatch.fnmatch(x.name, pattern)]
        return tables


class ListColumnsFromInfoSchemaMixin(_BaseInfoSchemaMixin, t.Generic[SESSION, DF]):
    @normalize(["tableName", "dbName"])
    def listColumns(self, tableName: str, dbName: t.Optional[str] = None) -> t.List[Column]:
        """Returns a t.List of columns for the given table/view in the specified database.

        .. versionadded:: 2.0.0

        Parameters
        ----------
        tableName : str
            name of the table to t.List columns.

            .. versionchanged:: 3.4.0
                Allow ``tableName`` to be qualified with catalog name when ``dbName`` is None.

        dbName : str, t.Optional
            name of the database to find the table to t.List columns.

        Returns
        -------
        t.List
            A t.List of :class:`Column`.

        Notes
        -----
        The order of arguments here is different from that of its JVM counterpart
        because Python does not support method overloading.

        If no database is specified, the current database and catalog
        are used. This API includes all temporary views.

        Examples
        --------
        >>> _ = spark.sql("DROP TABLE IF EXISTS tbl1")
        >>> _ = spark.sql("CREATE TABLE tblA (name STRING, age INT) USING parquet")
        >>> spark.catalog.t.listColumns("tblA")
        [Column(name='name', description=None, dataType='string', nullable=True, ...
        >>> _ = spark.sql("DROP TABLE tblA")
        """
        if df := self.session.temp_views.get(tableName):
            return [
                Column(
                    name=x,
                    description=None,
                    dataType="",
                    nullable=True,
                    isPartition=False,
                    isBucket=False,
                )
                for x in df.columns
            ]

        table = exp.to_table(tableName, dialect=self.session.input_dialect)
        schema = to_schema(dbName, dialect=self.session.input_dialect) if dbName else None
        if not table.db:
            if schema and schema.db:
                table.set("db", schema.args["db"])
            else:
                table.set(
                    "db",
                    exp.parse_identifier(
                        self.currentDatabase(), dialect=self.session.input_dialect
                    ),
                )
        if not table.catalog:
            if schema and schema.catalog:
                table.set("catalog", schema.args["catalog"])
            else:
                table.set(
                    "catalog",
                    exp.parse_identifier(self.currentCatalog(), dialect=self.session.input_dialect),
                )
        # if self.QUALIFY_INFO_SCHEMA_WITH_DATABASE:
        #     if not table.db:
        #         raise ValueError("dbName must be specified when listing columns from INFORMATION_SCHEMA")
        #     source_table = f"{table.db}.INFORMATION_SCHEMA.COLUMNS"
        # else:
        #     source_table = "INFORMATION_SCHEMA.COLUMNS"
        source_table = self._get_info_schema_table("columns", database=table.db)
        select = (
            exp.select(
                'column_name AS "column_name"',
                'data_type AS "data_type"',
                'is_nullable AS "is_nullable"',
            )
            .from_(source_table)
            .where(exp.column("table_name").eq(table.name))
        )
        if table.db:
            select = select.where(exp.column("table_schema").eq(table.db))
        if table.catalog:
            select = select.where(exp.column("table_catalog").eq(table.catalog))
        results = self.session._fetch_rows(select)
        return [
            Column(
                name=x["column_name"],
                description=None,
                dataType=x["data_type"],
                nullable=x["is_nullable"] == "YES",
                isPartition=False,
                isBucket=False,
            )
            for x in results
        ]
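Editor's note: these mixins are building blocks that each dialect's catalog composes (the real classes live in e.g. sqlframe/postgres/catalog.py or sqlframe/bigquery/catalog.py and parametrize the SESSION/DF generics). A hypothetical composition sketch, for illustration only:

class ExampleInfoSchemaCatalog(  # hypothetical name, not part of the wheel
    GetCurrentCatalogFromFunctionMixin,
    GetCurrentDatabaseFromFunctionMixin,
    SetCurrentDatabaseFromUseMixin,
    ListDatabasesFromInfoSchemaMixin,
    ListCatalogsFromInfoSchemaMixin,
    ListTablesFromInfoSchemaMixin,
    ListColumnsFromInfoSchemaMixin,
):
    # Query <db>.information_schema.<table> instead of bare information_schema.<table>
    QUALIFY_INFO_SCHEMA_WITH_DATABASE = True
    # Some engines expose INFORMATION_SCHEMA in upper case; leave False here
    UPPERCASE_INFO_SCHEMA = False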
sqlframe/base/mixins/readwriter_mixins.py
ADDED
@@ -0,0 +1,118 @@
from __future__ import annotations

import pathlib
import typing as t

import pandas as pd

from sqlframe.base.exceptions import UnsupportedOperationError
from sqlframe.base.readerwriter import (
    DF,
    SESSION,
    _BaseDataFrameReader,
    _BaseDataFrameWriter,
    _infer_format,
)
from sqlframe.base.util import pandas_to_spark_schema

if t.TYPE_CHECKING:
    from sqlframe.base._typing import OptionalPrimitiveType, PathOrPaths
    from sqlframe.base.types import StructType


class PandasLoaderMixin(_BaseDataFrameReader, t.Generic[SESSION, DF]):
    def load(
        self,
        path: t.Optional[PathOrPaths] = None,
        format: t.Optional[str] = None,
        schema: t.Optional[t.Union[StructType, str]] = None,
        **options: OptionalPrimitiveType,
    ) -> DF:
        """Loads data from a data source and returns it as a :class:`DataFrame`.

        .. versionadded:: 1.4.0

        .. versionchanged:: 3.4.0
            Supports Spark Connect.

        Parameters
        ----------
        path : str or list, t.Optional
            t.Optional string or a list of string for file-system backed data sources.
        format : str, t.Optional
            t.Optional string for format of the data source. Default to 'parquet'.
        schema : :class:`pyspark.sql.types.StructType` or str, t.Optional
            t.Optional :class:`pyspark.sql.types.StructType` for the input schema
            or a DDL-formatted string (For example ``col0 INT, col1 DOUBLE``).
        **options : dict
            all other string options

        Examples
        --------
        Load a CSV file with format, schema and options specified.

        >>> import tempfile
        >>> with tempfile.TemporaryDirectory() as d:
        ...     # Write a DataFrame into a CSV file with a header
        ...     df = spark.createDataFrame([{"age": 100, "name": "Hyukjin Kwon"}])
        ...     df.write.option("header", True).mode("overwrite").format("csv").save(d)
        ...
        ...     # Read the CSV file as a DataFrame with 'nullValue' option set to 'Hyukjin Kwon',
        ...     # and 'header' option set to `True`.
        ...     df = spark.read.load(
        ...         d, schema=df.schema, format="csv", nullValue="Hyukjin Kwon", header=True)
        ...     df.printSchema()
        ...     df.show()
        root
         |-- age: long (nullable = true)
         |-- name: string (nullable = true)
        +---+----+
        |age|name|
        +---+----+
        |100|NULL|
        +---+----+
        """
        assert path is not None, "path is required"
        assert isinstance(path, str), "path must be a string"
        format = format or _infer_format(path)
        kwargs = {k: v for k, v in options.items() if v is not None}
        if format == "json":
            df = pd.read_json(path, lines=True, **kwargs)  # type: ignore
        elif format == "parquet":
            df = pd.read_parquet(path, **kwargs)  # type: ignore
        elif format == "csv":
            df = pd.read_csv(path, **kwargs)  # type: ignore
        else:
            raise UnsupportedOperationError(f"Unsupported format: {format}")
        schema = schema or pandas_to_spark_schema(df)
        self.session._last_loaded_file = path
        return self._session.createDataFrame(list(df.itertuples(index=False)), schema)


class PandasWriterMixin(_BaseDataFrameWriter, t.Generic[SESSION, DF]):
    def _write(self, path: str, mode: t.Optional[str], format: str, **options):  # type: ignore
        mode, skip = self._validate_mode(path, mode)
        if skip:
            return
        pandas_df = self._df.toPandas()
        mode = self._mode_to_pandas_mode(mode)
        kwargs = {k: v for k, v in options.items() if v is not None}
        kwargs["index"] = False
        if format == "csv":
            kwargs["mode"] = mode
            if mode == "a" and pathlib.Path(path).exists():
                kwargs["header"] = False
            pandas_df.to_csv(path, **kwargs)
        elif format == "parquet":
            if mode == "a":
                raise NotImplementedError("Append mode is not supported for parquet.")
            pandas_df.to_parquet(path, **kwargs)
        elif format == "json":
            # Pandas versions are inconsistent on how to handle True/False index so we just remove it
            # since in all versions it will not result in an index column in the output.
            del kwargs["index"]
            kwargs["mode"] = mode
            kwargs["orient"] = "records"
            pandas_df.to_json(path, lines=True, **kwargs)
        else:
            raise NotImplementedError(f"Unsupported format: {format}")