sqlframe 1.1.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (98) hide show
  1. sqlframe/__init__.py +0 -0
  2. sqlframe/_version.py +16 -0
  3. sqlframe/base/__init__.py +0 -0
  4. sqlframe/base/_typing.py +39 -0
  5. sqlframe/base/catalog.py +1163 -0
  6. sqlframe/base/column.py +388 -0
  7. sqlframe/base/dataframe.py +1519 -0
  8. sqlframe/base/decorators.py +51 -0
  9. sqlframe/base/exceptions.py +14 -0
  10. sqlframe/base/function_alternatives.py +1055 -0
  11. sqlframe/base/functions.py +1678 -0
  12. sqlframe/base/group.py +102 -0
  13. sqlframe/base/mixins/__init__.py +0 -0
  14. sqlframe/base/mixins/catalog_mixins.py +419 -0
  15. sqlframe/base/mixins/readwriter_mixins.py +118 -0
  16. sqlframe/base/normalize.py +84 -0
  17. sqlframe/base/operations.py +87 -0
  18. sqlframe/base/readerwriter.py +679 -0
  19. sqlframe/base/session.py +585 -0
  20. sqlframe/base/transforms.py +13 -0
  21. sqlframe/base/types.py +418 -0
  22. sqlframe/base/util.py +242 -0
  23. sqlframe/base/window.py +139 -0
  24. sqlframe/bigquery/__init__.py +23 -0
  25. sqlframe/bigquery/catalog.py +255 -0
  26. sqlframe/bigquery/column.py +1 -0
  27. sqlframe/bigquery/dataframe.py +54 -0
  28. sqlframe/bigquery/functions.py +378 -0
  29. sqlframe/bigquery/group.py +14 -0
  30. sqlframe/bigquery/readwriter.py +29 -0
  31. sqlframe/bigquery/session.py +89 -0
  32. sqlframe/bigquery/types.py +1 -0
  33. sqlframe/bigquery/window.py +1 -0
  34. sqlframe/duckdb/__init__.py +20 -0
  35. sqlframe/duckdb/catalog.py +108 -0
  36. sqlframe/duckdb/column.py +1 -0
  37. sqlframe/duckdb/dataframe.py +55 -0
  38. sqlframe/duckdb/functions.py +47 -0
  39. sqlframe/duckdb/group.py +14 -0
  40. sqlframe/duckdb/readwriter.py +111 -0
  41. sqlframe/duckdb/session.py +65 -0
  42. sqlframe/duckdb/types.py +1 -0
  43. sqlframe/duckdb/window.py +1 -0
  44. sqlframe/postgres/__init__.py +23 -0
  45. sqlframe/postgres/catalog.py +106 -0
  46. sqlframe/postgres/column.py +1 -0
  47. sqlframe/postgres/dataframe.py +54 -0
  48. sqlframe/postgres/functions.py +61 -0
  49. sqlframe/postgres/group.py +14 -0
  50. sqlframe/postgres/readwriter.py +29 -0
  51. sqlframe/postgres/session.py +68 -0
  52. sqlframe/postgres/types.py +1 -0
  53. sqlframe/postgres/window.py +1 -0
  54. sqlframe/redshift/__init__.py +23 -0
  55. sqlframe/redshift/catalog.py +127 -0
  56. sqlframe/redshift/column.py +1 -0
  57. sqlframe/redshift/dataframe.py +54 -0
  58. sqlframe/redshift/functions.py +18 -0
  59. sqlframe/redshift/group.py +14 -0
  60. sqlframe/redshift/readwriter.py +29 -0
  61. sqlframe/redshift/session.py +53 -0
  62. sqlframe/redshift/types.py +1 -0
  63. sqlframe/redshift/window.py +1 -0
  64. sqlframe/snowflake/__init__.py +26 -0
  65. sqlframe/snowflake/catalog.py +134 -0
  66. sqlframe/snowflake/column.py +1 -0
  67. sqlframe/snowflake/dataframe.py +54 -0
  68. sqlframe/snowflake/functions.py +18 -0
  69. sqlframe/snowflake/group.py +14 -0
  70. sqlframe/snowflake/readwriter.py +29 -0
  71. sqlframe/snowflake/session.py +53 -0
  72. sqlframe/snowflake/types.py +1 -0
  73. sqlframe/snowflake/window.py +1 -0
  74. sqlframe/spark/__init__.py +23 -0
  75. sqlframe/spark/catalog.py +1028 -0
  76. sqlframe/spark/column.py +1 -0
  77. sqlframe/spark/dataframe.py +54 -0
  78. sqlframe/spark/functions.py +22 -0
  79. sqlframe/spark/group.py +14 -0
  80. sqlframe/spark/readwriter.py +29 -0
  81. sqlframe/spark/session.py +90 -0
  82. sqlframe/spark/types.py +1 -0
  83. sqlframe/spark/window.py +1 -0
  84. sqlframe/standalone/__init__.py +26 -0
  85. sqlframe/standalone/catalog.py +13 -0
  86. sqlframe/standalone/column.py +1 -0
  87. sqlframe/standalone/dataframe.py +36 -0
  88. sqlframe/standalone/functions.py +1 -0
  89. sqlframe/standalone/group.py +14 -0
  90. sqlframe/standalone/readwriter.py +19 -0
  91. sqlframe/standalone/session.py +40 -0
  92. sqlframe/standalone/types.py +1 -0
  93. sqlframe/standalone/window.py +1 -0
  94. sqlframe-1.1.3.dist-info/LICENSE +21 -0
  95. sqlframe-1.1.3.dist-info/METADATA +172 -0
  96. sqlframe-1.1.3.dist-info/RECORD +98 -0
  97. sqlframe-1.1.3.dist-info/WHEEL +5 -0
  98. sqlframe-1.1.3.dist-info/top_level.txt +1 -0
@@ -0,0 +1,378 @@
1
+ from __future__ import annotations
2
+
3
+ import inspect
4
+ import sys
5
+ import typing as t
6
+
7
+ from sqlglot import exp as sqlglot_expression
8
+
9
+ import sqlframe.base.functions
10
+ from sqlframe.base.util import get_func_from_session
11
+ from sqlframe.bigquery.column import Column
12
+
13
+ if t.TYPE_CHECKING:
14
+ from sqlframe.base._typing import ColumnOrLiteral, ColumnOrName
15
+
16
# Re-export every generic function from sqlframe.base.functions that works on
# BigQuery: a function is included only if it carries the `unsupported_engines`
# marker and neither "bigquery" nor the wildcard "*" appears in that list.
# Engine-specific overrides below then shadow any of these names as needed.
module = sys.modules["sqlframe.base.functions"]
globals().update(
    {
        name: func
        for name, func in inspect.getmembers(module, inspect.isfunction)
        if hasattr(func, "unsupported_engines")
        and "bigquery" not in func.unsupported_engines
        and "*" not in func.unsupported_engines
    }
)
26
+
27
+
28
+ from sqlframe.base.function_alternatives import ( # noqa
29
+ e_literal as e,
30
+ expm1_from_exp as expm1,
31
+ factorial_from_case_statement as factorial,
32
+ log1p_from_log as log1p,
33
+ rint_from_round as rint,
34
+ collect_set_from_list_distinct as collect_set,
35
+ isnull_using_equal as isnull,
36
+ nanvl_as_case as nanvl,
37
+ percentile_approx_without_accuracy_and_plural as percentile_approx,
38
+ rand_no_seed as rand,
39
+ year_from_extract as year,
40
+ quarter_from_extract as quarter,
41
+ month_from_extract as month,
42
+ dayofweek_from_extract as dayofweek,
43
+ dayofmonth_from_extract_with_day as dayofmonth,
44
+ dayofyear_from_extract as dayofyear,
45
+ hour_from_extract as hour,
46
+ minute_from_extract as minute,
47
+ second_from_extract as second,
48
+ weekofyear_from_extract_as_isoweek as weekofyear,
49
+ make_date_from_date_func as make_date,
50
+ to_date_from_timestamp as to_date,
51
+ last_day_with_cast as last_day,
52
+ sha1_force_sha1_and_to_hex as sha1,
53
+ hash_from_farm_fingerprint as hash,
54
+ base64_from_blob as base64,
55
+ concat_ws_from_array_to_string as concat_ws,
56
+ format_string_with_format as format_string,
57
+ instr_using_strpos as instr,
58
+ overlay_from_substr as overlay,
59
+ split_with_split as split,
60
+ regexp_extract_only_one_group as regexp_extract,
61
+ hex_casted_as_bytes as hex,
62
+ bit_length_from_length as bit_length,
63
+ element_at_using_brackets as element_at,
64
+ array_union_using_array_concat as array_union,
65
+ sequence_from_generate_array as sequence,
66
+ )
67
+
68
+
69
def typeof(col: ColumnOrName) -> Column:
    """Return the BigQuery type name of *col* via the community ``bqutil.fn.typeof`` UDF."""
    expression = sqlglot_expression.Anonymous(
        this="bqutil.fn.typeof",
        expressions=[Column.ensure_col(col).expression],
    )
    return Column(expression)
75
+
76
+
77
def degrees(col: ColumnOrName) -> Column:
    """Convert radians to degrees using the community ``bqutil.fn.degrees`` UDF."""
    expression = sqlglot_expression.Anonymous(
        this="bqutil.fn.degrees",
        expressions=[Column.ensure_col(col).expression],
    )
    return Column(expression)
83
+
84
+
85
def radians(col: ColumnOrName) -> Column:
    """Convert degrees to radians using the community ``bqutil.fn.radians`` UDF."""
    expression = sqlglot_expression.Anonymous(
        this="bqutil.fn.radians",
        expressions=[Column.ensure_col(col).expression],
    )
    return Column(expression)
91
+
92
+
93
def bround(col: ColumnOrName, scale: t.Optional[int] = None) -> Column:
    """Round half-to-even (banker's rounding) via ``bqutil.fn.cw_round_half_even``.

    The input is cast to BIGNUMERIC first; *scale* is forwarded as a second
    argument only when the caller provides one.
    """
    from sqlframe.base.session import _BaseSession

    lit = get_func_from_session("lit", _BaseSession())

    args = [Column.ensure_col(col).cast("bignumeric").expression]
    if scale is not None:
        args.append(lit(scale).expression)
    return Column(
        sqlglot_expression.Anonymous(
            this="bqutil.fn.cw_round_half_even",
            expressions=args,
        )
    )
107
+
108
+
109
def months_between(
    date1: ColumnOrName, date2: ColumnOrName, roundOff: t.Optional[bool] = None
) -> Column:
    """Months between *date1* and *date2* via ``bqutil.fn.cw_months_between``.

    Both inputs are cast to DATETIME. Mirroring Spark, the result is rounded
    to 8 decimal places unless *roundOff* is explicitly falsy.
    """
    round_ = get_func_from_session("round")
    lit = get_func_from_session("lit")

    result = Column(
        sqlglot_expression.Anonymous(
            this="bqutil.fn.cw_months_between",
            expressions=[
                Column.ensure_col(date1).cast("datetime").expression,
                Column.ensure_col(date2).cast("datetime").expression,
            ],
        )
    )
    # roundOff=None means "not specified", which defaults to rounding on.
    if roundOff is None or roundOff:
        result = round_(result, lit(8))
    return result
128
+
129
+
130
def next_day(col: ColumnOrName, dayOfWeek: str) -> Column:
    """First date after *col* that falls on *dayOfWeek*, via ``bqutil.fn.cw_next_day``."""
    lit = get_func_from_session("lit")

    date_expr = Column.ensure_col(col).cast("date").expression
    return Column(
        sqlglot_expression.Anonymous(
            this="bqutil.fn.cw_next_day",
            expressions=[date_expr, lit(dayOfWeek).expression],
        )
    )
139
+
140
+
141
def from_unixtime(col: ColumnOrName, format: t.Optional[str] = None) -> Column:
    """Convert Unix epoch seconds in *col* to a formatted timestamp string.

    Wraps TIMESTAMP_SECONDS(col) in to_timestamp(...) and renders it with
    FORMAT_TIMESTAMP using the session's default time format.
    """
    from sqlframe.base.session import _BaseSession

    session: _BaseSession = _BaseSession()
    lit = get_func_from_session("lit")
    to_timestamp = get_func_from_session("to_timestamp")

    expressions = [Column.ensure_col(col).expression]
    if format is not None:
        # NOTE(review): this appends the format literal as a second argument to
        # TIMESTAMP_SECONDS below, while `format` is also passed separately to
        # to_timestamp — confirm BigQuery accepts the extra argument / that this
        # branch is intended.
        expressions.append(lit(format).expression)
    return Column(
        sqlglot_expression.Anonymous(
            this="FORMAT_TIMESTAMP",
            expressions=[
                # Output pattern: the session-wide default, not `format`.
                lit(session.DEFAULT_TIME_FORMAT).expression,
                to_timestamp(
                    Column(
                        sqlglot_expression.Anonymous(
                            this="TIMESTAMP_SECONDS", expressions=expressions
                        )
                    ),
                    format,
                ).expression,
            ],
        )
    )
167
+
168
+
169
def unix_timestamp(
    timestamp: t.Optional[ColumnOrName] = None, format: t.Optional[str] = None
) -> Column:
    """Parse *timestamp* with *format* (UTC) and return Unix epoch seconds.

    Builds UNIX_SECONDS(PARSE_TIMESTAMP(format, timestamp, 'UTC')). When
    *format* is omitted, the session's DEFAULT_TIME_FORMAT is used.
    """
    from sqlframe.base.session import _BaseSession

    lit = get_func_from_session("lit")

    if format is None:
        format = _BaseSession().DEFAULT_TIME_FORMAT
    # NOTE(review): `timestamp` may be None here (Spark then uses the current
    # time); this relies on Column.ensure_col(None) doing something sensible —
    # confirm against the base Column implementation.
    return Column(
        sqlglot_expression.Anonymous(
            this="UNIX_SECONDS",
            expressions=[
                sqlglot_expression.Anonymous(
                    this="PARSE_TIMESTAMP",
                    expressions=[
                        lit(format).expression,
                        Column.ensure_col(timestamp).expression,
                        lit("UTC").expression,
                    ],
                )
            ],
        )
    )
193
+
194
+
195
def format_number(col: ColumnOrName, d: int) -> Column:
    """Format *col* with thousands separators and exactly *d* decimal places."""
    round_ = get_func_from_session("round")
    lit = get_func_from_session("lit")

    # %'.<d>f = fixed-point with grouping separators in BigQuery's FORMAT().
    pattern = lit(f"%'.{d}f").expression
    rounded = round_(Column.ensure_col(col).cast("float"), d).expression
    return Column(
        sqlglot_expression.Anonymous(
            this="FORMAT",
            expressions=[pattern, rounded],
        )
    )
208
+
209
+
210
def substring_index(str: ColumnOrName, delim: str, count: int) -> Column:
    """Substring of *str* before *count* occurrences of *delim*, via ``bqutil.fn.cw_substring_index``."""
    lit = get_func_from_session("lit")

    args = [
        Column.ensure_col(str).expression,
        lit(delim).expression,
        lit(count).expression,
    ]
    return Column(
        sqlglot_expression.Anonymous(this="bqutil.fn.cw_substring_index", expressions=args)
    )
223
+
224
+
225
def bin(col: ColumnOrName) -> Column:
    """Binary string representation of *col* (Spark's ``bin``).

    ``bqutil.fn.to_binary`` produces the binary form; the INT→STRING cast chain
    normalizes the UDF output to the string type Spark returns.
    """
    to_binary = sqlglot_expression.Anonymous(
        this="bqutil.fn.to_binary",
        expressions=[Column.ensure_col(col).expression],
    )
    return Column(to_binary).cast("int").cast("string")
236
+
237
+
238
def slice(
    x: ColumnOrName, start: t.Union[ColumnOrName, int], length: t.Union[ColumnOrName, int]
) -> Column:
    """Return the sub-array of *x* beginning at 1-based *start* with *length* elements.

    Implemented as ARRAY(SELECT x FROM UNNEST(x) AS x WITH OFFSET
    WHERE offset BETWEEN start - 1 AND start + length - 2).

    NOTE(review): Spark's slice also accepts a negative *start* (counting from
    the end); this implementation does not handle that case.
    """
    lit = get_func_from_session("lit")

    start_col = start if isinstance(start, Column) else lit(start)
    length_col = length if isinstance(length, Column) else lit(length)

    subquery = (
        sqlglot_expression.select(
            sqlglot_expression.column("x"),
        )
        .from_(
            sqlglot_expression.Unnest(
                expressions=[Column.ensure_col(x).expression],
                alias=sqlglot_expression.TableAlias(
                    columns=[sqlglot_expression.to_identifier("x")],
                ),
                offset=sqlglot_expression.to_identifier("offset"),
            )
        )
        .where(
            sqlglot_expression.Between(
                this=sqlglot_expression.column("offset"),
                low=(start_col - lit(1)).expression,
                # Fix: BETWEEN is inclusive over 0-based offsets, so the upper
                # bound must be start + length - 2 to select exactly `length`
                # elements; the previous bound (start + length) returned two
                # extra elements.
                high=(start_col + length_col - lit(2)).expression,
            )
        )
    )

    return Column(
        sqlglot_expression.Anonymous(
            this="ARRAY",
            expressions=[subquery],
        )
    )
274
+
275
+
276
def array_position(col: ColumnOrName, value: ColumnOrLiteral) -> Column:
    """Position of *value* in array *col*, or 0 when absent.

    The array is joined into a comma-separated string and searched with the
    community ``bqutil.fn.find_in_set`` UDF; COALESCE maps NULL (no match
    information) to 0.
    """
    lit = get_func_from_session("lit")

    needle = value if isinstance(value, Column) else lit(value)
    haystack = sqlglot_expression.Anonymous(
        this="ARRAY_TO_STRING",
        expressions=[Column.ensure_col(col).expression, lit(",").expression],
    )
    found = sqlglot_expression.Anonymous(
        this="bqutil.fn.find_in_set",
        expressions=[needle.expression, haystack],
    )
    return Column(
        sqlglot_expression.Coalesce(this=found, expressions=[lit(0).expression])
    )
296
+
297
+
298
def array_remove(col: ColumnOrName, value: ColumnOrLiteral) -> Column:
    """Return array *col* with every element equal to *value* removed.

    Builds the scalar subquery
    ``(SELECT ARRAY_AGG(x) FROM (SELECT * FROM UNNEST(col) AS x) AS t
    WHERE t.x <> value)``.
    """
    lit = get_func_from_session("lit")

    target = value if isinstance(value, Column) else lit(value)

    # Explode the array into rows, one element per row, aliased as `x`.
    exploded = sqlglot_expression.select(
        "*",
    ).from_(
        sqlglot_expression.Unnest(
            expressions=[Column.ensure_col(col).expression],
            alias=sqlglot_expression.TableAlias(
                columns=[sqlglot_expression.to_identifier("x")],
            ),
        )
    )

    # Re-aggregate everything that is not equal to the target value.
    aggregated = (
        sqlglot_expression.select(
            sqlglot_expression.Anonymous(
                this="ARRAY_AGG",
                expressions=[sqlglot_expression.column("x")],
            ),
        )
        .from_(exploded.subquery("t"))
        .where(
            sqlglot_expression.NEQ(
                this=sqlglot_expression.column("x", "t"),
                expression=target.expression,
            )
        )
    )

    return Column(aggregated.subquery())
331
+
332
+
333
def array_distinct(col: ColumnOrName) -> Column:
    """Deduplicate array *col* using the community ``bqutil.fn.cw_array_distinct`` UDF."""
    expression = sqlglot_expression.Anonymous(
        this="bqutil.fn.cw_array_distinct",
        expressions=[Column.ensure_col(col).expression],
    )
    return Column(expression)
340
+
341
+
342
def array_min(col: ColumnOrName) -> Column:
    """Minimum element of array *col* via the community ``bqutil.fn.cw_array_min`` UDF."""
    expression = sqlglot_expression.Anonymous(
        this="bqutil.fn.cw_array_min",
        expressions=[Column.ensure_col(col).expression],
    )
    return Column(expression)
349
+
350
+
351
def array_max(col: ColumnOrName) -> Column:
    """Maximum element of array *col* via the community ``bqutil.fn.cw_array_max`` UDF."""
    expression = sqlglot_expression.Anonymous(
        this="bqutil.fn.cw_array_max",
        expressions=[Column.ensure_col(col).expression],
    )
    return Column(expression)
358
+
359
+
360
def sort_array(col: ColumnOrName, asc: t.Optional[bool] = None) -> Column:
    """Sort the elements of array *col*; ascending unless *asc* is given and falsy.

    Implemented as ARRAY(SELECT x FROM UNNEST(col) AS x ORDER BY x <dir>).
    """
    direction = "ASC" if asc or asc is None else "DESC"
    ordered = (
        sqlglot_expression.select("x")
        .from_(
            sqlglot_expression.Unnest(
                expressions=[Column.ensure_col(col).expression],
                alias=sqlglot_expression.TableAlias(
                    columns=[sqlglot_expression.to_identifier("x")],
                ),
            )
        )
        .order_by(f"x {direction}")
    )

    return Column(sqlglot_expression.Anonymous(this="ARRAY", expressions=[ordered]))


# Spark's array_sort matches the default (ascending) behavior of sort_array.
array_sort = sort_array
@@ -0,0 +1,14 @@
1
+ # This code is based on code from Apache Spark under the license found in the LICENSE file located in the 'sqlframe' folder.
2
+
3
+ from __future__ import annotations
4
+
5
+ import typing as t
6
+
7
+ from sqlframe.base.group import _BaseGroupedData
8
+
9
+ if t.TYPE_CHECKING:
10
+ from sqlframe.bigquery.dataframe import BigQueryDataFrame
11
+
12
+
13
class BigQueryGroupedData(_BaseGroupedData["BigQueryDataFrame"]):
    """GroupedData for BigQuery; all behavior is inherited from _BaseGroupedData."""

    pass
@@ -0,0 +1,29 @@
1
+ # This code is based on code from Apache Spark under the license found in the LICENSE file located in the 'sqlframe' folder.
2
+
3
+ from __future__ import annotations
4
+
5
+ import typing as t
6
+
7
+ from sqlframe.base.mixins.readwriter_mixins import PandasLoaderMixin, PandasWriterMixin
8
+ from sqlframe.base.readerwriter import (
9
+ _BaseDataFrameReader,
10
+ _BaseDataFrameWriter,
11
+ )
12
+
13
+ if t.TYPE_CHECKING:
14
+ from sqlframe.bigquery.session import BigQuerySession # noqa
15
+ from sqlframe.bigquery.dataframe import BigQueryDataFrame # noqa
16
+
17
+
18
class BigQueryDataFrameReader(
    PandasLoaderMixin["BigQuerySession", "BigQueryDataFrame"],
    _BaseDataFrameReader["BigQuerySession", "BigQueryDataFrame"],
):
    """DataFrame reader for BigQuery; file loading comes from PandasLoaderMixin."""

    pass
23
+
24
+
25
class BigQueryDataFrameWriter(
    PandasWriterMixin["BigQuerySession", "BigQueryDataFrame"],
    _BaseDataFrameWriter["BigQuerySession", "BigQueryDataFrame"],
):
    """DataFrame writer for BigQuery; file writing comes from PandasWriterMixin."""

    pass
@@ -0,0 +1,89 @@
1
+ from __future__ import annotations
2
+
3
+ import typing as t
4
+
5
+ from sqlframe.base.session import _BaseSession
6
+ from sqlframe.bigquery.catalog import BigQueryCatalog
7
+ from sqlframe.bigquery.dataframe import BigQueryDataFrame
8
+ from sqlframe.bigquery.readwriter import (
9
+ BigQueryDataFrameReader,
10
+ BigQueryDataFrameWriter,
11
+ )
12
+
13
+ if t.TYPE_CHECKING:
14
+ from google.cloud.bigquery.client import Client as BigQueryClient
15
+ from google.cloud.bigquery.dbapi.connection import Connection as BigQueryConnection
16
+ else:
17
+ BigQueryClient = t.Any
18
+ BigQueryConnection = t.Any
19
+
20
+
21
+ class BigQuerySession(
22
+ _BaseSession[ # type: ignore
23
+ BigQueryCatalog,
24
+ BigQueryDataFrameReader,
25
+ BigQueryDataFrameWriter,
26
+ BigQueryDataFrame,
27
+ BigQueryConnection,
28
+ ],
29
+ ):
30
+ _catalog = BigQueryCatalog
31
+ _reader = BigQueryDataFrameReader
32
+ _writer = BigQueryDataFrameWriter
33
+ _df = BigQueryDataFrame
34
+
35
+ DEFAULT_TIME_FORMAT = "%Y-%m-%d %H:%M:%S"
36
+ QUALIFY_INFO_SCHEMA_WITH_DATABASE = True
37
+ SANITIZE_COLUMN_NAMES = True
38
+
39
+ def __init__(
40
+ self, conn: t.Optional[BigQueryConnection] = None, default_dataset: t.Optional[str] = None
41
+ ):
42
+ from google.cloud import bigquery
43
+ from google.cloud.bigquery.dbapi import connect
44
+
45
+ if not hasattr(self, "_conn"):
46
+ super().__init__(conn or connect())
47
+ if self._client.default_query_job_config is None:
48
+ self._client.default_query_job_config = bigquery.QueryJobConfig()
49
+ self.default_dataset = default_dataset
50
+
51
+ @property
52
+ def _client(self) -> BigQueryClient:
53
+ assert self._connection
54
+ return self._connection._client
55
+
56
+ @property
57
+ def default_dataset(self) -> t.Optional[str]:
58
+ return self._default_dataset
59
+
60
+ @default_dataset.setter
61
+ def default_dataset(self, dataset: str) -> None:
62
+ self._default_dataset = dataset
63
+ self._client.default_query_job_config.default_dataset = dataset # type: ignore
64
+
65
+ @property
66
+ def default_project(self) -> str:
67
+ return self._client.project
68
+
69
+ @default_project.setter
70
+ def default_project(self, project: str) -> None:
71
+ self._client.project = project
72
+
73
+ @classmethod
74
+ def _try_get_map(cls, value: t.Any) -> t.Optional[t.Dict[str, t.Any]]:
75
+ return None
76
+
77
+ class Builder(_BaseSession.Builder):
78
+ DEFAULT_INPUT_DIALECT = "bigquery"
79
+ DEFAULT_OUTPUT_DIALECT = "bigquery"
80
+
81
+ @property
82
+ def session(self) -> BigQuerySession:
83
+ return BigQuerySession(**self._session_kwargs)
84
+
85
+ def getOrCreate(self) -> BigQuerySession:
86
+ self._set_session_properties()
87
+ return self.session
88
+
89
+ builder = Builder()
@@ -0,0 +1 @@
1
+ from sqlframe.base.types import *
@@ -0,0 +1 @@
1
+ from sqlframe.base.window import *
@@ -0,0 +1,20 @@
1
+ from sqlframe.duckdb.catalog import DuckDBCatalog
2
+ from sqlframe.duckdb.column import DuckDBColumn
3
+ from sqlframe.duckdb.dataframe import DuckDBDataFrame, DuckDBDataFrameNaFunctions
4
+ from sqlframe.duckdb.group import DuckDBGroupedData
5
+ from sqlframe.duckdb.readwriter import DuckDBDataFrameReader, DuckDBDataFrameWriter
6
+ from sqlframe.duckdb.session import DuckDBSession
7
+ from sqlframe.duckdb.window import Window, WindowSpec
8
+
9
# Public API of the DuckDB engine package.
__all__ = [
    "DuckDBCatalog",
    "DuckDBColumn",
    "DuckDBDataFrame",
    "DuckDBDataFrameNaFunctions",
    "DuckDBGroupedData",
    "DuckDBDataFrameReader",
    "DuckDBDataFrameWriter",
    "DuckDBSession",
    "Window",
    "WindowSpec",
]
@@ -0,0 +1,108 @@
1
+ # This code is based on code from Apache Spark under the license found in the LICENSE file located in the 'sqlframe' folder.
2
+
3
+ from __future__ import annotations
4
+
5
+ import fnmatch
6
+ import typing as t
7
+
8
+ from sqlglot import exp
9
+
10
+ from sqlframe.base.catalog import Function, _BaseCatalog
11
+ from sqlframe.base.mixins.catalog_mixins import (
12
+ GetCurrentCatalogFromFunctionMixin,
13
+ GetCurrentDatabaseFromFunctionMixin,
14
+ ListCatalogsFromInfoSchemaMixin,
15
+ ListColumnsFromInfoSchemaMixin,
16
+ ListDatabasesFromInfoSchemaMixin,
17
+ ListTablesFromInfoSchemaMixin,
18
+ SetCurrentCatalogFromUseMixin,
19
+ SetCurrentDatabaseFromUseMixin,
20
+ )
21
+ from sqlframe.base.util import schema_, to_schema
22
+
23
+ if t.TYPE_CHECKING:
24
+ from sqlframe.duckdb.session import DuckDBSession # noqa
25
+ from sqlframe.duckdb.dataframe import DuckDBDataFrame # noqa
26
+
27
+
28
+ class DuckDBCatalog(
29
+ GetCurrentCatalogFromFunctionMixin["DuckDBSession", "DuckDBDataFrame"],
30
+ SetCurrentCatalogFromUseMixin["DuckDBSession", "DuckDBDataFrame"],
31
+ GetCurrentDatabaseFromFunctionMixin["DuckDBSession", "DuckDBDataFrame"],
32
+ ListDatabasesFromInfoSchemaMixin["DuckDBSession", "DuckDBDataFrame"],
33
+ ListCatalogsFromInfoSchemaMixin["DuckDBSession", "DuckDBDataFrame"],
34
+ SetCurrentDatabaseFromUseMixin["DuckDBSession", "DuckDBDataFrame"],
35
+ ListTablesFromInfoSchemaMixin["DuckDBSession", "DuckDBDataFrame"],
36
+ ListColumnsFromInfoSchemaMixin["DuckDBSession", "DuckDBDataFrame"],
37
+ _BaseCatalog["DuckDBSession", "DuckDBDataFrame"],
38
+ ):
39
+ def listFunctions(
40
+ self, dbName: t.Optional[str] = None, pattern: t.Optional[str] = None
41
+ ) -> t.List[Function]:
42
+ """
43
+ Returns a t.List of functions registered in the specified database.
44
+
45
+ .. versionadded:: 3.4.0
46
+
47
+ Parameters
48
+ ----------
49
+ dbName : str
50
+ name of the database to t.List the functions.
51
+ ``dbName`` can be qualified with catalog name.
52
+ pattern : str
53
+ The pattern that the function name needs to match.
54
+
55
+ .. versionchanged: 3.5.0
56
+ Adds ``pattern`` argument.
57
+
58
+ Returns
59
+ -------
60
+ t.List
61
+ A t.List of :class:`Function`.
62
+
63
+ Notes
64
+ -----
65
+ If no database is specified, the current database and catalog
66
+ are used. This API includes all temporary functions.
67
+
68
+ Examples
69
+ --------
70
+ >>> spark.catalog.t.listFunctions()
71
+ [Function(name=...
72
+
73
+ >>> spark.catalog.t.listFunctions(pattern="to_*")
74
+ [Function(name=...
75
+
76
+ >>> spark.catalog.t.listFunctions(pattern="*not_existing_func*")
77
+ []
78
+ """
79
+ if not dbName:
80
+ schema = schema_(
81
+ db=exp.parse_identifier(self.currentDatabase(), dialect=self.session.input_dialect),
82
+ catalog=exp.parse_identifier(
83
+ self.currentCatalog(), dialect=self.session.input_dialect
84
+ ),
85
+ )
86
+ else:
87
+ schema = to_schema(dbName, dialect=self.session.input_dialect)
88
+ select = (
89
+ exp.select("function_name", "schema_name", "database_name")
90
+ .from_("duckdb_functions()")
91
+ .where(exp.column("schema_name").eq(schema.db))
92
+ )
93
+ if schema.catalog:
94
+ select = select.where(exp.column("database_name").eq(schema.catalog))
95
+ functions = self.session._fetch_rows(select)
96
+ if pattern:
97
+ functions = [x for x in functions if fnmatch.fnmatch(x["function_name"], pattern)]
98
+ return [
99
+ Function(
100
+ name=x["function_name"],
101
+ catalog=x["database_name"],
102
+ namespace=[x["schema_name"]],
103
+ description=None,
104
+ className="",
105
+ isTemporary=False,
106
+ )
107
+ for x in functions
108
+ ]
@@ -0,0 +1 @@
1
+ from sqlframe.base.column import Column as DuckDBColumn
@@ -0,0 +1,55 @@
1
+ from __future__ import annotations
2
+
3
+ import logging
4
+ import sys
5
+ import typing as t
6
+
7
+ from sqlframe.base.dataframe import (
8
+ _BaseDataFrame,
9
+ _BaseDataFrameNaFunctions,
10
+ _BaseDataFrameStatFunctions,
11
+ )
12
+ from sqlframe.duckdb.group import DuckDBGroupedData
13
+
14
+ if sys.version_info >= (3, 11):
15
+ from typing import Self
16
+ else:
17
+ from typing_extensions import Self
18
+
19
+ if t.TYPE_CHECKING:
20
+ from sqlframe.duckdb.session import DuckDBSession # noqa
21
+ from sqlframe.duckdb.readwriter import DuckDBDataFrameWriter # noqa
22
+ from sqlframe.duckdb.group import DuckDBGroupedData # noqa
23
+
24
+
25
+ logger = logging.getLogger(__name__)
26
+
27
+
28
class DuckDBDataFrameNaFunctions(_BaseDataFrameNaFunctions["DuckDBDataFrame"]):
    """NA-handling functions for DuckDB DataFrames; all behavior inherited."""

    pass
30
+
31
+
32
+ class DuckDBDataFrameStatFunctions(_BaseDataFrameStatFunctions["DuckDBDataFrame"]):
33
+ pass
34
+
35
+
36
+ class DuckDBDataFrame(
37
+ _BaseDataFrame[
38
+ "DuckDBSession",
39
+ "DuckDBDataFrameWriter",
40
+ "DuckDBDataFrameNaFunctions",
41
+ "DuckDBDataFrameStatFunctions",
42
+ "DuckDBGroupedData",
43
+ ]
44
+ ):
45
+ _na = DuckDBDataFrameNaFunctions
46
+ _stat = DuckDBDataFrameStatFunctions
47
+ _group_data = DuckDBGroupedData
48
+
49
+ def cache(self) -> Self:
50
+ logger.warning("DuckDB does not support caching. Ignoring cache() call.")
51
+ return self
52
+
53
+ def persist(self) -> Self:
54
+ logger.warning("DuckDB does not support persist. Ignoring persist() call.")
55
+ return self