kumoai 2.13.0.dev202511211730__py3-none-any.whl → 2.14.0.dev202512141732__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (42)
  1. kumoai/__init__.py +12 -0
  2. kumoai/_version.py +1 -1
  3. kumoai/client/pquery.py +6 -2
  4. kumoai/connector/utils.py +23 -2
  5. kumoai/experimental/rfm/__init__.py +20 -45
  6. kumoai/experimental/rfm/backend/__init__.py +0 -0
  7. kumoai/experimental/rfm/backend/local/__init__.py +42 -0
  8. kumoai/experimental/rfm/{local_graph_store.py → backend/local/graph_store.py} +37 -90
  9. kumoai/experimental/rfm/backend/local/sampler.py +313 -0
  10. kumoai/experimental/rfm/backend/local/table.py +119 -0
  11. kumoai/experimental/rfm/backend/snow/__init__.py +37 -0
  12. kumoai/experimental/rfm/backend/snow/sampler.py +119 -0
  13. kumoai/experimental/rfm/backend/snow/table.py +135 -0
  14. kumoai/experimental/rfm/backend/sqlite/__init__.py +32 -0
  15. kumoai/experimental/rfm/backend/sqlite/sampler.py +112 -0
  16. kumoai/experimental/rfm/backend/sqlite/table.py +115 -0
  17. kumoai/experimental/rfm/base/__init__.py +23 -0
  18. kumoai/experimental/rfm/base/column.py +66 -0
  19. kumoai/experimental/rfm/base/sampler.py +773 -0
  20. kumoai/experimental/rfm/base/source.py +19 -0
  21. kumoai/experimental/rfm/{local_table.py → base/table.py} +152 -141
  22. kumoai/experimental/rfm/{local_graph.py → graph.py} +352 -80
  23. kumoai/experimental/rfm/infer/__init__.py +6 -0
  24. kumoai/experimental/rfm/infer/dtype.py +79 -0
  25. kumoai/experimental/rfm/infer/pkey.py +126 -0
  26. kumoai/experimental/rfm/infer/time_col.py +62 -0
  27. kumoai/experimental/rfm/pquery/pandas_executor.py +1 -1
  28. kumoai/experimental/rfm/rfm.py +224 -167
  29. kumoai/experimental/rfm/sagemaker.py +11 -3
  30. kumoai/pquery/predictive_query.py +10 -6
  31. kumoai/testing/decorators.py +1 -1
  32. kumoai/testing/snow.py +50 -0
  33. kumoai/utils/__init__.py +2 -0
  34. kumoai/utils/sql.py +3 -0
  35. {kumoai-2.13.0.dev202511211730.dist-info → kumoai-2.14.0.dev202512141732.dist-info}/METADATA +9 -8
  36. {kumoai-2.13.0.dev202511211730.dist-info → kumoai-2.14.0.dev202512141732.dist-info}/RECORD +39 -23
  37. kumoai/experimental/rfm/local_graph_sampler.py +0 -182
  38. kumoai/experimental/rfm/local_pquery_driver.py +0 -689
  39. kumoai/experimental/rfm/utils.py +0 -344
  40. {kumoai-2.13.0.dev202511211730.dist-info → kumoai-2.14.0.dev202512141732.dist-info}/WHEEL +0 -0
  41. {kumoai-2.13.0.dev202511211730.dist-info → kumoai-2.14.0.dev202512141732.dist-info}/licenses/LICENSE +0 -0
  42. {kumoai-2.13.0.dev202511211730.dist-info → kumoai-2.14.0.dev202512141732.dist-info}/top_level.txt +0 -0
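The headline change is a refactor of the experimental RFM package: the single-process `local_*` modules give way to a backend layer (`backend/local`, `backend/sqlite`, `backend/snow`) behind a shared `base.Sampler` interface, and `LocalGraph` becomes `Graph`. A minimal migration sketch, assuming `from_data` keeps its 2.13 call shape (the toy tables and column names are hypothetical):

    import pandas as pd

    from kumoai.experimental.rfm import Graph, KumoRFM  # was: LocalGraph

    # Hypothetical toy data; column names are illustrative only.
    df_users = pd.DataFrame({'user_id': [1, 2]})
    df_orders = pd.DataFrame({
        'order_id': [10, 11],
        'user_id': [1, 2],
        'time': pd.to_datetime(['2025-01-01', '2025-01-02']),
    })

    graph = Graph.from_data({'users': df_users, 'orders': df_orders})
    rfm = KumoRFM(graph)  # the 2.13-era `preprocess` flag is gone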
@@ -21,6 +21,13 @@ import numpy as np
 import pandas as pd
 from kumoapi.model_plan import RunMode
 from kumoapi.pquery import QueryType, ValidatedPredictiveQuery
+from kumoapi.pquery.AST import (
+    Aggregation,
+    Column,
+    Condition,
+    Join,
+    LogicalOperation,
+)
 from kumoapi.rfm import Context
 from kumoapi.rfm import Explanation as ExplanationConfig
 from kumoapi.rfm import (
@@ -29,16 +36,12 @@ from kumoapi.rfm import (
     RFMPredictRequest,
 )
 from kumoapi.task import TaskType
+from kumoapi.typing import AggregationType, Stype
 
 from kumoai.client.rfm import RFMAPI
 from kumoai.exceptions import HTTPException
-from kumoai.experimental.rfm import LocalGraph
-from kumoai.experimental.rfm.local_graph_sampler import LocalGraphSampler
-from kumoai.experimental.rfm.local_graph_store import LocalGraphStore
-from kumoai.experimental.rfm.local_pquery_driver import (
-    LocalPQueryDriver,
-    date_offset_to_seconds,
-)
+from kumoai.experimental.rfm import Graph
+from kumoai.experimental.rfm.base import DataBackend, Sampler
 from kumoai.mixin import CastMixin
 from kumoai.utils import InteractiveProgressLogger, ProgressLogger
 
@@ -123,17 +126,17 @@ class KumoRFM:
     :class:`KumoRFM` is a foundation model to generate predictions for any
     relational dataset without training.
     The model is pre-trained and the class provides an interface to query the
-    model from a :class:`LocalGraph` object.
+    model from a :class:`Graph` object.
 
     .. code-block:: python
 
-        from kumoai.experimental.rfm import LocalGraph, KumoRFM
+        from kumoai.experimental.rfm import Graph, KumoRFM
 
         df_users = pd.DataFrame(...)
         df_items = pd.DataFrame(...)
         df_orders = pd.DataFrame(...)
 
-        graph = LocalGraph.from_data({
+        graph = Graph.from_data({
             'users': df_users,
             'items': df_items,
             'orders': df_orders,
@@ -150,32 +153,41 @@ class KumoRFM:
 
     Args:
         graph: The graph.
-        preprocess: Whether to pre-process the data in advance during graph
-            materialization.
-            This is a runtime trade-off between graph materialization and model
-            processing speed.
-            It can be benefical to preprocess your data once and then run many
-            queries on top to achieve maximum model speed.
-            However, if activiated, graph materialization can take potentially
-            much longer, especially on graphs with many large text columns.
-            Best to tune this option manually.
        verbose: Whether to print verbose output.
     """
     def __init__(
         self,
-        graph: LocalGraph,
-        preprocess: bool = False,
+        graph: Graph,
         verbose: Union[bool, ProgressLogger] = True,
     ) -> None:
         graph = graph.validate()
         self._graph_def = graph._to_api_graph_definition()
-        self._graph_store = LocalGraphStore(graph, preprocess, verbose)
-        self._graph_sampler = LocalGraphSampler(self._graph_store)
+
+        if graph.backend == DataBackend.LOCAL:
+            from kumoai.experimental.rfm.backend.local import LocalSampler
+            self._sampler: Sampler = LocalSampler(graph, verbose)
+        elif graph.backend == DataBackend.SQLITE:
+            from kumoai.experimental.rfm.backend.sqlite import SQLiteSampler
+            self._sampler = SQLiteSampler(graph, verbose)
+        elif graph.backend == DataBackend.SNOWFLAKE:
+            from kumoai.experimental.rfm.backend.snow import SnowSampler
+            self._sampler = SnowSampler(graph, verbose)
+        else:
+            raise NotImplementedError
+
+        self._client: Optional[RFMAPI] = None
 
         self._batch_size: Optional[int | Literal['max']] = None
         self.num_retries: int = 0
+
+    @property
+    def _api_client(self) -> RFMAPI:
+        if self._client is not None:
+            return self._client
+
         from kumoai.experimental.rfm import global_state
-        self._api_client = RFMAPI(global_state.client)
+        self._client = RFMAPI(global_state.client)
+        return self._client
 
     def __repr__(self) -> str:
         return f'{self.__class__.__name__}()'
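Note on the hunk above: the `RFMAPI` client is no longer built eagerly in `__init__`; `_api_client` is now a property that creates and caches the client on first access, so a `KumoRFM` instance can be constructed before `global_state` has a client. A minimal sketch of the same lazy-caching pattern, with stand-in names:

    from typing import Optional

    class LazyClientHolder:
        """Stand-in for KumoRFM's cached `_api_client` property."""
        def __init__(self) -> None:
            self._client: Optional[str] = None  # nothing created yet

        @property
        def client(self) -> str:
            if self._client is not None:
                return self._client
            self._client = "expensive client"  # e.g. RFMAPI(global_state.client)
            return self._client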
@@ -225,7 +237,7 @@ class KumoRFM:
         run_mode: Union[RunMode, str] = RunMode.FAST,
         num_neighbors: Optional[List[int]] = None,
         num_hops: int = 2,
-        max_pq_iterations: int = 20,
+        max_pq_iterations: int = 10,
         random_seed: Optional[int] = _RANDOM_SEED,
         verbose: Union[bool, ProgressLogger] = True,
         use_prediction_time: bool = False,
@@ -244,7 +256,7 @@ class KumoRFM:
         run_mode: Union[RunMode, str] = RunMode.FAST,
         num_neighbors: Optional[List[int]] = None,
         num_hops: int = 2,
-        max_pq_iterations: int = 20,
+        max_pq_iterations: int = 10,
         random_seed: Optional[int] = _RANDOM_SEED,
         verbose: Union[bool, ProgressLogger] = True,
         use_prediction_time: bool = False,
@@ -262,7 +274,7 @@ class KumoRFM:
         run_mode: Union[RunMode, str] = RunMode.FAST,
         num_neighbors: Optional[List[int]] = None,
         num_hops: int = 2,
-        max_pq_iterations: int = 20,
+        max_pq_iterations: int = 10,
         random_seed: Optional[int] = _RANDOM_SEED,
         verbose: Union[bool, ProgressLogger] = True,
         use_prediction_time: bool = False,
@@ -358,9 +370,9 @@ class KumoRFM:
 
         batch_size: Optional[int] = None
         if self._batch_size == 'max':
-            task_type = LocalPQueryDriver.get_task_type(
-                query_def,
-                edge_types=self._graph_store.edge_types,
+            task_type = self._get_task_type(
+                query=query_def,
+                edge_types=self._sampler.edge_types,
             )
             batch_size = _MAX_PRED_SIZE[task_type]
         else:
@@ -434,10 +446,10 @@ class KumoRFM:
 
         # Cast 'ENTITY' to correct data type:
         if 'ENTITY' in df:
-            entity = query_def.entity_table
-            pkey_map = self._graph_store.pkey_map_dict[entity]
-            df['ENTITY'] = df['ENTITY'].astype(
-                type(pkey_map.index[0]))
+            table_dict = context.subgraph.table_dict
+            table = table_dict[query_def.entity_table]
+            ser = table.df[table.primary_key]
+            df['ENTITY'] = df['ENTITY'].astype(ser.dtype)
 
         # Cast 'ANCHOR_TIMESTAMP' to correct data type:
         if 'ANCHOR_TIMESTAMP' in df:
@@ -520,23 +532,18 @@ class KumoRFM:
             raise ValueError("At least one entity is required")
 
         if anchor_time is None:
-            anchor_time = self._graph_store.max_time
+            anchor_time = self._get_default_anchor_time(query_def)
 
         if isinstance(anchor_time, pd.Timestamp):
             self._validate_time(query_def, anchor_time, None, False)
         else:
             assert anchor_time == 'entity'
-            if (query_def.entity_table not in self._graph_store.time_dict):
+            if query_def.entity_table not in self._sampler.time_column_dict:
                 raise ValueError(f"Anchor time 'entity' requires the entity "
                                  f"table '{query_def.entity_table}' "
                                  f"to have a time column.")
 
-        node = self._graph_store.get_node_id(
-            table_name=query_def.entity_table,
-            pkey=pd.Series(indices),
-        )
-        query_driver = LocalPQueryDriver(self._graph_store, query_def)
-        return query_driver.is_valid(node, anchor_time)
+        raise NotImplementedError
 
     def evaluate(
         self,
@@ -548,7 +555,7 @@ class KumoRFM:
         run_mode: Union[RunMode, str] = RunMode.FAST,
         num_neighbors: Optional[List[int]] = None,
         num_hops: int = 2,
-        max_pq_iterations: int = 20,
+        max_pq_iterations: int = 10,
         random_seed: Optional[int] = _RANDOM_SEED,
         verbose: Union[bool, ProgressLogger] = True,
         use_prediction_time: bool = False,
@@ -659,7 +666,7 @@ class KumoRFM:
         *,
         anchor_time: Union[pd.Timestamp, Literal['entity'], None] = None,
         random_seed: Optional[int] = _RANDOM_SEED,
-        max_iterations: int = 20,
+        max_iterations: int = 10,
     ) -> pd.DataFrame:
         """Returns the labels of a predictive query for a specified anchor
         time.
@@ -679,40 +686,37 @@ class KumoRFM:
         query_def = self._parse_query(query)
 
         if anchor_time is None:
-            anchor_time = self._graph_store.max_time
+            anchor_time = self._get_default_anchor_time(query_def)
             if query_def.target_ast.date_offset_range is not None:
-                anchor_time = anchor_time - (
-                    query_def.target_ast.date_offset_range.end_date_offset *
-                    query_def.num_forecasts)
+                offset = query_def.target_ast.date_offset_range.end_date_offset
+                offset *= query_def.num_forecasts
+                anchor_time -= offset
 
         assert anchor_time is not None
         if isinstance(anchor_time, pd.Timestamp):
             self._validate_time(query_def, anchor_time, None, evaluate=True)
         else:
             assert anchor_time == 'entity'
-            if (query_def.entity_table not in self._graph_store.time_dict):
+            if query_def.entity_table not in self._sampler.time_column_dict:
                 raise ValueError(f"Anchor time 'entity' requires the entity "
                                  f"table '{query_def.entity_table}' "
                                  f"to have a time column")
 
-        query_driver = LocalPQueryDriver(self._graph_store, query_def,
-                                         random_seed)
-
-        node, time, y = query_driver.collect_test(
-            size=size,
-            anchor_time=anchor_time,
-            batch_size=min(10_000, size),
-            max_iterations=max_iterations,
-            guarantee_train_examples=False,
+        train, test = self._sampler.sample_target(
+            query=query,
+            num_train_examples=0,
+            train_anchor_time=anchor_time,
+            num_train_trials=0,
+            num_test_examples=size,
+            test_anchor_time=anchor_time,
+            num_test_trials=max_iterations * size,
+            random_seed=random_seed,
         )
 
-        entity = self._graph_store.pkey_map_dict[
-            query_def.entity_table].index[node]
-
         return pd.DataFrame({
-            'ENTITY': entity,
-            'ANCHOR_TIMESTAMP': time,
-            'TARGET': y,
+            'ENTITY': test.entity_pkey,
+            'ANCHOR_TIMESTAMP': test.anchor_time,
+            'TARGET': test.target,
         })
 
     # Helpers #################################################################
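Note: the rewritten `get_labels` above is the simplest consumer of the new `Sampler.sample_target` API; it requests zero train examples and `size` test examples at a single anchor time, then reads the result off the returned test split. The diff never shows the return type itself (it lives in `kumoai/experimental/rfm/base/sampler.py`), so the sketch below is an assumption that only mirrors the fields unpacked above:

    from typing import NamedTuple, Union

    import pandas as pd

    class TargetSplit(NamedTuple):
        """Hypothetical shape of each element of sample_target's return
        value: get_labels reads `entity_pkey`, `anchor_time` and `target`,
        and _get_context destructures a split as (pkey, time, y)."""
        entity_pkey: pd.Series
        anchor_time: Union[pd.Series, str]  # 'entity' for per-row times
        target: pd.Series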
@@ -735,8 +739,6 @@ class KumoRFM:
 
         resp = self._api_client.parse_query(request)
 
-        # TODO Expose validation warnings.
-
         if len(resp.validation_response.warnings) > 0:
             msg = '\n'.join([
                 f'{i+1}. {warning.title}: {warning.message}' for i, warning
@@ -754,6 +756,60 @@ class KumoRFM:
             raise ValueError(f"Failed to parse query '{query}'. "
                              f"{msg}") from None
 
+    @staticmethod
+    def _get_task_type(
+        query: ValidatedPredictiveQuery,
+        edge_types: List[Tuple[str, str, str]],
+    ) -> TaskType:
+        if isinstance(query.target_ast, (Condition, LogicalOperation)):
+            return TaskType.BINARY_CLASSIFICATION
+
+        target = query.target_ast
+        if isinstance(target, Join):
+            target = target.rhs_target
+        if isinstance(target, Aggregation):
+            if target.aggr == AggregationType.LIST_DISTINCT:
+                table_name, col_name = target._get_target_column_name().split(
+                    '.')
+                target_edge_types = [
+                    edge_type for edge_type in edge_types
+                    if edge_type[0] == table_name and edge_type[1] == col_name
+                ]
+                if len(target_edge_types) != 1:
+                    raise NotImplementedError(
+                        f"Multilabel-classification queries based on "
+                        f"'LIST_DISTINCT' are not supported yet. If you "
+                        f"planned to write a link prediction query instead, "
+                        f"make sure to register '{col_name}' as a "
+                        f"foreign key.")
+                return TaskType.TEMPORAL_LINK_PREDICTION
+
+            return TaskType.REGRESSION
+
+        assert isinstance(target, Column)
+
+        if target.stype in {Stype.ID, Stype.categorical}:
+            return TaskType.MULTICLASS_CLASSIFICATION
+
+        if target.stype in {Stype.numerical}:
+            return TaskType.REGRESSION
+
+        raise NotImplementedError("Task type not yet supported")
+
+    def _get_default_anchor_time(
+        self,
+        query: ValidatedPredictiveQuery,
+    ) -> pd.Timestamp:
+        if query.query_type == QueryType.TEMPORAL:
+            aggr_table_names = [
+                aggr._get_target_column_name().split('.')[0]
+                for aggr in query.get_all_target_aggregations()
+            ]
+            return self._sampler.get_max_time(aggr_table_names)
+
+        assert query.query_type == QueryType.STATIC
+        return self._sampler.get_max_time()
+
     def _validate_time(
         self,
         query: ValidatedPredictiveQuery,
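The new `_get_task_type` helper (replacing the removed `LocalPQueryDriver.get_task_type`) is a small decision table: `Condition`/`LogicalOperation` targets mean binary classification; `Aggregation` targets mean temporal link prediction when the aggregation is `LIST_DISTINCT` over exactly one foreign-key edge and regression otherwise; plain `Column` targets map `ID`/`categorical` stypes to multiclass classification and `numerical` to regression. A self-contained sketch of that table using local stand-ins rather than the kumoapi AST classes:

    from dataclasses import dataclass
    from enum import Enum, auto

    class Task(Enum):
        BINARY = auto()
        MULTICLASS = auto()
        REGRESSION = auto()
        LINK_PREDICTION = auto()

    @dataclass
    class Target:
        kind: str                    # 'condition', 'aggregation' or 'column'
        aggr: str = ''               # e.g. 'LIST_DISTINCT', 'COUNT', 'SUM'
        stype: str = ''              # e.g. 'ID', 'categorical', 'numerical'
        has_fkey_edge: bool = False  # LIST_DISTINCT column matches one edge

    def task_type(target: Target) -> Task:
        # Mirrors the branch order of KumoRFM._get_task_type above.
        if target.kind == 'condition':
            return Task.BINARY
        if target.kind == 'aggregation':
            if target.aggr == 'LIST_DISTINCT':
                if not target.has_fkey_edge:
                    raise NotImplementedError('multilabel LIST_DISTINCT')
                return Task.LINK_PREDICTION
            return Task.REGRESSION
        if target.stype in {'ID', 'categorical'}:
            return Task.MULTICLASS
        if target.stype == 'numerical':
            return Task.REGRESSION
        raise NotImplementedError('task type not yet supported')

    assert task_type(Target('aggregation', aggr='COUNT')) is Task.REGRESSION
    assert task_type(Target('column', stype='numerical')) is Task.REGRESSION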
@@ -762,28 +818,30 @@ class KumoRFM:
         evaluate: bool,
     ) -> None:
 
-        if self._graph_store.min_time == pd.Timestamp.max:
+        if len(self._sampler.time_column_dict) == 0:
             return  # Graph without timestamps
 
-        if anchor_time < self._graph_store.min_time:
+        min_time = self._sampler.get_min_time()
+        max_time = self._sampler.get_max_time()
+
+        if anchor_time < min_time:
             raise ValueError(f"Anchor timestamp '{anchor_time}' is before "
-                             f"the earliest timestamp "
-                             f"'{self._graph_store.min_time}' in the data.")
+                             f"the earliest timestamp '{min_time}' in the "
+                             f"data.")
 
-        if (context_anchor_time is not None
-                and context_anchor_time < self._graph_store.min_time):
+        if context_anchor_time is not None and context_anchor_time < min_time:
             raise ValueError(f"Context anchor timestamp is too early or "
                              f"aggregation time range is too large. To make "
                              f"this prediction, we would need data back to "
                              f"'{context_anchor_time}', however, your data "
-                             f"only contains data back to "
-                             f"'{self._graph_store.min_time}'.")
+                             f"only contains data back to '{min_time}'.")
 
         if query.target_ast.date_offset_range is not None:
             end_offset = query.target_ast.date_offset_range.end_date_offset
         else:
             end_offset = pd.DateOffset(0)
-        forecast_end_offset = end_offset * query.num_forecasts
+        end_offset = end_offset * query.num_forecasts
+
         if (context_anchor_time is not None
                 and context_anchor_time > anchor_time):
             warnings.warn(f"Context anchor timestamp "
@@ -793,7 +851,7 @@ class KumoRFM:
                           f"intended.")
         elif (query.query_type == QueryType.TEMPORAL
               and context_anchor_time is not None
-              and context_anchor_time + forecast_end_offset > anchor_time):
+              and context_anchor_time + end_offset > anchor_time):
             warnings.warn(f"Aggregation for context examples at timestamp "
                           f"'{context_anchor_time}' will leak information "
                           f"from the prediction anchor timestamp "
@@ -801,26 +859,23 @@ class KumoRFM:
                           f"intended.")
 
         elif (context_anchor_time is not None
-              and context_anchor_time - forecast_end_offset
-              < self._graph_store.min_time):
-            _time = context_anchor_time - forecast_end_offset
+              and context_anchor_time - end_offset < min_time):
+            _time = context_anchor_time - end_offset
             warnings.warn(f"Context anchor timestamp is too early or "
                           f"aggregation time range is too large. To form "
                           f"proper input data, we would need data back to "
                           f"'{_time}', however, your data only contains "
-                          f"data back to '{self._graph_store.min_time}'.")
+                          f"data back to '{min_time}'.")
 
-        if (not evaluate and anchor_time
-                > self._graph_store.max_time + pd.DateOffset(days=1)):
+        if not evaluate and anchor_time > max_time + pd.DateOffset(days=1):
             warnings.warn(f"Anchor timestamp '{anchor_time}' is after the "
-                          f"latest timestamp '{self._graph_store.max_time}' "
-                          f"in the data. Please make sure this is intended.")
+                          f"latest timestamp '{max_time}' in the data. Please "
+                          f"make sure this is intended.")
 
-        max_eval_time = self._graph_store.max_time - forecast_end_offset
-        if evaluate and anchor_time > max_eval_time:
+        if evaluate and anchor_time > max_time - end_offset:
             raise ValueError(
                 f"Anchor timestamp for evaluation is after the latest "
-                f"supported timestamp '{max_eval_time}'.")
+                f"supported timestamp '{max_time - end_offset}'.")
 
     def _get_context(
         self,
@@ -851,10 +906,9 @@ class KumoRFM:
                 f"'https://github.com/kumo-ai/kumo-rfm' if you "
                 f"must go beyond this for your use-case.")
 
-        query_driver = LocalPQueryDriver(self._graph_store, query, random_seed)
-        task_type = LocalPQueryDriver.get_task_type(
-            query,
-            edge_types=self._graph_store.edge_types,
+        task_type = self._get_task_type(
+            query=query,
+            edge_types=self._sampler.edge_types,
         )
 
         if logger is not None:
@@ -886,14 +940,17 @@ class KumoRFM:
             num_neighbors = [64, 64, 8, 8, 4, 4][:num_hops]
 
         if query.target_ast.date_offset_range is None:
-            end_offset = pd.DateOffset(0)
+            step_offset = pd.DateOffset(0)
         else:
-            end_offset = query.target_ast.date_offset_range.end_date_offset
-        forecast_end_offset = end_offset * query.num_forecasts
+            step_offset = query.target_ast.date_offset_range.end_date_offset
+        end_offset = step_offset * query.num_forecasts
+
         if anchor_time is None:
-            anchor_time = self._graph_store.max_time
+            anchor_time = self._get_default_anchor_time(query)
+
         if evaluate:
-            anchor_time = anchor_time - forecast_end_offset
+            anchor_time = anchor_time - end_offset
+
         if logger is not None:
             assert isinstance(anchor_time, pd.Timestamp)
             if anchor_time == pd.Timestamp.min:
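The renaming above separates two quantities the old code conflated: `step_offset` is the length of one target window, while `end_offset = step_offset * num_forecasts` is the full forecast horizon, which is also what evaluation subtracts from the anchor time. A quick worked example with hypothetical numbers:

    import pandas as pd

    step_offset = pd.DateOffset(days=7)       # one 7-day target window
    num_forecasts = 4                         # forecast four steps ahead
    end_offset = step_offset * num_forecasts  # 28 days in total

    max_time = pd.Timestamp('2025-06-30')
    eval_anchor = max_time - end_offset       # 2025-06-02: leaves room for
    print(eval_anchor)                        # all four windows of labels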
@@ -907,57 +964,71 @@ class KumoRFM:
 
         assert anchor_time is not None
         if isinstance(anchor_time, pd.Timestamp):
+            if context_anchor_time == 'entity':
+                raise ValueError("Anchor time 'entity' needs to be shared "
+                                 "for context and prediction examples")
             if context_anchor_time is None:
-                context_anchor_time = anchor_time - forecast_end_offset
+                context_anchor_time = anchor_time - end_offset
             self._validate_time(query, anchor_time, context_anchor_time,
                                 evaluate)
         else:
             assert anchor_time == 'entity'
-            if query.entity_table not in self._graph_store.time_dict:
+            if query.query_type != QueryType.STATIC:
+                raise ValueError("Anchor time 'entity' is only valid for "
+                                 "static predictive queries")
+            if query.entity_table not in self._sampler.time_column_dict:
                 raise ValueError(f"Anchor time 'entity' requires the entity "
                                  f"table '{query.entity_table}' to "
                                  f"have a time column")
-            if context_anchor_time is not None:
-                warnings.warn("Ignoring option 'context_anchor_time' for "
-                              "`anchor_time='entity'`")
-            context_anchor_time = None
+            if isinstance(context_anchor_time, pd.Timestamp):
+                raise ValueError("Anchor time 'entity' needs to be shared "
+                                 "for context and prediction examples")
+            context_anchor_time = 'entity'
 
-        y_test: Optional[pd.Series] = None
+        num_train_examples = _MAX_CONTEXT_SIZE[run_mode]
         if evaluate:
-            max_test_size = _MAX_TEST_SIZE[run_mode]
+            num_test_examples = _MAX_TEST_SIZE[run_mode]
             if task_type.is_link_pred:
-                max_test_size = max_test_size // 5
+                num_test_examples = num_test_examples // 5
+        else:
+            num_test_examples = 0
+
+        train, test = self._sampler.sample_target(
+            query=query,
+            num_train_examples=num_train_examples,
+            train_anchor_time=context_anchor_time,
+            num_train_trials=max_pq_iterations * num_train_examples,
+            num_test_examples=num_test_examples,
+            test_anchor_time=anchor_time,
+            num_test_trials=max_pq_iterations * num_test_examples,
+            random_seed=random_seed,
+        )
+        train_pkey, train_time, y_train = train
+        test_pkey, test_time, y_test = test
 
-            test_node, test_time, y_test = query_driver.collect_test(
-                size=max_test_size,
-                anchor_time=anchor_time,
-                max_iterations=max_pq_iterations,
-                guarantee_train_examples=True,
-            )
-            if logger is not None:
-                if task_type == TaskType.BINARY_CLASSIFICATION:
-                    pos = 100 * int((y_test > 0).sum()) / len(y_test)
-                    msg = (f"Collected {len(y_test):,} test examples with "
-                           f"{pos:.2f}% positive cases")
-                elif task_type == TaskType.MULTICLASS_CLASSIFICATION:
-                    msg = (f"Collected {len(y_test):,} test examples "
-                           f"holding {y_test.nunique()} classes")
-                elif task_type == TaskType.REGRESSION:
-                    _min, _max = float(y_test.min()), float(y_test.max())
-                    msg = (f"Collected {len(y_test):,} test examples with "
-                           f"targets between {format_value(_min)} and "
-                           f"{format_value(_max)}")
-                elif task_type == TaskType.TEMPORAL_LINK_PREDICTION:
-                    num_rhs = y_test.explode().nunique()
-                    msg = (f"Collected {len(y_test):,} test examples with "
-                           f"{num_rhs:,} unique items")
-                else:
-                    raise NotImplementedError
-                logger.log(msg)
+        if evaluate and logger is not None:
+            if task_type == TaskType.BINARY_CLASSIFICATION:
+                pos = 100 * int((y_test > 0).sum()) / len(y_test)
+                msg = (f"Collected {len(y_test):,} test examples with "
+                       f"{pos:.2f}% positive cases")
+            elif task_type == TaskType.MULTICLASS_CLASSIFICATION:
+                msg = (f"Collected {len(y_test):,} test examples holding "
+                       f"{y_test.nunique()} classes")
+            elif task_type == TaskType.REGRESSION:
+                _min, _max = float(y_test.min()), float(y_test.max())
+                msg = (f"Collected {len(y_test):,} test examples with targets "
+                       f"between {format_value(_min)} and "
+                       f"{format_value(_max)}")
+            elif task_type == TaskType.TEMPORAL_LINK_PREDICTION:
+                num_rhs = y_test.explode().nunique()
+                msg = (f"Collected {len(y_test):,} test examples with "
+                       f"{num_rhs:,} unique items")
+            else:
+                raise NotImplementedError
+            logger.log(msg)
 
-        else:
+        if not evaluate:
             assert indices is not None
-
             if len(indices) > _MAX_PRED_SIZE[task_type]:
                 raise ValueError(f"Cannot predict for more than "
                                  f"{_MAX_PRED_SIZE[task_type]:,} entities at "
@@ -965,26 +1036,12 @@ class KumoRFM:
                                  f"`KumoRFM.batch_mode` to process entities "
                                  f"in batches")
 
-            test_node = self._graph_store.get_node_id(
-                table_name=query.entity_table,
-                pkey=pd.Series(indices),
-            )
-
+            test_pkey = pd.Series(indices, dtype=train_pkey.dtype)
             if isinstance(anchor_time, pd.Timestamp):
-                test_time = pd.Series(anchor_time).repeat(
-                    len(test_node)).reset_index(drop=True)
+                test_time = pd.Series([anchor_time]).repeat(
+                    len(indices)).reset_index(drop=True)
             else:
-                time = self._graph_store.time_dict[query.entity_table]
-                time = time[test_node] * 1000**3
-                test_time = pd.Series(time, dtype='datetime64[ns]')
-
-            train_node, train_time, y_train = query_driver.collect_train(
-                size=_MAX_CONTEXT_SIZE[run_mode],
-                anchor_time=context_anchor_time or 'entity',
-                exclude_node=test_node if (query.query_type == QueryType.STATIC
-                                           or anchor_time == 'entity') else None,
-                max_iterations=max_pq_iterations,
-            )
+                train_time = test_time = 'entity'
 
         if logger is not None:
             if task_type == TaskType.BINARY_CLASSIFICATION:
@@ -1012,7 +1069,7 @@ class KumoRFM:
             final_aggr = query.get_final_target_aggregation()
             assert final_aggr is not None
             edge_fkey = final_aggr._get_target_column_name()
-            for edge_type in self._graph_store.edge_types:
+            for edge_type in self._sampler.edge_types:
                 if edge_fkey == f'{edge_type[0]}.{edge_type[1]}':
                     entity_table_names = (
                         query.entity_table,
@@ -1024,20 +1081,24 @@ class KumoRFM:
         # Exclude the entity anchor time from the feature set to prevent
         # running out-of-distribution between in-context and test examples:
         exclude_cols_dict = query.get_exclude_cols_dict()
-        if anchor_time == 'entity':
+        if entity_table_names[0] in self._sampler.time_column_dict:
             if entity_table_names[0] not in exclude_cols_dict:
                 exclude_cols_dict[entity_table_names[0]] = []
-            time_column_dict = self._graph_store.time_column_dict
-            time_column = time_column_dict[entity_table_names[0]]
+            time_column = self._sampler.time_column_dict[entity_table_names[0]]
             exclude_cols_dict[entity_table_names[0]].append(time_column)
 
-        subgraph = self._graph_sampler(
+        subgraph = self._sampler.sample_subgraph(
             entity_table_names=entity_table_names,
-            node=np.concatenate([train_node, test_node]),
-            time=np.concatenate([
-                train_time.astype('datetime64[ns]').astype(int).to_numpy(),
-                test_time.astype('datetime64[ns]').astype(int).to_numpy(),
-            ]),
+            entity_pkey=pd.concat(
+                [train_pkey, test_pkey],
+                axis=0,
+                ignore_index=True,
+            ),
+            anchor_time=pd.concat(
+                [train_time, test_time],
+                axis=0,
+                ignore_index=True,
+            ) if isinstance(train_time, pd.Series) else 'entity',
             num_neighbors=num_neighbors,
             exclude_cols_dict=exclude_cols_dict,
         )
@@ -1049,18 +1110,14 @@ class KumoRFM:
                 f"'https://github.com/kumo-ai/kumo-rfm' if you "
                 f"must go beyond this for your use-case.")
 
-        step_size: Optional[int] = None
-        if query.query_type == QueryType.TEMPORAL:
-            step_size = date_offset_to_seconds(end_offset)
-
         return Context(
             task_type=task_type,
             entity_table_names=entity_table_names,
             subgraph=subgraph,
             y_train=y_train,
-            y_test=y_test,
+            y_test=y_test if evaluate else None,
             top_k=query.top_k,
-            step_size=step_size,
+            step_size=None,
         )
 
     @staticmethod
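Taken together, the rfm.py hunks define the full surface that `KumoRFM` now requires from a backend: graph metadata (`edge_types`, `time_column_dict`), time bounds (`get_min_time`, `get_max_time`), label sampling (`sample_target`) and subgraph sampling (`sample_subgraph`). A hedged sketch of that contract as a `Protocol`, with signatures inferred from the call sites in this diff rather than from `base/sampler.py` (which the diff does not display):

    from typing import Dict, List, Literal, Optional, Protocol, Tuple, Union

    import pandas as pd

    class SamplerContract(Protocol):
        """What KumoRFM exercises on self._sampler in this diff."""

        edge_types: List[Tuple[str, str, str]]  # (src table, fkey, dst table)
        time_column_dict: Dict[str, str]        # table name -> time column

        def get_min_time(self) -> pd.Timestamp: ...

        def get_max_time(
            self,
            table_names: Optional[List[str]] = None,
        ) -> pd.Timestamp: ...

        def sample_target(
            self,
            *,
            query: object,  # str in get_labels, parsed query in _get_context
            num_train_examples: int,
            train_anchor_time: Union[pd.Timestamp, Literal['entity']],
            num_train_trials: int,
            num_test_examples: int,
            test_anchor_time: Union[pd.Timestamp, Literal['entity']],
            num_test_trials: int,
            random_seed: Optional[int],
        ) -> Tuple[tuple, tuple]: ...  # (train, test) splits, see get_labels

        def sample_subgraph(
            self,
            *,
            entity_table_names: Tuple[str, ...],
            entity_pkey: pd.Series,
            anchor_time: Union[pd.Series, Literal['entity']],
            num_neighbors: List[int],
            exclude_cols_dict: Dict[str, List[str]],
        ) -> object: ...  # subgraph consumed by kumoapi.rfm.Context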