fakesnow 0.9.21__py3-none-any.whl → 0.9.23__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only.
fakesnow/fakes.py CHANGED
@@ -1,747 +1,3 @@
- from __future__ import annotations
-
- import json
- import os
- import re
- import sys
- from collections.abc import Iterable, Iterator, Sequence
- from pathlib import Path
- from string import Template
- from types import TracebackType
- from typing import TYPE_CHECKING, Any, Literal, Optional, cast
-
- import duckdb
- from sqlglot import exp
-
- if TYPE_CHECKING:
-     import pandas as pd
-     import pyarrow.lib
- import numpy as np
- import pyarrow
- import snowflake.connector.converter
- import snowflake.connector.errors
- import sqlglot
- from duckdb import DuckDBPyConnection
- from snowflake.connector.cursor import DictCursor, ResultMetadata, SnowflakeCursor
- from snowflake.connector.result_batch import ResultBatch
- from sqlglot import parse_one
- from typing_extensions import Self
-
- import fakesnow.checks as checks
- import fakesnow.expr as expr
- import fakesnow.info_schema as info_schema
- import fakesnow.macros as macros
- import fakesnow.transforms as transforms
- from fakesnow.variables import Variables
-
- SCHEMA_UNSET = "schema_unset"
- SQL_SUCCESS = "SELECT 'Statement executed successfully.' as 'status'"
- SQL_CREATED_DATABASE = Template("SELECT 'Database ${name} successfully created.' as 'status'")
- SQL_CREATED_SCHEMA = Template("SELECT 'Schema ${name} successfully created.' as 'status'")
- SQL_CREATED_TABLE = Template("SELECT 'Table ${name} successfully created.' as 'status'")
- SQL_CREATED_VIEW = Template("SELECT 'View ${name} successfully created.' as 'status'")
- SQL_DROPPED = Template("SELECT '${name} successfully dropped.' as 'status'")
- SQL_INSERTED_ROWS = Template("SELECT ${count} as 'number of rows inserted'")
- SQL_UPDATED_ROWS = Template("SELECT ${count} as 'number of rows updated', 0 as 'number of multi-joined rows updated'")
- SQL_DELETED_ROWS = Template("SELECT ${count} as 'number of rows deleted'")
-
-
- class FakeSnowflakeCursor:
-     def __init__(
-         self,
-         conn: FakeSnowflakeConnection,
-         duck_conn: DuckDBPyConnection,
-         use_dict_result: bool = False,
-     ) -> None:
-         """Create a fake snowflake cursor backed by DuckDB.
-
-         Args:
-             conn (FakeSnowflakeConnection): Used to maintain current database and schema.
-             duck_conn (DuckDBPyConnection): DuckDB connection.
-             use_dict_result (bool, optional): If true rows are returned as dicts otherwise they
-                 are returned as tuples. Defaults to False.
-         """
-         self._conn = conn
-         self._duck_conn = duck_conn
-         self._use_dict_result = use_dict_result
-         self._last_sql = None
-         self._last_params = None
-         self._sqlstate = None
-         self._arraysize = 1
-         self._arrow_table = None
-         self._arrow_table_fetch_index = None
-         self._rowcount = None
-         self._converter = snowflake.connector.converter.SnowflakeConverter()
-
-     def __enter__(self) -> Self:
-         return self
-
-     def __exit__(
-         self,
-         exc_type: type[BaseException] | None,
-         exc_value: BaseException | None,
-         traceback: TracebackType | None,
-     ) -> None:
-         pass
-
-     @property
-     def arraysize(self) -> int:
-         return self._arraysize
-
-     @arraysize.setter
-     def arraysize(self, value: int) -> None:
-         self._arraysize = value
-
-     def close(self) -> bool:
-         self._last_sql = None
-         self._last_params = None
-         return True
-
-     def describe(self, command: str, *args: Any, **kwargs: Any) -> list[ResultMetadata]:
-         """Return the schema of the result without executing the query.
-
-         Takes the same arguments as execute
-
-         Returns:
-             list[ResultMetadata]: _description_
-         """
-
-         describe = f"DESCRIBE {command}"
-         self.execute(describe, *args, **kwargs)
-         return FakeSnowflakeCursor._describe_as_result_metadata(self.fetchall())
-
-     @property
-     def description(self) -> list[ResultMetadata]:
-         # use a separate cursor to avoid consuming the result set on this cursor
-         with self._conn.cursor() as cur:
-             # self._duck_conn.execute(sql, params)
-             expression = sqlglot.parse_one(f"DESCRIBE {self._last_sql}", read="duckdb")
-             cur._execute(expression, self._last_params)  # noqa: SLF001
-             meta = FakeSnowflakeCursor._describe_as_result_metadata(cur.fetchall())
-
-         return meta
-
-     def execute(
-         self,
-         command: str,
-         params: Sequence[Any] | dict[Any, Any] | None = None,
-         *args: Any,
-         **kwargs: Any,
-     ) -> FakeSnowflakeCursor:
-         try:
-             self._sqlstate = None
-
-             if os.environ.get("FAKESNOW_DEBUG") == "snowflake":
-                 print(f"{command};{params=}" if params else f"{command};", file=sys.stderr)
-
-             command = self._inline_variables(command)
-             command, params = self._rewrite_with_params(command, params)
-             if self._conn.nop_regexes and any(re.match(p, command, re.IGNORECASE) for p in self._conn.nop_regexes):
-                 transformed = transforms.SUCCESS_NOP
-             else:
-                 expression = parse_one(command, read="snowflake")
-                 transformed = self._transform(expression)
-             return self._execute(transformed, params)
-         except snowflake.connector.errors.ProgrammingError as e:
-             self._sqlstate = e.sqlstate
-             raise e
-
-     def _transform(self, expression: exp.Expression) -> exp.Expression:
-         return (
-             expression.transform(transforms.upper_case_unquoted_identifiers)
-             .transform(transforms.update_variables, variables=self._conn.variables)
-             .transform(transforms.set_schema, current_database=self._conn.database)
-             .transform(transforms.create_database, db_path=self._conn.db_path)
-             .transform(transforms.extract_comment_on_table)
-             .transform(transforms.extract_comment_on_columns)
-             .transform(transforms.information_schema_fs_columns_snowflake)
-             .transform(transforms.information_schema_fs_tables_ext)
-             .transform(transforms.drop_schema_cascade)
-             .transform(transforms.tag)
-             .transform(transforms.semi_structured_types)
-             .transform(transforms.try_parse_json)
-             .transform(transforms.split)
-             # NOTE: trim_cast_varchar must be before json_extract_cast_as_varchar
-             .transform(transforms.trim_cast_varchar)
-             # indices_to_json_extract must be before regex_substr
-             .transform(transforms.indices_to_json_extract)
-             .transform(transforms.json_extract_cast_as_varchar)
-             .transform(transforms.json_extract_cased_as_varchar)
-             .transform(transforms.json_extract_precedence)
-             .transform(transforms.flatten_value_cast_as_varchar)
-             .transform(transforms.flatten)
-             .transform(transforms.regex_replace)
-             .transform(transforms.regex_substr)
-             .transform(transforms.values_columns)
-             .transform(transforms.to_date)
-             .transform(transforms.to_decimal)
-             .transform(transforms.try_to_decimal)
-             .transform(transforms.to_timestamp_ntz)
-             .transform(transforms.to_timestamp)
-             .transform(transforms.object_construct)
-             .transform(transforms.timestamp_ntz)
-             .transform(transforms.float_to_double)
-             .transform(transforms.integer_precision)
-             .transform(transforms.extract_text_length)
-             .transform(transforms.sample)
-             .transform(transforms.array_size)
-             .transform(transforms.random)
-             .transform(transforms.identifier)
-             .transform(transforms.array_agg_within_group)
-             .transform(transforms.array_agg)
-             .transform(transforms.dateadd_date_cast)
-             .transform(transforms.dateadd_string_literal_timestamp_cast)
-             .transform(transforms.datediff_string_literal_timestamp_cast)
-             .transform(lambda e: transforms.show_schemas(e, self._conn.database))
-             .transform(lambda e: transforms.show_objects_tables(e, self._conn.database))
-             # TODO collapse into a single show_keys function
-             .transform(lambda e: transforms.show_keys(e, self._conn.database, kind="PRIMARY"))
-             .transform(lambda e: transforms.show_keys(e, self._conn.database, kind="UNIQUE"))
-             .transform(lambda e: transforms.show_keys(e, self._conn.database, kind="FOREIGN"))
-             .transform(transforms.show_users)
-             .transform(transforms.create_user)
-             .transform(transforms.sha256)
-             .transform(transforms.create_clone)
-             .transform(transforms.alias_in_join)
-             .transform(transforms.alter_table_strip_cluster_by)
-         )
-
-     def _execute(
-         self, transformed: exp.Expression, params: Sequence[Any] | dict[Any, Any] | None = None
-     ) -> FakeSnowflakeCursor:
-         self._arrow_table = None
-         self._arrow_table_fetch_index = None
-         self._rowcount = None
-
-         cmd = expr.key_command(transformed)
-
-         no_database, no_schema = checks.is_unqualified_table_expression(transformed)
-
-         if no_database and not self._conn.database_set:
-             raise snowflake.connector.errors.ProgrammingError(
-                 msg=f"Cannot perform {cmd}. This session does not have a current database. Call 'USE DATABASE', or use a qualified name.",  # noqa: E501
-                 errno=90105,
-                 sqlstate="22000",
-             )
-         elif no_schema and not self._conn.schema_set:
-             raise snowflake.connector.errors.ProgrammingError(
-                 msg=f"Cannot perform {cmd}. This session does not have a current schema. Call 'USE SCHEMA', or use a qualified name.",  # noqa: E501
-                 errno=90106,
-                 sqlstate="22000",
-             )
-
-         sql = transformed.sql(dialect="duckdb")
-
-         if transformed.find(exp.Select) and (seed := transformed.args.get("seed")):
-             sql = f"SELECT setseed({seed}); {sql}"
-
-         if (fs_debug := os.environ.get("FAKESNOW_DEBUG")) and fs_debug != "snowflake":
-             print(f"{sql};{params=}" if params else f"{sql};", file=sys.stderr)
-
-         result_sql = None
-
-         try:
-             self._duck_conn.execute(sql, params)
-         except duckdb.BinderException as e:
-             msg = e.args[0]
-             raise snowflake.connector.errors.ProgrammingError(msg=msg, errno=2043, sqlstate="02000") from None
-         except duckdb.CatalogException as e:
-             # minimal processing to make it look like a snowflake exception, message content may differ
-             msg = cast(str, e.args[0]).split("\n")[0]
-             raise snowflake.connector.errors.ProgrammingError(msg=msg, errno=2003, sqlstate="42S02") from None
-         except duckdb.TransactionException as e:
-             if "cannot rollback - no transaction is active" in str(
-                 e
-             ) or "cannot commit - no transaction is active" in str(e):
-                 # snowflake doesn't error on rollback or commit outside a tx
-                 result_sql = SQL_SUCCESS
-             else:
-                 raise e
-         except duckdb.ConnectionException as e:
-             raise snowflake.connector.errors.DatabaseError(msg=e.args[0], errno=250002, sqlstate="08003") from None
-
-         affected_count = None
-
-         if set_database := transformed.args.get("set_database"):
-             self._conn.database = set_database
-             self._conn.database_set = True
-
-         elif set_schema := transformed.args.get("set_schema"):
-             self._conn.schema = set_schema
-             self._conn.schema_set = True
-
-         elif create_db_name := transformed.args.get("create_db_name"):
-             # we created a new database, so create the info schema extensions
-             self._duck_conn.execute(info_schema.creation_sql(create_db_name))
-             result_sql = SQL_CREATED_DATABASE.substitute(name=create_db_name)
-
-         elif cmd == "INSERT":
-             (affected_count,) = self._duck_conn.fetchall()[0]
-             result_sql = SQL_INSERTED_ROWS.substitute(count=affected_count)
-
-         elif cmd == "UPDATE":
-             (affected_count,) = self._duck_conn.fetchall()[0]
-             result_sql = SQL_UPDATED_ROWS.substitute(count=affected_count)
-
-         elif cmd == "DELETE":
-             (affected_count,) = self._duck_conn.fetchall()[0]
-             result_sql = SQL_DELETED_ROWS.substitute(count=affected_count)
-
-         elif cmd == "DESCRIBE TABLE":
-             # DESCRIBE TABLE has already been run above to detect and error if the table exists
-             # We now rerun DESCRIBE TABLE but transformed with columns to match Snowflake
-             result_sql = transformed.transform(
-                 lambda e: transforms.describe_table(e, self._conn.database, self._conn.schema)
-             ).sql(dialect="duckdb")
-
-         elif (eid := transformed.find(exp.Identifier, bfs=False)) and isinstance(eid.this, str):
-             ident = eid.this if eid.quoted else eid.this.upper()
-             if cmd == "CREATE SCHEMA" and ident:
-                 result_sql = SQL_CREATED_SCHEMA.substitute(name=ident)
-
-             elif cmd == "CREATE TABLE" and ident:
-                 result_sql = SQL_CREATED_TABLE.substitute(name=ident)
-
-             elif cmd.startswith("ALTER") and ident:
-                 result_sql = SQL_SUCCESS
-
-             elif cmd == "CREATE VIEW" and ident:
-                 result_sql = SQL_CREATED_VIEW.substitute(name=ident)
-
-             elif cmd.startswith("DROP") and ident:
-                 result_sql = SQL_DROPPED.substitute(name=ident)
-
-                 # if dropping the current database/schema then reset conn metadata
-                 if cmd == "DROP DATABASE" and ident == self._conn.database:
-                     self._conn.database = None
-                     self._conn.schema = None
-
-                 elif cmd == "DROP SCHEMA" and ident == self._conn.schema:
-                     self._conn.schema = None
-
-         if table_comment := cast(tuple[exp.Table, str], transformed.args.get("table_comment")):
-             # record table comment
-             table, comment = table_comment
-             catalog = table.catalog or self._conn.database
-             schema = table.db or self._conn.schema
-             assert catalog and schema
-             self._duck_conn.execute(info_schema.insert_table_comment_sql(catalog, schema, table.name, comment))
-
-         if (text_lengths := cast(list[tuple[str, int]], transformed.args.get("text_lengths"))) and (
-             table := transformed.find(exp.Table)
-         ):
-             # record text lengths
-             catalog = table.catalog or self._conn.database
-             schema = table.db or self._conn.schema
-             assert catalog and schema
-             self._duck_conn.execute(info_schema.insert_text_lengths_sql(catalog, schema, table.name, text_lengths))
-
-         if result_sql:
-             self._duck_conn.execute(result_sql)
-
-         self._arrow_table = self._duck_conn.fetch_arrow_table()
-         self._rowcount = affected_count or self._arrow_table.num_rows
-
-         self._last_sql = result_sql or sql
-         self._last_params = params
-
-         return self
-
-     def executemany(
-         self,
-         command: str,
-         seqparams: Sequence[Any] | dict[str, Any],
-         **kwargs: Any,
-     ) -> FakeSnowflakeCursor:
-         if isinstance(seqparams, dict):
-             # see https://docs.snowflake.com/en/developer-guide/python-connector/python-connector-api
-             raise NotImplementedError("dict params not supported yet")
-
-         # TODO: support insert optimisations
-         # the snowflake connector will optimise inserts into a single query
-         # unless num_statements != 1 .. but for simplicity we execute each
-         # query one by one, which means the response differs
-         for p in seqparams:
-             self.execute(command, p)
-
-         return self
-
-     def fetchall(self) -> list[tuple] | list[dict]:
-         if self._arrow_table is None:
-             # mimic snowflake python connector error type
-             raise TypeError("No open result set")
-         return self.fetchmany(self._arrow_table.num_rows)
-
-     def fetch_pandas_all(self, **kwargs: dict[str, Any]) -> pd.DataFrame:
-         if self._arrow_table is None:
-             # mimic snowflake python connector error type
-             raise snowflake.connector.NotSupportedError("No open result set")
-         return self._arrow_table.to_pandas()
-
-     def fetchone(self) -> dict | tuple | None:
-         result = self.fetchmany(1)
-         return result[0] if result else None
-
-     def fetchmany(self, size: int | None = None) -> list[tuple] | list[dict]:
-         # https://peps.python.org/pep-0249/#fetchmany
-         size = size or self._arraysize
-
-         if self._arrow_table is None:
-             # mimic snowflake python connector error type
-             raise TypeError("No open result set")
-         if self._arrow_table_fetch_index is None:
-             self._arrow_table_fetch_index = 0
-         else:
-             self._arrow_table_fetch_index += size
-
-         tslice = self._arrow_table.slice(offset=self._arrow_table_fetch_index, length=size).to_pylist()
-         return tslice if self._use_dict_result else [tuple(d.values()) for d in tslice]
-
-     def get_result_batches(self) -> list[ResultBatch] | None:
-         if self._arrow_table is None:
-             return None
-         return [FakeResultBatch(self._use_dict_result, b) for b in self._arrow_table.to_batches(max_chunksize=1000)]
-
-     @property
-     def rowcount(self) -> int | None:
-         return self._rowcount
-
-     @property
-     def sfqid(self) -> str | None:
-         return "fakesnow"
-
-     @property
-     def sqlstate(self) -> str | None:
-         return self._sqlstate
-
-     @staticmethod
-     def _describe_as_result_metadata(describe_results: list) -> list[ResultMetadata]:
-         # fmt: off
-         def as_result_metadata(column_name: str, column_type: str, _: str) -> ResultMetadata:
-             # see https://docs.snowflake.com/en/user-guide/python-connector-api.html#type-codes
-             # and https://arrow.apache.org/docs/python/api/datatypes.html#type-checking
-             if column_type in {"BIGINT", "INTEGER"}:
-                 return ResultMetadata(
-                     name=column_name, type_code=0, display_size=None, internal_size=None, precision=38, scale=0, is_nullable=True  # noqa: E501
-                 )
-             elif column_type.startswith("DECIMAL"):
-                 match = re.search(r'\((\d+),(\d+)\)', column_type)
-                 if match:
-                     precision = int(match[1])
-                     scale = int(match[2])
-                 else:
-                     precision = scale = None
-                 return ResultMetadata(
-                     name=column_name, type_code=0, display_size=None, internal_size=None, precision=precision, scale=scale, is_nullable=True  # noqa: E501
-                 )
-             elif column_type == "VARCHAR":
-                 # TODO: fetch internal_size from varchar size
-                 return ResultMetadata(
-                     name=column_name, type_code=2, display_size=None, internal_size=16777216, precision=None, scale=None, is_nullable=True  # noqa: E501
-                 )
-             elif column_type == "DOUBLE":
-                 return ResultMetadata(
-                     name=column_name, type_code=1, display_size=None, internal_size=None, precision=None, scale=None, is_nullable=True  # noqa: E501
-                 )
-             elif column_type == "BOOLEAN":
-                 return ResultMetadata(
-                     name=column_name, type_code=13, display_size=None, internal_size=None, precision=None, scale=None, is_nullable=True  # noqa: E501
-                 )
-             elif column_type == "DATE":
-                 return ResultMetadata(
-                     name=column_name, type_code=3, display_size=None, internal_size=None, precision=None, scale=None, is_nullable=True  # noqa: E501
-                 )
-             elif column_type in {"TIMESTAMP", "TIMESTAMP_NS"}:
-                 return ResultMetadata(
-                     name=column_name, type_code=8, display_size=None, internal_size=None, precision=0, scale=9, is_nullable=True  # noqa: E501
-                 )
-             elif column_type == "TIMESTAMP WITH TIME ZONE":
-                 return ResultMetadata(
-                     name=column_name, type_code=7, display_size=None, internal_size=None, precision=0, scale=9, is_nullable=True  # noqa: E501
-                 )
-             elif column_type == "BLOB":
-                 return ResultMetadata(
-                     name=column_name, type_code=11, display_size=None, internal_size=8388608, precision=None, scale=None, is_nullable=True  # noqa: E501
-                 )
-             elif column_type == "TIME":
-                 return ResultMetadata(
-                     name=column_name, type_code=12, display_size=None, internal_size=None, precision=0, scale=9, is_nullable=True  # noqa: E501
-                 )
-             elif column_type == "JSON":
-                 # TODO: correctly map OBJECT and ARRAY see https://github.com/tekumara/fakesnow/issues/26
-                 return ResultMetadata(
-                     name=column_name, type_code=5, display_size=None, internal_size=None, precision=None, scale=None, is_nullable=True  # noqa: E501
-                 )
-             else:
-                 # TODO handle more types
-                 raise NotImplementedError(f"for column type {column_type}")
-
-         # fmt: on
-
-         meta = [
-             as_result_metadata(column_name, column_type, null)
-             for (column_name, column_type, null, _, _, _) in describe_results
-         ]
-         return meta
-
-     def _rewrite_with_params(
-         self,
-         command: str,
-         params: Sequence[Any] | dict[Any, Any] | None = None,
-     ) -> tuple[str, Sequence[Any] | dict[Any, Any] | None]:
-         if params and self._conn._paramstyle in ("pyformat", "format"):  # noqa: SLF001
-             # handle client-side in the same manner as the snowflake python connector
-
-             def convert(param: Any) -> Any:  # noqa: ANN401
-                 return self._converter.quote(self._converter.escape(self._converter.to_snowflake(param)))
-
-             if isinstance(params, dict):
-                 params = {k: convert(v) for k, v in params.items()}
-             else:
-                 params = tuple(convert(v) for v in params)
-
-             return command % params, None
-
-         return command, params
-
-     def _inline_variables(self, sql: str) -> str:
-         return self._conn.variables.inline_variables(sql)
-
-
- class FakeSnowflakeConnection:
-     def __init__(
-         self,
-         duck_conn: DuckDBPyConnection,
-         database: str | None = None,
-         schema: str | None = None,
-         create_database: bool = True,
-         create_schema: bool = True,
-         db_path: str | os.PathLike | None = None,
-         nop_regexes: list[str] | None = None,
-         *args: Any,
-         **kwargs: Any,
-     ):
-         self._duck_conn = duck_conn
-         # upper case database and schema like snowflake unquoted identifiers
-         # so they appear as upper-cased in information_schema
-         # catalog and schema names are not actually case-sensitive in duckdb even though
-         # they are as cased in information_schema.schemata, so when selecting from
-         # information_schema.schemata below we use upper-case to match any existing duckdb
-         # catalog or schemas like "information_schema"
-         self.database = database and database.upper()
-         self.schema = schema and schema.upper()
-
-         self.database_set = False
-         self.schema_set = False
-         self.db_path = Path(db_path) if db_path else None
-         self.nop_regexes = nop_regexes
-         self._paramstyle = snowflake.connector.paramstyle
-         self.variables = Variables()
-
-         # create database if needed
-         if (
-             create_database
-             and self.database
-             and not duck_conn.execute(
-                 f"""select * from information_schema.schemata
-                 where upper(catalog_name) = '{self.database}'"""
-             ).fetchone()
-         ):
-             db_file = f"{self.db_path/self.database}.db" if self.db_path else ":memory:"
-             duck_conn.execute(f"ATTACH DATABASE '{db_file}' AS {self.database}")
-             duck_conn.execute(info_schema.creation_sql(self.database))
-             duck_conn.execute(macros.creation_sql(self.database))
-
-         # create schema if needed
-         if (
-             create_schema
-             and self.database
-             and self.schema
-             and not duck_conn.execute(
-                 f"""select * from information_schema.schemata
-                 where upper(catalog_name) = '{self.database}' and upper(schema_name) = '{self.schema}'"""
-             ).fetchone()
-         ):
-             duck_conn.execute(f"CREATE SCHEMA {self.database}.{self.schema}")
-
-         # set database and schema if both exist
-         if (
-             self.database
-             and self.schema
-             and duck_conn.execute(
-                 f"""select * from information_schema.schemata
-                 where upper(catalog_name) = '{self.database}' and upper(schema_name) = '{self.schema}'"""
-             ).fetchone()
-         ):
-             duck_conn.execute(f"SET schema='{self.database}.{self.schema}'")
-             self.database_set = True
-             self.schema_set = True
-         # set database if only that exists
-         elif (
-             self.database
-             and duck_conn.execute(
-                 f"""select * from information_schema.schemata
-                 where upper(catalog_name) = '{self.database}'"""
-             ).fetchone()
-         ):
-             duck_conn.execute(f"SET schema='{self.database}.main'")
-             self.database_set = True
-
-         # use UTC instead of local time zone for consistent testing
-         duck_conn.execute("SET GLOBAL TimeZone = 'UTC'")
-
-     def __enter__(self) -> Self:
-         return self
-
-     def __exit__(
-         self,
-         exc_type: type[BaseException] | None,
-         exc_value: BaseException | None,
-         traceback: TracebackType | None,
-     ) -> None:
-         pass
-
-     def close(self, retry: bool = True) -> None:
-         self._duck_conn.close()
-
-     def commit(self) -> None:
-         self.cursor().execute("COMMIT")
-
-     def cursor(self, cursor_class: type[SnowflakeCursor] = SnowflakeCursor) -> FakeSnowflakeCursor:
-         # TODO: use duck_conn cursor for thread-safety
-         return FakeSnowflakeCursor(conn=self, duck_conn=self._duck_conn, use_dict_result=cursor_class == DictCursor)
-
-     def execute_string(
-         self,
-         sql_text: str,
-         remove_comments: bool = False,
-         return_cursors: bool = True,
-         cursor_class: type[SnowflakeCursor] = SnowflakeCursor,
-         **kwargs: dict[str, Any],
-     ) -> Iterable[FakeSnowflakeCursor]:
-         cursors = [
-             self.cursor(cursor_class).execute(e.sql(dialect="snowflake"))
-             for e in sqlglot.parse(sql_text, read="snowflake")
-             if e and not isinstance(e, exp.Semicolon)  # ignore comments
-         ]
-         return cursors if return_cursors else []
-
-     def rollback(self) -> None:
-         self.cursor().execute("ROLLBACK")
-
-     def _insert_df(self, df: pd.DataFrame, table_name: str) -> int:
-         # Objects in dataframes are written as parquet structs, and snowflake loads parquet structs as json strings.
-         # Whereas duckdb analyses a dataframe see https://duckdb.org/docs/api/python/data_ingestion.html#pandas-dataframes--object-columns
-         # and converts a object to the most specific type possible, eg: dict -> STRUCT, MAP or varchar, and list -> LIST
-         # For dicts see https://github.com/duckdb/duckdb/pull/3985 and https://github.com/duckdb/duckdb/issues/9510
-         #
-         # When the rows have dicts with different keys there isn't a single STRUCT that can cover them, so the type is
-         # varchar and value a string containing a struct representation. In order to support dicts with different keys
-         # we first convert the dicts to json strings. A pity we can't do something inside duckdb and avoid the dataframe
-         # copy and transform in python.
-
-         df = df.copy()
-
-         # Identify columns of type object
-         object_cols = df.select_dtypes(include=["object"]).columns
-
-         # Apply json.dumps to these columns
-         for col in object_cols:
-             # don't jsonify string
-             df[col] = df[col].apply(lambda x: json.dumps(x) if isinstance(x, (dict, list)) else x)
-
-         escaped_cols = ",".join(f'"{col}"' for col in df.columns.to_list())
-         self._duck_conn.execute(f"INSERT INTO {table_name}({escaped_cols}) SELECT * FROM df")
-
-         return self._duck_conn.fetchall()[0][0]
-
-
- class FakeResultBatch(ResultBatch):
-     def __init__(self, use_dict_result: bool, batch: pyarrow.RecordBatch):
-         self._use_dict_result = use_dict_result
-         self._batch = batch
-
-     def create_iter(
-         self, **kwargs: dict[str, Any]
-     ) -> Iterator[dict | Exception] | Iterator[tuple | Exception] | Iterator[pyarrow.Table] | Iterator[pd.DataFrame]:
-         if self._use_dict_result:
-             return iter(self._batch.to_pylist())
-
-         return iter(tuple(d.values()) for d in self._batch.to_pylist())
-
-     @property
-     def rowcount(self) -> int:
-         return self._batch.num_rows
-
-     def to_pandas(self) -> pd.DataFrame:
-         return self._batch.to_pandas()
-
-     def to_arrow(self) -> pyarrow.Table:
-         raise NotImplementedError()
-
-
- CopyResult = tuple[
-     str,
-     str,
-     int,
-     int,
-     int,
-     int,
-     Optional[str],
-     Optional[int],
-     Optional[int],
-     Optional[str],
- ]
-
- WritePandasResult = tuple[
-     bool,
-     int,
-     int,
-     Sequence[CopyResult],
- ]
-
-
- def sql_type(dtype: np.dtype) -> str:
-     if str(dtype) == "int64":
-         return "NUMBER"
-     elif str(dtype) == "object":
-         return "VARCHAR"
-     else:
-         raise NotImplementedError(f"sql_type {dtype=}")
-
-
- def write_pandas(
-     conn: FakeSnowflakeConnection,
-     df: pd.DataFrame,
-     table_name: str,
-     database: str | None = None,
-     schema: str | None = None,
-     chunk_size: int | None = None,
-     compression: str = "gzip",
-     on_error: str = "abort_statement",
-     parallel: int = 4,
-     quote_identifiers: bool = True,
-     auto_create_table: bool = False,
-     create_temp_table: bool = False,
-     overwrite: bool = False,
-     table_type: Literal["", "temp", "temporary", "transient"] = "",
-     **kwargs: Any,
- ) -> WritePandasResult:
-     name = table_name
-     if schema:
-         name = f"{schema}.{name}"
-     if database:
-         name = f"{database}.{name}"
-
-     if auto_create_table:
-         cols = [f"{c} {sql_type(t)}" for c, t in df.dtypes.to_dict().items()]
-
-         conn.cursor().execute(f"CREATE TABLE IF NOT EXISTS {name} ({','.join(cols)})")
-
-     count = conn._insert_df(df, name)  # noqa: SLF001
-
-     # mocks https://docs.snowflake.com/en/sql-reference/sql/copy-into-table.html#output
-     mock_copy_results = [("fakesnow/file0.txt", "LOADED", count, count, 1, 0, None, None, None, None)]
-
-     # return success
-     return (True, len(mock_copy_results), count, mock_copy_results)
+ from .conn import FakeSnowflakeConnection as FakeSnowflakeConnection
+ from .cursor import FakeSnowflakeCursor as FakeSnowflakeCursor
+ from .pandas_tools import write_pandas as write_pandas
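
Note: the 747-line module body removed above now lives in fakesnow/conn.py, fakesnow/cursor.py and fakesnow/pandas_tools.py, and fakes.py keeps only the three re-exports (the `X as X` form marks them as explicit re-exports for type checkers). A minimal sketch, assuming only the module layout shown in this diff, of the backward compatibility this preserves:

    # the shim keeps the 0.9.21 import path working in 0.9.23
    from fakesnow.fakes import FakeSnowflakeConnection, FakeSnowflakeCursor, write_pandas

    # hypothetical check: the old path and the new submodule name the same class
    from fakesnow.conn import FakeSnowflakeConnection as ConnFromNewModule

    assert FakeSnowflakeConnection is ConnFromNewModule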