sqlframe-1.1.3-py3-none-any.whl
This diff shows the content of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between package versions as published in their respective public registries.
- sqlframe/__init__.py +0 -0
- sqlframe/_version.py +16 -0
- sqlframe/base/__init__.py +0 -0
- sqlframe/base/_typing.py +39 -0
- sqlframe/base/catalog.py +1163 -0
- sqlframe/base/column.py +388 -0
- sqlframe/base/dataframe.py +1519 -0
- sqlframe/base/decorators.py +51 -0
- sqlframe/base/exceptions.py +14 -0
- sqlframe/base/function_alternatives.py +1055 -0
- sqlframe/base/functions.py +1678 -0
- sqlframe/base/group.py +102 -0
- sqlframe/base/mixins/__init__.py +0 -0
- sqlframe/base/mixins/catalog_mixins.py +419 -0
- sqlframe/base/mixins/readwriter_mixins.py +118 -0
- sqlframe/base/normalize.py +84 -0
- sqlframe/base/operations.py +87 -0
- sqlframe/base/readerwriter.py +679 -0
- sqlframe/base/session.py +585 -0
- sqlframe/base/transforms.py +13 -0
- sqlframe/base/types.py +418 -0
- sqlframe/base/util.py +242 -0
- sqlframe/base/window.py +139 -0
- sqlframe/bigquery/__init__.py +23 -0
- sqlframe/bigquery/catalog.py +255 -0
- sqlframe/bigquery/column.py +1 -0
- sqlframe/bigquery/dataframe.py +54 -0
- sqlframe/bigquery/functions.py +378 -0
- sqlframe/bigquery/group.py +14 -0
- sqlframe/bigquery/readwriter.py +29 -0
- sqlframe/bigquery/session.py +89 -0
- sqlframe/bigquery/types.py +1 -0
- sqlframe/bigquery/window.py +1 -0
- sqlframe/duckdb/__init__.py +20 -0
- sqlframe/duckdb/catalog.py +108 -0
- sqlframe/duckdb/column.py +1 -0
- sqlframe/duckdb/dataframe.py +55 -0
- sqlframe/duckdb/functions.py +47 -0
- sqlframe/duckdb/group.py +14 -0
- sqlframe/duckdb/readwriter.py +111 -0
- sqlframe/duckdb/session.py +65 -0
- sqlframe/duckdb/types.py +1 -0
- sqlframe/duckdb/window.py +1 -0
- sqlframe/postgres/__init__.py +23 -0
- sqlframe/postgres/catalog.py +106 -0
- sqlframe/postgres/column.py +1 -0
- sqlframe/postgres/dataframe.py +54 -0
- sqlframe/postgres/functions.py +61 -0
- sqlframe/postgres/group.py +14 -0
- sqlframe/postgres/readwriter.py +29 -0
- sqlframe/postgres/session.py +68 -0
- sqlframe/postgres/types.py +1 -0
- sqlframe/postgres/window.py +1 -0
- sqlframe/redshift/__init__.py +23 -0
- sqlframe/redshift/catalog.py +127 -0
- sqlframe/redshift/column.py +1 -0
- sqlframe/redshift/dataframe.py +54 -0
- sqlframe/redshift/functions.py +18 -0
- sqlframe/redshift/group.py +14 -0
- sqlframe/redshift/readwriter.py +29 -0
- sqlframe/redshift/session.py +53 -0
- sqlframe/redshift/types.py +1 -0
- sqlframe/redshift/window.py +1 -0
- sqlframe/snowflake/__init__.py +26 -0
- sqlframe/snowflake/catalog.py +134 -0
- sqlframe/snowflake/column.py +1 -0
- sqlframe/snowflake/dataframe.py +54 -0
- sqlframe/snowflake/functions.py +18 -0
- sqlframe/snowflake/group.py +14 -0
- sqlframe/snowflake/readwriter.py +29 -0
- sqlframe/snowflake/session.py +53 -0
- sqlframe/snowflake/types.py +1 -0
- sqlframe/snowflake/window.py +1 -0
- sqlframe/spark/__init__.py +23 -0
- sqlframe/spark/catalog.py +1028 -0
- sqlframe/spark/column.py +1 -0
- sqlframe/spark/dataframe.py +54 -0
- sqlframe/spark/functions.py +22 -0
- sqlframe/spark/group.py +14 -0
- sqlframe/spark/readwriter.py +29 -0
- sqlframe/spark/session.py +90 -0
- sqlframe/spark/types.py +1 -0
- sqlframe/spark/window.py +1 -0
- sqlframe/standalone/__init__.py +26 -0
- sqlframe/standalone/catalog.py +13 -0
- sqlframe/standalone/column.py +1 -0
- sqlframe/standalone/dataframe.py +36 -0
- sqlframe/standalone/functions.py +1 -0
- sqlframe/standalone/group.py +14 -0
- sqlframe/standalone/readwriter.py +19 -0
- sqlframe/standalone/session.py +40 -0
- sqlframe/standalone/types.py +1 -0
- sqlframe/standalone/window.py +1 -0
- sqlframe-1.1.3.dist-info/LICENSE +21 -0
- sqlframe-1.1.3.dist-info/METADATA +172 -0
- sqlframe-1.1.3.dist-info/RECORD +98 -0
- sqlframe-1.1.3.dist-info/WHEEL +5 -0
- sqlframe-1.1.3.dist-info/top_level.txt +1 -0
sqlframe/redshift/catalog.py
@@ -0,0 +1,127 @@
+# This code is based on code from Apache Spark under the license found in the LICENSE file located in the 'sqlframe' folder.
+
+from __future__ import annotations
+
+import fnmatch
+import typing as t
+
+from sqlglot import exp, parse_one
+
+from sqlframe.base.catalog import Function, _BaseCatalog
+from sqlframe.base.mixins.catalog_mixins import (
+    GetCurrentCatalogFromFunctionMixin,
+    GetCurrentDatabaseFromFunctionMixin,
+    ListCatalogsFromInfoSchemaMixin,
+    ListColumnsFromInfoSchemaMixin,
+    ListDatabasesFromInfoSchemaMixin,
+    ListTablesFromInfoSchemaMixin,
+    SetCurrentDatabaseFromSearchPathMixin,
+)
+from sqlframe.base.util import schema_, to_schema
+
+if t.TYPE_CHECKING:
+    from sqlframe.redshift.dataframe import RedshiftDataFrame
+    from sqlframe.redshift.session import RedshiftSession
+
+
+class RedshiftCatalog(
+    GetCurrentCatalogFromFunctionMixin["RedshiftSession", "RedshiftDataFrame"],
+    GetCurrentDatabaseFromFunctionMixin["RedshiftSession", "RedshiftDataFrame"],
+    ListDatabasesFromInfoSchemaMixin["RedshiftSession", "RedshiftDataFrame"],
+    ListCatalogsFromInfoSchemaMixin["RedshiftSession", "RedshiftDataFrame"],
+    SetCurrentDatabaseFromSearchPathMixin["RedshiftSession", "RedshiftDataFrame"],
+    ListTablesFromInfoSchemaMixin["RedshiftSession", "RedshiftDataFrame"],
+    ListColumnsFromInfoSchemaMixin["RedshiftSession", "RedshiftDataFrame"],
+    _BaseCatalog["RedshiftSession", "RedshiftDataFrame"],
+):
+    CURRENT_CATALOG_EXPRESSION: exp.Expression = exp.func("current_database")
+
+    def listFunctions(
+        self, dbName: t.Optional[str] = None, pattern: t.Optional[str] = None
+    ) -> t.List[Function]:
+        """
+        Returns a list of functions registered in the specified database.
+
+        .. versionadded:: 3.4.0
+
+        Parameters
+        ----------
+        dbName : str
+            name of the database to list the functions.
+            ``dbName`` can be qualified with catalog name.
+        pattern : str
+            The pattern that the function name needs to match.
+
+            .. versionchanged:: 3.5.0
+                Adds ``pattern`` argument.
+
+        Returns
+        -------
+        t.List
+            A list of :class:`Function`.
+
+        Notes
+        -----
+        If no database is specified, the current database and catalog
+        are used. This API includes all temporary functions.
+
+        Examples
+        --------
+        >>> spark.catalog.listFunctions()
+        [Function(name=...
+
+        >>> spark.catalog.listFunctions(pattern="to_*")
+        [Function(name=...
+
+        >>> spark.catalog.listFunctions(pattern="*not_existing_func*")
+        []
+        """
+        if dbName is None:
+            schema = schema_(
+                db=exp.parse_identifier(self.currentDatabase(), dialect=self.session.input_dialect),
+                catalog=exp.parse_identifier(
+                    self.currentCatalog(), dialect=self.session.input_dialect
+                ),
+            )
+        else:
+            schema = to_schema(dbName, dialect=self.session.input_dialect)
+            if not schema.catalog:
+                schema.set("catalog", exp.parse_identifier(self.currentCatalog()))
+        query = parse_one(
+            f"""SELECT database_name as catalog, schema_name as namespace, function_name as name
+            FROM svv_redshift_functions
+            WHERE database_name = '{schema.catalog}'
+            and schema_name = '{schema.db}'
+            ORDER BY function_name;
+            """,
+            dialect=self.session.input_dialect,
+        )
+        functions = self.session._fetch_rows(query)
+        if pattern:
+            functions = [x for x in functions if fnmatch.fnmatch(x["name"], pattern)]
+        return [
+            Function(
+                name=x["name"],
+                catalog=x["catalog"],
+                namespace=[x["namespace"]],
+                description=None,
+                className="",
+                isTemporary=False,
+            )
+            for x in functions
+        ]
+
+    # def get_columns(self, table_name: t.Union[exp.Table, str]) -> t.Dict[str, exp.DataType]:
+    #     table = self.ensure_table(table_name)
+    #     if not table.catalog:
+    #         table.set("catalog", exp.parse_identifier(self.currentCatalog()))
+    #     if not table.db:
+    #         table.set("db", exp.parse_identifier(self.currentDatabase()))
+    #     sql = f"SHOW COLUMNS FROM TABLE {table.sql(dialect=self.session.input_dialect)}"
+    #     results = sorted(self.session._fetch_rows(sql), key=lambda x: x["ordinal_position"])
+    #     return {
+    #         row["column_name"]: exp.DataType.build(
+    #             row["data_type"], dialect=self.session.input_dialect, udt=True
+    #         )
+    #         for row in results
+    #     }
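A note on the hunk above: the ``pattern`` argument is applied client-side with plain fnmatch-style globbing after the svv_redshift_functions query returns. A minimal standalone sketch of that filter (the sample rows are hypothetical, shaped like the query's result):

import fnmatch

# Hypothetical rows in the shape returned by the svv_redshift_functions query.
rows = [
    {"catalog": "dev", "namespace": "public", "name": "to_timestamp"},
    {"catalog": "dev", "namespace": "public", "name": "to_char"},
    {"catalog": "dev", "namespace": "public", "name": "my_udf"},
]

pattern = "to_*"
# Same client-side filter listFunctions applies when a pattern is given.
matches = [row for row in rows if fnmatch.fnmatch(row["name"], pattern)]
print([row["name"] for row in matches])  # ['to_timestamp', 'to_char']
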
sqlframe/redshift/column.py
@@ -0,0 +1 @@
+from sqlframe.base.column import Column
sqlframe/redshift/dataframe.py
@@ -0,0 +1,54 @@
+from __future__ import annotations
+
+import logging
+import sys
+import typing as t
+
+from sqlframe.base.dataframe import (
+    _BaseDataFrame,
+    _BaseDataFrameNaFunctions,
+    _BaseDataFrameStatFunctions,
+)
+from sqlframe.redshift.group import RedshiftGroupedData
+
+if sys.version_info >= (3, 11):
+    from typing import Self
+else:
+    from typing_extensions import Self
+
+if t.TYPE_CHECKING:
+    from sqlframe.redshift.readwriter import RedshiftDataFrameWriter
+    from sqlframe.redshift.session import RedshiftSession
+
+
+logger = logging.getLogger(__name__)
+
+
+class RedshiftDataFrameNaFunctions(_BaseDataFrameNaFunctions["RedshiftDataFrame"]):
+    pass
+
+
+class RedshiftDataFrameStatFunctions(_BaseDataFrameStatFunctions["RedshiftDataFrame"]):
+    pass
+
+
+class RedshiftDataFrame(
+    _BaseDataFrame[
+        "RedshiftSession",
+        "RedshiftDataFrameWriter",
+        "RedshiftDataFrameNaFunctions",
+        "RedshiftDataFrameStatFunctions",
+        "RedshiftGroupedData",
+    ]
+):
+    _na = RedshiftDataFrameNaFunctions
+    _stat = RedshiftDataFrameStatFunctions
+    _group_data = RedshiftGroupedData
+
+    def cache(self) -> Self:
+        logger.warning("Redshift does not support caching. Ignoring cache() call.")
+        return self
+
+    def persist(self) -> Self:
+        logger.warning("Redshift does not support persist. Ignoring persist() call.")
+        return self
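Because Redshift has no DataFrame-level caching, cache() and persist() above only log a warning and return self, so PySpark-style method chaining keeps working. A simplified stand-in (FakeFrame is illustrative, not sqlframe's class) showing the no-op contract:

import logging

logging.basicConfig(level=logging.WARNING)
logger = logging.getLogger("redshift_example")

class FakeFrame:
    """Hypothetical stand-in for RedshiftDataFrame, showing the no-op contract."""

    def cache(self) -> "FakeFrame":
        logger.warning("Redshift does not support caching. Ignoring cache() call.")
        return self  # returning self keeps PySpark-style chaining intact

df = FakeFrame()
assert df.cache().cache() is df  # only warnings are emitted; the frame is unchanged
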
sqlframe/redshift/functions.py
@@ -0,0 +1,18 @@
+import inspect
+import sys
+
+import sqlframe.base.functions
+
+module = sys.modules["sqlframe.base.functions"]
+globals().update(
+    {
+        name: func
+        for name, func in inspect.getmembers(module, inspect.isfunction)
+        if hasattr(func, "unsupported_engines")
+        and "redshift" not in func.unsupported_engines
+        and "*" not in func.unsupported_engines
+    }
+)
+
+
+from sqlframe.base.function_alternatives import e_literal as e  # noqa
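The module above re-exports every function from sqlframe.base.functions whose unsupported_engines attribute does not rule out this engine (or all engines via "*"). A standalone sketch of that gating logic, using a throwaway module and hypothetical function names (only the filtering expression is taken from the hunk above):

import inspect
import types

# Build a throwaway module with functions tagged the way sqlframe tags them.
mod = types.ModuleType("fake_functions")

def lit(value):  # supported everywhere in this sketch
    return value

lit.unsupported_engines = []

def explode(col):  # hypothetically unsupported on redshift
    return col

explode.unsupported_engines = ["redshift"]

mod.lit, mod.explode = lit, explode

exported = {
    name: func
    for name, func in inspect.getmembers(mod, inspect.isfunction)
    if hasattr(func, "unsupported_engines")
    and "redshift" not in func.unsupported_engines
    and "*" not in func.unsupported_engines
}
print(sorted(exported))  # ['lit'] -- explode is filtered out for redshift
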
sqlframe/redshift/group.py
@@ -0,0 +1,14 @@
+# This code is based on code from Apache Spark under the license found in the LICENSE file located in the 'sqlframe' folder.
+
+from __future__ import annotations
+
+import typing as t
+
+from sqlframe.base.group import _BaseGroupedData
+
+if t.TYPE_CHECKING:
+    from sqlframe.redshift.dataframe import RedshiftDataFrame
+
+
+class RedshiftGroupedData(_BaseGroupedData["RedshiftDataFrame"]):
+    pass
sqlframe/redshift/readwriter.py
@@ -0,0 +1,29 @@
+# This code is based on code from Apache Spark under the license found in the LICENSE file located in the 'sqlframe' folder.
+
+from __future__ import annotations
+
+import typing as t
+
+from sqlframe.base.mixins.readwriter_mixins import PandasLoaderMixin, PandasWriterMixin
+from sqlframe.base.readerwriter import (
+    _BaseDataFrameReader,
+    _BaseDataFrameWriter,
+)
+
+if t.TYPE_CHECKING:
+    from sqlframe.redshift.session import RedshiftSession  # noqa
+    from sqlframe.redshift.dataframe import RedshiftDataFrame  # noqa
+
+
+class RedshiftDataFrameReader(
+    PandasLoaderMixin["RedshiftSession", "RedshiftDataFrame"],
+    _BaseDataFrameReader["RedshiftSession", "RedshiftDataFrame"],
+):
+    pass
+
+
+class RedshiftDataFrameWriter(
+    PandasWriterMixin["RedshiftSession", "RedshiftDataFrame"],
+    _BaseDataFrameWriter["RedshiftSession", "RedshiftDataFrame"],
+):
+    pass
sqlframe/redshift/session.py
@@ -0,0 +1,53 @@
+from __future__ import annotations
+
+import typing as t
+import warnings
+
+from sqlframe.base.session import _BaseSession
+from sqlframe.redshift.catalog import RedshiftCatalog
+from sqlframe.redshift.dataframe import RedshiftDataFrame
+from sqlframe.redshift.readwriter import (
+    RedshiftDataFrameReader,
+    RedshiftDataFrameWriter,
+)
+
+if t.TYPE_CHECKING:
+    from redshift_connector.core import Connection as RedshiftConnection
+else:
+    RedshiftConnection = t.Any
+
+
+class RedshiftSession(
+    _BaseSession[  # type: ignore
+        RedshiftCatalog,
+        RedshiftDataFrameReader,
+        RedshiftDataFrameWriter,
+        RedshiftDataFrame,
+        RedshiftConnection,
+    ],
+):
+    _catalog = RedshiftCatalog
+    _reader = RedshiftDataFrameReader
+    _writer = RedshiftDataFrameWriter
+    _df = RedshiftDataFrame
+
+    def __init__(self, conn: t.Optional[RedshiftConnection] = None):
+        warnings.warn(
+            "RedshiftSession is still in active development. Functions may not work as expected."
+        )
+        if not hasattr(self, "_conn"):
+            super().__init__(conn)
+
+    class Builder(_BaseSession.Builder):
+        DEFAULT_INPUT_DIALECT = "redshift"
+        DEFAULT_OUTPUT_DIALECT = "redshift"
+
+        @property
+        def session(self) -> RedshiftSession:
+            return RedshiftSession(**self._session_kwargs)
+
+        def getOrCreate(self) -> RedshiftSession:
+            self._set_session_properties()
+            return self.session
+
+    builder = Builder()
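A possible usage sketch for the session above. It assumes the redshift-connector package; the connection parameters are placeholders, and passing the connection via the conn keyword matches the __init__ signature shown in the hunk:

# Sketch only: requires the redshift-connector package and real credentials.
import redshift_connector

from sqlframe.redshift import RedshiftSession

conn = redshift_connector.connect(
    host="examplecluster.abc123xyz789.us-west-1.redshift.amazonaws.com",  # placeholder
    database="dev",
    user="awsuser",
    password="my_password",
)

# Emits the "still in active development" warning shown above.
session = RedshiftSession(conn=conn)
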
sqlframe/redshift/types.py
@@ -0,0 +1 @@
+from sqlframe.base.types import *
sqlframe/redshift/window.py
@@ -0,0 +1 @@
+from sqlframe.base.window import *
sqlframe/snowflake/__init__.py
@@ -0,0 +1,26 @@
+from sqlframe.snowflake.catalog import SnowflakeCatalog
+from sqlframe.snowflake.column import Column
+from sqlframe.snowflake.dataframe import (
+    SnowflakeDataFrame,
+    SnowflakeDataFrameNaFunctions,
+)
+from sqlframe.snowflake.group import SnowflakeGroupedData
+from sqlframe.snowflake.readwriter import (
+    SnowflakeDataFrameReader,
+    SnowflakeDataFrameWriter,
+)
+from sqlframe.snowflake.session import SnowflakeSession
+from sqlframe.snowflake.window import Window, WindowSpec
+
+__all__ = [
+    "SnowflakeCatalog",
+    "Column",
+    "SnowflakeDataFrame",
+    "SnowflakeDataFrameNaFunctions",
+    "SnowflakeGroupedData",
+    "SnowflakeDataFrameReader",
+    "SnowflakeDataFrameWriter",
+    "SnowflakeSession",
+    "Window",
+    "WindowSpec",
+]
sqlframe/snowflake/catalog.py
@@ -0,0 +1,134 @@
+# This code is based on code from Apache Spark under the license found in the LICENSE file located in the 'sqlframe' folder.
+
+from __future__ import annotations
+
+import fnmatch
+import json
+import typing as t
+
+from sqlglot import exp, parse_one
+
+from sqlframe.base.catalog import Function, _BaseCatalog
+from sqlframe.base.decorators import normalize
+from sqlframe.base.mixins.catalog_mixins import (
+    GetCurrentCatalogFromFunctionMixin,
+    GetCurrentDatabaseFromFunctionMixin,
+    ListCatalogsFromInfoSchemaMixin,
+    ListColumnsFromInfoSchemaMixin,
+    ListDatabasesFromInfoSchemaMixin,
+    ListTablesFromInfoSchemaMixin,
+    SetCurrentCatalogFromUseMixin,
+    SetCurrentDatabaseFromUseMixin,
+)
+from sqlframe.base.util import schema_, to_schema
+
+if t.TYPE_CHECKING:
+    from sqlframe.snowflake.session import SnowflakeSession  # noqa
+    from sqlframe.snowflake.dataframe import SnowflakeDataFrame  # noqa
+
+
+class SnowflakeCatalog(
+    SetCurrentCatalogFromUseMixin["SnowflakeSession", "SnowflakeDataFrame"],
+    GetCurrentCatalogFromFunctionMixin["SnowflakeSession", "SnowflakeDataFrame"],
+    GetCurrentDatabaseFromFunctionMixin["SnowflakeSession", "SnowflakeDataFrame"],
+    ListDatabasesFromInfoSchemaMixin["SnowflakeSession", "SnowflakeDataFrame"],
+    ListCatalogsFromInfoSchemaMixin["SnowflakeSession", "SnowflakeDataFrame"],
+    SetCurrentDatabaseFromUseMixin["SnowflakeSession", "SnowflakeDataFrame"],
+    ListTablesFromInfoSchemaMixin["SnowflakeSession", "SnowflakeDataFrame"],
+    ListColumnsFromInfoSchemaMixin["SnowflakeSession", "SnowflakeDataFrame"],
+    _BaseCatalog["SnowflakeSession", "SnowflakeDataFrame"],
+):
+    CURRENT_CATALOG_EXPRESSION: exp.Expression = exp.func("current_database")
+
+    @normalize(["dbName"])
+    def listFunctions(
+        self, dbName: t.Optional[str] = None, pattern: t.Optional[str] = None
+    ) -> t.List[Function]:
+        """
+        Returns a list of functions registered in the specified database.
+
+        .. versionadded:: 3.4.0
+
+        Parameters
+        ----------
+        dbName : str
+            name of the database to list the functions.
+            ``dbName`` can be qualified with catalog name.
+        pattern : str
+            The pattern that the function name needs to match.
+
+            .. versionchanged:: 3.5.0
+                Adds ``pattern`` argument.
+
+        Returns
+        -------
+        t.List
+            A list of :class:`Function`.
+
+        Notes
+        -----
+        If no database is specified, the current database and catalog
+        are used. This API includes all temporary functions.
+
+        Examples
+        --------
+        >>> spark.catalog.listFunctions()
+        [Function(name=...
+
+        >>> spark.catalog.listFunctions(pattern="to_*")
+        [Function(name=...
+
+        >>> spark.catalog.listFunctions(pattern="*not_existing_func*")
+        []
+        """
+        if dbName is None:
+            schema = schema_(
+                db=exp.parse_identifier(self.currentDatabase(), dialect=self.session.input_dialect),
+                catalog=exp.parse_identifier(
+                    self.currentCatalog(), dialect=self.session.input_dialect
+                ),
+            )
+        else:
+            schema = to_schema(dbName, dialect=self.session.input_dialect)
+            if not schema.catalog:
+                schema.set("catalog", exp.parse_identifier(self.currentCatalog()))
+        query = parse_one(
+            f"""SHOW USER FUNCTIONS IN {schema.sql(dialect=self.session.input_dialect)}""",
+            dialect=self.session.input_dialect,
+        )
+        functions = self.session._fetch_rows(query)
+        if pattern:
+            functions = [x for x in functions if fnmatch.fnmatch(x["name"], pattern)]
+        return [
+            Function(
+                name=x["name"],
+                catalog=x["catalog_name"],
+                namespace=[x["schema_name"]],
+                description=None,
+                className="",
+                isTemporary=False,
+            )
+            for x in functions
+        ]
+
+    @normalize(["table_name"])
+    def get_columns(self, table_name: str) -> t.Dict[str, exp.DataType]:
+        table = exp.to_table(table_name)
+        if not table.catalog:
+            table.set(
+                "catalog",
+                exp.parse_identifier(self.currentCatalog(), dialect=self.session.input_dialect),
+            )
+        if not table.db:
+            table.set(
+                "db",
+                exp.parse_identifier(self.currentDatabase(), dialect=self.session.input_dialect),
+            )
+        sql = f"SHOW COLUMNS IN TABLE {table.sql(dialect=self.session.input_dialect)}"
+        results = self.session._fetch_rows(sql)
+        return {
+            row["column_name"]: exp.DataType.build(
+                json.loads(row["data_type"])["type"], dialect=self.session.input_dialect, udt=True
+            )
+            for row in results
+        }
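Snowflake's SHOW COLUMNS reports data_type as a JSON document, which is why get_columns above extracts the "type" key with json.loads before handing it to sqlglot. A minimal sketch with a hypothetical row (the row contents are illustrative):

import json

from sqlglot import exp

# Hypothetical SHOW COLUMNS row: Snowflake encodes data_type as JSON text.
row = {
    "column_name": "NAME",
    "data_type": json.dumps({"type": "TEXT", "length": 16777216, "nullable": True}),
}

type_name = json.loads(row["data_type"])["type"]  # "TEXT"
dtype = exp.DataType.build(type_name, dialect="snowflake", udt=True)
print(row["column_name"], dtype.sql(dialect="snowflake"))
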
sqlframe/snowflake/column.py
@@ -0,0 +1 @@
+from sqlframe.base.column import Column
sqlframe/snowflake/dataframe.py
@@ -0,0 +1,54 @@
+from __future__ import annotations
+
+import logging
+import sys
+import typing as t
+
+from sqlframe.base.dataframe import (
+    _BaseDataFrame,
+    _BaseDataFrameNaFunctions,
+    _BaseDataFrameStatFunctions,
+)
+from sqlframe.snowflake.group import SnowflakeGroupedData
+
+if sys.version_info >= (3, 11):
+    from typing import Self
+else:
+    from typing_extensions import Self
+
+if t.TYPE_CHECKING:
+    from sqlframe.snowflake.readwriter import SnowflakeDataFrameWriter
+    from sqlframe.snowflake.session import SnowflakeSession
+
+
+logger = logging.getLogger(__name__)
+
+
+class SnowflakeDataFrameNaFunctions(_BaseDataFrameNaFunctions["SnowflakeDataFrame"]):
+    pass
+
+
+class SnowflakeDataFrameStatFunctions(_BaseDataFrameStatFunctions["SnowflakeDataFrame"]):
+    pass
+
+
+class SnowflakeDataFrame(
+    _BaseDataFrame[
+        "SnowflakeSession",
+        "SnowflakeDataFrameWriter",
+        "SnowflakeDataFrameNaFunctions",
+        "SnowflakeDataFrameStatFunctions",
+        "SnowflakeGroupedData",
+    ]
+):
+    _na = SnowflakeDataFrameNaFunctions
+    _stat = SnowflakeDataFrameStatFunctions
+    _group_data = SnowflakeGroupedData
+
+    def cache(self) -> Self:
+        logger.warning("Snowflake does not support caching. Ignoring cache() call.")
+        return self
+
+    def persist(self) -> Self:
+        logger.warning("Snowflake does not support persist. Ignoring persist() call.")
+        return self
sqlframe/snowflake/functions.py
@@ -0,0 +1,18 @@
+import inspect
+import sys
+
+import sqlframe.base.functions
+
+module = sys.modules["sqlframe.base.functions"]
+globals().update(
+    {
+        name: func
+        for name, func in inspect.getmembers(module, inspect.isfunction)
+        if hasattr(func, "unsupported_engines")
+        and "snowflake" not in func.unsupported_engines
+        and "*" not in func.unsupported_engines
+    }
+)
+
+
+from sqlframe.base.function_alternatives import e_literal as e  # noqa
sqlframe/snowflake/group.py
@@ -0,0 +1,14 @@
+# This code is based on code from Apache Spark under the license found in the LICENSE file located in the 'sqlframe' folder.
+
+from __future__ import annotations
+
+import typing as t
+
+from sqlframe.base.group import _BaseGroupedData
+
+if t.TYPE_CHECKING:
+    from sqlframe.snowflake.dataframe import SnowflakeDataFrame
+
+
+class SnowflakeGroupedData(_BaseGroupedData["SnowflakeDataFrame"]):
+    pass
sqlframe/snowflake/readwriter.py
@@ -0,0 +1,29 @@
+# This code is based on code from Apache Spark under the license found in the LICENSE file located in the 'sqlframe' folder.
+
+from __future__ import annotations
+
+import typing as t
+
+from sqlframe.base.mixins.readwriter_mixins import PandasLoaderMixin, PandasWriterMixin
+from sqlframe.base.readerwriter import (
+    _BaseDataFrameReader,
+    _BaseDataFrameWriter,
+)
+
+if t.TYPE_CHECKING:
+    from sqlframe.snowflake.session import SnowflakeSession  # noqa
+    from sqlframe.snowflake.dataframe import SnowflakeDataFrame  # noqa
+
+
+class SnowflakeDataFrameReader(
+    PandasLoaderMixin["SnowflakeSession", "SnowflakeDataFrame"],
+    _BaseDataFrameReader["SnowflakeSession", "SnowflakeDataFrame"],
+):
+    pass
+
+
+class SnowflakeDataFrameWriter(
+    PandasWriterMixin["SnowflakeSession", "SnowflakeDataFrame"],
+    _BaseDataFrameWriter["SnowflakeSession", "SnowflakeDataFrame"],
+):
+    pass
sqlframe/snowflake/session.py
@@ -0,0 +1,53 @@
+from __future__ import annotations
+
+import typing as t
+import warnings
+
+from sqlframe.base.session import _BaseSession
+from sqlframe.snowflake.catalog import SnowflakeCatalog
+from sqlframe.snowflake.dataframe import SnowflakeDataFrame
+from sqlframe.snowflake.readwriter import (
+    SnowflakeDataFrameReader,
+    SnowflakeDataFrameWriter,
+)
+
+if t.TYPE_CHECKING:
+    from snowflake.connector import SnowflakeConnection
+else:
+    SnowflakeConnection = t.Any
+
+
+class SnowflakeSession(
+    _BaseSession[  # type: ignore
+        SnowflakeCatalog,
+        SnowflakeDataFrameReader,
+        SnowflakeDataFrameWriter,
+        SnowflakeDataFrame,
+        SnowflakeConnection,
+    ],
+):
+    _catalog = SnowflakeCatalog
+    _reader = SnowflakeDataFrameReader
+    _writer = SnowflakeDataFrameWriter
+    _df = SnowflakeDataFrame
+
+    def __init__(self, conn: t.Optional[SnowflakeConnection] = None):
+        warnings.warn(
+            "SnowflakeSession is still in active development. Functions may not work as expected."
+        )
+        if not hasattr(self, "_conn"):
+            super().__init__(conn)
+
+    class Builder(_BaseSession.Builder):
+        DEFAULT_INPUT_DIALECT = "snowflake"
+        DEFAULT_OUTPUT_DIALECT = "snowflake"
+
+        @property
+        def session(self) -> SnowflakeSession:
+            return SnowflakeSession(**self._session_kwargs)
+
+        def getOrCreate(self) -> SnowflakeSession:
+            self._set_session_properties()
+            return self.session
+
+    builder = Builder()
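A possible usage sketch for the session above, assuming the snowflake-connector-python package. The connection parameters are placeholders; passing conn matches the __init__ signature in the hunk, and the final sql() call assumes _BaseSession exposes a PySpark-style sql() method (not shown in this diff):

# Sketch only: requires snowflake-connector-python and real credentials.
from snowflake.connector import connect

from sqlframe.snowflake import SnowflakeSession

conn = connect(
    account="my_account",  # placeholder values
    user="my_user",
    password="my_password",
    warehouse="my_wh",
    database="my_db",
    schema="public",
)

# Emits the "still in active development" warning shown above.
session = SnowflakeSession(conn=conn)
df = session.sql("SELECT 1 AS id")  # assumes a PySpark-style sql() on the base session
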
sqlframe/snowflake/types.py
@@ -0,0 +1 @@
+from sqlframe.base.types import *
sqlframe/snowflake/window.py
@@ -0,0 +1 @@
+from sqlframe.base.window import *
sqlframe/redshift/__init__.py
@@ -0,0 +1,23 @@
+from sqlframe.redshift.catalog import RedshiftCatalog
+from sqlframe.redshift.column import Column
+from sqlframe.redshift.dataframe import RedshiftDataFrame, RedshiftDataFrameNaFunctions
+from sqlframe.redshift.group import RedshiftGroupedData
+from sqlframe.redshift.readwriter import (
+    RedshiftDataFrameReader,
+    RedshiftDataFrameWriter,
+)
+from sqlframe.redshift.session import RedshiftSession
+from sqlframe.redshift.window import Window, WindowSpec
+
+__all__ = [
+    "RedshiftCatalog",
+    "Column",
+    "RedshiftDataFrame",
+    "RedshiftDataFrameNaFunctions",
+    "RedshiftGroupedData",
+    "RedshiftDataFrameReader",
+    "RedshiftDataFrameWriter",
+    "RedshiftSession",
+    "Window",
+    "WindowSpec",
+]