fakesnow 0.9.35__py3-none-any.whl → 0.9.37__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,1335 @@
1
+ from __future__ import annotations
2
+
3
+ from pathlib import Path
4
+ from string import Template
5
+ from typing import ClassVar, cast
6
+
7
+ import sqlglot
8
+ from sqlglot import exp
9
+
10
+ from fakesnow.variables import Variables
11
+
12
# Parsed placeholder statement that transforms in this module return (or copy) when a statement is turned into a no-op.
SUCCESS_NOP = sqlglot.parse_one("SELECT 'Statement executed successfully.' as status")
13
+
14
+
15
def alias_in_join(expression: exp.Expression) -> exp.Expression:
    """Resolve SELECT-list aliases used on the left of a JOIN ... ON condition.

    For a SELECT that defines aliases and has joins, every join whose ON condition has an
    unqualified column on its left-hand side that matches one of those aliases gets that
    column's ``this`` replaced, in place, with the expression the alias stands for.

    Args:
        expression (exp.Expression): the expression that will be transformed.

    Returns:
        exp.Expression: The same expression object, possibly mutated.
    """
    if not isinstance(expression, exp.Select):
        return expression

    aliased = {item.args.get("alias"): item for item in expression.expressions if isinstance(item, exp.Alias)}
    if not aliased:
        return expression

    for join in expression.args.get("joins") or []:
        condition = join.args.get("on")
        if not condition:
            continue

        lhs = condition.this
        if not lhs or not isinstance(lhs, exp.Column):
            continue

        target = aliased.get(lhs.this)
        # don't rewrite col with table identifier
        if target and not lhs.table:
            lhs.args["this"] = target.this

    return expression
34
+
35
+
36
def alter_table_strip_cluster_by(expression: exp.Expression) -> exp.Expression:
    """Turn alter table cluster by into a no-op.

    Args:
        expression (exp.Expression): the expression that will be transformed.

    Returns:
        exp.Expression: A fresh copy of SUCCESS_NOP when the statement is an ALTER whose only
            action is a CLUSTER BY, otherwise the expression unchanged.
    """
    if (
        isinstance(expression, exp.Alter)
        and (actions := expression.args.get("actions"))
        and len(actions) == 1
        and (isinstance(actions[0], exp.Cluster))
    ):
        # Return a copy, like the other transforms in this module, so that anything that later
        # mutates the returned tree (eg: attaching extra args) can't modify the shared constant.
        return SUCCESS_NOP.copy()
    return expression
46
+
47
+
48
def array_size(expression: exp.Expression) -> exp.Expression:
    """Rewrite ARRAY_SIZE as a CASE that yields json_array_length only for JSON arrays.

    The generated CASE has no ELSE branch, so anything that is not a JSON array gives NULL.
    """
    if not isinstance(expression, exp.ArraySize):
        return expression

    subject = expression.this
    length = exp.Anonymous(this="json_array_length", expressions=[subject])
    # return null if not json array
    type_is_array = exp.EQ(
        this=exp.Anonymous(this="json_type", expressions=[subject]),
        expression=exp.Literal(this="ARRAY", is_string=True),
    )
    return exp.Case(ifs=[exp.If(this=type_is_array, true=length)])
59
+
60
+
61
def array_agg(expression: exp.Expression) -> exp.Expression:
    """Wrap ARRAY_AGG in TO_JSON.

    A plain ARRAY_AGG is wrapped directly. When ARRAY_AGG is the function of a window
    expression, the whole window is wrapped instead and the inner ARRAY_AGG is left alone.
    """
    if isinstance(expression, exp.ArrayAgg):
        wrap = not isinstance(expression.parent, exp.Window)
    else:
        wrap = isinstance(expression, exp.Window) and isinstance(expression.this, exp.ArrayAgg)

    if wrap:
        return exp.Anonymous(this="TO_JSON", expressions=[expression])
    return expression
69
+
70
+
71
def array_agg_within_group(expression: exp.Expression) -> exp.Expression:
    """Convert ARRAY_AGG(<expr>) WITHIN GROUP (<order-by-clause>) to ARRAY_AGG( <expr> <order-by-clause> )

    Snowflake orders the aggregated array with a trailing WITHIN GROUP (ORDER BY ...) clause,
    whereas DuckDB takes the ORDER BY inside the aggregate's argument list.
    See;
        - https://docs.snowflake.com/en/sql-reference/functions/array_agg
        - https://duckdb.org/docs/sql/aggregates.html#order-by-clause-in-aggregate-functions
    Note; Snowflake has following restriction;
        If you specify DISTINCT and WITHIN GROUP, both must refer to the same column.
    Transformation does not handle this restriction.
    """
    if not isinstance(expression, exp.WithinGroup):
        return expression

    agg = expression.find(exp.ArrayAgg)
    ordering = expression.expression
    if not agg or not ordering:
        return expression

    ordered_argument = exp.Order(this=agg.this, expressions=ordering.expressions)
    return exp.ArrayAgg(this=ordered_argument)
95
+
96
+
97
def create_clone(expression: exp.Expression) -> exp.Expression:
    """Transform create table clone to create table as select.

    A CREATE TABLE that contains a CLONE clause becomes CREATE TABLE <target> AS SELECT * FROM <source>.
    """
    if not isinstance(expression, exp.Create):
        return expression
    if str(expression.args.get("kind")).upper() != "TABLE":
        return expression

    clone = expression.find(exp.Clone)
    if not clone:
        return expression

    select_all = exp.Select(expressions=[exp.Star()], **{"from": exp.From(this=clone.this)})
    return exp.Create(this=expression.this, kind="TABLE", expression=select_all)
116
+
117
+
118
+ # TODO: move this into a Dialect as a transpilation
119
def create_database(expression: exp.Expression, db_path: Path | None = None) -> exp.Expression:
    """Transform create database to attach database.

    Example:
        >>> import sqlglot
        >>> sqlglot.parse_one("CREATE database foo").transform(create_database).sql()
        "ATTACH DATABASE ':memory:' AS foo"
    Args:
        expression (exp.Expression): the expression that will be transformed.
        db_path (Path | None): directory in which ``<db_name>.db`` is attached. When not
            given, an in-memory database is attached.

    Returns:
        exp.Expression: The transformed expression, with the database name stored in the create_db_name arg.
    """

    if isinstance(expression, exp.Create) and str(expression.args.get("kind")).upper() == "DATABASE":
        ident = expression.find(exp.Identifier)
        # call sql() so the message shows the statement rather than a bound-method repr
        assert ident, f"No identifier in {expression.sql()}"
        db_name = ident.this
        db_file = f"{db_path / db_name}.db" if db_path else ":memory:"

        if_not_exists = "IF NOT EXISTS " if expression.args.get("exists") else ""

        return exp.Command(
            this="ATTACH",
            expression=exp.Literal(this=f"{if_not_exists}DATABASE '{db_file}' AS {db_name}", is_string=True),
            create_db_name=db_name,
        )

    return expression
148
+
149
+
150
# Query template used by describe_table for tables and views outside _fs_information_schema.
# ${catalog}, ${schema} and ${table} are substituted textually into the WHERE clause, and the
# result is parsed with the duckdb dialect.
SQL_DESCRIBE_TABLE = Template(
    """
SELECT
column_name AS "name",
CASE WHEN data_type = 'NUMBER' THEN 'NUMBER(' || numeric_precision || ',' || numeric_scale || ')'
WHEN data_type = 'TEXT' THEN 'VARCHAR(' || coalesce(character_maximum_length,16777216) || ')'
WHEN data_type = 'TIMESTAMP_NTZ' THEN 'TIMESTAMP_NTZ(9)'
WHEN data_type = 'TIMESTAMP_TZ' THEN 'TIMESTAMP_TZ(9)'
WHEN data_type = 'TIME' THEN 'TIME(9)'
WHEN data_type = 'BINARY' THEN 'BINARY(8388608)'
ELSE data_type END AS "type",
'COLUMN' AS "kind",
CASE WHEN is_nullable = 'YES' THEN 'Y' ELSE 'N' END AS "null?",
column_default AS "default",
'N' AS "primary key",
'N' AS "unique key",
NULL::VARCHAR AS "check",
NULL::VARCHAR AS "expression",
NULL::VARCHAR AS "comment",
NULL::VARCHAR AS "policy name",
NULL::JSON AS "privacy domain",
FROM _fs_information_schema._fs_columns
WHERE table_catalog = '${catalog}' AND table_schema = '${schema}' AND table_name = '${table}'
ORDER BY ordinal_position
"""
)
176
+
177
# Query template used by describe_table when the target lives in _fs_information_schema.
# ${view} is substituted textually with "<schema>.<name>" and described via a DESCRIBE subquery;
# the result is parsed with the duckdb dialect.
SQL_DESCRIBE_INFO_SCHEMA = Template(
    """
SELECT
column_name AS "name",
column_type as "type",
'COLUMN' AS "kind",
CASE WHEN "null" = 'YES' THEN 'Y' ELSE 'N' END AS "null?",
NULL::VARCHAR AS "default",
'N' AS "primary key",
'N' AS "unique key",
NULL::VARCHAR AS "check",
NULL::VARCHAR AS "expression",
NULL::VARCHAR AS "comment",
NULL::VARCHAR AS "policy name",
NULL::JSON AS "privacy domain",
FROM (DESCRIBE ${view})
"""
)
195
+
196
+
197
def describe_table(
    expression: exp.Expression, current_database: str | None = None, current_schema: str | None = None
) -> exp.Expression:
    """Redirect to the information_schema._fs_columns to match snowflake.

    Applies to DESCRIBE statements of kind TABLE or VIEW. The table's own catalog and schema
    are used when present, falling back to current_database and current_schema.

    See https://docs.snowflake.com/en/sql-reference/sql/desc-table
    """
    if not isinstance(expression, exp.Describe):
        return expression

    kind = expression.args.get("kind")
    if not kind or not isinstance(kind, str) or kind.upper() not in ("TABLE", "VIEW"):
        return expression

    table = expression.find(exp.Table)
    if not table:
        return expression

    catalog = table.catalog or current_database
    schema = table.db or current_schema

    if schema and schema.upper() == "_FS_INFORMATION_SCHEMA":
        # describing an information_schema view
        # (schema already transformed from information_schema -> _fs_information_schema)
        query = SQL_DESCRIBE_INFO_SCHEMA.substitute(view=f"{schema}.{table.name}")
    else:
        query = SQL_DESCRIBE_TABLE.substitute(catalog=catalog, schema=schema, table=table.name)

    return sqlglot.parse_one(query, read="duckdb")
226
+
227
+
228
def drop_schema_cascade(expression: exp.Expression) -> exp.Expression:
    """Drop schema cascade.

    By default duckdb won't delete a schema if it contains tables, whereas snowflake will.
    So we add the cascade keyword to mimic snowflake's behaviour.

    Example:
        >>> import sqlglot
        >>> sqlglot.parse_one("DROP SCHEMA schema1").transform(drop_schema_cascade).sql()
        'DROP SCHEMA schema1 CASCADE'
    Args:
        expression (exp.Expression): the expression that will be transformed.

    Returns:
        exp.Expression: A copy with cascade enabled for DROP SCHEMA, otherwise the expression unchanged.
    """
    kind = expression.args.get("kind") if isinstance(expression, exp.Drop) else None

    if isinstance(kind, str) and kind.upper() == "SCHEMA":
        cascaded = expression.copy()
        cascaded.args["cascade"] = True
        return cascaded

    return expression
256
+
257
+
258
def dateadd_date_cast(expression: exp.Expression) -> exp.Expression:
    """Cast result of DATEADD to DATE if the given expression is a cast to DATE
    and unit is either DAY, WEEK, MONTH or YEAR to mimic Snowflake's DATEADD
    behaviour.

    Snowflake;
        SELECT DATEADD(DAY, 3, '2023-03-03'::DATE) as D;
        D: 2023-03-06 (DATE)
    DuckDB;
        SELECT CAST('2023-03-03' AS DATE) + INTERVAL 3 DAY AS D
        D: 2023-03-06 00:00:00 (TIMESTAMP)
    """
    if not isinstance(expression, exp.DateAdd):
        return expression

    unit_node = expression.unit
    if unit_node is None or not isinstance(unit_node.this, str):
        return expression

    unit = unit_node.this.upper()
    # an empty unit name is not rejected here; only a non-empty, unsupported unit is
    if unit and unit not in {"DAY", "WEEK", "MONTH", "YEAR"}:
        return expression

    operand = expression.this
    if not isinstance(operand, exp.Cast) or operand.to.this != exp.DataType.Type.DATE:
        return expression

    date_type = exp.DataType(this=exp.DataType.Type.DATE, nested=False, prefix=False)
    return exp.Cast(this=expression, to=date_type)
293
+
294
+
295
def dateadd_string_literal_timestamp_cast(expression: exp.Expression) -> exp.Expression:
    """Cast a string literal operand of DATEADD to TIMESTAMP.

    Snowflake's DATEADD implicitly casts string literals to timestamps regardless of unit.
    Returns a copy of the DATEADD with the cast applied; other expressions are returned as-is.
    """
    if not isinstance(expression, exp.DateAdd):
        return expression

    operand = expression.this
    if not (isinstance(operand, exp.Literal) and operand.is_string):
        return expression

    rewritten = expression.copy()
    # TODO: support TIMESTAMP_TYPE_MAPPING of TIMESTAMP_LTZ/TZ
    timestamp_type = exp.DataType(this=exp.DataType.Type.TIMESTAMP, nested=False, prefix=False)
    rewritten.set("this", exp.Cast(this=expression.this, to=timestamp_type))
    return rewritten
316
+
317
+
318
def datediff_string_literal_timestamp_cast(expression: exp.Expression) -> exp.Expression:
    """Cast string literal operands of DATEDIFF to TIMESTAMP.

    Snowflake's DATEDIFF implicitly casts string literals to timestamps regardless of unit.
    Returns a copy of the DATEDIFF whose two operands are copies of the originals, each
    wrapped in a cast when it is a string literal.
    """
    if not isinstance(expression, exp.DateDiff):
        return expression

    operands = (expression.this.copy(), expression.expression.copy())

    first, second = (
        # TODO: support TIMESTAMP_TYPE_MAPPING of TIMESTAMP_LTZ/TZ
        exp.Cast(this=operand, to=exp.DataType(this=exp.DataType.Type.TIMESTAMP, nested=False, prefix=False))
        if isinstance(operand, exp.Literal) and operand.is_string
        else operand
        for operand in operands
    )

    rewritten = expression.copy()
    rewritten.set("this", first)
    rewritten.set("expression", second)
    return rewritten
348
+
349
+
350
def extract_comment_on_columns(expression: exp.Expression) -> exp.Expression:
    """Extract column comments, removing them from the Expression.

    ALTER actions that set a column comment are taken out of the statement and recorded as
    (column name, comment) tuples. If no other actions remain, the statement is replaced by a
    copy of SUCCESS_NOP.

    Args:
        expression (exp.Expression): the expression that will be transformed.

    Returns:
        exp.Expression: The transformed expression, with any comments stored in the new 'col_comments' arg.
    """
    if not isinstance(expression, exp.Alter):
        return expression

    actions = expression.args.get("actions")
    if not actions:
        return expression

    col_comments: list[tuple[str, str]] = []
    remaining: list[exp.Expression] = []
    for action in actions:
        comment = action.args.get("comment") if isinstance(action, exp.AlterColumn) else None
        if comment:
            col_comments.append((action.name, comment.this))
        else:
            remaining.append(action)

    if remaining:
        expression.set("actions", remaining)
    else:
        expression = SUCCESS_NOP.copy()

    expression.args["col_comments"] = col_comments
    return expression
378
+
379
+
380
def extract_comment_on_table(expression: exp.Expression) -> exp.Expression:
    """Extract table comment, removing it from the Expression.

    duckdb doesn't support comments. So we remove them from the expression and store them in the table_comment arg.
    We also transform the expression to a NOP if the statement can't be executed by duckdb.

    Three statement shapes are handled: CREATE with a comment property, COMMENT ON, and
    ALTER ... SET with a comment property.

    Args:
        expression (exp.Expression): the expression that will be transformed.

    Returns:
        exp.Expression: The transformed expression, with any comment stored in the new 'table_comment' arg
            as a (table, comment) tuple.
    """

    if isinstance(expression, exp.Create) and (table := expression.find(exp.Table)):
        comment = None
        if props := cast(exp.Properties, expression.args.get("properties")):
            # split the comment property from the properties that are kept
            other_props = []
            for p in props.expressions:
                if isinstance(p, exp.SchemaCommentProperty) and (isinstance(p.this, (exp.Literal, exp.Var))):
                    comment = p.this.this
                else:
                    other_props.append(p)

            # work on a copy so the input expression keeps its original properties
            new = expression.copy()
            new_props: exp.Properties = new.args["properties"]
            new_props.set("expressions", other_props)
            # comment is None when no comment property was found
            new.args["table_comment"] = (table, comment)
            return new
    elif (
        isinstance(expression, exp.Comment)
        and (cexp := expression.args.get("expression"))
        and (table := expression.find(exp.Table))
    ):
        # COMMENT statement: replaced entirely by the NOP, which carries the comment
        new = SUCCESS_NOP.copy()
        new.args["table_comment"] = (table, cexp.this)
        return new
    elif (
        isinstance(expression, exp.Alter)
        and (sexp := expression.find(exp.AlterSet))
        and (scp := sexp.find(exp.SchemaCommentProperty))
        and isinstance(scp.this, exp.Literal)
        and (table := expression.find(exp.Table))
    ):
        # ALTER ... SET with a literal comment: replaced entirely by the NOP, which carries the comment
        new = SUCCESS_NOP.copy()
        new.args["table_comment"] = (table, scp.this.this)
        return new

    return expression
428
+
429
+
430
def extract_text_length(expression: exp.Expression) -> exp.Expression:
    """Extract length of text columns.

    duckdb doesn't have fixed-sized text types. So we capture the size of text types and store that in the
    text_lengths arg as a list of (column name, size) tuples.

    Args:
        expression (exp.Expression): the expression that will be transformed.

    Returns:
        exp.Expression: The original expression, with any text lengths stored in the new 'text_lengths' arg.
    """

    if isinstance(expression, (exp.Create, exp.Alter)):
        text_lengths = []

        # exp.Select is for a ctas, exp.Schema is a plain definition
        if cols := expression.find(exp.Select, exp.Schema):
            expressions = cols.expressions
        else:
            # alter table
            expressions = expression.args.get("actions") or []
        for e in expressions:
            # only VARCHAR/TEXT data types found anywhere under this column/action are considered
            if dts := [
                dt for dt in e.find_all(exp.DataType) if dt.this in (exp.DataType.Type.VARCHAR, exp.DataType.Type.TEXT)
            ]:
                col_name = e.alias if isinstance(e, exp.Alias) else e.name
                if len(dts) == 1 and (dt_size := dts[0].find(exp.DataTypeParam)):
                    # NOTE: this `and` chain evaluates to False (not an int) when the param is not a
                    # literal with a string value
                    size = (
                        isinstance(dt_size.this, exp.Literal)
                        and isinstance(dt_size.this.this, str)
                        and int(dt_size.this.this)
                    )
                else:
                    # size used when there is no explicit size, or more than one text type was found
                    size = 16777216
                text_lengths.append((col_name, size))

        if text_lengths:
            expression.args["text_lengths"] = text_lengths

    return expression
471
+
472
+
473
def flatten(expression: exp.Expression) -> exp.Expression:
    """Flatten an array.

    Rewrites a lateral / table-from-rows wrapper around an explode into a table over the
    ``_fs_flatten`` function, keeping any alias. A keyword argument (eg: ``input => x``) is
    unwrapped to its value.

    See https://docs.snowflake.com/en/sql-reference/functions/flatten

    TODO: support objects.
    """
    if not isinstance(expression, (exp.Lateral, exp.TableFromRows)):
        return expression

    explode = expression.this
    if not isinstance(explode, exp.Explode):
        return expression

    argument = explode.this
    source = argument.expression if isinstance(argument, exp.Kwarg) else argument
    alias = expression.args.get("alias")
    return exp.Table(this=exp.Anonymous(this="_fs_flatten", expressions=[source]), alias=alias)
488
+
489
+
490
def flatten_value_cast_as_varchar(expression: exp.Expression) -> exp.Expression:
    """Return raw unquoted string when flatten VALUE is cast to varchar.

    Applies to a cast of a column named VALUE to VARCHAR/TEXT, inside a SELECT that contains
    an explode.

    Returns a raw string using the Duckdb ->> operator, aka the json_extract_string function, see
    https://duckdb.org/docs/extensions/json#json-extraction-functions
    """
    if not isinstance(expression, exp.Cast):
        return expression

    column = expression.this
    if not isinstance(column, exp.Column) or column.name.upper() != "VALUE":
        return expression

    if expression.to.this not in [exp.DataType.Type.VARCHAR, exp.DataType.Type.TEXT]:
        return expression

    select = expression.find_ancestor(exp.Select)
    if not select or not select.find(exp.Explode):
        return expression

    root_path = exp.JSONPath(expressions=[exp.JSONPathRoot()])
    return exp.JSONExtractScalar(this=expression.this, expression=root_path)
507
+
508
+
509
def float_to_double(expression: exp.Expression) -> exp.Expression:
    """Convert float to double for 64 bit precision.

    The FLOAT data type is switched to DOUBLE in place.

    Snowflake floats are all 64 bit (ie: double)
    see https://docs.snowflake.com/en/sql-reference/data-types-numeric#float-float4-float8
    """
    if not isinstance(expression, exp.DataType):
        return expression

    if expression.this == exp.DataType.Type.FLOAT:
        expression.args["this"] = exp.DataType.Type.DOUBLE

    return expression
520
+
521
+
522
def identifier(expression: exp.Expression) -> exp.Expression:
    """Convert identifier function to an identifier.

    An anonymous IDENTIFIER(...) call becomes an unquoted identifier named after the value
    of its first argument.

    See https://docs.snowflake.com/en/sql-reference/identifier-literal
    """
    if not isinstance(expression, exp.Anonymous):
        return expression

    func_name = expression.this
    if not isinstance(func_name, str) or func_name.upper() != "IDENTIFIER":
        return expression

    first_arg = expression.expressions[0]
    return exp.Identifier(this=first_arg.this, quoted=False)
536
+
537
+
538
def indices_to_json_extract(expression: exp.Expression) -> exp.Expression:
    """Convert indices on objects and arrays to json_extract or json_extract_string

    Supports Snowflake array indices, see
    https://docs.snowflake.com/en/sql-reference/data-types-semistructured#accessing-elements-of-an-array-by-index-or-by-slice
    and object indices, see
    https://docs.snowflake.com/en/sql-reference/data-types-semistructured#accessing-elements-of-an-object-by-key

    Duckdb uses the -> operator, aka the json_extract function, see
    https://duckdb.org/docs/extensions/json#json-extraction-functions

    This works for Snowflake arrays too because we convert them to JSON in duckdb.
    """
    if not isinstance(expression, exp.Bracket) or len(expression.expressions) != 1:
        return expression

    index = expression.expressions[0]
    if not index or not isinstance(index, exp.Literal) or not index.this:
        return expression

    parent = expression.parent
    # If the parent is a cast to varchar, we need to use JSONExtractScalar
    # to get the unquoted string value.
    cast_to_varchar = isinstance(parent, exp.Cast) and parent.to.this == exp.DataType.Type.VARCHAR
    extractor = exp.JSONExtractScalar if cast_to_varchar else exp.JSONExtract

    # string keys address object members, anything else addresses an array position
    json_path = f"$.{index.this}" if index.is_string else f"$[{index.this}]"
    return extractor(this=expression.this, expression=exp.Literal(this=json_path, is_string=True))
570
+
571
+
572
def information_schema_fs_columns(expression: exp.Expression) -> exp.Expression:
    """Redirect to the _FS_COLUMNS view which has metadata that matches snowflake.

    Because duckdb doesn't store character_maximum_length or character_octet_length.
    The table reference is rewritten in place.
    """
    if not isinstance(expression, exp.Table):
        return expression

    schema_name = expression.db
    table_name = expression.name
    targets_columns = (
        schema_name
        and schema_name.upper() == "INFORMATION_SCHEMA"
        and table_name
        and table_name.upper() == "COLUMNS"
    )

    if targets_columns:
        expression.set("this", exp.Identifier(this="_FS_COLUMNS", quoted=False))
        expression.set("db", exp.Identifier(this="_FS_INFORMATION_SCHEMA", quoted=False))

    return expression
589
+
590
+
591
def information_schema_databases(
    expression: exp.Expression,
    current_schema: str | None = None,
) -> exp.Expression:
    """Redirect the DATABASES table to _FS_INFORMATION_SCHEMA.DATABASES.

    Applies when the table is qualified with INFORMATION_SCHEMA, or when current_schema is
    INFORMATION_SCHEMA. A new table expression is returned in that case.
    """
    if not isinstance(expression, exp.Table):
        return expression

    in_information_schema = expression.db.upper() == "INFORMATION_SCHEMA" or (
        current_schema and current_schema.upper() == "INFORMATION_SCHEMA"
    )
    if not in_information_schema or expression.name.upper() != "DATABASES":
        return expression

    return exp.Table(
        this=exp.Identifier(this="DATABASES", quoted=False),
        db=exp.Identifier(this="_FS_INFORMATION_SCHEMA", quoted=False),
    )
608
+
609
+
610
def information_schema_fs_tables(
    expression: exp.Expression,
) -> exp.Expression:
    """Use _FS_TABLES to access additional metadata columns (eg: comment).

    Only the first table found in a SELECT is inspected; it is rewritten in place when it is
    INFORMATION_SCHEMA.TABLES.
    """
    if not isinstance(expression, exp.Select):
        return expression

    first_table = expression.find(exp.Table)
    if not first_table:
        return expression

    if first_table.db.upper() == "INFORMATION_SCHEMA" and first_table.name.upper() == "TABLES":
        first_table.set("this", exp.Identifier(this="_FS_TABLES", quoted=False))
        first_table.set("db", exp.Identifier(this="_FS_INFORMATION_SCHEMA", quoted=False))

    return expression
625
+
626
+
627
def information_schema_fs_views(expression: exp.Expression) -> exp.Expression:
    """Use _FS_VIEWS to return Snowflake's version instead of duckdb's.

    Only the first table found in a SELECT is inspected; it is rewritten in place when it is
    INFORMATION_SCHEMA.VIEWS.
    """
    if not isinstance(expression, exp.Select):
        return expression

    first_table = expression.find(exp.Table)
    if not first_table:
        return expression

    if first_table.db.upper() == "INFORMATION_SCHEMA" and first_table.name.upper() == "VIEWS":
        first_table.set("this", exp.Identifier(this="_FS_VIEWS", quoted=False))
        first_table.set("db", exp.Identifier(this="_FS_INFORMATION_SCHEMA", quoted=False))

    return expression
640
+
641
+
642
# Data type parameters (38, 0) that integer_precision compares a DECIMAL's parameters against.
NUMBER_38_0 = [
    exp.DataTypeParam(this=exp.Literal(this="38", is_string=False)),
    exp.DataTypeParam(this=exp.Literal(this="0", is_string=False)),
]
646
+
647
+
648
def integer_precision(expression: exp.Expression) -> exp.Expression:
    """Convert integers and number(38,0) to bigint.

    So fetch_all will return int and dataframes will return them with a dtype of int64.

    Args:
        expression (exp.Expression): the expression that will be transformed.

    Returns:
        exp.Expression: A new BIGINT data type when the expression is a DECIMAL with no
            parameters or with parameters (38, 0), or is an INT, SMALLINT or TINYINT data type.
            Otherwise the expression unchanged.
    """
    # Guard both branches: previously operator precedence meant the INT/SMALLINT/TINYINT check
    # was applied to expressions of any type, not just data types.
    if not isinstance(expression, exp.DataType):
        return expression

    is_integral_decimal = expression.this == exp.DataType.Type.DECIMAL and (
        not expression.expressions or expression.expressions == NUMBER_38_0
    )
    is_narrow_integer = expression.this in (
        exp.DataType.Type.INT,
        exp.DataType.Type.SMALLINT,
        exp.DataType.Type.TINYINT,
    )

    if is_integral_decimal or is_narrow_integer:
        return exp.DataType(
            this=exp.DataType.Type.BIGINT,
            nested=False,
            prefix=False,
        )

    return expression
665
+
666
+
667
def json_extract_cased_as_varchar(expression: exp.Expression) -> exp.Expression:
    """Convert json to varchar inside JSONExtract.

    Snowflake case conversion (upper/lower) turns variant into varchar. This
    mimics that behaviour within get_path, by swapping the JSON extraction under an
    UPPER/LOWER for its scalar (string) variant in place.

    TODO: a generic version that works on any variant, not just JSONExtract

    Returns a raw string using the Duckdb ->> operator, aka the json_extract_string function, see
    https://duckdb.org/docs/extensions/json#json-extraction-functions
    """
    if not isinstance(expression, (exp.Upper, exp.Lower)):
        return expression

    extraction = expression.this
    if not extraction or not isinstance(extraction, exp.JSONExtract):
        return expression

    path = extraction.expression
    if path and isinstance(path, exp.JSONPath):
        expression.set("this", exp.JSONExtractScalar(this=extraction.this, expression=path))

    return expression
688
+
689
+
690
def json_extract_cast_as_varchar(expression: exp.Expression) -> exp.Expression:
    """Return raw unquoted string when casting json extraction to varchar.

    The JSON extraction directly under a cast is replaced, in place, by its scalar (string)
    variant. The cast itself is returned.

    Returns a raw string using the Duckdb ->> operator, aka the json_extract_string function, see
    https://duckdb.org/docs/extensions/json#json-extraction-functions
    """
    if not isinstance(expression, exp.Cast):
        return expression

    extraction = expression.this
    if not extraction or not isinstance(extraction, exp.JSONExtract):
        return expression

    path = extraction.expression
    if path and isinstance(path, exp.JSONPath):
        extraction.replace(exp.JSONExtractScalar(this=extraction.this, expression=path))

    return expression
705
+
706
+
707
def json_extract_precedence(expression: exp.Expression) -> exp.Expression:
    """Associate json extract operands to avoid duckdb operators of higher precedence transforming the expression.

    JSON extractions (both variants) are wrapped in parentheses.

    See https://github.com/tekumara/fakesnow/issues/53
    """
    is_json_extraction = isinstance(expression, (exp.JSONExtract, exp.JSONExtractScalar))
    return exp.Paren(this=expression) if is_json_extraction else expression
715
+
716
+
717
def random(expression: exp.Expression) -> exp.Expression:
    """Convert random() and random(seed).

    Snowflake random() is an signed 64 bit integer.
    Duckdb random() is a double between 0 and 1 and uses setseed() to set the seed.

    The first RANDOM found in a SELECT is replaced in place, and a literal seed, if present, is
    stored on the SELECT's 'seed' arg.
    """
    if not isinstance(expression, exp.Select):
        return expression

    rand = expression.find(exp.Rand)
    if not rand:
        return expression

    # shift result to between min and max signed 64bit integer
    centred = exp.Paren(this=exp.Sub(this=exp.Rand(), expression=exp.Literal(this="0.5", is_string=False)))
    scaled = exp.Paren(
        this=exp.Mul(
            this=centred,
            expression=exp.Literal(this="9223372036854775807", is_string=False),
        )
    )
    bigint = exp.DataType(this=exp.DataType.Type.BIGINT, nested=False, prefix=False)
    rand.replace(exp.Cast(this=scaled, to=bigint))

    # convert seed to double between 0 and 1 by dividing by max INTEGER (int32)
    # (not max BIGINT (int64) because we don't have enough floating point precision to distinguish seeds)
    # then attach to SELECT as the seed arg
    # (we can't attach it to exp.Rand because it will be rendered in the sql)
    seed = rand.this
    if seed and isinstance(seed, exp.Literal):
        expression.args["seed"] = f"{seed}/2147483647-0.5"

    return expression
745
+
746
+
747
def sample(expression: exp.Expression) -> exp.Expression:
    """Default the sampling method of a table sample to BERNOULLI when none is given."""
    if not isinstance(expression, exp.TableSample):
        return expression

    if not expression.args.get("method"):
        # set snowflake default (bernoulli) rather than use the duckdb default (system)
        # because bernoulli works better at small row sizes like we have in tests
        expression.set("method", exp.Var(this="BERNOULLI"))

    return expression
754
+
755
+
756
def object_construct(expression: exp.Expression) -> exp.Expression:
    """Convert OBJECT_CONSTRUCT to TO_JSON.

    Internally snowflake stores OBJECT types as a json string, so the Duckdb JSON type most closely matches.
    Key/value pairs where either side is NULL are dropped from a copy of the struct, which is
    then wrapped in TO_JSON.

    See https://docs.snowflake.com/en/sql-reference/functions/object_construct
    """
    if not isinstance(expression, exp.Struct):
        return expression

    kept = [
        item
        for item in expression.expressions
        if not (
            isinstance(item, exp.PropertyEQ)
            and (isinstance(item.left, exp.Null) or isinstance(item.right, exp.Null))
        )
    ]

    pruned = expression.copy()
    pruned.set("expressions", kept)
    return exp.Anonymous(this="TO_JSON", expressions=[pruned])
787
+
788
+
789
def regex_replace(expression: exp.Expression) -> exp.Expression:
    """Transform regex_replace expressions from snowflake to duckdb.

    Only applies when the pattern is a literal. The expression is modified in place.

    Raises:
        NotImplementedError: if the expression carries more than three args.
    """
    if not isinstance(expression, exp.RegexpReplace):
        return expression

    pattern = expression.expression
    if not isinstance(pattern, exp.Literal):
        return expression

    if len(expression.args) > 3:
        # see https://docs.snowflake.com/en/sql-reference/functions/regexp_replace
        raise NotImplementedError(
            "REGEXP_REPLACE with additional parameters (eg: <position>, <occurrence>, <parameters>)"
        )

    # pattern: snowflake requires escaping backslashes in single-quoted string constants, but duckdb doesn't
    # see https://docs.snowflake.com/en/sql-reference/functions-regexp#label-regexp-escape-character-caveats
    unescaped = pattern.this.replace("\\\\", "\\")
    expression.args["expression"] = exp.Literal(this=unescaped, is_string=True)

    # if no replacement string, the snowflake default is ''
    if not expression.args.get("replacement"):
        expression.args["replacement"] = exp.Literal(this="", is_string=True)

    # snowflake regex replacements are global
    expression.args["modifiers"] = exp.Literal(this="g", is_string=True)

    return expression
813
+
814
+
815
def regex_substr(expression: exp.Expression) -> exp.Expression:
    """Transform regex_substr expressions from snowflake to duckdb.

    The result is ``regexp_extract_all(subject[position:], pattern, group, parameters)[occurrence]``.

    See https://docs.snowflake.com/en/sql-reference/functions/regexp_substr

    Args:
        expression (exp.Expression): the expression that will be transformed.

    Returns:
        exp.Expression: A bracket (index) expression over regexp_extract_all for regexp
            extract expressions, otherwise the expression unchanged.
    """

    if not isinstance(expression, exp.RegexpExtract):
        return expression

    subject = expression.this

    # pattern: snowflake requires escaping backslashes in single-quoted string constants, but duckdb doesn't
    # see https://docs.snowflake.com/en/sql-reference/functions-regexp#label-regexp-escape-character-caveats
    pattern = expression.expression
    # only literals can be unescaped; other patterns (eg: a column) are passed through untouched
    if isinstance(pattern, exp.Literal):
        pattern.args["this"] = pattern.this.replace("\\\\", "\\")

    # number of characters from the beginning of the string where the function starts searching for matches
    position = expression.args.get("position") or exp.Literal(this="1", is_string=False)

    # which occurrence of the pattern to match
    occurrence = expression.args.get("occurrence")
    occurrence = int(occurrence.this) if occurrence else 1

    # the duckdb dialect increments bracket (ie: index) expressions by 1 because duckdb is 1-indexed,
    # so we need to compensate by subtracting 1
    occurrence = exp.Literal(this=str(occurrence - 1), is_string=False)

    if parameters := expression.args.get("parameters"):
        flags = parameters.this
        # remember whether 'e' (extract sub-matches) was requested *before* stripping it,
        # otherwise the default group below can never become 1
        extract_submatch = "e" in flags
        # 'e' parameter doesn't make sense for duckdb
        regex_parameters = exp.Literal(this=flags.replace("e", ""), is_string=True)
    else:
        extract_submatch = False
        regex_parameters = exp.Literal(is_string=True)

    group_num = expression.args.get("group")
    if not group_num:
        # snowflake: with 'e' and no explicit group the group defaults to 1, otherwise the whole match (0)
        group_num = exp.Literal(this="1" if extract_submatch else "0", is_string=False)

    return exp.Bracket(
        this=exp.Anonymous(
            this="regexp_extract_all",
            expressions=[
                # slice subject from position onwards
                exp.Bracket(this=subject, expressions=[exp.Slice(this=position)]),
                pattern,
                group_num,
                regex_parameters,
            ],
        ),
        # select index of occurrence
        expressions=[occurrence],
    )
869
+
870
+
871
+ # TODO: move this into a Dialect as a transpilation
872
def set_schema(expression: exp.Expression, current_database: str | None = None) -> exp.Expression:
    """Transform USE SCHEMA/DATABASE to SET schema.

    Example:
        >>> import sqlglot
        >>> sqlglot.parse_one("USE SCHEMA bar").transform(set_schema, current_database="foo").sql()
        "SET schema = 'foo.bar'"
        >>> sqlglot.parse_one("USE SCHEMA foo.bar").transform(set_schema).sql()
        "SET schema = 'foo.bar'"
        >>> sqlglot.parse_one("USE DATABASE marts").transform(set_schema).sql()
        "SET schema = 'marts.main'"

        See tests for more examples.
    Args:
        expression (exp.Expression): the expression that will be transformed.
        current_database (str | None): database used to qualify a USE SCHEMA whose
            schema isn't already qualified with a database. Defaults to None, which is
            only valid for USE DATABASE and database-qualified USE SCHEMA.

    Returns:
        exp.Expression: A SET schema expression if the input is a USE
            expression, otherwise expression is returned as-is.
    """

    if not isinstance(expression, exp.Use):
        return expression

    kind = expression.args.get("kind")
    if not (kind and isinstance(kind, exp.Var) and kind.name):
        return expression

    kind_name = kind.name.upper()
    if kind_name not in ["SCHEMA", "DATABASE"]:
        return expression

    assert expression.this, f"No identifier for USE expression {expression}"

    if kind_name == "DATABASE":
        # duckdb's default schema is main
        database = expression.this.name
        return exp.Command(
            this="SET", expression=exp.Literal.string(f"schema = '{database}.main'"), set_database=database
        )

    # SCHEMA: use the database it is qualified with, else fall back to the current database
    db = expression.this.args.get("db")
    db_name = db.name if db else current_database

    # assertion always true because check_db_schema is called before this
    assert db_name

    schema = expression.this.name
    return exp.Command(this="SET", expression=exp.Literal.string(f"schema = '{db_name}.{schema}'"), set_schema=schema)
925
+
926
+
927
def split(expression: exp.Expression) -> exp.Expression:
    """Wrap SPLIT in ``to_json``.

    duckdb's str_split returns varchar[]; converting it to a JSON array matches Snowflake.

    Args:
        expression (exp.Expression): the expression that will be transformed.

    Returns:
        exp.Expression: ``to_json(<split>)`` for a Split expression, otherwise the
            expression as-is.
    """
    if not isinstance(expression, exp.Split):
        return expression

    return exp.Anonymous(this="to_json", expressions=[expression])
935
+
936
+
937
def tag(expression: exp.Expression) -> exp.Expression:
    """Handle tags by turning tag statements into a no-op.

    duckdb doesn't support tags. In lieu of a full implementation, for now we make it a NOP.

    Example:
        >>> import sqlglot
        >>> sqlglot.parse_one("ALTER TABLE table1 SET TAG foo='bar'").transform(tag).sql()
        "SELECT 'Statement executed successfully.' AS status"
    Args:
        expression (exp.Expression): the expression that will be transformed.

    Returns:
        exp.Expression: SUCCESS_NOP for ALTER ... SET TAG, commands containing SET TAG
            and CREATE TAG; otherwise the expression as-is.
    """

    if isinstance(expression, exp.Alter):
        # alter ... set tag
        alter_actions = expression.args.get("actions") or []
        sets_tag = any(isinstance(action, exp.AlterSet) and action.args.get("tag") for action in alter_actions)
        return SUCCESS_NOP if sets_tag else expression

    if isinstance(expression, exp.Command):
        # alter table modify column set tag
        command_text = expression.args.get("expression")
        if isinstance(command_text, str) and "SET TAG" in command_text.upper():
            return SUCCESS_NOP
        return expression

    if isinstance(expression, exp.Create):
        # create tag
        created_kind = expression.args.get("kind")
        if isinstance(created_kind, str) and created_kind.upper() == "TAG":
            return SUCCESS_NOP

    return expression
974
+
975
+
976
def to_date(expression: exp.Expression) -> exp.Expression:
    """Convert to_date() to a cast.

    An anonymous ``TO_DATE(x, ...)`` call becomes ``CAST(x AS DATE)``; only the first
    argument is used.

    See https://docs.snowflake.com/en/sql-reference/functions/to_date

    Args:
        expression (exp.Expression): the expression that will be transformed.

    Returns:
        exp.Expression: The transformed expression.
    """

    if not isinstance(expression, exp.Anonymous):
        return expression

    func_name = expression.this
    if not isinstance(func_name, str) or func_name.upper() != "TO_DATE":
        return expression

    value = expression.expressions[0]
    date_type = exp.DataType(this=exp.DataType.Type.DATE, nested=False, prefix=False)
    return exp.Cast(this=value, to=date_type)
1002
+
1003
+
1004
def _get_to_number_args(e: exp.ToNumber) -> tuple[exp.Expression | None, exp.Expression | None, exp.Expression | None]:
    """Work out which to_number arguments are the format, precision and scale.

    The call shape is to_number(value, <format>, <precision>, <scale>) where the format is
    optional, so the parsed "format" slot may actually hold the precision.

    Returns:
        A (format, precision, scale) tuple, with None for each argument not supplied.
    """
    first = e.args.get("format")
    second = e.args.get("precision")
    third = e.args.get("scale")

    if first:
        if first.is_string:
            # to_number('100', 'TM9' ...)
            # to_number('100', 'TM9', 10 ...)
            # to_number('100', 'TM9', 10, 2)
            return first, second or None, third or None

        # to_number('100', 10, ...)
        # the first optional arg is not a string, so it must be precision
        # to_number('100', 10, 2)
        # ... and the next one must be scale
        return None, first, second or None

    if second:
        return None, second, third or None

    return None, None, None
1041
+
1042
+
1043
def _to_decimal(expression: exp.Expression, cast_node: type[exp.Cast]) -> exp.Expression:
    """Build a cast of the first argument to DECIMAL(precision, scale).

    Precision defaults to 38 and scale to 0 when not given as the second and third
    arguments.

    Args:
        expression (exp.Expression): the function call whose arguments are used.
        cast_node (type[exp.Cast]): the cast class to instantiate.

    Raises:
        NotImplementedError: if the second argument is a string, ie: a format.
    """
    call_args: list[exp.Expression] = expression.expressions
    has_second_arg = len(call_args) > 1

    if has_second_arg and call_args[1].is_string:
        # see https://docs.snowflake.com/en/sql-reference/functions/to_decimal#arguments
        raise NotImplementedError(f"{expression.this} with format argument")

    if has_second_arg:
        precision = call_args[1]
    else:
        precision = exp.Literal(this="38", is_string=False)

    if len(call_args) > 2:
        scale = call_args[2]
    else:
        scale = exp.Literal(this="0", is_string=False)

    decimal_type = exp.DataType(
        this=exp.DataType.Type.DECIMAL, expressions=[precision, scale], nested=False, prefix=False
    )
    return cast_node(this=call_args[0], to=decimal_type)
1057
+
1058
+
1059
def to_decimal(expression: exp.Expression) -> exp.Expression:
    """Transform to_decimal, to_number, to_numeric expressions from snowflake to duckdb.

    Each is rewritten as a cast to DECIMAL(precision, scale), defaulting to DECIMAL(38, 0).

    See https://docs.snowflake.com/en/sql-reference/functions/to_decimal

    Raises:
        NotImplementedError: if a format argument is supplied.
    """

    if isinstance(expression, exp.ToNumber):
        format_, precision, scale = _get_to_number_args(expression)
        if format_:
            raise NotImplementedError(f"{expression.this} with format argument")

        decimal_type = exp.DataType(
            this=exp.DataType.Type.DECIMAL,
            expressions=[
                precision or exp.Literal(this="38", is_string=False),
                scale or exp.Literal(this="0", is_string=False),
            ],
            nested=False,
            prefix=False,
        )
        return exp.Cast(this=expression.this, to=decimal_type)

    if isinstance(expression, exp.Anonymous):
        func_name = expression.this
        if isinstance(func_name, str) and func_name.upper() in ["TO_DECIMAL", "TO_NUMERIC"]:
            return _to_decimal(expression, exp.Cast)

    return expression
1088
+
1089
+
1090
def try_to_decimal(expression: exp.Expression) -> exp.Expression:
    """Transform try_to_decimal, try_to_number, try_to_numeric expressions from snowflake to duckdb.

    Each is rewritten as a TryCast to DECIMAL(precision, scale).

    See https://docs.snowflake.com/en/sql-reference/functions/try_to_decimal
    """

    if not isinstance(expression, exp.Anonymous):
        return expression

    func_name = expression.this
    if isinstance(func_name, str) and func_name.upper() in ["TRY_TO_DECIMAL", "TRY_TO_NUMBER", "TRY_TO_NUMERIC"]:
        return _to_decimal(expression, exp.TryCast)

    return expression
1103
+
1104
+
1105
def to_timestamp(expression: exp.Expression) -> exp.Expression:
    """Convert to_timestamp(seconds) to timestamp without timezone (ie: TIMESTAMP_NTZ).

    A UnixToTime expression is wrapped in a cast to TIMESTAMP; anything else is
    returned as-is.

    See https://docs.snowflake.com/en/sql-reference/functions/to_timestamp
    """

    if not isinstance(expression, exp.UnixToTime):
        return expression

    timestamp_type = exp.DataType(this=exp.DataType.Type.TIMESTAMP, nested=False, prefix=False)
    return exp.Cast(this=expression, to=timestamp_type)
1117
+
1118
+
1119
def to_timestamp_ntz(expression: exp.Expression) -> exp.Expression:
    """Convert to_timestamp_ntz to to_timestamp (StrToTime).

    The first argument is parsed with the fixed format ``%Y-%m-%d %H:%M:%S``.

    Because it's not yet supported by sqlglot, see https://github.com/tobymao/sqlglot/issues/2748
    """

    if not isinstance(expression, exp.Anonymous):
        return expression

    func_name = expression.this
    if not isinstance(func_name, str) or func_name.upper() != "TO_TIMESTAMP_NTZ":
        return expression

    value = expression.expressions[0]
    time_format = exp.Literal(this="%Y-%m-%d %H:%M:%S", is_string=True)
    return exp.StrToTime(this=value, format=time_format)
1133
+
1134
+
1135
def timestamp_ntz(expression: exp.Expression) -> exp.Expression:
    """Convert timestamp_ntz (snowflake) to timestamp (duckdb).

    NB: timestamp_ntz defaults to nanosecond precision (ie: NTZ(9)). The duckdb equivalent is TIMESTAMP_NS.
    However we use TIMESTAMP (ie: microsecond precision) here rather than TIMESTAMP_NS to avoid
    https://github.com/duckdb/duckdb/issues/7980 in test_write_pandas_timestamp_ntz.
    """

    is_ntz_type = isinstance(expression, exp.DataType) and expression.this == exp.DataType.Type.TIMESTAMPNTZ
    if not is_ntz_type:
        return expression

    # a fresh DataType is built, so nothing else from the original type node is carried over
    return exp.DataType(this=exp.DataType.Type.TIMESTAMP)
1147
+
1148
+
1149
def trim_cast_varchar(expression: exp.Expression) -> exp.Expression:
    """Snowflake's TRIM casts input to VARCHAR implicitly.

    The operand of a Trim is wrapped in a cast to VARCHAR unless it is already cast to
    VARCHAR or TEXT. All other arguments of the Trim (eg: the characters to trim) are
    carried over unchanged.

    Args:
        expression (exp.Expression): the expression that will be transformed.

    Returns:
        exp.Expression: A Trim over the cast operand, otherwise expression is returned as-is.
    """

    if not isinstance(expression, exp.Trim):
        return expression

    operand = expression.this
    if isinstance(operand, exp.Cast) and operand.to.this in [exp.DataType.Type.VARCHAR, exp.DataType.Type.TEXT]:
        return expression

    # keep every argument other than the operand, so that eg: TRIM(x, 'abc') still trims 'abc'
    other_args = {key: value for key, value in expression.args.items() if key != "this"}

    return exp.Trim(
        this=exp.Cast(this=operand, to=exp.DataType(this=exp.DataType.Type.VARCHAR, nested=False, prefix=False)),
        **other_args,
    )
1162
+
1163
+
1164
def try_parse_json(expression: exp.Expression) -> exp.Expression:
    """Convert TRY_PARSE_JSON() to TRY_CAST(... as JSON).

    Example:
        >>> import sqlglot
        >>> sqlglot.parse_one("select try_parse_json('{}')").transform(try_parse_json).sql()
        "SELECT TRY_CAST('{}' AS JSON)"
    Args:
        expression (exp.Expression): the expression that will be transformed.

    Returns:
        exp.Expression: The transformed expression.
    """

    if not isinstance(expression, exp.Anonymous):
        return expression

    func_name = expression.this
    if not isinstance(func_name, str) or func_name.upper() != "TRY_PARSE_JSON":
        return expression

    # only the first argument is used
    value = expression.expressions[0]
    json_type = exp.DataType(this=exp.DataType.Type.JSON, nested=False)
    return exp.TryCast(this=value, to=json_type)
1190
+
1191
+
1192
def semi_structured_types(expression: exp.Expression) -> exp.Expression:
    """Convert OBJECT, ARRAY, and VARIANT types to duckdb compatible types.

    Example:
        >>> import sqlglot
        >>> sqlglot.parse_one("CREATE TABLE table1 (name object)").transform(semi_structured_types).sql()
        "CREATE TABLE table1 (name JSON)"
    Args:
        expression (exp.Expression): the expression that will be transformed.

    Returns:
        exp.Expression: The transformed expression.
    """

    if not isinstance(expression, exp.DataType):
        return expression

    semi_structured = [
        exp.DataType.Type.ARRAY,
        exp.DataType.Type.OBJECT,
        exp.DataType.Type.VARIANT,
    ]
    if expression.this not in semi_structured:
        return expression

    # work on a copy so the original type node is left untouched
    json_type = expression.copy()
    json_type.args["this"] = exp.DataType.Type.JSON
    return json_type
1216
+
1217
+
1218
def upper_case_unquoted_identifiers(expression: exp.Expression) -> exp.Expression:
    """Upper case unquoted identifiers.

    Snowflake represents case-insensitivity using upper-case identifiers in cursor results.
    duckdb uses lowercase. We convert all unquoted identifiers to uppercase to match snowflake.

    Example:
        >>> import sqlglot
        >>> sqlglot.parse_one("select name, name as fname from table1").transform(upper_case_unquoted_identifiers).sql()
        'SELECT NAME, NAME AS FNAME FROM TABLE1'
    Args:
        expression (exp.Expression): the expression that will be transformed.

    Returns:
        exp.Expression: The transformed expression.
    """

    if not isinstance(expression, exp.Identifier):
        return expression

    # quoted identifiers keep their case
    if expression.quoted or not isinstance(expression.this, str):
        return expression

    upper_cased = expression.copy()
    upper_cased.set("this", expression.this.upper())
    return upper_cased
1241
+
1242
+
1243
def values_columns(expression: exp.Expression) -> exp.Expression:
    """Support column1, column2 expressions in VALUES.

    Snowflake uses column1, column2 .. for unnamed columns in VALUES. Whereas duckdb uses col0, col1 ..
    An unaliased VALUES inside a SELECT is given a table alias ``_`` with quoted columns
    COLUMN1, COLUMN2 .., one per element of the first tuple found. The expression is modified in place.

    See https://docs.snowflake.com/en/sql-reference/constructs/values#examples
    """

    if not isinstance(expression, exp.Values) or expression.alias:
        return expression

    if not expression.find_ancestor(exp.Select):
        return expression

    first_tuple = expression.find(exp.Tuple)
    if not first_tuple:
        return expression

    column_names = []
    for i in range(len(first_tuple.expressions)):
        column_names.append(exp.Identifier(this=f"COLUMN{i + 1}", quoted=True))

    table_alias = exp.TableAlias(this=exp.Identifier(this="_", quoted=False), columns=column_names)
    expression.set("alias", table_alias)

    return expression
1261
+
1262
+
1263
def create_user(expression: exp.Expression) -> exp.Expression:
    """Transform CREATE USER to a query against the global database's information_schema._fs_users table.

    https://docs.snowflake.com/en/sql-reference/sql/create-user

    Args:
        expression (exp.Expression): the expression that will be transformed.

    Returns:
        exp.Expression: An INSERT into the _fs_users table for CREATE USER <name>,
            otherwise expression is returned as-is.

    Raises:
        NotImplementedError: if anything follows the user name.
    """
    # XXX: this is a placeholder. We need to implement the full CREATE USER syntax, but
    # sqlglot doesnt yet support Create for snowflake.
    if isinstance(expression, exp.Command) and expression.this == "CREATE":
        sub_exp = expression.expression.strip()
        if sub_exp.upper().startswith("USER"):
            # split on any run of whitespace, so repeated spaces, tabs or newlines between
            # tokens don't produce empty tokens that would be mistaken for the name
            _, name, *ignored = sub_exp.split()
            if ignored:
                raise NotImplementedError(f"`CREATE USER` with {ignored}")

            # double up single quotes so the name can't terminate the string literal
            escaped_name = name.replace("'", "''")
            return sqlglot.parse_one(
                f"INSERT INTO _fs_global._fs_information_schema._fs_users (name) VALUES ('{escaped_name}')",
                read="duckdb",
            )

    return expression
1281
+
1282
+
1283
def update_variables(
    expression: exp.Expression,
    variables: Variables,
) -> exp.Expression:
    """Apply a variable-modifying statement to ``variables`` and replace it with a no-op.

    Args:
        expression (exp.Expression): the expression that will be transformed.
        variables (Variables): the variable store that is updated.

    Returns:
        exp.Expression: SUCCESS_NOP if the expression is a variable modifier,
            otherwise expression is returned as-is.
    """
    if not Variables.is_variable_modifier(expression):
        return expression

    # Nothing further to do if its a SET/UNSET operation.
    variables.update_variables(expression)
    return SUCCESS_NOP
1291
+
1292
+
1293
class SHA256(exp.Func):
    """Function expression named SHA256 that takes a single required argument."""

    # "this" is the only argument and it is required
    arg_types: ClassVar[dict[str, bool]] = {"this": True}
    _sql_names: ClassVar[list[str]] = ["SHA256"]
1296
+
1297
+
1298
def sha256(expression: exp.Expression) -> exp.Expression:
    """Convert sha2() or sha2_hex() to sha256().

    Convert sha2_binary() to unhex(sha256()).

    Only a digest length of 256 (explicit, or implied by omitting the length) is
    converted; any other length is returned as-is.

    Example:
        >>> import sqlglot
        >>> sqlglot.parse_one("insert into table1 (name) select sha2('foo')").transform(sha256).sql()
        "INSERT INTO table1 (name) SELECT SHA256('foo')"
    Args:
        expression (exp.Expression): the expression that will be transformed.

    Returns:
        exp.Expression: The transformed expression.
    """

    if isinstance(expression, exp.SHA2):
        # treat a missing length, or one explicitly set to None, as the default of 256
        length = expression.args.get("length") or exp.Literal.number(256)
        if length.this == "256":
            return SHA256(this=expression.this)
        return expression

    # the function name must be a plain string before it can be upper-cased
    if not (isinstance(expression, exp.Anonymous) and isinstance(expression.this, str)):
        return expression

    func_name = expression.this.upper()
    if func_name not in ("SHA2_HEX", "SHA2_BINARY"):
        return expression

    func_args = expression.expressions
    is_256 = len(func_args) == 1 or (len(func_args) == 2 and func_args[1].this == "256")
    if not is_256:
        return expression

    digest = SHA256(this=func_args[0])
    # sha256() yields hex, so the binary variant has to be unhexed
    return exp.Unhex(this=digest) if func_name == "SHA2_BINARY" else digest