sqlframe 1.1.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (98)
  1. sqlframe/__init__.py +0 -0
  2. sqlframe/_version.py +16 -0
  3. sqlframe/base/__init__.py +0 -0
  4. sqlframe/base/_typing.py +39 -0
  5. sqlframe/base/catalog.py +1163 -0
  6. sqlframe/base/column.py +388 -0
  7. sqlframe/base/dataframe.py +1519 -0
  8. sqlframe/base/decorators.py +51 -0
  9. sqlframe/base/exceptions.py +14 -0
  10. sqlframe/base/function_alternatives.py +1055 -0
  11. sqlframe/base/functions.py +1678 -0
  12. sqlframe/base/group.py +102 -0
  13. sqlframe/base/mixins/__init__.py +0 -0
  14. sqlframe/base/mixins/catalog_mixins.py +419 -0
  15. sqlframe/base/mixins/readwriter_mixins.py +118 -0
  16. sqlframe/base/normalize.py +84 -0
  17. sqlframe/base/operations.py +87 -0
  18. sqlframe/base/readerwriter.py +679 -0
  19. sqlframe/base/session.py +585 -0
  20. sqlframe/base/transforms.py +13 -0
  21. sqlframe/base/types.py +418 -0
  22. sqlframe/base/util.py +242 -0
  23. sqlframe/base/window.py +139 -0
  24. sqlframe/bigquery/__init__.py +23 -0
  25. sqlframe/bigquery/catalog.py +255 -0
  26. sqlframe/bigquery/column.py +1 -0
  27. sqlframe/bigquery/dataframe.py +54 -0
  28. sqlframe/bigquery/functions.py +378 -0
  29. sqlframe/bigquery/group.py +14 -0
  30. sqlframe/bigquery/readwriter.py +29 -0
  31. sqlframe/bigquery/session.py +89 -0
  32. sqlframe/bigquery/types.py +1 -0
  33. sqlframe/bigquery/window.py +1 -0
  34. sqlframe/duckdb/__init__.py +20 -0
  35. sqlframe/duckdb/catalog.py +108 -0
  36. sqlframe/duckdb/column.py +1 -0
  37. sqlframe/duckdb/dataframe.py +55 -0
  38. sqlframe/duckdb/functions.py +47 -0
  39. sqlframe/duckdb/group.py +14 -0
  40. sqlframe/duckdb/readwriter.py +111 -0
  41. sqlframe/duckdb/session.py +65 -0
  42. sqlframe/duckdb/types.py +1 -0
  43. sqlframe/duckdb/window.py +1 -0
  44. sqlframe/postgres/__init__.py +23 -0
  45. sqlframe/postgres/catalog.py +106 -0
  46. sqlframe/postgres/column.py +1 -0
  47. sqlframe/postgres/dataframe.py +54 -0
  48. sqlframe/postgres/functions.py +61 -0
  49. sqlframe/postgres/group.py +14 -0
  50. sqlframe/postgres/readwriter.py +29 -0
  51. sqlframe/postgres/session.py +68 -0
  52. sqlframe/postgres/types.py +1 -0
  53. sqlframe/postgres/window.py +1 -0
  54. sqlframe/redshift/__init__.py +23 -0
  55. sqlframe/redshift/catalog.py +127 -0
  56. sqlframe/redshift/column.py +1 -0
  57. sqlframe/redshift/dataframe.py +54 -0
  58. sqlframe/redshift/functions.py +18 -0
  59. sqlframe/redshift/group.py +14 -0
  60. sqlframe/redshift/readwriter.py +29 -0
  61. sqlframe/redshift/session.py +53 -0
  62. sqlframe/redshift/types.py +1 -0
  63. sqlframe/redshift/window.py +1 -0
  64. sqlframe/snowflake/__init__.py +26 -0
  65. sqlframe/snowflake/catalog.py +134 -0
  66. sqlframe/snowflake/column.py +1 -0
  67. sqlframe/snowflake/dataframe.py +54 -0
  68. sqlframe/snowflake/functions.py +18 -0
  69. sqlframe/snowflake/group.py +14 -0
  70. sqlframe/snowflake/readwriter.py +29 -0
  71. sqlframe/snowflake/session.py +53 -0
  72. sqlframe/snowflake/types.py +1 -0
  73. sqlframe/snowflake/window.py +1 -0
  74. sqlframe/spark/__init__.py +23 -0
  75. sqlframe/spark/catalog.py +1028 -0
  76. sqlframe/spark/column.py +1 -0
  77. sqlframe/spark/dataframe.py +54 -0
  78. sqlframe/spark/functions.py +22 -0
  79. sqlframe/spark/group.py +14 -0
  80. sqlframe/spark/readwriter.py +29 -0
  81. sqlframe/spark/session.py +90 -0
  82. sqlframe/spark/types.py +1 -0
  83. sqlframe/spark/window.py +1 -0
  84. sqlframe/standalone/__init__.py +26 -0
  85. sqlframe/standalone/catalog.py +13 -0
  86. sqlframe/standalone/column.py +1 -0
  87. sqlframe/standalone/dataframe.py +36 -0
  88. sqlframe/standalone/functions.py +1 -0
  89. sqlframe/standalone/group.py +14 -0
  90. sqlframe/standalone/readwriter.py +19 -0
  91. sqlframe/standalone/session.py +40 -0
  92. sqlframe/standalone/types.py +1 -0
  93. sqlframe/standalone/window.py +1 -0
  94. sqlframe-1.1.3.dist-info/LICENSE +21 -0
  95. sqlframe-1.1.3.dist-info/METADATA +172 -0
  96. sqlframe-1.1.3.dist-info/RECORD +98 -0
  97. sqlframe-1.1.3.dist-info/WHEEL +5 -0
  98. sqlframe-1.1.3.dist-info/top_level.txt +1 -0
sqlframe/redshift/catalog.py
@@ -0,0 +1,127 @@
+ # This code is based on code from Apache Spark under the license found in the LICENSE file located in the 'sqlframe' folder.
+
+ from __future__ import annotations
+
+ import fnmatch
+ import typing as t
+
+ from sqlglot import exp, parse_one
+
+ from sqlframe.base.catalog import Function, _BaseCatalog
+ from sqlframe.base.mixins.catalog_mixins import (
+     GetCurrentCatalogFromFunctionMixin,
+     GetCurrentDatabaseFromFunctionMixin,
+     ListCatalogsFromInfoSchemaMixin,
+     ListColumnsFromInfoSchemaMixin,
+     ListDatabasesFromInfoSchemaMixin,
+     ListTablesFromInfoSchemaMixin,
+     SetCurrentDatabaseFromSearchPathMixin,
+ )
+ from sqlframe.base.util import schema_, to_schema
+
+ if t.TYPE_CHECKING:
+     from sqlframe.redshift.dataframe import RedshiftDataFrame
+     from sqlframe.redshift.session import RedshiftSession
+
+
+ class RedshiftCatalog(
+     GetCurrentCatalogFromFunctionMixin["RedshiftSession", "RedshiftDataFrame"],
+     GetCurrentDatabaseFromFunctionMixin["RedshiftSession", "RedshiftDataFrame"],
+     ListDatabasesFromInfoSchemaMixin["RedshiftSession", "RedshiftDataFrame"],
+     ListCatalogsFromInfoSchemaMixin["RedshiftSession", "RedshiftDataFrame"],
+     SetCurrentDatabaseFromSearchPathMixin["RedshiftSession", "RedshiftDataFrame"],
+     ListTablesFromInfoSchemaMixin["RedshiftSession", "RedshiftDataFrame"],
+     ListColumnsFromInfoSchemaMixin["RedshiftSession", "RedshiftDataFrame"],
+     _BaseCatalog["RedshiftSession", "RedshiftDataFrame"],
+ ):
+     CURRENT_CATALOG_EXPRESSION: exp.Expression = exp.func("current_database")
+
+     def listFunctions(
+         self, dbName: t.Optional[str] = None, pattern: t.Optional[str] = None
+     ) -> t.List[Function]:
+         """
+         Returns a list of functions registered in the specified database.
+
+         .. versionadded:: 3.4.0
+
+         Parameters
+         ----------
+         dbName : str
+             name of the database to list the functions.
+             ``dbName`` can be qualified with catalog name.
+         pattern : str
+             The pattern that the function name needs to match.
+
+             .. versionchanged:: 3.5.0
+                 Adds ``pattern`` argument.
+
+         Returns
+         -------
+         list
+             A list of :class:`Function`.
+
+         Notes
+         -----
+         If no database is specified, the current database and catalog
+         are used. This API includes all temporary functions.
+
+         Examples
+         --------
+         >>> spark.catalog.listFunctions()
+         [Function(name=...
+
+         >>> spark.catalog.listFunctions(pattern="to_*")
+         [Function(name=...
+
+         >>> spark.catalog.listFunctions(pattern="*not_existing_func*")
+         []
+         """
+         if dbName is None:
+             schema = schema_(
+                 db=exp.parse_identifier(self.currentDatabase(), dialect=self.session.input_dialect),
+                 catalog=exp.parse_identifier(
+                     self.currentCatalog(), dialect=self.session.input_dialect
+                 ),
+             )
+         else:
+             schema = to_schema(dbName, dialect=self.session.input_dialect)
+             if not schema.catalog:
+                 schema.set("catalog", exp.parse_identifier(self.currentCatalog()))
+         query = parse_one(
+             f"""SELECT database_name as catalog, schema_name as namespace, function_name as name
+             FROM svv_redshift_functions
+             WHERE database_name = '{schema.catalog}'
+             and schema_name = '{schema.db}'
+             ORDER BY function_name;
+             """,
+             dialect=self.session.input_dialect,
+         )
+         functions = self.session._fetch_rows(query)
+         if pattern:
+             functions = [x for x in functions if fnmatch.fnmatch(x["name"], pattern)]
+         return [
+             Function(
+                 name=x["name"],
+                 catalog=x["catalog"],
+                 namespace=[x["namespace"]],
+                 description=None,
+                 className="",
+                 isTemporary=False,
+             )
+             for x in functions
+         ]
+
+     # def get_columns(self, table_name: t.Union[exp.Table, str]) -> t.Dict[str, exp.DataType]:
+     #     table = self.ensure_table(table_name)
+     #     if not table.catalog:
+     #         table.set("catalog", exp.parse_identifier(self.currentCatalog()))
+     #     if not table.db:
+     #         table.set("db", exp.parse_identifier(self.currentDatabase()))
+     #     sql = f"SHOW COLUMNS FROM TABLE {table.sql(dialect=self.session.input_dialect)}"
+     #     results = sorted(self.session._fetch_rows(sql), key=lambda x: x["ordinal_position"])
+     #     return {
+     #         row["column_name"]: exp.DataType.build(
+     #             row["data_type"], dialect=self.session.input_dialect, udt=True
+     #         )
+     #         for row in results
+     #     }
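Worth noting: `pattern` is applied client-side. The full contents of `svv_redshift_functions` are fetched, then filtered with `fnmatch`. A minimal standalone sketch of that filtering step (the rows here are hypothetical, shaped like the query's result set):

import fnmatch

# Hypothetical rows shaped like the svv_redshift_functions query output.
rows = [
    {"catalog": "dev", "namespace": "public", "name": "to_timestamp"},
    {"catalog": "dev", "namespace": "public", "name": "f_sales_tax"},
]

# The same client-side filter listFunctions applies when a pattern is given.
matches = [row for row in rows if fnmatch.fnmatch(row["name"], "to_*")]
print([row["name"] for row in matches])  # ['to_timestamp']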
sqlframe/redshift/column.py
@@ -0,0 +1 @@
+ from sqlframe.base.column import Column
sqlframe/redshift/dataframe.py
@@ -0,0 +1,54 @@
+ from __future__ import annotations
+
+ import logging
+ import sys
+ import typing as t
+
+ from sqlframe.base.dataframe import (
+     _BaseDataFrame,
+     _BaseDataFrameNaFunctions,
+     _BaseDataFrameStatFunctions,
+ )
+ from sqlframe.redshift.group import RedshiftGroupedData
+
+ if sys.version_info >= (3, 11):
+     from typing import Self
+ else:
+     from typing_extensions import Self
+
+ if t.TYPE_CHECKING:
+     from sqlframe.redshift.readwriter import RedshiftDataFrameWriter
+     from sqlframe.redshift.session import RedshiftSession
+
+
+ logger = logging.getLogger(__name__)
+
+
+ class RedshiftDataFrameNaFunctions(_BaseDataFrameNaFunctions["RedshiftDataFrame"]):
+     pass
+
+
+ class RedshiftDataFrameStatFunctions(_BaseDataFrameStatFunctions["RedshiftDataFrame"]):
+     pass
+
+
+ class RedshiftDataFrame(
+     _BaseDataFrame[
+         "RedshiftSession",
+         "RedshiftDataFrameWriter",
+         "RedshiftDataFrameNaFunctions",
+         "RedshiftDataFrameStatFunctions",
+         "RedshiftGroupedData",
+     ]
+ ):
+     _na = RedshiftDataFrameNaFunctions
+     _stat = RedshiftDataFrameStatFunctions
+     _group_data = RedshiftGroupedData
+
+     def cache(self) -> Self:
+         logger.warning("Redshift does not support caching. Ignoring cache() call.")
+         return self
+
+     def persist(self) -> Self:
+         logger.warning("Redshift does not support persist. Ignoring persist() call.")
+         return self
sqlframe/redshift/functions.py
@@ -0,0 +1,18 @@
+ import inspect
+ import sys
+
+ import sqlframe.base.functions
+
+ module = sys.modules["sqlframe.base.functions"]
+ globals().update(
+     {
+         name: func
+         for name, func in inspect.getmembers(module, inspect.isfunction)
+         if hasattr(func, "unsupported_engines")
+         and "redshift" not in func.unsupported_engines
+         and "*" not in func.unsupported_engines
+     }
+ )
+
+
+ from sqlframe.base.function_alternatives import e_literal as e  # noqa
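This module populates its own namespace at import time: every function in `sqlframe.base.functions` carries an `unsupported_engines` list, and only functions excluding neither "redshift" nor "*" are copied in. A minimal sketch of the mechanism, under the assumption that the marker is attached by a decorator (the `meta` decorator below is hypothetical; sqlframe's real marker lives in `sqlframe.base.functions`):

import inspect
import sys
import typing as t


def meta(unsupported_engines: t.Optional[t.List[str]] = None) -> t.Callable:
    # Hypothetical stand-in for sqlframe's marker: tag a function with the
    # engines that cannot run it.
    def decorator(func: t.Callable) -> t.Callable:
        func.unsupported_engines = unsupported_engines or []  # type: ignore[attr-defined]
        return func
    return decorator


@meta(unsupported_engines=["redshift"])
def array_sort(col):  # placeholder body, for illustration only
    ...


@meta()
def lit(value):  # no exclusions: available on every engine
    ...


# The same filter the engine-specific modules apply over sqlframe.base.functions:
module = sys.modules[__name__]
exported = {
    name: func
    for name, func in inspect.getmembers(module, inspect.isfunction)
    if hasattr(func, "unsupported_engines")
    and "redshift" not in func.unsupported_engines
    and "*" not in func.unsupported_engines
}
print(sorted(exported))  # ['lit'], since array_sort is excluded for redshift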
sqlframe/redshift/group.py
@@ -0,0 +1,14 @@
+ # This code is based on code from Apache Spark under the license found in the LICENSE file located in the 'sqlframe' folder.
+
+ from __future__ import annotations
+
+ import typing as t
+
+ from sqlframe.base.group import _BaseGroupedData
+
+ if t.TYPE_CHECKING:
+     from sqlframe.redshift.dataframe import RedshiftDataFrame
+
+
+ class RedshiftGroupedData(_BaseGroupedData["RedshiftDataFrame"]):
+     pass
sqlframe/redshift/readwriter.py
@@ -0,0 +1,29 @@
+ # This code is based on code from Apache Spark under the license found in the LICENSE file located in the 'sqlframe' folder.
+
+ from __future__ import annotations
+
+ import typing as t
+
+ from sqlframe.base.mixins.readwriter_mixins import PandasLoaderMixin, PandasWriterMixin
+ from sqlframe.base.readerwriter import (
+     _BaseDataFrameReader,
+     _BaseDataFrameWriter,
+ )
+
+ if t.TYPE_CHECKING:
+     from sqlframe.redshift.session import RedshiftSession  # noqa
+     from sqlframe.redshift.dataframe import RedshiftDataFrame  # noqa
+
+
+ class RedshiftDataFrameReader(
+     PandasLoaderMixin["RedshiftSession", "RedshiftDataFrame"],
+     _BaseDataFrameReader["RedshiftSession", "RedshiftDataFrame"],
+ ):
+     pass
+
+
+ class RedshiftDataFrameWriter(
+     PandasWriterMixin["RedshiftSession", "RedshiftDataFrame"],
+     _BaseDataFrameWriter["RedshiftSession", "RedshiftDataFrame"],
+ ):
+     pass
sqlframe/redshift/session.py
@@ -0,0 +1,53 @@
+ from __future__ import annotations
+
+ import typing as t
+ import warnings
+
+ from sqlframe.base.session import _BaseSession
+ from sqlframe.redshift.catalog import RedshiftCatalog
+ from sqlframe.redshift.dataframe import RedshiftDataFrame
+ from sqlframe.redshift.readwriter import (
+     RedshiftDataFrameReader,
+     RedshiftDataFrameWriter,
+ )
+
+ if t.TYPE_CHECKING:
+     from redshift_connector.core import Connection as RedshiftConnection
+ else:
+     RedshiftConnection = t.Any
+
+
+ class RedshiftSession(
+     _BaseSession[  # type: ignore
+         RedshiftCatalog,
+         RedshiftDataFrameReader,
+         RedshiftDataFrameWriter,
+         RedshiftDataFrame,
+         RedshiftConnection,
+     ],
+ ):
+     _catalog = RedshiftCatalog
+     _reader = RedshiftDataFrameReader
+     _writer = RedshiftDataFrameWriter
+     _df = RedshiftDataFrame
+
+     def __init__(self, conn: t.Optional[RedshiftConnection] = None):
+         warnings.warn(
+             "RedshiftSession is still in active development. Functions may not work as expected."
+         )
+         if not hasattr(self, "_conn"):
+             super().__init__(conn)
+
+     class Builder(_BaseSession.Builder):
+         DEFAULT_INPUT_DIALECT = "redshift"
+         DEFAULT_OUTPUT_DIALECT = "redshift"
+
+         @property
+         def session(self) -> RedshiftSession:
+             return RedshiftSession(**self._session_kwargs)
+
+         def getOrCreate(self) -> RedshiftSession:
+             self._set_session_properties()
+             return self.session
+
+     builder = Builder()
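A usage sketch, assuming `redshift_connector` is installed; the cluster endpoint and credentials below are placeholders:

import redshift_connector

from sqlframe.redshift.session import RedshiftSession

# Placeholder connection details; substitute your own cluster endpoint.
conn = redshift_connector.connect(
    host="examplecluster.abc123xyz789.us-west-2.redshift.amazonaws.com",
    database="dev",
    user="awsuser",
    password="my_password",
)

# Triggers the in-development warning above, then wraps the connection.
session = RedshiftSession(conn=conn)
print(session.catalog.currentDatabase())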
sqlframe/redshift/types.py
@@ -0,0 +1 @@
+ from sqlframe.base.types import *
sqlframe/redshift/window.py
@@ -0,0 +1 @@
+ from sqlframe.base.window import *
sqlframe/snowflake/__init__.py
@@ -0,0 +1,26 @@
+ from sqlframe.snowflake.catalog import SnowflakeCatalog
+ from sqlframe.snowflake.column import Column
+ from sqlframe.snowflake.dataframe import (
+     SnowflakeDataFrame,
+     SnowflakeDataFrameNaFunctions,
+ )
+ from sqlframe.snowflake.group import SnowflakeGroupedData
+ from sqlframe.snowflake.readwriter import (
+     SnowflakeDataFrameReader,
+     SnowflakeDataFrameWriter,
+ )
+ from sqlframe.snowflake.session import SnowflakeSession
+ from sqlframe.snowflake.window import Window, WindowSpec
+
+ __all__ = [
+     "SnowflakeCatalog",
+     "Column",
+     "SnowflakeDataFrame",
+     "SnowflakeDataFrameNaFunctions",
+     "SnowflakeGroupedData",
+     "SnowflakeDataFrameReader",
+     "SnowflakeDataFrameWriter",
+     "SnowflakeSession",
+     "Window",
+     "WindowSpec",
+ ]
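With this re-export layer in place, engine code can use PySpark-style imports from the package root, e.g.:

# Every name listed in __all__ above resolves from the package root.
from sqlframe.snowflake import (
    Column,
    SnowflakeDataFrame,
    SnowflakeSession,
    Window,
)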
sqlframe/snowflake/catalog.py
@@ -0,0 +1,134 @@
+ # This code is based on code from Apache Spark under the license found in the LICENSE file located in the 'sqlframe' folder.
+
+ from __future__ import annotations
+
+ import fnmatch
+ import json
+ import typing as t
+
+ from sqlglot import exp, parse_one
+
+ from sqlframe.base.catalog import Function, _BaseCatalog
+ from sqlframe.base.decorators import normalize
+ from sqlframe.base.mixins.catalog_mixins import (
+     GetCurrentCatalogFromFunctionMixin,
+     GetCurrentDatabaseFromFunctionMixin,
+     ListCatalogsFromInfoSchemaMixin,
+     ListColumnsFromInfoSchemaMixin,
+     ListDatabasesFromInfoSchemaMixin,
+     ListTablesFromInfoSchemaMixin,
+     SetCurrentCatalogFromUseMixin,
+     SetCurrentDatabaseFromUseMixin,
+ )
+ from sqlframe.base.util import schema_, to_schema
+
+ if t.TYPE_CHECKING:
+     from sqlframe.snowflake.session import SnowflakeSession  # noqa
+     from sqlframe.snowflake.dataframe import SnowflakeDataFrame  # noqa
+
+
+ class SnowflakeCatalog(
+     SetCurrentCatalogFromUseMixin["SnowflakeSession", "SnowflakeDataFrame"],
+     GetCurrentCatalogFromFunctionMixin["SnowflakeSession", "SnowflakeDataFrame"],
+     GetCurrentDatabaseFromFunctionMixin["SnowflakeSession", "SnowflakeDataFrame"],
+     ListDatabasesFromInfoSchemaMixin["SnowflakeSession", "SnowflakeDataFrame"],
+     ListCatalogsFromInfoSchemaMixin["SnowflakeSession", "SnowflakeDataFrame"],
+     SetCurrentDatabaseFromUseMixin["SnowflakeSession", "SnowflakeDataFrame"],
+     ListTablesFromInfoSchemaMixin["SnowflakeSession", "SnowflakeDataFrame"],
+     ListColumnsFromInfoSchemaMixin["SnowflakeSession", "SnowflakeDataFrame"],
+     _BaseCatalog["SnowflakeSession", "SnowflakeDataFrame"],
+ ):
+     CURRENT_CATALOG_EXPRESSION: exp.Expression = exp.func("current_database")
+
+     @normalize(["dbName"])
+     def listFunctions(
+         self, dbName: t.Optional[str] = None, pattern: t.Optional[str] = None
+     ) -> t.List[Function]:
+         """
+         Returns a list of functions registered in the specified database.
+
+         .. versionadded:: 3.4.0
+
+         Parameters
+         ----------
+         dbName : str
+             name of the database to list the functions.
+             ``dbName`` can be qualified with catalog name.
+         pattern : str
+             The pattern that the function name needs to match.
+
+             .. versionchanged:: 3.5.0
+                 Adds ``pattern`` argument.
+
+         Returns
+         -------
+         list
+             A list of :class:`Function`.
+
+         Notes
+         -----
+         If no database is specified, the current database and catalog
+         are used. This API includes all temporary functions.
+
+         Examples
+         --------
+         >>> spark.catalog.listFunctions()
+         [Function(name=...
+
+         >>> spark.catalog.listFunctions(pattern="to_*")
+         [Function(name=...
+
+         >>> spark.catalog.listFunctions(pattern="*not_existing_func*")
+         []
+         """
+         if dbName is None:
+             schema = schema_(
+                 db=exp.parse_identifier(self.currentDatabase(), dialect=self.session.input_dialect),
+                 catalog=exp.parse_identifier(
+                     self.currentCatalog(), dialect=self.session.input_dialect
+                 ),
+             )
+         else:
+             schema = to_schema(dbName, dialect=self.session.input_dialect)
+             if not schema.catalog:
+                 schema.set("catalog", exp.parse_identifier(self.currentCatalog()))
+         query = parse_one(
+             f"""SHOW USER FUNCTIONS IN {schema.sql(dialect=self.session.input_dialect)}""",
+             dialect=self.session.input_dialect,
+         )
+         functions = self.session._fetch_rows(query)
+         if pattern:
+             functions = [x for x in functions if fnmatch.fnmatch(x["name"], pattern)]
+         return [
+             Function(
+                 name=x["name"],
+                 catalog=x["catalog_name"],
+                 namespace=[x["schema_name"]],
+                 description=None,
+                 className="",
+                 isTemporary=False,
+             )
+             for x in functions
+         ]
+
+     @normalize(["table_name"])
+     def get_columns(self, table_name: str) -> t.Dict[str, exp.DataType]:
+         table = exp.to_table(table_name)
+         if not table.catalog:
+             table.set(
+                 "catalog",
+                 exp.parse_identifier(self.currentCatalog(), dialect=self.session.input_dialect),
+             )
+         if not table.db:
+             table.set(
+                 "db",
+                 exp.parse_identifier(self.currentDatabase(), dialect=self.session.input_dialect),
+             )
+         sql = f"SHOW COLUMNS IN TABLE {table.sql(dialect=self.session.input_dialect)}"
+         results = self.session._fetch_rows(sql)
+         return {
+             row["column_name"]: exp.DataType.build(
+                 json.loads(row["data_type"])["type"], dialect=self.session.input_dialect, udt=True
+             )
+             for row in results
+         }
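`get_columns` depends on Snowflake's `SHOW COLUMNS` returning `data_type` as a JSON-encoded string, which is why the value goes through `json.loads` before `exp.DataType.build`. A standalone sketch of that parsing step, with a hypothetical result row:

import json

from sqlglot import exp

# Hypothetical SHOW COLUMNS row; Snowflake encodes data_type as a JSON string.
row = {
    "column_name": "NAME",
    "data_type": '{"type":"TEXT","length":16777216,"nullable":true,"fixed":false}',
}

dtype = exp.DataType.build(
    json.loads(row["data_type"])["type"], dialect="snowflake", udt=True
)
print(row["column_name"], dtype.sql(dialect="snowflake"))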
sqlframe/snowflake/column.py
@@ -0,0 +1 @@
+ from sqlframe.base.column import Column
sqlframe/snowflake/dataframe.py
@@ -0,0 +1,54 @@
+ from __future__ import annotations
+
+ import logging
+ import sys
+ import typing as t
+
+ from sqlframe.base.dataframe import (
+     _BaseDataFrame,
+     _BaseDataFrameNaFunctions,
+     _BaseDataFrameStatFunctions,
+ )
+ from sqlframe.snowflake.group import SnowflakeGroupedData
+
+ if sys.version_info >= (3, 11):
+     from typing import Self
+ else:
+     from typing_extensions import Self
+
+ if t.TYPE_CHECKING:
+     from sqlframe.snowflake.readwriter import SnowflakeDataFrameWriter
+     from sqlframe.snowflake.session import SnowflakeSession
+
+
+ logger = logging.getLogger(__name__)
+
+
+ class SnowflakeDataFrameNaFunctions(_BaseDataFrameNaFunctions["SnowflakeDataFrame"]):
+     pass
+
+
+ class SnowflakeDataFrameStatFunctions(_BaseDataFrameStatFunctions["SnowflakeDataFrame"]):
+     pass
+
+
+ class SnowflakeDataFrame(
+     _BaseDataFrame[
+         "SnowflakeSession",
+         "SnowflakeDataFrameWriter",
+         "SnowflakeDataFrameNaFunctions",
+         "SnowflakeDataFrameStatFunctions",
+         "SnowflakeGroupedData",
+     ]
+ ):
+     _na = SnowflakeDataFrameNaFunctions
+     _stat = SnowflakeDataFrameStatFunctions
+     _group_data = SnowflakeGroupedData
+
+     def cache(self) -> Self:
+         logger.warning("Snowflake does not support caching. Ignoring cache() call.")
+         return self
+
+     def persist(self) -> Self:
+         logger.warning("Snowflake does not support persist. Ignoring persist() call.")
+         return self
sqlframe/snowflake/functions.py
@@ -0,0 +1,18 @@
+ import inspect
+ import sys
+
+ import sqlframe.base.functions
+
+ module = sys.modules["sqlframe.base.functions"]
+ globals().update(
+     {
+         name: func
+         for name, func in inspect.getmembers(module, inspect.isfunction)
+         if hasattr(func, "unsupported_engines")
+         and "snowflake" not in func.unsupported_engines
+         and "*" not in func.unsupported_engines
+     }
+ )
+
+
+ from sqlframe.base.function_alternatives import e_literal as e  # noqa
sqlframe/snowflake/group.py
@@ -0,0 +1,14 @@
+ # This code is based on code from Apache Spark under the license found in the LICENSE file located in the 'sqlframe' folder.
+
+ from __future__ import annotations
+
+ import typing as t
+
+ from sqlframe.base.group import _BaseGroupedData
+
+ if t.TYPE_CHECKING:
+     from sqlframe.snowflake.dataframe import SnowflakeDataFrame
+
+
+ class SnowflakeGroupedData(_BaseGroupedData["SnowflakeDataFrame"]):
+     pass
sqlframe/snowflake/readwriter.py
@@ -0,0 +1,29 @@
+ # This code is based on code from Apache Spark under the license found in the LICENSE file located in the 'sqlframe' folder.
+
+ from __future__ import annotations
+
+ import typing as t
+
+ from sqlframe.base.mixins.readwriter_mixins import PandasLoaderMixin, PandasWriterMixin
+ from sqlframe.base.readerwriter import (
+     _BaseDataFrameReader,
+     _BaseDataFrameWriter,
+ )
+
+ if t.TYPE_CHECKING:
+     from sqlframe.snowflake.session import SnowflakeSession  # noqa
+     from sqlframe.snowflake.dataframe import SnowflakeDataFrame  # noqa
+
+
+ class SnowflakeDataFrameReader(
+     PandasLoaderMixin["SnowflakeSession", "SnowflakeDataFrame"],
+     _BaseDataFrameReader["SnowflakeSession", "SnowflakeDataFrame"],
+ ):
+     pass
+
+
+ class SnowflakeDataFrameWriter(
+     PandasWriterMixin["SnowflakeSession", "SnowflakeDataFrame"],
+     _BaseDataFrameWriter["SnowflakeSession", "SnowflakeDataFrame"],
+ ):
+     pass
sqlframe/snowflake/session.py
@@ -0,0 +1,53 @@
+ from __future__ import annotations
+
+ import typing as t
+ import warnings
+
+ from sqlframe.base.session import _BaseSession
+ from sqlframe.snowflake.catalog import SnowflakeCatalog
+ from sqlframe.snowflake.dataframe import SnowflakeDataFrame
+ from sqlframe.snowflake.readwriter import (
+     SnowflakeDataFrameReader,
+     SnowflakeDataFrameWriter,
+ )
+
+ if t.TYPE_CHECKING:
+     from snowflake.connector import SnowflakeConnection
+ else:
+     SnowflakeConnection = t.Any
+
+
+ class SnowflakeSession(
+     _BaseSession[  # type: ignore
+         SnowflakeCatalog,
+         SnowflakeDataFrameReader,
+         SnowflakeDataFrameWriter,
+         SnowflakeDataFrame,
+         SnowflakeConnection,
+     ],
+ ):
+     _catalog = SnowflakeCatalog
+     _reader = SnowflakeDataFrameReader
+     _writer = SnowflakeDataFrameWriter
+     _df = SnowflakeDataFrame
+
+     def __init__(self, conn: t.Optional[SnowflakeConnection] = None):
+         warnings.warn(
+             "SnowflakeSession is still in active development. Functions may not work as expected."
+         )
+         if not hasattr(self, "_conn"):
+             super().__init__(conn)
+
+     class Builder(_BaseSession.Builder):
+         DEFAULT_INPUT_DIALECT = "snowflake"
+         DEFAULT_OUTPUT_DIALECT = "snowflake"
+
+         @property
+         def session(self) -> SnowflakeSession:
+             return SnowflakeSession(**self._session_kwargs)
+
+         def getOrCreate(self) -> SnowflakeSession:
+             self._set_session_properties()
+             return self.session
+
+     builder = Builder()
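A usage sketch, assuming `snowflake-connector-python` is installed; the account and credentials below are placeholders:

import snowflake.connector

from sqlframe.snowflake.session import SnowflakeSession

# Placeholder connection details; substitute your own account and credentials.
conn = snowflake.connector.connect(
    account="myorg-myaccount",
    user="my_user",
    password="my_password",
    warehouse="COMPUTE_WH",
    database="MY_DB",
    schema="PUBLIC",
)

# Triggers the in-development warning above, then wraps the connection.
session = SnowflakeSession(conn=conn)
print(session.catalog.currentDatabase())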
sqlframe/snowflake/types.py
@@ -0,0 +1 @@
+ from sqlframe.base.types import *
sqlframe/snowflake/window.py
@@ -0,0 +1 @@
+ from sqlframe.base.window import *
sqlframe/redshift/__init__.py
@@ -0,0 +1,23 @@
+ from sqlframe.redshift.catalog import RedshiftCatalog
+ from sqlframe.redshift.column import Column
+ from sqlframe.redshift.dataframe import RedshiftDataFrame, RedshiftDataFrameNaFunctions
+ from sqlframe.redshift.group import RedshiftGroupedData
+ from sqlframe.redshift.readwriter import (
+     RedshiftDataFrameReader,
+     RedshiftDataFrameWriter,
+ )
+ from sqlframe.redshift.session import RedshiftSession
+ from sqlframe.redshift.window import Window, WindowSpec
+
+ __all__ = [
+     "RedshiftCatalog",
+     "Column",
+     "RedshiftDataFrame",
+     "RedshiftDataFrameNaFunctions",
+     "RedshiftGroupedData",
+     "RedshiftDataFrameReader",
+     "RedshiftDataFrameWriter",
+     "RedshiftSession",
+     "Window",
+     "WindowSpec",
+ ]