kumoai 2.12.0.dev202511031731__cp313-cp313-macosx_11_0_arm64.whl → 2.13.0.dev202512061731__cp313-cp313-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (43)
  1. kumoai/__init__.py +18 -9
  2. kumoai/_version.py +1 -1
  3. kumoai/client/client.py +9 -13
  4. kumoai/client/endpoints.py +1 -0
  5. kumoai/client/rfm.py +35 -7
  6. kumoai/connector/utils.py +23 -2
  7. kumoai/experimental/rfm/__init__.py +164 -46
  8. kumoai/experimental/rfm/backend/__init__.py +0 -0
  9. kumoai/experimental/rfm/backend/local/__init__.py +42 -0
  10. kumoai/experimental/rfm/{local_graph_store.py → backend/local/graph_store.py} +20 -30
  11. kumoai/experimental/rfm/backend/local/sampler.py +131 -0
  12. kumoai/experimental/rfm/backend/local/table.py +109 -0
  13. kumoai/experimental/rfm/backend/snow/__init__.py +35 -0
  14. kumoai/experimental/rfm/backend/snow/table.py +117 -0
  15. kumoai/experimental/rfm/backend/sqlite/__init__.py +30 -0
  16. kumoai/experimental/rfm/backend/sqlite/table.py +101 -0
  17. kumoai/experimental/rfm/base/__init__.py +14 -0
  18. kumoai/experimental/rfm/base/column.py +66 -0
  19. kumoai/experimental/rfm/base/sampler.py +287 -0
  20. kumoai/experimental/rfm/base/source.py +18 -0
  21. kumoai/experimental/rfm/{local_table.py → base/table.py} +139 -139
  22. kumoai/experimental/rfm/{local_graph.py → graph.py} +334 -79
  23. kumoai/experimental/rfm/infer/__init__.py +6 -0
  24. kumoai/experimental/rfm/infer/dtype.py +79 -0
  25. kumoai/experimental/rfm/infer/pkey.py +126 -0
  26. kumoai/experimental/rfm/infer/time_col.py +62 -0
  27. kumoai/experimental/rfm/local_graph_sampler.py +43 -4
  28. kumoai/experimental/rfm/local_pquery_driver.py +222 -27
  29. kumoai/experimental/rfm/pquery/__init__.py +0 -4
  30. kumoai/experimental/rfm/pquery/pandas_executor.py +34 -8
  31. kumoai/experimental/rfm/rfm.py +153 -96
  32. kumoai/experimental/rfm/sagemaker.py +138 -0
  33. kumoai/spcs.py +1 -3
  34. kumoai/testing/decorators.py +1 -1
  35. kumoai/utils/progress_logger.py +10 -4
  36. {kumoai-2.12.0.dev202511031731.dist-info → kumoai-2.13.0.dev202512061731.dist-info}/METADATA +12 -2
  37. {kumoai-2.12.0.dev202511031731.dist-info → kumoai-2.13.0.dev202512061731.dist-info}/RECORD +40 -27
  38. kumoai/experimental/rfm/pquery/backend.py +0 -136
  39. kumoai/experimental/rfm/pquery/pandas_backend.py +0 -478
  40. kumoai/experimental/rfm/utils.py +0 -344
  41. {kumoai-2.12.0.dev202511031731.dist-info → kumoai-2.13.0.dev202512061731.dist-info}/WHEEL +0 -0
  42. {kumoai-2.12.0.dev202511031731.dist-info → kumoai-2.13.0.dev202512061731.dist-info}/licenses/LICENSE +0 -0
  43. {kumoai-2.12.0.dev202511031731.dist-info → kumoai-2.13.0.dev202512061731.dist-info}/top_level.txt +0 -0
@@ -5,31 +5,41 @@ from collections import defaultdict
5
5
  from collections.abc import Generator
6
6
  from contextlib import contextmanager
7
7
  from dataclasses import dataclass, replace
8
- from typing import Iterator, List, Literal, Optional, Union, overload
8
+ from typing import (
9
+ Any,
10
+ Dict,
11
+ Iterator,
12
+ List,
13
+ Literal,
14
+ Optional,
15
+ Tuple,
16
+ Union,
17
+ overload,
18
+ )
9
19
 
10
20
  import numpy as np
11
21
  import pandas as pd
12
22
  from kumoapi.model_plan import RunMode
13
- from kumoapi.pquery import QueryType
23
+ from kumoapi.pquery import QueryType, ValidatedPredictiveQuery
14
24
  from kumoapi.rfm import Context
15
25
  from kumoapi.rfm import Explanation as ExplanationConfig
16
26
  from kumoapi.rfm import (
17
- PQueryDefinition,
18
27
  RFMEvaluateRequest,
28
+ RFMParseQueryRequest,
19
29
  RFMPredictRequest,
20
- RFMValidateQueryRequest,
21
30
  )
22
31
  from kumoapi.task import TaskType
23
32
 
24
- from kumoai import global_state
33
+ from kumoai.client.rfm import RFMAPI
25
34
  from kumoai.exceptions import HTTPException
26
- from kumoai.experimental.rfm import LocalGraph
35
+ from kumoai.experimental.rfm import Graph
36
+ from kumoai.experimental.rfm.backend.local import LocalGraphStore
27
37
  from kumoai.experimental.rfm.local_graph_sampler import LocalGraphSampler
28
- from kumoai.experimental.rfm.local_graph_store import LocalGraphStore
29
38
  from kumoai.experimental.rfm.local_pquery_driver import (
30
39
  LocalPQueryDriver,
31
40
  date_offset_to_seconds,
32
41
  )
42
+ from kumoai.mixin import CastMixin
33
43
  from kumoai.utils import InteractiveProgressLogger, ProgressLogger
34
44
 
35
45
  _RANDOM_SEED = 42
@@ -60,6 +70,17 @@ _SIZE_LIMIT_MSG = ("Context size exceeds the 30MB limit. {stats}\nPlease "
60
70
  "beyond this for your use-case.")
61
71
 
62
72
 
73
+ @dataclass(repr=False)
74
+ class ExplainConfig(CastMixin):
75
+ """Configuration for explainability.
76
+
77
+ Args:
78
+ skip_summary: Whether to skip generating a human-readable summary of
79
+ the explanation.
80
+ """
81
+ skip_summary: bool = False
82
+
83
+
63
84
  @dataclass(repr=False)
64
85
  class Explanation:
65
86
  prediction: pd.DataFrame
@@ -87,6 +108,12 @@ class Explanation:
87
108
  def __repr__(self) -> str:
88
109
  return str((self.prediction, self.summary))
89
110
 
111
+ def _ipython_display_(self) -> None:
112
+ from IPython.display import Markdown, display
113
+
114
+ display(self.prediction)
115
+ display(Markdown(self.summary))
116
+
90
117
 
91
118
  class KumoRFM:
92
119
  r"""The Kumo Relational Foundation model (RFM) from the `KumoRFM: A
@@ -96,17 +123,17 @@ class KumoRFM:
96
123
  :class:`KumoRFM` is a foundation model to generate predictions for any
97
124
  relational dataset without training.
98
125
  The model is pre-trained and the class provides an interface to query the
99
- model from a :class:`LocalGraph` object.
126
+ model from a :class:`Graph` object.
100
127
 
101
128
  .. code-block:: python
102
129
 
103
- from kumoai.experimental.rfm import LocalGraph, KumoRFM
130
+ from kumoai.experimental.rfm import Graph, KumoRFM
104
131
 
105
132
  df_users = pd.DataFrame(...)
106
133
  df_items = pd.DataFrame(...)
107
134
  df_orders = pd.DataFrame(...)
108
135
 
109
- graph = LocalGraph.from_data({
136
+ graph = Graph.from_data({
110
137
  'users': df_users,
111
138
  'items': df_items,
112
139
  'orders': df_orders,
@@ -114,40 +141,41 @@ class KumoRFM:
114
141
 
115
142
  rfm = KumoRFM(graph)
116
143
 
117
- query = ("PREDICT COUNT(transactions.*, 0, 30, days)>0 "
118
- "FOR users.user_id=0")
119
- result = rfm.query(query)
144
+ query = ("PREDICT COUNT(orders.*, 0, 30, days)>0 "
145
+ "FOR users.user_id=1")
146
+ result = rfm.predict(query)
120
147
 
121
148
  print(result) # user_id COUNT(transactions.*, 0, 30, days) > 0
122
149
  # 1 0.85
123
150
 
124
151
  Args:
125
152
  graph: The graph.
126
- preprocess: Whether to pre-process the data in advance during graph
127
- materialization.
128
- This is a runtime trade-off between graph materialization and model
129
- processing speed.
130
- It can be benefical to preprocess your data once and then run many
131
- queries on top to achieve maximum model speed.
132
- However, if activiated, graph materialization can take potentially
133
- much longer, especially on graphs with many large text columns.
134
- Best to tune this option manually.
135
153
  verbose: Whether to print verbose output.
136
154
  """
137
155
  def __init__(
138
156
  self,
139
- graph: LocalGraph,
140
- preprocess: bool = False,
157
+ graph: Graph,
141
158
  verbose: Union[bool, ProgressLogger] = True,
142
159
  ) -> None:
143
160
  graph = graph.validate()
144
161
  self._graph_def = graph._to_api_graph_definition()
145
- self._graph_store = LocalGraphStore(graph, preprocess, verbose)
162
+ self._graph_store = LocalGraphStore(graph, verbose)
146
163
  self._graph_sampler = LocalGraphSampler(self._graph_store)
147
164
 
165
+ self._client: Optional[RFMAPI] = None
166
+
148
167
  self._batch_size: Optional[int | Literal['max']] = None
149
168
  self.num_retries: int = 0
150
169
 
170
+ @property
171
+ def _api_client(self) -> RFMAPI:
172
+ if self._client is not None:
173
+ return self._client
174
+
175
+ from kumoai.experimental.rfm import global_state
176
+ self._client = RFMAPI(global_state.client)
177
+ return self._client
178
+
151
179
  def __repr__(self) -> str:
152
180
  return f'{self.__class__.__name__}()'
153
181
 
@@ -209,7 +237,7 @@ class KumoRFM:
209
237
  query: str,
210
238
  indices: Union[List[str], List[float], List[int], None] = None,
211
239
  *,
212
- explain: Literal[True],
240
+ explain: Union[Literal[True], ExplainConfig, Dict[str, Any]],
213
241
  anchor_time: Union[pd.Timestamp, Literal['entity'], None] = None,
214
242
  context_anchor_time: Union[pd.Timestamp, None] = None,
215
243
  run_mode: Union[RunMode, str] = RunMode.FAST,
@@ -227,7 +255,7 @@ class KumoRFM:
227
255
  query: str,
228
256
  indices: Union[List[str], List[float], List[int], None] = None,
229
257
  *,
230
- explain: bool = False,
258
+ explain: Union[bool, ExplainConfig, Dict[str, Any]] = False,
231
259
  anchor_time: Union[pd.Timestamp, Literal['entity'], None] = None,
232
260
  context_anchor_time: Union[pd.Timestamp, None] = None,
233
261
  run_mode: Union[RunMode, str] = RunMode.FAST,
@@ -247,9 +275,12 @@ class KumoRFM:
247
275
  be generated for all indices, independent of whether they
248
276
  fulfill entity filter constraints. To pre-filter entities, use
249
277
  :meth:`~KumoRFM.is_valid_entity`.
250
- explain: If set to ``True``, will additionally explain the
251
- prediction. Explainability is currently only supported for
252
- single entity predictions with ``run_mode="FAST"``.
278
+ explain: Configuration for explainability.
279
+ If set to ``True``, will additionally explain the prediction.
280
+ Passing in an :class:`ExplainConfig` instance provides control
281
+ over which parts of explanation are generated.
282
+ Explainability is currently only supported for single entity
283
+ predictions with ``run_mode="FAST"``.
253
284
  anchor_time: The anchor timestamp for the prediction. If set to
254
285
  ``None``, will use the maximum timestamp in the data.
255
286
  If set to ``"entity"``, will use the timestamp of the entity.
@@ -273,42 +304,48 @@ class KumoRFM:
273
304
 
274
305
  Returns:
275
306
  The predictions as a :class:`pandas.DataFrame`.
276
- If ``explain=True``, additionally returns a textual summary that
277
- explains the prediction.
307
+ If ``explain`` is provided, returns an :class:`Explanation` object
308
+ containing the prediction, summary, and details.
278
309
  """
310
+ explain_config: Optional[ExplainConfig] = None
311
+ if explain is True:
312
+ explain_config = ExplainConfig()
313
+ elif explain is not False:
314
+ explain_config = ExplainConfig._cast(explain)
315
+
279
316
  query_def = self._parse_query(query)
317
+ query_str = query_def.to_string()
280
318
 
281
319
  if num_hops != 2 and num_neighbors is not None:
282
320
  warnings.warn(f"Received custom 'num_neighbors' option; ignoring "
283
321
  f"custom 'num_hops={num_hops}' option")
284
322
 
285
- if explain and run_mode in {RunMode.NORMAL, RunMode.BEST}:
323
+ if explain_config is not None and run_mode in {
324
+ RunMode.NORMAL, RunMode.BEST
325
+ }:
286
326
  warnings.warn(f"Explainability is currently only supported for "
287
327
  f"run mode 'FAST' (got '{run_mode}'). Provided run "
288
328
  f"mode has been reset. Please lower the run mode to "
289
329
  f"suppress this warning.")
290
330
 
291
331
  if indices is None:
292
- if query_def.entity.ids is None:
332
+ if query_def.rfm_entity_ids is None:
293
333
  raise ValueError("Cannot find entities to predict for. Please "
294
334
  "pass them via `predict(query, indices=...)`")
295
- indices = query_def.entity.ids.value
335
+ indices = query_def.get_rfm_entity_id_list()
296
336
  else:
297
- query_def = replace(
298
- query_def,
299
- entity=replace(query_def.entity, ids=None),
300
- )
337
+ query_def = replace(query_def, rfm_entity_ids=None)
301
338
 
302
339
  if len(indices) == 0:
303
340
  raise ValueError("At least one entity is required")
304
341
 
305
- if explain and len(indices) > 1:
342
+ if explain_config is not None and len(indices) > 1:
306
343
  raise ValueError(
307
344
  f"Cannot explain predictions for more than a single entity "
308
345
  f"(got {len(indices)})")
309
346
 
310
347
  query_repr = query_def.to_string(rich=True, exclude_predict=True)
311
- if explain:
348
+ if explain_config is not None:
312
349
  msg = f'[bold]EXPLAIN[/bold] {query_repr}'
313
350
  else:
314
351
  msg = f'[bold]PREDICT[/bold] {query_repr}'
@@ -320,8 +357,8 @@ class KumoRFM:
320
357
 
321
358
  batch_size: Optional[int] = None
322
359
  if self._batch_size == 'max':
323
- task_type = query_def.get_task_type(
324
- stypes=self._graph_store.stype_dict,
360
+ task_type = LocalPQueryDriver.get_task_type(
361
+ query_def,
325
362
  edge_types=self._graph_store.edge_types,
326
363
  )
327
364
  batch_size = _MAX_PRED_SIZE[task_type]
@@ -359,6 +396,7 @@ class KumoRFM:
359
396
  request = RFMPredictRequest(
360
397
  context=context,
361
398
  run_mode=RunMode(run_mode),
399
+ query=query_str,
362
400
  use_prediction_time=use_prediction_time,
363
401
  )
364
402
  with warnings.catch_warnings():
@@ -382,17 +420,20 @@ class KumoRFM:
382
420
 
383
421
  for attempt in range(self.num_retries + 1):
384
422
  try:
385
- if explain:
386
- resp = global_state.client.rfm_api.explain(_bytes)
423
+ if explain_config is not None:
424
+ resp = self._api_client.explain(
425
+ request=_bytes,
426
+ skip_summary=explain_config.skip_summary,
427
+ )
387
428
  summary = resp.summary
388
429
  details = resp.details
389
430
  else:
390
- resp = global_state.client.rfm_api.predict(_bytes)
431
+ resp = self._api_client.predict(_bytes)
391
432
  df = pd.DataFrame(**resp.prediction)
392
433
 
393
434
  # Cast 'ENTITY' to correct data type:
394
435
  if 'ENTITY' in df:
395
- entity = query_def.entity.pkey.table_name
436
+ entity = query_def.entity_table
396
437
  pkey_map = self._graph_store.pkey_map_dict[entity]
397
438
  df['ENTITY'] = df['ENTITY'].astype(
398
439
  type(pkey_map.index[0]))
@@ -434,7 +475,7 @@ class KumoRFM:
434
475
  else:
435
476
  prediction = pd.concat(predictions, ignore_index=True)
436
477
 
437
- if explain:
478
+ if explain_config is not None:
438
479
  assert len(predictions) == 1
439
480
  assert summary is not None
440
481
  assert details is not None
@@ -468,11 +509,11 @@ class KumoRFM:
468
509
  query_def = self._parse_query(query)
469
510
 
470
511
  if indices is None:
471
- if query_def.entity.ids is None:
512
+ if query_def.rfm_entity_ids is None:
472
513
  raise ValueError("Cannot find entities to predict for. Please "
473
514
  "pass them via "
474
515
  "`is_valid_entity(query, indices=...)`")
475
- indices = query_def.entity.ids.value
516
+ indices = query_def.get_rfm_entity_id_list()
476
517
 
477
518
  if len(indices) == 0:
478
519
  raise ValueError("At least one entity is required")
@@ -484,14 +525,13 @@ class KumoRFM:
484
525
  self._validate_time(query_def, anchor_time, None, False)
485
526
  else:
486
527
  assert anchor_time == 'entity'
487
- if (query_def.entity.pkey.table_name
488
- not in self._graph_store.time_dict):
528
+ if (query_def.entity_table not in self._graph_store.time_dict):
489
529
  raise ValueError(f"Anchor time 'entity' requires the entity "
490
- f"table '{query_def.entity.pkey.table_name}' "
491
- f"to have a time column")
530
+ f"table '{query_def.entity_table}' "
531
+ f"to have a time column.")
492
532
 
493
533
  node = self._graph_store.get_node_id(
494
- table_name=query_def.entity.pkey.table_name,
534
+ table_name=query_def.entity_table,
495
535
  pkey=pd.Series(indices),
496
536
  )
497
537
  query_driver = LocalPQueryDriver(self._graph_store, query_def)
@@ -547,10 +587,10 @@ class KumoRFM:
547
587
  warnings.warn(f"Received custom 'num_neighbors' option; ignoring "
548
588
  f"custom 'num_hops={num_hops}' option")
549
589
 
550
- if query_def.entity.ids is not None:
590
+ if query_def.rfm_entity_ids is not None:
551
591
  query_def = replace(
552
592
  query_def,
553
- entity=replace(query_def.entity, ids=None),
593
+ rfm_entity_ids=None,
554
594
  )
555
595
 
556
596
  query_repr = query_def.to_string(rich=True, exclude_predict=True)
@@ -591,10 +631,10 @@ class KumoRFM:
591
631
 
592
632
  if len(request_bytes) > _MAX_SIZE:
593
633
  stats_msg = Context.get_memory_stats(request_msg.context)
594
- raise ValueError(_SIZE_LIMIT_MSG.format(stats_msg=stats_msg))
634
+ raise ValueError(_SIZE_LIMIT_MSG.format(stats=stats_msg))
595
635
 
596
636
  try:
597
- resp = global_state.client.rfm_api.evaluate(request_bytes)
637
+ resp = self._api_client.evaluate(request_bytes)
598
638
  except HTTPException as e:
599
639
  try:
600
640
  msg = json.loads(e.detail)['detail']
@@ -639,18 +679,19 @@ class KumoRFM:
639
679
 
640
680
  if anchor_time is None:
641
681
  anchor_time = self._graph_store.max_time
642
- anchor_time = anchor_time - (query_def.target.end_offset *
643
- query_def.num_forecasts)
682
+ if query_def.target_ast.date_offset_range is not None:
683
+ anchor_time = anchor_time - (
684
+ query_def.target_ast.date_offset_range.end_date_offset *
685
+ query_def.num_forecasts)
644
686
 
645
687
  assert anchor_time is not None
646
688
  if isinstance(anchor_time, pd.Timestamp):
647
689
  self._validate_time(query_def, anchor_time, None, evaluate=True)
648
690
  else:
649
691
  assert anchor_time == 'entity'
650
- if (query_def.entity.pkey.table_name
651
- not in self._graph_store.time_dict):
692
+ if (query_def.entity_table not in self._graph_store.time_dict):
652
693
  raise ValueError(f"Anchor time 'entity' requires the entity "
653
- f"table '{query_def.entity.pkey.table_name}' "
694
+ f"table '{query_def.entity_table}' "
654
695
  f"to have a time column")
655
696
 
656
697
  query_driver = LocalPQueryDriver(self._graph_store, query_def,
@@ -665,7 +706,7 @@ class KumoRFM:
665
706
  )
666
707
 
667
708
  entity = self._graph_store.pkey_map_dict[
668
- query_def.entity.pkey.table_name].index[node]
709
+ query_def.entity_table].index[node]
669
710
 
670
711
  return pd.DataFrame({
671
712
  'ENTITY': entity,
@@ -675,8 +716,8 @@ class KumoRFM:
675
716
 
676
717
  # Helpers #################################################################
677
718
 
678
- def _parse_query(self, query: str) -> PQueryDefinition:
679
- if isinstance(query, PQueryDefinition):
719
+ def _parse_query(self, query: str) -> ValidatedPredictiveQuery:
720
+ if isinstance(query, ValidatedPredictiveQuery):
680
721
  return query
681
722
 
682
723
  if isinstance(query, str) and query.strip()[:9].lower() == 'evaluate ':
@@ -686,12 +727,13 @@ class KumoRFM:
686
727
  "predictions or evaluations.")
687
728
 
688
729
  try:
689
- request = RFMValidateQueryRequest(
730
+ request = RFMParseQueryRequest(
690
731
  query=query,
691
732
  graph_definition=self._graph_def,
692
733
  )
693
734
 
694
- resp = global_state.client.rfm_api.validate_query(request)
735
+ resp = self._api_client.parse_query(request)
736
+
695
737
  # TODO Expose validation warnings.
696
738
 
697
739
  if len(resp.validation_response.warnings) > 0:
@@ -702,7 +744,7 @@ class KumoRFM:
702
744
  warnings.warn(f"Encountered the following warnings during "
703
745
  f"parsing:\n{msg}")
704
746
 
705
- return resp.query_definition
747
+ return resp.query
706
748
  except HTTPException as e:
707
749
  try:
708
750
  msg = json.loads(e.detail)['detail']
@@ -713,7 +755,7 @@ class KumoRFM:
713
755
 
714
756
  def _validate_time(
715
757
  self,
716
- query: PQueryDefinition,
758
+ query: ValidatedPredictiveQuery,
717
759
  anchor_time: pd.Timestamp,
718
760
  context_anchor_time: Union[pd.Timestamp, None],
719
761
  evaluate: bool,
@@ -736,6 +778,11 @@ class KumoRFM:
736
778
  f"only contains data back to "
737
779
  f"'{self._graph_store.min_time}'.")
738
780
 
781
+ if query.target_ast.date_offset_range is not None:
782
+ end_offset = query.target_ast.date_offset_range.end_date_offset
783
+ else:
784
+ end_offset = pd.DateOffset(0)
785
+ forecast_end_offset = end_offset * query.num_forecasts
739
786
  if (context_anchor_time is not None
740
787
  and context_anchor_time > anchor_time):
741
788
  warnings.warn(f"Context anchor timestamp "
@@ -744,19 +791,18 @@ class KumoRFM:
744
791
  f"(got '{anchor_time}'). Please make sure this is "
745
792
  f"intended.")
746
793
  elif (query.query_type == QueryType.TEMPORAL
747
- and context_anchor_time is not None and context_anchor_time +
748
- query.target.end_offset * query.num_forecasts > anchor_time):
794
+ and context_anchor_time is not None
795
+ and context_anchor_time + forecast_end_offset > anchor_time):
749
796
  warnings.warn(f"Aggregation for context examples at timestamp "
750
797
  f"'{context_anchor_time}' will leak information "
751
798
  f"from the prediction anchor timestamp "
752
799
  f"'{anchor_time}'. Please make sure this is "
753
800
  f"intended.")
754
801
 
755
- elif (context_anchor_time is not None and context_anchor_time -
756
- query.target.end_offset * query.num_forecasts
802
+ elif (context_anchor_time is not None
803
+ and context_anchor_time - forecast_end_offset
757
804
  < self._graph_store.min_time):
758
- _time = context_anchor_time - (query.target.end_offset *
759
- query.num_forecasts)
805
+ _time = context_anchor_time - forecast_end_offset
760
806
  warnings.warn(f"Context anchor timestamp is too early or "
761
807
  f"aggregation time range is too large. To form "
762
808
  f"proper input data, we would need data back to "
@@ -769,8 +815,7 @@ class KumoRFM:
769
815
  f"latest timestamp '{self._graph_store.max_time}' "
770
816
  f"in the data. Please make sure this is intended.")
771
817
 
772
- max_eval_time = (self._graph_store.max_time -
773
- query.target.end_offset * query.num_forecasts)
818
+ max_eval_time = self._graph_store.max_time - forecast_end_offset
774
819
  if evaluate and anchor_time > max_eval_time:
775
820
  raise ValueError(
776
821
  f"Anchor timestamp for evaluation is after the latest "
@@ -778,7 +823,7 @@ class KumoRFM:
778
823
 
779
824
  def _get_context(
780
825
  self,
781
- query: PQueryDefinition,
826
+ query: ValidatedPredictiveQuery,
782
827
  indices: Union[List[str], List[float], List[int], None],
783
828
  anchor_time: Union[pd.Timestamp, Literal['entity'], None],
784
829
  context_anchor_time: Union[pd.Timestamp, None],
@@ -806,8 +851,8 @@ class KumoRFM:
806
851
  f"must go beyond this for your use-case.")
807
852
 
808
853
  query_driver = LocalPQueryDriver(self._graph_store, query, random_seed)
809
- task_type = query.get_task_type(
810
- stypes=self._graph_store.stype_dict,
854
+ task_type = LocalPQueryDriver.get_task_type(
855
+ query,
811
856
  edge_types=self._graph_store.edge_types,
812
857
  )
813
858
 
@@ -839,11 +884,15 @@ class KumoRFM:
839
884
  else:
840
885
  num_neighbors = [64, 64, 8, 8, 4, 4][:num_hops]
841
886
 
887
+ if query.target_ast.date_offset_range is None:
888
+ end_offset = pd.DateOffset(0)
889
+ else:
890
+ end_offset = query.target_ast.date_offset_range.end_date_offset
891
+ forecast_end_offset = end_offset * query.num_forecasts
842
892
  if anchor_time is None:
843
893
  anchor_time = self._graph_store.max_time
844
894
  if evaluate:
845
- anchor_time = anchor_time - (query.target.end_offset *
846
- query.num_forecasts)
895
+ anchor_time = anchor_time - forecast_end_offset
847
896
  if logger is not None:
848
897
  assert isinstance(anchor_time, pd.Timestamp)
849
898
  if anchor_time == pd.Timestamp.min:
@@ -858,15 +907,14 @@ class KumoRFM:
858
907
  assert anchor_time is not None
859
908
  if isinstance(anchor_time, pd.Timestamp):
860
909
  if context_anchor_time is None:
861
- context_anchor_time = anchor_time - (query.target.end_offset *
862
- query.num_forecasts)
910
+ context_anchor_time = anchor_time - forecast_end_offset
863
911
  self._validate_time(query, anchor_time, context_anchor_time,
864
912
  evaluate)
865
913
  else:
866
914
  assert anchor_time == 'entity'
867
- if query.entity.pkey.table_name not in self._graph_store.time_dict:
915
+ if query.entity_table not in self._graph_store.time_dict:
868
916
  raise ValueError(f"Anchor time 'entity' requires the entity "
869
- f"table '{query.entity.pkey.table_name}' to "
917
+ f"table '{query.entity_table}' to "
870
918
  f"have a time column")
871
919
  if context_anchor_time is not None:
872
920
  warnings.warn("Ignoring option 'context_anchor_time' for "
@@ -917,7 +965,7 @@ class KumoRFM:
917
965
  f"in batches")
918
966
 
919
967
  test_node = self._graph_store.get_node_id(
920
- table_name=query.entity.pkey.table_name,
968
+ table_name=query.entity_table,
921
969
  pkey=pd.Series(indices),
922
970
  )
923
971
 
@@ -925,8 +973,7 @@ class KumoRFM:
925
973
  test_time = pd.Series(anchor_time).repeat(
926
974
  len(test_node)).reset_index(drop=True)
927
975
  else:
928
- time = self._graph_store.time_dict[
929
- query.entity.pkey.table_name]
976
+ time = self._graph_store.time_dict[query.entity_table]
930
977
  time = time[test_node] * 1000**3
931
978
  test_time = pd.Series(time, dtype='datetime64[ns]')
932
979
 
@@ -959,12 +1006,23 @@ class KumoRFM:
959
1006
  raise NotImplementedError
960
1007
  logger.log(msg)
961
1008
 
962
- entity_table_names = query.get_entity_table_names(
963
- self._graph_store.edge_types)
1009
+ entity_table_names: Tuple[str, ...]
1010
+ if task_type.is_link_pred:
1011
+ final_aggr = query.get_final_target_aggregation()
1012
+ assert final_aggr is not None
1013
+ edge_fkey = final_aggr._get_target_column_name()
1014
+ for edge_type in self._graph_store.edge_types:
1015
+ if edge_fkey == f'{edge_type[0]}.{edge_type[1]}':
1016
+ entity_table_names = (
1017
+ query.entity_table,
1018
+ edge_type[2],
1019
+ )
1020
+ else:
1021
+ entity_table_names = (query.entity_table, )
964
1022
 
965
1023
  # Exclude the entity anchor time from the feature set to prevent
966
1024
  # running out-of-distribution between in-context and test examples:
967
- exclude_cols_dict = query.exclude_cols_dict
1025
+ exclude_cols_dict = query.get_exclude_cols_dict()
968
1026
  if anchor_time == 'entity':
969
1027
  if entity_table_names[0] not in exclude_cols_dict:
970
1028
  exclude_cols_dict[entity_table_names[0]] = []
@@ -979,7 +1037,6 @@ class KumoRFM:
979
1037
  train_time.astype('datetime64[ns]').astype(int).to_numpy(),
980
1038
  test_time.astype('datetime64[ns]').astype(int).to_numpy(),
981
1039
  ]),
982
- run_mode=run_mode,
983
1040
  num_neighbors=num_neighbors,
984
1041
  exclude_cols_dict=exclude_cols_dict,
985
1042
  )
@@ -993,7 +1050,7 @@ class KumoRFM:
993
1050
 
994
1051
  step_size: Optional[int] = None
995
1052
  if query.query_type == QueryType.TEMPORAL:
996
- step_size = date_offset_to_seconds(query.target.end_offset)
1053
+ step_size = date_offset_to_seconds(end_offset)
997
1054
 
998
1055
  return Context(
999
1056
  task_type=task_type,