kumoai 2.13.0.dev202512040649__cp313-cp313-win_amd64.whl → 2.14.0.dev202601081732__cp313-cp313-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (54)
  1. kumoai/__init__.py +35 -26
  2. kumoai/_version.py +1 -1
  3. kumoai/client/client.py +6 -0
  4. kumoai/client/jobs.py +26 -0
  5. kumoai/client/pquery.py +6 -2
  6. kumoai/connector/utils.py +21 -7
  7. kumoai/experimental/rfm/__init__.py +51 -24
  8. kumoai/experimental/rfm/authenticate.py +3 -4
  9. kumoai/experimental/rfm/backend/local/__init__.py +4 -0
  10. kumoai/experimental/rfm/{local_graph_store.py → backend/local/graph_store.py} +62 -110
  11. kumoai/experimental/rfm/backend/local/sampler.py +312 -0
  12. kumoai/experimental/rfm/backend/local/table.py +35 -31
  13. kumoai/experimental/rfm/backend/snow/__init__.py +2 -0
  14. kumoai/experimental/rfm/backend/snow/sampler.py +366 -0
  15. kumoai/experimental/rfm/backend/snow/table.py +177 -50
  16. kumoai/experimental/rfm/backend/sqlite/__init__.py +4 -2
  17. kumoai/experimental/rfm/backend/sqlite/sampler.py +454 -0
  18. kumoai/experimental/rfm/backend/sqlite/table.py +131 -48
  19. kumoai/experimental/rfm/base/__init__.py +23 -3
  20. kumoai/experimental/rfm/base/column.py +96 -10
  21. kumoai/experimental/rfm/base/expression.py +44 -0
  22. kumoai/experimental/rfm/base/sampler.py +782 -0
  23. kumoai/experimental/rfm/base/source.py +2 -1
  24. kumoai/experimental/rfm/base/sql_sampler.py +247 -0
  25. kumoai/experimental/rfm/base/table.py +404 -203
  26. kumoai/experimental/rfm/graph.py +374 -172
  27. kumoai/experimental/rfm/infer/__init__.py +6 -4
  28. kumoai/experimental/rfm/infer/dtype.py +7 -4
  29. kumoai/experimental/rfm/infer/multicategorical.py +1 -1
  30. kumoai/experimental/rfm/infer/pkey.py +4 -2
  31. kumoai/experimental/rfm/infer/stype.py +35 -0
  32. kumoai/experimental/rfm/infer/time_col.py +1 -2
  33. kumoai/experimental/rfm/pquery/executor.py +27 -27
  34. kumoai/experimental/rfm/pquery/pandas_executor.py +30 -32
  35. kumoai/experimental/rfm/relbench.py +76 -0
  36. kumoai/experimental/rfm/rfm.py +762 -467
  37. kumoai/experimental/rfm/sagemaker.py +4 -4
  38. kumoai/experimental/rfm/task_table.py +292 -0
  39. kumoai/kumolib.cp313-win_amd64.pyd +0 -0
  40. kumoai/pquery/predictive_query.py +10 -6
  41. kumoai/pquery/training_table.py +16 -2
  42. kumoai/testing/snow.py +50 -0
  43. kumoai/trainer/distilled_trainer.py +175 -0
  44. kumoai/utils/__init__.py +3 -2
  45. kumoai/utils/display.py +87 -0
  46. kumoai/utils/progress_logger.py +190 -12
  47. kumoai/utils/sql.py +3 -0
  48. {kumoai-2.13.0.dev202512040649.dist-info → kumoai-2.14.0.dev202601081732.dist-info}/METADATA +3 -2
  49. {kumoai-2.13.0.dev202512040649.dist-info → kumoai-2.14.0.dev202601081732.dist-info}/RECORD +52 -41
  50. kumoai/experimental/rfm/local_graph_sampler.py +0 -223
  51. kumoai/experimental/rfm/local_pquery_driver.py +0 -689
  52. {kumoai-2.13.0.dev202512040649.dist-info → kumoai-2.14.0.dev202601081732.dist-info}/WHEEL +0 -0
  53. {kumoai-2.13.0.dev202512040649.dist-info → kumoai-2.14.0.dev202601081732.dist-info}/licenses/LICENSE +0 -0
  54. {kumoai-2.13.0.dev202512040649.dist-info → kumoai-2.14.0.dev202601081732.dist-info}/top_level.txt +0 -0
@@ -1,26 +1,23 @@
  import json
+ import math
  import time
  import warnings
  from collections import defaultdict
- from collections.abc import Generator
+ from collections.abc import Generator, Iterator
  from contextlib import contextmanager
  from dataclasses import dataclass, replace
- from typing import (
-     Any,
-     Dict,
-     Iterator,
-     List,
-     Literal,
-     Optional,
-     Tuple,
-     Union,
-     overload,
- )
+ from typing import Any, Literal, overload
 
- import numpy as np
  import pandas as pd
  from kumoapi.model_plan import RunMode
  from kumoapi.pquery import QueryType, ValidatedPredictiveQuery
+ from kumoapi.pquery.AST import (
+     Aggregation,
+     Column,
+     Condition,
+     Join,
+     LogicalOperation,
+ )
  from kumoapi.rfm import Context
  from kumoapi.rfm import Explanation as ExplanationConfig
  from kumoapi.rfm import (
@@ -29,35 +26,38 @@ from kumoapi.rfm import (
      RFMPredictRequest,
  )
  from kumoapi.task import TaskType
+ from kumoapi.typing import AggregationType, Stype
+ from rich.console import Console
+ from rich.markdown import Markdown
 
+ from kumoai import in_notebook
  from kumoai.client.rfm import RFMAPI
  from kumoai.exceptions import HTTPException
- from kumoai.experimental.rfm import Graph
- from kumoai.experimental.rfm.local_graph_sampler import LocalGraphSampler
- from kumoai.experimental.rfm.local_graph_store import LocalGraphStore
- from kumoai.experimental.rfm.local_pquery_driver import (
-     LocalPQueryDriver,
-     date_offset_to_seconds,
- )
+ from kumoai.experimental.rfm import Graph, TaskTable
+ from kumoai.experimental.rfm.base import DataBackend, Sampler
  from kumoai.mixin import CastMixin
- from kumoai.utils import InteractiveProgressLogger, ProgressLogger
+ from kumoai.utils import ProgressLogger, display
 
  _RANDOM_SEED = 42
 
  _MAX_PRED_SIZE: dict[TaskType, int] = defaultdict(lambda: 1_000)
  _MAX_PRED_SIZE[TaskType.TEMPORAL_LINK_PREDICTION] = 200
 
+ _MAX_TEST_SIZE: dict[TaskType, int] = defaultdict(lambda: 2_000)
+ _MAX_TEST_SIZE[TaskType.TEMPORAL_LINK_PREDICTION] = 400
+
  _MAX_CONTEXT_SIZE = {
      RunMode.DEBUG: 100,
      RunMode.FAST: 1_000,
      RunMode.NORMAL: 5_000,
      RunMode.BEST: 10_000,
  }
- _MAX_TEST_SIZE = {  # Share test set size across run modes for fair comparison:
-     RunMode.DEBUG: 100,
-     RunMode.FAST: 2_000,
-     RunMode.NORMAL: 2_000,
-     RunMode.BEST: 2_000,
+
+ _DEFAULT_NUM_NEIGHBORS = {
+     RunMode.DEBUG: [16, 16, 4, 4, 1, 1],
+     RunMode.FAST: [32, 32, 8, 8, 4, 4],
+     RunMode.NORMAL: [64, 64, 8, 8, 4, 4],
+     RunMode.BEST: [64, 64, 8, 8, 4, 4],
  }
 
  _MAX_SIZE = 30 * 1024 * 1024
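
Note on the new module-level defaults above: `_MAX_TEST_SIZE` is now keyed by task type instead of run mode, and the previously inlined neighbor-sampling fan-outs now live in `_DEFAULT_NUM_NEIGHBORS`. A minimal sketch of how the `predict_task`/`evaluate_task` hunks further down consume these tables (the concrete `run_mode`, `num_hops`, and `task_type` values here are illustrative only):

    from kumoapi.model_plan import RunMode
    from kumoapi.task import TaskType

    run_mode, num_hops, task_type = RunMode.NORMAL, 2, TaskType.REGRESSION

    # Link prediction always uses the FAST fan-out; other tasks follow run_mode:
    key = RunMode.FAST if task_type.is_link_pred else run_mode
    num_neighbors = _DEFAULT_NUM_NEIGHBORS[key][:num_hops]  # -> [64, 64]

    # Test-set sizing is per task type (defaultdict: 2_000, link prediction 400):
    max_test_size = _MAX_TEST_SIZE[task_type]  # -> 2_000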
@@ -95,24 +95,36 @@ class Explanation:
      def __getitem__(self, index: Literal[1]) -> str:
          pass
 
-     def __getitem__(self, index: int) -> Union[pd.DataFrame, str]:
+     def __getitem__(self, index: int) -> pd.DataFrame | str:
          if index == 0:
              return self.prediction
          if index == 1:
              return self.summary
          raise IndexError("Index out of range")
 
-     def __iter__(self) -> Iterator[Union[pd.DataFrame, str]]:
+     def __iter__(self) -> Iterator[pd.DataFrame | str]:
          return iter((self.prediction, self.summary))
 
      def __repr__(self) -> str:
          return str((self.prediction, self.summary))
 
-     def _ipython_display_(self) -> None:
-         from IPython.display import Markdown, display
+     def __str__(self) -> str:
+         console = Console(soft_wrap=True)
+         with console.capture() as cap:
+             console.print(display.to_rich_table(self.prediction))
+             console.print(Markdown(self.summary))
+         return cap.get()[:-1]
+
+     def print(self) -> None:
+         r"""Prints the explanation."""
+         if in_notebook():
+             display.dataframe(self.prediction)
+             display.message(self.summary)
+         else:
+             print(self)
 
-         display(self.prediction)
-         display(Markdown(self.summary))
+     def _ipython_display_(self) -> None:
+         self.print()
 
 
  class KumoRFM:
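
With the changes above, `Explanation` now renders outside notebooks as well: `__str__` captures rich console output, and `print()` switches between notebook display and terminal printing. A small consumption sketch, grounded in the `__getitem__`/`__iter__` definitions shown here (the `model.predict(...)` call and query are hypothetical):

    explanation = model.predict(query, indices=[42], explain=True)
    prediction, summary = explanation  # tuple-style unpacking via __iter__()
    df = explanation[0]                # prediction DataFrame
    text = explanation[1]              # Markdown summary string
    explanation.print()                # rich table + Markdown, in or out of IPython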
@@ -151,21 +163,36 @@ class KumoRFM:
      Args:
          graph: The graph.
          verbose: Whether to print verbose output.
+         optimize: If set to ``True``, will optimize the underlying data
+             backend for efficient querying. For example, for transactional
+             database backends, will create any missing indices. Requires
+             write access to the data backend.
      """
      def __init__(
          self,
          graph: Graph,
-         verbose: Union[bool, ProgressLogger] = True,
+         verbose: bool | ProgressLogger = True,
+         optimize: bool = False,
      ) -> None:
          graph = graph.validate()
          self._graph_def = graph._to_api_graph_definition()
-         self._graph_store = LocalGraphStore(graph, verbose)
-         self._graph_sampler = LocalGraphSampler(self._graph_store)
 
-         self._client: Optional[RFMAPI] = None
+         if graph.backend == DataBackend.LOCAL:
+             from kumoai.experimental.rfm.backend.local import LocalSampler
+             self._sampler: Sampler = LocalSampler(graph, verbose)
+         elif graph.backend == DataBackend.SQLITE:
+             from kumoai.experimental.rfm.backend.sqlite import SQLiteSampler
+             self._sampler = SQLiteSampler(graph, verbose, optimize)
+         elif graph.backend == DataBackend.SNOWFLAKE:
+             from kumoai.experimental.rfm.backend.snow import SnowSampler
+             self._sampler = SnowSampler(graph, verbose)
+         else:
+             raise NotImplementedError
 
-         self._batch_size: Optional[int | Literal['max']] = None
-         self.num_retries: int = 0
+         self._client: RFMAPI | None = None
+
+         self._batch_size: int | Literal['max'] | None = None
+         self._num_retries: int = 0
 
      @property
      def _api_client(self) -> RFMAPI:
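
The constructor now dispatches on `graph.backend` and imports each sampler lazily, so only the dependencies of the chosen backend are required at runtime. A hedged construction sketch (the `Graph` setup is assumed from the rest of the `kumoai.experimental.rfm` package, which is not part of this hunk):

    from kumoai.experimental import rfm

    # graph: an rfm.Graph whose backend is LOCAL, SQLITE, or SNOWFLAKE.
    # optimize=True only affects transactional backends (e.g. SQLite), where it
    # may create missing indices and therefore requires write access:
    model = rfm.KumoRFM(graph, optimize=True)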
@@ -179,10 +206,34 @@ class KumoRFM:
      def __repr__(self) -> str:
          return f'{self.__class__.__name__}()'
 
+     @contextmanager
+     def retry(
+         self,
+         num_retries: int = 1,
+     ) -> Generator[None, None, None]:
+         """Context manager to retry failed queries due to unexpected server
+         issues.
+
+         .. code-block:: python
+
+             with model.retry(num_retries=1):
+                 df = model.predict(query, indices=...)
+
+         Args:
+             num_retries: The maximum number of retries.
+         """
+         if num_retries < 0:
+             raise ValueError(f"'num_retries' must be greater than or equal to "
+                              f"zero (got {num_retries})")
+
+         self._num_retries = num_retries
+         yield
+         self._num_retries = 0
+
      @contextmanager
      def batch_mode(
          self,
-         batch_size: Union[int, Literal['max']] = 'max',
+         batch_size: int | Literal['max'] = 'max',
          num_retries: int = 1,
      ) -> Generator[None, None, None]:
          """Context manager to predict in batches.
@@ -202,31 +253,26 @@ class KumoRFM:
              raise ValueError(f"'batch_size' must be greater than zero "
                               f"(got {batch_size})")
 
-         if num_retries < 0:
-             raise ValueError(f"'num_retries' must be greater than or equal to "
-                              f"zero (got {num_retries})")
-
          self._batch_size = batch_size
-         self.num_retries = num_retries
-         yield
+         with self.retry(self._num_retries or num_retries):
+             yield
          self._batch_size = None
-         self.num_retries = 0
 
      @overload
      def predict(
          self,
          query: str,
-         indices: Union[List[str], List[float], List[int], None] = None,
+         indices: list[str] | list[float] | list[int] | None = None,
          *,
          explain: Literal[False] = False,
-         anchor_time: Union[pd.Timestamp, Literal['entity'], None] = None,
-         context_anchor_time: Union[pd.Timestamp, None] = None,
-         run_mode: Union[RunMode, str] = RunMode.FAST,
-         num_neighbors: Optional[List[int]] = None,
+         anchor_time: pd.Timestamp | Literal['entity'] | None = None,
+         context_anchor_time: pd.Timestamp | None = None,
+         run_mode: RunMode | str = RunMode.FAST,
+         num_neighbors: list[int] | None = None,
          num_hops: int = 2,
-         max_pq_iterations: int = 20,
-         random_seed: Optional[int] = _RANDOM_SEED,
-         verbose: Union[bool, ProgressLogger] = True,
+         max_pq_iterations: int = 10,
+         random_seed: int | None = _RANDOM_SEED,
+         verbose: bool | ProgressLogger = True,
          use_prediction_time: bool = False,
      ) -> pd.DataFrame:
          pass
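
Retry handling moved out of `batch_mode` into the standalone `retry` context manager above; `batch_mode` now delegates via `self.retry(self._num_retries or num_retries)`, so an enclosing `retry` block takes precedence over the `num_retries` argument. A hedged usage sketch (query and indices are hypothetical):

    # Retry transient server failures for a single call:
    with model.retry(num_retries=2):
        df = model.predict(query, indices=indices)

    # Predict in batches of 1,000 entities; one retry per failed request:
    with model.batch_mode(batch_size=1_000, num_retries=1):
        df = model.predict(query, indices=indices)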
@@ -235,37 +281,56 @@ class KumoRFM:
      def predict(
          self,
          query: str,
-         indices: Union[List[str], List[float], List[int], None] = None,
+         indices: list[str] | list[float] | list[int] | None = None,
          *,
-         explain: Union[Literal[True], ExplainConfig, Dict[str, Any]],
-         anchor_time: Union[pd.Timestamp, Literal['entity'], None] = None,
-         context_anchor_time: Union[pd.Timestamp, None] = None,
-         run_mode: Union[RunMode, str] = RunMode.FAST,
-         num_neighbors: Optional[List[int]] = None,
+         explain: Literal[True] | ExplainConfig | dict[str, Any],
+         anchor_time: pd.Timestamp | Literal['entity'] | None = None,
+         context_anchor_time: pd.Timestamp | None = None,
+         run_mode: RunMode | str = RunMode.FAST,
+         num_neighbors: list[int] | None = None,
          num_hops: int = 2,
-         max_pq_iterations: int = 20,
-         random_seed: Optional[int] = _RANDOM_SEED,
-         verbose: Union[bool, ProgressLogger] = True,
+         max_pq_iterations: int = 10,
+         random_seed: int | None = _RANDOM_SEED,
+         verbose: bool | ProgressLogger = True,
          use_prediction_time: bool = False,
      ) -> Explanation:
          pass
 
+     @overload
      def predict(
          self,
          query: str,
-         indices: Union[List[str], List[float], List[int], None] = None,
+         indices: list[str] | list[float] | list[int] | None = None,
          *,
-         explain: Union[bool, ExplainConfig, Dict[str, Any]] = False,
-         anchor_time: Union[pd.Timestamp, Literal['entity'], None] = None,
-         context_anchor_time: Union[pd.Timestamp, None] = None,
-         run_mode: Union[RunMode, str] = RunMode.FAST,
-         num_neighbors: Optional[List[int]] = None,
+         explain: bool | ExplainConfig | dict[str, Any] = False,
+         anchor_time: pd.Timestamp | Literal['entity'] | None = None,
+         context_anchor_time: pd.Timestamp | None = None,
+         run_mode: RunMode | str = RunMode.FAST,
+         num_neighbors: list[int] | None = None,
          num_hops: int = 2,
-         max_pq_iterations: int = 20,
-         random_seed: Optional[int] = _RANDOM_SEED,
-         verbose: Union[bool, ProgressLogger] = True,
+         max_pq_iterations: int = 10,
+         random_seed: int | None = _RANDOM_SEED,
+         verbose: bool | ProgressLogger = True,
          use_prediction_time: bool = False,
-     ) -> Union[pd.DataFrame, Explanation]:
+     ) -> pd.DataFrame | Explanation:
+         pass
+
+     def predict(
+         self,
+         query: str,
+         indices: list[str] | list[float] | list[int] | None = None,
+         *,
+         explain: bool | ExplainConfig | dict[str, Any] = False,
+         anchor_time: pd.Timestamp | Literal['entity'] | None = None,
+         context_anchor_time: pd.Timestamp | None = None,
+         run_mode: RunMode | str = RunMode.FAST,
+         num_neighbors: list[int] | None = None,
+         num_hops: int = 2,
+         max_pq_iterations: int = 10,
+         random_seed: int | None = _RANDOM_SEED,
+         verbose: bool | ProgressLogger = True,
+         use_prediction_time: bool = False,
+     ) -> pd.DataFrame | Explanation:
          """Returns predictions for a predictive query.
 
          Args:
@@ -273,8 +338,7 @@ class KumoRFM:
              indices: The entity primary keys to predict for. Will override the
                  indices given as part of the predictive query. Predictions will
                  be generated for all indices, independent of whether they
-                 fulfill entity filter constraints. To pre-filter entities, use
-                 :meth:`~KumoRFM.is_valid_entity`.
+                 fulfill entity filter constraints.
              explain: Configuration for explainability.
                  If set to ``True``, will additionally explain the prediction.
                  Passing in an :class:`ExplainConfig` instance provides control
@@ -307,18 +371,152 @@ class KumoRFM:
              If ``explain`` is provided, returns an :class:`Explanation` object
              containing the prediction, summary, and details.
          """
-         explain_config: Optional[ExplainConfig] = None
-         if explain is True:
-             explain_config = ExplainConfig()
-         elif explain is not False:
-             explain_config = ExplainConfig._cast(explain)
-
          query_def = self._parse_query(query)
-         query_str = query_def.to_string()
 
+         if indices is None:
+             if query_def.rfm_entity_ids is None:
+                 raise ValueError("Cannot find entities to predict for. Please "
+                                  "pass them via `predict(query, indices=...)`")
+             indices = query_def.get_rfm_entity_id_list()
+         query_def = replace(
+             query_def,
+             for_each='FOR EACH',
+             rfm_entity_ids=None,
+         )
+
+         if not isinstance(verbose, ProgressLogger):
+             query_repr = query_def.to_string(rich=True, exclude_predict=True)
+             if explain is not False:
+                 msg = f'[bold]EXPLAIN[/bold] {query_repr}'
+             else:
+                 msg = f'[bold]PREDICT[/bold] {query_repr}'
+             verbose = ProgressLogger.default(msg=msg, verbose=verbose)
+
+         with verbose as logger:
+             task_table = self._get_task_table(
+                 query=query_def,
+                 indices=indices,
+                 anchor_time=anchor_time,
+                 context_anchor_time=context_anchor_time,
+                 run_mode=run_mode,
+                 max_pq_iterations=max_pq_iterations,
+                 random_seed=random_seed,
+                 logger=logger,
+             )
+             task_table._query = query_def.to_string()
+
+             return self.predict_task(
+                 task_table,
+                 explain=explain,
+                 run_mode=run_mode,
+                 num_neighbors=num_neighbors,
+                 num_hops=num_hops,
+                 verbose=verbose,
+                 exclude_cols_dict=query_def.get_exclude_cols_dict(),
+                 use_prediction_time=use_prediction_time,
+                 top_k=query_def.top_k,
+             )
+
+     @overload
+     def predict_task(
+         self,
+         task: TaskTable,
+         *,
+         explain: Literal[False] = False,
+         run_mode: RunMode | str = RunMode.FAST,
+         num_neighbors: list[int] | None = None,
+         num_hops: int = 2,
+         verbose: bool | ProgressLogger = True,
+         exclude_cols_dict: dict[str, list[str]] | None = None,
+         use_prediction_time: bool = False,
+         top_k: int | None = None,
+     ) -> pd.DataFrame:
+         pass
+
+     @overload
+     def predict_task(
+         self,
+         task: TaskTable,
+         *,
+         explain: Literal[True] | ExplainConfig | dict[str, Any],
+         run_mode: RunMode | str = RunMode.FAST,
+         num_neighbors: list[int] | None = None,
+         num_hops: int = 2,
+         verbose: bool | ProgressLogger = True,
+         exclude_cols_dict: dict[str, list[str]] | None = None,
+         use_prediction_time: bool = False,
+         top_k: int | None = None,
+     ) -> Explanation:
+         pass
+
+     @overload
+     def predict_task(
+         self,
+         task: TaskTable,
+         *,
+         explain: bool | ExplainConfig | dict[str, Any] = False,
+         run_mode: RunMode | str = RunMode.FAST,
+         num_neighbors: list[int] | None = None,
+         num_hops: int = 2,
+         verbose: bool | ProgressLogger = True,
+         exclude_cols_dict: dict[str, list[str]] | None = None,
+         use_prediction_time: bool = False,
+         top_k: int | None = None,
+     ) -> pd.DataFrame | Explanation:
+         pass
+
+     def predict_task(
+         self,
+         task: TaskTable,
+         *,
+         explain: bool | ExplainConfig | dict[str, Any] = False,
+         run_mode: RunMode | str = RunMode.FAST,
+         num_neighbors: list[int] | None = None,
+         num_hops: int = 2,
+         verbose: bool | ProgressLogger = True,
+         exclude_cols_dict: dict[str, list[str]] | None = None,
+         use_prediction_time: bool = False,
+         top_k: int | None = None,
+     ) -> pd.DataFrame | Explanation:
+         """Returns predictions for a custom task specification.
+
+         Args:
+             task: The custom :class:`TaskTable`.
+             explain: Configuration for explainability.
+                 If set to ``True``, will additionally explain the prediction.
+                 Passing in an :class:`ExplainConfig` instance provides control
+                 over which parts of the explanation are generated.
+                 Explainability is currently only supported for single-entity
+                 predictions with ``run_mode="FAST"``.
+             run_mode: The :class:`RunMode` for the query.
+             num_neighbors: The number of neighbors to sample for each hop.
+                 If specified, the ``num_hops`` option will be ignored.
+             num_hops: The number of hops to sample when generating the context.
+             verbose: Whether to print verbose output.
+             exclude_cols_dict: Any column in any table to exclude from the
+                 model input.
+             use_prediction_time: Whether to use the anchor timestamp as an
+                 additional feature during prediction. This is typically
+                 beneficial for time series forecasting tasks.
+             top_k: The number of predictions to return per entity.
+
+         Returns:
+             The predictions as a :class:`pandas.DataFrame`.
+             If ``explain`` is provided, returns an :class:`Explanation` object
+             containing the prediction, summary, and details.
+         """
          if num_hops != 2 and num_neighbors is not None:
              warnings.warn(f"Received custom 'num_neighbors' option; ignoring "
                            f"custom 'num_hops={num_hops}' option")
+         if num_neighbors is None:
+             key = RunMode.FAST if task.task_type.is_link_pred else run_mode
+             num_neighbors = _DEFAULT_NUM_NEIGHBORS[key][:num_hops]
+
+         explain_config: ExplainConfig | None = None
+         if explain is True:
+             explain_config = ExplainConfig()
+         elif explain is not False:
+             explain_config = ExplainConfig._cast(explain)
 
          if explain_config is not None and run_mode in {
              RunMode.NORMAL, RunMode.BEST
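
`predict_task` is the new lower-level entry point: `predict()` now only parses the query, materializes a `TaskTable` via `_get_task_table`, and forwards to it, so hand-built tasks can bypass PQL entirely. A hedged sketch (the `TaskTable` constructor lives in the new `kumoai/experimental/rfm/task_table.py`, which this hunk does not show, so its construction is left abstract here):

    from kumoai.experimental.rfm import TaskTable

    task: TaskTable = ...  # built with in-context and prediction examples
    df = model.predict_task(task, run_mode='FAST', num_hops=2)

    # With explanation (limited to a single prediction example):
    explanation = model.predict_task(task, explain=True)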
@@ -327,83 +525,82 @@ class KumoRFM:
                            f"run mode 'FAST' (got '{run_mode}'). Provided run "
                            f"mode has been reset. Please lower the run mode to "
                            f"suppress this warning.")
+             run_mode = RunMode.FAST
 
-         if indices is None:
-             if query_def.rfm_entity_ids is None:
-                 raise ValueError("Cannot find entities to predict for. Please "
-                                  "pass them via `predict(query, indices=...)`")
-             indices = query_def.get_rfm_entity_id_list()
-         else:
-             query_def = replace(query_def, rfm_entity_ids=None)
-
-         if len(indices) == 0:
-             raise ValueError("At least one entity is required")
-
-         if explain_config is not None and len(indices) > 1:
-             raise ValueError(
-                 f"Cannot explain predictions for more than a single entity "
-                 f"(got {len(indices)})")
-
-         query_repr = query_def.to_string(rich=True, exclude_predict=True)
-         if explain_config is not None:
-             msg = f'[bold]EXPLAIN[/bold] {query_repr}'
-         else:
-             msg = f'[bold]PREDICT[/bold] {query_repr}'
+         if explain_config is not None and task.num_prediction_examples > 1:
+             raise ValueError(f"Cannot explain predictions for more than a "
+                              f"single entity "
+                              f"(got {task.num_prediction_examples:,})")
 
          if not isinstance(verbose, ProgressLogger):
-             verbose = InteractiveProgressLogger(msg, verbose=verbose)
-
-         with verbose as logger:
-
-             batch_size: Optional[int] = None
-             if self._batch_size == 'max':
-                 task_type = LocalPQueryDriver.get_task_type(
-                     query_def,
-                     edge_types=self._graph_store.edge_types,
-                 )
-                 batch_size = _MAX_PRED_SIZE[task_type]
+             if task.task_type == TaskType.BINARY_CLASSIFICATION:
+                 task_type_repr = 'binary classification'
+             elif task.task_type == TaskType.MULTICLASS_CLASSIFICATION:
+                 task_type_repr = 'multi-class classification'
+             elif task.task_type == TaskType.REGRESSION:
+                 task_type_repr = 'regression'
+             elif task.task_type == TaskType.TEMPORAL_LINK_PREDICTION:
+                 task_type_repr = 'link prediction'
              else:
-                 batch_size = self._batch_size
+                 task_type_repr = str(task.task_type)
 
-             if batch_size is not None:
-                 offsets = range(0, len(indices), batch_size)
-                 batches = [indices[step:step + batch_size] for step in offsets]
+             if explain_config is not None:
+                 msg = f"Explaining {task_type_repr} task"
              else:
-                 batches = [indices]
+                 msg = f"Predicting {task_type_repr} task"
+             verbose = ProgressLogger.default(msg=msg, verbose=verbose)
 
-             if len(batches) > 1:
-                 logger.log(f"Splitting {len(indices):,} entities into "
-                            f"{len(batches):,} batches of size {batch_size:,}")
+         with verbose as logger:
+             if task.num_context_examples > _MAX_CONTEXT_SIZE[run_mode]:
+                 logger.log(f"Sub-sampled {_MAX_CONTEXT_SIZE[run_mode]:,} "
+                            f"out of {task.num_context_examples:,} in-context "
+                            f"examples")
+                 task = task.narrow_context(0, _MAX_CONTEXT_SIZE[run_mode])
+
+             if self._batch_size is None:
+                 batch_size = task.num_prediction_examples
+             elif self._batch_size == 'max':
+                 batch_size = _MAX_PRED_SIZE[task.task_type]
+             else:
+                 batch_size = self._batch_size
 
-             predictions: List[pd.DataFrame] = []
-             summary: Optional[str] = None
-             details: Optional[Explanation] = None
-             for i, batch in enumerate(batches):
-                 # TODO Re-use the context for subsequent predictions.
+             if batch_size > _MAX_PRED_SIZE[task.task_type]:
+                 raise ValueError(f"Cannot predict for more than "
+                                  f"{_MAX_PRED_SIZE[task.task_type]:,} "
+                                  f"entities at once (got {batch_size:,}). Use "
+                                  f"`KumoRFM.batch_mode` to process entities "
+                                  f"in batches with a sufficient batch size.")
+
+             if task.num_prediction_examples > batch_size:
+                 num = math.ceil(task.num_prediction_examples / batch_size)
+                 logger.log(f"Splitting {task.num_prediction_examples:,} "
+                            f"entities into {num:,} batches of size "
+                            f"{batch_size:,}")
+
+             predictions: list[pd.DataFrame] = []
+             summary: str | None = None
+             details: Explanation | None = None
+             for start in range(0, task.num_prediction_examples, batch_size):
                  context = self._get_context(
-                     query=query_def,
-                     indices=batch,
-                     anchor_time=anchor_time,
-                     context_anchor_time=context_anchor_time,
-                     run_mode=RunMode(run_mode),
+                     task=task.narrow_prediction(start, length=batch_size),
+                     run_mode=run_mode,
                      num_neighbors=num_neighbors,
-                     num_hops=num_hops,
-                     max_pq_iterations=max_pq_iterations,
-                     evaluate=False,
-                     random_seed=random_seed,
-                     logger=logger if i == 0 else None,
+                     exclude_cols_dict=exclude_cols_dict,
+                     top_k=top_k,
                  )
+                 context.y_test = None
+
                  request = RFMPredictRequest(
                      context=context,
                      run_mode=RunMode(run_mode),
-                     query=query_str,
+                     query=task._query,
                      use_prediction_time=use_prediction_time,
                  )
                  with warnings.catch_warnings():
                      warnings.filterwarnings('ignore', message='gencode')
                      request_msg = request.to_protobuf()
                      _bytes = request_msg.SerializeToString()
-                 if i == 0:
+                 if start == 0:
                      logger.log(f"Generated context of size "
                                 f"{len(_bytes) / (1024*1024):.2f}MB")
 
@@ -411,14 +608,11 @@ class KumoRFM:
                      stats = Context.get_memory_stats(request_msg.context)
                      raise ValueError(_SIZE_LIMIT_MSG.format(stats=stats))
 
-                 if (isinstance(verbose, InteractiveProgressLogger) and i == 0
-                         and len(batches) > 1):
-                     verbose.init_progress(
-                         total=len(batches),
-                         description='Predicting',
-                     )
+                 if start == 0 and task.num_prediction_examples > batch_size:
+                     num = math.ceil(task.num_prediction_examples / batch_size)
+                     verbose.init_progress(total=num, description='Predicting')
 
-                 for attempt in range(self.num_retries + 1):
+                 for attempt in range(self._num_retries + 1):
                      try:
                          if explain_config is not None:
                              resp = self._api_client.explain(
@@ -433,10 +627,10 @@ class KumoRFM:
 
                          # Cast 'ENTITY' to correct data type:
                          if 'ENTITY' in df:
-                             entity = query_def.entity_table
-                             pkey_map = self._graph_store.pkey_map_dict[entity]
-                             df['ENTITY'] = df['ENTITY'].astype(
-                                 type(pkey_map.index[0]))
+                             table_dict = context.subgraph.table_dict
+                             table = table_dict[context.entity_table_names[0]]
+                             ser = table.df[table.primary_key]
+                             df['ENTITY'] = df['ENTITY'].astype(ser.dtype)
 
                          # Cast 'ANCHOR_TIMESTAMP' to correct data type:
                          if 'ANCHOR_TIMESTAMP' in df:
@@ -451,13 +645,12 @@ class KumoRFM:
 
                          predictions.append(df)
 
-                         if (isinstance(verbose, InteractiveProgressLogger)
-                                 and len(batches) > 1):
+                         if task.num_prediction_examples > batch_size:
                              verbose.step()
 
                          break
                      except HTTPException as e:
-                         if attempt == self.num_retries:
+                         if attempt == self._num_retries:
                              try:
                                  msg = json.loads(e.detail)['detail']
                              except Exception:
@@ -487,69 +680,19 @@ class KumoRFM:
 
          return prediction
 
-     def is_valid_entity(
-         self,
-         query: str,
-         indices: Union[List[str], List[float], List[int], None] = None,
-         *,
-         anchor_time: Union[pd.Timestamp, Literal['entity'], None] = None,
-     ) -> np.ndarray:
-         r"""Returns a mask that denotes which entities are valid for the
-         given predictive query, *i.e.*, which entities fulfill (temporal)
-         entity filter constraints.
-
-         Args:
-             query: The predictive query.
-             indices: The entity primary keys to predict for. Will override the
-                 indices given as part of the predictive query.
-             anchor_time: The anchor timestamp for the prediction. If set to
-                 ``None``, will use the maximum timestamp in the data.
-                 If set to ``"entity"``, will use the timestamp of the entity.
-         """
-         query_def = self._parse_query(query)
-
-         if indices is None:
-             if query_def.rfm_entity_ids is None:
-                 raise ValueError("Cannot find entities to predict for. Please "
-                                  "pass them via "
-                                  "`is_valid_entity(query, indices=...)`")
-             indices = query_def.get_rfm_entity_id_list()
-
-         if len(indices) == 0:
-             raise ValueError("At least one entity is required")
-
-         if anchor_time is None:
-             anchor_time = self._graph_store.max_time
-
-         if isinstance(anchor_time, pd.Timestamp):
-             self._validate_time(query_def, anchor_time, None, False)
-         else:
-             assert anchor_time == 'entity'
-             if (query_def.entity_table not in self._graph_store.time_dict):
-                 raise ValueError(f"Anchor time 'entity' requires the entity "
-                                  f"table '{query_def.entity_table}' "
-                                  f"to have a time column.")
-
-         node = self._graph_store.get_node_id(
-             table_name=query_def.entity_table,
-             pkey=pd.Series(indices),
-         )
-         query_driver = LocalPQueryDriver(self._graph_store, query_def)
-         return query_driver.is_valid(node, anchor_time)
-
      def evaluate(
          self,
          query: str,
          *,
-         metrics: Optional[List[str]] = None,
-         anchor_time: Union[pd.Timestamp, Literal['entity'], None] = None,
-         context_anchor_time: Union[pd.Timestamp, None] = None,
-         run_mode: Union[RunMode, str] = RunMode.FAST,
-         num_neighbors: Optional[List[int]] = None,
+         metrics: list[str] | None = None,
+         anchor_time: pd.Timestamp | Literal['entity'] | None = None,
+         context_anchor_time: pd.Timestamp | None = None,
+         run_mode: RunMode | str = RunMode.FAST,
+         num_neighbors: list[int] | None = None,
          num_hops: int = 2,
-         max_pq_iterations: int = 20,
-         random_seed: Optional[int] = _RANDOM_SEED,
-         verbose: Union[bool, ProgressLogger] = True,
+         max_pq_iterations: int = 10,
+         random_seed: int | None = _RANDOM_SEED,
+         verbose: bool | ProgressLogger = True,
          use_prediction_time: bool = False,
      ) -> pd.DataFrame:
          """Evaluates a predictive query.
@@ -581,41 +724,120 @@ class KumoRFM:
          Returns:
              The metrics as a :class:`pandas.DataFrame`
          """
-         query_def = self._parse_query(query)
+         query_def = replace(
+             self._parse_query(query),
+             for_each='FOR EACH',
+             rfm_entity_ids=None,
+         )
+
+         if not isinstance(verbose, ProgressLogger):
+             query_repr = query_def.to_string(rich=True, exclude_predict=True)
+             msg = f'[bold]EVALUATE[/bold] {query_repr}'
+             verbose = ProgressLogger.default(msg=msg, verbose=verbose)
+
+         with verbose as logger:
+             task_table = self._get_task_table(
+                 query=query_def,
+                 indices=None,
+                 anchor_time=anchor_time,
+                 context_anchor_time=context_anchor_time,
+                 run_mode=run_mode,
+                 max_pq_iterations=max_pq_iterations,
+                 random_seed=random_seed,
+                 logger=logger,
+             )
+
+             return self.evaluate_task(
+                 task_table,
+                 metrics=metrics,
+                 run_mode=run_mode,
+                 num_neighbors=num_neighbors,
+                 num_hops=num_hops,
+                 verbose=verbose,
+                 exclude_cols_dict=query_def.get_exclude_cols_dict(),
+                 use_prediction_time=use_prediction_time,
+             )
+
+     def evaluate_task(
+         self,
+         task: TaskTable,
+         *,
+         metrics: list[str] | None = None,
+         run_mode: RunMode | str = RunMode.FAST,
+         num_neighbors: list[int] | None = None,
+         num_hops: int = 2,
+         verbose: bool | ProgressLogger = True,
+         exclude_cols_dict: dict[str, list[str]] | None = None,
+         use_prediction_time: bool = False,
+     ) -> pd.DataFrame:
+         """Evaluates a custom task specification.
 
+         Args:
+             task: The custom :class:`TaskTable`.
+             metrics: The metrics to use.
+             run_mode: The :class:`RunMode` for the query.
+             num_neighbors: The number of neighbors to sample for each hop.
+                 If specified, the ``num_hops`` option will be ignored.
+             num_hops: The number of hops to sample when generating the context.
+             verbose: Whether to print verbose output.
+             exclude_cols_dict: Any column in any table to exclude from the
+                 model input.
+             use_prediction_time: Whether to use the anchor timestamp as an
+                 additional feature during prediction. This is typically
+                 beneficial for time series forecasting tasks.
+
+         Returns:
+             The metrics as a :class:`pandas.DataFrame`
+         """
          if num_hops != 2 and num_neighbors is not None:
              warnings.warn(f"Received custom 'num_neighbors' option; ignoring "
                            f"custom 'num_hops={num_hops}' option")
+         if num_neighbors is None:
+             key = RunMode.FAST if task.task_type.is_link_pred else run_mode
+             num_neighbors = _DEFAULT_NUM_NEIGHBORS[key][:num_hops]
 
-         if query_def.rfm_entity_ids is not None:
-             query_def = replace(
-                 query_def,
-                 rfm_entity_ids=None,
-             )
-
-         query_repr = query_def.to_string(rich=True, exclude_predict=True)
-         msg = f'[bold]EVALUATE[/bold] {query_repr}'
+         if metrics is not None and len(metrics) > 0:
+             self._validate_metrics(metrics, task.task_type)
+             metrics = list(dict.fromkeys(metrics))
 
          if not isinstance(verbose, ProgressLogger):
-             verbose = InteractiveProgressLogger(msg, verbose=verbose)
+             if task.task_type == TaskType.BINARY_CLASSIFICATION:
+                 task_type_repr = 'binary classification'
+             elif task.task_type == TaskType.MULTICLASS_CLASSIFICATION:
+                 task_type_repr = 'multi-class classification'
+             elif task.task_type == TaskType.REGRESSION:
+                 task_type_repr = 'regression'
+             elif task.task_type == TaskType.TEMPORAL_LINK_PREDICTION:
+                 task_type_repr = 'link prediction'
+             else:
+                 task_type_repr = str(task.task_type)
+
+             msg = f"Evaluating {task_type_repr} task"
+             verbose = ProgressLogger.default(msg=msg, verbose=verbose)
 
          with verbose as logger:
+             if task.num_context_examples > _MAX_CONTEXT_SIZE[run_mode]:
+                 logger.log(f"Sub-sampled {_MAX_CONTEXT_SIZE[run_mode]:,} "
+                            f"out of {task.num_context_examples:,} in-context "
+                            f"examples")
+                 task = task.narrow_context(0, _MAX_CONTEXT_SIZE[run_mode])
+
+             if task.num_prediction_examples > _MAX_TEST_SIZE[task.task_type]:
+                 logger.log(f"Sub-sampled {_MAX_TEST_SIZE[task.task_type]:,} "
+                            f"out of {task.num_prediction_examples:,} test "
+                            f"examples")
+                 task = task.narrow_prediction(
+                     start=0,
+                     length=_MAX_TEST_SIZE[task.task_type],
+                 )
+
              context = self._get_context(
-                 query=query_def,
-                 indices=None,
-                 anchor_time=anchor_time,
-                 context_anchor_time=context_anchor_time,
-                 run_mode=RunMode(run_mode),
+                 task=task,
+                 run_mode=run_mode,
                  num_neighbors=num_neighbors,
-                 num_hops=num_hops,
-                 max_pq_iterations=max_pq_iterations,
-                 evaluate=True,
-                 random_seed=random_seed,
-                 logger=logger if verbose else None,
+                 exclude_cols_dict=exclude_cols_dict,
              )
-             if metrics is not None and len(metrics) > 0:
-                 self._validate_metrics(metrics, context.task_type)
-                 metrics = list(dict.fromkeys(metrics))
+
              request = RFMEvaluateRequest(
                  context=context,
                  run_mode=RunMode(run_mode),
@@ -633,17 +855,23 @@ class KumoRFM:
                  stats_msg = Context.get_memory_stats(request_msg.context)
                  raise ValueError(_SIZE_LIMIT_MSG.format(stats=stats_msg))
 
-             try:
-                 resp = self._api_client.evaluate(request_bytes)
-             except HTTPException as e:
+             for attempt in range(self._num_retries + 1):
                  try:
-                     msg = json.loads(e.detail)['detail']
-                 except Exception:
-                     msg = e.detail
-                 raise RuntimeError(f"An unexpected exception occurred. "
-                                    f"Please create an issue at "
-                                    f"'https://github.com/kumo-ai/kumo-rfm'. "
-                                    f"{msg}") from None
+                     resp = self._api_client.evaluate(request_bytes)
+                     break
+                 except HTTPException as e:
+                     if attempt == self._num_retries:
+                         try:
+                             msg = json.loads(e.detail)['detail']
+                         except Exception:
+                             msg = e.detail
+                         raise RuntimeError(
+                             f"An unexpected exception occurred. Please create "
+                             f"an issue at "
+                             f"'https://github.com/kumo-ai/kumo-rfm'. {msg}"
+                         ) from None
+
+                     time.sleep(2**attempt)  # 1s, 2s, 4s, 8s, ...
 
          return pd.DataFrame.from_dict(
              resp.metrics,
  resp.metrics,
@@ -656,9 +884,9 @@ class KumoRFM:
656
884
  query: str,
657
885
  size: int,
658
886
  *,
659
- anchor_time: Union[pd.Timestamp, Literal['entity'], None] = None,
660
- random_seed: Optional[int] = _RANDOM_SEED,
661
- max_iterations: int = 20,
887
+ anchor_time: pd.Timestamp | Literal['entity'] | None = None,
888
+ random_seed: int | None = _RANDOM_SEED,
889
+ max_iterations: int = 10,
662
890
  ) -> pd.DataFrame:
663
891
  """Returns the labels of a predictive query for a specified anchor
664
892
  time.
@@ -678,40 +906,37 @@ class KumoRFM:
          query_def = self._parse_query(query)
 
          if anchor_time is None:
-             anchor_time = self._graph_store.max_time
+             anchor_time = self._get_default_anchor_time(query_def)
              if query_def.target_ast.date_offset_range is not None:
-                 anchor_time = anchor_time - (
-                     query_def.target_ast.date_offset_range.end_date_offset *
-                     query_def.num_forecasts)
+                 offset = query_def.target_ast.date_offset_range.end_date_offset
+                 offset *= query_def.num_forecasts
+                 anchor_time -= offset
 
          assert anchor_time is not None
          if isinstance(anchor_time, pd.Timestamp):
              self._validate_time(query_def, anchor_time, None, evaluate=True)
          else:
              assert anchor_time == 'entity'
-             if (query_def.entity_table not in self._graph_store.time_dict):
+             if query_def.entity_table not in self._sampler.time_column_dict:
                  raise ValueError(f"Anchor time 'entity' requires the entity "
                                   f"table '{query_def.entity_table}' "
                                   f"to have a time column")
 
-         query_driver = LocalPQueryDriver(self._graph_store, query_def,
-                                          random_seed)
-
-         node, time, y = query_driver.collect_test(
-             size=size,
-             anchor_time=anchor_time,
-             batch_size=min(10_000, size),
-             max_iterations=max_iterations,
-             guarantee_train_examples=False,
+         train, test = self._sampler.sample_target(
+             query=query_def,
+             num_train_examples=0,
+             train_anchor_time=anchor_time,
+             num_train_trials=0,
+             num_test_examples=size,
+             test_anchor_time=anchor_time,
+             num_test_trials=max_iterations * size,
+             random_seed=random_seed,
          )
 
-         entity = self._graph_store.pkey_map_dict[
-             query_def.entity_table].index[node]
-
          return pd.DataFrame({
-             'ENTITY': entity,
-             'ANCHOR_TIMESTAMP': time,
-             'TARGET': y,
+             'ENTITY': test.entity_pkey,
+             'ANCHOR_TIMESTAMP': test.anchor_time,
+             'TARGET': test.target,
          })
 
      # Helpers #################################################################
@@ -726,63 +951,120 @@ class KumoRFM:
                              "`predict()` or `evaluate()` methods to perform "
                              "predictions or evaluations.")
 
-         try:
-             request = RFMParseQueryRequest(
-                 query=query,
-                 graph_definition=self._graph_def,
-             )
+         request = RFMParseQueryRequest(
+             query=query,
+             graph_definition=self._graph_def,
+         )
 
-             resp = self._api_client.parse_query(request)
+         for attempt in range(self._num_retries + 1):
+             try:
+                 resp = self._api_client.parse_query(request)
+                 break
+             except HTTPException as e:
+                 if attempt == self._num_retries:
+                     try:
+                         msg = json.loads(e.detail)['detail']
+                     except Exception:
+                         msg = e.detail
+                     raise ValueError(f"Failed to parse query '{query}'. {msg}")
 
-             # TODO Expose validation warnings.
+                 time.sleep(2**attempt)  # 1s, 2s, 4s, 8s, ...
 
-             if len(resp.validation_response.warnings) > 0:
-                 msg = '\n'.join([
-                     f'{i+1}. {warning.title}: {warning.message}' for i, warning
-                     in enumerate(resp.validation_response.warnings)
-                 ])
-                 warnings.warn(f"Encountered the following warnings during "
-                               f"parsing:\n{msg}")
+         if len(resp.validation_response.warnings) > 0:
+             msg = '\n'.join([
+                 f'{i+1}. {warning.title}: {warning.message}'
+                 for i, warning in enumerate(resp.validation_response.warnings)
+             ])
+             warnings.warn(f"Encountered the following warnings during "
+                           f"parsing:\n{msg}")
 
-             return resp.query
-         except HTTPException as e:
-             try:
-                 msg = json.loads(e.detail)['detail']
-             except Exception:
-                 msg = e.detail
-             raise ValueError(f"Failed to parse query '{query}'. "
-                              f"{msg}") from None
+         return resp.query
+
+     @staticmethod
+     def _get_task_type(
+         query: ValidatedPredictiveQuery,
+         edge_types: list[tuple[str, str, str]],
+     ) -> TaskType:
+         if isinstance(query.target_ast, (Condition, LogicalOperation)):
+             return TaskType.BINARY_CLASSIFICATION
+
+         target = query.target_ast
+         if isinstance(target, Join):
+             target = target.rhs_target
+         if isinstance(target, Aggregation):
+             if target.aggr == AggregationType.LIST_DISTINCT:
+                 table_name, col_name = target._get_target_column_name().split(
+                     '.')
+                 target_edge_types = [
+                     edge_type for edge_type in edge_types
+                     if edge_type[0] == table_name and edge_type[1] == col_name
+                 ]
+                 if len(target_edge_types) != 1:
+                     raise NotImplementedError(
+                         f"Multilabel-classification queries based on "
+                         f"'LIST_DISTINCT' are not supported yet. If you "
+                         f"planned to write a link prediction query instead, "
+                         f"make sure to register '{col_name}' as a "
+                         f"foreign key.")
+                 return TaskType.TEMPORAL_LINK_PREDICTION
+
+             return TaskType.REGRESSION
+
+         assert isinstance(target, Column)
+
+         if target.stype in {Stype.ID, Stype.categorical}:
+             return TaskType.MULTICLASS_CLASSIFICATION
+
+         if target.stype in {Stype.numerical}:
+             return TaskType.REGRESSION
+
+         raise NotImplementedError("Task type not yet supported")
+
+     def _get_default_anchor_time(
+         self,
+         query: ValidatedPredictiveQuery | None = None,
+     ) -> pd.Timestamp:
+         if query is not None and query.query_type == QueryType.TEMPORAL:
+             aggr_table_names = [
+                 aggr._get_target_column_name().split('.')[0]
+                 for aggr in query.get_all_target_aggregations()
+             ]
+             return self._sampler.get_max_time(aggr_table_names)
+
+         return self._sampler.get_max_time()
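
`_get_task_type` replaces `LocalPQueryDriver.get_task_type` and derives the task type purely from the parsed query AST. An illustrative comment-only sketch of the mapping it implements (the PQL snippets assume a hypothetical `users`/`orders` schema and are not taken from this diff):

    # Condition / LogicalOperation target           -> BINARY_CLASSIFICATION
    #   e.g. PREDICT COUNT(orders.*, 0, 30) > 0
    # Aggregation target other than LIST_DISTINCT   -> REGRESSION
    #   e.g. PREDICT SUM(orders.price, 0, 30)
    # LIST_DISTINCT over a registered foreign key   -> TEMPORAL_LINK_PREDICTION
    #   e.g. PREDICT LIST_DISTINCT(orders.item_id, 0, 7)
    # Column target with ID / categorical stype     -> MULTICLASS_CLASSIFICATION
    # Column target with numerical stype            -> REGRESSION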
 
      def _validate_time(
          self,
          query: ValidatedPredictiveQuery,
          anchor_time: pd.Timestamp,
-         context_anchor_time: Union[pd.Timestamp, None],
+         context_anchor_time: pd.Timestamp | None,
          evaluate: bool,
      ) -> None:
 
-         if self._graph_store.min_time == pd.Timestamp.max:
+         if len(self._sampler.time_column_dict) == 0:
              return  # Graph without timestamps
 
-         if anchor_time < self._graph_store.min_time:
+         min_time = self._sampler.get_min_time()
+         max_time = self._sampler.get_max_time()
+
+         if anchor_time < min_time:
              raise ValueError(f"Anchor timestamp '{anchor_time}' is before "
-                              f"the earliest timestamp "
-                              f"'{self._graph_store.min_time}' in the data.")
+                              f"the earliest timestamp '{min_time}' in the "
+                              f"data.")
 
-         if (context_anchor_time is not None
-                 and context_anchor_time < self._graph_store.min_time):
+         if context_anchor_time is not None and context_anchor_time < min_time:
              raise ValueError(f"Context anchor timestamp is too early or "
                               f"aggregation time range is too large. To make "
                               f"this prediction, we would need data back to "
                               f"'{context_anchor_time}', however, your data "
-                              f"only contains data back to "
-                              f"'{self._graph_store.min_time}'.")
+                              f"only contains data back to '{min_time}'.")
 
          if query.target_ast.date_offset_range is not None:
              end_offset = query.target_ast.date_offset_range.end_date_offset
          else:
              end_offset = pd.DateOffset(0)
-         forecast_end_offset = end_offset * query.num_forecasts
+         end_offset = end_offset * query.num_forecasts
+
          if (context_anchor_time is not None
                  and context_anchor_time > anchor_time):
              warnings.warn(f"Context anchor timestamp "
@@ -792,7 +1074,7 @@ class KumoRFM:
                            f"intended.")
          elif (query.query_type == QueryType.TEMPORAL
                and context_anchor_time is not None
-               and context_anchor_time + forecast_end_offset > anchor_time):
+               and context_anchor_time + end_offset > anchor_time):
              warnings.warn(f"Aggregation for context examples at timestamp "
                            f"'{context_anchor_time}' will leak information "
                            f"from the prediction anchor timestamp "
@@ -800,62 +1082,44 @@ class KumoRFM:
                            f"intended.")
 
          elif (context_anchor_time is not None
-               and context_anchor_time - forecast_end_offset
-               < self._graph_store.min_time):
-             _time = context_anchor_time - forecast_end_offset
+               and context_anchor_time - end_offset < min_time):
+             _time = context_anchor_time - end_offset
              warnings.warn(f"Context anchor timestamp is too early or "
                            f"aggregation time range is too large. To form "
                            f"proper input data, we would need data back to "
                            f"'{_time}', however, your data only contains "
-                           f"data back to '{self._graph_store.min_time}'.")
+                           f"data back to '{min_time}'.")
 
-         if (not evaluate and anchor_time
-                 > self._graph_store.max_time + pd.DateOffset(days=1)):
+         if not evaluate and anchor_time > max_time + pd.DateOffset(days=1):
              warnings.warn(f"Anchor timestamp '{anchor_time}' is after the "
-                           f"latest timestamp '{self._graph_store.max_time}' "
-                           f"in the data. Please make sure this is intended.")
+                           f"latest timestamp '{max_time}' in the data. Please "
+                           f"make sure this is intended.")
 
-         max_eval_time = self._graph_store.max_time - forecast_end_offset
-         if evaluate and anchor_time > max_eval_time:
+         if evaluate and anchor_time > max_time - end_offset:
              raise ValueError(
                  f"Anchor timestamp for evaluation is after the latest "
-                 f"supported timestamp '{max_eval_time}'.")
+                 f"supported timestamp '{max_time - end_offset}'.")
 
-     def _get_context(
+     def _get_task_table(
          self,
          query: ValidatedPredictiveQuery,
-         indices: Union[List[str], List[float], List[int], None],
-         anchor_time: Union[pd.Timestamp, Literal['entity'], None],
-         context_anchor_time: Union[pd.Timestamp, None],
-         run_mode: RunMode,
-         num_neighbors: Optional[List[int]],
-         num_hops: int,
-         max_pq_iterations: int,
-         evaluate: bool,
-         random_seed: Optional[int] = _RANDOM_SEED,
-         logger: Optional[ProgressLogger] = None,
-     ) -> Context:
-
-         if num_neighbors is not None:
-             num_hops = len(num_neighbors)
-
-         if num_hops < 0:
-             raise ValueError(f"'num_hops' must be non-negative "
-                              f"(got {num_hops})")
-         if num_hops > 6:
-             raise ValueError(f"Cannot predict on subgraphs with more than 6 "
-                              f"hops (got {num_hops}). Please reduce the "
-                              f"number of hops and try again. Please create a "
-                              f"feature request at "
-                              f"'https://github.com/kumo-ai/kumo-rfm' if you "
-                              f"must go beyond this for your use-case.")
-
-         query_driver = LocalPQueryDriver(self._graph_store, query, random_seed)
-         task_type = LocalPQueryDriver.get_task_type(
-             query,
-             edge_types=self._graph_store.edge_types,
+         indices: list[str] | list[float] | list[int] | None,
+         anchor_time: pd.Timestamp | Literal['entity'] | None = None,
+         context_anchor_time: pd.Timestamp | None = None,
+         run_mode: RunMode = RunMode.FAST,
+         max_pq_iterations: int = 10,
+         random_seed: int | None = _RANDOM_SEED,
+         logger: ProgressLogger | None = None,
+     ) -> TaskTable:
+
+         task_type = self._get_task_type(
+             query=query,
+             edge_types=self._sampler.edge_types,
          )
 
+         num_train_examples = _MAX_CONTEXT_SIZE[run_mode]
+         num_test_examples = _MAX_TEST_SIZE[task_type] if indices is None else 0
+
          if logger is not None:
              if task_type == TaskType.BINARY_CLASSIFICATION:
                  task_type_repr = 'binary classification'
@@ -869,30 +1133,17 @@ class KumoRFM:
                  task_type_repr = str(task_type)
              logger.log(f"Identified {query.query_type} {task_type_repr} task")
 
-         if task_type.is_link_pred and num_hops < 2:
-             raise ValueError(f"Cannot perform link prediction on subgraphs "
-                              f"with less than 2 hops (got {num_hops}) since "
-                              f"historical target entities need to be part of "
-                              f"the context. Please increase the number of "
-                              f"hops and try again.")
-
-         if num_neighbors is None:
-             if run_mode == RunMode.DEBUG:
-                 num_neighbors = [16, 16, 4, 4, 1, 1][:num_hops]
-             elif run_mode == RunMode.FAST or task_type.is_link_pred:
-                 num_neighbors = [32, 32, 8, 8, 4, 4][:num_hops]
-             else:
-                 num_neighbors = [64, 64, 8, 8, 4, 4][:num_hops]
-
          if query.target_ast.date_offset_range is None:
-             end_offset = pd.DateOffset(0)
+             step_offset = pd.DateOffset(0)
          else:
-             end_offset = query.target_ast.date_offset_range.end_date_offset
-         forecast_end_offset = end_offset * query.num_forecasts
+             step_offset = query.target_ast.date_offset_range.end_date_offset
+         end_offset = step_offset * query.num_forecasts
+
          if anchor_time is None:
-             anchor_time = self._graph_store.max_time
-             if evaluate:
-                 anchor_time = anchor_time - forecast_end_offset
+             anchor_time = self._get_default_anchor_time(query)
+             if num_test_examples > 0:
+                 anchor_time = anchor_time - end_offset
+
          if logger is not None:
              assert isinstance(anchor_time, pd.Timestamp)
              if anchor_time == pd.Timestamp.min:
@@ -904,114 +1155,98 @@ class KumoRFM:
904
1155
  else:
905
1156
  logger.log(f"Derived anchor time {anchor_time}")
906
1157
 
907
- assert anchor_time is not None
908
1158
  if isinstance(anchor_time, pd.Timestamp):
1159
+ if context_anchor_time == 'entity':
1160
+ raise ValueError("Anchor time 'entity' needs to be shared "
1161
+ "for context and prediction examples")
909
1162
  if context_anchor_time is None:
910
- context_anchor_time = anchor_time - forecast_end_offset
1163
+ context_anchor_time = anchor_time - end_offset
911
1164
  self._validate_time(query, anchor_time, context_anchor_time,
912
- evaluate)
1165
+ evaluate=num_test_examples > 0)
913
1166
  else:
914
1167
              assert anchor_time == 'entity'
-             if query.entity_table not in self._graph_store.time_dict:
+             if query.query_type != QueryType.STATIC:
+                 raise ValueError("Anchor time 'entity' is only valid for "
+                                  "static predictive queries")
+             if query.entity_table not in self._sampler.time_column_dict:
                  raise ValueError(f"Anchor time 'entity' requires the entity "
                                   f"table '{query.entity_table}' to "
                                   f"have a time column")
-             if context_anchor_time is not None:
-                 warnings.warn("Ignoring option 'context_anchor_time' for "
-                               "`anchor_time='entity'`")
-                 context_anchor_time = None
-
-         y_test: Optional[pd.Series] = None
-         if evaluate:
-             max_test_size = _MAX_TEST_SIZE[run_mode]
-             if task_type.is_link_pred:
-                 max_test_size = max_test_size // 5
-
-             test_node, test_time, y_test = query_driver.collect_test(
-                 size=max_test_size,
-                 anchor_time=anchor_time,
-                 max_iterations=max_pq_iterations,
-                 guarantee_train_examples=True,
-             )
-             if logger is not None:
-                 if task_type == TaskType.BINARY_CLASSIFICATION:
-                     pos = 100 * int((y_test > 0).sum()) / len(y_test)
-                     msg = (f"Collected {len(y_test):,} test examples with "
-                            f"{pos:.2f}% positive cases")
-                 elif task_type == TaskType.MULTICLASS_CLASSIFICATION:
-                     msg = (f"Collected {len(y_test):,} test examples "
-                            f"holding {y_test.nunique()} classes")
-                 elif task_type == TaskType.REGRESSION:
-                     _min, _max = float(y_test.min()), float(y_test.max())
-                     msg = (f"Collected {len(y_test):,} test examples with "
-                            f"targets between {format_value(_min)} and "
-                            f"{format_value(_max)}")
-                 elif task_type == TaskType.TEMPORAL_LINK_PREDICTION:
-                     num_rhs = y_test.explode().nunique()
-                     msg = (f"Collected {len(y_test):,} test examples with "
-                            f"{num_rhs:,} unique items")
-                 else:
-                     raise NotImplementedError
-                 logger.log(msg)
-
-         else:
-             assert indices is not None
-
-             if len(indices) > _MAX_PRED_SIZE[task_type]:
-                 raise ValueError(f"Cannot predict for more than "
-                                  f"{_MAX_PRED_SIZE[task_type]:,} entities at "
-                                  f"once (got {len(indices):,}). Use "
-                                  f"`KumoRFM.batch_mode` to process entities "
-                                  f"in batches")
+             if isinstance(context_anchor_time, pd.Timestamp):
+                 raise ValueError("Anchor time 'entity' needs to be shared "
+                                  "for context and prediction examples")
+             context_anchor_time = 'entity'
+
+         train, test = self._sampler.sample_target(
+             query=query,
+             num_train_examples=num_train_examples,
+             train_anchor_time=context_anchor_time,
+             num_train_trials=max_pq_iterations * num_train_examples,
+             num_test_examples=num_test_examples,
+             test_anchor_time=anchor_time,
+             num_test_trials=max_pq_iterations * num_test_examples,
+             random_seed=random_seed,
+         )
+         train_pkey, train_time, train_y = train
+         test_pkey, test_time, test_y = test

-             test_node = self._graph_store.get_node_id(
-                 table_name=query.entity_table,
-                 pkey=pd.Series(indices),
-             )
+         if num_test_examples > 0 and logger is not None:
+             if task_type == TaskType.BINARY_CLASSIFICATION:
+                 pos = 100 * int((test_y > 0).sum()) / len(test_y)
+                 msg = (f"Collected {len(test_y):,} test examples with "
+                        f"{pos:.2f}% positive cases")
+             elif task_type == TaskType.MULTICLASS_CLASSIFICATION:
+                 msg = (f"Collected {len(test_y):,} test examples holding "
+                        f"{test_y.nunique()} classes")
+             elif task_type == TaskType.REGRESSION:
+                 _min, _max = float(test_y.min()), float(test_y.max())
+                 msg = (f"Collected {len(test_y):,} test examples with targets "
+                        f"between {format_value(_min)} and "
+                        f"{format_value(_max)}")
+             elif task_type == TaskType.TEMPORAL_LINK_PREDICTION:
+                 num_rhs = test_y.explode().nunique()
+                 msg = (f"Collected {len(test_y):,} test examples with "
+                        f"{num_rhs:,} unique items")
+             else:
+                 raise NotImplementedError
+             logger.log(msg)

+         if num_test_examples == 0:
+             assert indices is not None
+             test_pkey = pd.Series(indices, dtype=train_pkey.dtype)
              if isinstance(anchor_time, pd.Timestamp):
-                 test_time = pd.Series(anchor_time).repeat(
-                     len(test_node)).reset_index(drop=True)
+                 test_time = pd.Series([anchor_time]).repeat(
+                     len(indices)).reset_index(drop=True)
              else:
-                 time = self._graph_store.time_dict[query.entity_table]
-                 time = time[test_node] * 1000**3
-                 test_time = pd.Series(time, dtype='datetime64[ns]')
-
-         train_node, train_time, y_train = query_driver.collect_train(
-             size=_MAX_CONTEXT_SIZE[run_mode],
-             anchor_time=context_anchor_time or 'entity',
-             exclude_node=test_node if (query.query_type == QueryType.STATIC
-                                        or anchor_time == 'entity') else None,
-             max_iterations=max_pq_iterations,
-         )
+                 train_time = test_time = 'entity'

          if logger is not None:
              if task_type == TaskType.BINARY_CLASSIFICATION:
-                 pos = 100 * int((y_train > 0).sum()) / len(y_train)
-                 msg = (f"Collected {len(y_train):,} in-context examples with "
+                 pos = 100 * int((train_y > 0).sum()) / len(train_y)
+                 msg = (f"Collected {len(train_y):,} in-context examples with "
                         f"{pos:.2f}% positive cases")
              elif task_type == TaskType.MULTICLASS_CLASSIFICATION:
-                 msg = (f"Collected {len(y_train):,} in-context examples "
-                        f"holding {y_train.nunique()} classes")
+                 msg = (f"Collected {len(train_y):,} in-context examples "
+                        f"holding {train_y.nunique()} classes")
              elif task_type == TaskType.REGRESSION:
-                 _min, _max = float(y_train.min()), float(y_train.max())
-                 msg = (f"Collected {len(y_train):,} in-context examples with "
+                 _min, _max = float(train_y.min()), float(train_y.max())
+                 msg = (f"Collected {len(train_y):,} in-context examples with "
                         f"targets between {format_value(_min)} and "
                         f"{format_value(_max)}")
              elif task_type == TaskType.TEMPORAL_LINK_PREDICTION:
-                 num_rhs = y_train.explode().nunique()
-                 msg = (f"Collected {len(y_train):,} in-context examples with "
+                 num_rhs = train_y.explode().nunique()
+                 msg = (f"Collected {len(train_y):,} in-context examples with "
                         f"{num_rhs:,} unique items")
              else:
                  raise NotImplementedError
              logger.log(msg)

-         entity_table_names: Tuple[str, ...]
+         entity_table_names: tuple[str] | tuple[str, str]
          if task_type.is_link_pred:
              final_aggr = query.get_final_target_aggregation()
              assert final_aggr is not None
              edge_fkey = final_aggr._get_target_column_name()
-             for edge_type in self._graph_store.edge_types:
+             for edge_type in self._sampler.edge_types:
                  if edge_fkey == f'{edge_type[0]}.{edge_type[1]}':
                      entity_table_names = (
                          query.entity_table,
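Note on the refactor above: the separate `query_driver.collect_train`/`collect_test` calls are folded into a single `self._sampler.sample_target` call that returns aligned `(pkey, time, y)` triples for the in-context and test splits. A minimal sketch of the new calling convention, with `sampler` standing in for `self._sampler` and purely illustrative argument values (not library defaults):

    # Illustrative sketch; names mirror the hunk above, values are placeholders.
    train, test = sampler.sample_target(
        query=query,                 # parsed predictive query
        num_train_examples=1000,     # in-context examples to collect
        train_anchor_time='entity',  # or a fixed pd.Timestamp
        num_train_trials=20 * 1000,  # budget: max_pq_iterations * examples
        num_test_examples=200,       # 0 when predicting rather than evaluating
        test_anchor_time='entity',
        num_test_trials=20 * 200,
        random_seed=42,
    )
    train_pkey, train_time, train_y = train  # entity keys, anchor times, labels
    test_pkey, test_time, test_y = test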
@@ -1020,23 +1255,80 @@ class KumoRFM:
          else:
              entity_table_names = (query.entity_table, )

+         context_df = pd.DataFrame({'ENTITY': train_pkey, 'TARGET': train_y})
+         if isinstance(train_time, pd.Series):
+             context_df['ANCHOR_TIMESTAMP'] = train_time
+         pred_df = pd.DataFrame({'ENTITY': test_pkey})
+         if num_test_examples > 0:
+             pred_df['TARGET'] = test_y
+         if isinstance(test_time, pd.Series):
+             pred_df['ANCHOR_TIMESTAMP'] = test_time
+
+         return TaskTable(
+             task_type=task_type,
+             context_df=context_df,
+             pred_df=pred_df,
+             entity_table_name=entity_table_names,
+             entity_column='ENTITY',
+             target_column='TARGET',
+             time_column='ANCHOR_TIMESTAMP' if isinstance(
+                 train_time, pd.Series) else TaskTable.ENTITY_TIME,
+         )
+
+     def _get_context(
+         self,
+         task: TaskTable,
+         run_mode: RunMode | str = RunMode.FAST,
+         num_neighbors: list[int] | None = None,
+         exclude_cols_dict: dict[str, list[str]] | None = None,
+         top_k: int | None = None,
+     ) -> Context:
+
+         if num_neighbors is None:
+             key = RunMode.FAST if task.task_type.is_link_pred else run_mode
+             num_neighbors = _DEFAULT_NUM_NEIGHBORS[key][:2]
+
+         if len(num_neighbors) > 6:
+             raise ValueError(f"Cannot predict on subgraphs with more than 6 "
+                              f"hops (got {len(num_neighbors)}). Reduce the "
+                              f"number of hops and try again. Please create a "
+                              f"feature request at "
+                              f"'https://github.com/kumo-ai/kumo-rfm' if you "
+                              f"must go beyond this for your use-case.")
+
          # Exclude the entity anchor time from the feature set to prevent
          # running out-of-distribution between in-context and test examples:
-         exclude_cols_dict = query.get_exclude_cols_dict()
-         if anchor_time == 'entity':
-             if entity_table_names[0] not in exclude_cols_dict:
-                 exclude_cols_dict[entity_table_names[0]] = []
-             time_column_dict = self._graph_store.time_column_dict
-             time_column = time_column_dict[entity_table_names[0]]
-             exclude_cols_dict[entity_table_names[0]].append(time_column)
-
-         subgraph = self._graph_sampler(
-             entity_table_names=entity_table_names,
-             node=np.concatenate([train_node, test_node]),
-             time=np.concatenate([
-                 train_time.astype('datetime64[ns]').astype(int).to_numpy(),
-                 test_time.astype('datetime64[ns]').astype(int).to_numpy(),
-             ]),
+         exclude_cols_dict = exclude_cols_dict or {}
+         if task.entity_table_name in self._sampler.time_column_dict:
+             if task.entity_table_name not in exclude_cols_dict:
+                 exclude_cols_dict[task.entity_table_name] = []
+             time_col = self._sampler.time_column_dict[task.entity_table_name]
+             exclude_cols_dict[task.entity_table_name].append(time_col)
+
+         entity_pkey = pd.concat([
+             task._context_df[task._entity_column],
+             task._pred_df[task._entity_column],
+         ], axis=0, ignore_index=True)
+
+         if task.use_entity_time:
+             if task.entity_table_name not in self._sampler.time_column_dict:
+                 raise ValueError(f"The given anchor time requires the entity "
+                                  f"table '{task.entity_table_name}' to have a "
+                                  f"time column")
+             anchor_time = 'entity'
+         elif task._time_column is not None:
+             anchor_time = pd.concat([
+                 task._context_df[task._time_column],
+                 task._pred_df[task._time_column],
+             ], axis=0, ignore_index=True)
+         else:
+             anchor_time = pd.Series(self._get_default_anchor_time()).repeat(
+                 len(entity_pkey)).reset_index(drop=True)
+
+         subgraph = self._sampler.sample_subgraph(
+             entity_table_names=task.entity_table_names,
+             entity_pkey=entity_pkey,
+             anchor_time=anchor_time,
              num_neighbors=num_neighbors,
              exclude_cols_dict=exclude_cols_dict,
          )
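The new `_get_context` funnels both splits through one `sample_subgraph` call: context and prediction entities are concatenated row-wise, and anchor times come from the entity table ('entity'), an explicit time column, or a repeated default. A self-contained pandas sketch of that alignment pattern (column names mirror the `TaskTable` defaults above; the frames are made-up stand-ins for `task._context_df`/`task._pred_df`):

    import pandas as pd

    context_df = pd.DataFrame({
        'ENTITY': [101, 102],
        'ANCHOR_TIMESTAMP': pd.to_datetime(['2024-01-01', '2024-01-02']),
    })
    pred_df = pd.DataFrame({
        'ENTITY': [103],
        'ANCHOR_TIMESTAMP': pd.to_datetime(['2024-01-03']),
    })

    # Row-wise concatenation keeps the splits aligned: the first
    # len(context_df) rows are in-context examples, the rest predictions.
    entity_pkey = pd.concat([context_df['ENTITY'], pred_df['ENTITY']],
                            axis=0, ignore_index=True)
    anchor_time = pd.concat([
        context_df['ANCHOR_TIMESTAMP'],
        pred_df['ANCHOR_TIMESTAMP'],
    ], axis=0, ignore_index=True)
    assert len(entity_pkey) == len(anchor_time) == 3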
@@ -1048,23 +1340,26 @@ class KumoRFM:
                               f"'https://github.com/kumo-ai/kumo-rfm' if you "
                               f"must go beyond this for your use-case.")

-         step_size: Optional[int] = None
-         if query.query_type == QueryType.TEMPORAL:
-             step_size = date_offset_to_seconds(end_offset)
+         if (task.task_type.is_link_pred
+                 and task.entity_table_names[-1] not in subgraph.table_dict):
+             raise ValueError("Cannot perform link prediction on subgraphs "
+                              "without any historical target entities. Please "
+                              "increase the number of hops and try again.")

          return Context(
-             task_type=task_type,
-             entity_table_names=entity_table_names,
+             task_type=task.task_type,
+             entity_table_names=task.entity_table_names,
              subgraph=subgraph,
-             y_train=y_train,
-             y_test=y_test,
-             top_k=query.top_k,
-             step_size=step_size,
+             y_train=task._context_df[task.target_column.name],
+             y_test=task._pred_df[task.target_column.name]
+             if task.evaluate else None,
+             top_k=top_k,
+             step_size=None,
          )

      @staticmethod
      def _validate_metrics(
-         metrics: List[str],
+         metrics: list[str],
          task_type: TaskType,
      ) -> None:

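Beyond the behavioral changes, the hunks above also migrate annotations from `typing.List`/`Tuple`/`Optional`/`Union` to PEP 585 builtin generics and PEP 604 `|` unions, which the cp313 wheel target supports natively. A before/after sketch of the pattern (hypothetical function names):

    from typing import List, Optional, Union  # old spellings

    def old_style(metrics: List[str], top_k: Optional[int],
                  value: Union[int, float]) -> None: ...

    def new_style(metrics: list[str], top_k: int | None,
                  value: int | float) -> None: ...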
@@ -1121,7 +1416,7 @@ class KumoRFM:
                           f"'https://github.com/kumo-ai/kumo-rfm'.")


- def format_value(value: Union[int, float]) -> str:
+ def format_value(value: int | float) -> str:
      if value == int(value):
          return f'{int(value):,}'
      if abs(value) >= 1000: