kumoai 2.13.0.dev202511271731__cp312-cp312-win_amd64.whl → 2.14.0.dev202512111731__cp312-cp312-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (37) hide show
  1. kumoai/__init__.py +12 -0
  2. kumoai/_version.py +1 -1
  3. kumoai/connector/utils.py +23 -2
  4. kumoai/experimental/rfm/__init__.py +20 -45
  5. kumoai/experimental/rfm/backend/__init__.py +0 -0
  6. kumoai/experimental/rfm/backend/local/__init__.py +42 -0
  7. kumoai/experimental/rfm/{local_graph_store.py → backend/local/graph_store.py} +37 -90
  8. kumoai/experimental/rfm/backend/local/sampler.py +313 -0
  9. kumoai/experimental/rfm/backend/local/table.py +109 -0
  10. kumoai/experimental/rfm/backend/snow/__init__.py +35 -0
  11. kumoai/experimental/rfm/backend/snow/table.py +117 -0
  12. kumoai/experimental/rfm/backend/sqlite/__init__.py +30 -0
  13. kumoai/experimental/rfm/backend/sqlite/table.py +101 -0
  14. kumoai/experimental/rfm/base/__init__.py +13 -0
  15. kumoai/experimental/rfm/base/column.py +66 -0
  16. kumoai/experimental/rfm/base/sampler.py +763 -0
  17. kumoai/experimental/rfm/base/source.py +18 -0
  18. kumoai/experimental/rfm/{local_table.py → base/table.py} +139 -139
  19. kumoai/experimental/rfm/{local_graph.py → graph.py} +334 -79
  20. kumoai/experimental/rfm/infer/__init__.py +6 -0
  21. kumoai/experimental/rfm/infer/dtype.py +79 -0
  22. kumoai/experimental/rfm/infer/pkey.py +126 -0
  23. kumoai/experimental/rfm/infer/time_col.py +62 -0
  24. kumoai/experimental/rfm/pquery/pandas_executor.py +1 -1
  25. kumoai/experimental/rfm/rfm.py +204 -166
  26. kumoai/experimental/rfm/sagemaker.py +11 -3
  27. kumoai/kumolib.cp312-win_amd64.pyd +0 -0
  28. kumoai/pquery/predictive_query.py +10 -6
  29. kumoai/testing/decorators.py +1 -1
  30. {kumoai-2.13.0.dev202511271731.dist-info → kumoai-2.14.0.dev202512111731.dist-info}/METADATA +9 -8
  31. {kumoai-2.13.0.dev202511271731.dist-info → kumoai-2.14.0.dev202512111731.dist-info}/RECORD +34 -22
  32. kumoai/experimental/rfm/local_graph_sampler.py +0 -182
  33. kumoai/experimental/rfm/local_pquery_driver.py +0 -689
  34. kumoai/experimental/rfm/utils.py +0 -344
  35. {kumoai-2.13.0.dev202511271731.dist-info → kumoai-2.14.0.dev202512111731.dist-info}/WHEEL +0 -0
  36. {kumoai-2.13.0.dev202511271731.dist-info → kumoai-2.14.0.dev202512111731.dist-info}/licenses/LICENSE +0 -0
  37. {kumoai-2.13.0.dev202511271731.dist-info → kumoai-2.14.0.dev202512111731.dist-info}/top_level.txt +0 -0
@@ -21,6 +21,13 @@ import numpy as np
21
21
  import pandas as pd
22
22
  from kumoapi.model_plan import RunMode
23
23
  from kumoapi.pquery import QueryType, ValidatedPredictiveQuery
24
+ from kumoapi.pquery.AST import (
25
+ Aggregation,
26
+ Column,
27
+ Condition,
28
+ Join,
29
+ LogicalOperation,
30
+ )
24
31
  from kumoapi.rfm import Context
25
32
  from kumoapi.rfm import Explanation as ExplanationConfig
26
33
  from kumoapi.rfm import (
@@ -29,16 +36,11 @@ from kumoapi.rfm import (
29
36
  RFMPredictRequest,
30
37
  )
31
38
  from kumoapi.task import TaskType
39
+ from kumoapi.typing import AggregationType, Stype
32
40
 
33
41
  from kumoai.client.rfm import RFMAPI
34
42
  from kumoai.exceptions import HTTPException
35
- from kumoai.experimental.rfm import LocalGraph
36
- from kumoai.experimental.rfm.local_graph_sampler import LocalGraphSampler
37
- from kumoai.experimental.rfm.local_graph_store import LocalGraphStore
38
- from kumoai.experimental.rfm.local_pquery_driver import (
39
- LocalPQueryDriver,
40
- date_offset_to_seconds,
41
- )
43
+ from kumoai.experimental.rfm import Graph
42
44
  from kumoai.mixin import CastMixin
43
45
  from kumoai.utils import InteractiveProgressLogger, ProgressLogger
44
46
 
@@ -123,17 +125,17 @@ class KumoRFM:
123
125
  :class:`KumoRFM` is a foundation model to generate predictions for any
124
126
  relational dataset without training.
125
127
  The model is pre-trained and the class provides an interface to query the
126
- model from a :class:`LocalGraph` object.
128
+ model from a :class:`Graph` object.
127
129
 
128
130
  .. code-block:: python
129
131
 
130
- from kumoai.experimental.rfm import LocalGraph, KumoRFM
132
+ from kumoai.experimental.rfm import Graph, KumoRFM
131
133
 
132
134
  df_users = pd.DataFrame(...)
133
135
  df_items = pd.DataFrame(...)
134
136
  df_orders = pd.DataFrame(...)
135
137
 
136
- graph = LocalGraph.from_data({
138
+ graph = Graph.from_data({
137
139
  'users': df_users,
138
140
  'items': df_items,
139
141
  'orders': df_orders,
@@ -150,27 +152,18 @@ class KumoRFM:
150
152
 
151
153
  Args:
152
154
  graph: The graph.
153
- preprocess: Whether to pre-process the data in advance during graph
154
- materialization.
155
- This is a runtime trade-off between graph materialization and model
156
- processing speed.
157
- It can be benefical to preprocess your data once and then run many
158
- queries on top to achieve maximum model speed.
159
- However, if activiated, graph materialization can take potentially
160
- much longer, especially on graphs with many large text columns.
161
- Best to tune this option manually.
162
155
  verbose: Whether to print verbose output.
163
156
  """
164
157
  def __init__(
165
158
  self,
166
- graph: LocalGraph,
167
- preprocess: bool = False,
159
+ graph: Graph,
168
160
  verbose: Union[bool, ProgressLogger] = True,
169
161
  ) -> None:
170
162
  graph = graph.validate()
171
163
  self._graph_def = graph._to_api_graph_definition()
172
- self._graph_store = LocalGraphStore(graph, preprocess, verbose)
173
- self._graph_sampler = LocalGraphSampler(self._graph_store)
164
+
165
+ from kumoai.experimental.rfm.backend.local import LocalSampler
166
+ self._sampler = LocalSampler(graph, verbose)
174
167
 
175
168
  self._client: Optional[RFMAPI] = None
176
169
 
@@ -234,7 +227,7 @@ class KumoRFM:
234
227
  run_mode: Union[RunMode, str] = RunMode.FAST,
235
228
  num_neighbors: Optional[List[int]] = None,
236
229
  num_hops: int = 2,
237
- max_pq_iterations: int = 20,
230
+ max_pq_iterations: int = 10,
238
231
  random_seed: Optional[int] = _RANDOM_SEED,
239
232
  verbose: Union[bool, ProgressLogger] = True,
240
233
  use_prediction_time: bool = False,
@@ -253,7 +246,7 @@ class KumoRFM:
253
246
  run_mode: Union[RunMode, str] = RunMode.FAST,
254
247
  num_neighbors: Optional[List[int]] = None,
255
248
  num_hops: int = 2,
256
- max_pq_iterations: int = 20,
249
+ max_pq_iterations: int = 10,
257
250
  random_seed: Optional[int] = _RANDOM_SEED,
258
251
  verbose: Union[bool, ProgressLogger] = True,
259
252
  use_prediction_time: bool = False,
@@ -271,7 +264,7 @@ class KumoRFM:
271
264
  run_mode: Union[RunMode, str] = RunMode.FAST,
272
265
  num_neighbors: Optional[List[int]] = None,
273
266
  num_hops: int = 2,
274
- max_pq_iterations: int = 20,
267
+ max_pq_iterations: int = 10,
275
268
  random_seed: Optional[int] = _RANDOM_SEED,
276
269
  verbose: Union[bool, ProgressLogger] = True,
277
270
  use_prediction_time: bool = False,
@@ -367,9 +360,9 @@ class KumoRFM:
367
360
 
368
361
  batch_size: Optional[int] = None
369
362
  if self._batch_size == 'max':
370
- task_type = LocalPQueryDriver.get_task_type(
371
- query_def,
372
- edge_types=self._graph_store.edge_types,
363
+ task_type = self._get_task_type(
364
+ query=query_def,
365
+ edge_types=self._sampler.edge_types,
373
366
  )
374
367
  batch_size = _MAX_PRED_SIZE[task_type]
375
368
  else:
@@ -443,10 +436,10 @@ class KumoRFM:
443
436
 
444
437
  # Cast 'ENTITY' to correct data type:
445
438
  if 'ENTITY' in df:
446
- entity = query_def.entity_table
447
- pkey_map = self._graph_store.pkey_map_dict[entity]
448
- df['ENTITY'] = df['ENTITY'].astype(
449
- type(pkey_map.index[0]))
439
+ table_dict = context.subgraph.table_dict
440
+ table = table_dict[query_def.entity_table]
441
+ ser = table.df[table.primary_key]
442
+ df['ENTITY'] = df['ENTITY'].astype(ser.dtype)
450
443
 
451
444
  # Cast 'ANCHOR_TIMESTAMP' to correct data type:
452
445
  if 'ANCHOR_TIMESTAMP' in df:
@@ -529,23 +522,18 @@ class KumoRFM:
529
522
  raise ValueError("At least one entity is required")
530
523
 
531
524
  if anchor_time is None:
532
- anchor_time = self._graph_store.max_time
525
+ anchor_time = self._get_default_anchor_time(query_def)
533
526
 
534
527
  if isinstance(anchor_time, pd.Timestamp):
535
528
  self._validate_time(query_def, anchor_time, None, False)
536
529
  else:
537
530
  assert anchor_time == 'entity'
538
- if (query_def.entity_table not in self._graph_store.time_dict):
531
+ if query_def.entity_table not in self._sampler.time_column_dict:
539
532
  raise ValueError(f"Anchor time 'entity' requires the entity "
540
533
  f"table '{query_def.entity_table}' "
541
534
  f"to have a time column.")
542
535
 
543
- node = self._graph_store.get_node_id(
544
- table_name=query_def.entity_table,
545
- pkey=pd.Series(indices),
546
- )
547
- query_driver = LocalPQueryDriver(self._graph_store, query_def)
548
- return query_driver.is_valid(node, anchor_time)
536
+ raise NotImplementedError
549
537
 
550
538
  def evaluate(
551
539
  self,
@@ -557,7 +545,7 @@ class KumoRFM:
557
545
  run_mode: Union[RunMode, str] = RunMode.FAST,
558
546
  num_neighbors: Optional[List[int]] = None,
559
547
  num_hops: int = 2,
560
- max_pq_iterations: int = 20,
548
+ max_pq_iterations: int = 10,
561
549
  random_seed: Optional[int] = _RANDOM_SEED,
562
550
  verbose: Union[bool, ProgressLogger] = True,
563
551
  use_prediction_time: bool = False,
@@ -668,7 +656,7 @@ class KumoRFM:
668
656
  *,
669
657
  anchor_time: Union[pd.Timestamp, Literal['entity'], None] = None,
670
658
  random_seed: Optional[int] = _RANDOM_SEED,
671
- max_iterations: int = 20,
659
+ max_iterations: int = 10,
672
660
  ) -> pd.DataFrame:
673
661
  """Returns the labels of a predictive query for a specified anchor
674
662
  time.
@@ -688,40 +676,37 @@ class KumoRFM:
688
676
  query_def = self._parse_query(query)
689
677
 
690
678
  if anchor_time is None:
691
- anchor_time = self._graph_store.max_time
679
+ anchor_time = self._get_default_anchor_time(query_def)
692
680
  if query_def.target_ast.date_offset_range is not None:
693
- anchor_time = anchor_time - (
694
- query_def.target_ast.date_offset_range.end_date_offset *
695
- query_def.num_forecasts)
681
+ offset = query_def.target_ast.date_offset_range.end_date_offset
682
+ offset *= query_def.num_forecasts
683
+ anchor_time -= offset
696
684
 
697
685
  assert anchor_time is not None
698
686
  if isinstance(anchor_time, pd.Timestamp):
699
687
  self._validate_time(query_def, anchor_time, None, evaluate=True)
700
688
  else:
701
689
  assert anchor_time == 'entity'
702
- if (query_def.entity_table not in self._graph_store.time_dict):
690
+ if query_def.entity_table not in self._sampler.time_column_dict:
703
691
  raise ValueError(f"Anchor time 'entity' requires the entity "
704
692
  f"table '{query_def.entity_table}' "
705
693
  f"to have a time column")
706
694
 
707
- query_driver = LocalPQueryDriver(self._graph_store, query_def,
708
- random_seed)
709
-
710
- node, time, y = query_driver.collect_test(
711
- size=size,
712
- anchor_time=anchor_time,
713
- batch_size=min(10_000, size),
714
- max_iterations=max_iterations,
715
- guarantee_train_examples=False,
695
+ train, test = self._sampler.sample_target(
696
+ query=query,
697
+ num_train_examples=0,
698
+ train_anchor_time=anchor_time,
699
+ num_train_trials=0,
700
+ num_test_examples=size,
701
+ test_anchor_time=anchor_time,
702
+ num_test_trials=max_iterations * size,
703
+ random_seed=random_seed,
716
704
  )
717
705
 
718
- entity = self._graph_store.pkey_map_dict[
719
- query_def.entity_table].index[node]
720
-
721
706
  return pd.DataFrame({
722
- 'ENTITY': entity,
723
- 'ANCHOR_TIMESTAMP': time,
724
- 'TARGET': y,
707
+ 'ENTITY': test.entity_pkey,
708
+ 'ANCHOR_TIMESTAMP': test.anchor_time,
709
+ 'TARGET': test.target,
725
710
  })
726
711
 
727
712
  # Helpers #################################################################
@@ -744,8 +729,6 @@ class KumoRFM:
744
729
 
745
730
  resp = self._api_client.parse_query(request)
746
731
 
747
- # TODO Expose validation warnings.
748
-
749
732
  if len(resp.validation_response.warnings) > 0:
750
733
  msg = '\n'.join([
751
734
  f'{i+1}. {warning.title}: {warning.message}' for i, warning
@@ -763,6 +746,60 @@ class KumoRFM:
763
746
  raise ValueError(f"Failed to parse query '{query}'. "
764
747
  f"{msg}") from None
765
748
 
749
+ @staticmethod
750
+ def _get_task_type(
751
+ query: ValidatedPredictiveQuery,
752
+ edge_types: List[Tuple[str, str, str]],
753
+ ) -> TaskType:
754
+ if isinstance(query.target_ast, (Condition, LogicalOperation)):
755
+ return TaskType.BINARY_CLASSIFICATION
756
+
757
+ target = query.target_ast
758
+ if isinstance(target, Join):
759
+ target = target.rhs_target
760
+ if isinstance(target, Aggregation):
761
+ if target.aggr == AggregationType.LIST_DISTINCT:
762
+ table_name, col_name = target._get_target_column_name().split(
763
+ '.')
764
+ target_edge_types = [
765
+ edge_type for edge_type in edge_types
766
+ if edge_type[0] == table_name and edge_type[1] == col_name
767
+ ]
768
+ if len(target_edge_types) != 1:
769
+ raise NotImplementedError(
770
+ f"Multilabel-classification queries based on "
771
+ f"'LIST_DISTINCT' are not supported yet. If you "
772
+ f"planned to write a link prediction query instead, "
773
+ f"make sure to register '{col_name}' as a "
774
+ f"foreign key.")
775
+ return TaskType.TEMPORAL_LINK_PREDICTION
776
+
777
+ return TaskType.REGRESSION
778
+
779
+ assert isinstance(target, Column)
780
+
781
+ if target.stype in {Stype.ID, Stype.categorical}:
782
+ return TaskType.MULTICLASS_CLASSIFICATION
783
+
784
+ if target.stype in {Stype.numerical}:
785
+ return TaskType.REGRESSION
786
+
787
+ raise NotImplementedError("Task type not yet supported")
788
+
789
+ def _get_default_anchor_time(
790
+ self,
791
+ query: ValidatedPredictiveQuery,
792
+ ) -> pd.Timestamp:
793
+ if query.query_type == QueryType.TEMPORAL:
794
+ aggr_table_names = [
795
+ aggr._get_target_column_name().split('.')[0]
796
+ for aggr in query.get_all_target_aggregations()
797
+ ]
798
+ return self._sampler.get_max_time(aggr_table_names)
799
+
800
+ assert query.query_type == QueryType.STATIC
801
+ return self._sampler.get_max_time()
802
+
766
803
  def _validate_time(
767
804
  self,
768
805
  query: ValidatedPredictiveQuery,
@@ -771,28 +808,30 @@ class KumoRFM:
771
808
  evaluate: bool,
772
809
  ) -> None:
773
810
 
774
- if self._graph_store.min_time == pd.Timestamp.max:
811
+ if len(self._sampler.time_column_dict) == 0:
775
812
  return # Graph without timestamps
776
813
 
777
- if anchor_time < self._graph_store.min_time:
814
+ min_time = self._sampler.get_min_time()
815
+ max_time = self._sampler.get_max_time()
816
+
817
+ if anchor_time < min_time:
778
818
  raise ValueError(f"Anchor timestamp '{anchor_time}' is before "
779
- f"the earliest timestamp "
780
- f"'{self._graph_store.min_time}' in the data.")
819
+ f"the earliest timestamp '{min_time}' in the "
820
+ f"data.")
781
821
 
782
- if (context_anchor_time is not None
783
- and context_anchor_time < self._graph_store.min_time):
822
+ if context_anchor_time is not None and context_anchor_time < min_time:
784
823
  raise ValueError(f"Context anchor timestamp is too early or "
785
824
  f"aggregation time range is too large. To make "
786
825
  f"this prediction, we would need data back to "
787
826
  f"'{context_anchor_time}', however, your data "
788
- f"only contains data back to "
789
- f"'{self._graph_store.min_time}'.")
827
+ f"only contains data back to '{min_time}'.")
790
828
 
791
829
  if query.target_ast.date_offset_range is not None:
792
830
  end_offset = query.target_ast.date_offset_range.end_date_offset
793
831
  else:
794
832
  end_offset = pd.DateOffset(0)
795
- forecast_end_offset = end_offset * query.num_forecasts
833
+ end_offset = end_offset * query.num_forecasts
834
+
796
835
  if (context_anchor_time is not None
797
836
  and context_anchor_time > anchor_time):
798
837
  warnings.warn(f"Context anchor timestamp "
@@ -802,7 +841,7 @@ class KumoRFM:
802
841
  f"intended.")
803
842
  elif (query.query_type == QueryType.TEMPORAL
804
843
  and context_anchor_time is not None
805
- and context_anchor_time + forecast_end_offset > anchor_time):
844
+ and context_anchor_time + end_offset > anchor_time):
806
845
  warnings.warn(f"Aggregation for context examples at timestamp "
807
846
  f"'{context_anchor_time}' will leak information "
808
847
  f"from the prediction anchor timestamp "
@@ -810,26 +849,23 @@ class KumoRFM:
810
849
  f"intended.")
811
850
 
812
851
  elif (context_anchor_time is not None
813
- and context_anchor_time - forecast_end_offset
814
- < self._graph_store.min_time):
815
- _time = context_anchor_time - forecast_end_offset
852
+ and context_anchor_time - end_offset < min_time):
853
+ _time = context_anchor_time - end_offset
816
854
  warnings.warn(f"Context anchor timestamp is too early or "
817
855
  f"aggregation time range is too large. To form "
818
856
  f"proper input data, we would need data back to "
819
857
  f"'{_time}', however, your data only contains "
820
- f"data back to '{self._graph_store.min_time}'.")
858
+ f"data back to '{min_time}'.")
821
859
 
822
- if (not evaluate and anchor_time
823
- > self._graph_store.max_time + pd.DateOffset(days=1)):
860
+ if not evaluate and anchor_time > max_time + pd.DateOffset(days=1):
824
861
  warnings.warn(f"Anchor timestamp '{anchor_time}' is after the "
825
- f"latest timestamp '{self._graph_store.max_time}' "
826
- f"in the data. Please make sure this is intended.")
862
+ f"latest timestamp '{max_time}' in the data. Please "
863
+ f"make sure this is intended.")
827
864
 
828
- max_eval_time = self._graph_store.max_time - forecast_end_offset
829
- if evaluate and anchor_time > max_eval_time:
865
+ if evaluate and anchor_time > max_time - end_offset:
830
866
  raise ValueError(
831
867
  f"Anchor timestamp for evaluation is after the latest "
832
- f"supported timestamp '{max_eval_time}'.")
868
+ f"supported timestamp '{max_time - end_offset}'.")
833
869
 
834
870
  def _get_context(
835
871
  self,
@@ -860,10 +896,9 @@ class KumoRFM:
860
896
  f"'https://github.com/kumo-ai/kumo-rfm' if you "
861
897
  f"must go beyond this for your use-case.")
862
898
 
863
- query_driver = LocalPQueryDriver(self._graph_store, query, random_seed)
864
- task_type = LocalPQueryDriver.get_task_type(
865
- query,
866
- edge_types=self._graph_store.edge_types,
899
+ task_type = self._get_task_type(
900
+ query=query,
901
+ edge_types=self._sampler.edge_types,
867
902
  )
868
903
 
869
904
  if logger is not None:
@@ -895,14 +930,17 @@ class KumoRFM:
895
930
  num_neighbors = [64, 64, 8, 8, 4, 4][:num_hops]
896
931
 
897
932
  if query.target_ast.date_offset_range is None:
898
- end_offset = pd.DateOffset(0)
933
+ step_offset = pd.DateOffset(0)
899
934
  else:
900
- end_offset = query.target_ast.date_offset_range.end_date_offset
901
- forecast_end_offset = end_offset * query.num_forecasts
935
+ step_offset = query.target_ast.date_offset_range.end_date_offset
936
+ end_offset = step_offset * query.num_forecasts
937
+
902
938
  if anchor_time is None:
903
- anchor_time = self._graph_store.max_time
939
+ anchor_time = self._get_default_anchor_time(query)
940
+
904
941
  if evaluate:
905
- anchor_time = anchor_time - forecast_end_offset
942
+ anchor_time = anchor_time - end_offset
943
+
906
944
  if logger is not None:
907
945
  assert isinstance(anchor_time, pd.Timestamp)
908
946
  if anchor_time == pd.Timestamp.min:
@@ -916,57 +954,71 @@ class KumoRFM:
916
954
 
917
955
  assert anchor_time is not None
918
956
  if isinstance(anchor_time, pd.Timestamp):
957
+ if context_anchor_time == 'entity':
958
+ raise ValueError("Anchor time 'entity' needs to be shared "
959
+ "for context and prediction examples")
919
960
  if context_anchor_time is None:
920
- context_anchor_time = anchor_time - forecast_end_offset
961
+ context_anchor_time = anchor_time - end_offset
921
962
  self._validate_time(query, anchor_time, context_anchor_time,
922
963
  evaluate)
923
964
  else:
924
965
  assert anchor_time == 'entity'
925
- if query.entity_table not in self._graph_store.time_dict:
966
+ if query.query_type != QueryType.STATIC:
967
+ raise ValueError("Anchor time 'entity' is only valid for "
968
+ "static predictive queries")
969
+ if query.entity_table not in self._sampler.time_column_dict:
926
970
  raise ValueError(f"Anchor time 'entity' requires the entity "
927
971
  f"table '{query.entity_table}' to "
928
972
  f"have a time column")
929
- if context_anchor_time is not None:
930
- warnings.warn("Ignoring option 'context_anchor_time' for "
931
- "`anchor_time='entity'`")
932
- context_anchor_time = None
973
+ if isinstance(context_anchor_time, pd.Timestamp):
974
+ raise ValueError("Anchor time 'entity' needs to be shared "
975
+ "for context and prediction examples")
976
+ context_anchor_time = 'entity'
933
977
 
934
- y_test: Optional[pd.Series] = None
978
+ num_train_examples = _MAX_CONTEXT_SIZE[run_mode]
935
979
  if evaluate:
936
- max_test_size = _MAX_TEST_SIZE[run_mode]
980
+ num_test_examples = _MAX_TEST_SIZE[run_mode]
937
981
  if task_type.is_link_pred:
938
- max_test_size = max_test_size // 5
982
+ num_test_examples = num_test_examples // 5
983
+ else:
984
+ num_test_examples = 0
985
+
986
+ train, test = self._sampler.sample_target(
987
+ query=query,
988
+ num_train_examples=num_train_examples,
989
+ train_anchor_time=context_anchor_time,
990
+ num_train_trials=max_pq_iterations * num_train_examples,
991
+ num_test_examples=num_test_examples,
992
+ test_anchor_time=anchor_time,
993
+ num_test_trials=max_pq_iterations * num_test_examples,
994
+ random_seed=random_seed,
995
+ )
996
+ train_pkey, train_time, y_train = train
997
+ test_pkey, test_time, y_test = test
939
998
 
940
- test_node, test_time, y_test = query_driver.collect_test(
941
- size=max_test_size,
942
- anchor_time=anchor_time,
943
- max_iterations=max_pq_iterations,
944
- guarantee_train_examples=True,
945
- )
946
- if logger is not None:
947
- if task_type == TaskType.BINARY_CLASSIFICATION:
948
- pos = 100 * int((y_test > 0).sum()) / len(y_test)
949
- msg = (f"Collected {len(y_test):,} test examples with "
950
- f"{pos:.2f}% positive cases")
951
- elif task_type == TaskType.MULTICLASS_CLASSIFICATION:
952
- msg = (f"Collected {len(y_test):,} test examples "
953
- f"holding {y_test.nunique()} classes")
954
- elif task_type == TaskType.REGRESSION:
955
- _min, _max = float(y_test.min()), float(y_test.max())
956
- msg = (f"Collected {len(y_test):,} test examples with "
957
- f"targets between {format_value(_min)} and "
958
- f"{format_value(_max)}")
959
- elif task_type == TaskType.TEMPORAL_LINK_PREDICTION:
960
- num_rhs = y_test.explode().nunique()
961
- msg = (f"Collected {len(y_test):,} test examples with "
962
- f"{num_rhs:,} unique items")
963
- else:
964
- raise NotImplementedError
965
- logger.log(msg)
999
+ if evaluate and logger is not None:
1000
+ if task_type == TaskType.BINARY_CLASSIFICATION:
1001
+ pos = 100 * int((y_test > 0).sum()) / len(y_test)
1002
+ msg = (f"Collected {len(y_test):,} test examples with "
1003
+ f"{pos:.2f}% positive cases")
1004
+ elif task_type == TaskType.MULTICLASS_CLASSIFICATION:
1005
+ msg = (f"Collected {len(y_test):,} test examples holding "
1006
+ f"{y_test.nunique()} classes")
1007
+ elif task_type == TaskType.REGRESSION:
1008
+ _min, _max = float(y_test.min()), float(y_test.max())
1009
+ msg = (f"Collected {len(y_test):,} test examples with targets "
1010
+ f"between {format_value(_min)} and "
1011
+ f"{format_value(_max)}")
1012
+ elif task_type == TaskType.TEMPORAL_LINK_PREDICTION:
1013
+ num_rhs = y_test.explode().nunique()
1014
+ msg = (f"Collected {len(y_test):,} test examples with "
1015
+ f"{num_rhs:,} unique items")
1016
+ else:
1017
+ raise NotImplementedError
1018
+ logger.log(msg)
966
1019
 
967
- else:
1020
+ if not evaluate:
968
1021
  assert indices is not None
969
-
970
1022
  if len(indices) > _MAX_PRED_SIZE[task_type]:
971
1023
  raise ValueError(f"Cannot predict for more than "
972
1024
  f"{_MAX_PRED_SIZE[task_type]:,} entities at "
@@ -974,26 +1026,12 @@ class KumoRFM:
974
1026
  f"`KumoRFM.batch_mode` to process entities "
975
1027
  f"in batches")
976
1028
 
977
- test_node = self._graph_store.get_node_id(
978
- table_name=query.entity_table,
979
- pkey=pd.Series(indices),
980
- )
981
-
1029
+ test_pkey = pd.Series(indices, dtype=train_pkey.dtype)
982
1030
  if isinstance(anchor_time, pd.Timestamp):
983
- test_time = pd.Series(anchor_time).repeat(
984
- len(test_node)).reset_index(drop=True)
1031
+ test_time = pd.Series([anchor_time]).repeat(
1032
+ len(indices)).reset_index(drop=True)
985
1033
  else:
986
- time = self._graph_store.time_dict[query.entity_table]
987
- time = time[test_node] * 1000**3
988
- test_time = pd.Series(time, dtype='datetime64[ns]')
989
-
990
- train_node, train_time, y_train = query_driver.collect_train(
991
- size=_MAX_CONTEXT_SIZE[run_mode],
992
- anchor_time=context_anchor_time or 'entity',
993
- exclude_node=test_node if (query.query_type == QueryType.STATIC
994
- or anchor_time == 'entity') else None,
995
- max_iterations=max_pq_iterations,
996
- )
1034
+ train_time = test_time = 'entity'
997
1035
 
998
1036
  if logger is not None:
999
1037
  if task_type == TaskType.BINARY_CLASSIFICATION:
@@ -1021,7 +1059,7 @@ class KumoRFM:
1021
1059
  final_aggr = query.get_final_target_aggregation()
1022
1060
  assert final_aggr is not None
1023
1061
  edge_fkey = final_aggr._get_target_column_name()
1024
- for edge_type in self._graph_store.edge_types:
1062
+ for edge_type in self._sampler.edge_types:
1025
1063
  if edge_fkey == f'{edge_type[0]}.{edge_type[1]}':
1026
1064
  entity_table_names = (
1027
1065
  query.entity_table,
@@ -1033,20 +1071,24 @@ class KumoRFM:
1033
1071
  # Exclude the entity anchor time from the feature set to prevent
1034
1072
  # running out-of-distribution between in-context and test examples:
1035
1073
  exclude_cols_dict = query.get_exclude_cols_dict()
1036
- if anchor_time == 'entity':
1074
+ if entity_table_names[0] in self._sampler.time_column_dict:
1037
1075
  if entity_table_names[0] not in exclude_cols_dict:
1038
1076
  exclude_cols_dict[entity_table_names[0]] = []
1039
- time_column_dict = self._graph_store.time_column_dict
1040
- time_column = time_column_dict[entity_table_names[0]]
1077
+ time_column = self._sampler.time_column_dict[entity_table_names[0]]
1041
1078
  exclude_cols_dict[entity_table_names[0]].append(time_column)
1042
1079
 
1043
- subgraph = self._graph_sampler(
1080
+ subgraph = self._sampler.sample_subgraph(
1044
1081
  entity_table_names=entity_table_names,
1045
- node=np.concatenate([train_node, test_node]),
1046
- time=np.concatenate([
1047
- train_time.astype('datetime64[ns]').astype(int).to_numpy(),
1048
- test_time.astype('datetime64[ns]').astype(int).to_numpy(),
1049
- ]),
1082
+ entity_pkey=pd.concat(
1083
+ [train_pkey, test_pkey],
1084
+ axis=0,
1085
+ ignore_index=True,
1086
+ ),
1087
+ anchor_time=pd.concat(
1088
+ [train_time, test_time],
1089
+ axis=0,
1090
+ ignore_index=True,
1091
+ ) if isinstance(train_time, pd.Series) else 'entity',
1050
1092
  num_neighbors=num_neighbors,
1051
1093
  exclude_cols_dict=exclude_cols_dict,
1052
1094
  )
@@ -1058,18 +1100,14 @@ class KumoRFM:
1058
1100
  f"'https://github.com/kumo-ai/kumo-rfm' if you "
1059
1101
  f"must go beyond this for your use-case.")
1060
1102
 
1061
- step_size: Optional[int] = None
1062
- if query.query_type == QueryType.TEMPORAL:
1063
- step_size = date_offset_to_seconds(end_offset)
1064
-
1065
1103
  return Context(
1066
1104
  task_type=task_type,
1067
1105
  entity_table_names=entity_table_names,
1068
1106
  subgraph=subgraph,
1069
1107
  y_train=y_train,
1070
- y_test=y_test,
1108
+ y_test=y_test if evaluate else None,
1071
1109
  top_k=query.top_k,
1072
- step_size=step_size,
1110
+ step_size=None,
1073
1111
  )
1074
1112
 
1075
1113
  @staticmethod
@@ -2,15 +2,22 @@ import base64
2
2
  import json
3
3
  from typing import Any, Dict, List, Tuple
4
4
 
5
- import boto3
6
5
  import requests
7
- from mypy_boto3_sagemaker_runtime.client import SageMakerRuntimeClient
8
- from mypy_boto3_sagemaker_runtime.type_defs import InvokeEndpointOutputTypeDef
9
6
 
10
7
  from kumoai.client import KumoClient
11
8
  from kumoai.client.endpoints import Endpoint, HTTPMethod
12
9
  from kumoai.exceptions import HTTPException
13
10
 
11
+ try:
12
+ # isort: off
13
+ from mypy_boto3_sagemaker_runtime.client import SageMakerRuntimeClient
14
+ from mypy_boto3_sagemaker_runtime.type_defs import (
15
+ InvokeEndpointOutputTypeDef, )
16
+ # isort: on
17
+ except ImportError:
18
+ SageMakerRuntimeClient = Any
19
+ InvokeEndpointOutputTypeDef = Any
20
+
14
21
 
15
22
  class SageMakerResponseAdapter(requests.Response):
16
23
  def __init__(self, sm_response: InvokeEndpointOutputTypeDef):
@@ -34,6 +41,7 @@ class SageMakerResponseAdapter(requests.Response):
34
41
 
35
42
  class KumoClient_SageMakerAdapter(KumoClient):
36
43
  def __init__(self, region: str, endpoint_name: str):
44
+ import boto3
37
45
  self._client: SageMakerRuntimeClient = boto3.client(
38
46
  service_name="sagemaker-runtime", region_name=region)
39
47
  self._endpoint_name = endpoint_name
Binary file