kumoai 2.7.0.dev202508201830__cp312-cp312-win_amd64.whl → 2.12.0.dev202511111731__cp312-cp312-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (35)
  1. kumoai/__init__.py +4 -2
  2. kumoai/_version.py +1 -1
  3. kumoai/client/client.py +10 -5
  4. kumoai/client/endpoints.py +1 -0
  5. kumoai/client/rfm.py +37 -8
  6. kumoai/connector/file_upload_connector.py +94 -85
  7. kumoai/connector/snowflake_connector.py +9 -0
  8. kumoai/connector/utils.py +1377 -209
  9. kumoai/experimental/rfm/__init__.py +5 -3
  10. kumoai/experimental/rfm/authenticate.py +8 -5
  11. kumoai/experimental/rfm/infer/timestamp.py +7 -4
  12. kumoai/experimental/rfm/local_graph.py +96 -82
  13. kumoai/experimental/rfm/local_graph_sampler.py +16 -8
  14. kumoai/experimental/rfm/local_graph_store.py +32 -10
  15. kumoai/experimental/rfm/local_pquery_driver.py +342 -46
  16. kumoai/experimental/rfm/local_table.py +142 -45
  17. kumoai/experimental/rfm/pquery/__init__.py +4 -4
  18. kumoai/experimental/rfm/pquery/{backend.py → executor.py} +28 -58
  19. kumoai/experimental/rfm/pquery/pandas_executor.py +532 -0
  20. kumoai/experimental/rfm/rfm.py +535 -125
  21. kumoai/experimental/rfm/utils.py +0 -3
  22. kumoai/jobs.py +27 -1
  23. kumoai/kumolib.cp312-win_amd64.pyd +0 -0
  24. kumoai/pquery/prediction_table.py +5 -3
  25. kumoai/pquery/training_table.py +5 -3
  26. kumoai/trainer/job.py +9 -30
  27. kumoai/trainer/trainer.py +19 -10
  28. kumoai/utils/__init__.py +2 -1
  29. kumoai/utils/progress_logger.py +96 -16
  30. {kumoai-2.7.0.dev202508201830.dist-info → kumoai-2.12.0.dev202511111731.dist-info}/METADATA +4 -5
  31. {kumoai-2.7.0.dev202508201830.dist-info → kumoai-2.12.0.dev202511111731.dist-info}/RECORD +34 -34
  32. kumoai/experimental/rfm/pquery/pandas_backend.py +0 -437
  33. {kumoai-2.7.0.dev202508201830.dist-info → kumoai-2.12.0.dev202511111731.dist-info}/WHEEL +0 -0
  34. {kumoai-2.7.0.dev202508201830.dist-info → kumoai-2.12.0.dev202511111731.dist-info}/licenses/LICENSE +0 -0
  35. {kumoai-2.7.0.dev202508201830.dist-info → kumoai-2.12.0.dev202511111731.dist-info}/top_level.txt +0 -0
@@ -1,17 +1,32 @@
 import json
+import time
 import warnings
-from typing import List, Literal, Optional, Union
+from collections import defaultdict
+from collections.abc import Generator
+from contextlib import contextmanager
+from dataclasses import dataclass, replace
+from typing import (
+    Any,
+    Dict,
+    Iterator,
+    List,
+    Literal,
+    Optional,
+    Tuple,
+    Union,
+    overload,
+)
 
 import numpy as np
 import pandas as pd
 from kumoapi.model_plan import RunMode
-from kumoapi.pquery import QueryType
+from kumoapi.pquery import QueryType, ValidatedPredictiveQuery
+from kumoapi.rfm import Context
+from kumoapi.rfm import Explanation as ExplanationConfig
 from kumoapi.rfm import (
-    Context,
-    PQueryDefinition,
     RFMEvaluateRequest,
+    RFMParseQueryRequest,
     RFMPredictRequest,
-    RFMValidateQueryRequest,
 )
 from kumoapi.task import TaskType
 
@@ -20,11 +35,18 @@ from kumoai.exceptions import HTTPException
 from kumoai.experimental.rfm import LocalGraph
 from kumoai.experimental.rfm.local_graph_sampler import LocalGraphSampler
 from kumoai.experimental.rfm.local_graph_store import LocalGraphStore
-from kumoai.experimental.rfm.local_pquery_driver import LocalPQueryDriver
-from kumoai.utils import ProgressLogger
+from kumoai.experimental.rfm.local_pquery_driver import (
+    LocalPQueryDriver,
+    date_offset_to_seconds,
+)
+from kumoai.mixin import CastMixin
+from kumoai.utils import InteractiveProgressLogger, ProgressLogger
 
 _RANDOM_SEED = 42
 
+_MAX_PRED_SIZE: dict[TaskType, int] = defaultdict(lambda: 1_000)
+_MAX_PRED_SIZE[TaskType.TEMPORAL_LINK_PREDICTION] = 200
+
 _MAX_CONTEXT_SIZE = {
     RunMode.DEBUG: 100,
     RunMode.FAST: 1_000,
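
Note: the new `_MAX_PRED_SIZE` constant is a `defaultdict`, so any task type without an explicit entry falls back to 1,000 entities per request, while temporal link prediction is capped at 200. A minimal reading of the lookup behavior:

    _MAX_PRED_SIZE[TaskType.TEMPORAL_LINK_PREDICTION]  # 200 (explicit entry)
    _MAX_PRED_SIZE[TaskType.REGRESSION]  # 1_000 (defaultdict fallback)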
@@ -39,7 +61,7 @@ _MAX_TEST_SIZE = { # Share test set size across run modes for fair comparison:
 }
 
 _MAX_SIZE = 30 * 1024 * 1024
-_SIZE_LIMIT_MSG = ("Context size exceeds the 30MB limit. {stats_msg}\nPlease "
+_SIZE_LIMIT_MSG = ("Context size exceeds the 30MB limit. {stats}\nPlease "
                    "reduce either the number of tables in the graph, their "
                    "number of columns (e.g., large text columns), "
                    "neighborhood configuration, or the run mode. If none of "
@@ -48,6 +70,51 @@ _SIZE_LIMIT_MSG = ("Context size exceeds the 30MB limit. {stats_msg}\nPlease "
                    "beyond this for your use-case.")
 
 
+@dataclass(repr=False)
+class ExplainConfig(CastMixin):
+    """Configuration for explainability.
+
+    Args:
+        skip_summary: Whether to skip generating a human-readable summary of
+            the explanation.
+    """
+    skip_summary: bool = False
+
+
+@dataclass(repr=False)
+class Explanation:
+    prediction: pd.DataFrame
+    summary: str
+    details: ExplanationConfig
+
+    @overload
+    def __getitem__(self, index: Literal[0]) -> pd.DataFrame:
+        pass
+
+    @overload
+    def __getitem__(self, index: Literal[1]) -> str:
+        pass
+
+    def __getitem__(self, index: int) -> Union[pd.DataFrame, str]:
+        if index == 0:
+            return self.prediction
+        if index == 1:
+            return self.summary
+        raise IndexError("Index out of range")
+
+    def __iter__(self) -> Iterator[Union[pd.DataFrame, str]]:
+        return iter((self.prediction, self.summary))
+
+    def __repr__(self) -> str:
+        return str((self.prediction, self.summary))
+
+    def _ipython_display_(self) -> None:
+        from IPython.display import Markdown, display
+
+        display(self.prediction)
+        display(Markdown(self.summary))
+
+
 class KumoRFM:
     r"""The Kumo Relational Foundation model (RFM) from the `KumoRFM: A
     Foundation Model for In-Context Learning on Relational Data
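
Usage note: the new `Explanation` container unpacks like a `(prediction, summary)` tuple via `__iter__`/`__getitem__`, with the structured payload available as `.details`. A minimal sketch, assuming an already constructed `KumoRFM` instance (`model` and `query` are illustrative names, not part of this diff):

    result = model.predict(query, indices=[42], explain=True)
    prediction, summary = result  # tuple-style unpacking via __iter__
    details = result.details  # structured explanation payload

    # 'explain' also accepts an ExplainConfig, or a dict cast through it:
    result = model.predict(query, indices=[42],
                           explain=ExplainConfig(skip_summary=True))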
@@ -98,35 +165,127 @@ class KumoRFM:
         self,
         graph: LocalGraph,
         preprocess: bool = False,
-        verbose: bool = True,
+        verbose: Union[bool, ProgressLogger] = True,
     ) -> None:
         graph = graph.validate()
         self._graph_def = graph._to_api_graph_definition()
         self._graph_store = LocalGraphStore(graph, preprocess, verbose)
         self._graph_sampler = LocalGraphSampler(self._graph_store)
 
+        self._batch_size: Optional[int | Literal['max']] = None
+        self.num_retries: int = 0
+
     def __repr__(self) -> str:
         return f'{self.__class__.__name__}()'
 
+    @contextmanager
+    def batch_mode(
+        self,
+        batch_size: Union[int, Literal['max']] = 'max',
+        num_retries: int = 1,
+    ) -> Generator[None, None, None]:
+        """Context manager to predict in batches.
+
+        .. code-block:: python
+
+            with model.batch_mode(batch_size='max', num_retries=1):
+                df = model.predict(query, indices=...)
+
+        Args:
+            batch_size: The batch size. If set to ``"max"``, will use the
+                maximum applicable batch size for the given task.
+            num_retries: The maximum number of retries for failed queries due
+                to unexpected server issues.
+        """
+        if batch_size != 'max' and batch_size <= 0:
+            raise ValueError(f"'batch_size' must be greater than zero "
+                             f"(got {batch_size})")
+
+        if num_retries < 0:
+            raise ValueError(f"'num_retries' must be greater than or equal to "
+                             f"zero (got {num_retries})")
+
+        self._batch_size = batch_size
+        self.num_retries = num_retries
+        yield
+        self._batch_size = None
+        self.num_retries = 0
+
+    @overload
     def predict(
         self,
         query: str,
+        indices: Union[List[str], List[float], List[int], None] = None,
         *,
+        explain: Literal[False] = False,
         anchor_time: Union[pd.Timestamp, Literal['entity'], None] = None,
+        context_anchor_time: Union[pd.Timestamp, None] = None,
         run_mode: Union[RunMode, str] = RunMode.FAST,
         num_neighbors: Optional[List[int]] = None,
         num_hops: int = 2,
         max_pq_iterations: int = 20,
         random_seed: Optional[int] = _RANDOM_SEED,
-        verbose: bool = True,
+        verbose: Union[bool, ProgressLogger] = True,
+        use_prediction_time: bool = False,
     ) -> pd.DataFrame:
+        pass
+
+    @overload
+    def predict(
+        self,
+        query: str,
+        indices: Union[List[str], List[float], List[int], None] = None,
+        *,
+        explain: Union[Literal[True], ExplainConfig, Dict[str, Any]],
+        anchor_time: Union[pd.Timestamp, Literal['entity'], None] = None,
+        context_anchor_time: Union[pd.Timestamp, None] = None,
+        run_mode: Union[RunMode, str] = RunMode.FAST,
+        num_neighbors: Optional[List[int]] = None,
+        num_hops: int = 2,
+        max_pq_iterations: int = 20,
+        random_seed: Optional[int] = _RANDOM_SEED,
+        verbose: Union[bool, ProgressLogger] = True,
+        use_prediction_time: bool = False,
+    ) -> Explanation:
+        pass
+
+    def predict(
+        self,
+        query: str,
+        indices: Union[List[str], List[float], List[int], None] = None,
+        *,
+        explain: Union[bool, ExplainConfig, Dict[str, Any]] = False,
+        anchor_time: Union[pd.Timestamp, Literal['entity'], None] = None,
+        context_anchor_time: Union[pd.Timestamp, None] = None,
+        run_mode: Union[RunMode, str] = RunMode.FAST,
+        num_neighbors: Optional[List[int]] = None,
+        num_hops: int = 2,
+        max_pq_iterations: int = 20,
+        random_seed: Optional[int] = _RANDOM_SEED,
+        verbose: Union[bool, ProgressLogger] = True,
+        use_prediction_time: bool = False,
+    ) -> Union[pd.DataFrame, Explanation]:
         """Returns predictions for a predictive query.
 
         Args:
             query: The predictive query.
-            anchor_time: The anchor timestamp for the query. If set to
-                :obj:`None`, will use the maximum timestamp in the data.
-                If set to :`"entity"`, will use the timestamp of the entity.
+            indices: The entity primary keys to predict for. Will override the
+                indices given as part of the predictive query. Predictions will
+                be generated for all indices, independent of whether they
+                fulfill entity filter constraints. To pre-filter entities, use
+                :meth:`~KumoRFM.is_valid_entity`.
+            explain: Configuration for explainability.
+                If set to ``True``, will additionally explain the prediction.
+                Passing in an :class:`ExplainConfig` instance provides control
+                over which parts of explanation are generated.
+                Explainability is currently only supported for single entity
+                predictions with ``run_mode="FAST"``.
+            anchor_time: The anchor timestamp for the prediction. If set to
+                ``None``, will use the maximum timestamp in the data.
+                If set to ``"entity"``, will use the timestamp of the entity.
+            context_anchor_time: The maximum anchor timestamp for context
+                examples. If set to ``None``, ``anchor_time`` will
+                determine the anchor time for context examples.
             run_mode: The :class:`RunMode` for the query.
             num_neighbors: The number of neighbors to sample for each hop.
                 If specified, the ``num_hops`` option will be ignored.
@@ -138,79 +297,244 @@ class KumoRFM:
                 entities to find valid labels.
             random_seed: A manual seed for generating pseudo-random numbers.
             verbose: Whether to print verbose output.
+            use_prediction_time: Whether to use the anchor timestamp as an
+                additional feature during prediction. This is typically
+                beneficial for time series forecasting tasks.
 
         Returns:
-            The predictions as a :class:`pandas.DataFrame`
+            The predictions as a :class:`pandas.DataFrame`.
+            If ``explain`` is provided, returns an :class:`Explanation` object
+            containing the prediction, summary, and details.
         """
-        explain = False
+        explain_config: Optional[ExplainConfig] = None
+        if explain is True:
+            explain_config = ExplainConfig()
+        elif explain is not False:
+            explain_config = ExplainConfig._cast(explain)
+
         query_def = self._parse_query(query)
+        query_str = query_def.to_string()
 
         if num_hops != 2 and num_neighbors is not None:
             warnings.warn(f"Received custom 'num_neighbors' option; ignoring "
                           f"custom 'num_hops={num_hops}' option")
 
-        if explain and run_mode in {RunMode.NORMAL, RunMode.BEST}:
+        if explain_config is not None and run_mode in {
+                RunMode.NORMAL, RunMode.BEST
+        }:
             warnings.warn(f"Explainability is currently only supported for "
                           f"run mode 'FAST' (got '{run_mode}'). Provided run "
                           f"mode has been reset. Please lower the run mode to "
                           f"suppress this warning.")
 
-        if explain:
-            assert query_def.entity.ids is not None
-            if len(query_def.entity.ids.value) > 1:
-                raise ValueError(
-                    f"Cannot explain predictions for more than a single "
-                    f"entity (got {len(query_def.entity.ids.value)})")
+        if indices is None:
+            if query_def.rfm_entity_ids is None:
+                raise ValueError("Cannot find entities to predict for. Please "
+                                 "pass them via `predict(query, indices=...)`")
+            indices = query_def.get_rfm_entity_id_list()
+        else:
+            query_def = replace(query_def, rfm_entity_ids=None)
+
+        if len(indices) == 0:
+            raise ValueError("At least one entity is required")
+
+        if explain_config is not None and len(indices) > 1:
+            raise ValueError(
+                f"Cannot explain predictions for more than a single entity "
+                f"(got {len(indices)})")
 
         query_repr = query_def.to_string(rich=True, exclude_predict=True)
-        if explain:
+        if explain_config is not None:
             msg = f'[bold]EXPLAIN[/bold] {query_repr}'
         else:
             msg = f'[bold]PREDICT[/bold] {query_repr}'
 
-        with ProgressLogger(msg, verbose=verbose) as logger:
-            context = self._get_context(
-                query_def,
-                anchor_time=anchor_time,
-                run_mode=RunMode(run_mode),
-                num_neighbors=num_neighbors,
-                num_hops=num_hops,
-                max_pq_iterations=max_pq_iterations,
-                evaluate=False,
-                random_seed=random_seed,
-                logger=logger,
-            )
-            request = RFMPredictRequest(
-                context=context,
-                run_mode=RunMode(run_mode),
+        if not isinstance(verbose, ProgressLogger):
+            verbose = InteractiveProgressLogger(msg, verbose=verbose)
+
+        with verbose as logger:
+
+            batch_size: Optional[int] = None
+            if self._batch_size == 'max':
+                task_type = LocalPQueryDriver.get_task_type(
+                    query_def,
+                    edge_types=self._graph_store.edge_types,
+                )
+                batch_size = _MAX_PRED_SIZE[task_type]
+            else:
+                batch_size = self._batch_size
+
+            if batch_size is not None:
+                offsets = range(0, len(indices), batch_size)
+                batches = [indices[step:step + batch_size] for step in offsets]
+            else:
+                batches = [indices]
+
+            if len(batches) > 1:
+                logger.log(f"Splitting {len(indices):,} entities into "
+                           f"{len(batches):,} batches of size {batch_size:,}")
+
+            predictions: List[pd.DataFrame] = []
+            summary: Optional[str] = None
+            details: Optional[Explanation] = None
+            for i, batch in enumerate(batches):
+                # TODO Re-use the context for subsequent predictions.
+                context = self._get_context(
+                    query=query_def,
+                    indices=batch,
+                    anchor_time=anchor_time,
+                    context_anchor_time=context_anchor_time,
+                    run_mode=RunMode(run_mode),
+                    num_neighbors=num_neighbors,
+                    num_hops=num_hops,
+                    max_pq_iterations=max_pq_iterations,
+                    evaluate=False,
+                    random_seed=random_seed,
+                    logger=logger if i == 0 else None,
+                )
+                request = RFMPredictRequest(
+                    context=context,
+                    run_mode=RunMode(run_mode),
+                    query=query_str,
+                    use_prediction_time=use_prediction_time,
+                )
+                with warnings.catch_warnings():
+                    warnings.filterwarnings('ignore', message='gencode')
+                    request_msg = request.to_protobuf()
+                _bytes = request_msg.SerializeToString()
+                if i == 0:
+                    logger.log(f"Generated context of size "
+                               f"{len(_bytes) / (1024*1024):.2f}MB")
+
+                if len(_bytes) > _MAX_SIZE:
+                    stats = Context.get_memory_stats(request_msg.context)
+                    raise ValueError(_SIZE_LIMIT_MSG.format(stats=stats))
+
+                if (isinstance(verbose, InteractiveProgressLogger) and i == 0
+                        and len(batches) > 1):
+                    verbose.init_progress(
+                        total=len(batches),
+                        description='Predicting',
+                    )
+
+                for attempt in range(self.num_retries + 1):
+                    try:
+                        if explain_config is not None:
+                            resp = global_state.client.rfm_api.explain(
+                                request=_bytes,
+                                skip_summary=explain_config.skip_summary,
+                            )
+                            summary = resp.summary
+                            details = resp.details
+                        else:
+                            resp = global_state.client.rfm_api.predict(_bytes)
+                        df = pd.DataFrame(**resp.prediction)
+
+                        # Cast 'ENTITY' to correct data type:
+                        if 'ENTITY' in df:
+                            entity = query_def.entity_table
+                            pkey_map = self._graph_store.pkey_map_dict[entity]
+                            df['ENTITY'] = df['ENTITY'].astype(
+                                type(pkey_map.index[0]))
+
+                        # Cast 'ANCHOR_TIMESTAMP' to correct data type:
+                        if 'ANCHOR_TIMESTAMP' in df:
+                            ser = df['ANCHOR_TIMESTAMP']
+                            if not pd.api.types.is_datetime64_any_dtype(ser):
+                                if isinstance(ser.iloc[0], str):
+                                    unit = None
+                                else:
+                                    unit = 'ms'
+                                df['ANCHOR_TIMESTAMP'] = pd.to_datetime(
+                                    ser, errors='coerce', unit=unit)
+
+                        predictions.append(df)
+
+                        if (isinstance(verbose, InteractiveProgressLogger)
+                                and len(batches) > 1):
+                            verbose.step()
+
+                        break
+                    except HTTPException as e:
+                        if attempt == self.num_retries:
+                            try:
+                                msg = json.loads(e.detail)['detail']
+                            except Exception:
+                                msg = e.detail
+                            raise RuntimeError(
+                                f"An unexpected exception occurred. Please "
+                                f"create an issue at "
+                                f"'https://github.com/kumo-ai/kumo-rfm'. {msg}"
+                            ) from None
+
+                        time.sleep(2**attempt)  # 1s, 2s, 4s, 8s, ...
+
+        if len(predictions) == 1:
+            prediction = predictions[0]
+        else:
+            prediction = pd.concat(predictions, ignore_index=True)
+
+        if explain_config is not None:
+            assert len(predictions) == 1
+            assert summary is not None
+            assert details is not None
+            return Explanation(
+                prediction=prediction,
+                summary=summary,
+                details=details,
             )
-            with warnings.catch_warnings():
-                warnings.filterwarnings('ignore', message='Protobuf gencode')
-                request_msg = request.to_protobuf()
-            request_bytes = request_msg.SerializeToString()
-            logger.log(f"Generated context of size "
-                       f"{len(request_bytes) / (1024*1024):.2f}MB")
 
-            if len(request_bytes) > _MAX_SIZE:
-                stats_msg = Context.get_memory_stats(request_msg.context)
-                raise ValueError(_SIZE_LIMIT_MSG.format(stats_msg=stats_msg))
+        return prediction
 
-            try:
-                if explain:
-                    resp = global_state.client.rfm_api.explain(request_bytes)
-                else:
-                    resp = global_state.client.rfm_api.predict(request_bytes)
-            except HTTPException as e:
-                try:
-                    msg = json.loads(e.detail)['detail']
-                except Exception:
-                    msg = e.detail
-                raise RuntimeError(f"An unexpected exception occurred. "
-                                   f"Please create an issue at "
-                                   f"'https://github.com/kumo-ai/kumo-rfm'. "
-                                   f"{msg}") from None
+    def is_valid_entity(
+        self,
+        query: str,
+        indices: Union[List[str], List[float], List[int], None] = None,
+        *,
+        anchor_time: Union[pd.Timestamp, Literal['entity'], None] = None,
+    ) -> np.ndarray:
+        r"""Returns a mask that denotes which entities are valid for the
+        given predictive query, *i.e.*, which entities fulfill (temporal)
+        entity filter constraints.
 
-        return pd.DataFrame(**resp.prediction)
+        Args:
+            query: The predictive query.
+            indices: The entity primary keys to predict for. Will override the
+                indices given as part of the predictive query.
+            anchor_time: The anchor timestamp for the prediction. If set to
+                ``None``, will use the maximum timestamp in the data.
+                If set to ``"entity"``, will use the timestamp of the entity.
+        """
+        query_def = self._parse_query(query)
+
+        if indices is None:
+            if query_def.rfm_entity_ids is None:
+                raise ValueError("Cannot find entities to predict for. Please "
+                                 "pass them via "
+                                 "`is_valid_entity(query, indices=...)`")
+            indices = query_def.get_rfm_entity_id_list()
+
+        if len(indices) == 0:
+            raise ValueError("At least one entity is required")
+
+        if anchor_time is None:
+            anchor_time = self._graph_store.max_time
+
+        if isinstance(anchor_time, pd.Timestamp):
+            self._validate_time(query_def, anchor_time, None, False)
+        else:
+            assert anchor_time == 'entity'
+            if (query_def.entity_table not in self._graph_store.time_dict):
+                raise ValueError(f"Anchor time 'entity' requires the entity "
+                                 f"table '{query_def.entity_table}' "
+                                 f"to have a time column.")
+
+        node = self._graph_store.get_node_id(
+            table_name=query_def.entity_table,
+            pkey=pd.Series(indices),
+        )
+        query_driver = LocalPQueryDriver(self._graph_store, query_def)
+        return query_driver.is_valid(node, anchor_time)
 
     def evaluate(
         self,
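
Taken together, this hunk turns `predict()` into a batched, retrying client: entities are split according to `batch_mode`, each batch builds its own context, and failed requests back off exponentially (1s, 2s, 4s, ...). A usage sketch combining the new `indices`, `batch_mode`, and `is_valid_entity` APIs (the entity ids are illustrative):

    ids = list(range(10_000))

    # Optional pre-filter: keep only entities that satisfy the query's
    # (temporal) entity filter constraints:
    mask = model.is_valid_entity(query, indices=ids)
    ids = [i for i, ok in zip(ids, mask) if ok]

    # Predict in batches of the maximum applicable size, retrying each
    # failed batch up to twice:
    with model.batch_mode(batch_size='max', num_retries=2):
        df = model.predict(query, indices=ids)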
@@ -218,21 +542,26 @@ class KumoRFM:
         *,
         metrics: Optional[List[str]] = None,
         anchor_time: Union[pd.Timestamp, Literal['entity'], None] = None,
+        context_anchor_time: Union[pd.Timestamp, None] = None,
         run_mode: Union[RunMode, str] = RunMode.FAST,
         num_neighbors: Optional[List[int]] = None,
         num_hops: int = 2,
         max_pq_iterations: int = 20,
         random_seed: Optional[int] = _RANDOM_SEED,
-        verbose: bool = True,
+        verbose: Union[bool, ProgressLogger] = True,
+        use_prediction_time: bool = False,
     ) -> pd.DataFrame:
         """Evaluates a predictive query.
 
         Args:
             query: The predictive query.
             metrics: The metrics to use.
-            anchor_time: The anchor timestamp for the query. If set to
-                :obj:`None`, will use the maximum timestamp in the data.
-                If set to :`"entity"`, will use the timestamp of the entity.
+            anchor_time: The anchor timestamp for the prediction. If set to
+                ``None``, will use the maximum timestamp in the data.
+                If set to ``"entity"``, will use the timestamp of the entity.
+            context_anchor_time: The maximum anchor timestamp for context
+                examples. If set to ``None``, ``anchor_time`` will
+                determine the anchor time for context examples.
             run_mode: The :class:`RunMode` for the query.
             num_neighbors: The number of neighbors to sample for each hop.
                 If specified, the ``num_hops`` option will be ignored.
@@ -244,6 +573,9 @@ class KumoRFM:
                 entities to find valid labels.
             random_seed: A manual seed for generating pseudo-random numbers.
             verbose: Whether to print verbose output.
+            use_prediction_time: Whether to use the anchor timestamp as an
+                additional feature during prediction. This is typically
+                beneficial for time series forecasting tasks.
 
         Returns:
             The metrics as a :class:`pandas.DataFrame`
@@ -254,13 +586,24 @@ class KumoRFM:
             warnings.warn(f"Received custom 'num_neighbors' option; ignoring "
                           f"custom 'num_hops={num_hops}' option")
 
+        if query_def.rfm_entity_ids is not None:
+            query_def = replace(
+                query_def,
+                rfm_entity_ids=None,
+            )
+
         query_repr = query_def.to_string(rich=True, exclude_predict=True)
         msg = f'[bold]EVALUATE[/bold] {query_repr}'
 
-        with ProgressLogger(msg, verbose=verbose) as logger:
+        if not isinstance(verbose, ProgressLogger):
+            verbose = InteractiveProgressLogger(msg, verbose=verbose)
+
+        with verbose as logger:
             context = self._get_context(
-                query_def,
+                query=query_def,
+                indices=None,
                 anchor_time=anchor_time,
+                context_anchor_time=context_anchor_time,
                 run_mode=RunMode(run_mode),
                 num_neighbors=num_neighbors,
                 num_hops=num_hops,
@@ -276,6 +619,7 @@ class KumoRFM:
                 context=context,
                 run_mode=RunMode(run_mode),
                 metrics=metrics,
+                use_prediction_time=use_prediction_time,
             )
             with warnings.catch_warnings():
                 warnings.filterwarnings('ignore', message='Protobuf gencode')
@@ -334,17 +678,19 @@ class KumoRFM:
 
         if anchor_time is None:
             anchor_time = self._graph_store.max_time
-            anchor_time = anchor_time - query_def.target.end_offset
+            if query_def.target_ast.date_offset_range is not None:
+                anchor_time = anchor_time - (
+                    query_def.target_ast.date_offset_range.end_date_offset *
+                    query_def.num_forecasts)
 
         assert anchor_time is not None
         if isinstance(anchor_time, pd.Timestamp):
-            self._validate_time(query_def, anchor_time, evaluate=True)
+            self._validate_time(query_def, anchor_time, None, evaluate=True)
         else:
             assert anchor_time == 'entity'
-            if (query_def.entity.pkey.table_name
-                    not in self._graph_store.time_dict):
+            if (query_def.entity_table not in self._graph_store.time_dict):
                 raise ValueError(f"Anchor time 'entity' requires the entity "
-                                 f"table '{query_def.entity.pkey.table_name}' "
+                                 f"table '{query_def.entity_table}' "
                                  f"to have a time column")
 
         query_driver = LocalPQueryDriver(self._graph_store, query_def,
@@ -355,18 +701,22 @@ class KumoRFM:
             anchor_time=anchor_time,
             batch_size=min(10_000, size),
             max_iterations=max_iterations,
+            guarantee_train_examples=False,
         )
 
+        entity = self._graph_store.pkey_map_dict[
+            query_def.entity_table].index[node]
+
         return pd.DataFrame({
-            'ENTITY': node,
+            'ENTITY': entity,
             'ANCHOR_TIMESTAMP': time,
             'TARGET': y,
         })
 
     # Helpers #################################################################
 
-    def _parse_query(self, query: str) -> PQueryDefinition:
-        if isinstance(query, PQueryDefinition):
+    def _parse_query(self, query: str) -> ValidatedPredictiveQuery:
+        if isinstance(query, ValidatedPredictiveQuery):
             return query
 
         if isinstance(query, str) and query.strip()[:9].lower() == 'evaluate ':
@@ -376,12 +726,12 @@ class KumoRFM:
                 "predictions or evaluations.")
 
         try:
-            request = RFMValidateQueryRequest(
+            request = RFMParseQueryRequest(
                 query=query,
                 graph_definition=self._graph_def,
             )
 
-            resp = global_state.client.rfm_api.validate_query(request)
+            resp = global_state.client.rfm_api.parse_query(request)
             # TODO Expose validation warnings.
 
             if len(resp.validation_response.warnings) > 0:
@@ -392,7 +742,7 @@ class KumoRFM:
                 warnings.warn(f"Encountered the following warnings during "
                               f"parsing:\n{msg}")
 
-            return resp.query_definition
+            return resp.query
         except HTTPException as e:
             try:
                 msg = json.loads(e.detail)['detail']
@@ -403,8 +753,9 @@ class KumoRFM:
 
     def _validate_time(
         self,
-        query: PQueryDefinition,
+        query: ValidatedPredictiveQuery,
         anchor_time: pd.Timestamp,
+        context_anchor_time: Union[pd.Timestamp, None],
         evaluate: bool,
     ) -> None:
 
@@ -416,22 +767,45 @@ class KumoRFM:
                              f"the earliest timestamp "
                              f"'{self._graph_store.min_time}' in the data.")
 
-        if anchor_time - query.target.end_offset < self._graph_store.min_time:
-            raise ValueError(f"Anchor timestamp is too early or aggregation "
-                             f"time range is too large. To make this "
-                             f"prediction, we would need data back to "
-                             f"'{anchor_time - query.target.end_offset}', "
-                             f"however, your data only contains data back to "
+        if (context_anchor_time is not None
+                and context_anchor_time < self._graph_store.min_time):
+            raise ValueError(f"Context anchor timestamp is too early or "
+                             f"aggregation time range is too large. To make "
+                             f"this prediction, we would need data back to "
+                             f"'{context_anchor_time}', however, your data "
+                             f"only contains data back to "
                              f"'{self._graph_store.min_time}'.")
 
-        if (anchor_time - 2 * query.target.end_offset
-                < self._graph_store.min_time):
-            warnings.warn(f"Anchor timestamp is too early or aggregation "
-                          f"time range is too large. To form proper input "
-                          f"data, we would need data back to "
-                          f"'{anchor_time - 2 * query.target.end_offset}', "
-                          f"however, your data only contains data back to "
-                          f"'{self._graph_store.min_time}'.")
+        if query.target_ast.date_offset_range is not None:
+            end_offset = query.target_ast.date_offset_range.end_date_offset
+        else:
+            end_offset = pd.DateOffset(0)
+        forecast_end_offset = end_offset * query.num_forecasts
+        if (context_anchor_time is not None
+                and context_anchor_time > anchor_time):
+            warnings.warn(f"Context anchor timestamp "
+                          f"(got '{context_anchor_time}') is set to a later "
+                          f"date than the prediction anchor timestamp "
+                          f"(got '{anchor_time}'). Please make sure this is "
+                          f"intended.")
+        elif (query.query_type == QueryType.TEMPORAL
+              and context_anchor_time is not None
+              and context_anchor_time + forecast_end_offset > anchor_time):
+            warnings.warn(f"Aggregation for context examples at timestamp "
+                          f"'{context_anchor_time}' will leak information "
+                          f"from the prediction anchor timestamp "
+                          f"'{anchor_time}'. Please make sure this is "
+                          f"intended.")
+
+        elif (context_anchor_time is not None
+              and context_anchor_time - forecast_end_offset
+              < self._graph_store.min_time):
+            _time = context_anchor_time - forecast_end_offset
+            warnings.warn(f"Context anchor timestamp is too early or "
+                          f"aggregation time range is too large. To form "
+                          f"proper input data, we would need data back to "
+                          f"'{_time}', however, your data only contains "
+                          f"data back to '{self._graph_store.min_time}'.")
 
         if (not evaluate and anchor_time
                 > self._graph_store.max_time + pd.DateOffset(days=1)):
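
The rewritten checks above validate `context_anchor_time` against three failure modes: earlier than the available data, later than the prediction anchor, and close enough to the anchor that the label windows of context examples leak past it. A sketch of the intended call pattern (timestamps are illustrative):

    # Context examples are labeled as of June 1st, predictions are made as of
    # July 1st, so a 30-day target window on context examples cannot leak:
    df = model.predict(
        query,
        indices=ids,
        anchor_time=pd.Timestamp('2024-07-01'),
        context_anchor_time=pd.Timestamp('2024-06-01'),
    )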
@@ -439,17 +813,18 @@ class KumoRFM:
                           f"latest timestamp '{self._graph_store.max_time}' "
                           f"in the data. Please make sure this is intended.")
 
-        if (evaluate and anchor_time
-                > self._graph_store.max_time - query.target.end_offset):
+        max_eval_time = self._graph_store.max_time - forecast_end_offset
+        if evaluate and anchor_time > max_eval_time:
             raise ValueError(
                 f"Anchor timestamp for evaluation is after the latest "
-                f"supported timestamp "
-                f"'{self._graph_store.max_time - query.target.end_offset}'.")
+                f"supported timestamp '{max_eval_time}'.")
 
     def _get_context(
         self,
-        query: PQueryDefinition,
+        query: ValidatedPredictiveQuery,
+        indices: Union[List[str], List[float], List[int], None],
         anchor_time: Union[pd.Timestamp, Literal['entity'], None],
+        context_anchor_time: Union[pd.Timestamp, None],
         run_mode: RunMode,
         num_neighbors: Optional[List[int]],
         num_hops: int,
@@ -474,8 +849,8 @@ class KumoRFM:
                              f"must go beyond this for your use-case.")
 
         query_driver = LocalPQueryDriver(self._graph_store, query, random_seed)
-        task_type = query.get_task_type(
-            stypes=self._graph_store.stype_dict,
+        task_type = LocalPQueryDriver.get_task_type(
+            query,
             edge_types=self._graph_store.edge_types,
         )
 
@@ -507,28 +882,42 @@ class KumoRFM:
         else:
             num_neighbors = [64, 64, 8, 8, 4, 4][:num_hops]
 
+        if query.target_ast.date_offset_range is None:
+            end_offset = pd.DateOffset(0)
+        else:
+            end_offset = query.target_ast.date_offset_range.end_date_offset
+        forecast_end_offset = end_offset * query.num_forecasts
         if anchor_time is None:
             anchor_time = self._graph_store.max_time
             if evaluate:
-                anchor_time = anchor_time - query.target.end_offset
+                anchor_time = anchor_time - forecast_end_offset
             if logger is not None:
                 assert isinstance(anchor_time, pd.Timestamp)
-                if (anchor_time.hour == 0 and anchor_time.minute == 0
-                        and anchor_time.second == 0
-                        and anchor_time.microsecond == 0):
+                if anchor_time == pd.Timestamp.min:
+                    pass  # Static graph
+                elif (anchor_time.hour == 0 and anchor_time.minute == 0
+                        and anchor_time.second == 0
+                        and anchor_time.microsecond == 0):
                     logger.log(f"Derived anchor time {anchor_time.date()}")
                 else:
                     logger.log(f"Derived anchor time {anchor_time}")
 
         assert anchor_time is not None
         if isinstance(anchor_time, pd.Timestamp):
-            self._validate_time(query, anchor_time, evaluate)
+            if context_anchor_time is None:
+                context_anchor_time = anchor_time - forecast_end_offset
+            self._validate_time(query, anchor_time, context_anchor_time,
+                                evaluate)
         else:
             assert anchor_time == 'entity'
-            if query.entity.pkey.table_name not in self._graph_store.time_dict:
+            if query.entity_table not in self._graph_store.time_dict:
                 raise ValueError(f"Anchor time 'entity' requires the entity "
-                                 f"table '{query.entity.pkey.table_name}' to "
+                                 f"table '{query.entity_table}' to "
                                  f"have a time column")
+            if context_anchor_time is not None:
+                warnings.warn("Ignoring option 'context_anchor_time' for "
+                              "`anchor_time='entity'`")
+                context_anchor_time = None
 
         y_test: Optional[pd.Series] = None
         if evaluate:
@@ -540,6 +929,7 @@ class KumoRFM:
                 size=max_test_size,
                 anchor_time=anchor_time,
                 max_iterations=max_pq_iterations,
+                guarantee_train_examples=True,
             )
             if logger is not None:
                 if task_type == TaskType.BINARY_CLASSIFICATION:
@@ -563,34 +953,31 @@ class KumoRFM:
                     logger.log(msg)
 
         else:
-            assert query.entity.ids is not None
+            assert indices is not None
 
-            max_num_test = 200 if task_type.is_link_pred else 1000
-            if len(query.entity.ids.value) > max_num_test:
+            if len(indices) > _MAX_PRED_SIZE[task_type]:
                 raise ValueError(f"Cannot predict for more than "
-                                 f"{max_num_test:,} entities at once "
-                                 f"(got {len(query.entity.ids.value):,})")
+                                 f"{_MAX_PRED_SIZE[task_type]:,} entities at "
+                                 f"once (got {len(indices):,}). Use "
+                                 f"`KumoRFM.batch_mode` to process entities "
+                                 f"in batches")
 
             test_node = self._graph_store.get_node_id(
-                table_name=query.entity.pkey.table_name,
-                pkey=pd.Series(
-                    query.entity.ids.value,
-                    dtype=query.entity.ids.dtype,
-                ),
+                table_name=query.entity_table,
+                pkey=pd.Series(indices),
             )
 
             if isinstance(anchor_time, pd.Timestamp):
                 test_time = pd.Series(anchor_time).repeat(
                     len(test_node)).reset_index(drop=True)
             else:
-                time = self._graph_store.time_dict[
-                    query.entity.pkey.table_name]
+                time = self._graph_store.time_dict[query.entity_table]
                 time = time[test_node] * 1000**3
                 test_time = pd.Series(time, dtype='datetime64[ns]')
 
         train_node, train_time, y_train = query_driver.collect_train(
             size=_MAX_CONTEXT_SIZE[run_mode],
-            anchor_time=anchor_time,
+            anchor_time=context_anchor_time or 'entity',
             exclude_node=test_node if (query.query_type == QueryType.STATIC
                                        or anchor_time == 'entity') else None,
             max_iterations=max_pq_iterations,
@@ -617,12 +1004,23 @@ class KumoRFM:
                 raise NotImplementedError
             logger.log(msg)
 
-        entity_table_names = query.get_entity_table_names(
-            self._graph_store.edge_types)
+        entity_table_names: Tuple[str, ...]
+        if task_type.is_link_pred:
+            final_aggr = query.get_final_target_aggregation()
+            assert final_aggr is not None
+            edge_fkey = final_aggr._get_target_column_name()
+            for edge_type in self._graph_store.edge_types:
+                if edge_fkey == f'{edge_type[0]}.{edge_type[1]}':
+                    entity_table_names = (
+                        query.entity_table,
+                        edge_type[2],
+                    )
+        else:
+            entity_table_names = (query.entity_table, )
 
         # Exclude the entity anchor time from the feature set to prevent
         # running out-of-distribution between in-context and test examples:
-        exclude_cols_dict = query.exclude_cols_dict
+        exclude_cols_dict = query.get_exclude_cols_dict()
         if anchor_time == 'entity':
             if entity_table_names[0] not in exclude_cols_dict:
                 exclude_cols_dict[entity_table_names[0]] = []
@@ -642,6 +1040,17 @@ class KumoRFM:
             exclude_cols_dict=exclude_cols_dict,
         )
 
+        if len(subgraph.table_dict) >= 15:
+            raise ValueError(f"Cannot query from a graph with more than 15 "
+                             f"tables (got {len(subgraph.table_dict)}). "
+                             f"Please create a feature request at "
+                             f"'https://github.com/kumo-ai/kumo-rfm' if you "
+                             f"must go beyond this for your use-case.")
+
+        step_size: Optional[int] = None
+        if query.query_type == QueryType.TEMPORAL:
+            step_size = date_offset_to_seconds(end_offset)
+
         return Context(
             task_type=task_type,
             entity_table_names=entity_table_names,
@@ -649,6 +1058,7 @@ class KumoRFM:
             y_train=y_train,
             y_test=y_test,
             top_k=query.top_k,
+            step_size=step_size,
         )
 
     @staticmethod
@@ -664,7 +1074,7 @@ class KumoRFM:
         elif task_type == TaskType.MULTICLASS_CLASSIFICATION:
             supported_metrics = ['acc', 'precision', 'recall', 'f1', 'mrr']
         elif task_type == TaskType.REGRESSION:
-            supported_metrics = ['mae', 'mape', 'mse', 'rmse', 'smape']
+            supported_metrics = ['mae', 'mape', 'mse', 'rmse', 'smape', 'r2']
         elif task_type == TaskType.TEMPORAL_LINK_PREDICTION:
             supported_metrics = [
                 'map@', 'ndcg@', 'mrr@', 'precision@', 'recall@', 'f1@',
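
With this change, regression evaluations can also report `r2`. A minimal sketch, assuming `query` is a regression-style predictive query (the metric list is illustrative):

    metrics_df = model.evaluate(query, metrics=['mae', 'rmse', 'r2'])
    print(metrics_df)  # the requested metrics as a pandas.DataFrame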