kumoai 2.13.0.dev202511211730__py3-none-any.whl → 2.15.0.dev202601131732__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (59)
  1. kumoai/__init__.py +35 -26
  2. kumoai/_version.py +1 -1
  3. kumoai/client/client.py +6 -0
  4. kumoai/client/jobs.py +26 -0
  5. kumoai/client/pquery.py +6 -2
  6. kumoai/connector/utils.py +44 -9
  7. kumoai/experimental/rfm/__init__.py +70 -68
  8. kumoai/experimental/rfm/authenticate.py +3 -4
  9. kumoai/experimental/rfm/backend/__init__.py +0 -0
  10. kumoai/experimental/rfm/backend/local/__init__.py +42 -0
  11. kumoai/experimental/rfm/{local_graph_store.py → backend/local/graph_store.py} +65 -127
  12. kumoai/experimental/rfm/backend/local/sampler.py +312 -0
  13. kumoai/experimental/rfm/backend/local/table.py +113 -0
  14. kumoai/experimental/rfm/backend/snow/__init__.py +37 -0
  15. kumoai/experimental/rfm/backend/snow/sampler.py +407 -0
  16. kumoai/experimental/rfm/backend/snow/table.py +245 -0
  17. kumoai/experimental/rfm/backend/sqlite/__init__.py +32 -0
  18. kumoai/experimental/rfm/backend/sqlite/sampler.py +454 -0
  19. kumoai/experimental/rfm/backend/sqlite/table.py +184 -0
  20. kumoai/experimental/rfm/base/__init__.py +30 -0
  21. kumoai/experimental/rfm/base/column.py +152 -0
  22. kumoai/experimental/rfm/base/expression.py +44 -0
  23. kumoai/experimental/rfm/base/mapper.py +69 -0
  24. kumoai/experimental/rfm/base/sampler.py +783 -0
  25. kumoai/experimental/rfm/base/source.py +19 -0
  26. kumoai/experimental/rfm/base/sql_sampler.py +385 -0
  27. kumoai/experimental/rfm/base/table.py +722 -0
  28. kumoai/experimental/rfm/base/utils.py +36 -0
  29. kumoai/experimental/rfm/{local_graph.py → graph.py} +581 -154
  30. kumoai/experimental/rfm/infer/__init__.py +8 -0
  31. kumoai/experimental/rfm/infer/dtype.py +84 -0
  32. kumoai/experimental/rfm/infer/multicategorical.py +1 -1
  33. kumoai/experimental/rfm/infer/pkey.py +128 -0
  34. kumoai/experimental/rfm/infer/stype.py +35 -0
  35. kumoai/experimental/rfm/infer/time_col.py +63 -0
  36. kumoai/experimental/rfm/pquery/executor.py +27 -27
  37. kumoai/experimental/rfm/pquery/pandas_executor.py +30 -32
  38. kumoai/experimental/rfm/relbench.py +76 -0
  39. kumoai/experimental/rfm/rfm.py +783 -481
  40. kumoai/experimental/rfm/sagemaker.py +15 -7
  41. kumoai/experimental/rfm/task_table.py +292 -0
  42. kumoai/pquery/predictive_query.py +10 -6
  43. kumoai/pquery/training_table.py +16 -2
  44. kumoai/testing/decorators.py +1 -1
  45. kumoai/testing/snow.py +50 -0
  46. kumoai/trainer/distilled_trainer.py +175 -0
  47. kumoai/utils/__init__.py +3 -2
  48. kumoai/utils/display.py +87 -0
  49. kumoai/utils/progress_logger.py +192 -13
  50. kumoai/utils/sql.py +3 -0
  51. {kumoai-2.13.0.dev202511211730.dist-info → kumoai-2.15.0.dev202601131732.dist-info}/METADATA +10 -8
  52. {kumoai-2.13.0.dev202511211730.dist-info → kumoai-2.15.0.dev202601131732.dist-info}/RECORD +55 -30
  53. kumoai/experimental/rfm/local_graph_sampler.py +0 -182
  54. kumoai/experimental/rfm/local_pquery_driver.py +0 -689
  55. kumoai/experimental/rfm/local_table.py +0 -545
  56. kumoai/experimental/rfm/utils.py +0 -344
  57. {kumoai-2.13.0.dev202511211730.dist-info → kumoai-2.15.0.dev202601131732.dist-info}/WHEEL +0 -0
  58. {kumoai-2.13.0.dev202511211730.dist-info → kumoai-2.15.0.dev202601131732.dist-info}/licenses/LICENSE +0 -0
  59. {kumoai-2.13.0.dev202511211730.dist-info → kumoai-2.15.0.dev202601131732.dist-info}/top_level.txt +0 -0
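The headline change is visible from the file list alone: the local-first API drops its `Local*` prefixes (`local_graph.py` → `graph.py`), and sampling is split into pluggable data backends (`backend/local`, `backend/sqlite`, `backend/snow`), each shipping its own `sampler.py` and `table.py`. A minimal orientation sketch of the new import surface, using only names that appear in this diff:

    # New public surface (renamed from LocalGraph/local_* modules):
    from kumoai.experimental.rfm import Graph, KumoRFM, TaskTable

    # Backend-specific samplers now live in dedicated packages:
    from kumoai.experimental.rfm.backend.local import LocalSampler
    from kumoai.experimental.rfm.backend.sqlite import SQLiteSampler
    from kumoai.experimental.rfm.backend.snow import SnowSampler

The detailed diff of `kumoai/experimental/rfm/rfm.py` follows.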
@@ -1,26 +1,23 @@
  import json
+ import math
  import time
  import warnings
  from collections import defaultdict
- from collections.abc import Generator
+ from collections.abc import Generator, Iterator
  from contextlib import contextmanager
  from dataclasses import dataclass, replace
- from typing import (
-     Any,
-     Dict,
-     Iterator,
-     List,
-     Literal,
-     Optional,
-     Tuple,
-     Union,
-     overload,
- )
+ from typing import Any, Literal, overload

- import numpy as np
  import pandas as pd
  from kumoapi.model_plan import RunMode
  from kumoapi.pquery import QueryType, ValidatedPredictiveQuery
+ from kumoapi.pquery.AST import (
+     Aggregation,
+     Column,
+     Condition,
+     Join,
+     LogicalOperation,
+ )
  from kumoapi.rfm import Context
  from kumoapi.rfm import Explanation as ExplanationConfig
  from kumoapi.rfm import (
@@ -29,35 +26,38 @@ from kumoapi.rfm import (
      RFMPredictRequest,
  )
  from kumoapi.task import TaskType
+ from kumoapi.typing import AggregationType, Stype
+ from rich.console import Console
+ from rich.markdown import Markdown

+ from kumoai import in_notebook
  from kumoai.client.rfm import RFMAPI
  from kumoai.exceptions import HTTPException
- from kumoai.experimental.rfm import LocalGraph
- from kumoai.experimental.rfm.local_graph_sampler import LocalGraphSampler
- from kumoai.experimental.rfm.local_graph_store import LocalGraphStore
- from kumoai.experimental.rfm.local_pquery_driver import (
-     LocalPQueryDriver,
-     date_offset_to_seconds,
- )
+ from kumoai.experimental.rfm import Graph, TaskTable
+ from kumoai.experimental.rfm.base import DataBackend, Sampler
  from kumoai.mixin import CastMixin
- from kumoai.utils import InteractiveProgressLogger, ProgressLogger
+ from kumoai.utils import ProgressLogger, display

  _RANDOM_SEED = 42

  _MAX_PRED_SIZE: dict[TaskType, int] = defaultdict(lambda: 1_000)
  _MAX_PRED_SIZE[TaskType.TEMPORAL_LINK_PREDICTION] = 200

+ _MAX_TEST_SIZE: dict[TaskType, int] = defaultdict(lambda: 2_000)
+ _MAX_TEST_SIZE[TaskType.TEMPORAL_LINK_PREDICTION] = 400
+
  _MAX_CONTEXT_SIZE = {
      RunMode.DEBUG: 100,
      RunMode.FAST: 1_000,
      RunMode.NORMAL: 5_000,
      RunMode.BEST: 10_000,
  }
- _MAX_TEST_SIZE = {  # Share test set size across run modes for fair comparison:
-     RunMode.DEBUG: 100,
-     RunMode.FAST: 2_000,
-     RunMode.NORMAL: 2_000,
-     RunMode.BEST: 2_000,
+
+ _DEFAULT_NUM_NEIGHBORS = {
+     RunMode.DEBUG: [16, 16, 4, 4, 1, 1],
+     RunMode.FAST: [32, 32, 8, 8, 4, 4],
+     RunMode.NORMAL: [64, 64, 8, 8, 4, 4],
+     RunMode.BEST: [64, 64, 8, 8, 4, 4],
  }

  _MAX_SIZE = 30 * 1024 * 1024
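A note on the constants above: `_MAX_TEST_SIZE` moves from a per-run-mode dict to a per-task-type `defaultdict`, and the old run-mode branching for neighbor fan-outs is collapsed into the `_DEFAULT_NUM_NEIGHBORS` table, which the new code simply slices to the requested hop count (`_DEFAULT_NUM_NEIGHBORS[key][:num_hops]`, see `predict_task` further down). A small sketch of that lookup, assuming `RunMode` from `kumoapi`:

    from kumoapi.model_plan import RunMode

    _DEFAULT_NUM_NEIGHBORS = {
        RunMode.DEBUG: [16, 16, 4, 4, 1, 1],
        RunMode.FAST: [32, 32, 8, 8, 4, 4],
        RunMode.NORMAL: [64, 64, 8, 8, 4, 4],
        RunMode.BEST: [64, 64, 8, 8, 4, 4],
    }

    # Two-hop sampling under RunMode.FAST keeps 32 neighbors per hop:
    num_hops = 2
    assert _DEFAULT_NUM_NEIGHBORS[RunMode.FAST][:num_hops] == [32, 32]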
@@ -95,24 +95,36 @@ class Explanation:
      def __getitem__(self, index: Literal[1]) -> str:
          pass

-     def __getitem__(self, index: int) -> Union[pd.DataFrame, str]:
+     def __getitem__(self, index: int) -> pd.DataFrame | str:
          if index == 0:
              return self.prediction
          if index == 1:
              return self.summary
          raise IndexError("Index out of range")

-     def __iter__(self) -> Iterator[Union[pd.DataFrame, str]]:
+     def __iter__(self) -> Iterator[pd.DataFrame | str]:
          return iter((self.prediction, self.summary))

      def __repr__(self) -> str:
          return str((self.prediction, self.summary))

-     def _ipython_display_(self) -> None:
-         from IPython.display import Markdown, display
+     def __str__(self) -> str:
+         console = Console(soft_wrap=True)
+         with console.capture() as cap:
+             console.print(display.to_rich_table(self.prediction))
+             console.print(Markdown(self.summary))
+         return cap.get()[:-1]
+
+     def print(self) -> None:
+         r"""Prints the explanation."""
+         if in_notebook():
+             display.dataframe(self.prediction)
+             display.message(self.summary)
+         else:
+             print(self)

-         display(self.prediction)
-         display(Markdown(self.summary))
+     def _ipython_display_(self) -> None:
+         self.print()


  class KumoRFM:
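With the hunk above, `Explanation` renders outside notebooks too: `__str__` captures a rich table plus the Markdown summary via `rich`, and `print()` switches on `in_notebook()`, while tuple-style access is unchanged. A short usage sketch; `model` and `query` are placeholders:

    exp = model.predict(query, indices=[42], explain=True)

    prediction, summary = exp   # tuple unpacking via __iter__
    df = exp[0]                 # same as exp.prediction
    exp.print()                 # rich output in both terminals and notebooks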
@@ -123,17 +135,17 @@ class KumoRFM:
      :class:`KumoRFM` is a foundation model to generate predictions for any
      relational dataset without training.
      The model is pre-trained and the class provides an interface to query the
-     model from a :class:`LocalGraph` object.
+     model from a :class:`Graph` object.

      .. code-block:: python

-         from kumoai.experimental.rfm import LocalGraph, KumoRFM
+         from kumoai.experimental.rfm import Graph, KumoRFM

          df_users = pd.DataFrame(...)
          df_items = pd.DataFrame(...)
          df_orders = pd.DataFrame(...)

-         graph = LocalGraph.from_data({
+         graph = Graph.from_data({
              'users': df_users,
              'items': df_items,
              'orders': df_orders,
@@ -150,40 +162,78 @@ class KumoRFM:
      Args:
          graph: The graph.
-         preprocess: Whether to pre-process the data in advance during graph
-             materialization.
-             This is a runtime trade-off between graph materialization and model
-             processing speed.
-             It can be benefical to preprocess your data once and then run many
-             queries on top to achieve maximum model speed.
-             However, if activiated, graph materialization can take potentially
-             much longer, especially on graphs with many large text columns.
-             Best to tune this option manually.
          verbose: Whether to print verbose output.
+         optimize: If set to ``True``, will optimize the underlying data backend
+             for optimal querying. For example, for transactional database
+             backends, will create any missing indices. Requires write-access to
+             the data backend.
      """
      def __init__(
          self,
-         graph: LocalGraph,
-         preprocess: bool = False,
-         verbose: Union[bool, ProgressLogger] = True,
+         graph: Graph,
+         verbose: bool | ProgressLogger = True,
+         optimize: bool = False,
      ) -> None:
          graph = graph.validate()
          self._graph_def = graph._to_api_graph_definition()
-         self._graph_store = LocalGraphStore(graph, preprocess, verbose)
-         self._graph_sampler = LocalGraphSampler(self._graph_store)

-         self._batch_size: Optional[int | Literal['max']] = None
-         self.num_retries: int = 0
+         if graph.backend == DataBackend.LOCAL:
+             from kumoai.experimental.rfm.backend.local import LocalSampler
+             self._sampler: Sampler = LocalSampler(graph, verbose)
+         elif graph.backend == DataBackend.SQLITE:
+             from kumoai.experimental.rfm.backend.sqlite import SQLiteSampler
+             self._sampler = SQLiteSampler(graph, verbose, optimize)
+         elif graph.backend == DataBackend.SNOWFLAKE:
+             from kumoai.experimental.rfm.backend.snow import SnowSampler
+             self._sampler = SnowSampler(graph, verbose)
+         else:
+             raise NotImplementedError
+
+         self._client: RFMAPI | None = None
+
+         self._batch_size: int | Literal['max'] | None = None
+         self._num_retries: int = 0
+
+     @property
+     def _api_client(self) -> RFMAPI:
+         if self._client is not None:
+             return self._client
+
          from kumoai.experimental.rfm import global_state
-         self._api_client = RFMAPI(global_state.client)
+         self._client = RFMAPI(global_state.client)
+         return self._client

      def __repr__(self) -> str:
          return f'{self.__class__.__name__}()'

+     @contextmanager
+     def retry(
+         self,
+         num_retries: int = 1,
+     ) -> Generator[None, None, None]:
+         """Context manager to retry failed queries due to unexpected server
+         issues.
+
+         .. code-block:: python
+
+             with model.retry(num_retries=1):
+                 df = model.predict(query, indices=...)
+
+         Args:
+             num_retries: The maximum number of retries.
+         """
+         if num_retries < 0:
+             raise ValueError(f"'num_retries' must be greater than or equal to "
+                              f"zero (got {num_retries})")
+
+         self._num_retries = num_retries
+         yield
+         self._num_retries = 0
+
      @contextmanager
      def batch_mode(
          self,
-         batch_size: Union[int, Literal['max']] = 'max',
+         batch_size: int | Literal['max'] = 'max',
          num_retries: int = 1,
      ) -> Generator[None, None, None]:
          """Context manager to predict in batches.
@@ -203,31 +253,26 @@ class KumoRFM:
              raise ValueError(f"'batch_size' must be greater than zero "
                               f"(got {batch_size})")

-         if num_retries < 0:
-             raise ValueError(f"'num_retries' must be greater than or equal to "
-                              f"zero (got {num_retries})")
-
          self._batch_size = batch_size
-         self.num_retries = num_retries
-         yield
+         with self.retry(self._num_retries or num_retries):
+             yield
          self._batch_size = None
-         self.num_retries = 0

      @overload
      def predict(
          self,
          query: str,
-         indices: Union[List[str], List[float], List[int], None] = None,
+         indices: list[str] | list[float] | list[int] | None = None,
          *,
          explain: Literal[False] = False,
-         anchor_time: Union[pd.Timestamp, Literal['entity'], None] = None,
-         context_anchor_time: Union[pd.Timestamp, None] = None,
-         run_mode: Union[RunMode, str] = RunMode.FAST,
-         num_neighbors: Optional[List[int]] = None,
+         anchor_time: pd.Timestamp | Literal['entity'] | None = None,
+         context_anchor_time: pd.Timestamp | None = None,
+         run_mode: RunMode | str = RunMode.FAST,
+         num_neighbors: list[int] | None = None,
          num_hops: int = 2,
-         max_pq_iterations: int = 20,
-         random_seed: Optional[int] = _RANDOM_SEED,
-         verbose: Union[bool, ProgressLogger] = True,
+         max_pq_iterations: int = 10,
+         random_seed: int | None = _RANDOM_SEED,
+         verbose: bool | ProgressLogger = True,
          use_prediction_time: bool = False,
      ) -> pd.DataFrame:
          pass
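The `@overload` stubs (the `explain=False` one above, plus the `explain=True` and catch-all stubs in the next hunk) encode the return-type contract for type checkers: a literal `explain=False` yields a `pd.DataFrame`, a truthy `explain` yields an `Explanation`, and the new catch-all stub covers `explain` values only known as `bool` statically. In practice:

    df = model.predict(query, indices=[1, 2, 3])            # -> pd.DataFrame
    exp = model.predict(query, indices=[1], explain=True)   # -> Explanation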
@@ -236,37 +281,56 @@ class KumoRFM:
      def predict(
          self,
          query: str,
-         indices: Union[List[str], List[float], List[int], None] = None,
+         indices: list[str] | list[float] | list[int] | None = None,
          *,
-         explain: Union[Literal[True], ExplainConfig, Dict[str, Any]],
-         anchor_time: Union[pd.Timestamp, Literal['entity'], None] = None,
-         context_anchor_time: Union[pd.Timestamp, None] = None,
-         run_mode: Union[RunMode, str] = RunMode.FAST,
-         num_neighbors: Optional[List[int]] = None,
+         explain: Literal[True] | ExplainConfig | dict[str, Any],
+         anchor_time: pd.Timestamp | Literal['entity'] | None = None,
+         context_anchor_time: pd.Timestamp | None = None,
+         run_mode: RunMode | str = RunMode.FAST,
+         num_neighbors: list[int] | None = None,
          num_hops: int = 2,
-         max_pq_iterations: int = 20,
-         random_seed: Optional[int] = _RANDOM_SEED,
-         verbose: Union[bool, ProgressLogger] = True,
+         max_pq_iterations: int = 10,
+         random_seed: int | None = _RANDOM_SEED,
+         verbose: bool | ProgressLogger = True,
          use_prediction_time: bool = False,
      ) -> Explanation:
          pass

+     @overload
+     def predict(
+         self,
+         query: str,
+         indices: list[str] | list[float] | list[int] | None = None,
+         *,
+         explain: bool | ExplainConfig | dict[str, Any] = False,
+         anchor_time: pd.Timestamp | Literal['entity'] | None = None,
+         context_anchor_time: pd.Timestamp | None = None,
+         run_mode: RunMode | str = RunMode.FAST,
+         num_neighbors: list[int] | None = None,
+         num_hops: int = 2,
+         max_pq_iterations: int = 10,
+         random_seed: int | None = _RANDOM_SEED,
+         verbose: bool | ProgressLogger = True,
+         use_prediction_time: bool = False,
+     ) -> pd.DataFrame | Explanation:
+         pass
+
      def predict(
          self,
          query: str,
-         indices: Union[List[str], List[float], List[int], None] = None,
+         indices: list[str] | list[float] | list[int] | None = None,
          *,
-         explain: Union[bool, ExplainConfig, Dict[str, Any]] = False,
-         anchor_time: Union[pd.Timestamp, Literal['entity'], None] = None,
-         context_anchor_time: Union[pd.Timestamp, None] = None,
-         run_mode: Union[RunMode, str] = RunMode.FAST,
-         num_neighbors: Optional[List[int]] = None,
+         explain: bool | ExplainConfig | dict[str, Any] = False,
+         anchor_time: pd.Timestamp | Literal['entity'] | None = None,
+         context_anchor_time: pd.Timestamp | None = None,
+         run_mode: RunMode | str = RunMode.FAST,
+         num_neighbors: list[int] | None = None,
          num_hops: int = 2,
-         max_pq_iterations: int = 20,
-         random_seed: Optional[int] = _RANDOM_SEED,
-         verbose: Union[bool, ProgressLogger] = True,
+         max_pq_iterations: int = 10,
+         random_seed: int | None = _RANDOM_SEED,
+         verbose: bool | ProgressLogger = True,
          use_prediction_time: bool = False,
-     ) -> Union[pd.DataFrame, Explanation]:
+     ) -> pd.DataFrame | Explanation:
          """Returns predictions for a predictive query.

          Args:
@@ -274,8 +338,7 @@ class KumoRFM:
              indices: The entity primary keys to predict for. Will override the
                  indices given as part of the predictive query. Predictions will
                  be generated for all indices, independent of whether they
-                 fulfill entity filter constraints. To pre-filter entities, use
-                 :meth:`~KumoRFM.is_valid_entity`.
+                 fulfill entity filter constraints.
              explain: Configuration for explainability.
                  If set to ``True``, will additionally explain the prediction.
                  Passing in an :class:`ExplainConfig` instance provides control
@@ -308,18 +371,152 @@ class KumoRFM:
              If ``explain`` is provided, returns an :class:`Explanation` object
              containing the prediction, summary, and details.
          """
-         explain_config: Optional[ExplainConfig] = None
-         if explain is True:
-             explain_config = ExplainConfig()
-         elif explain is not False:
-             explain_config = ExplainConfig._cast(explain)
-
          query_def = self._parse_query(query)
-         query_str = query_def.to_string()

+         if indices is None:
+             if query_def.rfm_entity_ids is None:
+                 raise ValueError("Cannot find entities to predict for. Please "
+                                  "pass them via `predict(query, indices=...)`")
+             indices = query_def.get_rfm_entity_id_list()
+         query_def = replace(
+             query_def,
+             for_each='FOR EACH',
+             rfm_entity_ids=None,
+         )
+
+         if not isinstance(verbose, ProgressLogger):
+             query_repr = query_def.to_string(rich=True, exclude_predict=True)
+             if explain is not False:
+                 msg = f'[bold]EXPLAIN[/bold] {query_repr}'
+             else:
+                 msg = f'[bold]PREDICT[/bold] {query_repr}'
+             verbose = ProgressLogger.default(msg=msg, verbose=verbose)
+
+         with verbose as logger:
+             task_table = self._get_task_table(
+                 query=query_def,
+                 indices=indices,
+                 anchor_time=anchor_time,
+                 context_anchor_time=context_anchor_time,
+                 run_mode=run_mode,
+                 max_pq_iterations=max_pq_iterations,
+                 random_seed=random_seed,
+                 logger=logger,
+             )
+             task_table._query = query_def.to_string()
+
+             return self.predict_task(
+                 task_table,
+                 explain=explain,
+                 run_mode=run_mode,
+                 num_neighbors=num_neighbors,
+                 num_hops=num_hops,
+                 verbose=verbose,
+                 exclude_cols_dict=query_def.get_exclude_cols_dict(),
+                 use_prediction_time=use_prediction_time,
+                 top_k=query_def.top_k,
+             )
+
+     @overload
+     def predict_task(
+         self,
+         task: TaskTable,
+         *,
+         explain: Literal[False] = False,
+         run_mode: RunMode | str = RunMode.FAST,
+         num_neighbors: list[int] | None = None,
+         num_hops: int = 2,
+         verbose: bool | ProgressLogger = True,
+         exclude_cols_dict: dict[str, list[str]] | None = None,
+         use_prediction_time: bool = False,
+         top_k: int | None = None,
+     ) -> pd.DataFrame:
+         pass
+
+     @overload
+     def predict_task(
+         self,
+         task: TaskTable,
+         *,
+         explain: Literal[True] | ExplainConfig | dict[str, Any],
+         run_mode: RunMode | str = RunMode.FAST,
+         num_neighbors: list[int] | None = None,
+         num_hops: int = 2,
+         verbose: bool | ProgressLogger = True,
+         exclude_cols_dict: dict[str, list[str]] | None = None,
+         use_prediction_time: bool = False,
+         top_k: int | None = None,
+     ) -> Explanation:
+         pass
+
+     @overload
+     def predict_task(
+         self,
+         task: TaskTable,
+         *,
+         explain: bool | ExplainConfig | dict[str, Any] = False,
+         run_mode: RunMode | str = RunMode.FAST,
+         num_neighbors: list[int] | None = None,
+         num_hops: int = 2,
+         verbose: bool | ProgressLogger = True,
+         exclude_cols_dict: dict[str, list[str]] | None = None,
+         use_prediction_time: bool = False,
+         top_k: int | None = None,
+     ) -> pd.DataFrame | Explanation:
+         pass
+
+     def predict_task(
+         self,
+         task: TaskTable,
+         *,
+         explain: bool | ExplainConfig | dict[str, Any] = False,
+         run_mode: RunMode | str = RunMode.FAST,
+         num_neighbors: list[int] | None = None,
+         num_hops: int = 2,
+         verbose: bool | ProgressLogger = True,
+         exclude_cols_dict: dict[str, list[str]] | None = None,
+         use_prediction_time: bool = False,
+         top_k: int | None = None,
+     ) -> pd.DataFrame | Explanation:
+         """Returns predictions for a custom task specification.
+
+         Args:
+             task: The custom :class:`TaskTable`.
+             explain: Configuration for explainability.
+                 If set to ``True``, will additionally explain the prediction.
+                 Passing in an :class:`ExplainConfig` instance provides control
+                 over which parts of explanation are generated.
+                 Explainability is currently only supported for single entity
+                 predictions with ``run_mode="FAST"``.
+             run_mode: The :class:`RunMode` for the query.
+             num_neighbors: The number of neighbors to sample for each hop.
+                 If specified, the ``num_hops`` option will be ignored.
+             num_hops: The number of hops to sample when generating the context.
+             verbose: Whether to print verbose output.
+             exclude_cols_dict: Any column in any table to exclude from the
+                 model input.
+             use_prediction_time: Whether to use the anchor timestamp as an
+                 additional feature during prediction. This is typically
+                 beneficial for time series forecasting tasks.
+             top_k: The number of predictions to return per entity.
+
+         Returns:
+             The predictions as a :class:`pandas.DataFrame`.
+             If ``explain`` is provided, returns an :class:`Explanation` object
+             containing the prediction, summary, and details.
+         """
          if num_hops != 2 and num_neighbors is not None:
              warnings.warn(f"Received custom 'num_neighbors' option; ignoring "
                            f"custom 'num_hops={num_hops}' option")
+         if num_neighbors is None:
+             key = RunMode.FAST if task.task_type.is_link_pred else run_mode
+             num_neighbors = _DEFAULT_NUM_NEIGHBORS[key][:num_hops]
+
+         explain_config: ExplainConfig | None = None
+         if explain is True:
+             explain_config = ExplainConfig()
+         elif explain is not False:
+             explain_config = ExplainConfig._cast(explain)

          if explain_config is not None and run_mode in {
              RunMode.NORMAL, RunMode.BEST
@@ -328,83 +525,82 @@ class KumoRFM:
                            f"run mode 'FAST' (got '{run_mode}'). Provided run "
                            f"mode has been reset. Please lower the run mode to "
                            f"suppress this warning.")
+             run_mode = RunMode.FAST

-         if indices is None:
-             if query_def.rfm_entity_ids is None:
-                 raise ValueError("Cannot find entities to predict for. Please "
-                                  "pass them via `predict(query, indices=...)`")
-             indices = query_def.get_rfm_entity_id_list()
-         else:
-             query_def = replace(query_def, rfm_entity_ids=None)
-
-         if len(indices) == 0:
-             raise ValueError("At least one entity is required")
-
-         if explain_config is not None and len(indices) > 1:
-             raise ValueError(
-                 f"Cannot explain predictions for more than a single entity "
-                 f"(got {len(indices)})")
-
-         query_repr = query_def.to_string(rich=True, exclude_predict=True)
-         if explain_config is not None:
-             msg = f'[bold]EXPLAIN[/bold] {query_repr}'
-         else:
-             msg = f'[bold]PREDICT[/bold] {query_repr}'
+         if explain_config is not None and task.num_prediction_examples > 1:
+             raise ValueError(f"Cannot explain predictions for more than a "
+                              f"single entity "
+                              f"(got {task.num_prediction_examples:,})")

          if not isinstance(verbose, ProgressLogger):
-             verbose = InteractiveProgressLogger(msg, verbose=verbose)
-
-         with verbose as logger:
-
-             batch_size: Optional[int] = None
-             if self._batch_size == 'max':
-                 task_type = LocalPQueryDriver.get_task_type(
-                     query_def,
-                     edge_types=self._graph_store.edge_types,
-                 )
-                 batch_size = _MAX_PRED_SIZE[task_type]
+             if task.task_type == TaskType.BINARY_CLASSIFICATION:
+                 task_type_repr = 'binary classification'
+             elif task.task_type == TaskType.MULTICLASS_CLASSIFICATION:
+                 task_type_repr = 'multi-class classification'
+             elif task.task_type == TaskType.REGRESSION:
+                 task_type_repr = 'regression'
+             elif task.task_type == TaskType.TEMPORAL_LINK_PREDICTION:
+                 task_type_repr = 'link prediction'
              else:
-                 batch_size = self._batch_size
+                 task_type_repr = str(task.task_type)

-             if batch_size is not None:
-                 offsets = range(0, len(indices), batch_size)
-                 batches = [indices[step:step + batch_size] for step in offsets]
+             if explain_config is not None:
+                 msg = f"Explaining {task_type_repr} task"
              else:
-                 batches = [indices]
+                 msg = f"Predicting {task_type_repr} task"
+             verbose = ProgressLogger.default(msg=msg, verbose=verbose)

-             if len(batches) > 1:
-                 logger.log(f"Splitting {len(indices):,} entities into "
-                            f"{len(batches):,} batches of size {batch_size:,}")
+         with verbose as logger:
+             if task.num_context_examples > _MAX_CONTEXT_SIZE[run_mode]:
+                 logger.log(f"Sub-sampled {_MAX_CONTEXT_SIZE[run_mode]:,} "
+                            f"out of {task.num_context_examples:,} in-context "
+                            f"examples")
+                 task = task.narrow_context(0, _MAX_CONTEXT_SIZE[run_mode])
+
+             if self._batch_size is None:
+                 batch_size = task.num_prediction_examples
+             elif self._batch_size == 'max':
+                 batch_size = _MAX_PRED_SIZE[task.task_type]
+             else:
+                 batch_size = self._batch_size

-             predictions: List[pd.DataFrame] = []
-             summary: Optional[str] = None
-             details: Optional[Explanation] = None
-             for i, batch in enumerate(batches):
-                 # TODO Re-use the context for subsequent predictions.
+             if batch_size > _MAX_PRED_SIZE[task.task_type]:
+                 raise ValueError(f"Cannot predict for more than "
+                                  f"{_MAX_PRED_SIZE[task.task_type]:,} "
+                                  f"entities at once (got {batch_size:,}). Use "
+                                  f"`KumoRFM.batch_mode` to process entities "
+                                  f"in batches with a sufficient batch size.")
+
+             if task.num_prediction_examples > batch_size:
+                 num = math.ceil(task.num_prediction_examples / batch_size)
+                 logger.log(f"Splitting {task.num_prediction_examples:,} "
+                            f"entities into {num:,} batches of size "
+                            f"{batch_size:,}")
+
+             predictions: list[pd.DataFrame] = []
+             summary: str | None = None
+             details: Explanation | None = None
+             for start in range(0, task.num_prediction_examples, batch_size):
                  context = self._get_context(
-                     query=query_def,
-                     indices=batch,
-                     anchor_time=anchor_time,
-                     context_anchor_time=context_anchor_time,
-                     run_mode=RunMode(run_mode),
+                     task=task.narrow_prediction(start, length=batch_size),
+                     run_mode=run_mode,
                      num_neighbors=num_neighbors,
-                     num_hops=num_hops,
-                     max_pq_iterations=max_pq_iterations,
-                     evaluate=False,
-                     random_seed=random_seed,
-                     logger=logger if i == 0 else None,
+                     exclude_cols_dict=exclude_cols_dict,
+                     top_k=top_k,
                  )
+                 context.y_test = None
+
                  request = RFMPredictRequest(
                      context=context,
                      run_mode=RunMode(run_mode),
-                     query=query_str,
+                     query=task._query,
                      use_prediction_time=use_prediction_time,
                  )
                  with warnings.catch_warnings():
                      warnings.filterwarnings('ignore', message='gencode')
                      request_msg = request.to_protobuf()
                  _bytes = request_msg.SerializeToString()
-                 if i == 0:
+                 if start == 0:
                      logger.log(f"Generated context of size "
                                 f"{len(_bytes) / (1024*1024):.2f}MB")

@@ -412,14 +608,11 @@ class KumoRFM:
                      stats = Context.get_memory_stats(request_msg.context)
                      raise ValueError(_SIZE_LIMIT_MSG.format(stats=stats))

-                 if (isinstance(verbose, InteractiveProgressLogger) and i == 0
-                         and len(batches) > 1):
-                     verbose.init_progress(
-                         total=len(batches),
-                         description='Predicting',
-                     )
+                 if start == 0 and task.num_prediction_examples > batch_size:
+                     num = math.ceil(task.num_prediction_examples / batch_size)
+                     verbose.init_progress(total=num, description='Predicting')

-                 for attempt in range(self.num_retries + 1):
+                 for attempt in range(self._num_retries + 1):
                      try:
                          if explain_config is not None:
                              resp = self._api_client.explain(
@@ -434,10 +627,10 @@ class KumoRFM:

                          # Cast 'ENTITY' to correct data type:
                          if 'ENTITY' in df:
-                             entity = query_def.entity_table
-                             pkey_map = self._graph_store.pkey_map_dict[entity]
-                             df['ENTITY'] = df['ENTITY'].astype(
-                                 type(pkey_map.index[0]))
+                             table_dict = context.subgraph.table_dict
+                             table = table_dict[context.entity_table_names[0]]
+                             ser = table.df[table.primary_key]
+                             df['ENTITY'] = df['ENTITY'].astype(ser.dtype)

                          # Cast 'ANCHOR_TIMESTAMP' to correct data type:
                          if 'ANCHOR_TIMESTAMP' in df:
@@ -452,13 +645,12 @@ class KumoRFM:

                          predictions.append(df)

-                         if (isinstance(verbose, InteractiveProgressLogger)
-                                 and len(batches) > 1):
+                         if task.num_prediction_examples > batch_size:
                              verbose.step()

                          break
                      except HTTPException as e:
-                         if attempt == self.num_retries:
+                         if attempt == self._num_retries:
                              try:
                                  msg = json.loads(e.detail)['detail']
                              except Exception:
@@ -488,69 +680,19 @@ class KumoRFM:

          return prediction

-     def is_valid_entity(
-         self,
-         query: str,
-         indices: Union[List[str], List[float], List[int], None] = None,
-         *,
-         anchor_time: Union[pd.Timestamp, Literal['entity'], None] = None,
-     ) -> np.ndarray:
-         r"""Returns a mask that denotes which entities are valid for the
-         given predictive query, *i.e.*, which entities fulfill (temporal)
-         entity filter constraints.
-
-         Args:
-             query: The predictive query.
-             indices: The entity primary keys to predict for. Will override the
-                 indices given as part of the predictive query.
-             anchor_time: The anchor timestamp for the prediction. If set to
-                 ``None``, will use the maximum timestamp in the data.
-                 If set to ``"entity"``, will use the timestamp of the entity.
-         """
-         query_def = self._parse_query(query)
-
-         if indices is None:
-             if query_def.rfm_entity_ids is None:
-                 raise ValueError("Cannot find entities to predict for. Please "
-                                  "pass them via "
-                                  "`is_valid_entity(query, indices=...)`")
-             indices = query_def.get_rfm_entity_id_list()
-
-         if len(indices) == 0:
-             raise ValueError("At least one entity is required")
-
-         if anchor_time is None:
-             anchor_time = self._graph_store.max_time
-
-         if isinstance(anchor_time, pd.Timestamp):
-             self._validate_time(query_def, anchor_time, None, False)
-         else:
-             assert anchor_time == 'entity'
-             if (query_def.entity_table not in self._graph_store.time_dict):
-                 raise ValueError(f"Anchor time 'entity' requires the entity "
-                                  f"table '{query_def.entity_table}' "
-                                  f"to have a time column.")
-
-         node = self._graph_store.get_node_id(
-             table_name=query_def.entity_table,
-             pkey=pd.Series(indices),
-         )
-         query_driver = LocalPQueryDriver(self._graph_store, query_def)
-         return query_driver.is_valid(node, anchor_time)
-
      def evaluate(
          self,
          query: str,
          *,
-         metrics: Optional[List[str]] = None,
-         anchor_time: Union[pd.Timestamp, Literal['entity'], None] = None,
-         context_anchor_time: Union[pd.Timestamp, None] = None,
-         run_mode: Union[RunMode, str] = RunMode.FAST,
-         num_neighbors: Optional[List[int]] = None,
+         metrics: list[str] | None = None,
+         anchor_time: pd.Timestamp | Literal['entity'] | None = None,
+         context_anchor_time: pd.Timestamp | None = None,
+         run_mode: RunMode | str = RunMode.FAST,
+         num_neighbors: list[int] | None = None,
          num_hops: int = 2,
-         max_pq_iterations: int = 20,
-         random_seed: Optional[int] = _RANDOM_SEED,
-         verbose: Union[bool, ProgressLogger] = True,
+         max_pq_iterations: int = 10,
+         random_seed: int | None = _RANDOM_SEED,
+         verbose: bool | ProgressLogger = True,
          use_prediction_time: bool = False,
      ) -> pd.DataFrame:
          """Evaluates a predictive query.
@@ -582,41 +724,120 @@ class KumoRFM:
          Returns:
              The metrics as a :class:`pandas.DataFrame`
          """
-         query_def = self._parse_query(query)
+         query_def = replace(
+             self._parse_query(query),
+             for_each='FOR EACH',
+             rfm_entity_ids=None,
+         )
+
+         if not isinstance(verbose, ProgressLogger):
+             query_repr = query_def.to_string(rich=True, exclude_predict=True)
+             msg = f'[bold]EVALUATE[/bold] {query_repr}'
+             verbose = ProgressLogger.default(msg=msg, verbose=verbose)
+
+         with verbose as logger:
+             task_table = self._get_task_table(
+                 query=query_def,
+                 indices=None,
+                 anchor_time=anchor_time,
+                 context_anchor_time=context_anchor_time,
+                 run_mode=run_mode,
+                 max_pq_iterations=max_pq_iterations,
+                 random_seed=random_seed,
+                 logger=logger,
+             )
+
+             return self.evaluate_task(
+                 task_table,
+                 metrics=metrics,
+                 run_mode=run_mode,
+                 num_neighbors=num_neighbors,
+                 num_hops=num_hops,
+                 verbose=verbose,
+                 exclude_cols_dict=query_def.get_exclude_cols_dict(),
+                 use_prediction_time=use_prediction_time,
+             )
+
+     def evaluate_task(
+         self,
+         task: TaskTable,
+         *,
+         metrics: list[str] | None = None,
+         run_mode: RunMode | str = RunMode.FAST,
+         num_neighbors: list[int] | None = None,
+         num_hops: int = 2,
+         verbose: bool | ProgressLogger = True,
+         exclude_cols_dict: dict[str, list[str]] | None = None,
+         use_prediction_time: bool = False,
+     ) -> pd.DataFrame:
+         """Evaluates a custom task specification.

+         Args:
+             task: The custom :class:`TaskTable`.
+             metrics: The metrics to use.
+             run_mode: The :class:`RunMode` for the query.
+             num_neighbors: The number of neighbors to sample for each hop.
+                 If specified, the ``num_hops`` option will be ignored.
+             num_hops: The number of hops to sample when generating the context.
+             verbose: Whether to print verbose output.
+             exclude_cols_dict: Any column in any table to exclude from the
+                 model input.
+             use_prediction_time: Whether to use the anchor timestamp as an
+                 additional feature during prediction. This is typically
+                 beneficial for time series forecasting tasks.
+
+         Returns:
+             The metrics as a :class:`pandas.DataFrame`
+         """
          if num_hops != 2 and num_neighbors is not None:
              warnings.warn(f"Received custom 'num_neighbors' option; ignoring "
                            f"custom 'num_hops={num_hops}' option")
+         if num_neighbors is None:
+             key = RunMode.FAST if task.task_type.is_link_pred else run_mode
+             num_neighbors = _DEFAULT_NUM_NEIGHBORS[key][:num_hops]

-         if query_def.rfm_entity_ids is not None:
-             query_def = replace(
-                 query_def,
-                 rfm_entity_ids=None,
-             )
-
-         query_repr = query_def.to_string(rich=True, exclude_predict=True)
-         msg = f'[bold]EVALUATE[/bold] {query_repr}'
+         if metrics is not None and len(metrics) > 0:
+             self._validate_metrics(metrics, task.task_type)
+             metrics = list(dict.fromkeys(metrics))

          if not isinstance(verbose, ProgressLogger):
-             verbose = InteractiveProgressLogger(msg, verbose=verbose)
+             if task.task_type == TaskType.BINARY_CLASSIFICATION:
+                 task_type_repr = 'binary classification'
+             elif task.task_type == TaskType.MULTICLASS_CLASSIFICATION:
+                 task_type_repr = 'multi-class classification'
+             elif task.task_type == TaskType.REGRESSION:
+                 task_type_repr = 'regression'
+             elif task.task_type == TaskType.TEMPORAL_LINK_PREDICTION:
+                 task_type_repr = 'link prediction'
+             else:
+                 task_type_repr = str(task.task_type)
+
+             msg = f"Evaluating {task_type_repr} task"
+             verbose = ProgressLogger.default(msg=msg, verbose=verbose)

          with verbose as logger:
+             if task.num_context_examples > _MAX_CONTEXT_SIZE[run_mode]:
+                 logger.log(f"Sub-sampled {_MAX_CONTEXT_SIZE[run_mode]:,} "
+                            f"out of {task.num_context_examples:,} in-context "
+                            f"examples")
+                 task = task.narrow_context(0, _MAX_CONTEXT_SIZE[run_mode])
+
+             if task.num_prediction_examples > _MAX_TEST_SIZE[task.task_type]:
+                 logger.log(f"Sub-sampled {_MAX_TEST_SIZE[task.task_type]:,} "
+                            f"out of {task.num_prediction_examples:,} test "
+                            f"examples")
+                 task = task.narrow_prediction(
+                     start=0,
+                     length=_MAX_TEST_SIZE[task.task_type],
+                 )
+
              context = self._get_context(
-                 query=query_def,
-                 indices=None,
-                 anchor_time=anchor_time,
-                 context_anchor_time=context_anchor_time,
-                 run_mode=RunMode(run_mode),
+                 task=task,
+                 run_mode=run_mode,
                  num_neighbors=num_neighbors,
-                 num_hops=num_hops,
-                 max_pq_iterations=max_pq_iterations,
-                 evaluate=True,
-                 random_seed=random_seed,
-                 logger=logger if verbose else None,
+                 exclude_cols_dict=exclude_cols_dict,
              )
-             if metrics is not None and len(metrics) > 0:
-                 self._validate_metrics(metrics, context.task_type)
-                 metrics = list(dict.fromkeys(metrics))
+
              request = RFMEvaluateRequest(
                  context=context,
                  run_mode=RunMode(run_mode),
@@ -634,17 +855,23 @@ class KumoRFM:
                  stats_msg = Context.get_memory_stats(request_msg.context)
                  raise ValueError(_SIZE_LIMIT_MSG.format(stats=stats_msg))

-             try:
-                 resp = self._api_client.evaluate(request_bytes)
-             except HTTPException as e:
+             for attempt in range(self._num_retries + 1):
                  try:
-                     msg = json.loads(e.detail)['detail']
-                 except Exception:
-                     msg = e.detail
-                 raise RuntimeError(f"An unexpected exception occurred. "
-                                    f"Please create an issue at "
-                                    f"'https://github.com/kumo-ai/kumo-rfm'. "
-                                    f"{msg}") from None
+                     resp = self._api_client.evaluate(request_bytes)
+                     break
+                 except HTTPException as e:
+                     if attempt == self._num_retries:
+                         try:
+                             msg = json.loads(e.detail)['detail']
+                         except Exception:
+                             msg = e.detail
+                         raise RuntimeError(
+                             f"An unexpected exception occurred. Please create "
+                             f"an issue at "
+                             f"'https://github.com/kumo-ai/kumo-rfm'. {msg}"
+                         ) from None
+
+                     time.sleep(2**attempt)  # 1s, 2s, 4s, 8s, ...

              return pd.DataFrame.from_dict(
                  resp.metrics,
@@ -657,9 +884,9 @@ class KumoRFM:
          query: str,
          size: int,
          *,
-         anchor_time: Union[pd.Timestamp, Literal['entity'], None] = None,
-         random_seed: Optional[int] = _RANDOM_SEED,
-         max_iterations: int = 20,
+         anchor_time: pd.Timestamp | Literal['entity'] | None = None,
+         random_seed: int | None = _RANDOM_SEED,
+         max_iterations: int = 10,
      ) -> pd.DataFrame:
          """Returns the labels of a predictive query for a specified anchor
          time.
@@ -679,40 +906,37 @@ class KumoRFM:
          query_def = self._parse_query(query)

          if anchor_time is None:
-             anchor_time = self._graph_store.max_time
+             anchor_time = self._get_default_anchor_time(query_def)
              if query_def.target_ast.date_offset_range is not None:
-                 anchor_time = anchor_time - (
-                     query_def.target_ast.date_offset_range.end_date_offset *
-                     query_def.num_forecasts)
+                 offset = query_def.target_ast.date_offset_range.end_date_offset
+                 offset *= query_def.num_forecasts
+                 anchor_time -= offset

          assert anchor_time is not None
          if isinstance(anchor_time, pd.Timestamp):
              self._validate_time(query_def, anchor_time, None, evaluate=True)
          else:
              assert anchor_time == 'entity'
-             if (query_def.entity_table not in self._graph_store.time_dict):
+             if query_def.entity_table not in self._sampler.time_column_dict:
                  raise ValueError(f"Anchor time 'entity' requires the entity "
                                   f"table '{query_def.entity_table}' "
                                   f"to have a time column")

-         query_driver = LocalPQueryDriver(self._graph_store, query_def,
-                                          random_seed)
-
-         node, time, y = query_driver.collect_test(
-             size=size,
-             anchor_time=anchor_time,
-             batch_size=min(10_000, size),
-             max_iterations=max_iterations,
-             guarantee_train_examples=False,
+         train, test = self._sampler.sample_target(
+             query=query_def,
+             num_train_examples=0,
+             train_anchor_time=anchor_time,
+             num_train_trials=0,
+             num_test_examples=size,
+             test_anchor_time=anchor_time,
+             num_test_trials=max_iterations * size,
+             random_seed=random_seed,
          )

-         entity = self._graph_store.pkey_map_dict[
-             query_def.entity_table].index[node]
-
          return pd.DataFrame({
-             'ENTITY': entity,
-             'ANCHOR_TIMESTAMP': time,
-             'TARGET': y,
+             'ENTITY': test.entity_pkey,
+             'ANCHOR_TIMESTAMP': test.anchor_time,
+             'TARGET': test.target,
          })

      # Helpers #################################################################
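The label-collection helper above now routes through the backend-agnostic `Sampler.sample_target()` (zero train examples, `size` test examples) instead of `LocalPQueryDriver.collect_test()`, and the named `test` split exposes `entity_pkey`, `anchor_time`, and `target`. The returned frame keeps its previous shape; a toy sketch of consuming it, with made-up values:

    import pandas as pd

    # Shape of the frame returned by the helper above (toy values):
    labels_df = pd.DataFrame({
        'ENTITY': [1, 2, 3],
        'ANCHOR_TIMESTAMP': pd.to_datetime(['2024-01-01'] * 3),
        'TARGET': [0, 1, 1],
    })

    positive_rate = (labels_df['TARGET'] > 0).mean()  # 2/3 here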
@@ -727,63 +951,128 @@ class KumoRFM:
                               "`predict()` or `evaluate()` methods to perform "
                               "predictions or evaluations.")

-         try:
-             request = RFMParseQueryRequest(
-                 query=query,
-                 graph_definition=self._graph_def,
-             )
+         request = RFMParseQueryRequest(
+             query=query,
+             graph_definition=self._graph_def,
+         )

-             resp = self._api_client.parse_query(request)
+         for attempt in range(self._num_retries + 1):
+             try:
+                 resp = self._api_client.parse_query(request)
+                 break
+             except HTTPException as e:
+                 if attempt == self._num_retries:
+                     try:
+                         msg = json.loads(e.detail)['detail']
+                     except Exception:
+                         msg = e.detail
+                     raise ValueError(f"Failed to parse query '{query}'. {msg}")

-             # TODO Expose validation warnings.
+                 time.sleep(2**attempt)  # 1s, 2s, 4s, 8s, ...

-             if len(resp.validation_response.warnings) > 0:
-                 msg = '\n'.join([
-                     f'{i+1}. {warning.title}: {warning.message}' for i, warning
-                     in enumerate(resp.validation_response.warnings)
-                 ])
-                 warnings.warn(f"Encountered the following warnings during "
-                               f"parsing:\n{msg}")
+         if len(resp.validation_response.warnings) > 0:
+             msg = '\n'.join([
+                 f'{i+1}. {warning.title}: {warning.message}'
+                 for i, warning in enumerate(resp.validation_response.warnings)
+             ])
+             warnings.warn(f"Encountered the following warnings during "
+                           f"parsing:\n{msg}")

-             return resp.query
-         except HTTPException as e:
-             try:
-                 msg = json.loads(e.detail)['detail']
-             except Exception:
-                 msg = e.detail
-             raise ValueError(f"Failed to parse query '{query}'. "
-                              f"{msg}") from None
+         return resp.query
+
+     @staticmethod
+     def _get_task_type(
+         query: ValidatedPredictiveQuery,
+         edge_types: list[tuple[str, str, str]],
+     ) -> TaskType:
+         if isinstance(query.target_ast, (Condition, LogicalOperation)):
+             return TaskType.BINARY_CLASSIFICATION
+
+         target = query.target_ast
+         if isinstance(target, Join):
+             target = target.rhs_target
+         if isinstance(target, Aggregation):
+             if target.aggr == AggregationType.LIST_DISTINCT:
+                 table_name, col_name = target._get_target_column_name().split(
+                     '.')
+                 target_edge_types = [
+                     edge_type for edge_type in edge_types
+                     if edge_type[0] == table_name and edge_type[1] == col_name
+                 ]
+                 if len(target_edge_types) != 1:
+                     raise NotImplementedError(
+                         f"Multilabel-classification queries based on "
+                         f"'LIST_DISTINCT' are not supported yet. If you "
+                         f"planned to write a link prediction query instead, "
+                         f"make sure to register '{col_name}' as a "
+                         f"foreign key.")
+                 return TaskType.TEMPORAL_LINK_PREDICTION
+
+             return TaskType.REGRESSION
+
+         assert isinstance(target, Column)
+
+         if target.stype in {Stype.ID, Stype.categorical}:
+             return TaskType.MULTICLASS_CLASSIFICATION
+
+         if target.stype in {Stype.numerical}:
+             return TaskType.REGRESSION
+
+         raise NotImplementedError("Task type not yet supported")
+
+     def _get_default_anchor_time(
+         self,
+         query: ValidatedPredictiveQuery | None = None,
+     ) -> pd.Timestamp:
+         if query is not None and query.query_type == QueryType.TEMPORAL:
+             aggr_table_names = [
+                 aggr._get_target_column_name().split('.')[0]
+                 for aggr in query.get_all_target_aggregations()
+             ]
+             return self._sampler.get_max_time(aggr_table_names)
+
+         return self._sampler.get_max_time()

      def _validate_time(
          self,
          query: ValidatedPredictiveQuery,
          anchor_time: pd.Timestamp,
-         context_anchor_time: Union[pd.Timestamp, None],
+         context_anchor_time: pd.Timestamp | None,
          evaluate: bool,
      ) -> None:

-         if self._graph_store.min_time == pd.Timestamp.max:
+         if len(self._sampler.time_column_dict) == 0:
              return  # Graph without timestamps

-         if anchor_time < self._graph_store.min_time:
+         if query.query_type == QueryType.TEMPORAL:
+             aggr_table_names = [
+                 aggr._get_target_column_name().split('.')[0]
+                 for aggr in query.get_all_target_aggregations()
+             ]
+             min_time = self._sampler.get_min_time(aggr_table_names)
+             max_time = self._sampler.get_max_time(aggr_table_names)
+         else:
+             min_time = self._sampler.get_min_time()
+             max_time = self._sampler.get_max_time()
+
+         if anchor_time < min_time:
              raise ValueError(f"Anchor timestamp '{anchor_time}' is before "
-                              f"the earliest timestamp "
-                              f"'{self._graph_store.min_time}' in the data.")
+                              f"the earliest timestamp '{min_time}' in the "
+                              f"data.")

-         if (context_anchor_time is not None
-                 and context_anchor_time < self._graph_store.min_time):
+         if context_anchor_time is not None and context_anchor_time < min_time:
              raise ValueError(f"Context anchor timestamp is too early or "
                               f"aggregation time range is too large. To make "
                               f"this prediction, we would need data back to "
                               f"'{context_anchor_time}', however, your data "
-                              f"only contains data back to "
-                              f"'{self._graph_store.min_time}'.")
+                              f"only contains data back to '{min_time}'.")

          if query.target_ast.date_offset_range is not None:
              end_offset = query.target_ast.date_offset_range.end_date_offset
          else:
              end_offset = pd.DateOffset(0)
-         forecast_end_offset = end_offset * query.num_forecasts
+         end_offset = end_offset * query.num_forecasts
+
          if (context_anchor_time is not None
                  and context_anchor_time > anchor_time):
              warnings.warn(f"Context anchor timestamp "
@@ -793,7 +1082,7 @@ class KumoRFM:
                            f"intended.")
          elif (query.query_type == QueryType.TEMPORAL
                and context_anchor_time is not None
-               and context_anchor_time + forecast_end_offset > anchor_time):
+               and context_anchor_time + end_offset > anchor_time):
              warnings.warn(f"Aggregation for context examples at timestamp "
                            f"'{context_anchor_time}' will leak information "
                            f"from the prediction anchor timestamp "
@@ -801,62 +1090,44 @@ class KumoRFM:
                            f"intended.")

          elif (context_anchor_time is not None
-               and context_anchor_time - forecast_end_offset
-               < self._graph_store.min_time):
-             _time = context_anchor_time - forecast_end_offset
+               and context_anchor_time - end_offset < min_time):
+             _time = context_anchor_time - end_offset
              warnings.warn(f"Context anchor timestamp is too early or "
                            f"aggregation time range is too large. To form "
                            f"proper input data, we would need data back to "
                            f"'{_time}', however, your data only contains "
-                           f"data back to '{self._graph_store.min_time}'.")
+                           f"data back to '{min_time}'.")

-         if (not evaluate and anchor_time
-                 > self._graph_store.max_time + pd.DateOffset(days=1)):
+         if not evaluate and anchor_time > max_time + pd.DateOffset(days=1):
              warnings.warn(f"Anchor timestamp '{anchor_time}' is after the "
-                           f"latest timestamp '{self._graph_store.max_time}' "
-                           f"in the data. Please make sure this is intended.")
+                           f"latest timestamp '{max_time}' in the data. Please "
+                           f"make sure this is intended.")

-         max_eval_time = self._graph_store.max_time - forecast_end_offset
-         if evaluate and anchor_time > max_eval_time:
+         if evaluate and anchor_time > max_time - end_offset:
              raise ValueError(
                  f"Anchor timestamp for evaluation is after the latest "
-                 f"supported timestamp '{max_eval_time}'.")
+                 f"supported timestamp '{max_time - end_offset}'.")

-     def _get_context(
+     def _get_task_table(
          self,
          query: ValidatedPredictiveQuery,
-         indices: Union[List[str], List[float], List[int], None],
-         anchor_time: Union[pd.Timestamp, Literal['entity'], None],
-         context_anchor_time: Union[pd.Timestamp, None],
-         run_mode: RunMode,
-         num_neighbors: Optional[List[int]],
-         num_hops: int,
-         max_pq_iterations: int,
-         evaluate: bool,
-         random_seed: Optional[int] = _RANDOM_SEED,
-         logger: Optional[ProgressLogger] = None,
-     ) -> Context:
-
-         if num_neighbors is not None:
-             num_hops = len(num_neighbors)
-
-         if num_hops < 0:
-             raise ValueError(f"'num_hops' must be non-negative "
-                              f"(got {num_hops})")
-         if num_hops > 6:
-             raise ValueError(f"Cannot predict on subgraphs with more than 6 "
-                              f"hops (got {num_hops}). Please reduce the "
-                              f"number of hops and try again. Please create a "
-                              f"feature request at "
-                              f"'https://github.com/kumo-ai/kumo-rfm' if you "
-                              f"must go beyond this for your use-case.")
-
-         query_driver = LocalPQueryDriver(self._graph_store, query, random_seed)
-         task_type = LocalPQueryDriver.get_task_type(
-             query,
-             edge_types=self._graph_store.edge_types,
+         indices: list[str] | list[float] | list[int] | None,
+         anchor_time: pd.Timestamp | Literal['entity'] | None = None,
+         context_anchor_time: pd.Timestamp | None = None,
+         run_mode: RunMode = RunMode.FAST,
+         max_pq_iterations: int = 10,
+         random_seed: int | None = _RANDOM_SEED,
+         logger: ProgressLogger | None = None,
+     ) -> TaskTable:
+
+         task_type = self._get_task_type(
+             query=query,
+             edge_types=self._sampler.edge_types,
          )

+         num_train_examples = _MAX_CONTEXT_SIZE[run_mode]
+         num_test_examples = _MAX_TEST_SIZE[task_type] if indices is None else 0
+
          if logger is not None:
              if task_type == TaskType.BINARY_CLASSIFICATION:
                  task_type_repr = 'binary classification'
@@ -870,30 +1141,17 @@ class KumoRFM:
              else:
                  task_type_repr = str(task_type)
              logger.log(f"Identified {query.query_type} {task_type_repr} task")

-         if task_type.is_link_pred and num_hops < 2:
-             raise ValueError(f"Cannot perform link prediction on subgraphs "
-                              f"with less than 2 hops (got {num_hops}) since "
-                              f"historical target entities need to be part of "
-                              f"the context. Please increase the number of "
-                              f"hops and try again.")
-
-         if num_neighbors is None:
-             if run_mode == RunMode.DEBUG:
-                 num_neighbors = [16, 16, 4, 4, 1, 1][:num_hops]
-             elif run_mode == RunMode.FAST or task_type.is_link_pred:
-                 num_neighbors = [32, 32, 8, 8, 4, 4][:num_hops]
-             else:
-                 num_neighbors = [64, 64, 8, 8, 4, 4][:num_hops]
-
          if query.target_ast.date_offset_range is None:
-             end_offset = pd.DateOffset(0)
+             step_offset = pd.DateOffset(0)
          else:
-             end_offset = query.target_ast.date_offset_range.end_date_offset
-         forecast_end_offset = end_offset * query.num_forecasts
+             step_offset = query.target_ast.date_offset_range.end_date_offset
+         end_offset = step_offset * query.num_forecasts
+
          if anchor_time is None:
-             anchor_time = self._graph_store.max_time
-             if evaluate:
-                 anchor_time = anchor_time - forecast_end_offset
+             anchor_time = self._get_default_anchor_time(query)
+             if num_test_examples > 0:
+                 anchor_time = anchor_time - end_offset
+
          if logger is not None:
              assert isinstance(anchor_time, pd.Timestamp)
              if anchor_time == pd.Timestamp.min:
@@ -905,114 +1163,98 @@ class KumoRFM:
905
1163
  else:
906
1164
  logger.log(f"Derived anchor time {anchor_time}")
907
1165
 
908
- assert anchor_time is not None
909
1166
  if isinstance(anchor_time, pd.Timestamp):
1167
+ if context_anchor_time == 'entity':
1168
+ raise ValueError("Anchor time 'entity' needs to be shared "
1169
+ "for context and prediction examples")
910
1170
  if context_anchor_time is None:
911
- context_anchor_time = anchor_time - forecast_end_offset
1171
+ context_anchor_time = anchor_time - end_offset
912
1172
  self._validate_time(query, anchor_time, context_anchor_time,
913
- evaluate)
1173
+ evaluate=num_test_examples > 0)
914
1174
  else:
915
1175
  assert anchor_time == 'entity'
916
- if query.entity_table not in self._graph_store.time_dict:
1176
+ if query.query_type != QueryType.STATIC:
1177
+ raise ValueError("Anchor time 'entity' is only valid for "
1178
+ "static predictive queries")
1179
+ if query.entity_table not in self._sampler.time_column_dict:
917
1180
  raise ValueError(f"Anchor time 'entity' requires the entity "
918
1181
  f"table '{query.entity_table}' to "
919
1182
  f"have a time column")
920
- if context_anchor_time is not None:
921
- warnings.warn("Ignoring option 'context_anchor_time' for "
922
- "`anchor_time='entity'`")
923
- context_anchor_time = None
924
-
925
- y_test: Optional[pd.Series] = None
926
- if evaluate:
927
- max_test_size = _MAX_TEST_SIZE[run_mode]
928
- if task_type.is_link_pred:
929
- max_test_size = max_test_size // 5
930
-
931
- test_node, test_time, y_test = query_driver.collect_test(
932
- size=max_test_size,
933
- anchor_time=anchor_time,
934
- max_iterations=max_pq_iterations,
935
- guarantee_train_examples=True,
936
- )
937
- if logger is not None:
938
- if task_type == TaskType.BINARY_CLASSIFICATION:
939
- pos = 100 * int((y_test > 0).sum()) / len(y_test)
940
- msg = (f"Collected {len(y_test):,} test examples with "
941
- f"{pos:.2f}% positive cases")
942
- elif task_type == TaskType.MULTICLASS_CLASSIFICATION:
943
- msg = (f"Collected {len(y_test):,} test examples "
944
- f"holding {y_test.nunique()} classes")
945
- elif task_type == TaskType.REGRESSION:
946
- _min, _max = float(y_test.min()), float(y_test.max())
947
- msg = (f"Collected {len(y_test):,} test examples with "
948
- f"targets between {format_value(_min)} and "
949
- f"{format_value(_max)}")
950
- elif task_type == TaskType.TEMPORAL_LINK_PREDICTION:
951
- num_rhs = y_test.explode().nunique()
952
- msg = (f"Collected {len(y_test):,} test examples with "
953
- f"{num_rhs:,} unique items")
954
- else:
955
- raise NotImplementedError
956
- logger.log(msg)
957
-
958
- else:
959
- assert indices is not None
960
-
961
- if len(indices) > _MAX_PRED_SIZE[task_type]:
962
- raise ValueError(f"Cannot predict for more than "
963
- f"{_MAX_PRED_SIZE[task_type]:,} entities at "
964
- f"once (got {len(indices):,}). Use "
965
- f"`KumoRFM.batch_mode` to process entities "
966
- f"in batches")
1183
+ if isinstance(context_anchor_time, pd.Timestamp):
1184
+ raise ValueError("Anchor time 'entity' needs to be shared "
1185
+ "for context and prediction examples")
1186
+ context_anchor_time = 'entity'
1187
+
1188
+ train, test = self._sampler.sample_target(
1189
+ query=query,
1190
+ num_train_examples=num_train_examples,
1191
+ train_anchor_time=context_anchor_time,
1192
+ num_train_trials=max_pq_iterations * num_train_examples,
1193
+ num_test_examples=num_test_examples,
1194
+ test_anchor_time=anchor_time,
1195
+ num_test_trials=max_pq_iterations * num_test_examples,
1196
+ random_seed=random_seed,
1197
+ )
1198
+ train_pkey, train_time, train_y = train
1199
+ test_pkey, test_time, test_y = test
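The old `collect_train`/`collect_test` driver calls are folded into a single `sample_target` call that returns a `(pkey, time, target)` triplet per split, and the internal iteration loop is replaced by an explicit trial budget of `max_pq_iterations` draws per requested example. That signature suggests a budgeted rejection-sampling loop; a toy sketch of the pattern (my illustration under that assumption, not the package's sampler):

    import random

    def sample_with_budget(candidates, accept, num_examples, num_trials):
        # Draw until enough examples are accepted or the budget is spent.
        out = []
        for _ in range(num_trials):
            if len(out) == num_examples:
                break
            candidate = random.choice(candidates)
            if accept(candidate):
                out.append(candidate)
        return out

    # Budget of 20 trials per example, mirroring num_trials =
    # max_pq_iterations * num_examples above; accept() stands in for the
    # predicate that checks whether a valid label exists at the anchor time.
    rows = sample_with_budget(list(range(1000)), lambda r: r % 3 == 0,
                              num_examples=50, num_trials=20 * 50)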
 
-            test_node = self._graph_store.get_node_id(
-                table_name=query.entity_table,
-                pkey=pd.Series(indices),
-            )
+        if num_test_examples > 0 and logger is not None:
+            if task_type == TaskType.BINARY_CLASSIFICATION:
+                pos = 100 * int((test_y > 0).sum()) / len(test_y)
+                msg = (f"Collected {len(test_y):,} test examples with "
+                       f"{pos:.2f}% positive cases")
+            elif task_type == TaskType.MULTICLASS_CLASSIFICATION:
+                msg = (f"Collected {len(test_y):,} test examples holding "
+                       f"{test_y.nunique()} classes")
+            elif task_type == TaskType.REGRESSION:
+                _min, _max = float(test_y.min()), float(test_y.max())
+                msg = (f"Collected {len(test_y):,} test examples with targets "
+                       f"between {format_value(_min)} and "
+                       f"{format_value(_max)}")
+            elif task_type == TaskType.TEMPORAL_LINK_PREDICTION:
+                num_rhs = test_y.explode().nunique()
+                msg = (f"Collected {len(test_y):,} test examples with "
+                       f"{num_rhs:,} unique items")
+            else:
+                raise NotImplementedError
+            logger.log(msg)
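One pandas idiom worth noting here: for temporal link prediction, each entry of `test_y` is a list of target keys, so counting distinct items requires flattening the lists first, which is exactly what `explode().nunique()` does:

    import pandas as pd

    # Each row holds the list of items an entity interacted with:
    test_y = pd.Series([[1, 2], [2, 3], [4]])
    assert test_y.explode().nunique() == 4  # distinct items: {1, 2, 3, 4}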
 
+        if num_test_examples == 0:
+            assert indices is not None
+            test_pkey = pd.Series(indices, dtype=train_pkey.dtype)
             if isinstance(anchor_time, pd.Timestamp):
-                test_time = pd.Series(anchor_time).repeat(
-                    len(test_node)).reset_index(drop=True)
+                test_time = pd.Series([anchor_time]).repeat(
+                    len(indices)).reset_index(drop=True)
             else:
-                time = self._graph_store.time_dict[query.entity_table]
-                time = time[test_node] * 1000**3
-                test_time = pd.Series(time, dtype='datetime64[ns]')
-
-        train_node, train_time, y_train = query_driver.collect_train(
-            size=_MAX_CONTEXT_SIZE[run_mode],
-            anchor_time=context_anchor_time or 'entity',
-            exclude_node=test_node if (query.query_type == QueryType.STATIC
-                                       or anchor_time == 'entity') else None,
-            max_iterations=max_pq_iterations,
-        )
+                train_time = test_time = 'entity'
 
         if logger is not None:
             if task_type == TaskType.BINARY_CLASSIFICATION:
-                pos = 100 * int((y_train > 0).sum()) / len(y_train)
-                msg = (f"Collected {len(y_train):,} in-context examples with "
+                pos = 100 * int((train_y > 0).sum()) / len(train_y)
+                msg = (f"Collected {len(train_y):,} in-context examples with "
                        f"{pos:.2f}% positive cases")
             elif task_type == TaskType.MULTICLASS_CLASSIFICATION:
-                msg = (f"Collected {len(y_train):,} in-context examples "
-                       f"holding {y_train.nunique()} classes")
+                msg = (f"Collected {len(train_y):,} in-context examples "
+                       f"holding {train_y.nunique()} classes")
             elif task_type == TaskType.REGRESSION:
-                _min, _max = float(y_train.min()), float(y_train.max())
-                msg = (f"Collected {len(y_train):,} in-context examples with "
+                _min, _max = float(train_y.min()), float(train_y.max())
+                msg = (f"Collected {len(train_y):,} in-context examples with "
                        f"targets between {format_value(_min)} and "
                        f"{format_value(_max)}")
             elif task_type == TaskType.TEMPORAL_LINK_PREDICTION:
-                num_rhs = y_train.explode().nunique()
-                msg = (f"Collected {len(y_train):,} in-context examples with "
+                num_rhs = train_y.explode().nunique()
+                msg = (f"Collected {len(train_y):,} in-context examples with "
                        f"{num_rhs:,} unique items")
             else:
                 raise NotImplementedError
             logger.log(msg)
 
-        entity_table_names: Tuple[str, ...]
+        entity_table_names: tuple[str] | tuple[str, str]
         if task_type.is_link_pred:
             final_aggr = query.get_final_target_aggregation()
             assert final_aggr is not None
             edge_fkey = final_aggr._get_target_column_name()
-            for edge_type in self._graph_store.edge_types:
+            for edge_type in self._sampler.edge_types:
                 if edge_fkey == f'{edge_type[0]}.{edge_type[1]}':
                     entity_table_names = (
                         query.entity_table,
@@ -1021,23 +1263,80 @@ class KumoRFM:
         else:
             entity_table_names = (query.entity_table, )
 
+        context_df = pd.DataFrame({'ENTITY': train_pkey, 'TARGET': train_y})
+        if isinstance(train_time, pd.Series):
+            context_df['ANCHOR_TIMESTAMP'] = train_time
+        pred_df = pd.DataFrame({'ENTITY': test_pkey})
+        if num_test_examples > 0:
+            pred_df['TARGET'] = test_y
+        if isinstance(test_time, pd.Series):
+            pred_df['ANCHOR_TIMESTAMP'] = test_time
+
+        return TaskTable(
+            task_type=task_type,
+            context_df=context_df,
+            pred_df=pred_df,
+            entity_table_name=entity_table_names,
+            entity_column='ENTITY',
+            target_column='TARGET',
+            time_column='ANCHOR_TIMESTAMP' if isinstance(
+                train_time, pd.Series) else TaskTable.ENTITY_TIME,
+        )
+
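The returned `TaskTable` wraps two plain DataFrames: the context frame always carries entities and targets, the prediction frame carries targets only when test examples were sampled, and the `ANCHOR_TIMESTAMP` column is present only for timestamp anchors (with `anchor_time='entity'`, the table falls back to `TaskTable.ENTITY_TIME` instead). Schematically, with made-up values:

    import pandas as pd

    context_df = pd.DataFrame({
        'ENTITY': [101, 205, 317],  # in-context entity primary keys
        'TARGET': [1, 0, 1],        # labels observed at the anchor time
        'ANCHOR_TIMESTAMP': pd.to_datetime(['2024-05-01'] * 3),
    })
    pred_df = pd.DataFrame({
        'ENTITY': [412, 523],       # entities to predict for
        'ANCHOR_TIMESTAMP': pd.to_datetime(['2024-06-01'] * 2),
    })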
+    def _get_context(
+        self,
+        task: TaskTable,
+        run_mode: RunMode | str = RunMode.FAST,
+        num_neighbors: list[int] | None = None,
+        exclude_cols_dict: dict[str, list[str]] | None = None,
+        top_k: int | None = None,
+    ) -> Context:
+
+        if num_neighbors is None:
+            key = RunMode.FAST if task.task_type.is_link_pred else run_mode
+            num_neighbors = _DEFAULT_NUM_NEIGHBORS[key][:2]
+
+        if len(num_neighbors) > 6:
+            raise ValueError(f"Cannot predict on subgraphs with more than 6 "
+                             f"hops (got {len(num_neighbors)}). Reduce the "
+                             f"number of hops and try again. Please create a "
+                             f"feature request at "
+                             f"'https://github.com/kumo-ai/kumo-rfm' if you "
+                             f"must go beyond this for your use-case.")
+
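`num_neighbors` is a per-hop fan-out list: entry i caps how many neighbors are sampled at hop i, so the list length is the number of hops and the running product bounds the subgraph size, which is why more than 6 hops is rejected. A sketch of the growth, assuming the new `_DEFAULT_NUM_NEIGHBORS` table mirrors the per-run-mode lists removed above (keyed by `RunMode` in the real code; plain strings here):

    # Hypothetical defaults, mirroring the removed per-run-mode lists:
    _DEFAULT_NUM_NEIGHBORS = {
        'debug': [16, 16, 4, 4, 1, 1],
        'fast': [32, 32, 8, 8, 4, 4],
        'best': [64, 64, 8, 8, 4, 4],
    }
    num_neighbors = _DEFAULT_NUM_NEIGHBORS['fast'][:2]  # two hops: [32, 32]

    # Worst-case number of sampled nodes per entity grows multiplicatively:
    total, frontier = 0, 1
    for fanout in num_neighbors:
        frontier *= fanout
        total += frontier
    print(total)  # 32 + 32 * 32 = 1056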
         # Exclude the entity anchor time from the feature set to prevent
         # running out-of-distribution between in-context and test examples:
-        exclude_cols_dict = query.get_exclude_cols_dict()
-        if anchor_time == 'entity':
-            if entity_table_names[0] not in exclude_cols_dict:
-                exclude_cols_dict[entity_table_names[0]] = []
-            time_column_dict = self._graph_store.time_column_dict
-            time_column = time_column_dict[entity_table_names[0]]
-            exclude_cols_dict[entity_table_names[0]].append(time_column)
-
-        subgraph = self._graph_sampler(
-            entity_table_names=entity_table_names,
-            node=np.concatenate([train_node, test_node]),
-            time=np.concatenate([
-                train_time.astype('datetime64[ns]').astype(int).to_numpy(),
-                test_time.astype('datetime64[ns]').astype(int).to_numpy(),
-            ]),
+        exclude_cols_dict = exclude_cols_dict or {}
+        if task.entity_table_name in self._sampler.time_column_dict:
+            if task.entity_table_name not in exclude_cols_dict:
+                exclude_cols_dict[task.entity_table_name] = []
+            time_col = self._sampler.time_column_dict[task.entity_table_name]
+            exclude_cols_dict[task.entity_table_name].append(time_col)
+
+        entity_pkey = pd.concat([
+            task._context_df[task._entity_column],
+            task._pred_df[task._entity_column],
+        ], axis=0, ignore_index=True)
+
+        if task.use_entity_time:
+            if task.entity_table_name not in self._sampler.time_column_dict:
+                raise ValueError(f"The given anchor time requires the entity "
+                                 f"table '{task.entity_table_name}' to have a "
+                                 f"time column")
+            anchor_time = 'entity'
+        elif task._time_column is not None:
+            anchor_time = pd.concat([
+                task._context_df[task._time_column],
+                task._pred_df[task._time_column],
+            ], axis=0, ignore_index=True)
+        else:
+            anchor_time = pd.Series(self._get_default_anchor_time()).repeat(
+                len(entity_pkey)).reset_index(drop=True)
+
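Three anchor-time sources are handled above: the entity's own time column (`'entity'`), an explicit per-row anchor-time column from the task, or a single default timestamp broadcast to every row. The broadcast uses the same `Series.repeat(...).reset_index(drop=True)` idiom as the prediction path earlier, for example:

    import pandas as pd

    default = pd.Timestamp('2024-06-30')  # stand-in for the derived default
    anchor_time = pd.Series(default).repeat(4).reset_index(drop=True)
    # 0   2024-06-30
    # 1   2024-06-30
    # 2   2024-06-30
    # 3   2024-06-30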
+        subgraph = self._sampler.sample_subgraph(
+            entity_table_names=task.entity_table_names,
+            entity_pkey=entity_pkey,
+            anchor_time=anchor_time,
             num_neighbors=num_neighbors,
             exclude_cols_dict=exclude_cols_dict,
         )
@@ -1049,23 +1348,26 @@ class KumoRFM:
                              f"'https://github.com/kumo-ai/kumo-rfm' if you "
                              f"must go beyond this for your use-case.")
 
-        step_size: Optional[int] = None
-        if query.query_type == QueryType.TEMPORAL:
-            step_size = date_offset_to_seconds(end_offset)
+        if (task.task_type.is_link_pred
+                and task.entity_table_names[-1] not in subgraph.table_dict):
+            raise ValueError("Cannot perform link prediction on subgraphs "
+                             "without any historical target entities. Please "
+                             "increase the number of hops and try again.")
 
         return Context(
-            task_type=task_type,
-            entity_table_names=entity_table_names,
+            task_type=task.task_type,
+            entity_table_names=task.entity_table_names,
             subgraph=subgraph,
-            y_train=y_train,
-            y_test=y_test,
-            top_k=query.top_k,
-            step_size=step_size,
+            y_train=task._context_df[task.target_column.name],
+            y_test=task._pred_df[task.target_column.name]
+            if task.evaluate else None,
+            top_k=top_k,
+            step_size=None,
         )
 
     @staticmethod
     def _validate_metrics(
-        metrics: List[str],
+        metrics: list[str],
         task_type: TaskType,
     ) -> None:
 
@@ -1122,7 +1424,7 @@ class KumoRFM:
                              f"'https://github.com/kumo-ai/kumo-rfm'.")
 
 
-def format_value(value: Union[int, float]) -> str:
+def format_value(value: int | float) -> str:
     if value == int(value):
         return f'{int(value):,}'
     if abs(value) >= 1000: