fakesnow 0.9.24__py3-none-any.whl → 0.9.26__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as published to a supported public registry. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their registry.
fakesnow/__init__.py CHANGED
@@ -90,3 +90,4 @@ def patch(
         yield None
     finally:
         stack.close()
+        fs.duck_conn.close()
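
Not part of the diff: a minimal usage sketch of the behavior this fixes. When the `patch()` context exits, the patch stack unwinds and, with this change, the shared DuckDB connection is closed as well.

```python
# Sketch of typical patch() usage; fakesnow.patch is the package's public API.
import fakesnow
import snowflake.connector

with fakesnow.patch():
    conn = snowflake.connector.connect()
    conn.cursor().execute("SELECT 1")
# on exit: patches are removed and fs.duck_conn.close() (the added line) runs
```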
fakesnow/arrow.py CHANGED
@@ -1,33 +1,45 @@
-from typing import Any
+from __future__ import annotations
+
+from typing import cast
 
 import pyarrow as pa
+import pyarrow.compute as pc
+
+from fakesnow.types import ColumnInfo
+
+
+def to_sf_schema(schema: pa.Schema, rowtype: list[ColumnInfo]) -> pa.Schema:
+    # expected by the snowflake connector
+    # uses rowtype to populate metadata, rather than the arrow schema type, for consistency with
+    # rowtype returned in the response
 
+    assert len(schema) == len(rowtype), f"schema and rowtype must be same length but f{len(schema)=} f{len(rowtype)=}"
 
-def with_sf_metadata(schema: pa.Schema) -> pa.Schema:
     # see https://github.com/snowflakedb/snowflake-connector-python/blob/e9393a6/src/snowflake/connector/nanoarrow_cpp/ArrowIterator/CArrowTableIterator.cpp#L32
     # and https://github.com/snowflakedb/snowflake-connector-python/blob/e9393a6/src/snowflake/connector/nanoarrow_cpp/ArrowIterator/SnowflakeType.cpp#L10
-    fms = []
-    for i, t in enumerate(schema.types):
-        f = schema.field(i)
-
-        # TODO: precision, scale, charLength etc. for all types
-
-        if t == pa.bool_():
-            fm = f.with_metadata({"logicalType": "BOOLEAN"})
-        elif t == pa.int64():
-            # scale and precision required, see here
-            # https://github.com/snowflakedb/snowflake-connector-python/blob/416ff57/src/snowflake/connector/nanoarrow_cpp/ArrowIterator/CArrowChunkIterator.cpp#L147
-            fm = f.with_metadata({"logicalType": "FIXED", "precision": "38", "scale": "0"})
-        elif t == pa.float64():
-            fm = f.with_metadata({"logicalType": "REAL"})
-        elif isinstance(t, pa.Decimal128Type):
-            fm = f.with_metadata({"logicalType": "FIXED", "precision": str(t.precision), "scale": str(t.scale)})
-        elif t == pa.string():
-            # TODO: set charLength to size of column
-            fm = f.with_metadata({"logicalType": "TEXT", "charLength": "16777216"})
-        else:
-            raise NotImplementedError(f"Unsupported Arrow type: {t}")
-        fms.append(fm)
+
+    def sf_field(field: pa.Field, c: ColumnInfo) -> pa.Field:
+        if isinstance(field.type, pa.TimestampType):
+            # snowflake uses a struct to represent timestamps, see timestamp_to_sf_struct
+            fields = [pa.field("epoch", pa.int64(), nullable=False), pa.field("fraction", pa.int32(), nullable=False)]
+            if field.type.tz:
+                fields.append(pa.field("timezone", nullable=False, type=pa.int32()))
+            field = field.with_type(pa.struct(fields))
+        elif isinstance(field.type, pa.Time64Type):
+            field = field.with_type(pa.int64())
+
+        return field.with_metadata(
+            {
+                "logicalType": c["type"].upper(),
+                # required for FIXED type see
+                # https://github.com/snowflakedb/snowflake-connector-python/blob/416ff57/src/snowflake/connector/nanoarrow_cpp/ArrowIterator/CArrowChunkIterator.cpp#L147
+                "precision": str(c["precision"] or 38),
+                "scale": str(c["scale"] or 0),
+                "charLength": str(c["length"] or 0),
+            }
+        )
+
+    fms = [sf_field(schema.field(i), c) for i, c in enumerate(rowtype)]
     return pa.schema(fms)
 
 
@@ -39,29 +51,56 @@ def to_ipc(table: pa.Table) -> pa.Buffer:
 
     sink = pa.BufferOutputStream()
 
-    with pa.ipc.new_stream(sink, with_sf_metadata(table.schema)) as writer:
+    with pa.ipc.new_stream(sink, table.schema) as writer:
         writer.write_batch(batch)
 
     return sink.getvalue()
 
 
-# TODO: should this be derived before with_schema?
-def to_rowtype(schema: pa.Schema) -> list[dict[str, Any]]:
-    return [
-        {
-            "name": f.name,
-            # TODO
-            # "database": "",
-            # "schema": "",
-            # "table": "",
-            "nullable": f.nullable,
-            "type": f.metadata.get(b"logicalType").decode("utf-8").lower(),  # type: ignore
-            # TODO
-            # "byteLength": 20,
-            "length": int(f.metadata.get(b"charLength")) if f.metadata.get(b"charLength") else None,  # type: ignore
-            "scale": int(f.metadata.get(b"scale")) if f.metadata.get(b"scale") else None,  # type: ignore
-            "precision": int(f.metadata.get(b"precision")) if f.metadata.get(b"precision") else None,  # type: ignore
-            "collation": None,
-        }
-        for f in schema
-    ]
+def to_sf(table: pa.Table, rowtype: list[ColumnInfo]) -> pa.Table:
+    def to_sf_col(col: pa.Array) -> pa.Array:
+        if pa.types.is_timestamp(col.type):
+            return timestamp_to_sf_struct(col)
+        elif pa.types.is_time(col.type):
+            # as nanoseconds
+            return pc.multiply(col.cast(pa.int64()), 1000)  # type: ignore https://github.com/zen-xu/pyarrow-stubs/issues/44
+        return col
+
+    return pa.Table.from_arrays([to_sf_col(c) for c in table.columns], schema=to_sf_schema(table.schema, rowtype))
+
+
+def timestamp_to_sf_struct(ts: pa.Array | pa.ChunkedArray) -> pa.Array:
+    if isinstance(ts, pa.ChunkedArray):
+        # combine because pa.StructArray.from_arrays doesn't support ChunkedArray
+        ts = cast(pa.Array, ts.combine_chunks())  # see https://github.com/zen-xu/pyarrow-stubs/issues/46
+
+    if not isinstance(ts.type, pa.TimestampType):
+        raise ValueError(f"Expected TimestampArray, got {type(ts)}")
+
+    # Round to seconds, ie: strip subseconds
+    tsa_without_us = pc.floor_temporal(ts, unit="second")  # type: ignore https://github.com/zen-xu/pyarrow-stubs/issues/45
+    epoch = pc.divide(tsa_without_us.cast(pa.int64()), 1_000_000)  # type: ignore https://github.com/zen-xu/pyarrow-stubs/issues/44
+
+    # Calculate fractional part as nanoseconds
+    fraction = pc.multiply(pc.subsecond(ts), 1_000_000_000).cast(pa.int32())  # type: ignore
+
+    if ts.type.tz:
+        assert ts.type.tz == "UTC", f"Timezone {ts.type.tz} not yet supported"
+        timezone = pa.array([1440] * len(ts), type=pa.int32())
+
+        return pa.StructArray.from_arrays(
+            arrays=[epoch, fraction, timezone],  # type: ignore https://github.com/zen-xu/pyarrow-stubs/issues/42
+            fields=[
+                pa.field("epoch", nullable=False, type=pa.int64()),
+                pa.field("fraction", nullable=False, type=pa.int32()),
+                pa.field("timezone", nullable=False, type=pa.int32()),
+            ],
+        )
+    else:
+        return pa.StructArray.from_arrays(
+            arrays=[epoch, fraction],  # type: ignore https://github.com/zen-xu/pyarrow-stubs/issues/42
+            fields=[
+                pa.field("epoch", nullable=False, type=pa.int64()),
+                pa.field("fraction", nullable=False, type=pa.int32()),
+            ],
+        )
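
Editor's sketch (not from the package) of the struct encoding used above: `timestamp_to_sf_struct` splits a timestamp into whole seconds (`epoch`) and a nanosecond remainder (`fraction`); tz-aware UTC values gain a constant `timezone` child.

```python
# Assumes pyarrow >= 8 (pc.floor_temporal / pc.subsecond, used above).
from datetime import datetime

import pyarrow as pa

from fakesnow.arrow import timestamp_to_sf_struct

ts = pa.array([datetime(2024, 1, 2, 3, 4, 5, 123456)], type=pa.timestamp("us"))
struct = timestamp_to_sf_struct(ts)
print(struct.field("epoch"))     # [1704164645] -- whole seconds since the epoch
print(struct.field("fraction"))  # [123456000] -- subsecond part in nanoseconds
```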
fakesnow/checks.py CHANGED
@@ -68,3 +68,11 @@ def is_unqualified_table_expression(expression: exp.Expression) -> tuple[bool, b
     no_schema = not node.args.get("db")
 
     return no_database, no_schema
+
+
+def equal(left: exp.Identifier, right: exp.Identifier) -> bool:
+    # as per https://docs.snowflake.com/en/sql-reference/identifiers-syntax#label-identifier-casing
+    lid = left.this if left.quoted else left.this.upper()
+    rid = right.this if right.quoted else right.this.upper()
+
+    return lid == rid
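
The new helper encodes Snowflake's identifier casing rules. A quick sketch of its semantics:

```python
# Unquoted identifiers fold to upper case before comparison; quoted ones
# compare verbatim, per the Snowflake doc linked in the code above.
from sqlglot import exp

from fakesnow import checks

assert checks.equal(exp.Identifier(this="t1", quoted=False), exp.Identifier(this="T1", quoted=False))
assert not checks.equal(exp.Identifier(this="t1", quoted=True), exp.Identifier(this="T1", quoted=False))
```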
fakesnow/cursor.py CHANGED
@@ -112,13 +112,15 @@ class FakeSnowflakeCursor:
 
     @property
     def description(self) -> list[ResultMetadata]:
+        return describe_as_result_metadata(self._describe_last_sql())
+
+    def _describe_last_sql(self) -> list:
         # use a separate cursor to avoid consuming the result set on this cursor
         with self._conn.cursor() as cur:
+            # TODO: can we replace with self._duck_conn.description?
             expression = sqlglot.parse_one(f"DESCRIBE {self._last_sql}", read="duckdb")
             cur._execute(expression, self._last_params)  # noqa: SLF001
-            meta = describe_as_result_metadata(cur.fetchall())
-
-        return meta
+            return cur.fetchall()
 
     def execute(
         self,
@@ -137,10 +139,15 @@ class FakeSnowflakeCursor:
             command, params = self._rewrite_with_params(command, params)
             if self._conn.nop_regexes and any(re.match(p, command, re.IGNORECASE) for p in self._conn.nop_regexes):
                 transformed = transforms.SUCCESS_NOP
-            else:
-                expression = parse_one(command, read="snowflake")
-                transformed = self._transform(expression)
-            return self._execute(transformed, params)
+                self._execute(transformed, params)
+                return self
+
+            expression = parse_one(command, read="snowflake")
+            for exp in self._transform_explode(expression):
+                transformed = self._transform(exp)
+                self._execute(transformed, params)
+
+            return self
         except snowflake.connector.errors.ProgrammingError as e:
             self._sqlstate = e.sqlstate
             raise e
@@ -155,6 +162,7 @@ class FakeSnowflakeCursor:
             .transform(transforms.extract_comment_on_columns)
             .transform(transforms.information_schema_fs_columns_snowflake)
             .transform(transforms.information_schema_fs_tables_ext)
+            .transform(transforms.information_schema_fs_views)
             .transform(transforms.drop_schema_cascade)
             .transform(transforms.tag)
             .transform(transforms.semi_structured_types)
@@ -205,9 +213,12 @@ class FakeSnowflakeCursor:
             .transform(transforms.alter_table_strip_cluster_by)
         )
 
-    def _execute(
-        self, transformed: exp.Expression, params: Sequence[Any] | dict[Any, Any] | None = None
-    ) -> FakeSnowflakeCursor:
+    def _transform_explode(self, expression: exp.Expression) -> list[exp.Expression]:
+        # Applies transformations that require splitting the expression into multiple expressions
+        # Split transforms have limited support at the moment.
+        return transforms.merge(expression)
+
+    def _execute(self, transformed: exp.Expression, params: Sequence[Any] | dict[Any, Any] | None = None) -> None:
         self._arrow_table = None
         self._arrow_table_fetch_index = None
         self._rowcount = None
@@ -284,6 +295,9 @@ class FakeSnowflakeCursor:
             (affected_count,) = self._duck_conn.fetchall()[0]
             result_sql = SQL_DELETED_ROWS.substitute(count=affected_count)
 
+        elif cmd == "TRUNCATETABLE":
+            result_sql = SQL_SUCCESS
+
         elif cmd in ("DESCRIBE TABLE", "DESCRIBE VIEW"):
             # DESCRIBE TABLE/VIEW has already been run above to detect and error if the table exists
             # We now rerun DESCRIBE TABLE/VIEW but transformed with columns to match Snowflake
@@ -343,8 +357,6 @@ class FakeSnowflakeCursor:
         self._last_sql = result_sql or sql
         self._last_params = params
 
-        return self
-
     def _log_sql(self, sql: str, params: Sequence[Any] | dict[Any, Any] | None = None) -> None:
         if (fs_debug := os.environ.get("FAKESNOW_DEBUG")) and fs_debug != "snowflake":
             print(f"{sql};{params=}" if params else f"{sql};", file=sys.stderr)
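
The observable effect of the `description` refactor, sketched under the usual `fakesnow.patch()` setup: reading `description` runs DESCRIBE on a separate cursor, so it does not consume this cursor's result set.

```python
import fakesnow
import snowflake.connector

with fakesnow.patch():
    conn = snowflake.connector.connect()
    cur = conn.cursor()
    cur.execute("SELECT 1 AS a")
    print([c.name for c in cur.description])  # expected ['A'] (unquoted names fold to upper case)
    print(cur.fetchall())                     # [(1,)] -- result set still intact
```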
fakesnow/info_schema.py CHANGED
@@ -102,7 +102,7 @@ where catalog_name not in ('memory', 'system', 'temp', '_fs_global')
 # replicates https://docs.snowflake.com/sql-reference/info-schema/views
 SQL_CREATE_INFORMATION_SCHEMA_VIEWS_VIEW = Template(
     """
-create view if not exists ${catalog}.information_schema.views AS
+create view if not exists ${catalog}.information_schema._fs_views AS
 select
     database_name as table_catalog,
     schema_name as table_schema,
fakesnow/server.py CHANGED
@@ -5,19 +5,22 @@ import json
 import secrets
 from base64 import b64encode
 from dataclasses import dataclass
+from typing import Any
 
+import snowflake.connector.errors
 from starlette.applications import Starlette
 from starlette.concurrency import run_in_threadpool
 from starlette.requests import Request
 from starlette.responses import JSONResponse
 from starlette.routing import Route
 
-from fakesnow.arrow import to_ipc, to_rowtype, with_sf_metadata
+from fakesnow.arrow import to_ipc, to_sf
 from fakesnow.fakes import FakeSnowflakeConnection
 from fakesnow.instance import FakeSnow
+from fakesnow.types import describe_as_rowtype
 
-fs = FakeSnow()
-sessions = {}
+shared_fs = FakeSnow()
+sessions: dict[str, FakeSnowflakeConnection] = {}
 
 
 @dataclass
@@ -27,9 +30,19 @@ class ServerError(Exception):
     message: str
 
 
-def login_request(request: Request) -> JSONResponse:
+async def login_request(request: Request) -> JSONResponse:
     database = request.query_params.get("databaseName")
     schema = request.query_params.get("schemaName")
+    body = await request.body()
+    body_json = json.loads(gzip.decompress(body))
+    session_params: dict[str, Any] = body_json["data"]["SESSION_PARAMETERS"]
+    if db_path := session_params.get("FAKESNOW_DB_PATH"):
+        # isolated creates a new in-memory database, rather than using the shared in-memory database
+        # so this connection won't share any tables with other connections
+        fs = FakeSnow() if db_path == ":isolated:" else FakeSnow(db_path=db_path)
+    else:
+        # share the in-memory database across connections
+        fs = shared_fs
     token = secrets.token_urlsafe(32)
     sessions[token] = fs.connect(database, schema)
     return JSONResponse({"data": {"token": token}, "success": True})
@@ -44,16 +57,30 @@ async def query_request(request: Request) -> JSONResponse:
 
     sql_text = body_json["sqlText"]
 
-    # only a single sql statement is sent at a time by the python snowflake connector
-    cur = await run_in_threadpool(conn.cursor().execute, sql_text)
-
-    assert cur._arrow_table, "No result set"  # noqa: SLF001
-
-    batch_bytes = to_ipc(cur._arrow_table)  # noqa: SLF001
-    rowset_b64 = b64encode(batch_bytes).decode("utf-8")
-
-    # TODO: avoid calling with_sf_metadata twice
-    rowtype = to_rowtype(with_sf_metadata(cur._arrow_table.schema))  # noqa: SLF001
+    try:
+        # only a single sql statement is sent at a time by the python snowflake connector
+        cur = await run_in_threadpool(conn.cursor().execute, sql_text)
+    except snowflake.connector.errors.ProgrammingError as e:
+        code = f"{e.errno:06d}"
+        return JSONResponse(
+            {
+                "data": {
+                    "errorCode": code,
+                    "sqlState": e.sqlstate,
+                },
+                "code": code,
+                "message": e.msg,
+                "success": False,
+            }
+        )
+
+    rowtype = describe_as_rowtype(cur._describe_last_sql())  # noqa: SLF001
+
+    if cur._arrow_table:  # noqa: SLF001
+        batch_bytes = to_ipc(to_sf(cur._arrow_table, rowtype))  # noqa: SLF001
+        rowset_b64 = b64encode(batch_bytes).decode("utf-8")
+    else:
+        rowset_b64 = ""
 
     return JSONResponse(
         {
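
A hedged connection sketch for the new `FAKESNOW_DB_PATH` session parameter; host, port, and credentials are illustrative and assume a fakesnow server running locally.

```python
import snowflake.connector

# ":isolated:" asks the server for a fresh in-memory database for this
# connection; any other value is treated as a path to an on-disk database.
conn = snowflake.connector.connect(
    user="fake", password="fake", account="fake",
    host="localhost", port=8000, protocol="http",
    session_parameters={"FAKESNOW_DB_PATH": ":isolated:"},
)
```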
fakesnow/transforms.py CHANGED
@@ -7,6 +7,7 @@ from typing import ClassVar, Literal, cast
 import sqlglot
 from sqlglot import exp
 
+from fakesnow import transforms_merge
 from fakesnow.instance import USERS_TABLE_FQ_NAME
 from fakesnow.variables import Variables
 
@@ -36,7 +37,7 @@ def alias_in_join(expression: exp.Expression) -> exp.Expression:
 def alter_table_strip_cluster_by(expression: exp.Expression) -> exp.Expression:
     """Turn alter table cluster by into a no-op"""
     if (
-        isinstance(expression, exp.AlterTable)
+        isinstance(expression, exp.Alter)
         and (actions := expression.args.get("actions"))
         and len(actions) == 1
         and (isinstance(actions[0], exp.Cluster))
@@ -355,7 +356,7 @@ def extract_comment_on_columns(expression: exp.Expression) -> exp.Expression:
         exp.Expression: The transformed expression, with any comment stored in the new 'table_comment' arg.
     """
 
-    if isinstance(expression, exp.AlterTable) and (actions := expression.args.get("actions")):
+    if isinstance(expression, exp.Alter) and (actions := expression.args.get("actions")):
         new_actions: list[exp.Expression] = []
         col_comments: list[tuple[str, str]] = []
         for a in actions:
@@ -409,7 +410,7 @@ def extract_comment_on_table(expression: exp.Expression) -> exp.Expression:
             new.args["table_comment"] = (table, cexp.this)
             return new
     elif (
-        isinstance(expression, exp.AlterTable)
+        isinstance(expression, exp.Alter)
        and (sexp := expression.find(exp.AlterSet))
        and (scp := sexp.find(exp.SchemaCommentProperty))
        and isinstance(scp.this, exp.Literal)
@@ -435,7 +436,7 @@ def extract_text_length(expression: exp.Expression) -> exp.Expression:
         exp.Expression: The original expression, with any text lengths stored in the new 'text_lengths' arg.
     """
 
-    if isinstance(expression, (exp.Create, exp.AlterTable)):
+    if isinstance(expression, (exp.Create, exp.Alter)):
         text_lengths = []
 
         # exp.Select is for a ctas, exp.Schema is a plain definition
@@ -470,7 +471,6 @@ def flatten(expression: exp.Expression) -> exp.Expression:
 
     See https://docs.snowflake.com/en/sql-reference/functions/flatten
 
-    TODO: return index.
     TODO: support objects.
     """
     if (
@@ -482,20 +482,34 @@
     ):
         explode_expression = expression.this.this.expression
 
-        return exp.Lateral(
-            this=exp.Unnest(
+        value = exp.Cast(
+            this=explode_expression,
+            to=exp.DataType(
+                this=exp.DataType.Type.ARRAY,
+                expressions=[exp.DataType(this=exp.DataType.Type.JSON, nested=False, prefix=False)],
+                nested=True,
+            ),
+        )
+
+        return exp.Subquery(
+            this=exp.Select(
                 expressions=[
-                    exp.Cast(
-                        this=explode_expression,
-                        to=exp.DataType(
-                            this=exp.DataType.Type.ARRAY,
-                            expressions=[exp.DataType(this=exp.DataType.Type.JSON, nested=False, prefix=False)],
-                            nested=True,
+                    exp.Unnest(
+                        expressions=[value],
+                        alias=exp.Identifier(this="VALUE", quoted=False),
+                    ),
+                    exp.Alias(
+                        this=exp.Sub(
+                            this=exp.Anonymous(
+                                this="generate_subscripts", expressions=[value, exp.Literal(this="1", is_string=False)]
+                            ),
+                            expression=exp.Literal(this="1", is_string=False),
                         ),
-                    )
+                        alias=exp.Identifier(this="INDEX", quoted=False),
+                    ),
                 ],
             ),
-            alias=exp.TableAlias(this=alias.this, columns=[exp.Identifier(this="VALUE", quoted=False)]),
+            alias=exp.TableAlias(this=alias.this),
         )
 
     return expression
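
An editor's sketch of the rewrite's intent, assuming the transform matches a lateral FLATTEN of this shape: FLATTEN now exposes a zero-based INDEX alongside VALUE, built from duckdb's generate_subscripts.

```python
import sqlglot

from fakesnow import transforms

expr = sqlglot.parse_one(
    "SELECT f.value, f.index FROM LATERAL FLATTEN(input => col1) AS f",
    read="snowflake",
)
# expected shape, roughly:
#   SELECT f.value, f.index FROM (
#     SELECT UNNEST(CAST(col1 AS JSON[])) AS VALUE,
#            GENERATE_SUBSCRIPTS(CAST(col1 AS JSON[]), 1) - 1 AS INDEX
#   ) AS f
print(expr.transform(transforms.flatten).sql(dialect="duckdb"))
```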
@@ -621,6 +635,20 @@ def information_schema_fs_tables_ext(expression: exp.Expression) -> exp.Expressi
     return expression
 
 
+def information_schema_fs_views(expression: exp.Expression) -> exp.Expression:
+    """Use information_schema._fs_views to return Snowflake's version instead of duckdb's."""
+
+    if (
+        isinstance(expression, exp.Select)
+        and (tbl_exp := expression.find(exp.Table))
+        and tbl_exp.name.upper() == "VIEWS"
+        and tbl_exp.db.upper() == "INFORMATION_SCHEMA"
+    ):
+        tbl_exp.set("this", exp.Identifier(this="_FS_VIEWS", quoted=False))
+
+    return expression
+
+
 def integer_precision(expression: exp.Expression) -> exp.Expression:
     """Convert integers to bigint.
 
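
A small sketch of the new rewrite, pairing with the `_fs_views` rename in info_schema.py above:

```python
import sqlglot

from fakesnow import transforms

expr = sqlglot.parse_one("SELECT table_name FROM information_schema.views", read="snowflake")
print(transforms.information_schema_fs_views(expr).sql(dialect="duckdb"))
# SELECT table_name FROM information_schema._FS_VIEWS
```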
@@ -691,6 +719,10 @@ def json_extract_precedence(expression: exp.Expression) -> exp.Expression:
     return expression
 
 
+def merge(expression: exp.Expression) -> list[exp.Expression]:
+    return transforms_merge.merge(expression)
+
+
 def random(expression: exp.Expression) -> exp.Expression:
     """Convert random() and random(seed).
 
@@ -702,8 +734,8 @@ def random(expression: exp.Expression) -> exp.Expression:
     new_rand = exp.Cast(
         this=exp.Paren(
             this=exp.Mul(
-                this=exp.Paren(this=exp.Sub(this=exp.Rand(), expression=exp.Literal(this=0.5, is_string=False))),
-                expression=exp.Literal(this=9223372036854775807, is_string=False),
+                this=exp.Paren(this=exp.Sub(this=exp.Rand(), expression=exp.Literal(this="0.5", is_string=False))),
+                expression=exp.Literal(this="9223372036854775807", is_string=False),
             )
         ),
         to=exp.DataType(this=exp.DataType.Type.BIGINT, nested=False, prefix=False),
@@ -804,31 +836,24 @@ def regex_substr(expression: exp.Expression) -> exp.Expression:
         pattern.args["this"] = pattern.this.replace("\\\\", "\\")
 
     # number of characters from the beginning of the string where the function starts searching for matches
-    try:
-        position = expression.args["position"]
-    except KeyError:
-        position = exp.Literal(this="1", is_string=False)
+    position = expression.args["position"] or exp.Literal(this="1", is_string=False)
 
     # which occurrence of the pattern to match
-    try:
-        occurrence = int(expression.args["occurrence"].this)
-    except KeyError:
-        occurrence = 1
+    occurrence = expression.args["occurrence"]
+    occurrence = int(occurrence.this) if occurrence else 1
 
     # the duckdb dialect increments bracket (ie: index) expressions by 1 because duckdb is 1-indexed,
     # so we need to compensate by subtracting 1
     occurrence = exp.Literal(this=str(occurrence - 1), is_string=False)
 
-    try:
-        regex_parameters_value = str(expression.args["parameters"].this)
+    if parameters := expression.args["parameters"]:
         # 'e' parameter doesn't make sense for duckdb
-        regex_parameters = exp.Literal(this=regex_parameters_value.replace("e", ""), is_string=True)
-    except KeyError:
+        regex_parameters = exp.Literal(this=parameters.this.replace("e", ""), is_string=True)
+    else:
         regex_parameters = exp.Literal(is_string=True)
 
-    try:
-        group_num = expression.args["group"]
-    except KeyError:
+    group_num = expression.args["group"]
+    if not group_num:
         if isinstance(regex_parameters.this, str) and "e" in regex_parameters.this:
             group_num = exp.Literal(this="1", is_string=False)
         else:
@@ -1018,7 +1043,7 @@ def tag(expression: exp.Expression) -> exp.Expression:
         exp.Expression: The transformed expression.
     """
 
-    if isinstance(expression, exp.AlterTable) and (actions := expression.args.get("actions")):
+    if isinstance(expression, exp.Alter) and (actions := expression.args.get("actions")):
         for a in actions:
             if isinstance(a, exp.AlterSet) and a.args.get("tag"):
                 return SUCCESS_NOP
fakesnow/transforms_merge.py ADDED
@@ -0,0 +1,203 @@
+import sqlglot
+from sqlglot import exp
+
+from fakesnow import checks
+
+# Implements snowflake's MERGE INTO functionality in duckdb (https://docs.snowflake.com/en/sql-reference/sql/merge).
+
+
+def merge(merge_expr: exp.Expression) -> list[exp.Expression]:
+    if not isinstance(merge_expr, exp.Merge):
+        return [merge_expr]
+
+    return [_create_merge_candidates(merge_expr), *_mutations(merge_expr), _counts(merge_expr)]
+
+
+def _create_merge_candidates(merge_expr: exp.Merge) -> exp.Expression:
+    """
+    Given a merge statement, produce a temporary table that joins together the target and source tables.
+    The merge_op column identifies which merge clause applies to the row.
+    """
+    target_tbl = merge_expr.this
+
+    source = merge_expr.args.get("using")
+    assert isinstance(source, exp.Expression)
+    source_id = (alias := source.args.get("alias")) and alias.this if isinstance(source, exp.Subquery) else source.this
+    assert isinstance(source_id, exp.Identifier)
+
+    join_expr = merge_expr.args.get("on")
+    assert isinstance(join_expr, exp.Binary)
+
+    case_when_clauses: list[str] = []
+    values: set[str] = set()
+
+    # extract keys that reference the source table from the join expression
+    # so they can be used by the mutation statements for joining
+    # will include the source table identifier
+    values.update(
+        map(
+            str,
+            {
+                c
+                for c in join_expr.find_all(exp.Column)
+                if (table := c.args.get("table"))
+                and isinstance(table, exp.Identifier)
+                and checks.equal(table, source_id)
+            },
+        )
+    )
+
+    # Iterate through the WHEN clauses to build up the CASE WHEN clauses
+    for w_idx, w in enumerate(merge_expr.expressions):
+        assert isinstance(w, exp.When), f"Expected When expression, got {w}"
+
+        predicate = join_expr.copy()
+        matched = w.args.get("matched")
+        then = w.args.get("then")
+        condition = w.args.get("condition")
+
+        if matched:
+            # matchedClause see https://docs.snowflake.com/en/sql-reference/sql/merge#matchedclause-for-updates-or-deletes
+            if condition:
+                # Combine the top level ON expression with the AND condition
+                # from this specific WHEN into a subquery, which we use to target rows.
+                # Eg. MERGE INTO t1 USING t2 ON t1.t1Key = t2.t2Key
+                #     WHEN MATCHED AND t2.marked = 1 THEN DELETE
+                predicate = exp.And(this=predicate, expression=condition)
+
+            if isinstance(then, exp.Update):
+                case_when_clauses.append(f"WHEN {predicate} THEN {w_idx}")
+                values.update([str(c.expression) for c in then.expressions if isinstance(c.expression, exp.Column)])
+            elif isinstance(then, exp.Var) and then.args.get("this") == "DELETE":
+                case_when_clauses.append(f"WHEN {predicate} THEN {w_idx}")
+            else:
+                raise AssertionError(f"Expected 'Update' or 'Delete', got {then}")
+        else:
+            # notMatchedClause see https://docs.snowflake.com/en/sql-reference/sql/merge#notmatchedclause-for-inserts
+            assert isinstance(then, exp.Insert), f"Expected 'Insert', got {then}"
+            insert_values = then.expression.expressions
+            values.update([str(c) for c in insert_values if isinstance(c, exp.Column)])
+            predicate = f"AND {condition}" if condition else ""
+            case_when_clauses.append(f"WHEN {target_tbl}.rowid is NULL {predicate} THEN {w_idx}")
+
+    sql = f"""
+    CREATE OR REPLACE TEMPORARY TABLE merge_candidates AS
+    SELECT
+        {', '.join(sorted(values))},
+        CASE
+            {' '.join(case_when_clauses)}
+            ELSE NULL
+        END AS MERGE_OP
+    FROM {target_tbl}
+    FULL OUTER JOIN {source} ON {join_expr.sql()}
+    WHERE MERGE_OP IS NOT NULL
+    """
+
+    return sqlglot.parse_one(sql)
+
+
+def _mutations(merge_expr: exp.Merge) -> list[exp.Expression]:
+    """
+    Given a merge statement, produce a list of delete, update and insert statements that use the
+    merge_candidates and source table to update the target table.
+    """
+    target_tbl = merge_expr.this
+    source = merge_expr.args.get("using")
+    source_tbl = source.alias if isinstance(source, exp.Subquery) else source
+    join_expr = merge_expr.args.get("on")
+
+    statements: list[exp.Expression] = []
+
+    # Iterate through the WHEN clauses to generate delete/update/insert statements
+    for w_idx, w in enumerate(merge_expr.expressions):
+        assert isinstance(w, exp.When), f"Expected When expression, got {w}"
+
+        matched = w.args.get("matched")
+        then = w.args.get("then")
+
+        if matched:
+            if isinstance(then, exp.Var) and then.args.get("this") == "DELETE":
+                delete_sql = f"""
+                DELETE FROM {target_tbl}
+                USING merge_candidates AS {source_tbl}
+                WHERE {join_expr}
+                AND {source_tbl}.merge_op = {w_idx}
+                """
+                statements.append(sqlglot.parse_one(delete_sql))
+            elif isinstance(then, exp.Update):
+                # when the update statement has a table alias, duckdb doesn't support the alias in the set
+                # column name, so we use e.this.this to get just the column name without its table prefix
+                set_clauses = ", ".join(
+                    [f"{e.this.this} = {e.expression.sql()}" for e in then.args.get("expressions", [])]
+                )
+                update_sql = f"""
+                UPDATE {target_tbl}
+                SET {set_clauses}
+                FROM merge_candidates AS {source_tbl}
+                WHERE {join_expr}
+                AND {source_tbl}.merge_op = {w_idx}
+                """
+                statements.append(sqlglot.parse_one(update_sql))
+            else:
+                raise AssertionError(f"Expected 'Update' or 'Delete', got {then}")
+        else:
+            assert isinstance(then, exp.Insert), f"Expected 'Insert', got {then}"
+            cols = [str(c) for c in then.this.expressions] if then.this else []
+            columns = f"({', '.join(cols)})" if cols else ""
+            values = ", ".join(map(str, then.expression.expressions))
+            insert_sql = f"""
+            INSERT INTO {target_tbl} {columns}
+            SELECT {values}
+            FROM merge_candidates AS {source_tbl}
+            WHERE {source_tbl}.merge_op = {w_idx}
+            """
+            statements.append(sqlglot.parse_one(insert_sql))
+
+    return statements
+
+
+def _counts(merge_expr: exp.Merge) -> exp.Expression:
+    """
+    Given a merge statement, derive a SQL statement which produces the following columns using the merge_candidates
+    table:
+
+    - "number of rows inserted"
+    - "number of rows updated"
+    - "number of rows deleted"
+
+    Only columns relevant to the merge operation are included, eg: if no rows are deleted, the "number of rows deleted"
+    column is not included.
+    """
+
+    # Initialize dictionaries to store operation types and their corresponding indices
+    operations = {"inserted": [], "updated": [], "deleted": []}
+
+    # Iterate through the WHEN clauses to categorize operations
+    for w_idx, w in enumerate(merge_expr.expressions):
+        assert isinstance(w, exp.When), f"Expected When expression, got {w}"
+
+        matched = w.args.get("matched")
+        then = w.args.get("then")
+
+        if matched:
+            if isinstance(then, exp.Update):
+                operations["updated"].append(w_idx)
+            elif isinstance(then, exp.Var) and then.args.get("this") == "DELETE":
+                operations["deleted"].append(w_idx)
+            else:
+                raise AssertionError(f"Expected 'Update' or 'Delete', got {then}")
+        else:
+            assert isinstance(then, exp.Insert), f"Expected 'Insert', got {then}"
+            operations["inserted"].append(w_idx)
+
+    count_statements = [
+        f"""COUNT_IF(merge_op in ({','.join(map(str, indices))})) as \"number of rows {op}\""""
+        for op, indices in operations.items()
+        if indices
+    ]
+    sql = f"""
+    SELECT {', '.join(count_statements)}
+    FROM merge_candidates
+    """
+
+    return sqlglot.parse_one(sql)
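
How the pieces fit together, as an editor's sketch: one MERGE becomes a merge_candidates temp table, one mutation statement per WHEN clause, and a final counts query.

```python
import sqlglot

from fakesnow import transforms_merge

merge_expr = sqlglot.parse_one(
    """
    MERGE INTO t1 USING t2 ON t1.t1key = t2.t2key
    WHEN MATCHED THEN UPDATE SET val = t2.newval
    WHEN NOT MATCHED THEN INSERT (t1key, val) VALUES (t2.t2key, t2.newval)
    """,
    read="snowflake",
)
for statement in transforms_merge.merge(merge_expr):
    print(statement.sql(dialect="duckdb"), end="\n\n")
# roughly: CREATE OR REPLACE TEMPORARY TABLE merge_candidates AS ...;
# an UPDATE and an INSERT joined against merge_candidates;
# then SELECT COUNT_IF(...) AS "number of rows updated", ... FROM merge_candidates
```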
fakesnow-0.9.24.dist-info/METADATA → fakesnow-0.9.26.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: fakesnow
-Version: 0.9.24
+Version: 0.9.26
 Summary: Fake Snowflake Connector for Python. Run, mock and test Snowflake DB locally.
 License: Apache License
                                  Version 2.0, January 2004
@@ -210,22 +210,22 @@ Classifier: License :: OSI Approved :: MIT License
 Requires-Python: >=3.9
 Description-Content-Type: text/markdown
 License-File: LICENSE
-Requires-Dist: duckdb~=1.0.0
+Requires-Dist: duckdb~=1.1.3
 Requires-Dist: pyarrow
 Requires-Dist: snowflake-connector-python
-Requires-Dist: sqlglot~=25.9.0
+Requires-Dist: sqlglot~=25.24.1
 Provides-Extra: dev
 Requires-Dist: build~=1.0; extra == "dev"
 Requires-Dist: dirty-equals; extra == "dev"
 Requires-Dist: pandas-stubs; extra == "dev"
 Requires-Dist: snowflake-connector-python[pandas,secure-local-storage]; extra == "dev"
-Requires-Dist: pre-commit~=3.4; extra == "dev"
-Requires-Dist: pyarrow-stubs; extra == "dev"
+Requires-Dist: pre-commit~=4.0; extra == "dev"
+Requires-Dist: pyarrow-stubs==10.0.1.9; extra == "dev"
 Requires-Dist: pytest~=8.0; extra == "dev"
 Requires-Dist: pytest-asyncio; extra == "dev"
-Requires-Dist: ruff~=0.5.1; extra == "dev"
+Requires-Dist: ruff~=0.7.2; extra == "dev"
 Requires-Dist: twine~=5.0; extra == "dev"
-Requires-Dist: snowflake-sqlalchemy~=1.5.0; extra == "dev"
+Requires-Dist: snowflake-sqlalchemy~=1.6.1; extra == "dev"
 Provides-Extra: notebook
 Requires-Dist: duckdb-engine; extra == "notebook"
 Requires-Dist: ipykernel; extra == "notebook"
fakesnow-0.9.26.dist-info/RECORD ADDED
@@ -0,0 +1,26 @@
+fakesnow/__init__.py,sha256=qUfgucQYPdELrJaxczalhJgWAWQ6cfTCUAHx6nUqRaI,3528
+fakesnow/__main__.py,sha256=GDrGyNTvBFuqn_UfDjKs7b3LPtU6gDv1KwosVDrukIM,76
+fakesnow/arrow.py,sha256=EGAYeuCnRuvmWBEGqw2YOcgQR4zcCsZBu85kSRl70dQ,4698
+fakesnow/checks.py,sha256=N8sXldhS3u1gG32qvZ4VFlsKgavRKrQrxLiQU8am1lw,2691
+fakesnow/cli.py,sha256=9qfI-Ssr6mo8UmIlXkUAOz2z2YPBgDsrEVaZv9FjGFs,2201
+fakesnow/conn.py,sha256=Gy_Z7BZRm5yMjV3x6hR4iegDQFdG9aJBjqWdc3iWYFU,5353
+fakesnow/cursor.py,sha256=8wWtRCxzrM1yiHmH2C-9CT0b98nTzr23ygeaEAkumRE,20086
+fakesnow/expr.py,sha256=CAxuYIUkwI339DQIBzvFF0F-m1tcVGKEPA5rDTzmH9A,892
+fakesnow/fakes.py,sha256=JQTiUkkwPeQrJ8FDWhPFPK6pGwd_aR2oiOrNzCWznlM,187
+fakesnow/fixtures.py,sha256=G-NkVeruSQAJ7fvSS2fR2oysUn0Yra1pohHlOvacKEk,455
+fakesnow/info_schema.py,sha256=nsDceFtjiSXrvkksKziVvqrefskaSyOmAspBwMAsaDg,6307
+fakesnow/instance.py,sha256=3cJvPRuFy19dMKXbtBLl6imzO48pEw8uTYhZyFDuwhk,3133
+fakesnow/macros.py,sha256=pX1YJDnQOkFJSHYUjQ6ErEkYIKvFI6Ncz_au0vv1csA,265
+fakesnow/pandas_tools.py,sha256=WjyjTV8QUCQQaCGboaEOvx2uo4BkknpWYjtLwkeCY6U,3468
+fakesnow/py.typed,sha256=B-DLSjYBi7pkKjwxCSdpVj2J02wgfJr-E7B1wOUyxYU,80
+fakesnow/server.py,sha256=SO5xKZ4rvySsuKDsoSPSCZcFuIX_K7d1XJYhRRJ-7Bk,4150
+fakesnow/transforms.py,sha256=VFLA5Fc1i4FuiVdvUuDrK-kA2caqiT8Gw9btMDPJhRA,55367
+fakesnow/transforms_merge.py,sha256=7rq-UPjfFNRrFsqR8xx3otwP6-k4eslLVLhfuqSXq1A,8314
+fakesnow/types.py,sha256=9Tt83Z7ctc9_v6SYyayXYz4MEI4RZo4zq_uqdj4g3Dk,2681
+fakesnow/variables.py,sha256=WXyPnkeNwD08gy52yF66CVe2twiYC50tztNfgXV4q1k,3032
+fakesnow-0.9.26.dist-info/LICENSE,sha256=kW-7NWIyaRMQiDpryfSmF2DObDZHGR1cJZ39s6B1Svg,11344
+fakesnow-0.9.26.dist-info/METADATA,sha256=92zIwzq7FP-BrfhUcKbdbqYs0eqN9TCKvT_NVdEKZTI,18075
+fakesnow-0.9.26.dist-info/WHEEL,sha256=bFJAMchF8aTQGUgMZzHJyDDMPTO3ToJ7x23SLJa1SVo,92
+fakesnow-0.9.26.dist-info/entry_points.txt,sha256=2riAUgu928ZIHawtO8EsfrMEJhi-EH-z_Vq7Q44xKPM,47
+fakesnow-0.9.26.dist-info/top_level.txt,sha256=500evXI1IFX9so82cizGIEMHAb_dJNPaZvd2H9dcKTA,24
+fakesnow-0.9.26.dist-info/RECORD,,
fakesnow-0.9.24.dist-info/WHEEL → fakesnow-0.9.26.dist-info/WHEEL CHANGED
@@ -1,5 +1,5 @@
 Wheel-Version: 1.0
-Generator: bdist_wheel (0.44.0)
+Generator: bdist_wheel (0.45.0)
 Root-Is-Purelib: true
 Tag: py3-none-any
 
@@ -1,25 +0,0 @@
1
- fakesnow/__init__.py,sha256=9tFJJKvowKNW3vfnlmza6hOLN1I52DwChgNc5Ew6CcA,3499
2
- fakesnow/__main__.py,sha256=GDrGyNTvBFuqn_UfDjKs7b3LPtU6gDv1KwosVDrukIM,76
3
- fakesnow/arrow.py,sha256=WLkr1nEiNxUcPdzadKSM33sRAiQJsN6LvuzTVIsi3D0,2766
4
- fakesnow/checks.py,sha256=-QMvdcrRbhN60rnzxLBJ0IkUBWyLR8gGGKKmCS0w9mA,2383
5
- fakesnow/cli.py,sha256=9qfI-Ssr6mo8UmIlXkUAOz2z2YPBgDsrEVaZv9FjGFs,2201
6
- fakesnow/conn.py,sha256=Gy_Z7BZRm5yMjV3x6hR4iegDQFdG9aJBjqWdc3iWYFU,5353
7
- fakesnow/cursor.py,sha256=2PtW9hzfXs3mzv6BBxXLoS-pPtD4otrfQ2KnPNNanGI,19441
8
- fakesnow/expr.py,sha256=CAxuYIUkwI339DQIBzvFF0F-m1tcVGKEPA5rDTzmH9A,892
9
- fakesnow/fakes.py,sha256=JQTiUkkwPeQrJ8FDWhPFPK6pGwd_aR2oiOrNzCWznlM,187
10
- fakesnow/fixtures.py,sha256=G-NkVeruSQAJ7fvSS2fR2oysUn0Yra1pohHlOvacKEk,455
11
- fakesnow/info_schema.py,sha256=DObVOrhzppAFHsdtj4YI9oRISn9SkJUG6ONjVleQQ_Y,6303
12
- fakesnow/instance.py,sha256=3cJvPRuFy19dMKXbtBLl6imzO48pEw8uTYhZyFDuwhk,3133
13
- fakesnow/macros.py,sha256=pX1YJDnQOkFJSHYUjQ6ErEkYIKvFI6Ncz_au0vv1csA,265
14
- fakesnow/pandas_tools.py,sha256=WjyjTV8QUCQQaCGboaEOvx2uo4BkknpWYjtLwkeCY6U,3468
15
- fakesnow/py.typed,sha256=B-DLSjYBi7pkKjwxCSdpVj2J02wgfJr-E7B1wOUyxYU,80
16
- fakesnow/server.py,sha256=8dzaLUUXPzCMm6-ESn0CBws6XSwwOpnUuHQAZJ-4SwU,3011
17
- fakesnow/transforms.py,sha256=ellcY5OBc7mqgL9ChNolrqcCLWXF9RH21Jt88FcFl-I,54419
18
- fakesnow/types.py,sha256=9Tt83Z7ctc9_v6SYyayXYz4MEI4RZo4zq_uqdj4g3Dk,2681
19
- fakesnow/variables.py,sha256=WXyPnkeNwD08gy52yF66CVe2twiYC50tztNfgXV4q1k,3032
20
- fakesnow-0.9.24.dist-info/LICENSE,sha256=kW-7NWIyaRMQiDpryfSmF2DObDZHGR1cJZ39s6B1Svg,11344
21
- fakesnow-0.9.24.dist-info/METADATA,sha256=LHKc6JYn9sxxFh6_i7kqlWz1fmloFv2CCmpalwPVFrE,18064
22
- fakesnow-0.9.24.dist-info/WHEEL,sha256=eOLhNAGa2EW3wWl_TU484h7q1UNgy0JXjjoqKoxAAQc,92
23
- fakesnow-0.9.24.dist-info/entry_points.txt,sha256=2riAUgu928ZIHawtO8EsfrMEJhi-EH-z_Vq7Q44xKPM,47
24
- fakesnow-0.9.24.dist-info/top_level.txt,sha256=500evXI1IFX9so82cizGIEMHAb_dJNPaZvd2H9dcKTA,24
25
- fakesnow-0.9.24.dist-info/RECORD,,