kumoai 2.13.0.dev202512040649__cp313-cp313-win_amd64.whl → 2.14.0.dev202601081732__cp313-cp313-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (54)
  1. kumoai/__init__.py +35 -26
  2. kumoai/_version.py +1 -1
  3. kumoai/client/client.py +6 -0
  4. kumoai/client/jobs.py +26 -0
  5. kumoai/client/pquery.py +6 -2
  6. kumoai/connector/utils.py +21 -7
  7. kumoai/experimental/rfm/__init__.py +51 -24
  8. kumoai/experimental/rfm/authenticate.py +3 -4
  9. kumoai/experimental/rfm/backend/local/__init__.py +4 -0
  10. kumoai/experimental/rfm/{local_graph_store.py → backend/local/graph_store.py} +62 -110
  11. kumoai/experimental/rfm/backend/local/sampler.py +312 -0
  12. kumoai/experimental/rfm/backend/local/table.py +35 -31
  13. kumoai/experimental/rfm/backend/snow/__init__.py +2 -0
  14. kumoai/experimental/rfm/backend/snow/sampler.py +366 -0
  15. kumoai/experimental/rfm/backend/snow/table.py +177 -50
  16. kumoai/experimental/rfm/backend/sqlite/__init__.py +4 -2
  17. kumoai/experimental/rfm/backend/sqlite/sampler.py +454 -0
  18. kumoai/experimental/rfm/backend/sqlite/table.py +131 -48
  19. kumoai/experimental/rfm/base/__init__.py +23 -3
  20. kumoai/experimental/rfm/base/column.py +96 -10
  21. kumoai/experimental/rfm/base/expression.py +44 -0
  22. kumoai/experimental/rfm/base/sampler.py +782 -0
  23. kumoai/experimental/rfm/base/source.py +2 -1
  24. kumoai/experimental/rfm/base/sql_sampler.py +247 -0
  25. kumoai/experimental/rfm/base/table.py +404 -203
  26. kumoai/experimental/rfm/graph.py +374 -172
  27. kumoai/experimental/rfm/infer/__init__.py +6 -4
  28. kumoai/experimental/rfm/infer/dtype.py +7 -4
  29. kumoai/experimental/rfm/infer/multicategorical.py +1 -1
  30. kumoai/experimental/rfm/infer/pkey.py +4 -2
  31. kumoai/experimental/rfm/infer/stype.py +35 -0
  32. kumoai/experimental/rfm/infer/time_col.py +1 -2
  33. kumoai/experimental/rfm/pquery/executor.py +27 -27
  34. kumoai/experimental/rfm/pquery/pandas_executor.py +30 -32
  35. kumoai/experimental/rfm/relbench.py +76 -0
  36. kumoai/experimental/rfm/rfm.py +762 -467
  37. kumoai/experimental/rfm/sagemaker.py +4 -4
  38. kumoai/experimental/rfm/task_table.py +292 -0
  39. kumoai/kumolib.cp313-win_amd64.pyd +0 -0
  40. kumoai/pquery/predictive_query.py +10 -6
  41. kumoai/pquery/training_table.py +16 -2
  42. kumoai/testing/snow.py +50 -0
  43. kumoai/trainer/distilled_trainer.py +175 -0
  44. kumoai/utils/__init__.py +3 -2
  45. kumoai/utils/display.py +87 -0
  46. kumoai/utils/progress_logger.py +190 -12
  47. kumoai/utils/sql.py +3 -0
  48. {kumoai-2.13.0.dev202512040649.dist-info → kumoai-2.14.0.dev202601081732.dist-info}/METADATA +3 -2
  49. {kumoai-2.13.0.dev202512040649.dist-info → kumoai-2.14.0.dev202601081732.dist-info}/RECORD +52 -41
  50. kumoai/experimental/rfm/local_graph_sampler.py +0 -223
  51. kumoai/experimental/rfm/local_pquery_driver.py +0 -689
  52. {kumoai-2.13.0.dev202512040649.dist-info → kumoai-2.14.0.dev202601081732.dist-info}/WHEEL +0 -0
  53. {kumoai-2.13.0.dev202512040649.dist-info → kumoai-2.14.0.dev202601081732.dist-info}/licenses/LICENSE +0 -0
  54. {kumoai-2.13.0.dev202512040649.dist-info → kumoai-2.14.0.dev202601081732.dist-info}/top_level.txt +0 -0
kumoai/experimental/rfm/backend/snow/__init__.py
@@ -27,9 +27,11 @@ def connect(**kwargs: Any) -> Connection:
 
 
 from .table import SnowTable  # noqa: E402
+from .sampler import SnowSampler  # noqa: E402
 
 __all__ = [
     'connect',
     'Connection',
     'SnowTable',
+    'SnowSampler',
 ]
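
Note: with the two added lines above, the Snowflake sampler becomes part of the backend's public surface. A minimal import sketch (hypothetical usage, assuming the new dev wheel and its Snowflake dependencies are installed):

    # Illustrative only; mirrors the updated `__all__` shown in the hunk above.
    from kumoai.experimental.rfm.backend.snow import (
        Connection,
        SnowSampler,
        SnowTable,
        connect,
    )
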
kumoai/experimental/rfm/backend/snow/sampler.py (new file)
@@ -0,0 +1,366 @@
+import json
+from collections.abc import Iterator
+from contextlib import contextmanager
+from typing import TYPE_CHECKING
+
+import numpy as np
+import pandas as pd
+import pyarrow as pa
+from kumoapi.pquery import ValidatedPredictiveQuery
+
+from kumoai.experimental.rfm.backend.snow import Connection, SnowTable
+from kumoai.experimental.rfm.base import SQLSampler, Table
+from kumoai.experimental.rfm.pquery import PQueryPandasExecutor
+from kumoai.utils import ProgressLogger
+
+if TYPE_CHECKING:
+    from kumoai.experimental.rfm import Graph
+
+
+@contextmanager
+def paramstyle(connection: Connection, style: str = 'qmark') -> Iterator[None]:
+    _style = connection._paramstyle
+    connection._paramstyle = style
+    yield
+    connection._paramstyle = _style
+
+
+class SnowSampler(SQLSampler):
+    def __init__(
+        self,
+        graph: 'Graph',
+        verbose: bool | ProgressLogger = True,
+    ) -> None:
+        super().__init__(graph=graph, verbose=verbose)
+
+        for table in graph.tables.values():
+            assert isinstance(table, SnowTable)
+            self._connection = table._connection
+
+    def _get_min_max_time_dict(
+        self,
+        table_names: list[str],
+    ) -> dict[str, tuple[pd.Timestamp, pd.Timestamp]]:
+        selects: list[str] = []
+        for table_name in table_names:
+            column = self.time_column_dict[table_name]
+            column_ref = self.table_column_ref_dict[table_name][column]
+            select = (f"SELECT\n"
+                      f" ? as table_name,\n"
+                      f" MIN({column_ref}) as min_date,\n"
+                      f" MAX({column_ref}) as max_date\n"
+                      f"FROM {self.source_name_dict[table_name]}")
+            selects.append(select)
+        sql = "\nUNION ALL\n".join(selects)
+
+        out_dict: dict[str, tuple[pd.Timestamp, pd.Timestamp]] = {}
+        with paramstyle(self._connection), self._connection.cursor() as cursor:
+            cursor.execute(sql, table_names)
+            rows = cursor.fetchall()
+            for table_name, _min, _max in rows:
+                out_dict[table_name] = (
+                    pd.Timestamp.max if _min is None else pd.Timestamp(_min),
+                    pd.Timestamp.min if _max is None else pd.Timestamp(_max),
+                )
+
+        return out_dict
+
+    def _sample_entity_table(
+        self,
+        table_name: str,
+        columns: set[str],
+        num_rows: int,
+        random_seed: int | None = None,
+    ) -> pd.DataFrame:
+        # NOTE Snowflake does support `SEED` only as part of `SYSTEM` sampling.
+        num_rows = min(num_rows, 1_000_000)  # Snowflake's upper limit.
+
+        source_table = self.source_table_dict[table_name]
+        filters: list[str] = []
+
+        key = self.primary_key_dict[table_name]
+        if key not in source_table or source_table[key].is_nullable:
+            key_ref = self.table_column_ref_dict[table_name][key]
+            filters.append(f" {key_ref} IS NOT NULL")
+
+        column = self.time_column_dict.get(table_name)
+        if column is None:
+            pass
+        elif column not in source_table or source_table[column].is_nullable:
+            column_ref = self.table_column_ref_dict[table_name][column]
+            filters.append(f" {column_ref} IS NOT NULL")
+
+        projections = [
+            self.table_column_proj_dict[table_name][column]
+            for column in columns
+        ]
+        sql = (f"SELECT {', '.join(projections)}\n"
+               f"FROM {self.source_name_dict[table_name]}\n"
+               f"SAMPLE ROW ({num_rows} ROWS)")
+        if len(filters) > 0:
+            sql += f"\nWHERE{' AND'.join(filters)}"
+
+        with self._connection.cursor() as cursor:
+            # NOTE This may return duplicate primary keys. This is okay.
+            cursor.execute(sql)
+            table = cursor.fetch_arrow_all()
+
+        return Table._sanitize(
+            df=table.to_pandas(types_mapper=pd.ArrowDtype),
+            dtype_dict=self.table_dtype_dict[table_name],
+            stype_dict=self.table_stype_dict[table_name],
+        )
+
+    def _sample_target(
+        self,
+        query: ValidatedPredictiveQuery,
+        entity_df: pd.DataFrame,
+        train_index: np.ndarray,
+        train_time: pd.Series,
+        num_train_examples: int,
+        test_index: np.ndarray,
+        test_time: pd.Series,
+        num_test_examples: int,
+        columns_dict: dict[str, set[str]],
+        time_offset_dict: dict[
+            tuple[str, str, str],
+            tuple[pd.DateOffset | None, pd.DateOffset],
+        ],
+    ) -> tuple[pd.Series, np.ndarray, pd.Series, np.ndarray]:
+
+        # NOTE For Snowflake, we execute everything at once to pay minimal
+        # query initialization costs.
+        index = np.concatenate([train_index, test_index])
+        time = pd.concat([train_time, test_time], axis=0, ignore_index=True)
+
+        entity_df = entity_df.iloc[index].reset_index(drop=True)
+
+        feat_dict: dict[str, pd.DataFrame] = {query.entity_table: entity_df}
+        time_dict: dict[str, pd.Series] = {}
+        time_column = self.time_column_dict.get(query.entity_table)
+        if time_column in columns_dict[query.entity_table]:
+            time_dict[query.entity_table] = entity_df[time_column]
+        batch_dict: dict[str, np.ndarray] = {
+            query.entity_table: np.arange(len(entity_df)),
+        }
+        for edge_type, (min_offset, max_offset) in time_offset_dict.items():
+            table_name, foreign_key, _ = edge_type
+            feat_dict[table_name], batch_dict[table_name] = self._by_time(
+                table_name=table_name,
+                foreign_key=foreign_key,
+                index=entity_df[self.primary_key_dict[query.entity_table]],
+                anchor_time=time,
+                min_offset=min_offset,
+                max_offset=max_offset,
+                columns=columns_dict[table_name],
+            )
+            time_column = self.time_column_dict.get(table_name)
+            if time_column in columns_dict[table_name]:
+                time_dict[table_name] = feat_dict[table_name][time_column]
+
+        y, mask = PQueryPandasExecutor().execute(
+            query=query,
+            feat_dict=feat_dict,
+            time_dict=time_dict,
+            batch_dict=batch_dict,
+            anchor_time=time,
+            num_forecasts=query.num_forecasts,
+        )
+
+        train_mask = mask[:len(train_index)]
+        test_mask = mask[len(train_index):]
+
+        boundary = int(train_mask.sum())
+        train_y = y.iloc[:boundary]
+        test_y = y.iloc[boundary:].reset_index(drop=True)
+
+        return train_y, train_mask, test_y, test_mask
+
+    def _by_pkey(
+        self,
+        table_name: str,
+        index: pd.Series,
+        columns: set[str],
+    ) -> tuple[pd.DataFrame, np.ndarray]:
+        key = self.primary_key_dict[table_name]
+        key_ref = self.table_column_ref_dict[table_name][key]
+        projections = [
+            self.table_column_proj_dict[table_name][column]
+            for column in columns
+        ]
+
+        payload = json.dumps(list(index))
+
+        sql = ("WITH TMP as (\n"
+               " SELECT\n"
+               " f.index as __KUMO_BATCH__,\n")
+        if self.table_dtype_dict[table_name][key].is_int():
+            sql += " f.value::NUMBER as __KUMO_ID__\n"
+        elif self.table_dtype_dict[table_name][key].is_float():
+            sql += " f.value::FLOAT as __KUMO_ID__\n"
+        else:
+            sql += " f.value::VARCHAR as __KUMO_ID__\n"
+        sql += (f" FROM TABLE(FLATTEN(INPUT => PARSE_JSON(?))) f\n"
+                f")\n"
+                f"SELECT "
+                f"TMP.__KUMO_BATCH__ as __KUMO_BATCH__, "
+                f"{', '.join(projections)}\n"
+                f"FROM TMP\n"
+                f"JOIN {self.source_name_dict[table_name]}\n"
+                f" ON {key_ref} = TMP.__KUMO_ID__")
+
+        with paramstyle(self._connection), self._connection.cursor() as cursor:
+            cursor.execute(sql, (payload, ))
+            table = cursor.fetch_arrow_all()
+
+        # Remove any duplicated primary keys in post-processing:
+        tmp = table.append_column('__KUMO_ID__', pa.array(range(len(table))))
+        gb = tmp.group_by('__KUMO_BATCH__').aggregate([('__KUMO_ID__', 'min')])
+        table = table.take(gb['__KUMO_ID___min'])
+
+        batch = table['__KUMO_BATCH__'].cast(pa.int64()).to_numpy()
+        batch_index = table.schema.get_field_index('__KUMO_BATCH__')
+        table = table.remove_column(batch_index)
+
+        return Table._sanitize(
+            df=table.to_pandas(),
+            dtype_dict=self.table_dtype_dict[table_name],
+            stype_dict=self.table_stype_dict[table_name],
+        ), batch
+
+    def _by_fkey(
+        self,
+        table_name: str,
+        foreign_key: str,
+        index: pd.Series,
+        num_neighbors: int,
+        anchor_time: pd.Series | None,
+        columns: set[str],
+    ) -> tuple[pd.DataFrame, np.ndarray]:
+        time_column = self.time_column_dict.get(table_name)
+
+        if time_column is not None and anchor_time is not None:
+            anchor_time = anchor_time.dt.strftime("%Y-%m-%d %H:%M:%S")
+            payload = json.dumps(list(zip(index, anchor_time)))
+        else:
+            payload = json.dumps(list(zip(index)))
+
+        key_ref = self.table_column_ref_dict[table_name][foreign_key]
+        projections = [
+            self.table_column_proj_dict[table_name][column]
+            for column in columns
+        ]
+
+        sql = ("WITH TMP as (\n"
+               " SELECT\n"
+               " f.index as __KUMO_BATCH__,\n")
+        if self.table_dtype_dict[table_name][foreign_key].is_int():
+            sql += " f.value[0]::NUMBER as __KUMO_ID__"
+        elif self.table_dtype_dict[table_name][foreign_key].is_float():
+            sql += " f.value[0]::FLOAT as __KUMO_ID__"
+        else:
+            sql += " f.value[0]::VARCHAR as __KUMO_ID__"
+        if time_column is not None and anchor_time is not None:
+            sql += (",\n"
+                    " f.value[1]::TIMESTAMP_NTZ as __KUMO_TIME__")
+        sql += (f"\n"
+                f" FROM TABLE(FLATTEN(INPUT => PARSE_JSON(?))) f\n"
+                f")\n"
+                f"SELECT "
+                f"TMP.__KUMO_BATCH__ as __KUMO_BATCH__, "
+                f"{', '.join(projections)}\n"
+                f"FROM TMP\n"
+                f"JOIN {self.source_name_dict[table_name]}\n"
+                f" ON {key_ref} = TMP.__KUMO_ID__\n")
+        if time_column is not None and anchor_time is not None:
+            time_ref = self.table_column_ref_dict[table_name][time_column]
+            sql += f" AND {time_ref} <= TMP.__KUMO_TIME__\n"
+        sql += ("QUALIFY ROW_NUMBER() OVER (\n"
+                " PARTITION BY TMP.__KUMO_BATCH__\n")
+        if time_column is not None:
+            sql += f" ORDER BY {time_ref} DESC\n"
+        else:
+            sql += f" ORDER BY {key_ref}\n"
+        sql += f") <= {num_neighbors}"
+
+        with paramstyle(self._connection), self._connection.cursor() as cursor:
+            cursor.execute(sql, (payload, ))
+            table = cursor.fetch_arrow_all()
+
+        batch = table['__KUMO_BATCH__'].cast(pa.int64()).to_numpy()
+        batch_index = table.schema.get_field_index('__KUMO_BATCH__')
+        table = table.remove_column(batch_index)
+
+        return Table._sanitize(
+            df=table.to_pandas(),
+            dtype_dict=self.table_dtype_dict[table_name],
+            stype_dict=self.table_stype_dict[table_name],
+        ), batch
+
+    # Helper Methods ##########################################################
+
+    def _by_time(
+        self,
+        table_name: str,
+        foreign_key: str,
+        index: pd.Series,
+        anchor_time: pd.Series,
+        min_offset: pd.DateOffset | None,
+        max_offset: pd.DateOffset,
+        columns: set[str],
+    ) -> tuple[pd.DataFrame, np.ndarray]:
+        time_column = self.time_column_dict[table_name]
+
+        end_time = anchor_time + max_offset
+        end_time = end_time.dt.strftime("%Y-%m-%d %H:%M:%S")
+        if min_offset is not None:
+            start_time = anchor_time + min_offset
+            start_time = start_time.dt.strftime("%Y-%m-%d %H:%M:%S")
+            payload = json.dumps(list(zip(index, end_time, start_time)))
+        else:
+            payload = json.dumps(list(zip(index, end_time)))
+
+        key_ref = self.table_column_ref_dict[table_name][foreign_key]
+        time_ref = self.table_column_ref_dict[table_name][time_column]
+        projections = [
+            self.table_column_proj_dict[table_name][column]
+            for column in columns
+        ]
+        sql = ("WITH TMP as (\n"
+               " SELECT\n"
+               " f.index as __KUMO_BATCH__,\n")
+        if self.table_dtype_dict[table_name][foreign_key].is_int():
+            sql += " f.value[0]::NUMBER as __KUMO_ID__,\n"
+        elif self.table_dtype_dict[table_name][foreign_key].is_float():
+            sql += " f.value[0]::FLOAT as __KUMO_ID__,\n"
+        else:
+            sql += " f.value[0]::VARCHAR as __KUMO_ID__,\n"
+        sql += " f.value[1]::TIMESTAMP_NTZ as __KUMO_END_TIME__"
+        if min_offset is not None:
+            sql += ",\n f.value[2]::TIMESTAMP_NTZ as __KUMO_START_TIME__"
+        sql += (f"\n"
+                f" FROM TABLE(FLATTEN(INPUT => PARSE_JSON(?))) f\n"
+                f")\n"
+                f"SELECT "
+                f"TMP.__KUMO_BATCH__ as __KUMO_BATCH__, "
+                f"{', '.join(projections)}\n"
+                f"FROM TMP\n"
+                f"JOIN {self.source_name_dict[table_name]}\n"
+                f" ON {key_ref} = TMP.__KUMO_ID__\n"
+                f" AND {time_ref} <= TMP.__KUMO_END_TIME__")
+        if min_offset is not None:
+            sql += f"\n AND {time_ref} > TMP.__KUMO_START_TIME__"
+
+        with paramstyle(self._connection), self._connection.cursor() as cursor:
+            cursor.execute(sql, (payload, ))
+            table = cursor.fetch_arrow_all()
+
+        batch = table['__KUMO_BATCH__'].cast(pa.int64()).to_numpy()
+        batch_index = table.schema.get_field_index('__KUMO_BATCH__')
+        table = table.remove_column(batch_index)
+
+        return Table._sanitize(
+            df=table.to_pandas(types_mapper=pd.ArrowDtype),
+            dtype_dict=self.table_dtype_dict[table_name],
+            stype_dict=self.table_stype_dict[table_name],
+        ), batch
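
The sampler above avoids per-row round trips by JSON-encoding every request: `_by_pkey`, `_by_fkey`, and `_by_time` serialize the (key, time) pairs with `json.dumps`, expand them server-side via `TABLE(FLATTEN(INPUT => PARSE_JSON(?)))`, and carry the positional `f.index` back as `__KUMO_BATCH__` so results can be regrouped per entity. A minimal, self-contained sketch of that encoding (illustrative values only, no Snowflake connection required):

    import json

    keys = [7, 3, 7]  # foreign-key value per sampled entity
    end_times = [
        "2024-01-01 00:00:00",
        "2024-02-01 00:00:00",
        "2024-03-01 00:00:00",
    ]

    # Same payload shape that `_by_fkey`/`_by_time` bind to the single `?` parameter.
    payload = json.dumps(list(zip(keys, end_times)))

    # Pure-Python stand-in for FLATTEN(INPUT => PARSE_JSON(?)):
    # f.index -> __KUMO_BATCH__, f.value[0] -> key, f.value[1] -> end/anchor time.
    for batch, (key, end_time) in enumerate(json.loads(payload)):
        print(batch, key, end_time)
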
kumoai/experimental/rfm/backend/snow/table.py
@@ -1,11 +1,22 @@
 import re
-from typing import List, Optional, Sequence
+from collections import Counter
+from collections.abc import Sequence
+from typing import cast
 
 import pandas as pd
+from kumoapi.model_plan import MissingType
 from kumoapi.typing import Dtype
 
-from kumoai.experimental.rfm.backend.sqlite import Connection
-from kumoai.experimental.rfm.base import SourceColumn, SourceForeignKey, Table
+from kumoai.experimental.rfm.backend.snow import Connection
+from kumoai.experimental.rfm.base import (
+    ColumnSpec,
+    ColumnSpecType,
+    DataBackend,
+    SourceColumn,
+    SourceForeignKey,
+    Table,
+)
+from kumoai.utils import quote_ident
 
 
 class SnowTable(Table):
@@ -14,6 +25,10 @@ class SnowTable(Table):
     Args:
         connection: The connection to a :class:`snowflake` database.
         name: The name of this table.
+        source_name: The source name of this table. If set to ``None``,
+            ``name`` is being used.
+        database: The database.
+        schema: The schema.
         columns: The selected columns of this table.
         primary_key: The name of the primary key of this table, if it exists.
         time_column: The name of the time column of this table, if it exists.
@@ -24,17 +39,27 @@ class SnowTable(Table):
         self,
         connection: Connection,
         name: str,
+        source_name: str | None = None,
         database: str | None = None,
         schema: str | None = None,
-        columns: Optional[Sequence[str]] = None,
-        primary_key: Optional[str] = None,
-        time_column: Optional[str] = None,
-        end_time_column: Optional[str] = None,
+        columns: Sequence[ColumnSpecType] | None = None,
+        primary_key: MissingType | str | None = MissingType.VALUE,
+        time_column: str | None = None,
+        end_time_column: str | None = None,
     ) -> None:
 
-        if database is not None and schema is None:
-            raise ValueError(f"Missing 'schema' for table '{name}' in "
-                             f"database '{database}'")
+        if database is None or schema is None:
+            with connection.cursor() as cursor:
+                cursor.execute("SELECT CURRENT_DATABASE(), CURRENT_SCHEMA()")
+                result = cursor.fetchone()
+                database = database or result[0]
+                assert database is not None
+                schema = schema or result[1]
+
+        if schema is None:
+            raise ValueError(f"Unspecified 'schema' for table "
+                             f"'{source_name or name}' in database "
+                             f"'{database}'")
 
         self._connection = connection
         self._database = database
@@ -42,6 +67,7 @@ class SnowTable(Table):
 
         super().__init__(
             name=name,
+            source_name=source_name,
             columns=columns,
             primary_key=primary_key,
             time_column=time_column,
@@ -49,67 +75,168 @@ class SnowTable(Table):
         )
 
     @property
-    def fqn_name(self) -> str:
-        names: List[str] = []
+    def source_name(self) -> str:
+        names: list[str] = []
         if self._database is not None:
-            assert self._schema is not None
-            names.extend([self._database, self._schema])
-        elif self._schema is not None:
+            names.append(self._database)
+        if self._schema is not None:
             names.append(self._schema)
-        names.append(self._name)
-        return '.'.join(names)
+        return '.'.join(names + [self._source_name])
 
-    def _get_source_columns(self) -> List[SourceColumn]:
-        source_columns: List[SourceColumn] = []
+    @property
+    def _quoted_source_name(self) -> str:
+        names: list[str] = []
+        if self._database is not None:
+            names.append(quote_ident(self._database))
+        if self._schema is not None:
+            names.append(quote_ident(self._schema))
+        return '.'.join(names + [quote_ident(self._source_name)])
+
+    @property
+    def backend(self) -> DataBackend:
+        return cast(DataBackend, DataBackend.SNOWFLAKE)
+
+    def _get_source_columns(self) -> list[SourceColumn]:
+        source_columns: list[SourceColumn] = []
         with self._connection.cursor() as cursor:
             try:
-                cursor.execute(f"DESCRIBE TABLE {self.fqn_name}")
+                sql = f"DESCRIBE TABLE {self._quoted_source_name}"
+                cursor.execute(sql)
             except Exception as e:
-                raise ValueError(
-                    f"Table '{self.fqn_name}' does not exist") from e
+                raise ValueError(f"Table '{self.source_name}' does not exist "
+                                 f"in the remote data backend") from e
 
             for row in cursor.fetchall():
-                column, type, _, _, _, is_pkey, is_unique = row[:7]
-
-                type = type.strip().upper()
-                if type.startswith('NUMBER'):
-                    dtype = Dtype.int
-                elif type.startswith('VARCHAR'):
-                    dtype = Dtype.string
-                elif type == 'FLOAT':
-                    dtype = Dtype.float
-                elif type == 'BOOLEAN':
-                    dtype = Dtype.bool
-                elif re.search('DATE|TIMESTAMP', type):
-                    dtype = Dtype.date
-                else:
-                    continue
+                column, dtype, _, null, _, is_pkey, is_unique, *_ = row
 
                 source_column = SourceColumn(
                     name=column,
-                    dtype=dtype,
+                    dtype=self._to_dtype(dtype),
                     is_primary_key=is_pkey.strip().upper() == 'Y',
                     is_unique_key=is_unique.strip().upper() == 'Y',
+                    is_nullable=null.strip().upper() == 'Y',
                 )
                 source_columns.append(source_column)
 
         return source_columns
 
-    def _get_source_foreign_keys(self) -> List[SourceForeignKey]:
-        source_fkeys: List[SourceForeignKey] = []
+    def _get_source_foreign_keys(self) -> list[SourceForeignKey]:
+        source_foreign_keys: list[SourceForeignKey] = []
         with self._connection.cursor() as cursor:
-            cursor.execute(f"SHOW IMPORTED KEYS IN TABLE {self.fqn_name}")
-            for row in cursor.fetchall():
-                _, _, _, dst_table, pkey, _, _, _, fkey = row[:9]
-                source_fkeys.append(SourceForeignKey(fkey, dst_table, pkey))
-        return source_fkeys
+            sql = f"SHOW IMPORTED KEYS IN TABLE {self._quoted_source_name}"
+            cursor.execute(sql)
+            rows = cursor.fetchall()
+            counts = Counter(row[13] for row in rows)
+            for row in rows:
+                if counts[row[13]] == 1:
+                    source_foreign_key = SourceForeignKey(
+                        name=row[8],
+                        dst_table=f'{row[1]}.{row[2]}.{row[3]}',
+                        primary_key=row[4],
+                    )
+                    source_foreign_keys.append(source_foreign_key)
+        return source_foreign_keys
+
+    def _get_source_sample_df(self) -> pd.DataFrame:
+        with self._connection.cursor() as cursor:
+            columns = [quote_ident(col) for col in self._source_column_dict]
+            sql = (f"SELECT {', '.join(columns)} "
+                   f"FROM {self._quoted_source_name} "
+                   f"LIMIT {self._NUM_SAMPLE_ROWS}")
+            cursor.execute(sql)
+            table = cursor.fetch_arrow_all()
 
-    def _get_sample_df(self) -> pd.DataFrame:
+        if table is None:
+            raise RuntimeError(f"Table '{self.source_name}' is empty")
+
+        return self._sanitize(
+            df=table.to_pandas(types_mapper=pd.ArrowDtype),
+            dtype_dict={
+                column.name: column.dtype
+                for column in self._source_column_dict.values()
+            },
+            stype_dict=None,
+        )
+
+    def _get_num_rows(self) -> int | None:
+        return None
+
+    def _get_expr_sample_df(
+        self,
+        columns: Sequence[ColumnSpec],
+    ) -> pd.DataFrame:
         with self._connection.cursor() as cursor:
-            columns = ', '.join(self._source_column_dict.keys())
-            cursor.execute(f"SELECT {columns} FROM {self.fqn_name} LIMIT 1000")
+            projections = [
+                f"{column.expr} AS {quote_ident(column.name)}"
+                for column in columns
+            ]
+            sql = (f"SELECT {', '.join(projections)} "
+                   f"FROM {self._quoted_source_name} "
+                   f"LIMIT {self._NUM_SAMPLE_ROWS}")
+            cursor.execute(sql)
             table = cursor.fetch_arrow_all()
-            return table.to_pandas(types_mapper=pd.ArrowDtype)
 
-    def _get_num_rows(self) -> Optional[int]:
+        if table is None:
+            raise RuntimeError(f"Table '{self.source_name}' is empty")
+
+        return self._sanitize(
+            df=table.to_pandas(types_mapper=pd.ArrowDtype),
+            dtype_dict={column.name: column.dtype
+                        for column in columns},
+            stype_dict=None,
+        )
+
+    @staticmethod
+    def _to_dtype(dtype: str | None) -> Dtype | None:
+        if dtype is None:
+            return None
+        dtype = dtype.strip().upper()
+        if dtype.startswith('NUMBER'):
+            try:  # Parse `scale` from 'NUMBER(precision, scale)':
+                scale = int(dtype.split(',')[-1].split(')')[0])
+                return Dtype.int if scale == 0 else Dtype.float
+            except Exception:
+                return Dtype.float
+        if dtype == 'FLOAT':
+            return Dtype.float
+        if dtype.startswith('VARCHAR'):
+            return Dtype.string
+        if dtype.startswith('BINARY'):
+            return Dtype.binary
+        if dtype == 'BOOLEAN':
+            return Dtype.bool
+        if dtype.startswith('DATE') or dtype.startswith('TIMESTAMP'):
+            return Dtype.date
+        if dtype.startswith('TIME'):
+            return Dtype.time
+        if dtype.startswith('VECTOR'):
+            try:  # Parse element data type from 'VECTOR(dtype, dimension)':
+                dtype = dtype.split(',')[0].split('(')[1].strip()
+                if dtype == 'INT':
+                    return Dtype.intlist
+                elif dtype == 'FLOAT':
+                    return Dtype.floatlist
+            except Exception:
+                pass
+            return Dtype.unsupported
+        if dtype.startswith('ARRAY'):
+            try:  # Parse element data type from 'ARRAY(dtype)':
+                dtype = dtype.split('(', maxsplit=1)[1]
+                dtype = dtype.rsplit(')', maxsplit=1)[0]
+                _dtype = SnowTable._to_dtype(dtype)
+                if _dtype is not None and _dtype.is_int():
+                    return Dtype.intlist
+                elif _dtype is not None and _dtype.is_float():
+                    return Dtype.floatlist
+                elif _dtype is not None and _dtype.is_string():
+                    return Dtype.stringlist
+            except Exception:
+                pass
+            return Dtype.unsupported
+        # Unsupported data types:
+        if re.search(
+                'DECFLOAT|VARIANT|OBJECT|MAP|FILE|GEOGRAPHY|GEOMETRY',
+                dtype,
+        ):
+            return Dtype.unsupported
         return None
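
The new `SnowTable._to_dtype` helper replaces the inline type mapping that previously lived in `_get_source_columns`. A rough sketch of the mapping it implements, assuming the new wheel (and its Snowflake extras) is importable locally:

    from kumoapi.typing import Dtype

    from kumoai.experimental.rfm.backend.snow.table import SnowTable

    # DESCRIBE TABLE type strings -> Kumo dtypes, per the branches above.
    assert SnowTable._to_dtype('NUMBER(38,0)') == Dtype.int        # zero scale
    assert SnowTable._to_dtype('NUMBER(10,2)') == Dtype.float      # non-zero scale
    assert SnowTable._to_dtype('VARCHAR(16777216)') == Dtype.string
    assert SnowTable._to_dtype('TIMESTAMP_NTZ(9)') == Dtype.date
    assert SnowTable._to_dtype('VECTOR(FLOAT, 768)') == Dtype.floatlist
    assert SnowTable._to_dtype('GEOGRAPHY') == Dtype.unsupported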