kumoai 2.12.1__py3-none-any.whl → 2.14.0.dev202512141732__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as published to one of the supported registries. It is provided for informational purposes only and reflects the changes between the package versions exactly as they appear in their respective public registries.
Files changed (43)
  1. kumoai/__init__.py +18 -9
  2. kumoai/_version.py +1 -1
  3. kumoai/client/client.py +9 -13
  4. kumoai/client/pquery.py +6 -2
  5. kumoai/connector/utils.py +23 -2
  6. kumoai/experimental/rfm/__init__.py +162 -46
  7. kumoai/experimental/rfm/backend/__init__.py +0 -0
  8. kumoai/experimental/rfm/backend/local/__init__.py +42 -0
  9. kumoai/experimental/rfm/{local_graph_store.py → backend/local/graph_store.py} +37 -90
  10. kumoai/experimental/rfm/backend/local/sampler.py +313 -0
  11. kumoai/experimental/rfm/backend/local/table.py +119 -0
  12. kumoai/experimental/rfm/backend/snow/__init__.py +37 -0
  13. kumoai/experimental/rfm/backend/snow/sampler.py +119 -0
  14. kumoai/experimental/rfm/backend/snow/table.py +135 -0
  15. kumoai/experimental/rfm/backend/sqlite/__init__.py +32 -0
  16. kumoai/experimental/rfm/backend/sqlite/sampler.py +112 -0
  17. kumoai/experimental/rfm/backend/sqlite/table.py +115 -0
  18. kumoai/experimental/rfm/base/__init__.py +23 -0
  19. kumoai/experimental/rfm/base/column.py +66 -0
  20. kumoai/experimental/rfm/base/sampler.py +773 -0
  21. kumoai/experimental/rfm/base/source.py +19 -0
  22. kumoai/experimental/rfm/{local_table.py → base/table.py} +152 -141
  23. kumoai/experimental/rfm/{local_graph.py → graph.py} +352 -80
  24. kumoai/experimental/rfm/infer/__init__.py +6 -0
  25. kumoai/experimental/rfm/infer/dtype.py +79 -0
  26. kumoai/experimental/rfm/infer/pkey.py +126 -0
  27. kumoai/experimental/rfm/infer/time_col.py +62 -0
  28. kumoai/experimental/rfm/pquery/pandas_executor.py +1 -1
  29. kumoai/experimental/rfm/rfm.py +233 -174
  30. kumoai/experimental/rfm/sagemaker.py +138 -0
  31. kumoai/spcs.py +1 -3
  32. kumoai/testing/decorators.py +1 -1
  33. kumoai/testing/snow.py +50 -0
  34. kumoai/utils/__init__.py +2 -0
  35. kumoai/utils/sql.py +3 -0
  36. {kumoai-2.12.1.dist-info → kumoai-2.14.0.dev202512141732.dist-info}/METADATA +12 -2
  37. {kumoai-2.12.1.dist-info → kumoai-2.14.0.dev202512141732.dist-info}/RECORD +40 -23
  38. kumoai/experimental/rfm/local_graph_sampler.py +0 -184
  39. kumoai/experimental/rfm/local_pquery_driver.py +0 -689
  40. kumoai/experimental/rfm/utils.py +0 -344
  41. {kumoai-2.12.1.dist-info → kumoai-2.14.0.dev202512141732.dist-info}/WHEEL +0 -0
  42. {kumoai-2.12.1.dist-info → kumoai-2.14.0.dev202512141732.dist-info}/licenses/LICENSE +0 -0
  43. {kumoai-2.12.1.dist-info → kumoai-2.14.0.dev202512141732.dist-info}/top_level.txt +0 -0
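The headline change in this release is the move from the local-only `LocalGraph`/`LocalTable` stack to a backend-agnostic `Graph` with pluggable samplers (local pandas, SQLite, and Snowflake under `kumoai/experimental/rfm/backend/`). Based on the updated docstring in the diff below, the renamed user-facing API looks roughly like this sketch; the frame contents and the key/time-column inference behavior are assumptions for illustration, not part of the diff:

    import pandas as pd

    from kumoai.experimental.rfm import Graph, KumoRFM

    # Placeholder data; the new infer/ modules in the file list suggest that
    # primary keys and time columns are inferred from frames like these
    # (assumption, not shown in this diff).
    df_users = pd.DataFrame({'user_id': [0, 1, 2]})
    df_items = pd.DataFrame({'item_id': [0, 1]})
    df_orders = pd.DataFrame({
        'user_id': [0, 1, 1],
        'item_id': [0, 0, 1],
        'time': pd.to_datetime(['2024-01-01', '2024-02-01', '2024-03-01']),
    })

    graph = Graph.from_data({          # was: LocalGraph.from_data(...)
        'users': df_users,
        'items': df_items,
        'orders': df_orders,
    })

    rfm = KumoRFM(graph)
    result = rfm.predict(              # was: rfm.query(...)
        "PREDICT COUNT(orders.*, 0, 30, days)>0 FOR users.user_id=1")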
@@ -21,6 +21,13 @@ import numpy as np
 import pandas as pd
 from kumoapi.model_plan import RunMode
 from kumoapi.pquery import QueryType, ValidatedPredictiveQuery
+from kumoapi.pquery.AST import (
+    Aggregation,
+    Column,
+    Condition,
+    Join,
+    LogicalOperation,
+)
 from kumoapi.rfm import Context
 from kumoapi.rfm import Explanation as ExplanationConfig
 from kumoapi.rfm import (
@@ -29,16 +36,12 @@ from kumoapi.rfm import (
     RFMPredictRequest,
 )
 from kumoapi.task import TaskType
+from kumoapi.typing import AggregationType, Stype
 
-from kumoai import global_state
+from kumoai.client.rfm import RFMAPI
 from kumoai.exceptions import HTTPException
-from kumoai.experimental.rfm import LocalGraph
-from kumoai.experimental.rfm.local_graph_sampler import LocalGraphSampler
-from kumoai.experimental.rfm.local_graph_store import LocalGraphStore
-from kumoai.experimental.rfm.local_pquery_driver import (
-    LocalPQueryDriver,
-    date_offset_to_seconds,
-)
+from kumoai.experimental.rfm import Graph
+from kumoai.experimental.rfm.base import DataBackend, Sampler
 from kumoai.mixin import CastMixin
 from kumoai.utils import InteractiveProgressLogger, ProgressLogger
 
@@ -123,17 +126,17 @@ class KumoRFM:
     :class:`KumoRFM` is a foundation model to generate predictions for any
     relational dataset without training.
     The model is pre-trained and the class provides an interface to query the
-    model from a :class:`LocalGraph` object.
+    model from a :class:`Graph` object.
 
     .. code-block:: python
 
-        from kumoai.experimental.rfm import LocalGraph, KumoRFM
+        from kumoai.experimental.rfm import Graph, KumoRFM
 
         df_users = pd.DataFrame(...)
         df_items = pd.DataFrame(...)
         df_orders = pd.DataFrame(...)
 
-        graph = LocalGraph.from_data({
+        graph = Graph.from_data({
            'users': df_users,
            'items': df_items,
            'orders': df_orders,
@@ -141,40 +144,51 @@ class KumoRFM:
 
         rfm = KumoRFM(graph)
 
-        query = ("PREDICT COUNT(transactions.*, 0, 30, days)>0 "
-                 "FOR users.user_id=0")
-        result = rfm.query(query)
+        query = ("PREDICT COUNT(orders.*, 0, 30, days)>0 "
+                 "FOR users.user_id=1")
+        result = rfm.predict(query)
 
         print(result)  # user_id  COUNT(transactions.*, 0, 30, days) > 0
                        # 1        0.85
 
    Args:
        graph: The graph.
-        preprocess: Whether to pre-process the data in advance during graph
-            materialization.
-            This is a runtime trade-off between graph materialization and model
-            processing speed.
-            It can be benefical to preprocess your data once and then run many
-            queries on top to achieve maximum model speed.
-            However, if activiated, graph materialization can take potentially
-            much longer, especially on graphs with many large text columns.
-            Best to tune this option manually.
        verbose: Whether to print verbose output.
    """
    def __init__(
        self,
-        graph: LocalGraph,
-        preprocess: bool = False,
+        graph: Graph,
        verbose: Union[bool, ProgressLogger] = True,
    ) -> None:
        graph = graph.validate()
        self._graph_def = graph._to_api_graph_definition()
-        self._graph_store = LocalGraphStore(graph, preprocess, verbose)
-        self._graph_sampler = LocalGraphSampler(self._graph_store)
+
+        if graph.backend == DataBackend.LOCAL:
+            from kumoai.experimental.rfm.backend.local import LocalSampler
+            self._sampler: Sampler = LocalSampler(graph, verbose)
+        elif graph.backend == DataBackend.SQLITE:
+            from kumoai.experimental.rfm.backend.sqlite import SQLiteSampler
+            self._sampler = SQLiteSampler(graph, verbose)
+        elif graph.backend == DataBackend.SNOWFLAKE:
+            from kumoai.experimental.rfm.backend.snow import SnowSampler
+            self._sampler = SnowSampler(graph, verbose)
+        else:
+            raise NotImplementedError
+
+        self._client: Optional[RFMAPI] = None
 
        self._batch_size: Optional[int | Literal['max']] = None
        self.num_retries: int = 0
 
+    @property
+    def _api_client(self) -> RFMAPI:
+        if self._client is not None:
+            return self._client
+
+        from kumoai.experimental.rfm import global_state
+        self._client = RFMAPI(global_state.client)
+        return self._client
+
    def __repr__(self) -> str:
        return f'{self.__class__.__name__}()'
 
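The constructor above now dispatches on `graph.backend` to one of three sampler implementations. The base interface itself (`kumoai/experimental/rfm/base/sampler.py`) is not shown in this diff, so the Protocol below is a rough reconstruction from the call sites elsewhere in the file; treat it as an illustration, not the actual class:

    from typing import Any, Dict, List, Optional, Protocol, Tuple

    import pandas as pd

    class SamplerLike(Protocol):
        """Hypothetical view of the new Sampler interface, inferred only
        from how `self._sampler` is used in rfm.py; the real base class in
        kumoai/experimental/rfm/base/sampler.py may differ."""

        # (table, foreign key, ...) triplets; this file only indexes the
        # first two components:
        edge_types: List[Tuple[str, str, str]]
        # table name -> name of its time column:
        time_column_dict: Dict[str, str]

        def get_min_time(self) -> pd.Timestamp: ...
        def get_max_time(
            self,
            table_names: Optional[List[str]] = None,
        ) -> pd.Timestamp: ...

        # Returns (train, test) target samples, see the note further below:
        def sample_target(self, **kwargs: Any) -> Tuple[Any, Any]: ...
        def sample_subgraph(self, **kwargs: Any) -> Any: ...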
@@ -223,7 +237,7 @@ class KumoRFM:
         run_mode: Union[RunMode, str] = RunMode.FAST,
         num_neighbors: Optional[List[int]] = None,
         num_hops: int = 2,
-        max_pq_iterations: int = 20,
+        max_pq_iterations: int = 10,
         random_seed: Optional[int] = _RANDOM_SEED,
         verbose: Union[bool, ProgressLogger] = True,
         use_prediction_time: bool = False,
@@ -242,7 +256,7 @@ class KumoRFM:
         run_mode: Union[RunMode, str] = RunMode.FAST,
         num_neighbors: Optional[List[int]] = None,
         num_hops: int = 2,
-        max_pq_iterations: int = 20,
+        max_pq_iterations: int = 10,
         random_seed: Optional[int] = _RANDOM_SEED,
         verbose: Union[bool, ProgressLogger] = True,
         use_prediction_time: bool = False,
@@ -260,7 +274,7 @@ class KumoRFM:
         run_mode: Union[RunMode, str] = RunMode.FAST,
         num_neighbors: Optional[List[int]] = None,
         num_hops: int = 2,
-        max_pq_iterations: int = 20,
+        max_pq_iterations: int = 10,
         random_seed: Optional[int] = _RANDOM_SEED,
         verbose: Union[bool, ProgressLogger] = True,
         use_prediction_time: bool = False,
@@ -356,9 +370,9 @@ class KumoRFM:
 
         batch_size: Optional[int] = None
         if self._batch_size == 'max':
-            task_type = LocalPQueryDriver.get_task_type(
-                query_def,
-                edge_types=self._graph_store.edge_types,
+            task_type = self._get_task_type(
+                query=query_def,
+                edge_types=self._sampler.edge_types,
             )
             batch_size = _MAX_PRED_SIZE[task_type]
         else:
@@ -420,22 +434,22 @@ class KumoRFM:
         for attempt in range(self.num_retries + 1):
             try:
                 if explain_config is not None:
-                    resp = global_state.client.rfm_api.explain(
+                    resp = self._api_client.explain(
                         request=_bytes,
                         skip_summary=explain_config.skip_summary,
                     )
                     summary = resp.summary
                     details = resp.details
                 else:
-                    resp = global_state.client.rfm_api.predict(_bytes)
+                    resp = self._api_client.predict(_bytes)
                 df = pd.DataFrame(**resp.prediction)
 
                 # Cast 'ENTITY' to correct data type:
                 if 'ENTITY' in df:
-                    entity = query_def.entity_table
-                    pkey_map = self._graph_store.pkey_map_dict[entity]
-                    df['ENTITY'] = df['ENTITY'].astype(
-                        type(pkey_map.index[0]))
+                    table_dict = context.subgraph.table_dict
+                    table = table_dict[query_def.entity_table]
+                    ser = table.df[table.primary_key]
+                    df['ENTITY'] = df['ENTITY'].astype(ser.dtype)
 
                 # Cast 'ANCHOR_TIMESTAMP' to correct data type:
                 if 'ANCHOR_TIMESTAMP' in df:
@@ -518,23 +532,18 @@
             raise ValueError("At least one entity is required")
 
         if anchor_time is None:
-            anchor_time = self._graph_store.max_time
+            anchor_time = self._get_default_anchor_time(query_def)
 
         if isinstance(anchor_time, pd.Timestamp):
             self._validate_time(query_def, anchor_time, None, False)
         else:
             assert anchor_time == 'entity'
-            if (query_def.entity_table not in self._graph_store.time_dict):
+            if query_def.entity_table not in self._sampler.time_column_dict:
                 raise ValueError(f"Anchor time 'entity' requires the entity "
                                  f"table '{query_def.entity_table}' "
                                  f"to have a time column.")
 
-        node = self._graph_store.get_node_id(
-            table_name=query_def.entity_table,
-            pkey=pd.Series(indices),
-        )
-        query_driver = LocalPQueryDriver(self._graph_store, query_def)
-        return query_driver.is_valid(node, anchor_time)
+        raise NotImplementedError
 
     def evaluate(
         self,
@@ -546,7 +555,7 @@
         run_mode: Union[RunMode, str] = RunMode.FAST,
         num_neighbors: Optional[List[int]] = None,
         num_hops: int = 2,
-        max_pq_iterations: int = 20,
+        max_pq_iterations: int = 10,
         random_seed: Optional[int] = _RANDOM_SEED,
         verbose: Union[bool, ProgressLogger] = True,
         use_prediction_time: bool = False,
@@ -633,7 +642,7 @@
             raise ValueError(_SIZE_LIMIT_MSG.format(stats=stats_msg))
 
         try:
-            resp = global_state.client.rfm_api.evaluate(request_bytes)
+            resp = self._api_client.evaluate(request_bytes)
         except HTTPException as e:
             try:
                 msg = json.loads(e.detail)['detail']
@@ -657,7 +666,7 @@
         *,
         anchor_time: Union[pd.Timestamp, Literal['entity'], None] = None,
         random_seed: Optional[int] = _RANDOM_SEED,
-        max_iterations: int = 20,
+        max_iterations: int = 10,
     ) -> pd.DataFrame:
         """Returns the labels of a predictive query for a specified anchor
         time.
@@ -677,40 +686,37 @@
         query_def = self._parse_query(query)
 
         if anchor_time is None:
-            anchor_time = self._graph_store.max_time
+            anchor_time = self._get_default_anchor_time(query_def)
             if query_def.target_ast.date_offset_range is not None:
-                anchor_time = anchor_time - (
-                    query_def.target_ast.date_offset_range.end_date_offset *
-                    query_def.num_forecasts)
+                offset = query_def.target_ast.date_offset_range.end_date_offset
+                offset *= query_def.num_forecasts
+                anchor_time -= offset
 
         assert anchor_time is not None
         if isinstance(anchor_time, pd.Timestamp):
             self._validate_time(query_def, anchor_time, None, evaluate=True)
         else:
             assert anchor_time == 'entity'
-            if (query_def.entity_table not in self._graph_store.time_dict):
+            if query_def.entity_table not in self._sampler.time_column_dict:
                 raise ValueError(f"Anchor time 'entity' requires the entity "
                                  f"table '{query_def.entity_table}' "
                                  f"to have a time column")
 
-        query_driver = LocalPQueryDriver(self._graph_store, query_def,
-                                         random_seed)
-
-        node, time, y = query_driver.collect_test(
-            size=size,
-            anchor_time=anchor_time,
-            batch_size=min(10_000, size),
-            max_iterations=max_iterations,
-            guarantee_train_examples=False,
+        train, test = self._sampler.sample_target(
+            query=query,
+            num_train_examples=0,
+            train_anchor_time=anchor_time,
+            num_train_trials=0,
+            num_test_examples=size,
+            test_anchor_time=anchor_time,
+            num_test_trials=max_iterations * size,
+            random_seed=random_seed,
         )
 
-        entity = self._graph_store.pkey_map_dict[
-            query_def.entity_table].index[node]
-
         return pd.DataFrame({
-            'ENTITY': entity,
-            'ANCHOR_TIMESTAMP': time,
-            'TARGET': y,
+            'ENTITY': test.entity_pkey,
+            'ANCHOR_TIMESTAMP': test.anchor_time,
+            'TARGET': test.target,
         })
 
     # Helpers #################################################################
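`sample_target` replaces the old `LocalPQueryDriver.collect_train`/`collect_test` calls. The code both unpacks its return values as 3-tuples and reads `.entity_pkey`/`.anchor_time`/`.target` on them, which is consistent with a NamedTuple along these lines (a sketch inferred from usage, not the actual definition in `base/sampler.py`):

    from typing import NamedTuple

    import pandas as pd

    class TargetSample(NamedTuple):
        """Hypothetical shape of each value returned by sample_target."""
        entity_pkey: pd.Series   # primary keys of the sampled entities
        anchor_time: pd.Series   # one anchor timestamp per example
        target: pd.Series        # PQ label evaluated at the anchor time

    sample = TargetSample(
        entity_pkey=pd.Series([0, 1]),
        anchor_time=pd.Series(pd.to_datetime(['2024-01-01', '2024-01-01'])),
        target=pd.Series([3, 0]),
    )
    pkey, time, y = sample            # tuple unpacking works ...
    assert sample.target.equals(y)    # ... and so does attribute access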
@@ -731,8 +737,7 @@
             graph_definition=self._graph_def,
         )
 
-        resp = global_state.client.rfm_api.parse_query(request)
-        # TODO Expose validation warnings.
+        resp = self._api_client.parse_query(request)
 
         if len(resp.validation_response.warnings) > 0:
             msg = '\n'.join([
@@ -751,6 +756,60 @@
             raise ValueError(f"Failed to parse query '{query}'. "
                              f"{msg}") from None
 
+    @staticmethod
+    def _get_task_type(
+        query: ValidatedPredictiveQuery,
+        edge_types: List[Tuple[str, str, str]],
+    ) -> TaskType:
+        if isinstance(query.target_ast, (Condition, LogicalOperation)):
+            return TaskType.BINARY_CLASSIFICATION
+
+        target = query.target_ast
+        if isinstance(target, Join):
+            target = target.rhs_target
+        if isinstance(target, Aggregation):
+            if target.aggr == AggregationType.LIST_DISTINCT:
+                table_name, col_name = target._get_target_column_name().split(
+                    '.')
+                target_edge_types = [
+                    edge_type for edge_type in edge_types
+                    if edge_type[0] == table_name and edge_type[1] == col_name
+                ]
+                if len(target_edge_types) != 1:
+                    raise NotImplementedError(
+                        f"Multilabel-classification queries based on "
+                        f"'LIST_DISTINCT' are not supported yet. If you "
+                        f"planned to write a link prediction query instead, "
+                        f"make sure to register '{col_name}' as a "
+                        f"foreign key.")
+                return TaskType.TEMPORAL_LINK_PREDICTION
+
+            return TaskType.REGRESSION
+
+        assert isinstance(target, Column)
+
+        if target.stype in {Stype.ID, Stype.categorical}:
+            return TaskType.MULTICLASS_CLASSIFICATION
+
+        if target.stype in {Stype.numerical}:
+            return TaskType.REGRESSION
+
+        raise NotImplementedError("Task type not yet supported")
+
+    def _get_default_anchor_time(
+        self,
+        query: ValidatedPredictiveQuery,
+    ) -> pd.Timestamp:
+        if query.query_type == QueryType.TEMPORAL:
+            aggr_table_names = [
+                aggr._get_target_column_name().split('.')[0]
+                for aggr in query.get_all_target_aggregations()
+            ]
+            return self._sampler.get_max_time(aggr_table_names)
+
+        assert query.query_type == QueryType.STATIC
+        return self._sampler.get_max_time()
+
     def _validate_time(
         self,
         query: ValidatedPredictiveQuery,
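The new client-side `_get_task_type` (added in the hunk above) replaces `LocalPQueryDriver.get_task_type`. Following its branches, the routing works out roughly as below; the query fragments and column names are illustrative only:

    # Illustrative mapping, tracing the branches of _get_task_type;
    # exact routing depends on the parsed AST, not the raw query string:
    #
    #   PREDICT COUNT(orders.*, 0, 30, days) > 0 ...    -> Condition
    #                                           -> BINARY_CLASSIFICATION
    #   PREDICT SUM(orders.price, 0, 30, days) ...      -> Aggregation
    #                                           -> REGRESSION
    #   PREDICT LIST_DISTINCT(orders.item_id, ...) ...  -> Aggregation over
    #       a registered foreign key           -> TEMPORAL_LINK_PREDICTION
    #   PREDICT users.country ...               -> categorical Column
    #                                           -> MULTICLASS_CLASSIFICATION
    #   PREDICT users.age ...                   -> numerical Column
    #                                           -> REGRESSION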
@@ -759,28 +818,30 @@
         evaluate: bool,
     ) -> None:
 
-        if self._graph_store.min_time == pd.Timestamp.max:
+        if len(self._sampler.time_column_dict) == 0:
             return  # Graph without timestamps
 
-        if anchor_time < self._graph_store.min_time:
+        min_time = self._sampler.get_min_time()
+        max_time = self._sampler.get_max_time()
+
+        if anchor_time < min_time:
             raise ValueError(f"Anchor timestamp '{anchor_time}' is before "
-                             f"the earliest timestamp "
-                             f"'{self._graph_store.min_time}' in the data.")
+                             f"the earliest timestamp '{min_time}' in the "
+                             f"data.")
 
-        if (context_anchor_time is not None
-                and context_anchor_time < self._graph_store.min_time):
+        if context_anchor_time is not None and context_anchor_time < min_time:
             raise ValueError(f"Context anchor timestamp is too early or "
                              f"aggregation time range is too large. To make "
                              f"this prediction, we would need data back to "
                              f"'{context_anchor_time}', however, your data "
-                             f"only contains data back to "
-                             f"'{self._graph_store.min_time}'.")
+                             f"only contains data back to '{min_time}'.")
 
         if query.target_ast.date_offset_range is not None:
             end_offset = query.target_ast.date_offset_range.end_date_offset
         else:
             end_offset = pd.DateOffset(0)
-        forecast_end_offset = end_offset * query.num_forecasts
+        end_offset = end_offset * query.num_forecasts
+
         if (context_anchor_time is not None
                 and context_anchor_time > anchor_time):
             warnings.warn(f"Context anchor timestamp "
@@ -790,7 +851,7 @@
                           f"intended.")
         elif (query.query_type == QueryType.TEMPORAL
               and context_anchor_time is not None
-              and context_anchor_time + forecast_end_offset > anchor_time):
+              and context_anchor_time + end_offset > anchor_time):
             warnings.warn(f"Aggregation for context examples at timestamp "
                           f"'{context_anchor_time}' will leak information "
                           f"from the prediction anchor timestamp "
@@ -798,26 +859,23 @@
                           f"intended.")
 
         elif (context_anchor_time is not None
-              and context_anchor_time - forecast_end_offset
-              < self._graph_store.min_time):
-            _time = context_anchor_time - forecast_end_offset
+              and context_anchor_time - end_offset < min_time):
+            _time = context_anchor_time - end_offset
             warnings.warn(f"Context anchor timestamp is too early or "
                           f"aggregation time range is too large. To form "
                           f"proper input data, we would need data back to "
                           f"'{_time}', however, your data only contains "
-                          f"data back to '{self._graph_store.min_time}'.")
+                          f"data back to '{min_time}'.")
 
-        if (not evaluate and anchor_time
-                > self._graph_store.max_time + pd.DateOffset(days=1)):
+        if not evaluate and anchor_time > max_time + pd.DateOffset(days=1):
             warnings.warn(f"Anchor timestamp '{anchor_time}' is after the "
-                          f"latest timestamp '{self._graph_store.max_time}' "
-                          f"in the data. Please make sure this is intended.")
+                          f"latest timestamp '{max_time}' in the data. Please "
+                          f"make sure this is intended.")
 
-        max_eval_time = self._graph_store.max_time - forecast_end_offset
-        if evaluate and anchor_time > max_eval_time:
+        if evaluate and anchor_time > max_time - end_offset:
             raise ValueError(
                 f"Anchor timestamp for evaluation is after the latest "
-                f"supported timestamp '{max_eval_time}'.")
+                f"supported timestamp '{max_time - end_offset}'.")
 
     def _get_context(
         self,
848
906
  f"'https://github.com/kumo-ai/kumo-rfm' if you "
849
907
  f"must go beyond this for your use-case.")
850
908
 
851
- query_driver = LocalPQueryDriver(self._graph_store, query, random_seed)
852
- task_type = LocalPQueryDriver.get_task_type(
853
- query,
854
- edge_types=self._graph_store.edge_types,
909
+ task_type = self._get_task_type(
910
+ query=query,
911
+ edge_types=self._sampler.edge_types,
855
912
  )
856
913
 
857
914
  if logger is not None:
@@ -883,14 +940,17 @@ class KumoRFM:
883
940
  num_neighbors = [64, 64, 8, 8, 4, 4][:num_hops]
884
941
 
885
942
  if query.target_ast.date_offset_range is None:
886
- end_offset = pd.DateOffset(0)
943
+ step_offset = pd.DateOffset(0)
887
944
  else:
888
- end_offset = query.target_ast.date_offset_range.end_date_offset
889
- forecast_end_offset = end_offset * query.num_forecasts
945
+ step_offset = query.target_ast.date_offset_range.end_date_offset
946
+ end_offset = step_offset * query.num_forecasts
947
+
890
948
  if anchor_time is None:
891
- anchor_time = self._graph_store.max_time
949
+ anchor_time = self._get_default_anchor_time(query)
950
+
892
951
  if evaluate:
893
- anchor_time = anchor_time - forecast_end_offset
952
+ anchor_time = anchor_time - end_offset
953
+
894
954
  if logger is not None:
895
955
  assert isinstance(anchor_time, pd.Timestamp)
896
956
  if anchor_time == pd.Timestamp.min:
@@ -904,57 +964,71 @@
 
         assert anchor_time is not None
         if isinstance(anchor_time, pd.Timestamp):
+            if context_anchor_time == 'entity':
+                raise ValueError("Anchor time 'entity' needs to be shared "
+                                 "for context and prediction examples")
             if context_anchor_time is None:
-                context_anchor_time = anchor_time - forecast_end_offset
+                context_anchor_time = anchor_time - end_offset
             self._validate_time(query, anchor_time, context_anchor_time,
                                 evaluate)
         else:
             assert anchor_time == 'entity'
-            if query.entity_table not in self._graph_store.time_dict:
+            if query.query_type != QueryType.STATIC:
+                raise ValueError("Anchor time 'entity' is only valid for "
+                                 "static predictive queries")
+            if query.entity_table not in self._sampler.time_column_dict:
                 raise ValueError(f"Anchor time 'entity' requires the entity "
                                  f"table '{query.entity_table}' to "
                                  f"have a time column")
-            if context_anchor_time is not None:
-                warnings.warn("Ignoring option 'context_anchor_time' for "
-                              "`anchor_time='entity'`")
-                context_anchor_time = None
+            if isinstance(context_anchor_time, pd.Timestamp):
+                raise ValueError("Anchor time 'entity' needs to be shared "
+                                 "for context and prediction examples")
+            context_anchor_time = 'entity'
 
-        y_test: Optional[pd.Series] = None
+        num_train_examples = _MAX_CONTEXT_SIZE[run_mode]
         if evaluate:
-            max_test_size = _MAX_TEST_SIZE[run_mode]
+            num_test_examples = _MAX_TEST_SIZE[run_mode]
             if task_type.is_link_pred:
-                max_test_size = max_test_size // 5
+                num_test_examples = num_test_examples // 5
+        else:
+            num_test_examples = 0
+
+        train, test = self._sampler.sample_target(
+            query=query,
+            num_train_examples=num_train_examples,
+            train_anchor_time=context_anchor_time,
+            num_train_trials=max_pq_iterations * num_train_examples,
+            num_test_examples=num_test_examples,
+            test_anchor_time=anchor_time,
+            num_test_trials=max_pq_iterations * num_test_examples,
+            random_seed=random_seed,
+        )
+        train_pkey, train_time, y_train = train
+        test_pkey, test_time, y_test = test
 
-            test_node, test_time, y_test = query_driver.collect_test(
-                size=max_test_size,
-                anchor_time=anchor_time,
-                max_iterations=max_pq_iterations,
-                guarantee_train_examples=True,
-            )
-            if logger is not None:
-                if task_type == TaskType.BINARY_CLASSIFICATION:
-                    pos = 100 * int((y_test > 0).sum()) / len(y_test)
-                    msg = (f"Collected {len(y_test):,} test examples with "
-                           f"{pos:.2f}% positive cases")
-                elif task_type == TaskType.MULTICLASS_CLASSIFICATION:
-                    msg = (f"Collected {len(y_test):,} test examples "
-                           f"holding {y_test.nunique()} classes")
-                elif task_type == TaskType.REGRESSION:
-                    _min, _max = float(y_test.min()), float(y_test.max())
-                    msg = (f"Collected {len(y_test):,} test examples with "
-                           f"targets between {format_value(_min)} and "
-                           f"{format_value(_max)}")
-                elif task_type == TaskType.TEMPORAL_LINK_PREDICTION:
-                    num_rhs = y_test.explode().nunique()
-                    msg = (f"Collected {len(y_test):,} test examples with "
-                           f"{num_rhs:,} unique items")
-                else:
-                    raise NotImplementedError
-                logger.log(msg)
+        if evaluate and logger is not None:
+            if task_type == TaskType.BINARY_CLASSIFICATION:
+                pos = 100 * int((y_test > 0).sum()) / len(y_test)
+                msg = (f"Collected {len(y_test):,} test examples with "
+                       f"{pos:.2f}% positive cases")
+            elif task_type == TaskType.MULTICLASS_CLASSIFICATION:
+                msg = (f"Collected {len(y_test):,} test examples holding "
+                       f"{y_test.nunique()} classes")
+            elif task_type == TaskType.REGRESSION:
+                _min, _max = float(y_test.min()), float(y_test.max())
+                msg = (f"Collected {len(y_test):,} test examples with targets "
+                       f"between {format_value(_min)} and "
+                       f"{format_value(_max)}")
+            elif task_type == TaskType.TEMPORAL_LINK_PREDICTION:
+                num_rhs = y_test.explode().nunique()
+                msg = (f"Collected {len(y_test):,} test examples with "
+                       f"{num_rhs:,} unique items")
+            else:
+                raise NotImplementedError
+            logger.log(msg)
 
-        else:
+        if not evaluate:
             assert indices is not None
-
             if len(indices) > _MAX_PRED_SIZE[task_type]:
                 raise ValueError(f"Cannot predict for more than "
                                  f"{_MAX_PRED_SIZE[task_type]:,} entities at "
@@ -962,26 +1036,12 @@
                                  f"`KumoRFM.batch_mode` to process entities "
                                  f"in batches")
 
-            test_node = self._graph_store.get_node_id(
-                table_name=query.entity_table,
-                pkey=pd.Series(indices),
-            )
-
+            test_pkey = pd.Series(indices, dtype=train_pkey.dtype)
             if isinstance(anchor_time, pd.Timestamp):
-                test_time = pd.Series(anchor_time).repeat(
-                    len(test_node)).reset_index(drop=True)
+                test_time = pd.Series([anchor_time]).repeat(
+                    len(indices)).reset_index(drop=True)
             else:
-                time = self._graph_store.time_dict[query.entity_table]
-                time = time[test_node] * 1000**3
-                test_time = pd.Series(time, dtype='datetime64[ns]')
-
-            train_node, train_time, y_train = query_driver.collect_train(
-                size=_MAX_CONTEXT_SIZE[run_mode],
-                anchor_time=context_anchor_time or 'entity',
-                exclude_node=test_node if (query.query_type == QueryType.STATIC
-                                           or anchor_time == 'entity') else None,
-                max_iterations=max_pq_iterations,
-            )
+                train_time = test_time = 'entity'
 
         if logger is not None:
             if task_type == TaskType.BINARY_CLASSIFICATION:
1009
1069
  final_aggr = query.get_final_target_aggregation()
1010
1070
  assert final_aggr is not None
1011
1071
  edge_fkey = final_aggr._get_target_column_name()
1012
- for edge_type in self._graph_store.edge_types:
1072
+ for edge_type in self._sampler.edge_types:
1013
1073
  if edge_fkey == f'{edge_type[0]}.{edge_type[1]}':
1014
1074
  entity_table_names = (
1015
1075
  query.entity_table,
@@ -1021,21 +1081,24 @@ class KumoRFM:
1021
1081
  # Exclude the entity anchor time from the feature set to prevent
1022
1082
  # running out-of-distribution between in-context and test examples:
1023
1083
  exclude_cols_dict = query.get_exclude_cols_dict()
1024
- if anchor_time == 'entity':
1084
+ if entity_table_names[0] in self._sampler.time_column_dict:
1025
1085
  if entity_table_names[0] not in exclude_cols_dict:
1026
1086
  exclude_cols_dict[entity_table_names[0]] = []
1027
- time_column_dict = self._graph_store.time_column_dict
1028
- time_column = time_column_dict[entity_table_names[0]]
1087
+ time_column = self._sampler.time_column_dict[entity_table_names[0]]
1029
1088
  exclude_cols_dict[entity_table_names[0]].append(time_column)
1030
1089
 
1031
- subgraph = self._graph_sampler(
1090
+ subgraph = self._sampler.sample_subgraph(
1032
1091
  entity_table_names=entity_table_names,
1033
- node=np.concatenate([train_node, test_node]),
1034
- time=np.concatenate([
1035
- train_time.astype('datetime64[ns]').astype(int).to_numpy(),
1036
- test_time.astype('datetime64[ns]').astype(int).to_numpy(),
1037
- ]),
1038
- run_mode=run_mode,
1092
+ entity_pkey=pd.concat(
1093
+ [train_pkey, test_pkey],
1094
+ axis=0,
1095
+ ignore_index=True,
1096
+ ),
1097
+ anchor_time=pd.concat(
1098
+ [train_time, test_time],
1099
+ axis=0,
1100
+ ignore_index=True,
1101
+ ) if isinstance(train_time, pd.Series) else 'entity',
1039
1102
  num_neighbors=num_neighbors,
1040
1103
  exclude_cols_dict=exclude_cols_dict,
1041
1104
  )
@@ -1047,18 +1110,14 @@
                              f"'https://github.com/kumo-ai/kumo-rfm' if you "
                              f"must go beyond this for your use-case.")
 
-        step_size: Optional[int] = None
-        if query.query_type == QueryType.TEMPORAL:
-            step_size = date_offset_to_seconds(end_offset)
-
         return Context(
             task_type=task_type,
             entity_table_names=entity_table_names,
             subgraph=subgraph,
             y_train=y_train,
-            y_test=y_test,
+            y_test=y_test if evaluate else None,
             top_k=query.top_k,
-            step_size=step_size,
+            step_size=None,
         )
 
     @staticmethod