sqlframe 1.1.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sqlframe/__init__.py +0 -0
- sqlframe/_version.py +16 -0
- sqlframe/base/__init__.py +0 -0
- sqlframe/base/_typing.py +39 -0
- sqlframe/base/catalog.py +1163 -0
- sqlframe/base/column.py +388 -0
- sqlframe/base/dataframe.py +1519 -0
- sqlframe/base/decorators.py +51 -0
- sqlframe/base/exceptions.py +14 -0
- sqlframe/base/function_alternatives.py +1055 -0
- sqlframe/base/functions.py +1678 -0
- sqlframe/base/group.py +102 -0
- sqlframe/base/mixins/__init__.py +0 -0
- sqlframe/base/mixins/catalog_mixins.py +419 -0
- sqlframe/base/mixins/readwriter_mixins.py +118 -0
- sqlframe/base/normalize.py +84 -0
- sqlframe/base/operations.py +87 -0
- sqlframe/base/readerwriter.py +679 -0
- sqlframe/base/session.py +585 -0
- sqlframe/base/transforms.py +13 -0
- sqlframe/base/types.py +418 -0
- sqlframe/base/util.py +242 -0
- sqlframe/base/window.py +139 -0
- sqlframe/bigquery/__init__.py +23 -0
- sqlframe/bigquery/catalog.py +255 -0
- sqlframe/bigquery/column.py +1 -0
- sqlframe/bigquery/dataframe.py +54 -0
- sqlframe/bigquery/functions.py +378 -0
- sqlframe/bigquery/group.py +14 -0
- sqlframe/bigquery/readwriter.py +29 -0
- sqlframe/bigquery/session.py +89 -0
- sqlframe/bigquery/types.py +1 -0
- sqlframe/bigquery/window.py +1 -0
- sqlframe/duckdb/__init__.py +20 -0
- sqlframe/duckdb/catalog.py +108 -0
- sqlframe/duckdb/column.py +1 -0
- sqlframe/duckdb/dataframe.py +55 -0
- sqlframe/duckdb/functions.py +47 -0
- sqlframe/duckdb/group.py +14 -0
- sqlframe/duckdb/readwriter.py +111 -0
- sqlframe/duckdb/session.py +65 -0
- sqlframe/duckdb/types.py +1 -0
- sqlframe/duckdb/window.py +1 -0
- sqlframe/postgres/__init__.py +23 -0
- sqlframe/postgres/catalog.py +106 -0
- sqlframe/postgres/column.py +1 -0
- sqlframe/postgres/dataframe.py +54 -0
- sqlframe/postgres/functions.py +61 -0
- sqlframe/postgres/group.py +14 -0
- sqlframe/postgres/readwriter.py +29 -0
- sqlframe/postgres/session.py +68 -0
- sqlframe/postgres/types.py +1 -0
- sqlframe/postgres/window.py +1 -0
- sqlframe/redshift/__init__.py +23 -0
- sqlframe/redshift/catalog.py +127 -0
- sqlframe/redshift/column.py +1 -0
- sqlframe/redshift/dataframe.py +54 -0
- sqlframe/redshift/functions.py +18 -0
- sqlframe/redshift/group.py +14 -0
- sqlframe/redshift/readwriter.py +29 -0
- sqlframe/redshift/session.py +53 -0
- sqlframe/redshift/types.py +1 -0
- sqlframe/redshift/window.py +1 -0
- sqlframe/snowflake/__init__.py +26 -0
- sqlframe/snowflake/catalog.py +134 -0
- sqlframe/snowflake/column.py +1 -0
- sqlframe/snowflake/dataframe.py +54 -0
- sqlframe/snowflake/functions.py +18 -0
- sqlframe/snowflake/group.py +14 -0
- sqlframe/snowflake/readwriter.py +29 -0
- sqlframe/snowflake/session.py +53 -0
- sqlframe/snowflake/types.py +1 -0
- sqlframe/snowflake/window.py +1 -0
- sqlframe/spark/__init__.py +23 -0
- sqlframe/spark/catalog.py +1028 -0
- sqlframe/spark/column.py +1 -0
- sqlframe/spark/dataframe.py +54 -0
- sqlframe/spark/functions.py +22 -0
- sqlframe/spark/group.py +14 -0
- sqlframe/spark/readwriter.py +29 -0
- sqlframe/spark/session.py +90 -0
- sqlframe/spark/types.py +1 -0
- sqlframe/spark/window.py +1 -0
- sqlframe/standalone/__init__.py +26 -0
- sqlframe/standalone/catalog.py +13 -0
- sqlframe/standalone/column.py +1 -0
- sqlframe/standalone/dataframe.py +36 -0
- sqlframe/standalone/functions.py +1 -0
- sqlframe/standalone/group.py +14 -0
- sqlframe/standalone/readwriter.py +19 -0
- sqlframe/standalone/session.py +40 -0
- sqlframe/standalone/types.py +1 -0
- sqlframe/standalone/window.py +1 -0
- sqlframe-1.1.3.dist-info/LICENSE +21 -0
- sqlframe-1.1.3.dist-info/METADATA +172 -0
- sqlframe-1.1.3.dist-info/RECORD +98 -0
- sqlframe-1.1.3.dist-info/WHEEL +5 -0
- sqlframe-1.1.3.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,378 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import inspect
|
|
4
|
+
import sys
|
|
5
|
+
import typing as t
|
|
6
|
+
|
|
7
|
+
from sqlglot import exp as sqlglot_expression
|
|
8
|
+
|
|
9
|
+
import sqlframe.base.functions
|
|
10
|
+
from sqlframe.base.util import get_func_from_session
|
|
11
|
+
from sqlframe.bigquery.column import Column
|
|
12
|
+
|
|
13
|
+
if t.TYPE_CHECKING:
|
|
14
|
+
from sqlframe.base._typing import ColumnOrLiteral, ColumnOrName
|
|
15
|
+
|
|
16
|
+
# Re-export every generic function defined in sqlframe.base.functions that is
# usable on BigQuery: a function qualifies only if it carries an
# `unsupported_engines` attribute that names neither "bigquery" nor "*"
# (unsupported everywhere). Engine-specific overrides below shadow these.
module = sys.modules["sqlframe.base.functions"]
globals().update(
    {
        name: func
        for name, func in inspect.getmembers(module, inspect.isfunction)
        if hasattr(func, "unsupported_engines")
        and "bigquery" not in func.unsupported_engines
        and "*" not in func.unsupported_engines
    }
)
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
from sqlframe.base.function_alternatives import ( # noqa
|
|
29
|
+
e_literal as e,
|
|
30
|
+
expm1_from_exp as expm1,
|
|
31
|
+
factorial_from_case_statement as factorial,
|
|
32
|
+
log1p_from_log as log1p,
|
|
33
|
+
rint_from_round as rint,
|
|
34
|
+
collect_set_from_list_distinct as collect_set,
|
|
35
|
+
isnull_using_equal as isnull,
|
|
36
|
+
nanvl_as_case as nanvl,
|
|
37
|
+
percentile_approx_without_accuracy_and_plural as percentile_approx,
|
|
38
|
+
rand_no_seed as rand,
|
|
39
|
+
year_from_extract as year,
|
|
40
|
+
quarter_from_extract as quarter,
|
|
41
|
+
month_from_extract as month,
|
|
42
|
+
dayofweek_from_extract as dayofweek,
|
|
43
|
+
dayofmonth_from_extract_with_day as dayofmonth,
|
|
44
|
+
dayofyear_from_extract as dayofyear,
|
|
45
|
+
hour_from_extract as hour,
|
|
46
|
+
minute_from_extract as minute,
|
|
47
|
+
second_from_extract as second,
|
|
48
|
+
weekofyear_from_extract_as_isoweek as weekofyear,
|
|
49
|
+
make_date_from_date_func as make_date,
|
|
50
|
+
to_date_from_timestamp as to_date,
|
|
51
|
+
last_day_with_cast as last_day,
|
|
52
|
+
sha1_force_sha1_and_to_hex as sha1,
|
|
53
|
+
hash_from_farm_fingerprint as hash,
|
|
54
|
+
base64_from_blob as base64,
|
|
55
|
+
concat_ws_from_array_to_string as concat_ws,
|
|
56
|
+
format_string_with_format as format_string,
|
|
57
|
+
instr_using_strpos as instr,
|
|
58
|
+
overlay_from_substr as overlay,
|
|
59
|
+
split_with_split as split,
|
|
60
|
+
regexp_extract_only_one_group as regexp_extract,
|
|
61
|
+
hex_casted_as_bytes as hex,
|
|
62
|
+
bit_length_from_length as bit_length,
|
|
63
|
+
element_at_using_brackets as element_at,
|
|
64
|
+
array_union_using_array_concat as array_union,
|
|
65
|
+
sequence_from_generate_array as sequence,
|
|
66
|
+
)
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
def typeof(col: ColumnOrName) -> Column:
    """Return the BigQuery type name of *col* via the ``bqutil.fn.typeof`` UDF."""
    target = Column.ensure_col(col).expression
    return Column(sqlglot_expression.Anonymous(this="bqutil.fn.typeof", expressions=[target]))
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
def degrees(col: ColumnOrName) -> Column:
    """Convert radians in *col* to degrees via the ``bqutil.fn.degrees`` UDF."""
    target = Column.ensure_col(col).expression
    return Column(sqlglot_expression.Anonymous(this="bqutil.fn.degrees", expressions=[target]))
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
def radians(col: ColumnOrName) -> Column:
    """Convert degrees in *col* to radians via the ``bqutil.fn.radians`` UDF."""
    target = Column.ensure_col(col).expression
    return Column(sqlglot_expression.Anonymous(this="bqutil.fn.radians", expressions=[target]))
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
def bround(col: ColumnOrName, scale: t.Optional[int] = None) -> Column:
    """Round *col* half-to-even (banker's rounding) via ``bqutil.fn.cw_round_half_even``.

    The column is cast to BIGNUMERIC first; *scale* (decimal places) is
    forwarded as a second argument only when provided.
    """
    from sqlframe.base.session import _BaseSession

    lit = get_func_from_session("lit", _BaseSession())

    args = [Column.ensure_col(col).cast("bignumeric").expression]
    if scale is not None:
        args.append(lit(scale).expression)
    rounded = sqlglot_expression.Anonymous(
        this="bqutil.fn.cw_round_half_even",
        expressions=args,
    )
    return Column(rounded)
|
|
107
|
+
|
|
108
|
+
|
|
109
|
+
def months_between(
    date1: ColumnOrName, date2: ColumnOrName, roundOff: t.Optional[bool] = None
) -> Column:
    """Number of months between *date1* and *date2* via ``bqutil.fn.cw_months_between``.

    Both inputs are cast to DATETIME. Unless *roundOff* is explicitly False,
    the result is rounded to 8 decimal places (the default, None, behaves
    like True).
    """
    round_ = get_func_from_session("round")
    lit = get_func_from_session("lit")

    diff = Column(
        sqlglot_expression.Anonymous(
            this="bqutil.fn.cw_months_between",
            expressions=[
                Column.ensure_col(date1).cast("datetime").expression,
                Column.ensure_col(date2).cast("datetime").expression,
            ],
        )
    )
    if roundOff is None or roundOff:
        diff = round_(diff, lit(8))
    return diff
|
|
128
|
+
|
|
129
|
+
|
|
130
|
+
def next_day(col: ColumnOrName, dayOfWeek: str) -> Column:
    """First date later than *col* that falls on *dayOfWeek*, via ``bqutil.fn.cw_next_day``."""
    lit = get_func_from_session("lit")

    date_expr = Column.ensure_col(col).cast("date").expression
    return Column(
        sqlglot_expression.Anonymous(
            this="bqutil.fn.cw_next_day",
            expressions=[date_expr, lit(dayOfWeek).expression],
        )
    )
|
|
139
|
+
|
|
140
|
+
|
|
141
|
+
def from_unixtime(col: ColumnOrName, format: t.Optional[str] = None) -> Column:
    """Convert Unix epoch seconds in *col* to a formatted timestamp string.

    Builds ``FORMAT_TIMESTAMP(session.DEFAULT_TIME_FORMAT,
    to_timestamp(TIMESTAMP_SECONDS(col), format))``.
    """
    from sqlframe.base.session import _BaseSession

    session: _BaseSession = _BaseSession()
    lit = get_func_from_session("lit")
    to_timestamp = get_func_from_session("to_timestamp")

    expressions = [Column.ensure_col(col).expression]
    if format is not None:
        # NOTE(review): this appends the format literal to TIMESTAMP_SECONDS's
        # argument list, yet BigQuery's TIMESTAMP_SECONDS takes a single INT64
        # argument — confirm the extra argument is intentional (format is also
        # passed to to_timestamp below).
        expressions.append(lit(format).expression)
    return Column(
        sqlglot_expression.Anonymous(
            this="FORMAT_TIMESTAMP",
            expressions=[
                # Output format always comes from the session default.
                lit(session.DEFAULT_TIME_FORMAT).expression,
                to_timestamp(
                    Column(
                        sqlglot_expression.Anonymous(
                            this="TIMESTAMP_SECONDS", expressions=expressions
                        )
                    ),
                    format,
                ).expression,
            ],
        )
    )
|
|
167
|
+
|
|
168
|
+
|
|
169
|
+
def unix_timestamp(
    timestamp: t.Optional[ColumnOrName] = None, format: t.Optional[str] = None
) -> Column:
    """Convert a time string matching *format* to Unix epoch seconds.

    Parses *timestamp* with ``PARSE_TIMESTAMP`` (UTC) and extracts the epoch
    via ``UNIX_SECONDS``. When *format* is omitted, the session's
    ``DEFAULT_TIME_FORMAT`` is used.
    """
    from sqlframe.base.session import _BaseSession

    lit = get_func_from_session("lit")

    if format is None:
        format = _BaseSession().DEFAULT_TIME_FORMAT
    # NOTE(review): *timestamp* is Optional but ensure_col is called
    # unconditionally below — confirm ensure_col tolerates None (PySpark
    # defaults to the current time when the argument is omitted).
    return Column(
        sqlglot_expression.Anonymous(
            this="UNIX_SECONDS",
            expressions=[
                sqlglot_expression.Anonymous(
                    this="PARSE_TIMESTAMP",
                    expressions=[
                        lit(format).expression,
                        Column.ensure_col(timestamp).expression,
                        lit("UTC").expression,
                    ],
                )
            ],
        )
    )
|
|
193
|
+
|
|
194
|
+
|
|
195
|
+
def format_number(col: ColumnOrName, d: int) -> Column:
    """Format *col* to *d* decimal places with thousands separators.

    Uses BigQuery's ``FORMAT`` with a ``%'.Nf`` pattern after rounding the
    value (cast to FLOAT) to *d* places.
    """
    round = get_func_from_session("round")
    lit = get_func_from_session("lit")

    pattern = lit(f"%'.{d}f").expression
    rounded = round(Column.ensure_col(col).cast("float"), d).expression
    return Column(
        sqlglot_expression.Anonymous(this="FORMAT", expressions=[pattern, rounded])
    )
|
|
208
|
+
|
|
209
|
+
|
|
210
|
+
def substring_index(str: ColumnOrName, delim: str, count: int) -> Column:
    """Substring of *str* before *count* occurrences of *delim*, via ``bqutil.fn.cw_substring_index``."""
    lit = get_func_from_session("lit")

    args = [
        Column.ensure_col(str).expression,
        lit(delim).expression,
        lit(count).expression,
    ]
    return Column(
        sqlglot_expression.Anonymous(this="bqutil.fn.cw_substring_index", expressions=args)
    )
|
|
223
|
+
|
|
224
|
+
|
|
225
|
+
def bin(col: ColumnOrName) -> Column:
    """Binary string representation of *col* via ``bqutil.fn.to_binary``."""
    binary = sqlglot_expression.Anonymous(
        this="bqutil.fn.to_binary",
        expressions=[Column.ensure_col(col).expression],
    )
    # Cast to INT then STRING so the result is returned as a digit string.
    return Column(binary).cast("int").cast("string")
|
|
236
|
+
|
|
237
|
+
|
|
238
|
+
def slice(
    x: ColumnOrName, start: t.Union[ColumnOrName, int], length: t.Union[ColumnOrName, int]
) -> Column:
    """Return a *length*-element subarray of *x* starting at 1-based index *start*.

    Implemented by unnesting the array WITH OFFSET, filtering the 0-based
    offsets to the requested window, and re-aggregating with ``ARRAY(SELECT …)``.
    """
    lit = get_func_from_session("lit")

    start_col = start if isinstance(start, Column) else lit(start)
    length_col = length if isinstance(length, Column) else lit(length)

    subquery = (
        sqlglot_expression.select(
            sqlglot_expression.column("x"),
        )
        .from_(
            sqlglot_expression.Unnest(
                expressions=[Column.ensure_col(x).expression],
                alias=sqlglot_expression.TableAlias(
                    columns=[sqlglot_expression.to_identifier("x")],
                ),
                offset=sqlglot_expression.to_identifier("offset"),
            )
        )
        .where(
            sqlglot_expression.Between(
                this=sqlglot_expression.column("offset"),
                # 1-based start maps to 0-based offset start-1; a window of
                # `length` elements therefore ends at offset start+length-2
                # (BETWEEN is inclusive). The previous high bound of
                # start+length returned two extra elements.
                low=(start_col - lit(1)).expression,
                high=(start_col + length_col - lit(2)).expression,
            )
        )
        # ARRAY(SELECT …) does not guarantee element order without an
        # explicit sort; order by offset to preserve the input ordering.
        .order_by("offset")
    )

    return Column(
        sqlglot_expression.Anonymous(
            this="ARRAY",
            expressions=[subquery],
        )
    )
|
|
274
|
+
|
|
275
|
+
|
|
276
|
+
def array_position(col: ColumnOrName, value: ColumnOrLiteral) -> Column:
    """1-based position of the first occurrence of *value* in array *col*.

    Joins the array into a comma-separated string and delegates to
    ``bqutil.fn.find_in_set``; COALESCE maps a NULL result to 0.
    """
    lit = get_func_from_session("lit")

    value_col = value if isinstance(value, Column) else lit(value)

    # NOTE(review): elements are joined with "," before searching, so elements
    # that themselves contain commas would split into multiple entries —
    # confirm this approximation is acceptable for expected inputs.
    return Column(
        sqlglot_expression.Coalesce(
            this=sqlglot_expression.Anonymous(
                this="bqutil.fn.find_in_set",
                expressions=[
                    value_col.expression,
                    sqlglot_expression.Anonymous(
                        this="ARRAY_TO_STRING",
                        expressions=[Column.ensure_col(col).expression, lit(",").expression],
                    ),
                ],
            ),
            expressions=[lit(0).expression],
        )
    )
|
|
296
|
+
|
|
297
|
+
|
|
298
|
+
def array_remove(col: ColumnOrName, value: ColumnOrLiteral) -> Column:
    """Return array *col* with every element equal to *value* removed.

    Unnests the array, filters out matching elements with a ``!=`` predicate,
    and re-aggregates the survivors with ``ARRAY_AGG``.
    """
    lit = get_func_from_session("lit")

    target = value if isinstance(value, Column) else lit(value)

    unnested = sqlglot_expression.select(
        "*",
    ).from_(
        sqlglot_expression.Unnest(
            expressions=[Column.ensure_col(col).expression],
            alias=sqlglot_expression.TableAlias(
                columns=[sqlglot_expression.to_identifier("x")],
            ),
        )
    )

    kept = (
        sqlglot_expression.select(
            sqlglot_expression.Anonymous(
                this="ARRAY_AGG",
                expressions=[sqlglot_expression.column("x")],
            ),
        )
        .from_(unnested.subquery("t"))
        .where(
            sqlglot_expression.NEQ(
                this=sqlglot_expression.column("x", "t"),
                expression=target.expression,
            )
        )
    )

    return Column(kept.subquery())
|
|
331
|
+
|
|
332
|
+
|
|
333
|
+
def array_distinct(col: ColumnOrName) -> Column:
    """Array *col* with duplicate elements removed, via ``bqutil.fn.cw_array_distinct``."""
    target = Column.ensure_col(col).expression
    return Column(
        sqlglot_expression.Anonymous(this="bqutil.fn.cw_array_distinct", expressions=[target])
    )
|
|
340
|
+
|
|
341
|
+
|
|
342
|
+
def array_min(col: ColumnOrName) -> Column:
    """Minimum element of array *col*, via ``bqutil.fn.cw_array_min``."""
    target = Column.ensure_col(col).expression
    return Column(
        sqlglot_expression.Anonymous(this="bqutil.fn.cw_array_min", expressions=[target])
    )
|
|
349
|
+
|
|
350
|
+
|
|
351
|
+
def array_max(col: ColumnOrName) -> Column:
    """Maximum element of array *col*, via ``bqutil.fn.cw_array_max``."""
    target = Column.ensure_col(col).expression
    return Column(
        sqlglot_expression.Anonymous(this="bqutil.fn.cw_array_max", expressions=[target])
    )
|
|
358
|
+
|
|
359
|
+
|
|
360
|
+
def sort_array(col: ColumnOrName, asc: t.Optional[bool] = None) -> Column:
    """Sort array *col*; ascending when *asc* is True or None, else descending.

    Unnests the array, orders the elements, and re-aggregates them with
    ``ARRAY(SELECT …)``.
    """
    direction = "ASC" if asc or asc is None else "DESC"

    unnest = sqlglot_expression.Unnest(
        expressions=[Column.ensure_col(col).expression],
        alias=sqlglot_expression.TableAlias(
            columns=[sqlglot_expression.to_identifier("x")],
        ),
    )
    ordered = sqlglot_expression.select("x").from_(unnest).order_by(f"x {direction}")

    return Column(sqlglot_expression.Anonymous(this="ARRAY", expressions=[ordered]))


# Spark exposes the same operation under both names.
array_sort = sort_array
|
|
@@ -0,0 +1,14 @@
|
|
|
1
|
+
# This code is based on code from Apache Spark under the license found in the LICENSE file located in the 'sqlframe' folder.
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import typing as t
|
|
6
|
+
|
|
7
|
+
from sqlframe.base.group import _BaseGroupedData
|
|
8
|
+
|
|
9
|
+
if t.TYPE_CHECKING:
|
|
10
|
+
from sqlframe.bigquery.dataframe import BigQueryDataFrame
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class BigQueryGroupedData(_BaseGroupedData["BigQueryDataFrame"]):
    """Grouped-data handle for BigQuery DataFrames.

    All aggregation behavior is inherited unchanged from ``_BaseGroupedData``.
    """

    pass
|
|
@@ -0,0 +1,29 @@
|
|
|
1
|
+
# This code is based on code from Apache Spark under the license found in the LICENSE file located in the 'sqlframe' folder.
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import typing as t
|
|
6
|
+
|
|
7
|
+
from sqlframe.base.mixins.readwriter_mixins import PandasLoaderMixin, PandasWriterMixin
|
|
8
|
+
from sqlframe.base.readerwriter import (
|
|
9
|
+
_BaseDataFrameReader,
|
|
10
|
+
_BaseDataFrameWriter,
|
|
11
|
+
)
|
|
12
|
+
|
|
13
|
+
if t.TYPE_CHECKING:
|
|
14
|
+
from sqlframe.bigquery.session import BigQuerySession # noqa
|
|
15
|
+
from sqlframe.bigquery.dataframe import BigQueryDataFrame # noqa
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class BigQueryDataFrameReader(
    PandasLoaderMixin["BigQuerySession", "BigQueryDataFrame"],
    _BaseDataFrameReader["BigQuerySession", "BigQueryDataFrame"],
):
    """DataFrame reader for BigQuery.

    File loading is provided by ``PandasLoaderMixin``; everything else comes
    from ``_BaseDataFrameReader``.
    """

    pass
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class BigQueryDataFrameWriter(
    PandasWriterMixin["BigQuerySession", "BigQueryDataFrame"],
    _BaseDataFrameWriter["BigQuerySession", "BigQueryDataFrame"],
):
    """DataFrame writer for BigQuery.

    File writing is provided by ``PandasWriterMixin``; everything else comes
    from ``_BaseDataFrameWriter``.
    """

    pass
|
|
@@ -0,0 +1,89 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import typing as t
|
|
4
|
+
|
|
5
|
+
from sqlframe.base.session import _BaseSession
|
|
6
|
+
from sqlframe.bigquery.catalog import BigQueryCatalog
|
|
7
|
+
from sqlframe.bigquery.dataframe import BigQueryDataFrame
|
|
8
|
+
from sqlframe.bigquery.readwriter import (
|
|
9
|
+
BigQueryDataFrameReader,
|
|
10
|
+
BigQueryDataFrameWriter,
|
|
11
|
+
)
|
|
12
|
+
|
|
13
|
+
if t.TYPE_CHECKING:
|
|
14
|
+
from google.cloud.bigquery.client import Client as BigQueryClient
|
|
15
|
+
from google.cloud.bigquery.dbapi.connection import Connection as BigQueryConnection
|
|
16
|
+
else:
|
|
17
|
+
BigQueryClient = t.Any
|
|
18
|
+
BigQueryConnection = t.Any
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class BigQuerySession(
    _BaseSession[  # type: ignore
        BigQueryCatalog,
        BigQueryDataFrameReader,
        BigQueryDataFrameWriter,
        BigQueryDataFrame,
        BigQueryConnection,
    ],
):
    """Session for running sqlframe operations against BigQuery.

    Wraps a ``google.cloud.bigquery`` DB-API connection; when no connection is
    supplied, a default one is created with ``connect()``.
    """

    _catalog = BigQueryCatalog
    _reader = BigQueryDataFrameReader
    _writer = BigQueryDataFrameWriter
    _df = BigQueryDataFrame

    # Format used when rendering/parsing timestamps in generated SQL.
    DEFAULT_TIME_FORMAT = "%Y-%m-%d %H:%M:%S"
    QUALIFY_INFO_SCHEMA_WITH_DATABASE = True
    SANITIZE_COLUMN_NAMES = True

    def __init__(
        self, conn: t.Optional[BigQueryConnection] = None, default_dataset: t.Optional[str] = None
    ):
        # Imported lazily so the google-cloud-bigquery dependency is only
        # required when a BigQuery session is actually constructed.
        from google.cloud import bigquery
        from google.cloud.bigquery.dbapi import connect

        # Guard against re-initialization (the base session appears to be
        # reused once `_conn` is set).
        if not hasattr(self, "_conn"):
            super().__init__(conn or connect())
            # Ensure a job config exists so default_dataset can be set on it.
            if self._client.default_query_job_config is None:
                self._client.default_query_job_config = bigquery.QueryJobConfig()
            self.default_dataset = default_dataset

    @property
    def _client(self) -> BigQueryClient:
        # The DB-API connection carries the underlying BigQuery client.
        assert self._connection
        return self._connection._client

    @property
    def default_dataset(self) -> t.Optional[str]:
        """Dataset used to qualify unqualified table references."""
        return self._default_dataset

    @default_dataset.setter
    def default_dataset(self, dataset: str) -> None:
        # Mirror the value onto the client's job config so every query
        # issued through the client picks it up.
        self._default_dataset = dataset
        self._client.default_query_job_config.default_dataset = dataset  # type: ignore

    @property
    def default_project(self) -> str:
        """GCP project the underlying client bills/queries against."""
        return self._client.project

    @default_project.setter
    def default_project(self, project: str) -> None:
        self._client.project = project

    @classmethod
    def _try_get_map(cls, value: t.Any) -> t.Optional[t.Dict[str, t.Any]]:
        # BigQuery results are never interpreted as map values here.
        return None

    class Builder(_BaseSession.Builder):
        """Fluent builder producing BigQuerySession instances."""

        DEFAULT_INPUT_DIALECT = "bigquery"
        DEFAULT_OUTPUT_DIALECT = "bigquery"

        @property
        def session(self) -> BigQuerySession:
            return BigQuerySession(**self._session_kwargs)

        def getOrCreate(self) -> BigQuerySession:
            self._set_session_properties()
            return self.session

    builder = Builder()
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
from sqlframe.base.types import *
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
from sqlframe.base.window import *
|
|
@@ -0,0 +1,20 @@
|
|
|
1
|
+
from sqlframe.duckdb.catalog import DuckDBCatalog
|
|
2
|
+
from sqlframe.duckdb.column import DuckDBColumn
|
|
3
|
+
from sqlframe.duckdb.dataframe import DuckDBDataFrame, DuckDBDataFrameNaFunctions
|
|
4
|
+
from sqlframe.duckdb.group import DuckDBGroupedData
|
|
5
|
+
from sqlframe.duckdb.readwriter import DuckDBDataFrameReader, DuckDBDataFrameWriter
|
|
6
|
+
from sqlframe.duckdb.session import DuckDBSession
|
|
7
|
+
from sqlframe.duckdb.window import Window, WindowSpec
|
|
8
|
+
|
|
9
|
+
__all__ = [
|
|
10
|
+
"DuckDBCatalog",
|
|
11
|
+
"DuckDBColumn",
|
|
12
|
+
"DuckDBDataFrame",
|
|
13
|
+
"DuckDBDataFrameNaFunctions",
|
|
14
|
+
"DuckDBGroupedData",
|
|
15
|
+
"DuckDBDataFrameReader",
|
|
16
|
+
"DuckDBDataFrameWriter",
|
|
17
|
+
"DuckDBSession",
|
|
18
|
+
"Window",
|
|
19
|
+
"WindowSpec",
|
|
20
|
+
]
|
|
@@ -0,0 +1,108 @@
|
|
|
1
|
+
# This code is based on code from Apache Spark under the license found in the LICENSE file located in the 'sqlframe' folder.
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import fnmatch
|
|
6
|
+
import typing as t
|
|
7
|
+
|
|
8
|
+
from sqlglot import exp
|
|
9
|
+
|
|
10
|
+
from sqlframe.base.catalog import Function, _BaseCatalog
|
|
11
|
+
from sqlframe.base.mixins.catalog_mixins import (
|
|
12
|
+
GetCurrentCatalogFromFunctionMixin,
|
|
13
|
+
GetCurrentDatabaseFromFunctionMixin,
|
|
14
|
+
ListCatalogsFromInfoSchemaMixin,
|
|
15
|
+
ListColumnsFromInfoSchemaMixin,
|
|
16
|
+
ListDatabasesFromInfoSchemaMixin,
|
|
17
|
+
ListTablesFromInfoSchemaMixin,
|
|
18
|
+
SetCurrentCatalogFromUseMixin,
|
|
19
|
+
SetCurrentDatabaseFromUseMixin,
|
|
20
|
+
)
|
|
21
|
+
from sqlframe.base.util import schema_, to_schema
|
|
22
|
+
|
|
23
|
+
if t.TYPE_CHECKING:
|
|
24
|
+
from sqlframe.duckdb.session import DuckDBSession # noqa
|
|
25
|
+
from sqlframe.duckdb.dataframe import DuckDBDataFrame # noqa
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
class DuckDBCatalog(
    GetCurrentCatalogFromFunctionMixin["DuckDBSession", "DuckDBDataFrame"],
    SetCurrentCatalogFromUseMixin["DuckDBSession", "DuckDBDataFrame"],
    GetCurrentDatabaseFromFunctionMixin["DuckDBSession", "DuckDBDataFrame"],
    ListDatabasesFromInfoSchemaMixin["DuckDBSession", "DuckDBDataFrame"],
    ListCatalogsFromInfoSchemaMixin["DuckDBSession", "DuckDBDataFrame"],
    SetCurrentDatabaseFromUseMixin["DuckDBSession", "DuckDBDataFrame"],
    ListTablesFromInfoSchemaMixin["DuckDBSession", "DuckDBDataFrame"],
    ListColumnsFromInfoSchemaMixin["DuckDBSession", "DuckDBDataFrame"],
    _BaseCatalog["DuckDBSession", "DuckDBDataFrame"],
):
    """Catalog implementation for DuckDB, assembled from info-schema mixins."""

    def listFunctions(
        self, dbName: t.Optional[str] = None, pattern: t.Optional[str] = None
    ) -> t.List[Function]:
        """
        Returns a list of functions registered in the specified database.

        .. versionadded:: 3.4.0

        Parameters
        ----------
        dbName : str
            name of the database to list the functions.
            ``dbName`` can be qualified with catalog name.
        pattern : str
            The pattern that the function name needs to match.

            .. versionchanged: 3.5.0
                Adds ``pattern`` argument.

        Returns
        -------
        list
            A list of :class:`Function`.

        Notes
        -----
        If no database is specified, the current database and catalog
        are used. This API includes all temporary functions.

        Examples
        --------
        >>> spark.catalog.listFunctions()
        [Function(name=...

        >>> spark.catalog.listFunctions(pattern="to_*")
        [Function(name=...

        >>> spark.catalog.listFunctions(pattern="*not_existing_func*")
        []
        """
        if not dbName:
            # Default to the session's current catalog and database.
            schema = schema_(
                db=exp.parse_identifier(self.currentDatabase(), dialect=self.session.input_dialect),
                catalog=exp.parse_identifier(
                    self.currentCatalog(), dialect=self.session.input_dialect
                ),
            )
        else:
            schema = to_schema(dbName, dialect=self.session.input_dialect)
        # duckdb_functions() is DuckDB's built-in table function that lists
        # all registered functions.
        select = (
            exp.select("function_name", "schema_name", "database_name")
            .from_("duckdb_functions()")
            .where(exp.column("schema_name").eq(schema.db))
        )
        if schema.catalog:
            select = select.where(exp.column("database_name").eq(schema.catalog))
        functions = self.session._fetch_rows(select)
        if pattern:
            # Glob-style matching on the function name (fnmatch semantics).
            functions = [x for x in functions if fnmatch.fnmatch(x["function_name"], pattern)]
        return [
            Function(
                name=x["function_name"],
                catalog=x["database_name"],
                namespace=[x["schema_name"]],
                description=None,
                className="",
                isTemporary=False,
            )
            for x in functions
        ]
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
from sqlframe.base.column import Column as DuckDBColumn
|
|
@@ -0,0 +1,55 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import logging
|
|
4
|
+
import sys
|
|
5
|
+
import typing as t
|
|
6
|
+
|
|
7
|
+
from sqlframe.base.dataframe import (
|
|
8
|
+
_BaseDataFrame,
|
|
9
|
+
_BaseDataFrameNaFunctions,
|
|
10
|
+
_BaseDataFrameStatFunctions,
|
|
11
|
+
)
|
|
12
|
+
from sqlframe.duckdb.group import DuckDBGroupedData
|
|
13
|
+
|
|
14
|
+
if sys.version_info >= (3, 11):
|
|
15
|
+
from typing import Self
|
|
16
|
+
else:
|
|
17
|
+
from typing_extensions import Self
|
|
18
|
+
|
|
19
|
+
if t.TYPE_CHECKING:
|
|
20
|
+
from sqlframe.duckdb.session import DuckDBSession # noqa
|
|
21
|
+
from sqlframe.duckdb.readwriter import DuckDBDataFrameWriter # noqa
|
|
22
|
+
from sqlframe.duckdb.group import DuckDBGroupedData # noqa
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
logger = logging.getLogger(__name__)
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
class DuckDBDataFrameNaFunctions(_BaseDataFrameNaFunctions["DuckDBDataFrame"]):
    """Null-handling helpers for DuckDB DataFrames; behavior inherited from the base class."""

    pass
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
class DuckDBDataFrameStatFunctions(_BaseDataFrameStatFunctions["DuckDBDataFrame"]):
    """Statistics helpers for DuckDB DataFrames; behavior inherited from the base class."""

    pass
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
class DuckDBDataFrame(
    _BaseDataFrame[
        "DuckDBSession",
        "DuckDBDataFrameWriter",
        "DuckDBDataFrameNaFunctions",
        "DuckDBDataFrameStatFunctions",
        "DuckDBGroupedData",
    ]
):
    """DataFrame implementation for DuckDB."""

    _na = DuckDBDataFrameNaFunctions
    _stat = DuckDBDataFrameStatFunctions
    _group_data = DuckDBGroupedData

    def cache(self) -> Self:
        """No-op: DuckDB has no caching support; logs a warning and returns self."""
        logger.warning("DuckDB does not support caching. Ignoring cache() call.")
        return self

    def persist(self) -> Self:
        """No-op: DuckDB has no persist support; logs a warning and returns self."""
        logger.warning("DuckDB does not support persist. Ignoring persist() call.")
        return self