fakesnow 0.9.22__py3-none-any.whl → 0.9.23__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
fakesnow/cursor.py ADDED
@@ -0,0 +1,463 @@
+ from __future__ import annotations
+
+ import os
+ import re
+ import sys
+ from collections.abc import Iterator, Sequence
+ from string import Template
+ from types import TracebackType
+ from typing import TYPE_CHECKING, Any, cast
+
+ import duckdb
+ import pyarrow
+ import snowflake.connector.converter
+ import snowflake.connector.errors
+ import sqlglot
+ from duckdb import DuckDBPyConnection
+ from snowflake.connector.cursor import ResultMetadata
+ from snowflake.connector.result_batch import ResultBatch
+ from sqlglot import exp, parse_one
+ from typing_extensions import Self
+
+ import fakesnow.checks as checks
+ import fakesnow.expr as expr
+ import fakesnow.info_schema as info_schema
+ import fakesnow.transforms as transforms
+ from fakesnow.types import describe_as_result_metadata
+
+ if TYPE_CHECKING:
+     import pandas as pd
+     import pyarrow.lib
+
+     from fakesnow.conn import FakeSnowflakeConnection
+
+
+ SCHEMA_UNSET = "schema_unset"
+ SQL_SUCCESS = "SELECT 'Statement executed successfully.' as 'status'"
+ SQL_CREATED_DATABASE = Template("SELECT 'Database ${name} successfully created.' as 'status'")
+ SQL_CREATED_SCHEMA = Template("SELECT 'Schema ${name} successfully created.' as 'status'")
+ SQL_CREATED_TABLE = Template("SELECT 'Table ${name} successfully created.' as 'status'")
+ SQL_CREATED_VIEW = Template("SELECT 'View ${name} successfully created.' as 'status'")
+ SQL_DROPPED = Template("SELECT '${name} successfully dropped.' as 'status'")
+ SQL_INSERTED_ROWS = Template("SELECT ${count} as 'number of rows inserted'")
+ SQL_UPDATED_ROWS = Template("SELECT ${count} as 'number of rows updated', 0 as 'number of multi-joined rows updated'")
+ SQL_DELETED_ROWS = Template("SELECT ${count} as 'number of rows deleted'")
+
+
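These `Template` constants render the one-row status results a Snowflake client expects back from DDL/DML statements. Rendering is plain standard-library substitution; for example:

    SQL_INSERTED_ROWS.substitute(count=3)
    # "SELECT 3 as 'number of rows inserted'"
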
+ class FakeSnowflakeCursor:
+     def __init__(
+         self,
+         conn: FakeSnowflakeConnection,
+         duck_conn: DuckDBPyConnection,
+         use_dict_result: bool = False,
+     ) -> None:
+         """Create a fake snowflake cursor backed by DuckDB.
+
+         Args:
+             conn (FakeSnowflakeConnection): Used to maintain current database and schema.
+             duck_conn (DuckDBPyConnection): DuckDB connection.
+             use_dict_result (bool, optional): If true, rows are returned as dicts; otherwise
+                 they are returned as tuples. Defaults to False.
+         """
+         self._conn = conn
+         self._duck_conn = duck_conn
+         self._use_dict_result = use_dict_result
+         self._last_sql = None
+         self._last_params = None
+         self._sqlstate = None
+         self._arraysize = 1
+         self._arrow_table = None
+         self._arrow_table_fetch_index = None
+         self._rowcount = None
+         self._converter = snowflake.connector.converter.SnowflakeConverter()
+
+     def __enter__(self) -> Self:
+         return self
+
+     def __exit__(
+         self,
+         exc_type: type[BaseException] | None,
+         exc_value: BaseException | None,
+         traceback: TracebackType | None,
+     ) -> None:
+         pass
+
+     @property
+     def arraysize(self) -> int:
+         return self._arraysize
+
+     @arraysize.setter
+     def arraysize(self, value: int) -> None:
+         self._arraysize = value
+
+     def close(self) -> bool:
+         self._last_sql = None
+         self._last_params = None
+         return True
+
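By default rows come back as tuples; `use_dict_result=True` switches to dicts keyed by column name (upper-cased for unquoted identifiers). A minimal sketch, constructing the cursor directly for illustration:

    cur = FakeSnowflakeCursor(conn, duck_conn, use_dict_result=True)
    cur.execute("SELECT 1 AS a")
    cur.fetchall()  # [{'A': 1}] rather than [(1,)]
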
+     def describe(self, command: str, *args: Any, **kwargs: Any) -> list[ResultMetadata]:
+         """Return the schema of the result without executing the query.
+
+         Takes the same arguments as execute.
+
+         Returns:
+             list[ResultMetadata]: Metadata for each column of the result.
+         """
+
+         describe = f"DESCRIBE {command}"
+         self.execute(describe, *args, **kwargs)
+         return describe_as_result_metadata(self.fetchall())
+
+     @property
+     def description(self) -> list[ResultMetadata]:
+         # use a separate cursor to avoid consuming the result set on this cursor
+         with self._conn.cursor() as cur:
+             expression = sqlglot.parse_one(f"DESCRIBE {self._last_sql}", read="duckdb")
+             cur._execute(expression, self._last_params)  # noqa: SLF001
+             meta = describe_as_result_metadata(cur.fetchall())
+
+         return meta
+
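`describe` wraps the query in a DESCRIBE, executes it, and maps the resulting rows to the connector's `ResultMetadata`, so it behaves like the real connector's schema preview. For example (the metadata fields shown are illustrative):

    cur.describe("SELECT 1 AS a")  # -> [ResultMetadata(name='A', ...)]
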
+     def execute(
+         self,
+         command: str,
+         params: Sequence[Any] | dict[Any, Any] | None = None,
+         *args: Any,
+         **kwargs: Any,
+     ) -> FakeSnowflakeCursor:
+         try:
+             self._sqlstate = None
+
+             if os.environ.get("FAKESNOW_DEBUG") == "snowflake":
+                 print(f"{command};{params=}" if params else f"{command};", file=sys.stderr)
+
+             command = self._inline_variables(command)
+             command, params = self._rewrite_with_params(command, params)
+             if self._conn.nop_regexes and any(re.match(p, command, re.IGNORECASE) for p in self._conn.nop_regexes):
+                 transformed = transforms.SUCCESS_NOP
+             else:
+                 expression = parse_one(command, read="snowflake")
+                 transformed = self._transform(expression)
+             return self._execute(transformed, params)
+         except snowflake.connector.errors.ProgrammingError as e:
+             self._sqlstate = e.sqlstate
+             raise e
+
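Setting the `FAKESNOW_DEBUG` environment variable to `snowflake` echoes each incoming Snowflake statement to stderr before transformation; any other non-empty value echoes the transformed DuckDB SQL instead (see `_log_sql` below). For example:

    os.environ["FAKESNOW_DEBUG"] = "snowflake"
    cur.execute("SELECT 1")  # prints to stderr: SELECT 1;
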
+     def _transform(self, expression: exp.Expression) -> exp.Expression:
+         return (
+             expression.transform(transforms.upper_case_unquoted_identifiers)
+             .transform(transforms.update_variables, variables=self._conn.variables)
+             .transform(transforms.set_schema, current_database=self._conn.database)
+             .transform(transforms.create_database, db_path=self._conn.db_path)
+             .transform(transforms.extract_comment_on_table)
+             .transform(transforms.extract_comment_on_columns)
+             .transform(transforms.information_schema_fs_columns_snowflake)
+             .transform(transforms.information_schema_fs_tables_ext)
+             .transform(transforms.drop_schema_cascade)
+             .transform(transforms.tag)
+             .transform(transforms.semi_structured_types)
+             .transform(transforms.try_parse_json)
+             .transform(transforms.split)
+             # NOTE: trim_cast_varchar must be before json_extract_cast_as_varchar
+             .transform(transforms.trim_cast_varchar)
+             # indices_to_json_extract must be before regex_substr
+             .transform(transforms.indices_to_json_extract)
+             .transform(transforms.json_extract_cast_as_varchar)
+             .transform(transforms.json_extract_cased_as_varchar)
+             .transform(transforms.json_extract_precedence)
+             .transform(transforms.flatten_value_cast_as_varchar)
+             .transform(transforms.flatten)
+             .transform(transforms.regex_replace)
+             .transform(transforms.regex_substr)
+             .transform(transforms.values_columns)
+             .transform(transforms.to_date)
+             .transform(transforms.to_decimal)
+             .transform(transforms.try_to_decimal)
+             .transform(transforms.to_timestamp_ntz)
+             .transform(transforms.to_timestamp)
+             .transform(transforms.object_construct)
+             .transform(transforms.timestamp_ntz)
+             .transform(transforms.float_to_double)
+             .transform(transforms.integer_precision)
+             .transform(transforms.extract_text_length)
+             .transform(transforms.sample)
+             .transform(transforms.array_size)
+             .transform(transforms.random)
+             .transform(transforms.identifier)
+             .transform(transforms.array_agg_within_group)
+             .transform(transforms.array_agg)
+             .transform(transforms.dateadd_date_cast)
+             .transform(transforms.dateadd_string_literal_timestamp_cast)
+             .transform(transforms.datediff_string_literal_timestamp_cast)
+             .transform(lambda e: transforms.show_schemas(e, self._conn.database))
+             .transform(lambda e: transforms.show_objects_tables(e, self._conn.database))
+             # TODO collapse into a single show_keys function
+             .transform(lambda e: transforms.show_keys(e, self._conn.database, kind="PRIMARY"))
+             .transform(lambda e: transforms.show_keys(e, self._conn.database, kind="UNIQUE"))
+             .transform(lambda e: transforms.show_keys(e, self._conn.database, kind="FOREIGN"))
+             .transform(transforms.show_users)
+             .transform(transforms.create_user)
+             .transform(transforms.sha256)
+             .transform(transforms.create_clone)
+             .transform(transforms.alias_in_join)
+             .transform(transforms.alter_table_strip_cluster_by)
+         )
+
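Each step in this chain rewrites one Snowflake-ism into DuckDB-compatible SQL, and for a few of them order matters (as the inline comments note). The transforms are ordinary sqlglot node visitors, so they can be exercised standalone; a small sketch (output shown is indicative):

    import sqlglot
    import fakesnow.transforms as transforms

    e = sqlglot.parse_one("SELECT col1 FROM tbl", read="snowflake")
    e.transform(transforms.upper_case_unquoted_identifiers).sql()
    # 'SELECT COL1 FROM TBL'
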
+     def _execute(
+         self, transformed: exp.Expression, params: Sequence[Any] | dict[Any, Any] | None = None
+     ) -> FakeSnowflakeCursor:
+         self._arrow_table = None
+         self._arrow_table_fetch_index = None
+         self._rowcount = None
+
+         cmd = expr.key_command(transformed)
+
+         no_database, no_schema = checks.is_unqualified_table_expression(transformed)
+
+         if no_database and not self._conn.database_set:
+             raise snowflake.connector.errors.ProgrammingError(
+                 msg=f"Cannot perform {cmd}. This session does not have a current database. Call 'USE DATABASE', or use a qualified name.",  # noqa: E501
+                 errno=90105,
+                 sqlstate="22000",
+             )
+         elif no_schema and not self._conn.schema_set:
+             raise snowflake.connector.errors.ProgrammingError(
+                 msg=f"Cannot perform {cmd}. This session does not have a current schema. Call 'USE SCHEMA', or use a qualified name.",  # noqa: E501
+                 errno=90106,
+                 sqlstate="22000",
+             )
+
+         sql = transformed.sql(dialect="duckdb")
+
+         if transformed.find(exp.Select) and (seed := transformed.args.get("seed")):
+             sql = f"SELECT setseed({seed}); {sql}"
+
+         result_sql = None
+
+         try:
+             self._log_sql(sql, params)
+             self._duck_conn.execute(sql, params)
+         except duckdb.BinderException as e:
+             msg = e.args[0]
+             raise snowflake.connector.errors.ProgrammingError(msg=msg, errno=2043, sqlstate="02000") from None
+         except duckdb.CatalogException as e:
+             # minimal processing to make it look like a snowflake exception, message content may differ
+             msg = cast(str, e.args[0]).split("\n")[0]
+             raise snowflake.connector.errors.ProgrammingError(msg=msg, errno=2003, sqlstate="42S02") from None
+         except duckdb.TransactionException as e:
+             if "cannot rollback - no transaction is active" in str(
+                 e
+             ) or "cannot commit - no transaction is active" in str(e):
+                 # snowflake doesn't error on rollback or commit outside a tx
+                 result_sql = SQL_SUCCESS
+             else:
+                 raise e
+         except duckdb.ConnectionException as e:
+             raise snowflake.connector.errors.DatabaseError(msg=e.args[0], errno=250002, sqlstate="08003") from None
+
+         affected_count = None
+
+         if set_database := transformed.args.get("set_database"):
+             self._conn.database = set_database
+             self._conn.database_set = True
+
+         elif set_schema := transformed.args.get("set_schema"):
+             self._conn.schema = set_schema
+             self._conn.schema_set = True
+
+         elif create_db_name := transformed.args.get("create_db_name"):
+             # we created a new database, so create the info schema extensions
+             self._duck_conn.execute(info_schema.creation_sql(create_db_name))
+             result_sql = SQL_CREATED_DATABASE.substitute(name=create_db_name)
+
+         elif cmd == "INSERT":
+             (affected_count,) = self._duck_conn.fetchall()[0]
+             result_sql = SQL_INSERTED_ROWS.substitute(count=affected_count)
+
+         elif cmd == "UPDATE":
+             (affected_count,) = self._duck_conn.fetchall()[0]
+             result_sql = SQL_UPDATED_ROWS.substitute(count=affected_count)
+
+         elif cmd == "DELETE":
+             (affected_count,) = self._duck_conn.fetchall()[0]
+             result_sql = SQL_DELETED_ROWS.substitute(count=affected_count)
+
+         elif cmd in ("DESCRIBE TABLE", "DESCRIBE VIEW"):
+             # DESCRIBE TABLE/VIEW has already been run above, to detect the table/view and error if it doesn't exist
+             # We now rerun DESCRIBE TABLE/VIEW but transformed with columns to match Snowflake
+             result_sql = transformed.transform(
+                 lambda e: transforms.describe_table(e, self._conn.database, self._conn.schema)
+             ).sql(dialect="duckdb")
+
+         elif (eid := transformed.find(exp.Identifier, bfs=False)) and isinstance(eid.this, str):
+             ident = eid.this if eid.quoted else eid.this.upper()
+             if cmd == "CREATE SCHEMA" and ident:
+                 result_sql = SQL_CREATED_SCHEMA.substitute(name=ident)
+
+             elif cmd == "CREATE TABLE" and ident:
+                 result_sql = SQL_CREATED_TABLE.substitute(name=ident)
+
+             elif cmd.startswith("ALTER") and ident:
+                 result_sql = SQL_SUCCESS
+
+             elif cmd == "CREATE VIEW" and ident:
+                 result_sql = SQL_CREATED_VIEW.substitute(name=ident)
+
+             elif cmd.startswith("DROP") and ident:
+                 result_sql = SQL_DROPPED.substitute(name=ident)
+
+                 # if dropping the current database/schema then reset conn metadata
+                 if cmd == "DROP DATABASE" and ident == self._conn.database:
+                     self._conn.database = None
+                     self._conn.schema = None
+
+                 elif cmd == "DROP SCHEMA" and ident == self._conn.schema:
+                     self._conn.schema = None
+
+         if table_comment := cast(tuple[exp.Table, str], transformed.args.get("table_comment")):
+             # record table comment
+             table, comment = table_comment
+             catalog = table.catalog or self._conn.database
+             schema = table.db or self._conn.schema
+             assert catalog and schema
+             self._duck_conn.execute(info_schema.insert_table_comment_sql(catalog, schema, table.name, comment))
+
+         if (text_lengths := cast(list[tuple[str, int]], transformed.args.get("text_lengths"))) and (
+             table := transformed.find(exp.Table)
+         ):
+             # record text lengths
+             catalog = table.catalog or self._conn.database
+             schema = table.db or self._conn.schema
+             assert catalog and schema
+             self._duck_conn.execute(info_schema.insert_text_lengths_sql(catalog, schema, table.name, text_lengths))
+
+         if result_sql:
+             self._log_sql(result_sql, params)
+             self._duck_conn.execute(result_sql)
+
+         self._arrow_table = self._duck_conn.fetch_arrow_table()
+         self._rowcount = affected_count or self._arrow_table.num_rows
+
+         self._last_sql = result_sql or sql
+         self._last_params = params
+
+         return self
+
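DuckDB exceptions are remapped to the connector's own error types, so calling code can catch what it would catch against real Snowflake. For instance, a query against a missing table surfaces as a `ProgrammingError`:

    try:
        cur.execute("SELECT * FROM missing_table")
    except snowflake.connector.errors.ProgrammingError as e:
        print(e.errno, e.sqlstate)  # 2003 42S02 (message text may differ from real Snowflake)
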
+     def _log_sql(self, sql: str, params: Sequence[Any] | dict[Any, Any] | None = None) -> None:
+         if (fs_debug := os.environ.get("FAKESNOW_DEBUG")) and fs_debug != "snowflake":
+             print(f"{sql};{params=}" if params else f"{sql};", file=sys.stderr)
+
+     def executemany(
+         self,
+         command: str,
+         seqparams: Sequence[Any] | dict[str, Any],
+         **kwargs: Any,
+     ) -> FakeSnowflakeCursor:
+         if isinstance(seqparams, dict):
+             # see https://docs.snowflake.com/en/developer-guide/python-connector/python-connector-api
+             raise NotImplementedError("dict params not supported yet")
+
+         # TODO: support insert optimisations
+         # the snowflake connector will optimise inserts into a single query
+         # unless num_statements != 1 .. but for simplicity we execute each
+         # query one by one, which means the response differs
+         for p in seqparams:
+             self.execute(command, p)
+
+         return self
+
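Because each parameter set is executed as its own statement, the final status (e.g. `rowcount`) reflects only the last execution, unlike the real connector's single batched insert. Usage is otherwise standard, assuming the connector's default pyformat paramstyle:

    cur.executemany("INSERT INTO t (a) VALUES (%s)", [(1,), (2,), (3,)])
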
+     def fetchall(self) -> list[tuple] | list[dict]:
+         if self._arrow_table is None:
+             # mimic snowflake python connector error type
+             raise TypeError("No open result set")
+         return self.fetchmany(self._arrow_table.num_rows)
+
+     def fetch_pandas_all(self, **kwargs: dict[str, Any]) -> pd.DataFrame:
+         if self._arrow_table is None:
+             # mimic snowflake python connector error type
+             raise snowflake.connector.NotSupportedError("No open result set")
+         return self._arrow_table.to_pandas()
+
+     def fetchone(self) -> dict | tuple | None:
+         result = self.fetchmany(1)
+         return result[0] if result else None
+
+     def fetchmany(self, size: int | None = None) -> list[tuple] | list[dict]:
+         # https://peps.python.org/pep-0249/#fetchmany
+         size = size or self._arraysize
+
+         if self._arrow_table is None:
+             # mimic snowflake python connector error type
+             raise TypeError("No open result set")
+         tslice = self._arrow_table.slice(offset=self._arrow_table_fetch_index or 0, length=size).to_pylist()
+
+         if self._arrow_table_fetch_index is None:
+             self._arrow_table_fetch_index = size
+         else:
+             self._arrow_table_fetch_index += size
+
+         return tslice if self._use_dict_result else [tuple(d.values()) for d in tslice]
+
+     def get_result_batches(self) -> list[ResultBatch] | None:
+         if self._arrow_table is None:
+             return None
+         return [FakeResultBatch(self._use_dict_result, b) for b in self._arrow_table.to_batches(max_chunksize=1000)]
+
+     @property
+     def rowcount(self) -> int | None:
+         return self._rowcount
+
+     @property
+     def sfqid(self) -> str | None:
+         return "fakesnow"
+
+     @property
+     def sqlstate(self) -> str | None:
+         return self._sqlstate
+
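The fetch methods follow PEP 249 semantics over the cached Arrow table, advancing a shared index:

    cur.execute("SELECT * FROM example")
    cur.fetchone()    # first row
    cur.fetchmany(2)  # next two rows
    cur.fetchall()    # all remaining rows
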
+     def _rewrite_with_params(
+         self,
+         command: str,
+         params: Sequence[Any] | dict[Any, Any] | None = None,
+     ) -> tuple[str, Sequence[Any] | dict[Any, Any] | None]:
+         if params and self._conn._paramstyle in ("pyformat", "format"):  # noqa: SLF001
+             # handle client-side in the same manner as the snowflake python connector
+
+             def convert(param: Any) -> Any:  # noqa: ANN401
+                 return self._converter.quote(self._converter.escape(self._converter.to_snowflake(param)))
+
+             if isinstance(params, dict):
+                 params = {k: convert(v) for k, v in params.items()}
+             else:
+                 params = tuple(convert(v) for v in params)
+
+             return command % params, None
+
+         return command, params
+
+     def _inline_variables(self, sql: str) -> str:
+         return self._conn.variables.inline_variables(sql)
+
+
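With the `pyformat`/`format` paramstyles, binding happens client-side just as in the snowflake python connector: each parameter is converted, escaped, quoted and then interpolated into the statement with `%`, so nothing is passed on to DuckDB as a bound parameter. Roughly:

    cur.execute("SELECT * FROM t WHERE name = %s AND id = %s", ("jones", 1))
    # is rewritten before parsing to:
    # SELECT * FROM t WHERE name = 'jones' AND id = 1
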
+ class FakeResultBatch(ResultBatch):
+     def __init__(self, use_dict_result: bool, batch: pyarrow.RecordBatch):
+         self._use_dict_result = use_dict_result
+         self._batch = batch
+
+     def create_iter(
+         self, **kwargs: dict[str, Any]
+     ) -> Iterator[dict | Exception] | Iterator[tuple | Exception] | Iterator[pyarrow.Table] | Iterator[pd.DataFrame]:
+         if self._use_dict_result:
+             return iter(self._batch.to_pylist())
+
+         return iter(tuple(d.values()) for d in self._batch.to_pylist())
+
+     @property
+     def rowcount(self) -> int:
+         return self._batch.num_rows
+
+     def to_pandas(self) -> pd.DataFrame:
+         return self._batch.to_pandas()
+
+     def to_arrow(self) -> pyarrow.Table:
+         raise NotImplementedError()
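Taken together, a round trip through the fake cursor looks like the snippet below: a minimal sketch using fakesnow's `patch()` entry point, which swaps these fakes in for the real connector:

    import fakesnow
    import snowflake.connector

    with fakesnow.patch():
        conn = snowflake.connector.connect()
        cur = conn.cursor()
        cur.execute("SELECT 'hello' AS greeting")
        print(cur.fetchall())  # [('hello',)]
        batches = cur.get_result_batches()  # FakeResultBatch chunks of up to 1000 rows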