kumoai 2.9.0.dev202509081831__cp312-cp312-win_amd64.whl → 2.12.0.dev202511111731__cp312-cp312-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (28)
  1. kumoai/__init__.py +4 -2
  2. kumoai/_version.py +1 -1
  3. kumoai/client/client.py +10 -5
  4. kumoai/client/endpoints.py +1 -0
  5. kumoai/client/rfm.py +37 -8
  6. kumoai/connector/file_upload_connector.py +71 -102
  7. kumoai/connector/utils.py +1367 -236
  8. kumoai/experimental/rfm/__init__.py +5 -3
  9. kumoai/experimental/rfm/authenticate.py +8 -5
  10. kumoai/experimental/rfm/infer/timestamp.py +7 -4
  11. kumoai/experimental/rfm/local_graph.py +90 -80
  12. kumoai/experimental/rfm/local_graph_sampler.py +16 -8
  13. kumoai/experimental/rfm/local_graph_store.py +22 -6
  14. kumoai/experimental/rfm/local_pquery_driver.py +336 -42
  15. kumoai/experimental/rfm/local_table.py +100 -22
  16. kumoai/experimental/rfm/pquery/__init__.py +4 -4
  17. kumoai/experimental/rfm/pquery/{backend.py → executor.py} +24 -58
  18. kumoai/experimental/rfm/pquery/{pandas_backend.py → pandas_executor.py} +278 -222
  19. kumoai/experimental/rfm/rfm.py +514 -117
  20. kumoai/jobs.py +1 -0
  21. kumoai/kumolib.cp312-win_amd64.pyd +0 -0
  22. kumoai/trainer/trainer.py +19 -10
  23. kumoai/utils/progress_logger.py +68 -0
  24. {kumoai-2.9.0.dev202509081831.dist-info → kumoai-2.12.0.dev202511111731.dist-info}/METADATA +4 -5
  25. {kumoai-2.9.0.dev202509081831.dist-info → kumoai-2.12.0.dev202511111731.dist-info}/RECORD +28 -28
  26. {kumoai-2.9.0.dev202509081831.dist-info → kumoai-2.12.0.dev202511111731.dist-info}/WHEEL +0 -0
  27. {kumoai-2.9.0.dev202509081831.dist-info → kumoai-2.12.0.dev202511111731.dist-info}/licenses/LICENSE +0 -0
  28. {kumoai-2.9.0.dev202509081831.dist-info → kumoai-2.12.0.dev202511111731.dist-info}/top_level.txt +0 -0
@@ -1,17 +1,32 @@
 import json
+import time
 import warnings
-from typing import List, Literal, Optional, Union
+from collections import defaultdict
+from collections.abc import Generator
+from contextlib import contextmanager
+from dataclasses import dataclass, replace
+from typing import (
+    Any,
+    Dict,
+    Iterator,
+    List,
+    Literal,
+    Optional,
+    Tuple,
+    Union,
+    overload,
+)

 import numpy as np
 import pandas as pd
 from kumoapi.model_plan import RunMode
-from kumoapi.pquery import QueryType
+from kumoapi.pquery import QueryType, ValidatedPredictiveQuery
+from kumoapi.rfm import Context
+from kumoapi.rfm import Explanation as ExplanationConfig
 from kumoapi.rfm import (
-    Context,
-    PQueryDefinition,
     RFMEvaluateRequest,
+    RFMParseQueryRequest,
     RFMPredictRequest,
-    RFMValidateQueryRequest,
 )
 from kumoapi.task import TaskType

@@ -24,10 +39,14 @@ from kumoai.experimental.rfm.local_pquery_driver import (
     LocalPQueryDriver,
     date_offset_to_seconds,
 )
+from kumoai.mixin import CastMixin
 from kumoai.utils import InteractiveProgressLogger, ProgressLogger

 _RANDOM_SEED = 42

+_MAX_PRED_SIZE: dict[TaskType, int] = defaultdict(lambda: 1_000)
+_MAX_PRED_SIZE[TaskType.TEMPORAL_LINK_PREDICTION] = 200
+
 _MAX_CONTEXT_SIZE = {
     RunMode.DEBUG: 100,
     RunMode.FAST: 1_000,
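
Note: `_MAX_PRED_SIZE` above is a `defaultdict`, so any task type not explicitly listed falls back to 1,000 entities per request, while temporal link prediction is capped at 200. A minimal, self-contained sketch of the lookup behavior (using a stand-in enum rather than the real `kumoapi.task.TaskType`):

    from collections import defaultdict
    from enum import Enum

    class TaskType(Enum):  # Stand-in for kumoapi.task.TaskType.
        REGRESSION = 'regression'
        TEMPORAL_LINK_PREDICTION = 'temporal_link_prediction'

    # Default cap of 1,000 entities; link prediction is capped at 200:
    _MAX_PRED_SIZE: dict[TaskType, int] = defaultdict(lambda: 1_000)
    _MAX_PRED_SIZE[TaskType.TEMPORAL_LINK_PREDICTION] = 200

    print(_MAX_PRED_SIZE[TaskType.REGRESSION])                # 1000 (default)
    print(_MAX_PRED_SIZE[TaskType.TEMPORAL_LINK_PREDICTION])  # 200
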
@@ -42,7 +61,7 @@ _MAX_TEST_SIZE = { # Share test set size across run modes for fair comparison:
 }

 _MAX_SIZE = 30 * 1024 * 1024
-_SIZE_LIMIT_MSG = ("Context size exceeds the 30MB limit. {stats_msg}\nPlease "
+_SIZE_LIMIT_MSG = ("Context size exceeds the 30MB limit. {stats}\nPlease "
                    "reduce either the number of tables in the graph, their "
                    "number of columns (e.g., large text columns), "
                    "neighborhood configuration, or the run mode. If none of "
@@ -51,6 +70,51 @@ _SIZE_LIMIT_MSG = ("Context size exceeds the 30MB limit. {stats_msg}\nPlease "
                    "beyond this for your use-case.")


+@dataclass(repr=False)
+class ExplainConfig(CastMixin):
+    """Configuration for explainability.
+
+    Args:
+        skip_summary: Whether to skip generating a human-readable summary of
+            the explanation.
+    """
+    skip_summary: bool = False
+
+
+@dataclass(repr=False)
+class Explanation:
+    prediction: pd.DataFrame
+    summary: str
+    details: ExplanationConfig
+
+    @overload
+    def __getitem__(self, index: Literal[0]) -> pd.DataFrame:
+        pass
+
+    @overload
+    def __getitem__(self, index: Literal[1]) -> str:
+        pass
+
+    def __getitem__(self, index: int) -> Union[pd.DataFrame, str]:
+        if index == 0:
+            return self.prediction
+        if index == 1:
+            return self.summary
+        raise IndexError("Index out of range")
+
+    def __iter__(self) -> Iterator[Union[pd.DataFrame, str]]:
+        return iter((self.prediction, self.summary))
+
+    def __repr__(self) -> str:
+        return str((self.prediction, self.summary))
+
+    def _ipython_display_(self) -> None:
+        from IPython.display import Markdown, display
+
+        display(self.prediction)
+        display(Markdown(self.summary))
+
+
 class KumoRFM:
     r"""The Kumo Relational Foundation model (RFM) from the `KumoRFM: A
     Foundation Model for In-Context Learning on Relational Data
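
Note: the new `Explanation` dataclass implements `__getitem__` and `__iter__`, so it can still be used like the `(prediction, summary)` pair it wraps while additionally exposing structured `details`. A runnable sketch of the tuple protocol (the `details` payload is a placeholder here; in practice it is a `kumoapi.rfm.Explanation`):

    import pandas as pd

    result = Explanation(
        prediction=pd.DataFrame({'ENTITY': [42], 'TARGET_PRED': [0.87]}),
        summary='Example summary.',
        details=None,  # Placeholder for the structured explanation payload.
    )

    prediction, summary = result    # Unpacking via __iter__.
    assert prediction is result[0]  # Literal[0] overload -> pd.DataFrame.
    assert summary is result[1]     # Literal[1] overload -> str.
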
@@ -108,28 +172,120 @@ class KumoRFM:
         self._graph_store = LocalGraphStore(graph, preprocess, verbose)
         self._graph_sampler = LocalGraphSampler(self._graph_store)

+        self._batch_size: Optional[int | Literal['max']] = None
+        self.num_retries: int = 0
+
     def __repr__(self) -> str:
         return f'{self.__class__.__name__}()'

+    @contextmanager
+    def batch_mode(
+        self,
+        batch_size: Union[int, Literal['max']] = 'max',
+        num_retries: int = 1,
+    ) -> Generator[None, None, None]:
+        """Context manager to predict in batches.
+
+        .. code-block:: python
+
+            with model.batch_mode(batch_size='max', num_retries=1):
+                df = model.predict(query, indices=...)
+
+        Args:
+            batch_size: The batch size. If set to ``"max"``, will use the
+                maximum applicable batch size for the given task.
+            num_retries: The maximum number of retries for failed queries due
+                to unexpected server issues.
+        """
+        if batch_size != 'max' and batch_size <= 0:
+            raise ValueError(f"'batch_size' must be greater than zero "
+                             f"(got {batch_size})")
+
+        if num_retries < 0:
+            raise ValueError(f"'num_retries' must be greater than or equal to "
+                             f"zero (got {num_retries})")
+
+        self._batch_size = batch_size
+        self.num_retries = num_retries
+        yield
+        self._batch_size = None
+        self.num_retries = 0
+
+    @overload
     def predict(
         self,
         query: str,
+        indices: Union[List[str], List[float], List[int], None] = None,
         *,
+        explain: Literal[False] = False,
         anchor_time: Union[pd.Timestamp, Literal['entity'], None] = None,
+        context_anchor_time: Union[pd.Timestamp, None] = None,
         run_mode: Union[RunMode, str] = RunMode.FAST,
         num_neighbors: Optional[List[int]] = None,
         num_hops: int = 2,
         max_pq_iterations: int = 20,
         random_seed: Optional[int] = _RANDOM_SEED,
         verbose: Union[bool, ProgressLogger] = True,
+        use_prediction_time: bool = False,
     ) -> pd.DataFrame:
+        pass
+
+    @overload
+    def predict(
+        self,
+        query: str,
+        indices: Union[List[str], List[float], List[int], None] = None,
+        *,
+        explain: Union[Literal[True], ExplainConfig, Dict[str, Any]],
+        anchor_time: Union[pd.Timestamp, Literal['entity'], None] = None,
+        context_anchor_time: Union[pd.Timestamp, None] = None,
+        run_mode: Union[RunMode, str] = RunMode.FAST,
+        num_neighbors: Optional[List[int]] = None,
+        num_hops: int = 2,
+        max_pq_iterations: int = 20,
+        random_seed: Optional[int] = _RANDOM_SEED,
+        verbose: Union[bool, ProgressLogger] = True,
+        use_prediction_time: bool = False,
+    ) -> Explanation:
+        pass
+
+    def predict(
+        self,
+        query: str,
+        indices: Union[List[str], List[float], List[int], None] = None,
+        *,
+        explain: Union[bool, ExplainConfig, Dict[str, Any]] = False,
+        anchor_time: Union[pd.Timestamp, Literal['entity'], None] = None,
+        context_anchor_time: Union[pd.Timestamp, None] = None,
+        run_mode: Union[RunMode, str] = RunMode.FAST,
+        num_neighbors: Optional[List[int]] = None,
+        num_hops: int = 2,
+        max_pq_iterations: int = 20,
+        random_seed: Optional[int] = _RANDOM_SEED,
+        verbose: Union[bool, ProgressLogger] = True,
+        use_prediction_time: bool = False,
+    ) -> Union[pd.DataFrame, Explanation]:
         """Returns predictions for a predictive query.

         Args:
             query: The predictive query.
-            anchor_time: The anchor timestamp for the query. If set to
-                :obj:`None`, will use the maximum timestamp in the data.
-                If set to :`"entity"`, will use the timestamp of the entity.
+            indices: The entity primary keys to predict for. Will override the
+                indices given as part of the predictive query. Predictions will
+                be generated for all indices, independent of whether they
+                fulfill entity filter constraints. To pre-filter entities, use
+                :meth:`~KumoRFM.is_valid_entity`.
+            explain: Configuration for explainability.
+                If set to ``True``, will additionally explain the prediction.
+                Passing in an :class:`ExplainConfig` instance provides control
+                over which parts of explanation are generated.
+                Explainability is currently only supported for single entity
+                predictions with ``run_mode="FAST"``.
+            anchor_time: The anchor timestamp for the prediction. If set to
+                ``None``, will use the maximum timestamp in the data.
+                If set to ``"entity"``, will use the timestamp of the entity.
+            context_anchor_time: The maximum anchor timestamp for context
+                examples. If set to ``None``, ``anchor_time`` will
+                determine the anchor time for context examples.
             run_mode: The :class:`RunMode` for the query.
             num_neighbors: The number of neighbors to sample for each hop.
                 If specified, the ``num_hops`` option will be ignored.
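
Note: together, the new `indices` argument and the `batch_mode` context manager let `predict` handle entity sets larger than the per-request cap; entities are chunked (1,000 per batch by default, 200 for link prediction) and failed batches are retried with exponential backoff. A usage sketch, assuming `model = KumoRFM(graph)`, a valid query string `query`, and a hypothetical key list:

    user_ids = [1, 2, 3]  # Hypothetical primary keys.
    with model.batch_mode(batch_size='max', num_retries=1):
        df = model.predict(query, indices=user_ids)
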
@@ -141,32 +297,54 @@ class KumoRFM:
                 entities to find valid labels.
             random_seed: A manual seed for generating pseudo-random numbers.
             verbose: Whether to print verbose output.
+            use_prediction_time: Whether to use the anchor timestamp as an
+                additional feature during prediction. This is typically
+                beneficial for time series forecasting tasks.

         Returns:
-            The predictions as a :class:`pandas.DataFrame`
+            The predictions as a :class:`pandas.DataFrame`.
+            If ``explain`` is provided, returns an :class:`Explanation` object
+            containing the prediction, summary, and details.
         """
-        explain = False
+        explain_config: Optional[ExplainConfig] = None
+        if explain is True:
+            explain_config = ExplainConfig()
+        elif explain is not False:
+            explain_config = ExplainConfig._cast(explain)
+
         query_def = self._parse_query(query)
+        query_str = query_def.to_string()

         if num_hops != 2 and num_neighbors is not None:
             warnings.warn(f"Received custom 'num_neighbors' option; ignoring "
                           f"custom 'num_hops={num_hops}' option")

-        if explain and run_mode in {RunMode.NORMAL, RunMode.BEST}:
+        if explain_config is not None and run_mode in {
+                RunMode.NORMAL, RunMode.BEST
+        }:
             warnings.warn(f"Explainability is currently only supported for "
                           f"run mode 'FAST' (got '{run_mode}'). Provided run "
                           f"mode has been reset. Please lower the run mode to "
                           f"suppress this warning.")

-        if explain:
-            assert query_def.entity.ids is not None
-            if len(query_def.entity.ids.value) > 1:
-                raise ValueError(
-                    f"Cannot explain predictions for more than a single "
-                    f"entity (got {len(query_def.entity.ids.value)})")
+        if indices is None:
+            if query_def.rfm_entity_ids is None:
+                raise ValueError("Cannot find entities to predict for. Please "
+                                 "pass them via `predict(query, indices=...)`")
+            indices = query_def.get_rfm_entity_id_list()
+        else:
+            query_def = replace(query_def, rfm_entity_ids=None)
+
+        if len(indices) == 0:
+            raise ValueError("At least one entity is required")
+
+        if explain_config is not None and len(indices) > 1:
+            raise ValueError(
+                f"Cannot explain predictions for more than a single entity "
+                f"(got {len(indices)})")

         query_repr = query_def.to_string(rich=True, exclude_predict=True)
-        if explain:
+        if explain_config is not None:
             msg = f'[bold]EXPLAIN[/bold] {query_repr}'
         else:
             msg = f'[bold]PREDICT[/bold] {query_repr}'
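
Note: `explain` now accepts `True`, an `ExplainConfig`, or a plain dict (cast via `CastMixin._cast`), so the following single-entity calls are all valid; the latter two skip the generated summary:

    result = model.predict(query, indices=[42], explain=True)
    result = model.predict(query, indices=[42],
                           explain=ExplainConfig(skip_summary=True))
    result = model.predict(query, indices=[42],
                           explain={'skip_summary': True})
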
@@ -175,48 +353,188 @@ class KumoRFM:
             verbose = InteractiveProgressLogger(msg, verbose=verbose)

         with verbose as logger:
-            context = self._get_context(
-                query_def,
-                anchor_time=anchor_time,
-                run_mode=RunMode(run_mode),
-                num_neighbors=num_neighbors,
-                num_hops=num_hops,
-                max_pq_iterations=max_pq_iterations,
-                evaluate=False,
-                random_seed=random_seed,
-                logger=logger,
-            )
-            request = RFMPredictRequest(
-                context=context,
-                run_mode=RunMode(run_mode),
+
+            batch_size: Optional[int] = None
+            if self._batch_size == 'max':
+                task_type = LocalPQueryDriver.get_task_type(
+                    query_def,
+                    edge_types=self._graph_store.edge_types,
+                )
+                batch_size = _MAX_PRED_SIZE[task_type]
+            else:
+                batch_size = self._batch_size
+
+            if batch_size is not None:
+                offsets = range(0, len(indices), batch_size)
+                batches = [indices[step:step + batch_size] for step in offsets]
+            else:
+                batches = [indices]
+
+            if len(batches) > 1:
+                logger.log(f"Splitting {len(indices):,} entities into "
+                           f"{len(batches):,} batches of size {batch_size:,}")
+
+            predictions: List[pd.DataFrame] = []
+            summary: Optional[str] = None
+            details: Optional[Explanation] = None
+            for i, batch in enumerate(batches):
+                # TODO Re-use the context for subsequent predictions.
+                context = self._get_context(
+                    query=query_def,
+                    indices=batch,
+                    anchor_time=anchor_time,
+                    context_anchor_time=context_anchor_time,
+                    run_mode=RunMode(run_mode),
+                    num_neighbors=num_neighbors,
+                    num_hops=num_hops,
+                    max_pq_iterations=max_pq_iterations,
+                    evaluate=False,
+                    random_seed=random_seed,
+                    logger=logger if i == 0 else None,
+                )
+                request = RFMPredictRequest(
+                    context=context,
+                    run_mode=RunMode(run_mode),
+                    query=query_str,
+                    use_prediction_time=use_prediction_time,
+                )
+                with warnings.catch_warnings():
+                    warnings.filterwarnings('ignore', message='gencode')
+                    request_msg = request.to_protobuf()
+                _bytes = request_msg.SerializeToString()
+                if i == 0:
+                    logger.log(f"Generated context of size "
+                               f"{len(_bytes) / (1024*1024):.2f}MB")
+
+                if len(_bytes) > _MAX_SIZE:
+                    stats = Context.get_memory_stats(request_msg.context)
+                    raise ValueError(_SIZE_LIMIT_MSG.format(stats=stats))
+
+                if (isinstance(verbose, InteractiveProgressLogger) and i == 0
+                        and len(batches) > 1):
+                    verbose.init_progress(
+                        total=len(batches),
+                        description='Predicting',
+                    )
+
+                for attempt in range(self.num_retries + 1):
+                    try:
+                        if explain_config is not None:
+                            resp = global_state.client.rfm_api.explain(
+                                request=_bytes,
+                                skip_summary=explain_config.skip_summary,
+                            )
+                            summary = resp.summary
+                            details = resp.details
+                        else:
+                            resp = global_state.client.rfm_api.predict(_bytes)
+                        df = pd.DataFrame(**resp.prediction)
+
+                        # Cast 'ENTITY' to correct data type:
+                        if 'ENTITY' in df:
+                            entity = query_def.entity_table
+                            pkey_map = self._graph_store.pkey_map_dict[entity]
+                            df['ENTITY'] = df['ENTITY'].astype(
+                                type(pkey_map.index[0]))
+
+                        # Cast 'ANCHOR_TIMESTAMP' to correct data type:
+                        if 'ANCHOR_TIMESTAMP' in df:
+                            ser = df['ANCHOR_TIMESTAMP']
+                            if not pd.api.types.is_datetime64_any_dtype(ser):
+                                if isinstance(ser.iloc[0], str):
+                                    unit = None
+                                else:
+                                    unit = 'ms'
+                                df['ANCHOR_TIMESTAMP'] = pd.to_datetime(
+                                    ser, errors='coerce', unit=unit)
+
+                        predictions.append(df)
+
+                        if (isinstance(verbose, InteractiveProgressLogger)
+                                and len(batches) > 1):
+                            verbose.step()
+
+                        break
+                    except HTTPException as e:
+                        if attempt == self.num_retries:
+                            try:
+                                msg = json.loads(e.detail)['detail']
+                            except Exception:
+                                msg = e.detail
+                            raise RuntimeError(
+                                f"An unexpected exception occurred. Please "
+                                f"create an issue at "
+                                f"'https://github.com/kumo-ai/kumo-rfm'. {msg}"
+                            ) from None
+
+                        time.sleep(2**attempt)  # 1s, 2s, 4s, 8s, ...
+
+        if len(predictions) == 1:
+            prediction = predictions[0]
+        else:
+            prediction = pd.concat(predictions, ignore_index=True)
+
+        if explain_config is not None:
+            assert len(predictions) == 1
+            assert summary is not None
+            assert details is not None
+            return Explanation(
+                prediction=prediction,
+                summary=summary,
+                details=details,
             )
-            with warnings.catch_warnings():
-                warnings.filterwarnings('ignore', message='Protobuf gencode')
-                request_msg = request.to_protobuf()
-            request_bytes = request_msg.SerializeToString()
-            logger.log(f"Generated context of size "
-                       f"{len(request_bytes) / (1024*1024):.2f}MB")

-            if len(request_bytes) > _MAX_SIZE:
-                stats_msg = Context.get_memory_stats(request_msg.context)
-                raise ValueError(_SIZE_LIMIT_MSG.format(stats_msg=stats_msg))
+        return prediction

-            try:
-                if explain:
-                    resp = global_state.client.rfm_api.explain(request_bytes)
-                else:
-                    resp = global_state.client.rfm_api.predict(request_bytes)
-            except HTTPException as e:
-                try:
-                    msg = json.loads(e.detail)['detail']
-                except Exception:
-                    msg = e.detail
-                raise RuntimeError(f"An unexpected exception occurred. "
-                                   f"Please create an issue at "
-                                   f"'https://github.com/kumo-ai/kumo-rfm'. "
-                                   f"{msg}") from None
+    def is_valid_entity(
+        self,
+        query: str,
+        indices: Union[List[str], List[float], List[int], None] = None,
+        *,
+        anchor_time: Union[pd.Timestamp, Literal['entity'], None] = None,
+    ) -> np.ndarray:
+        r"""Returns a mask that denotes which entities are valid for the
+        given predictive query, *i.e.*, which entities fulfill (temporal)
+        entity filter constraints.
+
+        Args:
+            query: The predictive query.
+            indices: The entity primary keys to predict for. Will override the
+                indices given as part of the predictive query.
+            anchor_time: The anchor timestamp for the prediction. If set to
+                ``None``, will use the maximum timestamp in the data.
+                If set to ``"entity"``, will use the timestamp of the entity.
+        """
+        query_def = self._parse_query(query)
+
+        if indices is None:
+            if query_def.rfm_entity_ids is None:
+                raise ValueError("Cannot find entities to predict for. Please "
+                                 "pass them via "
+                                 "`is_valid_entity(query, indices=...)`")
+            indices = query_def.get_rfm_entity_id_list()
+
+        if len(indices) == 0:
+            raise ValueError("At least one entity is required")
+
+        if anchor_time is None:
+            anchor_time = self._graph_store.max_time
+
+        if isinstance(anchor_time, pd.Timestamp):
+            self._validate_time(query_def, anchor_time, None, False)
+        else:
+            assert anchor_time == 'entity'
+            if (query_def.entity_table not in self._graph_store.time_dict):
+                raise ValueError(f"Anchor time 'entity' requires the entity "
+                                 f"table '{query_def.entity_table}' "
+                                 f"to have a time column.")

-        return pd.DataFrame(**resp.prediction)
+        node = self._graph_store.get_node_id(
+            table_name=query_def.entity_table,
+            pkey=pd.Series(indices),
+        )
+        query_driver = LocalPQueryDriver(self._graph_store, query_def)
+        return query_driver.is_valid(node, anchor_time)

     def evaluate(
         self,
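
Note: since `predict` no longer filters the given `indices` against entity filter constraints, the new `is_valid_entity` mask is the intended pre-filter. A sketch, assuming `candidate_ids` is a list of primary keys:

    import numpy as np

    mask = model.is_valid_entity(query, indices=candidate_ids)
    valid_ids = list(np.asarray(candidate_ids)[mask])
    df = model.predict(query, indices=valid_ids)
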
@@ -224,21 +542,26 @@ class KumoRFM:
         *,
         metrics: Optional[List[str]] = None,
         anchor_time: Union[pd.Timestamp, Literal['entity'], None] = None,
+        context_anchor_time: Union[pd.Timestamp, None] = None,
         run_mode: Union[RunMode, str] = RunMode.FAST,
         num_neighbors: Optional[List[int]] = None,
         num_hops: int = 2,
         max_pq_iterations: int = 20,
         random_seed: Optional[int] = _RANDOM_SEED,
         verbose: Union[bool, ProgressLogger] = True,
+        use_prediction_time: bool = False,
     ) -> pd.DataFrame:
         """Evaluates a predictive query.

         Args:
             query: The predictive query.
             metrics: The metrics to use.
-            anchor_time: The anchor timestamp for the query. If set to
-                :obj:`None`, will use the maximum timestamp in the data.
-                If set to :`"entity"`, will use the timestamp of the entity.
+            anchor_time: The anchor timestamp for the prediction. If set to
+                ``None``, will use the maximum timestamp in the data.
+                If set to ``"entity"``, will use the timestamp of the entity.
+            context_anchor_time: The maximum anchor timestamp for context
+                examples. If set to ``None``, ``anchor_time`` will
+                determine the anchor time for context examples.
             run_mode: The :class:`RunMode` for the query.
             num_neighbors: The number of neighbors to sample for each hop.
                 If specified, the ``num_hops`` option will be ignored.
@@ -250,6 +573,9 @@ class KumoRFM:
                 entities to find valid labels.
             random_seed: A manual seed for generating pseudo-random numbers.
             verbose: Whether to print verbose output.
+            use_prediction_time: Whether to use the anchor timestamp as an
+                additional feature during prediction. This is typically
+                beneficial for time series forecasting tasks.

         Returns:
             The metrics as a :class:`pandas.DataFrame`
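
Note: an `evaluate` call exercising the new `context_anchor_time` argument and the newly supported `'r2'` regression metric (see the metric change further below) might look like this sketch; the timestamps are illustrative:

    import pandas as pd

    metrics_df = model.evaluate(
        query,
        metrics=['mae', 'rmse', 'r2'],
        anchor_time=pd.Timestamp('2024-06-01'),
        context_anchor_time=pd.Timestamp('2024-05-01'),
    )
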
@@ -260,6 +586,12 @@ class KumoRFM:
             warnings.warn(f"Received custom 'num_neighbors' option; ignoring "
                           f"custom 'num_hops={num_hops}' option")

+        if query_def.rfm_entity_ids is not None:
+            query_def = replace(
+                query_def,
+                rfm_entity_ids=None,
+            )
+
         query_repr = query_def.to_string(rich=True, exclude_predict=True)
         msg = f'[bold]EVALUATE[/bold] {query_repr}'

@@ -268,8 +600,10 @@ class KumoRFM:

         with verbose as logger:
             context = self._get_context(
-                query_def,
+                query=query_def,
+                indices=None,
                 anchor_time=anchor_time,
+                context_anchor_time=context_anchor_time,
                 run_mode=RunMode(run_mode),
                 num_neighbors=num_neighbors,
                 num_hops=num_hops,
@@ -285,6 +619,7 @@ class KumoRFM:
                 context=context,
                 run_mode=RunMode(run_mode),
                 metrics=metrics,
+                use_prediction_time=use_prediction_time,
             )
             with warnings.catch_warnings():
                 warnings.filterwarnings('ignore', message='Protobuf gencode')
@@ -343,17 +678,19 @@ class KumoRFM:

         if anchor_time is None:
             anchor_time = self._graph_store.max_time
-            anchor_time = anchor_time - query_def.target.end_offset
+            if query_def.target_ast.date_offset_range is not None:
+                anchor_time = anchor_time - (
+                    query_def.target_ast.date_offset_range.end_date_offset *
+                    query_def.num_forecasts)

         assert anchor_time is not None
         if isinstance(anchor_time, pd.Timestamp):
-            self._validate_time(query_def, anchor_time, evaluate=True)
+            self._validate_time(query_def, anchor_time, None, evaluate=True)
         else:
             assert anchor_time == 'entity'
-            if (query_def.entity.pkey.table_name
-                    not in self._graph_store.time_dict):
+            if (query_def.entity_table not in self._graph_store.time_dict):
                 raise ValueError(f"Anchor time 'entity' requires the entity "
-                                 f"table '{query_def.entity.pkey.table_name}' "
+                                 f"table '{query_def.entity_table}' "
                                  f"to have a time column")

         query_driver = LocalPQueryDriver(self._graph_store, query_def,
@@ -364,18 +701,22 @@ class KumoRFM:
             anchor_time=anchor_time,
             batch_size=min(10_000, size),
             max_iterations=max_iterations,
+            guarantee_train_examples=False,
         )

+        entity = self._graph_store.pkey_map_dict[
+            query_def.entity_table].index[node]
+
         return pd.DataFrame({
-            'ENTITY': node,
+            'ENTITY': entity,
             'ANCHOR_TIMESTAMP': time,
             'TARGET': y,
         })

     # Helpers #################################################################

-    def _parse_query(self, query: str) -> PQueryDefinition:
-        if isinstance(query, PQueryDefinition):
+    def _parse_query(self, query: str) -> ValidatedPredictiveQuery:
+        if isinstance(query, ValidatedPredictiveQuery):
             return query

         if isinstance(query, str) and query.strip()[:9].lower() == 'evaluate ':
@@ -385,12 +726,12 @@ class KumoRFM:
                              "predictions or evaluations.")

         try:
-            request = RFMValidateQueryRequest(
+            request = RFMParseQueryRequest(
                 query=query,
                 graph_definition=self._graph_def,
             )

-            resp = global_state.client.rfm_api.validate_query(request)
+            resp = global_state.client.rfm_api.parse_query(request)
             # TODO Expose validation warnings.

             if len(resp.validation_response.warnings) > 0:
@@ -401,7 +742,7 @@ class KumoRFM:
                 warnings.warn(f"Encountered the following warnings during "
                               f"parsing:\n{msg}")

-            return resp.query_definition
+            return resp.query
         except HTTPException as e:
             try:
                 msg = json.loads(e.detail)['detail']
@@ -412,8 +753,9 @@ class KumoRFM:

     def _validate_time(
         self,
-        query: PQueryDefinition,
+        query: ValidatedPredictiveQuery,
         anchor_time: pd.Timestamp,
+        context_anchor_time: Union[pd.Timestamp, None],
         evaluate: bool,
     ) -> None:

@@ -425,20 +767,44 @@ class KumoRFM:
                              f"the earliest timestamp "
                              f"'{self._graph_store.min_time}' in the data.")

-        req_time = anchor_time - query.target.end_offset * query.num_forecasts
-        if req_time < self._graph_store.min_time:
-            raise ValueError(f"Anchor timestamp is too early or aggregation "
-                             f"time range is too large. To make this "
-                             f"prediction, we would need data back to "
-                             f"'{req_time}', however, your data only contains "
-                             f"data back to '{self._graph_store.min_time}'.")
-
-        req_time -= query.target.end_offset * query.num_forecasts
-        if req_time < self._graph_store.min_time:
-            warnings.warn(f"Anchor timestamp is too early or aggregation "
-                          f"time range is too large. To form proper input "
-                          f"data, we would need data back to "
-                          f"'{req_time}', however, your data only contains "
+        if (context_anchor_time is not None
+                and context_anchor_time < self._graph_store.min_time):
+            raise ValueError(f"Context anchor timestamp is too early or "
+                             f"aggregation time range is too large. To make "
+                             f"this prediction, we would need data back to "
+                             f"'{context_anchor_time}', however, your data "
+                             f"only contains data back to "
+                             f"'{self._graph_store.min_time}'.")
+
+        if query.target_ast.date_offset_range is not None:
+            end_offset = query.target_ast.date_offset_range.end_date_offset
+        else:
+            end_offset = pd.DateOffset(0)
+        forecast_end_offset = end_offset * query.num_forecasts
+        if (context_anchor_time is not None
+                and context_anchor_time > anchor_time):
+            warnings.warn(f"Context anchor timestamp "
+                          f"(got '{context_anchor_time}') is set to a later "
+                          f"date than the prediction anchor timestamp "
+                          f"(got '{anchor_time}'). Please make sure this is "
+                          f"intended.")
+        elif (query.query_type == QueryType.TEMPORAL
+              and context_anchor_time is not None
+              and context_anchor_time + forecast_end_offset > anchor_time):
+            warnings.warn(f"Aggregation for context examples at timestamp "
+                          f"'{context_anchor_time}' will leak information "
+                          f"from the prediction anchor timestamp "
+                          f"'{anchor_time}'. Please make sure this is "
+                          f"intended.")
+
+        elif (context_anchor_time is not None
+              and context_anchor_time - forecast_end_offset
+              < self._graph_store.min_time):
+            _time = context_anchor_time - forecast_end_offset
+            warnings.warn(f"Context anchor timestamp is too early or "
+                          f"aggregation time range is too large. To form "
+                          f"proper input data, we would need data back to "
+                          f"'{_time}', however, your data only contains "
                           f"data back to '{self._graph_store.min_time}'.")

         if (not evaluate and anchor_time
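
Note: the `forecast_end_offset = end_offset * query.num_forecasts` arithmetic above relies on `pandas.DateOffset` supporting integer multiplication. A quick worked example of the resulting window:

    import pandas as pd

    end_offset = pd.DateOffset(days=7)    # One 7-day aggregation window.
    forecast_end_offset = end_offset * 4  # Applied 4 times -> 28 days total.

    print(pd.Timestamp('2024-03-01') + forecast_end_offset)
    # 2024-03-29 00:00:00
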
@@ -447,8 +813,7 @@ class KumoRFM:
                           f"latest timestamp '{self._graph_store.max_time}' "
                           f"in the data. Please make sure this is intended.")

-        max_eval_time = (self._graph_store.max_time -
-                         query.target.end_offset * query.num_forecasts)
+        max_eval_time = self._graph_store.max_time - forecast_end_offset
         if evaluate and anchor_time > max_eval_time:
             raise ValueError(
                 f"Anchor timestamp for evaluation is after the latest "
@@ -456,8 +821,10 @@ class KumoRFM:

     def _get_context(
         self,
-        query: PQueryDefinition,
+        query: ValidatedPredictiveQuery,
+        indices: Union[List[str], List[float], List[int], None],
         anchor_time: Union[pd.Timestamp, Literal['entity'], None],
+        context_anchor_time: Union[pd.Timestamp, None],
         run_mode: RunMode,
         num_neighbors: Optional[List[int]],
         num_hops: int,
@@ -482,8 +849,8 @@ class KumoRFM:
                              f"must go beyond this for your use-case.")

         query_driver = LocalPQueryDriver(self._graph_store, query, random_seed)
-        task_type = query.get_task_type(
-            stypes=self._graph_store.stype_dict,
+        task_type = LocalPQueryDriver.get_task_type(
+            query,
             edge_types=self._graph_store.edge_types,
         )

@@ -515,28 +882,42 @@ class KumoRFM:
         else:
             num_neighbors = [64, 64, 8, 8, 4, 4][:num_hops]

+        if query.target_ast.date_offset_range is None:
+            end_offset = pd.DateOffset(0)
+        else:
+            end_offset = query.target_ast.date_offset_range.end_date_offset
+        forecast_end_offset = end_offset * query.num_forecasts
         if anchor_time is None:
             anchor_time = self._graph_store.max_time
             if evaluate:
-                anchor_time = anchor_time - query.target.end_offset
+                anchor_time = anchor_time - forecast_end_offset
             if logger is not None:
                 assert isinstance(anchor_time, pd.Timestamp)
-                if (anchor_time.hour == 0 and anchor_time.minute == 0
-                        and anchor_time.second == 0
-                        and anchor_time.microsecond == 0):
+                if anchor_time == pd.Timestamp.min:
+                    pass  # Static graph
+                elif (anchor_time.hour == 0 and anchor_time.minute == 0
+                      and anchor_time.second == 0
+                      and anchor_time.microsecond == 0):
                     logger.log(f"Derived anchor time {anchor_time.date()}")
                 else:
                     logger.log(f"Derived anchor time {anchor_time}")

         assert anchor_time is not None
         if isinstance(anchor_time, pd.Timestamp):
-            self._validate_time(query, anchor_time, evaluate)
+            if context_anchor_time is None:
+                context_anchor_time = anchor_time - forecast_end_offset
+            self._validate_time(query, anchor_time, context_anchor_time,
+                                evaluate)
         else:
             assert anchor_time == 'entity'
-            if query.entity.pkey.table_name not in self._graph_store.time_dict:
+            if query.entity_table not in self._graph_store.time_dict:
                 raise ValueError(f"Anchor time 'entity' requires the entity "
-                                 f"table '{query.entity.pkey.table_name}' to "
+                                 f"table '{query.entity_table}' to "
                                  f"have a time column")
+            if context_anchor_time is not None:
+                warnings.warn("Ignoring option 'context_anchor_time' for "
+                              "`anchor_time='entity'`")
+                context_anchor_time = None

         y_test: Optional[pd.Series] = None
         if evaluate:
@@ -548,6 +929,7 @@ class KumoRFM:
                 size=max_test_size,
                 anchor_time=anchor_time,
                 max_iterations=max_pq_iterations,
+                guarantee_train_examples=True,
             )
             if logger is not None:
                 if task_type == TaskType.BINARY_CLASSIFICATION:
@@ -571,34 +953,31 @@ class KumoRFM:
                 logger.log(msg)

         else:
-            assert query.entity.ids is not None
+            assert indices is not None

-            max_num_test = 200 if task_type.is_link_pred else 1000
-            if len(query.entity.ids.value) > max_num_test:
+            if len(indices) > _MAX_PRED_SIZE[task_type]:
                 raise ValueError(f"Cannot predict for more than "
-                                 f"{max_num_test:,} entities at once "
-                                 f"(got {len(query.entity.ids.value):,})")
+                                 f"{_MAX_PRED_SIZE[task_type]:,} entities at "
+                                 f"once (got {len(indices):,}). Use "
+                                 f"`KumoRFM.batch_mode` to process entities "
+                                 f"in batches")

             test_node = self._graph_store.get_node_id(
-                table_name=query.entity.pkey.table_name,
-                pkey=pd.Series(
-                    query.entity.ids.value,
-                    dtype=query.entity.ids.dtype,
-                ),
+                table_name=query.entity_table,
+                pkey=pd.Series(indices),
             )

             if isinstance(anchor_time, pd.Timestamp):
                 test_time = pd.Series(anchor_time).repeat(
                     len(test_node)).reset_index(drop=True)
             else:
-                time = self._graph_store.time_dict[
-                    query.entity.pkey.table_name]
+                time = self._graph_store.time_dict[query.entity_table]
                 time = time[test_node] * 1000**3
                 test_time = pd.Series(time, dtype='datetime64[ns]')

         train_node, train_time, y_train = query_driver.collect_train(
             size=_MAX_CONTEXT_SIZE[run_mode],
-            anchor_time=anchor_time,
+            anchor_time=context_anchor_time or 'entity',
             exclude_node=test_node if (query.query_type == QueryType.STATIC
                                        or anchor_time == 'entity') else None,
             max_iterations=max_pq_iterations,
@@ -625,12 +1004,23 @@ class KumoRFM:
                     raise NotImplementedError
                 logger.log(msg)

-        entity_table_names = query.get_entity_table_names(
-            self._graph_store.edge_types)
+        entity_table_names: Tuple[str, ...]
+        if task_type.is_link_pred:
+            final_aggr = query.get_final_target_aggregation()
+            assert final_aggr is not None
+            edge_fkey = final_aggr._get_target_column_name()
+            for edge_type in self._graph_store.edge_types:
+                if edge_fkey == f'{edge_type[0]}.{edge_type[1]}':
+                    entity_table_names = (
+                        query.entity_table,
+                        edge_type[2],
+                    )
+        else:
+            entity_table_names = (query.entity_table, )

         # Exclude the entity anchor time from the feature set to prevent
         # running out-of-distribution between in-context and test examples:
-        exclude_cols_dict = query.exclude_cols_dict
+        exclude_cols_dict = query.get_exclude_cols_dict()
         if anchor_time == 'entity':
             if entity_table_names[0] not in exclude_cols_dict:
                 exclude_cols_dict[entity_table_names[0]] = []
@@ -650,9 +1040,16 @@ class KumoRFM:
             exclude_cols_dict=exclude_cols_dict,
         )

+        if len(subgraph.table_dict) >= 15:
+            raise ValueError(f"Cannot query from a graph with more than 15 "
+                             f"tables (got {len(subgraph.table_dict)}). "
+                             f"Please create a feature request at "
+                             f"'https://github.com/kumo-ai/kumo-rfm' if you "
+                             f"must go beyond this for your use-case.")
+
         step_size: Optional[int] = None
         if query.query_type == QueryType.TEMPORAL:
-            step_size = date_offset_to_seconds(query.target.end_offset)
+            step_size = date_offset_to_seconds(end_offset)

         return Context(
             task_type=task_type,
@@ -677,7 +1074,7 @@ class KumoRFM:
         elif task_type == TaskType.MULTICLASS_CLASSIFICATION:
             supported_metrics = ['acc', 'precision', 'recall', 'f1', 'mrr']
         elif task_type == TaskType.REGRESSION:
-            supported_metrics = ['mae', 'mape', 'mse', 'rmse', 'smape']
+            supported_metrics = ['mae', 'mape', 'mse', 'rmse', 'smape', 'r2']
         elif task_type == TaskType.TEMPORAL_LINK_PREDICTION:
             supported_metrics = [
                 'map@', 'ndcg@', 'mrr@', 'precision@', 'recall@', 'f1@',