kumoai 2.9.0.dev202509061830__cp311-cp311-macosx_11_0_arm64.whl → 2.12.0.dev202511031731__cp311-cp311-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (28)
  1. kumoai/__init__.py +4 -2
  2. kumoai/_version.py +1 -1
  3. kumoai/client/client.py +10 -5
  4. kumoai/client/rfm.py +3 -2
  5. kumoai/connector/file_upload_connector.py +71 -102
  6. kumoai/connector/utils.py +1367 -236
  7. kumoai/experimental/rfm/__init__.py +2 -2
  8. kumoai/experimental/rfm/authenticate.py +8 -5
  9. kumoai/experimental/rfm/infer/timestamp.py +7 -4
  10. kumoai/experimental/rfm/local_graph.py +90 -80
  11. kumoai/experimental/rfm/local_graph_sampler.py +16 -8
  12. kumoai/experimental/rfm/local_graph_store.py +22 -6
  13. kumoai/experimental/rfm/local_pquery_driver.py +129 -28
  14. kumoai/experimental/rfm/local_table.py +100 -22
  15. kumoai/experimental/rfm/pquery/__init__.py +4 -0
  16. kumoai/experimental/rfm/pquery/backend.py +4 -0
  17. kumoai/experimental/rfm/pquery/executor.py +102 -0
  18. kumoai/experimental/rfm/pquery/pandas_backend.py +71 -30
  19. kumoai/experimental/rfm/pquery/pandas_executor.py +506 -0
  20. kumoai/experimental/rfm/rfm.py +442 -94
  21. kumoai/jobs.py +1 -0
  22. kumoai/trainer/trainer.py +19 -10
  23. kumoai/utils/progress_logger.py +62 -0
  24. {kumoai-2.9.0.dev202509061830.dist-info → kumoai-2.12.0.dev202511031731.dist-info}/METADATA +4 -5
  25. {kumoai-2.9.0.dev202509061830.dist-info → kumoai-2.12.0.dev202511031731.dist-info}/RECORD +28 -26
  26. {kumoai-2.9.0.dev202509061830.dist-info → kumoai-2.12.0.dev202511031731.dist-info}/WHEEL +0 -0
  27. {kumoai-2.9.0.dev202509061830.dist-info → kumoai-2.12.0.dev202511031731.dist-info}/licenses/LICENSE +0 -0
  28. {kumoai-2.9.0.dev202509061830.dist-info → kumoai-2.12.0.dev202511031731.dist-info}/top_level.txt +0 -0
@@ -1,13 +1,19 @@
 import json
+import time
 import warnings
-from typing import List, Literal, Optional, Union
+from collections import defaultdict
+from collections.abc import Generator
+from contextlib import contextmanager
+from dataclasses import dataclass, replace
+from typing import Iterator, List, Literal, Optional, Union, overload
 
 import numpy as np
 import pandas as pd
 from kumoapi.model_plan import RunMode
 from kumoapi.pquery import QueryType
+from kumoapi.rfm import Context
+from kumoapi.rfm import Explanation as ExplanationConfig
 from kumoapi.rfm import (
-    Context,
     PQueryDefinition,
     RFMEvaluateRequest,
     RFMPredictRequest,
@@ -20,11 +26,17 @@ from kumoai.exceptions import HTTPException
 from kumoai.experimental.rfm import LocalGraph
 from kumoai.experimental.rfm.local_graph_sampler import LocalGraphSampler
 from kumoai.experimental.rfm.local_graph_store import LocalGraphStore
-from kumoai.experimental.rfm.local_pquery_driver import LocalPQueryDriver
+from kumoai.experimental.rfm.local_pquery_driver import (
+    LocalPQueryDriver,
+    date_offset_to_seconds,
+)
 from kumoai.utils import InteractiveProgressLogger, ProgressLogger
 
 _RANDOM_SEED = 42
 
+_MAX_PRED_SIZE: dict[TaskType, int] = defaultdict(lambda: 1_000)
+_MAX_PRED_SIZE[TaskType.TEMPORAL_LINK_PREDICTION] = 200
+
 _MAX_CONTEXT_SIZE = {
     RunMode.DEBUG: 100,
     RunMode.FAST: 1_000,
@@ -39,7 +51,7 @@ _MAX_TEST_SIZE = { # Share test set size across run modes for fair comparison:
 }
 
 _MAX_SIZE = 30 * 1024 * 1024
-_SIZE_LIMIT_MSG = ("Context size exceeds the 30MB limit. {stats_msg}\nPlease "
+_SIZE_LIMIT_MSG = ("Context size exceeds the 30MB limit. {stats}\nPlease "
                    "reduce either the number of tables in the graph, their "
                    "number of columns (e.g., large text columns), "
                    "neighborhood configuration, or the run mode. If none of "
@@ -48,6 +60,34 @@ _SIZE_LIMIT_MSG = ("Context size exceeds the 30MB limit. {stats_msg}\nPlease "
                    "beyond this for your use-case.")
 
 
+@dataclass(repr=False)
+class Explanation:
+    prediction: pd.DataFrame
+    summary: str
+    details: ExplanationConfig
+
+    @overload
+    def __getitem__(self, index: Literal[0]) -> pd.DataFrame:
+        pass
+
+    @overload
+    def __getitem__(self, index: Literal[1]) -> str:
+        pass
+
+    def __getitem__(self, index: int) -> Union[pd.DataFrame, str]:
+        if index == 0:
+            return self.prediction
+        if index == 1:
+            return self.summary
+        raise IndexError("Index out of range")
+
+    def __iter__(self) -> Iterator[Union[pd.DataFrame, str]]:
+        return iter((self.prediction, self.summary))
+
+    def __repr__(self) -> str:
+        return str((self.prediction, self.summary))
+
+
 class KumoRFM:
     r"""The Kumo Relational Foundation model (RFM) from the `KumoRFM: A
     Foundation Model for In-Context Learning on Relational Data
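
Note: the new `Explanation` container is deliberately tuple-like, so call sites that used to receive a `(prediction, summary)` pair keep working unchanged. A minimal sketch (a constructed `KumoRFM` instance `model` and a query string `query` are assumed, both hypothetical here):

```python
# Hypothetical `model` and `query`; Explanation unpacks like a 2-tuple.
result = model.predict(query, indices=[42], explain=True)

df, summary = result                     # __iter__ yields (prediction, summary)
assert result[0] is result.prediction   # __getitem__(0) -> pd.DataFrame
assert result[1] is result.summary      # __getitem__(1) -> str
details = result.details                # only reachable as an attribute
```
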
@@ -105,28 +145,117 @@ class KumoRFM:
         self._graph_store = LocalGraphStore(graph, preprocess, verbose)
         self._graph_sampler = LocalGraphSampler(self._graph_store)
 
+        self._batch_size: Optional[int | Literal['max']] = None
+        self.num_retries: int = 0
+
     def __repr__(self) -> str:
         return f'{self.__class__.__name__}()'
 
+    @contextmanager
+    def batch_mode(
+        self,
+        batch_size: Union[int, Literal['max']] = 'max',
+        num_retries: int = 1,
+    ) -> Generator[None, None, None]:
+        """Context manager to predict in batches.
+
+        .. code-block:: python
+
+            with model.batch_mode(batch_size='max', num_retries=1):
+                df = model.predict(query, indices=...)
+
+        Args:
+            batch_size: The batch size. If set to ``"max"``, will use the
+                maximum applicable batch size for the given task.
+            num_retries: The maximum number of retries for failed queries due
+                to unexpected server issues.
+        """
+        if batch_size != 'max' and batch_size <= 0:
+            raise ValueError(f"'batch_size' must be greater than zero "
+                             f"(got {batch_size})")
+
+        if num_retries < 0:
+            raise ValueError(f"'num_retries' must be greater than or equal to "
+                             f"zero (got {num_retries})")
+
+        self._batch_size = batch_size
+        self.num_retries = num_retries
+        yield
+        self._batch_size = None
+        self.num_retries = 0
+
+    @overload
     def predict(
         self,
         query: str,
+        indices: Union[List[str], List[float], List[int], None] = None,
         *,
+        explain: Literal[False] = False,
         anchor_time: Union[pd.Timestamp, Literal['entity'], None] = None,
+        context_anchor_time: Union[pd.Timestamp, None] = None,
         run_mode: Union[RunMode, str] = RunMode.FAST,
         num_neighbors: Optional[List[int]] = None,
         num_hops: int = 2,
         max_pq_iterations: int = 20,
         random_seed: Optional[int] = _RANDOM_SEED,
         verbose: Union[bool, ProgressLogger] = True,
+        use_prediction_time: bool = False,
     ) -> pd.DataFrame:
+        pass
+
+    @overload
+    def predict(
+        self,
+        query: str,
+        indices: Union[List[str], List[float], List[int], None] = None,
+        *,
+        explain: Literal[True],
+        anchor_time: Union[pd.Timestamp, Literal['entity'], None] = None,
+        context_anchor_time: Union[pd.Timestamp, None] = None,
+        run_mode: Union[RunMode, str] = RunMode.FAST,
+        num_neighbors: Optional[List[int]] = None,
+        num_hops: int = 2,
+        max_pq_iterations: int = 20,
+        random_seed: Optional[int] = _RANDOM_SEED,
+        verbose: Union[bool, ProgressLogger] = True,
+        use_prediction_time: bool = False,
+    ) -> Explanation:
+        pass
+
+    def predict(
+        self,
+        query: str,
+        indices: Union[List[str], List[float], List[int], None] = None,
+        *,
+        explain: bool = False,
+        anchor_time: Union[pd.Timestamp, Literal['entity'], None] = None,
+        context_anchor_time: Union[pd.Timestamp, None] = None,
+        run_mode: Union[RunMode, str] = RunMode.FAST,
+        num_neighbors: Optional[List[int]] = None,
+        num_hops: int = 2,
+        max_pq_iterations: int = 20,
+        random_seed: Optional[int] = _RANDOM_SEED,
+        verbose: Union[bool, ProgressLogger] = True,
+        use_prediction_time: bool = False,
+    ) -> Union[pd.DataFrame, Explanation]:
         """Returns predictions for a predictive query.
 
         Args:
             query: The predictive query.
-            anchor_time: The anchor timestamp for the query. If set to
-                :obj:`None`, will use the maximum timestamp in the data.
-                If set to :`"entity"`, will use the timestamp of the entity.
+            indices: The entity primary keys to predict for. Will override the
+                indices given as part of the predictive query. Predictions will
+                be generated for all indices, independent of whether they
+                fulfill entity filter constraints. To pre-filter entities, use
+                :meth:`~KumoRFM.is_valid_entity`.
+            explain: If set to ``True``, will additionally explain the
+                prediction. Explainability is currently only supported for
+                single entity predictions with ``run_mode="FAST"``.
+            anchor_time: The anchor timestamp for the prediction. If set to
+                ``None``, will use the maximum timestamp in the data.
+                If set to ``"entity"``, will use the timestamp of the entity.
+            context_anchor_time: The maximum anchor timestamp for context
+                examples. If set to ``None``, ``anchor_time`` will
+                determine the anchor time for context examples.
             run_mode: The :class:`RunMode` for the query.
             num_neighbors: The number of neighbors to sample for each hop.
                 If specified, the ``num_hops`` option will be ignored.
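
Note: how `batch_mode` splitting plays out in practice, as a sketch (hypothetical `model` and `query`; per-task batch sizes come from `_MAX_PRED_SIZE`, i.e. 1,000 by default and 200 for temporal link prediction):

```python
# Hypothetical `model` and `query`: 2,500 indices with batch_size='max' on a
# non-link-prediction task split into ceil(2500 / 1000) = 3 requests
# (1,000 + 1,000 + 500), whose results come back as one concatenated frame.
indices = list(range(2_500))
with model.batch_mode(batch_size='max', num_retries=1):
    df = model.predict(query, indices=indices)
```
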
@@ -138,11 +267,15 @@ class KumoRFM:
                 entities to find valid labels.
             random_seed: A manual seed for generating pseudo-random numbers.
             verbose: Whether to print verbose output.
+            use_prediction_time: Whether to use the anchor timestamp as an
+                additional feature during prediction. This is typically
+                beneficial for time series forecasting tasks.
 
         Returns:
-            The predictions as a :class:`pandas.DataFrame`
+            The predictions as a :class:`pandas.DataFrame`.
+            If ``explain=True``, additionally returns a textual summary that
+            explains the prediction.
         """
-        explain = False
         query_def = self._parse_query(query)
 
         if num_hops != 2 and num_neighbors is not None:
@@ -155,12 +288,24 @@ class KumoRFM:
                           f"mode has been reset. Please lower the run mode to "
                           f"suppress this warning.")
 
-        if explain:
-            assert query_def.entity.ids is not None
-            if len(query_def.entity.ids.value) > 1:
-                raise ValueError(
-                    f"Cannot explain predictions for more than a single "
-                    f"entity (got {len(query_def.entity.ids.value)})")
+        if indices is None:
+            if query_def.entity.ids is None:
+                raise ValueError("Cannot find entities to predict for. Please "
+                                 "pass them via `predict(query, indices=...)`")
+            indices = query_def.entity.ids.value
+        else:
+            query_def = replace(
+                query_def,
+                entity=replace(query_def.entity, ids=None),
+            )
+
+        if len(indices) == 0:
+            raise ValueError("At least one entity is required")
+
+        if explain and len(indices) > 1:
+            raise ValueError(
+                f"Cannot explain predictions for more than a single entity "
+                f"(got {len(indices)})")
 
         query_repr = query_def.to_string(rich=True, exclude_predict=True)
         if explain:
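
Note: the precedence rule above, sketched (hypothetical `model` and schema; whenever `indices` is passed, ids embedded in the query are stripped via `dataclasses.replace`):

```python
# Hypothetical `model` and schema: explicit `indices` win over ids that are
# written into the query string itself.
query = "PREDICT COUNT(orders.*, 0, 30, days) > 0 FOR users.user_id = 1"
df = model.predict(query, indices=[2, 3])  # predicts for users 2 and 3, not 1
```
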
@@ -172,48 +317,185 @@ class KumoRFM:
             verbose = InteractiveProgressLogger(msg, verbose=verbose)
 
         with verbose as logger:
-            context = self._get_context(
-                query_def,
-                anchor_time=anchor_time,
-                run_mode=RunMode(run_mode),
-                num_neighbors=num_neighbors,
-                num_hops=num_hops,
-                max_pq_iterations=max_pq_iterations,
-                evaluate=False,
-                random_seed=random_seed,
-                logger=logger,
-            )
-            request = RFMPredictRequest(
-                context=context,
-                run_mode=RunMode(run_mode),
+
+            batch_size: Optional[int] = None
+            if self._batch_size == 'max':
+                task_type = query_def.get_task_type(
+                    stypes=self._graph_store.stype_dict,
+                    edge_types=self._graph_store.edge_types,
+                )
+                batch_size = _MAX_PRED_SIZE[task_type]
+            else:
+                batch_size = self._batch_size
+
+            if batch_size is not None:
+                offsets = range(0, len(indices), batch_size)
+                batches = [indices[step:step + batch_size] for step in offsets]
+            else:
+                batches = [indices]
+
+            if len(batches) > 1:
+                logger.log(f"Splitting {len(indices):,} entities into "
+                           f"{len(batches):,} batches of size {batch_size:,}")
+
+            predictions: List[pd.DataFrame] = []
+            summary: Optional[str] = None
+            details: Optional[ExplanationConfig] = None
+            for i, batch in enumerate(batches):
+                # TODO Re-use the context for subsequent predictions.
+                context = self._get_context(
+                    query=query_def,
+                    indices=batch,
+                    anchor_time=anchor_time,
+                    context_anchor_time=context_anchor_time,
+                    run_mode=RunMode(run_mode),
+                    num_neighbors=num_neighbors,
+                    num_hops=num_hops,
+                    max_pq_iterations=max_pq_iterations,
+                    evaluate=False,
+                    random_seed=random_seed,
+                    logger=logger if i == 0 else None,
+                )
+                request = RFMPredictRequest(
+                    context=context,
+                    run_mode=RunMode(run_mode),
+                    use_prediction_time=use_prediction_time,
+                )
+                with warnings.catch_warnings():
+                    warnings.filterwarnings('ignore', message='gencode')
+                    request_msg = request.to_protobuf()
+                    _bytes = request_msg.SerializeToString()
+                if i == 0:
+                    logger.log(f"Generated context of size "
+                               f"{len(_bytes) / (1024*1024):.2f}MB")
+
+                if len(_bytes) > _MAX_SIZE:
+                    stats = Context.get_memory_stats(request_msg.context)
+                    raise ValueError(_SIZE_LIMIT_MSG.format(stats=stats))
+
+                if (isinstance(verbose, InteractiveProgressLogger) and i == 0
+                        and len(batches) > 1):
+                    verbose.init_progress(
+                        total=len(batches),
+                        description='Predicting',
+                    )
+
+                for attempt in range(self.num_retries + 1):
+                    try:
+                        if explain:
+                            resp = global_state.client.rfm_api.explain(_bytes)
+                            summary = resp.summary
+                            details = resp.details
+                        else:
+                            resp = global_state.client.rfm_api.predict(_bytes)
+                        df = pd.DataFrame(**resp.prediction)
+
+                        # Cast 'ENTITY' to correct data type:
+                        if 'ENTITY' in df:
+                            entity = query_def.entity.pkey.table_name
+                            pkey_map = self._graph_store.pkey_map_dict[entity]
+                            df['ENTITY'] = df['ENTITY'].astype(
+                                type(pkey_map.index[0]))
+
+                        # Cast 'ANCHOR_TIMESTAMP' to correct data type:
+                        if 'ANCHOR_TIMESTAMP' in df:
+                            ser = df['ANCHOR_TIMESTAMP']
+                            if not pd.api.types.is_datetime64_any_dtype(ser):
+                                if isinstance(ser.iloc[0], str):
+                                    unit = None
+                                else:
+                                    unit = 'ms'
+                                df['ANCHOR_TIMESTAMP'] = pd.to_datetime(
+                                    ser, errors='coerce', unit=unit)
+
+                        predictions.append(df)
+
+                        if (isinstance(verbose, InteractiveProgressLogger)
+                                and len(batches) > 1):
+                            verbose.step()
+
+                        break
+                    except HTTPException as e:
+                        if attempt == self.num_retries:
+                            try:
+                                msg = json.loads(e.detail)['detail']
+                            except Exception:
+                                msg = e.detail
+                            raise RuntimeError(
+                                f"An unexpected exception occurred. Please "
+                                f"create an issue at "
+                                f"'https://github.com/kumo-ai/kumo-rfm'. {msg}"
+                            ) from None
+
+                        time.sleep(2**attempt)  # 1s, 2s, 4s, 8s, ...
+
+            if len(predictions) == 1:
+                prediction = predictions[0]
+            else:
+                prediction = pd.concat(predictions, ignore_index=True)
+
+            if explain:
+                assert len(predictions) == 1
+                assert summary is not None
+                assert details is not None
+                return Explanation(
+                    prediction=prediction,
+                    summary=summary,
+                    details=details,
                 )
-            with warnings.catch_warnings():
-                warnings.filterwarnings('ignore', message='Protobuf gencode')
-                request_msg = request.to_protobuf()
-                request_bytes = request_msg.SerializeToString()
-            logger.log(f"Generated context of size "
-                       f"{len(request_bytes) / (1024*1024):.2f}MB")
 
-            if len(request_bytes) > _MAX_SIZE:
-                stats_msg = Context.get_memory_stats(request_msg.context)
-                raise ValueError(_SIZE_LIMIT_MSG.format(stats_msg=stats_msg))
+            return prediction
 
-        try:
-            if explain:
-                resp = global_state.client.rfm_api.explain(request_bytes)
-            else:
-                resp = global_state.client.rfm_api.predict(request_bytes)
-        except HTTPException as e:
-            try:
-                msg = json.loads(e.detail)['detail']
-            except Exception:
-                msg = e.detail
-            raise RuntimeError(f"An unexpected exception occurred. "
-                               f"Please create an issue at "
-                               f"'https://github.com/kumo-ai/kumo-rfm'. "
-                               f"{msg}") from None
+    def is_valid_entity(
+        self,
+        query: str,
+        indices: Union[List[str], List[float], List[int], None] = None,
+        *,
+        anchor_time: Union[pd.Timestamp, Literal['entity'], None] = None,
+    ) -> np.ndarray:
+        r"""Returns a mask that denotes which entities are valid for the
+        given predictive query, *i.e.*, which entities fulfill (temporal)
+        entity filter constraints.
+
+        Args:
+            query: The predictive query.
+            indices: The entity primary keys to predict for. Will override the
+                indices given as part of the predictive query.
+            anchor_time: The anchor timestamp for the prediction. If set to
+                ``None``, will use the maximum timestamp in the data.
+                If set to ``"entity"``, will use the timestamp of the entity.
+        """
+        query_def = self._parse_query(query)
+
+        if indices is None:
+            if query_def.entity.ids is None:
+                raise ValueError("Cannot find entities to predict for. Please "
+                                 "pass them via "
+                                 "`is_valid_entity(query, indices=...)`")
+            indices = query_def.entity.ids.value
+
+        if len(indices) == 0:
+            raise ValueError("At least one entity is required")
+
+        if anchor_time is None:
+            anchor_time = self._graph_store.max_time
+
+        if isinstance(anchor_time, pd.Timestamp):
+            self._validate_time(query_def, anchor_time, None, False)
+        else:
+            assert anchor_time == 'entity'
+            if (query_def.entity.pkey.table_name
+                    not in self._graph_store.time_dict):
+                raise ValueError(f"Anchor time 'entity' requires the entity "
+                                 f"table '{query_def.entity.pkey.table_name}' "
+                                 f"to have a time column")
 
-        return pd.DataFrame(**resp.prediction)
+        node = self._graph_store.get_node_id(
+            table_name=query_def.entity.pkey.table_name,
+            pkey=pd.Series(indices),
+        )
+        query_driver = LocalPQueryDriver(self._graph_store, query_def)
+        return query_driver.is_valid(node, anchor_time)
 
     def evaluate(
         self,
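
Note: the intended pairing between `is_valid_entity` and `predict`, as a sketch (hypothetical `model`, `query`, and candidate ids):

```python
# Hypothetical `model` and `query`: keep only entities that satisfy the
# query's (temporal) entity filters before asking for predictions.
candidates = [1, 2, 3, 4, 5]
mask = model.is_valid_entity(query, indices=candidates)  # boolean np.ndarray
valid = [pk for pk, ok in zip(candidates, mask) if ok]
df = model.predict(query, indices=valid)
```
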
@@ -221,21 +503,26 @@ class KumoRFM:
         *,
         metrics: Optional[List[str]] = None,
         anchor_time: Union[pd.Timestamp, Literal['entity'], None] = None,
+        context_anchor_time: Union[pd.Timestamp, None] = None,
         run_mode: Union[RunMode, str] = RunMode.FAST,
         num_neighbors: Optional[List[int]] = None,
         num_hops: int = 2,
         max_pq_iterations: int = 20,
         random_seed: Optional[int] = _RANDOM_SEED,
         verbose: Union[bool, ProgressLogger] = True,
+        use_prediction_time: bool = False,
     ) -> pd.DataFrame:
         """Evaluates a predictive query.
 
         Args:
             query: The predictive query.
             metrics: The metrics to use.
-            anchor_time: The anchor timestamp for the query. If set to
-                :obj:`None`, will use the maximum timestamp in the data.
-                If set to :`"entity"`, will use the timestamp of the entity.
+            anchor_time: The anchor timestamp for the prediction. If set to
+                ``None``, will use the maximum timestamp in the data.
+                If set to ``"entity"``, will use the timestamp of the entity.
+            context_anchor_time: The maximum anchor timestamp for context
+                examples. If set to ``None``, ``anchor_time`` will
+                determine the anchor time for context examples.
             run_mode: The :class:`RunMode` for the query.
             num_neighbors: The number of neighbors to sample for each hop.
                 If specified, the ``num_hops`` option will be ignored.
@@ -247,6 +534,9 @@ class KumoRFM:
                 entities to find valid labels.
             random_seed: A manual seed for generating pseudo-random numbers.
             verbose: Whether to print verbose output.
+            use_prediction_time: Whether to use the anchor timestamp as an
+                additional feature during prediction. This is typically
+                beneficial for time series forecasting tasks.
 
         Returns:
             The metrics as a :class:`pandas.DataFrame`
@@ -257,6 +547,12 @@ class KumoRFM:
             warnings.warn(f"Received custom 'num_neighbors' option; ignoring "
                           f"custom 'num_hops={num_hops}' option")
 
+        if query_def.entity.ids is not None:
+            query_def = replace(
+                query_def,
+                entity=replace(query_def.entity, ids=None),
+            )
+
         query_repr = query_def.to_string(rich=True, exclude_predict=True)
         msg = f'[bold]EVALUATE[/bold] {query_repr}'
 
@@ -265,8 +561,10 @@ class KumoRFM:
 
         with verbose as logger:
             context = self._get_context(
-                query_def,
+                query=query_def,
+                indices=None,
                 anchor_time=anchor_time,
+                context_anchor_time=context_anchor_time,
                 run_mode=RunMode(run_mode),
                 num_neighbors=num_neighbors,
                 num_hops=num_hops,
@@ -282,6 +580,7 @@ class KumoRFM:
                 context=context,
                 run_mode=RunMode(run_mode),
                 metrics=metrics,
+                use_prediction_time=use_prediction_time,
             )
             with warnings.catch_warnings():
                 warnings.filterwarnings('ignore', message='Protobuf gencode')
@@ -340,11 +639,12 @@ class KumoRFM:
 
         if anchor_time is None:
             anchor_time = self._graph_store.max_time
-            anchor_time = anchor_time - query_def.target.end_offset
+            anchor_time = anchor_time - (query_def.target.end_offset *
+                                         query_def.num_forecasts)
 
         assert anchor_time is not None
         if isinstance(anchor_time, pd.Timestamp):
-            self._validate_time(query_def, anchor_time, evaluate=True)
+            self._validate_time(query_def, anchor_time, None, evaluate=True)
         else:
             assert anchor_time == 'entity'
             if (query_def.entity.pkey.table_name
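
Note: the horizon arithmetic above in isolation, as plain pandas (values are illustrative):

```python
import pandas as pd

# For a 7-day end_offset and num_forecasts=4, the anchor moves 28 days
# before the latest timestamp so that every forecast step still has labels.
max_time = pd.Timestamp("2024-12-31")
end_offset = pd.DateOffset(days=7)
num_forecasts = 4

anchor_time = max_time - end_offset * num_forecasts
print(anchor_time)  # 2024-12-03 00:00:00
```
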
@@ -361,10 +661,14 @@ class KumoRFM:
                 anchor_time=anchor_time,
                 batch_size=min(10_000, size),
                 max_iterations=max_iterations,
+                guarantee_train_examples=False,
             )
 
+        entity = self._graph_store.pkey_map_dict[
+            query_def.entity.pkey.table_name].index[node]
+
         return pd.DataFrame({
-            'ENTITY': node,
+            'ENTITY': entity,
             'ANCHOR_TIMESTAMP': time,
             'TARGET': y,
         })
@@ -411,6 +715,7 @@ class KumoRFM:
         self,
         query: PQueryDefinition,
         anchor_time: pd.Timestamp,
+        context_anchor_time: Union[pd.Timestamp, None],
         evaluate: bool,
     ) -> None:
 
@@ -422,22 +727,41 @@ class KumoRFM:
                              f"the earliest timestamp "
                              f"'{self._graph_store.min_time}' in the data.")
 
-        if anchor_time - query.target.end_offset < self._graph_store.min_time:
-            raise ValueError(f"Anchor timestamp is too early or aggregation "
-                             f"time range is too large. To make this "
-                             f"prediction, we would need data back to "
-                             f"'{anchor_time - query.target.end_offset}', "
-                             f"however, your data only contains data back to "
+        if (context_anchor_time is not None
+                and context_anchor_time < self._graph_store.min_time):
+            raise ValueError(f"Context anchor timestamp is too early or "
+                             f"aggregation time range is too large. To make "
+                             f"this prediction, we would need data back to "
+                             f"'{context_anchor_time}', however, your data "
+                             f"only contains data back to "
                              f"'{self._graph_store.min_time}'.")
 
-        if (anchor_time - 2 * query.target.end_offset
-                < self._graph_store.min_time):
-            warnings.warn(f"Anchor timestamp is too early or aggregation "
-                          f"time range is too large. To form proper input "
-                          f"data, we would need data back to "
-                          f"'{anchor_time - 2 * query.target.end_offset}', "
-                          f"however, your data only contains data back to "
-                          f"'{self._graph_store.min_time}'.")
+        if (context_anchor_time is not None
+                and context_anchor_time > anchor_time):
+            warnings.warn(f"Context anchor timestamp "
+                          f"(got '{context_anchor_time}') is set to a later "
+                          f"date than the prediction anchor timestamp "
+                          f"(got '{anchor_time}'). Please make sure this is "
+                          f"intended.")
+        elif (query.query_type == QueryType.TEMPORAL
+              and context_anchor_time is not None and context_anchor_time +
+              query.target.end_offset * query.num_forecasts > anchor_time):
+            warnings.warn(f"Aggregation for context examples at timestamp "
+                          f"'{context_anchor_time}' will leak information "
+                          f"from the prediction anchor timestamp "
+                          f"'{anchor_time}'. Please make sure this is "
+                          f"intended.")
+
+        elif (context_anchor_time is not None and context_anchor_time -
+              query.target.end_offset * query.num_forecasts
+              < self._graph_store.min_time):
+            _time = context_anchor_time - (query.target.end_offset *
+                                           query.num_forecasts)
+            warnings.warn(f"Context anchor timestamp is too early or "
+                          f"aggregation time range is too large. To form "
+                          f"proper input data, we would need data back to "
+                          f"'{_time}', however, your data only contains "
+                          f"data back to '{self._graph_store.min_time}'.")
 
         if (not evaluate and anchor_time
                 > self._graph_store.max_time + pd.DateOffset(days=1)):
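
Note: the leak check above as a worked example (illustrative values):

```python
import pandas as pd

# context_anchor_time + end_offset * num_forecasts must not exceed the
# prediction anchor time, otherwise context labels overlap the target window.
anchor_time = pd.Timestamp("2024-06-30")
context_anchor_time = pd.Timestamp("2024-06-20")
end_offset, num_forecasts = pd.DateOffset(days=7), 2

leaks = context_anchor_time + end_offset * num_forecasts > anchor_time
print(leaks)  # True: 2024-06-20 + 14 days = 2024-07-04 > 2024-06-30
```
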
@@ -445,17 +769,19 @@ class KumoRFM:
                           f"latest timestamp '{self._graph_store.max_time}' "
                           f"in the data. Please make sure this is intended.")
 
-        if (evaluate and anchor_time
-                > self._graph_store.max_time - query.target.end_offset):
+        max_eval_time = (self._graph_store.max_time -
+                         query.target.end_offset * query.num_forecasts)
+        if evaluate and anchor_time > max_eval_time:
             raise ValueError(
                 f"Anchor timestamp for evaluation is after the latest "
-                f"supported timestamp "
-                f"'{self._graph_store.max_time - query.target.end_offset}'.")
+                f"supported timestamp '{max_eval_time}'.")
 
     def _get_context(
         self,
         query: PQueryDefinition,
+        indices: Union[List[str], List[float], List[int], None],
         anchor_time: Union[pd.Timestamp, Literal['entity'], None],
+        context_anchor_time: Union[pd.Timestamp, None],
         run_mode: RunMode,
         num_neighbors: Optional[List[int]],
         num_hops: int,
@@ -516,25 +842,36 @@ class KumoRFM:
         if anchor_time is None:
             anchor_time = self._graph_store.max_time
             if evaluate:
-                anchor_time = anchor_time - query.target.end_offset
+                anchor_time = anchor_time - (query.target.end_offset *
+                                             query.num_forecasts)
             if logger is not None:
                 assert isinstance(anchor_time, pd.Timestamp)
-                if (anchor_time.hour == 0 and anchor_time.minute == 0
-                        and anchor_time.second == 0
-                        and anchor_time.microsecond == 0):
+                if anchor_time == pd.Timestamp.min:
+                    pass  # Static graph
+                elif (anchor_time.hour == 0 and anchor_time.minute == 0
+                      and anchor_time.second == 0
+                      and anchor_time.microsecond == 0):
                     logger.log(f"Derived anchor time {anchor_time.date()}")
                 else:
                     logger.log(f"Derived anchor time {anchor_time}")
 
         assert anchor_time is not None
         if isinstance(anchor_time, pd.Timestamp):
-            self._validate_time(query, anchor_time, evaluate)
+            if context_anchor_time is None:
+                context_anchor_time = anchor_time - (query.target.end_offset *
+                                                     query.num_forecasts)
+            self._validate_time(query, anchor_time, context_anchor_time,
+                                evaluate)
         else:
             assert anchor_time == 'entity'
             if query.entity.pkey.table_name not in self._graph_store.time_dict:
                 raise ValueError(f"Anchor time 'entity' requires the entity "
                                  f"table '{query.entity.pkey.table_name}' to "
                                  f"have a time column")
+            if context_anchor_time is not None:
+                warnings.warn("Ignoring option 'context_anchor_time' for "
+                              "`anchor_time='entity'`")
+                context_anchor_time = None
 
         y_test: Optional[pd.Series] = None
         if evaluate:
@@ -546,6 +883,7 @@ class KumoRFM:
                 size=max_test_size,
                 anchor_time=anchor_time,
                 max_iterations=max_pq_iterations,
+                guarantee_train_examples=True,
             )
             if logger is not None:
                 if task_type == TaskType.BINARY_CLASSIFICATION:
@@ -569,20 +907,18 @@ class KumoRFM:
                     logger.log(msg)
 
         else:
-            assert query.entity.ids is not None
+            assert indices is not None
 
-            max_num_test = 200 if task_type.is_link_pred else 1000
-            if len(query.entity.ids.value) > max_num_test:
+            if len(indices) > _MAX_PRED_SIZE[task_type]:
                 raise ValueError(f"Cannot predict for more than "
-                                 f"{max_num_test:,} entities at once "
-                                 f"(got {len(query.entity.ids.value):,})")
+                                 f"{_MAX_PRED_SIZE[task_type]:,} entities at "
+                                 f"once (got {len(indices):,}). Use "
+                                 f"`KumoRFM.batch_mode` to process entities "
+                                 f"in batches")
 
             test_node = self._graph_store.get_node_id(
                 table_name=query.entity.pkey.table_name,
-                pkey=pd.Series(
-                    query.entity.ids.value,
-                    dtype=query.entity.ids.dtype,
-                ),
+                pkey=pd.Series(indices),
             )
 
         if isinstance(anchor_time, pd.Timestamp):
@@ -596,7 +932,7 @@ class KumoRFM:
 
         train_node, train_time, y_train = query_driver.collect_train(
             size=_MAX_CONTEXT_SIZE[run_mode],
-            anchor_time=anchor_time,
+            anchor_time=context_anchor_time or 'entity',
             exclude_node=test_node if (query.query_type == QueryType.STATIC
                                        or anchor_time == 'entity') else None,
             max_iterations=max_pq_iterations,
@@ -648,6 +984,17 @@ class KumoRFM:
             exclude_cols_dict=exclude_cols_dict,
         )
 
+        if len(subgraph.table_dict) >= 15:
+            raise ValueError(f"Cannot query from a graph with more than 15 "
+                             f"tables (got {len(subgraph.table_dict)}). "
+                             f"Please create a feature request at "
+                             f"'https://github.com/kumo-ai/kumo-rfm' if you "
+                             f"must go beyond this for your use-case.")
+
+        step_size: Optional[int] = None
+        if query.query_type == QueryType.TEMPORAL:
+            step_size = date_offset_to_seconds(query.target.end_offset)
+
         return Context(
             task_type=task_type,
             entity_table_names=entity_table_names,
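
Note: what the new `step_size` field carries, sketched (this assumes `date_offset_to_seconds` converts a calendar offset to whole seconds, which is what its use as a forecast step suggests):

```python
import pandas as pd

from kumoai.experimental.rfm.local_pquery_driver import date_offset_to_seconds

# Assumption: a 1-day end_offset yields 86,400 seconds as the forecast step.
step_size = date_offset_to_seconds(pd.DateOffset(days=1))
print(step_size)  # expected: 86400
```
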
@@ -655,6 +1002,7 @@ class KumoRFM:
655
1002
  y_train=y_train,
656
1003
  y_test=y_test,
657
1004
  top_k=query.top_k,
1005
+ step_size=step_size,
658
1006
  )
659
1007
 
660
1008
  @staticmethod
@@ -670,7 +1018,7 @@ class KumoRFM:
         elif task_type == TaskType.MULTICLASS_CLASSIFICATION:
             supported_metrics = ['acc', 'precision', 'recall', 'f1', 'mrr']
         elif task_type == TaskType.REGRESSION:
-            supported_metrics = ['mae', 'mape', 'mse', 'rmse', 'smape']
+            supported_metrics = ['mae', 'mape', 'mse', 'rmse', 'smape', 'r2']
         elif task_type == TaskType.TEMPORAL_LINK_PREDICTION:
             supported_metrics = [
                 'map@', 'ndcg@', 'mrr@', 'precision@', 'recall@', 'f1@',
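
Note: usage of the newly supported regression metric, as a sketch (hypothetical `model` and regression `query`):

```python
# Hypothetical `model` and regression `query`: 'r2' can now be requested
# alongside the existing error metrics.
metrics_df = model.evaluate(query, metrics=["mae", "rmse", "r2"])
print(metrics_df)
```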