kumoai 2.14.0.dev202512151351__cp313-cp313-macosx_11_0_arm64.whl → 2.15.0.dev202601121731__cp313-cp313-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (48)
  1. kumoai/__init__.py +23 -26
  2. kumoai/_version.py +1 -1
  3. kumoai/client/client.py +6 -0
  4. kumoai/client/jobs.py +26 -0
  5. kumoai/connector/utils.py +21 -7
  6. kumoai/experimental/rfm/__init__.py +51 -24
  7. kumoai/experimental/rfm/authenticate.py +3 -4
  8. kumoai/experimental/rfm/backend/local/graph_store.py +37 -46
  9. kumoai/experimental/rfm/backend/local/sampler.py +0 -3
  10. kumoai/experimental/rfm/backend/local/table.py +24 -30
  11. kumoai/experimental/rfm/backend/snow/sampler.py +197 -90
  12. kumoai/experimental/rfm/backend/snow/table.py +159 -52
  13. kumoai/experimental/rfm/backend/sqlite/__init__.py +2 -2
  14. kumoai/experimental/rfm/backend/sqlite/sampler.py +199 -99
  15. kumoai/experimental/rfm/backend/sqlite/table.py +103 -45
  16. kumoai/experimental/rfm/base/__init__.py +6 -1
  17. kumoai/experimental/rfm/base/column.py +96 -10
  18. kumoai/experimental/rfm/base/expression.py +44 -0
  19. kumoai/experimental/rfm/base/mapper.py +69 -0
  20. kumoai/experimental/rfm/base/sampler.py +28 -18
  21. kumoai/experimental/rfm/base/source.py +1 -1
  22. kumoai/experimental/rfm/base/sql_sampler.py +342 -13
  23. kumoai/experimental/rfm/base/table.py +374 -208
  24. kumoai/experimental/rfm/base/utils.py +27 -0
  25. kumoai/experimental/rfm/graph.py +335 -180
  26. kumoai/experimental/rfm/infer/__init__.py +6 -4
  27. kumoai/experimental/rfm/infer/dtype.py +7 -4
  28. kumoai/experimental/rfm/infer/multicategorical.py +1 -1
  29. kumoai/experimental/rfm/infer/pkey.py +4 -2
  30. kumoai/experimental/rfm/infer/stype.py +35 -0
  31. kumoai/experimental/rfm/infer/time_col.py +5 -4
  32. kumoai/experimental/rfm/pquery/executor.py +27 -27
  33. kumoai/experimental/rfm/pquery/pandas_executor.py +29 -31
  34. kumoai/experimental/rfm/relbench.py +76 -0
  35. kumoai/experimental/rfm/rfm.py +600 -360
  36. kumoai/experimental/rfm/sagemaker.py +4 -4
  37. kumoai/experimental/rfm/task_table.py +292 -0
  38. kumoai/pquery/training_table.py +16 -2
  39. kumoai/testing/snow.py +3 -3
  40. kumoai/trainer/distilled_trainer.py +175 -0
  41. kumoai/utils/__init__.py +1 -2
  42. kumoai/utils/display.py +87 -0
  43. kumoai/utils/progress_logger.py +190 -12
  44. {kumoai-2.14.0.dev202512151351.dist-info → kumoai-2.15.0.dev202601121731.dist-info}/METADATA +3 -2
  45. {kumoai-2.14.0.dev202512151351.dist-info → kumoai-2.15.0.dev202601121731.dist-info}/RECORD +48 -40
  46. {kumoai-2.14.0.dev202512151351.dist-info → kumoai-2.15.0.dev202601121731.dist-info}/WHEEL +0 -0
  47. {kumoai-2.14.0.dev202512151351.dist-info → kumoai-2.15.0.dev202601121731.dist-info}/licenses/LICENSE +0 -0
  48. {kumoai-2.14.0.dev202512151351.dist-info → kumoai-2.15.0.dev202601121731.dist-info}/top_level.txt +0 -0
kumoai/experimental/rfm/rfm.py
@@ -1,23 +1,13 @@
 import json
+import math
 import time
 import warnings
 from collections import defaultdict
-from collections.abc import Generator
+from collections.abc import Generator, Iterator
 from contextlib import contextmanager
 from dataclasses import dataclass, replace
-from typing import (
-    Any,
-    Dict,
-    Iterator,
-    List,
-    Literal,
-    Optional,
-    Tuple,
-    Union,
-    overload,
-)
+from typing import Any, Literal, overload
 
-import numpy as np
 import pandas as pd
 from kumoapi.model_plan import RunMode
 from kumoapi.pquery import QueryType, ValidatedPredictiveQuery
@@ -37,30 +27,37 @@ from kumoapi.rfm (
 )
 from kumoapi.task import TaskType
 from kumoapi.typing import AggregationType, Stype
+from rich.console import Console
+from rich.markdown import Markdown
 
+from kumoai import in_notebook
 from kumoai.client.rfm import RFMAPI
 from kumoai.exceptions import HTTPException
-from kumoai.experimental.rfm import Graph
+from kumoai.experimental.rfm import Graph, TaskTable
 from kumoai.experimental.rfm.base import DataBackend, Sampler
 from kumoai.mixin import CastMixin
-from kumoai.utils import InteractiveProgressLogger, ProgressLogger
+from kumoai.utils import ProgressLogger, display
 
 _RANDOM_SEED = 42
 
 _MAX_PRED_SIZE: dict[TaskType, int] = defaultdict(lambda: 1_000)
 _MAX_PRED_SIZE[TaskType.TEMPORAL_LINK_PREDICTION] = 200
 
+_MAX_TEST_SIZE: dict[TaskType, int] = defaultdict(lambda: 2_000)
+_MAX_TEST_SIZE[TaskType.TEMPORAL_LINK_PREDICTION] = 400
+
 _MAX_CONTEXT_SIZE = {
     RunMode.DEBUG: 100,
     RunMode.FAST: 1_000,
     RunMode.NORMAL: 5_000,
     RunMode.BEST: 10_000,
 }
-_MAX_TEST_SIZE = {  # Share test set size across run modes for fair comparison:
-    RunMode.DEBUG: 100,
-    RunMode.FAST: 2_000,
-    RunMode.NORMAL: 2_000,
-    RunMode.BEST: 2_000,
+
+_DEFAULT_NUM_NEIGHBORS = {
+    RunMode.DEBUG: [16, 16, 4, 4, 1, 1],
+    RunMode.FAST: [32, 32, 8, 8, 4, 4],
+    RunMode.NORMAL: [64, 64, 8, 8, 4, 4],
+    RunMode.BEST: [64, 64, 8, 8, 4, 4],
 }
 
 _MAX_SIZE = 30 * 1024 * 1024
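Note on the constants above: the run-mode keyed `_MAX_TEST_SIZE` table is replaced by a task-type keyed one (the former `// 5` link-prediction special case in `_get_context` becomes the explicit `400` entry), and the new `_DEFAULT_NUM_NEIGHBORS` table centralizes the per-hop fan-outs that were previously hard-coded. A minimal sketch of the lookup pattern used by `predict_task`/`evaluate_task` later in this diff; the three loose variables are hypothetical stand-ins:

.. code-block:: python

    run_mode, num_hops, is_link_pred = RunMode.NORMAL, 2, False

    # Link prediction always falls back to the FAST fan-out:
    key = RunMode.FAST if is_link_pred else run_mode
    num_neighbors = _DEFAULT_NUM_NEIGHBORS[key][:num_hops]  # -> [64, 64]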
@@ -98,24 +95,36 @@ class Explanation:
     def __getitem__(self, index: Literal[1]) -> str:
         pass
 
-    def __getitem__(self, index: int) -> Union[pd.DataFrame, str]:
+    def __getitem__(self, index: int) -> pd.DataFrame | str:
         if index == 0:
             return self.prediction
         if index == 1:
             return self.summary
         raise IndexError("Index out of range")
 
-    def __iter__(self) -> Iterator[Union[pd.DataFrame, str]]:
+    def __iter__(self) -> Iterator[pd.DataFrame | str]:
         return iter((self.prediction, self.summary))
 
     def __repr__(self) -> str:
         return str((self.prediction, self.summary))
 
-    def _ipython_display_(self) -> None:
-        from IPython.display import Markdown, display
+    def __str__(self) -> str:
+        console = Console(soft_wrap=True)
+        with console.capture() as cap:
+            console.print(display.to_rich_table(self.prediction))
+            console.print(Markdown(self.summary))
+        return cap.get()[:-1]
+
+    def print(self) -> None:
+        r"""Prints the explanation."""
+        if in_notebook():
+            display.dataframe(self.prediction)
+            display.message(self.summary)
+        else:
+            print(self)
 
-        display(self.prediction)
-        display(Markdown(self.summary))
+    def _ipython_display_(self) -> None:
+        self.print()
 
 
 class KumoRFM:
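`Explanation` now renders through rich in both terminals and notebooks instead of importing from `IPython.display`: `__str__` captures rich console output, `print()` dispatches on `in_notebook()`, and `_ipython_display_` delegates to `print()`. Since the object still unpacks into a `(prediction, summary)` pair, a hedged usage sketch, assuming an existing `model` and `query`:

.. code-block:: python

    explanation = model.predict(query, indices=[42], explain=True)
    prediction, summary = explanation  # via Explanation.__iter__
    explanation.print()  # rich table plus Markdown summary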
@@ -162,7 +171,7 @@ class KumoRFM:
     def __init__(
         self,
         graph: Graph,
-        verbose: Union[bool, ProgressLogger] = True,
+        verbose: bool | ProgressLogger = True,
         optimize: bool = False,
     ) -> None:
         graph = graph.validate()
@@ -180,10 +189,10 @@
         else:
             raise NotImplementedError
 
-        self._client: Optional[RFMAPI] = None
+        self._client: RFMAPI | None = None
 
-        self._batch_size: Optional[int | Literal['max']] = None
-        self.num_retries: int = 0
+        self._batch_size: int | Literal['max'] | None = None
+        self._num_retries: int = 0
 
     @property
     def _api_client(self) -> RFMAPI:
@@ -197,10 +206,34 @@
     def __repr__(self) -> str:
         return f'{self.__class__.__name__}()'
 
+    @contextmanager
+    def retry(
+        self,
+        num_retries: int = 1,
+    ) -> Generator[None, None, None]:
+        """Context manager to retry failed queries due to unexpected server
+        issues.
+
+        .. code-block:: python
+
+            with model.retry(num_retries=1):
+                df = model.predict(query, indices=...)
+
+        Args:
+            num_retries: The maximum number of retries.
+        """
+        if num_retries < 0:
+            raise ValueError(f"'num_retries' must be greater than or equal to "
+                             f"zero (got {num_retries})")
+
+        self._num_retries = num_retries
+        yield
+        self._num_retries = 0
+
     @contextmanager
     def batch_mode(
         self,
-        batch_size: Union[int, Literal['max']] = 'max',
+        batch_size: int | Literal['max'] = 'max',
         num_retries: int = 1,
     ) -> Generator[None, None, None]:
         """Context manager to predict in batches.
@@ -220,31 +253,26 @@
             raise ValueError(f"'batch_size' must be greater than zero "
                              f"(got {batch_size})")
 
-        if num_retries < 0:
-            raise ValueError(f"'num_retries' must be greater than or equal to "
-                             f"zero (got {num_retries})")
-
         self._batch_size = batch_size
-        self.num_retries = num_retries
-        yield
+        with self.retry(self._num_retries or num_retries):
+            yield
         self._batch_size = None
-        self.num_retries = 0
 
     @overload
     def predict(
         self,
         query: str,
-        indices: Union[List[str], List[float], List[int], None] = None,
+        indices: list[str] | list[float] | list[int] | None = None,
         *,
         explain: Literal[False] = False,
-        anchor_time: Union[pd.Timestamp, Literal['entity'], None] = None,
-        context_anchor_time: Union[pd.Timestamp, None] = None,
-        run_mode: Union[RunMode, str] = RunMode.FAST,
-        num_neighbors: Optional[List[int]] = None,
+        anchor_time: pd.Timestamp | Literal['entity'] | None = None,
+        context_anchor_time: pd.Timestamp | None = None,
+        run_mode: RunMode | str = RunMode.FAST,
+        num_neighbors: list[int] | None = None,
         num_hops: int = 2,
         max_pq_iterations: int = 10,
-        random_seed: Optional[int] = _RANDOM_SEED,
-        verbose: Union[bool, ProgressLogger] = True,
+        random_seed: int | None = _RANDOM_SEED,
+        verbose: bool | ProgressLogger = True,
         use_prediction_time: bool = False,
     ) -> pd.DataFrame:
         pass
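Retry handling is factored out of `batch_mode` into the standalone `retry()` context manager, and `batch_mode` now composes with it: an already-active retry count (`self._num_retries or num_retries`) takes precedence over the argument. A hedged sketch of the combined usage, assuming an existing `model`, `query`, and `ids`:

.. code-block:: python

    with model.retry(num_retries=2):
        with model.batch_mode(batch_size=500):
            # Each batch request is retried up to two times on
            # unexpected HTTPException before the error is re-raised.
            df = model.predict(query, indices=ids)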
@@ -253,37 +281,56 @@
     def predict(
         self,
         query: str,
-        indices: Union[List[str], List[float], List[int], None] = None,
+        indices: list[str] | list[float] | list[int] | None = None,
         *,
-        explain: Union[Literal[True], ExplainConfig, Dict[str, Any]],
-        anchor_time: Union[pd.Timestamp, Literal['entity'], None] = None,
-        context_anchor_time: Union[pd.Timestamp, None] = None,
-        run_mode: Union[RunMode, str] = RunMode.FAST,
-        num_neighbors: Optional[List[int]] = None,
+        explain: Literal[True] | ExplainConfig | dict[str, Any],
+        anchor_time: pd.Timestamp | Literal['entity'] | None = None,
+        context_anchor_time: pd.Timestamp | None = None,
+        run_mode: RunMode | str = RunMode.FAST,
+        num_neighbors: list[int] | None = None,
         num_hops: int = 2,
         max_pq_iterations: int = 10,
-        random_seed: Optional[int] = _RANDOM_SEED,
-        verbose: Union[bool, ProgressLogger] = True,
+        random_seed: int | None = _RANDOM_SEED,
+        verbose: bool | ProgressLogger = True,
        use_prediction_time: bool = False,
     ) -> Explanation:
         pass
 
+    @overload
+    def predict(
+        self,
+        query: str,
+        indices: list[str] | list[float] | list[int] | None = None,
+        *,
+        explain: bool | ExplainConfig | dict[str, Any] = False,
+        anchor_time: pd.Timestamp | Literal['entity'] | None = None,
+        context_anchor_time: pd.Timestamp | None = None,
+        run_mode: RunMode | str = RunMode.FAST,
+        num_neighbors: list[int] | None = None,
+        num_hops: int = 2,
+        max_pq_iterations: int = 10,
+        random_seed: int | None = _RANDOM_SEED,
+        verbose: bool | ProgressLogger = True,
+        use_prediction_time: bool = False,
+    ) -> pd.DataFrame | Explanation:
+        pass
+
     def predict(
         self,
         query: str,
-        indices: Union[List[str], List[float], List[int], None] = None,
+        indices: list[str] | list[float] | list[int] | None = None,
         *,
-        explain: Union[bool, ExplainConfig, Dict[str, Any]] = False,
-        anchor_time: Union[pd.Timestamp, Literal['entity'], None] = None,
-        context_anchor_time: Union[pd.Timestamp, None] = None,
-        run_mode: Union[RunMode, str] = RunMode.FAST,
-        num_neighbors: Optional[List[int]] = None,
+        explain: bool | ExplainConfig | dict[str, Any] = False,
+        anchor_time: pd.Timestamp | Literal['entity'] | None = None,
+        context_anchor_time: pd.Timestamp | None = None,
+        run_mode: RunMode | str = RunMode.FAST,
+        num_neighbors: list[int] | None = None,
         num_hops: int = 2,
         max_pq_iterations: int = 10,
-        random_seed: Optional[int] = _RANDOM_SEED,
-        verbose: Union[bool, ProgressLogger] = True,
+        random_seed: int | None = _RANDOM_SEED,
+        verbose: bool | ProgressLogger = True,
         use_prediction_time: bool = False,
-    ) -> Union[pd.DataFrame, Explanation]:
+    ) -> pd.DataFrame | Explanation:
         """Returns predictions for a predictive query.
 
         Args:
@@ -291,8 +338,7 @@
             indices: The entity primary keys to predict for. Will override the
                 indices given as part of the predictive query. Predictions will
                 be generated for all indices, independent of whether they
-                fulfill entity filter constraints. To pre-filter entities, use
-                :meth:`~KumoRFM.is_valid_entity`.
+                fulfill entity filter constraints.
             explain: Configuration for explainability.
                 If set to ``True``, will additionally explain the prediction.
                 Passing in an :class:`ExplainConfig` instance provides control
@@ -325,18 +371,152 @@
             If ``explain`` is provided, returns an :class:`Explanation` object
             containing the prediction, summary, and details.
         """
-        explain_config: Optional[ExplainConfig] = None
-        if explain is True:
-            explain_config = ExplainConfig()
-        elif explain is not False:
-            explain_config = ExplainConfig._cast(explain)
-
         query_def = self._parse_query(query)
-        query_str = query_def.to_string()
 
+        if indices is None:
+            if query_def.rfm_entity_ids is None:
+                raise ValueError("Cannot find entities to predict for. Please "
+                                 "pass them via `predict(query, indices=...)`")
+            indices = query_def.get_rfm_entity_id_list()
+        query_def = replace(
+            query_def,
+            for_each='FOR EACH',
+            rfm_entity_ids=None,
+        )
+
+        if not isinstance(verbose, ProgressLogger):
+            query_repr = query_def.to_string(rich=True, exclude_predict=True)
+            if explain is not False:
+                msg = f'[bold]EXPLAIN[/bold] {query_repr}'
+            else:
+                msg = f'[bold]PREDICT[/bold] {query_repr}'
+            verbose = ProgressLogger.default(msg=msg, verbose=verbose)
+
+        with verbose as logger:
+            task_table = self._get_task_table(
+                query=query_def,
+                indices=indices,
+                anchor_time=anchor_time,
+                context_anchor_time=context_anchor_time,
+                run_mode=run_mode,
+                max_pq_iterations=max_pq_iterations,
+                random_seed=random_seed,
+                logger=logger,
+            )
+            task_table._query = query_def.to_string()
+
+            return self.predict_task(
+                task_table,
+                explain=explain,
+                run_mode=run_mode,
+                num_neighbors=num_neighbors,
+                num_hops=num_hops,
+                verbose=verbose,
+                exclude_cols_dict=query_def.get_exclude_cols_dict(),
+                use_prediction_time=use_prediction_time,
+                top_k=query_def.top_k,
+            )
+
+    @overload
+    def predict_task(
+        self,
+        task: TaskTable,
+        *,
+        explain: Literal[False] = False,
+        run_mode: RunMode | str = RunMode.FAST,
+        num_neighbors: list[int] | None = None,
+        num_hops: int = 2,
+        verbose: bool | ProgressLogger = True,
+        exclude_cols_dict: dict[str, list[str]] | None = None,
+        use_prediction_time: bool = False,
+        top_k: int | None = None,
+    ) -> pd.DataFrame:
+        pass
+
+    @overload
+    def predict_task(
+        self,
+        task: TaskTable,
+        *,
+        explain: Literal[True] | ExplainConfig | dict[str, Any],
+        run_mode: RunMode | str = RunMode.FAST,
+        num_neighbors: list[int] | None = None,
+        num_hops: int = 2,
+        verbose: bool | ProgressLogger = True,
+        exclude_cols_dict: dict[str, list[str]] | None = None,
+        use_prediction_time: bool = False,
+        top_k: int | None = None,
+    ) -> Explanation:
+        pass
+
+    @overload
+    def predict_task(
+        self,
+        task: TaskTable,
+        *,
+        explain: bool | ExplainConfig | dict[str, Any] = False,
+        run_mode: RunMode | str = RunMode.FAST,
+        num_neighbors: list[int] | None = None,
+        num_hops: int = 2,
+        verbose: bool | ProgressLogger = True,
+        exclude_cols_dict: dict[str, list[str]] | None = None,
+        use_prediction_time: bool = False,
+        top_k: int | None = None,
+    ) -> pd.DataFrame | Explanation:
+        pass
+
+    def predict_task(
+        self,
+        task: TaskTable,
+        *,
+        explain: bool | ExplainConfig | dict[str, Any] = False,
+        run_mode: RunMode | str = RunMode.FAST,
+        num_neighbors: list[int] | None = None,
+        num_hops: int = 2,
+        verbose: bool | ProgressLogger = True,
+        exclude_cols_dict: dict[str, list[str]] | None = None,
+        use_prediction_time: bool = False,
+        top_k: int | None = None,
+    ) -> pd.DataFrame | Explanation:
+        """Returns predictions for a custom task specification.
+
+        Args:
+            task: The custom :class:`TaskTable`.
+            explain: Configuration for explainability.
+                If set to ``True``, will additionally explain the prediction.
+                Passing in an :class:`ExplainConfig` instance provides control
+                over which parts of explanation are generated.
+                Explainability is currently only supported for single entity
+                predictions with ``run_mode="FAST"``.
+            run_mode: The :class:`RunMode` for the query.
+            num_neighbors: The number of neighbors to sample for each hop.
+                If specified, the ``num_hops`` option will be ignored.
+            num_hops: The number of hops to sample when generating the context.
+            verbose: Whether to print verbose output.
+            exclude_cols_dict: Any column in any table to exclude from the
+                model input.
+            use_prediction_time: Whether to use the anchor timestamp as an
+                additional feature during prediction. This is typically
+                beneficial for time series forecasting tasks.
+            top_k: The number of predictions to return per entity.
+
+        Returns:
+            The predictions as a :class:`pandas.DataFrame`.
+            If ``explain`` is provided, returns an :class:`Explanation` object
+            containing the prediction, summary, and details.
+        """
         if num_hops != 2 and num_neighbors is not None:
             warnings.warn(f"Received custom 'num_neighbors' option; ignoring "
                           f"custom 'num_hops={num_hops}' option")
+        if num_neighbors is None:
+            key = RunMode.FAST if task.task_type.is_link_pred else run_mode
+            num_neighbors = _DEFAULT_NUM_NEIGHBORS[key][:num_hops]
+
+        explain_config: ExplainConfig | None = None
+        if explain is True:
+            explain_config = ExplainConfig()
+        elif explain is not False:
+            explain_config = ExplainConfig._cast(explain)
 
         if explain_config is not None and run_mode in {
             RunMode.NORMAL, RunMode.BEST
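The string-query path of `predict` now reduces to building a `TaskTable` via `_get_task_table` and delegating to `predict_task`, so custom tasks can skip query parsing entirely. A minimal sketch, assuming `task` is a ready-made :class:`TaskTable`:

.. code-block:: python

    df = model.predict_task(
        task,
        run_mode=RunMode.FAST,
        num_hops=2,
    )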
@@ -345,83 +525,82 @@ class KumoRFM:
                           f"run mode 'FAST' (got '{run_mode}'). Provided run "
                           f"mode has been reset. Please lower the run mode to "
                           f"suppress this warning.")
+            run_mode = RunMode.FAST
 
-        if indices is None:
-            if query_def.rfm_entity_ids is None:
-                raise ValueError("Cannot find entities to predict for. Please "
-                                 "pass them via `predict(query, indices=...)`")
-            indices = query_def.get_rfm_entity_id_list()
-        else:
-            query_def = replace(query_def, rfm_entity_ids=None)
-
-        if len(indices) == 0:
-            raise ValueError("At least one entity is required")
-
-        if explain_config is not None and len(indices) > 1:
-            raise ValueError(
-                f"Cannot explain predictions for more than a single entity "
-                f"(got {len(indices)})")
-
-        query_repr = query_def.to_string(rich=True, exclude_predict=True)
-        if explain_config is not None:
-            msg = f'[bold]EXPLAIN[/bold] {query_repr}'
-        else:
-            msg = f'[bold]PREDICT[/bold] {query_repr}'
+        if explain_config is not None and task.num_prediction_examples > 1:
+            raise ValueError(f"Cannot explain predictions for more than a "
+                             f"single entity "
+                             f"(got {task.num_prediction_examples:,})")
 
         if not isinstance(verbose, ProgressLogger):
-            verbose = InteractiveProgressLogger(msg, verbose=verbose)
-
-        with verbose as logger:
-
-            batch_size: Optional[int] = None
-            if self._batch_size == 'max':
-                task_type = self._get_task_type(
-                    query=query_def,
-                    edge_types=self._sampler.edge_types,
-                )
-                batch_size = _MAX_PRED_SIZE[task_type]
+            if task.task_type == TaskType.BINARY_CLASSIFICATION:
+                task_type_repr = 'binary classification'
+            elif task.task_type == TaskType.MULTICLASS_CLASSIFICATION:
+                task_type_repr = 'multi-class classification'
+            elif task.task_type == TaskType.REGRESSION:
+                task_type_repr = 'regression'
+            elif task.task_type == TaskType.TEMPORAL_LINK_PREDICTION:
+                task_type_repr = 'link prediction'
             else:
-                batch_size = self._batch_size
+                task_type_repr = str(task.task_type)
 
-            if batch_size is not None:
-                offsets = range(0, len(indices), batch_size)
-                batches = [indices[step:step + batch_size] for step in offsets]
+            if explain_config is not None:
+                msg = f"Explaining {task_type_repr} task"
             else:
-                batches = [indices]
+                msg = f"Predicting {task_type_repr} task"
+            verbose = ProgressLogger.default(msg=msg, verbose=verbose)
 
-            if len(batches) > 1:
-                logger.log(f"Splitting {len(indices):,} entities into "
-                           f"{len(batches):,} batches of size {batch_size:,}")
+        with verbose as logger:
+            if task.num_context_examples > _MAX_CONTEXT_SIZE[run_mode]:
+                logger.log(f"Sub-sampled {_MAX_CONTEXT_SIZE[run_mode]:,} "
+                           f"out of {task.num_context_examples:,} in-context "
+                           f"examples")
+                task = task.narrow_context(0, _MAX_CONTEXT_SIZE[run_mode])
+
+            if self._batch_size is None:
+                batch_size = task.num_prediction_examples
+            elif self._batch_size == 'max':
+                batch_size = _MAX_PRED_SIZE[task.task_type]
+            else:
+                batch_size = self._batch_size
 
-            predictions: List[pd.DataFrame] = []
-            summary: Optional[str] = None
-            details: Optional[Explanation] = None
-            for i, batch in enumerate(batches):
-                # TODO Re-use the context for subsequent predictions.
+            if batch_size > _MAX_PRED_SIZE[task.task_type]:
+                raise ValueError(f"Cannot predict for more than "
+                                 f"{_MAX_PRED_SIZE[task.task_type]:,} "
+                                 f"entities at once (got {batch_size:,}). Use "
+                                 f"`KumoRFM.batch_mode` to process entities "
+                                 f"in batches with a sufficient batch size.")
+
+            if task.num_prediction_examples > batch_size:
+                num = math.ceil(task.num_prediction_examples / batch_size)
+                logger.log(f"Splitting {task.num_prediction_examples:,} "
+                           f"entities into {num:,} batches of size "
+                           f"{batch_size:,}")
+
+            predictions: list[pd.DataFrame] = []
+            summary: str | None = None
+            details: Explanation | None = None
+            for start in range(0, task.num_prediction_examples, batch_size):
                 context = self._get_context(
-                    query=query_def,
-                    indices=batch,
-                    anchor_time=anchor_time,
-                    context_anchor_time=context_anchor_time,
-                    run_mode=RunMode(run_mode),
+                    task=task.narrow_prediction(start, length=batch_size),
+                    run_mode=run_mode,
                     num_neighbors=num_neighbors,
-                    num_hops=num_hops,
-                    max_pq_iterations=max_pq_iterations,
-                    evaluate=False,
-                    random_seed=random_seed,
-                    logger=logger if i == 0 else None,
+                    exclude_cols_dict=exclude_cols_dict,
+                    top_k=top_k,
                 )
+                context.y_test = None
+
                 request = RFMPredictRequest(
                     context=context,
                     run_mode=RunMode(run_mode),
-                    query=query_str,
+                    query=task._query,
                     use_prediction_time=use_prediction_time,
                 )
                 with warnings.catch_warnings():
                     warnings.filterwarnings('ignore', message='gencode')
                     request_msg = request.to_protobuf()
                 _bytes = request_msg.SerializeToString()
-                if i == 0:
+                if start == 0:
                     logger.log(f"Generated context of size "
                                f"{len(_bytes) / (1024*1024):.2f}MB")
 
@@ -429,14 +608,11 @@
                     stats = Context.get_memory_stats(request_msg.context)
                     raise ValueError(_SIZE_LIMIT_MSG.format(stats=stats))
 
-                if (isinstance(verbose, InteractiveProgressLogger) and i == 0
-                        and len(batches) > 1):
-                    verbose.init_progress(
-                        total=len(batches),
-                        description='Predicting',
-                    )
+                if start == 0 and task.num_prediction_examples > batch_size:
+                    num = math.ceil(task.num_prediction_examples / batch_size)
+                    verbose.init_progress(total=num, description='Predicting')
 
-                for attempt in range(self.num_retries + 1):
+                for attempt in range(self._num_retries + 1):
                     try:
                         if explain_config is not None:
                             resp = self._api_client.explain(
@@ -452,7 +628,7 @@
                         # Cast 'ENTITY' to correct data type:
                         if 'ENTITY' in df:
                             table_dict = context.subgraph.table_dict
-                            table = table_dict[query_def.entity_table]
+                            table = table_dict[context.entity_table_names[0]]
                             ser = table.df[table.primary_key]
                             df['ENTITY'] = df['ENTITY'].astype(ser.dtype)
 
@@ -469,13 +645,12 @@
                         predictions.append(df)
 
-                        if (isinstance(verbose, InteractiveProgressLogger)
-                                and len(batches) > 1):
+                        if task.num_prediction_examples > batch_size:
                             verbose.step()
 
                         break
                     except HTTPException as e:
-                        if attempt == self.num_retries:
+                        if attempt == self._num_retries:
                             try:
                                 msg = json.loads(e.detail)['detail']
                             except Exception:
@@ -505,64 +680,19 @@
 
         return prediction
 
-    def is_valid_entity(
-        self,
-        query: str,
-        indices: Union[List[str], List[float], List[int], None] = None,
-        *,
-        anchor_time: Union[pd.Timestamp, Literal['entity'], None] = None,
-    ) -> np.ndarray:
-        r"""Returns a mask that denotes which entities are valid for the
-        given predictive query, *i.e.*, which entities fulfill (temporal)
-        entity filter constraints.
-
-        Args:
-            query: The predictive query.
-            indices: The entity primary keys to predict for. Will override the
-                indices given as part of the predictive query.
-            anchor_time: The anchor timestamp for the prediction. If set to
-                ``None``, will use the maximum timestamp in the data.
-                If set to ``"entity"``, will use the timestamp of the entity.
-        """
-        query_def = self._parse_query(query)
-
-        if indices is None:
-            if query_def.rfm_entity_ids is None:
-                raise ValueError("Cannot find entities to predict for. Please "
-                                 "pass them via "
-                                 "`is_valid_entity(query, indices=...)`")
-            indices = query_def.get_rfm_entity_id_list()
-
-        if len(indices) == 0:
-            raise ValueError("At least one entity is required")
-
-        if anchor_time is None:
-            anchor_time = self._get_default_anchor_time(query_def)
-
-        if isinstance(anchor_time, pd.Timestamp):
-            self._validate_time(query_def, anchor_time, None, False)
-        else:
-            assert anchor_time == 'entity'
-            if query_def.entity_table not in self._sampler.time_column_dict:
-                raise ValueError(f"Anchor time 'entity' requires the entity "
-                                 f"table '{query_def.entity_table}' "
-                                 f"to have a time column.")
-
-        raise NotImplementedError
-
     def evaluate(
         self,
         query: str,
         *,
-        metrics: Optional[List[str]] = None,
-        anchor_time: Union[pd.Timestamp, Literal['entity'], None] = None,
-        context_anchor_time: Union[pd.Timestamp, None] = None,
-        run_mode: Union[RunMode, str] = RunMode.FAST,
-        num_neighbors: Optional[List[int]] = None,
+        metrics: list[str] | None = None,
+        anchor_time: pd.Timestamp | Literal['entity'] | None = None,
+        context_anchor_time: pd.Timestamp | None = None,
+        run_mode: RunMode | str = RunMode.FAST,
+        num_neighbors: list[int] | None = None,
         num_hops: int = 2,
         max_pq_iterations: int = 10,
-        random_seed: Optional[int] = _RANDOM_SEED,
-        verbose: Union[bool, ProgressLogger] = True,
+        random_seed: int | None = _RANDOM_SEED,
+        verbose: bool | ProgressLogger = True,
         use_prediction_time: bool = False,
     ) -> pd.DataFrame:
         """Evaluates a predictive query.
@@ -594,41 +724,120 @@
         Returns:
             The metrics as a :class:`pandas.DataFrame`
         """
-        query_def = self._parse_query(query)
+        query_def = replace(
+            self._parse_query(query),
+            for_each='FOR EACH',
+            rfm_entity_ids=None,
+        )
+
+        if not isinstance(verbose, ProgressLogger):
+            query_repr = query_def.to_string(rich=True, exclude_predict=True)
+            msg = f'[bold]EVALUATE[/bold] {query_repr}'
+            verbose = ProgressLogger.default(msg=msg, verbose=verbose)
+
+        with verbose as logger:
+            task_table = self._get_task_table(
+                query=query_def,
+                indices=None,
+                anchor_time=anchor_time,
+                context_anchor_time=context_anchor_time,
+                run_mode=run_mode,
+                max_pq_iterations=max_pq_iterations,
+                random_seed=random_seed,
+                logger=logger,
+            )
+
+            return self.evaluate_task(
+                task_table,
+                metrics=metrics,
+                run_mode=run_mode,
+                num_neighbors=num_neighbors,
+                num_hops=num_hops,
+                verbose=verbose,
+                exclude_cols_dict=query_def.get_exclude_cols_dict(),
+                use_prediction_time=use_prediction_time,
+            )
+
+    def evaluate_task(
+        self,
+        task: TaskTable,
+        *,
+        metrics: list[str] | None = None,
+        run_mode: RunMode | str = RunMode.FAST,
+        num_neighbors: list[int] | None = None,
+        num_hops: int = 2,
+        verbose: bool | ProgressLogger = True,
+        exclude_cols_dict: dict[str, list[str]] | None = None,
+        use_prediction_time: bool = False,
+    ) -> pd.DataFrame:
+        """Evaluates a custom task specification.
+
+        Args:
+            task: The custom :class:`TaskTable`.
+            metrics: The metrics to use.
+            run_mode: The :class:`RunMode` for the query.
+            num_neighbors: The number of neighbors to sample for each hop.
+                If specified, the ``num_hops`` option will be ignored.
+            num_hops: The number of hops to sample when generating the context.
+            verbose: Whether to print verbose output.
+            exclude_cols_dict: Any column in any table to exclude from the
+                model input.
+            use_prediction_time: Whether to use the anchor timestamp as an
+                additional feature during prediction. This is typically
+                beneficial for time series forecasting tasks.
 
+        Returns:
+            The metrics as a :class:`pandas.DataFrame`
+        """
         if num_hops != 2 and num_neighbors is not None:
             warnings.warn(f"Received custom 'num_neighbors' option; ignoring "
                           f"custom 'num_hops={num_hops}' option")
+        if num_neighbors is None:
+            key = RunMode.FAST if task.task_type.is_link_pred else run_mode
+            num_neighbors = _DEFAULT_NUM_NEIGHBORS[key][:num_hops]
 
-        if query_def.rfm_entity_ids is not None:
-            query_def = replace(
-                query_def,
-                rfm_entity_ids=None,
-            )
-
-        query_repr = query_def.to_string(rich=True, exclude_predict=True)
-        msg = f'[bold]EVALUATE[/bold] {query_repr}'
+        if metrics is not None and len(metrics) > 0:
+            self._validate_metrics(metrics, task.task_type)
+            metrics = list(dict.fromkeys(metrics))
 
         if not isinstance(verbose, ProgressLogger):
-            verbose = InteractiveProgressLogger(msg, verbose=verbose)
+            if task.task_type == TaskType.BINARY_CLASSIFICATION:
+                task_type_repr = 'binary classification'
+            elif task.task_type == TaskType.MULTICLASS_CLASSIFICATION:
+                task_type_repr = 'multi-class classification'
+            elif task.task_type == TaskType.REGRESSION:
+                task_type_repr = 'regression'
+            elif task.task_type == TaskType.TEMPORAL_LINK_PREDICTION:
+                task_type_repr = 'link prediction'
+            else:
+                task_type_repr = str(task.task_type)
+
+            msg = f"Evaluating {task_type_repr} task"
+            verbose = ProgressLogger.default(msg=msg, verbose=verbose)
 
         with verbose as logger:
+            if task.num_context_examples > _MAX_CONTEXT_SIZE[run_mode]:
+                logger.log(f"Sub-sampled {_MAX_CONTEXT_SIZE[run_mode]:,} "
+                           f"out of {task.num_context_examples:,} in-context "
+                           f"examples")
+                task = task.narrow_context(0, _MAX_CONTEXT_SIZE[run_mode])
+
+            if task.num_prediction_examples > _MAX_TEST_SIZE[task.task_type]:
+                logger.log(f"Sub-sampled {_MAX_TEST_SIZE[task.task_type]:,} "
+                           f"out of {task.num_prediction_examples:,} test "
+                           f"examples")
+                task = task.narrow_prediction(
+                    start=0,
+                    length=_MAX_TEST_SIZE[task.task_type],
+                )
+
             context = self._get_context(
-                query=query_def,
-                indices=None,
-                anchor_time=anchor_time,
-                context_anchor_time=context_anchor_time,
-                run_mode=RunMode(run_mode),
+                task=task,
+                run_mode=run_mode,
                 num_neighbors=num_neighbors,
-                num_hops=num_hops,
-                max_pq_iterations=max_pq_iterations,
-                evaluate=True,
-                random_seed=random_seed,
-                logger=logger if verbose else None,
+                exclude_cols_dict=exclude_cols_dict,
             )
-            if metrics is not None and len(metrics) > 0:
-                self._validate_metrics(metrics, context.task_type)
-                metrics = list(dict.fromkeys(metrics))
+
             request = RFMEvaluateRequest(
                 context=context,
                 run_mode=RunMode(run_mode),
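`evaluate` mirrors the `predict` split: parse the query, build the `TaskTable`, then delegate. Calling the new `evaluate_task` directly looks like the following hedged sketch; the metric names are hypothetical and `task` must carry labeled prediction examples (i.e. `task.evaluate` is true):

.. code-block:: python

    metrics_df = model.evaluate_task(
        task,
        metrics=['auroc', 'auprc'],  # hypothetical metric names
        run_mode=RunMode.FAST,
    )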
@@ -646,17 +855,23 @@
                 stats_msg = Context.get_memory_stats(request_msg.context)
                 raise ValueError(_SIZE_LIMIT_MSG.format(stats=stats_msg))
 
-            try:
-                resp = self._api_client.evaluate(request_bytes)
-            except HTTPException as e:
+            for attempt in range(self._num_retries + 1):
                 try:
-                    msg = json.loads(e.detail)['detail']
-                except Exception:
-                    msg = e.detail
-                raise RuntimeError(f"An unexpected exception occurred. "
-                                   f"Please create an issue at "
-                                   f"'https://github.com/kumo-ai/kumo-rfm'. "
-                                   f"{msg}") from None
+                    resp = self._api_client.evaluate(request_bytes)
+                    break
+                except HTTPException as e:
+                    if attempt == self._num_retries:
+                        try:
+                            msg = json.loads(e.detail)['detail']
+                        except Exception:
+                            msg = e.detail
+                        raise RuntimeError(
+                            f"An unexpected exception occurred. Please create "
+                            f"an issue at "
+                            f"'https://github.com/kumo-ai/kumo-rfm'. {msg}"
+                        ) from None
+
+                    time.sleep(2**attempt)  # 1s, 2s, 4s, 8s, ...
 
             return pd.DataFrame.from_dict(
                 resp.metrics,
@@ -669,8 +884,8 @@
         query: str,
         size: int,
         *,
-        anchor_time: Union[pd.Timestamp, Literal['entity'], None] = None,
-        random_seed: Optional[int] = _RANDOM_SEED,
+        anchor_time: pd.Timestamp | Literal['entity'] | None = None,
+        random_seed: int | None = _RANDOM_SEED,
         max_iterations: int = 10,
     ) -> pd.DataFrame:
         """Returns the labels of a predictive query for a specified anchor
@@ -708,7 +923,7 @@
                              f"to have a time column")
 
         train, test = self._sampler.sample_target(
-            query=query,
+            query=query_def,
             num_train_examples=0,
             train_anchor_time=anchor_time,
            num_train_trials=0,
@@ -736,35 +951,39 @@
                 "`predict()` or `evaluate()` methods to perform "
                 "predictions or evaluations.")
 
-        try:
-            request = RFMParseQueryRequest(
-                query=query,
-                graph_definition=self._graph_def,
-            )
+        request = RFMParseQueryRequest(
+            query=query,
+            graph_definition=self._graph_def,
+        )
 
-            resp = self._api_client.parse_query(request)
+        for attempt in range(self._num_retries + 1):
+            try:
+                resp = self._api_client.parse_query(request)
+                break
+            except HTTPException as e:
+                if attempt == self._num_retries:
+                    try:
+                        msg = json.loads(e.detail)['detail']
+                    except Exception:
+                        msg = e.detail
+                    raise ValueError(f"Failed to parse query '{query}'. {msg}")
 
-            if len(resp.validation_response.warnings) > 0:
-                msg = '\n'.join([
-                    f'{i+1}. {warning.title}: {warning.message}' for i, warning
-                    in enumerate(resp.validation_response.warnings)
-                ])
-                warnings.warn(f"Encountered the following warnings during "
-                              f"parsing:\n{msg}")
+                time.sleep(2**attempt)  # 1s, 2s, 4s, 8s, ...
 
-            return resp.query
-        except HTTPException as e:
-            try:
-                msg = json.loads(e.detail)['detail']
-            except Exception:
-                msg = e.detail
-            raise ValueError(f"Failed to parse query '{query}'. "
-                             f"{msg}") from None
+        if len(resp.validation_response.warnings) > 0:
+            msg = '\n'.join([
+                f'{i+1}. {warning.title}: {warning.message}'
+                for i, warning in enumerate(resp.validation_response.warnings)
+            ])
+            warnings.warn(f"Encountered the following warnings during "
+                          f"parsing:\n{msg}")
+
+        return resp.query
 
     @staticmethod
     def _get_task_type(
         query: ValidatedPredictiveQuery,
-        edge_types: List[Tuple[str, str, str]],
+        edge_types: list[tuple[str, str, str]],
     ) -> TaskType:
         if isinstance(query.target_ast, (Condition, LogicalOperation)):
             return TaskType.BINARY_CLASSIFICATION
@@ -803,31 +1022,38 @@
 
     def _get_default_anchor_time(
         self,
-        query: ValidatedPredictiveQuery,
+        query: ValidatedPredictiveQuery | None = None,
     ) -> pd.Timestamp:
-        if query.query_type == QueryType.TEMPORAL:
+        if query is not None and query.query_type == QueryType.TEMPORAL:
             aggr_table_names = [
                 aggr._get_target_column_name().split('.')[0]
                 for aggr in query.get_all_target_aggregations()
             ]
             return self._sampler.get_max_time(aggr_table_names)
 
-        assert query.query_type == QueryType.STATIC
         return self._sampler.get_max_time()
 
     def _validate_time(
         self,
         query: ValidatedPredictiveQuery,
         anchor_time: pd.Timestamp,
-        context_anchor_time: Union[pd.Timestamp, None],
+        context_anchor_time: pd.Timestamp | None,
         evaluate: bool,
     ) -> None:
 
         if len(self._sampler.time_column_dict) == 0:
             return  # Graph without timestamps
 
-        min_time = self._sampler.get_min_time()
-        max_time = self._sampler.get_max_time()
+        if query.query_type == QueryType.TEMPORAL:
+            aggr_table_names = [
+                aggr._get_target_column_name().split('.')[0]
+                for aggr in query.get_all_target_aggregations()
+            ]
+            min_time = self._sampler.get_min_time(aggr_table_names)
+            max_time = self._sampler.get_max_time(aggr_table_names)
+        else:
+            min_time = self._sampler.get_min_time()
+            max_time = self._sampler.get_max_time()
 
         if anchor_time < min_time:
             raise ValueError(f"Anchor timestamp '{anchor_time}' is before "
@@ -882,40 +1108,26 @@
                 f"Anchor timestamp for evaluation is after the latest "
                 f"supported timestamp '{max_time - end_offset}'.")
 
-    def _get_context(
+    def _get_task_table(
         self,
         query: ValidatedPredictiveQuery,
-        indices: Union[List[str], List[float], List[int], None],
-        anchor_time: Union[pd.Timestamp, Literal['entity'], None],
-        context_anchor_time: Union[pd.Timestamp, None],
-        run_mode: RunMode,
-        num_neighbors: Optional[List[int]],
-        num_hops: int,
-        max_pq_iterations: int,
-        evaluate: bool,
-        random_seed: Optional[int] = _RANDOM_SEED,
-        logger: Optional[ProgressLogger] = None,
-    ) -> Context:
-
-        if num_neighbors is not None:
-            num_hops = len(num_neighbors)
-
-        if num_hops < 0:
-            raise ValueError(f"'num_hops' must be non-negative "
-                             f"(got {num_hops})")
-        if num_hops > 6:
-            raise ValueError(f"Cannot predict on subgraphs with more than 6 "
-                             f"hops (got {num_hops}). Please reduce the "
-                             f"number of hops and try again. Please create a "
-                             f"feature request at "
-                             f"'https://github.com/kumo-ai/kumo-rfm' if you "
-                             f"must go beyond this for your use-case.")
+        indices: list[str] | list[float] | list[int] | None,
+        anchor_time: pd.Timestamp | Literal['entity'] | None = None,
+        context_anchor_time: pd.Timestamp | None = None,
+        run_mode: RunMode = RunMode.FAST,
+        max_pq_iterations: int = 10,
+        random_seed: int | None = _RANDOM_SEED,
+        logger: ProgressLogger | None = None,
+    ) -> TaskTable:
 
         task_type = self._get_task_type(
             query=query,
             edge_types=self._sampler.edge_types,
         )
 
+        num_train_examples = _MAX_CONTEXT_SIZE[run_mode]
+        num_test_examples = _MAX_TEST_SIZE[task_type] if indices is None else 0
+
         if logger is not None:
             if task_type == TaskType.BINARY_CLASSIFICATION:
                 task_type_repr = 'binary classification'
@@ -929,21 +1141,6 @@
                 task_type_repr = str(task_type)
             logger.log(f"Identified {query.query_type} {task_type_repr} task")
 
-        if task_type.is_link_pred and num_hops < 2:
-            raise ValueError(f"Cannot perform link prediction on subgraphs "
-                             f"with less than 2 hops (got {num_hops}) since "
-                             f"historical target entities need to be part of "
-                             f"the context. Please increase the number of "
-                             f"hops and try again.")
-
-        if num_neighbors is None:
-            if run_mode == RunMode.DEBUG:
-                num_neighbors = [16, 16, 4, 4, 1, 1][:num_hops]
-            elif run_mode == RunMode.FAST or task_type.is_link_pred:
-                num_neighbors = [32, 32, 8, 8, 4, 4][:num_hops]
-            else:
-                num_neighbors = [64, 64, 8, 8, 4, 4][:num_hops]
-
         if query.target_ast.date_offset_range is None:
             step_offset = pd.DateOffset(0)
         else:
@@ -952,8 +1149,7 @@
 
         if anchor_time is None:
             anchor_time = self._get_default_anchor_time(query)
-
-        if evaluate:
+        if num_test_examples > 0:
             anchor_time = anchor_time - end_offset
 
         if logger is not None:
@@ -967,7 +1163,6 @@
             else:
                 logger.log(f"Derived anchor time {anchor_time}")
 
-        assert anchor_time is not None
         if isinstance(anchor_time, pd.Timestamp):
             if context_anchor_time == 'entity':
                 raise ValueError("Anchor time 'entity' needs to be shared "
@@ -975,7 +1170,7 @@
             if context_anchor_time is None:
                 context_anchor_time = anchor_time - end_offset
             self._validate_time(query, anchor_time, context_anchor_time,
-                                evaluate)
+                                evaluate=num_test_examples > 0)
         else:
             assert anchor_time == 'entity'
             if query.query_type != QueryType.STATIC:
@@ -990,14 +1185,6 @@
                 "for context and prediction examples")
             context_anchor_time = 'entity'
 
-        num_train_examples = _MAX_CONTEXT_SIZE[run_mode]
-        if evaluate:
-            num_test_examples = _MAX_TEST_SIZE[run_mode]
-            if task_type.is_link_pred:
-                num_test_examples = num_test_examples // 5
-        else:
-            num_test_examples = 0
-
         train, test = self._sampler.sample_target(
             query=query,
             num_train_examples=num_train_examples,
@@ -1008,39 +1195,32 @@
             num_test_trials=max_pq_iterations * num_test_examples,
             random_seed=random_seed,
         )
-        train_pkey, train_time, y_train = train
-        test_pkey, test_time, y_test = test
+        train_pkey, train_time, train_y = train
+        test_pkey, test_time, test_y = test
 
-        if evaluate and logger is not None:
+        if num_test_examples > 0 and logger is not None:
             if task_type == TaskType.BINARY_CLASSIFICATION:
-                pos = 100 * int((y_test > 0).sum()) / len(y_test)
-                msg = (f"Collected {len(y_test):,} test examples with "
+                pos = 100 * int((test_y > 0).sum()) / len(test_y)
+                msg = (f"Collected {len(test_y):,} test examples with "
                        f"{pos:.2f}% positive cases")
             elif task_type == TaskType.MULTICLASS_CLASSIFICATION:
-                msg = (f"Collected {len(y_test):,} test examples holding "
-                       f"{y_test.nunique()} classes")
+                msg = (f"Collected {len(test_y):,} test examples holding "
+                       f"{test_y.nunique()} classes")
             elif task_type == TaskType.REGRESSION:
-                _min, _max = float(y_test.min()), float(y_test.max())
-                msg = (f"Collected {len(y_test):,} test examples with targets "
+                _min, _max = float(test_y.min()), float(test_y.max())
+                msg = (f"Collected {len(test_y):,} test examples with targets "
                        f"between {format_value(_min)} and "
                        f"{format_value(_max)}")
             elif task_type == TaskType.TEMPORAL_LINK_PREDICTION:
-                num_rhs = y_test.explode().nunique()
-                msg = (f"Collected {len(y_test):,} test examples with "
+                num_rhs = test_y.explode().nunique()
+                msg = (f"Collected {len(test_y):,} test examples with "
                        f"{num_rhs:,} unique items")
             else:
                 raise NotImplementedError
             logger.log(msg)
 
-        if not evaluate:
+        if num_test_examples == 0:
             assert indices is not None
-            if len(indices) > _MAX_PRED_SIZE[task_type]:
-                raise ValueError(f"Cannot predict for more than "
-                                 f"{_MAX_PRED_SIZE[task_type]:,} entities at "
-                                 f"once (got {len(indices):,}). Use "
-                                 f"`KumoRFM.batch_mode` to process entities "
-                                 f"in batches")
-
             test_pkey = pd.Series(indices, dtype=train_pkey.dtype)
             if isinstance(anchor_time, pd.Timestamp):
                 test_time = pd.Series([anchor_time]).repeat(
@@ -1050,26 +1230,26 @@
 
         if logger is not None:
             if task_type == TaskType.BINARY_CLASSIFICATION:
-                pos = 100 * int((y_train > 0).sum()) / len(y_train)
-                msg = (f"Collected {len(y_train):,} in-context examples with "
+                pos = 100 * int((train_y > 0).sum()) / len(train_y)
+                msg = (f"Collected {len(train_y):,} in-context examples with "
                        f"{pos:.2f}% positive cases")
             elif task_type == TaskType.MULTICLASS_CLASSIFICATION:
-                msg = (f"Collected {len(y_train):,} in-context examples "
-                       f"holding {y_train.nunique()} classes")
+                msg = (f"Collected {len(train_y):,} in-context examples "
+                       f"holding {train_y.nunique()} classes")
             elif task_type == TaskType.REGRESSION:
-                _min, _max = float(y_train.min()), float(y_train.max())
-                msg = (f"Collected {len(y_train):,} in-context examples with "
+                _min, _max = float(train_y.min()), float(train_y.max())
+                msg = (f"Collected {len(train_y):,} in-context examples with "
                        f"targets between {format_value(_min)} and "
                        f"{format_value(_max)}")
             elif task_type == TaskType.TEMPORAL_LINK_PREDICTION:
-                num_rhs = y_train.explode().nunique()
-                msg = (f"Collected {len(y_train):,} in-context examples with "
+                num_rhs = train_y.explode().nunique()
+                msg = (f"Collected {len(train_y):,} in-context examples with "
                        f"{num_rhs:,} unique items")
             else:
                 raise NotImplementedError
             logger.log(msg)
 
-        entity_table_names: Tuple[str, ...]
+        entity_table_names: tuple[str] | tuple[str, str]
         if task_type.is_link_pred:
             final_aggr = query.get_final_target_aggregation()
             assert final_aggr is not None
@@ -1083,27 +1263,80 @@
         else:
             entity_table_names = (query.entity_table, )
 
+        context_df = pd.DataFrame({'ENTITY': train_pkey, 'TARGET': train_y})
+        if isinstance(train_time, pd.Series):
+            context_df['ANCHOR_TIMESTAMP'] = train_time
+        pred_df = pd.DataFrame({'ENTITY': test_pkey})
+        if num_test_examples > 0:
+            pred_df['TARGET'] = test_y
+        if isinstance(test_time, pd.Series):
+            pred_df['ANCHOR_TIMESTAMP'] = test_time
+
+        return TaskTable(
+            task_type=task_type,
+            context_df=context_df,
+            pred_df=pred_df,
+            entity_table_name=entity_table_names,
+            entity_column='ENTITY',
+            target_column='TARGET',
+            time_column='ANCHOR_TIMESTAMP' if isinstance(
+                train_time, pd.Series) else TaskTable.ENTITY_TIME,
+        )
+
+    def _get_context(
+        self,
+        task: TaskTable,
+        run_mode: RunMode | str = RunMode.FAST,
+        num_neighbors: list[int] | None = None,
+        exclude_cols_dict: dict[str, list[str]] | None = None,
+        top_k: int | None = None,
+    ) -> Context:
+
+        if num_neighbors is None:
+            key = RunMode.FAST if task.task_type.is_link_pred else run_mode
+            num_neighbors = _DEFAULT_NUM_NEIGHBORS[key][:2]
+
+        if len(num_neighbors) > 6:
+            raise ValueError(f"Cannot predict on subgraphs with more than 6 "
+                             f"hops (got {len(num_neighbors)}). Reduce the "
+                             f"number of hops and try again. Please create a "
+                             f"feature request at "
+                             f"'https://github.com/kumo-ai/kumo-rfm' if you "
+                             f"must go beyond this for your use-case.")
+
         # Exclude the entity anchor time from the feature set to prevent
         # running out-of-distribution between in-context and test examples:
-        exclude_cols_dict = query.get_exclude_cols_dict()
-        if entity_table_names[0] in self._sampler.time_column_dict:
-            if entity_table_names[0] not in exclude_cols_dict:
-                exclude_cols_dict[entity_table_names[0]] = []
-            time_column = self._sampler.time_column_dict[entity_table_names[0]]
-            exclude_cols_dict[entity_table_names[0]].append(time_column)
+        exclude_cols_dict = exclude_cols_dict or {}
+        if task.entity_table_name in self._sampler.time_column_dict:
+            if task.entity_table_name not in exclude_cols_dict:
+                exclude_cols_dict[task.entity_table_name] = []
+            time_col = self._sampler.time_column_dict[task.entity_table_name]
+            exclude_cols_dict[task.entity_table_name].append(time_col)
+
+        entity_pkey = pd.concat([
+            task._context_df[task._entity_column],
+            task._pred_df[task._entity_column],
+        ], axis=0, ignore_index=True)
+
+        if task.use_entity_time:
+            if task.entity_table_name not in self._sampler.time_column_dict:
+                raise ValueError(f"The given anchor time requires the entity "
+                                 f"table '{task.entity_table_name}' to have a "
+                                 f"time column")
+            anchor_time = 'entity'
+        elif task._time_column is not None:
+            anchor_time = pd.concat([
+                task._context_df[task._time_column],
+                task._pred_df[task._time_column],
+            ], axis=0, ignore_index=True)
+        else:
+            anchor_time = pd.Series(self._get_default_anchor_time()).repeat(
+                len(entity_pkey)).reset_index(drop=True)
 
         subgraph = self._sampler.sample_subgraph(
-            entity_table_names=entity_table_names,
-            entity_pkey=pd.concat(
-                [train_pkey, test_pkey],
-                axis=0,
-                ignore_index=True,
-            ),
-            anchor_time=pd.concat(
-                [train_time, test_time],
-                axis=0,
-                ignore_index=True,
-            ) if isinstance(train_time, pd.Series) else 'entity',
+            entity_table_names=task.entity_table_names,
+            entity_pkey=entity_pkey,
+            anchor_time=anchor_time,
             num_neighbors=num_neighbors,
             exclude_cols_dict=exclude_cols_dict,
         )
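The `TaskTable` constructor call above documents the frame layout `_get_task_table` produces: an in-context frame with entity, target, and optional anchor-time columns, plus a prediction frame with at least the entity column (targets only when evaluating). A hedged reconstruction with toy data; the table name `'users'` and all values are assumptions, and `kumoai/experimental/rfm/task_table.py` (new in this release) holds the authoritative signature:

.. code-block:: python

    import pandas as pd

    # Hypothetical in-context (labeled) examples:
    context_df = pd.DataFrame({
        'ENTITY': [1, 2, 3],
        'TARGET': [0, 1, 0],
        'ANCHOR_TIMESTAMP': pd.to_datetime(['2024-01-01'] * 3),
    })
    # Hypothetical prediction examples (no targets):
    pred_df = pd.DataFrame({
        'ENTITY': [4, 5],
        'ANCHOR_TIMESTAMP': pd.to_datetime(['2024-02-01'] * 2),
    })
    task = TaskTable(
        task_type=TaskType.BINARY_CLASSIFICATION,
        context_df=context_df,
        pred_df=pred_df,
        entity_table_name=('users', ),
        entity_column='ENTITY',
        target_column='TARGET',
        time_column='ANCHOR_TIMESTAMP',
    )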
@@ -1115,19 +1348,26 @@
                              f"'https://github.com/kumo-ai/kumo-rfm' if you "
                              f"must go beyond this for your use-case.")
 
+        if (task.task_type.is_link_pred
+                and task.entity_table_names[-1] not in subgraph.table_dict):
+            raise ValueError("Cannot perform link prediction on subgraphs "
+                             "without any historical target entities. Please "
+                             "increase the number of hops and try again.")
+
         return Context(
-            task_type=task_type,
-            entity_table_names=entity_table_names,
+            task_type=task.task_type,
+            entity_table_names=task.entity_table_names,
             subgraph=subgraph,
-            y_train=y_train,
-            y_test=y_test if evaluate else None,
-            top_k=query.top_k,
+            y_train=task._context_df[task.target_column.name],
+            y_test=task._pred_df[task.target_column.name]
+            if task.evaluate else None,
+            top_k=top_k,
             step_size=None,
         )
 
     @staticmethod
     def _validate_metrics(
-        metrics: List[str],
+        metrics: list[str],
         task_type: TaskType,
     ) -> None:
 
@@ -1184,7 +1424,7 @@
             f"'https://github.com/kumo-ai/kumo-rfm'.")
 
 
-def format_value(value: Union[int, float]) -> str:
+def format_value(value: int | float) -> str:
     if value == int(value):
         return f'{int(value):,}'
     if abs(value) >= 1000: