kumoai 2.10.0.dev202510021830__py3-none-any.whl → 2.12.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -5,21 +5,28 @@ from collections import defaultdict
5
5
  from collections.abc import Generator
6
6
  from contextlib import contextmanager
7
7
  from dataclasses import dataclass, replace
8
- from typing import Iterator, List, Literal, Optional, Union, overload
8
+ from typing import (
9
+ Any,
10
+ Dict,
11
+ Iterator,
12
+ List,
13
+ Literal,
14
+ Optional,
15
+ Tuple,
16
+ Union,
17
+ overload,
18
+ )
9
19
 
10
20
  import numpy as np
11
21
  import pandas as pd
12
22
  from kumoapi.model_plan import RunMode
13
- from kumoapi.pquery import QueryType
23
+ from kumoapi.pquery import QueryType, ValidatedPredictiveQuery
14
24
  from kumoapi.rfm import Context
15
25
  from kumoapi.rfm import Explanation as ExplanationConfig
16
26
  from kumoapi.rfm import (
17
- PQueryDefinition,
18
27
  RFMEvaluateRequest,
19
- RFMExplanationResponse,
28
+ RFMParseQueryRequest,
20
29
  RFMPredictRequest,
21
- RFMPredictResponse,
22
- RFMValidateQueryRequest,
23
30
  )
24
31
  from kumoapi.task import TaskType
25
32
 
@@ -32,6 +39,7 @@ from kumoai.experimental.rfm.local_pquery_driver import (
32
39
  LocalPQueryDriver,
33
40
  date_offset_to_seconds,
34
41
  )
42
+ from kumoai.mixin import CastMixin
35
43
  from kumoai.utils import InteractiveProgressLogger, ProgressLogger
36
44
 
37
45
  _RANDOM_SEED = 42
@@ -62,6 +70,17 @@ _SIZE_LIMIT_MSG = ("Context size exceeds the 30MB limit. {stats}\nPlease "
62
70
  "beyond this for your use-case.")
63
71
 
64
72
 
73
+ @dataclass(repr=False)
74
+ class ExplainConfig(CastMixin):
75
+ """Configuration for explainability.
76
+
77
+ Args:
78
+ skip_summary: Whether to skip generating a human-readable summary of
79
+ the explanation.
80
+ """
81
+ skip_summary: bool = False
82
+
83
+
65
84
  @dataclass(repr=False)
66
85
  class Explanation:
67
86
  prediction: pd.DataFrame
@@ -89,6 +108,12 @@ class Explanation:
89
108
  def __repr__(self) -> str:
90
109
  return str((self.prediction, self.summary))
91
110
 
111
+ def _ipython_display_(self) -> None:
112
+ from IPython.display import Markdown, display
113
+
114
+ display(self.prediction)
115
+ display(Markdown(self.summary))
116
+
92
117
 
93
118
  class KumoRFM:
94
119
  r"""The Kumo Relational Foundation model (RFM) from the `KumoRFM: A
@@ -201,6 +226,7 @@ class KumoRFM:
201
226
  max_pq_iterations: int = 20,
202
227
  random_seed: Optional[int] = _RANDOM_SEED,
203
228
  verbose: Union[bool, ProgressLogger] = True,
229
+ use_prediction_time: bool = False,
204
230
  ) -> pd.DataFrame:
205
231
  pass
206
232
 
@@ -210,7 +236,7 @@ class KumoRFM:
210
236
  query: str,
211
237
  indices: Union[List[str], List[float], List[int], None] = None,
212
238
  *,
213
- explain: Literal[True],
239
+ explain: Union[Literal[True], ExplainConfig, Dict[str, Any]],
214
240
  anchor_time: Union[pd.Timestamp, Literal['entity'], None] = None,
215
241
  context_anchor_time: Union[pd.Timestamp, None] = None,
216
242
  run_mode: Union[RunMode, str] = RunMode.FAST,
@@ -219,6 +245,7 @@ class KumoRFM:
219
245
  max_pq_iterations: int = 20,
220
246
  random_seed: Optional[int] = _RANDOM_SEED,
221
247
  verbose: Union[bool, ProgressLogger] = True,
248
+ use_prediction_time: bool = False,
222
249
  ) -> Explanation:
223
250
  pass
224
251
 
@@ -227,7 +254,7 @@ class KumoRFM:
227
254
  query: str,
228
255
  indices: Union[List[str], List[float], List[int], None] = None,
229
256
  *,
230
- explain: bool = False,
257
+ explain: Union[bool, ExplainConfig, Dict[str, Any]] = False,
231
258
  anchor_time: Union[pd.Timestamp, Literal['entity'], None] = None,
232
259
  context_anchor_time: Union[pd.Timestamp, None] = None,
233
260
  run_mode: Union[RunMode, str] = RunMode.FAST,
@@ -236,16 +263,23 @@ class KumoRFM:
236
263
  max_pq_iterations: int = 20,
237
264
  random_seed: Optional[int] = _RANDOM_SEED,
238
265
  verbose: Union[bool, ProgressLogger] = True,
266
+ use_prediction_time: bool = False,
239
267
  ) -> Union[pd.DataFrame, Explanation]:
240
268
  """Returns predictions for a predictive query.
241
269
 
242
270
  Args:
243
271
  query: The predictive query.
244
272
  indices: The entity primary keys to predict for. Will override the
245
- indices given as part of the predictive query.
246
- explain: If set to ``True``, will additionally explain the
247
- prediction. Explainability is currently only supported for
248
- single entity predictions with ``run_mode="FAST"``.
273
+ indices given as part of the predictive query. Predictions will
274
+ be generated for all indices, independent of whether they
275
+ fulfill entity filter constraints. To pre-filter entities, use
276
+ :meth:`~KumoRFM.is_valid_entity`.
277
+ explain: Configuration for explainability.
278
+ If set to ``True``, will additionally explain the prediction.
279
+ Passing in an :class:`ExplainConfig` instance provides control
280
+ over which parts of explanation are generated.
281
+ Explainability is currently only supported for single entity
282
+ predictions with ``run_mode="FAST"``.
249
283
  anchor_time: The anchor timestamp for the prediction. If set to
250
284
  ``None``, will use the maximum timestamp in the data.
251
285
  If set to ``"entity"``, will use the timestamp of the entity.
@@ -263,46 +297,54 @@ class KumoRFM:
263
297
  entities to find valid labels.
264
298
  random_seed: A manual seed for generating pseudo-random numbers.
265
299
  verbose: Whether to print verbose output.
300
+ use_prediction_time: Whether to use the anchor timestamp as an
301
+ additional feature during prediction. This is typically
302
+ beneficial for time series forecasting tasks.
266
303
 
267
304
  Returns:
268
305
  The predictions as a :class:`pandas.DataFrame`.
269
- If ``explain=True``, additionally returns a textual summary that
270
- explains the prediction.
306
+ If ``explain`` is provided, returns an :class:`Explanation` object
307
+ containing the prediction, summary, and details.
271
308
  """
309
+ explain_config: Optional[ExplainConfig] = None
310
+ if explain is True:
311
+ explain_config = ExplainConfig()
312
+ elif explain is not False:
313
+ explain_config = ExplainConfig._cast(explain)
314
+
272
315
  query_def = self._parse_query(query)
316
+ query_str = query_def.to_string()
273
317
 
274
318
  if num_hops != 2 and num_neighbors is not None:
275
319
  warnings.warn(f"Received custom 'num_neighbors' option; ignoring "
276
320
  f"custom 'num_hops={num_hops}' option")
277
321
 
278
- if explain and run_mode in {RunMode.NORMAL, RunMode.BEST}:
322
+ if explain_config is not None and run_mode in {
323
+ RunMode.NORMAL, RunMode.BEST
324
+ }:
279
325
  warnings.warn(f"Explainability is currently only supported for "
280
326
  f"run mode 'FAST' (got '{run_mode}'). Provided run "
281
327
  f"mode has been reset. Please lower the run mode to "
282
328
  f"suppress this warning.")
283
329
 
284
330
  if indices is None:
285
- if query_def.entity.ids is None:
331
+ if query_def.rfm_entity_ids is None:
286
332
  raise ValueError("Cannot find entities to predict for. Please "
287
333
  "pass them via `predict(query, indices=...)`")
288
- indices = query_def.entity.ids.value
334
+ indices = query_def.get_rfm_entity_id_list()
289
335
  else:
290
- query_def = replace(
291
- query_def,
292
- entity=replace(query_def.entity, ids=None),
293
- )
336
+ query_def = replace(query_def, rfm_entity_ids=None)
294
337
 
295
338
  if len(indices) == 0:
296
- raise ValueError("At least one entity is required for "
297
- "prediction")
339
+ raise ValueError("At least one entity is required")
298
340
 
299
- if explain and len(indices) > 1:
341
+ if explain_config is not None and len(indices) > 1:
300
342
  raise ValueError(
301
343
  f"Cannot explain predictions for more than a single entity "
302
344
  f"(got {len(indices)})")
303
345
 
304
346
  query_repr = query_def.to_string(rich=True, exclude_predict=True)
305
- if explain:
347
+ if explain_config is not None:
306
348
  msg = f'[bold]EXPLAIN[/bold] {query_repr}'
307
349
  else:
308
350
  msg = f'[bold]PREDICT[/bold] {query_repr}'
@@ -314,8 +356,8 @@ class KumoRFM:
314
356
 
315
357
  batch_size: Optional[int] = None
316
358
  if self._batch_size == 'max':
317
- task_type = query_def.get_task_type(
318
- stypes=self._graph_store.stype_dict,
359
+ task_type = LocalPQueryDriver.get_task_type(
360
+ query_def,
319
361
  edge_types=self._graph_store.edge_types,
320
362
  )
321
363
  batch_size = _MAX_PRED_SIZE[task_type]
@@ -332,10 +374,9 @@ class KumoRFM:
332
374
  logger.log(f"Splitting {len(indices):,} entities into "
333
375
  f"{len(batches):,} batches of size {batch_size:,}")
334
376
 
335
- resps: Union[
336
- List[RFMPredictResponse],
337
- List[RFMExplanationResponse],
338
- ] = []
377
+ predictions: List[pd.DataFrame] = []
378
+ summary: Optional[str] = None
379
+ details: Optional[Explanation] = None
339
380
  for i, batch in enumerate(batches):
340
381
  # TODO Re-use the context for subsequent predictions.
341
382
  context = self._get_context(
@@ -354,6 +395,8 @@ class KumoRFM:
354
395
  request = RFMPredictRequest(
355
396
  context=context,
356
397
  run_mode=RunMode(run_mode),
398
+ query=query_str,
399
+ use_prediction_time=use_prediction_time,
357
400
  )
358
401
  with warnings.catch_warnings():
359
402
  warnings.filterwarnings('ignore', message='gencode')
@@ -376,11 +419,36 @@ class KumoRFM:
376
419
 
377
420
  for attempt in range(self.num_retries + 1):
378
421
  try:
379
- if explain:
380
- resp = global_state.client.rfm_api.explain(_bytes)
422
+ if explain_config is not None:
423
+ resp = global_state.client.rfm_api.explain(
424
+ request=_bytes,
425
+ skip_summary=explain_config.skip_summary,
426
+ )
427
+ summary = resp.summary
428
+ details = resp.details
381
429
  else:
382
430
  resp = global_state.client.rfm_api.predict(_bytes)
383
- resps.append(resp)
431
+ df = pd.DataFrame(**resp.prediction)
432
+
433
+ # Cast 'ENTITY' to correct data type:
434
+ if 'ENTITY' in df:
435
+ entity = query_def.entity_table
436
+ pkey_map = self._graph_store.pkey_map_dict[entity]
437
+ df['ENTITY'] = df['ENTITY'].astype(
438
+ type(pkey_map.index[0]))
439
+
440
+ # Cast 'ANCHOR_TIMESTAMP' to correct data type:
441
+ if 'ANCHOR_TIMESTAMP' in df:
442
+ ser = df['ANCHOR_TIMESTAMP']
443
+ if not pd.api.types.is_datetime64_any_dtype(ser):
444
+ if isinstance(ser.iloc[0], str):
445
+ unit = None
446
+ else:
447
+ unit = 'ms'
448
+ df['ANCHOR_TIMESTAMP'] = pd.to_datetime(
449
+ ser, errors='coerce', unit=unit)
450
+
451
+ predictions.append(df)
384
452
 
385
453
  if (isinstance(verbose, InteractiveProgressLogger)
386
454
  and len(batches) > 1):
@@ -401,22 +469,73 @@ class KumoRFM:
401
469
 
402
470
  time.sleep(2**attempt) # 1s, 2s, 4s, 8s, ...
403
471
 
404
- predictions = [pd.DataFrame(**resp.prediction) for resp in resps]
405
472
  if len(predictions) == 1:
406
473
  prediction = predictions[0]
407
474
  else:
408
475
  prediction = pd.concat(predictions, ignore_index=True)
409
476
 
410
- if explain:
411
- assert len(resps) == 1
477
+ if explain_config is not None:
478
+ assert len(predictions) == 1
479
+ assert summary is not None
480
+ assert details is not None
412
481
  return Explanation(
413
482
  prediction=prediction,
414
- summary=resps[0].summary,
415
- details=resps[0].details,
483
+ summary=summary,
484
+ details=details,
416
485
  )
417
486
 
418
487
  return prediction
419
488
 
489
+ def is_valid_entity(
490
+ self,
491
+ query: str,
492
+ indices: Union[List[str], List[float], List[int], None] = None,
493
+ *,
494
+ anchor_time: Union[pd.Timestamp, Literal['entity'], None] = None,
495
+ ) -> np.ndarray:
496
+ r"""Returns a mask that denotes which entities are valid for the
497
+ given predictive query, *i.e.*, which entities fulfill (temporal)
498
+ entity filter constraints.
499
+
500
+ Args:
501
+ query: The predictive query.
502
+ indices: The entity primary keys to predict for. Will override the
503
+ indices given as part of the predictive query.
504
+ anchor_time: The anchor timestamp for the prediction. If set to
505
+ ``None``, will use the maximum timestamp in the data.
506
+ If set to ``"entity"``, will use the timestamp of the entity.
507
+ """
508
+ query_def = self._parse_query(query)
509
+
510
+ if indices is None:
511
+ if query_def.rfm_entity_ids is None:
512
+ raise ValueError("Cannot find entities to predict for. Please "
513
+ "pass them via "
514
+ "`is_valid_entity(query, indices=...)`")
515
+ indices = query_def.get_rfm_entity_id_list()
516
+
517
+ if len(indices) == 0:
518
+ raise ValueError("At least one entity is required")
519
+
520
+ if anchor_time is None:
521
+ anchor_time = self._graph_store.max_time
522
+
523
+ if isinstance(anchor_time, pd.Timestamp):
524
+ self._validate_time(query_def, anchor_time, None, False)
525
+ else:
526
+ assert anchor_time == 'entity'
527
+ if (query_def.entity_table not in self._graph_store.time_dict):
528
+ raise ValueError(f"Anchor time 'entity' requires the entity "
529
+ f"table '{query_def.entity_table}' "
530
+ f"to have a time column.")
531
+
532
+ node = self._graph_store.get_node_id(
533
+ table_name=query_def.entity_table,
534
+ pkey=pd.Series(indices),
535
+ )
536
+ query_driver = LocalPQueryDriver(self._graph_store, query_def)
537
+ return query_driver.is_valid(node, anchor_time)
538
+
420
539
  def evaluate(
421
540
  self,
422
541
  query: str,
@@ -430,6 +549,7 @@ class KumoRFM:
430
549
  max_pq_iterations: int = 20,
431
550
  random_seed: Optional[int] = _RANDOM_SEED,
432
551
  verbose: Union[bool, ProgressLogger] = True,
552
+ use_prediction_time: bool = False,
433
553
  ) -> pd.DataFrame:
434
554
  """Evaluates a predictive query.
435
555
 
@@ -453,6 +573,9 @@ class KumoRFM:
453
573
  entities to find valid labels.
454
574
  random_seed: A manual seed for generating pseudo-random numbers.
455
575
  verbose: Whether to print verbose output.
576
+ use_prediction_time: Whether to use the anchor timestamp as an
577
+ additional feature during prediction. This is typically
578
+ beneficial for time series forecasting tasks.
456
579
 
457
580
  Returns:
458
581
  The metrics as a :class:`pandas.DataFrame`
@@ -463,10 +586,10 @@ class KumoRFM:
463
586
  warnings.warn(f"Received custom 'num_neighbors' option; ignoring "
464
587
  f"custom 'num_hops={num_hops}' option")
465
588
 
466
- if query_def.entity.ids is not None:
589
+ if query_def.rfm_entity_ids is not None:
467
590
  query_def = replace(
468
591
  query_def,
469
- entity=replace(query_def.entity, ids=None),
592
+ rfm_entity_ids=None,
470
593
  )
471
594
 
472
595
  query_repr = query_def.to_string(rich=True, exclude_predict=True)
@@ -496,6 +619,7 @@ class KumoRFM:
496
619
  context=context,
497
620
  run_mode=RunMode(run_mode),
498
621
  metrics=metrics,
622
+ use_prediction_time=use_prediction_time,
499
623
  )
500
624
  with warnings.catch_warnings():
501
625
  warnings.filterwarnings('ignore', message='Protobuf gencode')
@@ -506,7 +630,7 @@ class KumoRFM:
506
630
 
507
631
  if len(request_bytes) > _MAX_SIZE:
508
632
  stats_msg = Context.get_memory_stats(request_msg.context)
509
- raise ValueError(_SIZE_LIMIT_MSG.format(stats_msg=stats_msg))
633
+ raise ValueError(_SIZE_LIMIT_MSG.format(stats=stats_msg))
510
634
 
511
635
  try:
512
636
  resp = global_state.client.rfm_api.evaluate(request_bytes)
@@ -554,18 +678,19 @@ class KumoRFM:
554
678
 
555
679
  if anchor_time is None:
556
680
  anchor_time = self._graph_store.max_time
557
- anchor_time = anchor_time - (query_def.target.end_offset *
558
- query_def.num_forecasts)
681
+ if query_def.target_ast.date_offset_range is not None:
682
+ anchor_time = anchor_time - (
683
+ query_def.target_ast.date_offset_range.end_date_offset *
684
+ query_def.num_forecasts)
559
685
 
560
686
  assert anchor_time is not None
561
687
  if isinstance(anchor_time, pd.Timestamp):
562
688
  self._validate_time(query_def, anchor_time, None, evaluate=True)
563
689
  else:
564
690
  assert anchor_time == 'entity'
565
- if (query_def.entity.pkey.table_name
566
- not in self._graph_store.time_dict):
691
+ if (query_def.entity_table not in self._graph_store.time_dict):
567
692
  raise ValueError(f"Anchor time 'entity' requires the entity "
568
- f"table '{query_def.entity.pkey.table_name}' "
693
+ f"table '{query_def.entity_table}' "
569
694
  f"to have a time column")
570
695
 
571
696
  query_driver = LocalPQueryDriver(self._graph_store, query_def,
@@ -580,7 +705,7 @@ class KumoRFM:
580
705
  )
581
706
 
582
707
  entity = self._graph_store.pkey_map_dict[
583
- query_def.entity.pkey.table_name].index[node]
708
+ query_def.entity_table].index[node]
584
709
 
585
710
  return pd.DataFrame({
586
711
  'ENTITY': entity,
@@ -590,8 +715,8 @@ class KumoRFM:
590
715
 
591
716
  # Helpers #################################################################
592
717
 
593
- def _parse_query(self, query: str) -> PQueryDefinition:
594
- if isinstance(query, PQueryDefinition):
718
+ def _parse_query(self, query: str) -> ValidatedPredictiveQuery:
719
+ if isinstance(query, ValidatedPredictiveQuery):
595
720
  return query
596
721
 
597
722
  if isinstance(query, str) and query.strip()[:9].lower() == 'evaluate ':
@@ -601,12 +726,12 @@ class KumoRFM:
601
726
  "predictions or evaluations.")
602
727
 
603
728
  try:
604
- request = RFMValidateQueryRequest(
729
+ request = RFMParseQueryRequest(
605
730
  query=query,
606
731
  graph_definition=self._graph_def,
607
732
  )
608
733
 
609
- resp = global_state.client.rfm_api.validate_query(request)
734
+ resp = global_state.client.rfm_api.parse_query(request)
610
735
  # TODO Expose validation warnings.
611
736
 
612
737
  if len(resp.validation_response.warnings) > 0:
@@ -617,7 +742,7 @@ class KumoRFM:
617
742
  warnings.warn(f"Encountered the following warnings during "
618
743
  f"parsing:\n{msg}")
619
744
 
620
- return resp.query_definition
745
+ return resp.query
621
746
  except HTTPException as e:
622
747
  try:
623
748
  msg = json.loads(e.detail)['detail']
@@ -628,7 +753,7 @@ class KumoRFM:
628
753
 
629
754
  def _validate_time(
630
755
  self,
631
- query: PQueryDefinition,
756
+ query: ValidatedPredictiveQuery,
632
757
  anchor_time: pd.Timestamp,
633
758
  context_anchor_time: Union[pd.Timestamp, None],
634
759
  evaluate: bool,
@@ -651,6 +776,11 @@ class KumoRFM:
651
776
  f"only contains data back to "
652
777
  f"'{self._graph_store.min_time}'.")
653
778
 
779
+ if query.target_ast.date_offset_range is not None:
780
+ end_offset = query.target_ast.date_offset_range.end_date_offset
781
+ else:
782
+ end_offset = pd.DateOffset(0)
783
+ forecast_end_offset = end_offset * query.num_forecasts
654
784
  if (context_anchor_time is not None
655
785
  and context_anchor_time > anchor_time):
656
786
  warnings.warn(f"Context anchor timestamp "
@@ -659,19 +789,18 @@ class KumoRFM:
659
789
  f"(got '{anchor_time}'). Please make sure this is "
660
790
  f"intended.")
661
791
  elif (query.query_type == QueryType.TEMPORAL
662
- and context_anchor_time is not None and context_anchor_time +
663
- query.target.end_offset * query.num_forecasts > anchor_time):
792
+ and context_anchor_time is not None
793
+ and context_anchor_time + forecast_end_offset > anchor_time):
664
794
  warnings.warn(f"Aggregation for context examples at timestamp "
665
795
  f"'{context_anchor_time}' will leak information "
666
796
  f"from the prediction anchor timestamp "
667
797
  f"'{anchor_time}'. Please make sure this is "
668
798
  f"intended.")
669
799
 
670
- elif (context_anchor_time is not None and context_anchor_time -
671
- query.target.end_offset * query.num_forecasts
800
+ elif (context_anchor_time is not None
801
+ and context_anchor_time - forecast_end_offset
672
802
  < self._graph_store.min_time):
673
- _time = context_anchor_time - (query.target.end_offset *
674
- query.num_forecasts)
803
+ _time = context_anchor_time - forecast_end_offset
675
804
  warnings.warn(f"Context anchor timestamp is too early or "
676
805
  f"aggregation time range is too large. To form "
677
806
  f"proper input data, we would need data back to "
@@ -684,8 +813,7 @@ class KumoRFM:
684
813
  f"latest timestamp '{self._graph_store.max_time}' "
685
814
  f"in the data. Please make sure this is intended.")
686
815
 
687
- max_eval_time = (self._graph_store.max_time -
688
- query.target.end_offset * query.num_forecasts)
816
+ max_eval_time = self._graph_store.max_time - forecast_end_offset
689
817
  if evaluate and anchor_time > max_eval_time:
690
818
  raise ValueError(
691
819
  f"Anchor timestamp for evaluation is after the latest "
@@ -693,7 +821,7 @@ class KumoRFM:
693
821
 
694
822
  def _get_context(
695
823
  self,
696
- query: PQueryDefinition,
824
+ query: ValidatedPredictiveQuery,
697
825
  indices: Union[List[str], List[float], List[int], None],
698
826
  anchor_time: Union[pd.Timestamp, Literal['entity'], None],
699
827
  context_anchor_time: Union[pd.Timestamp, None],
@@ -721,8 +849,8 @@ class KumoRFM:
721
849
  f"must go beyond this for your use-case.")
722
850
 
723
851
  query_driver = LocalPQueryDriver(self._graph_store, query, random_seed)
724
- task_type = query.get_task_type(
725
- stypes=self._graph_store.stype_dict,
852
+ task_type = LocalPQueryDriver.get_task_type(
853
+ query,
726
854
  edge_types=self._graph_store.edge_types,
727
855
  )
728
856
 
@@ -754,11 +882,15 @@ class KumoRFM:
754
882
  else:
755
883
  num_neighbors = [64, 64, 8, 8, 4, 4][:num_hops]
756
884
 
885
+ if query.target_ast.date_offset_range is None:
886
+ end_offset = pd.DateOffset(0)
887
+ else:
888
+ end_offset = query.target_ast.date_offset_range.end_date_offset
889
+ forecast_end_offset = end_offset * query.num_forecasts
757
890
  if anchor_time is None:
758
891
  anchor_time = self._graph_store.max_time
759
892
  if evaluate:
760
- anchor_time = anchor_time - (query.target.end_offset *
761
- query.num_forecasts)
893
+ anchor_time = anchor_time - forecast_end_offset
762
894
  if logger is not None:
763
895
  assert isinstance(anchor_time, pd.Timestamp)
764
896
  if anchor_time == pd.Timestamp.min:
@@ -773,15 +905,14 @@ class KumoRFM:
773
905
  assert anchor_time is not None
774
906
  if isinstance(anchor_time, pd.Timestamp):
775
907
  if context_anchor_time is None:
776
- context_anchor_time = anchor_time - (query.target.end_offset *
777
- query.num_forecasts)
908
+ context_anchor_time = anchor_time - forecast_end_offset
778
909
  self._validate_time(query, anchor_time, context_anchor_time,
779
910
  evaluate)
780
911
  else:
781
912
  assert anchor_time == 'entity'
782
- if query.entity.pkey.table_name not in self._graph_store.time_dict:
913
+ if query.entity_table not in self._graph_store.time_dict:
783
914
  raise ValueError(f"Anchor time 'entity' requires the entity "
784
- f"table '{query.entity.pkey.table_name}' to "
915
+ f"table '{query.entity_table}' to "
785
916
  f"have a time column")
786
917
  if context_anchor_time is not None:
787
918
  warnings.warn("Ignoring option 'context_anchor_time' for "
@@ -832,7 +963,7 @@ class KumoRFM:
832
963
  f"in batches")
833
964
 
834
965
  test_node = self._graph_store.get_node_id(
835
- table_name=query.entity.pkey.table_name,
966
+ table_name=query.entity_table,
836
967
  pkey=pd.Series(indices),
837
968
  )
838
969
 
@@ -840,8 +971,7 @@ class KumoRFM:
840
971
  test_time = pd.Series(anchor_time).repeat(
841
972
  len(test_node)).reset_index(drop=True)
842
973
  else:
843
- time = self._graph_store.time_dict[
844
- query.entity.pkey.table_name]
974
+ time = self._graph_store.time_dict[query.entity_table]
845
975
  time = time[test_node] * 1000**3
846
976
  test_time = pd.Series(time, dtype='datetime64[ns]')
847
977
 
@@ -874,12 +1004,23 @@ class KumoRFM:
874
1004
  raise NotImplementedError
875
1005
  logger.log(msg)
876
1006
 
877
- entity_table_names = query.get_entity_table_names(
878
- self._graph_store.edge_types)
1007
+ entity_table_names: Tuple[str, ...]
1008
+ if task_type.is_link_pred:
1009
+ final_aggr = query.get_final_target_aggregation()
1010
+ assert final_aggr is not None
1011
+ edge_fkey = final_aggr._get_target_column_name()
1012
+ for edge_type in self._graph_store.edge_types:
1013
+ if edge_fkey == f'{edge_type[0]}.{edge_type[1]}':
1014
+ entity_table_names = (
1015
+ query.entity_table,
1016
+ edge_type[2],
1017
+ )
1018
+ else:
1019
+ entity_table_names = (query.entity_table, )
879
1020
 
880
1021
  # Exclude the entity anchor time from the feature set to prevent
881
1022
  # running out-of-distribution between in-context and test examples:
882
- exclude_cols_dict = query.exclude_cols_dict
1023
+ exclude_cols_dict = query.get_exclude_cols_dict()
883
1024
  if anchor_time == 'entity':
884
1025
  if entity_table_names[0] not in exclude_cols_dict:
885
1026
  exclude_cols_dict[entity_table_names[0]] = []
@@ -908,7 +1049,7 @@ class KumoRFM:
908
1049
 
909
1050
  step_size: Optional[int] = None
910
1051
  if query.query_type == QueryType.TEMPORAL:
911
- step_size = date_offset_to_seconds(query.target.end_offset)
1052
+ step_size = date_offset_to_seconds(end_offset)
912
1053
 
913
1054
  return Context(
914
1055
  task_type=task_type,
@@ -933,7 +1074,7 @@ class KumoRFM:
933
1074
  elif task_type == TaskType.MULTICLASS_CLASSIFICATION:
934
1075
  supported_metrics = ['acc', 'precision', 'recall', 'f1', 'mrr']
935
1076
  elif task_type == TaskType.REGRESSION:
936
- supported_metrics = ['mae', 'mape', 'mse', 'rmse', 'smape']
1077
+ supported_metrics = ['mae', 'mape', 'mse', 'rmse', 'smape', 'r2']
937
1078
  elif task_type == TaskType.TEMPORAL_LINK_PREDICTION:
938
1079
  supported_metrics = [
939
1080
  'map@', 'ndcg@', 'mrr@', 'precision@', 'recall@', 'f1@',
kumoai/jobs.py CHANGED
@@ -26,6 +26,7 @@ class JobInterface(ABC, Generic[IDType, JobRequestType, JobResourceType]):
26
26
  limit (int): Max number of jobs to list, default 10.
27
27
 
28
28
  Example:
29
+ >>> # doctest: +SKIP
29
30
  >>> tags = {'pquery_name': 'my_pquery_name'}
30
31
  >>> jobs = BatchPredictionJob.search_by_tags(tags)
31
32
  Search limited to 10 results based on the `limit` parameter.
@@ -370,9 +370,11 @@ class PredictiveQuery:
370
370
  train_table_job_api = global_state.client.generate_train_table_job_api
371
371
  job_id: GenerateTrainTableJobID = train_table_job_api.create(
372
372
  GenerateTrainTableRequest(
373
- dict(custom_tags), pq_id, plan,
374
- graph_snapshot_id=self.graph.snapshot(
375
- non_blocking=non_blocking)))
373
+ dict(custom_tags),
374
+ pq_id,
375
+ plan,
376
+ None,
377
+ ))
376
378
 
377
379
  self._train_table = TrainingTableJob(job_id=job_id)
378
380
  if non_blocking:
@@ -451,9 +453,11 @@ class PredictiveQuery:
451
453
  bp_table_api = global_state.client.generate_prediction_table_job_api
452
454
  job_id: GeneratePredictionTableJobID = bp_table_api.create(
453
455
  GeneratePredictionTableRequest(
454
- dict(custom_tags), pq_id, plan,
455
- graph_snapshot_id=self.graph.snapshot(
456
- non_blocking=non_blocking)))
456
+ dict(custom_tags),
457
+ pq_id,
458
+ plan,
459
+ None,
460
+ ))
457
461
 
458
462
  self._prediction_table = PredictionTableJob(job_id=job_id)
459
463
  if non_blocking: