kumoai 2.14.0.dev202512191731__cp311-cp311-macosx_11_0_arm64.whl → 2.14.0.dev202601051732__cp311-cp311-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (36) hide show
  1. kumoai/__init__.py +23 -26
  2. kumoai/_version.py +1 -1
  3. kumoai/client/client.py +6 -0
  4. kumoai/client/jobs.py +24 -0
  5. kumoai/connector/utils.py +21 -7
  6. kumoai/experimental/rfm/__init__.py +24 -22
  7. kumoai/experimental/rfm/backend/local/graph_store.py +12 -21
  8. kumoai/experimental/rfm/backend/local/sampler.py +0 -3
  9. kumoai/experimental/rfm/backend/local/table.py +24 -25
  10. kumoai/experimental/rfm/backend/snow/sampler.py +106 -61
  11. kumoai/experimental/rfm/backend/snow/table.py +137 -64
  12. kumoai/experimental/rfm/backend/sqlite/sampler.py +127 -78
  13. kumoai/experimental/rfm/backend/sqlite/table.py +85 -55
  14. kumoai/experimental/rfm/base/__init__.py +6 -9
  15. kumoai/experimental/rfm/base/column.py +95 -11
  16. kumoai/experimental/rfm/base/expression.py +44 -0
  17. kumoai/experimental/rfm/base/sampler.py +5 -17
  18. kumoai/experimental/rfm/base/source.py +1 -1
  19. kumoai/experimental/rfm/base/sql_sampler.py +69 -9
  20. kumoai/experimental/rfm/base/table.py +258 -97
  21. kumoai/experimental/rfm/graph.py +106 -98
  22. kumoai/experimental/rfm/infer/dtype.py +4 -1
  23. kumoai/experimental/rfm/infer/multicategorical.py +1 -1
  24. kumoai/experimental/rfm/relbench.py +76 -0
  25. kumoai/experimental/rfm/rfm.py +394 -241
  26. kumoai/experimental/rfm/task_table.py +290 -0
  27. kumoai/trainer/distilled_trainer.py +175 -0
  28. kumoai/utils/display.py +51 -0
  29. kumoai/utils/progress_logger.py +13 -1
  30. {kumoai-2.14.0.dev202512191731.dist-info → kumoai-2.14.0.dev202601051732.dist-info}/METADATA +1 -1
  31. {kumoai-2.14.0.dev202512191731.dist-info → kumoai-2.14.0.dev202601051732.dist-info}/RECORD +34 -31
  32. kumoai/experimental/rfm/base/column_expression.py +0 -50
  33. kumoai/experimental/rfm/base/sql_table.py +0 -229
  34. {kumoai-2.14.0.dev202512191731.dist-info → kumoai-2.14.0.dev202601051732.dist-info}/WHEEL +0 -0
  35. {kumoai-2.14.0.dev202512191731.dist-info → kumoai-2.14.0.dev202601051732.dist-info}/licenses/LICENSE +0 -0
  36. {kumoai-2.14.0.dev202512191731.dist-info → kumoai-2.14.0.dev202601051732.dist-info}/top_level.txt +0 -0
@@ -1,4 +1,5 @@
1
1
  import json
2
+ import math
2
3
  import time
3
4
  import warnings
4
5
  from collections import defaultdict
@@ -28,30 +29,33 @@ from kumoapi.rfm import (
28
29
  from kumoapi.task import TaskType
29
30
  from kumoapi.typing import AggregationType, Stype
30
31
 
31
- from kumoai import in_notebook, in_snowflake_notebook
32
32
  from kumoai.client.rfm import RFMAPI
33
33
  from kumoai.exceptions import HTTPException
34
- from kumoai.experimental.rfm import Graph
34
+ from kumoai.experimental.rfm import Graph, TaskTable
35
35
  from kumoai.experimental.rfm.base import DataBackend, Sampler
36
36
  from kumoai.mixin import CastMixin
37
- from kumoai.utils import ProgressLogger
37
+ from kumoai.utils import ProgressLogger, display
38
38
 
39
39
  _RANDOM_SEED = 42
40
40
 
41
41
  _MAX_PRED_SIZE: dict[TaskType, int] = defaultdict(lambda: 1_000)
42
42
  _MAX_PRED_SIZE[TaskType.TEMPORAL_LINK_PREDICTION] = 200
43
43
 
44
+ _MAX_TEST_SIZE: dict[TaskType, int] = defaultdict(lambda: 2_000)
45
+ _MAX_TEST_SIZE[TaskType.TEMPORAL_LINK_PREDICTION] = 400
46
+
44
47
  _MAX_CONTEXT_SIZE = {
45
48
  RunMode.DEBUG: 100,
46
49
  RunMode.FAST: 1_000,
47
50
  RunMode.NORMAL: 5_000,
48
51
  RunMode.BEST: 10_000,
49
52
  }
50
- _MAX_TEST_SIZE = { # Share test set size across run modes for fair comparison:
51
- RunMode.DEBUG: 100,
52
- RunMode.FAST: 2_000,
53
- RunMode.NORMAL: 2_000,
54
- RunMode.BEST: 2_000,
53
+
54
+ _DEFAULT_NUM_NEIGHBORS = {
55
+ RunMode.DEBUG: [16, 16, 4, 4, 1, 1],
56
+ RunMode.FAST: [32, 32, 8, 8, 4, 4],
57
+ RunMode.NORMAL: [64, 64, 8, 8, 4, 4],
58
+ RunMode.BEST: [64, 64, 8, 8, 4, 4],
55
59
  }
56
60
 
57
61
  _MAX_SIZE = 30 * 1024 * 1024
@@ -104,23 +108,8 @@ class Explanation:
104
108
 
105
109
  def print(self) -> None:
106
110
  r"""Prints the explanation."""
107
- if in_snowflake_notebook():
108
- import streamlit as st
109
- st.dataframe(self.prediction, hide_index=True)
110
- st.markdown(self.summary)
111
- elif in_notebook():
112
- from IPython.display import Markdown, display
113
- try:
114
- if hasattr(self.prediction.style, 'hide'):
115
- display(self.prediction.hide(axis='index')) # pandas=2
116
- else:
117
- display(self.prediction.hide_index()) # pandas <1.3
118
- except ImportError:
119
- print(self.prediction.to_string(index=False)) # missing jinja2
120
- display(Markdown(self.summary))
121
- else:
122
- print(self.prediction.to_string(index=False))
123
- print(self.summary)
111
+ display.dataframe(self.prediction)
112
+ display.message(self.summary)
124
113
 
125
114
  def _ipython_display_(self) -> None:
126
115
  self.print()
@@ -333,18 +322,133 @@ class KumoRFM:
333
322
  If ``explain`` is provided, returns an :class:`Explanation` object
334
323
  containing the prediction, summary, and details.
335
324
  """
336
- explain_config: ExplainConfig | None = None
337
- if explain is True:
338
- explain_config = ExplainConfig()
339
- elif explain is not False:
340
- explain_config = ExplainConfig._cast(explain)
341
-
342
325
  query_def = self._parse_query(query)
343
- query_str = query_def.to_string()
344
326
 
327
+ if indices is None:
328
+ if query_def.rfm_entity_ids is None:
329
+ raise ValueError("Cannot find entities to predict for. Please "
330
+ "pass them via `predict(query, indices=...)`")
331
+ indices = query_def.get_rfm_entity_id_list()
332
+ else:
333
+ query_def = replace(query_def, rfm_entity_ids=None)
334
+
335
+ if not isinstance(verbose, ProgressLogger):
336
+ query_repr = query_def.to_string(rich=True, exclude_predict=True)
337
+ if explain is not False:
338
+ msg = f'[bold]EXPLAIN[/bold] {query_repr}'
339
+ else:
340
+ msg = f'[bold]PREDICT[/bold] {query_repr}'
341
+ verbose = ProgressLogger.default(msg=msg, verbose=verbose)
342
+
343
+ with verbose as logger:
344
+ task_table = self._get_task_table(
345
+ query=query_def,
346
+ indices=indices,
347
+ anchor_time=anchor_time,
348
+ context_anchor_time=context_anchor_time,
349
+ run_mode=run_mode,
350
+ max_pq_iterations=max_pq_iterations,
351
+ random_seed=random_seed,
352
+ logger=logger,
353
+ )
354
+ task_table._query = query_def.to_string() # type: ignore
355
+
356
+ return self.predict_task(
357
+ task_table,
358
+ explain=explain, # type: ignore
359
+ run_mode=run_mode,
360
+ num_neighbors=num_neighbors,
361
+ num_hops=num_hops,
362
+ verbose=verbose,
363
+ exclude_cols_dict=query_def.get_exclude_cols_dict(),
364
+ use_prediction_time=use_prediction_time,
365
+ top_k=query_def.top_k,
366
+ )
367
+
368
+ @overload
369
+ def predict_task(
370
+ self,
371
+ task: TaskTable,
372
+ *,
373
+ explain: Literal[False] = False,
374
+ run_mode: RunMode | str = RunMode.FAST,
375
+ num_neighbors: list[int] | None = None,
376
+ num_hops: int = 2,
377
+ verbose: bool | ProgressLogger = True,
378
+ exclude_cols_dict: dict[str, list[str]] | None = None,
379
+ use_prediction_time: bool = False,
380
+ top_k: int | None = None,
381
+ ) -> pd.DataFrame:
382
+ pass
383
+
384
+ @overload
385
+ def predict_task(
386
+ self,
387
+ task: TaskTable,
388
+ *,
389
+ explain: Literal[True] | ExplainConfig | dict[str, Any],
390
+ run_mode: RunMode | str = RunMode.FAST,
391
+ num_neighbors: list[int] | None = None,
392
+ num_hops: int = 2,
393
+ verbose: bool | ProgressLogger = True,
394
+ exclude_cols_dict: dict[str, list[str]] | None = None,
395
+ use_prediction_time: bool = False,
396
+ top_k: int | None = None,
397
+ ) -> Explanation:
398
+ pass
399
+
400
+ def predict_task(
401
+ self,
402
+ task: TaskTable,
403
+ *,
404
+ explain: bool | ExplainConfig | dict[str, Any] = False,
405
+ run_mode: RunMode | str = RunMode.FAST,
406
+ num_neighbors: list[int] | None = None,
407
+ num_hops: int = 2,
408
+ verbose: bool | ProgressLogger = True,
409
+ exclude_cols_dict: dict[str, list[str]] | None = None,
410
+ use_prediction_time: bool = False,
411
+ top_k: int | None = None,
412
+ ) -> pd.DataFrame | Explanation:
413
+ """Returns predictions for a custom task specification.
414
+
415
+ Args:
416
+ task: The custom :class:`TaskTable`.
417
+ explain: Configuration for explainability.
418
+ If set to ``True``, will additionally explain the prediction.
419
+ Passing in an :class:`ExplainConfig` instance provides control
420
+ over which parts of explanation are generated.
421
+ Explainability is currently only supported for single entity
422
+ predictions with ``run_mode="FAST"``.
423
+ run_mode: The :class:`RunMode` for the query.
424
+ num_neighbors: The number of neighbors to sample for each hop.
425
+ If specified, the ``num_hops`` option will be ignored.
426
+ num_hops: The number of hops to sample when generating the context.
427
+ verbose: Whether to print verbose output.
428
+ exclude_cols_dict: Any column in any table to exclude from the
429
+ model input.
430
+ use_prediction_time: Whether to use the anchor timestamp as an
431
+ additional feature during prediction. This is typically
432
+ beneficial for time series forecasting tasks.
433
+ top_k: The number of predictions to return per entity.
434
+
435
+ Returns:
436
+ The predictions as a :class:`pandas.DataFrame`.
437
+ If ``explain`` is provided, returns an :class:`Explanation` object
438
+ containing the prediction, summary, and details.
439
+ """
345
440
  if num_hops != 2 and num_neighbors is not None:
346
441
  warnings.warn(f"Received custom 'num_neighbors' option; ignoring "
347
442
  f"custom 'num_hops={num_hops}' option")
443
+ if num_neighbors is None:
444
+ key = RunMode.FAST if task.task_type.is_link_pred else run_mode
445
+ num_neighbors = _DEFAULT_NUM_NEIGHBORS[key][:num_hops]
446
+
447
+ explain_config: ExplainConfig | None = None
448
+ if explain is True:
449
+ explain_config = ExplainConfig()
450
+ elif explain is not False:
451
+ explain_config = ExplainConfig._cast(explain)
348
452
 
349
453
  if explain_config is not None and run_mode in {
350
454
  RunMode.NORMAL, RunMode.BEST
@@ -353,83 +457,82 @@ class KumoRFM:
353
457
  f"run mode 'FAST' (got '{run_mode}'). Provided run "
354
458
  f"mode has been reset. Please lower the run mode to "
355
459
  f"suppress this warning.")
460
+ run_mode = RunMode.FAST
356
461
 
357
- if indices is None:
358
- if query_def.rfm_entity_ids is None:
359
- raise ValueError("Cannot find entities to predict for. Please "
360
- "pass them via `predict(query, indices=...)`")
361
- indices = query_def.get_rfm_entity_id_list()
362
- else:
363
- query_def = replace(query_def, rfm_entity_ids=None)
364
-
365
- if len(indices) == 0:
366
- raise ValueError("At least one entity is required")
367
-
368
- if explain_config is not None and len(indices) > 1:
369
- raise ValueError(
370
- f"Cannot explain predictions for more than a single entity "
371
- f"(got {len(indices)})")
372
-
373
- query_repr = query_def.to_string(rich=True, exclude_predict=True)
374
- if explain_config is not None:
375
- msg = f'[bold]EXPLAIN[/bold] {query_repr}'
376
- else:
377
- msg = f'[bold]PREDICT[/bold] {query_repr}'
462
+ if explain_config is not None and task.num_prediction_examples > 1:
463
+ raise ValueError(f"Cannot explain predictions for more than a "
464
+ f"single entity "
465
+ f"(got {task.num_prediction_examples:,})")
378
466
 
379
467
  if not isinstance(verbose, ProgressLogger):
468
+ if task.task_type == TaskType.BINARY_CLASSIFICATION:
469
+ task_type_repr = 'binary classification'
470
+ elif task.task_type == TaskType.MULTICLASS_CLASSIFICATION:
471
+ task_type_repr = 'multi-class classification'
472
+ elif task.task_type == TaskType.REGRESSION:
473
+ task_type_repr = 'regression'
474
+ elif task.task_type == TaskType.TEMPORAL_LINK_PREDICTION:
475
+ task_type_repr = 'link prediction'
476
+ else:
477
+ task_type_repr = str(task.task_type)
478
+
479
+ if explain_config is not None:
480
+ msg = f'Explain {task_type_repr} task'
481
+ else:
482
+ msg = f'Predict {task_type_repr} task'
380
483
  verbose = ProgressLogger.default(msg=msg, verbose=verbose)
381
484
 
382
485
  with verbose as logger:
383
-
384
- batch_size: int | None = None
385
- if self._batch_size == 'max':
386
- task_type = self._get_task_type(
387
- query=query_def,
388
- edge_types=self._sampler.edge_types,
389
- )
390
- batch_size = _MAX_PRED_SIZE[task_type]
486
+ if task.num_context_examples > _MAX_CONTEXT_SIZE[run_mode]:
487
+ logger.log(f"Sub-sampled {_MAX_CONTEXT_SIZE[run_mode]:,} "
488
+ f"out of {task.num_context_examples:,} in-context "
489
+ f"examples")
490
+ task = task.narrow_context(0, _MAX_CONTEXT_SIZE[run_mode])
491
+
492
+ if self._batch_size is None:
493
+ batch_size = task.num_prediction_examples
494
+ elif self._batch_size == 'max':
495
+ batch_size = _MAX_PRED_SIZE[task.task_type]
391
496
  else:
392
497
  batch_size = self._batch_size
393
498
 
394
- if batch_size is not None:
395
- offsets = range(0, len(indices), batch_size)
396
- batches = [indices[step:step + batch_size] for step in offsets]
397
- else:
398
- batches = [indices]
499
+ if batch_size > _MAX_PRED_SIZE[task.task_type]:
500
+ raise ValueError(f"Cannot predict for more than "
501
+ f"{_MAX_PRED_SIZE[task.task_type]:,} "
502
+ f"entities at once (got {batch_size:,}). Use "
503
+ f"`KumoRFM.batch_mode` to process entities "
504
+ f"in batches with a sufficient batch size.")
399
505
 
400
- if len(batches) > 1:
401
- logger.log(f"Splitting {len(indices):,} entities into "
402
- f"{len(batches):,} batches of size {batch_size:,}")
506
+ if task.num_prediction_examples > batch_size:
507
+ num = math.ceil(task.num_prediction_examples / batch_size)
508
+ logger.log(f"Splitting {task.num_prediction_examples:,} "
509
+ f"entities into {num:,} batches of size "
510
+ f"{batch_size:,}")
403
511
 
404
512
  predictions: list[pd.DataFrame] = []
405
513
  summary: str | None = None
406
514
  details: Explanation | None = None
407
- for i, batch in enumerate(batches):
408
- # TODO Re-use the context for subsequent predictions.
515
+ for start in range(0, task.num_prediction_examples, batch_size):
409
516
  context = self._get_context(
410
- query=query_def,
411
- indices=batch,
412
- anchor_time=anchor_time,
413
- context_anchor_time=context_anchor_time,
414
- run_mode=RunMode(run_mode),
517
+ task=task.narrow_prediction(start, length=batch_size),
518
+ run_mode=run_mode,
415
519
  num_neighbors=num_neighbors,
416
- num_hops=num_hops,
417
- max_pq_iterations=max_pq_iterations,
418
- evaluate=False,
419
- random_seed=random_seed,
420
- logger=logger if i == 0 else None,
520
+ exclude_cols_dict=exclude_cols_dict,
521
+ top_k=top_k,
421
522
  )
523
+ context.y_test = None
524
+
422
525
  request = RFMPredictRequest(
423
526
  context=context,
424
527
  run_mode=RunMode(run_mode),
425
- query=query_str,
528
+ query=getattr(task, '_query', ''),
426
529
  use_prediction_time=use_prediction_time,
427
530
  )
428
531
  with warnings.catch_warnings():
429
532
  warnings.filterwarnings('ignore', message='gencode')
430
533
  request_msg = request.to_protobuf()
431
534
  _bytes = request_msg.SerializeToString()
432
- if i == 0:
535
+ if start == 0:
433
536
  logger.log(f"Generated context of size "
434
537
  f"{len(_bytes) / (1024*1024):.2f}MB")
435
538
 
@@ -437,11 +540,9 @@ class KumoRFM:
437
540
  stats = Context.get_memory_stats(request_msg.context)
438
541
  raise ValueError(_SIZE_LIMIT_MSG.format(stats=stats))
439
542
 
440
- if i == 0 and len(batches) > 1:
441
- verbose.init_progress(
442
- total=len(batches),
443
- description='Predicting',
444
- )
543
+ if start == 0 and task.num_prediction_examples > batch_size:
544
+ num = math.ceil(task.num_prediction_examples / batch_size)
545
+ verbose.init_progress(total=num, description='Predicting')
445
546
 
446
547
  for attempt in range(self.num_retries + 1):
447
548
  try:
@@ -459,7 +560,7 @@ class KumoRFM:
459
560
  # Cast 'ENTITY' to correct data type:
460
561
  if 'ENTITY' in df:
461
562
  table_dict = context.subgraph.table_dict
462
- table = table_dict[query_def.entity_table]
563
+ table = table_dict[context.entity_table_names[0]]
463
564
  ser = table.df[table.primary_key]
464
565
  df['ENTITY'] = df['ENTITY'].astype(ser.dtype)
465
566
 
@@ -476,7 +577,7 @@ class KumoRFM:
476
577
 
477
578
  predictions.append(df)
478
579
 
479
- if len(batches) > 1:
580
+ if task.num_prediction_examples > batch_size:
480
581
  verbose.step()
481
582
 
482
583
  break
@@ -601,40 +702,51 @@ class KumoRFM:
601
702
  The metrics as a :class:`pandas.DataFrame`
602
703
  """
603
704
  query_def = self._parse_query(query)
604
-
605
- if num_hops != 2 and num_neighbors is not None:
606
- warnings.warn(f"Received custom 'num_neighbors' option; ignoring "
607
- f"custom 'num_hops={num_hops}' option")
608
-
609
705
  if query_def.rfm_entity_ids is not None:
610
706
  query_def = replace(
611
707
  query_def,
612
708
  rfm_entity_ids=None,
613
709
  )
614
710
 
615
- query_repr = query_def.to_string(rich=True, exclude_predict=True)
616
- msg = f'[bold]EVALUATE[/bold] {query_repr}'
711
+ task_type = self._get_task_type(
712
+ query=query_def,
713
+ edge_types=self._sampler.edge_types,
714
+ )
715
+
716
+ if num_hops != 2 and num_neighbors is not None:
717
+ warnings.warn(f"Received custom 'num_neighbors' option; ignoring "
718
+ f"custom 'num_hops={num_hops}' option")
719
+ if num_neighbors is None:
720
+ key = RunMode.FAST if task_type.is_link_pred else run_mode
721
+ num_neighbors = _DEFAULT_NUM_NEIGHBORS[key][:num_hops]
722
+
723
+ if metrics is not None and len(metrics) > 0:
724
+ self._validate_metrics(metrics, task_type)
725
+ metrics = list(dict.fromkeys(metrics))
617
726
 
618
727
  if not isinstance(verbose, ProgressLogger):
728
+ query_repr = query_def.to_string(rich=True, exclude_predict=True)
729
+ msg = f'[bold]EVALUATE[/bold] {query_repr}'
619
730
  verbose = ProgressLogger.default(msg=msg, verbose=verbose)
620
731
 
621
732
  with verbose as logger:
622
- context = self._get_context(
733
+ task_table = self._get_task_table(
623
734
  query=query_def,
624
735
  indices=None,
625
736
  anchor_time=anchor_time,
626
737
  context_anchor_time=context_anchor_time,
627
- run_mode=RunMode(run_mode),
628
- num_neighbors=num_neighbors,
629
- num_hops=num_hops,
738
+ run_mode=run_mode,
630
739
  max_pq_iterations=max_pq_iterations,
631
- evaluate=True,
632
740
  random_seed=random_seed,
633
- logger=logger if verbose else None,
741
+ logger=logger,
634
742
  )
635
- if metrics is not None and len(metrics) > 0:
636
- self._validate_metrics(metrics, context.task_type)
637
- metrics = list(dict.fromkeys(metrics))
743
+ context = self._get_context(
744
+ task=task_table,
745
+ run_mode=run_mode,
746
+ num_neighbors=num_neighbors,
747
+ exclude_cols_dict=query_def.get_exclude_cols_dict(),
748
+ )
749
+
638
750
  request = RFMEvaluateRequest(
639
751
  context=context,
640
752
  run_mode=RunMode(run_mode),
@@ -652,17 +764,23 @@ class KumoRFM:
652
764
  stats_msg = Context.get_memory_stats(request_msg.context)
653
765
  raise ValueError(_SIZE_LIMIT_MSG.format(stats=stats_msg))
654
766
 
655
- try:
656
- resp = self._api_client.evaluate(request_bytes)
657
- except HTTPException as e:
767
+ for attempt in range(self.num_retries + 1):
658
768
  try:
659
- msg = json.loads(e.detail)['detail']
660
- except Exception:
661
- msg = e.detail
662
- raise RuntimeError(f"An unexpected exception occurred. "
663
- f"Please create an issue at "
664
- f"'https://github.com/kumo-ai/kumo-rfm'. "
665
- f"{msg}") from None
769
+ resp = self._api_client.evaluate(request_bytes)
770
+ break
771
+ except HTTPException as e:
772
+ if attempt == self.num_retries:
773
+ try:
774
+ msg = json.loads(e.detail)['detail']
775
+ except Exception:
776
+ msg = e.detail
777
+ raise RuntimeError(
778
+ f"An unexpected exception occurred. Please create "
779
+ f"an issue at "
780
+ f"'https://github.com/kumo-ai/kumo-rfm'. {msg}"
781
+ ) from None
782
+
783
+ time.sleep(2**attempt) # 1s, 2s, 4s, 8s, ...
666
784
 
667
785
  return pd.DataFrame.from_dict(
668
786
  resp.metrics,
@@ -714,7 +832,7 @@ class KumoRFM:
714
832
  f"to have a time column")
715
833
 
716
834
  train, test = self._sampler.sample_target(
717
- query=query,
835
+ query=query_def,
718
836
  num_train_examples=0,
719
837
  train_anchor_time=anchor_time,
720
838
  num_train_trials=0,
@@ -742,30 +860,34 @@ class KumoRFM:
742
860
  "`predict()` or `evaluate()` methods to perform "
743
861
  "predictions or evaluations.")
744
862
 
745
- try:
746
- request = RFMParseQueryRequest(
747
- query=query,
748
- graph_definition=self._graph_def,
749
- )
863
+ request = RFMParseQueryRequest(
864
+ query=query,
865
+ graph_definition=self._graph_def,
866
+ )
867
+
868
+ for attempt in range(self.num_retries + 1):
869
+ try:
870
+ resp = self._api_client.parse_query(request)
871
+ break
872
+ except HTTPException as e:
873
+ if attempt == self.num_retries:
874
+ try:
875
+ msg = json.loads(e.detail)['detail']
876
+ except Exception:
877
+ msg = e.detail
878
+ raise ValueError(f"Failed to parse query '{query}'. {msg}")
750
879
 
751
- resp = self._api_client.parse_query(request)
880
+ time.sleep(2**attempt) # 1s, 2s, 4s, 8s, ...
752
881
 
753
- if len(resp.validation_response.warnings) > 0:
754
- msg = '\n'.join([
755
- f'{i+1}. {warning.title}: {warning.message}' for i, warning
756
- in enumerate(resp.validation_response.warnings)
757
- ])
758
- warnings.warn(f"Encountered the following warnings during "
759
- f"parsing:\n{msg}")
882
+ if len(resp.validation_response.warnings) > 0:
883
+ msg = '\n'.join([
884
+ f'{i+1}. {warning.title}: {warning.message}'
885
+ for i, warning in enumerate(resp.validation_response.warnings)
886
+ ])
887
+ warnings.warn(f"Encountered the following warnings during "
888
+ f"parsing:\n{msg}")
760
889
 
761
- return resp.query
762
- except HTTPException as e:
763
- try:
764
- msg = json.loads(e.detail)['detail']
765
- except Exception:
766
- msg = e.detail
767
- raise ValueError(f"Failed to parse query '{query}'. "
768
- f"{msg}") from None
890
+ return resp.query
769
891
 
770
892
  @staticmethod
771
893
  def _get_task_type(
@@ -809,16 +931,15 @@ class KumoRFM:
809
931
 
810
932
  def _get_default_anchor_time(
811
933
  self,
812
- query: ValidatedPredictiveQuery,
934
+ query: ValidatedPredictiveQuery | None = None,
813
935
  ) -> pd.Timestamp:
814
- if query.query_type == QueryType.TEMPORAL:
936
+ if query is not None and query.query_type == QueryType.TEMPORAL:
815
937
  aggr_table_names = [
816
938
  aggr._get_target_column_name().split('.')[0]
817
939
  for aggr in query.get_all_target_aggregations()
818
940
  ]
819
941
  return self._sampler.get_max_time(aggr_table_names)
820
942
 
821
- assert query.query_type == QueryType.STATIC
822
943
  return self._sampler.get_max_time()
823
944
 
824
945
  def _validate_time(
@@ -888,40 +1009,26 @@ class KumoRFM:
888
1009
  f"Anchor timestamp for evaluation is after the latest "
889
1010
  f"supported timestamp '{max_time - end_offset}'.")
890
1011
 
891
- def _get_context(
1012
+ def _get_task_table(
892
1013
  self,
893
1014
  query: ValidatedPredictiveQuery,
894
1015
  indices: list[str] | list[float] | list[int] | None,
895
- anchor_time: pd.Timestamp | Literal['entity'] | None,
896
- context_anchor_time: pd.Timestamp | None,
897
- run_mode: RunMode,
898
- num_neighbors: list[int] | None,
899
- num_hops: int,
900
- max_pq_iterations: int,
901
- evaluate: bool,
1016
+ anchor_time: pd.Timestamp | Literal['entity'] | None = None,
1017
+ context_anchor_time: pd.Timestamp | None = None,
1018
+ run_mode: RunMode = RunMode.FAST,
1019
+ max_pq_iterations: int = 10,
902
1020
  random_seed: int | None = _RANDOM_SEED,
903
1021
  logger: ProgressLogger | None = None,
904
- ) -> Context:
905
-
906
- if num_neighbors is not None:
907
- num_hops = len(num_neighbors)
908
-
909
- if num_hops < 0:
910
- raise ValueError(f"'num_hops' must be non-negative "
911
- f"(got {num_hops})")
912
- if num_hops > 6:
913
- raise ValueError(f"Cannot predict on subgraphs with more than 6 "
914
- f"hops (got {num_hops}). Please reduce the "
915
- f"number of hops and try again. Please create a "
916
- f"feature request at "
917
- f"'https://github.com/kumo-ai/kumo-rfm' if you "
918
- f"must go beyond this for your use-case.")
1022
+ ) -> TaskTable:
919
1023
 
920
1024
  task_type = self._get_task_type(
921
1025
  query=query,
922
1026
  edge_types=self._sampler.edge_types,
923
1027
  )
924
1028
 
1029
+ num_train_examples = _MAX_CONTEXT_SIZE[run_mode]
1030
+ num_test_examples = _MAX_TEST_SIZE[task_type] if indices is None else 0
1031
+
925
1032
  if logger is not None:
926
1033
  if task_type == TaskType.BINARY_CLASSIFICATION:
927
1034
  task_type_repr = 'binary classification'
@@ -935,21 +1042,6 @@ class KumoRFM:
935
1042
  task_type_repr = str(task_type)
936
1043
  logger.log(f"Identified {query.query_type} {task_type_repr} task")
937
1044
 
938
- if task_type.is_link_pred and num_hops < 2:
939
- raise ValueError(f"Cannot perform link prediction on subgraphs "
940
- f"with less than 2 hops (got {num_hops}) since "
941
- f"historical target entities need to be part of "
942
- f"the context. Please increase the number of "
943
- f"hops and try again.")
944
-
945
- if num_neighbors is None:
946
- if run_mode == RunMode.DEBUG:
947
- num_neighbors = [16, 16, 4, 4, 1, 1][:num_hops]
948
- elif run_mode == RunMode.FAST or task_type.is_link_pred:
949
- num_neighbors = [32, 32, 8, 8, 4, 4][:num_hops]
950
- else:
951
- num_neighbors = [64, 64, 8, 8, 4, 4][:num_hops]
952
-
953
1045
  if query.target_ast.date_offset_range is None:
954
1046
  step_offset = pd.DateOffset(0)
955
1047
  else:
@@ -958,8 +1050,7 @@ class KumoRFM:
958
1050
 
959
1051
  if anchor_time is None:
960
1052
  anchor_time = self._get_default_anchor_time(query)
961
-
962
- if evaluate:
1053
+ if num_test_examples > 0:
963
1054
  anchor_time = anchor_time - end_offset
964
1055
 
965
1056
  if logger is not None:
@@ -973,7 +1064,6 @@ class KumoRFM:
973
1064
  else:
974
1065
  logger.log(f"Derived anchor time {anchor_time}")
975
1066
 
976
- assert anchor_time is not None
977
1067
  if isinstance(anchor_time, pd.Timestamp):
978
1068
  if context_anchor_time == 'entity':
979
1069
  raise ValueError("Anchor time 'entity' needs to be shared "
@@ -981,7 +1071,7 @@ class KumoRFM:
981
1071
  if context_anchor_time is None:
982
1072
  context_anchor_time = anchor_time - end_offset
983
1073
  self._validate_time(query, anchor_time, context_anchor_time,
984
- evaluate)
1074
+ evaluate=num_test_examples > 0)
985
1075
  else:
986
1076
  assert anchor_time == 'entity'
987
1077
  if query.query_type != QueryType.STATIC:
@@ -996,14 +1086,6 @@ class KumoRFM:
996
1086
  "for context and prediction examples")
997
1087
  context_anchor_time = 'entity'
998
1088
 
999
- num_train_examples = _MAX_CONTEXT_SIZE[run_mode]
1000
- if evaluate:
1001
- num_test_examples = _MAX_TEST_SIZE[run_mode]
1002
- if task_type.is_link_pred:
1003
- num_test_examples = num_test_examples // 5
1004
- else:
1005
- num_test_examples = 0
1006
-
1007
1089
  train, test = self._sampler.sample_target(
1008
1090
  query=query,
1009
1091
  num_train_examples=num_train_examples,
@@ -1014,39 +1096,32 @@ class KumoRFM:
1014
1096
  num_test_trials=max_pq_iterations * num_test_examples,
1015
1097
  random_seed=random_seed,
1016
1098
  )
1017
- train_pkey, train_time, y_train = train
1018
- test_pkey, test_time, y_test = test
1099
+ train_pkey, train_time, train_y = train
1100
+ test_pkey, test_time, test_y = test
1019
1101
 
1020
- if evaluate and logger is not None:
1102
+ if num_test_examples > 0 and logger is not None:
1021
1103
  if task_type == TaskType.BINARY_CLASSIFICATION:
1022
- pos = 100 * int((y_test > 0).sum()) / len(y_test)
1023
- msg = (f"Collected {len(y_test):,} test examples with "
1104
+ pos = 100 * int((test_y > 0).sum()) / len(test_y)
1105
+ msg = (f"Collected {len(test_y):,} test examples with "
1024
1106
  f"{pos:.2f}% positive cases")
1025
1107
  elif task_type == TaskType.MULTICLASS_CLASSIFICATION:
1026
- msg = (f"Collected {len(y_test):,} test examples holding "
1027
- f"{y_test.nunique()} classes")
1108
+ msg = (f"Collected {len(test_y):,} test examples holding "
1109
+ f"{test_y.nunique()} classes")
1028
1110
  elif task_type == TaskType.REGRESSION:
1029
- _min, _max = float(y_test.min()), float(y_test.max())
1030
- msg = (f"Collected {len(y_test):,} test examples with targets "
1111
+ _min, _max = float(test_y.min()), float(test_y.max())
1112
+ msg = (f"Collected {len(test_y):,} test examples with targets "
1031
1113
  f"between {format_value(_min)} and "
1032
1114
  f"{format_value(_max)}")
1033
1115
  elif task_type == TaskType.TEMPORAL_LINK_PREDICTION:
1034
- num_rhs = y_test.explode().nunique()
1035
- msg = (f"Collected {len(y_test):,} test examples with "
1116
+ num_rhs = test_y.explode().nunique()
1117
+ msg = (f"Collected {len(test_y):,} test examples with "
1036
1118
  f"{num_rhs:,} unique items")
1037
1119
  else:
1038
1120
  raise NotImplementedError
1039
1121
  logger.log(msg)
1040
1122
 
1041
- if not evaluate:
1123
+ if num_test_examples == 0:
1042
1124
  assert indices is not None
1043
- if len(indices) > _MAX_PRED_SIZE[task_type]:
1044
- raise ValueError(f"Cannot predict for more than "
1045
- f"{_MAX_PRED_SIZE[task_type]:,} entities at "
1046
- f"once (got {len(indices):,}). Use "
1047
- f"`KumoRFM.batch_mode` to process entities "
1048
- f"in batches")
1049
-
1050
1125
  test_pkey = pd.Series(indices, dtype=train_pkey.dtype)
1051
1126
  if isinstance(anchor_time, pd.Timestamp):
1052
1127
  test_time = pd.Series([anchor_time]).repeat(
@@ -1056,26 +1131,26 @@ class KumoRFM:
1056
1131
 
1057
1132
  if logger is not None:
1058
1133
  if task_type == TaskType.BINARY_CLASSIFICATION:
1059
- pos = 100 * int((y_train > 0).sum()) / len(y_train)
1060
- msg = (f"Collected {len(y_train):,} in-context examples with "
1134
+ pos = 100 * int((train_y > 0).sum()) / len(train_y)
1135
+ msg = (f"Collected {len(train_y):,} in-context examples with "
1061
1136
  f"{pos:.2f}% positive cases")
1062
1137
  elif task_type == TaskType.MULTICLASS_CLASSIFICATION:
1063
- msg = (f"Collected {len(y_train):,} in-context examples "
1064
- f"holding {y_train.nunique()} classes")
1138
+ msg = (f"Collected {len(train_y):,} in-context examples "
1139
+ f"holding {train_y.nunique()} classes")
1065
1140
  elif task_type == TaskType.REGRESSION:
1066
- _min, _max = float(y_train.min()), float(y_train.max())
1067
- msg = (f"Collected {len(y_train):,} in-context examples with "
1141
+ _min, _max = float(train_y.min()), float(train_y.max())
1142
+ msg = (f"Collected {len(train_y):,} in-context examples with "
1068
1143
  f"targets between {format_value(_min)} and "
1069
1144
  f"{format_value(_max)}")
1070
1145
  elif task_type == TaskType.TEMPORAL_LINK_PREDICTION:
1071
- num_rhs = y_train.explode().nunique()
1072
- msg = (f"Collected {len(y_train):,} in-context examples with "
1146
+ num_rhs = train_y.explode().nunique()
1147
+ msg = (f"Collected {len(train_y):,} in-context examples with "
1073
1148
  f"{num_rhs:,} unique items")
1074
1149
  else:
1075
1150
  raise NotImplementedError
1076
1151
  logger.log(msg)
1077
1152
 
1078
- entity_table_names: tuple[str, ...]
1153
+ entity_table_names: tuple[str] | tuple[str, str]
1079
1154
  if task_type.is_link_pred:
1080
1155
  final_aggr = query.get_final_target_aggregation()
1081
1156
  assert final_aggr is not None
@@ -1089,27 +1164,98 @@ class KumoRFM:
1089
1164
  else:
1090
1165
  entity_table_names = (query.entity_table, )
1091
1166
 
1167
+ context_df = pd.DataFrame({'ENTITY': train_pkey, 'TARGET': train_y})
1168
+ if isinstance(train_time, pd.Series):
1169
+ context_df['ANCHOR_TIMESTAMP'] = train_time
1170
+ pred_df = pd.DataFrame({'ENTITY': test_pkey})
1171
+ if num_test_examples > 0:
1172
+ pred_df['TARGET'] = test_y
1173
+ if isinstance(test_time, pd.Series):
1174
+ pred_df['ANCHOR_TIMESTAMP'] = test_time
1175
+
1176
+ return TaskTable(
1177
+ task_type=task_type,
1178
+ context_df=context_df,
1179
+ pred_df=pred_df,
1180
+ entity_table_name=entity_table_names,
1181
+ entity_column='ENTITY',
1182
+ target_column='TARGET',
1183
+ time_column='ANCHOR_TIMESTAMP' if isinstance(
1184
+ train_time, pd.Series) else TaskTable.ENTITY_TIME,
1185
+ )
1186
+
1187
+ def _get_context(
1188
+ self,
1189
+ task: TaskTable,
1190
+ run_mode: RunMode | str = RunMode.FAST,
1191
+ num_neighbors: list[int] | None = None,
1192
+ exclude_cols_dict: dict[str, list[str]] | None = None,
1193
+ top_k: int | None = None,
1194
+ ) -> Context:
1195
+
1196
+ # TODO Remove all
1197
+ if task.num_context_examples > max(_MAX_CONTEXT_SIZE.values()):
1198
+ raise ValueError(f"Cannot process a context with more than "
1199
+ f"{max(_MAX_CONTEXT_SIZE.values()):,} samples "
1200
+ f"(got {task.num_context_examples:,})")
1201
+ if task.evaluate:
1202
+ if task.num_prediction_examples > _MAX_TEST_SIZE[task.task_type]:
1203
+ raise ValueError(f"Cannot process a test set with more than "
1204
+ f"{_MAX_TEST_SIZE[task.task_type]:,} samples "
1205
+ f"for evaluation "
1206
+ f"(got {task.num_prediction_examples:,})")
1207
+ else:
1208
+ if task.num_prediction_examples > _MAX_PRED_SIZE[task.task_type]:
1209
+ raise ValueError(f"Cannot predict for more than "
1210
+ f"{_MAX_PRED_SIZE[task.task_type]:,} "
1211
+ f"entities at once "
1212
+ f"(got {task.num_prediction_examples:,})")
1213
+
1214
+ if num_neighbors is None:
1215
+ key = RunMode.FAST if task.task_type.is_link_pred else run_mode
1216
+ num_neighbors = _DEFAULT_NUM_NEIGHBORS[key][:2]
1217
+
1218
+ if len(num_neighbors) > 6:
1219
+ raise ValueError(f"Cannot predict on subgraphs with more than 6 "
1220
+ f"hops (got {len(num_neighbors)}). Reduce the "
1221
+ f"number of hops and try again. Please create a "
1222
+ f"feature request at "
1223
+ f"'https://github.com/kumo-ai/kumo-rfm' if you "
1224
+ f"must go beyond this for your use-case.")
1225
+
1092
1226
  # Exclude the entity anchor time from the feature set to prevent
1093
1227
  # running out-of-distribution between in-context and test examples:
1094
- exclude_cols_dict = query.get_exclude_cols_dict()
1095
- if entity_table_names[0] in self._sampler.time_column_dict:
1096
- if entity_table_names[0] not in exclude_cols_dict:
1097
- exclude_cols_dict[entity_table_names[0]] = []
1098
- time_column = self._sampler.time_column_dict[entity_table_names[0]]
1099
- exclude_cols_dict[entity_table_names[0]].append(time_column)
1228
+ exclude_cols_dict = exclude_cols_dict or {}
1229
+ if task.entity_table_name in self._sampler.time_column_dict:
1230
+ if task.entity_table_name not in exclude_cols_dict:
1231
+ exclude_cols_dict[task.entity_table_name] = []
1232
+ time_col = self._sampler.time_column_dict[task.entity_table_name]
1233
+ exclude_cols_dict[task.entity_table_name].append(time_col)
1234
+
1235
+ entity_pkey = pd.concat([
1236
+ task._context_df[task._entity_column],
1237
+ task._pred_df[task._entity_column],
1238
+ ], axis=0, ignore_index=True)
1239
+
1240
+ if task.use_entity_time:
1241
+ if task.entity_table_name not in self._sampler.time_column_dict:
1242
+ raise ValueError(f"The given annchor time requires the entity "
1243
+ f"table '{task.entity_table_name}' to have a "
1244
+ f"time column")
1245
+ anchor_time = 'entity'
1246
+ elif task._time_column is not None:
1247
+ anchor_time = pd.concat([
1248
+ task._context_df[task._time_column],
1249
+ task._pred_df[task._time_column],
1250
+ ], axis=0, ignore_index=True)
1251
+ else:
1252
+ anchor_time = pd.Series(self._get_default_anchor_time()).repeat(
1253
+ (len(entity_pkey))).reset_index(drop=True)
1100
1254
 
1101
1255
  subgraph = self._sampler.sample_subgraph(
1102
- entity_table_names=entity_table_names,
1103
- entity_pkey=pd.concat(
1104
- [train_pkey, test_pkey],
1105
- axis=0,
1106
- ignore_index=True,
1107
- ),
1108
- anchor_time=pd.concat(
1109
- [train_time, test_time],
1110
- axis=0,
1111
- ignore_index=True,
1112
- ) if isinstance(train_time, pd.Series) else 'entity',
1256
+ entity_table_names=task.entity_table_names,
1257
+ entity_pkey=entity_pkey,
1258
+ anchor_time=anchor_time,
1113
1259
  num_neighbors=num_neighbors,
1114
1260
  exclude_cols_dict=exclude_cols_dict,
1115
1261
  )
@@ -1121,13 +1267,20 @@ class KumoRFM:
1121
1267
  f"'https://github.com/kumo-ai/kumo-rfm' if you "
1122
1268
  f"must go beyond this for your use-case.")
1123
1269
 
1270
+ if (task.task_type.is_link_pred
1271
+ and task.entity_table_names[-1] not in subgraph.table_dict):
1272
+ raise ValueError("Cannot perform link prediction on subgraphs "
1273
+ "without any historical target entities. Please "
1274
+ "increase the number of hops and try again.")
1275
+
1124
1276
  return Context(
1125
- task_type=task_type,
1126
- entity_table_names=entity_table_names,
1277
+ task_type=task.task_type,
1278
+ entity_table_names=task.entity_table_names,
1127
1279
  subgraph=subgraph,
1128
- y_train=y_train,
1129
- y_test=y_test if evaluate else None,
1130
- top_k=query.top_k,
1280
+ y_train=task._context_df[task.target_column.name],
1281
+ y_test=task._pred_df[task.target_column.name]
1282
+ if task.evaluate else None,
1283
+ top_k=top_k,
1131
1284
  step_size=None,
1132
1285
  )
1133
1286