kumoai 2.14.0.dev202512141732__py3-none-any.whl → 2.15.0.dev202601131732__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (49)
  1. kumoai/__init__.py +23 -26
  2. kumoai/_version.py +1 -1
  3. kumoai/client/client.py +6 -0
  4. kumoai/client/jobs.py +26 -0
  5. kumoai/connector/utils.py +21 -7
  6. kumoai/experimental/rfm/__init__.py +51 -24
  7. kumoai/experimental/rfm/authenticate.py +3 -4
  8. kumoai/experimental/rfm/backend/local/graph_store.py +37 -46
  9. kumoai/experimental/rfm/backend/local/sampler.py +4 -5
  10. kumoai/experimental/rfm/backend/local/table.py +24 -30
  11. kumoai/experimental/rfm/backend/snow/sampler.py +331 -43
  12. kumoai/experimental/rfm/backend/snow/table.py +166 -56
  13. kumoai/experimental/rfm/backend/sqlite/__init__.py +2 -2
  14. kumoai/experimental/rfm/backend/sqlite/sampler.py +372 -30
  15. kumoai/experimental/rfm/backend/sqlite/table.py +117 -48
  16. kumoai/experimental/rfm/base/__init__.py +8 -1
  17. kumoai/experimental/rfm/base/column.py +96 -10
  18. kumoai/experimental/rfm/base/expression.py +44 -0
  19. kumoai/experimental/rfm/base/mapper.py +69 -0
  20. kumoai/experimental/rfm/base/sampler.py +28 -18
  21. kumoai/experimental/rfm/base/source.py +1 -1
  22. kumoai/experimental/rfm/base/sql_sampler.py +385 -0
  23. kumoai/experimental/rfm/base/table.py +374 -208
  24. kumoai/experimental/rfm/base/utils.py +36 -0
  25. kumoai/experimental/rfm/graph.py +335 -180
  26. kumoai/experimental/rfm/infer/__init__.py +6 -4
  27. kumoai/experimental/rfm/infer/dtype.py +10 -5
  28. kumoai/experimental/rfm/infer/multicategorical.py +1 -1
  29. kumoai/experimental/rfm/infer/pkey.py +4 -2
  30. kumoai/experimental/rfm/infer/stype.py +35 -0
  31. kumoai/experimental/rfm/infer/time_col.py +5 -4
  32. kumoai/experimental/rfm/pquery/executor.py +27 -27
  33. kumoai/experimental/rfm/pquery/pandas_executor.py +29 -31
  34. kumoai/experimental/rfm/relbench.py +76 -0
  35. kumoai/experimental/rfm/rfm.py +606 -361
  36. kumoai/experimental/rfm/sagemaker.py +4 -4
  37. kumoai/experimental/rfm/task_table.py +292 -0
  38. kumoai/pquery/training_table.py +16 -2
  39. kumoai/testing/snow.py +3 -3
  40. kumoai/trainer/distilled_trainer.py +175 -0
  41. kumoai/utils/__init__.py +1 -2
  42. kumoai/utils/display.py +87 -0
  43. kumoai/utils/progress_logger.py +192 -13
  44. kumoai/utils/sql.py +2 -2
  45. {kumoai-2.14.0.dev202512141732.dist-info → kumoai-2.15.0.dev202601131732.dist-info}/METADATA +3 -2
  46. {kumoai-2.14.0.dev202512141732.dist-info → kumoai-2.15.0.dev202601131732.dist-info}/RECORD +49 -40
  47. {kumoai-2.14.0.dev202512141732.dist-info → kumoai-2.15.0.dev202601131732.dist-info}/WHEEL +0 -0
  48. {kumoai-2.14.0.dev202512141732.dist-info → kumoai-2.15.0.dev202601131732.dist-info}/licenses/LICENSE +0 -0
  49. {kumoai-2.14.0.dev202512141732.dist-info → kumoai-2.15.0.dev202601131732.dist-info}/top_level.txt +0 -0
@@ -1,23 +1,13 @@
 import json
+import math
 import time
 import warnings
 from collections import defaultdict
-from collections.abc import Generator
+from collections.abc import Generator, Iterator
 from contextlib import contextmanager
 from dataclasses import dataclass, replace
-from typing import (
-    Any,
-    Dict,
-    Iterator,
-    List,
-    Literal,
-    Optional,
-    Tuple,
-    Union,
-    overload,
-)
+from typing import Any, Literal, overload
 
-import numpy as np
 import pandas as pd
 from kumoapi.model_plan import RunMode
 from kumoapi.pquery import QueryType, ValidatedPredictiveQuery
@@ -37,30 +27,37 @@ from kumoapi.rfm import (
 )
 from kumoapi.task import TaskType
 from kumoapi.typing import AggregationType, Stype
+from rich.console import Console
+from rich.markdown import Markdown
 
+from kumoai import in_notebook
 from kumoai.client.rfm import RFMAPI
 from kumoai.exceptions import HTTPException
-from kumoai.experimental.rfm import Graph
+from kumoai.experimental.rfm import Graph, TaskTable
 from kumoai.experimental.rfm.base import DataBackend, Sampler
 from kumoai.mixin import CastMixin
-from kumoai.utils import InteractiveProgressLogger, ProgressLogger
+from kumoai.utils import ProgressLogger, display
 
 _RANDOM_SEED = 42
 
 _MAX_PRED_SIZE: dict[TaskType, int] = defaultdict(lambda: 1_000)
 _MAX_PRED_SIZE[TaskType.TEMPORAL_LINK_PREDICTION] = 200
 
+_MAX_TEST_SIZE: dict[TaskType, int] = defaultdict(lambda: 2_000)
+_MAX_TEST_SIZE[TaskType.TEMPORAL_LINK_PREDICTION] = 400
+
 _MAX_CONTEXT_SIZE = {
     RunMode.DEBUG: 100,
     RunMode.FAST: 1_000,
     RunMode.NORMAL: 5_000,
     RunMode.BEST: 10_000,
 }
-_MAX_TEST_SIZE = {  # Share test set size across run modes for fair comparison:
-    RunMode.DEBUG: 100,
-    RunMode.FAST: 2_000,
-    RunMode.NORMAL: 2_000,
-    RunMode.BEST: 2_000,
+
+_DEFAULT_NUM_NEIGHBORS = {
+    RunMode.DEBUG: [16, 16, 4, 4, 1, 1],
+    RunMode.FAST: [32, 32, 8, 8, 4, 4],
+    RunMode.NORMAL: [64, 64, 8, 8, 4, 4],
+    RunMode.BEST: [64, 64, 8, 8, 4, 4],
 }
 
 _MAX_SIZE = 30 * 1024 * 1024
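A hedged illustration of how these per-run-mode fan-outs are consumed later in this diff: callers slice the list to the requested depth via `[:num_hops]`, so earlier entries govern the shallower hops:

```python
# First two hops under RunMode.FAST sample 32 neighbors each:
assert _DEFAULT_NUM_NEIGHBORS[RunMode.FAST][:2] == [32, 32]

# Deeper hops taper off to keep sampled subgraphs small:
assert _DEFAULT_NUM_NEIGHBORS[RunMode.BEST][:6] == [64, 64, 8, 8, 4, 4]
```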
@@ -98,24 +95,36 @@ class Explanation:
     def __getitem__(self, index: Literal[1]) -> str:
         pass
 
-    def __getitem__(self, index: int) -> Union[pd.DataFrame, str]:
+    def __getitem__(self, index: int) -> pd.DataFrame | str:
         if index == 0:
             return self.prediction
         if index == 1:
             return self.summary
         raise IndexError("Index out of range")
 
-    def __iter__(self) -> Iterator[Union[pd.DataFrame, str]]:
+    def __iter__(self) -> Iterator[pd.DataFrame | str]:
         return iter((self.prediction, self.summary))
 
     def __repr__(self) -> str:
         return str((self.prediction, self.summary))
 
-    def _ipython_display_(self) -> None:
-        from IPython.display import Markdown, display
+    def __str__(self) -> str:
+        console = Console(soft_wrap=True)
+        with console.capture() as cap:
+            console.print(display.to_rich_table(self.prediction))
+            console.print(Markdown(self.summary))
+        return cap.get()[:-1]
+
+    def print(self) -> None:
+        r"""Prints the explanation."""
+        if in_notebook():
+            display.dataframe(self.prediction)
+            display.message(self.summary)
+        else:
+            print(self)
 
-        display(self.prediction)
-        display(Markdown(self.summary))
+    def _ipython_display_(self) -> None:
+        self.print()
 
 
 class KumoRFM:
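The reworked `Explanation` unpacks like a `(prediction, summary)` tuple and now renders itself outside notebooks as well; a minimal usage sketch, assuming `model` is a `KumoRFM` instance and `query` a valid predictive query:

```python
# Tuple-style access, backed by __iter__ above:
prediction, summary = model.predict(query, indices=[42], explain=True)
print(prediction)  # pandas.DataFrame with the prediction
print(summary)     # Markdown summary string

# Or render both parts at once: rich output in terminals via __str__,
# native display in notebooks via _ipython_display_:
model.predict(query, indices=[42], explain=True).print()
```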
@@ -154,11 +163,16 @@ class KumoRFM:
     Args:
         graph: The graph.
         verbose: Whether to print verbose output.
+        optimize: If set to ``True``, will optimize the underlying data backend
+            for optimal querying. For example, for transactional database
+            backends, will create any missing indices. Requires write-access to
+            the data backend.
     """
     def __init__(
         self,
         graph: Graph,
-        verbose: Union[bool, ProgressLogger] = True,
+        verbose: bool | ProgressLogger = True,
+        optimize: bool = False,
     ) -> None:
         graph = graph.validate()
         self._graph_def = graph._to_api_graph_definition()
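A minimal constructor sketch for the new `optimize` flag, assuming `KumoRFM` remains importable from `kumoai.experimental.rfm` and `graph` is a SQLite-backed `Graph` (the flag is forwarded to `SQLiteSampler` in the next hunk):

```python
from kumoai.experimental.rfm import KumoRFM

# optimize=True may create missing indices on transactional backends,
# so it requires write access to the underlying database:
model = KumoRFM(graph, optimize=True)
```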
@@ -168,17 +182,17 @@
             self._sampler: Sampler = LocalSampler(graph, verbose)
         elif graph.backend == DataBackend.SQLITE:
             from kumoai.experimental.rfm.backend.sqlite import SQLiteSampler
-            self._sampler = SQLiteSampler(graph, verbose)
+            self._sampler = SQLiteSampler(graph, verbose, optimize)
         elif graph.backend == DataBackend.SNOWFLAKE:
             from kumoai.experimental.rfm.backend.snow import SnowSampler
             self._sampler = SnowSampler(graph, verbose)
         else:
             raise NotImplementedError
 
-        self._client: Optional[RFMAPI] = None
+        self._client: RFMAPI | None = None
 
-        self._batch_size: Optional[int | Literal['max']] = None
-        self.num_retries: int = 0
+        self._batch_size: int | Literal['max'] | None = None
+        self._num_retries: int = 0
 
     @property
     def _api_client(self) -> RFMAPI:
@@ -192,10 +206,34 @@
     def __repr__(self) -> str:
         return f'{self.__class__.__name__}()'
 
+    @contextmanager
+    def retry(
+        self,
+        num_retries: int = 1,
+    ) -> Generator[None, None, None]:
+        """Context manager to retry failed queries due to unexpected server
+        issues.
+
+        .. code-block:: python
+
+            with model.retry(num_retries=1):
+                df = model.predict(query, indices=...)
+
+        Args:
+            num_retries: The maximum number of retries.
+        """
+        if num_retries < 0:
+            raise ValueError(f"'num_retries' must be greater than or equal to "
+                             f"zero (got {num_retries})")
+
+        self._num_retries = num_retries
+        yield
+        self._num_retries = 0
+
     @contextmanager
     def batch_mode(
         self,
-        batch_size: Union[int, Literal['max']] = 'max',
+        batch_size: int | Literal['max'] = 'max',
         num_retries: int = 1,
     ) -> Generator[None, None, None]:
         """Context manager to predict in batches.
@@ -215,31 +253,26 @@
             raise ValueError(f"'batch_size' must be greater than zero "
                              f"(got {batch_size})")
 
-        if num_retries < 0:
-            raise ValueError(f"'num_retries' must be greater than or equal to "
-                             f"zero (got {num_retries})")
-
         self._batch_size = batch_size
-        self.num_retries = num_retries
-        yield
+        with self.retry(self._num_retries or num_retries):
+            yield
         self._batch_size = None
-        self.num_retries = 0
 
     @overload
     def predict(
         self,
         query: str,
-        indices: Union[List[str], List[float], List[int], None] = None,
+        indices: list[str] | list[float] | list[int] | None = None,
         *,
         explain: Literal[False] = False,
-        anchor_time: Union[pd.Timestamp, Literal['entity'], None] = None,
-        context_anchor_time: Union[pd.Timestamp, None] = None,
-        run_mode: Union[RunMode, str] = RunMode.FAST,
-        num_neighbors: Optional[List[int]] = None,
+        anchor_time: pd.Timestamp | Literal['entity'] | None = None,
+        context_anchor_time: pd.Timestamp | None = None,
+        run_mode: RunMode | str = RunMode.FAST,
+        num_neighbors: list[int] | None = None,
         num_hops: int = 2,
         max_pq_iterations: int = 10,
-        random_seed: Optional[int] = _RANDOM_SEED,
-        verbose: Union[bool, ProgressLogger] = True,
+        random_seed: int | None = _RANDOM_SEED,
+        verbose: bool | ProgressLogger = True,
         use_prediction_time: bool = False,
     ) -> pd.DataFrame:
         pass
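Since `batch_mode` now delegates its retry bookkeeping to `retry` (see the rewritten body in the hunk above), the two context managers compose; a short sketch under the same placeholder assumptions as above:

```python
# Predict in capped batches, retrying each failed request up to twice:
with model.batch_mode(batch_size=500, num_retries=2):
    df = model.predict(query, indices=ids)

# An outer retry() takes precedence over batch_mode's own num_retries,
# since batch_mode passes `self._num_retries or num_retries` to retry():
with model.retry(num_retries=3), model.batch_mode():
    df = model.predict(query, indices=ids)
```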
@@ -248,37 +281,56 @@
     def predict(
         self,
         query: str,
-        indices: Union[List[str], List[float], List[int], None] = None,
+        indices: list[str] | list[float] | list[int] | None = None,
         *,
-        explain: Union[Literal[True], ExplainConfig, Dict[str, Any]],
-        anchor_time: Union[pd.Timestamp, Literal['entity'], None] = None,
-        context_anchor_time: Union[pd.Timestamp, None] = None,
-        run_mode: Union[RunMode, str] = RunMode.FAST,
-        num_neighbors: Optional[List[int]] = None,
+        explain: Literal[True] | ExplainConfig | dict[str, Any],
+        anchor_time: pd.Timestamp | Literal['entity'] | None = None,
+        context_anchor_time: pd.Timestamp | None = None,
+        run_mode: RunMode | str = RunMode.FAST,
+        num_neighbors: list[int] | None = None,
         num_hops: int = 2,
         max_pq_iterations: int = 10,
-        random_seed: Optional[int] = _RANDOM_SEED,
-        verbose: Union[bool, ProgressLogger] = True,
+        random_seed: int | None = _RANDOM_SEED,
+        verbose: bool | ProgressLogger = True,
         use_prediction_time: bool = False,
     ) -> Explanation:
         pass
 
+    @overload
+    def predict(
+        self,
+        query: str,
+        indices: list[str] | list[float] | list[int] | None = None,
+        *,
+        explain: bool | ExplainConfig | dict[str, Any] = False,
+        anchor_time: pd.Timestamp | Literal['entity'] | None = None,
+        context_anchor_time: pd.Timestamp | None = None,
+        run_mode: RunMode | str = RunMode.FAST,
+        num_neighbors: list[int] | None = None,
+        num_hops: int = 2,
+        max_pq_iterations: int = 10,
+        random_seed: int | None = _RANDOM_SEED,
+        verbose: bool | ProgressLogger = True,
+        use_prediction_time: bool = False,
+    ) -> pd.DataFrame | Explanation:
+        pass
+
     def predict(
         self,
         query: str,
-        indices: Union[List[str], List[float], List[int], None] = None,
+        indices: list[str] | list[float] | list[int] | None = None,
         *,
-        explain: Union[bool, ExplainConfig, Dict[str, Any]] = False,
-        anchor_time: Union[pd.Timestamp, Literal['entity'], None] = None,
-        context_anchor_time: Union[pd.Timestamp, None] = None,
-        run_mode: Union[RunMode, str] = RunMode.FAST,
-        num_neighbors: Optional[List[int]] = None,
+        explain: bool | ExplainConfig | dict[str, Any] = False,
+        anchor_time: pd.Timestamp | Literal['entity'] | None = None,
+        context_anchor_time: pd.Timestamp | None = None,
+        run_mode: RunMode | str = RunMode.FAST,
+        num_neighbors: list[int] | None = None,
         num_hops: int = 2,
         max_pq_iterations: int = 10,
-        random_seed: Optional[int] = _RANDOM_SEED,
-        verbose: Union[bool, ProgressLogger] = True,
+        random_seed: int | None = _RANDOM_SEED,
+        verbose: bool | ProgressLogger = True,
        use_prediction_time: bool = False,
-    ) -> Union[pd.DataFrame, Explanation]:
+    ) -> pd.DataFrame | Explanation:
         """Returns predictions for a predictive query.
 
         Args:
@@ -286,8 +338,7 @@
             indices: The entity primary keys to predict for. Will override the
                 indices given as part of the predictive query. Predictions will
                 be generated for all indices, independent of whether they
-                fulfill entity filter constraints. To pre-filter entities, use
-                :meth:`~KumoRFM.is_valid_entity`.
+                fulfill entity filter constraints.
             explain: Configuration for explainability.
                 If set to ``True``, will additionally explain the prediction.
                 Passing in an :class:`ExplainConfig` instance provides control
@@ -320,18 +371,152 @@
             If ``explain`` is provided, returns an :class:`Explanation` object
             containing the prediction, summary, and details.
         """
-        explain_config: Optional[ExplainConfig] = None
-        if explain is True:
-            explain_config = ExplainConfig()
-        elif explain is not False:
-            explain_config = ExplainConfig._cast(explain)
-
         query_def = self._parse_query(query)
-        query_str = query_def.to_string()
 
+        if indices is None:
+            if query_def.rfm_entity_ids is None:
+                raise ValueError("Cannot find entities to predict for. Please "
+                                 "pass them via `predict(query, indices=...)`")
+            indices = query_def.get_rfm_entity_id_list()
+        query_def = replace(
+            query_def,
+            for_each='FOR EACH',
+            rfm_entity_ids=None,
+        )
+
+        if not isinstance(verbose, ProgressLogger):
+            query_repr = query_def.to_string(rich=True, exclude_predict=True)
+            if explain is not False:
+                msg = f'[bold]EXPLAIN[/bold] {query_repr}'
+            else:
+                msg = f'[bold]PREDICT[/bold] {query_repr}'
+            verbose = ProgressLogger.default(msg=msg, verbose=verbose)
+
+        with verbose as logger:
+            task_table = self._get_task_table(
+                query=query_def,
+                indices=indices,
+                anchor_time=anchor_time,
+                context_anchor_time=context_anchor_time,
+                run_mode=run_mode,
+                max_pq_iterations=max_pq_iterations,
+                random_seed=random_seed,
+                logger=logger,
+            )
+            task_table._query = query_def.to_string()
+
+            return self.predict_task(
+                task_table,
+                explain=explain,
+                run_mode=run_mode,
+                num_neighbors=num_neighbors,
+                num_hops=num_hops,
+                verbose=verbose,
+                exclude_cols_dict=query_def.get_exclude_cols_dict(),
+                use_prediction_time=use_prediction_time,
+                top_k=query_def.top_k,
+            )
+
+    @overload
+    def predict_task(
+        self,
+        task: TaskTable,
+        *,
+        explain: Literal[False] = False,
+        run_mode: RunMode | str = RunMode.FAST,
+        num_neighbors: list[int] | None = None,
+        num_hops: int = 2,
+        verbose: bool | ProgressLogger = True,
+        exclude_cols_dict: dict[str, list[str]] | None = None,
+        use_prediction_time: bool = False,
+        top_k: int | None = None,
+    ) -> pd.DataFrame:
+        pass
+
+    @overload
+    def predict_task(
+        self,
+        task: TaskTable,
+        *,
+        explain: Literal[True] | ExplainConfig | dict[str, Any],
+        run_mode: RunMode | str = RunMode.FAST,
+        num_neighbors: list[int] | None = None,
+        num_hops: int = 2,
+        verbose: bool | ProgressLogger = True,
+        exclude_cols_dict: dict[str, list[str]] | None = None,
+        use_prediction_time: bool = False,
+        top_k: int | None = None,
+    ) -> Explanation:
+        pass
+
+    @overload
+    def predict_task(
+        self,
+        task: TaskTable,
+        *,
+        explain: bool | ExplainConfig | dict[str, Any] = False,
+        run_mode: RunMode | str = RunMode.FAST,
+        num_neighbors: list[int] | None = None,
+        num_hops: int = 2,
+        verbose: bool | ProgressLogger = True,
+        exclude_cols_dict: dict[str, list[str]] | None = None,
+        use_prediction_time: bool = False,
+        top_k: int | None = None,
+    ) -> pd.DataFrame | Explanation:
+        pass
+
+    def predict_task(
+        self,
+        task: TaskTable,
+        *,
+        explain: bool | ExplainConfig | dict[str, Any] = False,
+        run_mode: RunMode | str = RunMode.FAST,
+        num_neighbors: list[int] | None = None,
+        num_hops: int = 2,
+        verbose: bool | ProgressLogger = True,
+        exclude_cols_dict: dict[str, list[str]] | None = None,
+        use_prediction_time: bool = False,
+        top_k: int | None = None,
+    ) -> pd.DataFrame | Explanation:
+        """Returns predictions for a custom task specification.
+
+        Args:
+            task: The custom :class:`TaskTable`.
+            explain: Configuration for explainability.
+                If set to ``True``, will additionally explain the prediction.
+                Passing in an :class:`ExplainConfig` instance provides control
+                over which parts of explanation are generated.
+                Explainability is currently only supported for single entity
+                predictions with ``run_mode="FAST"``.
+            run_mode: The :class:`RunMode` for the query.
+            num_neighbors: The number of neighbors to sample for each hop.
+                If specified, the ``num_hops`` option will be ignored.
+            num_hops: The number of hops to sample when generating the context.
+            verbose: Whether to print verbose output.
+            exclude_cols_dict: Any column in any table to exclude from the
+                model input.
+            use_prediction_time: Whether to use the anchor timestamp as an
+                additional feature during prediction. This is typically
+                beneficial for time series forecasting tasks.
+            top_k: The number of predictions to return per entity.
+
+        Returns:
+            The predictions as a :class:`pandas.DataFrame`.
+            If ``explain`` is provided, returns an :class:`Explanation` object
+            containing the prediction, summary, and details.
+        """
         if num_hops != 2 and num_neighbors is not None:
             warnings.warn(f"Received custom 'num_neighbors' option; ignoring "
                           f"custom 'num_hops={num_hops}' option")
+        if num_neighbors is None:
+            key = RunMode.FAST if task.task_type.is_link_pred else run_mode
+            num_neighbors = _DEFAULT_NUM_NEIGHBORS[key][:num_hops]
+
+        explain_config: ExplainConfig | None = None
+        if explain is True:
+            explain_config = ExplainConfig()
+        elif explain is not False:
+            explain_config = ExplainConfig._cast(explain)
 
         if explain_config is not None and run_mode in {
             RunMode.NORMAL, RunMode.BEST
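The new `predict_task` entry point makes the intermediate task specification user-facing. A hedged sketch of driving it with a hand-built `TaskTable`, mirroring the keyword arguments of the constructor call in `_get_task_table` further below — the table name, columns, and data are placeholders, and the authoritative signature lives in the new `task_table.py`:

```python
import pandas as pd

from kumoapi.task import TaskType
from kumoai.experimental.rfm import TaskTable

task = TaskTable(
    task_type=TaskType.BINARY_CLASSIFICATION,
    # Labeled in-context examples the model conditions on:
    context_df=pd.DataFrame({
        'ENTITY': [1, 2, 3],
        'TARGET': [0, 1, 0],
        'ANCHOR_TIMESTAMP': pd.Timestamp('2024-01-01'),
    }),
    # Entities to predict for (no targets needed):
    pred_df=pd.DataFrame({
        'ENTITY': [4, 5],
        'ANCHOR_TIMESTAMP': pd.Timestamp('2024-01-01'),
    }),
    entity_table_name='users',
    entity_column='ENTITY',
    target_column='TARGET',
    time_column='ANCHOR_TIMESTAMP',
)
df = model.predict_task(task, run_mode='FAST')
```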
@@ -340,83 +525,82 @@
                           f"run mode 'FAST' (got '{run_mode}'). Provided run "
                           f"mode has been reset. Please lower the run mode to "
                           f"suppress this warning.")
+            run_mode = RunMode.FAST
 
-        if indices is None:
-            if query_def.rfm_entity_ids is None:
-                raise ValueError("Cannot find entities to predict for. Please "
-                                 "pass them via `predict(query, indices=...)`")
-            indices = query_def.get_rfm_entity_id_list()
-        else:
-            query_def = replace(query_def, rfm_entity_ids=None)
-
-        if len(indices) == 0:
-            raise ValueError("At least one entity is required")
-
-        if explain_config is not None and len(indices) > 1:
-            raise ValueError(
-                f"Cannot explain predictions for more than a single entity "
-                f"(got {len(indices)})")
-
-        query_repr = query_def.to_string(rich=True, exclude_predict=True)
-        if explain_config is not None:
-            msg = f'[bold]EXPLAIN[/bold] {query_repr}'
-        else:
-            msg = f'[bold]PREDICT[/bold] {query_repr}'
+        if explain_config is not None and task.num_prediction_examples > 1:
+            raise ValueError(f"Cannot explain predictions for more than a "
+                             f"single entity "
+                             f"(got {task.num_prediction_examples:,})")
 
         if not isinstance(verbose, ProgressLogger):
-            verbose = InteractiveProgressLogger(msg, verbose=verbose)
-
-        with verbose as logger:
-
-            batch_size: Optional[int] = None
-            if self._batch_size == 'max':
-                task_type = self._get_task_type(
-                    query=query_def,
-                    edge_types=self._sampler.edge_types,
-                )
-                batch_size = _MAX_PRED_SIZE[task_type]
+            if task.task_type == TaskType.BINARY_CLASSIFICATION:
+                task_type_repr = 'binary classification'
+            elif task.task_type == TaskType.MULTICLASS_CLASSIFICATION:
+                task_type_repr = 'multi-class classification'
+            elif task.task_type == TaskType.REGRESSION:
+                task_type_repr = 'regression'
+            elif task.task_type == TaskType.TEMPORAL_LINK_PREDICTION:
+                task_type_repr = 'link prediction'
             else:
-                batch_size = self._batch_size
+                task_type_repr = str(task.task_type)
 
-            if batch_size is not None:
-                offsets = range(0, len(indices), batch_size)
-                batches = [indices[step:step + batch_size] for step in offsets]
+            if explain_config is not None:
+                msg = f"Explaining {task_type_repr} task"
             else:
-                batches = [indices]
+                msg = f"Predicting {task_type_repr} task"
+            verbose = ProgressLogger.default(msg=msg, verbose=verbose)
 
-            if len(batches) > 1:
-                logger.log(f"Splitting {len(indices):,} entities into "
-                           f"{len(batches):,} batches of size {batch_size:,}")
+        with verbose as logger:
+            if task.num_context_examples > _MAX_CONTEXT_SIZE[run_mode]:
+                logger.log(f"Sub-sampled {_MAX_CONTEXT_SIZE[run_mode]:,} "
+                           f"out of {task.num_context_examples:,} in-context "
+                           f"examples")
+                task = task.narrow_context(0, _MAX_CONTEXT_SIZE[run_mode])
+
+            if self._batch_size is None:
+                batch_size = task.num_prediction_examples
+            elif self._batch_size == 'max':
+                batch_size = _MAX_PRED_SIZE[task.task_type]
+            else:
+                batch_size = self._batch_size
 
-            predictions: List[pd.DataFrame] = []
-            summary: Optional[str] = None
-            details: Optional[Explanation] = None
-            for i, batch in enumerate(batches):
-                # TODO Re-use the context for subsequent predictions.
+            if batch_size > _MAX_PRED_SIZE[task.task_type]:
+                raise ValueError(f"Cannot predict for more than "
+                                 f"{_MAX_PRED_SIZE[task.task_type]:,} "
+                                 f"entities at once (got {batch_size:,}). Use "
+                                 f"`KumoRFM.batch_mode` to process entities "
+                                 f"in batches with a sufficient batch size.")
+
+            if task.num_prediction_examples > batch_size:
+                num = math.ceil(task.num_prediction_examples / batch_size)
+                logger.log(f"Splitting {task.num_prediction_examples:,} "
+                           f"entities into {num:,} batches of size "
+                           f"{batch_size:,}")
+
+            predictions: list[pd.DataFrame] = []
+            summary: str | None = None
+            details: Explanation | None = None
+            for start in range(0, task.num_prediction_examples, batch_size):
                 context = self._get_context(
-                    query=query_def,
-                    indices=batch,
-                    anchor_time=anchor_time,
-                    context_anchor_time=context_anchor_time,
-                    run_mode=RunMode(run_mode),
+                    task=task.narrow_prediction(start, length=batch_size),
+                    run_mode=run_mode,
                     num_neighbors=num_neighbors,
-                    num_hops=num_hops,
-                    max_pq_iterations=max_pq_iterations,
-                    evaluate=False,
-                    random_seed=random_seed,
-                    logger=logger if i == 0 else None,
+                    exclude_cols_dict=exclude_cols_dict,
+                    top_k=top_k,
                 )
+                context.y_test = None
+
                 request = RFMPredictRequest(
                     context=context,
                     run_mode=RunMode(run_mode),
-                    query=query_str,
+                    query=task._query,
                     use_prediction_time=use_prediction_time,
                 )
                 with warnings.catch_warnings():
                     warnings.filterwarnings('ignore', message='gencode')
                     request_msg = request.to_protobuf()
                     _bytes = request_msg.SerializeToString()
-                if i == 0:
+                if start == 0:
                     logger.log(f"Generated context of size "
                                f"{len(_bytes) / (1024*1024):.2f}MB")
 
@@ -424,14 +608,11 @@
                     stats = Context.get_memory_stats(request_msg.context)
                     raise ValueError(_SIZE_LIMIT_MSG.format(stats=stats))
 
-                if (isinstance(verbose, InteractiveProgressLogger) and i == 0
-                        and len(batches) > 1):
-                    verbose.init_progress(
-                        total=len(batches),
-                        description='Predicting',
-                    )
+                if start == 0 and task.num_prediction_examples > batch_size:
+                    num = math.ceil(task.num_prediction_examples / batch_size)
+                    verbose.init_progress(total=num, description='Predicting')
 
-                for attempt in range(self.num_retries + 1):
+                for attempt in range(self._num_retries + 1):
                     try:
                         if explain_config is not None:
                             resp = self._api_client.explain(
@@ -447,7 +628,7 @@
                         # Cast 'ENTITY' to correct data type:
                         if 'ENTITY' in df:
                             table_dict = context.subgraph.table_dict
-                            table = table_dict[query_def.entity_table]
+                            table = table_dict[context.entity_table_names[0]]
                             ser = table.df[table.primary_key]
                             df['ENTITY'] = df['ENTITY'].astype(ser.dtype)
 
@@ -464,13 +645,12 @@
 
                         predictions.append(df)
 
-                        if (isinstance(verbose, InteractiveProgressLogger)
-                                and len(batches) > 1):
+                        if task.num_prediction_examples > batch_size:
                             verbose.step()
 
                         break
                     except HTTPException as e:
-                        if attempt == self.num_retries:
+                        if attempt == self._num_retries:
                             try:
                                 msg = json.loads(e.detail)['detail']
                             except Exception:
@@ -500,64 +680,19 @@
 
         return prediction
 
-    def is_valid_entity(
-        self,
-        query: str,
-        indices: Union[List[str], List[float], List[int], None] = None,
-        *,
-        anchor_time: Union[pd.Timestamp, Literal['entity'], None] = None,
-    ) -> np.ndarray:
-        r"""Returns a mask that denotes which entities are valid for the
-        given predictive query, *i.e.*, which entities fulfill (temporal)
-        entity filter constraints.
-
-        Args:
-            query: The predictive query.
-            indices: The entity primary keys to predict for. Will override the
-                indices given as part of the predictive query.
-            anchor_time: The anchor timestamp for the prediction. If set to
-                ``None``, will use the maximum timestamp in the data.
-                If set to ``"entity"``, will use the timestamp of the entity.
-        """
-        query_def = self._parse_query(query)
-
-        if indices is None:
-            if query_def.rfm_entity_ids is None:
-                raise ValueError("Cannot find entities to predict for. Please "
-                                 "pass them via "
-                                 "`is_valid_entity(query, indices=...)`")
-            indices = query_def.get_rfm_entity_id_list()
-
-        if len(indices) == 0:
-            raise ValueError("At least one entity is required")
-
-        if anchor_time is None:
-            anchor_time = self._get_default_anchor_time(query_def)
-
-        if isinstance(anchor_time, pd.Timestamp):
-            self._validate_time(query_def, anchor_time, None, False)
-        else:
-            assert anchor_time == 'entity'
-            if query_def.entity_table not in self._sampler.time_column_dict:
-                raise ValueError(f"Anchor time 'entity' requires the entity "
-                                 f"table '{query_def.entity_table}' "
-                                 f"to have a time column.")
-
-        raise NotImplementedError
-
     def evaluate(
         self,
         query: str,
         *,
-        metrics: Optional[List[str]] = None,
-        anchor_time: Union[pd.Timestamp, Literal['entity'], None] = None,
-        context_anchor_time: Union[pd.Timestamp, None] = None,
-        run_mode: Union[RunMode, str] = RunMode.FAST,
-        num_neighbors: Optional[List[int]] = None,
+        metrics: list[str] | None = None,
+        anchor_time: pd.Timestamp | Literal['entity'] | None = None,
+        context_anchor_time: pd.Timestamp | None = None,
+        run_mode: RunMode | str = RunMode.FAST,
+        num_neighbors: list[int] | None = None,
         num_hops: int = 2,
         max_pq_iterations: int = 10,
-        random_seed: Optional[int] = _RANDOM_SEED,
-        verbose: Union[bool, ProgressLogger] = True,
+        random_seed: int | None = _RANDOM_SEED,
+        verbose: bool | ProgressLogger = True,
         use_prediction_time: bool = False,
     ) -> pd.DataFrame:
         """Evaluates a predictive query.
@@ -589,41 +724,120 @@
         Returns:
             The metrics as a :class:`pandas.DataFrame`
         """
-        query_def = self._parse_query(query)
+        query_def = replace(
+            self._parse_query(query),
+            for_each='FOR EACH',
+            rfm_entity_ids=None,
+        )
+
+        if not isinstance(verbose, ProgressLogger):
+            query_repr = query_def.to_string(rich=True, exclude_predict=True)
+            msg = f'[bold]EVALUATE[/bold] {query_repr}'
+            verbose = ProgressLogger.default(msg=msg, verbose=verbose)
+
+        with verbose as logger:
+            task_table = self._get_task_table(
+                query=query_def,
+                indices=None,
+                anchor_time=anchor_time,
+                context_anchor_time=context_anchor_time,
+                run_mode=run_mode,
+                max_pq_iterations=max_pq_iterations,
+                random_seed=random_seed,
+                logger=logger,
+            )
+
+            return self.evaluate_task(
+                task_table,
+                metrics=metrics,
+                run_mode=run_mode,
+                num_neighbors=num_neighbors,
+                num_hops=num_hops,
+                verbose=verbose,
+                exclude_cols_dict=query_def.get_exclude_cols_dict(),
+                use_prediction_time=use_prediction_time,
+            )
+
+    def evaluate_task(
+        self,
+        task: TaskTable,
+        *,
+        metrics: list[str] | None = None,
+        run_mode: RunMode | str = RunMode.FAST,
+        num_neighbors: list[int] | None = None,
+        num_hops: int = 2,
+        verbose: bool | ProgressLogger = True,
+        exclude_cols_dict: dict[str, list[str]] | None = None,
+        use_prediction_time: bool = False,
+    ) -> pd.DataFrame:
+        """Evaluates a custom task specification.
+
+        Args:
+            task: The custom :class:`TaskTable`.
+            metrics: The metrics to use.
+            run_mode: The :class:`RunMode` for the query.
+            num_neighbors: The number of neighbors to sample for each hop.
+                If specified, the ``num_hops`` option will be ignored.
+            num_hops: The number of hops to sample when generating the context.
+            verbose: Whether to print verbose output.
+            exclude_cols_dict: Any column in any table to exclude from the
+                model input.
+            use_prediction_time: Whether to use the anchor timestamp as an
+                additional feature during prediction. This is typically
+                beneficial for time series forecasting tasks.
 
+        Returns:
+            The metrics as a :class:`pandas.DataFrame`
+        """
         if num_hops != 2 and num_neighbors is not None:
             warnings.warn(f"Received custom 'num_neighbors' option; ignoring "
                           f"custom 'num_hops={num_hops}' option")
+        if num_neighbors is None:
+            key = RunMode.FAST if task.task_type.is_link_pred else run_mode
+            num_neighbors = _DEFAULT_NUM_NEIGHBORS[key][:num_hops]
 
-        if query_def.rfm_entity_ids is not None:
-            query_def = replace(
-                query_def,
-                rfm_entity_ids=None,
-            )
-
-        query_repr = query_def.to_string(rich=True, exclude_predict=True)
-        msg = f'[bold]EVALUATE[/bold] {query_repr}'
+        if metrics is not None and len(metrics) > 0:
+            self._validate_metrics(metrics, task.task_type)
+            metrics = list(dict.fromkeys(metrics))
 
         if not isinstance(verbose, ProgressLogger):
-            verbose = InteractiveProgressLogger(msg, verbose=verbose)
+            if task.task_type == TaskType.BINARY_CLASSIFICATION:
+                task_type_repr = 'binary classification'
+            elif task.task_type == TaskType.MULTICLASS_CLASSIFICATION:
+                task_type_repr = 'multi-class classification'
+            elif task.task_type == TaskType.REGRESSION:
+                task_type_repr = 'regression'
+            elif task.task_type == TaskType.TEMPORAL_LINK_PREDICTION:
+                task_type_repr = 'link prediction'
+            else:
+                task_type_repr = str(task.task_type)
+
+            msg = f"Evaluating {task_type_repr} task"
+            verbose = ProgressLogger.default(msg=msg, verbose=verbose)
 
         with verbose as logger:
+            if task.num_context_examples > _MAX_CONTEXT_SIZE[run_mode]:
+                logger.log(f"Sub-sampled {_MAX_CONTEXT_SIZE[run_mode]:,} "
+                           f"out of {task.num_context_examples:,} in-context "
+                           f"examples")
+                task = task.narrow_context(0, _MAX_CONTEXT_SIZE[run_mode])
+
+            if task.num_prediction_examples > _MAX_TEST_SIZE[task.task_type]:
+                logger.log(f"Sub-sampled {_MAX_TEST_SIZE[task.task_type]:,} "
+                           f"out of {task.num_prediction_examples:,} test "
+                           f"examples")
+                task = task.narrow_prediction(
+                    start=0,
+                    length=_MAX_TEST_SIZE[task.task_type],
+                )
+
             context = self._get_context(
-                query=query_def,
-                indices=None,
-                anchor_time=anchor_time,
-                context_anchor_time=context_anchor_time,
-                run_mode=RunMode(run_mode),
+                task=task,
+                run_mode=run_mode,
                 num_neighbors=num_neighbors,
-                num_hops=num_hops,
-                max_pq_iterations=max_pq_iterations,
-                evaluate=True,
-                random_seed=random_seed,
-                logger=logger if verbose else None,
+                exclude_cols_dict=exclude_cols_dict,
             )
-            if metrics is not None and len(metrics) > 0:
-                self._validate_metrics(metrics, context.task_type)
-                metrics = list(dict.fromkeys(metrics))
+
             request = RFMEvaluateRequest(
                 context=context,
                 run_mode=RunMode(run_mode),
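A brief sketch of the resulting evaluation split; the metric names are illustrative, since valid choices depend on the task type and are checked by `_validate_metrics`:

```python
# Query-driven: builds the TaskTable internally, then delegates:
metrics_df = model.evaluate(query, metrics=['auroc', 'auprc'])

# Task-driven: evaluate a custom TaskTable whose pred_df carries targets:
metrics_df = model.evaluate_task(task, metrics=['auroc'], run_mode='FAST')
```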
@@ -641,17 +855,23 @@
                 stats_msg = Context.get_memory_stats(request_msg.context)
                 raise ValueError(_SIZE_LIMIT_MSG.format(stats=stats_msg))
 
-            try:
-                resp = self._api_client.evaluate(request_bytes)
-            except HTTPException as e:
+            for attempt in range(self._num_retries + 1):
                 try:
-                    msg = json.loads(e.detail)['detail']
-                except Exception:
-                    msg = e.detail
-                raise RuntimeError(f"An unexpected exception occurred. "
-                                   f"Please create an issue at "
-                                   f"'https://github.com/kumo-ai/kumo-rfm'. "
-                                   f"{msg}") from None
+                    resp = self._api_client.evaluate(request_bytes)
+                    break
+                except HTTPException as e:
+                    if attempt == self._num_retries:
+                        try:
+                            msg = json.loads(e.detail)['detail']
+                        except Exception:
+                            msg = e.detail
+                        raise RuntimeError(
+                            f"An unexpected exception occurred. Please create "
+                            f"an issue at "
+                            f"'https://github.com/kumo-ai/kumo-rfm'. {msg}"
+                        ) from None
+
+                    time.sleep(2**attempt)  # 1s, 2s, 4s, 8s, ...
 
         return pd.DataFrame.from_dict(
             resp.metrics,
@@ -664,8 +884,8 @@
         query: str,
         size: int,
         *,
-        anchor_time: Union[pd.Timestamp, Literal['entity'], None] = None,
-        random_seed: Optional[int] = _RANDOM_SEED,
+        anchor_time: pd.Timestamp | Literal['entity'] | None = None,
+        random_seed: int | None = _RANDOM_SEED,
         max_iterations: int = 10,
     ) -> pd.DataFrame:
         """Returns the labels of a predictive query for a specified anchor
@@ -703,7 +923,7 @@
                              f"to have a time column")
 
         train, test = self._sampler.sample_target(
-            query=query,
+            query=query_def,
             num_train_examples=0,
             train_anchor_time=anchor_time,
             num_train_trials=0,
@@ -731,35 +951,39 @@
                 "`predict()` or `evaluate()` methods to perform "
                 "predictions or evaluations.")
 
-        try:
-            request = RFMParseQueryRequest(
-                query=query,
-                graph_definition=self._graph_def,
-            )
+        request = RFMParseQueryRequest(
+            query=query,
+            graph_definition=self._graph_def,
+        )
 
-            resp = self._api_client.parse_query(request)
+        for attempt in range(self._num_retries + 1):
+            try:
+                resp = self._api_client.parse_query(request)
+                break
+            except HTTPException as e:
+                if attempt == self._num_retries:
+                    try:
+                        msg = json.loads(e.detail)['detail']
+                    except Exception:
+                        msg = e.detail
+                    raise ValueError(f"Failed to parse query '{query}'. {msg}")
 
-            if len(resp.validation_response.warnings) > 0:
-                msg = '\n'.join([
-                    f'{i+1}. {warning.title}: {warning.message}' for i, warning
-                    in enumerate(resp.validation_response.warnings)
-                ])
-                warnings.warn(f"Encountered the following warnings during "
-                              f"parsing:\n{msg}")
+            time.sleep(2**attempt)  # 1s, 2s, 4s, 8s, ...
 
-            return resp.query
-        except HTTPException as e:
-            try:
-                msg = json.loads(e.detail)['detail']
-            except Exception:
-                msg = e.detail
-            raise ValueError(f"Failed to parse query '{query}'. "
-                             f"{msg}") from None
+        if len(resp.validation_response.warnings) > 0:
+            msg = '\n'.join([
+                f'{i+1}. {warning.title}: {warning.message}'
+                for i, warning in enumerate(resp.validation_response.warnings)
+            ])
+            warnings.warn(f"Encountered the following warnings during "
+                          f"parsing:\n{msg}")
+
+        return resp.query
 
     @staticmethod
     def _get_task_type(
         query: ValidatedPredictiveQuery,
-        edge_types: List[Tuple[str, str, str]],
+        edge_types: list[tuple[str, str, str]],
     ) -> TaskType:
         if isinstance(query.target_ast, (Condition, LogicalOperation)):
             return TaskType.BINARY_CLASSIFICATION
@@ -798,31 +1022,38 @@
 
     def _get_default_anchor_time(
         self,
-        query: ValidatedPredictiveQuery,
+        query: ValidatedPredictiveQuery | None = None,
     ) -> pd.Timestamp:
-        if query.query_type == QueryType.TEMPORAL:
+        if query is not None and query.query_type == QueryType.TEMPORAL:
             aggr_table_names = [
                 aggr._get_target_column_name().split('.')[0]
                 for aggr in query.get_all_target_aggregations()
             ]
             return self._sampler.get_max_time(aggr_table_names)
 
-        assert query.query_type == QueryType.STATIC
         return self._sampler.get_max_time()
 
     def _validate_time(
         self,
         query: ValidatedPredictiveQuery,
         anchor_time: pd.Timestamp,
-        context_anchor_time: Union[pd.Timestamp, None],
+        context_anchor_time: pd.Timestamp | None,
         evaluate: bool,
     ) -> None:
 
         if len(self._sampler.time_column_dict) == 0:
             return  # Graph without timestamps
 
-        min_time = self._sampler.get_min_time()
-        max_time = self._sampler.get_max_time()
+        if query.query_type == QueryType.TEMPORAL:
+            aggr_table_names = [
+                aggr._get_target_column_name().split('.')[0]
+                for aggr in query.get_all_target_aggregations()
+            ]
+            min_time = self._sampler.get_min_time(aggr_table_names)
+            max_time = self._sampler.get_max_time(aggr_table_names)
+        else:
+            min_time = self._sampler.get_min_time()
+            max_time = self._sampler.get_max_time()
 
         if anchor_time < min_time:
             raise ValueError(f"Anchor timestamp '{anchor_time}' is before "
@@ -877,40 +1108,26 @@
                 f"Anchor timestamp for evaluation is after the latest "
                 f"supported timestamp '{max_time - end_offset}'.")
 
-    def _get_context(
+    def _get_task_table(
         self,
         query: ValidatedPredictiveQuery,
-        indices: Union[List[str], List[float], List[int], None],
-        anchor_time: Union[pd.Timestamp, Literal['entity'], None],
-        context_anchor_time: Union[pd.Timestamp, None],
-        run_mode: RunMode,
-        num_neighbors: Optional[List[int]],
-        num_hops: int,
-        max_pq_iterations: int,
-        evaluate: bool,
-        random_seed: Optional[int] = _RANDOM_SEED,
-        logger: Optional[ProgressLogger] = None,
-    ) -> Context:
-
-        if num_neighbors is not None:
-            num_hops = len(num_neighbors)
-
-        if num_hops < 0:
-            raise ValueError(f"'num_hops' must be non-negative "
-                             f"(got {num_hops})")
-        if num_hops > 6:
-            raise ValueError(f"Cannot predict on subgraphs with more than 6 "
-                             f"hops (got {num_hops}). Please reduce the "
-                             f"number of hops and try again. Please create a "
-                             f"feature request at "
-                             f"'https://github.com/kumo-ai/kumo-rfm' if you "
-                             f"must go beyond this for your use-case.")
+        indices: list[str] | list[float] | list[int] | None,
+        anchor_time: pd.Timestamp | Literal['entity'] | None = None,
+        context_anchor_time: pd.Timestamp | None = None,
+        run_mode: RunMode = RunMode.FAST,
+        max_pq_iterations: int = 10,
+        random_seed: int | None = _RANDOM_SEED,
+        logger: ProgressLogger | None = None,
+    ) -> TaskTable:
 
         task_type = self._get_task_type(
             query=query,
             edge_types=self._sampler.edge_types,
         )
 
+        num_train_examples = _MAX_CONTEXT_SIZE[run_mode]
+        num_test_examples = _MAX_TEST_SIZE[task_type] if indices is None else 0
+
         if logger is not None:
             if task_type == TaskType.BINARY_CLASSIFICATION:
                 task_type_repr = 'binary classification'
@@ -924,21 +1141,6 @@
                 task_type_repr = str(task_type)
             logger.log(f"Identified {query.query_type} {task_type_repr} task")
 
-        if task_type.is_link_pred and num_hops < 2:
-            raise ValueError(f"Cannot perform link prediction on subgraphs "
-                             f"with less than 2 hops (got {num_hops}) since "
-                             f"historical target entities need to be part of "
-                             f"the context. Please increase the number of "
-                             f"hops and try again.")
-
-        if num_neighbors is None:
-            if run_mode == RunMode.DEBUG:
-                num_neighbors = [16, 16, 4, 4, 1, 1][:num_hops]
-            elif run_mode == RunMode.FAST or task_type.is_link_pred:
-                num_neighbors = [32, 32, 8, 8, 4, 4][:num_hops]
-            else:
-                num_neighbors = [64, 64, 8, 8, 4, 4][:num_hops]
-
         if query.target_ast.date_offset_range is None:
             step_offset = pd.DateOffset(0)
         else:
@@ -947,8 +1149,7 @@
 
         if anchor_time is None:
             anchor_time = self._get_default_anchor_time(query)
-
-            if evaluate:
+            if num_test_examples > 0:
                 anchor_time = anchor_time - end_offset
 
         if logger is not None:
@@ -962,7 +1163,6 @@
             else:
                 logger.log(f"Derived anchor time {anchor_time}")
 
-        assert anchor_time is not None
         if isinstance(anchor_time, pd.Timestamp):
             if context_anchor_time == 'entity':
                 raise ValueError("Anchor time 'entity' needs to be shared "
@@ -970,7 +1170,7 @@
             if context_anchor_time is None:
                 context_anchor_time = anchor_time - end_offset
             self._validate_time(query, anchor_time, context_anchor_time,
-                                evaluate)
+                                evaluate=num_test_examples > 0)
         else:
             assert anchor_time == 'entity'
             if query.query_type != QueryType.STATIC:
@@ -985,14 +1185,6 @@
                                  "for context and prediction examples")
             context_anchor_time = 'entity'
 
-        num_train_examples = _MAX_CONTEXT_SIZE[run_mode]
-        if evaluate:
-            num_test_examples = _MAX_TEST_SIZE[run_mode]
-            if task_type.is_link_pred:
-                num_test_examples = num_test_examples // 5
-        else:
-            num_test_examples = 0
-
         train, test = self._sampler.sample_target(
             query=query,
             num_train_examples=num_train_examples,
@@ -1003,39 +1195,32 @@
             num_test_trials=max_pq_iterations * num_test_examples,
             random_seed=random_seed,
         )
-        train_pkey, train_time, y_train = train
-        test_pkey, test_time, y_test = test
+        train_pkey, train_time, train_y = train
+        test_pkey, test_time, test_y = test
 
-        if evaluate and logger is not None:
+        if num_test_examples > 0 and logger is not None:
             if task_type == TaskType.BINARY_CLASSIFICATION:
-                pos = 100 * int((y_test > 0).sum()) / len(y_test)
-                msg = (f"Collected {len(y_test):,} test examples with "
+                pos = 100 * int((test_y > 0).sum()) / len(test_y)
+                msg = (f"Collected {len(test_y):,} test examples with "
                        f"{pos:.2f}% positive cases")
             elif task_type == TaskType.MULTICLASS_CLASSIFICATION:
-                msg = (f"Collected {len(y_test):,} test examples holding "
-                       f"{y_test.nunique()} classes")
+                msg = (f"Collected {len(test_y):,} test examples holding "
+                       f"{test_y.nunique()} classes")
             elif task_type == TaskType.REGRESSION:
-                _min, _max = float(y_test.min()), float(y_test.max())
-                msg = (f"Collected {len(y_test):,} test examples with targets "
+                _min, _max = float(test_y.min()), float(test_y.max())
+                msg = (f"Collected {len(test_y):,} test examples with targets "
                        f"between {format_value(_min)} and "
                        f"{format_value(_max)}")
             elif task_type == TaskType.TEMPORAL_LINK_PREDICTION:
-                num_rhs = y_test.explode().nunique()
-                msg = (f"Collected {len(y_test):,} test examples with "
+                num_rhs = test_y.explode().nunique()
+                msg = (f"Collected {len(test_y):,} test examples with "
                        f"{num_rhs:,} unique items")
             else:
                 raise NotImplementedError
             logger.log(msg)
 
-        if not evaluate:
+        if num_test_examples == 0:
             assert indices is not None
-            if len(indices) > _MAX_PRED_SIZE[task_type]:
-                raise ValueError(f"Cannot predict for more than "
-                                 f"{_MAX_PRED_SIZE[task_type]:,} entities at "
-                                 f"once (got {len(indices):,}). Use "
-                                 f"`KumoRFM.batch_mode` to process entities "
-                                 f"in batches")
-
             test_pkey = pd.Series(indices, dtype=train_pkey.dtype)
             if isinstance(anchor_time, pd.Timestamp):
                 test_time = pd.Series([anchor_time]).repeat(
1045
1230
 
1046
1231
  if logger is not None:
1047
1232
  if task_type == TaskType.BINARY_CLASSIFICATION:
1048
- pos = 100 * int((y_train > 0).sum()) / len(y_train)
1049
- msg = (f"Collected {len(y_train):,} in-context examples with "
1233
+ pos = 100 * int((train_y > 0).sum()) / len(train_y)
1234
+ msg = (f"Collected {len(train_y):,} in-context examples with "
1050
1235
  f"{pos:.2f}% positive cases")
1051
1236
  elif task_type == TaskType.MULTICLASS_CLASSIFICATION:
1052
- msg = (f"Collected {len(y_train):,} in-context examples "
1053
- f"holding {y_train.nunique()} classes")
1237
+ msg = (f"Collected {len(train_y):,} in-context examples "
1238
+ f"holding {train_y.nunique()} classes")
1054
1239
  elif task_type == TaskType.REGRESSION:
1055
- _min, _max = float(y_train.min()), float(y_train.max())
1056
- msg = (f"Collected {len(y_train):,} in-context examples with "
1240
+ _min, _max = float(train_y.min()), float(train_y.max())
1241
+ msg = (f"Collected {len(train_y):,} in-context examples with "
1057
1242
  f"targets between {format_value(_min)} and "
1058
1243
  f"{format_value(_max)}")
1059
1244
  elif task_type == TaskType.TEMPORAL_LINK_PREDICTION:
1060
- num_rhs = y_train.explode().nunique()
1061
- msg = (f"Collected {len(y_train):,} in-context examples with "
1245
+ num_rhs = train_y.explode().nunique()
1246
+ msg = (f"Collected {len(train_y):,} in-context examples with "
1062
1247
  f"{num_rhs:,} unique items")
1063
1248
  else:
1064
1249
  raise NotImplementedError
1065
1250
  logger.log(msg)
1066
1251
 
1067
- entity_table_names: Tuple[str, ...]
1252
+ entity_table_names: tuple[str] | tuple[str, str]
1068
1253
  if task_type.is_link_pred:
1069
1254
  final_aggr = query.get_final_target_aggregation()
1070
1255
  assert final_aggr is not None
@@ -1078,27 +1263,80 @@
         else:
            entity_table_names = (query.entity_table, )
 
+        context_df = pd.DataFrame({'ENTITY': train_pkey, 'TARGET': train_y})
+        if isinstance(train_time, pd.Series):
+            context_df['ANCHOR_TIMESTAMP'] = train_time
+        pred_df = pd.DataFrame({'ENTITY': test_pkey})
+        if num_test_examples > 0:
+            pred_df['TARGET'] = test_y
+        if isinstance(test_time, pd.Series):
+            pred_df['ANCHOR_TIMESTAMP'] = test_time
+
+        return TaskTable(
+            task_type=task_type,
+            context_df=context_df,
+            pred_df=pred_df,
+            entity_table_name=entity_table_names,
+            entity_column='ENTITY',
+            target_column='TARGET',
+            time_column='ANCHOR_TIMESTAMP' if isinstance(
+                train_time, pd.Series) else TaskTable.ENTITY_TIME,
+        )
+
+    def _get_context(
+        self,
+        task: TaskTable,
+        run_mode: RunMode | str = RunMode.FAST,
+        num_neighbors: list[int] | None = None,
+        exclude_cols_dict: dict[str, list[str]] | None = None,
+        top_k: int | None = None,
+    ) -> Context:
+
+        if num_neighbors is None:
+            key = RunMode.FAST if task.task_type.is_link_pred else run_mode
+            num_neighbors = _DEFAULT_NUM_NEIGHBORS[key][:2]
+
+        if len(num_neighbors) > 6:
+            raise ValueError(f"Cannot predict on subgraphs with more than 6 "
+                             f"hops (got {len(num_neighbors)}). Reduce the "
+                             f"number of hops and try again. Please create a "
+                             f"feature request at "
+                             f"'https://github.com/kumo-ai/kumo-rfm' if you "
+                             f"must go beyond this for your use-case.")
+
         # Exclude the entity anchor time from the feature set to prevent
         # running out-of-distribution between in-context and test examples:
-        exclude_cols_dict = query.get_exclude_cols_dict()
-        if entity_table_names[0] in self._sampler.time_column_dict:
-            if entity_table_names[0] not in exclude_cols_dict:
-                exclude_cols_dict[entity_table_names[0]] = []
-            time_column = self._sampler.time_column_dict[entity_table_names[0]]
-            exclude_cols_dict[entity_table_names[0]].append(time_column)
+        exclude_cols_dict = exclude_cols_dict or {}
+        if task.entity_table_name in self._sampler.time_column_dict:
+            if task.entity_table_name not in exclude_cols_dict:
+                exclude_cols_dict[task.entity_table_name] = []
+            time_col = self._sampler.time_column_dict[task.entity_table_name]
+            exclude_cols_dict[task.entity_table_name].append(time_col)
+
+        entity_pkey = pd.concat([
+            task._context_df[task._entity_column],
+            task._pred_df[task._entity_column],
+        ], axis=0, ignore_index=True)
+
+        if task.use_entity_time:
+            if task.entity_table_name not in self._sampler.time_column_dict:
+                raise ValueError(f"The given anchor time requires the entity "
+                                 f"table '{task.entity_table_name}' to have a "
+                                 f"time column")
+            anchor_time = 'entity'
+        elif task._time_column is not None:
+            anchor_time = pd.concat([
+                task._context_df[task._time_column],
+                task._pred_df[task._time_column],
+            ], axis=0, ignore_index=True)
+        else:
+            anchor_time = pd.Series(self._get_default_anchor_time()).repeat(
+                len(entity_pkey)).reset_index(drop=True)
 
         subgraph = self._sampler.sample_subgraph(
-            entity_table_names=entity_table_names,
-            entity_pkey=pd.concat(
-                [train_pkey, test_pkey],
-                axis=0,
-                ignore_index=True,
-            ),
-            anchor_time=pd.concat(
-                [train_time, test_time],
-                axis=0,
-                ignore_index=True,
-            ) if isinstance(train_time, pd.Series) else 'entity',
+            entity_table_names=task.entity_table_names,
+            entity_pkey=entity_pkey,
+            anchor_time=anchor_time,
             num_neighbors=num_neighbors,
             exclude_cols_dict=exclude_cols_dict,
         )
@@ -1110,19 +1348,26 @@
                              f"'https://github.com/kumo-ai/kumo-rfm' if you "
                             f"must go beyond this for your use-case.")
 
+        if (task.task_type.is_link_pred
+                and task.entity_table_names[-1] not in subgraph.table_dict):
+            raise ValueError("Cannot perform link prediction on subgraphs "
+                             "without any historical target entities. Please "
+                             "increase the number of hops and try again.")
+
         return Context(
-            task_type=task_type,
-            entity_table_names=entity_table_names,
+            task_type=task.task_type,
+            entity_table_names=task.entity_table_names,
             subgraph=subgraph,
-            y_train=y_train,
-            y_test=y_test if evaluate else None,
-            top_k=query.top_k,
+            y_train=task._context_df[task.target_column.name],
+            y_test=task._pred_df[task.target_column.name]
+            if task.evaluate else None,
+            top_k=top_k,
             step_size=None,
         )
 
     @staticmethod
     def _validate_metrics(
-        metrics: List[str],
+        metrics: list[str],
         task_type: TaskType,
     ) -> None:
 
@@ -1179,7 +1424,7 @@
                     f"'https://github.com/kumo-ai/kumo-rfm'.")
 
 
-def format_value(value: Union[int, float]) -> str:
+def format_value(value: int | float) -> str:
     if value == int(value):
         return f'{int(value):,}'
     if abs(value) >= 1000: