fakesnow 0.9.35__py3-none-any.whl → 0.9.37__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
fakesnow/__init__.py CHANGED
@@ -91,3 +91,54 @@ def patch(
     finally:
         stack.close()
         fs.duck_conn.close()
+
+
+@contextmanager
+def server(port: int | None = None, session_parameters: dict[str, str | int | bool] | None = None) -> Iterator[dict]:
+    """Start a fake snowflake server in a separate thread and yield connection kwargs.
+
+    Args:
+        port (int | None, optional): Port to run the server on. If None, an available port is chosen. Defaults to None.
+
+    Yields:
+        Iterator[dict]: Connection parameters for the fake snowflake server.
+    """
+    import socket
+    import threading
+    from time import sleep
+
+    import uvicorn
+
+    import fakesnow.server
+
+    # find an unused TCP port between 1024-65535
+    if not port:
+        with contextlib.closing(socket.socket(type=socket.SOCK_STREAM)) as sock:
+            sock.bind(("127.0.0.1", 0))
+            port = sock.getsockname()[1]
+
+    assert port
+    server = uvicorn.Server(uvicorn.Config(fakesnow.server.app, port=port, log_level="info"))
+    thread = threading.Thread(target=server.run, name="fakesnow server", daemon=True)
+    thread.start()
+
+    while not server.started:
+        sleep(0.1)
+
+    try:
+        yield dict(
+            user="fake",
+            password="snow",
+            account="fakesnow",
+            host="localhost",
+            port=port,
+            protocol="http",
+            # disable telemetry
+            session_parameters={"CLIENT_OUT_OF_BAND_TELEMETRY_ENABLED": False} | (session_parameters or {}),
+            # disable retries on error
+            network_timeout=1,
+        )
+    finally:
+        server.should_exit = True
+        # wait for server thread to end
+        thread.join()
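
Note: the new fakesnow.server() context manager exercises snowflake-connector-python's
real HTTP path instead of patching it. A minimal sketch of connecting with the
yielded kwargs (the query is illustrative):

    import fakesnow
    import snowflake.connector

    with fakesnow.server() as conn_kwargs:
        # conn_kwargs carries host/port/protocol for the local uvicorn server
        with snowflake.connector.connect(**conn_kwargs) as conn:
            cur = conn.cursor()
            cur.execute("SELECT 'hello' AS greeting")
            print(cur.fetchall())
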
fakesnow/conn.py CHANGED
@@ -6,8 +6,7 @@ from pathlib import Path
 from types import TracebackType
 from typing import Any
 
-import snowflake.connector.converter
-import snowflake.connector.errors
+import snowflake.connector
 import sqlglot
 from duckdb import DuckDBPyConnection
 from snowflake.connector.cursor import DictCursor, SnowflakeCursor
fakesnow/converter.py CHANGED
@@ -8,20 +8,20 @@ from datetime import date, time, timezone
 
 
 def from_binding(binding: dict[str, str]) -> int | bytes | bool | date | time | datetime.datetime | str:
-    typ = binding["type"]
+    type_ = binding["type"]
     value = binding["value"]
-    if typ == "FIXED":
+    if type_ == "FIXED":
         return int(value)
-    elif typ == "BINARY":
+    elif type_ == "BINARY":
         return from_binary(value)
     # TODO: not strictly needed
-    elif typ == "BOOLEAN":
+    elif type_ == "BOOLEAN":
         return value.lower() == "true"
-    elif typ == "DATE":
+    elif type_ == "DATE":
         return from_date(value)
-    elif typ == "TIME":
+    elif type_ == "TIME":
         return from_time(value)
-    elif typ == "TIMESTAMP_NTZ":
+    elif type_ == "TIMESTAMP_NTZ":
         return from_datetime(value)
     else:
         # For other types, return str
fakesnow/copy_into.py ADDED
@@ -0,0 +1,199 @@
+from __future__ import annotations
+
+from collections.abc import Sequence
+from dataclasses import dataclass, field
+from typing import Any, Protocol, cast
+from urllib.parse import urlparse, urlunparse
+
+import duckdb
+import snowflake.connector.errors
+from duckdb import DuckDBPyConnection
+from sqlglot import exp
+
+from fakesnow import logger
+
+
+def copy_into(
+    duck_conn: DuckDBPyConnection, expr: exp.Copy, params: Sequence[Any] | dict[Any, Any] | None = None
+) -> str:
+    cparams = _params(expr)
+    urls = _source_urls(expr, cparams.files)
+    inserts = _inserts(expr, cparams, urls)
+
+    results = []
+    try:
+        # TODO: fetch files last modified dates and check if file exists in load_history already
+        for i, url in zip(inserts, urls):
+            sql = i.sql(dialect="duckdb")
+            logger.log_sql(sql, params)
+            duck_conn.execute(sql, params)
+            (affected_count,) = duck_conn.fetchall()[0]
+            results.append(f"('{url}', 'LOADED', {affected_count}, {affected_count}, 1, 0, NULL, NULL, NULL, NULL)")
+
+        # TODO: update load_history with the results if loaded
+
+        columns = "file, status, rows_parsed, rows_loaded, error_limit, errors_seen, first_error, first_error_line, first_error_character, first_error_column_name"  # noqa: E501
+        values = "\n, ".join(results)
+        sql = f"SELECT * FROM (VALUES\n {values}\n) AS t({columns})"
+        duck_conn.execute(sql)
+        return sql
+    except duckdb.HTTPException as e:
+        raise snowflake.connector.errors.ProgrammingError(msg=e.args[0], errno=91016, sqlstate="22000") from None
+    except duckdb.ConversionException as e:
+        raise snowflake.connector.errors.ProgrammingError(msg=e.args[0], errno=100038, sqlstate="22018") from None
+
+
+def _params(expr: exp.Copy) -> Params:
+    kwargs = {}
+    force = False
+
+    params = cast(list[exp.CopyParameter], expr.args.get("params", []))
+    cparams = Params()
+    for param in params:
+        assert isinstance(param.this, exp.Var), f"{param.this.__class__} is not a Var"
+        var = param.this.name.upper()
+        if var == "FILE_FORMAT":
+            if kwargs.get("file_format"):
+                raise ValueError(cparams)
+
+            var_type = next((e.args["value"].this for e in param.expressions if e.this.this == "TYPE"), None)
+            if not var_type:
+                raise NotImplementedError("FILE_FORMAT without TYPE is not currently implemented")
+
+            if var_type == "CSV":
+                kwargs["file_format"] = handle_csv(param.expressions)
+            else:
+                raise NotImplementedError(f"{var_type} FILE_FORMAT is not currently implemented")
+        elif var == "FORCE":
+            force = True
+        elif var == "FILES":
+            kwargs["files"] = [lit.name for lit in param.find_all(exp.Literal)]
+        else:
+            raise ValueError(f"Unknown copy parameter: {param.this}")
+
+    if not force:
+        raise NotImplementedError("COPY INTO with FORCE=false (default) is not currently implemented")
+
+    return Params(**kwargs)
+
+
+def _source_urls(expr: exp.Copy, files: list[str]) -> list[str]:
+    """
+    Given a COPY statement and a list of files, return a list of URLs with each file appended as a fragment.
+    Checks that the source is a valid URL.
+    """
+    source = expr.args["files"][0].this
+    assert isinstance(source, exp.Literal), f"{source} is not a exp.Literal"
+
+    scheme, netloc, path, params, query, fragment = urlparse(source.name)
+    if not scheme:
+        raise snowflake.connector.errors.ProgrammingError(
+            msg=f"SQL compilation error:\ninvalid URL prefix found in: '{source.name}'", errno=1011, sqlstate="42601"
+        )
+
+    # rebuild url from components to ensure correct handling of host slash
+    return [_urlunparse(scheme, netloc, path, params, query, fragment, file) for file in files] or [source.name]
+
+
+def _urlunparse(scheme: str, netloc: str, path: str, params: str, query: str, fragment: str, suffix: str) -> str:
+    """Construct a URL from its components appending suffix to the last used component."""
+    if fragment:
+        fragment += suffix
+    elif query:
+        query += suffix
+    elif params:
+        params += suffix
+    else:
+        path += suffix
+    return urlunparse((scheme, netloc, path, params, query, fragment))
+
+
+def _inserts(expr: exp.Copy, params: Params, urls: list[str]) -> list[exp.Expression]:
+    # INTO expression
+    target = expr.this
+    columns = [exp.Column(this=exp.Identifier(this=f"column{i}")) for i in range(len(target.expressions))] or [
+        exp.Column(this=exp.Star())
+    ]
+
+    return [
+        exp.Insert(
+            this=target,
+            expression=exp.Select(expressions=columns).from_(exp.Table(this=params.file_format.read_expression(url))),
+        )
+        for url in urls
+    ]
+
+
+def handle_csv(expressions: list[exp.Property]) -> ReadCSV:
+    skip_header = ReadCSV.skip_header
+    quote = ReadCSV.quote
+    delimiter = ReadCSV.delimiter
+
+    for expression in expressions:
+        exp_type = expression.name
+        if exp_type in {"TYPE"}:
+            continue
+
+        elif exp_type == "SKIP_HEADER":
+            skip_header = True
+        elif exp_type == "FIELD_OPTIONALLY_ENCLOSED_BY":
+            quote = expression.args["value"].this
+        elif exp_type == "FIELD_DELIMITER":
+            delimiter = expression.args["value"].this
+        else:
+            raise NotImplementedError(f"{exp_type} is not currently implemented")
+
+    return ReadCSV(
+        skip_header=skip_header,
+        quote=quote,
+        delimiter=delimiter,
+    )
+
+
+@dataclass
+class FileTypeHandler(Protocol):
+    def read_expression(self, url: str) -> exp.Expression: ...
+
+    @staticmethod
+    def make_eq(name: str, value: list | str | int | bool) -> exp.EQ:
+        if isinstance(value, list):
+            expression = exp.array(*[exp.Literal(this=str(v), is_string=isinstance(v, str)) for v in value])
+        elif isinstance(value, bool):
+            expression = exp.Boolean(this=value)
+        else:
+            expression = exp.Literal(this=str(value), is_string=isinstance(value, str))
+
+        return exp.EQ(this=exp.Literal(this=name, is_string=False), expression=expression)
+
+
+@dataclass
+class ReadCSV(FileTypeHandler):
+    skip_header: bool = False
+    quote: str | None = None
+    delimiter: str = ","
+
+    def read_expression(self, url: str) -> exp.Expression:
+        args = []
+
+        # don't parse header and use as column names, keep them as column0, column1, etc
+        args.append(self.make_eq("header", False))
+
+        if self.skip_header:
+            args.append(self.make_eq("skip", 1))
+
+        if self.quote:
+            quote = self.quote.replace("'", "''")
+            args.append(self.make_eq("quote", quote))
+
+        if self.delimiter and self.delimiter != ",":
+            delimiter = self.delimiter.replace("'", "''")
+            args.append(self.make_eq("sep", delimiter))
+
+        return exp.func("read_csv", exp.Literal(this=url, is_string=True), *args)
+
+
+@dataclass
+class Params:
+    files: list[str] = field(default_factory=list)
+    # Snowflake defaults to CSV when no file format is specified
+    file_format: FileTypeHandler = field(default_factory=ReadCSV)
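
Note: copy_into currently handles CSV sources with FORCE=TRUE only (anything else
raises NotImplementedError). A hedged sketch of a statement it can emulate; the
table name, bucket URL and file names are hypothetical:

    # cur is a fakesnow cursor; cursor.py routes exp.Copy to copy_into()
    cur.execute(
        """
        COPY INTO table1
        FROM 's3://mybucket/data/'
        FILES = ('file1.csv', 'file2.csv')
        FILE_FORMAT = (TYPE = 'CSV' SKIP_HEADER = 1 FIELD_DELIMITER = '|')
        FORCE = TRUE
        """
    )
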
fakesnow/cursor.py CHANGED
@@ -25,12 +25,13 @@ import fakesnow.checks as checks
 import fakesnow.expr as expr
 import fakesnow.info_schema as info_schema
 import fakesnow.transforms as transforms
+from fakesnow import logger
+from fakesnow.copy_into import copy_into
 from fakesnow.rowtype import describe_as_result_metadata
 
 if TYPE_CHECKING:
     # don't require pandas at import time
     import pandas as pd
-    import pyarrow.lib
 
     # avoid circular import
     from fakesnow.conn import FakeSnowflakeConnection
@@ -46,6 +47,12 @@ SQL_DROPPED = Template("SELECT '${name} successfully dropped.' as 'status'")
 SQL_INSERTED_ROWS = Template("SELECT ${count} as 'number of rows inserted'")
 SQL_UPDATED_ROWS = Template("SELECT ${count} as 'number of rows updated', 0 as 'number of multi-joined rows updated'")
 SQL_DELETED_ROWS = Template("SELECT ${count} as 'number of rows deleted'")
+SQL_COPY_ROWS = Template(
+    "SELECT '${file}' as file, 'LOADED' as status, ${count} as rows_parsed, "
+    "${count} as rows_loaded, 1 as error_limit, 0 as errors_seen, "
+    "NULL as first_error, NULL as first_error_line, NULL as first_error_character, "
+    "NULL as first_error_column_name"
+)
 
 
 class FakeSnowflakeCursor:
@@ -279,8 +286,11 @@ class FakeSnowflakeCursor:
         result_sql = None
 
         try:
-            self._log_sql(sql, params)
-            self._duck_conn.execute(sql, params)
+            if isinstance(transformed, exp.Copy):
+                sql = copy_into(self._duck_conn, transformed, params)
+            else:
+                logger.log_sql(sql, params)
+                self._duck_conn.execute(sql, params)
         except duckdb.BinderException as e:
             msg = e.args[0]
             raise snowflake.connector.errors.ProgrammingError(msg=msg, errno=2043, sqlstate="02000") from None
@@ -384,7 +394,7 @@ class FakeSnowflakeCursor:
             self._duck_conn.execute(info_schema.insert_text_lengths_sql(catalog, schema, table.name, text_lengths))
 
         if result_sql:
-            self._log_sql(result_sql, params)
+            logger.log_sql(result_sql)
             self._duck_conn.execute(result_sql)
 
         self._arrow_table = self._duck_conn.fetch_arrow_table()
@@ -394,10 +404,6 @@ class FakeSnowflakeCursor:
         self._last_sql = result_sql or sql
         self._last_params = None if result_sql else params
 
-    def _log_sql(self, sql: str, params: Sequence[Any] | dict[Any, Any] | None = None) -> None:
-        if (fs_debug := os.environ.get("FAKESNOW_DEBUG")) and fs_debug != "snowflake":
-            print(f"{sql};{params=}" if params else f"{sql};", file=sys.stderr)
-
     def executemany(
         self,
         command: str,
fakesnow/fixtures.py CHANGED
@@ -1,4 +1,5 @@
 from collections.abc import Iterator
+from typing import Any
 
 import pytest
 
@@ -11,6 +12,12 @@ def _fakesnow() -> Iterator[None]:
         yield
 
 
+@pytest.fixture(scope="session")
+def fakesnow_server() -> Iterator[dict[str, Any]]:
+    with fakesnow.server() as conn_kwargs:
+        yield conn_kwargs
+
+
 @pytest.fixture
 def _fakesnow_no_auto_create() -> Iterator[None]:
     with fakesnow.patch(create_database_on_connect=False, create_schema_on_connect=False):
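
Note: the session-scoped fixture is available to tests when the fakesnow pytest
plugin is loaded. A hedged sketch of a test using it (the query and assertion are
illustrative):

    import snowflake.connector

    def test_with_server(fakesnow_server: dict) -> None:
        # fakesnow_server carries the connection kwargs yielded by fakesnow.server()
        with snowflake.connector.connect(**fakesnow_server) as conn:
            cur = conn.cursor()
            cur.execute("SELECT 1")
            assert cur.fetchone() == (1,)
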
fakesnow/logger.py ADDED
@@ -0,0 +1,11 @@
+from __future__ import annotations
+
+import os
+import sys
+from collections.abc import Sequence
+from typing import Any
+
+
+def log_sql(sql: str, params: Sequence[Any] | dict[Any, Any] | None = None) -> None:
+    if (fs_debug := os.environ.get("FAKESNOW_DEBUG")) and fs_debug != "snowflake":
+        print(f"{sql};{params=}" if params else f"{sql};", file=sys.stderr)
fakesnow/macros.py CHANGED
@@ -6,8 +6,32 @@ CREATE MACRO IF NOT EXISTS ${catalog}.equal_null(a, b) AS a IS NOT DISTINCT FROM
 """
 )
 
+# emulate the Snowflake FLATTEN function for ARRAYs
+# see https://docs.snowflake.com/en/sql-reference/functions/flatten.html
+FS_FLATTEN = Template(
+    """
+CREATE OR REPLACE MACRO ${catalog}._fs_flatten(input) AS TABLE
+SELECT
+    NULL AS SEQ, -- TODO use a sequence and nextval
+    CAST(NULL AS VARCHAR) AS KEY,
+    '[' || GENERATE_SUBSCRIPTS(
+        CAST(TO_JSON(input) AS JSON []),
+        1
+    ) - 1 || ']' AS PATH,
+    GENERATE_SUBSCRIPTS(
+        CAST(TO_JSON(input) AS JSON []),
+        1
+    ) - 1 AS INDEX,
+    UNNEST(
+        CAST(TO_JSON(input) AS JSON [])
+    ) AS VALUE,
+    TO_JSON(input) AS THIS;
+"""
+)
+
 
 def creation_sql(catalog: str) -> str:
     return f"""
     {EQUAL_NULL.substitute(catalog=catalog)};
+    {FS_FLATTEN.substitute(catalog=catalog)};
     """
fakesnow/server.py CHANGED
@@ -59,7 +59,10 @@ async def login_request(request: Request) -> JSONResponse:
         {
             "data": {
                 "token": token,
-                "parameters": [{"name": "AUTOCOMMIT", "value": True}],
+                "parameters": [
+                    {"name": "AUTOCOMMIT", "value": True},
+                    {"name": "CLIENT_SESSION_KEEP_ALIVE_HEARTBEAT_FREQUENCY", "value": 3600},
+                ],
             },
             "success": True,
         }