fakesnow 0.9.36__py3-none-any.whl → 0.9.37__py3-none-any.whl

fakesnow/__init__.py CHANGED
@@ -119,7 +119,7 @@ def server(port: int | None = None, session_parameters: dict[str, str | int | bo
 
     assert port
     server = uvicorn.Server(uvicorn.Config(fakesnow.server.app, port=port, log_level="info"))
-    thread = threading.Thread(target=server.run, name="Server", daemon=True)
+    thread = threading.Thread(target=server.run, name="fakesnow server", daemon=True)
    thread.start()
 
     while not server.started:
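
A hedged sketch of why the rename helps: the server thread is now identifiable among live threads when debugging. This assumes the context-manager usage of `fakesnow.server()` from the fakesnow README; it is not part of the diff.

```python
import threading

import fakesnow

with fakesnow.server() as conn_kwargs:
    # the uvicorn thread now shows up with a recognizable name
    assert "fakesnow server" in [t.name for t in threading.enumerate()]
```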
fakesnow/converter.py CHANGED
@@ -8,20 +8,20 @@ from datetime import date, time, timezone
 
 
 def from_binding(binding: dict[str, str]) -> int | bytes | bool | date | time | datetime.datetime | str:
-    typ = binding["type"]
+    type_ = binding["type"]
     value = binding["value"]
-    if typ == "FIXED":
+    if type_ == "FIXED":
         return int(value)
-    elif typ == "BINARY":
+    elif type_ == "BINARY":
         return from_binary(value)
     # TODO: not strictly needed
-    elif typ == "BOOLEAN":
+    elif type_ == "BOOLEAN":
         return value.lower() == "true"
-    elif typ == "DATE":
+    elif type_ == "DATE":
         return from_date(value)
-    elif typ == "TIME":
+    elif type_ == "TIME":
         return from_time(value)
-    elif typ == "TIMESTAMP_NTZ":
+    elif type_ == "TIMESTAMP_NTZ":
         return from_datetime(value)
     else:
         # For other types, return str
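
A hedged example of the conversion above (not part of the diff): Snowflake-style bind parameters arrive as `{"type": ..., "value": ...}` dicts. The `TEXT` case assumes the `else` branch returns the value unchanged, as its comment says; the hunk ends before that `return`.

```python
from fakesnow.converter import from_binding

assert from_binding({"type": "FIXED", "value": "42"}) == 42
assert from_binding({"type": "BOOLEAN", "value": "True"}) is True
# unhandled types are assumed to fall through and stay as str
assert from_binding({"type": "TEXT", "value": "hello"}) == "hello"
```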
fakesnow/copy_into.py ADDED
@@ -0,0 +1,199 @@
+from __future__ import annotations
+
+from collections.abc import Sequence
+from dataclasses import dataclass, field
+from typing import Any, Protocol, cast
+from urllib.parse import urlparse, urlunparse
+
+import duckdb
+import snowflake.connector.errors
+from duckdb import DuckDBPyConnection
+from sqlglot import exp
+
+from fakesnow import logger
+
+
+def copy_into(
+    duck_conn: DuckDBPyConnection, expr: exp.Copy, params: Sequence[Any] | dict[Any, Any] | None = None
+) -> str:
+    cparams = _params(expr)
+    urls = _source_urls(expr, cparams.files)
+    inserts = _inserts(expr, cparams, urls)
+
+    results = []
+    try:
+        # TODO: fetch files last modified dates and check if file exists in load_history already
+        for i, url in zip(inserts, urls):
+            sql = i.sql(dialect="duckdb")
+            logger.log_sql(sql, params)
+            duck_conn.execute(sql, params)
+            (affected_count,) = duck_conn.fetchall()[0]
+            results.append(f"('{url}', 'LOADED', {affected_count}, {affected_count}, 1, 0, NULL, NULL, NULL, NULL)")
+
+        # TODO: update load_history with the results if loaded
+
+        columns = "file, status, rows_parsed, rows_loaded, error_limit, errors_seen, first_error, first_error_line, first_error_character, first_error_column_name"  # noqa: E501
+        values = "\n, ".join(results)
+        sql = f"SELECT * FROM (VALUES\n {values}\n) AS t({columns})"
+        duck_conn.execute(sql)
+        return sql
+    except duckdb.HTTPException as e:
+        raise snowflake.connector.errors.ProgrammingError(msg=e.args[0], errno=91016, sqlstate="22000") from None
+    except duckdb.ConversionException as e:
+        raise snowflake.connector.errors.ProgrammingError(msg=e.args[0], errno=100038, sqlstate="22018") from None
+
+
+def _params(expr: exp.Copy) -> Params:
+    kwargs = {}
+    force = False
+
+    params = cast(list[exp.CopyParameter], expr.args.get("params", []))
+    cparams = Params()
+    for param in params:
+        assert isinstance(param.this, exp.Var), f"{param.this.__class__} is not a Var"
+        var = param.this.name.upper()
+        if var == "FILE_FORMAT":
+            if kwargs.get("file_format"):
+                raise ValueError(cparams)
+
+            var_type = next((e.args["value"].this for e in param.expressions if e.this.this == "TYPE"), None)
+            if not var_type:
+                raise NotImplementedError("FILE_FORMAT without TYPE is not currently implemented")
+
+            if var_type == "CSV":
+                kwargs["file_format"] = handle_csv(param.expressions)
+            else:
+                raise NotImplementedError(f"{var_type} FILE_FORMAT is not currently implemented")
+        elif var == "FORCE":
+            force = True
+        elif var == "FILES":
+            kwargs["files"] = [lit.name for lit in param.find_all(exp.Literal)]
+        else:
+            raise ValueError(f"Unknown copy parameter: {param.this}")
+
+    if not force:
+        raise NotImplementedError("COPY INTO with FORCE=false (default) is not currently implemented")
+
+    return Params(**kwargs)
+
+
+def _source_urls(expr: exp.Copy, files: list[str]) -> list[str]:
+    """
+    Given a COPY statement and a list of files, return a list of URLs with each file appended as a fragment.
+    Checks that the source is a valid URL.
+    """
+    source = expr.args["files"][0].this
+    assert isinstance(source, exp.Literal), f"{source} is not a exp.Literal"
+
+    scheme, netloc, path, params, query, fragment = urlparse(source.name)
+    if not scheme:
+        raise snowflake.connector.errors.ProgrammingError(
+            msg=f"SQL compilation error:\ninvalid URL prefix found in: '{source.name}'", errno=1011, sqlstate="42601"
+        )
+
+    # rebuild url from components to ensure correct handling of host slash
+    return [_urlunparse(scheme, netloc, path, params, query, fragment, file) for file in files] or [source.name]
+
+
+def _urlunparse(scheme: str, netloc: str, path: str, params: str, query: str, fragment: str, suffix: str) -> str:
+    """Construct a URL from its components appending suffix to the last used component."""
+    if fragment:
+        fragment += suffix
+    elif query:
+        query += suffix
+    elif params:
+        params += suffix
+    else:
+        path += suffix
+    return urlunparse((scheme, netloc, path, params, query, fragment))
+
+
+def _inserts(expr: exp.Copy, params: Params, urls: list[str]) -> list[exp.Expression]:
+    # INTO expression
+    target = expr.this
+    columns = [exp.Column(this=exp.Identifier(this=f"column{i}")) for i in range(len(target.expressions))] or [
+        exp.Column(this=exp.Star())
+    ]
+
+    return [
+        exp.Insert(
+            this=target,
+            expression=exp.Select(expressions=columns).from_(exp.Table(this=params.file_format.read_expression(url))),
+        )
+        for url in urls
+    ]
+
+
+def handle_csv(expressions: list[exp.Property]) -> ReadCSV:
+    skip_header = ReadCSV.skip_header
+    quote = ReadCSV.quote
+    delimiter = ReadCSV.delimiter
+
+    for expression in expressions:
+        exp_type = expression.name
+        if exp_type in {"TYPE"}:
+            continue
+
+        elif exp_type == "SKIP_HEADER":
+            skip_header = True
+        elif exp_type == "FIELD_OPTIONALLY_ENCLOSED_BY":
+            quote = expression.args["value"].this
+        elif exp_type == "FIELD_DELIMITER":
+            delimiter = expression.args["value"].this
+        else:
+            raise NotImplementedError(f"{exp_type} is not currently implemented")
+
+    return ReadCSV(
+        skip_header=skip_header,
+        quote=quote,
+        delimiter=delimiter,
+    )
+
+
+@dataclass
+class FileTypeHandler(Protocol):
+    def read_expression(self, url: str) -> exp.Expression: ...
+
+    @staticmethod
+    def make_eq(name: str, value: list | str | int | bool) -> exp.EQ:
+        if isinstance(value, list):
+            expression = exp.array(*[exp.Literal(this=str(v), is_string=isinstance(v, str)) for v in value])
+        elif isinstance(value, bool):
+            expression = exp.Boolean(this=value)
+        else:
+            expression = exp.Literal(this=str(value), is_string=isinstance(value, str))
+
+        return exp.EQ(this=exp.Literal(this=name, is_string=False), expression=expression)
+
+
+@dataclass
+class ReadCSV(FileTypeHandler):
+    skip_header: bool = False
+    quote: str | None = None
+    delimiter: str = ","
+
+    def read_expression(self, url: str) -> exp.Expression:
+        args = []
+
+        # don't parse header and use as column names, keep them as column0, column1, etc
+        args.append(self.make_eq("header", False))
+
+        if self.skip_header:
+            args.append(self.make_eq("skip", 1))
+
+        if self.quote:
+            quote = self.quote.replace("'", "''")
+            args.append(self.make_eq("quote", quote))
+
+        if self.delimiter and self.delimiter != ",":
+            delimiter = self.delimiter.replace("'", "''")
+            args.append(self.make_eq("sep", delimiter))
+
+        return exp.func("read_csv", exp.Literal(this=url, is_string=True), *args)
+
+
+@dataclass
+class Params:
+    files: list[str] = field(default_factory=list)
+    # Snowflake defaults to CSV when no file format is specified
+    file_format: FileTypeHandler = field(default_factory=ReadCSV)
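
To see what the CSV handler produces, here is a hedged sketch (not part of the diff) that builds a `read_csv` expression directly. The URL is hypothetical, and the exact SQL rendering depends on the sqlglot version:

```python
from fakesnow.copy_into import ReadCSV

# equivalent of FIELD_OPTIONALLY_ENCLOSED_BY='"', FIELD_DELIMITER='|', SKIP_HEADER=1
handler = ReadCSV(skip_header=True, quote='"', delimiter="|")
expression = handler.read_expression("s3://mybucket/data.csv")

# renders roughly as:
# READ_CSV('s3://mybucket/data.csv', header = FALSE, skip = 1, quote = '"', sep = '|')
print(expression.sql(dialect="duckdb"))
```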
fakesnow/cursor.py CHANGED
@@ -25,12 +25,13 @@ import fakesnow.checks as checks
 import fakesnow.expr as expr
 import fakesnow.info_schema as info_schema
 import fakesnow.transforms as transforms
+from fakesnow import logger
+from fakesnow.copy_into import copy_into
 from fakesnow.rowtype import describe_as_result_metadata
 
 if TYPE_CHECKING:
     # don't require pandas at import time
     import pandas as pd
-    import pyarrow.lib
 
     # avoid circular import
     from fakesnow.conn import FakeSnowflakeConnection
@@ -255,7 +256,6 @@ class FakeSnowflakeCursor:
             .transform(lambda e: transforms.show_keys(e, self._conn.database, kind="FOREIGN"))
             .transform(transforms.show_users)
             .transform(transforms.create_user)
-            .transform(transforms.copy_into)
             .transform(transforms.sha256)
             .transform(transforms.create_clone)
             .transform(transforms.alias_in_join)
@@ -286,8 +286,11 @@ class FakeSnowflakeCursor:
         result_sql = None
 
         try:
-            self._log_sql(sql, params)
-            self._duck_conn.execute(sql, params)
+            if isinstance(transformed, exp.Copy):
+                sql = copy_into(self._duck_conn, transformed, params)
+            else:
+                logger.log_sql(sql, params)
+                self._duck_conn.execute(sql, params)
         except duckdb.BinderException as e:
             msg = e.args[0]
             raise snowflake.connector.errors.ProgrammingError(msg=msg, errno=2043, sqlstate="02000") from None
@@ -307,10 +310,6 @@ class FakeSnowflakeCursor:
             raise snowflake.connector.errors.DatabaseError(msg=e.args[0], errno=250002, sqlstate="08003") from None
         except duckdb.ParserException as e:
             raise snowflake.connector.errors.ProgrammingError(msg=e.args[0], errno=1003, sqlstate="42000") from None
-        except duckdb.HTTPException as e:
-            raise snowflake.connector.errors.ProgrammingError(msg=e.args[0], errno=91016, sqlstate="22000") from None
-        except duckdb.ConversionException as e:
-            raise snowflake.connector.errors.ProgrammingError(msg=e.args[0], errno=100038, sqlstate="22018") from None
 
         affected_count = None
 
@@ -330,10 +329,6 @@ class FakeSnowflakeCursor:
             self._duck_conn.execute(info_schema.per_db_creation_sql(create_db_name))
             result_sql = SQL_CREATED_DATABASE.substitute(name=create_db_name)
 
-        elif copy_from := transformed.args.get("copy_from"):
-            (affected_count,) = self._duck_conn.fetchall()[0]
-            result_sql = SQL_COPY_ROWS.substitute(count=affected_count, file=copy_from)
-
         elif cmd == "INSERT":
             (affected_count,) = self._duck_conn.fetchall()[0]
             result_sql = SQL_INSERTED_ROWS.substitute(count=affected_count)
@@ -399,7 +394,7 @@ class FakeSnowflakeCursor:
             self._duck_conn.execute(info_schema.insert_text_lengths_sql(catalog, schema, table.name, text_lengths))
 
         if result_sql:
-            self._log_sql(result_sql, params)
+            logger.log_sql(result_sql)
             self._duck_conn.execute(result_sql)
 
         self._arrow_table = self._duck_conn.fetch_arrow_table()
@@ -409,10 +404,6 @@ class FakeSnowflakeCursor:
         self._last_sql = result_sql or sql
         self._last_params = None if result_sql else params
 
-    def _log_sql(self, sql: str, params: Sequence[Any] | dict[Any, Any] | None = None) -> None:
-        if (fs_debug := os.environ.get("FAKESNOW_DEBUG")) and fs_debug != "snowflake":
-            print(f"{sql};{params=}" if params else f"{sql};", file=sys.stderr)
-
     def executemany(
         self,
         command: str,
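
For context, a hedged end-to-end sketch of how the new dispatch is reached from the connector API (not part of the diff). The table and URL are made up, a reachable CSV at the URL is assumed, and `FORCE = TRUE` is required since `FORCE=false` is not yet implemented (see copy_into.py above):

```python
import fakesnow
import snowflake.connector

with fakesnow.patch():
    conn = snowflake.connector.connect(database="db1", schema="schema1")
    cur = conn.cursor()
    cur.execute("CREATE TABLE t1 (a INT, b VARCHAR)")
    # parsed as exp.Copy, so the cursor hands it to copy_into() rather than duckdb directly
    cur.execute("COPY INTO t1 FROM 'https://example.com/data/' FILES = ('f.csv') FORCE = TRUE")
    print(cur.fetchall())  # one row per file: (file, 'LOADED', rows_parsed, rows_loaded, ...)
```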
fakesnow/logger.py ADDED
@@ -0,0 +1,11 @@
+from __future__ import annotations
+
+import os
+import sys
+from collections.abc import Sequence
+from typing import Any
+
+
+def log_sql(sql: str, params: Sequence[Any] | dict[Any, Any] | None = None) -> None:
+    if (fs_debug := os.environ.get("FAKESNOW_DEBUG")) and fs_debug != "snowflake":
+        print(f"{sql};{params=}" if params else f"{sql};", file=sys.stderr)
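
A small sketch of the extracted helper in use: with `FAKESNOW_DEBUG` set to anything other than `"snowflake"`, each statement is echoed to stderr.

```python
import os

os.environ["FAKESNOW_DEBUG"] = "1"

from fakesnow import logger

logger.log_sql("SELECT 1")                # stderr: SELECT 1;
logger.log_sql("SELECT ?", params=(42,))  # stderr: SELECT ?;params=(42,)
```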
fakesnow/macros.py CHANGED
@@ -6,8 +6,32 @@ CREATE MACRO IF NOT EXISTS ${catalog}.equal_null(a, b) AS a IS NOT DISTINCT FROM
 """
 )
 
+# emulate the Snowflake FLATTEN function for ARRAYs
+# see https://docs.snowflake.com/en/sql-reference/functions/flatten.html
+FS_FLATTEN = Template(
+    """
+CREATE OR REPLACE MACRO ${catalog}._fs_flatten(input) AS TABLE
+SELECT
+    NULL AS SEQ, -- TODO use a sequence and nextval
+    CAST(NULL AS VARCHAR) AS KEY,
+    '[' || GENERATE_SUBSCRIPTS(
+        CAST(TO_JSON(input) AS JSON []),
+        1
+    ) - 1 || ']' AS PATH,
+    GENERATE_SUBSCRIPTS(
+        CAST(TO_JSON(input) AS JSON []),
+        1
+    ) - 1 AS INDEX,
+    UNNEST(
+        CAST(TO_JSON(input) AS JSON [])
+    ) AS VALUE,
+    TO_JSON(input) AS THIS;
+"""
+)
+
 
 def creation_sql(catalog: str) -> str:
     return f"""
     {EQUAL_NULL.substitute(catalog=catalog)};
+    {FS_FLATTEN.substitute(catalog=catalog)};
     """