kumoai-2.13.0.dev202512040252-cp310-cp310-win_amd64.whl → kumoai-2.15.0.dev202601141731-cp310-cp310-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (56)
  1. kumoai/__init__.py +35 -26
  2. kumoai/_version.py +1 -1
  3. kumoai/client/client.py +6 -0
  4. kumoai/client/jobs.py +26 -0
  5. kumoai/client/pquery.py +6 -2
  6. kumoai/connector/utils.py +21 -7
  7. kumoai/experimental/rfm/__init__.py +51 -24
  8. kumoai/experimental/rfm/authenticate.py +3 -4
  9. kumoai/experimental/rfm/backend/local/__init__.py +4 -0
  10. kumoai/experimental/rfm/{local_graph_store.py → backend/local/graph_store.py} +62 -110
  11. kumoai/experimental/rfm/backend/local/sampler.py +312 -0
  12. kumoai/experimental/rfm/backend/local/table.py +35 -31
  13. kumoai/experimental/rfm/backend/snow/__init__.py +2 -0
  14. kumoai/experimental/rfm/backend/snow/sampler.py +407 -0
  15. kumoai/experimental/rfm/backend/snow/table.py +181 -51
  16. kumoai/experimental/rfm/backend/sqlite/__init__.py +4 -2
  17. kumoai/experimental/rfm/backend/sqlite/sampler.py +456 -0
  18. kumoai/experimental/rfm/backend/sqlite/table.py +131 -48
  19. kumoai/experimental/rfm/base/__init__.py +23 -3
  20. kumoai/experimental/rfm/base/column.py +96 -10
  21. kumoai/experimental/rfm/base/expression.py +44 -0
  22. kumoai/experimental/rfm/base/mapper.py +69 -0
  23. kumoai/experimental/rfm/base/sampler.py +783 -0
  24. kumoai/experimental/rfm/base/source.py +2 -1
  25. kumoai/experimental/rfm/base/sql_sampler.py +385 -0
  26. kumoai/experimental/rfm/base/table.py +385 -203
  27. kumoai/experimental/rfm/base/utils.py +36 -0
  28. kumoai/experimental/rfm/graph.py +374 -172
  29. kumoai/experimental/rfm/infer/__init__.py +6 -4
  30. kumoai/experimental/rfm/infer/dtype.py +10 -5
  31. kumoai/experimental/rfm/infer/multicategorical.py +1 -1
  32. kumoai/experimental/rfm/infer/pkey.py +4 -2
  33. kumoai/experimental/rfm/infer/stype.py +35 -0
  34. kumoai/experimental/rfm/infer/time_col.py +5 -4
  35. kumoai/experimental/rfm/pquery/executor.py +27 -27
  36. kumoai/experimental/rfm/pquery/pandas_executor.py +30 -32
  37. kumoai/experimental/rfm/relbench.py +76 -0
  38. kumoai/experimental/rfm/rfm.py +770 -467
  39. kumoai/experimental/rfm/sagemaker.py +4 -4
  40. kumoai/experimental/rfm/task_table.py +292 -0
  41. kumoai/kumolib.cp310-win_amd64.pyd +0 -0
  42. kumoai/pquery/predictive_query.py +10 -6
  43. kumoai/pquery/training_table.py +16 -2
  44. kumoai/testing/snow.py +50 -0
  45. kumoai/trainer/distilled_trainer.py +175 -0
  46. kumoai/utils/__init__.py +3 -2
  47. kumoai/utils/display.py +87 -0
  48. kumoai/utils/progress_logger.py +192 -13
  49. kumoai/utils/sql.py +3 -0
  50. {kumoai-2.13.0.dev202512040252.dist-info → kumoai-2.15.0.dev202601141731.dist-info}/METADATA +3 -2
  51. {kumoai-2.13.0.dev202512040252.dist-info → kumoai-2.15.0.dev202601141731.dist-info}/RECORD +54 -41
  52. kumoai/experimental/rfm/local_graph_sampler.py +0 -223
  53. kumoai/experimental/rfm/local_pquery_driver.py +0 -689
  54. {kumoai-2.13.0.dev202512040252.dist-info → kumoai-2.15.0.dev202601141731.dist-info}/WHEEL +0 -0
  55. {kumoai-2.13.0.dev202512040252.dist-info → kumoai-2.15.0.dev202601141731.dist-info}/licenses/LICENSE +0 -0
  56. {kumoai-2.13.0.dev202512040252.dist-info → kumoai-2.15.0.dev202601141731.dist-info}/top_level.txt +0 -0
kumoai/experimental/rfm/backend/snow/__init__.py
@@ -27,9 +27,11 @@ def connect(**kwargs: Any) -> Connection:
 
 
 from .table import SnowTable  # noqa: E402
+from .sampler import SnowSampler  # noqa: E402
 
 __all__ = [
     'connect',
     'Connection',
     'SnowTable',
+    'SnowSampler',
 ]
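
The `__init__.py` hunk above exports the new `SnowSampler` alongside `SnowTable`. A minimal usage sketch, assuming a `kumoai.experimental.rfm.Graph` whose tables are Snowflake-backed `SnowTable` objects (the `graph` variable below is hypothetical and not part of this diff):

    from kumoai.experimental.rfm.backend.snow import SnowSampler

    # `graph` is assumed to be an rfm Graph built on SnowTable-backed tables.
    sampler = SnowSampler(graph=graph, verbose=True)
    print(sampler.num_rows_dict)  # per-table row counts collected at construction
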
kumoai/experimental/rfm/backend/snow/sampler.py
@@ -0,0 +1,407 @@
+import json
+import math
+from collections.abc import Iterator
+from contextlib import contextmanager
+from typing import TYPE_CHECKING, cast
+
+import numpy as np
+import pandas as pd
+import pyarrow as pa
+from kumoapi.pquery import ValidatedPredictiveQuery
+
+from kumoai.experimental.rfm.backend.snow import Connection, SnowTable
+from kumoai.experimental.rfm.base import SQLSampler, Table
+from kumoai.experimental.rfm.pquery import PQueryPandasExecutor
+from kumoai.utils import ProgressLogger, quote_ident
+
+if TYPE_CHECKING:
+    from kumoai.experimental.rfm import Graph
+
+
+@contextmanager
+def paramstyle(connection: Connection, style: str = 'qmark') -> Iterator[None]:
+    _style = connection._paramstyle
+    connection._paramstyle = style
+    yield
+    connection._paramstyle = _style
+
+
+class SnowSampler(SQLSampler):
+    def __init__(
+        self,
+        graph: 'Graph',
+        verbose: bool | ProgressLogger = True,
+    ) -> None:
+        super().__init__(graph=graph, verbose=verbose)
+
+        for table in graph.tables.values():
+            assert isinstance(table, SnowTable)
+            self._connection = table._connection
+
+        self._num_rows_dict: dict[str, int] = {
+            table.name: cast(int, table._num_rows)
+            for table in graph.tables.values()
+        }
+
+    @property
+    def num_rows_dict(self) -> dict[str, int]:
+        return self._num_rows_dict
+
+    def _get_min_max_time_dict(
+        self,
+        table_names: list[str],
+    ) -> dict[str, tuple[pd.Timestamp, pd.Timestamp]]:
+        selects: list[str] = []
+        for table_name in table_names:
+            column = self.time_column_dict[table_name]
+            column_ref = self.table_column_ref_dict[table_name][column]
+            ident = quote_ident(table_name, char="'")
+            select = (f"SELECT\n"
+                      f" {ident} as table_name,\n"
+                      f" MIN({column_ref}) as min_date,\n"
+                      f" MAX({column_ref}) as max_date\n"
+                      f"FROM {self.source_name_dict[table_name]}")
+            selects.append(select)
+        sql = "\nUNION ALL\n".join(selects)
+
+        out_dict: dict[str, tuple[pd.Timestamp, pd.Timestamp]] = {}
+        with self._connection.cursor() as cursor:
+            cursor.execute(sql)
+            for table_name, _min, _max in cursor.fetchall():
+                out_dict[table_name] = (
+                    pd.Timestamp.max if _min is None else pd.Timestamp(_min),
+                    pd.Timestamp.min if _max is None else pd.Timestamp(_max),
+                )
+
+        return out_dict
+
+    def _sample_entity_table(
+        self,
+        table_name: str,
+        columns: set[str],
+        num_rows: int,
+        random_seed: int | None = None,
+    ) -> pd.DataFrame:
+        # NOTE Snowflake does support `SEED` only as part of `SYSTEM` sampling.
+        num_rows = min(num_rows, 1_000_000)  # Snowflake's upper limit.
+
+        source_table = self.source_table_dict[table_name]
+        filters: list[str] = []
+
+        key = self.primary_key_dict[table_name]
+        if key not in source_table or source_table[key].is_nullable:
+            key_ref = self.table_column_ref_dict[table_name][key]
+            filters.append(f" {key_ref} IS NOT NULL")
+
+        column = self.time_column_dict.get(table_name)
+        if column is None:
+            pass
+        elif column not in source_table or source_table[column].is_nullable:
+            column_ref = self.table_column_ref_dict[table_name][column]
+            filters.append(f" {column_ref} IS NOT NULL")
+
+        projections = [
+            self.table_column_proj_dict[table_name][column]
+            for column in columns
+        ]
+        sql = (f"SELECT {', '.join(projections)}\n"
+               f"FROM {self.source_name_dict[table_name]}\n"
+               f"SAMPLE ROW ({num_rows} ROWS)")
+        if len(filters) > 0:
+            sql += f"\nWHERE{' AND'.join(filters)}"
+
+        with self._connection.cursor() as cursor:
+            # NOTE This may return duplicate primary keys. This is okay.
+            cursor.execute(sql)
+            table = cursor.fetch_arrow_all()
+
+        return Table._sanitize(
+            df=table.to_pandas(types_mapper=pd.ArrowDtype),
+            dtype_dict=self.table_dtype_dict[table_name],
+            stype_dict=self.table_stype_dict[table_name],
+        )
+
+    def _sample_target(
+        self,
+        query: ValidatedPredictiveQuery,
+        entity_df: pd.DataFrame,
+        train_index: np.ndarray,
+        train_time: pd.Series,
+        num_train_examples: int,
+        test_index: np.ndarray,
+        test_time: pd.Series,
+        num_test_examples: int,
+        columns_dict: dict[str, set[str]],
+        time_offset_dict: dict[
+            tuple[str, str, str],
+            tuple[pd.DateOffset | None, pd.DateOffset],
+        ],
+    ) -> tuple[pd.Series, np.ndarray, pd.Series, np.ndarray]:
+
+        # NOTE For Snowflake, we execute everything at once to pay minimal
+        # query initialization costs.
+        index = np.concatenate([train_index, test_index])
+        time = pd.concat([train_time, test_time], axis=0, ignore_index=True)
+
+        entity_df = entity_df.iloc[index].reset_index(drop=True)
+
+        feat_dict: dict[str, pd.DataFrame] = {query.entity_table: entity_df}
+        time_dict: dict[str, pd.Series] = {}
+        time_column = self.time_column_dict.get(query.entity_table)
+        if time_column in columns_dict[query.entity_table]:
+            time_dict[query.entity_table] = entity_df[time_column]
+        batch_dict: dict[str, np.ndarray] = {
+            query.entity_table: np.arange(len(entity_df)),
+        }
+        for edge_type, (min_offset, max_offset) in time_offset_dict.items():
+            table_name, foreign_key, _ = edge_type
+            feat_dict[table_name], batch_dict[table_name] = self._by_time(
+                table_name=table_name,
+                foreign_key=foreign_key,
+                index=entity_df[self.primary_key_dict[query.entity_table]],
+                anchor_time=time,
+                min_offset=min_offset,
+                max_offset=max_offset,
+                columns=columns_dict[table_name],
+            )
+            time_column = self.time_column_dict.get(table_name)
+            if time_column in columns_dict[table_name]:
+                time_dict[table_name] = feat_dict[table_name][time_column]
+
+        y, mask = PQueryPandasExecutor().execute(
+            query=query,
+            feat_dict=feat_dict,
+            time_dict=time_dict,
+            batch_dict=batch_dict,
+            anchor_time=time,
+            num_forecasts=query.num_forecasts,
+        )
+
+        train_mask = mask[:len(train_index)]
+        test_mask = mask[len(train_index):]
+
+        boundary = int(train_mask.sum())
+        train_y = y.iloc[:boundary]
+        test_y = y.iloc[boundary:].reset_index(drop=True)
+
+        return train_y, train_mask, test_y, test_mask
+
+    def _by_pkey(
+        self,
+        table_name: str,
+        index: pd.Series,
+        columns: set[str],
+    ) -> tuple[pd.DataFrame, np.ndarray]:
+        key = self.primary_key_dict[table_name]
+        key_ref = self.table_column_ref_dict[table_name][key]
+        projections = [
+            self.table_column_proj_dict[table_name][column]
+            for column in columns
+        ]
+
+        payload = json.dumps(list(index))
+
+        sql = ("WITH TMP as (\n"
+               " SELECT\n"
+               " f.index as __KUMO_BATCH__,\n")
+        if self.table_dtype_dict[table_name][key].is_int():
+            sql += " f.value::NUMBER as __KUMO_ID__\n"
+        elif self.table_dtype_dict[table_name][key].is_float():
+            sql += " f.value::FLOAT as __KUMO_ID__\n"
+        else:
+            sql += " f.value::VARCHAR as __KUMO_ID__\n"
+        sql += (f" FROM TABLE(FLATTEN(INPUT => PARSE_JSON(?))) f\n"
+                f")\n"
+                f"SELECT "
+                f"TMP.__KUMO_BATCH__ as __KUMO_BATCH__, "
+                f"{', '.join(projections)}\n"
+                f"FROM TMP\n"
+                f"JOIN {self.source_name_dict[table_name]}\n"
+                f" ON {key_ref} = TMP.__KUMO_ID__")
+
+        with paramstyle(self._connection), self._connection.cursor() as cursor:
+            cursor.execute(sql, (payload, ))
+            table = cursor.fetch_arrow_all()
+
+        # Remove any duplicated primary keys in post-processing:
+        tmp = table.append_column('__KUMO_ID__', pa.array(range(len(table))))
+        gb = tmp.group_by('__KUMO_BATCH__').aggregate([('__KUMO_ID__', 'min')])
+        table = table.take(gb['__KUMO_ID___min'])
+
+        batch = table['__KUMO_BATCH__'].cast(pa.int64()).to_numpy()
+        batch_index = table.schema.get_field_index('__KUMO_BATCH__')
+        table = table.remove_column(batch_index)
+
+        return Table._sanitize(
+            df=table.to_pandas(),
+            dtype_dict=self.table_dtype_dict[table_name],
+            stype_dict=self.table_stype_dict[table_name],
+        ), batch
+
+    def _by_fkey(
+        self,
+        table_name: str,
+        foreign_key: str,
+        index: pd.Series,
+        num_neighbors: int,
+        anchor_time: pd.Series | None,
+        columns: set[str],
+    ) -> tuple[pd.DataFrame, np.ndarray]:
+        time_column = self.time_column_dict.get(table_name)
+
+        end_time: pd.Series | None = None
+        start_time: pd.Series | None = None
+        if time_column is not None and anchor_time is not None:
+            # In order to avoid a full table scan, we limit foreign key
+            # sampling to a certain time range, approximated by the number of
+            # rows, timestamp ranges and `num_neighbors` value.
+            # Downstream, this helps Snowflake to apply partition pruning:
+            dst_table_name = [
+                dst_table
+                for key, dst_table in self.foreign_key_dict[table_name]
+                if key == foreign_key
+            ][0]
+            num_facts = self.num_rows_dict[table_name]
+            num_entities = self.num_rows_dict[dst_table_name]
+            min_time = self.get_min_time([table_name])
+            max_time = self.get_max_time([table_name])
+            freq = num_facts / num_entities
+            freq = freq / max((max_time - min_time).total_seconds(), 1)
+            offset = pd.Timedelta(seconds=math.ceil(5 * num_neighbors / freq))
+
+            end_time = anchor_time.dt.strftime("%Y-%m-%d %H:%M:%S")
+            start_time = anchor_time - offset
+            start_time = start_time.dt.strftime("%Y-%m-%d %H:%M:%S")
+            payload = json.dumps(list(zip(index, end_time, start_time)))
+        else:
+            payload = json.dumps(list(zip(index)))
+
+        key_ref = self.table_column_ref_dict[table_name][foreign_key]
+        projections = [
+            self.table_column_proj_dict[table_name][column]
+            for column in columns
+        ]
+
+        sql = ("WITH TMP as (\n"
+               " SELECT\n"
+               " f.index as __KUMO_BATCH__,\n")
+        if self.table_dtype_dict[table_name][foreign_key].is_int():
+            sql += " f.value[0]::NUMBER as __KUMO_ID__"
+        elif self.table_dtype_dict[table_name][foreign_key].is_float():
+            sql += " f.value[0]::FLOAT as __KUMO_ID__"
+        else:
+            sql += " f.value[0]::VARCHAR as __KUMO_ID__"
+        if end_time is not None and start_time is not None:
+            sql += (",\n"
+                    " f.value[1]::TIMESTAMP_NTZ as __KUMO_END_TIME__,\n"
+                    " f.value[2]::TIMESTAMP_NTZ as __KUMO_START_TIME__")
+        sql += (f"\n"
+                f" FROM TABLE(FLATTEN(INPUT => PARSE_JSON(?))) f\n"
+                f")\n"
+                f"SELECT "
+                f"TMP.__KUMO_BATCH__ as __KUMO_BATCH__, "
+                f"{', '.join(projections)}\n"
+                f"FROM TMP\n"
+                f"JOIN {self.source_name_dict[table_name]}\n"
+                f" ON {key_ref} = TMP.__KUMO_ID__\n")
+        if end_time is not None and start_time is not None:
+            assert time_column is not None
+            time_ref = self.table_column_ref_dict[table_name][time_column]
+            sql += (f" AND {time_ref} <= TMP.__KUMO_END_TIME__\n"
+                    f" AND {time_ref} > TMP.__KUMO_START_TIME__\n"
+                    f"WHERE {time_ref} <= '{end_time.max()}'\n"
+                    f" AND {time_ref} > '{start_time.min()}'\n")
+        sql += ("QUALIFY ROW_NUMBER() OVER (\n"
+                " PARTITION BY TMP.__KUMO_BATCH__\n")
+        if time_column is not None:
+            sql += f" ORDER BY {time_ref} DESC\n"
+        else:
+            sql += f" ORDER BY {key_ref}\n"
+        sql += f") <= {num_neighbors}"
+
+        with paramstyle(self._connection), self._connection.cursor() as cursor:
+            cursor.execute(sql, (payload, ))
+            table = cursor.fetch_arrow_all()
+
+        batch = table['__KUMO_BATCH__'].cast(pa.int64()).to_numpy()
+        batch_index = table.schema.get_field_index('__KUMO_BATCH__')
+        table = table.remove_column(batch_index)
+
+        return Table._sanitize(
+            df=table.to_pandas(),
+            dtype_dict=self.table_dtype_dict[table_name],
+            stype_dict=self.table_stype_dict[table_name],
+        ), batch
+
+    # Helper Methods ##########################################################
+
+    def _by_time(
+        self,
+        table_name: str,
+        foreign_key: str,
+        index: pd.Series,
+        anchor_time: pd.Series,
+        min_offset: pd.DateOffset | None,
+        max_offset: pd.DateOffset,
+        columns: set[str],
+    ) -> tuple[pd.DataFrame, np.ndarray]:
+        time_column = self.time_column_dict[table_name]
+
+        end_time = anchor_time + max_offset
+        end_time = end_time.dt.strftime("%Y-%m-%d %H:%M:%S")
+        start_time: pd.Series | None = None
+        if min_offset is not None:
+            start_time = anchor_time + min_offset
+            start_time = start_time.dt.strftime("%Y-%m-%d %H:%M:%S")
+            payload = json.dumps(list(zip(index, end_time, start_time)))
+        else:
+            payload = json.dumps(list(zip(index, end_time)))
+
+        key_ref = self.table_column_ref_dict[table_name][foreign_key]
+        time_ref = self.table_column_ref_dict[table_name][time_column]
+        projections = [
+            self.table_column_proj_dict[table_name][column]
+            for column in columns
+        ]
+        sql = ("WITH TMP as (\n"
+               " SELECT\n"
+               " f.index as __KUMO_BATCH__,\n")
+        if self.table_dtype_dict[table_name][foreign_key].is_int():
+            sql += " f.value[0]::NUMBER as __KUMO_ID__,\n"
+        elif self.table_dtype_dict[table_name][foreign_key].is_float():
+            sql += " f.value[0]::FLOAT as __KUMO_ID__,\n"
+        else:
+            sql += " f.value[0]::VARCHAR as __KUMO_ID__,\n"
+        sql += " f.value[1]::TIMESTAMP_NTZ as __KUMO_END_TIME__"
+        if min_offset is not None:
+            sql += ",\n f.value[2]::TIMESTAMP_NTZ as __KUMO_START_TIME__"
+        sql += (f"\n"
+                f" FROM TABLE(FLATTEN(INPUT => PARSE_JSON(?))) f\n"
+                f")\n"
+                f"SELECT "
+                f"TMP.__KUMO_BATCH__ as __KUMO_BATCH__, "
+                f"{', '.join(projections)}\n"
+                f"FROM TMP\n"
+                f"JOIN {self.source_name_dict[table_name]}\n"
+                f" ON {key_ref} = TMP.__KUMO_ID__\n"
+                f" AND {time_ref} <= TMP.__KUMO_END_TIME__\n")
+        if start_time is not None:
+            sql += f"AND {time_ref} > TMP.__KUMO_START_TIME__\n"
+        # Add global time bounds to enable partition pruning:
+        sql += f"WHERE {time_ref} <= '{end_time.max()}'"
+        if start_time is not None:
+            sql += f"\nAND {time_ref} > '{start_time.min()}'"
+
+        with paramstyle(self._connection), self._connection.cursor() as cursor:
+            cursor.execute(sql, (payload, ))
+            table = cursor.fetch_arrow_all()
+
+        batch = table['__KUMO_BATCH__'].cast(pa.int64()).to_numpy()
+        batch_index = table.schema.get_field_index('__KUMO_BATCH__')
+        table = table.remove_column(batch_index)
+
+        return Table._sanitize(
+            df=table.to_pandas(types_mapper=pd.ArrowDtype),
+            dtype_dict=self.table_dtype_dict[table_name],
+            stype_dict=self.table_stype_dict[table_name],
+        ), batch
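
The lookback-window heuristic in `_by_fkey` above bounds the foreign-key join to a time range derived from table sizes and the observed time span, so Snowflake can prune partitions instead of scanning the full fact table. A back-of-envelope sketch with assumed numbers (10M fact rows, 1M entities, a one-year span, `num_neighbors=32`); the factor of 5 mirrors the safety margin used in the code:

    import math

    import pandas as pd

    num_facts, num_entities, num_neighbors = 10_000_000, 1_000_000, 32
    span_seconds = pd.Timedelta(days=365).total_seconds()

    # Average rows per entity per second, as computed in `_by_fkey`:
    freq = (num_facts / num_entities) / span_seconds
    # Lookback window subtracted from each anchor time (__KUMO_START_TIME__):
    offset = pd.Timedelta(seconds=math.ceil(5 * num_neighbors / freq))
    print(offset)  # 5840 days (~16 years) here; shrinks for high-frequency tables
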