kumoai 2.10.0.dev202509281831__cp313-cp313-win_amd64.whl → 2.13.0.dev202511211730__cp313-cp313-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,21 +1,36 @@
  import json
+ import time
  import warnings
- from typing import List, Literal, Optional, Union
+ from collections import defaultdict
+ from collections.abc import Generator
+ from contextlib import contextmanager
+ from dataclasses import dataclass, replace
+ from typing import (
+     Any,
+     Dict,
+     Iterator,
+     List,
+     Literal,
+     Optional,
+     Tuple,
+     Union,
+     overload,
+ )

  import numpy as np
  import pandas as pd
  from kumoapi.model_plan import RunMode
- from kumoapi.pquery import QueryType
+ from kumoapi.pquery import QueryType, ValidatedPredictiveQuery
+ from kumoapi.rfm import Context
+ from kumoapi.rfm import Explanation as ExplanationConfig
  from kumoapi.rfm import (
-     Context,
-     PQueryDefinition,
      RFMEvaluateRequest,
+     RFMParseQueryRequest,
      RFMPredictRequest,
-     RFMValidateQueryRequest,
  )
  from kumoapi.task import TaskType

- from kumoai import global_state
+ from kumoai.client.rfm import RFMAPI
  from kumoai.exceptions import HTTPException
  from kumoai.experimental.rfm import LocalGraph
  from kumoai.experimental.rfm.local_graph_sampler import LocalGraphSampler
@@ -24,10 +39,14 @@ from kumoai.experimental.rfm.local_pquery_driver import (
      LocalPQueryDriver,
      date_offset_to_seconds,
  )
+ from kumoai.mixin import CastMixin
  from kumoai.utils import InteractiveProgressLogger, ProgressLogger

  _RANDOM_SEED = 42

+ _MAX_PRED_SIZE: dict[TaskType, int] = defaultdict(lambda: 1_000)
+ _MAX_PRED_SIZE[TaskType.TEMPORAL_LINK_PREDICTION] = 200
+
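The new `_MAX_PRED_SIZE` mapping caps how many entities a single request may carry: temporal link prediction is limited to 200, and every other task type falls back to the `defaultdict` factory default of 1,000. A minimal sketch of the lookup semantics, reusing the same `TaskType` enum:

    from collections import defaultdict

    caps = defaultdict(lambda: 1_000)              # default cap per request
    caps[TaskType.TEMPORAL_LINK_PREDICTION] = 200  # stricter link-pred cap

    caps[TaskType.REGRESSION]                # -> 1_000 (via the factory)
    caps[TaskType.TEMPORAL_LINK_PREDICTION]  # -> 200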
  _MAX_CONTEXT_SIZE = {
      RunMode.DEBUG: 100,
      RunMode.FAST: 1_000,
@@ -42,7 +61,7 @@ _MAX_TEST_SIZE = { # Share test set size across run modes for fair comparison:
  }

  _MAX_SIZE = 30 * 1024 * 1024
- _SIZE_LIMIT_MSG = ("Context size exceeds the 30MB limit. {stats_msg}\nPlease "
+ _SIZE_LIMIT_MSG = ("Context size exceeds the 30MB limit. {stats}\nPlease "
                     "reduce either the number of tables in the graph, their "
                     "number of columns (e.g., large text columns), "
                     "neighborhood configuration, or the run mode. If none of "
@@ -51,6 +70,51 @@ _SIZE_LIMIT_MSG = ("Context size exceeds the 30MB limit. {stats_msg}\nPlease "
                     "beyond this for your use-case.")


+ @dataclass(repr=False)
+ class ExplainConfig(CastMixin):
+     """Configuration for explainability.
+
+     Args:
+         skip_summary: Whether to skip generating a human-readable summary of
+             the explanation.
+     """
+     skip_summary: bool = False
+
+
+ @dataclass(repr=False)
+ class Explanation:
+     prediction: pd.DataFrame
+     summary: str
+     details: ExplanationConfig
+
+     @overload
+     def __getitem__(self, index: Literal[0]) -> pd.DataFrame:
+         pass
+
+     @overload
+     def __getitem__(self, index: Literal[1]) -> str:
+         pass
+
+     def __getitem__(self, index: int) -> Union[pd.DataFrame, str]:
+         if index == 0:
+             return self.prediction
+         if index == 1:
+             return self.summary
+         raise IndexError("Index out of range")
+
+     def __iter__(self) -> Iterator[Union[pd.DataFrame, str]]:
+         return iter((self.prediction, self.summary))
+
+     def __repr__(self) -> str:
+         return str((self.prediction, self.summary))
+
+     def _ipython_display_(self) -> None:
+         from IPython.display import Markdown, display
+
+         display(self.prediction)
+         display(Markdown(self.summary))
+
+
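The `Explanation` dataclass above is written to behave like a `(prediction, summary)` pair: `__getitem__`, `__iter__`, and `__repr__` all forward to the first two fields, so indexing and unpacking keep working. A minimal usage sketch, assuming `df` and `details` are already built:

    exp = Explanation(prediction=df, summary='...', details=details)
    exp[0]                     # the prediction DataFrame
    exp[1]                     # the human-readable summary
    prediction, summary = exp  # tuple-style unpacking via __iter__

In notebooks, `_ipython_display_` renders the DataFrame followed by the summary as Markdown.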
  class KumoRFM:
      r"""The Kumo Relational Foundation model (RFM) from the `KumoRFM: A
      Foundation Model for In-Context Learning on Relational Data
@@ -77,9 +141,9 @@ class KumoRFM:

          rfm = KumoRFM(graph)

-         query = ("PREDICT COUNT(transactions.*, 0, 30, days)>0 "
-                  "FOR users.user_id=0")
-         result = rfm.query(query)
+         query = ("PREDICT COUNT(orders.*, 0, 30, days)>0 "
+                  "FOR users.user_id=1")
+         result = rfm.predict(query)

          print(result)  # user_id  COUNT(transactions.*, 0, 30, days) > 0
                         # 1        0.85
@@ -108,13 +172,54 @@ class KumoRFM:
          self._graph_store = LocalGraphStore(graph, preprocess, verbose)
          self._graph_sampler = LocalGraphSampler(self._graph_store)

+         self._batch_size: Optional[int | Literal['max']] = None
+         self.num_retries: int = 0
+         from kumoai.experimental.rfm import global_state
+         self._api_client = RFMAPI(global_state.client)
+
      def __repr__(self) -> str:
          return f'{self.__class__.__name__}()'

+     @contextmanager
+     def batch_mode(
+         self,
+         batch_size: Union[int, Literal['max']] = 'max',
+         num_retries: int = 1,
+     ) -> Generator[None, None, None]:
+         """Context manager to predict in batches.
+
+         .. code-block:: python
+
+             with model.batch_mode(batch_size='max', num_retries=1):
+                 df = model.predict(query, indices=...)
+
+         Args:
+             batch_size: The batch size. If set to ``"max"``, will use the
+                 maximum applicable batch size for the given task.
+             num_retries: The maximum number of retries for failed queries due
+                 to unexpected server issues.
+         """
+         if batch_size != 'max' and batch_size <= 0:
+             raise ValueError(f"'batch_size' must be greater than zero "
+                              f"(got {batch_size})")
+
+         if num_retries < 0:
+             raise ValueError(f"'num_retries' must be greater than or equal to "
+                              f"zero (got {num_retries})")
+
+         self._batch_size = batch_size
+         self.num_retries = num_retries
+         yield
+         self._batch_size = None
+         self.num_retries = 0
+
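Entering `batch_mode` only stashes `batch_size` and `num_retries` on the instance; the actual splitting and retrying happen inside `predict` further below. A hedged usage sketch (graph and query assumed to exist):

    rfm = KumoRFM(graph)
    with rfm.batch_mode(batch_size=500, num_retries=2):
        # 10,000 indices are split into 20 batches of 500. Each failed
        # batch is retried with exponential backoff (1s, then 2s) before
        # giving up and raising a RuntimeError.
        df = rfm.predict(query, indices=list(range(10_000)))

Note that the reset to defaults happens after `yield` without a `try`/`finally`, so an exception raised inside the block leaves batch mode engaged.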
+     @overload
      def predict(
          self,
          query: str,
+         indices: Union[List[str], List[float], List[int], None] = None,
          *,
+         explain: Literal[False] = False,
          anchor_time: Union[pd.Timestamp, Literal['entity'], None] = None,
          context_anchor_time: Union[pd.Timestamp, None] = None,
          run_mode: Union[RunMode, str] = RunMode.FAST,
@@ -123,16 +228,65 @@ class KumoRFM:
          max_pq_iterations: int = 20,
          random_seed: Optional[int] = _RANDOM_SEED,
          verbose: Union[bool, ProgressLogger] = True,
+         use_prediction_time: bool = False,
      ) -> pd.DataFrame:
+         pass
+
+     @overload
+     def predict(
+         self,
+         query: str,
+         indices: Union[List[str], List[float], List[int], None] = None,
+         *,
+         explain: Union[Literal[True], ExplainConfig, Dict[str, Any]],
+         anchor_time: Union[pd.Timestamp, Literal['entity'], None] = None,
+         context_anchor_time: Union[pd.Timestamp, None] = None,
+         run_mode: Union[RunMode, str] = RunMode.FAST,
+         num_neighbors: Optional[List[int]] = None,
+         num_hops: int = 2,
+         max_pq_iterations: int = 20,
+         random_seed: Optional[int] = _RANDOM_SEED,
+         verbose: Union[bool, ProgressLogger] = True,
+         use_prediction_time: bool = False,
+     ) -> Explanation:
+         pass
+
+     def predict(
+         self,
+         query: str,
+         indices: Union[List[str], List[float], List[int], None] = None,
+         *,
+         explain: Union[bool, ExplainConfig, Dict[str, Any]] = False,
+         anchor_time: Union[pd.Timestamp, Literal['entity'], None] = None,
+         context_anchor_time: Union[pd.Timestamp, None] = None,
+         run_mode: Union[RunMode, str] = RunMode.FAST,
+         num_neighbors: Optional[List[int]] = None,
+         num_hops: int = 2,
+         max_pq_iterations: int = 20,
+         random_seed: Optional[int] = _RANDOM_SEED,
+         verbose: Union[bool, ProgressLogger] = True,
+         use_prediction_time: bool = False,
+     ) -> Union[pd.DataFrame, Explanation]:
          """Returns predictions for a predictive query.

          Args:
              query: The predictive query.
+             indices: The entity primary keys to predict for. Will override the
+                 indices given as part of the predictive query. Predictions will
+                 be generated for all indices, independent of whether they
+                 fulfill entity filter constraints. To pre-filter entities, use
+                 :meth:`~KumoRFM.is_valid_entity`.
+             explain: Configuration for explainability.
+                 If set to ``True``, will additionally explain the prediction.
+                 Passing in an :class:`ExplainConfig` instance provides control
+                 over which parts of the explanation are generated.
+                 Explainability is currently only supported for single entity
+                 predictions with ``run_mode="FAST"``.
              anchor_time: The anchor timestamp for the prediction. If set to
-                 :obj:`None`, will use the maximum timestamp in the data.
-                 If set to :`"entity"`, will use the timestamp of the entity.
+                 ``None``, will use the maximum timestamp in the data.
+                 If set to ``"entity"``, will use the timestamp of the entity.
              context_anchor_time: The maximum anchor timestamp for context
-                 examples. If set to :obj:`None`, :obj:`anchor_time` will
+                 examples. If set to ``None``, ``anchor_time`` will
                  determine the anchor time for context examples.
              run_mode: The :class:`RunMode` for the query.
              num_neighbors: The number of neighbors to sample for each hop.
@@ -145,32 +299,54 @@ class KumoRFM:
                  entities to find valid labels.
              random_seed: A manual seed for generating pseudo-random numbers.
              verbose: Whether to print verbose output.
+             use_prediction_time: Whether to use the anchor timestamp as an
+                 additional feature during prediction. This is typically
+                 beneficial for time series forecasting tasks.

          Returns:
-             The predictions as a :class:`pandas.DataFrame`
+             The predictions as a :class:`pandas.DataFrame`.
+             If ``explain`` is provided, returns an :class:`Explanation` object
+             containing the prediction, summary, and details.
          """
-         explain = False
+         explain_config: Optional[ExplainConfig] = None
+         if explain is True:
+             explain_config = ExplainConfig()
+         elif explain is not False:
+             explain_config = ExplainConfig._cast(explain)
+
          query_def = self._parse_query(query)
+         query_str = query_def.to_string()

          if num_hops != 2 and num_neighbors is not None:
              warnings.warn(f"Received custom 'num_neighbors' option; ignoring "
                            f"custom 'num_hops={num_hops}' option")

-         if explain and run_mode in {RunMode.NORMAL, RunMode.BEST}:
+         if explain_config is not None and run_mode in {
+                 RunMode.NORMAL, RunMode.BEST
+         }:
              warnings.warn(f"Explainability is currently only supported for "
                            f"run mode 'FAST' (got '{run_mode}'). Provided run "
                            f"mode has been reset. Please lower the run mode to "
                            f"suppress this warning.")

-         if explain:
-             assert query_def.entity.ids is not None
-             if len(query_def.entity.ids.value) > 1:
-                 raise ValueError(
-                     f"Cannot explain predictions for more than a single "
-                     f"entity (got {len(query_def.entity.ids.value)})")
+         if indices is None:
+             if query_def.rfm_entity_ids is None:
+                 raise ValueError("Cannot find entities to predict for. Please "
+                                  "pass them via `predict(query, indices=...)`")
+             indices = query_def.get_rfm_entity_id_list()
+         else:
+             query_def = replace(query_def, rfm_entity_ids=None)
+
+         if len(indices) == 0:
+             raise ValueError("At least one entity is required")
+
+         if explain_config is not None and len(indices) > 1:
+             raise ValueError(
+                 f"Cannot explain predictions for more than a single entity "
+                 f"(got {len(indices)})")

          query_repr = query_def.to_string(rich=True, exclude_predict=True)
-         if explain:
+         if explain_config is not None:
              msg = f'[bold]EXPLAIN[/bold] {query_repr}'
          else:
              msg = f'[bold]PREDICT[/bold] {query_repr}'
@@ -179,49 +355,188 @@ class KumoRFM:
              verbose = InteractiveProgressLogger(msg, verbose=verbose)

          with verbose as logger:
-             context = self._get_context(
-                 query_def,
-                 anchor_time=anchor_time,
-                 context_anchor_time=context_anchor_time,
-                 run_mode=RunMode(run_mode),
-                 num_neighbors=num_neighbors,
-                 num_hops=num_hops,
-                 max_pq_iterations=max_pq_iterations,
-                 evaluate=False,
-                 random_seed=random_seed,
-                 logger=logger,
-             )
-             request = RFMPredictRequest(
-                 context=context,
-                 run_mode=RunMode(run_mode),
+
+             batch_size: Optional[int] = None
+             if self._batch_size == 'max':
+                 task_type = LocalPQueryDriver.get_task_type(
+                     query_def,
+                     edge_types=self._graph_store.edge_types,
+                 )
+                 batch_size = _MAX_PRED_SIZE[task_type]
+             else:
+                 batch_size = self._batch_size
+
+             if batch_size is not None:
+                 offsets = range(0, len(indices), batch_size)
+                 batches = [indices[step:step + batch_size] for step in offsets]
+             else:
+                 batches = [indices]
+
+             if len(batches) > 1:
+                 logger.log(f"Splitting {len(indices):,} entities into "
+                            f"{len(batches):,} batches of size {batch_size:,}")
+
+             predictions: List[pd.DataFrame] = []
+             summary: Optional[str] = None
+             details: Optional[Explanation] = None
+             for i, batch in enumerate(batches):
+                 # TODO Re-use the context for subsequent predictions.
+                 context = self._get_context(
+                     query=query_def,
+                     indices=batch,
+                     anchor_time=anchor_time,
+                     context_anchor_time=context_anchor_time,
+                     run_mode=RunMode(run_mode),
+                     num_neighbors=num_neighbors,
+                     num_hops=num_hops,
+                     max_pq_iterations=max_pq_iterations,
+                     evaluate=False,
+                     random_seed=random_seed,
+                     logger=logger if i == 0 else None,
+                 )
+                 request = RFMPredictRequest(
+                     context=context,
+                     run_mode=RunMode(run_mode),
+                     query=query_str,
+                     use_prediction_time=use_prediction_time,
+                 )
+                 with warnings.catch_warnings():
+                     warnings.filterwarnings('ignore', message='gencode')
+                     request_msg = request.to_protobuf()
+                 _bytes = request_msg.SerializeToString()
+                 if i == 0:
+                     logger.log(f"Generated context of size "
+                                f"{len(_bytes) / (1024*1024):.2f}MB")
+
+                 if len(_bytes) > _MAX_SIZE:
+                     stats = Context.get_memory_stats(request_msg.context)
+                     raise ValueError(_SIZE_LIMIT_MSG.format(stats=stats))
+
+                 if (isinstance(verbose, InteractiveProgressLogger) and i == 0
+                         and len(batches) > 1):
+                     verbose.init_progress(
+                         total=len(batches),
+                         description='Predicting',
+                     )
+
+                 for attempt in range(self.num_retries + 1):
+                     try:
+                         if explain_config is not None:
+                             resp = self._api_client.explain(
+                                 request=_bytes,
+                                 skip_summary=explain_config.skip_summary,
+                             )
+                             summary = resp.summary
+                             details = resp.details
+                         else:
+                             resp = self._api_client.predict(_bytes)
+                         df = pd.DataFrame(**resp.prediction)
+
+                         # Cast 'ENTITY' to correct data type:
+                         if 'ENTITY' in df:
+                             entity = query_def.entity_table
+                             pkey_map = self._graph_store.pkey_map_dict[entity]
+                             df['ENTITY'] = df['ENTITY'].astype(
+                                 type(pkey_map.index[0]))
+
+                         # Cast 'ANCHOR_TIMESTAMP' to correct data type:
+                         if 'ANCHOR_TIMESTAMP' in df:
+                             ser = df['ANCHOR_TIMESTAMP']
+                             if not pd.api.types.is_datetime64_any_dtype(ser):
+                                 if isinstance(ser.iloc[0], str):
+                                     unit = None
+                                 else:
+                                     unit = 'ms'
+                                 df['ANCHOR_TIMESTAMP'] = pd.to_datetime(
+                                     ser, errors='coerce', unit=unit)
+
+                         predictions.append(df)
+
+                         if (isinstance(verbose, InteractiveProgressLogger)
+                                 and len(batches) > 1):
+                             verbose.step()
+
+                         break
+                     except HTTPException as e:
+                         if attempt == self.num_retries:
+                             try:
+                                 msg = json.loads(e.detail)['detail']
+                             except Exception:
+                                 msg = e.detail
+                             raise RuntimeError(
+                                 f"An unexpected exception occurred. Please "
+                                 f"create an issue at "
+                                 f"'https://github.com/kumo-ai/kumo-rfm'. {msg}"
+                             ) from None
+
+                         time.sleep(2**attempt)  # 1s, 2s, 4s, 8s, ...
+
+         if len(predictions) == 1:
+             prediction = predictions[0]
+         else:
+             prediction = pd.concat(predictions, ignore_index=True)
+
+         if explain_config is not None:
+             assert len(predictions) == 1
+             assert summary is not None
+             assert details is not None
+             return Explanation(
+                 prediction=prediction,
+                 summary=summary,
+                 details=details,
              )
-         with warnings.catch_warnings():
-             warnings.filterwarnings('ignore', message='Protobuf gencode')
-             request_msg = request.to_protobuf()
-             request_bytes = request_msg.SerializeToString()
-             logger.log(f"Generated context of size "
-                        f"{len(request_bytes) / (1024*1024):.2f}MB")

-             if len(request_bytes) > _MAX_SIZE:
-                 stats_msg = Context.get_memory_stats(request_msg.context)
-                 raise ValueError(_SIZE_LIMIT_MSG.format(stats_msg=stats_msg))
+         return prediction

-             try:
-                 if explain:
-                     resp = global_state.client.rfm_api.explain(request_bytes)
-                 else:
-                     resp = global_state.client.rfm_api.predict(request_bytes)
-             except HTTPException as e:
-                 try:
-                     msg = json.loads(e.detail)['detail']
-                 except Exception:
-                     msg = e.detail
-                 raise RuntimeError(f"An unexpected exception occurred. "
-                                    f"Please create an issue at "
-                                    f"'https://github.com/kumo-ai/kumo-rfm'. "
-                                    f"{msg}") from None
+     def is_valid_entity(
+         self,
+         query: str,
+         indices: Union[List[str], List[float], List[int], None] = None,
+         *,
+         anchor_time: Union[pd.Timestamp, Literal['entity'], None] = None,
+     ) -> np.ndarray:
+         r"""Returns a mask that denotes which entities are valid for the
+         given predictive query, *i.e.*, which entities fulfill (temporal)
+         entity filter constraints.
+
+         Args:
+             query: The predictive query.
+             indices: The entity primary keys to predict for. Will override the
+                 indices given as part of the predictive query.
+             anchor_time: The anchor timestamp for the prediction. If set to
+                 ``None``, will use the maximum timestamp in the data.
+                 If set to ``"entity"``, will use the timestamp of the entity.
+         """
+         query_def = self._parse_query(query)
+
+         if indices is None:
+             if query_def.rfm_entity_ids is None:
+                 raise ValueError("Cannot find entities to predict for. Please "
+                                  "pass them via "
+                                  "`is_valid_entity(query, indices=...)`")
+             indices = query_def.get_rfm_entity_id_list()
+
+         if len(indices) == 0:
+             raise ValueError("At least one entity is required")

-         return pd.DataFrame(**resp.prediction)
+         if anchor_time is None:
+             anchor_time = self._graph_store.max_time
+
+         if isinstance(anchor_time, pd.Timestamp):
+             self._validate_time(query_def, anchor_time, None, False)
+         else:
+             assert anchor_time == 'entity'
+             if (query_def.entity_table not in self._graph_store.time_dict):
+                 raise ValueError(f"Anchor time 'entity' requires the entity "
+                                  f"table '{query_def.entity_table}' "
+                                  f"to have a time column.")
+
+         node = self._graph_store.get_node_id(
+             table_name=query_def.entity_table,
+             pkey=pd.Series(indices),
+         )
+         query_driver = LocalPQueryDriver(self._graph_store, query_def)
+         return query_driver.is_valid(node, anchor_time)
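The `predict` docstring points here: since `predict(query, indices=...)` no longer drops entities that fail (temporal) entity filters, `is_valid_entity` is the intended pre-filter. A hedged sketch, with `candidates` as an assumed list of primary keys:

    mask = rfm.is_valid_entity(query, indices=candidates)  # boolean np.ndarray
    valid = [pkey for pkey, keep in zip(candidates, mask) if keep]
    df = rfm.predict(query, indices=valid)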

      def evaluate(
          self,
@@ -236,6 +551,7 @@ class KumoRFM:
          max_pq_iterations: int = 20,
          random_seed: Optional[int] = _RANDOM_SEED,
          verbose: Union[bool, ProgressLogger] = True,
+         use_prediction_time: bool = False,
      ) -> pd.DataFrame:
          """Evaluates a predictive query.

@@ -243,10 +559,10 @@ class KumoRFM:
              query: The predictive query.
              metrics: The metrics to use.
              anchor_time: The anchor timestamp for the prediction. If set to
-                 :obj:`None`, will use the maximum timestamp in the data.
-                 If set to :`"entity"`, will use the timestamp of the entity.
+                 ``None``, will use the maximum timestamp in the data.
+                 If set to ``"entity"``, will use the timestamp of the entity.
              context_anchor_time: The maximum anchor timestamp for context
-                 examples. If set to :obj:`None`, :obj:`anchor_time` will
+                 examples. If set to ``None``, ``anchor_time`` will
                  determine the anchor time for context examples.
              run_mode: The :class:`RunMode` for the query.
              num_neighbors: The number of neighbors to sample for each hop.
@@ -259,6 +575,9 @@ class KumoRFM:
                  entities to find valid labels.
              random_seed: A manual seed for generating pseudo-random numbers.
              verbose: Whether to print verbose output.
+             use_prediction_time: Whether to use the anchor timestamp as an
+                 additional feature during prediction. This is typically
+                 beneficial for time series forecasting tasks.

          Returns:
              The metrics as a :class:`pandas.DataFrame`
@@ -269,6 +588,12 @@ class KumoRFM:
              warnings.warn(f"Received custom 'num_neighbors' option; ignoring "
                            f"custom 'num_hops={num_hops}' option")

+         if query_def.rfm_entity_ids is not None:
+             query_def = replace(
+                 query_def,
+                 rfm_entity_ids=None,
+             )
+
          query_repr = query_def.to_string(rich=True, exclude_predict=True)
          msg = f'[bold]EVALUATE[/bold] {query_repr}'

@@ -277,7 +602,8 @@ class KumoRFM:

          with verbose as logger:
              context = self._get_context(
-                 query_def,
+                 query=query_def,
+                 indices=None,
                  anchor_time=anchor_time,
                  context_anchor_time=context_anchor_time,
                  run_mode=RunMode(run_mode),
@@ -295,6 +621,7 @@ class KumoRFM:
                  context=context,
                  run_mode=RunMode(run_mode),
                  metrics=metrics,
+                 use_prediction_time=use_prediction_time,
              )
              with warnings.catch_warnings():
                  warnings.filterwarnings('ignore', message='Protobuf gencode')
@@ -305,10 +632,10 @@ class KumoRFM:

              if len(request_bytes) > _MAX_SIZE:
                  stats_msg = Context.get_memory_stats(request_msg.context)
-                 raise ValueError(_SIZE_LIMIT_MSG.format(stats_msg=stats_msg))
+                 raise ValueError(_SIZE_LIMIT_MSG.format(stats=stats_msg))

              try:
-                 resp = global_state.client.rfm_api.evaluate(request_bytes)
+                 resp = self._api_client.evaluate(request_bytes)
              except HTTPException as e:
                  try:
                      msg = json.loads(e.detail)['detail']
@@ -353,18 +680,19 @@ class KumoRFM:

          if anchor_time is None:
              anchor_time = self._graph_store.max_time
-             anchor_time = anchor_time - (query_def.target.end_offset *
-                                          query_def.num_forecasts)
+             if query_def.target_ast.date_offset_range is not None:
+                 anchor_time = anchor_time - (
+                     query_def.target_ast.date_offset_range.end_date_offset *
+                     query_def.num_forecasts)

          assert anchor_time is not None
          if isinstance(anchor_time, pd.Timestamp):
              self._validate_time(query_def, anchor_time, None, evaluate=True)
          else:
              assert anchor_time == 'entity'
-             if (query_def.entity.pkey.table_name
-                     not in self._graph_store.time_dict):
+             if (query_def.entity_table not in self._graph_store.time_dict):
                  raise ValueError(f"Anchor time 'entity' requires the entity "
-                                  f"table '{query_def.entity.pkey.table_name}' "
+                                  f"table '{query_def.entity_table}' "
                                   f"to have a time column")

          query_driver = LocalPQueryDriver(self._graph_store, query_def,
@@ -379,7 +707,7 @@ class KumoRFM:
          )

          entity = self._graph_store.pkey_map_dict[
-             query_def.entity.pkey.table_name].index[node]
+             query_def.entity_table].index[node]

          return pd.DataFrame({
              'ENTITY': entity,
@@ -389,8 +717,8 @@ class KumoRFM:

      # Helpers #################################################################

-     def _parse_query(self, query: str) -> PQueryDefinition:
-         if isinstance(query, PQueryDefinition):
+     def _parse_query(self, query: str) -> ValidatedPredictiveQuery:
+         if isinstance(query, ValidatedPredictiveQuery):
              return query

          if isinstance(query, str) and query.strip()[:9].lower() == 'evaluate ':
@@ -400,12 +728,13 @@ class KumoRFM:
                               "predictions or evaluations.")

          try:
-             request = RFMValidateQueryRequest(
+             request = RFMParseQueryRequest(
                  query=query,
                  graph_definition=self._graph_def,
              )

-             resp = global_state.client.rfm_api.validate_query(request)
+             resp = self._api_client.parse_query(request)
+
              # TODO Expose validation warnings.

              if len(resp.validation_response.warnings) > 0:
@@ -416,7 +745,7 @@ class KumoRFM:
                  warnings.warn(f"Encountered the following warnings during "
                                f"parsing:\n{msg}")

-             return resp.query_definition
+             return resp.query
          except HTTPException as e:
              try:
                  msg = json.loads(e.detail)['detail']
@@ -427,7 +756,7 @@ class KumoRFM:

      def _validate_time(
          self,
-         query: PQueryDefinition,
+         query: ValidatedPredictiveQuery,
          anchor_time: pd.Timestamp,
          context_anchor_time: Union[pd.Timestamp, None],
          evaluate: bool,
@@ -450,6 +779,11 @@ class KumoRFM:
                               f"only contains data back to "
                               f"'{self._graph_store.min_time}'.")

+         if query.target_ast.date_offset_range is not None:
+             end_offset = query.target_ast.date_offset_range.end_date_offset
+         else:
+             end_offset = pd.DateOffset(0)
+         forecast_end_offset = end_offset * query.num_forecasts
          if (context_anchor_time is not None
                  and context_anchor_time > anchor_time):
              warnings.warn(f"Context anchor timestamp "
@@ -458,19 +792,18 @@ class KumoRFM:
                            f"(got '{anchor_time}'). Please make sure this is "
                            f"intended.")
          elif (query.query_type == QueryType.TEMPORAL
-               and context_anchor_time is not None and context_anchor_time +
-               query.target.end_offset * query.num_forecasts > anchor_time):
+               and context_anchor_time is not None
+               and context_anchor_time + forecast_end_offset > anchor_time):
              warnings.warn(f"Aggregation for context examples at timestamp "
                            f"'{context_anchor_time}' will leak information "
                            f"from the prediction anchor timestamp "
                            f"'{anchor_time}'. Please make sure this is "
                            f"intended.")

-         elif (context_anchor_time is not None and context_anchor_time -
-               query.target.end_offset * query.num_forecasts
+         elif (context_anchor_time is not None
+               and context_anchor_time - forecast_end_offset
                < self._graph_store.min_time):
-             _time = context_anchor_time - (query.target.end_offset *
-                                            query.num_forecasts)
+             _time = context_anchor_time - forecast_end_offset
              warnings.warn(f"Context anchor timestamp is too early or "
                            f"aggregation time range is too large. To form "
                            f"proper input data, we would need data back to "
@@ -483,8 +816,7 @@ class KumoRFM:
                            f"latest timestamp '{self._graph_store.max_time}' "
                            f"in the data. Please make sure this is intended.")

-         max_eval_time = (self._graph_store.max_time -
-                          query.target.end_offset * query.num_forecasts)
+         max_eval_time = self._graph_store.max_time - forecast_end_offset
          if evaluate and anchor_time > max_eval_time:
              raise ValueError(
                  f"Anchor timestamp for evaluation is after the latest "
@@ -492,7 +824,8 @@ class KumoRFM:

      def _get_context(
          self,
-         query: PQueryDefinition,
+         query: ValidatedPredictiveQuery,
+         indices: Union[List[str], List[float], List[int], None],
          anchor_time: Union[pd.Timestamp, Literal['entity'], None],
          context_anchor_time: Union[pd.Timestamp, None],
          run_mode: RunMode,
@@ -519,8 +852,8 @@ class KumoRFM:
                               f"must go beyond this for your use-case.")

          query_driver = LocalPQueryDriver(self._graph_store, query, random_seed)
-         task_type = query.get_task_type(
-             stypes=self._graph_store.stype_dict,
+         task_type = LocalPQueryDriver.get_task_type(
+             query,
              edge_types=self._graph_store.edge_types,
          )

@@ -552,11 +885,15 @@ class KumoRFM:
          else:
              num_neighbors = [64, 64, 8, 8, 4, 4][:num_hops]

+         if query.target_ast.date_offset_range is None:
+             end_offset = pd.DateOffset(0)
+         else:
+             end_offset = query.target_ast.date_offset_range.end_date_offset
+         forecast_end_offset = end_offset * query.num_forecasts
          if anchor_time is None:
              anchor_time = self._graph_store.max_time
              if evaluate:
-                 anchor_time = anchor_time - (query.target.end_offset *
-                                              query.num_forecasts)
+                 anchor_time = anchor_time - forecast_end_offset
              if logger is not None:
                  assert isinstance(anchor_time, pd.Timestamp)
                  if anchor_time == pd.Timestamp.min:
@@ -571,15 +908,14 @@ class KumoRFM:
          assert anchor_time is not None
          if isinstance(anchor_time, pd.Timestamp):
              if context_anchor_time is None:
-                 context_anchor_time = anchor_time - (query.target.end_offset *
-                                                      query.num_forecasts)
+                 context_anchor_time = anchor_time - forecast_end_offset
              self._validate_time(query, anchor_time, context_anchor_time,
                                  evaluate)
          else:
              assert anchor_time == 'entity'
-             if query.entity.pkey.table_name not in self._graph_store.time_dict:
+             if query.entity_table not in self._graph_store.time_dict:
                  raise ValueError(f"Anchor time 'entity' requires the entity "
-                                  f"table '{query.entity.pkey.table_name}' to "
+                                  f"table '{query.entity_table}' to "
                                   f"have a time column")
              if context_anchor_time is not None:
                  warnings.warn("Ignoring option 'context_anchor_time' for "
@@ -620,28 +956,25 @@ class KumoRFM:
                  logger.log(msg)

          else:
-             assert query.entity.ids is not None
+             assert indices is not None

-             max_num_test = 200 if task_type.is_link_pred else 1000
-             if len(query.entity.ids.value) > max_num_test:
+             if len(indices) > _MAX_PRED_SIZE[task_type]:
                  raise ValueError(f"Cannot predict for more than "
-                                  f"{max_num_test:,} entities at once "
-                                  f"(got {len(query.entity.ids.value):,})")
+                                  f"{_MAX_PRED_SIZE[task_type]:,} entities at "
+                                  f"once (got {len(indices):,}). Use "
+                                  f"`KumoRFM.batch_mode` to process entities "
+                                  f"in batches")

              test_node = self._graph_store.get_node_id(
-                 table_name=query.entity.pkey.table_name,
-                 pkey=pd.Series(
-                     query.entity.ids.value,
-                     dtype=query.entity.ids.dtype,
-                 ),
+                 table_name=query.entity_table,
+                 pkey=pd.Series(indices),
              )

              if isinstance(anchor_time, pd.Timestamp):
                  test_time = pd.Series(anchor_time).repeat(
                      len(test_node)).reset_index(drop=True)
              else:
-                 time = self._graph_store.time_dict[
-                     query.entity.pkey.table_name]
+                 time = self._graph_store.time_dict[query.entity_table]
                  time = time[test_node] * 1000**3
                  test_time = pd.Series(time, dtype='datetime64[ns]')

@@ -674,12 +1007,23 @@ class KumoRFM:
                  raise NotImplementedError
              logger.log(msg)

-         entity_table_names = query.get_entity_table_names(
-             self._graph_store.edge_types)
+         entity_table_names: Tuple[str, ...]
+         if task_type.is_link_pred:
+             final_aggr = query.get_final_target_aggregation()
+             assert final_aggr is not None
+             edge_fkey = final_aggr._get_target_column_name()
+             for edge_type in self._graph_store.edge_types:
+                 if edge_fkey == f'{edge_type[0]}.{edge_type[1]}':
+                     entity_table_names = (
+                         query.entity_table,
+                         edge_type[2],
+                     )
+         else:
+             entity_table_names = (query.entity_table, )

          # Exclude the entity anchor time from the feature set to prevent
          # running out-of-distribution between in-context and test examples:
-         exclude_cols_dict = query.exclude_cols_dict
+         exclude_cols_dict = query.get_exclude_cols_dict()
          if anchor_time == 'entity':
              if entity_table_names[0] not in exclude_cols_dict:
                  exclude_cols_dict[entity_table_names[0]] = []
694
1038
  train_time.astype('datetime64[ns]').astype(int).to_numpy(),
695
1039
  test_time.astype('datetime64[ns]').astype(int).to_numpy(),
696
1040
  ]),
697
- run_mode=run_mode,
698
1041
  num_neighbors=num_neighbors,
699
1042
  exclude_cols_dict=exclude_cols_dict,
700
1043
  )
@@ -708,7 +1051,7 @@ class KumoRFM:

          step_size: Optional[int] = None
          if query.query_type == QueryType.TEMPORAL:
-             step_size = date_offset_to_seconds(query.target.end_offset)
+             step_size = date_offset_to_seconds(end_offset)

          return Context(
              task_type=task_type,
@@ -733,7 +1076,7 @@ class KumoRFM:
          elif task_type == TaskType.MULTICLASS_CLASSIFICATION:
              supported_metrics = ['acc', 'precision', 'recall', 'f1', 'mrr']
          elif task_type == TaskType.REGRESSION:
-             supported_metrics = ['mae', 'mape', 'mse', 'rmse', 'smape']
+             supported_metrics = ['mae', 'mape', 'mse', 'rmse', 'smape', 'r2']
          elif task_type == TaskType.TEMPORAL_LINK_PREDICTION:
              supported_metrics = [
                  'map@', 'ndcg@', 'mrr@', 'precision@', 'recall@', 'f1@',
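
With `r2` added to the supported regression metrics, an evaluation can request it alongside the existing ones. A hedged sketch against the `evaluate` signature shown above, with `query` assumed to be a regression predictive query:

    result = rfm.evaluate(query, metrics=['mae', 'rmse', 'r2'])
    print(result)  # the metrics as a pandas.DataFrame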