kumoai 2.10.0.dev202509291830__cp312-cp312-macosx_11_0_arm64.whl → 2.13.0.dev202511161731__cp312-cp312-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -4,19 +4,29 @@ import warnings
4
4
  from collections import defaultdict
5
5
  from collections.abc import Generator
6
6
  from contextlib import contextmanager
7
- from dataclasses import replace
8
- from typing import List, Literal, Optional, Union
7
+ from dataclasses import dataclass, replace
8
+ from typing import (
9
+ Any,
10
+ Dict,
11
+ Iterator,
12
+ List,
13
+ Literal,
14
+ Optional,
15
+ Tuple,
16
+ Union,
17
+ overload,
18
+ )
9
19
 
10
20
  import numpy as np
11
21
  import pandas as pd
12
22
  from kumoapi.model_plan import RunMode
13
- from kumoapi.pquery import QueryType
23
+ from kumoapi.pquery import QueryType, ValidatedPredictiveQuery
24
+ from kumoapi.rfm import Context
25
+ from kumoapi.rfm import Explanation as ExplanationConfig
14
26
  from kumoapi.rfm import (
15
- Context,
16
- PQueryDefinition,
17
27
  RFMEvaluateRequest,
28
+ RFMParseQueryRequest,
18
29
  RFMPredictRequest,
19
- RFMValidateQueryRequest,
20
30
  )
21
31
  from kumoapi.task import TaskType
22
32
 
@@ -29,6 +39,7 @@ from kumoai.experimental.rfm.local_pquery_driver import (
29
39
  LocalPQueryDriver,
30
40
  date_offset_to_seconds,
31
41
  )
42
+ from kumoai.mixin import CastMixin
32
43
  from kumoai.utils import InteractiveProgressLogger, ProgressLogger
33
44
 
34
45
  _RANDOM_SEED = 42
@@ -59,6 +70,51 @@ _SIZE_LIMIT_MSG = ("Context size exceeds the 30MB limit. {stats}\nPlease "
59
70
  "beyond this for your use-case.")
60
71
 
61
72
 
73
+ @dataclass(repr=False)
74
+ class ExplainConfig(CastMixin):
75
+ """Configuration for explainability.
76
+
77
+ Args:
78
+ skip_summary: Whether to skip generating a human-readable summary of
79
+ the explanation.
80
+ """
81
+ skip_summary: bool = False
82
+
83
+
84
+ @dataclass(repr=False)
85
+ class Explanation:
86
+ prediction: pd.DataFrame
87
+ summary: str
88
+ details: ExplanationConfig
89
+
90
+ @overload
91
+ def __getitem__(self, index: Literal[0]) -> pd.DataFrame:
92
+ pass
93
+
94
+ @overload
95
+ def __getitem__(self, index: Literal[1]) -> str:
96
+ pass
97
+
98
+ def __getitem__(self, index: int) -> Union[pd.DataFrame, str]:
99
+ if index == 0:
100
+ return self.prediction
101
+ if index == 1:
102
+ return self.summary
103
+ raise IndexError("Index out of range")
104
+
105
+ def __iter__(self) -> Iterator[Union[pd.DataFrame, str]]:
106
+ return iter((self.prediction, self.summary))
107
+
108
+ def __repr__(self) -> str:
109
+ return str((self.prediction, self.summary))
110
+
111
+ def _ipython_display_(self) -> None:
112
+ from IPython.display import Markdown, display
113
+
114
+ display(self.prediction)
115
+ display(Markdown(self.summary))
116
+
117
+
62
118
  class KumoRFM:
63
119
  r"""The Kumo Relational Foundation model (RFM) from the `KumoRFM: A
64
120
  Foundation Model for In-Context Learning on Relational Data
@@ -116,7 +172,7 @@ class KumoRFM:
116
172
  self._graph_store = LocalGraphStore(graph, preprocess, verbose)
117
173
  self._graph_sampler = LocalGraphSampler(self._graph_store)
118
174
 
119
- self._batch_size: Optional[int | Literal['auto']] = None
175
+ self._batch_size: Optional[int | Literal['max']] = None
120
176
  self.num_retries: int = 0
121
177
 
122
178
  def __repr__(self) -> str:
@@ -125,23 +181,23 @@ class KumoRFM:
125
181
  @contextmanager
126
182
  def batch_mode(
127
183
  self,
128
- batch_size: Union[int, Literal['auto']] = 'auto',
184
+ batch_size: Union[int, Literal['max']] = 'max',
129
185
  num_retries: int = 1,
130
186
  ) -> Generator[None, None, None]:
131
187
  """Context manager to predict in batches.
132
188
 
133
189
  .. code-block:: python
134
190
 
135
- with model.batch_mode(batch_size='auto', num_retries=1):
191
+ with model.batch_mode(batch_size='max', num_retries=1):
136
192
  df = model.predict(query, indices=...)
137
193
 
138
194
  Args:
139
- batch_size: The batch size. If set to ``"auto"``, will use the
195
+ batch_size: The batch size. If set to ``"max"``, will use the
140
196
  maximum applicable batch size for the given task.
141
197
  num_retries: The maximum number of retries for failed queries due
142
198
  to unexpected server issues.
143
199
  """
144
- if batch_size != 'auto' and batch_size <= 0:
200
+ if batch_size != 'max' and batch_size <= 0:
145
201
  raise ValueError(f"'batch_size' must be greater than zero "
146
202
  f"(got {batch_size})")
147
203
 
@@ -155,11 +211,13 @@ class KumoRFM:
155
211
  self._batch_size = None
156
212
  self.num_retries = 0
157
213
 
214
+ @overload
158
215
  def predict(
159
216
  self,
160
217
  query: str,
161
218
  indices: Union[List[str], List[float], List[int], None] = None,
162
219
  *,
220
+ explain: Literal[False] = False,
163
221
  anchor_time: Union[pd.Timestamp, Literal['entity'], None] = None,
164
222
  context_anchor_time: Union[pd.Timestamp, None] = None,
165
223
  run_mode: Union[RunMode, str] = RunMode.FAST,
@@ -168,18 +226,65 @@ class KumoRFM:
168
226
  max_pq_iterations: int = 20,
169
227
  random_seed: Optional[int] = _RANDOM_SEED,
170
228
  verbose: Union[bool, ProgressLogger] = True,
229
+ use_prediction_time: bool = False,
171
230
  ) -> pd.DataFrame:
231
+ pass
232
+
233
+ @overload
234
+ def predict(
235
+ self,
236
+ query: str,
237
+ indices: Union[List[str], List[float], List[int], None] = None,
238
+ *,
239
+ explain: Union[Literal[True], ExplainConfig, Dict[str, Any]],
240
+ anchor_time: Union[pd.Timestamp, Literal['entity'], None] = None,
241
+ context_anchor_time: Union[pd.Timestamp, None] = None,
242
+ run_mode: Union[RunMode, str] = RunMode.FAST,
243
+ num_neighbors: Optional[List[int]] = None,
244
+ num_hops: int = 2,
245
+ max_pq_iterations: int = 20,
246
+ random_seed: Optional[int] = _RANDOM_SEED,
247
+ verbose: Union[bool, ProgressLogger] = True,
248
+ use_prediction_time: bool = False,
249
+ ) -> Explanation:
250
+ pass
251
+
252
+ def predict(
253
+ self,
254
+ query: str,
255
+ indices: Union[List[str], List[float], List[int], None] = None,
256
+ *,
257
+ explain: Union[bool, ExplainConfig, Dict[str, Any]] = False,
258
+ anchor_time: Union[pd.Timestamp, Literal['entity'], None] = None,
259
+ context_anchor_time: Union[pd.Timestamp, None] = None,
260
+ run_mode: Union[RunMode, str] = RunMode.FAST,
261
+ num_neighbors: Optional[List[int]] = None,
262
+ num_hops: int = 2,
263
+ max_pq_iterations: int = 20,
264
+ random_seed: Optional[int] = _RANDOM_SEED,
265
+ verbose: Union[bool, ProgressLogger] = True,
266
+ use_prediction_time: bool = False,
267
+ ) -> Union[pd.DataFrame, Explanation]:
172
268
  """Returns predictions for a predictive query.
173
269
 
174
270
  Args:
175
271
  query: The predictive query.
176
- indices: The entity primary keys to predict on. Will override the
177
- indices given as part of the predictive query.
272
+ indices: The entity primary keys to predict for. Will override the
273
+ indices given as part of the predictive query. Predictions will
274
+ be generated for all indices, independent of whether they
275
+ fulfill entity filter constraints. To pre-filter entities, use
276
+ :meth:`~KumoRFM.is_valid_entity`.
277
+ explain: Configuration for explainability.
278
+ If set to ``True``, will additionally explain the prediction.
279
+ Passing in an :class:`ExplainConfig` instance provides control
280
+ over which parts of explanation are generated.
281
+ Explainability is currently only supported for single entity
282
+ predictions with ``run_mode="FAST"``.
178
283
  anchor_time: The anchor timestamp for the prediction. If set to
179
- :obj:`None`, will use the maximum timestamp in the data.
180
- If set to :`"entity"`, will use the timestamp of the entity.
284
+ ``None``, will use the maximum timestamp in the data.
285
+ If set to ``"entity"``, will use the timestamp of the entity.
181
286
  context_anchor_time: The maximum anchor timestamp for context
182
- examples. If set to :obj:`None`, :obj:`anchor_time` will
287
+ examples. If set to ``None``, ``anchor_time`` will
183
288
  determine the anchor time for context examples.
184
289
  run_mode: The :class:`RunMode` for the query.
185
290
  num_neighbors: The number of neighbors to sample for each hop.
@@ -192,46 +297,54 @@ class KumoRFM:
192
297
  entities to find valid labels.
193
298
  random_seed: A manual seed for generating pseudo-random numbers.
194
299
  verbose: Whether to print verbose output.
300
+ use_prediction_time: Whether to use the anchor timestamp as an
301
+ additional feature during prediction. This is typically
302
+ beneficial for time series forecasting tasks.
195
303
 
196
304
  Returns:
197
- The predictions as a :class:`pandas.DataFrame`
305
+ The predictions as a :class:`pandas.DataFrame`.
306
+ If ``explain`` is provided, returns an :class:`Explanation` object
307
+ containing the prediction, summary, and details.
198
308
  """
199
- explain = False
309
+ explain_config: Optional[ExplainConfig] = None
310
+ if explain is True:
311
+ explain_config = ExplainConfig()
312
+ elif explain is not False:
313
+ explain_config = ExplainConfig._cast(explain)
314
+
200
315
  query_def = self._parse_query(query)
316
+ query_str = query_def.to_string()
201
317
 
202
318
  if num_hops != 2 and num_neighbors is not None:
203
319
  warnings.warn(f"Received custom 'num_neighbors' option; ignoring "
204
320
  f"custom 'num_hops={num_hops}' option")
205
321
 
206
- if explain and run_mode in {RunMode.NORMAL, RunMode.BEST}:
322
+ if explain_config is not None and run_mode in {
323
+ RunMode.NORMAL, RunMode.BEST
324
+ }:
207
325
  warnings.warn(f"Explainability is currently only supported for "
208
326
  f"run mode 'FAST' (got '{run_mode}'). Provided run "
209
327
  f"mode has been reset. Please lower the run mode to "
210
328
  f"suppress this warning.")
211
329
 
212
330
  if indices is None:
213
- if query_def.entity.ids is None:
331
+ if query_def.rfm_entity_ids is None:
214
332
  raise ValueError("Cannot find entities to predict for. Please "
215
333
  "pass them via `predict(query, indices=...)`")
216
- indices = query_def.entity.ids.value
334
+ indices = query_def.get_rfm_entity_id_list()
217
335
  else:
218
- query_def = replace(
219
- query_def,
220
- entity=replace(query_def.entity, ids=None),
221
- )
336
+ query_def = replace(query_def, rfm_entity_ids=None)
222
337
 
223
338
  if len(indices) == 0:
224
- raise ValueError("At least one entity is required for "
225
- "prediction")
339
+ raise ValueError("At least one entity is required")
226
340
 
227
- if explain:
228
- if len(indices) > 1:
229
- raise ValueError(
230
- f"Cannot explain predictions for more than a single "
231
- f"entity (got {len(indices)})")
341
+ if explain_config is not None and len(indices) > 1:
342
+ raise ValueError(
343
+ f"Cannot explain predictions for more than a single entity "
344
+ f"(got {len(indices)})")
232
345
 
233
346
  query_repr = query_def.to_string(rich=True, exclude_predict=True)
234
- if explain:
347
+ if explain_config is not None:
235
348
  msg = f'[bold]EXPLAIN[/bold] {query_repr}'
236
349
  else:
237
350
  msg = f'[bold]PREDICT[/bold] {query_repr}'
@@ -242,9 +355,9 @@ class KumoRFM:
242
355
  with verbose as logger:
243
356
 
244
357
  batch_size: Optional[int] = None
245
- if self._batch_size == 'auto':
246
- task_type = query_def.get_task_type(
247
- stypes=self._graph_store.stype_dict,
358
+ if self._batch_size == 'max':
359
+ task_type = LocalPQueryDriver.get_task_type(
360
+ query_def,
248
361
  edge_types=self._graph_store.edge_types,
249
362
  )
250
363
  batch_size = _MAX_PRED_SIZE[task_type]
@@ -261,7 +374,9 @@ class KumoRFM:
261
374
  logger.log(f"Splitting {len(indices):,} entities into "
262
375
  f"{len(batches):,} batches of size {batch_size:,}")
263
376
 
264
- dfs: List[pd.DataFrame] = []
377
+ predictions: List[pd.DataFrame] = []
378
+ summary: Optional[str] = None
379
+ details: Optional[Explanation] = None
265
380
  for i, batch in enumerate(batches):
266
381
  # TODO Re-use the context for subsequent predictions.
267
382
  context = self._get_context(
@@ -280,6 +395,8 @@ class KumoRFM:
280
395
  request = RFMPredictRequest(
281
396
  context=context,
282
397
  run_mode=RunMode(run_mode),
398
+ query=query_str,
399
+ use_prediction_time=use_prediction_time,
283
400
  )
284
401
  with warnings.catch_warnings():
285
402
  warnings.filterwarnings('ignore', message='gencode')
@@ -302,11 +419,36 @@ class KumoRFM:
302
419
 
303
420
  for attempt in range(self.num_retries + 1):
304
421
  try:
305
- if explain:
306
- resp = global_state.client.rfm_api.explain(_bytes)
422
+ if explain_config is not None:
423
+ resp = global_state.client.rfm_api.explain(
424
+ request=_bytes,
425
+ skip_summary=explain_config.skip_summary,
426
+ )
427
+ summary = resp.summary
428
+ details = resp.details
307
429
  else:
308
430
  resp = global_state.client.rfm_api.predict(_bytes)
309
- dfs.append(pd.DataFrame(**resp.prediction))
431
+ df = pd.DataFrame(**resp.prediction)
432
+
433
+ # Cast 'ENTITY' to correct data type:
434
+ if 'ENTITY' in df:
435
+ entity = query_def.entity_table
436
+ pkey_map = self._graph_store.pkey_map_dict[entity]
437
+ df['ENTITY'] = df['ENTITY'].astype(
438
+ type(pkey_map.index[0]))
439
+
440
+ # Cast 'ANCHOR_TIMESTAMP' to correct data type:
441
+ if 'ANCHOR_TIMESTAMP' in df:
442
+ ser = df['ANCHOR_TIMESTAMP']
443
+ if not pd.api.types.is_datetime64_any_dtype(ser):
444
+ if isinstance(ser.iloc[0], str):
445
+ unit = None
446
+ else:
447
+ unit = 'ms'
448
+ df['ANCHOR_TIMESTAMP'] = pd.to_datetime(
449
+ ser, errors='coerce', unit=unit)
450
+
451
+ predictions.append(df)
310
452
 
311
453
  if (isinstance(verbose, InteractiveProgressLogger)
312
454
  and len(batches) > 1):
@@ -327,7 +469,72 @@ class KumoRFM:
327
469
 
328
470
  time.sleep(2**attempt) # 1s, 2s, 4s, 8s, ...
329
471
 
330
- return dfs[0] if len(dfs) == 1 else pd.concat(dfs, ignore_index=True)
472
+ if len(predictions) == 1:
473
+ prediction = predictions[0]
474
+ else:
475
+ prediction = pd.concat(predictions, ignore_index=True)
476
+
477
+ if explain_config is not None:
478
+ assert len(predictions) == 1
479
+ assert summary is not None
480
+ assert details is not None
481
+ return Explanation(
482
+ prediction=prediction,
483
+ summary=summary,
484
+ details=details,
485
+ )
486
+
487
+ return prediction
488
+
489
+ def is_valid_entity(
490
+ self,
491
+ query: str,
492
+ indices: Union[List[str], List[float], List[int], None] = None,
493
+ *,
494
+ anchor_time: Union[pd.Timestamp, Literal['entity'], None] = None,
495
+ ) -> np.ndarray:
496
+ r"""Returns a mask that denotes which entities are valid for the
497
+ given predictive query, *i.e.*, which entities fulfill (temporal)
498
+ entity filter constraints.
499
+
500
+ Args:
501
+ query: The predictive query.
502
+ indices: The entity primary keys to predict for. Will override the
503
+ indices given as part of the predictive query.
504
+ anchor_time: The anchor timestamp for the prediction. If set to
505
+ ``None``, will use the maximum timestamp in the data.
506
+ If set to ``"entity"``, will use the timestamp of the entity.
507
+ """
508
+ query_def = self._parse_query(query)
509
+
510
+ if indices is None:
511
+ if query_def.rfm_entity_ids is None:
512
+ raise ValueError("Cannot find entities to predict for. Please "
513
+ "pass them via "
514
+ "`is_valid_entity(query, indices=...)`")
515
+ indices = query_def.get_rfm_entity_id_list()
516
+
517
+ if len(indices) == 0:
518
+ raise ValueError("At least one entity is required")
519
+
520
+ if anchor_time is None:
521
+ anchor_time = self._graph_store.max_time
522
+
523
+ if isinstance(anchor_time, pd.Timestamp):
524
+ self._validate_time(query_def, anchor_time, None, False)
525
+ else:
526
+ assert anchor_time == 'entity'
527
+ if (query_def.entity_table not in self._graph_store.time_dict):
528
+ raise ValueError(f"Anchor time 'entity' requires the entity "
529
+ f"table '{query_def.entity_table}' "
530
+ f"to have a time column.")
531
+
532
+ node = self._graph_store.get_node_id(
533
+ table_name=query_def.entity_table,
534
+ pkey=pd.Series(indices),
535
+ )
536
+ query_driver = LocalPQueryDriver(self._graph_store, query_def)
537
+ return query_driver.is_valid(node, anchor_time)
331
538
 
332
539
  def evaluate(
333
540
  self,
@@ -342,6 +549,7 @@ class KumoRFM:
342
549
  max_pq_iterations: int = 20,
343
550
  random_seed: Optional[int] = _RANDOM_SEED,
344
551
  verbose: Union[bool, ProgressLogger] = True,
552
+ use_prediction_time: bool = False,
345
553
  ) -> pd.DataFrame:
346
554
  """Evaluates a predictive query.
347
555
 
@@ -349,10 +557,10 @@ class KumoRFM:
349
557
  query: The predictive query.
350
558
  metrics: The metrics to use.
351
559
  anchor_time: The anchor timestamp for the prediction. If set to
352
- :obj:`None`, will use the maximum timestamp in the data.
353
- If set to :`"entity"`, will use the timestamp of the entity.
560
+ ``None``, will use the maximum timestamp in the data.
561
+ If set to ``"entity"``, will use the timestamp of the entity.
354
562
  context_anchor_time: The maximum anchor timestamp for context
355
- examples. If set to :obj:`None`, :obj:`anchor_time` will
563
+ examples. If set to ``None``, ``anchor_time`` will
356
564
  determine the anchor time for context examples.
357
565
  run_mode: The :class:`RunMode` for the query.
358
566
  num_neighbors: The number of neighbors to sample for each hop.
@@ -365,6 +573,9 @@ class KumoRFM:
365
573
  entities to find valid labels.
366
574
  random_seed: A manual seed for generating pseudo-random numbers.
367
575
  verbose: Whether to print verbose output.
576
+ use_prediction_time: Whether to use the anchor timestamp as an
577
+ additional feature during prediction. This is typically
578
+ beneficial for time series forecasting tasks.
368
579
 
369
580
  Returns:
370
581
  The metrics as a :class:`pandas.DataFrame`
@@ -375,10 +586,10 @@ class KumoRFM:
375
586
  warnings.warn(f"Received custom 'num_neighbors' option; ignoring "
376
587
  f"custom 'num_hops={num_hops}' option")
377
588
 
378
- if query_def.entity.ids is not None:
589
+ if query_def.rfm_entity_ids is not None:
379
590
  query_def = replace(
380
591
  query_def,
381
- entity=replace(query_def.entity, ids=None),
592
+ rfm_entity_ids=None,
382
593
  )
383
594
 
384
595
  query_repr = query_def.to_string(rich=True, exclude_predict=True)
@@ -408,6 +619,7 @@ class KumoRFM:
408
619
  context=context,
409
620
  run_mode=RunMode(run_mode),
410
621
  metrics=metrics,
622
+ use_prediction_time=use_prediction_time,
411
623
  )
412
624
  with warnings.catch_warnings():
413
625
  warnings.filterwarnings('ignore', message='Protobuf gencode')
@@ -418,7 +630,7 @@ class KumoRFM:
418
630
 
419
631
  if len(request_bytes) > _MAX_SIZE:
420
632
  stats_msg = Context.get_memory_stats(request_msg.context)
421
- raise ValueError(_SIZE_LIMIT_MSG.format(stats_msg=stats_msg))
633
+ raise ValueError(_SIZE_LIMIT_MSG.format(stats=stats_msg))
422
634
 
423
635
  try:
424
636
  resp = global_state.client.rfm_api.evaluate(request_bytes)
@@ -466,18 +678,19 @@ class KumoRFM:
466
678
 
467
679
  if anchor_time is None:
468
680
  anchor_time = self._graph_store.max_time
469
- anchor_time = anchor_time - (query_def.target.end_offset *
470
- query_def.num_forecasts)
681
+ if query_def.target_ast.date_offset_range is not None:
682
+ anchor_time = anchor_time - (
683
+ query_def.target_ast.date_offset_range.end_date_offset *
684
+ query_def.num_forecasts)
471
685
 
472
686
  assert anchor_time is not None
473
687
  if isinstance(anchor_time, pd.Timestamp):
474
688
  self._validate_time(query_def, anchor_time, None, evaluate=True)
475
689
  else:
476
690
  assert anchor_time == 'entity'
477
- if (query_def.entity.pkey.table_name
478
- not in self._graph_store.time_dict):
691
+ if (query_def.entity_table not in self._graph_store.time_dict):
479
692
  raise ValueError(f"Anchor time 'entity' requires the entity "
480
- f"table '{query_def.entity.pkey.table_name}' "
693
+ f"table '{query_def.entity_table}' "
481
694
  f"to have a time column")
482
695
 
483
696
  query_driver = LocalPQueryDriver(self._graph_store, query_def,
@@ -492,7 +705,7 @@ class KumoRFM:
492
705
  )
493
706
 
494
707
  entity = self._graph_store.pkey_map_dict[
495
- query_def.entity.pkey.table_name].index[node]
708
+ query_def.entity_table].index[node]
496
709
 
497
710
  return pd.DataFrame({
498
711
  'ENTITY': entity,
@@ -502,8 +715,8 @@ class KumoRFM:
502
715
 
503
716
  # Helpers #################################################################
504
717
 
505
- def _parse_query(self, query: str) -> PQueryDefinition:
506
- if isinstance(query, PQueryDefinition):
718
+ def _parse_query(self, query: str) -> ValidatedPredictiveQuery:
719
+ if isinstance(query, ValidatedPredictiveQuery):
507
720
  return query
508
721
 
509
722
  if isinstance(query, str) and query.strip()[:9].lower() == 'evaluate ':
@@ -513,12 +726,12 @@ class KumoRFM:
513
726
  "predictions or evaluations.")
514
727
 
515
728
  try:
516
- request = RFMValidateQueryRequest(
729
+ request = RFMParseQueryRequest(
517
730
  query=query,
518
731
  graph_definition=self._graph_def,
519
732
  )
520
733
 
521
- resp = global_state.client.rfm_api.validate_query(request)
734
+ resp = global_state.client.rfm_api.parse_query(request)
522
735
  # TODO Expose validation warnings.
523
736
 
524
737
  if len(resp.validation_response.warnings) > 0:
@@ -529,7 +742,7 @@ class KumoRFM:
529
742
  warnings.warn(f"Encountered the following warnings during "
530
743
  f"parsing:\n{msg}")
531
744
 
532
- return resp.query_definition
745
+ return resp.query
533
746
  except HTTPException as e:
534
747
  try:
535
748
  msg = json.loads(e.detail)['detail']
@@ -540,7 +753,7 @@ class KumoRFM:
540
753
 
541
754
  def _validate_time(
542
755
  self,
543
- query: PQueryDefinition,
756
+ query: ValidatedPredictiveQuery,
544
757
  anchor_time: pd.Timestamp,
545
758
  context_anchor_time: Union[pd.Timestamp, None],
546
759
  evaluate: bool,
@@ -563,6 +776,11 @@ class KumoRFM:
563
776
  f"only contains data back to "
564
777
  f"'{self._graph_store.min_time}'.")
565
778
 
779
+ if query.target_ast.date_offset_range is not None:
780
+ end_offset = query.target_ast.date_offset_range.end_date_offset
781
+ else:
782
+ end_offset = pd.DateOffset(0)
783
+ forecast_end_offset = end_offset * query.num_forecasts
566
784
  if (context_anchor_time is not None
567
785
  and context_anchor_time > anchor_time):
568
786
  warnings.warn(f"Context anchor timestamp "
@@ -571,19 +789,18 @@ class KumoRFM:
571
789
  f"(got '{anchor_time}'). Please make sure this is "
572
790
  f"intended.")
573
791
  elif (query.query_type == QueryType.TEMPORAL
574
- and context_anchor_time is not None and context_anchor_time +
575
- query.target.end_offset * query.num_forecasts > anchor_time):
792
+ and context_anchor_time is not None
793
+ and context_anchor_time + forecast_end_offset > anchor_time):
576
794
  warnings.warn(f"Aggregation for context examples at timestamp "
577
795
  f"'{context_anchor_time}' will leak information "
578
796
  f"from the prediction anchor timestamp "
579
797
  f"'{anchor_time}'. Please make sure this is "
580
798
  f"intended.")
581
799
 
582
- elif (context_anchor_time is not None and context_anchor_time -
583
- query.target.end_offset * query.num_forecasts
800
+ elif (context_anchor_time is not None
801
+ and context_anchor_time - forecast_end_offset
584
802
  < self._graph_store.min_time):
585
- _time = context_anchor_time - (query.target.end_offset *
586
- query.num_forecasts)
803
+ _time = context_anchor_time - forecast_end_offset
587
804
  warnings.warn(f"Context anchor timestamp is too early or "
588
805
  f"aggregation time range is too large. To form "
589
806
  f"proper input data, we would need data back to "
@@ -596,8 +813,7 @@ class KumoRFM:
596
813
  f"latest timestamp '{self._graph_store.max_time}' "
597
814
  f"in the data. Please make sure this is intended.")
598
815
 
599
- max_eval_time = (self._graph_store.max_time -
600
- query.target.end_offset * query.num_forecasts)
816
+ max_eval_time = self._graph_store.max_time - forecast_end_offset
601
817
  if evaluate and anchor_time > max_eval_time:
602
818
  raise ValueError(
603
819
  f"Anchor timestamp for evaluation is after the latest "
@@ -605,7 +821,7 @@ class KumoRFM:
605
821
 
606
822
  def _get_context(
607
823
  self,
608
- query: PQueryDefinition,
824
+ query: ValidatedPredictiveQuery,
609
825
  indices: Union[List[str], List[float], List[int], None],
610
826
  anchor_time: Union[pd.Timestamp, Literal['entity'], None],
611
827
  context_anchor_time: Union[pd.Timestamp, None],
@@ -633,8 +849,8 @@ class KumoRFM:
633
849
  f"must go beyond this for your use-case.")
634
850
 
635
851
  query_driver = LocalPQueryDriver(self._graph_store, query, random_seed)
636
- task_type = query.get_task_type(
637
- stypes=self._graph_store.stype_dict,
852
+ task_type = LocalPQueryDriver.get_task_type(
853
+ query,
638
854
  edge_types=self._graph_store.edge_types,
639
855
  )
640
856
 
@@ -666,11 +882,15 @@ class KumoRFM:
666
882
  else:
667
883
  num_neighbors = [64, 64, 8, 8, 4, 4][:num_hops]
668
884
 
885
+ if query.target_ast.date_offset_range is None:
886
+ end_offset = pd.DateOffset(0)
887
+ else:
888
+ end_offset = query.target_ast.date_offset_range.end_date_offset
889
+ forecast_end_offset = end_offset * query.num_forecasts
669
890
  if anchor_time is None:
670
891
  anchor_time = self._graph_store.max_time
671
892
  if evaluate:
672
- anchor_time = anchor_time - (query.target.end_offset *
673
- query.num_forecasts)
893
+ anchor_time = anchor_time - forecast_end_offset
674
894
  if logger is not None:
675
895
  assert isinstance(anchor_time, pd.Timestamp)
676
896
  if anchor_time == pd.Timestamp.min:
@@ -685,15 +905,14 @@ class KumoRFM:
685
905
  assert anchor_time is not None
686
906
  if isinstance(anchor_time, pd.Timestamp):
687
907
  if context_anchor_time is None:
688
- context_anchor_time = anchor_time - (query.target.end_offset *
689
- query.num_forecasts)
908
+ context_anchor_time = anchor_time - forecast_end_offset
690
909
  self._validate_time(query, anchor_time, context_anchor_time,
691
910
  evaluate)
692
911
  else:
693
912
  assert anchor_time == 'entity'
694
- if query.entity.pkey.table_name not in self._graph_store.time_dict:
913
+ if query.entity_table not in self._graph_store.time_dict:
695
914
  raise ValueError(f"Anchor time 'entity' requires the entity "
696
- f"table '{query.entity.pkey.table_name}' to "
915
+ f"table '{query.entity_table}' to "
697
916
  f"have a time column")
698
917
  if context_anchor_time is not None:
699
918
  warnings.warn("Ignoring option 'context_anchor_time' for "
@@ -744,7 +963,7 @@ class KumoRFM:
744
963
  f"in batches")
745
964
 
746
965
  test_node = self._graph_store.get_node_id(
747
- table_name=query.entity.pkey.table_name,
966
+ table_name=query.entity_table,
748
967
  pkey=pd.Series(indices),
749
968
  )
750
969
 
@@ -752,8 +971,7 @@ class KumoRFM:
752
971
  test_time = pd.Series(anchor_time).repeat(
753
972
  len(test_node)).reset_index(drop=True)
754
973
  else:
755
- time = self._graph_store.time_dict[
756
- query.entity.pkey.table_name]
974
+ time = self._graph_store.time_dict[query.entity_table]
757
975
  time = time[test_node] * 1000**3
758
976
  test_time = pd.Series(time, dtype='datetime64[ns]')
759
977
 
@@ -786,12 +1004,23 @@ class KumoRFM:
786
1004
  raise NotImplementedError
787
1005
  logger.log(msg)
788
1006
 
789
- entity_table_names = query.get_entity_table_names(
790
- self._graph_store.edge_types)
1007
+ entity_table_names: Tuple[str, ...]
1008
+ if task_type.is_link_pred:
1009
+ final_aggr = query.get_final_target_aggregation()
1010
+ assert final_aggr is not None
1011
+ edge_fkey = final_aggr._get_target_column_name()
1012
+ for edge_type in self._graph_store.edge_types:
1013
+ if edge_fkey == f'{edge_type[0]}.{edge_type[1]}':
1014
+ entity_table_names = (
1015
+ query.entity_table,
1016
+ edge_type[2],
1017
+ )
1018
+ else:
1019
+ entity_table_names = (query.entity_table, )
791
1020
 
792
1021
  # Exclude the entity anchor time from the feature set to prevent
793
1022
  # running out-of-distribution between in-context and test examples:
794
- exclude_cols_dict = query.exclude_cols_dict
1023
+ exclude_cols_dict = query.get_exclude_cols_dict()
795
1024
  if anchor_time == 'entity':
796
1025
  if entity_table_names[0] not in exclude_cols_dict:
797
1026
  exclude_cols_dict[entity_table_names[0]] = []
@@ -820,7 +1049,7 @@ class KumoRFM:
820
1049
 
821
1050
  step_size: Optional[int] = None
822
1051
  if query.query_type == QueryType.TEMPORAL:
823
- step_size = date_offset_to_seconds(query.target.end_offset)
1052
+ step_size = date_offset_to_seconds(end_offset)
824
1053
 
825
1054
  return Context(
826
1055
  task_type=task_type,
@@ -845,7 +1074,7 @@ class KumoRFM:
845
1074
  elif task_type == TaskType.MULTICLASS_CLASSIFICATION:
846
1075
  supported_metrics = ['acc', 'precision', 'recall', 'f1', 'mrr']
847
1076
  elif task_type == TaskType.REGRESSION:
848
- supported_metrics = ['mae', 'mape', 'mse', 'rmse', 'smape']
1077
+ supported_metrics = ['mae', 'mape', 'mse', 'rmse', 'smape', 'r2']
849
1078
  elif task_type == TaskType.TEMPORAL_LINK_PREDICTION:
850
1079
  supported_metrics = [
851
1080
  'map@', 'ndcg@', 'mrr@', 'precision@', 'recall@', 'f1@',