kumoai-2.8.0.dev202508221830-cp312-cp312-win_amd64.whl → kumoai-2.13.0.dev202512041141-cp312-cp312-win_amd64.whl

This diff compares the content of two publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in the public registry.

Potentially problematic release: this version of kumoai might be problematic.

Files changed (52)
  1. kumoai/__init__.py +22 -11
  2. kumoai/_version.py +1 -1
  3. kumoai/client/client.py +17 -16
  4. kumoai/client/endpoints.py +1 -0
  5. kumoai/client/rfm.py +37 -8
  6. kumoai/connector/file_upload_connector.py +94 -85
  7. kumoai/connector/utils.py +1399 -210
  8. kumoai/experimental/rfm/__init__.py +164 -46
  9. kumoai/experimental/rfm/authenticate.py +8 -5
  10. kumoai/experimental/rfm/backend/__init__.py +0 -0
  11. kumoai/experimental/rfm/backend/local/__init__.py +38 -0
  12. kumoai/experimental/rfm/backend/local/table.py +109 -0
  13. kumoai/experimental/rfm/backend/snow/__init__.py +35 -0
  14. kumoai/experimental/rfm/backend/snow/table.py +117 -0
  15. kumoai/experimental/rfm/backend/sqlite/__init__.py +30 -0
  16. kumoai/experimental/rfm/backend/sqlite/table.py +101 -0
  17. kumoai/experimental/rfm/base/__init__.py +10 -0
  18. kumoai/experimental/rfm/base/column.py +66 -0
  19. kumoai/experimental/rfm/base/source.py +18 -0
  20. kumoai/experimental/rfm/base/table.py +545 -0
  21. kumoai/experimental/rfm/{local_graph.py → graph.py} +413 -144
  22. kumoai/experimental/rfm/infer/__init__.py +6 -0
  23. kumoai/experimental/rfm/infer/dtype.py +79 -0
  24. kumoai/experimental/rfm/infer/pkey.py +126 -0
  25. kumoai/experimental/rfm/infer/time_col.py +62 -0
  26. kumoai/experimental/rfm/infer/timestamp.py +7 -4
  27. kumoai/experimental/rfm/local_graph_sampler.py +58 -11
  28. kumoai/experimental/rfm/local_graph_store.py +45 -37
  29. kumoai/experimental/rfm/local_pquery_driver.py +342 -46
  30. kumoai/experimental/rfm/pquery/__init__.py +4 -4
  31. kumoai/experimental/rfm/pquery/{backend.py → executor.py} +28 -58
  32. kumoai/experimental/rfm/pquery/pandas_executor.py +532 -0
  33. kumoai/experimental/rfm/rfm.py +559 -148
  34. kumoai/experimental/rfm/sagemaker.py +138 -0
  35. kumoai/jobs.py +27 -1
  36. kumoai/kumolib.cp312-win_amd64.pyd +0 -0
  37. kumoai/pquery/prediction_table.py +5 -3
  38. kumoai/pquery/training_table.py +5 -3
  39. kumoai/spcs.py +1 -3
  40. kumoai/testing/decorators.py +1 -1
  41. kumoai/trainer/job.py +9 -30
  42. kumoai/trainer/trainer.py +19 -10
  43. kumoai/utils/__init__.py +2 -1
  44. kumoai/utils/progress_logger.py +96 -16
  45. {kumoai-2.8.0.dev202508221830.dist-info → kumoai-2.13.0.dev202512041141.dist-info}/METADATA +14 -5
  46. {kumoai-2.8.0.dev202508221830.dist-info → kumoai-2.13.0.dev202512041141.dist-info}/RECORD +49 -36
  47. kumoai/experimental/rfm/local_table.py +0 -448
  48. kumoai/experimental/rfm/pquery/pandas_backend.py +0 -437
  49. kumoai/experimental/rfm/utils.py +0 -347
  50. {kumoai-2.8.0.dev202508221830.dist-info → kumoai-2.13.0.dev202512041141.dist-info}/WHEEL +0 -0
  51. {kumoai-2.8.0.dev202508221830.dist-info → kumoai-2.13.0.dev202512041141.dist-info}/licenses/LICENSE +0 -0
  52. {kumoai-2.8.0.dev202508221830.dist-info → kumoai-2.13.0.dev202512041141.dist-info}/top_level.txt +0 -0
@@ -1,30 +1,52 @@
 import json
+import time
 import warnings
-from typing import List, Literal, Optional, Union
+from collections import defaultdict
+from collections.abc import Generator
+from contextlib import contextmanager
+from dataclasses import dataclass, replace
+from typing import (
+    Any,
+    Dict,
+    Iterator,
+    List,
+    Literal,
+    Optional,
+    Tuple,
+    Union,
+    overload,
+)

 import numpy as np
 import pandas as pd
 from kumoapi.model_plan import RunMode
-from kumoapi.pquery import QueryType
+from kumoapi.pquery import QueryType, ValidatedPredictiveQuery
+from kumoapi.rfm import Context
+from kumoapi.rfm import Explanation as ExplanationConfig
 from kumoapi.rfm import (
-    Context,
-    PQueryDefinition,
     RFMEvaluateRequest,
+    RFMParseQueryRequest,
     RFMPredictRequest,
-    RFMValidateQueryRequest,
 )
 from kumoapi.task import TaskType

-from kumoai import global_state
+from kumoai.client.rfm import RFMAPI
 from kumoai.exceptions import HTTPException
-from kumoai.experimental.rfm import LocalGraph
+from kumoai.experimental.rfm import Graph
 from kumoai.experimental.rfm.local_graph_sampler import LocalGraphSampler
 from kumoai.experimental.rfm.local_graph_store import LocalGraphStore
-from kumoai.experimental.rfm.local_pquery_driver import LocalPQueryDriver
-from kumoai.utils import ProgressLogger
+from kumoai.experimental.rfm.local_pquery_driver import (
+    LocalPQueryDriver,
+    date_offset_to_seconds,
+)
+from kumoai.mixin import CastMixin
+from kumoai.utils import InteractiveProgressLogger, ProgressLogger

 _RANDOM_SEED = 42

+_MAX_PRED_SIZE: dict[TaskType, int] = defaultdict(lambda: 1_000)
+_MAX_PRED_SIZE[TaskType.TEMPORAL_LINK_PREDICTION] = 200
+
 _MAX_CONTEXT_SIZE = {
     RunMode.DEBUG: 100,
     RunMode.FAST: 1_000,
@@ -39,7 +61,7 @@ _MAX_TEST_SIZE = {  # Share test set size across run modes for fair comparison:
 }

 _MAX_SIZE = 30 * 1024 * 1024
-_SIZE_LIMIT_MSG = ("Context size exceeds the 30MB limit. {stats_msg}\nPlease "
+_SIZE_LIMIT_MSG = ("Context size exceeds the 30MB limit. {stats}\nPlease "
                    "reduce either the number of tables in the graph, their "
                    "number of columns (e.g., large text columns), "
                    "neighborhood configuration, or the run mode. If none of "
@@ -48,6 +70,51 @@ _SIZE_LIMIT_MSG = ("Context size exceeds the 30MB limit. {stats_msg}\nPlease "
                    "beyond this for your use-case.")


+@dataclass(repr=False)
+class ExplainConfig(CastMixin):
+    """Configuration for explainability.
+
+    Args:
+        skip_summary: Whether to skip generating a human-readable summary of
+            the explanation.
+    """
+    skip_summary: bool = False
+
+
+@dataclass(repr=False)
+class Explanation:
+    prediction: pd.DataFrame
+    summary: str
+    details: ExplanationConfig
+
+    @overload
+    def __getitem__(self, index: Literal[0]) -> pd.DataFrame:
+        pass
+
+    @overload
+    def __getitem__(self, index: Literal[1]) -> str:
+        pass
+
+    def __getitem__(self, index: int) -> Union[pd.DataFrame, str]:
+        if index == 0:
+            return self.prediction
+        if index == 1:
+            return self.summary
+        raise IndexError("Index out of range")
+
+    def __iter__(self) -> Iterator[Union[pd.DataFrame, str]]:
+        return iter((self.prediction, self.summary))
+
+    def __repr__(self) -> str:
+        return str((self.prediction, self.summary))
+
+    def _ipython_display_(self) -> None:
+        from IPython.display import Markdown, display
+
+        display(self.prediction)
+        display(Markdown(self.summary))
+
+
 class KumoRFM:
     r"""The Kumo Relational Foundation model (RFM) from the `KumoRFM: A
     Foundation Model for In-Context Learning on Relational Data
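
The `Explanation` container added above is designed to unpack like a
`(prediction, summary)` tuple via its `__iter__` and `__getitem__` methods.
A minimal sketch of that behavior (the standalone construction and import
path are assumptions based on the module shown in this diff; in practice the
object is returned by `KumoRFM.predict(..., explain=True)`):

    import pandas as pd
    from kumoai.experimental.rfm.rfm import Explanation  # assumed path

    explanation = Explanation(
        prediction=pd.DataFrame({'TARGET_PRED': [0.85]}),
        summary='Example Markdown summary.',
        details=None,  # stand-in; real responses carry explanation details
    )
    prediction, summary = explanation    # tuple-style unpacking via __iter__
    assert explanation[0] is prediction  # positional access via __getitem__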
@@ -56,17 +123,17 @@ class KumoRFM:
     :class:`KumoRFM` is a foundation model to generate predictions for any
     relational dataset without training.
     The model is pre-trained and the class provides an interface to query the
-    model from a :class:`LocalGraph` object.
+    model from a :class:`Graph` object.

     .. code-block:: python

-        from kumoai.experimental.rfm import LocalGraph, KumoRFM
+        from kumoai.experimental.rfm import Graph, KumoRFM

         df_users = pd.DataFrame(...)
         df_items = pd.DataFrame(...)
         df_orders = pd.DataFrame(...)

-        graph = LocalGraph.from_data({
+        graph = Graph.from_data({
             'users': df_users,
             'items': df_items,
             'orders': df_orders,
@@ -74,59 +141,152 @@ class KumoRFM:

         rfm = KumoRFM(graph)

-        query = ("PREDICT COUNT(transactions.*, 0, 30, days)>0 "
-                 "FOR users.user_id=0")
-        result = rfm.query(query)
+        query = ("PREDICT COUNT(orders.*, 0, 30, days)>0 "
+                 "FOR users.user_id=1")
+        result = rfm.predict(query)

         print(result)  # user_id  COUNT(transactions.*, 0, 30, days) > 0
                        # 1        0.85

     Args:
         graph: The graph.
-        preprocess: Whether to pre-process the data in advance during graph
-            materialization.
-            This is a runtime trade-off between graph materialization and model
-            processing speed.
-            It can be benefical to preprocess your data once and then run many
-            queries on top to achieve maximum model speed.
-            However, if activiated, graph materialization can take potentially
-            much longer, especially on graphs with many large text columns.
-            Best to tune this option manually.
         verbose: Whether to print verbose output.
     """
     def __init__(
         self,
-        graph: LocalGraph,
-        preprocess: bool = False,
-        verbose: bool = True,
+        graph: Graph,
+        verbose: Union[bool, ProgressLogger] = True,
     ) -> None:
         graph = graph.validate()
         self._graph_def = graph._to_api_graph_definition()
-        self._graph_store = LocalGraphStore(graph, preprocess, verbose)
+        self._graph_store = LocalGraphStore(graph, verbose)
         self._graph_sampler = LocalGraphSampler(self._graph_store)

+        self._client: Optional[RFMAPI] = None
+
+        self._batch_size: Optional[int | Literal['max']] = None
+        self.num_retries: int = 0
+
+    @property
+    def _api_client(self) -> RFMAPI:
+        if self._client is not None:
+            return self._client
+
+        from kumoai.experimental.rfm import global_state
+        self._client = RFMAPI(global_state.client)
+        return self._client
+
     def __repr__(self) -> str:
         return f'{self.__class__.__name__}()'

+    @contextmanager
+    def batch_mode(
+        self,
+        batch_size: Union[int, Literal['max']] = 'max',
+        num_retries: int = 1,
+    ) -> Generator[None, None, None]:
+        """Context manager to predict in batches.
+
+        .. code-block:: python
+
+            with model.batch_mode(batch_size='max', num_retries=1):
+                df = model.predict(query, indices=...)
+
+        Args:
+            batch_size: The batch size. If set to ``"max"``, will use the
+                maximum applicable batch size for the given task.
+            num_retries: The maximum number of retries for failed queries due
+                to unexpected server issues.
+        """
+        if batch_size != 'max' and batch_size <= 0:
+            raise ValueError(f"'batch_size' must be greater than zero "
+                             f"(got {batch_size})")
+
+        if num_retries < 0:
+            raise ValueError(f"'num_retries' must be greater than or equal to "
+                             f"zero (got {num_retries})")
+
+        self._batch_size = batch_size
+        self.num_retries = num_retries
+        yield
+        self._batch_size = None
+        self.num_retries = 0
+
+    @overload
     def predict(
         self,
         query: str,
+        indices: Union[List[str], List[float], List[int], None] = None,
         *,
+        explain: Literal[False] = False,
         anchor_time: Union[pd.Timestamp, Literal['entity'], None] = None,
+        context_anchor_time: Union[pd.Timestamp, None] = None,
         run_mode: Union[RunMode, str] = RunMode.FAST,
         num_neighbors: Optional[List[int]] = None,
         num_hops: int = 2,
         max_pq_iterations: int = 20,
         random_seed: Optional[int] = _RANDOM_SEED,
-        verbose: bool = True,
+        verbose: Union[bool, ProgressLogger] = True,
+        use_prediction_time: bool = False,
     ) -> pd.DataFrame:
+        pass
+
+    @overload
+    def predict(
+        self,
+        query: str,
+        indices: Union[List[str], List[float], List[int], None] = None,
+        *,
+        explain: Union[Literal[True], ExplainConfig, Dict[str, Any]],
+        anchor_time: Union[pd.Timestamp, Literal['entity'], None] = None,
+        context_anchor_time: Union[pd.Timestamp, None] = None,
+        run_mode: Union[RunMode, str] = RunMode.FAST,
+        num_neighbors: Optional[List[int]] = None,
+        num_hops: int = 2,
+        max_pq_iterations: int = 20,
+        random_seed: Optional[int] = _RANDOM_SEED,
+        verbose: Union[bool, ProgressLogger] = True,
+        use_prediction_time: bool = False,
+    ) -> Explanation:
+        pass
+
+    def predict(
+        self,
+        query: str,
+        indices: Union[List[str], List[float], List[int], None] = None,
+        *,
+        explain: Union[bool, ExplainConfig, Dict[str, Any]] = False,
+        anchor_time: Union[pd.Timestamp, Literal['entity'], None] = None,
+        context_anchor_time: Union[pd.Timestamp, None] = None,
+        run_mode: Union[RunMode, str] = RunMode.FAST,
+        num_neighbors: Optional[List[int]] = None,
+        num_hops: int = 2,
+        max_pq_iterations: int = 20,
+        random_seed: Optional[int] = _RANDOM_SEED,
+        verbose: Union[bool, ProgressLogger] = True,
+        use_prediction_time: bool = False,
+    ) -> Union[pd.DataFrame, Explanation]:
         """Returns predictions for a predictive query.

         Args:
             query: The predictive query.
-            anchor_time: The anchor timestamp for the query. If set to
-                :obj:`None`, will use the maximum timestamp in the data.
-                If set to :`"entity"`, will use the timestamp of the entity.
+            indices: The entity primary keys to predict for. Will override the
+                indices given as part of the predictive query. Predictions will
+                be generated for all indices, independent of whether they
+                fulfill entity filter constraints. To pre-filter entities, use
+                :meth:`~KumoRFM.is_valid_entity`.
+            explain: Configuration for explainability.
+                If set to ``True``, will additionally explain the prediction.
+                Passing in an :class:`ExplainConfig` instance provides control
+                over which parts of explanation are generated.
+                Explainability is currently only supported for single entity
+                predictions with ``run_mode="FAST"``.
+            anchor_time: The anchor timestamp for the prediction. If set to
+                ``None``, will use the maximum timestamp in the data.
+                If set to ``"entity"``, will use the timestamp of the entity.
+            context_anchor_time: The maximum anchor timestamp for context
+                examples. If set to ``None``, ``anchor_time`` will
+                determine the anchor time for context examples.
             run_mode: The :class:`RunMode` for the query.
             num_neighbors: The number of neighbors to sample for each hop.
                 If specified, the ``num_hops`` option will be ignored.
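
A minimal usage sketch of the batching API added above, assuming `rfm` is a
constructed `KumoRFM` instance and `user_ids` is a list of primary keys from
a hypothetical 'users' table (the query string is illustrative):

    query = "PREDICT COUNT(orders.*, 0, 30, days)>0 FOR users.user_id"
    with rfm.batch_mode(batch_size='max', num_retries=2):
        # Entities are split into task-dependent batches (1,000 by default,
        # 200 for temporal link prediction) and failed requests are retried
        # with exponential backoff, per the implementation below.
        df = rfm.predict(query, indices=user_ids)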
@@ -138,79 +298,244 @@ class KumoRFM:
                 entities to find valid labels.
             random_seed: A manual seed for generating pseudo-random numbers.
             verbose: Whether to print verbose output.
+            use_prediction_time: Whether to use the anchor timestamp as an
+                additional feature during prediction. This is typically
+                beneficial for time series forecasting tasks.

         Returns:
-            The predictions as a :class:`pandas.DataFrame`
+            The predictions as a :class:`pandas.DataFrame`.
+            If ``explain`` is provided, returns an :class:`Explanation` object
+            containing the prediction, summary, and details.
         """
-        explain = False
+        explain_config: Optional[ExplainConfig] = None
+        if explain is True:
+            explain_config = ExplainConfig()
+        elif explain is not False:
+            explain_config = ExplainConfig._cast(explain)
+
         query_def = self._parse_query(query)
+        query_str = query_def.to_string()

         if num_hops != 2 and num_neighbors is not None:
             warnings.warn(f"Received custom 'num_neighbors' option; ignoring "
                           f"custom 'num_hops={num_hops}' option")

-        if explain and run_mode in {RunMode.NORMAL, RunMode.BEST}:
+        if explain_config is not None and run_mode in {
+                RunMode.NORMAL, RunMode.BEST
+        }:
             warnings.warn(f"Explainability is currently only supported for "
                           f"run mode 'FAST' (got '{run_mode}'). Provided run "
                           f"mode has been reset. Please lower the run mode to "
                           f"suppress this warning.")

-        if explain:
-            assert query_def.entity.ids is not None
-            if len(query_def.entity.ids.value) > 1:
-                raise ValueError(
-                    f"Cannot explain predictions for more than a single "
-                    f"entity (got {len(query_def.entity.ids.value)})")
+        if indices is None:
+            if query_def.rfm_entity_ids is None:
+                raise ValueError("Cannot find entities to predict for. Please "
+                                 "pass them via `predict(query, indices=...)`")
+            indices = query_def.get_rfm_entity_id_list()
+        else:
+            query_def = replace(query_def, rfm_entity_ids=None)
+
+        if len(indices) == 0:
+            raise ValueError("At least one entity is required")
+
+        if explain_config is not None and len(indices) > 1:
+            raise ValueError(
+                f"Cannot explain predictions for more than a single entity "
+                f"(got {len(indices)})")

         query_repr = query_def.to_string(rich=True, exclude_predict=True)
-        if explain:
+        if explain_config is not None:
             msg = f'[bold]EXPLAIN[/bold] {query_repr}'
         else:
             msg = f'[bold]PREDICT[/bold] {query_repr}'

-        with ProgressLogger(msg, verbose=verbose) as logger:
-            context = self._get_context(
-                query_def,
-                anchor_time=anchor_time,
-                run_mode=RunMode(run_mode),
-                num_neighbors=num_neighbors,
-                num_hops=num_hops,
-                max_pq_iterations=max_pq_iterations,
-                evaluate=False,
-                random_seed=random_seed,
-                logger=logger,
-            )
-            request = RFMPredictRequest(
-                context=context,
-                run_mode=RunMode(run_mode),
+        if not isinstance(verbose, ProgressLogger):
+            verbose = InteractiveProgressLogger(msg, verbose=verbose)
+
+        with verbose as logger:
+
+            batch_size: Optional[int] = None
+            if self._batch_size == 'max':
+                task_type = LocalPQueryDriver.get_task_type(
+                    query_def,
+                    edge_types=self._graph_store.edge_types,
+                )
+                batch_size = _MAX_PRED_SIZE[task_type]
+            else:
+                batch_size = self._batch_size
+
+            if batch_size is not None:
+                offsets = range(0, len(indices), batch_size)
+                batches = [indices[step:step + batch_size] for step in offsets]
+            else:
+                batches = [indices]
+
+            if len(batches) > 1:
+                logger.log(f"Splitting {len(indices):,} entities into "
+                           f"{len(batches):,} batches of size {batch_size:,}")
+
+            predictions: List[pd.DataFrame] = []
+            summary: Optional[str] = None
+            details: Optional[Explanation] = None
+            for i, batch in enumerate(batches):
+                # TODO Re-use the context for subsequent predictions.
+                context = self._get_context(
+                    query=query_def,
+                    indices=batch,
+                    anchor_time=anchor_time,
+                    context_anchor_time=context_anchor_time,
+                    run_mode=RunMode(run_mode),
+                    num_neighbors=num_neighbors,
+                    num_hops=num_hops,
+                    max_pq_iterations=max_pq_iterations,
+                    evaluate=False,
+                    random_seed=random_seed,
+                    logger=logger if i == 0 else None,
+                )
+                request = RFMPredictRequest(
+                    context=context,
+                    run_mode=RunMode(run_mode),
+                    query=query_str,
+                    use_prediction_time=use_prediction_time,
+                )
+                with warnings.catch_warnings():
+                    warnings.filterwarnings('ignore', message='gencode')
+                    request_msg = request.to_protobuf()
+                _bytes = request_msg.SerializeToString()
+                if i == 0:
+                    logger.log(f"Generated context of size "
+                               f"{len(_bytes) / (1024*1024):.2f}MB")
+
+                if len(_bytes) > _MAX_SIZE:
+                    stats = Context.get_memory_stats(request_msg.context)
+                    raise ValueError(_SIZE_LIMIT_MSG.format(stats=stats))
+
+                if (isinstance(verbose, InteractiveProgressLogger) and i == 0
+                        and len(batches) > 1):
+                    verbose.init_progress(
+                        total=len(batches),
+                        description='Predicting',
+                    )
+
+                for attempt in range(self.num_retries + 1):
+                    try:
+                        if explain_config is not None:
+                            resp = self._api_client.explain(
+                                request=_bytes,
+                                skip_summary=explain_config.skip_summary,
+                            )
+                            summary = resp.summary
+                            details = resp.details
+                        else:
+                            resp = self._api_client.predict(_bytes)
+                        df = pd.DataFrame(**resp.prediction)

+                        # Cast 'ENTITY' to correct data type:
+                        if 'ENTITY' in df:
+                            entity = query_def.entity_table
+                            pkey_map = self._graph_store.pkey_map_dict[entity]
+                            df['ENTITY'] = df['ENTITY'].astype(
+                                type(pkey_map.index[0]))
+
+                        # Cast 'ANCHOR_TIMESTAMP' to correct data type:
+                        if 'ANCHOR_TIMESTAMP' in df:
+                            ser = df['ANCHOR_TIMESTAMP']
+                            if not pd.api.types.is_datetime64_any_dtype(ser):
+                                if isinstance(ser.iloc[0], str):
+                                    unit = None
+                                else:
+                                    unit = 'ms'
+                                df['ANCHOR_TIMESTAMP'] = pd.to_datetime(
+                                    ser, errors='coerce', unit=unit)
+
+                        predictions.append(df)
+
+                        if (isinstance(verbose, InteractiveProgressLogger)
+                                and len(batches) > 1):
+                            verbose.step()
+
+                        break
+                    except HTTPException as e:
+                        if attempt == self.num_retries:
+                            try:
+                                msg = json.loads(e.detail)['detail']
+                            except Exception:
+                                msg = e.detail
+                            raise RuntimeError(
+                                f"An unexpected exception occurred. Please "
+                                f"create an issue at "
+                                f"'https://github.com/kumo-ai/kumo-rfm'. {msg}"
+                            ) from None
+
+                        time.sleep(2**attempt)  # 1s, 2s, 4s, 8s, ...
+
+        if len(predictions) == 1:
+            prediction = predictions[0]
+        else:
+            prediction = pd.concat(predictions, ignore_index=True)
+
+        if explain_config is not None:
+            assert len(predictions) == 1
+            assert summary is not None
+            assert details is not None
+            return Explanation(
+                prediction=prediction,
+                summary=summary,
+                details=details,
             )
-            with warnings.catch_warnings():
-                warnings.filterwarnings('ignore', message='Protobuf gencode')
-                request_msg = request.to_protobuf()
-            request_bytes = request_msg.SerializeToString()
-            logger.log(f"Generated context of size "
-                       f"{len(request_bytes) / (1024*1024):.2f}MB")

-            if len(request_bytes) > _MAX_SIZE:
-                stats_msg = Context.get_memory_stats(request_msg.context)
-                raise ValueError(_SIZE_LIMIT_MSG.format(stats_msg=stats_msg))
+        return prediction

-            try:
-                if explain:
-                    resp = global_state.client.rfm_api.explain(request_bytes)
-                else:
-                    resp = global_state.client.rfm_api.predict(request_bytes)
-            except HTTPException as e:
-                try:
-                    msg = json.loads(e.detail)['detail']
-                except Exception:
-                    msg = e.detail
-                raise RuntimeError(f"An unexpected exception occurred. "
-                                   f"Please create an issue at "
-                                   f"'https://github.com/kumo-ai/kumo-rfm'. "
-                                   f"{msg}") from None
+    def is_valid_entity(
+        self,
+        query: str,
+        indices: Union[List[str], List[float], List[int], None] = None,
+        *,
+        anchor_time: Union[pd.Timestamp, Literal['entity'], None] = None,
+    ) -> np.ndarray:
+        r"""Returns a mask that denotes which entities are valid for the
+        given predictive query, *i.e.*, which entities fulfill (temporal)
+        entity filter constraints.

-        return pd.DataFrame(**resp.prediction)
+        Args:
+            query: The predictive query.
+            indices: The entity primary keys to predict for. Will override the
+                indices given as part of the predictive query.
+            anchor_time: The anchor timestamp for the prediction. If set to
+                ``None``, will use the maximum timestamp in the data.
+                If set to ``"entity"``, will use the timestamp of the entity.
+        """
+        query_def = self._parse_query(query)
+
+        if indices is None:
+            if query_def.rfm_entity_ids is None:
+                raise ValueError("Cannot find entities to predict for. Please "
+                                 "pass them via "
+                                 "`is_valid_entity(query, indices=...)`")
+            indices = query_def.get_rfm_entity_id_list()
+
+        if len(indices) == 0:
+            raise ValueError("At least one entity is required")
+
+        if anchor_time is None:
+            anchor_time = self._graph_store.max_time
+
+        if isinstance(anchor_time, pd.Timestamp):
+            self._validate_time(query_def, anchor_time, None, False)
+        else:
+            assert anchor_time == 'entity'
+            if (query_def.entity_table not in self._graph_store.time_dict):
+                raise ValueError(f"Anchor time 'entity' requires the entity "
+                                 f"table '{query_def.entity_table}' "
+                                 f"to have a time column.")
+
+        node = self._graph_store.get_node_id(
+            table_name=query_def.entity_table,
+            pkey=pd.Series(indices),
+        )
+        query_driver = LocalPQueryDriver(self._graph_store, query_def)
+        return query_driver.is_valid(node, anchor_time)

     def evaluate(
         self,
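
A sketch of pre-filtering entities with `is_valid_entity` before predicting,
as suggested by the `indices` docstring above (`rfm`, `query`, and
`candidate_ids` are assumed to exist; per the docstring, the return value is
a boolean mask aligned with the given indices):

    mask = rfm.is_valid_entity(query, indices=candidate_ids)
    valid_ids = [i for i, ok in zip(candidate_ids, mask) if ok]
    df = rfm.predict(query, indices=valid_ids)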
@@ -218,21 +543,26 @@ class KumoRFM:
         *,
         metrics: Optional[List[str]] = None,
         anchor_time: Union[pd.Timestamp, Literal['entity'], None] = None,
+        context_anchor_time: Union[pd.Timestamp, None] = None,
         run_mode: Union[RunMode, str] = RunMode.FAST,
         num_neighbors: Optional[List[int]] = None,
         num_hops: int = 2,
         max_pq_iterations: int = 20,
         random_seed: Optional[int] = _RANDOM_SEED,
-        verbose: bool = True,
+        verbose: Union[bool, ProgressLogger] = True,
+        use_prediction_time: bool = False,
     ) -> pd.DataFrame:
         """Evaluates a predictive query.

         Args:
             query: The predictive query.
             metrics: The metrics to use.
-            anchor_time: The anchor timestamp for the query. If set to
-                :obj:`None`, will use the maximum timestamp in the data.
-                If set to :`"entity"`, will use the timestamp of the entity.
+            anchor_time: The anchor timestamp for the prediction. If set to
+                ``None``, will use the maximum timestamp in the data.
+                If set to ``"entity"``, will use the timestamp of the entity.
+            context_anchor_time: The maximum anchor timestamp for context
+                examples. If set to ``None``, ``anchor_time`` will
+                determine the anchor time for context examples.
             run_mode: The :class:`RunMode` for the query.
             num_neighbors: The number of neighbors to sample for each hop.
                 If specified, the ``num_hops`` option will be ignored.
@@ -244,6 +574,9 @@ class KumoRFM:
                 entities to find valid labels.
             random_seed: A manual seed for generating pseudo-random numbers.
             verbose: Whether to print verbose output.
+            use_prediction_time: Whether to use the anchor timestamp as an
+                additional feature during prediction. This is typically
+                beneficial for time series forecasting tasks.

         Returns:
             The metrics as a :class:`pandas.DataFrame`
@@ -254,13 +587,24 @@ class KumoRFM:
             warnings.warn(f"Received custom 'num_neighbors' option; ignoring "
                           f"custom 'num_hops={num_hops}' option")

+        if query_def.rfm_entity_ids is not None:
+            query_def = replace(
+                query_def,
+                rfm_entity_ids=None,
+            )
+
         query_repr = query_def.to_string(rich=True, exclude_predict=True)
         msg = f'[bold]EVALUATE[/bold] {query_repr}'

-        with ProgressLogger(msg, verbose=verbose) as logger:
+        if not isinstance(verbose, ProgressLogger):
+            verbose = InteractiveProgressLogger(msg, verbose=verbose)
+
+        with verbose as logger:
             context = self._get_context(
-                query_def,
+                query=query_def,
+                indices=None,
                 anchor_time=anchor_time,
+                context_anchor_time=context_anchor_time,
                 run_mode=RunMode(run_mode),
                 num_neighbors=num_neighbors,
                 num_hops=num_hops,
@@ -276,6 +620,7 @@ class KumoRFM:
                 context=context,
                 run_mode=RunMode(run_mode),
                 metrics=metrics,
+                use_prediction_time=use_prediction_time,
             )
             with warnings.catch_warnings():
                 warnings.filterwarnings('ignore', message='Protobuf gencode')
@@ -286,10 +631,10 @@ class KumoRFM:

             if len(request_bytes) > _MAX_SIZE:
                 stats_msg = Context.get_memory_stats(request_msg.context)
-                raise ValueError(_SIZE_LIMIT_MSG.format(stats_msg=stats_msg))
+                raise ValueError(_SIZE_LIMIT_MSG.format(stats=stats_msg))

             try:
-                resp = global_state.client.rfm_api.evaluate(request_bytes)
+                resp = self._api_client.evaluate(request_bytes)
             except HTTPException as e:
                 try:
                     msg = json.loads(e.detail)['detail']
@@ -334,17 +679,19 @@ class KumoRFM:

         if anchor_time is None:
             anchor_time = self._graph_store.max_time
-            anchor_time = anchor_time - query_def.target.end_offset
+            if query_def.target_ast.date_offset_range is not None:
+                anchor_time = anchor_time - (
+                    query_def.target_ast.date_offset_range.end_date_offset *
+                    query_def.num_forecasts)

         assert anchor_time is not None
         if isinstance(anchor_time, pd.Timestamp):
-            self._validate_time(query_def, anchor_time, evaluate=True)
+            self._validate_time(query_def, anchor_time, None, evaluate=True)
         else:
             assert anchor_time == 'entity'
-            if (query_def.entity.pkey.table_name
-                    not in self._graph_store.time_dict):
+            if (query_def.entity_table not in self._graph_store.time_dict):
                 raise ValueError(f"Anchor time 'entity' requires the entity "
-                                 f"table '{query_def.entity.pkey.table_name}' "
+                                 f"table '{query_def.entity_table}' "
                                  f"to have a time column")

         query_driver = LocalPQueryDriver(self._graph_store, query_def,
@@ -355,18 +702,22 @@ class KumoRFM:
             anchor_time=anchor_time,
             batch_size=min(10_000, size),
             max_iterations=max_iterations,
+            guarantee_train_examples=False,
         )

+        entity = self._graph_store.pkey_map_dict[
+            query_def.entity_table].index[node]
+
         return pd.DataFrame({
-            'ENTITY': node,
+            'ENTITY': entity,
             'ANCHOR_TIMESTAMP': time,
             'TARGET': y,
         })

     # Helpers #################################################################

-    def _parse_query(self, query: str) -> PQueryDefinition:
-        if isinstance(query, PQueryDefinition):
+    def _parse_query(self, query: str) -> ValidatedPredictiveQuery:
+        if isinstance(query, ValidatedPredictiveQuery):
             return query

         if isinstance(query, str) and query.strip()[:9].lower() == 'evaluate ':
@@ -376,12 +727,13 @@ class KumoRFM:
                 "predictions or evaluations.")

         try:
-            request = RFMValidateQueryRequest(
+            request = RFMParseQueryRequest(
                 query=query,
                 graph_definition=self._graph_def,
             )

-            resp = global_state.client.rfm_api.validate_query(request)
+            resp = self._api_client.parse_query(request)
+
             # TODO Expose validation warnings.

             if len(resp.validation_response.warnings) > 0:
@@ -392,7 +744,7 @@ class KumoRFM:
                 warnings.warn(f"Encountered the following warnings during "
                               f"parsing:\n{msg}")

-            return resp.query_definition
+            return resp.query
         except HTTPException as e:
             try:
                 msg = json.loads(e.detail)['detail']
@@ -403,8 +755,9 @@ class KumoRFM:

     def _validate_time(
         self,
-        query: PQueryDefinition,
+        query: ValidatedPredictiveQuery,
         anchor_time: pd.Timestamp,
+        context_anchor_time: Union[pd.Timestamp, None],
         evaluate: bool,
     ) -> None:

@@ -416,22 +769,45 @@ class KumoRFM:
                              f"the earliest timestamp "
                              f"'{self._graph_store.min_time}' in the data.")

-        if anchor_time - query.target.end_offset < self._graph_store.min_time:
-            raise ValueError(f"Anchor timestamp is too early or aggregation "
-                             f"time range is too large. To make this "
-                             f"prediction, we would need data back to "
-                             f"'{anchor_time - query.target.end_offset}', "
-                             f"however, your data only contains data back to "
+        if (context_anchor_time is not None
+                and context_anchor_time < self._graph_store.min_time):
+            raise ValueError(f"Context anchor timestamp is too early or "
+                             f"aggregation time range is too large. To make "
+                             f"this prediction, we would need data back to "
+                             f"'{context_anchor_time}', however, your data "
+                             f"only contains data back to "
                              f"'{self._graph_store.min_time}'.")

-        if (anchor_time - 2 * query.target.end_offset
-                < self._graph_store.min_time):
-            warnings.warn(f"Anchor timestamp is too early or aggregation "
-                          f"time range is too large. To form proper input "
-                          f"data, we would need data back to "
-                          f"'{anchor_time - 2 * query.target.end_offset}', "
-                          f"however, your data only contains data back to "
-                          f"'{self._graph_store.min_time}'.")
+        if query.target_ast.date_offset_range is not None:
+            end_offset = query.target_ast.date_offset_range.end_date_offset
+        else:
+            end_offset = pd.DateOffset(0)
+        forecast_end_offset = end_offset * query.num_forecasts
+        if (context_anchor_time is not None
+                and context_anchor_time > anchor_time):
+            warnings.warn(f"Context anchor timestamp "
+                          f"(got '{context_anchor_time}') is set to a later "
+                          f"date than the prediction anchor timestamp "
+                          f"(got '{anchor_time}'). Please make sure this is "
+                          f"intended.")
+        elif (query.query_type == QueryType.TEMPORAL
+              and context_anchor_time is not None
+              and context_anchor_time + forecast_end_offset > anchor_time):
+            warnings.warn(f"Aggregation for context examples at timestamp "
+                          f"'{context_anchor_time}' will leak information "
+                          f"from the prediction anchor timestamp "
+                          f"'{anchor_time}'. Please make sure this is "
+                          f"intended.")
+
+        elif (context_anchor_time is not None
+              and context_anchor_time - forecast_end_offset
+              < self._graph_store.min_time):
+            _time = context_anchor_time - forecast_end_offset
+            warnings.warn(f"Context anchor timestamp is too early or "
+                          f"aggregation time range is too large. To form "
+                          f"proper input data, we would need data back to "
+                          f"'{_time}', however, your data only contains "
+                          f"data back to '{self._graph_store.min_time}'.")

         if (not evaluate and anchor_time
                 > self._graph_store.max_time + pd.DateOffset(days=1)):
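
A worked example of the offset arithmetic used in `_validate_time` above
(values are hypothetical): multiplying a `pandas.DateOffset` by
`num_forecasts` applies the step that many times, so a 30-day window
forecast four steps ahead reaches back 120 days.

    import pandas as pd

    end_offset = pd.DateOffset(days=30)
    forecast_end_offset = end_offset * 4  # applies the 30-day step 4 times
    anchor_time = pd.Timestamp('2024-05-01')
    context_anchor_time = anchor_time - forecast_end_offset
    print(context_anchor_time)  # 2024-01-02, i.e. 120 days earlier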
@@ -439,17 +815,18 @@ class KumoRFM:
                           f"latest timestamp '{self._graph_store.max_time}' "
                           f"in the data. Please make sure this is intended.")

-        if (evaluate and anchor_time
-                > self._graph_store.max_time - query.target.end_offset):
+        max_eval_time = self._graph_store.max_time - forecast_end_offset
+        if evaluate and anchor_time > max_eval_time:
             raise ValueError(
                 f"Anchor timestamp for evaluation is after the latest "
-                f"supported timestamp "
-                f"'{self._graph_store.max_time - query.target.end_offset}'.")
+                f"supported timestamp '{max_eval_time}'.")

     def _get_context(
         self,
-        query: PQueryDefinition,
+        query: ValidatedPredictiveQuery,
+        indices: Union[List[str], List[float], List[int], None],
         anchor_time: Union[pd.Timestamp, Literal['entity'], None],
+        context_anchor_time: Union[pd.Timestamp, None],
        run_mode: RunMode,
        num_neighbors: Optional[List[int]],
        num_hops: int,
@@ -474,8 +851,8 @@ class KumoRFM:
                              f"must go beyond this for your use-case.")

         query_driver = LocalPQueryDriver(self._graph_store, query, random_seed)
-        task_type = query.get_task_type(
-            stypes=self._graph_store.stype_dict,
+        task_type = LocalPQueryDriver.get_task_type(
+            query,
             edge_types=self._graph_store.edge_types,
         )

@@ -507,28 +884,42 @@ class KumoRFM:
         else:
             num_neighbors = [64, 64, 8, 8, 4, 4][:num_hops]

+        if query.target_ast.date_offset_range is None:
+            end_offset = pd.DateOffset(0)
+        else:
+            end_offset = query.target_ast.date_offset_range.end_date_offset
+        forecast_end_offset = end_offset * query.num_forecasts
         if anchor_time is None:
             anchor_time = self._graph_store.max_time
             if evaluate:
-                anchor_time = anchor_time - query.target.end_offset
+                anchor_time = anchor_time - forecast_end_offset
             if logger is not None:
                 assert isinstance(anchor_time, pd.Timestamp)
-                if (anchor_time.hour == 0 and anchor_time.minute == 0
-                        and anchor_time.second == 0
-                        and anchor_time.microsecond == 0):
+                if anchor_time == pd.Timestamp.min:
+                    pass  # Static graph
+                elif (anchor_time.hour == 0 and anchor_time.minute == 0
+                      and anchor_time.second == 0
+                      and anchor_time.microsecond == 0):
                     logger.log(f"Derived anchor time {anchor_time.date()}")
                 else:
                     logger.log(f"Derived anchor time {anchor_time}")

         assert anchor_time is not None
         if isinstance(anchor_time, pd.Timestamp):
-            self._validate_time(query, anchor_time, evaluate)
+            if context_anchor_time is None:
+                context_anchor_time = anchor_time - forecast_end_offset
+            self._validate_time(query, anchor_time, context_anchor_time,
+                                evaluate)
         else:
             assert anchor_time == 'entity'
-            if query.entity.pkey.table_name not in self._graph_store.time_dict:
+            if query.entity_table not in self._graph_store.time_dict:
                 raise ValueError(f"Anchor time 'entity' requires the entity "
-                                 f"table '{query.entity.pkey.table_name}' to "
+                                 f"table '{query.entity_table}' to "
                                  f"have a time column")
+            if context_anchor_time is not None:
+                warnings.warn("Ignoring option 'context_anchor_time' for "
+                              "`anchor_time='entity'`")
+                context_anchor_time = None

         y_test: Optional[pd.Series] = None
         if evaluate:
@@ -540,6 +931,7 @@ class KumoRFM:
                 size=max_test_size,
                 anchor_time=anchor_time,
                 max_iterations=max_pq_iterations,
+                guarantee_train_examples=True,
             )
             if logger is not None:
                 if task_type == TaskType.BINARY_CLASSIFICATION:
@@ -563,34 +955,31 @@ class KumoRFM:
                 logger.log(msg)

         else:
-            assert query.entity.ids is not None
+            assert indices is not None

-            max_num_test = 200 if task_type.is_link_pred else 1000
-            if len(query.entity.ids.value) > max_num_test:
+            if len(indices) > _MAX_PRED_SIZE[task_type]:
                 raise ValueError(f"Cannot predict for more than "
-                                 f"{max_num_test:,} entities at once "
-                                 f"(got {len(query.entity.ids.value):,})")
+                                 f"{_MAX_PRED_SIZE[task_type]:,} entities at "
+                                 f"once (got {len(indices):,}). Use "
+                                 f"`KumoRFM.batch_mode` to process entities "
+                                 f"in batches")

             test_node = self._graph_store.get_node_id(
-                table_name=query.entity.pkey.table_name,
-                pkey=pd.Series(
-                    query.entity.ids.value,
-                    dtype=query.entity.ids.dtype,
-                ),
+                table_name=query.entity_table,
+                pkey=pd.Series(indices),
             )

             if isinstance(anchor_time, pd.Timestamp):
                 test_time = pd.Series(anchor_time).repeat(
                     len(test_node)).reset_index(drop=True)
             else:
-                time = self._graph_store.time_dict[
-                    query.entity.pkey.table_name]
+                time = self._graph_store.time_dict[query.entity_table]
                 time = time[test_node] * 1000**3
                 test_time = pd.Series(time, dtype='datetime64[ns]')

             train_node, train_time, y_train = query_driver.collect_train(
                 size=_MAX_CONTEXT_SIZE[run_mode],
-                anchor_time=anchor_time,
+                anchor_time=context_anchor_time or 'entity',
                 exclude_node=test_node if (query.query_type == QueryType.STATIC
                                            or anchor_time == 'entity') else None,
                 max_iterations=max_pq_iterations,
@@ -617,12 +1006,23 @@ class KumoRFM:
                 raise NotImplementedError
             logger.log(msg)

-        entity_table_names = query.get_entity_table_names(
-            self._graph_store.edge_types)
+        entity_table_names: Tuple[str, ...]
+        if task_type.is_link_pred:
+            final_aggr = query.get_final_target_aggregation()
+            assert final_aggr is not None
+            edge_fkey = final_aggr._get_target_column_name()
+            for edge_type in self._graph_store.edge_types:
+                if edge_fkey == f'{edge_type[0]}.{edge_type[1]}':
+                    entity_table_names = (
+                        query.entity_table,
+                        edge_type[2],
+                    )
+        else:
+            entity_table_names = (query.entity_table, )

         # Exclude the entity anchor time from the feature set to prevent
         # running out-of-distribution between in-context and test examples:
-        exclude_cols_dict = query.exclude_cols_dict
+        exclude_cols_dict = query.get_exclude_cols_dict()
         if anchor_time == 'entity':
             if entity_table_names[0] not in exclude_cols_dict:
                 exclude_cols_dict[entity_table_names[0]] = []
@@ -637,11 +1037,21 @@ class KumoRFM:
                 train_time.astype('datetime64[ns]').astype(int).to_numpy(),
                 test_time.astype('datetime64[ns]').astype(int).to_numpy(),
             ]),
-            run_mode=run_mode,
             num_neighbors=num_neighbors,
             exclude_cols_dict=exclude_cols_dict,
         )

+        if len(subgraph.table_dict) >= 15:
+            raise ValueError(f"Cannot query from a graph with more than 15 "
+                             f"tables (got {len(subgraph.table_dict)}). "
+                             f"Please create a feature request at "
+                             f"'https://github.com/kumo-ai/kumo-rfm' if you "
+                             f"must go beyond this for your use-case.")
+
+        step_size: Optional[int] = None
+        if query.query_type == QueryType.TEMPORAL:
+            step_size = date_offset_to_seconds(end_offset)
+
         return Context(
             task_type=task_type,
             entity_table_names=entity_table_names,
@@ -649,6 +1059,7 @@ class KumoRFM:
             y_train=y_train,
             y_test=y_test,
             top_k=query.top_k,
+            step_size=step_size,
         )

     @staticmethod
@@ -664,7 +1075,7 @@ class KumoRFM:
         elif task_type == TaskType.MULTICLASS_CLASSIFICATION:
             supported_metrics = ['acc', 'precision', 'recall', 'f1', 'mrr']
         elif task_type == TaskType.REGRESSION:
-            supported_metrics = ['mae', 'mape', 'mse', 'rmse', 'smape']
+            supported_metrics = ['mae', 'mape', 'mse', 'rmse', 'smape', 'r2']
         elif task_type == TaskType.TEMPORAL_LINK_PREDICTION:
             supported_metrics = [
                 'map@', 'ndcg@', 'mrr@', 'precision@', 'recall@', 'f1@',
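
A short evaluation sketch using the newly supported 'r2' regression metric
from the hunk above (`rfm` is an assumed `KumoRFM` instance; the query string
and timestamp are illustrative):

    import pandas as pd

    metrics_df = rfm.evaluate(
        "PREDICT SUM(orders.price, 0, 30, days) FOR users.user_id",
        metrics=['mae', 'rmse', 'r2'],
        context_anchor_time=pd.Timestamp('2024-03-01'),
    )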