fakesnow 0.9.23__py3-none-any.whl → 0.9.25__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
fakesnow/__init__.py CHANGED
@@ -90,3 +90,4 @@ def patch(
         yield None
     finally:
         stack.close()
+        fs.duck_conn.close()
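
For context, patch() is used as a context manager, and with this release the shared duckdb connection is also closed when the patch exits. A minimal sketch of typical usage (mirroring the project README):

    import fakesnow
    import snowflake.connector

    with fakesnow.patch():
        conn = snowflake.connector.connect()
        print(conn.cursor().execute("SELECT 'hello fake world'").fetchone())
    # on exit the patches are removed and, as of this release, the
    # underlying duckdb connection is closed too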
fakesnow/arrow.py CHANGED
@@ -1,33 +1,45 @@
-from typing import Any
+from __future__ import annotations
+
+from typing import cast
 
 import pyarrow as pa
+import pyarrow.compute as pc
+
+from fakesnow.types import ColumnInfo
+
+
+def to_sf_schema(schema: pa.Schema, rowtype: list[ColumnInfo]) -> pa.Schema:
+    # expected by the snowflake connector
+    # uses rowtype to populate metadata, rather than the arrow schema type, for consistency with
+    # the rowtype returned in the response
 
+    assert len(schema) == len(rowtype), f"schema and rowtype must be same length but {len(schema)=} {len(rowtype)=}"
 
-def with_sf_metadata(schema: pa.Schema) -> pa.Schema:
     # see https://github.com/snowflakedb/snowflake-connector-python/blob/e9393a6/src/snowflake/connector/nanoarrow_cpp/ArrowIterator/CArrowTableIterator.cpp#L32
     # and https://github.com/snowflakedb/snowflake-connector-python/blob/e9393a6/src/snowflake/connector/nanoarrow_cpp/ArrowIterator/SnowflakeType.cpp#L10
-    fms = []
-    for i, t in enumerate(schema.types):
-        f = schema.field(i)
-
-        # TODO: precision, scale, charLength etc. for all types
-
-        if t == pa.bool_():
-            fm = f.with_metadata({"logicalType": "BOOLEAN"})
-        elif t == pa.int64():
-            # scale and precision required, see
-            # https://github.com/snowflakedb/snowflake-connector-python/blob/416ff57/src/snowflake/connector/nanoarrow_cpp/ArrowIterator/CArrowChunkIterator.cpp#L147
-            fm = f.with_metadata({"logicalType": "FIXED", "precision": "38", "scale": "0"})
-        elif t == pa.float64():
-            fm = f.with_metadata({"logicalType": "REAL"})
-        elif isinstance(t, pa.Decimal128Type):
-            fm = f.with_metadata({"logicalType": "FIXED", "precision": str(t.precision), "scale": str(t.scale)})
-        elif t == pa.string():
-            # TODO: set charLength to size of column
-            fm = f.with_metadata({"logicalType": "TEXT", "charLength": "16777216"})
-        else:
-            raise NotImplementedError(f"Unsupported Arrow type: {t}")
-        fms.append(fm)
+
+    def sf_field(field: pa.Field, c: ColumnInfo) -> pa.Field:
+        if isinstance(field.type, pa.TimestampType):
+            # snowflake uses a struct to represent timestamps, see timestamp_to_sf_struct
+            fields = [pa.field("epoch", pa.int64(), nullable=False), pa.field("fraction", pa.int32(), nullable=False)]
+            if field.type.tz:
+                fields.append(pa.field("timezone", nullable=False, type=pa.int32()))
+            field = field.with_type(pa.struct(fields))
+        elif isinstance(field.type, pa.Time64Type):
+            field = field.with_type(pa.int64())
+
+        return field.with_metadata(
+            {
+                "logicalType": c["type"].upper(),
+                # precision and scale are required for the FIXED type, see
+                # https://github.com/snowflakedb/snowflake-connector-python/blob/416ff57/src/snowflake/connector/nanoarrow_cpp/ArrowIterator/CArrowChunkIterator.cpp#L147
+                "precision": str(c["precision"] or 38),
+                "scale": str(c["scale"] or 0),
+                "charLength": str(c["length"] or 0),
+            }
+        )
+
+    fms = [sf_field(schema.field(i), c) for i, c in enumerate(rowtype)]
     return pa.schema(fms)
 
 
@@ -39,29 +51,56 @@ def to_ipc(table: pa.Table) -> pa.Buffer:
 
     sink = pa.BufferOutputStream()
 
-    with pa.ipc.new_stream(sink, with_sf_metadata(table.schema)) as writer:
+    with pa.ipc.new_stream(sink, table.schema) as writer:
         writer.write_batch(batch)
 
     return sink.getvalue()
 
 
-# TODO: should this be derived before with_schema?
-def to_rowtype(schema: pa.Schema) -> list[dict[str, Any]]:
-    return [
-        {
-            "name": f.name,
-            # TODO
-            # "database": "",
-            # "schema": "",
-            # "table": "",
-            "nullable": f.nullable,
-            "type": f.metadata.get(b"logicalType").decode("utf-8").lower(),  # type: ignore
-            # TODO
-            # "byteLength": 20,
-            "length": int(f.metadata.get(b"charLength")) if f.metadata.get(b"charLength") else None,  # type: ignore
-            "scale": int(f.metadata.get(b"scale")) if f.metadata.get(b"scale") else None,  # type: ignore
-            "precision": int(f.metadata.get(b"precision")) if f.metadata.get(b"precision") else None,  # type: ignore
-            "collation": None,
-        }
-        for f in schema
-    ]
+def to_sf(table: pa.Table, rowtype: list[ColumnInfo]) -> pa.Table:
+    def to_sf_col(col: pa.Array) -> pa.Array:
+        if pa.types.is_timestamp(col.type):
+            return timestamp_to_sf_struct(col)
+        elif pa.types.is_time(col.type):
+            # as nanoseconds
+            return pc.multiply(col.cast(pa.int64()), 1000)  # type: ignore https://github.com/zen-xu/pyarrow-stubs/issues/44
+        return col
+
+    return pa.Table.from_arrays([to_sf_col(c) for c in table.columns], schema=to_sf_schema(table.schema, rowtype))
+
+
+def timestamp_to_sf_struct(ts: pa.Array | pa.ChunkedArray) -> pa.Array:
+    if isinstance(ts, pa.ChunkedArray):
+        # combine because pa.StructArray.from_arrays doesn't support ChunkedArray
+        ts = cast(pa.Array, ts.combine_chunks())  # see https://github.com/zen-xu/pyarrow-stubs/issues/46
+
+    if not isinstance(ts.type, pa.TimestampType):
+        raise ValueError(f"Expected TimestampArray, got {type(ts)}")
+
+    # round to seconds, ie: strip the subseconds
+    tsa_without_us = pc.floor_temporal(ts, unit="second")  # type: ignore https://github.com/zen-xu/pyarrow-stubs/issues/45
+    epoch = pc.divide(tsa_without_us.cast(pa.int64()), 1_000_000)  # type: ignore https://github.com/zen-xu/pyarrow-stubs/issues/44
+
+    # calculate the fractional part as nanoseconds
+    fraction = pc.multiply(pc.subsecond(ts), 1_000_000_000).cast(pa.int32())  # type: ignore
+
+    if ts.type.tz:
+        assert ts.type.tz == "UTC", f"Timezone {ts.type.tz} not yet supported"
+        timezone = pa.array([1440] * len(ts), type=pa.int32())
+
+        return pa.StructArray.from_arrays(
+            arrays=[epoch, fraction, timezone],  # type: ignore https://github.com/zen-xu/pyarrow-stubs/issues/42
+            fields=[
+                pa.field("epoch", nullable=False, type=pa.int64()),
+                pa.field("fraction", nullable=False, type=pa.int32()),
+                pa.field("timezone", nullable=False, type=pa.int32()),
+            ],
+        )
+    else:
+        return pa.StructArray.from_arrays(
+            arrays=[epoch, fraction],  # type: ignore https://github.com/zen-xu/pyarrow-stubs/issues/42
+            fields=[
+                pa.field("epoch", nullable=False, type=pa.int64()),
+                pa.field("fraction", nullable=False, type=pa.int32()),
+            ],
+        )
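
To make the struct encoding above concrete, a small sketch using the new function (the expected output is derived by reading the code, not from running Snowflake itself):

    import pyarrow as pa

    from fakesnow.arrow import timestamp_to_sf_struct

    # 2024-01-02 03:04:05.123456 UTC, expressed in microseconds
    ts = pa.array([1704164645123456], type=pa.timestamp("us", tz="UTC"))
    print(timestamp_to_sf_struct(ts))
    # one struct row: epoch=1704164645 (whole seconds since the Unix epoch),
    # fraction=123456000 (subseconds as nanoseconds) and timezone=1440
    # (the fixed UTC encoding used above)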
fakesnow/checks.py CHANGED
@@ -68,3 +68,11 @@ def is_unqualified_table_expression(expression: exp.Expression) -> tuple[bool, b
     no_schema = not node.args.get("db")
 
     return no_database, no_schema
+
+
+def equal(left: exp.Identifier, right: exp.Identifier) -> bool:
+    # as per https://docs.snowflake.com/en/sql-reference/identifiers-syntax#label-identifier-casing
+    lid = left.this if left.quoted else left.this.upper()
+    rid = right.this if right.quoted else right.this.upper()
+
+    return lid == rid
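
A quick sketch of the new helper's casing behaviour, per the Snowflake identifier rules linked above:

    from sqlglot import exp

    from fakesnow.checks import equal

    # unquoted identifiers fold to upper case, so t2 == T2
    assert equal(exp.Identifier(this="t2", quoted=False), exp.Identifier(this="T2", quoted=False))
    # quoted identifiers are case-sensitive, so "t2" != "T2"
    assert not equal(exp.Identifier(this="t2", quoted=True), exp.Identifier(this="T2", quoted=True))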
fakesnow/conn.py CHANGED
@@ -1,13 +1,11 @@
 from __future__ import annotations
 
-import json
 import os
 from collections.abc import Iterable
 from pathlib import Path
 from types import TracebackType
 from typing import Any
 
-import pandas as pd
 import snowflake.connector.converter
 import snowflake.connector.errors
 import sqlglot
@@ -147,29 +145,3 @@ class FakeSnowflakeConnection:
 
     def rollback(self) -> None:
         self.cursor().execute("ROLLBACK")
-
-    def _insert_df(self, df: pd.DataFrame, table_name: str) -> int:
-        # Objects in dataframes are written as parquet structs, and snowflake loads parquet structs as json strings,
-        # whereas duckdb analyses a dataframe (see https://duckdb.org/docs/api/python/data_ingestion.html#pandas-dataframes--object-columns)
-        # and converts an object to the most specific type possible, eg: dict -> STRUCT, MAP or varchar, and list -> LIST.
-        # For dicts see https://github.com/duckdb/duckdb/pull/3985 and https://github.com/duckdb/duckdb/issues/9510
-        #
-        # When the rows have dicts with different keys there isn't a single STRUCT that can cover them, so the type is
-        # varchar and the value a string containing a struct representation. In order to support dicts with different
-        # keys we first convert the dicts to json strings. A pity we can't do something inside duckdb and avoid the
-        # dataframe copy and transform in python.
-
-        df = df.copy()
-
-        # identify columns of type object
-        object_cols = df.select_dtypes(include=["object"]).columns
-
-        # apply json.dumps to these columns
-        for col in object_cols:
-            # don't jsonify strings
-            df[col] = df[col].apply(lambda x: json.dumps(x) if isinstance(x, (dict, list)) else x)
-
-        escaped_cols = ",".join(f'"{col}"' for col in df.columns.to_list())
-        self._duck_conn.execute(f"INSERT INTO {table_name}({escaped_cols}) SELECT * FROM df")
-
-        return self._duck_conn.fetchall()[0][0]
fakesnow/cursor.py CHANGED
@@ -26,9 +26,11 @@ import fakesnow.transforms as transforms
 from fakesnow.types import describe_as_result_metadata
 
 if TYPE_CHECKING:
+    # don't require pandas at import time
     import pandas as pd
     import pyarrow.lib
 
+    # avoid circular import
     from fakesnow.conn import FakeSnowflakeConnection
 
 
@@ -110,13 +112,15 @@ class FakeSnowflakeCursor:
 
     @property
     def description(self) -> list[ResultMetadata]:
+        return describe_as_result_metadata(self._describe_last_sql())
+
+    def _describe_last_sql(self) -> list:
         # use a separate cursor to avoid consuming the result set on this cursor
         with self._conn.cursor() as cur:
+            # TODO: can we replace with self._duck_conn.description?
             expression = sqlglot.parse_one(f"DESCRIBE {self._last_sql}", read="duckdb")
             cur._execute(expression, self._last_params)  # noqa: SLF001
-            meta = describe_as_result_metadata(cur.fetchall())
-
-        return meta
+            return cur.fetchall()
 
     def execute(
         self,
@@ -135,10 +139,15 @@ class FakeSnowflakeCursor:
             command, params = self._rewrite_with_params(command, params)
             if self._conn.nop_regexes and any(re.match(p, command, re.IGNORECASE) for p in self._conn.nop_regexes):
                 transformed = transforms.SUCCESS_NOP
-            else:
-                expression = parse_one(command, read="snowflake")
-                transformed = self._transform(expression)
-            return self._execute(transformed, params)
+                self._execute(transformed, params)
+                return self
+
+            expression = parse_one(command, read="snowflake")
+            for exp in self._transform_explode(expression):
+                transformed = self._transform(exp)
+                self._execute(transformed, params)
+
+            return self
         except snowflake.connector.errors.ProgrammingError as e:
             self._sqlstate = e.sqlstate
             raise e
@@ -203,9 +212,12 @@ class FakeSnowflakeCursor:
             .transform(transforms.alter_table_strip_cluster_by)
         )
 
-    def _execute(
-        self, transformed: exp.Expression, params: Sequence[Any] | dict[Any, Any] | None = None
-    ) -> FakeSnowflakeCursor:
+    def _transform_explode(self, expression: exp.Expression) -> list[exp.Expression]:
+        # applies transformations that require splitting the expression into multiple expressions
+        # split transforms have limited support at the moment
+        return transforms.merge(expression)
+
+    def _execute(self, transformed: exp.Expression, params: Sequence[Any] | dict[Any, Any] | None = None) -> None:
         self._arrow_table = None
         self._arrow_table_fetch_index = None
         self._rowcount = None
@@ -341,8 +353,6 @@ class FakeSnowflakeCursor:
         self._last_sql = result_sql or sql
         self._last_params = params
 
-        return self
-
     def _log_sql(self, sql: str, params: Sequence[Any] | dict[Any, Any] | None = None) -> None:
         if (fs_debug := os.environ.get("FAKESNOW_DEBUG")) and fs_debug != "snowflake":
             print(f"{sql};{params=}" if params else f"{sql};", file=sys.stderr)
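
Note that although `return self` moved out of _execute, execute() itself still returns the cursor, so the usual call chaining keeps working. A quick sketch, assuming conn is an existing fakesnow connection:

    # execute returns the cursor itself, so fetches can chain off it
    rows = conn.cursor().execute("SELECT 1").fetchall()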
fakesnow/pandas_tools.py CHANGED
@@ -1,14 +1,18 @@
 from __future__ import annotations
 
+import json
 from collections.abc import Sequence
 from typing import TYPE_CHECKING, Any, Literal, Optional
 
 import numpy as np
+from duckdb import DuckDBPyConnection
+
+from fakesnow.conn import FakeSnowflakeConnection
 
 if TYPE_CHECKING:
+    # don't require pandas at import time
     import pandas as pd
 
-    from fakesnow.conn import FakeSnowflakeConnection
 
 CopyResult = tuple[
     str,
@@ -68,10 +72,37 @@ def write_pandas(
 
     conn.cursor().execute(f"CREATE TABLE IF NOT EXISTS {name} ({','.join(cols)})")
 
-    count = conn._insert_df(df, name)  # noqa: SLF001
+    count = _insert_df(conn._duck_conn, df, name)  # noqa: SLF001
 
     # mocks https://docs.snowflake.com/en/sql-reference/sql/copy-into-table.html#output
     mock_copy_results = [("fakesnow/file0.txt", "LOADED", count, count, 1, 0, None, None, None, None)]
 
     # return success
     return (True, len(mock_copy_results), count, mock_copy_results)
+
+
+def _insert_df(duck_conn: DuckDBPyConnection, df: pd.DataFrame, table_name: str) -> int:
+    # Objects in dataframes are written as parquet structs, and snowflake loads parquet structs as json strings,
+    # whereas duckdb analyses a dataframe (see https://duckdb.org/docs/api/python/data_ingestion.html#pandas-dataframes--object-columns)
+    # and converts an object to the most specific type possible, eg: dict -> STRUCT, MAP or varchar, and list -> LIST.
+    # For dicts see https://github.com/duckdb/duckdb/pull/3985 and https://github.com/duckdb/duckdb/issues/9510
+    #
+    # When the rows have dicts with different keys there isn't a single STRUCT that can cover them, so the type is
+    # varchar and the value a string containing a struct representation. In order to support dicts with different
+    # keys we first convert the dicts to json strings. A pity we can't do something inside duckdb and avoid the
+    # dataframe copy and transform in python.
+
+    df = df.copy()
+
+    # identify columns of type object
+    object_cols = df.select_dtypes(include=["object"]).columns
+
+    # apply json.dumps to these columns
+    for col in object_cols:
+        # don't jsonify strings
+        df[col] = df[col].apply(lambda x: json.dumps(x) if isinstance(x, (dict, list)) else x)
+
+    escaped_cols = ",".join(f'"{col}"' for col in df.columns.to_list())
+    duck_conn.execute(f"INSERT INTO {table_name}({escaped_cols}) SELECT * FROM df")
+
+    return duck_conn.fetchall()[0][0]
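
The comment block above explains why object columns are json-encoded before the insert; a standalone sketch of just that transform:

    import json

    import pandas as pd

    df = pd.DataFrame({"v": [{"a": 1}, {"b": 2, "c": 3}, "plain string"]})
    # dicts and lists become json strings; other objects (eg: str) are untouched
    df["v"] = df["v"].apply(lambda x: json.dumps(x) if isinstance(x, (dict, list)) else x)
    print(df["v"].tolist())  # ['{"a": 1}', '{"b": 2, "c": 3}', 'plain string']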
fakesnow/server.py CHANGED
@@ -5,19 +5,22 @@ import json
 import secrets
 from base64 import b64encode
 from dataclasses import dataclass
+from typing import Any
 
+import snowflake.connector.errors
 from starlette.applications import Starlette
 from starlette.concurrency import run_in_threadpool
 from starlette.requests import Request
 from starlette.responses import JSONResponse
 from starlette.routing import Route
 
-from fakesnow.arrow import to_ipc, to_rowtype, with_sf_metadata
+from fakesnow.arrow import to_ipc, to_sf
 from fakesnow.fakes import FakeSnowflakeConnection
 from fakesnow.instance import FakeSnow
+from fakesnow.types import describe_as_rowtype
 
-fs = FakeSnow()
-sessions = {}
+shared_fs = FakeSnow()
+sessions: dict[str, FakeSnowflakeConnection] = {}
 
 
 @dataclass
@@ -27,9 +30,19 @@ class ServerError(Exception):
     message: str
 
 
-def login_request(request: Request) -> JSONResponse:
+async def login_request(request: Request) -> JSONResponse:
     database = request.query_params.get("databaseName")
     schema = request.query_params.get("schemaName")
+    body = await request.body()
+    body_json = json.loads(gzip.decompress(body))
+    session_params: dict[str, Any] = body_json["data"]["SESSION_PARAMETERS"]
+    if db_path := session_params.get("FAKESNOW_DB_PATH"):
+        # :isolated: creates a new in-memory database, rather than using the shared in-memory database,
+        # so this connection won't share any tables with other connections
+        fs = FakeSnow() if db_path == ":isolated:" else FakeSnow(db_path=db_path)
+    else:
+        # share the in-memory database across connections
+        fs = shared_fs
     token = secrets.token_urlsafe(32)
     sessions[token] = fs.connect(database, schema)
     return JSONResponse({"data": {"token": token}, "success": True})
@@ -44,16 +57,30 @@ async def query_request(request: Request) -> JSONResponse:
 
     sql_text = body_json["sqlText"]
 
-    # only a single sql statement is sent at a time by the python snowflake connector
-    cur = await run_in_threadpool(conn.cursor().execute, sql_text)
-
-    assert cur._arrow_table, "No result set"  # noqa: SLF001
-
-    batch_bytes = to_ipc(cur._arrow_table)  # noqa: SLF001
-    rowset_b64 = b64encode(batch_bytes).decode("utf-8")
-
-    # TODO: avoid calling with_sf_metadata twice
-    rowtype = to_rowtype(with_sf_metadata(cur._arrow_table.schema))  # noqa: SLF001
+    try:
+        # only a single sql statement is sent at a time by the python snowflake connector
+        cur = await run_in_threadpool(conn.cursor().execute, sql_text)
+    except snowflake.connector.errors.ProgrammingError as e:
+        code = f"{e.errno:06d}"
+        return JSONResponse(
+            {
+                "data": {
+                    "errorCode": code,
+                    "sqlState": e.sqlstate,
+                },
+                "code": code,
+                "message": e.msg,
+                "success": False,
+            }
+        )
+
+    rowtype = describe_as_rowtype(cur._describe_last_sql())  # noqa: SLF001
+
+    if cur._arrow_table:  # noqa: SLF001
+        batch_bytes = to_ipc(to_sf(cur._arrow_table, rowtype))  # noqa: SLF001
+        rowset_b64 = b64encode(batch_bytes).decode("utf-8")
+    else:
+        rowset_b64 = ""
 
     return JSONResponse(
         {
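
A sketch of how a client could opt into an isolated database via the new FAKESNOW_DB_PATH session parameter (the host and port are assumptions for a locally running fakesnow server):

    import snowflake.connector

    conn = snowflake.connector.connect(
        user="fake",
        password="fake",
        account="fake",
        host="localhost",
        port=8000,
        protocol="http",
        # ":isolated:" requests a fresh in-memory database for this connection;
        # any other value is treated as an on-disk database path
        session_parameters={"FAKESNOW_DB_PATH": ":isolated:"},
    )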
fakesnow/transforms.py CHANGED
@@ -7,6 +7,7 @@ from typing import ClassVar, Literal, cast
 import sqlglot
 from sqlglot import exp
 
+from fakesnow import transforms_merge
 from fakesnow.instance import USERS_TABLE_FQ_NAME
 from fakesnow.variables import Variables
 
@@ -691,6 +692,10 @@ def json_extract_precedence(expression: exp.Expression) -> exp.Expression:
     return expression
 
 
+def merge(expression: exp.Expression) -> list[exp.Expression]:
+    return transforms_merge.merge(expression)
+
+
 def random(expression: exp.Expression) -> exp.Expression:
     """Convert random() and random(seed).
fakesnow/transforms_merge.py ADDED
@@ -0,0 +1,203 @@
+import sqlglot
+from sqlglot import exp
+
+from fakesnow import checks
+
+# Implements snowflake's MERGE INTO functionality in duckdb (https://docs.snowflake.com/en/sql-reference/sql/merge).
+
+
+def merge(merge_expr: exp.Expression) -> list[exp.Expression]:
+    if not isinstance(merge_expr, exp.Merge):
+        return [merge_expr]
+
+    return [_create_merge_candidates(merge_expr), *_mutations(merge_expr), _counts(merge_expr)]
+
+
+def _create_merge_candidates(merge_expr: exp.Merge) -> exp.Expression:
+    """
+    Given a merge statement, produce a temporary table that joins together the target and source tables.
+    The merge_op column identifies which merge clause applies to the row.
+    """
+    target_tbl = merge_expr.this
+
+    source = merge_expr.args.get("using")
+    assert isinstance(source, exp.Expression)
+    source_id = (alias := source.args.get("alias")) and alias.this if isinstance(source, exp.Subquery) else source.this
+    assert isinstance(source_id, exp.Identifier)
+
+    join_expr = merge_expr.args.get("on")
+    assert isinstance(join_expr, exp.Binary)
+
+    case_when_clauses: list[str] = []
+    values: set[str] = set()
+
+    # extract keys that reference the source table from the join expression
+    # so they can be used by the mutation statements for joining;
+    # this will include the source table identifier
+    values.update(
+        map(
+            str,
+            {
+                c
+                for c in join_expr.find_all(exp.Column)
+                if (table := c.args.get("table"))
+                and isinstance(table, exp.Identifier)
+                and checks.equal(table, source_id)
+            },
+        )
+    )
+
+    # Iterate through the WHEN clauses to build up the CASE WHEN clauses
+    for w_idx, w in enumerate(merge_expr.expressions):
+        assert isinstance(w, exp.When), f"Expected When expression, got {w}"
+
+        predicate = join_expr.copy()
+        matched = w.args.get("matched")
+        then = w.args.get("then")
+        condition = w.args.get("condition")
+
+        if matched:
+            # matchedClause, see https://docs.snowflake.com/en/sql-reference/sql/merge#matchedclause-for-updates-or-deletes
+            if condition:
+                # Combine the top level ON expression with the AND condition
+                # from this specific WHEN into a single predicate used to target rows.
+                # Eg. MERGE INTO t1 USING t2 ON t1.t1Key = t2.t2Key
+                #     WHEN MATCHED AND t2.marked = 1 THEN DELETE
+                predicate = exp.And(this=predicate, expression=condition)
+
+            if isinstance(then, exp.Update):
+                case_when_clauses.append(f"WHEN {predicate} THEN {w_idx}")
+                values.update([str(c.expression) for c in then.expressions if isinstance(c.expression, exp.Column)])
+            elif isinstance(then, exp.Var) and then.args.get("this") == "DELETE":
+                case_when_clauses.append(f"WHEN {predicate} THEN {w_idx}")
+            else:
+                raise AssertionError(f"Expected 'Update' or 'Delete', got {then}")
+        else:
+            # notMatchedClause, see https://docs.snowflake.com/en/sql-reference/sql/merge#notmatchedclause-for-inserts
+            assert isinstance(then, exp.Insert), f"Expected 'Insert', got {then}"
+            insert_values = then.expression.expressions
+            values.update([str(c) for c in insert_values if isinstance(c, exp.Column)])
+            predicate = f"AND {condition}" if condition else ""
+            case_when_clauses.append(f"WHEN {target_tbl}.rowid is NULL {predicate} THEN {w_idx}")
+
+    sql = f"""
+    CREATE OR REPLACE TEMPORARY TABLE merge_candidates AS
+    SELECT
+        {', '.join(sorted(values))},
+        CASE
+            {' '.join(case_when_clauses)}
+            ELSE NULL
+        END AS MERGE_OP
+    FROM {target_tbl}
+    FULL OUTER JOIN {source} ON {join_expr.sql()}
+    WHERE MERGE_OP IS NOT NULL
+    """
+
+    return sqlglot.parse_one(sql)
+
+
+def _mutations(merge_expr: exp.Merge) -> list[exp.Expression]:
+    """
+    Given a merge statement, produce a list of delete, update and insert statements that use the
+    merge_candidates and source table to update the target table.
+    """
+    target_tbl = merge_expr.this
+    source = merge_expr.args.get("using")
+    source_tbl = source.alias if isinstance(source, exp.Subquery) else source
+    join_expr = merge_expr.args.get("on")
+
+    statements: list[exp.Expression] = []
+
+    # Iterate through the WHEN clauses to generate delete/update/insert statements
+    for w_idx, w in enumerate(merge_expr.expressions):
+        assert isinstance(w, exp.When), f"Expected When expression, got {w}"
+
+        matched = w.args.get("matched")
+        then = w.args.get("then")
+
+        if matched:
+            if isinstance(then, exp.Var) and then.args.get("this") == "DELETE":
+                delete_sql = f"""
+                DELETE FROM {target_tbl}
+                USING merge_candidates AS {source_tbl}
+                WHERE {join_expr}
+                AND {source_tbl}.merge_op = {w_idx}
+                """
+                statements.append(sqlglot.parse_one(delete_sql))
+            elif isinstance(then, exp.Update):
+                # when the update statement has a table alias, duckdb doesn't support the alias in the set
+                # column name, so we use e.this.this to get just the column name without its table prefix
+                set_clauses = ", ".join(
+                    [f"{e.this.this} = {e.expression.sql()}" for e in then.args.get("expressions", [])]
+                )
+                update_sql = f"""
+                UPDATE {target_tbl}
+                SET {set_clauses}
+                FROM merge_candidates AS {source_tbl}
+                WHERE {join_expr}
+                AND {source_tbl}.merge_op = {w_idx}
+                """
+                statements.append(sqlglot.parse_one(update_sql))
+            else:
+                raise AssertionError(f"Expected 'Update' or 'Delete', got {then}")
+        else:
+            assert isinstance(then, exp.Insert), f"Expected 'Insert', got {then}"
+            cols = [str(c) for c in then.this.expressions] if then.this else []
+            columns = f"({', '.join(cols)})" if cols else ""
+            values = ", ".join(map(str, then.expression.expressions))
+            insert_sql = f"""
+            INSERT INTO {target_tbl} {columns}
+            SELECT {values}
+            FROM merge_candidates AS {source_tbl}
+            WHERE {source_tbl}.merge_op = {w_idx}
+            """
+            statements.append(sqlglot.parse_one(insert_sql))
+
+    return statements
+
+
+def _counts(merge_expr: exp.Merge) -> exp.Expression:
+    """
+    Given a merge statement, derive a SQL statement which produces the following columns using the merge_candidates
+    table:
+
+    - "number of rows inserted"
+    - "number of rows updated"
+    - "number of rows deleted"
+
+    Only columns relevant to the merge operation are included, eg: if no rows are deleted, the "number of rows deleted"
+    column is not included.
+    """
+
+    # Initialize a dictionary mapping operation types to their corresponding WHEN-clause indices
+    operations = {"inserted": [], "updated": [], "deleted": []}
+
+    # Iterate through the WHEN clauses to categorize operations
+    for w_idx, w in enumerate(merge_expr.expressions):
+        assert isinstance(w, exp.When), f"Expected When expression, got {w}"
+
+        matched = w.args.get("matched")
+        then = w.args.get("then")
+
+        if matched:
+            if isinstance(then, exp.Update):
+                operations["updated"].append(w_idx)
+            elif isinstance(then, exp.Var) and then.args.get("this") == "DELETE":
+                operations["deleted"].append(w_idx)
+            else:
+                raise AssertionError(f"Expected 'Update' or 'Delete', got {then}")
+        else:
+            assert isinstance(then, exp.Insert), f"Expected 'Insert', got {then}"
+            operations["inserted"].append(w_idx)
+
+    count_statements = [
+        f"""COUNT_IF(merge_op in ({','.join(map(str, indices))})) as \"number of rows {op}\""""
+        for op, indices in operations.items()
+        if indices
+    ]
+    sql = f"""
+    SELECT {', '.join(count_statements)}
+    FROM merge_candidates
+    """
+
+    return sqlglot.parse_one(sql)
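
To see the split in action, a minimal sketch (the MERGE statement mirrors the example in the comments above):

    import sqlglot

    from fakesnow import transforms_merge

    sql = """
    MERGE INTO t1 USING t2 ON t1.t1Key = t2.t2Key
    WHEN MATCHED AND t2.marked = 1 THEN DELETE
    WHEN NOT MATCHED THEN INSERT (t1Key) VALUES (t2.t2Key)
    """
    # one MERGE becomes: the merge_candidates temp table, then the
    # delete/insert mutations, then the closing counts query
    for e in transforms_merge.merge(sqlglot.parse_one(sql, read="snowflake")):
        print(e.sql(dialect="duckdb"), "\n")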
fakesnow-0.9.25.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: fakesnow
-Version: 0.9.23
+Version: 0.9.25
 Summary: Fake Snowflake Connector for Python. Run, mock and test Snowflake DB locally.
 License: Apache License
                Version 2.0, January 2004
@@ -220,10 +220,10 @@ Requires-Dist: dirty-equals; extra == "dev"
 Requires-Dist: pandas-stubs; extra == "dev"
 Requires-Dist: snowflake-connector-python[pandas,secure-local-storage]; extra == "dev"
 Requires-Dist: pre-commit~=3.4; extra == "dev"
-Requires-Dist: pyarrow-stubs; extra == "dev"
+Requires-Dist: pyarrow-stubs==10.0.1.9; extra == "dev"
 Requires-Dist: pytest~=8.0; extra == "dev"
 Requires-Dist: pytest-asyncio; extra == "dev"
-Requires-Dist: ruff~=0.5.1; extra == "dev"
+Requires-Dist: ruff~=0.6.3; extra == "dev"
 Requires-Dist: twine~=5.0; extra == "dev"
 Requires-Dist: snowflake-sqlalchemy~=1.5.0; extra == "dev"
 Provides-Extra: notebook
fakesnow-0.9.25.dist-info/RECORD ADDED
@@ -0,0 +1,26 @@
+fakesnow/__init__.py,sha256=qUfgucQYPdELrJaxczalhJgWAWQ6cfTCUAHx6nUqRaI,3528
+fakesnow/__main__.py,sha256=GDrGyNTvBFuqn_UfDjKs7b3LPtU6gDv1KwosVDrukIM,76
+fakesnow/arrow.py,sha256=EGAYeuCnRuvmWBEGqw2YOcgQR4zcCsZBu85kSRl70dQ,4698
+fakesnow/checks.py,sha256=N8sXldhS3u1gG32qvZ4VFlsKgavRKrQrxLiQU8am1lw,2691
+fakesnow/cli.py,sha256=9qfI-Ssr6mo8UmIlXkUAOz2z2YPBgDsrEVaZv9FjGFs,2201
+fakesnow/conn.py,sha256=Gy_Z7BZRm5yMjV3x6hR4iegDQFdG9aJBjqWdc3iWYFU,5353
+fakesnow/cursor.py,sha256=JbLSTzIN5Hu6ECn1kQ8-hC8V-ENeEWTID4JxHbGpEIo,19948
+fakesnow/expr.py,sha256=CAxuYIUkwI339DQIBzvFF0F-m1tcVGKEPA5rDTzmH9A,892
+fakesnow/fakes.py,sha256=JQTiUkkwPeQrJ8FDWhPFPK6pGwd_aR2oiOrNzCWznlM,187
+fakesnow/fixtures.py,sha256=G-NkVeruSQAJ7fvSS2fR2oysUn0Yra1pohHlOvacKEk,455
+fakesnow/info_schema.py,sha256=DObVOrhzppAFHsdtj4YI9oRISn9SkJUG6ONjVleQQ_Y,6303
+fakesnow/instance.py,sha256=3cJvPRuFy19dMKXbtBLl6imzO48pEw8uTYhZyFDuwhk,3133
+fakesnow/macros.py,sha256=pX1YJDnQOkFJSHYUjQ6ErEkYIKvFI6Ncz_au0vv1csA,265
+fakesnow/pandas_tools.py,sha256=WjyjTV8QUCQQaCGboaEOvx2uo4BkknpWYjtLwkeCY6U,3468
+fakesnow/py.typed,sha256=B-DLSjYBi7pkKjwxCSdpVj2J02wgfJr-E7B1wOUyxYU,80
+fakesnow/server.py,sha256=SO5xKZ4rvySsuKDsoSPSCZcFuIX_K7d1XJYhRRJ-7Bk,4150
+fakesnow/transforms.py,sha256=hPQ9L1TvsDjFiGbs8mGUnIFyPUR7JxFU8FZRKFj5ZD0,54568
+fakesnow/transforms_merge.py,sha256=7rq-UPjfFNRrFsqR8xx3otwP6-k4eslLVLhfuqSXq1A,8314
+fakesnow/types.py,sha256=9Tt83Z7ctc9_v6SYyayXYz4MEI4RZo4zq_uqdj4g3Dk,2681
+fakesnow/variables.py,sha256=WXyPnkeNwD08gy52yF66CVe2twiYC50tztNfgXV4q1k,3032
+fakesnow-0.9.25.dist-info/LICENSE,sha256=kW-7NWIyaRMQiDpryfSmF2DObDZHGR1cJZ39s6B1Svg,11344
+fakesnow-0.9.25.dist-info/METADATA,sha256=1RktYqC8KfU4ekh8IGApjHSLb5oe3KMgBLjyKJKXHRc,18074
+fakesnow-0.9.25.dist-info/WHEEL,sha256=eOLhNAGa2EW3wWl_TU484h7q1UNgy0JXjjoqKoxAAQc,92
+fakesnow-0.9.25.dist-info/entry_points.txt,sha256=2riAUgu928ZIHawtO8EsfrMEJhi-EH-z_Vq7Q44xKPM,47
+fakesnow-0.9.25.dist-info/top_level.txt,sha256=500evXI1IFX9so82cizGIEMHAb_dJNPaZvd2H9dcKTA,24
+fakesnow-0.9.25.dist-info/RECORD,,
fakesnow-0.9.23.dist-info/RECORD DELETED
@@ -1,25 +0,0 @@
-fakesnow/__init__.py,sha256=9tFJJKvowKNW3vfnlmza6hOLN1I52DwChgNc5Ew6CcA,3499
-fakesnow/__main__.py,sha256=GDrGyNTvBFuqn_UfDjKs7b3LPtU6gDv1KwosVDrukIM,76
-fakesnow/arrow.py,sha256=WLkr1nEiNxUcPdzadKSM33sRAiQJsN6LvuzTVIsi3D0,2766
-fakesnow/checks.py,sha256=-QMvdcrRbhN60rnzxLBJ0IkUBWyLR8gGGKKmCS0w9mA,2383
-fakesnow/cli.py,sha256=9qfI-Ssr6mo8UmIlXkUAOz2z2YPBgDsrEVaZv9FjGFs,2201
-fakesnow/conn.py,sha256=yR9SMGSKkLvdPfi5fDwj9PQggOrPcKkdBBnQy2y_Bak,6921
-fakesnow/cursor.py,sha256=lITuMMy_hA9_riC121Sv9bFDFbU9BP9NlOLdHlP0ahY,19371
-fakesnow/expr.py,sha256=CAxuYIUkwI339DQIBzvFF0F-m1tcVGKEPA5rDTzmH9A,892
-fakesnow/fakes.py,sha256=JQTiUkkwPeQrJ8FDWhPFPK6pGwd_aR2oiOrNzCWznlM,187
-fakesnow/fixtures.py,sha256=G-NkVeruSQAJ7fvSS2fR2oysUn0Yra1pohHlOvacKEk,455
-fakesnow/info_schema.py,sha256=DObVOrhzppAFHsdtj4YI9oRISn9SkJUG6ONjVleQQ_Y,6303
-fakesnow/instance.py,sha256=3cJvPRuFy19dMKXbtBLl6imzO48pEw8uTYhZyFDuwhk,3133
-fakesnow/macros.py,sha256=pX1YJDnQOkFJSHYUjQ6ErEkYIKvFI6Ncz_au0vv1csA,265
-fakesnow/pandas_tools.py,sha256=ecL0kxIVU5o2--P3bRLWWVhxXqq6Km4trFr36txukMg,1897
-fakesnow/py.typed,sha256=B-DLSjYBi7pkKjwxCSdpVj2J02wgfJr-E7B1wOUyxYU,80
-fakesnow/server.py,sha256=8dzaLUUXPzCMm6-ESn0CBws6XSwwOpnUuHQAZJ-4SwU,3011
-fakesnow/transforms.py,sha256=ellcY5OBc7mqgL9ChNolrqcCLWXF9RH21Jt88FcFl-I,54419
-fakesnow/types.py,sha256=9Tt83Z7ctc9_v6SYyayXYz4MEI4RZo4zq_uqdj4g3Dk,2681
-fakesnow/variables.py,sha256=WXyPnkeNwD08gy52yF66CVe2twiYC50tztNfgXV4q1k,3032
-fakesnow-0.9.23.dist-info/LICENSE,sha256=kW-7NWIyaRMQiDpryfSmF2DObDZHGR1cJZ39s6B1Svg,11344
-fakesnow-0.9.23.dist-info/METADATA,sha256=90vwAf40lg9ipw2urbu9nMq7cEARbtENE0aKd7qSpjo,18064
-fakesnow-0.9.23.dist-info/WHEEL,sha256=eOLhNAGa2EW3wWl_TU484h7q1UNgy0JXjjoqKoxAAQc,92
-fakesnow-0.9.23.dist-info/entry_points.txt,sha256=2riAUgu928ZIHawtO8EsfrMEJhi-EH-z_Vq7Q44xKPM,47
-fakesnow-0.9.23.dist-info/top_level.txt,sha256=500evXI1IFX9so82cizGIEMHAb_dJNPaZvd2H9dcKTA,24
-fakesnow-0.9.23.dist-info/RECORD,,