kumoai 2.12.0.dev202511061731__cp311-cp311-win_amd64.whl → 2.14.0.dev202512311733__cp311-cp311-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (59)
  1. kumoai/__init__.py +41 -35
  2. kumoai/_version.py +1 -1
  3. kumoai/client/client.py +15 -13
  4. kumoai/client/jobs.py +24 -0
  5. kumoai/client/pquery.py +6 -2
  6. kumoai/client/rfm.py +15 -7
  7. kumoai/connector/utils.py +23 -2
  8. kumoai/experimental/rfm/__init__.py +191 -48
  9. kumoai/experimental/rfm/authenticate.py +3 -4
  10. kumoai/experimental/rfm/backend/__init__.py +0 -0
  11. kumoai/experimental/rfm/backend/local/__init__.py +42 -0
  12. kumoai/experimental/rfm/{local_graph_store.py → backend/local/graph_store.py} +65 -127
  13. kumoai/experimental/rfm/backend/local/sampler.py +312 -0
  14. kumoai/experimental/rfm/backend/local/table.py +113 -0
  15. kumoai/experimental/rfm/backend/snow/__init__.py +37 -0
  16. kumoai/experimental/rfm/backend/snow/sampler.py +297 -0
  17. kumoai/experimental/rfm/backend/snow/table.py +242 -0
  18. kumoai/experimental/rfm/backend/sqlite/__init__.py +32 -0
  19. kumoai/experimental/rfm/backend/sqlite/sampler.py +398 -0
  20. kumoai/experimental/rfm/backend/sqlite/table.py +184 -0
  21. kumoai/experimental/rfm/base/__init__.py +30 -0
  22. kumoai/experimental/rfm/base/column.py +152 -0
  23. kumoai/experimental/rfm/base/expression.py +44 -0
  24. kumoai/experimental/rfm/base/sampler.py +761 -0
  25. kumoai/experimental/rfm/base/source.py +19 -0
  26. kumoai/experimental/rfm/base/sql_sampler.py +143 -0
  27. kumoai/experimental/rfm/base/table.py +735 -0
  28. kumoai/experimental/rfm/graph.py +1237 -0
  29. kumoai/experimental/rfm/infer/__init__.py +8 -0
  30. kumoai/experimental/rfm/infer/dtype.py +82 -0
  31. kumoai/experimental/rfm/infer/multicategorical.py +1 -1
  32. kumoai/experimental/rfm/infer/pkey.py +128 -0
  33. kumoai/experimental/rfm/infer/stype.py +35 -0
  34. kumoai/experimental/rfm/infer/time_col.py +61 -0
  35. kumoai/experimental/rfm/pquery/executor.py +27 -27
  36. kumoai/experimental/rfm/pquery/pandas_executor.py +30 -32
  37. kumoai/experimental/rfm/relbench.py +76 -0
  38. kumoai/experimental/rfm/rfm.py +346 -248
  39. kumoai/experimental/rfm/sagemaker.py +138 -0
  40. kumoai/kumolib.cp311-win_amd64.pyd +0 -0
  41. kumoai/pquery/predictive_query.py +10 -6
  42. kumoai/spcs.py +1 -3
  43. kumoai/testing/decorators.py +1 -1
  44. kumoai/testing/snow.py +50 -0
  45. kumoai/trainer/distilled_trainer.py +175 -0
  46. kumoai/utils/__init__.py +3 -2
  47. kumoai/utils/display.py +51 -0
  48. kumoai/utils/progress_logger.py +188 -16
  49. kumoai/utils/sql.py +3 -0
  50. {kumoai-2.12.0.dev202511061731.dist-info → kumoai-2.14.0.dev202512311733.dist-info}/METADATA +13 -2
  51. {kumoai-2.12.0.dev202511061731.dist-info → kumoai-2.14.0.dev202512311733.dist-info}/RECORD +54 -31
  52. kumoai/experimental/rfm/local_graph.py +0 -810
  53. kumoai/experimental/rfm/local_graph_sampler.py +0 -184
  54. kumoai/experimental/rfm/local_pquery_driver.py +0 -689
  55. kumoai/experimental/rfm/local_table.py +0 -545
  56. kumoai/experimental/rfm/utils.py +0 -344
  57. {kumoai-2.12.0.dev202511061731.dist-info → kumoai-2.14.0.dev202512311733.dist-info}/WHEEL +0 -0
  58. {kumoai-2.12.0.dev202511061731.dist-info → kumoai-2.14.0.dev202512311733.dist-info}/licenses/LICENSE +0 -0
  59. {kumoai-2.12.0.dev202511061731.dist-info → kumoai-2.14.0.dev202512311733.dist-info}/top_level.txt +0 -0
--- a/kumoai/experimental/rfm/rfm.py
+++ b/kumoai/experimental/rfm/rfm.py
@@ -2,15 +2,22 @@ import json
 import time
 import warnings
 from collections import defaultdict
-from collections.abc import Generator
+from collections.abc import Generator, Iterator
 from contextlib import contextmanager
 from dataclasses import dataclass, replace
-from typing import Iterator, List, Literal, Optional, Tuple, Union, overload
+from typing import Any, Literal, overload
 
 import numpy as np
 import pandas as pd
 from kumoapi.model_plan import RunMode
 from kumoapi.pquery import QueryType, ValidatedPredictiveQuery
+from kumoapi.pquery.AST import (
+    Aggregation,
+    Column,
+    Condition,
+    Join,
+    LogicalOperation,
+)
 from kumoapi.rfm import Context
 from kumoapi.rfm import Explanation as ExplanationConfig
 from kumoapi.rfm import (
@@ -19,17 +26,14 @@ from kumoapi.rfm import (
     RFMPredictRequest,
 )
 from kumoapi.task import TaskType
+from kumoapi.typing import AggregationType, Stype
 
-from kumoai import global_state
+from kumoai.client.rfm import RFMAPI
 from kumoai.exceptions import HTTPException
-from kumoai.experimental.rfm import LocalGraph
-from kumoai.experimental.rfm.local_graph_sampler import LocalGraphSampler
-from kumoai.experimental.rfm.local_graph_store import LocalGraphStore
-from kumoai.experimental.rfm.local_pquery_driver import (
-    LocalPQueryDriver,
-    date_offset_to_seconds,
-)
-from kumoai.utils import InteractiveProgressLogger, ProgressLogger
+from kumoai.experimental.rfm import Graph
+from kumoai.experimental.rfm.base import DataBackend, Sampler
+from kumoai.mixin import CastMixin
+from kumoai.utils import ProgressLogger, display
 
 _RANDOM_SEED = 42
 
@@ -59,6 +63,17 @@ _SIZE_LIMIT_MSG = ("Context size exceeds the 30MB limit. {stats}\nPlease "
                    "beyond this for your use-case.")
 
 
+@dataclass(repr=False)
+class ExplainConfig(CastMixin):
+    """Configuration for explainability.
+
+    Args:
+        skip_summary: Whether to skip generating a human-readable summary of
+            the explanation.
+    """
+    skip_summary: bool = False
+
+
 @dataclass(repr=False)
 class Explanation:
     prediction: pd.DataFrame
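`ExplainConfig` feeds the widened `explain` parameter of `KumoRFM.predict` further down in this diff, where booleans and dicts are normalized via `ExplainConfig._cast`. A minimal usage sketch (not from the package docs), assuming an already-constructed `rfm`, a parseable `query`, and that `ExplainConfig` is importable from this module:

    explanation = rfm.predict(query, indices=[1],
                              explain=ExplainConfig(skip_summary=True))
    # Equivalent dict form; cast through CastMixin._cast:
    explanation = rfm.predict(query, indices=[1],
                              explain={'skip_summary': True})
    prediction_df, summary = explanation  # Explanation unpacks via __iter__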
@@ -73,19 +88,27 @@ class Explanation:
     def __getitem__(self, index: Literal[1]) -> str:
         pass
 
-    def __getitem__(self, index: int) -> Union[pd.DataFrame, str]:
+    def __getitem__(self, index: int) -> pd.DataFrame | str:
         if index == 0:
             return self.prediction
         if index == 1:
             return self.summary
         raise IndexError("Index out of range")
 
-    def __iter__(self) -> Iterator[Union[pd.DataFrame, str]]:
+    def __iter__(self) -> Iterator[pd.DataFrame | str]:
         return iter((self.prediction, self.summary))
 
     def __repr__(self) -> str:
         return str((self.prediction, self.summary))
 
+    def print(self) -> None:
+        r"""Prints the explanation."""
+        display.dataframe(self.prediction)
+        display.message(self.summary)
+
+    def _ipython_display_(self) -> None:
+        self.print()
+
 
 class KumoRFM:
     r"""The Kumo Relational Foundation model (RFM) from the `KumoRFM: A
@@ -95,17 +118,17 @@ class KumoRFM:
     :class:`KumoRFM` is a foundation model to generate predictions for any
     relational dataset without training.
     The model is pre-trained and the class provides an interface to query the
-    model from a :class:`LocalGraph` object.
+    model from a :class:`Graph` object.
 
     .. code-block:: python
 
-        from kumoai.experimental.rfm import LocalGraph, KumoRFM
+        from kumoai.experimental.rfm import Graph, KumoRFM
 
         df_users = pd.DataFrame(...)
         df_items = pd.DataFrame(...)
         df_orders = pd.DataFrame(...)
 
-        graph = LocalGraph.from_data({
+        graph = Graph.from_data({
             'users': df_users,
             'items': df_items,
             'orders': df_orders,
@@ -113,47 +136,63 @@ class KumoRFM:
 
         rfm = KumoRFM(graph)
 
-        query = ("PREDICT COUNT(transactions.*, 0, 30, days)>0 "
-                 "FOR users.user_id=0")
-        result = rfm.query(query)
+        query = ("PREDICT COUNT(orders.*, 0, 30, days)>0 "
+                 "FOR users.user_id=1")
+        result = rfm.predict(query)
 
         print(result)  # user_id COUNT(transactions.*, 0, 30, days) > 0
                        # 1       0.85
 
     Args:
         graph: The graph.
-        preprocess: Whether to pre-process the data in advance during graph
-            materialization.
-            This is a runtime trade-off between graph materialization and model
-            processing speed.
-            It can be benefical to preprocess your data once and then run many
-            queries on top to achieve maximum model speed.
-            However, if activiated, graph materialization can take potentially
-            much longer, especially on graphs with many large text columns.
-            Best to tune this option manually.
         verbose: Whether to print verbose output.
+        optimize: If set to ``True``, will optimize the underlying data backend
+            for optimal querying. For example, for transactional database
+            backends, will create any missing indices. Requires write-access to
+            the data backend.
     """
     def __init__(
         self,
-        graph: LocalGraph,
-        preprocess: bool = False,
-        verbose: Union[bool, ProgressLogger] = True,
+        graph: Graph,
+        verbose: bool | ProgressLogger = True,
+        optimize: bool = False,
     ) -> None:
         graph = graph.validate()
         self._graph_def = graph._to_api_graph_definition()
-        self._graph_store = LocalGraphStore(graph, preprocess, verbose)
-        self._graph_sampler = LocalGraphSampler(self._graph_store)
 
-        self._batch_size: Optional[int | Literal['max']] = None
+        if graph.backend == DataBackend.LOCAL:
+            from kumoai.experimental.rfm.backend.local import LocalSampler
+            self._sampler: Sampler = LocalSampler(graph, verbose)
+        elif graph.backend == DataBackend.SQLITE:
+            from kumoai.experimental.rfm.backend.sqlite import SQLiteSampler
+            self._sampler = SQLiteSampler(graph, verbose, optimize)
+        elif graph.backend == DataBackend.SNOWFLAKE:
+            from kumoai.experimental.rfm.backend.snow import SnowSampler
+            self._sampler = SnowSampler(graph, verbose)
+        else:
+            raise NotImplementedError
+
+        self._client: RFMAPI | None = None
+
+        self._batch_size: int | Literal['max'] | None = None
         self.num_retries: int = 0
 
+    @property
+    def _api_client(self) -> RFMAPI:
+        if self._client is not None:
+            return self._client
+
+        from kumoai.experimental.rfm import global_state
+        self._client = RFMAPI(global_state.client)
+        return self._client
+
     def __repr__(self) -> str:
         return f'{self.__class__.__name__}()'
 
     @contextmanager
     def batch_mode(
         self,
-        batch_size: Union[int, Literal['max']] = 'max',
+        batch_size: int | Literal['max'] = 'max',
         num_retries: int = 1,
     ) -> Generator[None, None, None]:
         """Context manager to predict in batches.
@@ -187,17 +226,17 @@ class KumoRFM:
     def predict(
         self,
         query: str,
-        indices: Union[List[str], List[float], List[int], None] = None,
+        indices: list[str] | list[float] | list[int] | None = None,
         *,
         explain: Literal[False] = False,
-        anchor_time: Union[pd.Timestamp, Literal['entity'], None] = None,
-        context_anchor_time: Union[pd.Timestamp, None] = None,
-        run_mode: Union[RunMode, str] = RunMode.FAST,
-        num_neighbors: Optional[List[int]] = None,
+        anchor_time: pd.Timestamp | Literal['entity'] | None = None,
+        context_anchor_time: pd.Timestamp | None = None,
+        run_mode: RunMode | str = RunMode.FAST,
+        num_neighbors: list[int] | None = None,
         num_hops: int = 2,
-        max_pq_iterations: int = 20,
-        random_seed: Optional[int] = _RANDOM_SEED,
-        verbose: Union[bool, ProgressLogger] = True,
+        max_pq_iterations: int = 10,
+        random_seed: int | None = _RANDOM_SEED,
+        verbose: bool | ProgressLogger = True,
         use_prediction_time: bool = False,
     ) -> pd.DataFrame:
         pass
@@ -206,17 +245,17 @@
     def predict(
         self,
         query: str,
-        indices: Union[List[str], List[float], List[int], None] = None,
+        indices: list[str] | list[float] | list[int] | None = None,
         *,
-        explain: Literal[True],
-        anchor_time: Union[pd.Timestamp, Literal['entity'], None] = None,
-        context_anchor_time: Union[pd.Timestamp, None] = None,
-        run_mode: Union[RunMode, str] = RunMode.FAST,
-        num_neighbors: Optional[List[int]] = None,
+        explain: Literal[True] | ExplainConfig | dict[str, Any],
+        anchor_time: pd.Timestamp | Literal['entity'] | None = None,
+        context_anchor_time: pd.Timestamp | None = None,
+        run_mode: RunMode | str = RunMode.FAST,
+        num_neighbors: list[int] | None = None,
         num_hops: int = 2,
-        max_pq_iterations: int = 20,
-        random_seed: Optional[int] = _RANDOM_SEED,
-        verbose: Union[bool, ProgressLogger] = True,
+        max_pq_iterations: int = 10,
+        random_seed: int | None = _RANDOM_SEED,
+        verbose: bool | ProgressLogger = True,
         use_prediction_time: bool = False,
     ) -> Explanation:
         pass
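The two stubs above are `typing.overload` declarations over the real implementation that follows: calls without `explain` (or with `explain=False`) type-check as returning `pd.DataFrame`, while any truthy `explain` form type-checks as returning `Explanation`. A self-contained sketch of the same idiom, with generic names that are not part of kumoai:

    from typing import Literal, overload

    @overload
    def run(explain: Literal[False] = False) -> str: ...
    @overload
    def run(explain: Literal[True]) -> tuple[str, str]: ...
    def run(explain: bool = False) -> str | tuple[str, str]:
        # Single runtime implementation; the stubs only narrow static types.
        return ('result', 'why') if explain else 'result'

    plain = run()      # a type checker infers: str
    pair = run(True)   # a type checker infers: tuple[str, str]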
@@ -224,19 +263,19 @@
     def predict(
         self,
         query: str,
-        indices: Union[List[str], List[float], List[int], None] = None,
+        indices: list[str] | list[float] | list[int] | None = None,
         *,
-        explain: bool = False,
-        anchor_time: Union[pd.Timestamp, Literal['entity'], None] = None,
-        context_anchor_time: Union[pd.Timestamp, None] = None,
-        run_mode: Union[RunMode, str] = RunMode.FAST,
-        num_neighbors: Optional[List[int]] = None,
+        explain: bool | ExplainConfig | dict[str, Any] = False,
+        anchor_time: pd.Timestamp | Literal['entity'] | None = None,
+        context_anchor_time: pd.Timestamp | None = None,
+        run_mode: RunMode | str = RunMode.FAST,
+        num_neighbors: list[int] | None = None,
         num_hops: int = 2,
-        max_pq_iterations: int = 20,
-        random_seed: Optional[int] = _RANDOM_SEED,
-        verbose: Union[bool, ProgressLogger] = True,
+        max_pq_iterations: int = 10,
+        random_seed: int | None = _RANDOM_SEED,
+        verbose: bool | ProgressLogger = True,
         use_prediction_time: bool = False,
-    ) -> Union[pd.DataFrame, Explanation]:
+    ) -> pd.DataFrame | Explanation:
         """Returns predictions for a predictive query.
 
         Args:
@@ -246,9 +285,12 @@
                 be generated for all indices, independent of whether they
                 fulfill entity filter constraints. To pre-filter entities, use
                 :meth:`~KumoRFM.is_valid_entity`.
-            explain: If set to ``True``, will additionally explain the
-                prediction. Explainability is currently only supported for
-                single entity predictions with ``run_mode="FAST"``.
+            explain: Configuration for explainability.
+                If set to ``True``, will additionally explain the prediction.
+                Passing in an :class:`ExplainConfig` instance provides control
+                over which parts of explanation are generated.
+                Explainability is currently only supported for single entity
+                predictions with ``run_mode="FAST"``.
             anchor_time: The anchor timestamp for the prediction. If set to
                 ``None``, will use the maximum timestamp in the data.
                 If set to ``"entity"``, will use the timestamp of the entity.
@@ -272,16 +314,25 @@
 
         Returns:
             The predictions as a :class:`pandas.DataFrame`.
-            If ``explain=True``, additionally returns a textual summary that
-            explains the prediction.
+            If ``explain`` is provided, returns an :class:`Explanation` object
+            containing the prediction, summary, and details.
         """
+        explain_config: ExplainConfig | None = None
+        if explain is True:
+            explain_config = ExplainConfig()
+        elif explain is not False:
+            explain_config = ExplainConfig._cast(explain)
+
         query_def = self._parse_query(query)
+        query_str = query_def.to_string()
 
         if num_hops != 2 and num_neighbors is not None:
             warnings.warn(f"Received custom 'num_neighbors' option; ignoring "
                           f"custom 'num_hops={num_hops}' option")
 
-        if explain and run_mode in {RunMode.NORMAL, RunMode.BEST}:
+        if explain_config is not None and run_mode in {
+                RunMode.NORMAL, RunMode.BEST
+        }:
             warnings.warn(f"Explainability is currently only supported for "
                           f"run mode 'FAST' (got '{run_mode}'). Provided run "
                           f"mode has been reset. Please lower the run mode to "
@@ -298,27 +349,27 @@
         if len(indices) == 0:
             raise ValueError("At least one entity is required")
 
-        if explain and len(indices) > 1:
+        if explain_config is not None and len(indices) > 1:
             raise ValueError(
                 f"Cannot explain predictions for more than a single entity "
                 f"(got {len(indices)})")
 
         query_repr = query_def.to_string(rich=True, exclude_predict=True)
-        if explain:
+        if explain_config is not None:
             msg = f'[bold]EXPLAIN[/bold] {query_repr}'
         else:
             msg = f'[bold]PREDICT[/bold] {query_repr}'
 
         if not isinstance(verbose, ProgressLogger):
-            verbose = InteractiveProgressLogger(msg, verbose=verbose)
+            verbose = ProgressLogger.default(msg=msg, verbose=verbose)
 
         with verbose as logger:
 
-            batch_size: Optional[int] = None
+            batch_size: int | None = None
             if self._batch_size == 'max':
-                task_type = LocalPQueryDriver.get_task_type(
-                    query_def,
-                    edge_types=self._graph_store.edge_types,
+                task_type = self._get_task_type(
+                    query=query_def,
+                    edge_types=self._sampler.edge_types,
                 )
                 batch_size = _MAX_PRED_SIZE[task_type]
             else:
@@ -334,9 +385,9 @@
                 logger.log(f"Splitting {len(indices):,} entities into "
                            f"{len(batches):,} batches of size {batch_size:,}")
 
-            predictions: List[pd.DataFrame] = []
-            summary: Optional[str] = None
-            details: Optional[Explanation] = None
+            predictions: list[pd.DataFrame] = []
+            summary: str | None = None
+            details: Explanation | None = None
             for i, batch in enumerate(batches):
                 # TODO Re-use the context for subsequent predictions.
                 context = self._get_context(
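The batching above is driven by the `batch_mode` context manager shown earlier. A usage sketch, assuming `rfm` and a query over many entities; `'max'` resolves to the task-type-specific `_MAX_PRED_SIZE` cap, and each failed batch is retried `num_retries` times:

    with rfm.batch_mode(batch_size='max', num_retries=1):
        df = rfm.predict(query, indices=list(range(100_000)))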
@@ -355,6 +406,7 @@
                 request = RFMPredictRequest(
                     context=context,
                     run_mode=RunMode(run_mode),
+                    query=query_str,
                     use_prediction_time=use_prediction_time,
                 )
                 with warnings.catch_warnings():
@@ -369,8 +421,7 @@
                     stats = Context.get_memory_stats(request_msg.context)
                     raise ValueError(_SIZE_LIMIT_MSG.format(stats=stats))
 
-                if (isinstance(verbose, InteractiveProgressLogger) and i == 0
-                        and len(batches) > 1):
+                if i == 0 and len(batches) > 1:
                     verbose.init_progress(
                         total=len(batches),
                         description='Predicting',
@@ -378,20 +429,23 @@
 
                 for attempt in range(self.num_retries + 1):
                     try:
-                        if explain:
-                            resp = global_state.client.rfm_api.explain(_bytes)
+                        if explain_config is not None:
+                            resp = self._api_client.explain(
+                                request=_bytes,
+                                skip_summary=explain_config.skip_summary,
+                            )
                             summary = resp.summary
                             details = resp.details
                         else:
-                            resp = global_state.client.rfm_api.predict(_bytes)
+                            resp = self._api_client.predict(_bytes)
                         df = pd.DataFrame(**resp.prediction)
 
                         # Cast 'ENTITY' to correct data type:
                         if 'ENTITY' in df:
-                            entity = query_def.entity_table
-                            pkey_map = self._graph_store.pkey_map_dict[entity]
-                            df['ENTITY'] = df['ENTITY'].astype(
-                                type(pkey_map.index[0]))
+                            table_dict = context.subgraph.table_dict
+                            table = table_dict[query_def.entity_table]
+                            ser = table.df[table.primary_key]
+                            df['ENTITY'] = df['ENTITY'].astype(ser.dtype)
 
                         # Cast 'ANCHOR_TIMESTAMP' to correct data type:
                         if 'ANCHOR_TIMESTAMP' in df:
@@ -406,8 +460,7 @@
 
                         predictions.append(df)
 
-                        if (isinstance(verbose, InteractiveProgressLogger)
-                                and len(batches) > 1):
+                        if len(batches) > 1:
                             verbose.step()
 
                         break
@@ -430,7 +483,7 @@
         else:
             prediction = pd.concat(predictions, ignore_index=True)
 
-        if explain:
+        if explain_config is not None:
             assert len(predictions) == 1
             assert summary is not None
             assert details is not None
@@ -445,9 +498,9 @@
     def is_valid_entity(
         self,
         query: str,
-        indices: Union[List[str], List[float], List[int], None] = None,
+        indices: list[str] | list[float] | list[int] | None = None,
         *,
-        anchor_time: Union[pd.Timestamp, Literal['entity'], None] = None,
+        anchor_time: pd.Timestamp | Literal['entity'] | None = None,
     ) -> np.ndarray:
         r"""Returns a mask that denotes which entities are valid for the
         given predictive query, *i.e.*, which entities fulfill (temporal)
@@ -474,37 +527,32 @@
             raise ValueError("At least one entity is required")
 
         if anchor_time is None:
-            anchor_time = self._graph_store.max_time
+            anchor_time = self._get_default_anchor_time(query_def)
 
         if isinstance(anchor_time, pd.Timestamp):
             self._validate_time(query_def, anchor_time, None, False)
         else:
             assert anchor_time == 'entity'
-            if (query_def.entity_table not in self._graph_store.time_dict):
+            if query_def.entity_table not in self._sampler.time_column_dict:
                 raise ValueError(f"Anchor time 'entity' requires the entity "
                                  f"table '{query_def.entity_table}' "
                                  f"to have a time column.")
 
-        node = self._graph_store.get_node_id(
-            table_name=query_def.entity_table,
-            pkey=pd.Series(indices),
-        )
-        query_driver = LocalPQueryDriver(self._graph_store, query_def)
-        return query_driver.is_valid(node, anchor_time)
+        raise NotImplementedError
 
     def evaluate(
         self,
         query: str,
         *,
-        metrics: Optional[List[str]] = None,
-        anchor_time: Union[pd.Timestamp, Literal['entity'], None] = None,
-        context_anchor_time: Union[pd.Timestamp, None] = None,
-        run_mode: Union[RunMode, str] = RunMode.FAST,
-        num_neighbors: Optional[List[int]] = None,
+        metrics: list[str] | None = None,
+        anchor_time: pd.Timestamp | Literal['entity'] | None = None,
+        context_anchor_time: pd.Timestamp | None = None,
+        run_mode: RunMode | str = RunMode.FAST,
+        num_neighbors: list[int] | None = None,
         num_hops: int = 2,
-        max_pq_iterations: int = 20,
-        random_seed: Optional[int] = _RANDOM_SEED,
-        verbose: Union[bool, ProgressLogger] = True,
+        max_pq_iterations: int = 10,
+        random_seed: int | None = _RANDOM_SEED,
+        verbose: bool | ProgressLogger = True,
         use_prediction_time: bool = False,
     ) -> pd.DataFrame:
         """Evaluates a predictive query.
@@ -552,7 +600,7 @@
         msg = f'[bold]EVALUATE[/bold] {query_repr}'
 
         if not isinstance(verbose, ProgressLogger):
-            verbose = InteractiveProgressLogger(msg, verbose=verbose)
+            verbose = ProgressLogger.default(msg=msg, verbose=verbose)
 
         with verbose as logger:
             context = self._get_context(
@@ -586,10 +634,10 @@
 
             if len(request_bytes) > _MAX_SIZE:
                 stats_msg = Context.get_memory_stats(request_msg.context)
-                raise ValueError(_SIZE_LIMIT_MSG.format(stats_msg=stats_msg))
+                raise ValueError(_SIZE_LIMIT_MSG.format(stats=stats_msg))
 
             try:
-                resp = global_state.client.rfm_api.evaluate(request_bytes)
+                resp = self._api_client.evaluate(request_bytes)
             except HTTPException as e:
                 try:
                     msg = json.loads(e.detail)['detail']
@@ -611,9 +659,9 @@
         query: str,
         size: int,
         *,
-        anchor_time: Union[pd.Timestamp, Literal['entity'], None] = None,
-        random_seed: Optional[int] = _RANDOM_SEED,
-        max_iterations: int = 20,
+        anchor_time: pd.Timestamp | Literal['entity'] | None = None,
+        random_seed: int | None = _RANDOM_SEED,
+        max_iterations: int = 10,
     ) -> pd.DataFrame:
         """Returns the labels of a predictive query for a specified anchor
         time.
@@ -633,40 +681,37 @@
         query_def = self._parse_query(query)
 
         if anchor_time is None:
-            anchor_time = self._graph_store.max_time
+            anchor_time = self._get_default_anchor_time(query_def)
             if query_def.target_ast.date_offset_range is not None:
-                anchor_time = anchor_time - (
-                    query_def.target_ast.date_offset_range.end_date_offset *
-                    query_def.num_forecasts)
+                offset = query_def.target_ast.date_offset_range.end_date_offset
+                offset *= query_def.num_forecasts
+                anchor_time -= offset
 
         assert anchor_time is not None
         if isinstance(anchor_time, pd.Timestamp):
             self._validate_time(query_def, anchor_time, None, evaluate=True)
         else:
             assert anchor_time == 'entity'
-            if (query_def.entity_table not in self._graph_store.time_dict):
+            if query_def.entity_table not in self._sampler.time_column_dict:
                 raise ValueError(f"Anchor time 'entity' requires the entity "
                                  f"table '{query_def.entity_table}' "
                                  f"to have a time column")
 
-        query_driver = LocalPQueryDriver(self._graph_store, query_def,
-                                         random_seed)
-
-        node, time, y = query_driver.collect_test(
-            size=size,
-            anchor_time=anchor_time,
-            batch_size=min(10_000, size),
-            max_iterations=max_iterations,
-            guarantee_train_examples=False,
+        train, test = self._sampler.sample_target(
+            query=query_def,
+            num_train_examples=0,
+            train_anchor_time=anchor_time,
+            num_train_trials=0,
+            num_test_examples=size,
+            test_anchor_time=anchor_time,
+            num_test_trials=max_iterations * size,
+            random_seed=random_seed,
         )
 
-        entity = self._graph_store.pkey_map_dict[
-            query_def.entity_table].index[node]
-
         return pd.DataFrame({
-            'ENTITY': entity,
-            'ANCHOR_TIMESTAMP': time,
-            'TARGET': y,
+            'ENTITY': test.entity_pkey,
+            'ANCHOR_TIMESTAMP': test.anchor_time,
+            'TARGET': test.target,
         })
 
     # Helpers #################################################################
@@ -687,8 +732,7 @@
             graph_definition=self._graph_def,
         )
 
-        resp = global_state.client.rfm_api.parse_query(request)
-        # TODO Expose validation warnings.
+        resp = self._api_client.parse_query(request)
 
         if len(resp.validation_response.warnings) > 0:
             msg = '\n'.join([
@@ -707,36 +751,92 @@
             raise ValueError(f"Failed to parse query '{query}'. "
                              f"{msg}") from None
 
+    @staticmethod
+    def _get_task_type(
+        query: ValidatedPredictiveQuery,
+        edge_types: list[tuple[str, str, str]],
+    ) -> TaskType:
+        if isinstance(query.target_ast, (Condition, LogicalOperation)):
+            return TaskType.BINARY_CLASSIFICATION
+
+        target = query.target_ast
+        if isinstance(target, Join):
+            target = target.rhs_target
+        if isinstance(target, Aggregation):
+            if target.aggr == AggregationType.LIST_DISTINCT:
+                table_name, col_name = target._get_target_column_name().split(
+                    '.')
+                target_edge_types = [
+                    edge_type for edge_type in edge_types
+                    if edge_type[0] == table_name and edge_type[1] == col_name
+                ]
+                if len(target_edge_types) != 1:
+                    raise NotImplementedError(
+                        f"Multilabel-classification queries based on "
+                        f"'LIST_DISTINCT' are not supported yet. If you "
+                        f"planned to write a link prediction query instead, "
+                        f"make sure to register '{col_name}' as a "
+                        f"foreign key.")
+                return TaskType.TEMPORAL_LINK_PREDICTION
+
+            return TaskType.REGRESSION
+
+        assert isinstance(target, Column)
+
+        if target.stype in {Stype.ID, Stype.categorical}:
+            return TaskType.MULTICLASS_CLASSIFICATION
+
+        if target.stype in {Stype.numerical}:
+            return TaskType.REGRESSION
+
+        raise NotImplementedError("Task type not yet supported")
+
+    def _get_default_anchor_time(
+        self,
+        query: ValidatedPredictiveQuery,
+    ) -> pd.Timestamp:
+        if query.query_type == QueryType.TEMPORAL:
+            aggr_table_names = [
+                aggr._get_target_column_name().split('.')[0]
+                for aggr in query.get_all_target_aggregations()
+            ]
+            return self._sampler.get_max_time(aggr_table_names)
+
+        assert query.query_type == QueryType.STATIC
+        return self._sampler.get_max_time()
+
     def _validate_time(
         self,
         query: ValidatedPredictiveQuery,
         anchor_time: pd.Timestamp,
-        context_anchor_time: Union[pd.Timestamp, None],
+        context_anchor_time: pd.Timestamp | None,
         evaluate: bool,
     ) -> None:
 
-        if self._graph_store.min_time == pd.Timestamp.max:
+        if len(self._sampler.time_column_dict) == 0:
             return  # Graph without timestamps
 
-        if anchor_time < self._graph_store.min_time:
+        min_time = self._sampler.get_min_time()
+        max_time = self._sampler.get_max_time()
+
+        if anchor_time < min_time:
             raise ValueError(f"Anchor timestamp '{anchor_time}' is before "
-                             f"the earliest timestamp "
-                             f"'{self._graph_store.min_time}' in the data.")
+                             f"the earliest timestamp '{min_time}' in the "
+                             f"data.")
 
-        if (context_anchor_time is not None
-                and context_anchor_time < self._graph_store.min_time):
+        if context_anchor_time is not None and context_anchor_time < min_time:
             raise ValueError(f"Context anchor timestamp is too early or "
                              f"aggregation time range is too large. To make "
                              f"this prediction, we would need data back to "
                              f"'{context_anchor_time}', however, your data "
-                             f"only contains data back to "
-                             f"'{self._graph_store.min_time}'.")
+                             f"only contains data back to '{min_time}'.")
 
         if query.target_ast.date_offset_range is not None:
             end_offset = query.target_ast.date_offset_range.end_date_offset
         else:
             end_offset = pd.DateOffset(0)
-        forecast_end_offset = end_offset * query.num_forecasts
+        end_offset = end_offset * query.num_forecasts
+
         if (context_anchor_time is not None
                 and context_anchor_time > anchor_time):
             warnings.warn(f"Context anchor timestamp "
@@ -746,7 +846,7 @@
                           f"intended.")
         elif (query.query_type == QueryType.TEMPORAL
               and context_anchor_time is not None
-              and context_anchor_time + forecast_end_offset > anchor_time):
+              and context_anchor_time + end_offset > anchor_time):
             warnings.warn(f"Aggregation for context examples at timestamp "
                           f"'{context_anchor_time}' will leak information "
                           f"from the prediction anchor timestamp "
@@ -754,40 +854,37 @@
                           f"intended.")
 
         elif (context_anchor_time is not None
-              and context_anchor_time - forecast_end_offset
-              < self._graph_store.min_time):
-            _time = context_anchor_time - forecast_end_offset
+              and context_anchor_time - end_offset < min_time):
+            _time = context_anchor_time - end_offset
             warnings.warn(f"Context anchor timestamp is too early or "
                           f"aggregation time range is too large. To form "
                           f"proper input data, we would need data back to "
                           f"'{_time}', however, your data only contains "
-                          f"data back to '{self._graph_store.min_time}'.")
+                          f"data back to '{min_time}'.")
 
-        if (not evaluate and anchor_time
-                > self._graph_store.max_time + pd.DateOffset(days=1)):
+        if not evaluate and anchor_time > max_time + pd.DateOffset(days=1):
             warnings.warn(f"Anchor timestamp '{anchor_time}' is after the "
-                          f"latest timestamp '{self._graph_store.max_time}' "
-                          f"in the data. Please make sure this is intended.")
+                          f"latest timestamp '{max_time}' in the data. Please "
+                          f"make sure this is intended.")
 
-        max_eval_time = self._graph_store.max_time - forecast_end_offset
-        if evaluate and anchor_time > max_eval_time:
+        if evaluate and anchor_time > max_time - end_offset:
             raise ValueError(
                 f"Anchor timestamp for evaluation is after the latest "
-                f"supported timestamp '{max_eval_time}'.")
+                f"supported timestamp '{max_time - end_offset}'.")
 
     def _get_context(
         self,
         query: ValidatedPredictiveQuery,
-        indices: Union[List[str], List[float], List[int], None],
-        anchor_time: Union[pd.Timestamp, Literal['entity'], None],
-        context_anchor_time: Union[pd.Timestamp, None],
+        indices: list[str] | list[float] | list[int] | None,
+        anchor_time: pd.Timestamp | Literal['entity'] | None,
+        context_anchor_time: pd.Timestamp | None,
         run_mode: RunMode,
-        num_neighbors: Optional[List[int]],
+        num_neighbors: list[int] | None,
         num_hops: int,
         max_pq_iterations: int,
         evaluate: bool,
-        random_seed: Optional[int] = _RANDOM_SEED,
-        logger: Optional[ProgressLogger] = None,
+        random_seed: int | None = _RANDOM_SEED,
+        logger: ProgressLogger | None = None,
     ) -> Context:
 
         if num_neighbors is not None:
@@ -804,10 +901,9 @@
                              f"'https://github.com/kumo-ai/kumo-rfm' if you "
                              f"must go beyond this for your use-case.")
 
-        query_driver = LocalPQueryDriver(self._graph_store, query, random_seed)
-        task_type = LocalPQueryDriver.get_task_type(
-            query,
-            edge_types=self._graph_store.edge_types,
+        task_type = self._get_task_type(
+            query=query,
+            edge_types=self._sampler.edge_types,
         )
 
         if logger is not None:
@@ -839,14 +935,17 @@
             num_neighbors = [64, 64, 8, 8, 4, 4][:num_hops]
 
         if query.target_ast.date_offset_range is None:
-            end_offset = pd.DateOffset(0)
+            step_offset = pd.DateOffset(0)
         else:
-            end_offset = query.target_ast.date_offset_range.end_date_offset
-        forecast_end_offset = end_offset * query.num_forecasts
+            step_offset = query.target_ast.date_offset_range.end_date_offset
+        end_offset = step_offset * query.num_forecasts
+
         if anchor_time is None:
-            anchor_time = self._graph_store.max_time
+            anchor_time = self._get_default_anchor_time(query)
+
         if evaluate:
-            anchor_time = anchor_time - forecast_end_offset
+            anchor_time = anchor_time - end_offset
+
         if logger is not None:
             assert isinstance(anchor_time, pd.Timestamp)
             if anchor_time == pd.Timestamp.min:
@@ -860,57 +959,71 @@
 
         assert anchor_time is not None
         if isinstance(anchor_time, pd.Timestamp):
+            if context_anchor_time == 'entity':
+                raise ValueError("Anchor time 'entity' needs to be shared "
+                                 "for context and prediction examples")
             if context_anchor_time is None:
-                context_anchor_time = anchor_time - forecast_end_offset
+                context_anchor_time = anchor_time - end_offset
             self._validate_time(query, anchor_time, context_anchor_time,
                                 evaluate)
         else:
             assert anchor_time == 'entity'
-            if query.entity_table not in self._graph_store.time_dict:
+            if query.query_type != QueryType.STATIC:
+                raise ValueError("Anchor time 'entity' is only valid for "
+                                 "static predictive queries")
+            if query.entity_table not in self._sampler.time_column_dict:
                 raise ValueError(f"Anchor time 'entity' requires the entity "
                                  f"table '{query.entity_table}' to "
                                  f"have a time column")
-            if context_anchor_time is not None:
-                warnings.warn("Ignoring option 'context_anchor_time' for "
-                              "`anchor_time='entity'`")
-            context_anchor_time = None
+            if isinstance(context_anchor_time, pd.Timestamp):
+                raise ValueError("Anchor time 'entity' needs to be shared "
+                                 "for context and prediction examples")
+            context_anchor_time = 'entity'
 
-        y_test: Optional[pd.Series] = None
+        num_train_examples = _MAX_CONTEXT_SIZE[run_mode]
         if evaluate:
-            max_test_size = _MAX_TEST_SIZE[run_mode]
+            num_test_examples = _MAX_TEST_SIZE[run_mode]
             if task_type.is_link_pred:
-                max_test_size = max_test_size // 5
+                num_test_examples = num_test_examples // 5
+        else:
+            num_test_examples = 0
+
+        train, test = self._sampler.sample_target(
+            query=query,
+            num_train_examples=num_train_examples,
+            train_anchor_time=context_anchor_time,
+            num_train_trials=max_pq_iterations * num_train_examples,
+            num_test_examples=num_test_examples,
+            test_anchor_time=anchor_time,
+            num_test_trials=max_pq_iterations * num_test_examples,
+            random_seed=random_seed,
+        )
+        train_pkey, train_time, y_train = train
+        test_pkey, test_time, y_test = test
 
-            test_node, test_time, y_test = query_driver.collect_test(
-                size=max_test_size,
-                anchor_time=anchor_time,
-                max_iterations=max_pq_iterations,
-                guarantee_train_examples=True,
-            )
-            if logger is not None:
-                if task_type == TaskType.BINARY_CLASSIFICATION:
-                    pos = 100 * int((y_test > 0).sum()) / len(y_test)
-                    msg = (f"Collected {len(y_test):,} test examples with "
-                           f"{pos:.2f}% positive cases")
-                elif task_type == TaskType.MULTICLASS_CLASSIFICATION:
-                    msg = (f"Collected {len(y_test):,} test examples "
-                           f"holding {y_test.nunique()} classes")
-                elif task_type == TaskType.REGRESSION:
-                    _min, _max = float(y_test.min()), float(y_test.max())
-                    msg = (f"Collected {len(y_test):,} test examples with "
-                           f"targets between {format_value(_min)} and "
-                           f"{format_value(_max)}")
-                elif task_type == TaskType.TEMPORAL_LINK_PREDICTION:
-                    num_rhs = y_test.explode().nunique()
-                    msg = (f"Collected {len(y_test):,} test examples with "
-                           f"{num_rhs:,} unique items")
-                else:
-                    raise NotImplementedError
-                logger.log(msg)
+        if evaluate and logger is not None:
+            if task_type == TaskType.BINARY_CLASSIFICATION:
+                pos = 100 * int((y_test > 0).sum()) / len(y_test)
+                msg = (f"Collected {len(y_test):,} test examples with "
+                       f"{pos:.2f}% positive cases")
+            elif task_type == TaskType.MULTICLASS_CLASSIFICATION:
+                msg = (f"Collected {len(y_test):,} test examples holding "
+                       f"{y_test.nunique()} classes")
+            elif task_type == TaskType.REGRESSION:
+                _min, _max = float(y_test.min()), float(y_test.max())
+                msg = (f"Collected {len(y_test):,} test examples with targets "
+                       f"between {format_value(_min)} and "
+                       f"{format_value(_max)}")
+            elif task_type == TaskType.TEMPORAL_LINK_PREDICTION:
+                num_rhs = y_test.explode().nunique()
+                msg = (f"Collected {len(y_test):,} test examples with "
+                       f"{num_rhs:,} unique items")
+            else:
+                raise NotImplementedError
+            logger.log(msg)
 
-        else:
+        if not evaluate:
             assert indices is not None
-
             if len(indices) > _MAX_PRED_SIZE[task_type]:
                 raise ValueError(f"Cannot predict for more than "
                                  f"{_MAX_PRED_SIZE[task_type]:,} entities at "
@@ -918,26 +1031,12 @@
                                  f"`KumoRFM.batch_mode` to process entities "
                                  f"in batches")
 
-            test_node = self._graph_store.get_node_id(
-                table_name=query.entity_table,
-                pkey=pd.Series(indices),
-            )
-
+            test_pkey = pd.Series(indices, dtype=train_pkey.dtype)
             if isinstance(anchor_time, pd.Timestamp):
-                test_time = pd.Series(anchor_time).repeat(
-                    len(test_node)).reset_index(drop=True)
+                test_time = pd.Series([anchor_time]).repeat(
+                    len(indices)).reset_index(drop=True)
             else:
-                time = self._graph_store.time_dict[query.entity_table]
-                time = time[test_node] * 1000**3
-                test_time = pd.Series(time, dtype='datetime64[ns]')
-
-            train_node, train_time, y_train = query_driver.collect_train(
-                size=_MAX_CONTEXT_SIZE[run_mode],
-                anchor_time=context_anchor_time or 'entity',
-                exclude_node=test_node if (query.query_type == QueryType.STATIC
-                             or anchor_time == 'entity') else None,
-                max_iterations=max_pq_iterations,
-            )
+                train_time = test_time = 'entity'
 
         if logger is not None:
             if task_type == TaskType.BINARY_CLASSIFICATION:
@@ -960,12 +1059,12 @@
                     raise NotImplementedError
                 logger.log(msg)
 
-        entity_table_names: Tuple[str, ...]
+        entity_table_names: tuple[str, ...]
         if task_type.is_link_pred:
             final_aggr = query.get_final_target_aggregation()
             assert final_aggr is not None
             edge_fkey = final_aggr._get_target_column_name()
-            for edge_type in self._graph_store.edge_types:
+            for edge_type in self._sampler.edge_types:
                 if edge_fkey == f'{edge_type[0]}.{edge_type[1]}':
                     entity_table_names = (
                         query.entity_table,
@@ -977,21 +1076,24 @@
         # Exclude the entity anchor time from the feature set to prevent
         # running out-of-distribution between in-context and test examples:
         exclude_cols_dict = query.get_exclude_cols_dict()
-        if anchor_time == 'entity':
+        if entity_table_names[0] in self._sampler.time_column_dict:
            if entity_table_names[0] not in exclude_cols_dict:
                exclude_cols_dict[entity_table_names[0]] = []
-            time_column_dict = self._graph_store.time_column_dict
-            time_column = time_column_dict[entity_table_names[0]]
+            time_column = self._sampler.time_column_dict[entity_table_names[0]]
            exclude_cols_dict[entity_table_names[0]].append(time_column)
 
-        subgraph = self._graph_sampler(
+        subgraph = self._sampler.sample_subgraph(
             entity_table_names=entity_table_names,
-            node=np.concatenate([train_node, test_node]),
-            time=np.concatenate([
-                train_time.astype('datetime64[ns]').astype(int).to_numpy(),
-                test_time.astype('datetime64[ns]').astype(int).to_numpy(),
-            ]),
-            run_mode=run_mode,
+            entity_pkey=pd.concat(
+                [train_pkey, test_pkey],
+                axis=0,
+                ignore_index=True,
+            ),
+            anchor_time=pd.concat(
+                [train_time, test_time],
+                axis=0,
+                ignore_index=True,
+            ) if isinstance(train_time, pd.Series) else 'entity',
             num_neighbors=num_neighbors,
             exclude_cols_dict=exclude_cols_dict,
         )
@@ -1003,23 +1105,19 @@
                          f"'https://github.com/kumo-ai/kumo-rfm' if you "
                          f"must go beyond this for your use-case.")
 
-        step_size: Optional[int] = None
-        if query.query_type == QueryType.TEMPORAL:
-            step_size = date_offset_to_seconds(end_offset)
-
         return Context(
             task_type=task_type,
             entity_table_names=entity_table_names,
             subgraph=subgraph,
             y_train=y_train,
-            y_test=y_test,
+            y_test=y_test if evaluate else None,
             top_k=query.top_k,
-            step_size=step_size,
+            step_size=None,
         )
 
     @staticmethod
     def _validate_metrics(
-        metrics: List[str],
+        metrics: list[str],
         task_type: TaskType,
     ) -> None:
 
@@ -1076,7 +1174,7 @@
                 f"'https://github.com/kumo-ai/kumo-rfm'.")
 
 
-def format_value(value: Union[int, float]) -> str:
+def format_value(value: int | float) -> str:
     if value == int(value):
         return f'{int(value):,}'
     if abs(value) >= 1000:
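The diff ends inside `format_value`; the branch visible here renders integral values with a thousands separator, for example:

    assert format_value(42.0) == '42'      # value == int(value)
    assert format_value(1234) == '1,234'
    # The abs(value) >= 1000 branch continues beyond this diff.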