fakesnow 0.9.22__py3-none-any.whl → 0.9.24__py3-none-any.whl

This diff compares the contents of two publicly released versions of this package, as they appear in their respective public registries. It is provided for informational purposes only.
fakesnow/fakes.py CHANGED
@@ -1,750 +1,3 @@
- from __future__ import annotations
-
- import json
- import os
- import re
- import sys
- from collections.abc import Iterable, Iterator, Sequence
- from pathlib import Path
- from string import Template
- from types import TracebackType
- from typing import TYPE_CHECKING, Any, Literal, Optional, cast
-
- import duckdb
- from sqlglot import exp
-
- if TYPE_CHECKING:
-     import pandas as pd
-     import pyarrow.lib
- import numpy as np
- import pyarrow
- import snowflake.connector.converter
- import snowflake.connector.errors
- import sqlglot
- from duckdb import DuckDBPyConnection
- from snowflake.connector.cursor import DictCursor, ResultMetadata, SnowflakeCursor
- from snowflake.connector.result_batch import ResultBatch
- from sqlglot import parse_one
- from typing_extensions import Self
-
- import fakesnow.checks as checks
- import fakesnow.expr as expr
- import fakesnow.info_schema as info_schema
- import fakesnow.macros as macros
- import fakesnow.transforms as transforms
- from fakesnow.variables import Variables
-
- SCHEMA_UNSET = "schema_unset"
- SQL_SUCCESS = "SELECT 'Statement executed successfully.' as 'status'"
- SQL_CREATED_DATABASE = Template("SELECT 'Database ${name} successfully created.' as 'status'")
- SQL_CREATED_SCHEMA = Template("SELECT 'Schema ${name} successfully created.' as 'status'")
- SQL_CREATED_TABLE = Template("SELECT 'Table ${name} successfully created.' as 'status'")
- SQL_CREATED_VIEW = Template("SELECT 'View ${name} successfully created.' as 'status'")
- SQL_DROPPED = Template("SELECT '${name} successfully dropped.' as 'status'")
- SQL_INSERTED_ROWS = Template("SELECT ${count} as 'number of rows inserted'")
- SQL_UPDATED_ROWS = Template("SELECT ${count} as 'number of rows updated', 0 as 'number of multi-joined rows updated'")
- SQL_DELETED_ROWS = Template("SELECT ${count} as 'number of rows deleted'")
-
-
- class FakeSnowflakeCursor:
-     def __init__(
-         self,
-         conn: FakeSnowflakeConnection,
-         duck_conn: DuckDBPyConnection,
-         use_dict_result: bool = False,
-     ) -> None:
-         """Create a fake snowflake cursor backed by DuckDB.
-
-         Args:
-             conn (FakeSnowflakeConnection): Used to maintain current database and schema.
-             duck_conn (DuckDBPyConnection): DuckDB connection.
-             use_dict_result (bool, optional): If true, rows are returned as dicts; otherwise
-                 they are returned as tuples. Defaults to False.
-         """
-         self._conn = conn
-         self._duck_conn = duck_conn
-         self._use_dict_result = use_dict_result
-         self._last_sql = None
-         self._last_params = None
-         self._sqlstate = None
-         self._arraysize = 1
-         self._arrow_table = None
-         self._arrow_table_fetch_index = None
-         self._rowcount = None
-         self._converter = snowflake.connector.converter.SnowflakeConverter()
-
-     def __enter__(self) -> Self:
-         return self
-
-     def __exit__(
-         self,
-         exc_type: type[BaseException] | None,
-         exc_value: BaseException | None,
-         traceback: TracebackType | None,
-     ) -> None:
-         pass
-
-     @property
-     def arraysize(self) -> int:
-         return self._arraysize
-
-     @arraysize.setter
-     def arraysize(self, value: int) -> None:
-         self._arraysize = value
-
-     def close(self) -> bool:
-         self._last_sql = None
-         self._last_params = None
-         return True
-
-     def describe(self, command: str, *args: Any, **kwargs: Any) -> list[ResultMetadata]:
-         """Return the schema of the result without executing the query.
-
-         Takes the same arguments as execute.
-
-         Returns:
-             list[ResultMetadata]: Metadata describing the result columns.
-         """
-
-         describe = f"DESCRIBE {command}"
-         self.execute(describe, *args, **kwargs)
-         return FakeSnowflakeCursor._describe_as_result_metadata(self.fetchall())
-
-     @property
-     def description(self) -> list[ResultMetadata]:
-         # use a separate cursor to avoid consuming the result set on this cursor
-         with self._conn.cursor() as cur:
-             expression = sqlglot.parse_one(f"DESCRIBE {self._last_sql}", read="duckdb")
-             cur._execute(expression, self._last_params)  # noqa: SLF001
-             meta = FakeSnowflakeCursor._describe_as_result_metadata(cur.fetchall())
-
-         return meta
-
-     def execute(
-         self,
-         command: str,
-         params: Sequence[Any] | dict[Any, Any] | None = None,
-         *args: Any,
-         **kwargs: Any,
-     ) -> FakeSnowflakeCursor:
-         try:
-             self._sqlstate = None
-
-             if os.environ.get("FAKESNOW_DEBUG") == "snowflake":
-                 print(f"{command};{params=}" if params else f"{command};", file=sys.stderr)
-
-             command = self._inline_variables(command)
-             command, params = self._rewrite_with_params(command, params)
-             if self._conn.nop_regexes and any(re.match(p, command, re.IGNORECASE) for p in self._conn.nop_regexes):
-                 transformed = transforms.SUCCESS_NOP
-             else:
-                 expression = parse_one(command, read="snowflake")
-                 transformed = self._transform(expression)
-             return self._execute(transformed, params)
-         except snowflake.connector.errors.ProgrammingError as e:
-             self._sqlstate = e.sqlstate
-             raise e
-
-     def _transform(self, expression: exp.Expression) -> exp.Expression:
-         return (
-             expression.transform(transforms.upper_case_unquoted_identifiers)
-             .transform(transforms.update_variables, variables=self._conn.variables)
-             .transform(transforms.set_schema, current_database=self._conn.database)
-             .transform(transforms.create_database, db_path=self._conn.db_path)
-             .transform(transforms.extract_comment_on_table)
-             .transform(transforms.extract_comment_on_columns)
-             .transform(transforms.information_schema_fs_columns_snowflake)
-             .transform(transforms.information_schema_fs_tables_ext)
-             .transform(transforms.drop_schema_cascade)
-             .transform(transforms.tag)
-             .transform(transforms.semi_structured_types)
-             .transform(transforms.try_parse_json)
-             .transform(transforms.split)
-             # NOTE: trim_cast_varchar must be before json_extract_cast_as_varchar
-             .transform(transforms.trim_cast_varchar)
-             # indices_to_json_extract must be before regex_substr
-             .transform(transforms.indices_to_json_extract)
-             .transform(transforms.json_extract_cast_as_varchar)
-             .transform(transforms.json_extract_cased_as_varchar)
-             .transform(transforms.json_extract_precedence)
-             .transform(transforms.flatten_value_cast_as_varchar)
-             .transform(transforms.flatten)
-             .transform(transforms.regex_replace)
-             .transform(transforms.regex_substr)
-             .transform(transforms.values_columns)
-             .transform(transforms.to_date)
-             .transform(transforms.to_decimal)
-             .transform(transforms.try_to_decimal)
-             .transform(transforms.to_timestamp_ntz)
-             .transform(transforms.to_timestamp)
-             .transform(transforms.object_construct)
-             .transform(transforms.timestamp_ntz)
-             .transform(transforms.float_to_double)
-             .transform(transforms.integer_precision)
-             .transform(transforms.extract_text_length)
-             .transform(transforms.sample)
-             .transform(transforms.array_size)
-             .transform(transforms.random)
-             .transform(transforms.identifier)
-             .transform(transforms.array_agg_within_group)
-             .transform(transforms.array_agg)
-             .transform(transforms.dateadd_date_cast)
-             .transform(transforms.dateadd_string_literal_timestamp_cast)
-             .transform(transforms.datediff_string_literal_timestamp_cast)
-             .transform(lambda e: transforms.show_schemas(e, self._conn.database))
-             .transform(lambda e: transforms.show_objects_tables(e, self._conn.database))
-             # TODO collapse into a single show_keys function
-             .transform(lambda e: transforms.show_keys(e, self._conn.database, kind="PRIMARY"))
-             .transform(lambda e: transforms.show_keys(e, self._conn.database, kind="UNIQUE"))
-             .transform(lambda e: transforms.show_keys(e, self._conn.database, kind="FOREIGN"))
-             .transform(transforms.show_users)
-             .transform(transforms.create_user)
-             .transform(transforms.sha256)
-             .transform(transforms.create_clone)
-             .transform(transforms.alias_in_join)
-             .transform(transforms.alter_table_strip_cluster_by)
-         )
-
-     def _execute(
-         self, transformed: exp.Expression, params: Sequence[Any] | dict[Any, Any] | None = None
-     ) -> FakeSnowflakeCursor:
-         self._arrow_table = None
-         self._arrow_table_fetch_index = None
-         self._rowcount = None
-
-         cmd = expr.key_command(transformed)
-
-         no_database, no_schema = checks.is_unqualified_table_expression(transformed)
-
-         if no_database and not self._conn.database_set:
-             raise snowflake.connector.errors.ProgrammingError(
-                 msg=f"Cannot perform {cmd}. This session does not have a current database. Call 'USE DATABASE', or use a qualified name.",  # noqa: E501
-                 errno=90105,
-                 sqlstate="22000",
-             )
-         elif no_schema and not self._conn.schema_set:
-             raise snowflake.connector.errors.ProgrammingError(
-                 msg=f"Cannot perform {cmd}. This session does not have a current schema. Call 'USE SCHEMA', or use a qualified name.",  # noqa: E501
-                 errno=90106,
-                 sqlstate="22000",
-             )
-
-         sql = transformed.sql(dialect="duckdb")
-
-         if transformed.find(exp.Select) and (seed := transformed.args.get("seed")):
-             sql = f"SELECT setseed({seed}); {sql}"
-
-         result_sql = None
-
-         try:
-             self._log_sql(sql, params)
-             self._duck_conn.execute(sql, params)
-         except duckdb.BinderException as e:
-             msg = e.args[0]
-             raise snowflake.connector.errors.ProgrammingError(msg=msg, errno=2043, sqlstate="02000") from None
-         except duckdb.CatalogException as e:
-             # minimal processing to make it look like a snowflake exception, message content may differ
-             msg = cast(str, e.args[0]).split("\n")[0]
-             raise snowflake.connector.errors.ProgrammingError(msg=msg, errno=2003, sqlstate="42S02") from None
-         except duckdb.TransactionException as e:
-             if "cannot rollback - no transaction is active" in str(
-                 e
-             ) or "cannot commit - no transaction is active" in str(e):
-                 # snowflake doesn't error on rollback or commit outside a tx
-                 result_sql = SQL_SUCCESS
-             else:
-                 raise e
-         except duckdb.ConnectionException as e:
-             raise snowflake.connector.errors.DatabaseError(msg=e.args[0], errno=250002, sqlstate="08003") from None
-
-         affected_count = None
-
-         if set_database := transformed.args.get("set_database"):
-             self._conn.database = set_database
-             self._conn.database_set = True
-
-         elif set_schema := transformed.args.get("set_schema"):
-             self._conn.schema = set_schema
-             self._conn.schema_set = True
-
-         elif create_db_name := transformed.args.get("create_db_name"):
-             # we created a new database, so create the info schema extensions
-             self._duck_conn.execute(info_schema.creation_sql(create_db_name))
-             result_sql = SQL_CREATED_DATABASE.substitute(name=create_db_name)
-
-         elif cmd == "INSERT":
-             (affected_count,) = self._duck_conn.fetchall()[0]
-             result_sql = SQL_INSERTED_ROWS.substitute(count=affected_count)
-
-         elif cmd == "UPDATE":
-             (affected_count,) = self._duck_conn.fetchall()[0]
-             result_sql = SQL_UPDATED_ROWS.substitute(count=affected_count)
-
-         elif cmd == "DELETE":
-             (affected_count,) = self._duck_conn.fetchall()[0]
-             result_sql = SQL_DELETED_ROWS.substitute(count=affected_count)
-
-         elif cmd in ("DESCRIBE TABLE", "DESCRIBE VIEW"):
-             # DESCRIBE TABLE/VIEW has already been run above to detect and error if the table doesn't exist
-             # We now rerun DESCRIBE TABLE/VIEW but transformed with columns to match Snowflake
-             result_sql = transformed.transform(
-                 lambda e: transforms.describe_table(e, self._conn.database, self._conn.schema)
-             ).sql(dialect="duckdb")
-
-         elif (eid := transformed.find(exp.Identifier, bfs=False)) and isinstance(eid.this, str):
-             ident = eid.this if eid.quoted else eid.this.upper()
-             if cmd == "CREATE SCHEMA" and ident:
-                 result_sql = SQL_CREATED_SCHEMA.substitute(name=ident)
-
-             elif cmd == "CREATE TABLE" and ident:
-                 result_sql = SQL_CREATED_TABLE.substitute(name=ident)
-
-             elif cmd.startswith("ALTER") and ident:
-                 result_sql = SQL_SUCCESS
-
-             elif cmd == "CREATE VIEW" and ident:
-                 result_sql = SQL_CREATED_VIEW.substitute(name=ident)
-
-             elif cmd.startswith("DROP") and ident:
-                 result_sql = SQL_DROPPED.substitute(name=ident)
-
-                 # if dropping the current database/schema then reset conn metadata
-                 if cmd == "DROP DATABASE" and ident == self._conn.database:
-                     self._conn.database = None
-                     self._conn.schema = None
-
-                 elif cmd == "DROP SCHEMA" and ident == self._conn.schema:
-                     self._conn.schema = None
-
-         if table_comment := cast(tuple[exp.Table, str], transformed.args.get("table_comment")):
-             # record table comment
-             table, comment = table_comment
-             catalog = table.catalog or self._conn.database
-             schema = table.db or self._conn.schema
-             assert catalog and schema
-             self._duck_conn.execute(info_schema.insert_table_comment_sql(catalog, schema, table.name, comment))
-
-         if (text_lengths := cast(list[tuple[str, int]], transformed.args.get("text_lengths"))) and (
-             table := transformed.find(exp.Table)
-         ):
-             # record text lengths
-             catalog = table.catalog or self._conn.database
-             schema = table.db or self._conn.schema
-             assert catalog and schema
-             self._duck_conn.execute(info_schema.insert_text_lengths_sql(catalog, schema, table.name, text_lengths))
-
-         if result_sql:
-             self._log_sql(result_sql, params)
-             self._duck_conn.execute(result_sql)
-
-         self._arrow_table = self._duck_conn.fetch_arrow_table()
-         self._rowcount = affected_count or self._arrow_table.num_rows
-
-         self._last_sql = result_sql or sql
-         self._last_params = params
-
-         return self
-
-     def _log_sql(self, sql: str, params: Sequence[Any] | dict[Any, Any] | None = None) -> None:
-         if (fs_debug := os.environ.get("FAKESNOW_DEBUG")) and fs_debug != "snowflake":
-             print(f"{sql};{params=}" if params else f"{sql};", file=sys.stderr)
-
-     def executemany(
-         self,
-         command: str,
-         seqparams: Sequence[Any] | dict[str, Any],
-         **kwargs: Any,
-     ) -> FakeSnowflakeCursor:
-         if isinstance(seqparams, dict):
-             # see https://docs.snowflake.com/en/developer-guide/python-connector/python-connector-api
-             raise NotImplementedError("dict params not supported yet")
-
-         # TODO: support insert optimisations
-         # the snowflake connector will optimise inserts into a single query
-         # unless num_statements != 1 .. but for simplicity we execute each
-         # query one by one, which means the response differs
-         for p in seqparams:
-             self.execute(command, p)
-
-         return self
-
-     def fetchall(self) -> list[tuple] | list[dict]:
-         if self._arrow_table is None:
-             # mimic snowflake python connector error type
-             raise TypeError("No open result set")
-         return self.fetchmany(self._arrow_table.num_rows)
-
-     def fetch_pandas_all(self, **kwargs: dict[str, Any]) -> pd.DataFrame:
-         if self._arrow_table is None:
-             # mimic snowflake python connector error type
-             raise snowflake.connector.NotSupportedError("No open result set")
-         return self._arrow_table.to_pandas()
-
-     def fetchone(self) -> dict | tuple | None:
-         result = self.fetchmany(1)
-         return result[0] if result else None
-
-     def fetchmany(self, size: int | None = None) -> list[tuple] | list[dict]:
-         # https://peps.python.org/pep-0249/#fetchmany
-         size = size or self._arraysize
-
-         if self._arrow_table is None:
-             # mimic snowflake python connector error type
-             raise TypeError("No open result set")
-         tslice = self._arrow_table.slice(offset=self._arrow_table_fetch_index or 0, length=size).to_pylist()
-
-         if self._arrow_table_fetch_index is None:
-             self._arrow_table_fetch_index = size
-         else:
-             self._arrow_table_fetch_index += size
-
-         return tslice if self._use_dict_result else [tuple(d.values()) for d in tslice]
-
-     def get_result_batches(self) -> list[ResultBatch] | None:
-         if self._arrow_table is None:
-             return None
-         return [FakeResultBatch(self._use_dict_result, b) for b in self._arrow_table.to_batches(max_chunksize=1000)]
-
-     @property
-     def rowcount(self) -> int | None:
-         return self._rowcount
-
-     @property
-     def sfqid(self) -> str | None:
-         return "fakesnow"
-
-     @property
-     def sqlstate(self) -> str | None:
-         return self._sqlstate
-
-     @staticmethod
-     def _describe_as_result_metadata(describe_results: list) -> list[ResultMetadata]:
-         # fmt: off
-         def as_result_metadata(column_name: str, column_type: str, _: str) -> ResultMetadata:
-             # see https://docs.snowflake.com/en/user-guide/python-connector-api.html#type-codes
-             # and https://arrow.apache.org/docs/python/api/datatypes.html#type-checking
-             if column_type in {"BIGINT", "INTEGER"}:
-                 return ResultMetadata(
-                     name=column_name, type_code=0, display_size=None, internal_size=None, precision=38, scale=0, is_nullable=True  # noqa: E501
-                 )
-             elif column_type.startswith("DECIMAL"):
-                 match = re.search(r'\((\d+),(\d+)\)', column_type)
-                 if match:
-                     precision = int(match[1])
-                     scale = int(match[2])
-                 else:
-                     precision = scale = None
-                 return ResultMetadata(
-                     name=column_name, type_code=0, display_size=None, internal_size=None, precision=precision, scale=scale, is_nullable=True  # noqa: E501
-                 )
-             elif column_type == "VARCHAR":
-                 # TODO: fetch internal_size from varchar size
-                 return ResultMetadata(
-                     name=column_name, type_code=2, display_size=None, internal_size=16777216, precision=None, scale=None, is_nullable=True  # noqa: E501
-                 )
-             elif column_type == "DOUBLE":
-                 return ResultMetadata(
-                     name=column_name, type_code=1, display_size=None, internal_size=None, precision=None, scale=None, is_nullable=True  # noqa: E501
-                 )
-             elif column_type == "BOOLEAN":
-                 return ResultMetadata(
-                     name=column_name, type_code=13, display_size=None, internal_size=None, precision=None, scale=None, is_nullable=True  # noqa: E501
-                 )
-             elif column_type == "DATE":
-                 return ResultMetadata(
-                     name=column_name, type_code=3, display_size=None, internal_size=None, precision=None, scale=None, is_nullable=True  # noqa: E501
-                 )
-             elif column_type in {"TIMESTAMP", "TIMESTAMP_NS"}:
-                 return ResultMetadata(
-                     name=column_name, type_code=8, display_size=None, internal_size=None, precision=0, scale=9, is_nullable=True  # noqa: E501
-                 )
-             elif column_type == "TIMESTAMP WITH TIME ZONE":
-                 return ResultMetadata(
-                     name=column_name, type_code=7, display_size=None, internal_size=None, precision=0, scale=9, is_nullable=True  # noqa: E501
-                 )
-             elif column_type == "BLOB":
-                 return ResultMetadata(
-                     name=column_name, type_code=11, display_size=None, internal_size=8388608, precision=None, scale=None, is_nullable=True  # noqa: E501
-                 )
-             elif column_type == "TIME":
-                 return ResultMetadata(
-                     name=column_name, type_code=12, display_size=None, internal_size=None, precision=0, scale=9, is_nullable=True  # noqa: E501
-                 )
-             elif column_type == "JSON":
-                 # TODO: correctly map OBJECT and ARRAY see https://github.com/tekumara/fakesnow/issues/26
-                 return ResultMetadata(
-                     name=column_name, type_code=5, display_size=None, internal_size=None, precision=None, scale=None, is_nullable=True  # noqa: E501
-                 )
-             else:
-                 # TODO handle more types
-                 raise NotImplementedError(f"for column type {column_type}")
-
-         # fmt: on
-
-         meta = [
-             as_result_metadata(column_name, column_type, null)
-             for (column_name, column_type, null, _, _, _) in describe_results
-         ]
-         return meta
-
-     def _rewrite_with_params(
-         self,
-         command: str,
-         params: Sequence[Any] | dict[Any, Any] | None = None,
-     ) -> tuple[str, Sequence[Any] | dict[Any, Any] | None]:
-         if params and self._conn._paramstyle in ("pyformat", "format"):  # noqa: SLF001
-             # handle client-side in the same manner as the snowflake python connector
-
-             def convert(param: Any) -> Any:  # noqa: ANN401
-                 return self._converter.quote(self._converter.escape(self._converter.to_snowflake(param)))
-
-             if isinstance(params, dict):
-                 params = {k: convert(v) for k, v in params.items()}
-             else:
-                 params = tuple(convert(v) for v in params)
-
-             return command % params, None
-
-         return command, params
-
-     def _inline_variables(self, sql: str) -> str:
-         return self._conn.variables.inline_variables(sql)
-
-
- class FakeSnowflakeConnection:
-     def __init__(
-         self,
-         duck_conn: DuckDBPyConnection,
-         database: str | None = None,
-         schema: str | None = None,
-         create_database: bool = True,
-         create_schema: bool = True,
-         db_path: str | os.PathLike | None = None,
-         nop_regexes: list[str] | None = None,
-         *args: Any,
-         **kwargs: Any,
-     ):
-         self._duck_conn = duck_conn
-         # upper case database and schema like snowflake unquoted identifiers
-         # so they appear as upper-cased in information_schema
-         # catalog and schema names are not actually case-sensitive in duckdb even though
-         # they are as cased in information_schema.schemata, so when selecting from
-         # information_schema.schemata below we use upper-case to match any existing duckdb
-         # catalog or schemas like "information_schema"
-         self.database = database and database.upper()
-         self.schema = schema and schema.upper()
-
-         self.database_set = False
-         self.schema_set = False
-         self.db_path = Path(db_path) if db_path else None
-         self.nop_regexes = nop_regexes
-         self._paramstyle = snowflake.connector.paramstyle
-         self.variables = Variables()
-
-         # create database if needed
-         if (
-             create_database
-             and self.database
-             and not duck_conn.execute(
-                 f"""select * from information_schema.schemata
-                 where upper(catalog_name) = '{self.database}'"""
-             ).fetchone()
-         ):
-             db_file = f"{self.db_path/self.database}.db" if self.db_path else ":memory:"
-             duck_conn.execute(f"ATTACH DATABASE '{db_file}' AS {self.database}")
-             duck_conn.execute(info_schema.creation_sql(self.database))
-             duck_conn.execute(macros.creation_sql(self.database))
-
-         # create schema if needed
-         if (
-             create_schema
-             and self.database
-             and self.schema
-             and not duck_conn.execute(
-                 f"""select * from information_schema.schemata
-                 where upper(catalog_name) = '{self.database}' and upper(schema_name) = '{self.schema}'"""
-             ).fetchone()
-         ):
-             duck_conn.execute(f"CREATE SCHEMA {self.database}.{self.schema}")
-
-         # set database and schema if both exist
-         if (
-             self.database
-             and self.schema
-             and duck_conn.execute(
-                 f"""select * from information_schema.schemata
-                 where upper(catalog_name) = '{self.database}' and upper(schema_name) = '{self.schema}'"""
-             ).fetchone()
-         ):
-             duck_conn.execute(f"SET schema='{self.database}.{self.schema}'")
-             self.database_set = True
-             self.schema_set = True
-         # set database if only that exists
-         elif (
-             self.database
-             and duck_conn.execute(
-                 f"""select * from information_schema.schemata
-                 where upper(catalog_name) = '{self.database}'"""
-             ).fetchone()
-         ):
-             duck_conn.execute(f"SET schema='{self.database}.main'")
-             self.database_set = True
-
-         # use UTC instead of local time zone for consistent testing
-         duck_conn.execute("SET GLOBAL TimeZone = 'UTC'")
-
-     def __enter__(self) -> Self:
-         return self
-
-     def __exit__(
-         self,
-         exc_type: type[BaseException] | None,
-         exc_value: BaseException | None,
-         traceback: TracebackType | None,
-     ) -> None:
-         pass
-
-     def close(self, retry: bool = True) -> None:
-         self._duck_conn.close()
-
-     def commit(self) -> None:
-         self.cursor().execute("COMMIT")
-
-     def cursor(self, cursor_class: type[SnowflakeCursor] = SnowflakeCursor) -> FakeSnowflakeCursor:
-         # TODO: use duck_conn cursor for thread-safety
-         return FakeSnowflakeCursor(conn=self, duck_conn=self._duck_conn, use_dict_result=cursor_class == DictCursor)
-
-     def execute_string(
-         self,
-         sql_text: str,
-         remove_comments: bool = False,
-         return_cursors: bool = True,
-         cursor_class: type[SnowflakeCursor] = SnowflakeCursor,
-         **kwargs: dict[str, Any],
-     ) -> Iterable[FakeSnowflakeCursor]:
-         cursors = [
-             self.cursor(cursor_class).execute(e.sql(dialect="snowflake"))
-             for e in sqlglot.parse(sql_text, read="snowflake")
-             if e and not isinstance(e, exp.Semicolon)  # ignore comments
-         ]
-         return cursors if return_cursors else []
-
-     def rollback(self) -> None:
-         self.cursor().execute("ROLLBACK")
-
-     def _insert_df(self, df: pd.DataFrame, table_name: str) -> int:
-         # Objects in dataframes are written as parquet structs, and snowflake loads parquet structs as json strings.
-         # Whereas duckdb analyses a dataframe see https://duckdb.org/docs/api/python/data_ingestion.html#pandas-dataframes--object-columns
-         # and converts an object to the most specific type possible, eg: dict -> STRUCT, MAP or varchar, and list -> LIST
-         # For dicts see https://github.com/duckdb/duckdb/pull/3985 and https://github.com/duckdb/duckdb/issues/9510
-         #
-         # When the rows have dicts with different keys there isn't a single STRUCT that can cover them, so the type is
-         # varchar and value a string containing a struct representation. In order to support dicts with different keys
-         # we first convert the dicts to json strings. A pity we can't do something inside duckdb and avoid the dataframe
-         # copy and transform in python.
-
-         df = df.copy()
-
-         # Identify columns of type object
-         object_cols = df.select_dtypes(include=["object"]).columns
-
-         # Apply json.dumps to these columns
-         for col in object_cols:
-             # don't jsonify strings
-             df[col] = df[col].apply(lambda x: json.dumps(x) if isinstance(x, (dict, list)) else x)
-
-         escaped_cols = ",".join(f'"{col}"' for col in df.columns.to_list())
-         self._duck_conn.execute(f"INSERT INTO {table_name}({escaped_cols}) SELECT * FROM df")
-
-         return self._duck_conn.fetchall()[0][0]
-
-
- class FakeResultBatch(ResultBatch):
-     def __init__(self, use_dict_result: bool, batch: pyarrow.RecordBatch):
-         self._use_dict_result = use_dict_result
-         self._batch = batch
-
-     def create_iter(
-         self, **kwargs: dict[str, Any]
-     ) -> Iterator[dict | Exception] | Iterator[tuple | Exception] | Iterator[pyarrow.Table] | Iterator[pd.DataFrame]:
-         if self._use_dict_result:
-             return iter(self._batch.to_pylist())
-
-         return iter(tuple(d.values()) for d in self._batch.to_pylist())
-
-     @property
-     def rowcount(self) -> int:
-         return self._batch.num_rows
-
-     def to_pandas(self) -> pd.DataFrame:
-         return self._batch.to_pandas()
-
-     def to_arrow(self) -> pyarrow.Table:
-         raise NotImplementedError()
-
-
- CopyResult = tuple[
-     str,
-     str,
-     int,
-     int,
-     int,
-     int,
-     Optional[str],
-     Optional[int],
-     Optional[int],
-     Optional[str],
- ]
-
- WritePandasResult = tuple[
-     bool,
-     int,
-     int,
-     Sequence[CopyResult],
- ]
-
-
- def sql_type(dtype: np.dtype) -> str:
-     if str(dtype) == "int64":
-         return "NUMBER"
-     elif str(dtype) == "object":
-         return "VARCHAR"
-     else:
-         raise NotImplementedError(f"sql_type {dtype=}")
-
-
- def write_pandas(
-     conn: FakeSnowflakeConnection,
-     df: pd.DataFrame,
-     table_name: str,
-     database: str | None = None,
-     schema: str | None = None,
-     chunk_size: int | None = None,
-     compression: str = "gzip",
-     on_error: str = "abort_statement",
-     parallel: int = 4,
-     quote_identifiers: bool = True,
-     auto_create_table: bool = False,
-     create_temp_table: bool = False,
-     overwrite: bool = False,
-     table_type: Literal["", "temp", "temporary", "transient"] = "",
-     **kwargs: Any,
- ) -> WritePandasResult:
-     name = table_name
-     if schema:
-         name = f"{schema}.{name}"
-     if database:
-         name = f"{database}.{name}"
-
-     if auto_create_table:
-         cols = [f"{c} {sql_type(t)}" for c, t in df.dtypes.to_dict().items()]
-
-         conn.cursor().execute(f"CREATE TABLE IF NOT EXISTS {name} ({','.join(cols)})")
-
-     count = conn._insert_df(df, name)  # noqa: SLF001
-
-     # mocks https://docs.snowflake.com/en/sql-reference/sql/copy-into-table.html#output
-     mock_copy_results = [("fakesnow/file0.txt", "LOADED", count, count, 1, 0, None, None, None, None)]
-
-     # return success
-     return (True, len(mock_copy_results), count, mock_copy_results)
+ from .conn import FakeSnowflakeConnection as FakeSnowflakeConnection
+ from .cursor import FakeSnowflakeCursor as FakeSnowflakeCursor
+ from .pandas_tools import write_pandas as write_pandas
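
The net effect: fakes.py shrinks from a 750-line module to a three-line compatibility facade. The cursor, connection and pandas helpers now live in fakesnow/cursor.py, fakesnow/conn.py and fakesnow/pandas_tools.py, and the "name as name" import form marks each name as an intentional re-export for type checkers that disallow implicit re-exports (e.g. mypy with implicit_reexport disabled). Below is a minimal sketch of what this preserves for downstream code, assuming the moved modules keep the 0.9.22 behaviour shown above and that fakesnow's dependencies (duckdb, pandas, snowflake-connector-python) are installed; the table and column names are illustrative only.

    # Old import paths still resolve, via the fakes.py re-exports above.
    import duckdb
    import pandas as pd

    from fakesnow.fakes import FakeSnowflakeConnection, write_pandas

    # The class is now defined in fakesnow.conn (per the diff), so its
    # __module__ changes even though the import path does not.
    assert FakeSnowflakeConnection.__module__ == "fakesnow.conn"

    # Same flow as before the split: a fake connection backed by an
    # in-memory DuckDB database, loaded via the write_pandas shim.
    conn = FakeSnowflakeConnection(duckdb.connect(), database="DB1", schema="S1")
    df = pd.DataFrame({"ID": [1, 2], "NAME": ["a", "b"]})
    success, nchunks, nrows, _ = write_pandas(conn, df, "T1", auto_create_table=True)
    assert success and nrows == 2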