kumoai 2.14.0.dev202512211732__cp313-cp313-win_amd64.whl → 2.15.0.dev202601151732__cp313-cp313-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (43)
  1. kumoai/__init__.py +23 -26
  2. kumoai/_version.py +1 -1
  3. kumoai/client/client.py +6 -0
  4. kumoai/client/jobs.py +26 -0
  5. kumoai/connector/utils.py +21 -7
  6. kumoai/experimental/rfm/__init__.py +24 -22
  7. kumoai/experimental/rfm/backend/local/graph_store.py +12 -21
  8. kumoai/experimental/rfm/backend/local/sampler.py +0 -3
  9. kumoai/experimental/rfm/backend/local/table.py +24 -25
  10. kumoai/experimental/rfm/backend/snow/sampler.py +235 -80
  11. kumoai/experimental/rfm/backend/snow/table.py +146 -70
  12. kumoai/experimental/rfm/backend/sqlite/sampler.py +196 -89
  13. kumoai/experimental/rfm/backend/sqlite/table.py +85 -55
  14. kumoai/experimental/rfm/base/__init__.py +6 -9
  15. kumoai/experimental/rfm/base/column.py +95 -11
  16. kumoai/experimental/rfm/base/expression.py +44 -0
  17. kumoai/experimental/rfm/base/mapper.py +69 -0
  18. kumoai/experimental/rfm/base/sampler.py +28 -18
  19. kumoai/experimental/rfm/base/source.py +1 -1
  20. kumoai/experimental/rfm/base/sql_sampler.py +320 -19
  21. kumoai/experimental/rfm/base/table.py +256 -109
  22. kumoai/experimental/rfm/base/utils.py +36 -0
  23. kumoai/experimental/rfm/graph.py +130 -110
  24. kumoai/experimental/rfm/infer/dtype.py +7 -2
  25. kumoai/experimental/rfm/infer/multicategorical.py +1 -1
  26. kumoai/experimental/rfm/infer/time_col.py +4 -2
  27. kumoai/experimental/rfm/relbench.py +76 -0
  28. kumoai/experimental/rfm/rfm.py +540 -306
  29. kumoai/experimental/rfm/task_table.py +292 -0
  30. kumoai/kumolib.cp313-win_amd64.pyd +0 -0
  31. kumoai/pquery/training_table.py +16 -2
  32. kumoai/testing/snow.py +3 -3
  33. kumoai/trainer/distilled_trainer.py +175 -0
  34. kumoai/utils/display.py +87 -0
  35. kumoai/utils/progress_logger.py +15 -2
  36. kumoai/utils/sql.py +2 -2
  37. {kumoai-2.14.0.dev202512211732.dist-info → kumoai-2.15.0.dev202601151732.dist-info}/METADATA +2 -2
  38. {kumoai-2.14.0.dev202512211732.dist-info → kumoai-2.15.0.dev202601151732.dist-info}/RECORD +41 -36
  39. kumoai/experimental/rfm/base/column_expression.py +0 -50
  40. kumoai/experimental/rfm/base/sql_table.py +0 -229
  41. {kumoai-2.14.0.dev202512211732.dist-info → kumoai-2.15.0.dev202601151732.dist-info}/WHEEL +0 -0
  42. {kumoai-2.14.0.dev202512211732.dist-info → kumoai-2.15.0.dev202601151732.dist-info}/licenses/LICENSE +0 -0
  43. {kumoai-2.14.0.dev202512211732.dist-info → kumoai-2.15.0.dev202601151732.dist-info}/top_level.txt +0 -0
@@ -1,4 +1,5 @@
 import json
+import math
 import time
 import warnings
 from collections import defaultdict
@@ -7,7 +8,6 @@ from contextlib import contextmanager
 from dataclasses import dataclass, replace
 from typing import Any, Literal, overload
 
-import numpy as np
 import pandas as pd
 from kumoapi.model_plan import RunMode
 from kumoapi.pquery import QueryType, ValidatedPredictiveQuery
@@ -27,31 +27,37 @@ from kumoapi.rfm import (
 )
 from kumoapi.task import TaskType
 from kumoapi.typing import AggregationType, Stype
+from rich.console import Console
+from rich.markdown import Markdown
 
-from kumoai import in_notebook, in_snowflake_notebook
+from kumoai import in_notebook
 from kumoai.client.rfm import RFMAPI
 from kumoai.exceptions import HTTPException
-from kumoai.experimental.rfm import Graph
+from kumoai.experimental.rfm import Graph, TaskTable
 from kumoai.experimental.rfm.base import DataBackend, Sampler
 from kumoai.mixin import CastMixin
-from kumoai.utils import ProgressLogger
+from kumoai.utils import ProgressLogger, display
 
 _RANDOM_SEED = 42
 
 _MAX_PRED_SIZE: dict[TaskType, int] = defaultdict(lambda: 1_000)
 _MAX_PRED_SIZE[TaskType.TEMPORAL_LINK_PREDICTION] = 200
 
+_MAX_TEST_SIZE: dict[TaskType, int] = defaultdict(lambda: 2_000)
+_MAX_TEST_SIZE[TaskType.TEMPORAL_LINK_PREDICTION] = 400
+
 _MAX_CONTEXT_SIZE = {
     RunMode.DEBUG: 100,
     RunMode.FAST: 1_000,
     RunMode.NORMAL: 5_000,
     RunMode.BEST: 10_000,
 }
-_MAX_TEST_SIZE = {  # Share test set size across run modes for fair comparison:
-    RunMode.DEBUG: 100,
-    RunMode.FAST: 2_000,
-    RunMode.NORMAL: 2_000,
-    RunMode.BEST: 2_000,
+
+_DEFAULT_NUM_NEIGHBORS = {
    RunMode.DEBUG: [16, 16, 4, 4, 1, 1],
    RunMode.FAST: [32, 32, 8, 8, 4, 4],
    RunMode.NORMAL: [64, 64, 8, 8, 4, 4],
    RunMode.BEST: [64, 64, 8, 8, 4, 4],
 }
 
 _MAX_SIZE = 30 * 1024 * 1024
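
The consolidated `_DEFAULT_NUM_NEIGHBORS` table above replaces the per-run-mode fan-out lists that were previously hard-coded inside `_get_context` (removed in a later hunk); callers slice it by the number of hops. A small illustration of the lookup, written the way `predict_task` performs it below:

    # Two hops under RunMode.FAST sample 32 neighbors at each hop:
    num_neighbors = _DEFAULT_NUM_NEIGHBORS[RunMode.FAST][:2]  # [32, 32]
    # Link prediction always uses the FAST fan-out regardless of run mode
    # (see the `key = RunMode.FAST if ...` lines in later hunks).
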
@@ -102,25 +108,20 @@ class Explanation:
     def __repr__(self) -> str:
         return str((self.prediction, self.summary))
 
+    def __str__(self) -> str:
+        console = Console(soft_wrap=True)
+        with console.capture() as cap:
+            console.print(display.to_rich_table(self.prediction))
+            console.print(Markdown(self.summary))
+        return cap.get()[:-1]
+
     def print(self) -> None:
         r"""Prints the explanation."""
-        if in_snowflake_notebook():
-            import streamlit as st
-            st.dataframe(self.prediction, hide_index=True)
-            st.markdown(self.summary)
-        elif in_notebook():
-            from IPython.display import Markdown, display
-            try:
-                if hasattr(self.prediction.style, 'hide'):
-                    display(self.prediction.hide(axis='index'))  # pandas=2
-                else:
-                    display(self.prediction.hide_index())  # pandas <1.3
-            except ImportError:
-                print(self.prediction.to_string(index=False))  # missing jinja2
-            display(Markdown(self.summary))
+        if in_notebook():
+            display.dataframe(self.prediction)
+            display.message(self.summary)
         else:
-            print(self.prediction.to_string(index=False))
-            print(self.summary)
+            print(self)
 
     def _ipython_display_(self) -> None:
         self.print()
@@ -191,7 +192,7 @@ class KumoRFM:
         self._client: RFMAPI | None = None
 
         self._batch_size: int | Literal['max'] | None = None
-        self.num_retries: int = 0
+        self._num_retries: int = 0
 
     @property
     def _api_client(self) -> RFMAPI:
@@ -205,6 +206,30 @@ class KumoRFM:
     def __repr__(self) -> str:
         return f'{self.__class__.__name__}()'
 
+    @contextmanager
+    def retry(
+        self,
+        num_retries: int = 1,
+    ) -> Generator[None, None, None]:
+        """Context manager to retry failed queries due to unexpected server
+        issues.
+
+        .. code-block:: python
+
+            with model.retry(num_retries=1):
+                df = model.predict(query, indices=...)
+
+        Args:
+            num_retries: The maximum number of retries.
+        """
+        if num_retries < 0:
+            raise ValueError(f"'num_retries' must be greater than or equal to "
+                             f"zero (got {num_retries})")
+
+        self._num_retries = num_retries
+        yield
+        self._num_retries = 0
+
     @contextmanager
     def batch_mode(
         self,
@@ -228,15 +253,10 @@ class KumoRFM:
             raise ValueError(f"'batch_size' must be greater than zero "
                              f"(got {batch_size})")
 
-        if num_retries < 0:
-            raise ValueError(f"'num_retries' must be greater than or equal to "
-                             f"zero (got {num_retries})")
-
         self._batch_size = batch_size
-        self.num_retries = num_retries
-        yield
+        with self.retry(self._num_retries or num_retries):
+            yield
         self._batch_size = None
-        self.num_retries = 0
 
     @overload
     def predict(
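
Taken together, the two hunks above turn the old `num_retries` bookkeeping of `batch_mode` into a standalone `retry` context manager, which `batch_mode` now delegates to via `self.retry(self._num_retries or num_retries)`. A minimal usage sketch; `model`, `query`, and `ids` are hypothetical stand-ins:

    with model.retry(num_retries=2):              # retry transient failures
        with model.batch_mode(batch_size=1_000):  # predict in batches
            df = model.predict(query, indices=ids)
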
@@ -276,6 +296,25 @@
     ) -> Explanation:
         pass
 
+    @overload
+    def predict(
+        self,
+        query: str,
+        indices: list[str] | list[float] | list[int] | None = None,
+        *,
+        explain: bool | ExplainConfig | dict[str, Any] = False,
+        anchor_time: pd.Timestamp | Literal['entity'] | None = None,
+        context_anchor_time: pd.Timestamp | None = None,
+        run_mode: RunMode | str = RunMode.FAST,
+        num_neighbors: list[int] | None = None,
+        num_hops: int = 2,
+        max_pq_iterations: int = 10,
+        random_seed: int | None = _RANDOM_SEED,
+        verbose: bool | ProgressLogger = True,
+        use_prediction_time: bool = False,
+    ) -> pd.DataFrame | Explanation:
+        pass
+
     def predict(
         self,
         query: str,
@@ -299,8 +338,7 @@
             indices: The entity primary keys to predict for. Will override the
                 indices given as part of the predictive query. Predictions will
                 be generated for all indices, independent of whether they
-                fulfill entity filter constraints. To pre-filter entities, use
-                :meth:`~KumoRFM.is_valid_entity`.
+                fulfill entity filter constraints.
             explain: Configuration for explainability.
                 If set to ``True``, will additionally explain the prediction.
                 Passing in an :class:`ExplainConfig` instance provides control
@@ -333,18 +371,152 @@
             If ``explain`` is provided, returns an :class:`Explanation` object
             containing the prediction, summary, and details.
         """
-        explain_config: ExplainConfig | None = None
-        if explain is True:
-            explain_config = ExplainConfig()
-        elif explain is not False:
-            explain_config = ExplainConfig._cast(explain)
-
         query_def = self._parse_query(query)
-        query_str = query_def.to_string()
 
+        if indices is None:
+            if query_def.rfm_entity_ids is None:
+                raise ValueError("Cannot find entities to predict for. Please "
+                                 "pass them via `predict(query, indices=...)`")
+            indices = query_def.get_rfm_entity_id_list()
+        query_def = replace(
+            query_def,
+            for_each='FOR EACH',
+            rfm_entity_ids=None,
+        )
+
+        if not isinstance(verbose, ProgressLogger):
+            query_repr = query_def.to_string(rich=True, exclude_predict=True)
+            if explain is not False:
+                msg = f'[bold]EXPLAIN[/bold] {query_repr}'
+            else:
+                msg = f'[bold]PREDICT[/bold] {query_repr}'
+            verbose = ProgressLogger.default(msg=msg, verbose=verbose)
+
+        with verbose as logger:
+            task_table = self._get_task_table(
+                query=query_def,
+                indices=indices,
+                anchor_time=anchor_time,
+                context_anchor_time=context_anchor_time,
+                run_mode=run_mode,
+                max_pq_iterations=max_pq_iterations,
+                random_seed=random_seed,
+                logger=logger,
+            )
+            task_table._query = query_def.to_string()
+
+            return self.predict_task(
+                task_table,
+                explain=explain,
+                run_mode=run_mode,
+                num_neighbors=num_neighbors,
+                num_hops=num_hops,
+                verbose=verbose,
+                exclude_cols_dict=query_def.get_exclude_cols_dict(),
+                use_prediction_time=use_prediction_time,
+                top_k=query_def.top_k,
+            )
+
+    @overload
+    def predict_task(
+        self,
+        task: TaskTable,
+        *,
+        explain: Literal[False] = False,
+        run_mode: RunMode | str = RunMode.FAST,
+        num_neighbors: list[int] | None = None,
+        num_hops: int = 2,
+        verbose: bool | ProgressLogger = True,
+        exclude_cols_dict: dict[str, list[str]] | None = None,
+        use_prediction_time: bool = False,
+        top_k: int | None = None,
+    ) -> pd.DataFrame:
+        pass
+
+    @overload
+    def predict_task(
+        self,
+        task: TaskTable,
+        *,
+        explain: Literal[True] | ExplainConfig | dict[str, Any],
+        run_mode: RunMode | str = RunMode.FAST,
+        num_neighbors: list[int] | None = None,
+        num_hops: int = 2,
+        verbose: bool | ProgressLogger = True,
+        exclude_cols_dict: dict[str, list[str]] | None = None,
+        use_prediction_time: bool = False,
+        top_k: int | None = None,
+    ) -> Explanation:
+        pass
+
+    @overload
+    def predict_task(
+        self,
+        task: TaskTable,
+        *,
+        explain: bool | ExplainConfig | dict[str, Any] = False,
+        run_mode: RunMode | str = RunMode.FAST,
+        num_neighbors: list[int] | None = None,
+        num_hops: int = 2,
+        verbose: bool | ProgressLogger = True,
+        exclude_cols_dict: dict[str, list[str]] | None = None,
+        use_prediction_time: bool = False,
+        top_k: int | None = None,
+    ) -> pd.DataFrame | Explanation:
+        pass
+
+    def predict_task(
+        self,
+        task: TaskTable,
+        *,
+        explain: bool | ExplainConfig | dict[str, Any] = False,
+        run_mode: RunMode | str = RunMode.FAST,
+        num_neighbors: list[int] | None = None,
+        num_hops: int = 2,
+        verbose: bool | ProgressLogger = True,
+        exclude_cols_dict: dict[str, list[str]] | None = None,
+        use_prediction_time: bool = False,
+        top_k: int | None = None,
+    ) -> pd.DataFrame | Explanation:
+        """Returns predictions for a custom task specification.
+
+        Args:
+            task: The custom :class:`TaskTable`.
+            explain: Configuration for explainability.
+                If set to ``True``, will additionally explain the prediction.
+                Passing in an :class:`ExplainConfig` instance provides control
+                over which parts of explanation are generated.
+                Explainability is currently only supported for single entity
+                predictions with ``run_mode="FAST"``.
+            run_mode: The :class:`RunMode` for the query.
+            num_neighbors: The number of neighbors to sample for each hop.
+                If specified, the ``num_hops`` option will be ignored.
+            num_hops: The number of hops to sample when generating the context.
+            verbose: Whether to print verbose output.
+            exclude_cols_dict: Any column in any table to exclude from the
+                model input.
+            use_prediction_time: Whether to use the anchor timestamp as an
+                additional feature during prediction. This is typically
+                beneficial for time series forecasting tasks.
+            top_k: The number of predictions to return per entity.
+
+        Returns:
+            The predictions as a :class:`pandas.DataFrame`.
+            If ``explain`` is provided, returns an :class:`Explanation` object
+            containing the prediction, summary, and details.
+        """
         if num_hops != 2 and num_neighbors is not None:
             warnings.warn(f"Received custom 'num_neighbors' option; ignoring "
                           f"custom 'num_hops={num_hops}' option")
+        if num_neighbors is None:
+            key = RunMode.FAST if task.task_type.is_link_pred else run_mode
+            num_neighbors = _DEFAULT_NUM_NEIGHBORS[key][:num_hops]
+
+        explain_config: ExplainConfig | None = None
+        if explain is True:
+            explain_config = ExplainConfig()
+        elif explain is not False:
+            explain_config = ExplainConfig._cast(explain)
 
         if explain_config is not None and run_mode in {
             RunMode.NORMAL, RunMode.BEST
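
As the hunk above shows, `predict` is now a thin wrapper: it resolves entity indices, builds a `TaskTable` via `_get_task_table`, and forwards to the new public `predict_task` method. A sketch of calling `predict_task` directly, assuming `model` is a `KumoRFM` instance and `task` a prebuilt `TaskTable` (hypothetical names):

    # With explain=True, the overloads above type the result as an
    # `Explanation` (prediction, summary, details):
    explanation = model.predict_task(task, explain=True, run_mode='FAST')
    print(explanation.summary)
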
@@ -353,83 +525,82 @@
                           f"run mode 'FAST' (got '{run_mode}'). Provided run "
                           f"mode has been reset. Please lower the run mode to "
                           f"suppress this warning.")
+            run_mode = RunMode.FAST
 
-        if indices is None:
-            if query_def.rfm_entity_ids is None:
-                raise ValueError("Cannot find entities to predict for. Please "
-                                 "pass them via `predict(query, indices=...)`")
-            indices = query_def.get_rfm_entity_id_list()
-        else:
-            query_def = replace(query_def, rfm_entity_ids=None)
-
-        if len(indices) == 0:
-            raise ValueError("At least one entity is required")
-
-        if explain_config is not None and len(indices) > 1:
-            raise ValueError(
-                f"Cannot explain predictions for more than a single entity "
-                f"(got {len(indices)})")
-
-        query_repr = query_def.to_string(rich=True, exclude_predict=True)
-        if explain_config is not None:
-            msg = f'[bold]EXPLAIN[/bold] {query_repr}'
-        else:
-            msg = f'[bold]PREDICT[/bold] {query_repr}'
+        if explain_config is not None and task.num_prediction_examples > 1:
+            raise ValueError(f"Cannot explain predictions for more than a "
+                             f"single entity "
+                             f"(got {task.num_prediction_examples:,})")
 
         if not isinstance(verbose, ProgressLogger):
+            if task.task_type == TaskType.BINARY_CLASSIFICATION:
+                task_type_repr = 'binary classification'
+            elif task.task_type == TaskType.MULTICLASS_CLASSIFICATION:
+                task_type_repr = 'multi-class classification'
+            elif task.task_type == TaskType.REGRESSION:
+                task_type_repr = 'regression'
+            elif task.task_type == TaskType.TEMPORAL_LINK_PREDICTION:
+                task_type_repr = 'link prediction'
+            else:
+                task_type_repr = str(task.task_type)
+
+            if explain_config is not None:
+                msg = f"Explaining {task_type_repr} task"
+            else:
+                msg = f"Predicting {task_type_repr} task"
             verbose = ProgressLogger.default(msg=msg, verbose=verbose)
 
         with verbose as logger:
-
-            batch_size: int | None = None
-            if self._batch_size == 'max':
-                task_type = self._get_task_type(
-                    query=query_def,
-                    edge_types=self._sampler.edge_types,
-                )
-                batch_size = _MAX_PRED_SIZE[task_type]
+            if task.num_context_examples > _MAX_CONTEXT_SIZE[run_mode]:
+                logger.log(f"Sub-sampled {_MAX_CONTEXT_SIZE[run_mode]:,} "
+                           f"out of {task.num_context_examples:,} in-context "
+                           f"examples")
+                task = task.narrow_context(0, _MAX_CONTEXT_SIZE[run_mode])
+
+            if self._batch_size is None:
+                batch_size = task.num_prediction_examples
+            elif self._batch_size == 'max':
+                batch_size = _MAX_PRED_SIZE[task.task_type]
             else:
                 batch_size = self._batch_size
 
-            if batch_size is not None:
-                offsets = range(0, len(indices), batch_size)
-                batches = [indices[step:step + batch_size] for step in offsets]
-            else:
-                batches = [indices]
+            if batch_size > _MAX_PRED_SIZE[task.task_type]:
+                raise ValueError(f"Cannot predict for more than "
+                                 f"{_MAX_PRED_SIZE[task.task_type]:,} "
+                                 f"entities at once (got {batch_size:,}). Use "
+                                 f"`KumoRFM.batch_mode` to process entities "
+                                 f"in batches with a sufficient batch size.")
 
-            if len(batches) > 1:
-                logger.log(f"Splitting {len(indices):,} entities into "
-                           f"{len(batches):,} batches of size {batch_size:,}")
+            if task.num_prediction_examples > batch_size:
+                num = math.ceil(task.num_prediction_examples / batch_size)
+                logger.log(f"Splitting {task.num_prediction_examples:,} "
+                           f"entities into {num:,} batches of size "
+                           f"{batch_size:,}")
 
             predictions: list[pd.DataFrame] = []
             summary: str | None = None
             details: Explanation | None = None
-            for i, batch in enumerate(batches):
-                # TODO Re-use the context for subsequent predictions.
+            for start in range(0, task.num_prediction_examples, batch_size):
                 context = self._get_context(
-                    query=query_def,
-                    indices=batch,
-                    anchor_time=anchor_time,
-                    context_anchor_time=context_anchor_time,
-                    run_mode=RunMode(run_mode),
+                    task=task.narrow_prediction(start, length=batch_size),
+                    run_mode=run_mode,
                     num_neighbors=num_neighbors,
-                    num_hops=num_hops,
-                    max_pq_iterations=max_pq_iterations,
-                    evaluate=False,
-                    random_seed=random_seed,
-                    logger=logger if i == 0 else None,
+                    exclude_cols_dict=exclude_cols_dict,
+                    top_k=top_k,
                 )
+                context.y_test = None
+
                 request = RFMPredictRequest(
                     context=context,
                     run_mode=RunMode(run_mode),
-                    query=query_str,
+                    query=task._query,
                     use_prediction_time=use_prediction_time,
                 )
                 with warnings.catch_warnings():
                     warnings.filterwarnings('ignore', message='gencode')
                     request_msg = request.to_protobuf()
                     _bytes = request_msg.SerializeToString()
-                if i == 0:
+                if start == 0:
                     logger.log(f"Generated context of size "
                                f"{len(_bytes) / (1024*1024):.2f}MB")
 
@@ -437,13 +608,11 @@
                     stats = Context.get_memory_stats(request_msg.context)
                     raise ValueError(_SIZE_LIMIT_MSG.format(stats=stats))
 
-                if i == 0 and len(batches) > 1:
-                    verbose.init_progress(
-                        total=len(batches),
-                        description='Predicting',
-                    )
+                if start == 0 and task.num_prediction_examples > batch_size:
+                    num = math.ceil(task.num_prediction_examples / batch_size)
+                    verbose.init_progress(total=num, description='Predicting')
 
-                for attempt in range(self.num_retries + 1):
+                for attempt in range(self._num_retries + 1):
                    try:
                        if explain_config is not None:
                            resp = self._api_client.explain(
@@ -459,7 +628,7 @@
                         # Cast 'ENTITY' to correct data type:
                         if 'ENTITY' in df:
                             table_dict = context.subgraph.table_dict
-                            table = table_dict[query_def.entity_table]
+                            table = table_dict[context.entity_table_names[0]]
                             ser = table.df[table.primary_key]
                             df['ENTITY'] = df['ENTITY'].astype(ser.dtype)
 
@@ -476,12 +645,12 @@
 
                         predictions.append(df)
 
-                        if len(batches) > 1:
+                        if task.num_prediction_examples > batch_size:
                             verbose.step()
 
                         break
                     except HTTPException as e:
-                        if attempt == self.num_retries:
+                        if attempt == self._num_retries:
                             try:
                                 msg = json.loads(e.detail)['detail']
                             except Exception:
@@ -511,51 +680,6 @@
 
         return prediction
 
-    def is_valid_entity(
-        self,
-        query: str,
-        indices: list[str] | list[float] | list[int] | None = None,
-        *,
-        anchor_time: pd.Timestamp | Literal['entity'] | None = None,
-    ) -> np.ndarray:
-        r"""Returns a mask that denotes which entities are valid for the
-        given predictive query, *i.e.*, which entities fulfill (temporal)
-        entity filter constraints.
-
-        Args:
-            query: The predictive query.
-            indices: The entity primary keys to predict for. Will override the
-                indices given as part of the predictive query.
-            anchor_time: The anchor timestamp for the prediction. If set to
-                ``None``, will use the maximum timestamp in the data.
-                If set to ``"entity"``, will use the timestamp of the entity.
-        """
-        query_def = self._parse_query(query)
-
-        if indices is None:
-            if query_def.rfm_entity_ids is None:
-                raise ValueError("Cannot find entities to predict for. Please "
-                                 "pass them via "
-                                 "`is_valid_entity(query, indices=...)`")
-            indices = query_def.get_rfm_entity_id_list()
-
-        if len(indices) == 0:
-            raise ValueError("At least one entity is required")
-
-        if anchor_time is None:
-            anchor_time = self._get_default_anchor_time(query_def)
-
-        if isinstance(anchor_time, pd.Timestamp):
-            self._validate_time(query_def, anchor_time, None, False)
-        else:
-            assert anchor_time == 'entity'
-            if query_def.entity_table not in self._sampler.time_column_dict:
-                raise ValueError(f"Anchor time 'entity' requires the entity "
-                                 f"table '{query_def.entity_table}' "
-                                 f"to have a time column.")
-
-        raise NotImplementedError
-
     def evaluate(
         self,
         query: str,
@@ -600,41 +724,120 @@
         Returns:
             The metrics as a :class:`pandas.DataFrame`
         """
-        query_def = self._parse_query(query)
+        query_def = replace(
+            self._parse_query(query),
+            for_each='FOR EACH',
+            rfm_entity_ids=None,
+        )
+
+        if not isinstance(verbose, ProgressLogger):
+            query_repr = query_def.to_string(rich=True, exclude_predict=True)
+            msg = f'[bold]EVALUATE[/bold] {query_repr}'
+            verbose = ProgressLogger.default(msg=msg, verbose=verbose)
+
+        with verbose as logger:
+            task_table = self._get_task_table(
+                query=query_def,
+                indices=None,
+                anchor_time=anchor_time,
+                context_anchor_time=context_anchor_time,
+                run_mode=run_mode,
+                max_pq_iterations=max_pq_iterations,
+                random_seed=random_seed,
+                logger=logger,
+            )
+
+            return self.evaluate_task(
+                task_table,
+                metrics=metrics,
+                run_mode=run_mode,
+                num_neighbors=num_neighbors,
+                num_hops=num_hops,
+                verbose=verbose,
+                exclude_cols_dict=query_def.get_exclude_cols_dict(),
+                use_prediction_time=use_prediction_time,
+            )
 
+    def evaluate_task(
+        self,
+        task: TaskTable,
+        *,
+        metrics: list[str] | None = None,
+        run_mode: RunMode | str = RunMode.FAST,
+        num_neighbors: list[int] | None = None,
+        num_hops: int = 2,
+        verbose: bool | ProgressLogger = True,
+        exclude_cols_dict: dict[str, list[str]] | None = None,
+        use_prediction_time: bool = False,
+    ) -> pd.DataFrame:
+        """Evaluates a custom task specification.
+
+        Args:
+            task: The custom :class:`TaskTable`.
+            metrics: The metrics to use.
+            run_mode: The :class:`RunMode` for the query.
+            num_neighbors: The number of neighbors to sample for each hop.
+                If specified, the ``num_hops`` option will be ignored.
+            num_hops: The number of hops to sample when generating the context.
+            verbose: Whether to print verbose output.
+            exclude_cols_dict: Any column in any table to exclude from the
+                model input.
+            use_prediction_time: Whether to use the anchor timestamp as an
+                additional feature during prediction. This is typically
+                beneficial for time series forecasting tasks.
+
+        Returns:
+            The metrics as a :class:`pandas.DataFrame`
+        """
         if num_hops != 2 and num_neighbors is not None:
             warnings.warn(f"Received custom 'num_neighbors' option; ignoring "
                           f"custom 'num_hops={num_hops}' option")
+        if num_neighbors is None:
+            key = RunMode.FAST if task.task_type.is_link_pred else run_mode
+            num_neighbors = _DEFAULT_NUM_NEIGHBORS[key][:num_hops]
 
-        if query_def.rfm_entity_ids is not None:
-            query_def = replace(
-                query_def,
-                rfm_entity_ids=None,
-            )
-
-        query_repr = query_def.to_string(rich=True, exclude_predict=True)
-        msg = f'[bold]EVALUATE[/bold] {query_repr}'
+        if metrics is not None and len(metrics) > 0:
+            self._validate_metrics(metrics, task.task_type)
+            metrics = list(dict.fromkeys(metrics))
 
         if not isinstance(verbose, ProgressLogger):
+            if task.task_type == TaskType.BINARY_CLASSIFICATION:
+                task_type_repr = 'binary classification'
+            elif task.task_type == TaskType.MULTICLASS_CLASSIFICATION:
+                task_type_repr = 'multi-class classification'
+            elif task.task_type == TaskType.REGRESSION:
+                task_type_repr = 'regression'
+            elif task.task_type == TaskType.TEMPORAL_LINK_PREDICTION:
+                task_type_repr = 'link prediction'
+            else:
+                task_type_repr = str(task.task_type)
+
+            msg = f"Evaluating {task_type_repr} task"
             verbose = ProgressLogger.default(msg=msg, verbose=verbose)
 
         with verbose as logger:
+            if task.num_context_examples > _MAX_CONTEXT_SIZE[run_mode]:
+                logger.log(f"Sub-sampled {_MAX_CONTEXT_SIZE[run_mode]:,} "
+                           f"out of {task.num_context_examples:,} in-context "
+                           f"examples")
+                task = task.narrow_context(0, _MAX_CONTEXT_SIZE[run_mode])
+
+            if task.num_prediction_examples > _MAX_TEST_SIZE[task.task_type]:
+                logger.log(f"Sub-sampled {_MAX_TEST_SIZE[task.task_type]:,} "
+                           f"out of {task.num_prediction_examples:,} test "
+                           f"examples")
+                task = task.narrow_prediction(
+                    start=0,
+                    length=_MAX_TEST_SIZE[task.task_type],
+                )
+
             context = self._get_context(
-                query=query_def,
-                indices=None,
-                anchor_time=anchor_time,
-                context_anchor_time=context_anchor_time,
-                run_mode=RunMode(run_mode),
+                task=task,
+                run_mode=run_mode,
                 num_neighbors=num_neighbors,
-                num_hops=num_hops,
-                max_pq_iterations=max_pq_iterations,
-                evaluate=True,
-                random_seed=random_seed,
-                logger=logger if verbose else None,
+                exclude_cols_dict=exclude_cols_dict,
             )
-            if metrics is not None and len(metrics) > 0:
-                self._validate_metrics(metrics, context.task_type)
-                metrics = list(dict.fromkeys(metrics))
+
             request = RFMEvaluateRequest(
                 context=context,
                 run_mode=RunMode(run_mode),
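
`evaluate` mirrors the same split as `predict`: it normalizes the query, builds a `TaskTable`, and forwards to `evaluate_task`, which validates and order-preservingly deduplicates `metrics` (via `dict.fromkeys`) and sub-samples oversized context and test sets before building the request. A sketch with hypothetical metric names:

    # One row per metric, returned as a pandas DataFrame:
    metrics_df = model.evaluate_task(task, metrics=['auroc', 'auprc'])
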
@@ -652,17 +855,23 @@
                 stats_msg = Context.get_memory_stats(request_msg.context)
                 raise ValueError(_SIZE_LIMIT_MSG.format(stats=stats_msg))
 
-            try:
-                resp = self._api_client.evaluate(request_bytes)
-            except HTTPException as e:
+            for attempt in range(self._num_retries + 1):
                 try:
-                    msg = json.loads(e.detail)['detail']
-                except Exception:
-                    msg = e.detail
-                raise RuntimeError(f"An unexpected exception occurred. "
-                                   f"Please create an issue at "
-                                   f"'https://github.com/kumo-ai/kumo-rfm'. "
-                                   f"{msg}") from None
+                    resp = self._api_client.evaluate(request_bytes)
+                    break
+                except HTTPException as e:
+                    if attempt == self._num_retries:
+                        try:
+                            msg = json.loads(e.detail)['detail']
+                        except Exception:
+                            msg = e.detail
+                        raise RuntimeError(
+                            f"An unexpected exception occurred. Please create "
+                            f"an issue at "
+                            f"'https://github.com/kumo-ai/kumo-rfm'. {msg}"
+                        ) from None
+
+                    time.sleep(2**attempt)  # 1s, 2s, 4s, 8s, ...
 
         return pd.DataFrame.from_dict(
             resp.metrics,
@@ -714,7 +923,7 @@
                              f"to have a time column")
 
        train, test = self._sampler.sample_target(
-            query=query,
+            query=query_def,
            num_train_examples=0,
            train_anchor_time=anchor_time,
            num_train_trials=0,
@@ -742,30 +951,34 @@
                          "`predict()` or `evaluate()` methods to perform "
                          "predictions or evaluations.")
 
-        try:
-            request = RFMParseQueryRequest(
-                query=query,
-                graph_definition=self._graph_def,
-            )
+        request = RFMParseQueryRequest(
+            query=query,
+            graph_definition=self._graph_def,
+        )
+
+        for attempt in range(self._num_retries + 1):
+            try:
+                resp = self._api_client.parse_query(request)
+                break
+            except HTTPException as e:
+                if attempt == self._num_retries:
+                    try:
+                        msg = json.loads(e.detail)['detail']
+                    except Exception:
+                        msg = e.detail
+                    raise ValueError(f"Failed to parse query '{query}'. {msg}")
 
-            resp = self._api_client.parse_query(request)
+                time.sleep(2**attempt)  # 1s, 2s, 4s, 8s, ...
 
-            if len(resp.validation_response.warnings) > 0:
-                msg = '\n'.join([
-                    f'{i+1}. {warning.title}: {warning.message}' for i, warning
-                    in enumerate(resp.validation_response.warnings)
-                ])
-                warnings.warn(f"Encountered the following warnings during "
-                              f"parsing:\n{msg}")
+        if len(resp.validation_response.warnings) > 0:
+            msg = '\n'.join([
+                f'{i+1}. {warning.title}: {warning.message}'
+                for i, warning in enumerate(resp.validation_response.warnings)
+            ])
+            warnings.warn(f"Encountered the following warnings during "
+                          f"parsing:\n{msg}")
 
-            return resp.query
-        except HTTPException as e:
-            try:
-                msg = json.loads(e.detail)['detail']
-            except Exception:
-                msg = e.detail
-            raise ValueError(f"Failed to parse query '{query}'. "
-                             f"{msg}") from None
+        return resp.query
 
     @staticmethod
     def _get_task_type(
@@ -809,16 +1022,15 @@
 
     def _get_default_anchor_time(
         self,
-        query: ValidatedPredictiveQuery,
+        query: ValidatedPredictiveQuery | None = None,
     ) -> pd.Timestamp:
-        if query.query_type == QueryType.TEMPORAL:
+        if query is not None and query.query_type == QueryType.TEMPORAL:
             aggr_table_names = [
                 aggr._get_target_column_name().split('.')[0]
                 for aggr in query.get_all_target_aggregations()
             ]
             return self._sampler.get_max_time(aggr_table_names)
 
-        assert query.query_type == QueryType.STATIC
         return self._sampler.get_max_time()
 
     def _validate_time(
@@ -832,8 +1044,16 @@
         if len(self._sampler.time_column_dict) == 0:
             return  # Graph without timestamps
 
-        min_time = self._sampler.get_min_time()
-        max_time = self._sampler.get_max_time()
+        if query.query_type == QueryType.TEMPORAL:
+            aggr_table_names = [
+                aggr._get_target_column_name().split('.')[0]
+                for aggr in query.get_all_target_aggregations()
+            ]
+            min_time = self._sampler.get_min_time(aggr_table_names)
+            max_time = self._sampler.get_max_time(aggr_table_names)
+        else:
+            min_time = self._sampler.get_min_time()
+            max_time = self._sampler.get_max_time()
 
         if anchor_time < min_time:
             raise ValueError(f"Anchor timestamp '{anchor_time}' is before "
@@ -888,40 +1108,26 @@
                 f"Anchor timestamp for evaluation is after the latest "
                 f"supported timestamp '{max_time - end_offset}'.")
 
-    def _get_context(
+    def _get_task_table(
         self,
         query: ValidatedPredictiveQuery,
         indices: list[str] | list[float] | list[int] | None,
-        anchor_time: pd.Timestamp | Literal['entity'] | None,
-        context_anchor_time: pd.Timestamp | None,
-        run_mode: RunMode,
-        num_neighbors: list[int] | None,
-        num_hops: int,
-        max_pq_iterations: int,
-        evaluate: bool,
+        anchor_time: pd.Timestamp | Literal['entity'] | None = None,
+        context_anchor_time: pd.Timestamp | None = None,
+        run_mode: RunMode = RunMode.FAST,
+        max_pq_iterations: int = 10,
         random_seed: int | None = _RANDOM_SEED,
         logger: ProgressLogger | None = None,
-    ) -> Context:
-
-        if num_neighbors is not None:
-            num_hops = len(num_neighbors)
-
-        if num_hops < 0:
-            raise ValueError(f"'num_hops' must be non-negative "
-                             f"(got {num_hops})")
-        if num_hops > 6:
-            raise ValueError(f"Cannot predict on subgraphs with more than 6 "
-                             f"hops (got {num_hops}). Please reduce the "
-                             f"number of hops and try again. Please create a "
-                             f"feature request at "
-                             f"'https://github.com/kumo-ai/kumo-rfm' if you "
-                             f"must go beyond this for your use-case.")
+    ) -> TaskTable:
 
         task_type = self._get_task_type(
             query=query,
             edge_types=self._sampler.edge_types,
         )
 
+        num_train_examples = _MAX_CONTEXT_SIZE[run_mode]
+        num_test_examples = _MAX_TEST_SIZE[task_type] if indices is None else 0
+
         if logger is not None:
             if task_type == TaskType.BINARY_CLASSIFICATION:
                 task_type_repr = 'binary classification'
@@ -935,21 +1141,6 @@
                 task_type_repr = str(task_type)
             logger.log(f"Identified {query.query_type} {task_type_repr} task")
 
-        if task_type.is_link_pred and num_hops < 2:
-            raise ValueError(f"Cannot perform link prediction on subgraphs "
-                             f"with less than 2 hops (got {num_hops}) since "
-                             f"historical target entities need to be part of "
-                             f"the context. Please increase the number of "
-                             f"hops and try again.")
-
-        if num_neighbors is None:
-            if run_mode == RunMode.DEBUG:
-                num_neighbors = [16, 16, 4, 4, 1, 1][:num_hops]
-            elif run_mode == RunMode.FAST or task_type.is_link_pred:
-                num_neighbors = [32, 32, 8, 8, 4, 4][:num_hops]
-            else:
-                num_neighbors = [64, 64, 8, 8, 4, 4][:num_hops]
-
         if query.target_ast.date_offset_range is None:
             step_offset = pd.DateOffset(0)
         else:
@@ -958,8 +1149,7 @@
 
         if anchor_time is None:
             anchor_time = self._get_default_anchor_time(query)
-
-        if evaluate:
+        if num_test_examples > 0:
             anchor_time = anchor_time - end_offset
 
         if logger is not None:
@@ -973,7 +1163,6 @@
             else:
                 logger.log(f"Derived anchor time {anchor_time}")
 
-        assert anchor_time is not None
         if isinstance(anchor_time, pd.Timestamp):
             if context_anchor_time == 'entity':
                 raise ValueError("Anchor time 'entity' needs to be shared "
@@ -981,7 +1170,7 @@
             if context_anchor_time is None:
                 context_anchor_time = anchor_time - end_offset
             self._validate_time(query, anchor_time, context_anchor_time,
-                                evaluate)
+                                evaluate=num_test_examples > 0)
         else:
             assert anchor_time == 'entity'
             if query.query_type != QueryType.STATIC:
@@ -996,14 +1185,6 @@
                                  "for context and prediction examples")
             context_anchor_time = 'entity'
 
-        num_train_examples = _MAX_CONTEXT_SIZE[run_mode]
-        if evaluate:
-            num_test_examples = _MAX_TEST_SIZE[run_mode]
-            if task_type.is_link_pred:
-                num_test_examples = num_test_examples // 5
-        else:
-            num_test_examples = 0
-
         train, test = self._sampler.sample_target(
             query=query,
             num_train_examples=num_train_examples,
@@ -1014,39 +1195,32 @@
             num_test_trials=max_pq_iterations * num_test_examples,
             random_seed=random_seed,
         )
-        train_pkey, train_time, y_train = train
-        test_pkey, test_time, y_test = test
+        train_pkey, train_time, train_y = train
+        test_pkey, test_time, test_y = test
 
-        if evaluate and logger is not None:
+        if num_test_examples > 0 and logger is not None:
             if task_type == TaskType.BINARY_CLASSIFICATION:
-                pos = 100 * int((y_test > 0).sum()) / len(y_test)
-                msg = (f"Collected {len(y_test):,} test examples with "
+                pos = 100 * int((test_y > 0).sum()) / len(test_y)
+                msg = (f"Collected {len(test_y):,} test examples with "
                       f"{pos:.2f}% positive cases")
             elif task_type == TaskType.MULTICLASS_CLASSIFICATION:
-                msg = (f"Collected {len(y_test):,} test examples holding "
-                       f"{y_test.nunique()} classes")
+                msg = (f"Collected {len(test_y):,} test examples holding "
+                       f"{test_y.nunique()} classes")
             elif task_type == TaskType.REGRESSION:
-                _min, _max = float(y_test.min()), float(y_test.max())
-                msg = (f"Collected {len(y_test):,} test examples with targets "
+                _min, _max = float(test_y.min()), float(test_y.max())
+                msg = (f"Collected {len(test_y):,} test examples with targets "
                       f"between {format_value(_min)} and "
                       f"{format_value(_max)}")
             elif task_type == TaskType.TEMPORAL_LINK_PREDICTION:
-                num_rhs = y_test.explode().nunique()
-                msg = (f"Collected {len(y_test):,} test examples with "
+                num_rhs = test_y.explode().nunique()
+                msg = (f"Collected {len(test_y):,} test examples with "
                       f"{num_rhs:,} unique items")
             else:
                 raise NotImplementedError
             logger.log(msg)
 
-        if not evaluate:
+        if num_test_examples == 0:
             assert indices is not None
-            if len(indices) > _MAX_PRED_SIZE[task_type]:
-                raise ValueError(f"Cannot predict for more than "
-                                 f"{_MAX_PRED_SIZE[task_type]:,} entities at "
-                                 f"once (got {len(indices):,}). Use "
-                                 f"`KumoRFM.batch_mode` to process entities "
-                                 f"in batches")
-
             test_pkey = pd.Series(indices, dtype=train_pkey.dtype)
             if isinstance(anchor_time, pd.Timestamp):
                 test_time = pd.Series([anchor_time]).repeat(
@@ -1056,26 +1230,26 @@
 
         if logger is not None:
             if task_type == TaskType.BINARY_CLASSIFICATION:
-                pos = 100 * int((y_train > 0).sum()) / len(y_train)
-                msg = (f"Collected {len(y_train):,} in-context examples with "
+                pos = 100 * int((train_y > 0).sum()) / len(train_y)
+                msg = (f"Collected {len(train_y):,} in-context examples with "
                       f"{pos:.2f}% positive cases")
             elif task_type == TaskType.MULTICLASS_CLASSIFICATION:
-                msg = (f"Collected {len(y_train):,} in-context examples "
-                       f"holding {y_train.nunique()} classes")
+                msg = (f"Collected {len(train_y):,} in-context examples "
+                       f"holding {train_y.nunique()} classes")
             elif task_type == TaskType.REGRESSION:
-                _min, _max = float(y_train.min()), float(y_train.max())
-                msg = (f"Collected {len(y_train):,} in-context examples with "
+                _min, _max = float(train_y.min()), float(train_y.max())
+                msg = (f"Collected {len(train_y):,} in-context examples with "
                       f"targets between {format_value(_min)} and "
                       f"{format_value(_max)}")
             elif task_type == TaskType.TEMPORAL_LINK_PREDICTION:
-                num_rhs = y_train.explode().nunique()
-                msg = (f"Collected {len(y_train):,} in-context examples with "
+                num_rhs = train_y.explode().nunique()
+                msg = (f"Collected {len(train_y):,} in-context examples with "
                       f"{num_rhs:,} unique items")
             else:
                 raise NotImplementedError
             logger.log(msg)
 
-        entity_table_names: tuple[str, ...]
+        entity_table_names: tuple[str] | tuple[str, str]
         if task_type.is_link_pred:
             final_aggr = query.get_final_target_aggregation()
             assert final_aggr is not None
@@ -1089,27 +1263,80 @@
         else:
             entity_table_names = (query.entity_table, )
 
+        context_df = pd.DataFrame({'ENTITY': train_pkey, 'TARGET': train_y})
+        if isinstance(train_time, pd.Series):
+            context_df['ANCHOR_TIMESTAMP'] = train_time
+        pred_df = pd.DataFrame({'ENTITY': test_pkey})
+        if num_test_examples > 0:
+            pred_df['TARGET'] = test_y
+        if isinstance(test_time, pd.Series):
+            pred_df['ANCHOR_TIMESTAMP'] = test_time
+
+        return TaskTable(
+            task_type=task_type,
+            context_df=context_df,
+            pred_df=pred_df,
+            entity_table_name=entity_table_names,
+            entity_column='ENTITY',
+            target_column='TARGET',
+            time_column='ANCHOR_TIMESTAMP' if isinstance(
+                train_time, pd.Series) else TaskTable.ENTITY_TIME,
+        )
+
+    def _get_context(
+        self,
+        task: TaskTable,
+        run_mode: RunMode | str = RunMode.FAST,
+        num_neighbors: list[int] | None = None,
+        exclude_cols_dict: dict[str, list[str]] | None = None,
+        top_k: int | None = None,
+    ) -> Context:
+
+        if num_neighbors is None:
+            key = RunMode.FAST if task.task_type.is_link_pred else run_mode
+            num_neighbors = _DEFAULT_NUM_NEIGHBORS[key][:2]
+
+        if len(num_neighbors) > 6:
+            raise ValueError(f"Cannot predict on subgraphs with more than 6 "
+                             f"hops (got {len(num_neighbors)}). Reduce the "
+                             f"number of hops and try again. Please create a "
+                             f"feature request at "
+                             f"'https://github.com/kumo-ai/kumo-rfm' if you "
+                             f"must go beyond this for your use-case.")
+
         # Exclude the entity anchor time from the feature set to prevent
         # running out-of-distribution between in-context and test examples:
-        exclude_cols_dict = query.get_exclude_cols_dict()
-        if entity_table_names[0] in self._sampler.time_column_dict:
-            if entity_table_names[0] not in exclude_cols_dict:
-                exclude_cols_dict[entity_table_names[0]] = []
-            time_column = self._sampler.time_column_dict[entity_table_names[0]]
-            exclude_cols_dict[entity_table_names[0]].append(time_column)
+        exclude_cols_dict = exclude_cols_dict or {}
+        if task.entity_table_name in self._sampler.time_column_dict:
+            if task.entity_table_name not in exclude_cols_dict:
+                exclude_cols_dict[task.entity_table_name] = []
+            time_col = self._sampler.time_column_dict[task.entity_table_name]
+            exclude_cols_dict[task.entity_table_name].append(time_col)
+
+        entity_pkey = pd.concat([
+            task._context_df[task._entity_column],
+            task._pred_df[task._entity_column],
+        ], axis=0, ignore_index=True)
+
+        if task.use_entity_time:
+            if task.entity_table_name not in self._sampler.time_column_dict:
+                raise ValueError(f"The given anchor time requires the entity "
+                                 f"table '{task.entity_table_name}' to have a "
+                                 f"time column")
+            anchor_time = 'entity'
+        elif task._time_column is not None:
+            anchor_time = pd.concat([
+                task._context_df[task._time_column],
+                task._pred_df[task._time_column],
+            ], axis=0, ignore_index=True)
+        else:
+            anchor_time = pd.Series(self._get_default_anchor_time()).repeat(
+                (len(entity_pkey))).reset_index(drop=True)
 
         subgraph = self._sampler.sample_subgraph(
-            entity_table_names=entity_table_names,
-            entity_pkey=pd.concat(
-                [train_pkey, test_pkey],
-                axis=0,
-                ignore_index=True,
-            ),
-            anchor_time=pd.concat(
-                [train_time, test_time],
-                axis=0,
-                ignore_index=True,
-            ) if isinstance(train_time, pd.Series) else 'entity',
+            entity_table_names=task.entity_table_names,
+            entity_pkey=entity_pkey,
+            anchor_time=anchor_time,
             num_neighbors=num_neighbors,
             exclude_cols_dict=exclude_cols_dict,
         )
@@ -1121,13 +1348,20 @@
                              f"'https://github.com/kumo-ai/kumo-rfm' if you "
                              f"must go beyond this for your use-case.")
 
+        if (task.task_type.is_link_pred
+                and task.entity_table_names[-1] not in subgraph.table_dict):
+            raise ValueError("Cannot perform link prediction on subgraphs "
+                             "without any historical target entities. Please "
+                             "increase the number of hops and try again.")
+
         return Context(
-            task_type=task_type,
-            entity_table_names=entity_table_names,
+            task_type=task.task_type,
+            entity_table_names=task.entity_table_names,
             subgraph=subgraph,
-            y_train=y_train,
-            y_test=y_test if evaluate else None,
-            top_k=query.top_k,
+            y_train=task._context_df[task.target_column.name],
+            y_test=task._pred_df[task.target_column.name]
+            if task.evaluate else None,
+            top_k=top_k,
             step_size=None,
         )
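
The `TaskTable(...)` call site in `_get_task_table` above shows the shape a hand-built task needs: a labeled in-context frame, a prediction frame, and entity/target/time column names. A hedged sketch based solely on that call site; the table name and column contents are illustrative:

    import pandas as pd
    from kumoapi.task import TaskType
    from kumoai.experimental.rfm import TaskTable

    # In-context examples carry labels; prediction rows may omit 'TARGET':
    context_df = pd.DataFrame({
        'ENTITY': [1, 2, 3],
        'TARGET': [0, 1, 0],
        'ANCHOR_TIMESTAMP': pd.to_datetime(['2025-01-01'] * 3),
    })
    pred_df = pd.DataFrame({
        'ENTITY': [4, 5],
        'ANCHOR_TIMESTAMP': pd.to_datetime(['2025-01-02'] * 2),
    })

    task = TaskTable(
        task_type=TaskType.BINARY_CLASSIFICATION,
        context_df=context_df,
        pred_df=pred_df,
        entity_table_name=('users', ),  # hypothetical entity table
        entity_column='ENTITY',
        target_column='TARGET',
        time_column='ANCHOR_TIMESTAMP',
    )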