kumoai 2.9.0.dev202509081831-cp312-cp312-win_amd64.whl → 2.13.0.dev202511201731-cp312-cp312-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (30)
  1. kumoai/__init__.py +10 -11
  2. kumoai/_version.py +1 -1
  3. kumoai/client/client.py +17 -16
  4. kumoai/client/endpoints.py +1 -0
  5. kumoai/client/rfm.py +37 -8
  6. kumoai/connector/file_upload_connector.py +71 -102
  7. kumoai/connector/utils.py +1367 -236
  8. kumoai/experimental/rfm/__init__.py +153 -10
  9. kumoai/experimental/rfm/authenticate.py +8 -5
  10. kumoai/experimental/rfm/infer/timestamp.py +7 -4
  11. kumoai/experimental/rfm/local_graph.py +90 -80
  12. kumoai/experimental/rfm/local_graph_sampler.py +16 -10
  13. kumoai/experimental/rfm/local_graph_store.py +22 -6
  14. kumoai/experimental/rfm/local_pquery_driver.py +336 -42
  15. kumoai/experimental/rfm/local_table.py +100 -22
  16. kumoai/experimental/rfm/pquery/__init__.py +4 -4
  17. kumoai/experimental/rfm/pquery/{backend.py → executor.py} +24 -58
  18. kumoai/experimental/rfm/pquery/{pandas_backend.py → pandas_executor.py} +278 -222
  19. kumoai/experimental/rfm/rfm.py +523 -124
  20. kumoai/experimental/rfm/sagemaker.py +130 -0
  21. kumoai/jobs.py +1 -0
  22. kumoai/kumolib.cp312-win_amd64.pyd +0 -0
  23. kumoai/spcs.py +1 -3
  24. kumoai/trainer/trainer.py +19 -10
  25. kumoai/utils/progress_logger.py +68 -0
  26. {kumoai-2.9.0.dev202509081831.dist-info → kumoai-2.13.0.dev202511201731.dist-info}/METADATA +13 -5
  27. {kumoai-2.9.0.dev202509081831.dist-info → kumoai-2.13.0.dev202511201731.dist-info}/RECORD +30 -29
  28. {kumoai-2.9.0.dev202509081831.dist-info → kumoai-2.13.0.dev202511201731.dist-info}/WHEEL +0 -0
  29. {kumoai-2.9.0.dev202509081831.dist-info → kumoai-2.13.0.dev202511201731.dist-info}/licenses/LICENSE +0 -0
  30. {kumoai-2.9.0.dev202509081831.dist-info → kumoai-2.13.0.dev202511201731.dist-info}/top_level.txt +0 -0
@@ -1,21 +1,36 @@
 import json
+import time
 import warnings
-from typing import List, Literal, Optional, Union
+from collections import defaultdict
+from collections.abc import Generator
+from contextlib import contextmanager
+from dataclasses import dataclass, replace
+from typing import (
+    Any,
+    Dict,
+    Iterator,
+    List,
+    Literal,
+    Optional,
+    Tuple,
+    Union,
+    overload,
+)
 
 import numpy as np
 import pandas as pd
 from kumoapi.model_plan import RunMode
-from kumoapi.pquery import QueryType
+from kumoapi.pquery import QueryType, ValidatedPredictiveQuery
+from kumoapi.rfm import Context
+from kumoapi.rfm import Explanation as ExplanationConfig
 from kumoapi.rfm import (
-    Context,
-    PQueryDefinition,
     RFMEvaluateRequest,
+    RFMParseQueryRequest,
     RFMPredictRequest,
-    RFMValidateQueryRequest,
 )
 from kumoapi.task import TaskType
 
-from kumoai import global_state
+from kumoai.client.rfm import RFMAPI
 from kumoai.exceptions import HTTPException
 from kumoai.experimental.rfm import LocalGraph
 from kumoai.experimental.rfm.local_graph_sampler import LocalGraphSampler
@@ -24,10 +39,14 @@ from kumoai.experimental.rfm.local_pquery_driver import (
     LocalPQueryDriver,
     date_offset_to_seconds,
 )
+from kumoai.mixin import CastMixin
 from kumoai.utils import InteractiveProgressLogger, ProgressLogger
 
 _RANDOM_SEED = 42
 
+_MAX_PRED_SIZE: dict[TaskType, int] = defaultdict(lambda: 1_000)
+_MAX_PRED_SIZE[TaskType.TEMPORAL_LINK_PREDICTION] = 200
+
 _MAX_CONTEXT_SIZE = {
     RunMode.DEBUG: 100,
     RunMode.FAST: 1_000,
@@ -42,7 +61,7 @@ _MAX_TEST_SIZE = { # Share test set size across run modes for fair comparison:
 }
 
 _MAX_SIZE = 30 * 1024 * 1024
-_SIZE_LIMIT_MSG = ("Context size exceeds the 30MB limit. {stats_msg}\nPlease "
+_SIZE_LIMIT_MSG = ("Context size exceeds the 30MB limit. {stats}\nPlease "
                    "reduce either the number of tables in the graph, their "
                    "number of columns (e.g., large text columns), "
                    "neighborhood configuration, or the run mode. If none of "
@@ -51,6 +70,51 @@ _SIZE_LIMIT_MSG = ("Context size exceeds the 30MB limit. {stats_msg}\nPlease "
                    "beyond this for your use-case.")
 
 
+@dataclass(repr=False)
+class ExplainConfig(CastMixin):
+    """Configuration for explainability.
+
+    Args:
+        skip_summary: Whether to skip generating a human-readable summary of
+            the explanation.
+    """
+    skip_summary: bool = False
+
+
+@dataclass(repr=False)
+class Explanation:
+    prediction: pd.DataFrame
+    summary: str
+    details: ExplanationConfig
+
+    @overload
+    def __getitem__(self, index: Literal[0]) -> pd.DataFrame:
+        pass
+
+    @overload
+    def __getitem__(self, index: Literal[1]) -> str:
+        pass
+
+    def __getitem__(self, index: int) -> Union[pd.DataFrame, str]:
+        if index == 0:
+            return self.prediction
+        if index == 1:
+            return self.summary
+        raise IndexError("Index out of range")
+
+    def __iter__(self) -> Iterator[Union[pd.DataFrame, str]]:
+        return iter((self.prediction, self.summary))
+
+    def __repr__(self) -> str:
+        return str((self.prediction, self.summary))
+
+    def _ipython_display_(self) -> None:
+        from IPython.display import Markdown, display
+
+        display(self.prediction)
+        display(Markdown(self.summary))
+
+
 class KumoRFM:
     r"""The Kumo Relational Foundation model (RFM) from the `KumoRFM: A
     Foundation Model for In-Context Learning on Relational Data
@@ -77,9 +141,9 @@ class KumoRFM:
 
         rfm = KumoRFM(graph)
 
-        query = ("PREDICT COUNT(transactions.*, 0, 30, days)>0 "
-                 "FOR users.user_id=0")
-        result = rfm.query(query)
+        query = ("PREDICT COUNT(orders.*, 0, 30, days)>0 "
+                 "FOR users.user_id=1")
+        result = rfm.predict(query)
 
         print(result)  # user_id  COUNT(transactions.*, 0, 30, days) > 0
                        # 1        0.85
@@ -108,28 +172,122 @@ class KumoRFM:
         self._graph_store = LocalGraphStore(graph, preprocess, verbose)
         self._graph_sampler = LocalGraphSampler(self._graph_store)
 
+        self._batch_size: Optional[int | Literal['max']] = None
+        self.num_retries: int = 0
+        from kumoai.experimental.rfm import global_state
+        self._api_client = RFMAPI(global_state.client)
+
     def __repr__(self) -> str:
         return f'{self.__class__.__name__}()'
 
+    @contextmanager
+    def batch_mode(
+        self,
+        batch_size: Union[int, Literal['max']] = 'max',
+        num_retries: int = 1,
+    ) -> Generator[None, None, None]:
+        """Context manager to predict in batches.
+
+        .. code-block:: python
+
+            with model.batch_mode(batch_size='max', num_retries=1):
+                df = model.predict(query, indices=...)
+
+        Args:
+            batch_size: The batch size. If set to ``"max"``, will use the
+                maximum applicable batch size for the given task.
+            num_retries: The maximum number of retries for failed queries due
+                to unexpected server issues.
+        """
+        if batch_size != 'max' and batch_size <= 0:
+            raise ValueError(f"'batch_size' must be greater than zero "
+                             f"(got {batch_size})")
+
+        if num_retries < 0:
+            raise ValueError(f"'num_retries' must be greater than or equal to "
+                             f"zero (got {num_retries})")
+
+        self._batch_size = batch_size
+        self.num_retries = num_retries
+        yield
+        self._batch_size = None
+        self.num_retries = 0
+
+    @overload
     def predict(
         self,
         query: str,
+        indices: Union[List[str], List[float], List[int], None] = None,
         *,
+        explain: Literal[False] = False,
         anchor_time: Union[pd.Timestamp, Literal['entity'], None] = None,
+        context_anchor_time: Union[pd.Timestamp, None] = None,
         run_mode: Union[RunMode, str] = RunMode.FAST,
         num_neighbors: Optional[List[int]] = None,
         num_hops: int = 2,
         max_pq_iterations: int = 20,
         random_seed: Optional[int] = _RANDOM_SEED,
         verbose: Union[bool, ProgressLogger] = True,
+        use_prediction_time: bool = False,
     ) -> pd.DataFrame:
+        pass
+
+    @overload
+    def predict(
+        self,
+        query: str,
+        indices: Union[List[str], List[float], List[int], None] = None,
+        *,
+        explain: Union[Literal[True], ExplainConfig, Dict[str, Any]],
+        anchor_time: Union[pd.Timestamp, Literal['entity'], None] = None,
+        context_anchor_time: Union[pd.Timestamp, None] = None,
+        run_mode: Union[RunMode, str] = RunMode.FAST,
+        num_neighbors: Optional[List[int]] = None,
+        num_hops: int = 2,
+        max_pq_iterations: int = 20,
+        random_seed: Optional[int] = _RANDOM_SEED,
+        verbose: Union[bool, ProgressLogger] = True,
+        use_prediction_time: bool = False,
+    ) -> Explanation:
+        pass
+
+    def predict(
+        self,
+        query: str,
+        indices: Union[List[str], List[float], List[int], None] = None,
+        *,
+        explain: Union[bool, ExplainConfig, Dict[str, Any]] = False,
+        anchor_time: Union[pd.Timestamp, Literal['entity'], None] = None,
+        context_anchor_time: Union[pd.Timestamp, None] = None,
+        run_mode: Union[RunMode, str] = RunMode.FAST,
+        num_neighbors: Optional[List[int]] = None,
+        num_hops: int = 2,
+        max_pq_iterations: int = 20,
+        random_seed: Optional[int] = _RANDOM_SEED,
+        verbose: Union[bool, ProgressLogger] = True,
+        use_prediction_time: bool = False,
+    ) -> Union[pd.DataFrame, Explanation]:
         """Returns predictions for a predictive query.
 
         Args:
             query: The predictive query.
-            anchor_time: The anchor timestamp for the query. If set to
-                :obj:`None`, will use the maximum timestamp in the data.
-                If set to :`"entity"`, will use the timestamp of the entity.
+            indices: The entity primary keys to predict for. Will override the
+                indices given as part of the predictive query. Predictions will
+                be generated for all indices, independent of whether they
+                fulfill entity filter constraints. To pre-filter entities, use
+                :meth:`~KumoRFM.is_valid_entity`.
+            explain: Configuration for explainability.
+                If set to ``True``, will additionally explain the prediction.
+                Passing in an :class:`ExplainConfig` instance provides control
+                over which parts of explanation are generated.
+                Explainability is currently only supported for single entity
+                predictions with ``run_mode="FAST"``.
+            anchor_time: The anchor timestamp for the prediction. If set to
+                ``None``, will use the maximum timestamp in the data.
+                If set to ``"entity"``, will use the timestamp of the entity.
+            context_anchor_time: The maximum anchor timestamp for context
+                examples. If set to ``None``, ``anchor_time`` will
+                determine the anchor time for context examples.
             run_mode: The :class:`RunMode` for the query.
             num_neighbors: The number of neighbors to sample for each hop.
                 If specified, the ``num_hops`` option will be ignored.
@@ -141,32 +299,54 @@ class KumoRFM:
                 entities to find valid labels.
             random_seed: A manual seed for generating pseudo-random numbers.
             verbose: Whether to print verbose output.
+            use_prediction_time: Whether to use the anchor timestamp as an
+                additional feature during prediction. This is typically
+                beneficial for time series forecasting tasks.
 
         Returns:
-            The predictions as a :class:`pandas.DataFrame`
+            The predictions as a :class:`pandas.DataFrame`.
+            If ``explain`` is provided, returns an :class:`Explanation` object
+            containing the prediction, summary, and details.
         """
-        explain = False
+        explain_config: Optional[ExplainConfig] = None
+        if explain is True:
+            explain_config = ExplainConfig()
+        elif explain is not False:
+            explain_config = ExplainConfig._cast(explain)
+
         query_def = self._parse_query(query)
+        query_str = query_def.to_string()
 
         if num_hops != 2 and num_neighbors is not None:
            warnings.warn(f"Received custom 'num_neighbors' option; ignoring "
                          f"custom 'num_hops={num_hops}' option")
 
-        if explain and run_mode in {RunMode.NORMAL, RunMode.BEST}:
+        if explain_config is not None and run_mode in {
+                RunMode.NORMAL, RunMode.BEST
+        }:
             warnings.warn(f"Explainability is currently only supported for "
                           f"run mode 'FAST' (got '{run_mode}'). Provided run "
                           f"mode has been reset. Please lower the run mode to "
                           f"suppress this warning.")
 
-        if explain:
-            assert query_def.entity.ids is not None
-            if len(query_def.entity.ids.value) > 1:
-                raise ValueError(
-                    f"Cannot explain predictions for more than a single "
-                    f"entity (got {len(query_def.entity.ids.value)})")
+        if indices is None:
+            if query_def.rfm_entity_ids is None:
+                raise ValueError("Cannot find entities to predict for. Please "
+                                 "pass them via `predict(query, indices=...)`")
+            indices = query_def.get_rfm_entity_id_list()
+        else:
+            query_def = replace(query_def, rfm_entity_ids=None)
+
+        if len(indices) == 0:
+            raise ValueError("At least one entity is required")
+
+        if explain_config is not None and len(indices) > 1:
+            raise ValueError(
+                f"Cannot explain predictions for more than a single entity "
+                f"(got {len(indices)})")
 
         query_repr = query_def.to_string(rich=True, exclude_predict=True)
-        if explain:
+        if explain_config is not None:
             msg = f'[bold]EXPLAIN[/bold] {query_repr}'
         else:
             msg = f'[bold]PREDICT[/bold] {query_repr}'
@@ -175,48 +355,188 @@ class KumoRFM:
             verbose = InteractiveProgressLogger(msg, verbose=verbose)
 
         with verbose as logger:
-            context = self._get_context(
-                query_def,
-                anchor_time=anchor_time,
-                run_mode=RunMode(run_mode),
-                num_neighbors=num_neighbors,
-                num_hops=num_hops,
-                max_pq_iterations=max_pq_iterations,
-                evaluate=False,
-                random_seed=random_seed,
-                logger=logger,
-            )
-            request = RFMPredictRequest(
-                context=context,
-                run_mode=RunMode(run_mode),
+
+            batch_size: Optional[int] = None
+            if self._batch_size == 'max':
+                task_type = LocalPQueryDriver.get_task_type(
+                    query_def,
+                    edge_types=self._graph_store.edge_types,
+                )
+                batch_size = _MAX_PRED_SIZE[task_type]
+            else:
+                batch_size = self._batch_size
+
+            if batch_size is not None:
+                offsets = range(0, len(indices), batch_size)
+                batches = [indices[step:step + batch_size] for step in offsets]
+            else:
+                batches = [indices]
+
+            if len(batches) > 1:
+                logger.log(f"Splitting {len(indices):,} entities into "
+                           f"{len(batches):,} batches of size {batch_size:,}")
+
+            predictions: List[pd.DataFrame] = []
+            summary: Optional[str] = None
+            details: Optional[Explanation] = None
+            for i, batch in enumerate(batches):
+                # TODO Re-use the context for subsequent predictions.
+                context = self._get_context(
+                    query=query_def,
+                    indices=batch,
+                    anchor_time=anchor_time,
+                    context_anchor_time=context_anchor_time,
+                    run_mode=RunMode(run_mode),
+                    num_neighbors=num_neighbors,
+                    num_hops=num_hops,
+                    max_pq_iterations=max_pq_iterations,
+                    evaluate=False,
+                    random_seed=random_seed,
+                    logger=logger if i == 0 else None,
+                )
+                request = RFMPredictRequest(
+                    context=context,
+                    run_mode=RunMode(run_mode),
+                    query=query_str,
+                    use_prediction_time=use_prediction_time,
+                )
+                with warnings.catch_warnings():
+                    warnings.filterwarnings('ignore', message='gencode')
+                    request_msg = request.to_protobuf()
+                _bytes = request_msg.SerializeToString()
+                if i == 0:
+                    logger.log(f"Generated context of size "
+                               f"{len(_bytes) / (1024*1024):.2f}MB")
+
+                if len(_bytes) > _MAX_SIZE:
+                    stats = Context.get_memory_stats(request_msg.context)
+                    raise ValueError(_SIZE_LIMIT_MSG.format(stats=stats))
+
+                if (isinstance(verbose, InteractiveProgressLogger) and i == 0
+                        and len(batches) > 1):
+                    verbose.init_progress(
+                        total=len(batches),
+                        description='Predicting',
+                    )
+
+                for attempt in range(self.num_retries + 1):
+                    try:
+                        if explain_config is not None:
+                            resp = self._api_client.explain(
+                                request=_bytes,
+                                skip_summary=explain_config.skip_summary,
+                            )
+                            summary = resp.summary
+                            details = resp.details
+                        else:
+                            resp = self._api_client.predict(_bytes)
+                        df = pd.DataFrame(**resp.prediction)
+
+                        # Cast 'ENTITY' to correct data type:
+                        if 'ENTITY' in df:
+                            entity = query_def.entity_table
+                            pkey_map = self._graph_store.pkey_map_dict[entity]
+                            df['ENTITY'] = df['ENTITY'].astype(
+                                type(pkey_map.index[0]))
+
+                        # Cast 'ANCHOR_TIMESTAMP' to correct data type:
+                        if 'ANCHOR_TIMESTAMP' in df:
+                            ser = df['ANCHOR_TIMESTAMP']
+                            if not pd.api.types.is_datetime64_any_dtype(ser):
+                                if isinstance(ser.iloc[0], str):
+                                    unit = None
+                                else:
+                                    unit = 'ms'
+                                df['ANCHOR_TIMESTAMP'] = pd.to_datetime(
+                                    ser, errors='coerce', unit=unit)
+
+                        predictions.append(df)
+
+                        if (isinstance(verbose, InteractiveProgressLogger)
+                                and len(batches) > 1):
+                            verbose.step()
+
+                        break
+                    except HTTPException as e:
+                        if attempt == self.num_retries:
+                            try:
+                                msg = json.loads(e.detail)['detail']
+                            except Exception:
+                                msg = e.detail
+                            raise RuntimeError(
+                                f"An unexpected exception occurred. Please "
+                                f"create an issue at "
+                                f"'https://github.com/kumo-ai/kumo-rfm'. {msg}"
+                            ) from None
+
+                        time.sleep(2**attempt)  # 1s, 2s, 4s, 8s, ...
+
+        if len(predictions) == 1:
+            prediction = predictions[0]
+        else:
+            prediction = pd.concat(predictions, ignore_index=True)
+
+        if explain_config is not None:
+            assert len(predictions) == 1
+            assert summary is not None
+            assert details is not None
+            return Explanation(
+                prediction=prediction,
+                summary=summary,
+                details=details,
             )
-            with warnings.catch_warnings():
-                warnings.filterwarnings('ignore', message='Protobuf gencode')
-                request_msg = request.to_protobuf()
-            request_bytes = request_msg.SerializeToString()
-            logger.log(f"Generated context of size "
-                       f"{len(request_bytes) / (1024*1024):.2f}MB")
 
-            if len(request_bytes) > _MAX_SIZE:
-                stats_msg = Context.get_memory_stats(request_msg.context)
-                raise ValueError(_SIZE_LIMIT_MSG.format(stats_msg=stats_msg))
+        return prediction
 
-            try:
-                if explain:
-                    resp = global_state.client.rfm_api.explain(request_bytes)
-                else:
-                    resp = global_state.client.rfm_api.predict(request_bytes)
-            except HTTPException as e:
-                try:
-                    msg = json.loads(e.detail)['detail']
-                except Exception:
-                    msg = e.detail
-                raise RuntimeError(f"An unexpected exception occurred. "
-                                   f"Please create an issue at "
-                                   f"'https://github.com/kumo-ai/kumo-rfm'. "
-                                   f"{msg}") from None
+    def is_valid_entity(
+        self,
+        query: str,
+        indices: Union[List[str], List[float], List[int], None] = None,
+        *,
+        anchor_time: Union[pd.Timestamp, Literal['entity'], None] = None,
+    ) -> np.ndarray:
+        r"""Returns a mask that denotes which entities are valid for the
+        given predictive query, *i.e.*, which entities fulfill (temporal)
+        entity filter constraints.
 
-        return pd.DataFrame(**resp.prediction)
+        Args:
+            query: The predictive query.
+            indices: The entity primary keys to predict for. Will override the
+                indices given as part of the predictive query.
+            anchor_time: The anchor timestamp for the prediction. If set to
+                ``None``, will use the maximum timestamp in the data.
+                If set to ``"entity"``, will use the timestamp of the entity.
+        """
+        query_def = self._parse_query(query)
+
+        if indices is None:
+            if query_def.rfm_entity_ids is None:
+                raise ValueError("Cannot find entities to predict for. Please "
+                                 "pass them via "
+                                 "`is_valid_entity(query, indices=...)`")
+            indices = query_def.get_rfm_entity_id_list()
+
+        if len(indices) == 0:
+            raise ValueError("At least one entity is required")
+
+        if anchor_time is None:
+            anchor_time = self._graph_store.max_time
+
+        if isinstance(anchor_time, pd.Timestamp):
+            self._validate_time(query_def, anchor_time, None, False)
+        else:
+            assert anchor_time == 'entity'
+            if (query_def.entity_table not in self._graph_store.time_dict):
+                raise ValueError(f"Anchor time 'entity' requires the entity "
+                                 f"table '{query_def.entity_table}' "
+                                 f"to have a time column.")
+
+        node = self._graph_store.get_node_id(
+            table_name=query_def.entity_table,
+            pkey=pd.Series(indices),
+        )
+        query_driver = LocalPQueryDriver(self._graph_store, query_def)
+        return query_driver.is_valid(node, anchor_time)
 
     def evaluate(
         self,
@@ -224,21 +544,26 @@ class KumoRFM:
         *,
         metrics: Optional[List[str]] = None,
         anchor_time: Union[pd.Timestamp, Literal['entity'], None] = None,
+        context_anchor_time: Union[pd.Timestamp, None] = None,
         run_mode: Union[RunMode, str] = RunMode.FAST,
         num_neighbors: Optional[List[int]] = None,
         num_hops: int = 2,
         max_pq_iterations: int = 20,
         random_seed: Optional[int] = _RANDOM_SEED,
         verbose: Union[bool, ProgressLogger] = True,
+        use_prediction_time: bool = False,
    ) -> pd.DataFrame:
         """Evaluates a predictive query.
 
         Args:
             query: The predictive query.
             metrics: The metrics to use.
-            anchor_time: The anchor timestamp for the query. If set to
-                :obj:`None`, will use the maximum timestamp in the data.
-                If set to :`"entity"`, will use the timestamp of the entity.
+            anchor_time: The anchor timestamp for the prediction. If set to
+                ``None``, will use the maximum timestamp in the data.
+                If set to ``"entity"``, will use the timestamp of the entity.
+            context_anchor_time: The maximum anchor timestamp for context
+                examples. If set to ``None``, ``anchor_time`` will
+                determine the anchor time for context examples.
             run_mode: The :class:`RunMode` for the query.
             num_neighbors: The number of neighbors to sample for each hop.
                 If specified, the ``num_hops`` option will be ignored.
@@ -250,6 +575,9 @@ class KumoRFM:
                 entities to find valid labels.
             random_seed: A manual seed for generating pseudo-random numbers.
             verbose: Whether to print verbose output.
+            use_prediction_time: Whether to use the anchor timestamp as an
+                additional feature during prediction. This is typically
+                beneficial for time series forecasting tasks.
 
         Returns:
             The metrics as a :class:`pandas.DataFrame`
@@ -260,6 +588,12 @@ class KumoRFM:
             warnings.warn(f"Received custom 'num_neighbors' option; ignoring "
                           f"custom 'num_hops={num_hops}' option")
 
+        if query_def.rfm_entity_ids is not None:
+            query_def = replace(
+                query_def,
+                rfm_entity_ids=None,
+            )
+
         query_repr = query_def.to_string(rich=True, exclude_predict=True)
         msg = f'[bold]EVALUATE[/bold] {query_repr}'
 
@@ -268,8 +602,10 @@ class KumoRFM:
 
         with verbose as logger:
             context = self._get_context(
-                query_def,
+                query=query_def,
+                indices=None,
                 anchor_time=anchor_time,
+                context_anchor_time=context_anchor_time,
                 run_mode=RunMode(run_mode),
                 num_neighbors=num_neighbors,
                 num_hops=num_hops,
@@ -285,6 +621,7 @@ class KumoRFM:
                 context=context,
                 run_mode=RunMode(run_mode),
                 metrics=metrics,
+                use_prediction_time=use_prediction_time,
             )
             with warnings.catch_warnings():
                 warnings.filterwarnings('ignore', message='Protobuf gencode')
@@ -295,10 +632,10 @@ class KumoRFM:
 
             if len(request_bytes) > _MAX_SIZE:
                 stats_msg = Context.get_memory_stats(request_msg.context)
-                raise ValueError(_SIZE_LIMIT_MSG.format(stats_msg=stats_msg))
+                raise ValueError(_SIZE_LIMIT_MSG.format(stats=stats_msg))
 
             try:
-                resp = global_state.client.rfm_api.evaluate(request_bytes)
+                resp = self._api_client.evaluate(request_bytes)
             except HTTPException as e:
                 try:
                     msg = json.loads(e.detail)['detail']
@@ -343,17 +680,19 @@ class KumoRFM:
 
         if anchor_time is None:
             anchor_time = self._graph_store.max_time
-            anchor_time = anchor_time - query_def.target.end_offset
+            if query_def.target_ast.date_offset_range is not None:
+                anchor_time = anchor_time - (
+                    query_def.target_ast.date_offset_range.end_date_offset *
+                    query_def.num_forecasts)
 
         assert anchor_time is not None
         if isinstance(anchor_time, pd.Timestamp):
-            self._validate_time(query_def, anchor_time, evaluate=True)
+            self._validate_time(query_def, anchor_time, None, evaluate=True)
         else:
             assert anchor_time == 'entity'
-            if (query_def.entity.pkey.table_name
-                    not in self._graph_store.time_dict):
+            if (query_def.entity_table not in self._graph_store.time_dict):
                 raise ValueError(f"Anchor time 'entity' requires the entity "
-                                 f"table '{query_def.entity.pkey.table_name}' "
+                                 f"table '{query_def.entity_table}' "
                                  f"to have a time column")
 
         query_driver = LocalPQueryDriver(self._graph_store, query_def,
@@ -364,18 +703,22 @@ class KumoRFM:
             anchor_time=anchor_time,
             batch_size=min(10_000, size),
             max_iterations=max_iterations,
+            guarantee_train_examples=False,
         )
 
+        entity = self._graph_store.pkey_map_dict[
+            query_def.entity_table].index[node]
+
         return pd.DataFrame({
-            'ENTITY': node,
+            'ENTITY': entity,
             'ANCHOR_TIMESTAMP': time,
             'TARGET': y,
         })
 
     # Helpers #################################################################
 
-    def _parse_query(self, query: str) -> PQueryDefinition:
-        if isinstance(query, PQueryDefinition):
+    def _parse_query(self, query: str) -> ValidatedPredictiveQuery:
+        if isinstance(query, ValidatedPredictiveQuery):
            return query
 
        if isinstance(query, str) and query.strip()[:9].lower() == 'evaluate ':
@@ -385,12 +728,13 @@ class KumoRFM:
                              "predictions or evaluations.")
 
         try:
-            request = RFMValidateQueryRequest(
+            request = RFMParseQueryRequest(
                 query=query,
                 graph_definition=self._graph_def,
             )
 
-            resp = global_state.client.rfm_api.validate_query(request)
+            resp = self._api_client.parse_query(request)
+
             # TODO Expose validation warnings.
 
             if len(resp.validation_response.warnings) > 0:
@@ -401,7 +745,7 @@ class KumoRFM:
                 warnings.warn(f"Encountered the following warnings during "
                               f"parsing:\n{msg}")
 
-            return resp.query_definition
+            return resp.query
         except HTTPException as e:
             try:
                 msg = json.loads(e.detail)['detail']
@@ -412,8 +756,9 @@ class KumoRFM:
 
     def _validate_time(
         self,
-        query: PQueryDefinition,
+        query: ValidatedPredictiveQuery,
         anchor_time: pd.Timestamp,
+        context_anchor_time: Union[pd.Timestamp, None],
         evaluate: bool,
     ) -> None:
 
@@ -425,20 +770,44 @@ class KumoRFM:
                              f"the earliest timestamp "
                              f"'{self._graph_store.min_time}' in the data.")
 
-        req_time = anchor_time - query.target.end_offset * query.num_forecasts
-        if req_time < self._graph_store.min_time:
-            raise ValueError(f"Anchor timestamp is too early or aggregation "
-                             f"time range is too large. To make this "
-                             f"prediction, we would need data back to "
-                             f"'{req_time}', however, your data only contains "
-                             f"data back to '{self._graph_store.min_time}'.")
-
-        req_time -= query.target.end_offset * query.num_forecasts
-        if req_time < self._graph_store.min_time:
-            warnings.warn(f"Anchor timestamp is too early or aggregation "
-                          f"time range is too large. To form proper input "
-                          f"data, we would need data back to "
-                          f"'{req_time}', however, your data only contains "
+        if (context_anchor_time is not None
+                and context_anchor_time < self._graph_store.min_time):
+            raise ValueError(f"Context anchor timestamp is too early or "
+                             f"aggregation time range is too large. To make "
+                             f"this prediction, we would need data back to "
+                             f"'{context_anchor_time}', however, your data "
+                             f"only contains data back to "
+                             f"'{self._graph_store.min_time}'.")
+
+        if query.target_ast.date_offset_range is not None:
+            end_offset = query.target_ast.date_offset_range.end_date_offset
+        else:
+            end_offset = pd.DateOffset(0)
+        forecast_end_offset = end_offset * query.num_forecasts
+        if (context_anchor_time is not None
+                and context_anchor_time > anchor_time):
+            warnings.warn(f"Context anchor timestamp "
+                          f"(got '{context_anchor_time}') is set to a later "
+                          f"date than the prediction anchor timestamp "
+                          f"(got '{anchor_time}'). Please make sure this is "
+                          f"intended.")
+        elif (query.query_type == QueryType.TEMPORAL
+                and context_anchor_time is not None
+                and context_anchor_time + forecast_end_offset > anchor_time):
+            warnings.warn(f"Aggregation for context examples at timestamp "
+                          f"'{context_anchor_time}' will leak information "
+                          f"from the prediction anchor timestamp "
+                          f"'{anchor_time}'. Please make sure this is "
+                          f"intended.")
+
+        elif (context_anchor_time is not None
+                and context_anchor_time - forecast_end_offset
+                < self._graph_store.min_time):
+            _time = context_anchor_time - forecast_end_offset
+            warnings.warn(f"Context anchor timestamp is too early or "
+                          f"aggregation time range is too large. To form "
+                          f"proper input data, we would need data back to "
+                          f"'{_time}', however, your data only contains "
                           f"data back to '{self._graph_store.min_time}'.")
 
         if (not evaluate and anchor_time
@@ -447,8 +816,7 @@ class KumoRFM:
                           f"latest timestamp '{self._graph_store.max_time}' "
                           f"in the data. Please make sure this is intended.")
 
-        max_eval_time = (self._graph_store.max_time -
-                         query.target.end_offset * query.num_forecasts)
+        max_eval_time = self._graph_store.max_time - forecast_end_offset
         if evaluate and anchor_time > max_eval_time:
             raise ValueError(
                 f"Anchor timestamp for evaluation is after the latest "
@@ -456,8 +824,10 @@ class KumoRFM:
 
     def _get_context(
         self,
-        query: PQueryDefinition,
+        query: ValidatedPredictiveQuery,
+        indices: Union[List[str], List[float], List[int], None],
         anchor_time: Union[pd.Timestamp, Literal['entity'], None],
+        context_anchor_time: Union[pd.Timestamp, None],
         run_mode: RunMode,
         num_neighbors: Optional[List[int]],
         num_hops: int,
@@ -482,8 +852,8 @@ class KumoRFM:
                              f"must go beyond this for your use-case.")
 
         query_driver = LocalPQueryDriver(self._graph_store, query, random_seed)
-        task_type = query.get_task_type(
-            stypes=self._graph_store.stype_dict,
+        task_type = LocalPQueryDriver.get_task_type(
+            query,
             edge_types=self._graph_store.edge_types,
         )
 
@@ -515,28 +885,42 @@ class KumoRFM:
         else:
             num_neighbors = [64, 64, 8, 8, 4, 4][:num_hops]
 
+        if query.target_ast.date_offset_range is None:
+            end_offset = pd.DateOffset(0)
+        else:
+            end_offset = query.target_ast.date_offset_range.end_date_offset
+        forecast_end_offset = end_offset * query.num_forecasts
         if anchor_time is None:
             anchor_time = self._graph_store.max_time
             if evaluate:
-                anchor_time = anchor_time - query.target.end_offset
+                anchor_time = anchor_time - forecast_end_offset
             if logger is not None:
                 assert isinstance(anchor_time, pd.Timestamp)
-                if (anchor_time.hour == 0 and anchor_time.minute == 0
-                        and anchor_time.second == 0
-                        and anchor_time.microsecond == 0):
+                if anchor_time == pd.Timestamp.min:
+                    pass  # Static graph
+                elif (anchor_time.hour == 0 and anchor_time.minute == 0
+                        and anchor_time.second == 0
+                        and anchor_time.microsecond == 0):
                     logger.log(f"Derived anchor time {anchor_time.date()}")
                 else:
                     logger.log(f"Derived anchor time {anchor_time}")
 
         assert anchor_time is not None
         if isinstance(anchor_time, pd.Timestamp):
-            self._validate_time(query, anchor_time, evaluate)
+            if context_anchor_time is None:
+                context_anchor_time = anchor_time - forecast_end_offset
+            self._validate_time(query, anchor_time, context_anchor_time,
+                                evaluate)
         else:
             assert anchor_time == 'entity'
-            if query.entity.pkey.table_name not in self._graph_store.time_dict:
+            if query.entity_table not in self._graph_store.time_dict:
                 raise ValueError(f"Anchor time 'entity' requires the entity "
-                                 f"table '{query.entity.pkey.table_name}' to "
+                                 f"table '{query.entity_table}' to "
                                  f"have a time column")
+            if context_anchor_time is not None:
+                warnings.warn("Ignoring option 'context_anchor_time' for "
+                              "`anchor_time='entity'`")
+                context_anchor_time = None
 
         y_test: Optional[pd.Series] = None
         if evaluate:
@@ -548,6 +932,7 @@ class KumoRFM:
                 size=max_test_size,
                 anchor_time=anchor_time,
                 max_iterations=max_pq_iterations,
+                guarantee_train_examples=True,
             )
             if logger is not None:
                 if task_type == TaskType.BINARY_CLASSIFICATION:
@@ -571,34 +956,31 @@ class KumoRFM:
                 logger.log(msg)
 
         else:
-            assert query.entity.ids is not None
+            assert indices is not None
 
-            max_num_test = 200 if task_type.is_link_pred else 1000
-            if len(query.entity.ids.value) > max_num_test:
+            if len(indices) > _MAX_PRED_SIZE[task_type]:
                 raise ValueError(f"Cannot predict for more than "
-                                 f"{max_num_test:,} entities at once "
-                                 f"(got {len(query.entity.ids.value):,})")
+                                 f"{_MAX_PRED_SIZE[task_type]:,} entities at "
+                                 f"once (got {len(indices):,}). Use "
+                                 f"`KumoRFM.batch_mode` to process entities "
+                                 f"in batches")
 
             test_node = self._graph_store.get_node_id(
-                table_name=query.entity.pkey.table_name,
-                pkey=pd.Series(
-                    query.entity.ids.value,
-                    dtype=query.entity.ids.dtype,
-                ),
+                table_name=query.entity_table,
+                pkey=pd.Series(indices),
             )
 
             if isinstance(anchor_time, pd.Timestamp):
                 test_time = pd.Series(anchor_time).repeat(
                     len(test_node)).reset_index(drop=True)
             else:
-                time = self._graph_store.time_dict[
-                    query.entity.pkey.table_name]
+                time = self._graph_store.time_dict[query.entity_table]
                 time = time[test_node] * 1000**3
                 test_time = pd.Series(time, dtype='datetime64[ns]')
 
         train_node, train_time, y_train = query_driver.collect_train(
             size=_MAX_CONTEXT_SIZE[run_mode],
-            anchor_time=anchor_time,
+            anchor_time=context_anchor_time or 'entity',
             exclude_node=test_node if (query.query_type == QueryType.STATIC
                                        or anchor_time == 'entity') else None,
             max_iterations=max_pq_iterations,
@@ -625,12 +1007,23 @@ class KumoRFM:
                 raise NotImplementedError
             logger.log(msg)
 
-        entity_table_names = query.get_entity_table_names(
-            self._graph_store.edge_types)
+        entity_table_names: Tuple[str, ...]
+        if task_type.is_link_pred:
+            final_aggr = query.get_final_target_aggregation()
+            assert final_aggr is not None
+            edge_fkey = final_aggr._get_target_column_name()
+            for edge_type in self._graph_store.edge_types:
+                if edge_fkey == f'{edge_type[0]}.{edge_type[1]}':
+                    entity_table_names = (
+                        query.entity_table,
+                        edge_type[2],
+                    )
+        else:
+            entity_table_names = (query.entity_table, )
 
         # Exclude the entity anchor time from the feature set to prevent
         # running out-of-distribution between in-context and test examples:
-        exclude_cols_dict = query.exclude_cols_dict
+        exclude_cols_dict = query.get_exclude_cols_dict()
         if anchor_time == 'entity':
             if entity_table_names[0] not in exclude_cols_dict:
                 exclude_cols_dict[entity_table_names[0]] = []
@@ -645,14 +1038,20 @@ class KumoRFM:
                 train_time.astype('datetime64[ns]').astype(int).to_numpy(),
                 test_time.astype('datetime64[ns]').astype(int).to_numpy(),
             ]),
-            run_mode=run_mode,
             num_neighbors=num_neighbors,
             exclude_cols_dict=exclude_cols_dict,
         )
 
+        if len(subgraph.table_dict) >= 15:
+            raise ValueError(f"Cannot query from a graph with more than 15 "
+                             f"tables (got {len(subgraph.table_dict)}). "
+                             f"Please create a feature request at "
+                             f"'https://github.com/kumo-ai/kumo-rfm' if you "
+                             f"must go beyond this for your use-case.")
+
         step_size: Optional[int] = None
         if query.query_type == QueryType.TEMPORAL:
-            step_size = date_offset_to_seconds(query.target.end_offset)
+            step_size = date_offset_to_seconds(end_offset)
 
         return Context(
             task_type=task_type,
@@ -677,7 +1076,7 @@ class KumoRFM:
         elif task_type == TaskType.MULTICLASS_CLASSIFICATION:
             supported_metrics = ['acc', 'precision', 'recall', 'f1', 'mrr']
         elif task_type == TaskType.REGRESSION:
-            supported_metrics = ['mae', 'mape', 'mse', 'rmse', 'smape']
+            supported_metrics = ['mae', 'mape', 'mse', 'rmse', 'smape', 'r2']
         elif task_type == TaskType.TEMPORAL_LINK_PREDICTION:
             supported_metrics = [
                 'map@', 'ndcg@', 'mrr@', 'precision@', 'recall@', 'f1@',
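
Usage note: taken together, the rfm.py changes above reshape the public calling pattern. predict() replaces query(), accepts an indices= override and an explain= flag (returning an Explanation that unpacks into (prediction, summary)), is_valid_entity() pre-filters entities against (temporal) entity filters, and batch_mode() splits large index lists into request-sized batches capped by _MAX_PRED_SIZE (1,000 entities per request by default, 200 for temporal link prediction). A minimal sketch of that surface, based only on the signatures shown in this diff; the authenticated session and the DataFrames df_users and df_orders are assumed, not part of this release:

    from kumoai.experimental.rfm import KumoRFM, LocalGraph

    # Hypothetical tables: df_users/df_orders are assumed pandas DataFrames.
    graph = LocalGraph.from_data({'users': df_users, 'orders': df_orders})
    model = KumoRFM(graph)

    query = "PREDICT COUNT(orders.*, 0, 30, days)>0 FOR users.user_id=1"

    # Single-entity prediction with an explanation (run_mode='FAST' only);
    # the returned Explanation unpacks into (prediction, summary).
    prediction, summary = model.predict(query, explain=True)

    # Pre-filter a list of entity primary keys, then predict in batches:
    indices = [1, 2, 3]
    mask = model.is_valid_entity(query, indices=indices)  # np.ndarray mask
    valid = [i for i, ok in zip(indices, mask) if ok]
    with model.batch_mode(batch_size='max', num_retries=1):
        df = model.predict(query, indices=valid)

Outside batch_mode(), predict() raises once an index list exceeds _MAX_PRED_SIZE for the task, so index lists above 1,000 entities (200 for temporal link prediction) require the context manager, which also retries failed batches with exponential backoff.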