kumoai 2.14.0.dev202512211732__cp313-cp313-macosx_11_0_arm64.whl → 2.15.0.dev202601121731__cp313-cp313-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their respective public registries.
Files changed (41)
  1. kumoai/__init__.py +23 -26
  2. kumoai/_version.py +1 -1
  3. kumoai/client/client.py +6 -0
  4. kumoai/client/jobs.py +26 -0
  5. kumoai/connector/utils.py +21 -7
  6. kumoai/experimental/rfm/__init__.py +24 -22
  7. kumoai/experimental/rfm/backend/local/graph_store.py +12 -21
  8. kumoai/experimental/rfm/backend/local/sampler.py +0 -3
  9. kumoai/experimental/rfm/backend/local/table.py +24 -25
  10. kumoai/experimental/rfm/backend/snow/sampler.py +190 -71
  11. kumoai/experimental/rfm/backend/snow/table.py +137 -64
  12. kumoai/experimental/rfm/backend/sqlite/sampler.py +192 -87
  13. kumoai/experimental/rfm/backend/sqlite/table.py +85 -55
  14. kumoai/experimental/rfm/base/__init__.py +6 -9
  15. kumoai/experimental/rfm/base/column.py +95 -11
  16. kumoai/experimental/rfm/base/expression.py +44 -0
  17. kumoai/experimental/rfm/base/mapper.py +69 -0
  18. kumoai/experimental/rfm/base/sampler.py +28 -18
  19. kumoai/experimental/rfm/base/source.py +1 -1
  20. kumoai/experimental/rfm/base/sql_sampler.py +320 -19
  21. kumoai/experimental/rfm/base/table.py +256 -109
  22. kumoai/experimental/rfm/base/utils.py +27 -0
  23. kumoai/experimental/rfm/graph.py +115 -107
  24. kumoai/experimental/rfm/infer/dtype.py +4 -1
  25. kumoai/experimental/rfm/infer/multicategorical.py +1 -1
  26. kumoai/experimental/rfm/infer/time_col.py +4 -2
  27. kumoai/experimental/rfm/relbench.py +76 -0
  28. kumoai/experimental/rfm/rfm.py +540 -306
  29. kumoai/experimental/rfm/task_table.py +292 -0
  30. kumoai/pquery/training_table.py +16 -2
  31. kumoai/testing/snow.py +3 -3
  32. kumoai/trainer/distilled_trainer.py +175 -0
  33. kumoai/utils/display.py +87 -0
  34. kumoai/utils/progress_logger.py +13 -1
  35. {kumoai-2.14.0.dev202512211732.dist-info → kumoai-2.15.0.dev202601121731.dist-info}/METADATA +2 -2
  36. {kumoai-2.14.0.dev202512211732.dist-info → kumoai-2.15.0.dev202601121731.dist-info}/RECORD +39 -34
  37. kumoai/experimental/rfm/base/column_expression.py +0 -50
  38. kumoai/experimental/rfm/base/sql_table.py +0 -229
  39. {kumoai-2.14.0.dev202512211732.dist-info → kumoai-2.15.0.dev202601121731.dist-info}/WHEEL +0 -0
  40. {kumoai-2.14.0.dev202512211732.dist-info → kumoai-2.15.0.dev202601121731.dist-info}/licenses/LICENSE +0 -0
  41. {kumoai-2.14.0.dev202512211732.dist-info → kumoai-2.15.0.dev202601121731.dist-info}/top_level.txt +0 -0
kumoai/experimental/rfm/backend/snow/sampler.py
@@ -1,16 +1,20 @@
  import json
  from collections.abc import Iterator
  from contextlib import contextmanager
+ from typing import TYPE_CHECKING

  import numpy as np
  import pandas as pd
  import pyarrow as pa
  from kumoapi.pquery import ValidatedPredictiveQuery

- from kumoai.experimental.rfm.backend.snow import Connection
- from kumoai.experimental.rfm.base import SQLSampler
+ from kumoai.experimental.rfm.backend.snow import Connection, SnowTable
+ from kumoai.experimental.rfm.base import SQLSampler, Table
  from kumoai.experimental.rfm.pquery import PQueryPandasExecutor
- from kumoai.utils import quote_ident
+ from kumoai.utils import ProgressLogger
+
+ if TYPE_CHECKING:
+     from kumoai.experimental.rfm import Graph


  @contextmanager
@@ -22,18 +26,30 @@ def paramstyle(connection: Connection, style: str = 'qmark') -> Iterator[None]:


  class SnowSampler(SQLSampler):
+     def __init__(
+         self,
+         graph: 'Graph',
+         verbose: bool | ProgressLogger = True,
+     ) -> None:
+         super().__init__(graph=graph, verbose=verbose)
+
+         for table in graph.tables.values():
+             assert isinstance(table, SnowTable)
+             self._connection = table._connection
+
      def _get_min_max_time_dict(
          self,
          table_names: list[str],
      ) -> dict[str, tuple[pd.Timestamp, pd.Timestamp]]:
          selects: list[str] = []
          for table_name in table_names:
-             time_column = self.time_column_dict[table_name]
+             column = self.time_column_dict[table_name]
+             column_ref = self.table_column_ref_dict[table_name][column]
              select = (f"SELECT\n"
                        f" ? as table_name,\n"
-                       f" MIN({quote_ident(time_column)}) as min_date,\n"
-                       f" MAX({quote_ident(time_column)}) as max_date\n"
-                       f"FROM {self.fqn_dict[table_name]}")
+                       f" MIN({column_ref}) as min_date,\n"
+                       f" MAX({column_ref}) as max_date\n"
+                       f"FROM {self.source_name_dict[table_name]}")
              selects.append(select)
          sql = "\nUNION ALL\n".join(selects)

@@ -59,17 +75,27 @@ class SnowSampler(SQLSampler):
          # NOTE Snowflake does support `SEED` only as part of `SYSTEM` sampling.
          num_rows = min(num_rows, 1_000_000)  # Snowflake's upper limit.

+         source_table = self.source_table_dict[table_name]
          filters: list[str] = []
-         primary_key = self.primary_key_dict[table_name]
-         if self.source_table_dict[table_name][primary_key].is_nullable:
-             filters.append(f" {quote_ident(primary_key)} IS NOT NULL")
-         time_column = self.time_column_dict.get(table_name)
-         if (time_column is not None and
-                 self.source_table_dict[table_name][time_column].is_nullable):
-             filters.append(f" {quote_ident(time_column)} IS NOT NULL")

-         sql = (f"SELECT {', '.join(quote_ident(col) for col in columns)}\n"
-                f"FROM {self.fqn_dict[table_name]}\n"
+         key = self.primary_key_dict[table_name]
+         if key not in source_table or source_table[key].is_nullable:
+             key_ref = self.table_column_ref_dict[table_name][key]
+             filters.append(f" {key_ref} IS NOT NULL")
+
+         column = self.time_column_dict.get(table_name)
+         if column is None:
+             pass
+         elif column not in source_table or source_table[column].is_nullable:
+             column_ref = self.table_column_ref_dict[table_name][column]
+             filters.append(f" {column_ref} IS NOT NULL")
+
+         projections = [
+             self.table_column_proj_dict[table_name][column]
+             for column in columns
+         ]
+         sql = (f"SELECT {', '.join(projections)}\n"
+                f"FROM {self.source_name_dict[table_name]}\n"
                 f"SAMPLE ROW ({num_rows} ROWS)")
          if len(filters) > 0:
              sql += f"\nWHERE{' AND'.join(filters)}"
@@ -79,7 +105,11 @@
              cursor.execute(sql)
              table = cursor.fetch_arrow_all()

-         return self._sanitize(table_name, table)
+         return Table._sanitize(
+             df=table.to_pandas(types_mapper=pd.ArrowDtype),
+             dtype_dict=self.table_dtype_dict[table_name],
+             stype_dict=self.table_stype_dict[table_name],
+         )

      def _sample_target(
          self,
@@ -114,11 +144,11 @@
              query.entity_table: np.arange(len(entity_df)),
          }
          for edge_type, (min_offset, max_offset) in time_offset_dict.items():
-             table_name, fkey, _ = edge_type
+             table_name, foreign_key, _ = edge_type
              feat_dict[table_name], batch_dict[table_name] = self._by_time(
                  table_name=table_name,
-                 fkey=fkey,
-                 pkey=entity_df[self.primary_key_dict[query.entity_table]],
+                 foreign_key=foreign_key,
+                 index=entity_df[self.primary_key_dict[query.entity_table]],
                  anchor_time=time,
                  min_offset=min_offset,
                  max_offset=max_offset,
@@ -149,104 +179,193 @@ class SnowSampler(SQLSampler):
      def _by_pkey(
          self,
          table_name: str,
-         pkey: pd.Series,
+         index: pd.Series,
          columns: set[str],
      ) -> tuple[pd.DataFrame, np.ndarray]:
+         key = self.primary_key_dict[table_name]
+         key_ref = self.table_column_ref_dict[table_name][key]
+         projections = [
+             self.table_column_proj_dict[table_name][column]
+             for column in columns
+         ]

-         pkey_name = self.primary_key_dict[table_name]
-         source_table = self.source_table_dict[table_name]
-
-         payload = json.dumps(list(pkey))
+         payload = json.dumps(list(index))

          sql = ("WITH TMP as (\n"
                 " SELECT\n"
-                " f.index as BATCH,\n")
-         if source_table[pkey_name].dtype.is_int():
-             sql += " f.value::NUMBER as ID\n"
-         elif source_table[pkey_name].dtype.is_float():
-             sql += " f.value::FLOAT as ID\n"
+                " f.index as __KUMO_BATCH__,\n")
+         if self.table_dtype_dict[table_name][key].is_int():
+             sql += " f.value::NUMBER as __KUMO_ID__\n"
+         elif self.table_dtype_dict[table_name][key].is_float():
+             sql += " f.value::FLOAT as __KUMO_ID__\n"
          else:
-             sql += " f.value::VARCHAR as ID\n"
+             sql += " f.value::VARCHAR as __KUMO_ID__\n"
          sql += (f" FROM TABLE(FLATTEN(INPUT => PARSE_JSON(?))) f\n"
                  f")\n"
-                 f"SELECT TMP.BATCH as __BATCH__, "
-                 f"{', '.join('ENT.' + quote_ident(col) for col in columns)}\n"
+                 f"SELECT "
+                 f"TMP.__KUMO_BATCH__ as __KUMO_BATCH__, "
+                 f"{', '.join(projections)}\n"
                  f"FROM TMP\n"
-                 f"JOIN {self.fqn_dict[table_name]} ENT\n"
-                 f" ON ENT.{quote_ident(pkey_name)} = TMP.ID")
+                 f"JOIN {self.source_name_dict[table_name]}\n"
+                 f" ON {key_ref} = TMP.__KUMO_ID__")

          with paramstyle(self._connection), self._connection.cursor() as cursor:
              cursor.execute(sql, (payload, ))
              table = cursor.fetch_arrow_all()

          # Remove any duplicated primary keys in post-processing:
-         tmp = table.append_column('__TMP__', pa.array(range(len(table))))
-         gb = tmp.group_by('__BATCH__').aggregate([('__TMP__', 'min')])
-         table = table.take(gb['__TMP___min'])
+         tmp = table.append_column('__KUMO_ID__', pa.array(range(len(table))))
+         gb = tmp.group_by('__KUMO_BATCH__').aggregate([('__KUMO_ID__', 'min')])
+         table = table.take(gb['__KUMO_ID___min'])
+
+         batch = table['__KUMO_BATCH__'].cast(pa.int64()).to_numpy()
+         batch_index = table.schema.get_field_index('__KUMO_BATCH__')
+         table = table.remove_column(batch_index)

-         batch = table['__BATCH__'].cast(pa.int64()).to_numpy()
-         table = table.remove_column(table.schema.get_field_index('__BATCH__'))
+         return Table._sanitize(
+             df=table.to_pandas(),
+             dtype_dict=self.table_dtype_dict[table_name],
+             stype_dict=self.table_stype_dict[table_name],
+         ), batch

-         return table.to_pandas(), batch  # TODO Use `self._sanitize`.
+     def _by_fkey(
+         self,
+         table_name: str,
+         foreign_key: str,
+         index: pd.Series,
+         num_neighbors: int,
+         anchor_time: pd.Series | None,
+         columns: set[str],
+     ) -> tuple[pd.DataFrame, np.ndarray]:
+         time_column = self.time_column_dict.get(table_name)
+
+         if time_column is not None and anchor_time is not None:
+             anchor_time = anchor_time.dt.strftime("%Y-%m-%d %H:%M:%S")
+             payload = json.dumps(list(zip(index, anchor_time)))
+         else:
+             payload = json.dumps(list(zip(index)))
+
+         key_ref = self.table_column_ref_dict[table_name][foreign_key]
+         projections = [
+             self.table_column_proj_dict[table_name][column]
+             for column in columns
+         ]
+
+         sql = ("WITH TMP as (\n"
+                " SELECT\n"
+                " f.index as __KUMO_BATCH__,\n")
+         if self.table_dtype_dict[table_name][foreign_key].is_int():
+             sql += " f.value[0]::NUMBER as __KUMO_ID__"
+         elif self.table_dtype_dict[table_name][foreign_key].is_float():
+             sql += " f.value[0]::FLOAT as __KUMO_ID__"
+         else:
+             sql += " f.value[0]::VARCHAR as __KUMO_ID__"
+         if time_column is not None and anchor_time is not None:
+             sql += (",\n"
+                     " f.value[1]::TIMESTAMP_NTZ as __KUMO_TIME__")
+         sql += (f"\n"
+                 f" FROM TABLE(FLATTEN(INPUT => PARSE_JSON(?))) f\n"
+                 f")\n"
+                 f"SELECT "
+                 f"TMP.__KUMO_BATCH__ as __KUMO_BATCH__, "
+                 f"{', '.join(projections)}\n"
+                 f"FROM TMP\n"
+                 f"JOIN {self.source_name_dict[table_name]}\n"
+                 f" ON {key_ref} = TMP.__KUMO_ID__\n")
+         if time_column is not None and anchor_time is not None:
+             time_ref = self.table_column_ref_dict[table_name][time_column]
+             sql += f" AND {time_ref} <= TMP.__KUMO_TIME__\n"
+         sql += ("QUALIFY ROW_NUMBER() OVER (\n"
+                 " PARTITION BY TMP.__KUMO_BATCH__\n")
+         if time_column is not None:
+             sql += f" ORDER BY {time_ref} DESC\n"
+         else:
+             sql += f" ORDER BY {key_ref}\n"
+         sql += f") <= {num_neighbors}"
+
+         with paramstyle(self._connection), self._connection.cursor() as cursor:
+             cursor.execute(sql, (payload, ))
+             table = cursor.fetch_arrow_all()
+
+         batch = table['__KUMO_BATCH__'].cast(pa.int64()).to_numpy()
+         batch_index = table.schema.get_field_index('__KUMO_BATCH__')
+         table = table.remove_column(batch_index)
+
+         return Table._sanitize(
+             df=table.to_pandas(),
+             dtype_dict=self.table_dtype_dict[table_name],
+             stype_dict=self.table_stype_dict[table_name],
+         ), batch

      # Helper Methods ##########################################################

      def _by_time(
          self,
          table_name: str,
-         fkey: str,
-         pkey: pd.Series,
+         foreign_key: str,
+         index: pd.Series,
          anchor_time: pd.Series,
          min_offset: pd.DateOffset | None,
          max_offset: pd.DateOffset,
          columns: set[str],
      ) -> tuple[pd.DataFrame, np.ndarray]:
+         time_column = self.time_column_dict[table_name]

          end_time = anchor_time + max_offset
          end_time = end_time.dt.strftime("%Y-%m-%d %H:%M:%S")
+         start_time: pd.Series | None = None
          if min_offset is not None:
              start_time = anchor_time + min_offset
              start_time = start_time.dt.strftime("%Y-%m-%d %H:%M:%S")
-             payload = json.dumps(list(zip(pkey, end_time, start_time)))
+             payload = json.dumps(list(zip(index, end_time, start_time)))
          else:
-             payload = json.dumps(list(zip(pkey, end_time)))
-
-         # Based on benchmarking, JSON payload is the fastest way to query by
-         # custom indices (compared to large `IN` clauses or temporary tables):
-         source_table = self.source_table_dict[table_name]
-         time_column = self.time_column_dict[table_name]
+             payload = json.dumps(list(zip(index, end_time)))
+
+         key_ref = self.table_column_ref_dict[table_name][foreign_key]
+         time_ref = self.table_column_ref_dict[table_name][time_column]
+         projections = [
+             self.table_column_proj_dict[table_name][column]
+             for column in columns
+         ]
          sql = ("WITH TMP as (\n"
                 " SELECT\n"
-                " f.index as BATCH,\n")
-         if source_table[fkey].dtype.is_int():
-             sql += " f.value[0]::NUMBER as ID,\n"
-         elif source_table[fkey].dtype.is_float():
-             sql += " f.value[0]::FLOAT as ID,\n"
+                " f.index as __KUMO_BATCH__,\n")
+         if self.table_dtype_dict[table_name][foreign_key].is_int():
+             sql += " f.value[0]::NUMBER as __KUMO_ID__,\n"
+         elif self.table_dtype_dict[table_name][foreign_key].is_float():
+             sql += " f.value[0]::FLOAT as __KUMO_ID__,\n"
          else:
-             sql += " f.value[0]::VARCHAR as ID,\n"
-         sql += " f.value[1]::TIMESTAMP_NTZ as END_TIME"
+             sql += " f.value[0]::VARCHAR as __KUMO_ID__,\n"
+         sql += " f.value[1]::TIMESTAMP_NTZ as __KUMO_END_TIME__"
          if min_offset is not None:
-             sql += ",\n f.value[2]::TIMESTAMP_NTZ as START_TIME"
+             sql += ",\n f.value[2]::TIMESTAMP_NTZ as __KUMO_START_TIME__"
          sql += (f"\n"
                  f" FROM TABLE(FLATTEN(INPUT => PARSE_JSON(?))) f\n"
                  f")\n"
-                 f"SELECT TMP.BATCH as __BATCH__, "
-                 f"{', '.join('FACT.' + quote_ident(col) for col in columns)}\n"
+                 f"SELECT "
+                 f"TMP.__KUMO_BATCH__ as __KUMO_BATCH__, "
+                 f"{', '.join(projections)}\n"
                  f"FROM TMP\n"
-                 f"JOIN {self.fqn_dict[table_name]} FACT\n"
-                 f" ON FACT.{quote_ident(fkey)} = TMP.ID\n"
-                 f" AND FACT.{quote_ident(time_column)} <= TMP.END_TIME")
-         if min_offset is not None:
-             sql += f"\n AND FACT.{quote_ident(time_column)} > TMP.START_TIME"
+                 f"JOIN {self.source_name_dict[table_name]}\n"
+                 f" ON {key_ref} = TMP.__KUMO_ID__\n"
+                 f" AND {time_ref} <= TMP.__KUMO_END_TIME__\n")
+         if start_time is not None:
+             sql += f"AND {time_ref} > TMP.__KUMO_START_TIME__\n"
+         # Add global time bounds to enable partition pruning:
+         sql += f"WHERE {time_ref} <= '{end_time.max()}'"
+         if start_time is not None:
+             sql += f"\nAND {time_ref} > '{start_time.min()}'"

          with paramstyle(self._connection), self._connection.cursor() as cursor:
              cursor.execute(sql, (payload, ))
              table = cursor.fetch_arrow_all()

-         batch = table['__BATCH__'].cast(pa.int64()).to_numpy()
-         table = table.remove_column(table.schema.get_field_index('__BATCH__'))
-
-         return self._sanitize(table_name, table), batch
+         batch = table['__KUMO_BATCH__'].cast(pa.int64()).to_numpy()
+         batch_index = table.schema.get_field_index('__KUMO_BATCH__')
+         table = table.remove_column(batch_index)

-     def _sanitize(self, table_name: str, table: pa.table) -> pd.DataFrame:
-         return table.to_pandas(types_mapper=pd.ArrowDtype)
+         return Table._sanitize(
+             df=table.to_pandas(types_mapper=pd.ArrowDtype),
+             dtype_dict=self.table_dtype_dict[table_name],
+             stype_dict=self.table_stype_dict[table_name],
+         ), batch
kumoai/experimental/rfm/backend/snow/table.py
@@ -1,4 +1,5 @@
  import re
+ from collections import Counter
  from collections.abc import Sequence
  from typing import cast

@@ -8,28 +9,27 @@ from kumoapi.typing import Dtype

  from kumoai.experimental.rfm.backend.snow import Connection
  from kumoai.experimental.rfm.base import (
-     ColumnExpressionSpec,
-     ColumnExpressionType,
+     ColumnSpec,
+     ColumnSpecType,
      DataBackend,
      SourceColumn,
      SourceForeignKey,
-     SQLTable,
+     Table,
  )
  from kumoai.utils import quote_ident


- class SnowTable(SQLTable):
+ class SnowTable(Table):
      r"""A table backed by a :class:`sqlite` database.

      Args:
          connection: The connection to a :class:`snowflake` database.
-         name: The logical name of this table.
-         source_name: The physical name of this table in the database. If set to
-             ``None``, ``name`` is being used.
+         name: The name of this table.
+         source_name: The source name of this table. If set to ``None``,
+             ``name`` is being used.
          database: The database.
          schema: The schema.
-         columns: The selected physical columns of this table.
-         column_expressions: The logical columns of this table.
+         columns: The selected columns of this table.
          primary_key: The name of the primary key of this table, if it exists.
          time_column: The name of the time column of this table, if it exists.
          end_time_column: The name of the end time column of this table, if it
@@ -42,14 +42,21 @@ class SnowTable(SQLTable):
          source_name: str | None = None,
          database: str | None = None,
          schema: str | None = None,
-         columns: Sequence[str] | None = None,
-         column_expressions: Sequence[ColumnExpressionType] | None = None,
+         columns: Sequence[ColumnSpecType] | None = None,
          primary_key: MissingType | str | None = MissingType.VALUE,
          time_column: str | None = None,
          end_time_column: str | None = None,
      ) -> None:

-         if database is not None and schema is None:
+         if database is None or schema is None:
+             with connection.cursor() as cursor:
+                 cursor.execute("SELECT CURRENT_DATABASE(), CURRENT_SCHEMA()")
+                 result = cursor.fetchone()
+                 database = database or result[0]
+                 assert database is not None
+                 schema = schema or result[1]
+
+         if schema is None:
              raise ValueError(f"Unspecified 'schema' for table "
                               f"'{source_name or name}' in database "
                               f"'{database}'")
@@ -62,37 +69,22 @@ class SnowTable(SQLTable):
              name=name,
              source_name=source_name,
              columns=columns,
-             column_expressions=column_expressions,
              primary_key=primary_key,
              time_column=time_column,
              end_time_column=end_time_column,
          )

-     @staticmethod
-     def to_dtype(snowflake_dtype: str | None) -> Dtype | None:
-         if snowflake_dtype is None:
-             return None
-         snowflake_dtype = snowflake_dtype.strip().upper()
-         # TODO 'NUMBER(...)' is not always an integer!
-         if snowflake_dtype.startswith('NUMBER'):
-             return Dtype.int
-         elif snowflake_dtype.startswith('VARCHAR'):
-             return Dtype.string
-         elif snowflake_dtype == 'FLOAT':
-             return Dtype.float
-         elif snowflake_dtype == 'BOOLEAN':
-             return Dtype.bool
-         elif re.search('DATE|TIMESTAMP', snowflake_dtype):
-             return Dtype.date
-         return None
-
      @property
-     def backend(self) -> DataBackend:
-         return cast(DataBackend, DataBackend.SNOWFLAKE)
+     def source_name(self) -> str:
+         names: list[str] = []
+         if self._database is not None:
+             names.append(self._database)
+         if self._schema is not None:
+             names.append(self._schema)
+         return '.'.join(names + [self._source_name])

      @property
-     def fqn(self) -> str:
-         r"""The fully-qualified quoted table name."""
+     def _quoted_source_name(self) -> str:
          names: list[str] = []
          if self._database is not None:
              names.append(quote_ident(self._database))
@@ -100,32 +92,26 @@ class SnowTable(SQLTable):
              names.append(quote_ident(self._schema))
          return '.'.join(names + [quote_ident(self._source_name)])

+     @property
+     def backend(self) -> DataBackend:
+         return cast(DataBackend, DataBackend.SNOWFLAKE)
+
      def _get_source_columns(self) -> list[SourceColumn]:
          source_columns: list[SourceColumn] = []
          with self._connection.cursor() as cursor:
              try:
-                 sql = f"DESCRIBE TABLE {self.fqn}"
+                 sql = f"DESCRIBE TABLE {self._quoted_source_name}"
                  cursor.execute(sql)
              except Exception as e:
-                 names: list[str] = []
-                 if self._database is not None:
-                     names.append(self._database)
-                 if self._schema is not None:
-                     names.append(self._schema)
-                 source_name = '.'.join(names + [self._source_name])
-                 raise ValueError(f"Table '{source_name}' does not exist in "
-                                  f"the remote data backend") from e
+                 raise ValueError(f"Table '{self.source_name}' does not exist "
+                                  f"in the remote data backend") from e

              for row in cursor.fetchall():
-                 column, type, _, null, _, is_pkey, is_unique, *_ = row
-
-                 dtype = self.to_dtype(type)
-                 if dtype is None:
-                     continue
+                 column, dtype, _, null, _, is_pkey, is_unique, *_ = row

                  source_column = SourceColumn(
                      name=column,
-                     dtype=dtype,
+                     dtype=self._to_dtype(dtype),
                      is_primary_key=is_pkey.strip().upper() == 'Y',
                      is_unique_key=is_unique.strip().upper() == 'Y',
                      is_nullable=null.strip().upper() == 'Y',
@@ -135,35 +121,122 @@ class SnowTable(SQLTable):
          return source_columns

      def _get_source_foreign_keys(self) -> list[SourceForeignKey]:
-         source_fkeys: list[SourceForeignKey] = []
+         source_foreign_keys: list[SourceForeignKey] = []
          with self._connection.cursor() as cursor:
-             sql = f"SHOW IMPORTED KEYS IN TABLE {self.fqn}"
+             sql = f"SHOW IMPORTED KEYS IN TABLE {self._quoted_source_name}"
              cursor.execute(sql)
-             for row in cursor.fetchall():
-                 _, _, _, dst_table, pkey, _, _, _, fkey, *_ = row
-                 source_fkeys.append(SourceForeignKey(fkey, dst_table, pkey))
-         return source_fkeys
+             rows = cursor.fetchall()
+             counts = Counter(row[13] for row in rows)
+             for row in rows:
+                 if counts[row[13]] == 1:
+                     source_foreign_key = SourceForeignKey(
+                         name=row[8],
+                         dst_table=f'{row[1]}.{row[2]}.{row[3]}',
+                         primary_key=row[4],
+                     )
+                     source_foreign_keys.append(source_foreign_key)
+         return source_foreign_keys

      def _get_source_sample_df(self) -> pd.DataFrame:
          with self._connection.cursor() as cursor:
              columns = [quote_ident(col) for col in self._source_column_dict]
-             sql = f"SELECT {', '.join(columns)} FROM {self.fqn} LIMIT 1000"
+             sql = (f"SELECT {', '.join(columns)} "
+                    f"FROM {self._quoted_source_name} "
+                    f"LIMIT {self._NUM_SAMPLE_ROWS}")
              cursor.execute(sql)
              table = cursor.fetch_arrow_all()
-             return table.to_pandas(types_mapper=pd.ArrowDtype)
+
+         if table is None:
+             raise RuntimeError(f"Table '{self.source_name}' is empty")
+
+         return self._sanitize(
+             df=table.to_pandas(types_mapper=pd.ArrowDtype),
+             dtype_dict={
+                 column.name: column.dtype
+                 for column in self._source_column_dict.values()
+             },
+             stype_dict=None,
+         )

      def _get_num_rows(self) -> int | None:
          return None

-     def _get_expression_sample_df(
+     def _get_expr_sample_df(
          self,
-         specs: Sequence[ColumnExpressionSpec],
+         columns: Sequence[ColumnSpec],
      ) -> pd.DataFrame:
          with self._connection.cursor() as cursor:
-             columns = [
-                 f"{spec.expr} AS {quote_ident(spec.name)}" for spec in specs
+             projections = [
+                 f"{column.expr} AS {quote_ident(column.name)}"
+                 for column in columns
              ]
-             sql = f"SELECT {', '.join(columns)} FROM {self.fqn} LIMIT 1000"
+             sql = (f"SELECT {', '.join(projections)} "
+                    f"FROM {self._quoted_source_name} "
+                    f"LIMIT {self._NUM_SAMPLE_ROWS}")
              cursor.execute(sql)
              table = cursor.fetch_arrow_all()
-             return table.to_pandas(types_mapper=pd.ArrowDtype)
+
+         if table is None:
+             raise RuntimeError(f"Table '{self.source_name}' is empty")
+
+         return self._sanitize(
+             df=table.to_pandas(types_mapper=pd.ArrowDtype),
+             dtype_dict={column.name: column.dtype
+                         for column in columns},
+             stype_dict=None,
+         )
+
+     @staticmethod
+     def _to_dtype(dtype: str | None) -> Dtype | None:
+         if dtype is None:
+             return None
+         dtype = dtype.strip().upper()
+         if dtype.startswith('NUMBER'):
+             try:  # Parse `scale` from 'NUMBER(precision, scale)':
+                 scale = int(dtype.split(',')[-1].split(')')[0])
+                 return Dtype.int if scale == 0 else Dtype.float
+             except Exception:
+                 return Dtype.float
+         if dtype == 'FLOAT':
+             return Dtype.float
+         if dtype.startswith('VARCHAR'):
+             return Dtype.string
+         if dtype.startswith('BINARY'):
+             return Dtype.binary
+         if dtype == 'BOOLEAN':
+             return Dtype.bool
+         if dtype.startswith('DATE') or dtype.startswith('TIMESTAMP'):
+             return Dtype.date
+         if dtype.startswith('TIME'):
+             return Dtype.time
+         if dtype.startswith('VECTOR'):
+             try:  # Parse element data type from 'VECTOR(dtype, dimension)':
+                 dtype = dtype.split(',')[0].split('(')[1].strip()
+                 if dtype == 'INT':
+                     return Dtype.intlist
+                 elif dtype == 'FLOAT':
+                     return Dtype.floatlist
+             except Exception:
+                 pass
+             return Dtype.unsupported
+         if dtype.startswith('ARRAY'):
+             try:  # Parse element data type from 'ARRAY(dtype)':
+                 dtype = dtype.split('(', maxsplit=1)[1]
+                 dtype = dtype.rsplit(')', maxsplit=1)[0]
+                 _dtype = SnowTable._to_dtype(dtype)
+                 if _dtype is not None and _dtype.is_int():
+                     return Dtype.intlist
+                 elif _dtype is not None and _dtype.is_float():
+                     return Dtype.floatlist
+                 elif _dtype is not None and _dtype.is_string():
+                     return Dtype.stringlist
+             except Exception:
+                 pass
+             return Dtype.unsupported
+         # Unsupported data types:
+         if re.search(
+                 'DECFLOAT|VARIANT|OBJECT|MAP|FILE|GEOGRAPHY|GEOMETRY',
+                 dtype,
+         ):
+             return Dtype.unsupported
+         return None