kumoai 2.11.0.dev202510161830__py3-none-any.whl → 2.13.0.dev202511211730__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -5,23 +5,32 @@ from collections import defaultdict
5
5
  from collections.abc import Generator
6
6
  from contextlib import contextmanager
7
7
  from dataclasses import dataclass, replace
8
- from typing import Iterator, List, Literal, Optional, Union, overload
8
+ from typing import (
9
+ Any,
10
+ Dict,
11
+ Iterator,
12
+ List,
13
+ Literal,
14
+ Optional,
15
+ Tuple,
16
+ Union,
17
+ overload,
18
+ )
9
19
 
10
20
  import numpy as np
11
21
  import pandas as pd
12
22
  from kumoapi.model_plan import RunMode
13
- from kumoapi.pquery import QueryType
23
+ from kumoapi.pquery import QueryType, ValidatedPredictiveQuery
14
24
  from kumoapi.rfm import Context
15
25
  from kumoapi.rfm import Explanation as ExplanationConfig
16
26
  from kumoapi.rfm import (
17
- PQueryDefinition,
18
27
  RFMEvaluateRequest,
28
+ RFMParseQueryRequest,
19
29
  RFMPredictRequest,
20
- RFMValidateQueryRequest,
21
30
  )
22
31
  from kumoapi.task import TaskType
23
32
 
24
- from kumoai import global_state
33
+ from kumoai.client.rfm import RFMAPI
25
34
  from kumoai.exceptions import HTTPException
26
35
  from kumoai.experimental.rfm import LocalGraph
27
36
  from kumoai.experimental.rfm.local_graph_sampler import LocalGraphSampler
@@ -30,6 +39,7 @@ from kumoai.experimental.rfm.local_pquery_driver import (
30
39
  LocalPQueryDriver,
31
40
  date_offset_to_seconds,
32
41
  )
42
+ from kumoai.mixin import CastMixin
33
43
  from kumoai.utils import InteractiveProgressLogger, ProgressLogger
34
44
 
35
45
  _RANDOM_SEED = 42
@@ -60,6 +70,17 @@ _SIZE_LIMIT_MSG = ("Context size exceeds the 30MB limit. {stats}\nPlease "
60
70
  "beyond this for your use-case.")
61
71
 
62
72
 
73
+ @dataclass(repr=False)
74
+ class ExplainConfig(CastMixin):
75
+ """Configuration for explainability.
76
+
77
+ Args:
78
+ skip_summary: Whether to skip generating a human-readable summary of
79
+ the explanation.
80
+ """
81
+ skip_summary: bool = False
82
+
83
+
63
84
  @dataclass(repr=False)
64
85
  class Explanation:
65
86
  prediction: pd.DataFrame
@@ -87,6 +108,12 @@ class Explanation:
87
108
  def __repr__(self) -> str:
88
109
  return str((self.prediction, self.summary))
89
110
 
111
+ def _ipython_display_(self) -> None:
112
+ from IPython.display import Markdown, display
113
+
114
+ display(self.prediction)
115
+ display(Markdown(self.summary))
116
+
90
117
 
91
118
  class KumoRFM:
92
119
  r"""The Kumo Relational Foundation model (RFM) from the `KumoRFM: A
@@ -114,9 +141,9 @@ class KumoRFM:
114
141
 
115
142
  rfm = KumoRFM(graph)
116
143
 
117
- query = ("PREDICT COUNT(transactions.*, 0, 30, days)>0 "
118
- "FOR users.user_id=0")
119
- result = rfm.query(query)
144
+ query = ("PREDICT COUNT(orders.*, 0, 30, days)>0 "
145
+ "FOR users.user_id=1")
146
+ result = rfm.predict(query)
120
147
 
121
148
  print(result) # user_id COUNT(transactions.*, 0, 30, days) > 0
122
149
  # 1 0.85
@@ -147,6 +174,8 @@ class KumoRFM:
147
174
 
148
175
  self._batch_size: Optional[int | Literal['max']] = None
149
176
  self.num_retries: int = 0
177
+ from kumoai.experimental.rfm import global_state
178
+ self._api_client = RFMAPI(global_state.client)
150
179
 
151
180
  def __repr__(self) -> str:
152
181
  return f'{self.__class__.__name__}()'
@@ -199,6 +228,7 @@ class KumoRFM:
199
228
  max_pq_iterations: int = 20,
200
229
  random_seed: Optional[int] = _RANDOM_SEED,
201
230
  verbose: Union[bool, ProgressLogger] = True,
231
+ use_prediction_time: bool = False,
202
232
  ) -> pd.DataFrame:
203
233
  pass
204
234
 
@@ -208,7 +238,7 @@ class KumoRFM:
208
238
  query: str,
209
239
  indices: Union[List[str], List[float], List[int], None] = None,
210
240
  *,
211
- explain: Literal[True],
241
+ explain: Union[Literal[True], ExplainConfig, Dict[str, Any]],
212
242
  anchor_time: Union[pd.Timestamp, Literal['entity'], None] = None,
213
243
  context_anchor_time: Union[pd.Timestamp, None] = None,
214
244
  run_mode: Union[RunMode, str] = RunMode.FAST,
@@ -217,6 +247,7 @@ class KumoRFM:
217
247
  max_pq_iterations: int = 20,
218
248
  random_seed: Optional[int] = _RANDOM_SEED,
219
249
  verbose: Union[bool, ProgressLogger] = True,
250
+ use_prediction_time: bool = False,
220
251
  ) -> Explanation:
221
252
  pass
222
253
 
@@ -225,7 +256,7 @@ class KumoRFM:
225
256
  query: str,
226
257
  indices: Union[List[str], List[float], List[int], None] = None,
227
258
  *,
228
- explain: bool = False,
259
+ explain: Union[bool, ExplainConfig, Dict[str, Any]] = False,
229
260
  anchor_time: Union[pd.Timestamp, Literal['entity'], None] = None,
230
261
  context_anchor_time: Union[pd.Timestamp, None] = None,
231
262
  run_mode: Union[RunMode, str] = RunMode.FAST,
@@ -234,6 +265,7 @@ class KumoRFM:
234
265
  max_pq_iterations: int = 20,
235
266
  random_seed: Optional[int] = _RANDOM_SEED,
236
267
  verbose: Union[bool, ProgressLogger] = True,
268
+ use_prediction_time: bool = False,
237
269
  ) -> Union[pd.DataFrame, Explanation]:
238
270
  """Returns predictions for a predictive query.
239
271
 
@@ -244,9 +276,12 @@ class KumoRFM:
244
276
  be generated for all indices, independent of whether they
245
277
  fulfill entity filter constraints. To pre-filter entities, use
246
278
  :meth:`~KumoRFM.is_valid_entity`.
247
- explain: If set to ``True``, will additionally explain the
248
- prediction. Explainability is currently only supported for
249
- single entity predictions with ``run_mode="FAST"``.
279
+ explain: Configuration for explainability.
280
+ If set to ``True``, will additionally explain the prediction.
281
+ Passing in an :class:`ExplainConfig` instance provides control
282
+ over which parts of explanation are generated.
283
+ Explainability is currently only supported for single entity
284
+ predictions with ``run_mode="FAST"``.
250
285
  anchor_time: The anchor timestamp for the prediction. If set to
251
286
  ``None``, will use the maximum timestamp in the data.
252
287
  If set to ``"entity"``, will use the timestamp of the entity.
@@ -264,45 +299,54 @@ class KumoRFM:
264
299
  entities to find valid labels.
265
300
  random_seed: A manual seed for generating pseudo-random numbers.
266
301
  verbose: Whether to print verbose output.
302
+ use_prediction_time: Whether to use the anchor timestamp as an
303
+ additional feature during prediction. This is typically
304
+ beneficial for time series forecasting tasks.
267
305
 
268
306
  Returns:
269
307
  The predictions as a :class:`pandas.DataFrame`.
270
- If ``explain=True``, additionally returns a textual summary that
271
- explains the prediction.
308
+ If ``explain`` is provided, returns an :class:`Explanation` object
309
+ containing the prediction, summary, and details.
272
310
  """
311
+ explain_config: Optional[ExplainConfig] = None
312
+ if explain is True:
313
+ explain_config = ExplainConfig()
314
+ elif explain is not False:
315
+ explain_config = ExplainConfig._cast(explain)
316
+
273
317
  query_def = self._parse_query(query)
318
+ query_str = query_def.to_string()
274
319
 
275
320
  if num_hops != 2 and num_neighbors is not None:
276
321
  warnings.warn(f"Received custom 'num_neighbors' option; ignoring "
277
322
  f"custom 'num_hops={num_hops}' option")
278
323
 
279
- if explain and run_mode in {RunMode.NORMAL, RunMode.BEST}:
324
+ if explain_config is not None and run_mode in {
325
+ RunMode.NORMAL, RunMode.BEST
326
+ }:
280
327
  warnings.warn(f"Explainability is currently only supported for "
281
328
  f"run mode 'FAST' (got '{run_mode}'). Provided run "
282
329
  f"mode has been reset. Please lower the run mode to "
283
330
  f"suppress this warning.")
284
331
 
285
332
  if indices is None:
286
- if query_def.entity.ids is None:
333
+ if query_def.rfm_entity_ids is None:
287
334
  raise ValueError("Cannot find entities to predict for. Please "
288
335
  "pass them via `predict(query, indices=...)`")
289
- indices = query_def.entity.ids.value
336
+ indices = query_def.get_rfm_entity_id_list()
290
337
  else:
291
- query_def = replace(
292
- query_def,
293
- entity=replace(query_def.entity, ids=None),
294
- )
338
+ query_def = replace(query_def, rfm_entity_ids=None)
295
339
 
296
340
  if len(indices) == 0:
297
341
  raise ValueError("At least one entity is required")
298
342
 
299
- if explain and len(indices) > 1:
343
+ if explain_config is not None and len(indices) > 1:
300
344
  raise ValueError(
301
345
  f"Cannot explain predictions for more than a single entity "
302
346
  f"(got {len(indices)})")
303
347
 
304
348
  query_repr = query_def.to_string(rich=True, exclude_predict=True)
305
- if explain:
349
+ if explain_config is not None:
306
350
  msg = f'[bold]EXPLAIN[/bold] {query_repr}'
307
351
  else:
308
352
  msg = f'[bold]PREDICT[/bold] {query_repr}'
@@ -314,8 +358,8 @@ class KumoRFM:
314
358
 
315
359
  batch_size: Optional[int] = None
316
360
  if self._batch_size == 'max':
317
- task_type = query_def.get_task_type(
318
- stypes=self._graph_store.stype_dict,
361
+ task_type = LocalPQueryDriver.get_task_type(
362
+ query_def,
319
363
  edge_types=self._graph_store.edge_types,
320
364
  )
321
365
  batch_size = _MAX_PRED_SIZE[task_type]
@@ -353,6 +397,8 @@ class KumoRFM:
353
397
  request = RFMPredictRequest(
354
398
  context=context,
355
399
  run_mode=RunMode(run_mode),
400
+ query=query_str,
401
+ use_prediction_time=use_prediction_time,
356
402
  )
357
403
  with warnings.catch_warnings():
358
404
  warnings.filterwarnings('ignore', message='gencode')
@@ -375,17 +421,20 @@ class KumoRFM:
375
421
 
376
422
  for attempt in range(self.num_retries + 1):
377
423
  try:
378
- if explain:
379
- resp = global_state.client.rfm_api.explain(_bytes)
424
+ if explain_config is not None:
425
+ resp = self._api_client.explain(
426
+ request=_bytes,
427
+ skip_summary=explain_config.skip_summary,
428
+ )
380
429
  summary = resp.summary
381
430
  details = resp.details
382
431
  else:
383
- resp = global_state.client.rfm_api.predict(_bytes)
432
+ resp = self._api_client.predict(_bytes)
384
433
  df = pd.DataFrame(**resp.prediction)
385
434
 
386
435
  # Cast 'ENTITY' to correct data type:
387
436
  if 'ENTITY' in df:
388
- entity = query_def.entity.pkey.table_name
437
+ entity = query_def.entity_table
389
438
  pkey_map = self._graph_store.pkey_map_dict[entity]
390
439
  df['ENTITY'] = df['ENTITY'].astype(
391
440
  type(pkey_map.index[0]))
@@ -427,7 +476,7 @@ class KumoRFM:
427
476
  else:
428
477
  prediction = pd.concat(predictions, ignore_index=True)
429
478
 
430
- if explain:
479
+ if explain_config is not None:
431
480
  assert len(predictions) == 1
432
481
  assert summary is not None
433
482
  assert details is not None
@@ -461,11 +510,11 @@ class KumoRFM:
461
510
  query_def = self._parse_query(query)
462
511
 
463
512
  if indices is None:
464
- if query_def.entity.ids is None:
513
+ if query_def.rfm_entity_ids is None:
465
514
  raise ValueError("Cannot find entities to predict for. Please "
466
515
  "pass them via "
467
516
  "`is_valid_entity(query, indices=...)`")
468
- indices = query_def.entity.ids.value
517
+ indices = query_def.get_rfm_entity_id_list()
469
518
 
470
519
  if len(indices) == 0:
471
520
  raise ValueError("At least one entity is required")
@@ -477,14 +526,13 @@ class KumoRFM:
477
526
  self._validate_time(query_def, anchor_time, None, False)
478
527
  else:
479
528
  assert anchor_time == 'entity'
480
- if (query_def.entity.pkey.table_name
481
- not in self._graph_store.time_dict):
529
+ if (query_def.entity_table not in self._graph_store.time_dict):
482
530
  raise ValueError(f"Anchor time 'entity' requires the entity "
483
- f"table '{query_def.entity.pkey.table_name}' "
484
- f"to have a time column")
531
+ f"table '{query_def.entity_table}' "
532
+ f"to have a time column.")
485
533
 
486
534
  node = self._graph_store.get_node_id(
487
- table_name=query_def.entity.pkey.table_name,
535
+ table_name=query_def.entity_table,
488
536
  pkey=pd.Series(indices),
489
537
  )
490
538
  query_driver = LocalPQueryDriver(self._graph_store, query_def)
@@ -503,6 +551,7 @@ class KumoRFM:
503
551
  max_pq_iterations: int = 20,
504
552
  random_seed: Optional[int] = _RANDOM_SEED,
505
553
  verbose: Union[bool, ProgressLogger] = True,
554
+ use_prediction_time: bool = False,
506
555
  ) -> pd.DataFrame:
507
556
  """Evaluates a predictive query.
508
557
 
@@ -526,6 +575,9 @@ class KumoRFM:
526
575
  entities to find valid labels.
527
576
  random_seed: A manual seed for generating pseudo-random numbers.
528
577
  verbose: Whether to print verbose output.
578
+ use_prediction_time: Whether to use the anchor timestamp as an
579
+ additional feature during prediction. This is typically
580
+ beneficial for time series forecasting tasks.
529
581
 
530
582
  Returns:
531
583
  The metrics as a :class:`pandas.DataFrame`
@@ -536,10 +588,10 @@ class KumoRFM:
536
588
  warnings.warn(f"Received custom 'num_neighbors' option; ignoring "
537
589
  f"custom 'num_hops={num_hops}' option")
538
590
 
539
- if query_def.entity.ids is not None:
591
+ if query_def.rfm_entity_ids is not None:
540
592
  query_def = replace(
541
593
  query_def,
542
- entity=replace(query_def.entity, ids=None),
594
+ rfm_entity_ids=None,
543
595
  )
544
596
 
545
597
  query_repr = query_def.to_string(rich=True, exclude_predict=True)
@@ -569,6 +621,7 @@ class KumoRFM:
569
621
  context=context,
570
622
  run_mode=RunMode(run_mode),
571
623
  metrics=metrics,
624
+ use_prediction_time=use_prediction_time,
572
625
  )
573
626
  with warnings.catch_warnings():
574
627
  warnings.filterwarnings('ignore', message='Protobuf gencode')
@@ -579,10 +632,10 @@ class KumoRFM:
579
632
 
580
633
  if len(request_bytes) > _MAX_SIZE:
581
634
  stats_msg = Context.get_memory_stats(request_msg.context)
582
- raise ValueError(_SIZE_LIMIT_MSG.format(stats_msg=stats_msg))
635
+ raise ValueError(_SIZE_LIMIT_MSG.format(stats=stats_msg))
583
636
 
584
637
  try:
585
- resp = global_state.client.rfm_api.evaluate(request_bytes)
638
+ resp = self._api_client.evaluate(request_bytes)
586
639
  except HTTPException as e:
587
640
  try:
588
641
  msg = json.loads(e.detail)['detail']
@@ -627,18 +680,19 @@ class KumoRFM:
627
680
 
628
681
  if anchor_time is None:
629
682
  anchor_time = self._graph_store.max_time
630
- anchor_time = anchor_time - (query_def.target.end_offset *
631
- query_def.num_forecasts)
683
+ if query_def.target_ast.date_offset_range is not None:
684
+ anchor_time = anchor_time - (
685
+ query_def.target_ast.date_offset_range.end_date_offset *
686
+ query_def.num_forecasts)
632
687
 
633
688
  assert anchor_time is not None
634
689
  if isinstance(anchor_time, pd.Timestamp):
635
690
  self._validate_time(query_def, anchor_time, None, evaluate=True)
636
691
  else:
637
692
  assert anchor_time == 'entity'
638
- if (query_def.entity.pkey.table_name
639
- not in self._graph_store.time_dict):
693
+ if (query_def.entity_table not in self._graph_store.time_dict):
640
694
  raise ValueError(f"Anchor time 'entity' requires the entity "
641
- f"table '{query_def.entity.pkey.table_name}' "
695
+ f"table '{query_def.entity_table}' "
642
696
  f"to have a time column")
643
697
 
644
698
  query_driver = LocalPQueryDriver(self._graph_store, query_def,
@@ -653,7 +707,7 @@ class KumoRFM:
653
707
  )
654
708
 
655
709
  entity = self._graph_store.pkey_map_dict[
656
- query_def.entity.pkey.table_name].index[node]
710
+ query_def.entity_table].index[node]
657
711
 
658
712
  return pd.DataFrame({
659
713
  'ENTITY': entity,
@@ -663,8 +717,8 @@ class KumoRFM:
663
717
 
664
718
  # Helpers #################################################################
665
719
 
666
- def _parse_query(self, query: str) -> PQueryDefinition:
667
- if isinstance(query, PQueryDefinition):
720
+ def _parse_query(self, query: str) -> ValidatedPredictiveQuery:
721
+ if isinstance(query, ValidatedPredictiveQuery):
668
722
  return query
669
723
 
670
724
  if isinstance(query, str) and query.strip()[:9].lower() == 'evaluate ':
@@ -674,12 +728,13 @@ class KumoRFM:
674
728
  "predictions or evaluations.")
675
729
 
676
730
  try:
677
- request = RFMValidateQueryRequest(
731
+ request = RFMParseQueryRequest(
678
732
  query=query,
679
733
  graph_definition=self._graph_def,
680
734
  )
681
735
 
682
- resp = global_state.client.rfm_api.validate_query(request)
736
+ resp = self._api_client.parse_query(request)
737
+
683
738
  # TODO Expose validation warnings.
684
739
 
685
740
  if len(resp.validation_response.warnings) > 0:
@@ -690,7 +745,7 @@ class KumoRFM:
690
745
  warnings.warn(f"Encountered the following warnings during "
691
746
  f"parsing:\n{msg}")
692
747
 
693
- return resp.query_definition
748
+ return resp.query
694
749
  except HTTPException as e:
695
750
  try:
696
751
  msg = json.loads(e.detail)['detail']
@@ -701,7 +756,7 @@ class KumoRFM:
701
756
 
702
757
  def _validate_time(
703
758
  self,
704
- query: PQueryDefinition,
759
+ query: ValidatedPredictiveQuery,
705
760
  anchor_time: pd.Timestamp,
706
761
  context_anchor_time: Union[pd.Timestamp, None],
707
762
  evaluate: bool,
@@ -724,6 +779,11 @@ class KumoRFM:
724
779
  f"only contains data back to "
725
780
  f"'{self._graph_store.min_time}'.")
726
781
 
782
+ if query.target_ast.date_offset_range is not None:
783
+ end_offset = query.target_ast.date_offset_range.end_date_offset
784
+ else:
785
+ end_offset = pd.DateOffset(0)
786
+ forecast_end_offset = end_offset * query.num_forecasts
727
787
  if (context_anchor_time is not None
728
788
  and context_anchor_time > anchor_time):
729
789
  warnings.warn(f"Context anchor timestamp "
@@ -732,19 +792,18 @@ class KumoRFM:
732
792
  f"(got '{anchor_time}'). Please make sure this is "
733
793
  f"intended.")
734
794
  elif (query.query_type == QueryType.TEMPORAL
735
- and context_anchor_time is not None and context_anchor_time +
736
- query.target.end_offset * query.num_forecasts > anchor_time):
795
+ and context_anchor_time is not None
796
+ and context_anchor_time + forecast_end_offset > anchor_time):
737
797
  warnings.warn(f"Aggregation for context examples at timestamp "
738
798
  f"'{context_anchor_time}' will leak information "
739
799
  f"from the prediction anchor timestamp "
740
800
  f"'{anchor_time}'. Please make sure this is "
741
801
  f"intended.")
742
802
 
743
- elif (context_anchor_time is not None and context_anchor_time -
744
- query.target.end_offset * query.num_forecasts
803
+ elif (context_anchor_time is not None
804
+ and context_anchor_time - forecast_end_offset
745
805
  < self._graph_store.min_time):
746
- _time = context_anchor_time - (query.target.end_offset *
747
- query.num_forecasts)
806
+ _time = context_anchor_time - forecast_end_offset
748
807
  warnings.warn(f"Context anchor timestamp is too early or "
749
808
  f"aggregation time range is too large. To form "
750
809
  f"proper input data, we would need data back to "
@@ -757,8 +816,7 @@ class KumoRFM:
757
816
  f"latest timestamp '{self._graph_store.max_time}' "
758
817
  f"in the data. Please make sure this is intended.")
759
818
 
760
- max_eval_time = (self._graph_store.max_time -
761
- query.target.end_offset * query.num_forecasts)
819
+ max_eval_time = self._graph_store.max_time - forecast_end_offset
762
820
  if evaluate and anchor_time > max_eval_time:
763
821
  raise ValueError(
764
822
  f"Anchor timestamp for evaluation is after the latest "
@@ -766,7 +824,7 @@ class KumoRFM:
766
824
 
767
825
  def _get_context(
768
826
  self,
769
- query: PQueryDefinition,
827
+ query: ValidatedPredictiveQuery,
770
828
  indices: Union[List[str], List[float], List[int], None],
771
829
  anchor_time: Union[pd.Timestamp, Literal['entity'], None],
772
830
  context_anchor_time: Union[pd.Timestamp, None],
@@ -794,8 +852,8 @@ class KumoRFM:
794
852
  f"must go beyond this for your use-case.")
795
853
 
796
854
  query_driver = LocalPQueryDriver(self._graph_store, query, random_seed)
797
- task_type = query.get_task_type(
798
- stypes=self._graph_store.stype_dict,
855
+ task_type = LocalPQueryDriver.get_task_type(
856
+ query,
799
857
  edge_types=self._graph_store.edge_types,
800
858
  )
801
859
 
@@ -827,11 +885,15 @@ class KumoRFM:
827
885
  else:
828
886
  num_neighbors = [64, 64, 8, 8, 4, 4][:num_hops]
829
887
 
888
+ if query.target_ast.date_offset_range is None:
889
+ end_offset = pd.DateOffset(0)
890
+ else:
891
+ end_offset = query.target_ast.date_offset_range.end_date_offset
892
+ forecast_end_offset = end_offset * query.num_forecasts
830
893
  if anchor_time is None:
831
894
  anchor_time = self._graph_store.max_time
832
895
  if evaluate:
833
- anchor_time = anchor_time - (query.target.end_offset *
834
- query.num_forecasts)
896
+ anchor_time = anchor_time - forecast_end_offset
835
897
  if logger is not None:
836
898
  assert isinstance(anchor_time, pd.Timestamp)
837
899
  if anchor_time == pd.Timestamp.min:
@@ -846,15 +908,14 @@ class KumoRFM:
846
908
  assert anchor_time is not None
847
909
  if isinstance(anchor_time, pd.Timestamp):
848
910
  if context_anchor_time is None:
849
- context_anchor_time = anchor_time - (query.target.end_offset *
850
- query.num_forecasts)
911
+ context_anchor_time = anchor_time - forecast_end_offset
851
912
  self._validate_time(query, anchor_time, context_anchor_time,
852
913
  evaluate)
853
914
  else:
854
915
  assert anchor_time == 'entity'
855
- if query.entity.pkey.table_name not in self._graph_store.time_dict:
916
+ if query.entity_table not in self._graph_store.time_dict:
856
917
  raise ValueError(f"Anchor time 'entity' requires the entity "
857
- f"table '{query.entity.pkey.table_name}' to "
918
+ f"table '{query.entity_table}' to "
858
919
  f"have a time column")
859
920
  if context_anchor_time is not None:
860
921
  warnings.warn("Ignoring option 'context_anchor_time' for "
@@ -905,7 +966,7 @@ class KumoRFM:
905
966
  f"in batches")
906
967
 
907
968
  test_node = self._graph_store.get_node_id(
908
- table_name=query.entity.pkey.table_name,
969
+ table_name=query.entity_table,
909
970
  pkey=pd.Series(indices),
910
971
  )
911
972
 
@@ -913,8 +974,7 @@ class KumoRFM:
913
974
  test_time = pd.Series(anchor_time).repeat(
914
975
  len(test_node)).reset_index(drop=True)
915
976
  else:
916
- time = self._graph_store.time_dict[
917
- query.entity.pkey.table_name]
977
+ time = self._graph_store.time_dict[query.entity_table]
918
978
  time = time[test_node] * 1000**3
919
979
  test_time = pd.Series(time, dtype='datetime64[ns]')
920
980
 
@@ -947,12 +1007,23 @@ class KumoRFM:
947
1007
  raise NotImplementedError
948
1008
  logger.log(msg)
949
1009
 
950
- entity_table_names = query.get_entity_table_names(
951
- self._graph_store.edge_types)
1010
+ entity_table_names: Tuple[str, ...]
1011
+ if task_type.is_link_pred:
1012
+ final_aggr = query.get_final_target_aggregation()
1013
+ assert final_aggr is not None
1014
+ edge_fkey = final_aggr._get_target_column_name()
1015
+ for edge_type in self._graph_store.edge_types:
1016
+ if edge_fkey == f'{edge_type[0]}.{edge_type[1]}':
1017
+ entity_table_names = (
1018
+ query.entity_table,
1019
+ edge_type[2],
1020
+ )
1021
+ else:
1022
+ entity_table_names = (query.entity_table, )
952
1023
 
953
1024
  # Exclude the entity anchor time from the feature set to prevent
954
1025
  # running out-of-distribution between in-context and test examples:
955
- exclude_cols_dict = query.exclude_cols_dict
1026
+ exclude_cols_dict = query.get_exclude_cols_dict()
956
1027
  if anchor_time == 'entity':
957
1028
  if entity_table_names[0] not in exclude_cols_dict:
958
1029
  exclude_cols_dict[entity_table_names[0]] = []
@@ -967,7 +1038,6 @@ class KumoRFM:
967
1038
  train_time.astype('datetime64[ns]').astype(int).to_numpy(),
968
1039
  test_time.astype('datetime64[ns]').astype(int).to_numpy(),
969
1040
  ]),
970
- run_mode=run_mode,
971
1041
  num_neighbors=num_neighbors,
972
1042
  exclude_cols_dict=exclude_cols_dict,
973
1043
  )
@@ -981,7 +1051,7 @@ class KumoRFM:
981
1051
 
982
1052
  step_size: Optional[int] = None
983
1053
  if query.query_type == QueryType.TEMPORAL:
984
- step_size = date_offset_to_seconds(query.target.end_offset)
1054
+ step_size = date_offset_to_seconds(end_offset)
985
1055
 
986
1056
  return Context(
987
1057
  task_type=task_type,