kumoai 2.13.0.dev202512091732__cp311-cp311-macosx_11_0_arm64.whl → 2.14.0.dev202601051732__cp311-cp311-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (51)
  1. kumoai/__init__.py +23 -26
  2. kumoai/_version.py +1 -1
  3. kumoai/client/client.py +6 -0
  4. kumoai/client/jobs.py +24 -0
  5. kumoai/client/pquery.py +6 -2
  6. kumoai/connector/utils.py +21 -7
  7. kumoai/experimental/rfm/__init__.py +51 -24
  8. kumoai/experimental/rfm/authenticate.py +3 -4
  9. kumoai/experimental/rfm/backend/local/graph_store.py +52 -104
  10. kumoai/experimental/rfm/backend/local/sampler.py +125 -55
  11. kumoai/experimental/rfm/backend/local/table.py +35 -31
  12. kumoai/experimental/rfm/backend/snow/__init__.py +2 -0
  13. kumoai/experimental/rfm/backend/snow/sampler.py +297 -0
  14. kumoai/experimental/rfm/backend/snow/table.py +174 -49
  15. kumoai/experimental/rfm/backend/sqlite/__init__.py +4 -2
  16. kumoai/experimental/rfm/backend/sqlite/sampler.py +398 -0
  17. kumoai/experimental/rfm/backend/sqlite/table.py +131 -48
  18. kumoai/experimental/rfm/base/__init__.py +21 -5
  19. kumoai/experimental/rfm/base/column.py +96 -10
  20. kumoai/experimental/rfm/base/expression.py +44 -0
  21. kumoai/experimental/rfm/base/sampler.py +422 -35
  22. kumoai/experimental/rfm/base/source.py +2 -1
  23. kumoai/experimental/rfm/base/sql_sampler.py +144 -0
  24. kumoai/experimental/rfm/base/table.py +386 -195
  25. kumoai/experimental/rfm/graph.py +350 -178
  26. kumoai/experimental/rfm/infer/__init__.py +6 -4
  27. kumoai/experimental/rfm/infer/dtype.py +7 -4
  28. kumoai/experimental/rfm/infer/multicategorical.py +1 -1
  29. kumoai/experimental/rfm/infer/pkey.py +4 -2
  30. kumoai/experimental/rfm/infer/stype.py +35 -0
  31. kumoai/experimental/rfm/infer/time_col.py +1 -2
  32. kumoai/experimental/rfm/pquery/executor.py +27 -27
  33. kumoai/experimental/rfm/pquery/pandas_executor.py +29 -31
  34. kumoai/experimental/rfm/relbench.py +76 -0
  35. kumoai/experimental/rfm/rfm.py +630 -408
  36. kumoai/experimental/rfm/sagemaker.py +4 -4
  37. kumoai/experimental/rfm/task_table.py +290 -0
  38. kumoai/pquery/predictive_query.py +10 -6
  39. kumoai/testing/snow.py +50 -0
  40. kumoai/trainer/distilled_trainer.py +175 -0
  41. kumoai/utils/__init__.py +3 -2
  42. kumoai/utils/display.py +51 -0
  43. kumoai/utils/progress_logger.py +190 -12
  44. kumoai/utils/sql.py +3 -0
  45. {kumoai-2.13.0.dev202512091732.dist-info → kumoai-2.14.0.dev202601051732.dist-info}/METADATA +3 -2
  46. {kumoai-2.13.0.dev202512091732.dist-info → kumoai-2.14.0.dev202601051732.dist-info}/RECORD +49 -40
  47. kumoai/experimental/rfm/local_graph_sampler.py +0 -223
  48. kumoai/experimental/rfm/local_pquery_driver.py +0 -689
  49. {kumoai-2.13.0.dev202512091732.dist-info → kumoai-2.14.0.dev202601051732.dist-info}/WHEEL +0 -0
  50. {kumoai-2.13.0.dev202512091732.dist-info → kumoai-2.14.0.dev202601051732.dist-info}/licenses/LICENSE +0 -0
  51. {kumoai-2.13.0.dev202512091732.dist-info → kumoai-2.14.0.dev202601051732.dist-info}/top_level.txt +0 -0
@@ -1,26 +1,24 @@
  import json
+ import math
  import time
  import warnings
  from collections import defaultdict
- from collections.abc import Generator
+ from collections.abc import Generator, Iterator
  from contextlib import contextmanager
  from dataclasses import dataclass, replace
- from typing import (
-     Any,
-     Dict,
-     Iterator,
-     List,
-     Literal,
-     Optional,
-     Tuple,
-     Union,
-     overload,
- )
+ from typing import Any, Literal, overload

  import numpy as np
  import pandas as pd
  from kumoapi.model_plan import RunMode
  from kumoapi.pquery import QueryType, ValidatedPredictiveQuery
+ from kumoapi.pquery.AST import (
+     Aggregation,
+     Column,
+     Condition,
+     Join,
+     LogicalOperation,
+ )
  from kumoapi.rfm import Context
  from kumoapi.rfm import Explanation as ExplanationConfig
  from kumoapi.rfm import (
@@ -29,35 +27,35 @@ from kumoapi.rfm import (
      RFMPredictRequest,
  )
  from kumoapi.task import TaskType
+ from kumoapi.typing import AggregationType, Stype

  from kumoai.client.rfm import RFMAPI
  from kumoai.exceptions import HTTPException
- from kumoai.experimental.rfm import Graph
- from kumoai.experimental.rfm.backend.local import LocalGraphStore
- from kumoai.experimental.rfm.local_graph_sampler import LocalGraphSampler
- from kumoai.experimental.rfm.local_pquery_driver import (
-     LocalPQueryDriver,
-     date_offset_to_seconds,
- )
+ from kumoai.experimental.rfm import Graph, TaskTable
+ from kumoai.experimental.rfm.base import DataBackend, Sampler
  from kumoai.mixin import CastMixin
- from kumoai.utils import InteractiveProgressLogger, ProgressLogger
+ from kumoai.utils import ProgressLogger, display

  _RANDOM_SEED = 42

  _MAX_PRED_SIZE: dict[TaskType, int] = defaultdict(lambda: 1_000)
  _MAX_PRED_SIZE[TaskType.TEMPORAL_LINK_PREDICTION] = 200

+ _MAX_TEST_SIZE: dict[TaskType, int] = defaultdict(lambda: 2_000)
+ _MAX_TEST_SIZE[TaskType.TEMPORAL_LINK_PREDICTION] = 400
+
  _MAX_CONTEXT_SIZE = {
      RunMode.DEBUG: 100,
      RunMode.FAST: 1_000,
      RunMode.NORMAL: 5_000,
      RunMode.BEST: 10_000,
  }
- _MAX_TEST_SIZE = {  # Share test set size across run modes for fair comparison:
-     RunMode.DEBUG: 100,
-     RunMode.FAST: 2_000,
-     RunMode.NORMAL: 2_000,
-     RunMode.BEST: 2_000,
+
+ _DEFAULT_NUM_NEIGHBORS = {
+     RunMode.DEBUG: [16, 16, 4, 4, 1, 1],
+     RunMode.FAST: [32, 32, 8, 8, 4, 4],
+     RunMode.NORMAL: [64, 64, 8, 8, 4, 4],
+     RunMode.BEST: [64, 64, 8, 8, 4, 4],
  }

  _MAX_SIZE = 30 * 1024 * 1024
@@ -95,24 +93,26 @@ class Explanation:
      def __getitem__(self, index: Literal[1]) -> str:
          pass

-     def __getitem__(self, index: int) -> Union[pd.DataFrame, str]:
+     def __getitem__(self, index: int) -> pd.DataFrame | str:
          if index == 0:
              return self.prediction
          if index == 1:
              return self.summary
          raise IndexError("Index out of range")

-     def __iter__(self) -> Iterator[Union[pd.DataFrame, str]]:
+     def __iter__(self) -> Iterator[pd.DataFrame | str]:
          return iter((self.prediction, self.summary))

      def __repr__(self) -> str:
          return str((self.prediction, self.summary))

-     def _ipython_display_(self) -> None:
-         from IPython.display import Markdown, display
+     def print(self) -> None:
+         r"""Prints the explanation."""
+         display.dataframe(self.prediction)
+         display.message(self.summary)

-         display(self.prediction)
-         display(Markdown(self.summary))
+     def _ipython_display_(self) -> None:
+         self.print()


  class KumoRFM:
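The reworked `Explanation` above keeps its tuple-like behavior but now renders through the new `kumoai.utils.display` helpers instead of importing IPython directly, and the notebook hook simply delegates to the new `print()` method. A minimal usage sketch; the `rfm` handle and `query` string are assumptions, not part of this hunk:

    # Sketch only: assumes a constructed `KumoRFM` instance `rfm` and a valid
    # predictive query `query` for its graph.
    explanation = rfm.predict(query, indices=[42], explain=True)

    prediction, summary = explanation    # unpacks via __iter__
    assert explanation[0] is prediction  # pd.DataFrame
    assert explanation[1] is summary     # str

    explanation.print()  # renders outside notebooks, too (new in this release)
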
@@ -151,20 +151,35 @@ class KumoRFM:
      Args:
          graph: The graph.
          verbose: Whether to print verbose output.
+         optimize: If set to ``True``, will optimize the underlying data backend
+             for optimal querying. For example, for transactional database
+             backends, will create any missing indices. Requires write-access to
+             the data backend.
      """
      def __init__(
          self,
          graph: Graph,
-         verbose: Union[bool, ProgressLogger] = True,
+         verbose: bool | ProgressLogger = True,
+         optimize: bool = False,
      ) -> None:
          graph = graph.validate()
          self._graph_def = graph._to_api_graph_definition()
-         self._graph_store = LocalGraphStore(graph, verbose)
-         self._graph_sampler = LocalGraphSampler(self._graph_store)

-         self._client: Optional[RFMAPI] = None
+         if graph.backend == DataBackend.LOCAL:
+             from kumoai.experimental.rfm.backend.local import LocalSampler
+             self._sampler: Sampler = LocalSampler(graph, verbose)
+         elif graph.backend == DataBackend.SQLITE:
+             from kumoai.experimental.rfm.backend.sqlite import SQLiteSampler
+             self._sampler = SQLiteSampler(graph, verbose, optimize)
+         elif graph.backend == DataBackend.SNOWFLAKE:
+             from kumoai.experimental.rfm.backend.snow import SnowSampler
+             self._sampler = SnowSampler(graph, verbose)
+         else:
+             raise NotImplementedError

-         self._batch_size: Optional[int | Literal['max']] = None
+         self._client: RFMAPI | None = None
+
+         self._batch_size: int | Literal['max'] | None = None
          self.num_retries: int = 0

      @property
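The constructor no longer materializes a `LocalGraphStore` unconditionally; it dispatches on `graph.backend` to a `LocalSampler`, `SQLiteSampler`, or `SnowSampler`, and the new `optimize` flag is only forwarded to the SQLite backend. A hedged sketch of the new surface (how `graph` is built is outside this file, and the import path is assumed from the package layout above):

    from kumoai.experimental.rfm import KumoRFM

    # `graph` is assumed to be a kumoai Graph whose `backend` reports
    # DataBackend.SQLITE; `optimize=True` may create missing indices and
    # therefore needs write access to the database.
    rfm = KumoRFM(graph, optimize=True)
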
@@ -182,7 +197,7 @@ class KumoRFM:
      @contextmanager
      def batch_mode(
          self,
-         batch_size: Union[int, Literal['max']] = 'max',
+         batch_size: int | Literal['max'] = 'max',
          num_retries: int = 1,
      ) -> Generator[None, None, None]:
          """Context manager to predict in batches.
@@ -216,17 +231,17 @@ class KumoRFM:
      def predict(
          self,
          query: str,
-         indices: Union[List[str], List[float], List[int], None] = None,
+         indices: list[str] | list[float] | list[int] | None = None,
          *,
          explain: Literal[False] = False,
-         anchor_time: Union[pd.Timestamp, Literal['entity'], None] = None,
-         context_anchor_time: Union[pd.Timestamp, None] = None,
-         run_mode: Union[RunMode, str] = RunMode.FAST,
-         num_neighbors: Optional[List[int]] = None,
+         anchor_time: pd.Timestamp | Literal['entity'] | None = None,
+         context_anchor_time: pd.Timestamp | None = None,
+         run_mode: RunMode | str = RunMode.FAST,
+         num_neighbors: list[int] | None = None,
          num_hops: int = 2,
-         max_pq_iterations: int = 20,
-         random_seed: Optional[int] = _RANDOM_SEED,
-         verbose: Union[bool, ProgressLogger] = True,
+         max_pq_iterations: int = 10,
+         random_seed: int | None = _RANDOM_SEED,
+         verbose: bool | ProgressLogger = True,
          use_prediction_time: bool = False,
      ) -> pd.DataFrame:
          pass
@@ -235,17 +250,17 @@ class KumoRFM:
      def predict(
          self,
          query: str,
-         indices: Union[List[str], List[float], List[int], None] = None,
+         indices: list[str] | list[float] | list[int] | None = None,
          *,
-         explain: Union[Literal[True], ExplainConfig, Dict[str, Any]],
-         anchor_time: Union[pd.Timestamp, Literal['entity'], None] = None,
-         context_anchor_time: Union[pd.Timestamp, None] = None,
-         run_mode: Union[RunMode, str] = RunMode.FAST,
-         num_neighbors: Optional[List[int]] = None,
+         explain: Literal[True] | ExplainConfig | dict[str, Any],
+         anchor_time: pd.Timestamp | Literal['entity'] | None = None,
+         context_anchor_time: pd.Timestamp | None = None,
+         run_mode: RunMode | str = RunMode.FAST,
+         num_neighbors: list[int] | None = None,
          num_hops: int = 2,
-         max_pq_iterations: int = 20,
-         random_seed: Optional[int] = _RANDOM_SEED,
-         verbose: Union[bool, ProgressLogger] = True,
+         max_pq_iterations: int = 10,
+         random_seed: int | None = _RANDOM_SEED,
+         verbose: bool | ProgressLogger = True,
          use_prediction_time: bool = False,
      ) -> Explanation:
          pass
@@ -253,19 +268,19 @@ class KumoRFM:
      def predict(
          self,
          query: str,
-         indices: Union[List[str], List[float], List[int], None] = None,
+         indices: list[str] | list[float] | list[int] | None = None,
          *,
-         explain: Union[bool, ExplainConfig, Dict[str, Any]] = False,
-         anchor_time: Union[pd.Timestamp, Literal['entity'], None] = None,
-         context_anchor_time: Union[pd.Timestamp, None] = None,
-         run_mode: Union[RunMode, str] = RunMode.FAST,
-         num_neighbors: Optional[List[int]] = None,
+         explain: bool | ExplainConfig | dict[str, Any] = False,
+         anchor_time: pd.Timestamp | Literal['entity'] | None = None,
+         context_anchor_time: pd.Timestamp | None = None,
+         run_mode: RunMode | str = RunMode.FAST,
+         num_neighbors: list[int] | None = None,
          num_hops: int = 2,
-         max_pq_iterations: int = 20,
-         random_seed: Optional[int] = _RANDOM_SEED,
-         verbose: Union[bool, ProgressLogger] = True,
+         max_pq_iterations: int = 10,
+         random_seed: int | None = _RANDOM_SEED,
+         verbose: bool | ProgressLogger = True,
          use_prediction_time: bool = False,
-     ) -> Union[pd.DataFrame, Explanation]:
+     ) -> pd.DataFrame | Explanation:
          """Returns predictions for a predictive query.

          Args:
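Note the defaults that change across all three `predict` overloads: `max_pq_iterations` drops from 20 to 10, and the `typing` aliases give way to PEP 604 unions. A hedged call sketch; the predictive-query text and table/column names are illustrative and not taken from this diff:

    import pandas as pd
    from kumoapi.model_plan import RunMode

    # Hypothetical schema: 'orders' fact table, 'users' entity table.
    query = ("PREDICT COUNT(orders.*, 0, 30, days) > 0 "
             "FOR users.user_id IN (1, 2, 3)")

    df = rfm.predict(
        query,                                   # entities come from the query
        anchor_time=pd.Timestamp('2024-06-01'),
        run_mode=RunMode.FAST,                   # default
        max_pq_iterations=10,                    # new default (was 20)
    )
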
@@ -307,18 +322,133 @@ class KumoRFM:
              If ``explain`` is provided, returns an :class:`Explanation` object
              containing the prediction, summary, and details.
          """
-         explain_config: Optional[ExplainConfig] = None
-         if explain is True:
-             explain_config = ExplainConfig()
-         elif explain is not False:
-             explain_config = ExplainConfig._cast(explain)
-
          query_def = self._parse_query(query)
-         query_str = query_def.to_string()

+         if indices is None:
+             if query_def.rfm_entity_ids is None:
+                 raise ValueError("Cannot find entities to predict for. Please "
+                                  "pass them via `predict(query, indices=...)`")
+             indices = query_def.get_rfm_entity_id_list()
+         else:
+             query_def = replace(query_def, rfm_entity_ids=None)
+
+         if not isinstance(verbose, ProgressLogger):
+             query_repr = query_def.to_string(rich=True, exclude_predict=True)
+             if explain is not False:
+                 msg = f'[bold]EXPLAIN[/bold] {query_repr}'
+             else:
+                 msg = f'[bold]PREDICT[/bold] {query_repr}'
+             verbose = ProgressLogger.default(msg=msg, verbose=verbose)
+
+         with verbose as logger:
+             task_table = self._get_task_table(
+                 query=query_def,
+                 indices=indices,
+                 anchor_time=anchor_time,
+                 context_anchor_time=context_anchor_time,
+                 run_mode=run_mode,
+                 max_pq_iterations=max_pq_iterations,
+                 random_seed=random_seed,
+                 logger=logger,
+             )
+             task_table._query = query_def.to_string()  # type: ignore
+
+             return self.predict_task(
+                 task_table,
+                 explain=explain,  # type: ignore
+                 run_mode=run_mode,
+                 num_neighbors=num_neighbors,
+                 num_hops=num_hops,
+                 verbose=verbose,
+                 exclude_cols_dict=query_def.get_exclude_cols_dict(),
+                 use_prediction_time=use_prediction_time,
+                 top_k=query_def.top_k,
+             )
+
+     @overload
+     def predict_task(
+         self,
+         task: TaskTable,
+         *,
+         explain: Literal[False] = False,
+         run_mode: RunMode | str = RunMode.FAST,
+         num_neighbors: list[int] | None = None,
+         num_hops: int = 2,
+         verbose: bool | ProgressLogger = True,
+         exclude_cols_dict: dict[str, list[str]] | None = None,
+         use_prediction_time: bool = False,
+         top_k: int | None = None,
+     ) -> pd.DataFrame:
+         pass
+
+     @overload
+     def predict_task(
+         self,
+         task: TaskTable,
+         *,
+         explain: Literal[True] | ExplainConfig | dict[str, Any],
+         run_mode: RunMode | str = RunMode.FAST,
+         num_neighbors: list[int] | None = None,
+         num_hops: int = 2,
+         verbose: bool | ProgressLogger = True,
+         exclude_cols_dict: dict[str, list[str]] | None = None,
+         use_prediction_time: bool = False,
+         top_k: int | None = None,
+     ) -> Explanation:
+         pass
+
+     def predict_task(
+         self,
+         task: TaskTable,
+         *,
+         explain: bool | ExplainConfig | dict[str, Any] = False,
+         run_mode: RunMode | str = RunMode.FAST,
+         num_neighbors: list[int] | None = None,
+         num_hops: int = 2,
+         verbose: bool | ProgressLogger = True,
+         exclude_cols_dict: dict[str, list[str]] | None = None,
+         use_prediction_time: bool = False,
+         top_k: int | None = None,
+     ) -> pd.DataFrame | Explanation:
+         """Returns predictions for a custom task specification.
+
+         Args:
+             task: The custom :class:`TaskTable`.
+             explain: Configuration for explainability.
+                 If set to ``True``, will additionally explain the prediction.
+                 Passing in an :class:`ExplainConfig` instance provides control
+                 over which parts of explanation are generated.
+                 Explainability is currently only supported for single entity
+                 predictions with ``run_mode="FAST"``.
+             run_mode: The :class:`RunMode` for the query.
+             num_neighbors: The number of neighbors to sample for each hop.
+                 If specified, the ``num_hops`` option will be ignored.
+             num_hops: The number of hops to sample when generating the context.
+             verbose: Whether to print verbose output.
+             exclude_cols_dict: Any column in any table to exclude from the
+                 model input.
+             use_prediction_time: Whether to use the anchor timestamp as an
+                 additional feature during prediction. This is typically
+                 beneficial for time series forecasting tasks.
+             top_k: The number of predictions to return per entity.
+
+         Returns:
+             The predictions as a :class:`pandas.DataFrame`.
+             If ``explain`` is provided, returns an :class:`Explanation` object
+             containing the prediction, summary, and details.
+         """
          if num_hops != 2 and num_neighbors is not None:
              warnings.warn(f"Received custom 'num_neighbors' option; ignoring "
                            f"custom 'num_hops={num_hops}' option")
+         if num_neighbors is None:
+             key = RunMode.FAST if task.task_type.is_link_pred else run_mode
+             num_neighbors = _DEFAULT_NUM_NEIGHBORS[key][:num_hops]
+
+         explain_config: ExplainConfig | None = None
+         if explain is True:
+             explain_config = ExplainConfig()
+         elif explain is not False:
+             explain_config = ExplainConfig._cast(explain)

          if explain_config is not None and run_mode in {
              RunMode.NORMAL, RunMode.BEST
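`predict()` is now a thin wrapper: it materializes a `TaskTable` via `_get_task_table()` and forwards to the new public `predict_task()`, so callers can skip query parsing entirely and supply their own in-context examples. A sketch modeled on the `TaskTable(...)` construction in `_get_task_table` further below; all table names, column names, and label values are assumptions:

    import pandas as pd
    from kumoapi.task import TaskType
    from kumoai.experimental.rfm import TaskTable

    # Labeled in-context examples and unlabeled prediction rows:
    context_df = pd.DataFrame({
        'ENTITY': [1, 2, 3, 4],
        'TARGET': [0, 1, 0, 1],
        'ANCHOR_TIMESTAMP': pd.to_datetime(['2024-01-01'] * 4),
    })
    pred_df = pd.DataFrame({
        'ENTITY': [5, 6],
        'ANCHOR_TIMESTAMP': pd.to_datetime(['2024-02-01'] * 2),
    })

    task = TaskTable(
        task_type=TaskType.BINARY_CLASSIFICATION,
        context_df=context_df,
        pred_df=pred_df,
        entity_table_name=('users', ),  # hypothetical entity table
        entity_column='ENTITY',
        target_column='TARGET',
        time_column='ANCHOR_TIMESTAMP',
    )
    df = rfm.predict_task(task)
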
@@ -327,83 +457,82 @@ class KumoRFM:
                            f"run mode 'FAST' (got '{run_mode}'). Provided run "
                            f"mode has been reset. Please lower the run mode to "
                            f"suppress this warning.")
+             run_mode = RunMode.FAST

-         if indices is None:
-             if query_def.rfm_entity_ids is None:
-                 raise ValueError("Cannot find entities to predict for. Please "
-                                  "pass them via `predict(query, indices=...)`")
-             indices = query_def.get_rfm_entity_id_list()
-         else:
-             query_def = replace(query_def, rfm_entity_ids=None)
-
-         if len(indices) == 0:
-             raise ValueError("At least one entity is required")
-
-         if explain_config is not None and len(indices) > 1:
-             raise ValueError(
-                 f"Cannot explain predictions for more than a single entity "
-                 f"(got {len(indices)})")
-
-         query_repr = query_def.to_string(rich=True, exclude_predict=True)
-         if explain_config is not None:
-             msg = f'[bold]EXPLAIN[/bold] {query_repr}'
-         else:
-             msg = f'[bold]PREDICT[/bold] {query_repr}'
+         if explain_config is not None and task.num_prediction_examples > 1:
+             raise ValueError(f"Cannot explain predictions for more than a "
+                              f"single entity "
+                              f"(got {task.num_prediction_examples:,})")

          if not isinstance(verbose, ProgressLogger):
-             verbose = InteractiveProgressLogger(msg, verbose=verbose)
-
-         with verbose as logger:
-
-             batch_size: Optional[int] = None
-             if self._batch_size == 'max':
-                 task_type = LocalPQueryDriver.get_task_type(
-                     query_def,
-                     edge_types=self._graph_store.edge_types,
-                 )
-                 batch_size = _MAX_PRED_SIZE[task_type]
+             if task.task_type == TaskType.BINARY_CLASSIFICATION:
+                 task_type_repr = 'binary classification'
+             elif task.task_type == TaskType.MULTICLASS_CLASSIFICATION:
+                 task_type_repr = 'multi-class classification'
+             elif task.task_type == TaskType.REGRESSION:
+                 task_type_repr = 'regression'
+             elif task.task_type == TaskType.TEMPORAL_LINK_PREDICTION:
+                 task_type_repr = 'link prediction'
              else:
-                 batch_size = self._batch_size
+                 task_type_repr = str(task.task_type)

-             if batch_size is not None:
-                 offsets = range(0, len(indices), batch_size)
-                 batches = [indices[step:step + batch_size] for step in offsets]
+             if explain_config is not None:
+                 msg = f'Explain {task_type_repr} task'
              else:
-                 batches = [indices]
+                 msg = f'Predict {task_type_repr} task'
+             verbose = ProgressLogger.default(msg=msg, verbose=verbose)

-             if len(batches) > 1:
-                 logger.log(f"Splitting {len(indices):,} entities into "
-                            f"{len(batches):,} batches of size {batch_size:,}")
+         with verbose as logger:
+             if task.num_context_examples > _MAX_CONTEXT_SIZE[run_mode]:
+                 logger.log(f"Sub-sampled {_MAX_CONTEXT_SIZE[run_mode]:,} "
+                            f"out of {task.num_context_examples:,} in-context "
+                            f"examples")
+                 task = task.narrow_context(0, _MAX_CONTEXT_SIZE[run_mode])
+
+             if self._batch_size is None:
+                 batch_size = task.num_prediction_examples
+             elif self._batch_size == 'max':
+                 batch_size = _MAX_PRED_SIZE[task.task_type]
+             else:
+                 batch_size = self._batch_size

-             predictions: List[pd.DataFrame] = []
-             summary: Optional[str] = None
-             details: Optional[Explanation] = None
-             for i, batch in enumerate(batches):
-                 # TODO Re-use the context for subsequent predictions.
+             if batch_size > _MAX_PRED_SIZE[task.task_type]:
+                 raise ValueError(f"Cannot predict for more than "
+                                  f"{_MAX_PRED_SIZE[task.task_type]:,} "
+                                  f"entities at once (got {batch_size:,}). Use "
+                                  f"`KumoRFM.batch_mode` to process entities "
+                                  f"in batches with a sufficient batch size.")
+
+             if task.num_prediction_examples > batch_size:
+                 num = math.ceil(task.num_prediction_examples / batch_size)
+                 logger.log(f"Splitting {task.num_prediction_examples:,} "
+                            f"entities into {num:,} batches of size "
+                            f"{batch_size:,}")
+
+             predictions: list[pd.DataFrame] = []
+             summary: str | None = None
+             details: Explanation | None = None
+             for start in range(0, task.num_prediction_examples, batch_size):
                  context = self._get_context(
-                     query=query_def,
-                     indices=batch,
-                     anchor_time=anchor_time,
-                     context_anchor_time=context_anchor_time,
-                     run_mode=RunMode(run_mode),
+                     task=task.narrow_prediction(start, length=batch_size),
+                     run_mode=run_mode,
                      num_neighbors=num_neighbors,
-                     num_hops=num_hops,
-                     max_pq_iterations=max_pq_iterations,
-                     evaluate=False,
-                     random_seed=random_seed,
-                     logger=logger if i == 0 else None,
+                     exclude_cols_dict=exclude_cols_dict,
+                     top_k=top_k,
                  )
+                 context.y_test = None
+
                  request = RFMPredictRequest(
                      context=context,
                      run_mode=RunMode(run_mode),
-                     query=query_str,
+                     query=getattr(task, '_query', ''),
                      use_prediction_time=use_prediction_time,
                  )
                  with warnings.catch_warnings():
                      warnings.filterwarnings('ignore', message='gencode')
                      request_msg = request.to_protobuf()
                      _bytes = request_msg.SerializeToString()
-                 if i == 0:
+                 if start == 0:
                      logger.log(f"Generated context of size "
                                 f"{len(_bytes) / (1024*1024):.2f}MB")
@@ -411,12 +540,9 @@ class KumoRFM:
                      stats = Context.get_memory_stats(request_msg.context)
                      raise ValueError(_SIZE_LIMIT_MSG.format(stats=stats))

-                 if (isinstance(verbose, InteractiveProgressLogger) and i == 0
-                         and len(batches) > 1):
-                     verbose.init_progress(
-                         total=len(batches),
-                         description='Predicting',
-                     )
+                 if start == 0 and task.num_prediction_examples > batch_size:
+                     num = math.ceil(task.num_prediction_examples / batch_size)
+                     verbose.init_progress(total=num, description='Predicting')

                  for attempt in range(self.num_retries + 1):
                      try:
@@ -433,10 +559,10 @@ class KumoRFM:

                          # Cast 'ENTITY' to correct data type:
                          if 'ENTITY' in df:
-                             entity = query_def.entity_table
-                             pkey_map = self._graph_store.pkey_map_dict[entity]
-                             df['ENTITY'] = df['ENTITY'].astype(
-                                 type(pkey_map.index[0]))
+                             table_dict = context.subgraph.table_dict
+                             table = table_dict[context.entity_table_names[0]]
+                             ser = table.df[table.primary_key]
+                             df['ENTITY'] = df['ENTITY'].astype(ser.dtype)

                          # Cast 'ANCHOR_TIMESTAMP' to correct data type:
                          if 'ANCHOR_TIMESTAMP' in df:
@@ -451,8 +577,7 @@ class KumoRFM:

                          predictions.append(df)

-                         if (isinstance(verbose, InteractiveProgressLogger)
-                                 and len(batches) > 1):
+                         if task.num_prediction_examples > batch_size:
                              verbose.step()

                          break
@@ -490,9 +615,9 @@ class KumoRFM:
      def is_valid_entity(
          self,
          query: str,
-         indices: Union[List[str], List[float], List[int], None] = None,
+         indices: list[str] | list[float] | list[int] | None = None,
          *,
-         anchor_time: Union[pd.Timestamp, Literal['entity'], None] = None,
+         anchor_time: pd.Timestamp | Literal['entity'] | None = None,
      ) -> np.ndarray:
          r"""Returns a mask that denotes which entities are valid for the
          given predictive query, *i.e.*, which entities fulfill (temporal)
@@ -519,37 +644,32 @@ class KumoRFM:
              raise ValueError("At least one entity is required")

          if anchor_time is None:
-             anchor_time = self._graph_store.max_time
+             anchor_time = self._get_default_anchor_time(query_def)

          if isinstance(anchor_time, pd.Timestamp):
              self._validate_time(query_def, anchor_time, None, False)
          else:
              assert anchor_time == 'entity'
-             if (query_def.entity_table not in self._graph_store.time_dict):
+             if query_def.entity_table not in self._sampler.time_column_dict:
                  raise ValueError(f"Anchor time 'entity' requires the entity "
                                   f"table '{query_def.entity_table}' "
                                   f"to have a time column.")

-         node = self._graph_store.get_node_id(
-             table_name=query_def.entity_table,
-             pkey=pd.Series(indices),
-         )
-         query_driver = LocalPQueryDriver(self._graph_store, query_def)
-         return query_driver.is_valid(node, anchor_time)
+         raise NotImplementedError

      def evaluate(
          self,
          query: str,
          *,
-         metrics: Optional[List[str]] = None,
-         anchor_time: Union[pd.Timestamp, Literal['entity'], None] = None,
-         context_anchor_time: Union[pd.Timestamp, None] = None,
-         run_mode: Union[RunMode, str] = RunMode.FAST,
-         num_neighbors: Optional[List[int]] = None,
+         metrics: list[str] | None = None,
+         anchor_time: pd.Timestamp | Literal['entity'] | None = None,
+         context_anchor_time: pd.Timestamp | None = None,
+         run_mode: RunMode | str = RunMode.FAST,
+         num_neighbors: list[int] | None = None,
          num_hops: int = 2,
-         max_pq_iterations: int = 20,
-         random_seed: Optional[int] = _RANDOM_SEED,
-         verbose: Union[bool, ProgressLogger] = True,
+         max_pq_iterations: int = 10,
+         random_seed: int | None = _RANDOM_SEED,
+         verbose: bool | ProgressLogger = True,
          use_prediction_time: bool = False,
      ) -> pd.DataFrame:
          """Evaluates a predictive query.
@@ -582,40 +702,51 @@ class KumoRFM:
              The metrics as a :class:`pandas.DataFrame`
          """
          query_def = self._parse_query(query)
-
-         if num_hops != 2 and num_neighbors is not None:
-             warnings.warn(f"Received custom 'num_neighbors' option; ignoring "
-                           f"custom 'num_hops={num_hops}' option")
-
          if query_def.rfm_entity_ids is not None:
              query_def = replace(
                  query_def,
                  rfm_entity_ids=None,
              )

-         query_repr = query_def.to_string(rich=True, exclude_predict=True)
-         msg = f'[bold]EVALUATE[/bold] {query_repr}'
+         task_type = self._get_task_type(
+             query=query_def,
+             edge_types=self._sampler.edge_types,
+         )
+
+         if num_hops != 2 and num_neighbors is not None:
+             warnings.warn(f"Received custom 'num_neighbors' option; ignoring "
+                           f"custom 'num_hops={num_hops}' option")
+         if num_neighbors is None:
+             key = RunMode.FAST if task_type.is_link_pred else run_mode
+             num_neighbors = _DEFAULT_NUM_NEIGHBORS[key][:num_hops]
+
+         if metrics is not None and len(metrics) > 0:
+             self._validate_metrics(metrics, task_type)
+             metrics = list(dict.fromkeys(metrics))

          if not isinstance(verbose, ProgressLogger):
-             verbose = InteractiveProgressLogger(msg, verbose=verbose)
+             query_repr = query_def.to_string(rich=True, exclude_predict=True)
+             msg = f'[bold]EVALUATE[/bold] {query_repr}'
+             verbose = ProgressLogger.default(msg=msg, verbose=verbose)

          with verbose as logger:
-             context = self._get_context(
+             task_table = self._get_task_table(
                  query=query_def,
                  indices=None,
                  anchor_time=anchor_time,
                  context_anchor_time=context_anchor_time,
-                 run_mode=RunMode(run_mode),
-                 num_neighbors=num_neighbors,
-                 num_hops=num_hops,
+                 run_mode=run_mode,
                  max_pq_iterations=max_pq_iterations,
-                 evaluate=True,
                  random_seed=random_seed,
-                 logger=logger if verbose else None,
+                 logger=logger,
+             )
+             context = self._get_context(
+                 task=task_table,
+                 run_mode=run_mode,
+                 num_neighbors=num_neighbors,
+                 exclude_cols_dict=query_def.get_exclude_cols_dict(),
              )
-             if metrics is not None and len(metrics) > 0:
-                 self._validate_metrics(metrics, context.task_type)
-                 metrics = list(dict.fromkeys(metrics))
+
              request = RFMEvaluateRequest(
                  context=context,
                  run_mode=RunMode(run_mode),
@@ -633,17 +764,23 @@ class KumoRFM:
                  stats_msg = Context.get_memory_stats(request_msg.context)
                  raise ValueError(_SIZE_LIMIT_MSG.format(stats=stats_msg))

-             try:
-                 resp = self._api_client.evaluate(request_bytes)
-             except HTTPException as e:
+             for attempt in range(self.num_retries + 1):
                  try:
-                     msg = json.loads(e.detail)['detail']
-                 except Exception:
-                     msg = e.detail
-                 raise RuntimeError(f"An unexpected exception occurred. "
-                                    f"Please create an issue at "
-                                    f"'https://github.com/kumo-ai/kumo-rfm'. "
-                                    f"{msg}") from None
+                     resp = self._api_client.evaluate(request_bytes)
+                     break
+                 except HTTPException as e:
+                     if attempt == self.num_retries:
+                         try:
+                             msg = json.loads(e.detail)['detail']
+                         except Exception:
+                             msg = e.detail
+                         raise RuntimeError(
+                             f"An unexpected exception occurred. Please create "
+                             f"an issue at "
+                             f"'https://github.com/kumo-ai/kumo-rfm'. {msg}"
+                         ) from None
+
+                     time.sleep(2**attempt)  # 1s, 2s, 4s, 8s, ...

          return pd.DataFrame.from_dict(
              resp.metrics,
@@ -656,9 +793,9 @@ class KumoRFM:
          query: str,
          size: int,
          *,
-         anchor_time: Union[pd.Timestamp, Literal['entity'], None] = None,
-         random_seed: Optional[int] = _RANDOM_SEED,
-         max_iterations: int = 20,
+         anchor_time: pd.Timestamp | Literal['entity'] | None = None,
+         random_seed: int | None = _RANDOM_SEED,
+         max_iterations: int = 10,
      ) -> pd.DataFrame:
          """Returns the labels of a predictive query for a specified anchor
          time.
@@ -678,40 +815,37 @@ class KumoRFM:
          query_def = self._parse_query(query)

          if anchor_time is None:
-             anchor_time = self._graph_store.max_time
+             anchor_time = self._get_default_anchor_time(query_def)
              if query_def.target_ast.date_offset_range is not None:
-                 anchor_time = anchor_time - (
-                     query_def.target_ast.date_offset_range.end_date_offset *
-                     query_def.num_forecasts)
+                 offset = query_def.target_ast.date_offset_range.end_date_offset
+                 offset *= query_def.num_forecasts
+                 anchor_time -= offset

          assert anchor_time is not None
          if isinstance(anchor_time, pd.Timestamp):
              self._validate_time(query_def, anchor_time, None, evaluate=True)
          else:
              assert anchor_time == 'entity'
-             if (query_def.entity_table not in self._graph_store.time_dict):
+             if query_def.entity_table not in self._sampler.time_column_dict:
                  raise ValueError(f"Anchor time 'entity' requires the entity "
                                   f"table '{query_def.entity_table}' "
                                   f"to have a time column")

-         query_driver = LocalPQueryDriver(self._graph_store, query_def,
-                                          random_seed)
-
-         node, time, y = query_driver.collect_test(
-             size=size,
-             anchor_time=anchor_time,
-             batch_size=min(10_000, size),
-             max_iterations=max_iterations,
-             guarantee_train_examples=False,
+         train, test = self._sampler.sample_target(
+             query=query_def,
+             num_train_examples=0,
+             train_anchor_time=anchor_time,
+             num_train_trials=0,
+             num_test_examples=size,
+             test_anchor_time=anchor_time,
+             num_test_trials=max_iterations * size,
+             random_seed=random_seed,
          )

-         entity = self._graph_store.pkey_map_dict[
-             query_def.entity_table].index[node]
-
          return pd.DataFrame({
-             'ENTITY': entity,
-             'ANCHOR_TIMESTAMP': time,
-             'TARGET': y,
+             'ENTITY': test.entity_pkey,
+             'ANCHOR_TIMESTAMP': test.anchor_time,
+             'TARGET': test.target,
          })

      # Helpers #################################################################
@@ -726,63 +860,120 @@ class KumoRFM:
                               "`predict()` or `evaluate()` methods to perform "
                               "predictions or evaluations.")

-         try:
-             request = RFMParseQueryRequest(
-                 query=query,
-                 graph_definition=self._graph_def,
-             )
+         request = RFMParseQueryRequest(
+             query=query,
+             graph_definition=self._graph_def,
+         )

-             resp = self._api_client.parse_query(request)
+         for attempt in range(self.num_retries + 1):
+             try:
+                 resp = self._api_client.parse_query(request)
+                 break
+             except HTTPException as e:
+                 if attempt == self.num_retries:
+                     try:
+                         msg = json.loads(e.detail)['detail']
+                     except Exception:
+                         msg = e.detail
+                     raise ValueError(f"Failed to parse query '{query}'. {msg}")

-             # TODO Expose validation warnings.
+                 time.sleep(2**attempt)  # 1s, 2s, 4s, 8s, ...

-             if len(resp.validation_response.warnings) > 0:
-                 msg = '\n'.join([
-                     f'{i+1}. {warning.title}: {warning.message}' for i, warning
-                     in enumerate(resp.validation_response.warnings)
-                 ])
-                 warnings.warn(f"Encountered the following warnings during "
-                               f"parsing:\n{msg}")
+         if len(resp.validation_response.warnings) > 0:
+             msg = '\n'.join([
+                 f'{i+1}. {warning.title}: {warning.message}'
+                 for i, warning in enumerate(resp.validation_response.warnings)
+             ])
+             warnings.warn(f"Encountered the following warnings during "
+                           f"parsing:\n{msg}")

-             return resp.query
-         except HTTPException as e:
-             try:
-                 msg = json.loads(e.detail)['detail']
-             except Exception:
-                 msg = e.detail
-             raise ValueError(f"Failed to parse query '{query}'. "
-                              f"{msg}") from None
+         return resp.query
+
+     @staticmethod
+     def _get_task_type(
+         query: ValidatedPredictiveQuery,
+         edge_types: list[tuple[str, str, str]],
+     ) -> TaskType:
+         if isinstance(query.target_ast, (Condition, LogicalOperation)):
+             return TaskType.BINARY_CLASSIFICATION
+
+         target = query.target_ast
+         if isinstance(target, Join):
+             target = target.rhs_target
+         if isinstance(target, Aggregation):
+             if target.aggr == AggregationType.LIST_DISTINCT:
+                 table_name, col_name = target._get_target_column_name().split(
+                     '.')
+                 target_edge_types = [
+                     edge_type for edge_type in edge_types
+                     if edge_type[0] == table_name and edge_type[1] == col_name
+                 ]
+                 if len(target_edge_types) != 1:
+                     raise NotImplementedError(
+                         f"Multilabel-classification queries based on "
+                         f"'LIST_DISTINCT' are not supported yet. If you "
+                         f"planned to write a link prediction query instead, "
+                         f"make sure to register '{col_name}' as a "
+                         f"foreign key.")
+                 return TaskType.TEMPORAL_LINK_PREDICTION
+
+             return TaskType.REGRESSION
+
+         assert isinstance(target, Column)
+
+         if target.stype in {Stype.ID, Stype.categorical}:
+             return TaskType.MULTICLASS_CLASSIFICATION
+
+         if target.stype in {Stype.numerical}:
+             return TaskType.REGRESSION
+
+         raise NotImplementedError("Task type not yet supported")
+
+     def _get_default_anchor_time(
+         self,
+         query: ValidatedPredictiveQuery | None = None,
+     ) -> pd.Timestamp:
+         if query is not None and query.query_type == QueryType.TEMPORAL:
+             aggr_table_names = [
+                 aggr._get_target_column_name().split('.')[0]
+                 for aggr in query.get_all_target_aggregations()
+             ]
+             return self._sampler.get_max_time(aggr_table_names)
+
+         return self._sampler.get_max_time()

      def _validate_time(
          self,
          query: ValidatedPredictiveQuery,
          anchor_time: pd.Timestamp,
-         context_anchor_time: Union[pd.Timestamp, None],
+         context_anchor_time: pd.Timestamp | None,
          evaluate: bool,
      ) -> None:

-         if self._graph_store.min_time == pd.Timestamp.max:
+         if len(self._sampler.time_column_dict) == 0:
              return  # Graph without timestamps

-         if anchor_time < self._graph_store.min_time:
+         min_time = self._sampler.get_min_time()
+         max_time = self._sampler.get_max_time()
+
+         if anchor_time < min_time:
              raise ValueError(f"Anchor timestamp '{anchor_time}' is before "
-                              f"the earliest timestamp "
-                              f"'{self._graph_store.min_time}' in the data.")
+                              f"the earliest timestamp '{min_time}' in the "
+                              f"data.")

-         if (context_anchor_time is not None
-                 and context_anchor_time < self._graph_store.min_time):
+         if context_anchor_time is not None and context_anchor_time < min_time:
              raise ValueError(f"Context anchor timestamp is too early or "
                               f"aggregation time range is too large. To make "
                               f"this prediction, we would need data back to "
                               f"'{context_anchor_time}', however, your data "
-                              f"only contains data back to "
-                              f"'{self._graph_store.min_time}'.")
+                              f"only contains data back to '{min_time}'.")

          if query.target_ast.date_offset_range is not None:
              end_offset = query.target_ast.date_offset_range.end_date_offset
          else:
              end_offset = pd.DateOffset(0)
-         forecast_end_offset = end_offset * query.num_forecasts
+         end_offset = end_offset * query.num_forecasts
+
          if (context_anchor_time is not None
                  and context_anchor_time > anchor_time):
              warnings.warn(f"Context anchor timestamp "
@@ -792,7 +983,7 @@ class KumoRFM:
                            f"intended.")
          elif (query.query_type == QueryType.TEMPORAL
                and context_anchor_time is not None
-               and context_anchor_time + forecast_end_offset > anchor_time):
+               and context_anchor_time + end_offset > anchor_time):
              warnings.warn(f"Aggregation for context examples at timestamp "
                            f"'{context_anchor_time}' will leak information "
                            f"from the prediction anchor timestamp "
@@ -800,62 +991,44 @@ class KumoRFM:
                            f"intended.")

          elif (context_anchor_time is not None
-               and context_anchor_time - forecast_end_offset
-               < self._graph_store.min_time):
-             _time = context_anchor_time - forecast_end_offset
+               and context_anchor_time - end_offset < min_time):
+             _time = context_anchor_time - end_offset
              warnings.warn(f"Context anchor timestamp is too early or "
                            f"aggregation time range is too large. To form "
                            f"proper input data, we would need data back to "
                            f"'{_time}', however, your data only contains "
-                           f"data back to '{self._graph_store.min_time}'.")
+                           f"data back to '{min_time}'.")

-         if (not evaluate and anchor_time
-                 > self._graph_store.max_time + pd.DateOffset(days=1)):
+         if not evaluate and anchor_time > max_time + pd.DateOffset(days=1):
              warnings.warn(f"Anchor timestamp '{anchor_time}' is after the "
-                           f"latest timestamp '{self._graph_store.max_time}' "
-                           f"in the data. Please make sure this is intended.")
+                           f"latest timestamp '{max_time}' in the data. Please "
+                           f"make sure this is intended.")

-         max_eval_time = self._graph_store.max_time - forecast_end_offset
-         if evaluate and anchor_time > max_eval_time:
+         if evaluate and anchor_time > max_time - end_offset:
              raise ValueError(
                  f"Anchor timestamp for evaluation is after the latest "
-                 f"supported timestamp '{max_eval_time}'.")
+                 f"supported timestamp '{max_time - end_offset}'.")

-     def _get_context(
+     def _get_task_table(
          self,
          query: ValidatedPredictiveQuery,
-         indices: Union[List[str], List[float], List[int], None],
-         anchor_time: Union[pd.Timestamp, Literal['entity'], None],
-         context_anchor_time: Union[pd.Timestamp, None],
-         run_mode: RunMode,
-         num_neighbors: Optional[List[int]],
-         num_hops: int,
-         max_pq_iterations: int,
-         evaluate: bool,
-         random_seed: Optional[int] = _RANDOM_SEED,
-         logger: Optional[ProgressLogger] = None,
-     ) -> Context:
-
-         if num_neighbors is not None:
-             num_hops = len(num_neighbors)
-
-         if num_hops < 0:
-             raise ValueError(f"'num_hops' must be non-negative "
-                              f"(got {num_hops})")
-         if num_hops > 6:
-             raise ValueError(f"Cannot predict on subgraphs with more than 6 "
-                              f"hops (got {num_hops}). Please reduce the "
-                              f"number of hops and try again. Please create a "
-                              f"feature request at "
-                              f"'https://github.com/kumo-ai/kumo-rfm' if you "
-                              f"must go beyond this for your use-case.")
-
-         query_driver = LocalPQueryDriver(self._graph_store, query, random_seed)
-         task_type = LocalPQueryDriver.get_task_type(
-             query,
-             edge_types=self._graph_store.edge_types,
+         indices: list[str] | list[float] | list[int] | None,
+         anchor_time: pd.Timestamp | Literal['entity'] | None = None,
+         context_anchor_time: pd.Timestamp | None = None,
+         run_mode: RunMode = RunMode.FAST,
+         max_pq_iterations: int = 10,
+         random_seed: int | None = _RANDOM_SEED,
+         logger: ProgressLogger | None = None,
+     ) -> TaskTable:
+
+         task_type = self._get_task_type(
+             query=query,
+             edge_types=self._sampler.edge_types,
          )

+         num_train_examples = _MAX_CONTEXT_SIZE[run_mode]
+         num_test_examples = _MAX_TEST_SIZE[task_type] if indices is None else 0
+
          if logger is not None:
              if task_type == TaskType.BINARY_CLASSIFICATION:
                  task_type_repr = 'binary classification'
@@ -869,30 +1042,17 @@ class KumoRFM:
                  task_type_repr = str(task_type)
              logger.log(f"Identified {query.query_type} {task_type_repr} task")

-         if task_type.is_link_pred and num_hops < 2:
-             raise ValueError(f"Cannot perform link prediction on subgraphs "
-                              f"with less than 2 hops (got {num_hops}) since "
-                              f"historical target entities need to be part of "
-                              f"the context. Please increase the number of "
-                              f"hops and try again.")
-
-         if num_neighbors is None:
-             if run_mode == RunMode.DEBUG:
-                 num_neighbors = [16, 16, 4, 4, 1, 1][:num_hops]
-             elif run_mode == RunMode.FAST or task_type.is_link_pred:
-                 num_neighbors = [32, 32, 8, 8, 4, 4][:num_hops]
-             else:
-                 num_neighbors = [64, 64, 8, 8, 4, 4][:num_hops]
-
          if query.target_ast.date_offset_range is None:
-             end_offset = pd.DateOffset(0)
+             step_offset = pd.DateOffset(0)
          else:
-             end_offset = query.target_ast.date_offset_range.end_date_offset
-         forecast_end_offset = end_offset * query.num_forecasts
+             step_offset = query.target_ast.date_offset_range.end_date_offset
+         end_offset = step_offset * query.num_forecasts
+
          if anchor_time is None:
-             anchor_time = self._graph_store.max_time
-             if evaluate:
-                 anchor_time = anchor_time - forecast_end_offset
+             anchor_time = self._get_default_anchor_time(query)
+             if num_test_examples > 0:
+                 anchor_time = anchor_time - end_offset
+
          if logger is not None:
              assert isinstance(anchor_time, pd.Timestamp)
              if anchor_time == pd.Timestamp.min:
@@ -904,114 +1064,98 @@ class KumoRFM:
              else:
                  logger.log(f"Derived anchor time {anchor_time}")

-         assert anchor_time is not None
          if isinstance(anchor_time, pd.Timestamp):
+             if context_anchor_time == 'entity':
+                 raise ValueError("Anchor time 'entity' needs to be shared "
+                                  "for context and prediction examples")
              if context_anchor_time is None:
-                 context_anchor_time = anchor_time - forecast_end_offset
+                 context_anchor_time = anchor_time - end_offset
              self._validate_time(query, anchor_time, context_anchor_time,
-                                 evaluate)
+                                 evaluate=num_test_examples > 0)
          else:
              assert anchor_time == 'entity'
-             if query.entity_table not in self._graph_store.time_dict:
+             if query.query_type != QueryType.STATIC:
+                 raise ValueError("Anchor time 'entity' is only valid for "
+                                  "static predictive queries")
+             if query.entity_table not in self._sampler.time_column_dict:
                  raise ValueError(f"Anchor time 'entity' requires the entity "
                                   f"table '{query.entity_table}' to "
                                   f"have a time column")
-             if context_anchor_time is not None:
-                 warnings.warn("Ignoring option 'context_anchor_time' for "
-                               "`anchor_time='entity'`")
-                 context_anchor_time = None
-
-         y_test: Optional[pd.Series] = None
-         if evaluate:
-             max_test_size = _MAX_TEST_SIZE[run_mode]
-             if task_type.is_link_pred:
-                 max_test_size = max_test_size // 5
-
-             test_node, test_time, y_test = query_driver.collect_test(
-                 size=max_test_size,
-                 anchor_time=anchor_time,
-                 max_iterations=max_pq_iterations,
-                 guarantee_train_examples=True,
-             )
-             if logger is not None:
-                 if task_type == TaskType.BINARY_CLASSIFICATION:
-                     pos = 100 * int((y_test > 0).sum()) / len(y_test)
-                     msg = (f"Collected {len(y_test):,} test examples with "
-                            f"{pos:.2f}% positive cases")
-                 elif task_type == TaskType.MULTICLASS_CLASSIFICATION:
-                     msg = (f"Collected {len(y_test):,} test examples "
-                            f"holding {y_test.nunique()} classes")
-                 elif task_type == TaskType.REGRESSION:
-                     _min, _max = float(y_test.min()), float(y_test.max())
-                     msg = (f"Collected {len(y_test):,} test examples with "
-                            f"targets between {format_value(_min)} and "
-                            f"{format_value(_max)}")
-                 elif task_type == TaskType.TEMPORAL_LINK_PREDICTION:
-                     num_rhs = y_test.explode().nunique()
-                     msg = (f"Collected {len(y_test):,} test examples with "
-                            f"{num_rhs:,} unique items")
-                 else:
-                     raise NotImplementedError
-                 logger.log(msg)
-
-         else:
-             assert indices is not None
-
-             if len(indices) > _MAX_PRED_SIZE[task_type]:
-                 raise ValueError(f"Cannot predict for more than "
-                                  f"{_MAX_PRED_SIZE[task_type]:,} entities at "
-                                  f"once (got {len(indices):,}). Use "
-                                  f"`KumoRFM.batch_mode` to process entities "
-                                  f"in batches")
+             if isinstance(context_anchor_time, pd.Timestamp):
+                 raise ValueError("Anchor time 'entity' needs to be shared "
+                                  "for context and prediction examples")
+             context_anchor_time = 'entity'
+
+         train, test = self._sampler.sample_target(
+             query=query,
+             num_train_examples=num_train_examples,
+             train_anchor_time=context_anchor_time,
+             num_train_trials=max_pq_iterations * num_train_examples,
+             num_test_examples=num_test_examples,
+             test_anchor_time=anchor_time,
+             num_test_trials=max_pq_iterations * num_test_examples,
+             random_seed=random_seed,
+         )
+         train_pkey, train_time, train_y = train
+         test_pkey, test_time, test_y = test

-         test_node = self._graph_store.get_node_id(
-             table_name=query.entity_table,
-             pkey=pd.Series(indices),
-         )
+         if num_test_examples > 0 and logger is not None:
+             if task_type == TaskType.BINARY_CLASSIFICATION:
+                 pos = 100 * int((test_y > 0).sum()) / len(test_y)
+                 msg = (f"Collected {len(test_y):,} test examples with "
+                        f"{pos:.2f}% positive cases")
+             elif task_type == TaskType.MULTICLASS_CLASSIFICATION:
+                 msg = (f"Collected {len(test_y):,} test examples holding "
+                        f"{test_y.nunique()} classes")
+             elif task_type == TaskType.REGRESSION:
+                 _min, _max = float(test_y.min()), float(test_y.max())
+                 msg = (f"Collected {len(test_y):,} test examples with targets "
+                        f"between {format_value(_min)} and "
+                        f"{format_value(_max)}")
+             elif task_type == TaskType.TEMPORAL_LINK_PREDICTION:
+                 num_rhs = test_y.explode().nunique()
+                 msg = (f"Collected {len(test_y):,} test examples with "
+                        f"{num_rhs:,} unique items")
+             else:
+                 raise NotImplementedError
+             logger.log(msg)

+         if num_test_examples == 0:
+             assert indices is not None
+             test_pkey = pd.Series(indices, dtype=train_pkey.dtype)
              if isinstance(anchor_time, pd.Timestamp):
-                 test_time = pd.Series(anchor_time).repeat(
-                     len(test_node)).reset_index(drop=True)
+                 test_time = pd.Series([anchor_time]).repeat(
+                     len(indices)).reset_index(drop=True)
              else:
-                 time = self._graph_store.time_dict[query.entity_table]
-                 time = time[test_node] * 1000**3
-                 test_time = pd.Series(time, dtype='datetime64[ns]')
-
-         train_node, train_time, y_train = query_driver.collect_train(
-             size=_MAX_CONTEXT_SIZE[run_mode],
-             anchor_time=context_anchor_time or 'entity',
-             exclude_node=test_node if (query.query_type == QueryType.STATIC
-                                        or anchor_time == 'entity') else None,
-             max_iterations=max_pq_iterations,
-         )
+                 train_time = test_time = 'entity'

          if logger is not None:
              if task_type == TaskType.BINARY_CLASSIFICATION:
-                 pos = 100 * int((y_train > 0).sum()) / len(y_train)
-                 msg = (f"Collected {len(y_train):,} in-context examples with "
+                 pos = 100 * int((train_y > 0).sum()) / len(train_y)
+                 msg = (f"Collected {len(train_y):,} in-context examples with "
                         f"{pos:.2f}% positive cases")
              elif task_type == TaskType.MULTICLASS_CLASSIFICATION:
-                 msg = (f"Collected {len(y_train):,} in-context examples "
-                        f"holding {y_train.nunique()} classes")
+                 msg = (f"Collected {len(train_y):,} in-context examples "
+                        f"holding {train_y.nunique()} classes")
              elif task_type == TaskType.REGRESSION:
-                 _min, _max = float(y_train.min()), float(y_train.max())
-                 msg = (f"Collected {len(y_train):,} in-context examples with "
+                 _min, _max = float(train_y.min()), float(train_y.max())
+                 msg = (f"Collected {len(train_y):,} in-context examples with "
                         f"targets between {format_value(_min)} and "
                         f"{format_value(_max)}")
              elif task_type == TaskType.TEMPORAL_LINK_PREDICTION:
-                 num_rhs = y_train.explode().nunique()
-                 msg = (f"Collected {len(y_train):,} in-context examples with "
+                 num_rhs = train_y.explode().nunique()
+                 msg = (f"Collected {len(train_y):,} in-context examples with "
                         f"{num_rhs:,} unique items")
              else:
                  raise NotImplementedError
              logger.log(msg)

-         entity_table_names: Tuple[str, ...]
+         entity_table_names: tuple[str] | tuple[str, str]
          if task_type.is_link_pred:
              final_aggr = query.get_final_target_aggregation()
              assert final_aggr is not None
              edge_fkey = final_aggr._get_target_column_name()
-             for edge_type in self._graph_store.edge_types:
+             for edge_type in self._sampler.edge_types:
                  if edge_fkey == f'{edge_type[0]}.{edge_type[1]}':
                      entity_table_names = (
                          query.entity_table,
@@ -1020,23 +1164,98 @@ class KumoRFM:
          else:
              entity_table_names = (query.entity_table, )

+         context_df = pd.DataFrame({'ENTITY': train_pkey, 'TARGET': train_y})
+         if isinstance(train_time, pd.Series):
+             context_df['ANCHOR_TIMESTAMP'] = train_time
+         pred_df = pd.DataFrame({'ENTITY': test_pkey})
+         if num_test_examples > 0:
+             pred_df['TARGET'] = test_y
+         if isinstance(test_time, pd.Series):
+             pred_df['ANCHOR_TIMESTAMP'] = test_time
+
+         return TaskTable(
+             task_type=task_type,
+             context_df=context_df,
+             pred_df=pred_df,
+             entity_table_name=entity_table_names,
+             entity_column='ENTITY',
+             target_column='TARGET',
+             time_column='ANCHOR_TIMESTAMP' if isinstance(
+                 train_time, pd.Series) else TaskTable.ENTITY_TIME,
+         )
+
+     def _get_context(
+         self,
+         task: TaskTable,
+         run_mode: RunMode | str = RunMode.FAST,
+         num_neighbors: list[int] | None = None,
+         exclude_cols_dict: dict[str, list[str]] | None = None,
+         top_k: int | None = None,
+     ) -> Context:
+
+         # TODO Remove all
+         if task.num_context_examples > max(_MAX_CONTEXT_SIZE.values()):
+             raise ValueError(f"Cannot process a context with more than "
+                              f"{max(_MAX_CONTEXT_SIZE.values()):,} samples "
+                              f"(got {task.num_context_examples:,})")
+         if task.evaluate:
+             if task.num_prediction_examples > _MAX_TEST_SIZE[task.task_type]:
+                 raise ValueError(f"Cannot process a test set with more than "
+                                  f"{_MAX_TEST_SIZE[task.task_type]:,} samples "
+                                  f"for evaluation "
+                                  f"(got {task.num_prediction_examples:,})")
+         else:
+             if task.num_prediction_examples > _MAX_PRED_SIZE[task.task_type]:
+                 raise ValueError(f"Cannot predict for more than "
+                                  f"{_MAX_PRED_SIZE[task.task_type]:,} "
+                                  f"entities at once "
+                                  f"(got {task.num_prediction_examples:,})")
+
+         if num_neighbors is None:
+             key = RunMode.FAST if task.task_type.is_link_pred else run_mode
+             num_neighbors = _DEFAULT_NUM_NEIGHBORS[key][:2]
+
+         if len(num_neighbors) > 6:
+             raise ValueError(f"Cannot predict on subgraphs with more than 6 "
+                              f"hops (got {len(num_neighbors)}). Reduce the "
+                              f"number of hops and try again. Please create a "
+                              f"feature request at "
+                              f"'https://github.com/kumo-ai/kumo-rfm' if you "
+                              f"must go beyond this for your use-case.")
+
          # Exclude the entity anchor time from the feature set to prevent
          # running out-of-distribution between in-context and test examples:
-         exclude_cols_dict = query.get_exclude_cols_dict()
-         if anchor_time == 'entity':
-             if entity_table_names[0] not in exclude_cols_dict:
-                 exclude_cols_dict[entity_table_names[0]] = []
-             time_column_dict = self._graph_store.time_column_dict
-             time_column = time_column_dict[entity_table_names[0]]
-             exclude_cols_dict[entity_table_names[0]].append(time_column)
-
-         subgraph = self._graph_sampler(
-             entity_table_names=entity_table_names,
-             node=np.concatenate([train_node, test_node]),
-             time=np.concatenate([
-                 train_time.astype('datetime64[ns]').astype(int).to_numpy(),
-                 test_time.astype('datetime64[ns]').astype(int).to_numpy(),
-             ]),
+         exclude_cols_dict = exclude_cols_dict or {}
+         if task.entity_table_name in self._sampler.time_column_dict:
+             if task.entity_table_name not in exclude_cols_dict:
+                 exclude_cols_dict[task.entity_table_name] = []
+             time_col = self._sampler.time_column_dict[task.entity_table_name]
+             exclude_cols_dict[task.entity_table_name].append(time_col)
+
+         entity_pkey = pd.concat([
+             task._context_df[task._entity_column],
+             task._pred_df[task._entity_column],
+         ], axis=0, ignore_index=True)
+
+         if task.use_entity_time:
+             if task.entity_table_name not in self._sampler.time_column_dict:
+                 raise ValueError(f"The given anchor time requires the entity "
+                                  f"table '{task.entity_table_name}' to have a "
+                                  f"time column")
+             anchor_time = 'entity'
+         elif task._time_column is not None:
+             anchor_time = pd.concat([
+                 task._context_df[task._time_column],
+                 task._pred_df[task._time_column],
+             ], axis=0, ignore_index=True)
+         else:
+             anchor_time = pd.Series(self._get_default_anchor_time()).repeat(
+                 (len(entity_pkey))).reset_index(drop=True)
+
+         subgraph = self._sampler.sample_subgraph(
+             entity_table_names=task.entity_table_names,
+             entity_pkey=entity_pkey,
+             anchor_time=anchor_time,
              num_neighbors=num_neighbors,
              exclude_cols_dict=exclude_cols_dict,
          )
@@ -1048,23 +1267,26 @@ class KumoRFM:
                               f"'https://github.com/kumo-ai/kumo-rfm' if you "
                               f"must go beyond this for your use-case.")

-         step_size: Optional[int] = None
-         if query.query_type == QueryType.TEMPORAL:
-             step_size = date_offset_to_seconds(end_offset)
+         if (task.task_type.is_link_pred
+                 and task.entity_table_names[-1] not in subgraph.table_dict):
+             raise ValueError("Cannot perform link prediction on subgraphs "
+                              "without any historical target entities. Please "
+                              "increase the number of hops and try again.")

          return Context(
-             task_type=task_type,
-             entity_table_names=entity_table_names,
+             task_type=task.task_type,
+             entity_table_names=task.entity_table_names,
              subgraph=subgraph,
-             y_train=y_train,
-             y_test=y_test,
-             top_k=query.top_k,
-             step_size=step_size,
+             y_train=task._context_df[task.target_column.name],
+             y_test=task._pred_df[task.target_column.name]
+             if task.evaluate else None,
+             top_k=top_k,
+             step_size=None,
          )

      @staticmethod
      def _validate_metrics(
-         metrics: List[str],
+         metrics: list[str],
          task_type: TaskType,
      ) -> None:

@@ -1121,7 +1343,7 @@ class KumoRFM:
                               f"'https://github.com/kumo-ai/kumo-rfm'.")


- def format_value(value: Union[int, float]) -> str:
+ def format_value(value: int | float) -> str:
      if value == int(value):
          return f'{int(value):,}'
      if abs(value) >= 1000: