kumoai 2.12.0.dev202510231830__cp311-cp311-win_amd64.whl → 2.14.0.dev202512311733__cp311-cp311-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (64)
  1. kumoai/__init__.py +41 -35
  2. kumoai/_version.py +1 -1
  3. kumoai/client/client.py +15 -13
  4. kumoai/client/endpoints.py +1 -0
  5. kumoai/client/jobs.py +24 -0
  6. kumoai/client/pquery.py +6 -2
  7. kumoai/client/rfm.py +35 -7
  8. kumoai/connector/utils.py +23 -2
  9. kumoai/experimental/rfm/__init__.py +191 -48
  10. kumoai/experimental/rfm/authenticate.py +3 -4
  11. kumoai/experimental/rfm/backend/__init__.py +0 -0
  12. kumoai/experimental/rfm/backend/local/__init__.py +42 -0
  13. kumoai/experimental/rfm/{local_graph_store.py → backend/local/graph_store.py} +65 -127
  14. kumoai/experimental/rfm/backend/local/sampler.py +312 -0
  15. kumoai/experimental/rfm/backend/local/table.py +113 -0
  16. kumoai/experimental/rfm/backend/snow/__init__.py +37 -0
  17. kumoai/experimental/rfm/backend/snow/sampler.py +297 -0
  18. kumoai/experimental/rfm/backend/snow/table.py +242 -0
  19. kumoai/experimental/rfm/backend/sqlite/__init__.py +32 -0
  20. kumoai/experimental/rfm/backend/sqlite/sampler.py +398 -0
  21. kumoai/experimental/rfm/backend/sqlite/table.py +184 -0
  22. kumoai/experimental/rfm/base/__init__.py +30 -0
  23. kumoai/experimental/rfm/base/column.py +152 -0
  24. kumoai/experimental/rfm/base/expression.py +44 -0
  25. kumoai/experimental/rfm/base/sampler.py +761 -0
  26. kumoai/experimental/rfm/base/source.py +19 -0
  27. kumoai/experimental/rfm/base/sql_sampler.py +143 -0
  28. kumoai/experimental/rfm/base/table.py +735 -0
  29. kumoai/experimental/rfm/graph.py +1237 -0
  30. kumoai/experimental/rfm/infer/__init__.py +8 -0
  31. kumoai/experimental/rfm/infer/dtype.py +82 -0
  32. kumoai/experimental/rfm/infer/multicategorical.py +1 -1
  33. kumoai/experimental/rfm/infer/pkey.py +128 -0
  34. kumoai/experimental/rfm/infer/stype.py +35 -0
  35. kumoai/experimental/rfm/infer/time_col.py +61 -0
  36. kumoai/experimental/rfm/pquery/__init__.py +0 -4
  37. kumoai/experimental/rfm/pquery/executor.py +27 -27
  38. kumoai/experimental/rfm/pquery/pandas_executor.py +64 -40
  39. kumoai/experimental/rfm/relbench.py +76 -0
  40. kumoai/experimental/rfm/rfm.py +386 -276
  41. kumoai/experimental/rfm/sagemaker.py +138 -0
  42. kumoai/kumolib.cp311-win_amd64.pyd +0 -0
  43. kumoai/pquery/predictive_query.py +10 -6
  44. kumoai/spcs.py +1 -3
  45. kumoai/testing/decorators.py +1 -1
  46. kumoai/testing/snow.py +50 -0
  47. kumoai/trainer/distilled_trainer.py +175 -0
  48. kumoai/trainer/trainer.py +9 -10
  49. kumoai/utils/__init__.py +3 -2
  50. kumoai/utils/display.py +51 -0
  51. kumoai/utils/progress_logger.py +188 -16
  52. kumoai/utils/sql.py +3 -0
  53. {kumoai-2.12.0.dev202510231830.dist-info → kumoai-2.14.0.dev202512311733.dist-info}/METADATA +13 -2
  54. {kumoai-2.12.0.dev202510231830.dist-info → kumoai-2.14.0.dev202512311733.dist-info}/RECORD +57 -36
  55. kumoai/experimental/rfm/local_graph.py +0 -810
  56. kumoai/experimental/rfm/local_graph_sampler.py +0 -184
  57. kumoai/experimental/rfm/local_pquery_driver.py +0 -494
  58. kumoai/experimental/rfm/local_table.py +0 -545
  59. kumoai/experimental/rfm/pquery/backend.py +0 -136
  60. kumoai/experimental/rfm/pquery/pandas_backend.py +0 -478
  61. kumoai/experimental/rfm/utils.py +0 -344
  62. {kumoai-2.12.0.dev202510231830.dist-info → kumoai-2.14.0.dev202512311733.dist-info}/WHEEL +0 -0
  63. {kumoai-2.12.0.dev202510231830.dist-info → kumoai-2.14.0.dev202512311733.dist-info}/licenses/LICENSE +0 -0
  64. {kumoai-2.12.0.dev202510231830.dist-info → kumoai-2.14.0.dev202512311733.dist-info}/top_level.txt +0 -0
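The file list above shows the old single-module local implementation (local_graph.py, local_table.py, local_pquery_driver.py) being replaced by a backend package (backend/local, backend/sqlite, backend/snow) behind a shared base.Sampler interface, with LocalGraph renamed to Graph. A minimal sketch of the renamed entry points, read off the docstring changes in the diff below; the table data is hypothetical and dev-build signatures may differ:

    import pandas as pd
    from kumoai.experimental.rfm import Graph, KumoRFM

    # Hypothetical tables; Graph.from_data replaces LocalGraph.from_data.
    df_users = pd.DataFrame({'user_id': [0, 1]})
    df_orders = pd.DataFrame({'user_id': [1, 1], 'price': [9.99, 5.49],
                              'time': pd.to_datetime(['2025-01-01',
                                                      '2025-02-01'])})
    graph = Graph.from_data({'users': df_users, 'orders': df_orders})

    rfm = KumoRFM(graph)  # optimize=True only matters for SQL-backed graphs

    # rfm.predict() replaces the former rfm.query():
    result = rfm.predict(
        "PREDICT COUNT(orders.*, 0, 30, days)>0 FOR users.user_id=1")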
@@ -2,35 +2,38 @@ import json
 import time
 import warnings
 from collections import defaultdict
-from collections.abc import Generator
+from collections.abc import Generator, Iterator
 from contextlib import contextmanager
 from dataclasses import dataclass, replace
-from typing import Iterator, List, Literal, Optional, Union, overload
+from typing import Any, Literal, overload
 
 import numpy as np
 import pandas as pd
 from kumoapi.model_plan import RunMode
-from kumoapi.pquery import QueryType
+from kumoapi.pquery import QueryType, ValidatedPredictiveQuery
+from kumoapi.pquery.AST import (
+    Aggregation,
+    Column,
+    Condition,
+    Join,
+    LogicalOperation,
+)
 from kumoapi.rfm import Context
 from kumoapi.rfm import Explanation as ExplanationConfig
 from kumoapi.rfm import (
-    PQueryDefinition,
     RFMEvaluateRequest,
+    RFMParseQueryRequest,
     RFMPredictRequest,
-    RFMValidateQueryRequest,
 )
 from kumoapi.task import TaskType
+from kumoapi.typing import AggregationType, Stype
 
-from kumoai import global_state
+from kumoai.client.rfm import RFMAPI
 from kumoai.exceptions import HTTPException
-from kumoai.experimental.rfm import LocalGraph
-from kumoai.experimental.rfm.local_graph_sampler import LocalGraphSampler
-from kumoai.experimental.rfm.local_graph_store import LocalGraphStore
-from kumoai.experimental.rfm.local_pquery_driver import (
-    LocalPQueryDriver,
-    date_offset_to_seconds,
-)
-from kumoai.utils import InteractiveProgressLogger, ProgressLogger
+from kumoai.experimental.rfm import Graph
+from kumoai.experimental.rfm.base import DataBackend, Sampler
+from kumoai.mixin import CastMixin
+from kumoai.utils import ProgressLogger, display
 
 _RANDOM_SEED = 42
 
@@ -60,6 +63,17 @@ _SIZE_LIMIT_MSG = ("Context size exceeds the 30MB limit. {stats}\nPlease "
                    "beyond this for your use-case.")
 
 
+@dataclass(repr=False)
+class ExplainConfig(CastMixin):
+    """Configuration for explainability.
+
+    Args:
+        skip_summary: Whether to skip generating a human-readable summary of
+            the explanation.
+    """
+    skip_summary: bool = False
+
+
 @dataclass(repr=False)
 class Explanation:
     prediction: pd.DataFrame
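The hunk above adds ExplainConfig, a CastMixin dataclass, so later code in this diff can cast plain dicts via ExplainConfig._cast(explain). A hedged sketch of the accepted forms of the new predict(..., explain=...) argument; whether ExplainConfig is re-exported from kumoai.experimental.rfm is an assumption:

    # All three forms request an explanation; the last two skip the summary.
    rfm.predict(query, indices=[1], explain=True)
    rfm.predict(query, indices=[1], explain=ExplainConfig(skip_summary=True))
    rfm.predict(query, indices=[1], explain={'skip_summary': True})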
@@ -74,19 +88,27 @@ class Explanation:
     def __getitem__(self, index: Literal[1]) -> str:
         pass
 
-    def __getitem__(self, index: int) -> Union[pd.DataFrame, str]:
+    def __getitem__(self, index: int) -> pd.DataFrame | str:
         if index == 0:
             return self.prediction
         if index == 1:
             return self.summary
         raise IndexError("Index out of range")
 
-    def __iter__(self) -> Iterator[Union[pd.DataFrame, str]]:
+    def __iter__(self) -> Iterator[pd.DataFrame | str]:
         return iter((self.prediction, self.summary))
 
     def __repr__(self) -> str:
         return str((self.prediction, self.summary))
 
+    def print(self) -> None:
+        r"""Prints the explanation."""
+        display.dataframe(self.prediction)
+        display.message(self.summary)
+
+    def _ipython_display_(self) -> None:
+        self.print()
+
 
 class KumoRFM:
     r"""The Kumo Relational Foundation model (RFM) from the `KumoRFM: A
@@ -96,17 +118,17 @@ class KumoRFM:
     :class:`KumoRFM` is a foundation model to generate predictions for any
     relational dataset without training.
     The model is pre-trained and the class provides an interface to query the
-    model from a :class:`LocalGraph` object.
+    model from a :class:`Graph` object.
 
     .. code-block:: python
 
-        from kumoai.experimental.rfm import LocalGraph, KumoRFM
+        from kumoai.experimental.rfm import Graph, KumoRFM
 
         df_users = pd.DataFrame(...)
         df_items = pd.DataFrame(...)
         df_orders = pd.DataFrame(...)
 
-        graph = LocalGraph.from_data({
+        graph = Graph.from_data({
             'users': df_users,
             'items': df_items,
             'orders': df_orders,
@@ -114,47 +136,63 @@ class KumoRFM:
 
         rfm = KumoRFM(graph)
 
-        query = ("PREDICT COUNT(transactions.*, 0, 30, days)>0 "
-                 "FOR users.user_id=0")
-        result = rfm.query(query)
+        query = ("PREDICT COUNT(orders.*, 0, 30, days)>0 "
+                 "FOR users.user_id=1")
+        result = rfm.predict(query)
 
         print(result)  # user_id  COUNT(transactions.*, 0, 30, days) > 0
                        # 1        0.85
 
     Args:
         graph: The graph.
-        preprocess: Whether to pre-process the data in advance during graph
-            materialization.
-            This is a runtime trade-off between graph materialization and model
-            processing speed.
-            It can be benefical to preprocess your data once and then run many
-            queries on top to achieve maximum model speed.
-            However, if activiated, graph materialization can take potentially
-            much longer, especially on graphs with many large text columns.
-            Best to tune this option manually.
         verbose: Whether to print verbose output.
+        optimize: If set to ``True``, will optimize the underlying data backend
+            for optimal querying. For example, for transactional database
+            backends, will create any missing indices. Requires write-access to
+            the data backend.
     """
     def __init__(
         self,
-        graph: LocalGraph,
-        preprocess: bool = False,
-        verbose: Union[bool, ProgressLogger] = True,
+        graph: Graph,
+        verbose: bool | ProgressLogger = True,
+        optimize: bool = False,
    ) -> None:
         graph = graph.validate()
         self._graph_def = graph._to_api_graph_definition()
-        self._graph_store = LocalGraphStore(graph, preprocess, verbose)
-        self._graph_sampler = LocalGraphSampler(self._graph_store)
 
-        self._batch_size: Optional[int | Literal['max']] = None
+        if graph.backend == DataBackend.LOCAL:
+            from kumoai.experimental.rfm.backend.local import LocalSampler
+            self._sampler: Sampler = LocalSampler(graph, verbose)
+        elif graph.backend == DataBackend.SQLITE:
+            from kumoai.experimental.rfm.backend.sqlite import SQLiteSampler
+            self._sampler = SQLiteSampler(graph, verbose, optimize)
+        elif graph.backend == DataBackend.SNOWFLAKE:
+            from kumoai.experimental.rfm.backend.snow import SnowSampler
+            self._sampler = SnowSampler(graph, verbose)
+        else:
+            raise NotImplementedError
+
+        self._client: RFMAPI | None = None
+
+        self._batch_size: int | Literal['max'] | None = None
         self.num_retries: int = 0
 
+    @property
+    def _api_client(self) -> RFMAPI:
+        if self._client is not None:
+            return self._client
+
+        from kumoai.experimental.rfm import global_state
+        self._client = RFMAPI(global_state.client)
+        return self._client
+
     def __repr__(self) -> str:
         return f'{self.__class__.__name__}()'
 
     @contextmanager
     def batch_mode(
         self,
-        batch_size: Union[int, Literal['max']] = 'max',
+        batch_size: int | Literal['max'] = 'max',
         num_retries: int = 1,
     ) -> Generator[None, None, None]:
         """Context manager to predict in batches.
@@ -188,17 +226,17 @@ class KumoRFM:
     def predict(
         self,
         query: str,
-        indices: Union[List[str], List[float], List[int], None] = None,
+        indices: list[str] | list[float] | list[int] | None = None,
         *,
         explain: Literal[False] = False,
-        anchor_time: Union[pd.Timestamp, Literal['entity'], None] = None,
-        context_anchor_time: Union[pd.Timestamp, None] = None,
-        run_mode: Union[RunMode, str] = RunMode.FAST,
-        num_neighbors: Optional[List[int]] = None,
+        anchor_time: pd.Timestamp | Literal['entity'] | None = None,
+        context_anchor_time: pd.Timestamp | None = None,
+        run_mode: RunMode | str = RunMode.FAST,
+        num_neighbors: list[int] | None = None,
         num_hops: int = 2,
-        max_pq_iterations: int = 20,
-        random_seed: Optional[int] = _RANDOM_SEED,
-        verbose: Union[bool, ProgressLogger] = True,
+        max_pq_iterations: int = 10,
+        random_seed: int | None = _RANDOM_SEED,
+        verbose: bool | ProgressLogger = True,
         use_prediction_time: bool = False,
     ) -> pd.DataFrame:
         pass
@@ -207,17 +245,17 @@ class KumoRFM:
     def predict(
         self,
         query: str,
-        indices: Union[List[str], List[float], List[int], None] = None,
+        indices: list[str] | list[float] | list[int] | None = None,
         *,
-        explain: Literal[True],
-        anchor_time: Union[pd.Timestamp, Literal['entity'], None] = None,
-        context_anchor_time: Union[pd.Timestamp, None] = None,
-        run_mode: Union[RunMode, str] = RunMode.FAST,
-        num_neighbors: Optional[List[int]] = None,
+        explain: Literal[True] | ExplainConfig | dict[str, Any],
+        anchor_time: pd.Timestamp | Literal['entity'] | None = None,
+        context_anchor_time: pd.Timestamp | None = None,
+        run_mode: RunMode | str = RunMode.FAST,
+        num_neighbors: list[int] | None = None,
         num_hops: int = 2,
-        max_pq_iterations: int = 20,
-        random_seed: Optional[int] = _RANDOM_SEED,
-        verbose: Union[bool, ProgressLogger] = True,
+        max_pq_iterations: int = 10,
+        random_seed: int | None = _RANDOM_SEED,
+        verbose: bool | ProgressLogger = True,
         use_prediction_time: bool = False,
     ) -> Explanation:
         pass
@@ -225,19 +263,19 @@ class KumoRFM:
     def predict(
         self,
         query: str,
-        indices: Union[List[str], List[float], List[int], None] = None,
+        indices: list[str] | list[float] | list[int] | None = None,
         *,
-        explain: bool = False,
-        anchor_time: Union[pd.Timestamp, Literal['entity'], None] = None,
-        context_anchor_time: Union[pd.Timestamp, None] = None,
-        run_mode: Union[RunMode, str] = RunMode.FAST,
-        num_neighbors: Optional[List[int]] = None,
+        explain: bool | ExplainConfig | dict[str, Any] = False,
+        anchor_time: pd.Timestamp | Literal['entity'] | None = None,
+        context_anchor_time: pd.Timestamp | None = None,
+        run_mode: RunMode | str = RunMode.FAST,
+        num_neighbors: list[int] | None = None,
         num_hops: int = 2,
-        max_pq_iterations: int = 20,
-        random_seed: Optional[int] = _RANDOM_SEED,
-        verbose: Union[bool, ProgressLogger] = True,
+        max_pq_iterations: int = 10,
+        random_seed: int | None = _RANDOM_SEED,
+        verbose: bool | ProgressLogger = True,
         use_prediction_time: bool = False,
-    ) -> Union[pd.DataFrame, Explanation]:
+    ) -> pd.DataFrame | Explanation:
         """Returns predictions for a predictive query.
 
         Args:
@@ -247,9 +285,12 @@ class KumoRFM:
                 be generated for all indices, independent of whether they
                 fulfill entity filter constraints. To pre-filter entities, use
                 :meth:`~KumoRFM.is_valid_entity`.
-            explain: If set to ``True``, will additionally explain the
-                prediction. Explainability is currently only supported for
-                single entity predictions with ``run_mode="FAST"``.
+            explain: Configuration for explainability.
+                If set to ``True``, will additionally explain the prediction.
+                Passing in an :class:`ExplainConfig` instance provides control
+                over which parts of explanation are generated.
+                Explainability is currently only supported for single entity
+                predictions with ``run_mode="FAST"``.
             anchor_time: The anchor timestamp for the prediction. If set to
                 ``None``, will use the maximum timestamp in the data.
                 If set to ``"entity"``, will use the timestamp of the entity.
@@ -273,56 +314,62 @@ class KumoRFM:
 
         Returns:
             The predictions as a :class:`pandas.DataFrame`.
-            If ``explain=True``, additionally returns a textual summary that
-            explains the prediction.
+            If ``explain`` is provided, returns an :class:`Explanation` object
+            containing the prediction, summary, and details.
         """
+        explain_config: ExplainConfig | None = None
+        if explain is True:
+            explain_config = ExplainConfig()
+        elif explain is not False:
+            explain_config = ExplainConfig._cast(explain)
+
         query_def = self._parse_query(query)
+        query_str = query_def.to_string()
 
         if num_hops != 2 and num_neighbors is not None:
             warnings.warn(f"Received custom 'num_neighbors' option; ignoring "
                           f"custom 'num_hops={num_hops}' option")
 
-        if explain and run_mode in {RunMode.NORMAL, RunMode.BEST}:
+        if explain_config is not None and run_mode in {
+                RunMode.NORMAL, RunMode.BEST
+        }:
             warnings.warn(f"Explainability is currently only supported for "
                           f"run mode 'FAST' (got '{run_mode}'). Provided run "
                           f"mode has been reset. Please lower the run mode to "
                           f"suppress this warning.")
 
         if indices is None:
-            if query_def.entity.ids is None:
+            if query_def.rfm_entity_ids is None:
                 raise ValueError("Cannot find entities to predict for. Please "
                                  "pass them via `predict(query, indices=...)`")
-            indices = query_def.entity.ids.value
+            indices = query_def.get_rfm_entity_id_list()
         else:
-            query_def = replace(
-                query_def,
-                entity=replace(query_def.entity, ids=None),
-            )
+            query_def = replace(query_def, rfm_entity_ids=None)
 
         if len(indices) == 0:
             raise ValueError("At least one entity is required")
 
-        if explain and len(indices) > 1:
+        if explain_config is not None and len(indices) > 1:
             raise ValueError(
                 f"Cannot explain predictions for more than a single entity "
                 f"(got {len(indices)})")
 
         query_repr = query_def.to_string(rich=True, exclude_predict=True)
-        if explain:
+        if explain_config is not None:
             msg = f'[bold]EXPLAIN[/bold] {query_repr}'
         else:
             msg = f'[bold]PREDICT[/bold] {query_repr}'
 
         if not isinstance(verbose, ProgressLogger):
-            verbose = InteractiveProgressLogger(msg, verbose=verbose)
+            verbose = ProgressLogger.default(msg=msg, verbose=verbose)
 
         with verbose as logger:
 
-            batch_size: Optional[int] = None
+            batch_size: int | None = None
             if self._batch_size == 'max':
-                task_type = query_def.get_task_type(
-                    stypes=self._graph_store.stype_dict,
-                    edge_types=self._graph_store.edge_types,
+                task_type = self._get_task_type(
+                    query=query_def,
+                    edge_types=self._sampler.edge_types,
                 )
                 batch_size = _MAX_PRED_SIZE[task_type]
             else:
@@ -338,9 +385,9 @@ class KumoRFM:
                 logger.log(f"Splitting {len(indices):,} entities into "
                            f"{len(batches):,} batches of size {batch_size:,}")
 
-            predictions: List[pd.DataFrame] = []
-            summary: Optional[str] = None
-            details: Optional[Explanation] = None
+            predictions: list[pd.DataFrame] = []
+            summary: str | None = None
+            details: Explanation | None = None
             for i, batch in enumerate(batches):
                 # TODO Re-use the context for subsequent predictions.
                 context = self._get_context(
@@ -359,6 +406,7 @@ class KumoRFM:
                 request = RFMPredictRequest(
                     context=context,
                     run_mode=RunMode(run_mode),
+                    query=query_str,
                     use_prediction_time=use_prediction_time,
                 )
                 with warnings.catch_warnings():
@@ -373,8 +421,7 @@ class KumoRFM:
                     stats = Context.get_memory_stats(request_msg.context)
                     raise ValueError(_SIZE_LIMIT_MSG.format(stats=stats))
 
-                if (isinstance(verbose, InteractiveProgressLogger) and i == 0
-                        and len(batches) > 1):
+                if i == 0 and len(batches) > 1:
                     verbose.init_progress(
                         total=len(batches),
                         description='Predicting',
@@ -382,20 +429,23 @@ class KumoRFM:
 
                 for attempt in range(self.num_retries + 1):
                     try:
-                        if explain:
-                            resp = global_state.client.rfm_api.explain(_bytes)
+                        if explain_config is not None:
+                            resp = self._api_client.explain(
+                                request=_bytes,
+                                skip_summary=explain_config.skip_summary,
+                            )
                             summary = resp.summary
                             details = resp.details
                         else:
-                            resp = global_state.client.rfm_api.predict(_bytes)
+                            resp = self._api_client.predict(_bytes)
                         df = pd.DataFrame(**resp.prediction)
 
                         # Cast 'ENTITY' to correct data type:
                         if 'ENTITY' in df:
-                            entity = query_def.entity.pkey.table_name
-                            pkey_map = self._graph_store.pkey_map_dict[entity]
-                            df['ENTITY'] = df['ENTITY'].astype(
-                                type(pkey_map.index[0]))
+                            table_dict = context.subgraph.table_dict
+                            table = table_dict[query_def.entity_table]
+                            ser = table.df[table.primary_key]
+                            df['ENTITY'] = df['ENTITY'].astype(ser.dtype)
 
                         # Cast 'ANCHOR_TIMESTAMP' to correct data type:
                         if 'ANCHOR_TIMESTAMP' in df:
@@ -410,8 +460,7 @@ class KumoRFM:
 
                         predictions.append(df)
 
-                        if (isinstance(verbose, InteractiveProgressLogger)
-                                and len(batches) > 1):
+                        if len(batches) > 1:
                             verbose.step()
 
                         break
@@ -434,7 +483,7 @@ class KumoRFM:
         else:
             prediction = pd.concat(predictions, ignore_index=True)
 
-        if explain:
+        if explain_config is not None:
             assert len(predictions) == 1
             assert summary is not None
             assert details is not None
@@ -449,9 +498,9 @@ class KumoRFM:
     def is_valid_entity(
         self,
         query: str,
-        indices: Union[List[str], List[float], List[int], None] = None,
+        indices: list[str] | list[float] | list[int] | None = None,
         *,
-        anchor_time: Union[pd.Timestamp, Literal['entity'], None] = None,
+        anchor_time: pd.Timestamp | Literal['entity'] | None = None,
     ) -> np.ndarray:
         r"""Returns a mask that denotes which entities are valid for the
         given predictive query, *i.e.*, which entities fulfill (temporal)
@@ -468,48 +517,42 @@ class KumoRFM:
         query_def = self._parse_query(query)
 
         if indices is None:
-            if query_def.entity.ids is None:
+            if query_def.rfm_entity_ids is None:
                 raise ValueError("Cannot find entities to predict for. Please "
                                  "pass them via "
                                  "`is_valid_entity(query, indices=...)`")
-            indices = query_def.entity.ids.value
+            indices = query_def.get_rfm_entity_id_list()
 
         if len(indices) == 0:
             raise ValueError("At least one entity is required")
 
         if anchor_time is None:
-            anchor_time = self._graph_store.max_time
+            anchor_time = self._get_default_anchor_time(query_def)
 
         if isinstance(anchor_time, pd.Timestamp):
             self._validate_time(query_def, anchor_time, None, False)
         else:
             assert anchor_time == 'entity'
-            if (query_def.entity.pkey.table_name
-                    not in self._graph_store.time_dict):
+            if query_def.entity_table not in self._sampler.time_column_dict:
                 raise ValueError(f"Anchor time 'entity' requires the entity "
-                                 f"table '{query_def.entity.pkey.table_name}' "
-                                 f"to have a time column")
+                                 f"table '{query_def.entity_table}' "
+                                 f"to have a time column.")
 
-        node = self._graph_store.get_node_id(
-            table_name=query_def.entity.pkey.table_name,
-            pkey=pd.Series(indices),
-        )
-        query_driver = LocalPQueryDriver(self._graph_store, query_def)
-        return query_driver.is_valid(node, anchor_time)
+        raise NotImplementedError
 
     def evaluate(
         self,
         query: str,
         *,
-        metrics: Optional[List[str]] = None,
-        anchor_time: Union[pd.Timestamp, Literal['entity'], None] = None,
-        context_anchor_time: Union[pd.Timestamp, None] = None,
-        run_mode: Union[RunMode, str] = RunMode.FAST,
-        num_neighbors: Optional[List[int]] = None,
+        metrics: list[str] | None = None,
+        anchor_time: pd.Timestamp | Literal['entity'] | None = None,
+        context_anchor_time: pd.Timestamp | None = None,
+        run_mode: RunMode | str = RunMode.FAST,
+        num_neighbors: list[int] | None = None,
         num_hops: int = 2,
-        max_pq_iterations: int = 20,
-        random_seed: Optional[int] = _RANDOM_SEED,
-        verbose: Union[bool, ProgressLogger] = True,
+        max_pq_iterations: int = 10,
+        random_seed: int | None = _RANDOM_SEED,
+        verbose: bool | ProgressLogger = True,
         use_prediction_time: bool = False,
     ) -> pd.DataFrame:
         """Evaluates a predictive query.
@@ -547,17 +590,17 @@ class KumoRFM:
             warnings.warn(f"Received custom 'num_neighbors' option; ignoring "
                           f"custom 'num_hops={num_hops}' option")
 
-        if query_def.entity.ids is not None:
+        if query_def.rfm_entity_ids is not None:
             query_def = replace(
                 query_def,
-                entity=replace(query_def.entity, ids=None),
+                rfm_entity_ids=None,
             )
 
         query_repr = query_def.to_string(rich=True, exclude_predict=True)
         msg = f'[bold]EVALUATE[/bold] {query_repr}'
 
         if not isinstance(verbose, ProgressLogger):
-            verbose = InteractiveProgressLogger(msg, verbose=verbose)
+            verbose = ProgressLogger.default(msg=msg, verbose=verbose)
 
         with verbose as logger:
             context = self._get_context(
@@ -591,10 +634,10 @@ class KumoRFM:
 
             if len(request_bytes) > _MAX_SIZE:
                 stats_msg = Context.get_memory_stats(request_msg.context)
-                raise ValueError(_SIZE_LIMIT_MSG.format(stats_msg=stats_msg))
+                raise ValueError(_SIZE_LIMIT_MSG.format(stats=stats_msg))
 
             try:
-                resp = global_state.client.rfm_api.evaluate(request_bytes)
+                resp = self._api_client.evaluate(request_bytes)
             except HTTPException as e:
                 try:
                     msg = json.loads(e.detail)['detail']
@@ -616,9 +659,9 @@ class KumoRFM:
         query: str,
         size: int,
         *,
-        anchor_time: Union[pd.Timestamp, Literal['entity'], None] = None,
-        random_seed: Optional[int] = _RANDOM_SEED,
-        max_iterations: int = 20,
+        anchor_time: pd.Timestamp | Literal['entity'] | None = None,
+        random_seed: int | None = _RANDOM_SEED,
+        max_iterations: int = 10,
     ) -> pd.DataFrame:
         """Returns the labels of a predictive query for a specified anchor
         time.
@@ -638,45 +681,43 @@ class KumoRFM:
         query_def = self._parse_query(query)
 
         if anchor_time is None:
-            anchor_time = self._graph_store.max_time
-            anchor_time = anchor_time - (query_def.target.end_offset *
-                                         query_def.num_forecasts)
+            anchor_time = self._get_default_anchor_time(query_def)
+            if query_def.target_ast.date_offset_range is not None:
+                offset = query_def.target_ast.date_offset_range.end_date_offset
+                offset *= query_def.num_forecasts
+                anchor_time -= offset
 
         assert anchor_time is not None
         if isinstance(anchor_time, pd.Timestamp):
             self._validate_time(query_def, anchor_time, None, evaluate=True)
         else:
             assert anchor_time == 'entity'
-            if (query_def.entity.pkey.table_name
-                    not in self._graph_store.time_dict):
+            if query_def.entity_table not in self._sampler.time_column_dict:
                 raise ValueError(f"Anchor time 'entity' requires the entity "
-                                 f"table '{query_def.entity.pkey.table_name}' "
+                                 f"table '{query_def.entity_table}' "
                                  f"to have a time column")
 
-        query_driver = LocalPQueryDriver(self._graph_store, query_def,
-                                         random_seed)
-
-        node, time, y = query_driver.collect_test(
-            size=size,
-            anchor_time=anchor_time,
-            batch_size=min(10_000, size),
-            max_iterations=max_iterations,
-            guarantee_train_examples=False,
+        train, test = self._sampler.sample_target(
+            query=query_def,
+            num_train_examples=0,
+            train_anchor_time=anchor_time,
+            num_train_trials=0,
+            num_test_examples=size,
+            test_anchor_time=anchor_time,
+            num_test_trials=max_iterations * size,
+            random_seed=random_seed,
         )
 
-        entity = self._graph_store.pkey_map_dict[
-            query_def.entity.pkey.table_name].index[node]
-
         return pd.DataFrame({
-            'ENTITY': entity,
-            'ANCHOR_TIMESTAMP': time,
-            'TARGET': y,
+            'ENTITY': test.entity_pkey,
+            'ANCHOR_TIMESTAMP': test.anchor_time,
+            'TARGET': test.target,
         })
 
     # Helpers #################################################################
 
-    def _parse_query(self, query: str) -> PQueryDefinition:
-        if isinstance(query, PQueryDefinition):
+    def _parse_query(self, query: str) -> ValidatedPredictiveQuery:
+        if isinstance(query, ValidatedPredictiveQuery):
            return query
 
        if isinstance(query, str) and query.strip()[:9].lower() == 'evaluate ':
@@ -686,13 +727,12 @@ class KumoRFM:
                              "predictions or evaluations.")
 
         try:
-            request = RFMValidateQueryRequest(
+            request = RFMParseQueryRequest(
                 query=query,
                 graph_definition=self._graph_def,
             )
 
-            resp = global_state.client.rfm_api.validate_query(request)
-            # TODO Expose validation warnings.
+            resp = self._api_client.parse_query(request)
 
             if len(resp.validation_response.warnings) > 0:
                 msg = '\n'.join([
@@ -702,7 +742,7 @@ class KumoRFM:
                 warnings.warn(f"Encountered the following warnings during "
                               f"parsing:\n{msg}")
 
-            return resp.query_definition
+            return resp.query
         except HTTPException as e:
             try:
                 msg = json.loads(e.detail)['detail']
@@ -711,30 +751,91 @@ class KumoRFM:
             raise ValueError(f"Failed to parse query '{query}'. "
                              f"{msg}") from None
 
+    @staticmethod
+    def _get_task_type(
+        query: ValidatedPredictiveQuery,
+        edge_types: list[tuple[str, str, str]],
+    ) -> TaskType:
+        if isinstance(query.target_ast, (Condition, LogicalOperation)):
+            return TaskType.BINARY_CLASSIFICATION
+
+        target = query.target_ast
+        if isinstance(target, Join):
+            target = target.rhs_target
+        if isinstance(target, Aggregation):
+            if target.aggr == AggregationType.LIST_DISTINCT:
+                table_name, col_name = target._get_target_column_name().split(
+                    '.')
+                target_edge_types = [
+                    edge_type for edge_type in edge_types
+                    if edge_type[0] == table_name and edge_type[1] == col_name
+                ]
+                if len(target_edge_types) != 1:
+                    raise NotImplementedError(
+                        f"Multilabel-classification queries based on "
+                        f"'LIST_DISTINCT' are not supported yet. If you "
+                        f"planned to write a link prediction query instead, "
+                        f"make sure to register '{col_name}' as a "
+                        f"foreign key.")
+                return TaskType.TEMPORAL_LINK_PREDICTION
+
+            return TaskType.REGRESSION
+
+        assert isinstance(target, Column)
+
+        if target.stype in {Stype.ID, Stype.categorical}:
+            return TaskType.MULTICLASS_CLASSIFICATION
+
+        if target.stype in {Stype.numerical}:
+            return TaskType.REGRESSION
+
+        raise NotImplementedError("Task type not yet supported")
+
+    def _get_default_anchor_time(
+        self,
+        query: ValidatedPredictiveQuery,
+    ) -> pd.Timestamp:
+        if query.query_type == QueryType.TEMPORAL:
+            aggr_table_names = [
+                aggr._get_target_column_name().split('.')[0]
+                for aggr in query.get_all_target_aggregations()
+            ]
+            return self._sampler.get_max_time(aggr_table_names)
+
+        assert query.query_type == QueryType.STATIC
+        return self._sampler.get_max_time()
+
     def _validate_time(
         self,
-        query: PQueryDefinition,
+        query: ValidatedPredictiveQuery,
         anchor_time: pd.Timestamp,
-        context_anchor_time: Union[pd.Timestamp, None],
+        context_anchor_time: pd.Timestamp | None,
         evaluate: bool,
     ) -> None:
 
-        if self._graph_store.min_time == pd.Timestamp.max:
+        if len(self._sampler.time_column_dict) == 0:
             return  # Graph without timestamps
 
-        if anchor_time < self._graph_store.min_time:
+        min_time = self._sampler.get_min_time()
+        max_time = self._sampler.get_max_time()
+
+        if anchor_time < min_time:
             raise ValueError(f"Anchor timestamp '{anchor_time}' is before "
-                             f"the earliest timestamp "
-                             f"'{self._graph_store.min_time}' in the data.")
+                             f"the earliest timestamp '{min_time}' in the "
+                             f"data.")
 
-        if (context_anchor_time is not None
-                and context_anchor_time < self._graph_store.min_time):
+        if context_anchor_time is not None and context_anchor_time < min_time:
             raise ValueError(f"Context anchor timestamp is too early or "
                              f"aggregation time range is too large. To make "
                              f"this prediction, we would need data back to "
                              f"'{context_anchor_time}', however, your data "
-                             f"only contains data back to "
-                             f"'{self._graph_store.min_time}'.")
+                             f"only contains data back to '{min_time}'.")
+
+        if query.target_ast.date_offset_range is not None:
+            end_offset = query.target_ast.date_offset_range.end_date_offset
+        else:
+            end_offset = pd.DateOffset(0)
+        end_offset = end_offset * query.num_forecasts
 
         if (context_anchor_time is not None
                 and context_anchor_time > anchor_time):
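The new _get_task_type helper above replaces the old PQueryDefinition.get_task_type and dispatches on the parsed target AST. Read off the code, the mapping is roughly:

    # Condition / LogicalOperation target      -> BINARY_CLASSIFICATION
    #     e.g. PREDICT COUNT(orders.*, 0, 30, days) > 0
    # Aggregation target (not LIST_DISTINCT)   -> REGRESSION
    #     e.g. PREDICT SUM(orders.price, 0, 30, days)
    # LIST_DISTINCT over a single foreign key  -> TEMPORAL_LINK_PREDICTION
    # Column target with ID/categorical stype  -> MULTICLASS_CLASSIFICATION
    # Column target with numerical stype       -> REGRESSION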
@@ -744,51 +845,46 @@ class KumoRFM:
                           f"(got '{anchor_time}'). Please make sure this is "
                           f"intended.")
         elif (query.query_type == QueryType.TEMPORAL
-              and context_anchor_time is not None and context_anchor_time +
-              query.target.end_offset * query.num_forecasts > anchor_time):
+              and context_anchor_time is not None
+              and context_anchor_time + end_offset > anchor_time):
             warnings.warn(f"Aggregation for context examples at timestamp "
                           f"'{context_anchor_time}' will leak information "
                           f"from the prediction anchor timestamp "
                           f"'{anchor_time}'. Please make sure this is "
                           f"intended.")
 
-        elif (context_anchor_time is not None and context_anchor_time -
-              query.target.end_offset * query.num_forecasts
-              < self._graph_store.min_time):
-            _time = context_anchor_time - (query.target.end_offset *
-                                           query.num_forecasts)
+        elif (context_anchor_time is not None
+              and context_anchor_time - end_offset < min_time):
+            _time = context_anchor_time - end_offset
             warnings.warn(f"Context anchor timestamp is too early or "
                           f"aggregation time range is too large. To form "
                           f"proper input data, we would need data back to "
                           f"'{_time}', however, your data only contains "
-                          f"data back to '{self._graph_store.min_time}'.")
+                          f"data back to '{min_time}'.")
 
-        if (not evaluate and anchor_time
-                > self._graph_store.max_time + pd.DateOffset(days=1)):
+        if not evaluate and anchor_time > max_time + pd.DateOffset(days=1):
             warnings.warn(f"Anchor timestamp '{anchor_time}' is after the "
-                          f"latest timestamp '{self._graph_store.max_time}' "
-                          f"in the data. Please make sure this is intended.")
+                          f"latest timestamp '{max_time}' in the data. Please "
+                          f"make sure this is intended.")
 
-        max_eval_time = (self._graph_store.max_time -
-                         query.target.end_offset * query.num_forecasts)
-        if evaluate and anchor_time > max_eval_time:
+        if evaluate and anchor_time > max_time - end_offset:
             raise ValueError(
                 f"Anchor timestamp for evaluation is after the latest "
-                f"supported timestamp '{max_eval_time}'.")
+                f"supported timestamp '{max_time - end_offset}'.")
 
     def _get_context(
         self,
-        query: PQueryDefinition,
-        indices: Union[List[str], List[float], List[int], None],
-        anchor_time: Union[pd.Timestamp, Literal['entity'], None],
-        context_anchor_time: Union[pd.Timestamp, None],
+        query: ValidatedPredictiveQuery,
+        indices: list[str] | list[float] | list[int] | None,
+        anchor_time: pd.Timestamp | Literal['entity'] | None,
+        context_anchor_time: pd.Timestamp | None,
         run_mode: RunMode,
-        num_neighbors: Optional[List[int]],
+        num_neighbors: list[int] | None,
         num_hops: int,
         max_pq_iterations: int,
         evaluate: bool,
-        random_seed: Optional[int] = _RANDOM_SEED,
-        logger: Optional[ProgressLogger] = None,
+        random_seed: int | None = _RANDOM_SEED,
+        logger: ProgressLogger | None = None,
     ) -> Context:
 
         if num_neighbors is not None:
@@ -805,10 +901,9 @@ class KumoRFM:
                           f"'https://github.com/kumo-ai/kumo-rfm' if you "
                           f"must go beyond this for your use-case.")
 
-        query_driver = LocalPQueryDriver(self._graph_store, query, random_seed)
-        task_type = query.get_task_type(
-            stypes=self._graph_store.stype_dict,
-            edge_types=self._graph_store.edge_types,
+        task_type = self._get_task_type(
+            query=query,
+            edge_types=self._sampler.edge_types,
         )
 
         if logger is not None:
@@ -839,11 +934,18 @@ class KumoRFM:
         else:
             num_neighbors = [64, 64, 8, 8, 4, 4][:num_hops]
 
+        if query.target_ast.date_offset_range is None:
+            step_offset = pd.DateOffset(0)
+        else:
+            step_offset = query.target_ast.date_offset_range.end_date_offset
+        end_offset = step_offset * query.num_forecasts
+
         if anchor_time is None:
-            anchor_time = self._graph_store.max_time
+            anchor_time = self._get_default_anchor_time(query)
+
         if evaluate:
-            anchor_time = anchor_time - (query.target.end_offset *
-                                         query.num_forecasts)
+            anchor_time = anchor_time - end_offset
+
         if logger is not None:
             assert isinstance(anchor_time, pd.Timestamp)
             if anchor_time == pd.Timestamp.min:
@@ -857,58 +959,71 @@ class KumoRFM:
 
         assert anchor_time is not None
         if isinstance(anchor_time, pd.Timestamp):
+            if context_anchor_time == 'entity':
+                raise ValueError("Anchor time 'entity' needs to be shared "
+                                 "for context and prediction examples")
             if context_anchor_time is None:
-                context_anchor_time = anchor_time - (query.target.end_offset *
-                                                     query.num_forecasts)
+                context_anchor_time = anchor_time - end_offset
             self._validate_time(query, anchor_time, context_anchor_time,
                                 evaluate)
         else:
             assert anchor_time == 'entity'
-            if query.entity.pkey.table_name not in self._graph_store.time_dict:
+            if query.query_type != QueryType.STATIC:
+                raise ValueError("Anchor time 'entity' is only valid for "
+                                 "static predictive queries")
+            if query.entity_table not in self._sampler.time_column_dict:
                 raise ValueError(f"Anchor time 'entity' requires the entity "
-                                 f"table '{query.entity.pkey.table_name}' to "
+                                 f"table '{query.entity_table}' to "
                                  f"have a time column")
-            if context_anchor_time is not None:
-                warnings.warn("Ignoring option 'context_anchor_time' for "
-                              "`anchor_time='entity'`")
-            context_anchor_time = None
+            if isinstance(context_anchor_time, pd.Timestamp):
+                raise ValueError("Anchor time 'entity' needs to be shared "
+                                 "for context and prediction examples")
+            context_anchor_time = 'entity'
 
-        y_test: Optional[pd.Series] = None
+        num_train_examples = _MAX_CONTEXT_SIZE[run_mode]
         if evaluate:
-            max_test_size = _MAX_TEST_SIZE[run_mode]
+            num_test_examples = _MAX_TEST_SIZE[run_mode]
             if task_type.is_link_pred:
-                max_test_size = max_test_size // 5
+                num_test_examples = num_test_examples // 5
+        else:
+            num_test_examples = 0
+
+        train, test = self._sampler.sample_target(
+            query=query,
+            num_train_examples=num_train_examples,
+            train_anchor_time=context_anchor_time,
+            num_train_trials=max_pq_iterations * num_train_examples,
+            num_test_examples=num_test_examples,
+            test_anchor_time=anchor_time,
+            num_test_trials=max_pq_iterations * num_test_examples,
+            random_seed=random_seed,
+        )
+        train_pkey, train_time, y_train = train
+        test_pkey, test_time, y_test = test
 
-            test_node, test_time, y_test = query_driver.collect_test(
-                size=max_test_size,
-                anchor_time=anchor_time,
-                max_iterations=max_pq_iterations,
-                guarantee_train_examples=True,
-            )
-            if logger is not None:
-                if task_type == TaskType.BINARY_CLASSIFICATION:
-                    pos = 100 * int((y_test > 0).sum()) / len(y_test)
-                    msg = (f"Collected {len(y_test):,} test examples with "
-                           f"{pos:.2f}% positive cases")
-                elif task_type == TaskType.MULTICLASS_CLASSIFICATION:
-                    msg = (f"Collected {len(y_test):,} test examples "
-                           f"holding {y_test.nunique()} classes")
-                elif task_type == TaskType.REGRESSION:
-                    _min, _max = float(y_test.min()), float(y_test.max())
-                    msg = (f"Collected {len(y_test):,} test examples with "
-                           f"targets between {format_value(_min)} and "
-                           f"{format_value(_max)}")
-                elif task_type == TaskType.TEMPORAL_LINK_PREDICTION:
-                    num_rhs = y_test.explode().nunique()
-                    msg = (f"Collected {len(y_test):,} test examples with "
-                           f"{num_rhs:,} unique items")
-                else:
-                    raise NotImplementedError
-                logger.log(msg)
+        if evaluate and logger is not None:
+            if task_type == TaskType.BINARY_CLASSIFICATION:
+                pos = 100 * int((y_test > 0).sum()) / len(y_test)
+                msg = (f"Collected {len(y_test):,} test examples with "
+                       f"{pos:.2f}% positive cases")
+            elif task_type == TaskType.MULTICLASS_CLASSIFICATION:
+                msg = (f"Collected {len(y_test):,} test examples holding "
+                       f"{y_test.nunique()} classes")
+            elif task_type == TaskType.REGRESSION:
+                _min, _max = float(y_test.min()), float(y_test.max())
+                msg = (f"Collected {len(y_test):,} test examples with targets "
+                       f"between {format_value(_min)} and "
+                       f"{format_value(_max)}")
+            elif task_type == TaskType.TEMPORAL_LINK_PREDICTION:
+                num_rhs = y_test.explode().nunique()
+                msg = (f"Collected {len(y_test):,} test examples with "
+                       f"{num_rhs:,} unique items")
+            else:
+                raise NotImplementedError
+            logger.log(msg)
 
-        else:
+        if not evaluate:
             assert indices is not None
-
             if len(indices) > _MAX_PRED_SIZE[task_type]:
                 raise ValueError(f"Cannot predict for more than "
                                  f"{_MAX_PRED_SIZE[task_type]:,} entities at "
@@ -916,27 +1031,12 @@ class KumoRFM:
                                  f"`KumoRFM.batch_mode` to process entities "
                                  f"in batches")
 
-            test_node = self._graph_store.get_node_id(
-                table_name=query.entity.pkey.table_name,
-                pkey=pd.Series(indices),
-            )
-
+            test_pkey = pd.Series(indices, dtype=train_pkey.dtype)
             if isinstance(anchor_time, pd.Timestamp):
-                test_time = pd.Series(anchor_time).repeat(
-                    len(test_node)).reset_index(drop=True)
+                test_time = pd.Series([anchor_time]).repeat(
+                    len(indices)).reset_index(drop=True)
             else:
-                time = self._graph_store.time_dict[
-                    query.entity.pkey.table_name]
-                time = time[test_node] * 1000**3
-                test_time = pd.Series(time, dtype='datetime64[ns]')
-
-            train_node, train_time, y_train = query_driver.collect_train(
-                size=_MAX_CONTEXT_SIZE[run_mode],
-                anchor_time=context_anchor_time or 'entity',
-                exclude_node=test_node if (query.query_type == QueryType.STATIC
-                                           or anchor_time == 'entity') else None,
-                max_iterations=max_pq_iterations,
-            )
+                train_time = test_time = 'entity'
 
         if logger is not None:
             if task_type == TaskType.BINARY_CLASSIFICATION:
@@ -959,27 +1059,41 @@ class KumoRFM:
                 raise NotImplementedError
             logger.log(msg)
 
-        entity_table_names = query.get_entity_table_names(
-            self._graph_store.edge_types)
+        entity_table_names: tuple[str, ...]
+        if task_type.is_link_pred:
+            final_aggr = query.get_final_target_aggregation()
+            assert final_aggr is not None
+            edge_fkey = final_aggr._get_target_column_name()
+            for edge_type in self._sampler.edge_types:
+                if edge_fkey == f'{edge_type[0]}.{edge_type[1]}':
+                    entity_table_names = (
+                        query.entity_table,
+                        edge_type[2],
+                    )
+        else:
+            entity_table_names = (query.entity_table, )
 
         # Exclude the entity anchor time from the feature set to prevent
         # running out-of-distribution between in-context and test examples:
-        exclude_cols_dict = query.exclude_cols_dict
-        if anchor_time == 'entity':
+        exclude_cols_dict = query.get_exclude_cols_dict()
+        if entity_table_names[0] in self._sampler.time_column_dict:
             if entity_table_names[0] not in exclude_cols_dict:
                 exclude_cols_dict[entity_table_names[0]] = []
-            time_column_dict = self._graph_store.time_column_dict
-            time_column = time_column_dict[entity_table_names[0]]
+            time_column = self._sampler.time_column_dict[entity_table_names[0]]
             exclude_cols_dict[entity_table_names[0]].append(time_column)
 
-        subgraph = self._graph_sampler(
+        subgraph = self._sampler.sample_subgraph(
             entity_table_names=entity_table_names,
-            node=np.concatenate([train_node, test_node]),
-            time=np.concatenate([
-                train_time.astype('datetime64[ns]').astype(int).to_numpy(),
-                test_time.astype('datetime64[ns]').astype(int).to_numpy(),
-            ]),
-            run_mode=run_mode,
+            entity_pkey=pd.concat(
+                [train_pkey, test_pkey],
+                axis=0,
+                ignore_index=True,
+            ),
+            anchor_time=pd.concat(
+                [train_time, test_time],
+                axis=0,
+                ignore_index=True,
+            ) if isinstance(train_time, pd.Series) else 'entity',
             num_neighbors=num_neighbors,
             exclude_cols_dict=exclude_cols_dict,
         )
@@ -991,23 +1105,19 @@ class KumoRFM:
                              f"'https://github.com/kumo-ai/kumo-rfm' if you "
                              f"must go beyond this for your use-case.")
 
-        step_size: Optional[int] = None
-        if query.query_type == QueryType.TEMPORAL:
-            step_size = date_offset_to_seconds(query.target.end_offset)
-
         return Context(
             task_type=task_type,
             entity_table_names=entity_table_names,
             subgraph=subgraph,
             y_train=y_train,
-            y_test=y_test,
+            y_test=y_test if evaluate else None,
             top_k=query.top_k,
-            step_size=step_size,
+            step_size=None,
         )
 
     @staticmethod
     def _validate_metrics(
-        metrics: List[str],
+        metrics: list[str],
         task_type: TaskType,
     ) -> None:
 
@@ -1064,7 +1174,7 @@ class KumoRFM:
                 f"'https://github.com/kumo-ai/kumo-rfm'.")
 
 
-def format_value(value: Union[int, float]) -> str:
+def format_value(value: int | float) -> str:
     if value == int(value):
         return f'{int(value):,}'
     if abs(value) >= 1000:
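Downstream handling of predict() output stays DataFrame-based: the hunks above cast the ENTITY column to the entity table's primary-key dtype and ANCHOR_TIMESTAMP to datetime64[ns]. A post-processing sketch; treating the last column as the score is an assumption, since the score column name depends on the query:

    df = rfm.predict(query, indices=[1, 2, 3])
    df = df.set_index('ENTITY')
    top = df.sort_values(df.columns[-1], ascending=False).head(10)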