kumoai 2.13.0.dev202512040651__cp312-cp312-win_amd64.whl → 2.14.0.dev202512111731__cp312-cp312-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -21,6 +21,13 @@ import numpy as np
  import pandas as pd
  from kumoapi.model_plan import RunMode
  from kumoapi.pquery import QueryType, ValidatedPredictiveQuery
+ from kumoapi.pquery.AST import (
+     Aggregation,
+     Column,
+     Condition,
+     Join,
+     LogicalOperation,
+ )
  from kumoapi.rfm import Context
  from kumoapi.rfm import Explanation as ExplanationConfig
  from kumoapi.rfm import (
@@ -29,16 +36,11 @@ from kumoapi.rfm import (
      RFMPredictRequest,
  )
  from kumoapi.task import TaskType
+ from kumoapi.typing import AggregationType, Stype

  from kumoai.client.rfm import RFMAPI
  from kumoai.exceptions import HTTPException
  from kumoai.experimental.rfm import Graph
- from kumoai.experimental.rfm.local_graph_sampler import LocalGraphSampler
- from kumoai.experimental.rfm.local_graph_store import LocalGraphStore
- from kumoai.experimental.rfm.local_pquery_driver import (
-     LocalPQueryDriver,
-     date_offset_to_seconds,
- )
  from kumoai.mixin import CastMixin
  from kumoai.utils import InteractiveProgressLogger, ProgressLogger

@@ -159,8 +161,9 @@ class KumoRFM:
      ) -> None:
          graph = graph.validate()
          self._graph_def = graph._to_api_graph_definition()
-         self._graph_store = LocalGraphStore(graph, verbose)
-         self._graph_sampler = LocalGraphSampler(self._graph_store)
+
+         from kumoai.experimental.rfm.backend.local import LocalSampler
+         self._sampler = LocalSampler(graph, verbose)

          self._client: Optional[RFMAPI] = None

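Note: the constructor now builds a single `LocalSampler` in place of the former `LocalGraphStore`/`LocalGraphSampler` pair, and imports it lazily inside `__init__`. A minimal sketch of that deferred-import pattern, with a stdlib module standing in for the real backend in `kumoai.experimental.rfm.backend.local`:

    class Model:
        """Illustrates the deferred import used in KumoRFM.__init__ above."""

        def __init__(self, graph: object) -> None:
            # Importing here instead of at module scope keeps `import kumoai`
            # cheap: the backend is only loaded once a model is constructed.
            import json as local_backend  # stand-in for the sampler backend
            self._backend = local_backend

    model = Model(graph=None)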
@@ -224,7 +227,7 @@ class KumoRFM:
          run_mode: Union[RunMode, str] = RunMode.FAST,
          num_neighbors: Optional[List[int]] = None,
          num_hops: int = 2,
-         max_pq_iterations: int = 20,
+         max_pq_iterations: int = 10,
          random_seed: Optional[int] = _RANDOM_SEED,
          verbose: Union[bool, ProgressLogger] = True,
          use_prediction_time: bool = False,
@@ -243,7 +246,7 @@ class KumoRFM:
          run_mode: Union[RunMode, str] = RunMode.FAST,
          num_neighbors: Optional[List[int]] = None,
          num_hops: int = 2,
-         max_pq_iterations: int = 20,
+         max_pq_iterations: int = 10,
          random_seed: Optional[int] = _RANDOM_SEED,
          verbose: Union[bool, ProgressLogger] = True,
          use_prediction_time: bool = False,
@@ -261,7 +264,7 @@ class KumoRFM:
          run_mode: Union[RunMode, str] = RunMode.FAST,
          num_neighbors: Optional[List[int]] = None,
          num_hops: int = 2,
-         max_pq_iterations: int = 20,
+         max_pq_iterations: int = 10,
          random_seed: Optional[int] = _RANDOM_SEED,
          verbose: Union[bool, ProgressLogger] = True,
          use_prediction_time: bool = False,
@@ -357,9 +360,9 @@ class KumoRFM:

          batch_size: Optional[int] = None
          if self._batch_size == 'max':
-             task_type = LocalPQueryDriver.get_task_type(
-                 query_def,
-                 edge_types=self._graph_store.edge_types,
+             task_type = self._get_task_type(
+                 query=query_def,
+                 edge_types=self._sampler.edge_types,
              )
              batch_size = _MAX_PRED_SIZE[task_type]
          else:
@@ -433,10 +436,10 @@ class KumoRFM:

          # Cast 'ENTITY' to correct data type:
          if 'ENTITY' in df:
-             entity = query_def.entity_table
-             pkey_map = self._graph_store.pkey_map_dict[entity]
-             df['ENTITY'] = df['ENTITY'].astype(
-                 type(pkey_map.index[0]))
+             table_dict = context.subgraph.table_dict
+             table = table_dict[query_def.entity_table]
+             ser = table.df[table.primary_key]
+             df['ENTITY'] = df['ENTITY'].astype(ser.dtype)

          # Cast 'ANCHOR_TIMESTAMP' to correct data type:
          if 'ANCHOR_TIMESTAMP' in df:
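The `ENTITY` cast above now takes its target dtype from the primary-key column of the sampled subgraph instead of inspecting the type of the first entry in a pkey map. A minimal pandas sketch of the same idea (all names illustrative):

    import pandas as pd

    pkey = pd.Series([101, 102, 103], name='user_id')  # primary-key column
    df = pd.DataFrame({'ENTITY': ['101', '103']})      # stringly-typed output
    df['ENTITY'] = df['ENTITY'].astype(pkey.dtype)     # cast back to int64
    assert df['ENTITY'].dtype == pkey.dtype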
@@ -519,23 +522,18 @@ class KumoRFM:
              raise ValueError("At least one entity is required")

          if anchor_time is None:
-             anchor_time = self._graph_store.max_time
+             anchor_time = self._get_default_anchor_time(query_def)

          if isinstance(anchor_time, pd.Timestamp):
              self._validate_time(query_def, anchor_time, None, False)
          else:
              assert anchor_time == 'entity'
-             if (query_def.entity_table not in self._graph_store.time_dict):
+             if query_def.entity_table not in self._sampler.time_column_dict:
                  raise ValueError(f"Anchor time 'entity' requires the entity "
                                   f"table '{query_def.entity_table}' "
                                   f"to have a time column.")

-         node = self._graph_store.get_node_id(
-             table_name=query_def.entity_table,
-             pkey=pd.Series(indices),
-         )
-         query_driver = LocalPQueryDriver(self._graph_store, query_def)
-         return query_driver.is_valid(node, anchor_time)
+         raise NotImplementedError

      def evaluate(
          self,
@@ -547,7 +545,7 @@ class KumoRFM:
          run_mode: Union[RunMode, str] = RunMode.FAST,
          num_neighbors: Optional[List[int]] = None,
          num_hops: int = 2,
-         max_pq_iterations: int = 20,
+         max_pq_iterations: int = 10,
          random_seed: Optional[int] = _RANDOM_SEED,
          verbose: Union[bool, ProgressLogger] = True,
          use_prediction_time: bool = False,
@@ -658,7 +656,7 @@ class KumoRFM:
          *,
          anchor_time: Union[pd.Timestamp, Literal['entity'], None] = None,
          random_seed: Optional[int] = _RANDOM_SEED,
-         max_iterations: int = 20,
+         max_iterations: int = 10,
      ) -> pd.DataFrame:
          """Returns the labels of a predictive query for a specified anchor
          time.
@@ -678,40 +676,37 @@ class KumoRFM:
          query_def = self._parse_query(query)

          if anchor_time is None:
-             anchor_time = self._graph_store.max_time
+             anchor_time = self._get_default_anchor_time(query_def)
          if query_def.target_ast.date_offset_range is not None:
-             anchor_time = anchor_time - (
-                 query_def.target_ast.date_offset_range.end_date_offset *
-                 query_def.num_forecasts)
+             offset = query_def.target_ast.date_offset_range.end_date_offset
+             offset *= query_def.num_forecasts
+             anchor_time -= offset

          assert anchor_time is not None
          if isinstance(anchor_time, pd.Timestamp):
              self._validate_time(query_def, anchor_time, None, evaluate=True)
          else:
              assert anchor_time == 'entity'
-             if (query_def.entity_table not in self._graph_store.time_dict):
+             if query_def.entity_table not in self._sampler.time_column_dict:
                  raise ValueError(f"Anchor time 'entity' requires the entity "
                                   f"table '{query_def.entity_table}' "
                                   f"to have a time column")

-         query_driver = LocalPQueryDriver(self._graph_store, query_def,
-                                          random_seed)
-
-         node, time, y = query_driver.collect_test(
-             size=size,
-             anchor_time=anchor_time,
-             batch_size=min(10_000, size),
-             max_iterations=max_iterations,
-             guarantee_train_examples=False,
+         train, test = self._sampler.sample_target(
+             query=query,
+             num_train_examples=0,
+             train_anchor_time=anchor_time,
+             num_train_trials=0,
+             num_test_examples=size,
+             test_anchor_time=anchor_time,
+             num_test_trials=max_iterations * size,
+             random_seed=random_seed,
          )

-         entity = self._graph_store.pkey_map_dict[
-             query_def.entity_table].index[node]
-
          return pd.DataFrame({
-             'ENTITY': entity,
-             'ANCHOR_TIMESTAMP': time,
-             'TARGET': y,
+             'ENTITY': test.entity_pkey,
+             'ANCHOR_TIMESTAMP': test.anchor_time,
+             'TARGET': test.target,
          })

      # Helpers #################################################################
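`get_labels` now delegates to `LocalSampler.sample_target`, requesting zero train examples and `size` test examples, and reads the result through named fields. A hedged reconstruction of the return shape implied by the attribute accesses above (`test.entity_pkey`, `test.anchor_time`, `test.target`); the actual class in `kumoai.experimental.rfm.backend.local` may differ:

    from typing import NamedTuple

    import pandas as pd

    class TargetSample(NamedTuple):  # hypothetical stand-in
        entity_pkey: pd.Series
        anchor_time: pd.Series
        target: pd.Series

    test = TargetSample(
        entity_pkey=pd.Series([1, 2]),
        anchor_time=pd.Series(pd.to_datetime(['2025-01-01', '2025-01-02'])),
        target=pd.Series([0.5, 1.5]),
    )
    labels = pd.DataFrame({
        'ENTITY': test.entity_pkey,
        'ANCHOR_TIMESTAMP': test.anchor_time,
        'TARGET': test.target,
    })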
@@ -734,8 +729,6 @@ class KumoRFM:

          resp = self._api_client.parse_query(request)

-         # TODO Expose validation warnings.
-
          if len(resp.validation_response.warnings) > 0:
              msg = '\n'.join([
                  f'{i+1}. {warning.title}: {warning.message}' for i, warning
@@ -753,6 +746,60 @@ class KumoRFM:
              raise ValueError(f"Failed to parse query '{query}'. "
                               f"{msg}") from None

+     @staticmethod
+     def _get_task_type(
+         query: ValidatedPredictiveQuery,
+         edge_types: List[Tuple[str, str, str]],
+     ) -> TaskType:
+         if isinstance(query.target_ast, (Condition, LogicalOperation)):
+             return TaskType.BINARY_CLASSIFICATION
+
+         target = query.target_ast
+         if isinstance(target, Join):
+             target = target.rhs_target
+         if isinstance(target, Aggregation):
+             if target.aggr == AggregationType.LIST_DISTINCT:
+                 table_name, col_name = target._get_target_column_name().split(
+                     '.')
+                 target_edge_types = [
+                     edge_type for edge_type in edge_types
+                     if edge_type[0] == table_name and edge_type[1] == col_name
+                 ]
+                 if len(target_edge_types) != 1:
+                     raise NotImplementedError(
+                         f"Multilabel-classification queries based on "
+                         f"'LIST_DISTINCT' are not supported yet. If you "
+                         f"planned to write a link prediction query instead, "
+                         f"make sure to register '{col_name}' as a "
+                         f"foreign key.")
+                 return TaskType.TEMPORAL_LINK_PREDICTION
+
+             return TaskType.REGRESSION
+
+         assert isinstance(target, Column)
+
+         if target.stype in {Stype.ID, Stype.categorical}:
+             return TaskType.MULTICLASS_CLASSIFICATION
+
+         if target.stype in {Stype.numerical}:
+             return TaskType.REGRESSION
+
+         raise NotImplementedError("Task type not yet supported")
+
+     def _get_default_anchor_time(
+         self,
+         query: ValidatedPredictiveQuery,
+     ) -> pd.Timestamp:
+         if query.query_type == QueryType.TEMPORAL:
+             aggr_table_names = [
+                 aggr._get_target_column_name().split('.')[0]
+                 for aggr in query.get_all_target_aggregations()
+             ]
+             return self._sampler.get_max_time(aggr_table_names)
+
+         assert query.query_type == QueryType.STATIC
+         return self._sampler.get_max_time()
+
      def _validate_time(
          self,
          query: ValidatedPredictiveQuery,
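The new `_get_task_type` helper absorbs what `LocalPQueryDriver.get_task_type` used to do: a small decision tree over the query's target AST. A runnable mirror of that tree, using plain dicts as stand-in AST nodes (the real node classes are `Condition`, `LogicalOperation`, `Join`, `Aggregation`, and `Column` from `kumoapi.pquery.AST`):

    from enum import Enum, auto

    class Task(Enum):
        BINARY = auto()
        MULTICLASS = auto()
        REGRESSION = auto()
        LINK_PRED = auto()

    def infer_task(target: dict) -> Task:
        # Condition/LogicalOperation targets -> binary classification.
        if target['kind'] == 'condition':
            return Task.BINARY
        # Aggregation targets -> regression, except LIST_DISTINCT over a
        # single matching foreign-key edge -> temporal link prediction.
        if target['kind'] == 'aggregation':
            if target.get('aggr') == 'LIST_DISTINCT':
                return Task.LINK_PRED
            return Task.REGRESSION
        # Raw column targets -> multiclass for ID/categorical columns,
        # regression for numerical ones.
        assert target['kind'] == 'column'
        if target['stype'] in ('ID', 'categorical'):
            return Task.MULTICLASS
        return Task.REGRESSION

    assert infer_task({'kind': 'condition'}) is Task.BINARY
    assert infer_task({'kind': 'aggregation', 'aggr': 'COUNT'}) is Task.REGRESSION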
@@ -761,28 +808,30 @@ class KumoRFM:
          evaluate: bool,
      ) -> None:

-         if self._graph_store.min_time == pd.Timestamp.max:
+         if len(self._sampler.time_column_dict) == 0:
              return  # Graph without timestamps

-         if anchor_time < self._graph_store.min_time:
+         min_time = self._sampler.get_min_time()
+         max_time = self._sampler.get_max_time()
+
+         if anchor_time < min_time:
              raise ValueError(f"Anchor timestamp '{anchor_time}' is before "
-                              f"the earliest timestamp "
-                              f"'{self._graph_store.min_time}' in the data.")
+                              f"the earliest timestamp '{min_time}' in the "
+                              f"data.")

-         if (context_anchor_time is not None
-                 and context_anchor_time < self._graph_store.min_time):
+         if context_anchor_time is not None and context_anchor_time < min_time:
              raise ValueError(f"Context anchor timestamp is too early or "
                               f"aggregation time range is too large. To make "
                               f"this prediction, we would need data back to "
                               f"'{context_anchor_time}', however, your data "
-                              f"only contains data back to "
-                              f"'{self._graph_store.min_time}'.")
+                              f"only contains data back to '{min_time}'.")

          if query.target_ast.date_offset_range is not None:
              end_offset = query.target_ast.date_offset_range.end_date_offset
          else:
              end_offset = pd.DateOffset(0)
-         forecast_end_offset = end_offset * query.num_forecasts
+         end_offset = end_offset * query.num_forecasts
+
          if (context_anchor_time is not None
                  and context_anchor_time > anchor_time):
              warnings.warn(f"Context anchor timestamp "
@@ -792,7 +841,7 @@ class KumoRFM:
                            f"intended.")
          elif (query.query_type == QueryType.TEMPORAL
                and context_anchor_time is not None
-               and context_anchor_time + forecast_end_offset > anchor_time):
+               and context_anchor_time + end_offset > anchor_time):
              warnings.warn(f"Aggregation for context examples at timestamp "
                            f"'{context_anchor_time}' will leak information "
                            f"from the prediction anchor timestamp "
@@ -800,26 +849,23 @@ class KumoRFM:
                            f"intended.")

          elif (context_anchor_time is not None
-               and context_anchor_time - forecast_end_offset
-               < self._graph_store.min_time):
-             _time = context_anchor_time - forecast_end_offset
+               and context_anchor_time - end_offset < min_time):
+             _time = context_anchor_time - end_offset
              warnings.warn(f"Context anchor timestamp is too early or "
                            f"aggregation time range is too large. To form "
                            f"proper input data, we would need data back to "
                            f"'{_time}', however, your data only contains "
-                           f"data back to '{self._graph_store.min_time}'.")
+                           f"data back to '{min_time}'.")

-         if (not evaluate and anchor_time
-                 > self._graph_store.max_time + pd.DateOffset(days=1)):
+         if not evaluate and anchor_time > max_time + pd.DateOffset(days=1):
              warnings.warn(f"Anchor timestamp '{anchor_time}' is after the "
-                           f"latest timestamp '{self._graph_store.max_time}' "
-                           f"in the data. Please make sure this is intended.")
+                           f"latest timestamp '{max_time}' in the data. Please "
+                           f"make sure this is intended.")

-         max_eval_time = self._graph_store.max_time - forecast_end_offset
-         if evaluate and anchor_time > max_eval_time:
+         if evaluate and anchor_time > max_time - end_offset:
              raise ValueError(
                  f"Anchor timestamp for evaluation is after the latest "
-                 f"supported timestamp '{max_eval_time}'.")
+                 f"supported timestamp '{max_time - end_offset}'.")

      def _get_context(
          self,
850
896
  f"'https://github.com/kumo-ai/kumo-rfm' if you "
851
897
  f"must go beyond this for your use-case.")
852
898
 
853
- query_driver = LocalPQueryDriver(self._graph_store, query, random_seed)
854
- task_type = LocalPQueryDriver.get_task_type(
855
- query,
856
- edge_types=self._graph_store.edge_types,
899
+ task_type = self._get_task_type(
900
+ query=query,
901
+ edge_types=self._sampler.edge_types,
857
902
  )
858
903
 
859
904
  if logger is not None:
@@ -885,14 +930,17 @@ class KumoRFM:
              num_neighbors = [64, 64, 8, 8, 4, 4][:num_hops]

          if query.target_ast.date_offset_range is None:
-             end_offset = pd.DateOffset(0)
+             step_offset = pd.DateOffset(0)
          else:
-             end_offset = query.target_ast.date_offset_range.end_date_offset
-         forecast_end_offset = end_offset * query.num_forecasts
+             step_offset = query.target_ast.date_offset_range.end_date_offset
+         end_offset = step_offset * query.num_forecasts
+
          if anchor_time is None:
-             anchor_time = self._graph_store.max_time
+             anchor_time = self._get_default_anchor_time(query)
+
          if evaluate:
-             anchor_time = anchor_time - forecast_end_offset
+             anchor_time = anchor_time - end_offset
+
          if logger is not None:
              assert isinstance(anchor_time, pd.Timestamp)
              if anchor_time == pd.Timestamp.min:
@@ -906,57 +954,71 @@ class KumoRFM:

          assert anchor_time is not None
          if isinstance(anchor_time, pd.Timestamp):
+             if context_anchor_time == 'entity':
+                 raise ValueError("Anchor time 'entity' needs to be shared "
+                                  "for context and prediction examples")
              if context_anchor_time is None:
-                 context_anchor_time = anchor_time - forecast_end_offset
+                 context_anchor_time = anchor_time - end_offset
              self._validate_time(query, anchor_time, context_anchor_time,
                                  evaluate)
          else:
              assert anchor_time == 'entity'
-             if query.entity_table not in self._graph_store.time_dict:
+             if query.query_type != QueryType.STATIC:
+                 raise ValueError("Anchor time 'entity' is only valid for "
+                                  "static predictive queries")
+             if query.entity_table not in self._sampler.time_column_dict:
                  raise ValueError(f"Anchor time 'entity' requires the entity "
                                   f"table '{query.entity_table}' to "
                                   f"have a time column")
-             if context_anchor_time is not None:
-                 warnings.warn("Ignoring option 'context_anchor_time' for "
-                               "`anchor_time='entity'`")
-             context_anchor_time = None
+             if isinstance(context_anchor_time, pd.Timestamp):
+                 raise ValueError("Anchor time 'entity' needs to be shared "
+                                  "for context and prediction examples")
+             context_anchor_time = 'entity'

-         y_test: Optional[pd.Series] = None
+         num_train_examples = _MAX_CONTEXT_SIZE[run_mode]
          if evaluate:
-             max_test_size = _MAX_TEST_SIZE[run_mode]
+             num_test_examples = _MAX_TEST_SIZE[run_mode]
              if task_type.is_link_pred:
-                 max_test_size = max_test_size // 5
+                 num_test_examples = num_test_examples // 5
+         else:
+             num_test_examples = 0
+
+         train, test = self._sampler.sample_target(
+             query=query,
+             num_train_examples=num_train_examples,
+             train_anchor_time=context_anchor_time,
+             num_train_trials=max_pq_iterations * num_train_examples,
+             num_test_examples=num_test_examples,
+             test_anchor_time=anchor_time,
+             num_test_trials=max_pq_iterations * num_test_examples,
+             random_seed=random_seed,
+         )
+         train_pkey, train_time, y_train = train
+         test_pkey, test_time, y_test = test

-             test_node, test_time, y_test = query_driver.collect_test(
-                 size=max_test_size,
-                 anchor_time=anchor_time,
-                 max_iterations=max_pq_iterations,
-                 guarantee_train_examples=True,
-             )
-             if logger is not None:
-                 if task_type == TaskType.BINARY_CLASSIFICATION:
-                     pos = 100 * int((y_test > 0).sum()) / len(y_test)
-                     msg = (f"Collected {len(y_test):,} test examples with "
-                            f"{pos:.2f}% positive cases")
-                 elif task_type == TaskType.MULTICLASS_CLASSIFICATION:
-                     msg = (f"Collected {len(y_test):,} test examples "
-                            f"holding {y_test.nunique()} classes")
-                 elif task_type == TaskType.REGRESSION:
-                     _min, _max = float(y_test.min()), float(y_test.max())
-                     msg = (f"Collected {len(y_test):,} test examples with "
-                            f"targets between {format_value(_min)} and "
-                            f"{format_value(_max)}")
-                 elif task_type == TaskType.TEMPORAL_LINK_PREDICTION:
-                     num_rhs = y_test.explode().nunique()
-                     msg = (f"Collected {len(y_test):,} test examples with "
-                            f"{num_rhs:,} unique items")
-                 else:
-                     raise NotImplementedError
-                 logger.log(msg)
+         if evaluate and logger is not None:
+             if task_type == TaskType.BINARY_CLASSIFICATION:
+                 pos = 100 * int((y_test > 0).sum()) / len(y_test)
+                 msg = (f"Collected {len(y_test):,} test examples with "
+                        f"{pos:.2f}% positive cases")
+             elif task_type == TaskType.MULTICLASS_CLASSIFICATION:
+                 msg = (f"Collected {len(y_test):,} test examples holding "
+                        f"{y_test.nunique()} classes")
+             elif task_type == TaskType.REGRESSION:
+                 _min, _max = float(y_test.min()), float(y_test.max())
+                 msg = (f"Collected {len(y_test):,} test examples with targets "
+                        f"between {format_value(_min)} and "
+                        f"{format_value(_max)}")
+             elif task_type == TaskType.TEMPORAL_LINK_PREDICTION:
+                 num_rhs = y_test.explode().nunique()
+                 msg = (f"Collected {len(y_test):,} test examples with "
+                        f"{num_rhs:,} unique items")
+             else:
+                 raise NotImplementedError
+             logger.log(msg)

-         else:
+         if not evaluate:
              assert indices is not None
-
              if len(indices) > _MAX_PRED_SIZE[task_type]:
                  raise ValueError(f"Cannot predict for more than "
                                   f"{_MAX_PRED_SIZE[task_type]:,} entities at "
@@ -964,26 +1026,12 @@ class KumoRFM:
                                   f"`KumoRFM.batch_mode` to process entities "
                                   f"in batches")

-             test_node = self._graph_store.get_node_id(
-                 table_name=query.entity_table,
-                 pkey=pd.Series(indices),
-             )
-
+             test_pkey = pd.Series(indices, dtype=train_pkey.dtype)
              if isinstance(anchor_time, pd.Timestamp):
-                 test_time = pd.Series(anchor_time).repeat(
-                     len(test_node)).reset_index(drop=True)
+                 test_time = pd.Series([anchor_time]).repeat(
+                     len(indices)).reset_index(drop=True)
              else:
-                 time = self._graph_store.time_dict[query.entity_table]
-                 time = time[test_node] * 1000**3
-                 test_time = pd.Series(time, dtype='datetime64[ns]')
-
-             train_node, train_time, y_train = query_driver.collect_train(
-                 size=_MAX_CONTEXT_SIZE[run_mode],
-                 anchor_time=context_anchor_time or 'entity',
-                 exclude_node=test_node if (query.query_type == QueryType.STATIC
-                                            or anchor_time == 'entity') else None,
-                 max_iterations=max_pq_iterations,
-             )
+                 train_time = test_time = 'entity'

          if logger is not None:
              if task_type == TaskType.BINARY_CLASSIFICATION:
@@ -1011,7 +1059,7 @@ class KumoRFM:
          final_aggr = query.get_final_target_aggregation()
          assert final_aggr is not None
          edge_fkey = final_aggr._get_target_column_name()
-         for edge_type in self._graph_store.edge_types:
+         for edge_type in self._sampler.edge_types:
              if edge_fkey == f'{edge_type[0]}.{edge_type[1]}':
                  entity_table_names = (
                      query.entity_table,
@@ -1023,20 +1071,24 @@ class KumoRFM:
          # Exclude the entity anchor time from the feature set to prevent
          # running out-of-distribution between in-context and test examples:
          exclude_cols_dict = query.get_exclude_cols_dict()
-         if anchor_time == 'entity':
+         if entity_table_names[0] in self._sampler.time_column_dict:
              if entity_table_names[0] not in exclude_cols_dict:
                  exclude_cols_dict[entity_table_names[0]] = []
-             time_column_dict = self._graph_store.time_column_dict
-             time_column = time_column_dict[entity_table_names[0]]
+             time_column = self._sampler.time_column_dict[entity_table_names[0]]
              exclude_cols_dict[entity_table_names[0]].append(time_column)

-         subgraph = self._graph_sampler(
+         subgraph = self._sampler.sample_subgraph(
              entity_table_names=entity_table_names,
-             node=np.concatenate([train_node, test_node]),
-             time=np.concatenate([
-                 train_time.astype('datetime64[ns]').astype(int).to_numpy(),
-                 test_time.astype('datetime64[ns]').astype(int).to_numpy(),
-             ]),
+             entity_pkey=pd.concat(
+                 [train_pkey, test_pkey],
+                 axis=0,
+                 ignore_index=True,
+             ),
+             anchor_time=pd.concat(
+                 [train_time, test_time],
+                 axis=0,
+                 ignore_index=True,
+             ) if isinstance(train_time, pd.Series) else 'entity',
              num_neighbors=num_neighbors,
              exclude_cols_dict=exclude_cols_dict,
          )
@@ -1048,18 +1100,14 @@ class KumoRFM:
                               f"'https://github.com/kumo-ai/kumo-rfm' if you "
                               f"must go beyond this for your use-case.")

-         step_size: Optional[int] = None
-         if query.query_type == QueryType.TEMPORAL:
-             step_size = date_offset_to_seconds(end_offset)
-
          return Context(
              task_type=task_type,
              entity_table_names=entity_table_names,
              subgraph=subgraph,
              y_train=y_train,
-             y_test=y_test,
+             y_test=y_test if evaluate else None,
              top_k=query.top_k,
-             step_size=step_size,
+             step_size=None,
          )

      @staticmethod
Binary file
@@ -370,9 +370,11 @@ class PredictiveQuery:
          train_table_job_api = global_state.client.generate_train_table_job_api
          job_id: GenerateTrainTableJobID = train_table_job_api.create(
              GenerateTrainTableRequest(
-                 dict(custom_tags), pq_id, plan,
-                 graph_snapshot_id=self.graph.snapshot(
-                     non_blocking=non_blocking)))
+                 dict(custom_tags),
+                 pq_id,
+                 plan,
+                 None,
+             ))

          self._train_table = TrainingTableJob(job_id=job_id)
          if non_blocking:
@@ -451,9 +453,11 @@ class PredictiveQuery:
          bp_table_api = global_state.client.generate_prediction_table_job_api
          job_id: GeneratePredictionTableJobID = bp_table_api.create(
              GeneratePredictionTableRequest(
-                 dict(custom_tags), pq_id, plan,
-                 graph_snapshot_id=self.graph.snapshot(
-                     non_blocking=non_blocking)))
+                 dict(custom_tags),
+                 pq_id,
+                 plan,
+                 None,
+             ))

          self._prediction_table = PredictionTableJob(job_id=job_id)
          if non_blocking:
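In both request constructors the keyword argument `graph_snapshot_id=self.graph.snapshot(non_blocking=non_blocking)` is replaced by a positional `None`, so generating a training or prediction table no longer snapshots the graph when the job request is created. A hedged stand-in showing the call shape (the real `GenerateTrainTableRequest`/`GeneratePredictionTableRequest` types come from the kumoapi package and may order fields differently):

    from dataclasses import dataclass
    from typing import Any, Optional

    @dataclass
    class TableRequest:  # hypothetical stand-in
        custom_tags: dict
        pq_id: str
        plan: Any
        graph_snapshot_id: Optional[str] = None

    # Before: TableRequest(dict(tags), pq_id, plan, graph_snapshot_id=snapshot)
    # After: no snapshot is taken up front.
    req = TableRequest({}, 'pq-123', None, None)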
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: kumoai
- Version: 2.13.0.dev202512040651
+ Version: 2.14.0.dev202512111731
  Summary: AI on the Modern Data Stack
  Author-email: "Kumo.AI" <hello@kumo.ai>
  License-Expression: MIT
@@ -23,7 +23,7 @@ Requires-Dist: requests>=2.28.2
  Requires-Dist: urllib3
  Requires-Dist: plotly
  Requires-Dist: typing_extensions>=4.5.0
- Requires-Dist: kumo-api==0.48.0
+ Requires-Dist: kumo-api==0.49.0
  Requires-Dist: tqdm>=4.66.0
  Requires-Dist: aiohttp>=3.10.0
  Requires-Dist: pydantic>=1.10.21