kumoai 2.13.0.dev202512081731__cp313-cp313-macosx_11_0_arm64.whl → 2.14.0.dev202512151351__cp313-cp313-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (30)
  1. kumoai/_version.py +1 -1
  2. kumoai/client/pquery.py +6 -2
  3. kumoai/experimental/rfm/backend/local/graph_store.py +19 -62
  4. kumoai/experimental/rfm/backend/local/sampler.py +213 -14
  5. kumoai/experimental/rfm/backend/local/table.py +12 -2
  6. kumoai/experimental/rfm/backend/snow/__init__.py +2 -0
  7. kumoai/experimental/rfm/backend/snow/sampler.py +264 -0
  8. kumoai/experimental/rfm/backend/snow/table.py +35 -17
  9. kumoai/experimental/rfm/backend/sqlite/__init__.py +2 -0
  10. kumoai/experimental/rfm/backend/sqlite/sampler.py +354 -0
  11. kumoai/experimental/rfm/backend/sqlite/table.py +36 -11
  12. kumoai/experimental/rfm/base/__init__.py +17 -6
  13. kumoai/experimental/rfm/base/sampler.py +438 -38
  14. kumoai/experimental/rfm/base/source.py +1 -0
  15. kumoai/experimental/rfm/base/sql_sampler.py +56 -0
  16. kumoai/experimental/rfm/base/table.py +12 -1
  17. kumoai/experimental/rfm/graph.py +26 -9
  18. kumoai/experimental/rfm/pquery/pandas_executor.py +1 -1
  19. kumoai/experimental/rfm/rfm.py +214 -151
  20. kumoai/pquery/predictive_query.py +10 -6
  21. kumoai/testing/snow.py +50 -0
  22. kumoai/utils/__init__.py +2 -0
  23. kumoai/utils/sql.py +3 -0
  24. {kumoai-2.13.0.dev202512081731.dist-info → kumoai-2.14.0.dev202512151351.dist-info}/METADATA +2 -2
  25. {kumoai-2.13.0.dev202512081731.dist-info → kumoai-2.14.0.dev202512151351.dist-info}/RECORD +28 -25
  26. kumoai/experimental/rfm/local_graph_sampler.py +0 -223
  27. kumoai/experimental/rfm/local_pquery_driver.py +0 -689
  28. {kumoai-2.13.0.dev202512081731.dist-info → kumoai-2.14.0.dev202512151351.dist-info}/WHEEL +0 -0
  29. {kumoai-2.13.0.dev202512081731.dist-info → kumoai-2.14.0.dev202512151351.dist-info}/licenses/LICENSE +0 -0
  30. {kumoai-2.13.0.dev202512081731.dist-info → kumoai-2.14.0.dev202512151351.dist-info}/top_level.txt +0 -0
kumoai/experimental/rfm/rfm.py

@@ -21,6 +21,13 @@ import numpy as np
 import pandas as pd
 from kumoapi.model_plan import RunMode
 from kumoapi.pquery import QueryType, ValidatedPredictiveQuery
+from kumoapi.pquery.AST import (
+    Aggregation,
+    Column,
+    Condition,
+    Join,
+    LogicalOperation,
+)
 from kumoapi.rfm import Context
 from kumoapi.rfm import Explanation as ExplanationConfig
 from kumoapi.rfm import (
@@ -29,16 +36,12 @@ from kumoapi.rfm import (
     RFMPredictRequest,
 )
 from kumoapi.task import TaskType
+from kumoapi.typing import AggregationType, Stype

 from kumoai.client.rfm import RFMAPI
 from kumoai.exceptions import HTTPException
 from kumoai.experimental.rfm import Graph
-from kumoai.experimental.rfm.backend.local import LocalGraphStore
-from kumoai.experimental.rfm.local_graph_sampler import LocalGraphSampler
-from kumoai.experimental.rfm.local_pquery_driver import (
-    LocalPQueryDriver,
-    date_offset_to_seconds,
-)
+from kumoai.experimental.rfm.base import DataBackend, Sampler
 from kumoai.mixin import CastMixin
 from kumoai.utils import InteractiveProgressLogger, ProgressLogger

@@ -151,16 +154,31 @@ class KumoRFM:
     Args:
         graph: The graph.
         verbose: Whether to print verbose output.
+        optimize: If set to ``True``, will optimize the underlying data backend
+            for optimal querying. For example, for transactional database
+            backends, will create any missing indices. Requires write-access to
+            the data backend.
     """
     def __init__(
         self,
         graph: Graph,
         verbose: Union[bool, ProgressLogger] = True,
+        optimize: bool = False,
     ) -> None:
         graph = graph.validate()
         self._graph_def = graph._to_api_graph_definition()
-        self._graph_store = LocalGraphStore(graph, verbose)
-        self._graph_sampler = LocalGraphSampler(self._graph_store)
+
+        if graph.backend == DataBackend.LOCAL:
+            from kumoai.experimental.rfm.backend.local import LocalSampler
+            self._sampler: Sampler = LocalSampler(graph, verbose)
+        elif graph.backend == DataBackend.SQLITE:
+            from kumoai.experimental.rfm.backend.sqlite import SQLiteSampler
+            self._sampler = SQLiteSampler(graph, verbose, optimize)
+        elif graph.backend == DataBackend.SNOWFLAKE:
+            from kumoai.experimental.rfm.backend.snow import SnowSampler
+            self._sampler = SnowSampler(graph, verbose)
+        else:
+            raise NotImplementedError

         self._client: Optional[RFMAPI] = None

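Note on the new constructor: backend selection is driven entirely by `graph.backend`, and `optimize` is only forwarded to the SQLite sampler. A minimal usage sketch, assuming a SQLite-backed graph; the `Graph.from_sqlite` helper below is hypothetical and not part of this diff:

    from kumoai.experimental.rfm import Graph, KumoRFM

    # Hypothetical constructor: the diff only shows that a graph reports its
    # storage via `graph.backend`, not how a SQLite-backed graph is built.
    graph = Graph.from_sqlite("shop.db")

    # `optimize=True` may create missing indices in the database and thus
    # requires write-access; it is not forwarded to the local or Snowflake
    # samplers above.
    model = KumoRFM(graph, optimize=True)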
@@ -224,7 +242,7 @@ class KumoRFM:
         run_mode: Union[RunMode, str] = RunMode.FAST,
         num_neighbors: Optional[List[int]] = None,
         num_hops: int = 2,
-        max_pq_iterations: int = 20,
+        max_pq_iterations: int = 10,
         random_seed: Optional[int] = _RANDOM_SEED,
         verbose: Union[bool, ProgressLogger] = True,
         use_prediction_time: bool = False,
@@ -243,7 +261,7 @@ class KumoRFM:
         run_mode: Union[RunMode, str] = RunMode.FAST,
         num_neighbors: Optional[List[int]] = None,
         num_hops: int = 2,
-        max_pq_iterations: int = 20,
+        max_pq_iterations: int = 10,
         random_seed: Optional[int] = _RANDOM_SEED,
         verbose: Union[bool, ProgressLogger] = True,
         use_prediction_time: bool = False,
@@ -261,7 +279,7 @@ class KumoRFM:
         run_mode: Union[RunMode, str] = RunMode.FAST,
         num_neighbors: Optional[List[int]] = None,
         num_hops: int = 2,
-        max_pq_iterations: int = 20,
+        max_pq_iterations: int = 10,
         random_seed: Optional[int] = _RANDOM_SEED,
         verbose: Union[bool, ProgressLogger] = True,
         use_prediction_time: bool = False,
@@ -357,9 +375,9 @@ class KumoRFM:

         batch_size: Optional[int] = None
         if self._batch_size == 'max':
-            task_type = LocalPQueryDriver.get_task_type(
-                query_def,
-                edge_types=self._graph_store.edge_types,
+            task_type = self._get_task_type(
+                query=query_def,
+                edge_types=self._sampler.edge_types,
             )
             batch_size = _MAX_PRED_SIZE[task_type]
         else:
@@ -433,10 +451,10 @@ class KumoRFM:

             # Cast 'ENTITY' to correct data type:
             if 'ENTITY' in df:
-                entity = query_def.entity_table
-                pkey_map = self._graph_store.pkey_map_dict[entity]
-                df['ENTITY'] = df['ENTITY'].astype(
-                    type(pkey_map.index[0]))
+                table_dict = context.subgraph.table_dict
+                table = table_dict[query_def.entity_table]
+                ser = table.df[table.primary_key]
+                df['ENTITY'] = df['ENTITY'].astype(ser.dtype)

             # Cast 'ANCHOR_TIMESTAMP' to correct data type:
             if 'ANCHOR_TIMESTAMP' in df:
@@ -519,23 +537,18 @@ class KumoRFM:
             raise ValueError("At least one entity is required")

         if anchor_time is None:
-            anchor_time = self._graph_store.max_time
+            anchor_time = self._get_default_anchor_time(query_def)

         if isinstance(anchor_time, pd.Timestamp):
             self._validate_time(query_def, anchor_time, None, False)
         else:
             assert anchor_time == 'entity'
-            if (query_def.entity_table not in self._graph_store.time_dict):
+            if query_def.entity_table not in self._sampler.time_column_dict:
                 raise ValueError(f"Anchor time 'entity' requires the entity "
                                  f"table '{query_def.entity_table}' "
                                  f"to have a time column.")

-        node = self._graph_store.get_node_id(
-            table_name=query_def.entity_table,
-            pkey=pd.Series(indices),
-        )
-        query_driver = LocalPQueryDriver(self._graph_store, query_def)
-        return query_driver.is_valid(node, anchor_time)
+        raise NotImplementedError

     def evaluate(
         self,
@@ -547,7 +560,7 @@ class KumoRFM:
         run_mode: Union[RunMode, str] = RunMode.FAST,
         num_neighbors: Optional[List[int]] = None,
         num_hops: int = 2,
-        max_pq_iterations: int = 20,
+        max_pq_iterations: int = 10,
         random_seed: Optional[int] = _RANDOM_SEED,
         verbose: Union[bool, ProgressLogger] = True,
         use_prediction_time: bool = False,
@@ -658,7 +671,7 @@ class KumoRFM:
         *,
         anchor_time: Union[pd.Timestamp, Literal['entity'], None] = None,
         random_seed: Optional[int] = _RANDOM_SEED,
-        max_iterations: int = 20,
+        max_iterations: int = 10,
     ) -> pd.DataFrame:
         """Returns the labels of a predictive query for a specified anchor
         time.
@@ -678,40 +691,37 @@ class KumoRFM:
         query_def = self._parse_query(query)

         if anchor_time is None:
-            anchor_time = self._graph_store.max_time
+            anchor_time = self._get_default_anchor_time(query_def)
             if query_def.target_ast.date_offset_range is not None:
-                anchor_time = anchor_time - (
-                    query_def.target_ast.date_offset_range.end_date_offset *
-                    query_def.num_forecasts)
+                offset = query_def.target_ast.date_offset_range.end_date_offset
+                offset *= query_def.num_forecasts
+                anchor_time -= offset

         assert anchor_time is not None
         if isinstance(anchor_time, pd.Timestamp):
             self._validate_time(query_def, anchor_time, None, evaluate=True)
         else:
             assert anchor_time == 'entity'
-            if (query_def.entity_table not in self._graph_store.time_dict):
+            if query_def.entity_table not in self._sampler.time_column_dict:
                 raise ValueError(f"Anchor time 'entity' requires the entity "
                                  f"table '{query_def.entity_table}' "
                                  f"to have a time column")

-        query_driver = LocalPQueryDriver(self._graph_store, query_def,
-                                         random_seed)
-
-        node, time, y = query_driver.collect_test(
-            size=size,
-            anchor_time=anchor_time,
-            batch_size=min(10_000, size),
-            max_iterations=max_iterations,
-            guarantee_train_examples=False,
+        train, test = self._sampler.sample_target(
+            query=query,
+            num_train_examples=0,
+            train_anchor_time=anchor_time,
+            num_train_trials=0,
+            num_test_examples=size,
+            test_anchor_time=anchor_time,
+            num_test_trials=max_iterations * size,
+            random_seed=random_seed,
         )

-        entity = self._graph_store.pkey_map_dict[
-            query_def.entity_table].index[node]
-
         return pd.DataFrame({
-            'ENTITY': entity,
-            'ANCHOR_TIMESTAMP': time,
-            'TARGET': y,
+            'ENTITY': test.entity_pkey,
+            'ANCHOR_TIMESTAMP': test.anchor_time,
+            'TARGET': test.target,
         })

     # Helpers #################################################################
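Note: `get_labels()` keeps its output contract; only the collection path changed from `LocalPQueryDriver.collect_test` to `Sampler.sample_target` with `num_train_examples=0`. A usage sketch, where the predictive query string is illustrative PQL and not taken from this diff:

    # `model` is a KumoRFM instance; the query string is an example only.
    labels = model.get_labels(
        "PREDICT COUNT(orders.*, 0, 30) > 0 FOR EACH users.user_id",
        size=1_000,
    )
    # One row per sampled example, exactly as before:
    #   ENTITY            primary key of the sampled entity
    #   ANCHOR_TIMESTAMP  anchor time of the example
    #   TARGET            label computed by the predictive query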
@@ -734,8 +744,6 @@ class KumoRFM:

         resp = self._api_client.parse_query(request)

-        # TODO Expose validation warnings.
-
         if len(resp.validation_response.warnings) > 0:
             msg = '\n'.join([
                 f'{i+1}. {warning.title}: {warning.message}' for i, warning
@@ -753,6 +761,60 @@ class KumoRFM:
             raise ValueError(f"Failed to parse query '{query}'. "
                              f"{msg}") from None

+    @staticmethod
+    def _get_task_type(
+        query: ValidatedPredictiveQuery,
+        edge_types: List[Tuple[str, str, str]],
+    ) -> TaskType:
+        if isinstance(query.target_ast, (Condition, LogicalOperation)):
+            return TaskType.BINARY_CLASSIFICATION
+
+        target = query.target_ast
+        if isinstance(target, Join):
+            target = target.rhs_target
+        if isinstance(target, Aggregation):
+            if target.aggr == AggregationType.LIST_DISTINCT:
+                table_name, col_name = target._get_target_column_name().split(
+                    '.')
+                target_edge_types = [
+                    edge_type for edge_type in edge_types
+                    if edge_type[0] == table_name and edge_type[1] == col_name
+                ]
+                if len(target_edge_types) != 1:
+                    raise NotImplementedError(
+                        f"Multilabel-classification queries based on "
+                        f"'LIST_DISTINCT' are not supported yet. If you "
+                        f"planned to write a link prediction query instead, "
+                        f"make sure to register '{col_name}' as a "
+                        f"foreign key.")
+                return TaskType.TEMPORAL_LINK_PREDICTION
+
+            return TaskType.REGRESSION
+
+        assert isinstance(target, Column)
+
+        if target.stype in {Stype.ID, Stype.categorical}:
+            return TaskType.MULTICLASS_CLASSIFICATION
+
+        if target.stype in {Stype.numerical}:
+            return TaskType.REGRESSION
+
+        raise NotImplementedError("Task type not yet supported")
+
+    def _get_default_anchor_time(
+        self,
+        query: ValidatedPredictiveQuery,
+    ) -> pd.Timestamp:
+        if query.query_type == QueryType.TEMPORAL:
+            aggr_table_names = [
+                aggr._get_target_column_name().split('.')[0]
+                for aggr in query.get_all_target_aggregations()
+            ]
+            return self._sampler.get_max_time(aggr_table_names)
+
+        assert query.query_type == QueryType.STATIC
+        return self._sampler.get_max_time()
+
     def _validate_time(
         self,
         query: ValidatedPredictiveQuery,
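Note: `_get_task_type` replaces the former `LocalPQueryDriver.get_task_type` and derives the task type purely from the shape of the validated target AST. The mapping it implements, annotated with illustrative PQL (the query strings are examples, not from this diff):

    # Condition / LogicalOperation      -> BINARY_CLASSIFICATION
    #   e.g. PREDICT COUNT(orders.*, 0, 30) > 0
    # Aggregation with LIST_DISTINCT    -> TEMPORAL_LINK_PREDICTION
    #   e.g. PREDICT LIST_DISTINCT(orders.item_id, 0, 7)
    #   (requires exactly one matching foreign-key edge, otherwise raises)
    # any other Aggregation             -> REGRESSION
    #   e.g. PREDICT SUM(orders.price, 0, 30)
    # Column with ID/categorical stype  -> MULTICLASS_CLASSIFICATION
    # Column with numerical stype       -> REGRESSION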
@@ -761,28 +823,30 @@
         evaluate: bool,
     ) -> None:

-        if self._graph_store.min_time == pd.Timestamp.max:
+        if len(self._sampler.time_column_dict) == 0:
             return  # Graph without timestamps

-        if anchor_time < self._graph_store.min_time:
+        min_time = self._sampler.get_min_time()
+        max_time = self._sampler.get_max_time()
+
+        if anchor_time < min_time:
             raise ValueError(f"Anchor timestamp '{anchor_time}' is before "
-                             f"the earliest timestamp "
-                             f"'{self._graph_store.min_time}' in the data.")
+                             f"the earliest timestamp '{min_time}' in the "
+                             f"data.")

-        if (context_anchor_time is not None
-                and context_anchor_time < self._graph_store.min_time):
+        if context_anchor_time is not None and context_anchor_time < min_time:
             raise ValueError(f"Context anchor timestamp is too early or "
                              f"aggregation time range is too large. To make "
                              f"this prediction, we would need data back to "
                              f"'{context_anchor_time}', however, your data "
-                             f"only contains data back to "
-                             f"'{self._graph_store.min_time}'.")
+                             f"only contains data back to '{min_time}'.")

         if query.target_ast.date_offset_range is not None:
             end_offset = query.target_ast.date_offset_range.end_date_offset
         else:
             end_offset = pd.DateOffset(0)
-        forecast_end_offset = end_offset * query.num_forecasts
+        end_offset = end_offset * query.num_forecasts
+
         if (context_anchor_time is not None
                 and context_anchor_time > anchor_time):
             warnings.warn(f"Context anchor timestamp "
@@ -792,7 +856,7 @@
                           f"intended.")
         elif (query.query_type == QueryType.TEMPORAL
               and context_anchor_time is not None
-              and context_anchor_time + forecast_end_offset > anchor_time):
+              and context_anchor_time + end_offset > anchor_time):
             warnings.warn(f"Aggregation for context examples at timestamp "
                           f"'{context_anchor_time}' will leak information "
                           f"from the prediction anchor timestamp "
@@ -800,26 +864,23 @@
                           f"intended.")

         elif (context_anchor_time is not None
-              and context_anchor_time - forecast_end_offset
-              < self._graph_store.min_time):
-            _time = context_anchor_time - forecast_end_offset
+              and context_anchor_time - end_offset < min_time):
+            _time = context_anchor_time - end_offset
             warnings.warn(f"Context anchor timestamp is too early or "
                           f"aggregation time range is too large. To form "
                           f"proper input data, we would need data back to "
                           f"'{_time}', however, your data only contains "
-                          f"data back to '{self._graph_store.min_time}'.")
+                          f"data back to '{min_time}'.")

-        if (not evaluate and anchor_time
-                > self._graph_store.max_time + pd.DateOffset(days=1)):
+        if not evaluate and anchor_time > max_time + pd.DateOffset(days=1):
             warnings.warn(f"Anchor timestamp '{anchor_time}' is after the "
-                          f"latest timestamp '{self._graph_store.max_time}' "
-                          f"in the data. Please make sure this is intended.")
+                          f"latest timestamp '{max_time}' in the data. Please "
+                          f"make sure this is intended.")

-        max_eval_time = self._graph_store.max_time - forecast_end_offset
-        if evaluate and anchor_time > max_eval_time:
+        if evaluate and anchor_time > max_time - end_offset:
             raise ValueError(
                 f"Anchor timestamp for evaluation is after the latest "
-                f"supported timestamp '{max_eval_time}'.")
+                f"supported timestamp '{max_time - end_offset}'.")

     def _get_context(
         self,
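Note: the `_validate_time` rewrite folds the old `forecast_end_offset` into `end_offset` (`end_date_offset * num_forecasts`). A worked example of the evaluation bound it enforces, under assumed values:

    import pandas as pd

    # A 30-day target window forecast 3 times spans 90 days in total:
    end_offset = pd.DateOffset(days=30) * 3    # <DateOffset: days=90>

    # With data up to 2025-06-30, evaluation anchors later than
    # max_time - end_offset raise, because the full target window would
    # extend past the available data:
    max_time = pd.Timestamp("2025-06-30")
    latest_eval_anchor = max_time - end_offset  # Timestamp('2025-04-01')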
@@ -850,10 +911,9 @@
                 f"'https://github.com/kumo-ai/kumo-rfm' if you "
                 f"must go beyond this for your use-case.")

-        query_driver = LocalPQueryDriver(self._graph_store, query, random_seed)
-        task_type = LocalPQueryDriver.get_task_type(
-            query,
-            edge_types=self._graph_store.edge_types,
+        task_type = self._get_task_type(
+            query=query,
+            edge_types=self._sampler.edge_types,
         )

         if logger is not None:
@@ -885,14 +945,17 @@
             num_neighbors = [64, 64, 8, 8, 4, 4][:num_hops]

         if query.target_ast.date_offset_range is None:
-            end_offset = pd.DateOffset(0)
+            step_offset = pd.DateOffset(0)
         else:
-            end_offset = query.target_ast.date_offset_range.end_date_offset
-        forecast_end_offset = end_offset * query.num_forecasts
+            step_offset = query.target_ast.date_offset_range.end_date_offset
+        end_offset = step_offset * query.num_forecasts
+
         if anchor_time is None:
-            anchor_time = self._graph_store.max_time
+            anchor_time = self._get_default_anchor_time(query)
+
         if evaluate:
-            anchor_time = anchor_time - forecast_end_offset
+            anchor_time = anchor_time - end_offset
+
         if logger is not None:
             assert isinstance(anchor_time, pd.Timestamp)
             if anchor_time == pd.Timestamp.min:
@@ -906,57 +969,71 @@

         assert anchor_time is not None
         if isinstance(anchor_time, pd.Timestamp):
+            if context_anchor_time == 'entity':
+                raise ValueError("Anchor time 'entity' needs to be shared "
+                                 "for context and prediction examples")
             if context_anchor_time is None:
-                context_anchor_time = anchor_time - forecast_end_offset
+                context_anchor_time = anchor_time - end_offset
             self._validate_time(query, anchor_time, context_anchor_time,
                                 evaluate)
         else:
             assert anchor_time == 'entity'
-            if query.entity_table not in self._graph_store.time_dict:
+            if query.query_type != QueryType.STATIC:
+                raise ValueError("Anchor time 'entity' is only valid for "
+                                 "static predictive queries")
+            if query.entity_table not in self._sampler.time_column_dict:
                 raise ValueError(f"Anchor time 'entity' requires the entity "
                                  f"table '{query.entity_table}' to "
                                  f"have a time column")
-            if context_anchor_time is not None:
-                warnings.warn("Ignoring option 'context_anchor_time' for "
-                              "`anchor_time='entity'`")
-            context_anchor_time = None
+            if isinstance(context_anchor_time, pd.Timestamp):
+                raise ValueError("Anchor time 'entity' needs to be shared "
+                                 "for context and prediction examples")
+            context_anchor_time = 'entity'

-        y_test: Optional[pd.Series] = None
+        num_train_examples = _MAX_CONTEXT_SIZE[run_mode]
         if evaluate:
-            max_test_size = _MAX_TEST_SIZE[run_mode]
+            num_test_examples = _MAX_TEST_SIZE[run_mode]
             if task_type.is_link_pred:
-                max_test_size = max_test_size // 5
+                num_test_examples = num_test_examples // 5
+        else:
+            num_test_examples = 0
+
+        train, test = self._sampler.sample_target(
+            query=query,
+            num_train_examples=num_train_examples,
+            train_anchor_time=context_anchor_time,
+            num_train_trials=max_pq_iterations * num_train_examples,
+            num_test_examples=num_test_examples,
+            test_anchor_time=anchor_time,
+            num_test_trials=max_pq_iterations * num_test_examples,
+            random_seed=random_seed,
+        )
+        train_pkey, train_time, y_train = train
+        test_pkey, test_time, y_test = test

-            test_node, test_time, y_test = query_driver.collect_test(
-                size=max_test_size,
-                anchor_time=anchor_time,
-                max_iterations=max_pq_iterations,
-                guarantee_train_examples=True,
-            )
-            if logger is not None:
-                if task_type == TaskType.BINARY_CLASSIFICATION:
-                    pos = 100 * int((y_test > 0).sum()) / len(y_test)
-                    msg = (f"Collected {len(y_test):,} test examples with "
-                           f"{pos:.2f}% positive cases")
-                elif task_type == TaskType.MULTICLASS_CLASSIFICATION:
-                    msg = (f"Collected {len(y_test):,} test examples "
-                           f"holding {y_test.nunique()} classes")
-                elif task_type == TaskType.REGRESSION:
-                    _min, _max = float(y_test.min()), float(y_test.max())
-                    msg = (f"Collected {len(y_test):,} test examples with "
-                           f"targets between {format_value(_min)} and "
-                           f"{format_value(_max)}")
-                elif task_type == TaskType.TEMPORAL_LINK_PREDICTION:
-                    num_rhs = y_test.explode().nunique()
-                    msg = (f"Collected {len(y_test):,} test examples with "
-                           f"{num_rhs:,} unique items")
-                else:
-                    raise NotImplementedError
-                logger.log(msg)
+        if evaluate and logger is not None:
+            if task_type == TaskType.BINARY_CLASSIFICATION:
+                pos = 100 * int((y_test > 0).sum()) / len(y_test)
+                msg = (f"Collected {len(y_test):,} test examples with "
+                       f"{pos:.2f}% positive cases")
+            elif task_type == TaskType.MULTICLASS_CLASSIFICATION:
+                msg = (f"Collected {len(y_test):,} test examples holding "
+                       f"{y_test.nunique()} classes")
+            elif task_type == TaskType.REGRESSION:
+                _min, _max = float(y_test.min()), float(y_test.max())
+                msg = (f"Collected {len(y_test):,} test examples with targets "
+                       f"between {format_value(_min)} and "
+                       f"{format_value(_max)}")
+            elif task_type == TaskType.TEMPORAL_LINK_PREDICTION:
+                num_rhs = y_test.explode().nunique()
+                msg = (f"Collected {len(y_test):,} test examples with "
+                       f"{num_rhs:,} unique items")
+            else:
+                raise NotImplementedError
+            logger.log(msg)

-        else:
+        if not evaluate:
             assert indices is not None
-
             if len(indices) > _MAX_PRED_SIZE[task_type]:
                 raise ValueError(f"Cannot predict for more than "
                                  f"{_MAX_PRED_SIZE[task_type]:,} entities at "
@@ -964,26 +1041,12 @@
                                  f"`KumoRFM.batch_mode` to process entities "
                                  f"in batches")

-            test_node = self._graph_store.get_node_id(
-                table_name=query.entity_table,
-                pkey=pd.Series(indices),
-            )
-
+            test_pkey = pd.Series(indices, dtype=train_pkey.dtype)
             if isinstance(anchor_time, pd.Timestamp):
-                test_time = pd.Series(anchor_time).repeat(
-                    len(test_node)).reset_index(drop=True)
+                test_time = pd.Series([anchor_time]).repeat(
+                    len(indices)).reset_index(drop=True)
             else:
-                time = self._graph_store.time_dict[query.entity_table]
-                time = time[test_node] * 1000**3
-                test_time = pd.Series(time, dtype='datetime64[ns]')
-
-        train_node, train_time, y_train = query_driver.collect_train(
-            size=_MAX_CONTEXT_SIZE[run_mode],
-            anchor_time=context_anchor_time or 'entity',
-            exclude_node=test_node if (query.query_type == QueryType.STATIC
-                                       or anchor_time == 'entity') else None,
-            max_iterations=max_pq_iterations,
-        )
+                train_time = test_time = 'entity'

         if logger is not None:
             if task_type == TaskType.BINARY_CLASSIFICATION:
@@ -1011,7 +1074,7 @@
             final_aggr = query.get_final_target_aggregation()
             assert final_aggr is not None
             edge_fkey = final_aggr._get_target_column_name()
-            for edge_type in self._graph_store.edge_types:
+            for edge_type in self._sampler.edge_types:
                 if edge_fkey == f'{edge_type[0]}.{edge_type[1]}':
                     entity_table_names = (
                         query.entity_table,
1023
1086
  # Exclude the entity anchor time from the feature set to prevent
1024
1087
  # running out-of-distribution between in-context and test examples:
1025
1088
  exclude_cols_dict = query.get_exclude_cols_dict()
1026
- if anchor_time == 'entity':
1089
+ if entity_table_names[0] in self._sampler.time_column_dict:
1027
1090
  if entity_table_names[0] not in exclude_cols_dict:
1028
1091
  exclude_cols_dict[entity_table_names[0]] = []
1029
- time_column_dict = self._graph_store.time_column_dict
1030
- time_column = time_column_dict[entity_table_names[0]]
1092
+ time_column = self._sampler.time_column_dict[entity_table_names[0]]
1031
1093
  exclude_cols_dict[entity_table_names[0]].append(time_column)
1032
1094
 
1033
- subgraph = self._graph_sampler(
1095
+ subgraph = self._sampler.sample_subgraph(
1034
1096
  entity_table_names=entity_table_names,
1035
- node=np.concatenate([train_node, test_node]),
1036
- time=np.concatenate([
1037
- train_time.astype('datetime64[ns]').astype(int).to_numpy(),
1038
- test_time.astype('datetime64[ns]').astype(int).to_numpy(),
1039
- ]),
1097
+ entity_pkey=pd.concat(
1098
+ [train_pkey, test_pkey],
1099
+ axis=0,
1100
+ ignore_index=True,
1101
+ ),
1102
+ anchor_time=pd.concat(
1103
+ [train_time, test_time],
1104
+ axis=0,
1105
+ ignore_index=True,
1106
+ ) if isinstance(train_time, pd.Series) else 'entity',
1040
1107
  num_neighbors=num_neighbors,
1041
1108
  exclude_cols_dict=exclude_cols_dict,
1042
1109
  )
@@ -1048,18 +1115,14 @@ class KumoRFM:
1048
1115
  f"'https://github.com/kumo-ai/kumo-rfm' if you "
1049
1116
  f"must go beyond this for your use-case.")
1050
1117
 
1051
- step_size: Optional[int] = None
1052
- if query.query_type == QueryType.TEMPORAL:
1053
- step_size = date_offset_to_seconds(end_offset)
1054
-
1055
1118
  return Context(
1056
1119
  task_type=task_type,
1057
1120
  entity_table_names=entity_table_names,
1058
1121
  subgraph=subgraph,
1059
1122
  y_train=y_train,
1060
- y_test=y_test,
1123
+ y_test=y_test if evaluate else None,
1061
1124
  top_k=query.top_k,
1062
- step_size=step_size,
1125
+ step_size=None,
1063
1126
  )
1064
1127
 
1065
1128
  @staticmethod
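Note: taken together, the rfm.py hunks above move `KumoRFM` from `LocalGraphStore`/`LocalGraphSampler` onto the backend-agnostic `Sampler` interface from `kumoai.experimental.rfm.base`. A rough sketch of the surface relied on here, reconstructed from the call sites in this diff (attribute and method names appear in the diff; exact signatures and return types are assumptions):

    from typing import Dict, List, Optional, Tuple
    import pandas as pd

    class Sampler:  # sketch, not the actual definition in base/sampler.py
        edge_types: List[Tuple[str, str, str]]  # (table, foreign key, table)
        time_column_dict: Dict[str, str]        # table name -> time column

        def get_min_time(self) -> pd.Timestamp: ...

        def get_max_time(self, table_names: Optional[List[str]] = None,
                         ) -> pd.Timestamp: ...

        # Returns (train, test); each is unpackable as (entity_pkey,
        # anchor_time, target) and also exposes these as attributes:
        def sample_target(self, *, query, num_train_examples,
                          train_anchor_time, num_train_trials,
                          num_test_examples, test_anchor_time,
                          num_test_trials, random_seed): ...

        # Returns the in-context subgraph around the given entities:
        def sample_subgraph(self, *, entity_table_names, entity_pkey,
                            anchor_time, num_neighbors,
                            exclude_cols_dict): ...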
kumoai/pquery/predictive_query.py

@@ -370,9 +370,11 @@ class PredictiveQuery:
         train_table_job_api = global_state.client.generate_train_table_job_api
         job_id: GenerateTrainTableJobID = train_table_job_api.create(
             GenerateTrainTableRequest(
-                dict(custom_tags), pq_id, plan,
-                graph_snapshot_id=self.graph.snapshot(
-                    non_blocking=non_blocking)))
+                dict(custom_tags),
+                pq_id,
+                plan,
+                None,
+            ))

         self._train_table = TrainingTableJob(job_id=job_id)
         if non_blocking:
@@ -451,9 +453,11 @@ class PredictiveQuery:
         bp_table_api = global_state.client.generate_prediction_table_job_api
         job_id: GeneratePredictionTableJobID = bp_table_api.create(
             GeneratePredictionTableRequest(
-                dict(custom_tags), pq_id, plan,
-                graph_snapshot_id=self.graph.snapshot(
-                    non_blocking=non_blocking)))
+                dict(custom_tags),
+                pq_id,
+                plan,
+                None,
+            ))

         self._prediction_table = PredictionTableJob(job_id=job_id)
         if non_blocking: