fakesnow 0.9.24__py3-none-any.whl → 0.9.25__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
fakesnow/__init__.py CHANGED
@@ -90,3 +90,4 @@ def patch(
90
90
  yield None
91
91
  finally:
92
92
  stack.close()
93
+ fs.duck_conn.close()
fakesnow/arrow.py CHANGED
@@ -1,33 +1,45 @@
1
- from typing import Any
1
+ from __future__ import annotations
2
+
3
+ from typing import cast
2
4
 
3
5
  import pyarrow as pa
6
+ import pyarrow.compute as pc
7
+
8
+ from fakesnow.types import ColumnInfo
9
+
10
+
11
def to_sf_schema(schema: pa.Schema, rowtype: list[ColumnInfo]) -> pa.Schema:
    """Build the arrow schema the snowflake connector expects for a result set.

    Field metadata (logicalType, precision, scale, charLength) is populated from ``rowtype`` rather
    than derived from the arrow types, for consistency with the rowtype returned in the response.
    Timestamp fields are retyped to snowflake's epoch/fraction[/timezone] struct and time64 fields
    to int64, to match how ``to_sf`` converts the column values.

    Args:
        schema: arrow schema of the result table.
        rowtype: column descriptions, one per field of ``schema`` and in the same order.

    Returns:
        A new schema with snowflake-compatible field types and metadata.

    Raises:
        AssertionError: if ``schema`` and ``rowtype`` differ in length.
    """
    assert len(schema) == len(rowtype), f"schema and rowtype must be same length but {len(schema)=} {len(rowtype)=}"

    # see https://github.com/snowflakedb/snowflake-connector-python/blob/e9393a6/src/snowflake/connector/nanoarrow_cpp/ArrowIterator/CArrowTableIterator.cpp#L32
    # and https://github.com/snowflakedb/snowflake-connector-python/blob/e9393a6/src/snowflake/connector/nanoarrow_cpp/ArrowIterator/SnowflakeType.cpp#L10

    def sf_field(field: pa.Field, c: ColumnInfo) -> pa.Field:
        if isinstance(field.type, pa.TimestampType):
            # snowflake uses a struct to represent timestamps, see timestamp_to_sf_struct
            fields = [pa.field("epoch", pa.int64(), nullable=False), pa.field("fraction", pa.int32(), nullable=False)]
            if field.type.tz:
                fields.append(pa.field("timezone", nullable=False, type=pa.int32()))
            field = field.with_type(pa.struct(fields))
        elif isinstance(field.type, pa.Time64Type):
            field = field.with_type(pa.int64())

        return field.with_metadata(
            {
                "logicalType": c["type"].upper(),
                # required for FIXED type see
                # https://github.com/snowflakedb/snowflake-connector-python/blob/416ff57/src/snowflake/connector/nanoarrow_cpp/ArrowIterator/CArrowChunkIterator.cpp#L147
                "precision": str(c["precision"] or 38),
                "scale": str(c["scale"] or 0),
                "charLength": str(c["length"] or 0),
            }
        )

    fms = [sf_field(schema.field(i), c) for i, c in enumerate(rowtype)]
    return pa.schema(fms)
32
44
 
33
45
 
@@ -39,29 +51,56 @@ def to_ipc(table: pa.Table) -> pa.Buffer:
39
51
 
40
52
  sink = pa.BufferOutputStream()
41
53
 
42
- with pa.ipc.new_stream(sink, with_sf_metadata(table.schema)) as writer:
54
+ with pa.ipc.new_stream(sink, table.schema) as writer:
43
55
  writer.write_batch(batch)
44
56
 
45
57
  return sink.getvalue()
46
58
 
47
59
 
48
- # TODO: should this be derived before with_schema?
49
- def to_rowtype(schema: pa.Schema) -> list[dict[str, Any]]:
50
- return [
51
- {
52
- "name": f.name,
53
- # TODO
54
- # "database": "",
55
- # "schema": "",
56
- # "table": "",
57
- "nullable": f.nullable,
58
- "type": f.metadata.get(b"logicalType").decode("utf-8").lower(), # type: ignore
59
- # TODO
60
- # "byteLength": 20,
61
- "length": int(f.metadata.get(b"charLength")) if f.metadata.get(b"charLength") else None, # type: ignore
62
- "scale": int(f.metadata.get(b"scale")) if f.metadata.get(b"scale") else None, # type: ignore
63
- "precision": int(f.metadata.get(b"precision")) if f.metadata.get(b"precision") else None, # type: ignore
64
- "collation": None,
65
- }
66
- for f in schema
67
- ]
60
def to_sf(table: pa.Table, rowtype: list[ColumnInfo]) -> pa.Table:
    """Convert ``table`` into the column layout the snowflake connector expects.

    - timestamp columns become snowflake's epoch/fraction[/timezone] structs
    - time columns become int64 nanoseconds
    - every other column is passed through unchanged

    The resulting table's schema (including snowflake field metadata) comes from
    ``to_sf_schema(table.schema, rowtype)``.
    """
    # nanoseconds per tick for each arrow time unit; scaling by the column's own unit avoids
    # assuming microseconds (the previous hard-coded factor of 1000)
    ns_per_tick = {"s": 1_000_000_000, "ms": 1_000_000, "us": 1_000, "ns": 1}

    def to_sf_col(col: pa.Array) -> pa.Array:
        if pa.types.is_timestamp(col.type):
            return timestamp_to_sf_struct(col)
        elif pa.types.is_time(col.type):
            # as nanoseconds
            return pc.multiply(col.cast(pa.int64()), ns_per_tick[col.type.unit])  # type: ignore https://github.com/zen-xu/pyarrow-stubs/issues/44
        return col

    return pa.Table.from_arrays([to_sf_col(c) for c in table.columns], schema=to_sf_schema(table.schema, rowtype))
70
+
71
+
72
def timestamp_to_sf_struct(ts: pa.Array | pa.ChunkedArray) -> pa.Array:
    """Convert a timestamp array into snowflake's struct representation.

    The struct has an ``epoch`` field (whole seconds since the unix epoch, int64) and a ``fraction``
    field (nanoseconds within that second, int32). Tz-aware input also gets a ``timezone`` field
    (int32), which is set to 1440 for every row; only UTC is currently supported.

    Epoch and fraction are computed with integer arithmetic scaled by the timestamp's unit, so any
    of s/ms/us/ns input is handled. Float rounding cannot make the int32 cast fail.

    Raises:
        ValueError: if ``ts`` is not a timestamp array.
        AssertionError: if ``ts`` has a timezone other than UTC.
    """
    if isinstance(ts, pa.ChunkedArray):
        # combine because pa.StructArray.from_arrays doesn't support ChunkedArray
        ts = cast(pa.Array, ts.combine_chunks())  # see https://github.com/zen-xu/pyarrow-stubs/issues/46

    if not isinstance(ts.type, pa.TimestampType):
        raise ValueError(f"Expected TimestampArray, got {type(ts)}")

    # number of ticks of ts.type.unit in one second, eg: 1_000_000 for microsecond timestamps
    ticks_per_second = {"s": 1, "ms": 1_000, "us": 1_000_000, "ns": 1_000_000_000}[ts.type.unit]

    ticks = ts.cast(pa.int64())
    # Round down to seconds, ie: strip subseconds
    whole_second_ticks = pc.floor_temporal(ts, unit="second").cast(pa.int64())  # type: ignore https://github.com/zen-xu/pyarrow-stubs/issues/45
    # exact integer division: whole_second_ticks is always a multiple of ticks_per_second
    epoch = pc.divide(whole_second_ticks, ticks_per_second)  # type: ignore https://github.com/zen-xu/pyarrow-stubs/issues/44

    # Fractional part as nanoseconds, in integers so the int32 cast never sees a non-integral value
    fraction = pc.multiply(  # type: ignore https://github.com/zen-xu/pyarrow-stubs/issues/44
        pc.subtract(ticks, whole_second_ticks),  # type: ignore https://github.com/zen-xu/pyarrow-stubs/issues/44
        1_000_000_000 // ticks_per_second,
    ).cast(pa.int32())

    if ts.type.tz:
        assert ts.type.tz == "UTC", f"Timezone {ts.type.tz} not yet supported"
        timezone = pa.array([1440] * len(ts), type=pa.int32())

        return pa.StructArray.from_arrays(
            arrays=[epoch, fraction, timezone],  # type: ignore https://github.com/zen-xu/pyarrow-stubs/issues/42
            fields=[
                pa.field("epoch", nullable=False, type=pa.int64()),
                pa.field("fraction", nullable=False, type=pa.int32()),
                pa.field("timezone", nullable=False, type=pa.int32()),
            ],
        )
    else:
        return pa.StructArray.from_arrays(
            arrays=[epoch, fraction],  # type: ignore https://github.com/zen-xu/pyarrow-stubs/issues/42
            fields=[
                pa.field("epoch", nullable=False, type=pa.int64()),
                pa.field("fraction", nullable=False, type=pa.int32()),
            ],
        )
fakesnow/checks.py CHANGED
@@ -68,3 +68,11 @@ def is_unqualified_table_expression(expression: exp.Expression) -> tuple[bool, b
68
68
  no_schema = not node.args.get("db")
69
69
 
70
70
  return no_database, no_schema
71
+
72
+
73
def equal(left: exp.Identifier, right: exp.Identifier) -> bool:
    """Return True when two identifiers name the same snowflake object.

    Unquoted identifiers are case-insensitive, so they are compared upper-cased; quoted identifiers
    keep their exact case. As per
    https://docs.snowflake.com/en/sql-reference/identifiers-syntax#label-identifier-casing
    """
    resolved = [ident.this if ident.quoted else ident.this.upper() for ident in (left, right)]
    return resolved[0] == resolved[1]
fakesnow/cursor.py CHANGED
@@ -112,13 +112,15 @@ class FakeSnowflakeCursor:
112
112
 
113
113
  @property
114
114
  def description(self) -> list[ResultMetadata]:
115
+ return describe_as_result_metadata(self._describe_last_sql())
116
+
117
+ def _describe_last_sql(self) -> list:
115
118
  # use a separate cursor to avoid consuming the result set on this cursor
116
119
  with self._conn.cursor() as cur:
120
+ # TODO: can we replace with self._duck_conn.description?
117
121
  expression = sqlglot.parse_one(f"DESCRIBE {self._last_sql}", read="duckdb")
118
122
  cur._execute(expression, self._last_params) # noqa: SLF001
119
- meta = describe_as_result_metadata(cur.fetchall())
120
-
121
- return meta
123
+ return cur.fetchall()
122
124
 
123
125
  def execute(
124
126
  self,
@@ -137,10 +139,15 @@ class FakeSnowflakeCursor:
137
139
  command, params = self._rewrite_with_params(command, params)
138
140
  if self._conn.nop_regexes and any(re.match(p, command, re.IGNORECASE) for p in self._conn.nop_regexes):
139
141
  transformed = transforms.SUCCESS_NOP
140
- else:
141
- expression = parse_one(command, read="snowflake")
142
- transformed = self._transform(expression)
143
- return self._execute(transformed, params)
142
+ self._execute(transformed, params)
143
+ return self
144
+
145
+ expression = parse_one(command, read="snowflake")
146
+ for exp in self._transform_explode(expression):
147
+ transformed = self._transform(exp)
148
+ self._execute(transformed, params)
149
+
150
+ return self
144
151
  except snowflake.connector.errors.ProgrammingError as e:
145
152
  self._sqlstate = e.sqlstate
146
153
  raise e
@@ -205,9 +212,12 @@ class FakeSnowflakeCursor:
205
212
  .transform(transforms.alter_table_strip_cluster_by)
206
213
  )
207
214
 
208
- def _execute(
209
- self, transformed: exp.Expression, params: Sequence[Any] | dict[Any, Any] | None = None
210
- ) -> FakeSnowflakeCursor:
215
+ def _transform_explode(self, expression: exp.Expression) -> list[exp.Expression]:
216
+ # Applies transformations that require splitting the expression into multiple expressions
217
+ # Split transforms have limited support at the moment.
218
+ return transforms.merge(expression)
219
+
220
+ def _execute(self, transformed: exp.Expression, params: Sequence[Any] | dict[Any, Any] | None = None) -> None:
211
221
  self._arrow_table = None
212
222
  self._arrow_table_fetch_index = None
213
223
  self._rowcount = None
@@ -343,8 +353,6 @@ class FakeSnowflakeCursor:
343
353
  self._last_sql = result_sql or sql
344
354
  self._last_params = params
345
355
 
346
- return self
347
-
348
356
  def _log_sql(self, sql: str, params: Sequence[Any] | dict[Any, Any] | None = None) -> None:
349
357
  if (fs_debug := os.environ.get("FAKESNOW_DEBUG")) and fs_debug != "snowflake":
350
358
  print(f"{sql};{params=}" if params else f"{sql};", file=sys.stderr)
fakesnow/server.py CHANGED
@@ -5,19 +5,22 @@ import json
5
5
  import secrets
6
6
  from base64 import b64encode
7
7
  from dataclasses import dataclass
8
+ from typing import Any
8
9
 
10
+ import snowflake.connector.errors
9
11
  from starlette.applications import Starlette
10
12
  from starlette.concurrency import run_in_threadpool
11
13
  from starlette.requests import Request
12
14
  from starlette.responses import JSONResponse
13
15
  from starlette.routing import Route
14
16
 
15
- from fakesnow.arrow import to_ipc, to_rowtype, with_sf_metadata
17
+ from fakesnow.arrow import to_ipc, to_sf
16
18
  from fakesnow.fakes import FakeSnowflakeConnection
17
19
  from fakesnow.instance import FakeSnow
20
+ from fakesnow.types import describe_as_rowtype
18
21
 
19
- fs = FakeSnow()
20
- sessions = {}
22
+ shared_fs = FakeSnow()
23
+ sessions: dict[str, FakeSnowflakeConnection] = {}
21
24
 
22
25
 
23
26
  @dataclass
@@ -27,9 +30,19 @@ class ServerError(Exception):
27
30
  message: str
28
31
 
29
32
 
30
async def login_request(request: Request) -> JSONResponse:
    """Handle a connector login: open a FakeSnow connection and return its session token.

    The ``databaseName``/``schemaName`` query params select the connection's database/schema. The
    gzipped JSON body may carry a ``FAKESNOW_DB_PATH`` session parameter:

    - ``":isolated:"`` creates a new in-memory database private to this connection
    - any other value is passed to ``FakeSnow(db_path=...)``
    - when absent, the shared in-memory database is used, so tables are shared across connections
    """
    database = request.query_params.get("databaseName")
    schema = request.query_params.get("schemaName")
    body = await request.body()
    body_json = json.loads(gzip.decompress(body))
    # tolerate a login body without (or with a null) data.SESSION_PARAMETERS instead of failing
    # with a KeyError; treat it the same as no session parameters
    session_params: dict[str, Any] = (body_json.get("data") or {}).get("SESSION_PARAMETERS") or {}
    if db_path := session_params.get("FAKESNOW_DB_PATH"):
        # isolated creates a new in-memory database, rather than using the shared in-memory database
        # so this connection won't share any tables with other connections
        fs = FakeSnow() if db_path == ":isolated:" else FakeSnow(db_path=db_path)
    else:
        # share the in-memory database across connections
        fs = shared_fs
    token = secrets.token_urlsafe(32)
    sessions[token] = fs.connect(database, schema)
    return JSONResponse({"data": {"token": token}, "success": True})
@@ -44,16 +57,30 @@ async def query_request(request: Request) -> JSONResponse:
44
57
 
45
58
  sql_text = body_json["sqlText"]
46
59
 
47
- # only a single sql statement is sent at a time by the python snowflake connector
48
- cur = await run_in_threadpool(conn.cursor().execute, sql_text)
49
-
50
- assert cur._arrow_table, "No result set" # noqa: SLF001
51
-
52
- batch_bytes = to_ipc(cur._arrow_table) # noqa: SLF001
53
- rowset_b64 = b64encode(batch_bytes).decode("utf-8")
54
-
55
- # TODO: avoid calling with_sf_metadata twice
56
- rowtype = to_rowtype(with_sf_metadata(cur._arrow_table.schema)) # noqa: SLF001
60
+ try:
61
+ # only a single sql statement is sent at a time by the python snowflake connector
62
+ cur = await run_in_threadpool(conn.cursor().execute, sql_text)
63
+ except snowflake.connector.errors.ProgrammingError as e:
64
+ code = f"{e.errno:06d}"
65
+ return JSONResponse(
66
+ {
67
+ "data": {
68
+ "errorCode": code,
69
+ "sqlState": e.sqlstate,
70
+ },
71
+ "code": code,
72
+ "message": e.msg,
73
+ "success": False,
74
+ }
75
+ )
76
+
77
+ rowtype = describe_as_rowtype(cur._describe_last_sql()) # noqa: SLF001
78
+
79
+ if cur._arrow_table: # noqa: SLF001
80
+ batch_bytes = to_ipc(to_sf(cur._arrow_table, rowtype)) # noqa: SLF001
81
+ rowset_b64 = b64encode(batch_bytes).decode("utf-8")
82
+ else:
83
+ rowset_b64 = ""
57
84
 
58
85
  return JSONResponse(
59
86
  {
fakesnow/transforms.py CHANGED
@@ -7,6 +7,7 @@ from typing import ClassVar, Literal, cast
7
7
  import sqlglot
8
8
  from sqlglot import exp
9
9
 
10
+ from fakesnow import transforms_merge
10
11
  from fakesnow.instance import USERS_TABLE_FQ_NAME
11
12
  from fakesnow.variables import Variables
12
13
 
@@ -691,6 +692,10 @@ def json_extract_precedence(expression: exp.Expression) -> exp.Expression:
691
692
  return expression
692
693
 
693
694
 
695
def merge(expression: exp.Expression) -> list[exp.Expression]:
    """Expand a MERGE statement into the duckdb statements that emulate it.

    Delegates to ``transforms_merge.merge``; a non-MERGE expression comes back as a one-element list.
    """
    statements = transforms_merge.merge(expression)
    return statements
697
+
698
+
694
699
  def random(expression: exp.Expression) -> exp.Expression:
695
700
  """Convert random() and random(seed).
696
701
 
@@ -0,0 +1,203 @@
1
+ import sqlglot
2
+ from sqlglot import exp
3
+
4
+ from fakesnow import checks
5
+
6
+ # Implements snowflake's MERGE INTO functionality in duckdb (https://docs.snowflake.com/en/sql-reference/sql/merge).
7
+
8
+
9
def merge(merge_expr: exp.Expression) -> list[exp.Expression]:
    """Translate a snowflake MERGE into an ordered list of duckdb statements.

    The statements are:
    1. create the ``merge_candidates`` temp table;
    2. one mutation per WHEN clause;
    3. a SELECT reporting the affected row counts.

    Any other expression is returned unchanged as a one-element list.
    """
    if isinstance(merge_expr, exp.Merge):
        statements = [_create_merge_candidates(merge_expr)]
        statements.extend(_mutations(merge_expr))
        statements.append(_counts(merge_expr))
        return statements
    return [merge_expr]
14
+
15
+
16
def _create_merge_candidates(merge_expr: exp.Merge) -> exp.Expression:
    """
    Given a merge statement, produce a temporary table that joins together the target and source tables.
    The merge_op column identifies which merge clause applies to the row.

    The table carries the source join keys plus the columns that UPDATE SET / INSERT VALUES
    expressions read. These are bare column values, and source-qualified columns nested inside
    larger expressions. The mutation statements read them from merge_candidates aliased as the source.
    """
    target_tbl = merge_expr.this
    # an aliased target (MERGE INTO t1 AS t) must be referenced by its alias, otherwise the rowid
    # reference below would render as the invalid "t1 AS t.rowid"
    target_alias = target_tbl.args.get("alias")
    target_ref = target_alias.this if target_alias else target_tbl

    source = merge_expr.args.get("using")
    assert isinstance(source, exp.Expression)
    # columns reference the source by its alias when it has one (always the case for a subquery),
    # otherwise by the table name
    source_alias = source.args.get("alias")
    source_id = source_alias.this if source_alias else source.this
    assert isinstance(source_id, exp.Identifier)

    join_expr = merge_expr.args.get("on")
    assert isinstance(join_expr, exp.Binary)

    def references_source(col: exp.Column) -> bool:
        # True when the column is qualified with the source table/alias
        table = col.args.get("table")
        return isinstance(table, exp.Identifier) and checks.equal(table, source_id)

    def needed_columns(value: exp.Expression) -> list[str]:
        # a bare column value is always carried over (it may be unqualified); columns nested in a
        # larger expression are carried over when they reference the source
        if isinstance(value, exp.Column):
            return [str(value)]
        return [str(c) for c in value.find_all(exp.Column) if references_source(c)]

    case_when_clauses: list[str] = []

    # extract keys that reference the source table from the join expression
    # so they can be used by the mutation statements for joining
    # will include the source table identifier
    values: set[str] = {str(c) for c in join_expr.find_all(exp.Column) if references_source(c)}

    # Iterate through the WHEN clauses to build up the CASE WHEN clauses
    for w_idx, w in enumerate(merge_expr.expressions):
        assert isinstance(w, exp.When), f"Expected When expression, got {w}"

        matched = w.args.get("matched")
        then = w.args.get("then")
        condition = w.args.get("condition")

        if matched:
            # matchedClause see https://docs.snowflake.com/en/sql-reference/sql/merge#matchedclause-for-updates-or-deletes
            predicate = join_expr.copy()
            if condition:
                # Combine the top level ON expression with the AND condition
                # from this specific WHEN into a subquery, we use to target rows.
                # Eg. MERGE INTO t1 USING t2 ON t1.t1Key = t2.t2Key
                #     WHEN MATCHED AND t2.marked = 1 THEN DELETE
                predicate = exp.And(this=predicate, expression=condition)

            if isinstance(then, exp.Update):
                case_when_clauses.append(f"WHEN {predicate} THEN {w_idx}")
                for set_expr in then.expressions:
                    values.update(needed_columns(set_expr.expression))
            elif isinstance(then, exp.Var) and then.args.get("this") == "DELETE":
                case_when_clauses.append(f"WHEN {predicate} THEN {w_idx}")
            else:
                raise AssertionError(f"Expected 'Update' or 'Delete', got {then}")
        else:
            # notMatchedClause see https://docs.snowflake.com/en/sql-reference/sql/merge#notmatchedclause-for-inserts
            assert isinstance(then, exp.Insert), f"Expected 'Insert', got {then}"
            for insert_value in then.expression.expressions:
                values.update(needed_columns(insert_value))
            condition_sql = f"AND {condition}" if condition else ""
            case_when_clauses.append(f"WHEN {target_ref}.rowid is NULL {condition_sql} THEN {w_idx}")

    sql = f"""
        CREATE OR REPLACE TEMPORARY TABLE merge_candidates AS
        SELECT
            {', '.join(sorted(values))},
            CASE
                {' '.join(case_when_clauses)}
                ELSE NULL
            END AS MERGE_OP
        FROM {target_tbl}
        FULL OUTER JOIN {source} ON {join_expr.sql()}
        WHERE MERGE_OP IS NOT NULL
    """

    return sqlglot.parse_one(sql)
97
+
98
+
99
def _mutations(merge_expr: exp.Merge) -> list[exp.Expression]:
    """
    Given a merge statement, produce a list of delete, update and insert statements that use the
    merge_candidates and source table to update the target table.

    One statement is produced per WHEN clause, in clause order. Each one reads merge_candidates
    aliased as the source, filtered to the rows whose merge_op equals that clause's index.

    Raises:
        AssertionError: for a WHEN clause that is not UPDATE/DELETE (matched) or INSERT (not matched).
    """
    target_tbl = merge_expr.this
    source = merge_expr.args.get("using")
    # merge_candidates is aliased as the source so the ON / SET / VALUES expressions resolve against
    # it. Use the source's alias whenever it has one (always for a subquery, optionally for a table):
    # rendering an aliased table would produce the invalid "merge_candidates AS t2 AS s".
    source_tbl = source.alias if isinstance(source, exp.Expression) and source.alias else source
    join_expr = merge_expr.args.get("on")

    statements: list[exp.Expression] = []

    # Iterate through the WHEN clauses to generate delete/update/insert statements
    for w_idx, w in enumerate(merge_expr.expressions):
        assert isinstance(w, exp.When), f"Expected When expression, got {w}"

        matched = w.args.get("matched")
        then = w.args.get("then")

        if matched:
            if isinstance(then, exp.Var) and then.args.get("this") == "DELETE":
                delete_sql = f"""
                    DELETE FROM {target_tbl}
                    USING merge_candidates AS {source_tbl}
                    WHERE {join_expr}
                    AND {source_tbl}.merge_op = {w_idx}
                """
                statements.append(sqlglot.parse_one(delete_sql))
            elif isinstance(then, exp.Update):
                # when the update statement has a table alias, duckdb doesn't support the alias in the set
                # column name, so we use e.this.this to get just the column name without its table prefix
                set_clauses = ", ".join(
                    [f"{e.this.this} = {e.expression.sql()}" for e in then.args.get("expressions", [])]
                )
                update_sql = f"""
                    UPDATE {target_tbl}
                    SET {set_clauses}
                    FROM merge_candidates AS {source_tbl}
                    WHERE {join_expr}
                    AND {source_tbl}.merge_op = {w_idx}
                """
                statements.append(sqlglot.parse_one(update_sql))
            else:
                raise AssertionError(f"Expected 'Update' or 'Delete', got {then}")
        else:
            assert isinstance(then, exp.Insert), f"Expected 'Insert', got {then}"
            cols = [str(c) for c in then.this.expressions] if then.this else []
            columns = f"({', '.join(cols)})" if cols else ""
            values = ", ".join(map(str, then.expression.expressions))
            insert_sql = f"""
                INSERT INTO {target_tbl} {columns}
                SELECT {values}
                FROM merge_candidates AS {source_tbl}
                WHERE {source_tbl}.merge_op = {w_idx}
            """
            statements.append(sqlglot.parse_one(insert_sql))

    return statements
157
+
158
+
159
def _counts(merge_expr: exp.Merge) -> exp.Expression:
    """
    Given a merge statement, derive a SQL statement over the merge_candidates table that produces
    the following columns:

    - "number of rows inserted"
    - "number of rows updated"
    - "number of rows deleted"

    A column is only included when at least one WHEN clause performs that kind of operation, eg: if
    no WHEN clause deletes, the "number of rows deleted" column is not included.

    Raises:
        AssertionError: for a WHEN clause that is not UPDATE/DELETE (matched) or INSERT (not matched).
    """
    # WHEN clause indices grouped by operation; dict order fixes the output column order
    indices_by_op: dict[str, list[int]] = {"inserted": [], "updated": [], "deleted": []}

    for w_idx, w in enumerate(merge_expr.expressions):
        assert isinstance(w, exp.When), f"Expected When expression, got {w}"

        then = w.args.get("then")
        if not w.args.get("matched"):
            assert isinstance(then, exp.Insert), f"Expected 'Insert', got {then}"
            op = "inserted"
        elif isinstance(then, exp.Update):
            op = "updated"
        elif isinstance(then, exp.Var) and then.args.get("this") == "DELETE":
            op = "deleted"
        else:
            raise AssertionError(f"Expected 'Update' or 'Delete', got {then}")
        indices_by_op[op].append(w_idx)

    count_columns = ", ".join(
        f'COUNT_IF(merge_op in ({",".join(str(i) for i in idxs)})) as "number of rows {op}"'
        for op, idxs in indices_by_op.items()
        if idxs
    )
    sql = f"""
        SELECT {count_columns}
        FROM merge_candidates
    """

    return sqlglot.parse_one(sql)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: fakesnow
3
- Version: 0.9.24
3
+ Version: 0.9.25
4
4
  Summary: Fake Snowflake Connector for Python. Run, mock and test Snowflake DB locally.
5
5
  License: Apache License
6
6
  Version 2.0, January 2004
@@ -220,10 +220,10 @@ Requires-Dist: dirty-equals; extra == "dev"
220
220
  Requires-Dist: pandas-stubs; extra == "dev"
221
221
  Requires-Dist: snowflake-connector-python[pandas,secure-local-storage]; extra == "dev"
222
222
  Requires-Dist: pre-commit~=3.4; extra == "dev"
223
- Requires-Dist: pyarrow-stubs; extra == "dev"
223
+ Requires-Dist: pyarrow-stubs==10.0.1.9; extra == "dev"
224
224
  Requires-Dist: pytest~=8.0; extra == "dev"
225
225
  Requires-Dist: pytest-asyncio; extra == "dev"
226
- Requires-Dist: ruff~=0.5.1; extra == "dev"
226
+ Requires-Dist: ruff~=0.6.3; extra == "dev"
227
227
  Requires-Dist: twine~=5.0; extra == "dev"
228
228
  Requires-Dist: snowflake-sqlalchemy~=1.5.0; extra == "dev"
229
229
  Provides-Extra: notebook
@@ -1,10 +1,10 @@
1
- fakesnow/__init__.py,sha256=9tFJJKvowKNW3vfnlmza6hOLN1I52DwChgNc5Ew6CcA,3499
1
+ fakesnow/__init__.py,sha256=qUfgucQYPdELrJaxczalhJgWAWQ6cfTCUAHx6nUqRaI,3528
2
2
  fakesnow/__main__.py,sha256=GDrGyNTvBFuqn_UfDjKs7b3LPtU6gDv1KwosVDrukIM,76
3
- fakesnow/arrow.py,sha256=WLkr1nEiNxUcPdzadKSM33sRAiQJsN6LvuzTVIsi3D0,2766
4
- fakesnow/checks.py,sha256=-QMvdcrRbhN60rnzxLBJ0IkUBWyLR8gGGKKmCS0w9mA,2383
3
+ fakesnow/arrow.py,sha256=EGAYeuCnRuvmWBEGqw2YOcgQR4zcCsZBu85kSRl70dQ,4698
4
+ fakesnow/checks.py,sha256=N8sXldhS3u1gG32qvZ4VFlsKgavRKrQrxLiQU8am1lw,2691
5
5
  fakesnow/cli.py,sha256=9qfI-Ssr6mo8UmIlXkUAOz2z2YPBgDsrEVaZv9FjGFs,2201
6
6
  fakesnow/conn.py,sha256=Gy_Z7BZRm5yMjV3x6hR4iegDQFdG9aJBjqWdc3iWYFU,5353
7
- fakesnow/cursor.py,sha256=2PtW9hzfXs3mzv6BBxXLoS-pPtD4otrfQ2KnPNNanGI,19441
7
+ fakesnow/cursor.py,sha256=JbLSTzIN5Hu6ECn1kQ8-hC8V-ENeEWTID4JxHbGpEIo,19948
8
8
  fakesnow/expr.py,sha256=CAxuYIUkwI339DQIBzvFF0F-m1tcVGKEPA5rDTzmH9A,892
9
9
  fakesnow/fakes.py,sha256=JQTiUkkwPeQrJ8FDWhPFPK6pGwd_aR2oiOrNzCWznlM,187
10
10
  fakesnow/fixtures.py,sha256=G-NkVeruSQAJ7fvSS2fR2oysUn0Yra1pohHlOvacKEk,455
@@ -13,13 +13,14 @@ fakesnow/instance.py,sha256=3cJvPRuFy19dMKXbtBLl6imzO48pEw8uTYhZyFDuwhk,3133
13
13
  fakesnow/macros.py,sha256=pX1YJDnQOkFJSHYUjQ6ErEkYIKvFI6Ncz_au0vv1csA,265
14
14
  fakesnow/pandas_tools.py,sha256=WjyjTV8QUCQQaCGboaEOvx2uo4BkknpWYjtLwkeCY6U,3468
15
15
  fakesnow/py.typed,sha256=B-DLSjYBi7pkKjwxCSdpVj2J02wgfJr-E7B1wOUyxYU,80
16
- fakesnow/server.py,sha256=8dzaLUUXPzCMm6-ESn0CBws6XSwwOpnUuHQAZJ-4SwU,3011
17
- fakesnow/transforms.py,sha256=ellcY5OBc7mqgL9ChNolrqcCLWXF9RH21Jt88FcFl-I,54419
16
+ fakesnow/server.py,sha256=SO5xKZ4rvySsuKDsoSPSCZcFuIX_K7d1XJYhRRJ-7Bk,4150
17
+ fakesnow/transforms.py,sha256=hPQ9L1TvsDjFiGbs8mGUnIFyPUR7JxFU8FZRKFj5ZD0,54568
18
+ fakesnow/transforms_merge.py,sha256=7rq-UPjfFNRrFsqR8xx3otwP6-k4eslLVLhfuqSXq1A,8314
18
19
  fakesnow/types.py,sha256=9Tt83Z7ctc9_v6SYyayXYz4MEI4RZo4zq_uqdj4g3Dk,2681
19
20
  fakesnow/variables.py,sha256=WXyPnkeNwD08gy52yF66CVe2twiYC50tztNfgXV4q1k,3032
20
- fakesnow-0.9.24.dist-info/LICENSE,sha256=kW-7NWIyaRMQiDpryfSmF2DObDZHGR1cJZ39s6B1Svg,11344
21
- fakesnow-0.9.24.dist-info/METADATA,sha256=LHKc6JYn9sxxFh6_i7kqlWz1fmloFv2CCmpalwPVFrE,18064
22
- fakesnow-0.9.24.dist-info/WHEEL,sha256=eOLhNAGa2EW3wWl_TU484h7q1UNgy0JXjjoqKoxAAQc,92
23
- fakesnow-0.9.24.dist-info/entry_points.txt,sha256=2riAUgu928ZIHawtO8EsfrMEJhi-EH-z_Vq7Q44xKPM,47
24
- fakesnow-0.9.24.dist-info/top_level.txt,sha256=500evXI1IFX9so82cizGIEMHAb_dJNPaZvd2H9dcKTA,24
25
- fakesnow-0.9.24.dist-info/RECORD,,
21
+ fakesnow-0.9.25.dist-info/LICENSE,sha256=kW-7NWIyaRMQiDpryfSmF2DObDZHGR1cJZ39s6B1Svg,11344
22
+ fakesnow-0.9.25.dist-info/METADATA,sha256=1RktYqC8KfU4ekh8IGApjHSLb5oe3KMgBLjyKJKXHRc,18074
23
+ fakesnow-0.9.25.dist-info/WHEEL,sha256=eOLhNAGa2EW3wWl_TU484h7q1UNgy0JXjjoqKoxAAQc,92
24
+ fakesnow-0.9.25.dist-info/entry_points.txt,sha256=2riAUgu928ZIHawtO8EsfrMEJhi-EH-z_Vq7Q44xKPM,47
25
+ fakesnow-0.9.25.dist-info/top_level.txt,sha256=500evXI1IFX9so82cizGIEMHAb_dJNPaZvd2H9dcKTA,24
26
+ fakesnow-0.9.25.dist-info/RECORD,,