kumoai 2.13.0.dev202512041731__cp310-cp310-win_amd64.whl → 2.15.0.dev202601141731__cp310-cp310-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (56)
  1. kumoai/__init__.py +23 -26
  2. kumoai/_version.py +1 -1
  3. kumoai/client/client.py +6 -0
  4. kumoai/client/jobs.py +26 -0
  5. kumoai/client/pquery.py +6 -2
  6. kumoai/connector/utils.py +21 -7
  7. kumoai/experimental/rfm/__init__.py +51 -24
  8. kumoai/experimental/rfm/authenticate.py +3 -4
  9. kumoai/experimental/rfm/backend/local/__init__.py +4 -0
  10. kumoai/experimental/rfm/{local_graph_store.py → backend/local/graph_store.py} +62 -110
  11. kumoai/experimental/rfm/backend/local/sampler.py +312 -0
  12. kumoai/experimental/rfm/backend/local/table.py +35 -31
  13. kumoai/experimental/rfm/backend/snow/__init__.py +2 -0
  14. kumoai/experimental/rfm/backend/snow/sampler.py +407 -0
  15. kumoai/experimental/rfm/backend/snow/table.py +178 -50
  16. kumoai/experimental/rfm/backend/sqlite/__init__.py +4 -2
  17. kumoai/experimental/rfm/backend/sqlite/sampler.py +456 -0
  18. kumoai/experimental/rfm/backend/sqlite/table.py +131 -48
  19. kumoai/experimental/rfm/base/__init__.py +22 -4
  20. kumoai/experimental/rfm/base/column.py +96 -10
  21. kumoai/experimental/rfm/base/expression.py +44 -0
  22. kumoai/experimental/rfm/base/mapper.py +69 -0
  23. kumoai/experimental/rfm/base/sampler.py +696 -47
  24. kumoai/experimental/rfm/base/source.py +2 -1
  25. kumoai/experimental/rfm/base/sql_sampler.py +385 -0
  26. kumoai/experimental/rfm/base/table.py +384 -207
  27. kumoai/experimental/rfm/base/utils.py +36 -0
  28. kumoai/experimental/rfm/graph.py +359 -187
  29. kumoai/experimental/rfm/infer/__init__.py +6 -4
  30. kumoai/experimental/rfm/infer/dtype.py +10 -5
  31. kumoai/experimental/rfm/infer/multicategorical.py +1 -1
  32. kumoai/experimental/rfm/infer/pkey.py +4 -2
  33. kumoai/experimental/rfm/infer/stype.py +35 -0
  34. kumoai/experimental/rfm/infer/time_col.py +5 -4
  35. kumoai/experimental/rfm/pquery/executor.py +27 -27
  36. kumoai/experimental/rfm/pquery/pandas_executor.py +30 -32
  37. kumoai/experimental/rfm/relbench.py +76 -0
  38. kumoai/experimental/rfm/rfm.py +770 -467
  39. kumoai/experimental/rfm/sagemaker.py +4 -4
  40. kumoai/experimental/rfm/task_table.py +292 -0
  41. kumoai/kumolib.cp310-win_amd64.pyd +0 -0
  42. kumoai/pquery/predictive_query.py +10 -6
  43. kumoai/pquery/training_table.py +16 -2
  44. kumoai/testing/snow.py +50 -0
  45. kumoai/trainer/distilled_trainer.py +175 -0
  46. kumoai/utils/__init__.py +3 -2
  47. kumoai/utils/display.py +87 -0
  48. kumoai/utils/progress_logger.py +192 -13
  49. kumoai/utils/sql.py +3 -0
  50. {kumoai-2.13.0.dev202512041731.dist-info → kumoai-2.15.0.dev202601141731.dist-info}/METADATA +3 -2
  51. {kumoai-2.13.0.dev202512041731.dist-info → kumoai-2.15.0.dev202601141731.dist-info}/RECORD +54 -42
  52. kumoai/experimental/rfm/local_graph_sampler.py +0 -223
  53. kumoai/experimental/rfm/local_pquery_driver.py +0 -689
  54. {kumoai-2.13.0.dev202512041731.dist-info → kumoai-2.15.0.dev202601141731.dist-info}/WHEEL +0 -0
  55. {kumoai-2.13.0.dev202512041731.dist-info → kumoai-2.15.0.dev202601141731.dist-info}/licenses/LICENSE +0 -0
  56. {kumoai-2.13.0.dev202512041731.dist-info → kumoai-2.15.0.dev202601141731.dist-info}/top_level.txt +0 -0
@@ -1,26 +1,23 @@
  import json
+ import math
  import time
  import warnings
  from collections import defaultdict
- from collections.abc import Generator
+ from collections.abc import Generator, Iterator
  from contextlib import contextmanager
  from dataclasses import dataclass, replace
- from typing import (
-     Any,
-     Dict,
-     Iterator,
-     List,
-     Literal,
-     Optional,
-     Tuple,
-     Union,
-     overload,
- )
+ from typing import Any, Literal, overload

- import numpy as np
  import pandas as pd
  from kumoapi.model_plan import RunMode
  from kumoapi.pquery import QueryType, ValidatedPredictiveQuery
+ from kumoapi.pquery.AST import (
+     Aggregation,
+     Column,
+     Condition,
+     Join,
+     LogicalOperation,
+ )
  from kumoapi.rfm import Context
  from kumoapi.rfm import Explanation as ExplanationConfig
  from kumoapi.rfm import (
@@ -29,35 +26,38 @@ from kumoapi.rfm import (
      RFMPredictRequest,
  )
  from kumoapi.task import TaskType
+ from kumoapi.typing import AggregationType, Stype
+ from rich.console import Console
+ from rich.markdown import Markdown

+ from kumoai import in_notebook
  from kumoai.client.rfm import RFMAPI
  from kumoai.exceptions import HTTPException
- from kumoai.experimental.rfm import Graph
- from kumoai.experimental.rfm.local_graph_sampler import LocalGraphSampler
- from kumoai.experimental.rfm.local_graph_store import LocalGraphStore
- from kumoai.experimental.rfm.local_pquery_driver import (
-     LocalPQueryDriver,
-     date_offset_to_seconds,
- )
+ from kumoai.experimental.rfm import Graph, TaskTable
+ from kumoai.experimental.rfm.base import DataBackend, Sampler
  from kumoai.mixin import CastMixin
- from kumoai.utils import InteractiveProgressLogger, ProgressLogger
+ from kumoai.utils import ProgressLogger, display

  _RANDOM_SEED = 42

  _MAX_PRED_SIZE: dict[TaskType, int] = defaultdict(lambda: 1_000)
  _MAX_PRED_SIZE[TaskType.TEMPORAL_LINK_PREDICTION] = 200

+ _MAX_TEST_SIZE: dict[TaskType, int] = defaultdict(lambda: 2_000)
+ _MAX_TEST_SIZE[TaskType.TEMPORAL_LINK_PREDICTION] = 400
+
  _MAX_CONTEXT_SIZE = {
      RunMode.DEBUG: 100,
      RunMode.FAST: 1_000,
      RunMode.NORMAL: 5_000,
      RunMode.BEST: 10_000,
  }
- _MAX_TEST_SIZE = {  # Share test set size across run modes for fair comparison:
-     RunMode.DEBUG: 100,
-     RunMode.FAST: 2_000,
-     RunMode.NORMAL: 2_000,
-     RunMode.BEST: 2_000,
+
+ _DEFAULT_NUM_NEIGHBORS = {
+     RunMode.DEBUG: [16, 16, 4, 4, 1, 1],
+     RunMode.FAST: [32, 32, 8, 8, 4, 4],
+     RunMode.NORMAL: [64, 64, 8, 8, 4, 4],
+     RunMode.BEST: [64, 64, 8, 8, 4, 4],
  }

  _MAX_SIZE = 30 * 1024 * 1024
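The `_DEFAULT_NUM_NEIGHBORS` table replaces the per-call branching that previously lived in `_get_context()` (removed further down in this diff): both `predict_task()` and `evaluate_task()` now derive per-hop fan-outs by slicing the run-mode entry down to `num_hops`, and link prediction is pinned to the `FAST` fan-outs. A minimal sketch of that lookup; the enum here is a stand-in for `kumoapi.model_plan.RunMode`:

```python
from enum import Enum


class RunMode(str, Enum):  # stand-in for kumoapi.model_plan.RunMode
    DEBUG = 'DEBUG'
    FAST = 'FAST'
    NORMAL = 'NORMAL'
    BEST = 'BEST'


_DEFAULT_NUM_NEIGHBORS = {
    RunMode.DEBUG: [16, 16, 4, 4, 1, 1],
    RunMode.FAST: [32, 32, 8, 8, 4, 4],
    RunMode.NORMAL: [64, 64, 8, 8, 4, 4],
    RunMode.BEST: [64, 64, 8, 8, 4, 4],
}


def default_num_neighbors(run_mode: RunMode, num_hops: int = 2,
                          is_link_pred: bool = False) -> list[int]:
    # Mirrors `num_neighbors = _DEFAULT_NUM_NEIGHBORS[key][:num_hops]`
    # in predict_task()/evaluate_task() below:
    key = RunMode.FAST if is_link_pred else run_mode
    return _DEFAULT_NUM_NEIGHBORS[key][:num_hops]


assert default_num_neighbors(RunMode.NORMAL) == [64, 64]
assert default_num_neighbors(RunMode.BEST, num_hops=4) == [64, 64, 8, 8]
assert default_num_neighbors(RunMode.BEST, is_link_pred=True) == [32, 32]
```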
@@ -95,24 +95,36 @@ class Explanation:
      def __getitem__(self, index: Literal[1]) -> str:
          pass

-     def __getitem__(self, index: int) -> Union[pd.DataFrame, str]:
+     def __getitem__(self, index: int) -> pd.DataFrame | str:
          if index == 0:
              return self.prediction
          if index == 1:
              return self.summary
          raise IndexError("Index out of range")

-     def __iter__(self) -> Iterator[Union[pd.DataFrame, str]]:
+     def __iter__(self) -> Iterator[pd.DataFrame | str]:
          return iter((self.prediction, self.summary))

      def __repr__(self) -> str:
          return str((self.prediction, self.summary))

-     def _ipython_display_(self) -> None:
-         from IPython.display import Markdown, display
+     def __str__(self) -> str:
+         console = Console(soft_wrap=True)
+         with console.capture() as cap:
+             console.print(display.to_rich_table(self.prediction))
+             console.print(Markdown(self.summary))
+         return cap.get()[:-1]
+
+     def print(self) -> None:
+         r"""Prints the explanation."""
+         if in_notebook():
+             display.dataframe(self.prediction)
+             display.message(self.summary)
+         else:
+             print(self)

-         display(self.prediction)
-         display(Markdown(self.summary))
+     def _ipython_display_(self) -> None:
+         self.print()


  class KumoRFM:
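With `__iter__`, `__getitem__`, and the new `print()` method, an `Explanation` now renders consistently in notebooks (rich table plus Markdown) and in plain terminals (via `__str__`). A usage sketch; `model` and `query` are placeholders, and the entity id is illustrative:

```python
result = model.predict(query, indices=[42], explain=True)

prediction, summary = result  # __iter__ yields (prediction, summary)
df = result[0]                # pd.DataFrame, same as result.prediction
text = result[1]              # str, same as result.summary

result.print()  # rich output in notebooks, falls back to print(str(...))
```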
@@ -151,21 +163,36 @@ class KumoRFM:
      Args:
          graph: The graph.
          verbose: Whether to print verbose output.
+         optimize: If set to ``True``, will optimize the underlying data backend
+             for optimal querying. For example, for transactional database
+             backends, will create any missing indices. Requires write-access to
+             the data backend.
      """
      def __init__(
          self,
          graph: Graph,
-         verbose: Union[bool, ProgressLogger] = True,
+         verbose: bool | ProgressLogger = True,
+         optimize: bool = False,
      ) -> None:
          graph = graph.validate()
          self._graph_def = graph._to_api_graph_definition()
-         self._graph_store = LocalGraphStore(graph, verbose)
-         self._graph_sampler = LocalGraphSampler(self._graph_store)

-         self._client: Optional[RFMAPI] = None
+         if graph.backend == DataBackend.LOCAL:
+             from kumoai.experimental.rfm.backend.local import LocalSampler
+             self._sampler: Sampler = LocalSampler(graph, verbose)
+         elif graph.backend == DataBackend.SQLITE:
+             from kumoai.experimental.rfm.backend.sqlite import SQLiteSampler
+             self._sampler = SQLiteSampler(graph, verbose, optimize)
+         elif graph.backend == DataBackend.SNOWFLAKE:
+             from kumoai.experimental.rfm.backend.snow import SnowSampler
+             self._sampler = SnowSampler(graph, verbose)
+         else:
+             raise NotImplementedError

-         self._batch_size: Optional[int | Literal['max']] = None
-         self.num_retries: int = 0
+         self._client: RFMAPI | None = None
+
+         self._batch_size: int | Literal['max'] | None = None
+         self._num_retries: int = 0

      @property
      def _api_client(self) -> RFMAPI:
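The constructor now dispatches on `graph.backend` instead of always materializing a `LocalGraphStore`, and `optimize` is only forwarded to the SQLite sampler, where index creation requires write access. A sketch of the entry points, assuming `KumoRFM` is exported from `kumoai.experimental.rfm` and that the graphs below were built against the respective backends:

```python
from kumoai.experimental.rfm import KumoRFM

model = KumoRFM(local_graph)                  # DataBackend.LOCAL (in-memory)

# SQLite: let the sampler create any missing indices (needs write access):
model = KumoRFM(sqlite_graph, optimize=True)

model = KumoRFM(snowflake_graph)              # DataBackend.SNOWFLAKE
```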
@@ -179,10 +206,34 @@ class KumoRFM:
      def __repr__(self) -> str:
          return f'{self.__class__.__name__}()'

+     @contextmanager
+     def retry(
+         self,
+         num_retries: int = 1,
+     ) -> Generator[None, None, None]:
+         """Context manager to retry failed queries due to unexpected server
+         issues.
+
+         .. code-block:: python
+
+             with model.retry(num_retries=1):
+                 df = model.predict(query, indices=...)
+
+         Args:
+             num_retries: The maximum number of retries.
+         """
+         if num_retries < 0:
+             raise ValueError(f"'num_retries' must be greater than or equal to "
+                              f"zero (got {num_retries})")
+
+         self._num_retries = num_retries
+         yield
+         self._num_retries = 0
+
      @contextmanager
      def batch_mode(
          self,
-         batch_size: Union[int, Literal['max']] = 'max',
+         batch_size: int | Literal['max'] = 'max',
          num_retries: int = 1,
      ) -> Generator[None, None, None]:
          """Context manager to predict in batches.
@@ -202,31 +253,26 @@ class KumoRFM:
              raise ValueError(f"'batch_size' must be greater than zero "
                               f"(got {batch_size})")

-         if num_retries < 0:
-             raise ValueError(f"'num_retries' must be greater than or equal to "
-                              f"zero (got {num_retries})")
-
          self._batch_size = batch_size
-         self.num_retries = num_retries
-         yield
+         with self.retry(self._num_retries or num_retries):
+             yield
          self._batch_size = None
-         self.num_retries = 0

      @overload
      def predict(
          self,
          query: str,
-         indices: Union[List[str], List[float], List[int], None] = None,
+         indices: list[str] | list[float] | list[int] | None = None,
          *,
          explain: Literal[False] = False,
-         anchor_time: Union[pd.Timestamp, Literal['entity'], None] = None,
-         context_anchor_time: Union[pd.Timestamp, None] = None,
-         run_mode: Union[RunMode, str] = RunMode.FAST,
-         num_neighbors: Optional[List[int]] = None,
+         anchor_time: pd.Timestamp | Literal['entity'] | None = None,
+         context_anchor_time: pd.Timestamp | None = None,
+         run_mode: RunMode | str = RunMode.FAST,
+         num_neighbors: list[int] | None = None,
          num_hops: int = 2,
-         max_pq_iterations: int = 20,
-         random_seed: Optional[int] = _RANDOM_SEED,
-         verbose: Union[bool, ProgressLogger] = True,
+         max_pq_iterations: int = 10,
+         random_seed: int | None = _RANDOM_SEED,
+         verbose: bool | ProgressLogger = True,
          use_prediction_time: bool = False,
      ) -> pd.DataFrame:
          pass
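`batch_mode` no longer manages the retry counter itself; it defers to the new `retry()` context manager, and `self._num_retries or num_retries` lets an enclosing `retry()` take precedence. A usage sketch based on the docstrings above (`model`, `query`, and `ids` are placeholders):

```python
# Retry a single call against transient server errors:
with model.retry(num_retries=2):
    df = model.predict(query, indices=ids)

# Predict in batches of 500; each batch gets one retry by default:
with model.batch_mode(batch_size=500):
    df = model.predict(query, indices=ids)

# An enclosing retry() overrides batch_mode's own num_retries:
with model.retry(num_retries=3), model.batch_mode(batch_size='max'):
    df = model.predict(query, indices=ids)
```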
@@ -235,37 +281,56 @@ class KumoRFM:
      @overload
      def predict(
          self,
          query: str,
-         indices: Union[List[str], List[float], List[int], None] = None,
+         indices: list[str] | list[float] | list[int] | None = None,
          *,
-         explain: Union[Literal[True], ExplainConfig, Dict[str, Any]],
-         anchor_time: Union[pd.Timestamp, Literal['entity'], None] = None,
-         context_anchor_time: Union[pd.Timestamp, None] = None,
-         run_mode: Union[RunMode, str] = RunMode.FAST,
-         num_neighbors: Optional[List[int]] = None,
+         explain: Literal[True] | ExplainConfig | dict[str, Any],
+         anchor_time: pd.Timestamp | Literal['entity'] | None = None,
+         context_anchor_time: pd.Timestamp | None = None,
+         run_mode: RunMode | str = RunMode.FAST,
+         num_neighbors: list[int] | None = None,
          num_hops: int = 2,
-         max_pq_iterations: int = 20,
-         random_seed: Optional[int] = _RANDOM_SEED,
-         verbose: Union[bool, ProgressLogger] = True,
+         max_pq_iterations: int = 10,
+         random_seed: int | None = _RANDOM_SEED,
+         verbose: bool | ProgressLogger = True,
          use_prediction_time: bool = False,
      ) -> Explanation:
          pass

+     @overload
+     def predict(
+         self,
+         query: str,
+         indices: list[str] | list[float] | list[int] | None = None,
+         *,
+         explain: bool | ExplainConfig | dict[str, Any] = False,
+         anchor_time: pd.Timestamp | Literal['entity'] | None = None,
+         context_anchor_time: pd.Timestamp | None = None,
+         run_mode: RunMode | str = RunMode.FAST,
+         num_neighbors: list[int] | None = None,
+         num_hops: int = 2,
+         max_pq_iterations: int = 10,
+         random_seed: int | None = _RANDOM_SEED,
+         verbose: bool | ProgressLogger = True,
+         use_prediction_time: bool = False,
+     ) -> pd.DataFrame | Explanation:
+         pass
+
      def predict(
          self,
          query: str,
-         indices: Union[List[str], List[float], List[int], None] = None,
+         indices: list[str] | list[float] | list[int] | None = None,
          *,
-         explain: Union[bool, ExplainConfig, Dict[str, Any]] = False,
-         anchor_time: Union[pd.Timestamp, Literal['entity'], None] = None,
-         context_anchor_time: Union[pd.Timestamp, None] = None,
-         run_mode: Union[RunMode, str] = RunMode.FAST,
-         num_neighbors: Optional[List[int]] = None,
+         explain: bool | ExplainConfig | dict[str, Any] = False,
+         anchor_time: pd.Timestamp | Literal['entity'] | None = None,
+         context_anchor_time: pd.Timestamp | None = None,
+         run_mode: RunMode | str = RunMode.FAST,
+         num_neighbors: list[int] | None = None,
          num_hops: int = 2,
-         max_pq_iterations: int = 20,
-         random_seed: Optional[int] = _RANDOM_SEED,
-         verbose: Union[bool, ProgressLogger] = True,
+         max_pq_iterations: int = 10,
+         random_seed: int | None = _RANDOM_SEED,
+         verbose: bool | ProgressLogger = True,
          use_prediction_time: bool = False,
-     ) -> Union[pd.DataFrame, Explanation]:
+     ) -> pd.DataFrame | Explanation:
          """Returns predictions for a predictive query.

          Args:
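The three `predict()` overloads pin the return type down at the call site: `explain=False` (the default) yields a plain `pd.DataFrame`, any truthy `explain` yields an `Explanation`, and the new catch-all overload covers `explain` values that are only known at runtime. In short (placeholders as before):

```python
df = model.predict(query, indices=ids)                 # -> pd.DataFrame
exp = model.predict(query, indices=[7], explain=True)  # -> Explanation
prediction, summary = exp
```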
@@ -273,8 +338,7 @@ class KumoRFM:
              indices: The entity primary keys to predict for. Will override the
                  indices given as part of the predictive query. Predictions will
                  be generated for all indices, independent of whether they
-                 fulfill entity filter constraints. To pre-filter entities, use
-                 :meth:`~KumoRFM.is_valid_entity`.
+                 fulfill entity filter constraints.
              explain: Configuration for explainability.
                  If set to ``True``, will additionally explain the prediction.
                  Passing in an :class:`ExplainConfig` instance provides control
@@ -307,18 +371,152 @@ class KumoRFM:
          If ``explain`` is provided, returns an :class:`Explanation` object
          containing the prediction, summary, and details.
          """
-         explain_config: Optional[ExplainConfig] = None
-         if explain is True:
-             explain_config = ExplainConfig()
-         elif explain is not False:
-             explain_config = ExplainConfig._cast(explain)
-
          query_def = self._parse_query(query)
-         query_str = query_def.to_string()

+         if indices is None:
+             if query_def.rfm_entity_ids is None:
+                 raise ValueError("Cannot find entities to predict for. Please "
+                                  "pass them via `predict(query, indices=...)`")
+             indices = query_def.get_rfm_entity_id_list()
+             query_def = replace(
+                 query_def,
+                 for_each='FOR EACH',
+                 rfm_entity_ids=None,
+             )
+
+         if not isinstance(verbose, ProgressLogger):
+             query_repr = query_def.to_string(rich=True, exclude_predict=True)
+             if explain is not False:
+                 msg = f'[bold]EXPLAIN[/bold] {query_repr}'
+             else:
+                 msg = f'[bold]PREDICT[/bold] {query_repr}'
+             verbose = ProgressLogger.default(msg=msg, verbose=verbose)
+
+         with verbose as logger:
+             task_table = self._get_task_table(
+                 query=query_def,
+                 indices=indices,
+                 anchor_time=anchor_time,
+                 context_anchor_time=context_anchor_time,
+                 run_mode=run_mode,
+                 max_pq_iterations=max_pq_iterations,
+                 random_seed=random_seed,
+                 logger=logger,
+             )
+             task_table._query = query_def.to_string()
+
+         return self.predict_task(
+             task_table,
+             explain=explain,
+             run_mode=run_mode,
+             num_neighbors=num_neighbors,
+             num_hops=num_hops,
+             verbose=verbose,
+             exclude_cols_dict=query_def.get_exclude_cols_dict(),
+             use_prediction_time=use_prediction_time,
+             top_k=query_def.top_k,
+         )
+
+     @overload
+     def predict_task(
+         self,
+         task: TaskTable,
+         *,
+         explain: Literal[False] = False,
+         run_mode: RunMode | str = RunMode.FAST,
+         num_neighbors: list[int] | None = None,
+         num_hops: int = 2,
+         verbose: bool | ProgressLogger = True,
+         exclude_cols_dict: dict[str, list[str]] | None = None,
+         use_prediction_time: bool = False,
+         top_k: int | None = None,
+     ) -> pd.DataFrame:
+         pass
+
+     @overload
+     def predict_task(
+         self,
+         task: TaskTable,
+         *,
+         explain: Literal[True] | ExplainConfig | dict[str, Any],
+         run_mode: RunMode | str = RunMode.FAST,
+         num_neighbors: list[int] | None = None,
+         num_hops: int = 2,
+         verbose: bool | ProgressLogger = True,
+         exclude_cols_dict: dict[str, list[str]] | None = None,
+         use_prediction_time: bool = False,
+         top_k: int | None = None,
+     ) -> Explanation:
+         pass
+
+     @overload
+     def predict_task(
+         self,
+         task: TaskTable,
+         *,
+         explain: bool | ExplainConfig | dict[str, Any] = False,
+         run_mode: RunMode | str = RunMode.FAST,
+         num_neighbors: list[int] | None = None,
+         num_hops: int = 2,
+         verbose: bool | ProgressLogger = True,
+         exclude_cols_dict: dict[str, list[str]] | None = None,
+         use_prediction_time: bool = False,
+         top_k: int | None = None,
+     ) -> pd.DataFrame | Explanation:
+         pass
+
+     def predict_task(
+         self,
+         task: TaskTable,
+         *,
+         explain: bool | ExplainConfig | dict[str, Any] = False,
+         run_mode: RunMode | str = RunMode.FAST,
+         num_neighbors: list[int] | None = None,
+         num_hops: int = 2,
+         verbose: bool | ProgressLogger = True,
+         exclude_cols_dict: dict[str, list[str]] | None = None,
+         use_prediction_time: bool = False,
+         top_k: int | None = None,
+     ) -> pd.DataFrame | Explanation:
+         """Returns predictions for a custom task specification.
+
+         Args:
+             task: The custom :class:`TaskTable`.
+             explain: Configuration for explainability.
+                 If set to ``True``, will additionally explain the prediction.
+                 Passing in an :class:`ExplainConfig` instance provides control
+                 over which parts of explanation are generated.
+                 Explainability is currently only supported for single entity
+                 predictions with ``run_mode="FAST"``.
+             run_mode: The :class:`RunMode` for the query.
+             num_neighbors: The number of neighbors to sample for each hop.
+                 If specified, the ``num_hops`` option will be ignored.
+             num_hops: The number of hops to sample when generating the context.
+             verbose: Whether to print verbose output.
+             exclude_cols_dict: Any column in any table to exclude from the
+                 model input.
+             use_prediction_time: Whether to use the anchor timestamp as an
+                 additional feature during prediction. This is typically
+                 beneficial for time series forecasting tasks.
+             top_k: The number of predictions to return per entity.
+
+         Returns:
+             The predictions as a :class:`pandas.DataFrame`.
+             If ``explain`` is provided, returns an :class:`Explanation` object
+             containing the prediction, summary, and details.
+         """
          if num_hops != 2 and num_neighbors is not None:
              warnings.warn(f"Received custom 'num_neighbors' option; ignoring "
                            f"custom 'num_hops={num_hops}' option")
+         if num_neighbors is None:
+             key = RunMode.FAST if task.task_type.is_link_pred else run_mode
+             num_neighbors = _DEFAULT_NUM_NEIGHBORS[key][:num_hops]
+
+         explain_config: ExplainConfig | None = None
+         if explain is True:
+             explain_config = ExplainConfig()
+         elif explain is not False:
+             explain_config = ExplainConfig._cast(explain)

          if explain_config is not None and run_mode in {
              RunMode.NORMAL, RunMode.BEST
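`predict()` is now a thin wrapper: it parses the query into a `TaskTable` via `_get_task_table()` and forwards to the new public `predict_task()`, which also accepts hand-built tasks. `TaskTable` construction lives in the new `kumoai/experimental/rfm/task_table.py` (not shown in this hunk), so the sketch below only illustrates the forwarding contract with an already-built `task`; the `exclude_cols_dict` table and column names are made up:

```python
df = model.predict_task(
    task,                                     # a TaskTable instance
    run_mode='FAST',
    num_hops=2,
    exclude_cols_dict={'users': ['email']},   # drop columns from model input
    top_k=10,                                 # predictions kept per entity
)
```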
@@ -327,83 +525,82 @@ class KumoRFM:
                            f"run mode 'FAST' (got '{run_mode}'). Provided run "
                            f"mode has been reset. Please lower the run mode to "
                            f"suppress this warning.")
+             run_mode = RunMode.FAST

-         if indices is None:
-             if query_def.rfm_entity_ids is None:
-                 raise ValueError("Cannot find entities to predict for. Please "
-                                  "pass them via `predict(query, indices=...)`")
-             indices = query_def.get_rfm_entity_id_list()
-         else:
-             query_def = replace(query_def, rfm_entity_ids=None)
-
-         if len(indices) == 0:
-             raise ValueError("At least one entity is required")
-
-         if explain_config is not None and len(indices) > 1:
-             raise ValueError(
-                 f"Cannot explain predictions for more than a single entity "
-                 f"(got {len(indices)})")
-
-         query_repr = query_def.to_string(rich=True, exclude_predict=True)
-         if explain_config is not None:
-             msg = f'[bold]EXPLAIN[/bold] {query_repr}'
-         else:
-             msg = f'[bold]PREDICT[/bold] {query_repr}'
+         if explain_config is not None and task.num_prediction_examples > 1:
+             raise ValueError(f"Cannot explain predictions for more than a "
+                              f"single entity "
+                              f"(got {task.num_prediction_examples:,})")

          if not isinstance(verbose, ProgressLogger):
-             verbose = InteractiveProgressLogger(msg, verbose=verbose)
-
-         with verbose as logger:
-
-             batch_size: Optional[int] = None
-             if self._batch_size == 'max':
-                 task_type = LocalPQueryDriver.get_task_type(
-                     query_def,
-                     edge_types=self._graph_store.edge_types,
-                 )
-                 batch_size = _MAX_PRED_SIZE[task_type]
+             if task.task_type == TaskType.BINARY_CLASSIFICATION:
+                 task_type_repr = 'binary classification'
+             elif task.task_type == TaskType.MULTICLASS_CLASSIFICATION:
+                 task_type_repr = 'multi-class classification'
+             elif task.task_type == TaskType.REGRESSION:
+                 task_type_repr = 'regression'
+             elif task.task_type == TaskType.TEMPORAL_LINK_PREDICTION:
+                 task_type_repr = 'link prediction'
              else:
-                 batch_size = self._batch_size
+                 task_type_repr = str(task.task_type)

-             if batch_size is not None:
-                 offsets = range(0, len(indices), batch_size)
-                 batches = [indices[step:step + batch_size] for step in offsets]
+             if explain_config is not None:
+                 msg = f"Explaining {task_type_repr} task"
              else:
-                 batches = [indices]
+                 msg = f"Predicting {task_type_repr} task"
+             verbose = ProgressLogger.default(msg=msg, verbose=verbose)

-             if len(batches) > 1:
-                 logger.log(f"Splitting {len(indices):,} entities into "
-                            f"{len(batches):,} batches of size {batch_size:,}")
+         with verbose as logger:
+             if task.num_context_examples > _MAX_CONTEXT_SIZE[run_mode]:
+                 logger.log(f"Sub-sampled {_MAX_CONTEXT_SIZE[run_mode]:,} "
+                            f"out of {task.num_context_examples:,} in-context "
+                            f"examples")
+                 task = task.narrow_context(0, _MAX_CONTEXT_SIZE[run_mode])
+
+             if self._batch_size is None:
+                 batch_size = task.num_prediction_examples
+             elif self._batch_size == 'max':
+                 batch_size = _MAX_PRED_SIZE[task.task_type]
+             else:
+                 batch_size = self._batch_size

-             predictions: List[pd.DataFrame] = []
-             summary: Optional[str] = None
-             details: Optional[Explanation] = None
-             for i, batch in enumerate(batches):
-                 # TODO Re-use the context for subsequent predictions.
+             if batch_size > _MAX_PRED_SIZE[task.task_type]:
+                 raise ValueError(f"Cannot predict for more than "
+                                  f"{_MAX_PRED_SIZE[task.task_type]:,} "
+                                  f"entities at once (got {batch_size:,}). Use "
+                                  f"`KumoRFM.batch_mode` to process entities "
+                                  f"in batches with a sufficient batch size.")
+
+             if task.num_prediction_examples > batch_size:
+                 num = math.ceil(task.num_prediction_examples / batch_size)
+                 logger.log(f"Splitting {task.num_prediction_examples:,} "
+                            f"entities into {num:,} batches of size "
+                            f"{batch_size:,}")
+
+             predictions: list[pd.DataFrame] = []
+             summary: str | None = None
+             details: Explanation | None = None
+             for start in range(0, task.num_prediction_examples, batch_size):
                  context = self._get_context(
-                     query=query_def,
-                     indices=batch,
-                     anchor_time=anchor_time,
-                     context_anchor_time=context_anchor_time,
-                     run_mode=RunMode(run_mode),
+                     task=task.narrow_prediction(start, length=batch_size),
+                     run_mode=run_mode,
                      num_neighbors=num_neighbors,
-                     num_hops=num_hops,
-                     max_pq_iterations=max_pq_iterations,
-                     evaluate=False,
-                     random_seed=random_seed,
-                     logger=logger if i == 0 else None,
+                     exclude_cols_dict=exclude_cols_dict,
+                     top_k=top_k,
                  )
+                 context.y_test = None
+
                  request = RFMPredictRequest(
                      context=context,
                      run_mode=RunMode(run_mode),
-                     query=query_str,
+                     query=task._query,
                      use_prediction_time=use_prediction_time,
                  )
                  with warnings.catch_warnings():
                      warnings.filterwarnings('ignore', message='gencode')
                      request_msg = request.to_protobuf()
                      _bytes = request_msg.SerializeToString()
-                 if i == 0:
+                 if start == 0:
                      logger.log(f"Generated context of size "
                                 f"{len(_bytes) / (1024*1024):.2f}MB")

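Batching is now expressed as offsets into the `TaskTable`: `narrow_prediction(start, length=batch_size)` selects each window and `math.ceil` gives the batch count (hence the new `import math`). The arithmetic in isolation:

```python
import math

num_prediction_examples = 1_337
batch_size = 200  # _MAX_PRED_SIZE default is 1_000; 200 for link prediction

num_batches = math.ceil(num_prediction_examples / batch_size)  # -> 7
windows = [(start, min(batch_size, num_prediction_examples - start))
           for start in range(0, num_prediction_examples, batch_size)]
# [(0, 200), (200, 200), ..., (1200, 137)]
```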
@@ -411,14 +608,11 @@ class KumoRFM:
                      stats = Context.get_memory_stats(request_msg.context)
                      raise ValueError(_SIZE_LIMIT_MSG.format(stats=stats))

-                 if (isinstance(verbose, InteractiveProgressLogger) and i == 0
-                         and len(batches) > 1):
-                     verbose.init_progress(
-                         total=len(batches),
-                         description='Predicting',
-                     )
+                 if start == 0 and task.num_prediction_examples > batch_size:
+                     num = math.ceil(task.num_prediction_examples / batch_size)
+                     verbose.init_progress(total=num, description='Predicting')

-                 for attempt in range(self.num_retries + 1):
+                 for attempt in range(self._num_retries + 1):
                      try:
                          if explain_config is not None:
                              resp = self._api_client.explain(
@@ -433,10 +627,10 @@ class KumoRFM:

                          # Cast 'ENTITY' to correct data type:
                          if 'ENTITY' in df:
-                             entity = query_def.entity_table
-                             pkey_map = self._graph_store.pkey_map_dict[entity]
-                             df['ENTITY'] = df['ENTITY'].astype(
-                                 type(pkey_map.index[0]))
+                             table_dict = context.subgraph.table_dict
+                             table = table_dict[context.entity_table_names[0]]
+                             ser = table.df[table.primary_key]
+                             df['ENTITY'] = df['ENTITY'].astype(ser.dtype)

                          # Cast 'ANCHOR_TIMESTAMP' to correct data type:
                          if 'ANCHOR_TIMESTAMP' in df:
@@ -451,13 +645,12 @@ class KumoRFM:

                          predictions.append(df)

-                         if (isinstance(verbose, InteractiveProgressLogger)
-                                 and len(batches) > 1):
+                         if task.num_prediction_examples > batch_size:
                              verbose.step()

                          break
                      except HTTPException as e:
-                         if attempt == self.num_retries:
+                         if attempt == self._num_retries:
                              try:
                                  msg = json.loads(e.detail)['detail']
                              except Exception:
@@ -487,69 +680,19 @@ class KumoRFM:

          return prediction

-     def is_valid_entity(
-         self,
-         query: str,
-         indices: Union[List[str], List[float], List[int], None] = None,
-         *,
-         anchor_time: Union[pd.Timestamp, Literal['entity'], None] = None,
-     ) -> np.ndarray:
-         r"""Returns a mask that denotes which entities are valid for the
-         given predictive query, *i.e.*, which entities fulfill (temporal)
-         entity filter constraints.
-
-         Args:
-             query: The predictive query.
-             indices: The entity primary keys to predict for. Will override the
-                 indices given as part of the predictive query.
-             anchor_time: The anchor timestamp for the prediction. If set to
-                 ``None``, will use the maximum timestamp in the data.
-                 If set to ``"entity"``, will use the timestamp of the entity.
-         """
-         query_def = self._parse_query(query)
-
-         if indices is None:
-             if query_def.rfm_entity_ids is None:
-                 raise ValueError("Cannot find entities to predict for. Please "
-                                  "pass them via "
-                                  "`is_valid_entity(query, indices=...)`")
-             indices = query_def.get_rfm_entity_id_list()
-
-         if len(indices) == 0:
-             raise ValueError("At least one entity is required")
-
-         if anchor_time is None:
-             anchor_time = self._graph_store.max_time
-
-         if isinstance(anchor_time, pd.Timestamp):
-             self._validate_time(query_def, anchor_time, None, False)
-         else:
-             assert anchor_time == 'entity'
-             if (query_def.entity_table not in self._graph_store.time_dict):
-                 raise ValueError(f"Anchor time 'entity' requires the entity "
-                                  f"table '{query_def.entity_table}' "
-                                  f"to have a time column.")
-
-         node = self._graph_store.get_node_id(
-             table_name=query_def.entity_table,
-             pkey=pd.Series(indices),
-         )
-         query_driver = LocalPQueryDriver(self._graph_store, query_def)
-         return query_driver.is_valid(node, anchor_time)
-
      def evaluate(
          self,
          query: str,
          *,
-         metrics: Optional[List[str]] = None,
-         anchor_time: Union[pd.Timestamp, Literal['entity'], None] = None,
-         context_anchor_time: Union[pd.Timestamp, None] = None,
-         run_mode: Union[RunMode, str] = RunMode.FAST,
-         num_neighbors: Optional[List[int]] = None,
+         metrics: list[str] | None = None,
+         anchor_time: pd.Timestamp | Literal['entity'] | None = None,
+         context_anchor_time: pd.Timestamp | None = None,
+         run_mode: RunMode | str = RunMode.FAST,
+         num_neighbors: list[int] | None = None,
          num_hops: int = 2,
-         max_pq_iterations: int = 20,
-         random_seed: Optional[int] = _RANDOM_SEED,
-         verbose: Union[bool, ProgressLogger] = True,
+         max_pq_iterations: int = 10,
+         random_seed: int | None = _RANDOM_SEED,
+         verbose: bool | ProgressLogger = True,
          use_prediction_time: bool = False,
      ) -> pd.DataFrame:
          """Evaluates a predictive query.
@@ -581,41 +724,120 @@ class KumoRFM:
          Returns:
              The metrics as a :class:`pandas.DataFrame`
          """
-         query_def = self._parse_query(query)
+         query_def = replace(
+             self._parse_query(query),
+             for_each='FOR EACH',
+             rfm_entity_ids=None,
+         )
+
+         if not isinstance(verbose, ProgressLogger):
+             query_repr = query_def.to_string(rich=True, exclude_predict=True)
+             msg = f'[bold]EVALUATE[/bold] {query_repr}'
+             verbose = ProgressLogger.default(msg=msg, verbose=verbose)

+         with verbose as logger:
+             task_table = self._get_task_table(
+                 query=query_def,
+                 indices=None,
+                 anchor_time=anchor_time,
+                 context_anchor_time=context_anchor_time,
+                 run_mode=run_mode,
+                 max_pq_iterations=max_pq_iterations,
+                 random_seed=random_seed,
+                 logger=logger,
+             )
+
+         return self.evaluate_task(
+             task_table,
+             metrics=metrics,
+             run_mode=run_mode,
+             num_neighbors=num_neighbors,
+             num_hops=num_hops,
+             verbose=verbose,
+             exclude_cols_dict=query_def.get_exclude_cols_dict(),
+             use_prediction_time=use_prediction_time,
+         )
+
+     def evaluate_task(
+         self,
+         task: TaskTable,
+         *,
+         metrics: list[str] | None = None,
+         run_mode: RunMode | str = RunMode.FAST,
+         num_neighbors: list[int] | None = None,
+         num_hops: int = 2,
+         verbose: bool | ProgressLogger = True,
+         exclude_cols_dict: dict[str, list[str]] | None = None,
+         use_prediction_time: bool = False,
+     ) -> pd.DataFrame:
+         """Evaluates a custom task specification.
+
+         Args:
+             task: The custom :class:`TaskTable`.
+             metrics: The metrics to use.
+             run_mode: The :class:`RunMode` for the query.
+             num_neighbors: The number of neighbors to sample for each hop.
+                 If specified, the ``num_hops`` option will be ignored.
+             num_hops: The number of hops to sample when generating the context.
+             verbose: Whether to print verbose output.
+             exclude_cols_dict: Any column in any table to exclude from the
+                 model input.
+             use_prediction_time: Whether to use the anchor timestamp as an
+                 additional feature during prediction. This is typically
+                 beneficial for time series forecasting tasks.
+
+         Returns:
+             The metrics as a :class:`pandas.DataFrame`
+         """
          if num_hops != 2 and num_neighbors is not None:
              warnings.warn(f"Received custom 'num_neighbors' option; ignoring "
                            f"custom 'num_hops={num_hops}' option")
+         if num_neighbors is None:
+             key = RunMode.FAST if task.task_type.is_link_pred else run_mode
+             num_neighbors = _DEFAULT_NUM_NEIGHBORS[key][:num_hops]

-         if query_def.rfm_entity_ids is not None:
-             query_def = replace(
-                 query_def,
-                 rfm_entity_ids=None,
-             )
-
-         query_repr = query_def.to_string(rich=True, exclude_predict=True)
-         msg = f'[bold]EVALUATE[/bold] {query_repr}'
+         if metrics is not None and len(metrics) > 0:
+             self._validate_metrics(metrics, task.task_type)
+             metrics = list(dict.fromkeys(metrics))

          if not isinstance(verbose, ProgressLogger):
-             verbose = InteractiveProgressLogger(msg, verbose=verbose)
+             if task.task_type == TaskType.BINARY_CLASSIFICATION:
+                 task_type_repr = 'binary classification'
+             elif task.task_type == TaskType.MULTICLASS_CLASSIFICATION:
+                 task_type_repr = 'multi-class classification'
+             elif task.task_type == TaskType.REGRESSION:
+                 task_type_repr = 'regression'
+             elif task.task_type == TaskType.TEMPORAL_LINK_PREDICTION:
+                 task_type_repr = 'link prediction'
+             else:
+                 task_type_repr = str(task.task_type)
+
+             msg = f"Evaluating {task_type_repr} task"
+             verbose = ProgressLogger.default(msg=msg, verbose=verbose)

          with verbose as logger:
+             if task.num_context_examples > _MAX_CONTEXT_SIZE[run_mode]:
+                 logger.log(f"Sub-sampled {_MAX_CONTEXT_SIZE[run_mode]:,} "
+                            f"out of {task.num_context_examples:,} in-context "
+                            f"examples")
+                 task = task.narrow_context(0, _MAX_CONTEXT_SIZE[run_mode])
+
+             if task.num_prediction_examples > _MAX_TEST_SIZE[task.task_type]:
+                 logger.log(f"Sub-sampled {_MAX_TEST_SIZE[task.task_type]:,} "
+                            f"out of {task.num_prediction_examples:,} test "
+                            f"examples")
+                 task = task.narrow_prediction(
+                     start=0,
+                     length=_MAX_TEST_SIZE[task.task_type],
+                 )
+
              context = self._get_context(
-                 query=query_def,
-                 indices=None,
-                 anchor_time=anchor_time,
-                 context_anchor_time=context_anchor_time,
-                 run_mode=RunMode(run_mode),
+                 task=task,
+                 run_mode=run_mode,
                  num_neighbors=num_neighbors,
-                 num_hops=num_hops,
-                 max_pq_iterations=max_pq_iterations,
-                 evaluate=True,
-                 random_seed=random_seed,
-                 logger=logger if verbose else None,
+                 exclude_cols_dict=exclude_cols_dict,
              )
-             if metrics is not None and len(metrics) > 0:
-                 self._validate_metrics(metrics, context.task_type)
-                 metrics = list(dict.fromkeys(metrics))
+
              request = RFMEvaluateRequest(
                  context=context,
                  run_mode=RunMode(run_mode),
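`evaluate()` mirrors the new `predict()` structure: it builds a `TaskTable` with held-out test examples and delegates to `evaluate_task()`. Note that the test-set cap is now keyed by task type (`_MAX_TEST_SIZE[task.task_type]`, 2,000 by default and 400 for link prediction) rather than by run mode. A usage sketch; the metric names are hypothetical and are checked by `_validate_metrics()`:

```python
metrics_df = model.evaluate(
    query,                       # e.g. a temporal aggregation query
    metrics=['auroc', 'auprc'],  # hypothetical names for a binary task
    run_mode='FAST',
)
```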
@@ -633,17 +855,23 @@ class KumoRFM:
                  stats_msg = Context.get_memory_stats(request_msg.context)
                  raise ValueError(_SIZE_LIMIT_MSG.format(stats=stats_msg))

-             try:
-                 resp = self._api_client.evaluate(request_bytes)
-             except HTTPException as e:
+             for attempt in range(self._num_retries + 1):
                  try:
-                     msg = json.loads(e.detail)['detail']
-                 except Exception:
-                     msg = e.detail
-                 raise RuntimeError(f"An unexpected exception occurred. "
-                                    f"Please create an issue at "
-                                    f"'https://github.com/kumo-ai/kumo-rfm'. "
-                                    f"{msg}") from None
+                     resp = self._api_client.evaluate(request_bytes)
+                     break
+                 except HTTPException as e:
+                     if attempt == self._num_retries:
+                         try:
+                             msg = json.loads(e.detail)['detail']
+                         except Exception:
+                             msg = e.detail
+                         raise RuntimeError(
+                             f"An unexpected exception occurred. Please create "
+                             f"an issue at "
+                             f"'https://github.com/kumo-ai/kumo-rfm'. {msg}"
+                         ) from None
+
+                     time.sleep(2**attempt)  # 1s, 2s, 4s, 8s, ...

              return pd.DataFrame.from_dict(
                  resp.metrics,
@@ -656,9 +884,9 @@ class KumoRFM:
          query: str,
          size: int,
          *,
-         anchor_time: Union[pd.Timestamp, Literal['entity'], None] = None,
-         random_seed: Optional[int] = _RANDOM_SEED,
-         max_iterations: int = 20,
+         anchor_time: pd.Timestamp | Literal['entity'] | None = None,
+         random_seed: int | None = _RANDOM_SEED,
+         max_iterations: int = 10,
      ) -> pd.DataFrame:
          """Returns the labels of a predictive query for a specified anchor
          time.
@@ -678,40 +906,37 @@ class KumoRFM:
          query_def = self._parse_query(query)

          if anchor_time is None:
-             anchor_time = self._graph_store.max_time
+             anchor_time = self._get_default_anchor_time(query_def)
              if query_def.target_ast.date_offset_range is not None:
-                 anchor_time = anchor_time - (
-                     query_def.target_ast.date_offset_range.end_date_offset *
-                     query_def.num_forecasts)
+                 offset = query_def.target_ast.date_offset_range.end_date_offset
+                 offset *= query_def.num_forecasts
+                 anchor_time -= offset

          assert anchor_time is not None
          if isinstance(anchor_time, pd.Timestamp):
              self._validate_time(query_def, anchor_time, None, evaluate=True)
          else:
              assert anchor_time == 'entity'
-             if (query_def.entity_table not in self._graph_store.time_dict):
+             if query_def.entity_table not in self._sampler.time_column_dict:
                  raise ValueError(f"Anchor time 'entity' requires the entity "
                                   f"table '{query_def.entity_table}' "
                                   f"to have a time column")

-         query_driver = LocalPQueryDriver(self._graph_store, query_def,
-                                          random_seed)
-
-         node, time, y = query_driver.collect_test(
-             size=size,
-             anchor_time=anchor_time,
-             batch_size=min(10_000, size),
-             max_iterations=max_iterations,
-             guarantee_train_examples=False,
+         train, test = self._sampler.sample_target(
+             query=query_def,
+             num_train_examples=0,
+             train_anchor_time=anchor_time,
+             num_train_trials=0,
+             num_test_examples=size,
+             test_anchor_time=anchor_time,
+             num_test_trials=max_iterations * size,
+             random_seed=random_seed,
          )

-         entity = self._graph_store.pkey_map_dict[
-             query_def.entity_table].index[node]
-
          return pd.DataFrame({
-             'ENTITY': entity,
-             'ANCHOR_TIMESTAMP': time,
-             'TARGET': y,
+             'ENTITY': test.entity_pkey,
+             'ANCHOR_TIMESTAMP': test.anchor_time,
+             'TARGET': test.target,
          })

      # Helpers #################################################################
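The hunk below replaces `LocalPQueryDriver.get_task_type()` with a static `_get_task_type()` and adds retry/backoff to `_parse_query()`. The task-type decision tree, summarized as comments (see the code below for the exact checks):

```python
# _get_task_type(query, edge_types), as introduced in the hunk below:
#
# Condition / LogicalOperation target      -> BINARY_CLASSIFICATION
# Aggregation target:
#     LIST_DISTINCT over a foreign key     -> TEMPORAL_LINK_PREDICTION
#     LIST_DISTINCT over anything else     -> NotImplementedError
#     any other aggregation                -> REGRESSION
# Column target:
#     Stype.ID / Stype.categorical         -> MULTICLASS_CLASSIFICATION
#     Stype.numerical                      -> REGRESSION
#     anything else                        -> NotImplementedError
```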
@@ -726,63 +951,128 @@ class KumoRFM:
                               "`predict()` or `evaluate()` methods to perform "
                               "predictions or evaluations.")

-         try:
-             request = RFMParseQueryRequest(
-                 query=query,
-                 graph_definition=self._graph_def,
-             )
+         request = RFMParseQueryRequest(
+             query=query,
+             graph_definition=self._graph_def,
+         )
+
+         for attempt in range(self._num_retries + 1):
+             try:
+                 resp = self._api_client.parse_query(request)
+                 break
+             except HTTPException as e:
+                 if attempt == self._num_retries:
+                     try:
+                         msg = json.loads(e.detail)['detail']
+                     except Exception:
+                         msg = e.detail
+                     raise ValueError(f"Failed to parse query '{query}'. {msg}")

-             resp = self._api_client.parse_query(request)
+                 time.sleep(2**attempt)  # 1s, 2s, 4s, 8s, ...

-             # TODO Expose validation warnings.
+         if len(resp.validation_response.warnings) > 0:
+             msg = '\n'.join([
+                 f'{i+1}. {warning.title}: {warning.message}'
+                 for i, warning in enumerate(resp.validation_response.warnings)
+             ])
+             warnings.warn(f"Encountered the following warnings during "
+                           f"parsing:\n{msg}")

-             if len(resp.validation_response.warnings) > 0:
-                 msg = '\n'.join([
-                     f'{i+1}. {warning.title}: {warning.message}' for i, warning
-                     in enumerate(resp.validation_response.warnings)
-                 ])
-                 warnings.warn(f"Encountered the following warnings during "
-                               f"parsing:\n{msg}")
+         return resp.query

-             return resp.query
-         except HTTPException as e:
-             try:
-                 msg = json.loads(e.detail)['detail']
-             except Exception:
-                 msg = e.detail
-             raise ValueError(f"Failed to parse query '{query}'. "
-                              f"{msg}") from None
+     @staticmethod
+     def _get_task_type(
+         query: ValidatedPredictiveQuery,
+         edge_types: list[tuple[str, str, str]],
+     ) -> TaskType:
+         if isinstance(query.target_ast, (Condition, LogicalOperation)):
+             return TaskType.BINARY_CLASSIFICATION
+
+         target = query.target_ast
+         if isinstance(target, Join):
+             target = target.rhs_target
+         if isinstance(target, Aggregation):
+             if target.aggr == AggregationType.LIST_DISTINCT:
+                 table_name, col_name = target._get_target_column_name().split(
+                     '.')
+                 target_edge_types = [
+                     edge_type for edge_type in edge_types
+                     if edge_type[0] == table_name and edge_type[1] == col_name
+                 ]
+                 if len(target_edge_types) != 1:
+                     raise NotImplementedError(
+                         f"Multilabel-classification queries based on "
+                         f"'LIST_DISTINCT' are not supported yet. If you "
+                         f"planned to write a link prediction query instead, "
+                         f"make sure to register '{col_name}' as a "
+                         f"foreign key.")
+                 return TaskType.TEMPORAL_LINK_PREDICTION
+
+             return TaskType.REGRESSION
+
+         assert isinstance(target, Column)
+
+         if target.stype in {Stype.ID, Stype.categorical}:
+             return TaskType.MULTICLASS_CLASSIFICATION
+
+         if target.stype in {Stype.numerical}:
+             return TaskType.REGRESSION
+
+         raise NotImplementedError("Task type not yet supported")
+
+     def _get_default_anchor_time(
+         self,
+         query: ValidatedPredictiveQuery | None = None,
+     ) -> pd.Timestamp:
+         if query is not None and query.query_type == QueryType.TEMPORAL:
+             aggr_table_names = [
+                 aggr._get_target_column_name().split('.')[0]
+                 for aggr in query.get_all_target_aggregations()
+             ]
+             return self._sampler.get_max_time(aggr_table_names)
+
+         return self._sampler.get_max_time()

      def _validate_time(
          self,
          query: ValidatedPredictiveQuery,
          anchor_time: pd.Timestamp,
-         context_anchor_time: Union[pd.Timestamp, None],
+         context_anchor_time: pd.Timestamp | None,
          evaluate: bool,
      ) -> None:

-         if self._graph_store.min_time == pd.Timestamp.max:
+         if len(self._sampler.time_column_dict) == 0:
              return  # Graph without timestamps

-         if anchor_time < self._graph_store.min_time:
+         if query.query_type == QueryType.TEMPORAL:
+             aggr_table_names = [
+                 aggr._get_target_column_name().split('.')[0]
+                 for aggr in query.get_all_target_aggregations()
+             ]
+             min_time = self._sampler.get_min_time(aggr_table_names)
+             max_time = self._sampler.get_max_time(aggr_table_names)
+         else:
+             min_time = self._sampler.get_min_time()
+             max_time = self._sampler.get_max_time()
+
+         if anchor_time < min_time:
              raise ValueError(f"Anchor timestamp '{anchor_time}' is before "
-                              f"the earliest timestamp "
-                              f"'{self._graph_store.min_time}' in the data.")
+                              f"the earliest timestamp '{min_time}' in the "
+                              f"data.")

-         if (context_anchor_time is not None
-                 and context_anchor_time < self._graph_store.min_time):
+         if context_anchor_time is not None and context_anchor_time < min_time:
              raise ValueError(f"Context anchor timestamp is too early or "
                               f"aggregation time range is too large. To make "
                               f"this prediction, we would need data back to "
                               f"'{context_anchor_time}', however, your data "
-                              f"only contains data back to "
-                              f"'{self._graph_store.min_time}'.")
+                              f"only contains data back to '{min_time}'.")

          if query.target_ast.date_offset_range is not None:
              end_offset = query.target_ast.date_offset_range.end_date_offset
          else:
              end_offset = pd.DateOffset(0)
-         forecast_end_offset = end_offset * query.num_forecasts
+         end_offset = end_offset * query.num_forecasts
+
          if (context_anchor_time is not None
                  and context_anchor_time > anchor_time):
              warnings.warn(f"Context anchor timestamp "
@@ -792,7 +1082,7 @@ class KumoRFM:
                            f"intended.")
          elif (query.query_type == QueryType.TEMPORAL
                and context_anchor_time is not None
-               and context_anchor_time + forecast_end_offset > anchor_time):
+               and context_anchor_time + end_offset > anchor_time):
              warnings.warn(f"Aggregation for context examples at timestamp "
                            f"'{context_anchor_time}' will leak information "
                            f"from the prediction anchor timestamp "
@@ -800,62 +1090,44 @@ class KumoRFM:
                            f"intended.")

          elif (context_anchor_time is not None
-               and context_anchor_time - forecast_end_offset
-               < self._graph_store.min_time):
-             _time = context_anchor_time - forecast_end_offset
+               and context_anchor_time - end_offset < min_time):
+             _time = context_anchor_time - end_offset
              warnings.warn(f"Context anchor timestamp is too early or "
                            f"aggregation time range is too large. To form "
                            f"proper input data, we would need data back to "
                            f"'{_time}', however, your data only contains "
-                           f"data back to '{self._graph_store.min_time}'.")
+                           f"data back to '{min_time}'.")

-         if (not evaluate and anchor_time
-                 > self._graph_store.max_time + pd.DateOffset(days=1)):
+         if not evaluate and anchor_time > max_time + pd.DateOffset(days=1):
              warnings.warn(f"Anchor timestamp '{anchor_time}' is after the "
-                           f"latest timestamp '{self._graph_store.max_time}' "
-                           f"in the data. Please make sure this is intended.")
+                           f"latest timestamp '{max_time}' in the data. Please "
+                           f"make sure this is intended.")

-         max_eval_time = self._graph_store.max_time - forecast_end_offset
-         if evaluate and anchor_time > max_eval_time:
+         if evaluate and anchor_time > max_time - end_offset:
              raise ValueError(
                  f"Anchor timestamp for evaluation is after the latest "
-                 f"supported timestamp '{max_eval_time}'.")
+                 f"supported timestamp '{max_time - end_offset}'.")

-     def _get_context(
+     def _get_task_table(
          self,
          query: ValidatedPredictiveQuery,
-         indices: Union[List[str], List[float], List[int], None],
-         anchor_time: Union[pd.Timestamp, Literal['entity'], None],
-         context_anchor_time: Union[pd.Timestamp, None],
-         run_mode: RunMode,
-         num_neighbors: Optional[List[int]],
-         num_hops: int,
-         max_pq_iterations: int,
-         evaluate: bool,
-         random_seed: Optional[int] = _RANDOM_SEED,
-         logger: Optional[ProgressLogger] = None,
-     ) -> Context:
-
-         if num_neighbors is not None:
-             num_hops = len(num_neighbors)
-
-         if num_hops < 0:
-             raise ValueError(f"'num_hops' must be non-negative "
-                              f"(got {num_hops})")
-         if num_hops > 6:
-             raise ValueError(f"Cannot predict on subgraphs with more than 6 "
-                              f"hops (got {num_hops}). Please reduce the "
-                              f"number of hops and try again. Please create a "
-                              f"feature request at "
-                              f"'https://github.com/kumo-ai/kumo-rfm' if you "
-                              f"must go beyond this for your use-case.")
-
-         query_driver = LocalPQueryDriver(self._graph_store, query, random_seed)
-         task_type = LocalPQueryDriver.get_task_type(
-             query,
-             edge_types=self._graph_store.edge_types,
+         indices: list[str] | list[float] | list[int] | None,
+         anchor_time: pd.Timestamp | Literal['entity'] | None = None,
+         context_anchor_time: pd.Timestamp | None = None,
+         run_mode: RunMode = RunMode.FAST,
+         max_pq_iterations: int = 10,
+         random_seed: int | None = _RANDOM_SEED,
+         logger: ProgressLogger | None = None,
+     ) -> TaskTable:
+
+         task_type = self._get_task_type(
+             query=query,
+             edge_types=self._sampler.edge_types,
          )

+         num_train_examples = _MAX_CONTEXT_SIZE[run_mode]
+         num_test_examples = _MAX_TEST_SIZE[task_type] if indices is None else 0
+
          if logger is not None:
              if task_type == TaskType.BINARY_CLASSIFICATION:
                  task_type_repr = 'binary classification'
@@ -869,30 +1141,17 @@ class KumoRFM:
                  task_type_repr = str(task_type)
              logger.log(f"Identified {query.query_type} {task_type_repr} task")

-         if task_type.is_link_pred and num_hops < 2:
-             raise ValueError(f"Cannot perform link prediction on subgraphs "
-                              f"with less than 2 hops (got {num_hops}) since "
-                              f"historical target entities need to be part of "
-                              f"the context. Please increase the number of "
-                              f"hops and try again.")
-
-         if num_neighbors is None:
-             if run_mode == RunMode.DEBUG:
-                 num_neighbors = [16, 16, 4, 4, 1, 1][:num_hops]
-             elif run_mode == RunMode.FAST or task_type.is_link_pred:
-                 num_neighbors = [32, 32, 8, 8, 4, 4][:num_hops]
-             else:
-                 num_neighbors = [64, 64, 8, 8, 4, 4][:num_hops]
-
          if query.target_ast.date_offset_range is None:
-             end_offset = pd.DateOffset(0)
+             step_offset = pd.DateOffset(0)
          else:
-             end_offset = query.target_ast.date_offset_range.end_date_offset
-         forecast_end_offset = end_offset * query.num_forecasts
+             step_offset = query.target_ast.date_offset_range.end_date_offset
+         end_offset = step_offset * query.num_forecasts
+
          if anchor_time is None:
-             anchor_time = self._graph_store.max_time
-             if evaluate:
-                 anchor_time = anchor_time - forecast_end_offset
+             anchor_time = self._get_default_anchor_time(query)
+             if num_test_examples > 0:
+                 anchor_time = anchor_time - end_offset
+
          if logger is not None:
              assert isinstance(anchor_time, pd.Timestamp)
              if anchor_time == pd.Timestamp.min:
@@ -904,114 +1163,98 @@ class KumoRFM:
904
1163
  else:
905
1164
  logger.log(f"Derived anchor time {anchor_time}")
906
1165
 
907
- assert anchor_time is not None
908
1166
  if isinstance(anchor_time, pd.Timestamp):
1167
+ if context_anchor_time == 'entity':
1168
+ raise ValueError("Anchor time 'entity' needs to be shared "
1169
+ "for context and prediction examples")
909
1170
  if context_anchor_time is None:
910
- context_anchor_time = anchor_time - forecast_end_offset
1171
+ context_anchor_time = anchor_time - end_offset
911
1172
  self._validate_time(query, anchor_time, context_anchor_time,
912
- evaluate)
1173
+ evaluate=num_test_examples > 0)
913
1174
  else:
914
1175
  assert anchor_time == 'entity'
915
- if query.entity_table not in self._graph_store.time_dict:
1176
+ if query.query_type != QueryType.STATIC:
1177
+ raise ValueError("Anchor time 'entity' is only valid for "
1178
+ "static predictive queries")
1179
+ if query.entity_table not in self._sampler.time_column_dict:
916
1180
  raise ValueError(f"Anchor time 'entity' requires the entity "
917
1181
  f"table '{query.entity_table}' to "
918
1182
  f"have a time column")
919
- if context_anchor_time is not None:
920
- warnings.warn("Ignoring option 'context_anchor_time' for "
921
- "`anchor_time='entity'`")
922
- context_anchor_time = None
923
-
924
- y_test: Optional[pd.Series] = None
925
- if evaluate:
926
- max_test_size = _MAX_TEST_SIZE[run_mode]
927
- if task_type.is_link_pred:
928
- max_test_size = max_test_size // 5
929
-
930
- test_node, test_time, y_test = query_driver.collect_test(
931
- size=max_test_size,
932
- anchor_time=anchor_time,
933
- max_iterations=max_pq_iterations,
934
- guarantee_train_examples=True,
935
- )
936
- if logger is not None:
937
- if task_type == TaskType.BINARY_CLASSIFICATION:
938
- pos = 100 * int((y_test > 0).sum()) / len(y_test)
939
- msg = (f"Collected {len(y_test):,} test examples with "
940
- f"{pos:.2f}% positive cases")
941
- elif task_type == TaskType.MULTICLASS_CLASSIFICATION:
942
- msg = (f"Collected {len(y_test):,} test examples "
943
- f"holding {y_test.nunique()} classes")
944
- elif task_type == TaskType.REGRESSION:
945
- _min, _max = float(y_test.min()), float(y_test.max())
946
- msg = (f"Collected {len(y_test):,} test examples with "
947
- f"targets between {format_value(_min)} and "
948
- f"{format_value(_max)}")
949
- elif task_type == TaskType.TEMPORAL_LINK_PREDICTION:
950
- num_rhs = y_test.explode().nunique()
951
- msg = (f"Collected {len(y_test):,} test examples with "
952
- f"{num_rhs:,} unique items")
953
- else:
954
- raise NotImplementedError
955
- logger.log(msg)
956
-
957
- else:
958
- assert indices is not None
959
-
960
- if len(indices) > _MAX_PRED_SIZE[task_type]:
961
- raise ValueError(f"Cannot predict for more than "
962
- f"{_MAX_PRED_SIZE[task_type]:,} entities at "
963
- f"once (got {len(indices):,}). Use "
964
- f"`KumoRFM.batch_mode` to process entities "
965
- f"in batches")
1183
+ if isinstance(context_anchor_time, pd.Timestamp):
1184
+ raise ValueError("Anchor time 'entity' needs to be shared "
1185
+ "for context and prediction examples")
1186
+ context_anchor_time = 'entity'
1187
+
1188
+ train, test = self._sampler.sample_target(
1189
+ query=query,
1190
+ num_train_examples=num_train_examples,
1191
+ train_anchor_time=context_anchor_time,
1192
+ num_train_trials=max_pq_iterations * num_train_examples,
1193
+ num_test_examples=num_test_examples,
1194
+ test_anchor_time=anchor_time,
1195
+ num_test_trials=max_pq_iterations * num_test_examples,
1196
+ random_seed=random_seed,
1197
+ )
1198
+ train_pkey, train_time, train_y = train
1199
+ test_pkey, test_time, test_y = test
966
1200
 
-            test_node = self._graph_store.get_node_id(
-                table_name=query.entity_table,
-                pkey=pd.Series(indices),
-            )
+        if num_test_examples > 0 and logger is not None:
+            if task_type == TaskType.BINARY_CLASSIFICATION:
+                pos = 100 * int((test_y > 0).sum()) / len(test_y)
+                msg = (f"Collected {len(test_y):,} test examples with "
+                       f"{pos:.2f}% positive cases")
+            elif task_type == TaskType.MULTICLASS_CLASSIFICATION:
+                msg = (f"Collected {len(test_y):,} test examples holding "
+                       f"{test_y.nunique()} classes")
+            elif task_type == TaskType.REGRESSION:
+                _min, _max = float(test_y.min()), float(test_y.max())
+                msg = (f"Collected {len(test_y):,} test examples with targets "
+                       f"between {format_value(_min)} and "
+                       f"{format_value(_max)}")
+            elif task_type == TaskType.TEMPORAL_LINK_PREDICTION:
+                num_rhs = test_y.explode().nunique()
+                msg = (f"Collected {len(test_y):,} test examples with "
+                       f"{num_rhs:,} unique items")
+            else:
+                raise NotImplementedError
+            logger.log(msg)

+        if num_test_examples == 0:
+            assert indices is not None
+            test_pkey = pd.Series(indices, dtype=train_pkey.dtype)
             if isinstance(anchor_time, pd.Timestamp):
-                test_time = pd.Series(anchor_time).repeat(
-                    len(test_node)).reset_index(drop=True)
+                test_time = pd.Series([anchor_time]).repeat(
+                    len(indices)).reset_index(drop=True)
             else:
-                time = self._graph_store.time_dict[query.entity_table]
-                time = time[test_node] * 1000**3
-                test_time = pd.Series(time, dtype='datetime64[ns]')
-
-        train_node, train_time, y_train = query_driver.collect_train(
-            size=_MAX_CONTEXT_SIZE[run_mode],
-            anchor_time=context_anchor_time or 'entity',
-            exclude_node=test_node if (query.query_type == QueryType.STATIC
-                                       or anchor_time == 'entity') else None,
-            max_iterations=max_pq_iterations,
-        )
+                train_time = test_time = 'entity'
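Note: in prediction-only mode (`num_test_examples == 0`) the prediction entities now come from the user-provided `indices`, and a fixed `pd.Timestamp` anchor is broadcast to every row; with `anchor_time='entity'` both anchors fall back to the `'entity'` sentinel instead. A sketch of the broadcast:

    import pandas as pd

    indices = [101, 102, 103]
    anchor_time = pd.Timestamp('2024-06-01')
    # Same broadcast as above: one anchor timestamp per prediction entity.
    test_time = pd.Series([anchor_time]).repeat(
        len(indices)).reset_index(drop=True)
    assert len(test_time) == len(indices)
    assert (test_time == anchor_time).all()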
 
         if logger is not None:
             if task_type == TaskType.BINARY_CLASSIFICATION:
-                pos = 100 * int((y_train > 0).sum()) / len(y_train)
-                msg = (f"Collected {len(y_train):,} in-context examples with "
+                pos = 100 * int((train_y > 0).sum()) / len(train_y)
+                msg = (f"Collected {len(train_y):,} in-context examples with "
                        f"{pos:.2f}% positive cases")
             elif task_type == TaskType.MULTICLASS_CLASSIFICATION:
-                msg = (f"Collected {len(y_train):,} in-context examples "
-                       f"holding {y_train.nunique()} classes")
+                msg = (f"Collected {len(train_y):,} in-context examples "
+                       f"holding {train_y.nunique()} classes")
             elif task_type == TaskType.REGRESSION:
-                _min, _max = float(y_train.min()), float(y_train.max())
-                msg = (f"Collected {len(y_train):,} in-context examples with "
+                _min, _max = float(train_y.min()), float(train_y.max())
+                msg = (f"Collected {len(train_y):,} in-context examples with "
                        f"targets between {format_value(_min)} and "
                        f"{format_value(_max)}")
             elif task_type == TaskType.TEMPORAL_LINK_PREDICTION:
-                num_rhs = y_train.explode().nunique()
-                msg = (f"Collected {len(y_train):,} in-context examples with "
+                num_rhs = train_y.explode().nunique()
+                msg = (f"Collected {len(train_y):,} in-context examples with "
                        f"{num_rhs:,} unique items")
             else:
                 raise NotImplementedError
             logger.log(msg)
 
-        entity_table_names: Tuple[str, ...]
+        entity_table_names: tuple[str] | tuple[str, str]
         if task_type.is_link_pred:
             final_aggr = query.get_final_target_aggregation()
             assert final_aggr is not None
             edge_fkey = final_aggr._get_target_column_name()
-            for edge_type in self._graph_store.edge_types:
+            for edge_type in self._sampler.edge_types:
                 if edge_fkey == f'{edge_type[0]}.{edge_type[1]}':
                     entity_table_names = (
                         query.entity_table,
@@ -1020,23 +1263,80 @@ class KumoRFM:
         else:
             entity_table_names = (query.entity_table, )
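Note: for link prediction, the target table is resolved by matching the final aggregation's target column `'table.fkey'` against the sampler's edge types; for all other tasks the tuple holds only the entity table. A hedged sketch, assuming edge types are (source_table, foreign_key, destination_table) triples:

    # Illustrative tables; not taken from any real schema:
    edge_types = [('orders', 'user_id', 'users'),
                  ('orders', 'item_id', 'items')]
    edge_fkey = 'orders.item_id'  # from the final target aggregation
    entity_table_names = ('users', )  # default: entity table only
    for src, fkey, dst in edge_types:
        if edge_fkey == f'{src}.{fkey}':
            # The matched edge supplies the second table name; that line
            # is truncated in this hunk, so Ellipsis stands in for it here.
            entity_table_names = ('users', ...)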
 
+        context_df = pd.DataFrame({'ENTITY': train_pkey, 'TARGET': train_y})
+        if isinstance(train_time, pd.Series):
+            context_df['ANCHOR_TIMESTAMP'] = train_time
+        pred_df = pd.DataFrame({'ENTITY': test_pkey})
+        if num_test_examples > 0:
+            pred_df['TARGET'] = test_y
+        if isinstance(test_time, pd.Series):
+            pred_df['ANCHOR_TIMESTAMP'] = test_time
+
+        return TaskTable(
+            task_type=task_type,
+            context_df=context_df,
+            pred_df=pred_df,
+            entity_table_name=entity_table_names,
+            entity_column='ENTITY',
+            target_column='TARGET',
+            time_column='ANCHOR_TIMESTAMP' if isinstance(
+                train_time, pd.Series) else TaskTable.ENTITY_TIME,
+        )
+
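Note: the method now normalizes everything into a `TaskTable` with canonical ENTITY/TARGET/ANCHOR_TIMESTAMP column names. An illustrative shape of the two frames (values made up):

    import pandas as pd

    context_df = pd.DataFrame({
        'ENTITY': [17, 42],  # in-context entity primary keys
        'TARGET': [1, 0],    # observed labels
        'ANCHOR_TIMESTAMP': pd.to_datetime(['2024-01-01', '2024-01-08']),
    })
    pred_df = pd.DataFrame({
        'ENTITY': [101, 102],  # entities to predict for
        'ANCHOR_TIMESTAMP': pd.to_datetime(['2024-06-01', '2024-06-01']),
    })  # 'TARGET' is only present when test examples were sampled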
+    def _get_context(
+        self,
+        task: TaskTable,
+        run_mode: RunMode | str = RunMode.FAST,
+        num_neighbors: list[int] | None = None,
+        exclude_cols_dict: dict[str, list[str]] | None = None,
+        top_k: int | None = None,
+    ) -> Context:
+
+        if num_neighbors is None:
+            key = RunMode.FAST if task.task_type.is_link_pred else run_mode
+            num_neighbors = _DEFAULT_NUM_NEIGHBORS[key][:2]
+
+        if len(num_neighbors) > 6:
+            raise ValueError(f"Cannot predict on subgraphs with more than 6 "
+                             f"hops (got {len(num_neighbors)}). Reduce the "
+                             f"number of hops and try again. Please create a "
+                             f"feature request at "
+                             f"'https://github.com/kumo-ai/kumo-rfm' if you "
+                             f"must go beyond this for your use-case.")
+
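Note: `num_neighbors` is a per-hop fan-out list; the default keeps the first two hops of the run-mode preset, and anything beyond six hops is rejected. A toy illustration (the preset values below are invented; the concrete ones live in `_DEFAULT_NUM_NEIGHBORS` and will differ):

    # Hypothetical preset table; only the list-per-run-mode shape is assumed:
    _PRESET = {'fast': [16, 8, 4], 'best': [32, 16, 8]}
    num_neighbors = _PRESET['fast'][:2]  # -> [16, 8]: two hops by default
    assert len(num_neighbors) <= 6       # the guard enforced above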
         # Exclude the entity anchor time from the feature set to prevent
         # running out-of-distribution between in-context and test examples:
-        exclude_cols_dict = query.get_exclude_cols_dict()
-        if anchor_time == 'entity':
-            if entity_table_names[0] not in exclude_cols_dict:
-                exclude_cols_dict[entity_table_names[0]] = []
-            time_column_dict = self._graph_store.time_column_dict
-            time_column = time_column_dict[entity_table_names[0]]
-            exclude_cols_dict[entity_table_names[0]].append(time_column)
-
-        subgraph = self._graph_sampler(
-            entity_table_names=entity_table_names,
-            node=np.concatenate([train_node, test_node]),
-            time=np.concatenate([
-                train_time.astype('datetime64[ns]').astype(int).to_numpy(),
-                test_time.astype('datetime64[ns]').astype(int).to_numpy(),
-            ]),
+        exclude_cols_dict = exclude_cols_dict or {}
+        if task.entity_table_name in self._sampler.time_column_dict:
+            if task.entity_table_name not in exclude_cols_dict:
+                exclude_cols_dict[task.entity_table_name] = []
+            time_col = self._sampler.time_column_dict[task.entity_table_name]
+            exclude_cols_dict[task.entity_table_name].append(time_col)
+
+        entity_pkey = pd.concat([
+            task._context_df[task._entity_column],
+            task._pred_df[task._entity_column],
+        ], axis=0, ignore_index=True)
+
+        if task.use_entity_time:
+            if task.entity_table_name not in self._sampler.time_column_dict:
+                raise ValueError(f"The given anchor time requires the entity "
+                                 f"table '{task.entity_table_name}' to have a "
+                                 f"time column")
+            anchor_time = 'entity'
+        elif task._time_column is not None:
+            anchor_time = pd.concat([
+                task._context_df[task._time_column],
+                task._pred_df[task._time_column],
+            ], axis=0, ignore_index=True)
+        else:
+            anchor_time = pd.Series(self._get_default_anchor_time()).repeat(
+                len(entity_pkey)).reset_index(drop=True)
+
+        subgraph = self._sampler.sample_subgraph(
+            entity_table_names=task.entity_table_names,
+            entity_pkey=entity_pkey,
+            anchor_time=anchor_time,
             num_neighbors=num_neighbors,
             exclude_cols_dict=exclude_cols_dict,
         )
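Note: the anchor time handed to `sample_subgraph` is one of three things: the `'entity'` sentinel, the per-row timestamps concatenated from both frames, or a broadcast default. A sketch of the concatenation path, with context rows first and prediction rows after (the same ordering as the `entity_pkey` concatenation above):

    import pandas as pd

    context_times = pd.Series(pd.to_datetime(['2024-01-01', '2024-01-08']))
    pred_times = pd.Series(pd.to_datetime(['2024-06-01']))
    anchor_time = pd.concat([context_times, pred_times], axis=0,
                            ignore_index=True)
    assert len(anchor_time) == 3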
@@ -1048,23 +1348,26 @@ class KumoRFM:
                             f"'https://github.com/kumo-ai/kumo-rfm' if you "
                             f"must go beyond this for your use-case.")

-        step_size: Optional[int] = None
-        if query.query_type == QueryType.TEMPORAL:
-            step_size = date_offset_to_seconds(end_offset)
+        if (task.task_type.is_link_pred
+                and task.entity_table_names[-1] not in subgraph.table_dict):
+            raise ValueError("Cannot perform link prediction on subgraphs "
+                             "without any historical target entities. Please "
+                             "increase the number of hops and try again.")
 
         return Context(
-            task_type=task_type,
-            entity_table_names=entity_table_names,
+            task_type=task.task_type,
+            entity_table_names=task.entity_table_names,
             subgraph=subgraph,
-            y_train=y_train,
-            y_test=y_test,
-            top_k=query.top_k,
-            step_size=step_size,
+            y_train=task._context_df[task.target_column.name],
+            y_test=task._pred_df[task.target_column.name]
+            if task.evaluate else None,
+            top_k=top_k,
+            step_size=None,
         )
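Note: `Context` now draws its labels from the `TaskTable`: `y_test` is only populated when the task carries evaluation targets, and `step_size` is no longer derived from the query. A sketch of the conditional selection:

    import pandas as pd

    pred_df = pd.DataFrame({'TARGET': [1, 0, 1]})
    evaluate = True  # i.e. the prediction frame carries ground-truth targets
    y_test = pred_df['TARGET'] if evaluate else None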
 
     @staticmethod
     def _validate_metrics(
-        metrics: List[str],
+        metrics: list[str],
         task_type: TaskType,
     ) -> None:
 
@@ -1121,7 +1424,7 @@ class KumoRFM:
                         f"'https://github.com/kumo-ai/kumo-rfm'.")


-def format_value(value: Union[int, float]) -> str:
+def format_value(value: int | float) -> str:
     if value == int(value):
         return f'{int(value):,}'
     if abs(value) >= 1000:
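Note: `format_value` renders integral values with thousands separators; the remaining branches fall outside this hunk. The visible integral path behaves like:

    value = 1234567.0
    assert value == int(value)           # integral, so the fast path applies
    assert f'{int(value):,}' == '1,234,567'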