kumoai 2.13.0.dev202512040651__cp312-cp312-win_amd64.whl → 2.14.0.dev202512111731__cp312-cp312-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kumoai/__init__.py +12 -0
- kumoai/_version.py +1 -1
- kumoai/experimental/rfm/backend/local/__init__.py +4 -0
- kumoai/experimental/rfm/{local_graph_store.py → backend/local/graph_store.py} +31 -70
- kumoai/experimental/rfm/backend/local/sampler.py +313 -0
- kumoai/experimental/rfm/backend/snow/table.py +1 -1
- kumoai/experimental/rfm/base/__init__.py +3 -0
- kumoai/experimental/rfm/base/sampler.py +763 -0
- kumoai/experimental/rfm/base/table.py +8 -3
- kumoai/experimental/rfm/graph.py +33 -17
- kumoai/experimental/rfm/pquery/pandas_executor.py +1 -1
- kumoai/experimental/rfm/rfm.py +199 -151
- kumoai/kumolib.cp312-win_amd64.pyd +0 -0
- kumoai/pquery/predictive_query.py +10 -6
- {kumoai-2.13.0.dev202512040651.dist-info → kumoai-2.14.0.dev202512111731.dist-info}/METADATA +2 -2
- {kumoai-2.13.0.dev202512040651.dist-info → kumoai-2.14.0.dev202512111731.dist-info}/RECORD +19 -19
- kumoai/experimental/rfm/local_graph_sampler.py +0 -223
- kumoai/experimental/rfm/local_pquery_driver.py +0 -689
- {kumoai-2.13.0.dev202512040651.dist-info → kumoai-2.14.0.dev202512111731.dist-info}/WHEEL +0 -0
- {kumoai-2.13.0.dev202512040651.dist-info → kumoai-2.14.0.dev202512111731.dist-info}/licenses/LICENSE +0 -0
- {kumoai-2.13.0.dev202512040651.dist-info → kumoai-2.14.0.dev202512111731.dist-info}/top_level.txt +0 -0
kumoai/experimental/rfm/rfm.py
CHANGED
|
@@ -21,6 +21,13 @@ import numpy as np
|
|
|
21
21
|
import pandas as pd
|
|
22
22
|
from kumoapi.model_plan import RunMode
|
|
23
23
|
from kumoapi.pquery import QueryType, ValidatedPredictiveQuery
|
|
24
|
+
from kumoapi.pquery.AST import (
|
|
25
|
+
Aggregation,
|
|
26
|
+
Column,
|
|
27
|
+
Condition,
|
|
28
|
+
Join,
|
|
29
|
+
LogicalOperation,
|
|
30
|
+
)
|
|
24
31
|
from kumoapi.rfm import Context
|
|
25
32
|
from kumoapi.rfm import Explanation as ExplanationConfig
|
|
26
33
|
from kumoapi.rfm import (
|
|
@@ -29,16 +36,11 @@ from kumoapi.rfm import (
|
|
|
29
36
|
RFMPredictRequest,
|
|
30
37
|
)
|
|
31
38
|
from kumoapi.task import TaskType
|
|
39
|
+
from kumoapi.typing import AggregationType, Stype
|
|
32
40
|
|
|
33
41
|
from kumoai.client.rfm import RFMAPI
|
|
34
42
|
from kumoai.exceptions import HTTPException
|
|
35
43
|
from kumoai.experimental.rfm import Graph
|
|
36
|
-
from kumoai.experimental.rfm.local_graph_sampler import LocalGraphSampler
|
|
37
|
-
from kumoai.experimental.rfm.local_graph_store import LocalGraphStore
|
|
38
|
-
from kumoai.experimental.rfm.local_pquery_driver import (
|
|
39
|
-
LocalPQueryDriver,
|
|
40
|
-
date_offset_to_seconds,
|
|
41
|
-
)
|
|
42
44
|
from kumoai.mixin import CastMixin
|
|
43
45
|
from kumoai.utils import InteractiveProgressLogger, ProgressLogger
|
|
44
46
|
|
|
@@ -159,8 +161,9 @@ class KumoRFM:
|
|
|
159
161
|
) -> None:
|
|
160
162
|
graph = graph.validate()
|
|
161
163
|
self._graph_def = graph._to_api_graph_definition()
|
|
162
|
-
|
|
163
|
-
|
|
164
|
+
|
|
165
|
+
from kumoai.experimental.rfm.backend.local import LocalSampler
|
|
166
|
+
self._sampler = LocalSampler(graph, verbose)
|
|
164
167
|
|
|
165
168
|
self._client: Optional[RFMAPI] = None
|
|
166
169
|
|
|
@@ -224,7 +227,7 @@ class KumoRFM:
|
|
|
224
227
|
run_mode: Union[RunMode, str] = RunMode.FAST,
|
|
225
228
|
num_neighbors: Optional[List[int]] = None,
|
|
226
229
|
num_hops: int = 2,
|
|
227
|
-
max_pq_iterations: int =
|
|
230
|
+
max_pq_iterations: int = 10,
|
|
228
231
|
random_seed: Optional[int] = _RANDOM_SEED,
|
|
229
232
|
verbose: Union[bool, ProgressLogger] = True,
|
|
230
233
|
use_prediction_time: bool = False,
|
|
@@ -243,7 +246,7 @@ class KumoRFM:
|
|
|
243
246
|
run_mode: Union[RunMode, str] = RunMode.FAST,
|
|
244
247
|
num_neighbors: Optional[List[int]] = None,
|
|
245
248
|
num_hops: int = 2,
|
|
246
|
-
max_pq_iterations: int =
|
|
249
|
+
max_pq_iterations: int = 10,
|
|
247
250
|
random_seed: Optional[int] = _RANDOM_SEED,
|
|
248
251
|
verbose: Union[bool, ProgressLogger] = True,
|
|
249
252
|
use_prediction_time: bool = False,
|
|
@@ -261,7 +264,7 @@ class KumoRFM:
|
|
|
261
264
|
run_mode: Union[RunMode, str] = RunMode.FAST,
|
|
262
265
|
num_neighbors: Optional[List[int]] = None,
|
|
263
266
|
num_hops: int = 2,
|
|
264
|
-
max_pq_iterations: int =
|
|
267
|
+
max_pq_iterations: int = 10,
|
|
265
268
|
random_seed: Optional[int] = _RANDOM_SEED,
|
|
266
269
|
verbose: Union[bool, ProgressLogger] = True,
|
|
267
270
|
use_prediction_time: bool = False,
|
|
@@ -357,9 +360,9 @@ class KumoRFM:
|
|
|
357
360
|
|
|
358
361
|
batch_size: Optional[int] = None
|
|
359
362
|
if self._batch_size == 'max':
|
|
360
|
-
task_type =
|
|
361
|
-
query_def,
|
|
362
|
-
edge_types=self.
|
|
363
|
+
task_type = self._get_task_type(
|
|
364
|
+
query=query_def,
|
|
365
|
+
edge_types=self._sampler.edge_types,
|
|
363
366
|
)
|
|
364
367
|
batch_size = _MAX_PRED_SIZE[task_type]
|
|
365
368
|
else:
|
|
@@ -433,10 +436,10 @@ class KumoRFM:
|
|
|
433
436
|
|
|
434
437
|
# Cast 'ENTITY' to correct data type:
|
|
435
438
|
if 'ENTITY' in df:
|
|
436
|
-
|
|
437
|
-
|
|
438
|
-
|
|
439
|
-
|
|
439
|
+
table_dict = context.subgraph.table_dict
|
|
440
|
+
table = table_dict[query_def.entity_table]
|
|
441
|
+
ser = table.df[table.primary_key]
|
|
442
|
+
df['ENTITY'] = df['ENTITY'].astype(ser.dtype)
|
|
440
443
|
|
|
441
444
|
# Cast 'ANCHOR_TIMESTAMP' to correct data type:
|
|
442
445
|
if 'ANCHOR_TIMESTAMP' in df:
|
|
@@ -519,23 +522,18 @@ class KumoRFM:
|
|
|
519
522
|
raise ValueError("At least one entity is required")
|
|
520
523
|
|
|
521
524
|
if anchor_time is None:
|
|
522
|
-
anchor_time = self.
|
|
525
|
+
anchor_time = self._get_default_anchor_time(query_def)
|
|
523
526
|
|
|
524
527
|
if isinstance(anchor_time, pd.Timestamp):
|
|
525
528
|
self._validate_time(query_def, anchor_time, None, False)
|
|
526
529
|
else:
|
|
527
530
|
assert anchor_time == 'entity'
|
|
528
|
-
if
|
|
531
|
+
if query_def.entity_table not in self._sampler.time_column_dict:
|
|
529
532
|
raise ValueError(f"Anchor time 'entity' requires the entity "
|
|
530
533
|
f"table '{query_def.entity_table}' "
|
|
531
534
|
f"to have a time column.")
|
|
532
535
|
|
|
533
|
-
|
|
534
|
-
table_name=query_def.entity_table,
|
|
535
|
-
pkey=pd.Series(indices),
|
|
536
|
-
)
|
|
537
|
-
query_driver = LocalPQueryDriver(self._graph_store, query_def)
|
|
538
|
-
return query_driver.is_valid(node, anchor_time)
|
|
536
|
+
raise NotImplementedError
|
|
539
537
|
|
|
540
538
|
def evaluate(
|
|
541
539
|
self,
|
|
@@ -547,7 +545,7 @@ class KumoRFM:
|
|
|
547
545
|
run_mode: Union[RunMode, str] = RunMode.FAST,
|
|
548
546
|
num_neighbors: Optional[List[int]] = None,
|
|
549
547
|
num_hops: int = 2,
|
|
550
|
-
max_pq_iterations: int =
|
|
548
|
+
max_pq_iterations: int = 10,
|
|
551
549
|
random_seed: Optional[int] = _RANDOM_SEED,
|
|
552
550
|
verbose: Union[bool, ProgressLogger] = True,
|
|
553
551
|
use_prediction_time: bool = False,
|
|
@@ -658,7 +656,7 @@ class KumoRFM:
|
|
|
658
656
|
*,
|
|
659
657
|
anchor_time: Union[pd.Timestamp, Literal['entity'], None] = None,
|
|
660
658
|
random_seed: Optional[int] = _RANDOM_SEED,
|
|
661
|
-
max_iterations: int =
|
|
659
|
+
max_iterations: int = 10,
|
|
662
660
|
) -> pd.DataFrame:
|
|
663
661
|
"""Returns the labels of a predictive query for a specified anchor
|
|
664
662
|
time.
|
|
@@ -678,40 +676,37 @@ class KumoRFM:
|
|
|
678
676
|
query_def = self._parse_query(query)
|
|
679
677
|
|
|
680
678
|
if anchor_time is None:
|
|
681
|
-
anchor_time = self.
|
|
679
|
+
anchor_time = self._get_default_anchor_time(query_def)
|
|
682
680
|
if query_def.target_ast.date_offset_range is not None:
|
|
683
|
-
|
|
684
|
-
|
|
685
|
-
|
|
681
|
+
offset = query_def.target_ast.date_offset_range.end_date_offset
|
|
682
|
+
offset *= query_def.num_forecasts
|
|
683
|
+
anchor_time -= offset
|
|
686
684
|
|
|
687
685
|
assert anchor_time is not None
|
|
688
686
|
if isinstance(anchor_time, pd.Timestamp):
|
|
689
687
|
self._validate_time(query_def, anchor_time, None, evaluate=True)
|
|
690
688
|
else:
|
|
691
689
|
assert anchor_time == 'entity'
|
|
692
|
-
if
|
|
690
|
+
if query_def.entity_table not in self._sampler.time_column_dict:
|
|
693
691
|
raise ValueError(f"Anchor time 'entity' requires the entity "
|
|
694
692
|
f"table '{query_def.entity_table}' "
|
|
695
693
|
f"to have a time column")
|
|
696
694
|
|
|
697
|
-
|
|
698
|
-
|
|
699
|
-
|
|
700
|
-
|
|
701
|
-
|
|
702
|
-
|
|
703
|
-
|
|
704
|
-
|
|
705
|
-
|
|
695
|
+
train, test = self._sampler.sample_target(
|
|
696
|
+
query=query,
|
|
697
|
+
num_train_examples=0,
|
|
698
|
+
train_anchor_time=anchor_time,
|
|
699
|
+
num_train_trials=0,
|
|
700
|
+
num_test_examples=size,
|
|
701
|
+
test_anchor_time=anchor_time,
|
|
702
|
+
num_test_trials=max_iterations * size,
|
|
703
|
+
random_seed=random_seed,
|
|
706
704
|
)
|
|
707
705
|
|
|
708
|
-
entity = self._graph_store.pkey_map_dict[
|
|
709
|
-
query_def.entity_table].index[node]
|
|
710
|
-
|
|
711
706
|
return pd.DataFrame({
|
|
712
|
-
'ENTITY':
|
|
713
|
-
'ANCHOR_TIMESTAMP':
|
|
714
|
-
'TARGET':
|
|
707
|
+
'ENTITY': test.entity_pkey,
|
|
708
|
+
'ANCHOR_TIMESTAMP': test.anchor_time,
|
|
709
|
+
'TARGET': test.target,
|
|
715
710
|
})
|
|
716
711
|
|
|
717
712
|
# Helpers #################################################################
|
|
@@ -734,8 +729,6 @@ class KumoRFM:
|
|
|
734
729
|
|
|
735
730
|
resp = self._api_client.parse_query(request)
|
|
736
731
|
|
|
737
|
-
# TODO Expose validation warnings.
|
|
738
|
-
|
|
739
732
|
if len(resp.validation_response.warnings) > 0:
|
|
740
733
|
msg = '\n'.join([
|
|
741
734
|
f'{i+1}. {warning.title}: {warning.message}' for i, warning
|
|
@@ -753,6 +746,60 @@ class KumoRFM:
|
|
|
753
746
|
raise ValueError(f"Failed to parse query '{query}'. "
|
|
754
747
|
f"{msg}") from None
|
|
755
748
|
|
|
749
|
+
@staticmethod
|
|
750
|
+
def _get_task_type(
|
|
751
|
+
query: ValidatedPredictiveQuery,
|
|
752
|
+
edge_types: List[Tuple[str, str, str]],
|
|
753
|
+
) -> TaskType:
|
|
754
|
+
if isinstance(query.target_ast, (Condition, LogicalOperation)):
|
|
755
|
+
return TaskType.BINARY_CLASSIFICATION
|
|
756
|
+
|
|
757
|
+
target = query.target_ast
|
|
758
|
+
if isinstance(target, Join):
|
|
759
|
+
target = target.rhs_target
|
|
760
|
+
if isinstance(target, Aggregation):
|
|
761
|
+
if target.aggr == AggregationType.LIST_DISTINCT:
|
|
762
|
+
table_name, col_name = target._get_target_column_name().split(
|
|
763
|
+
'.')
|
|
764
|
+
target_edge_types = [
|
|
765
|
+
edge_type for edge_type in edge_types
|
|
766
|
+
if edge_type[0] == table_name and edge_type[1] == col_name
|
|
767
|
+
]
|
|
768
|
+
if len(target_edge_types) != 1:
|
|
769
|
+
raise NotImplementedError(
|
|
770
|
+
f"Multilabel-classification queries based on "
|
|
771
|
+
f"'LIST_DISTINCT' are not supported yet. If you "
|
|
772
|
+
f"planned to write a link prediction query instead, "
|
|
773
|
+
f"make sure to register '{col_name}' as a "
|
|
774
|
+
f"foreign key.")
|
|
775
|
+
return TaskType.TEMPORAL_LINK_PREDICTION
|
|
776
|
+
|
|
777
|
+
return TaskType.REGRESSION
|
|
778
|
+
|
|
779
|
+
assert isinstance(target, Column)
|
|
780
|
+
|
|
781
|
+
if target.stype in {Stype.ID, Stype.categorical}:
|
|
782
|
+
return TaskType.MULTICLASS_CLASSIFICATION
|
|
783
|
+
|
|
784
|
+
if target.stype in {Stype.numerical}:
|
|
785
|
+
return TaskType.REGRESSION
|
|
786
|
+
|
|
787
|
+
raise NotImplementedError("Task type not yet supported")
|
|
788
|
+
|
|
789
|
+
def _get_default_anchor_time(
|
|
790
|
+
self,
|
|
791
|
+
query: ValidatedPredictiveQuery,
|
|
792
|
+
) -> pd.Timestamp:
|
|
793
|
+
if query.query_type == QueryType.TEMPORAL:
|
|
794
|
+
aggr_table_names = [
|
|
795
|
+
aggr._get_target_column_name().split('.')[0]
|
|
796
|
+
for aggr in query.get_all_target_aggregations()
|
|
797
|
+
]
|
|
798
|
+
return self._sampler.get_max_time(aggr_table_names)
|
|
799
|
+
|
|
800
|
+
assert query.query_type == QueryType.STATIC
|
|
801
|
+
return self._sampler.get_max_time()
|
|
802
|
+
|
|
756
803
|
def _validate_time(
|
|
757
804
|
self,
|
|
758
805
|
query: ValidatedPredictiveQuery,
|
|
@@ -761,28 +808,30 @@ class KumoRFM:
|
|
|
761
808
|
evaluate: bool,
|
|
762
809
|
) -> None:
|
|
763
810
|
|
|
764
|
-
if self.
|
|
811
|
+
if len(self._sampler.time_column_dict) == 0:
|
|
765
812
|
return # Graph without timestamps
|
|
766
813
|
|
|
767
|
-
|
|
814
|
+
min_time = self._sampler.get_min_time()
|
|
815
|
+
max_time = self._sampler.get_max_time()
|
|
816
|
+
|
|
817
|
+
if anchor_time < min_time:
|
|
768
818
|
raise ValueError(f"Anchor timestamp '{anchor_time}' is before "
|
|
769
|
-
f"the earliest timestamp "
|
|
770
|
-
f"
|
|
819
|
+
f"the earliest timestamp '{min_time}' in the "
|
|
820
|
+
f"data.")
|
|
771
821
|
|
|
772
|
-
if
|
|
773
|
-
and context_anchor_time < self._graph_store.min_time):
|
|
822
|
+
if context_anchor_time is not None and context_anchor_time < min_time:
|
|
774
823
|
raise ValueError(f"Context anchor timestamp is too early or "
|
|
775
824
|
f"aggregation time range is too large. To make "
|
|
776
825
|
f"this prediction, we would need data back to "
|
|
777
826
|
f"'{context_anchor_time}', however, your data "
|
|
778
|
-
f"only contains data back to "
|
|
779
|
-
f"'{self._graph_store.min_time}'.")
|
|
827
|
+
f"only contains data back to '{min_time}'.")
|
|
780
828
|
|
|
781
829
|
if query.target_ast.date_offset_range is not None:
|
|
782
830
|
end_offset = query.target_ast.date_offset_range.end_date_offset
|
|
783
831
|
else:
|
|
784
832
|
end_offset = pd.DateOffset(0)
|
|
785
|
-
|
|
833
|
+
end_offset = end_offset * query.num_forecasts
|
|
834
|
+
|
|
786
835
|
if (context_anchor_time is not None
|
|
787
836
|
and context_anchor_time > anchor_time):
|
|
788
837
|
warnings.warn(f"Context anchor timestamp "
|
|
@@ -792,7 +841,7 @@ class KumoRFM:
|
|
|
792
841
|
f"intended.")
|
|
793
842
|
elif (query.query_type == QueryType.TEMPORAL
|
|
794
843
|
and context_anchor_time is not None
|
|
795
|
-
and context_anchor_time +
|
|
844
|
+
and context_anchor_time + end_offset > anchor_time):
|
|
796
845
|
warnings.warn(f"Aggregation for context examples at timestamp "
|
|
797
846
|
f"'{context_anchor_time}' will leak information "
|
|
798
847
|
f"from the prediction anchor timestamp "
|
|
@@ -800,26 +849,23 @@ class KumoRFM:
|
|
|
800
849
|
f"intended.")
|
|
801
850
|
|
|
802
851
|
elif (context_anchor_time is not None
|
|
803
|
-
and context_anchor_time -
|
|
804
|
-
|
|
805
|
-
_time = context_anchor_time - forecast_end_offset
|
|
852
|
+
and context_anchor_time - end_offset < min_time):
|
|
853
|
+
_time = context_anchor_time - end_offset
|
|
806
854
|
warnings.warn(f"Context anchor timestamp is too early or "
|
|
807
855
|
f"aggregation time range is too large. To form "
|
|
808
856
|
f"proper input data, we would need data back to "
|
|
809
857
|
f"'{_time}', however, your data only contains "
|
|
810
|
-
f"data back to '{
|
|
858
|
+
f"data back to '{min_time}'.")
|
|
811
859
|
|
|
812
|
-
if
|
|
813
|
-
> self._graph_store.max_time + pd.DateOffset(days=1)):
|
|
860
|
+
if not evaluate and anchor_time > max_time + pd.DateOffset(days=1):
|
|
814
861
|
warnings.warn(f"Anchor timestamp '{anchor_time}' is after the "
|
|
815
|
-
f"latest timestamp '{
|
|
816
|
-
f"
|
|
862
|
+
f"latest timestamp '{max_time}' in the data. Please "
|
|
863
|
+
f"make sure this is intended.")
|
|
817
864
|
|
|
818
|
-
|
|
819
|
-
if evaluate and anchor_time > max_eval_time:
|
|
865
|
+
if evaluate and anchor_time > max_time - end_offset:
|
|
820
866
|
raise ValueError(
|
|
821
867
|
f"Anchor timestamp for evaluation is after the latest "
|
|
822
|
-
f"supported timestamp '{
|
|
868
|
+
f"supported timestamp '{max_time - end_offset}'.")
|
|
823
869
|
|
|
824
870
|
def _get_context(
|
|
825
871
|
self,
|
|
@@ -850,10 +896,9 @@ class KumoRFM:
|
|
|
850
896
|
f"'https://github.com/kumo-ai/kumo-rfm' if you "
|
|
851
897
|
f"must go beyond this for your use-case.")
|
|
852
898
|
|
|
853
|
-
|
|
854
|
-
|
|
855
|
-
|
|
856
|
-
edge_types=self._graph_store.edge_types,
|
|
899
|
+
task_type = self._get_task_type(
|
|
900
|
+
query=query,
|
|
901
|
+
edge_types=self._sampler.edge_types,
|
|
857
902
|
)
|
|
858
903
|
|
|
859
904
|
if logger is not None:
|
|
@@ -885,14 +930,17 @@ class KumoRFM:
|
|
|
885
930
|
num_neighbors = [64, 64, 8, 8, 4, 4][:num_hops]
|
|
886
931
|
|
|
887
932
|
if query.target_ast.date_offset_range is None:
|
|
888
|
-
|
|
933
|
+
step_offset = pd.DateOffset(0)
|
|
889
934
|
else:
|
|
890
|
-
|
|
891
|
-
|
|
935
|
+
step_offset = query.target_ast.date_offset_range.end_date_offset
|
|
936
|
+
end_offset = step_offset * query.num_forecasts
|
|
937
|
+
|
|
892
938
|
if anchor_time is None:
|
|
893
|
-
anchor_time = self.
|
|
939
|
+
anchor_time = self._get_default_anchor_time(query)
|
|
940
|
+
|
|
894
941
|
if evaluate:
|
|
895
|
-
anchor_time = anchor_time -
|
|
942
|
+
anchor_time = anchor_time - end_offset
|
|
943
|
+
|
|
896
944
|
if logger is not None:
|
|
897
945
|
assert isinstance(anchor_time, pd.Timestamp)
|
|
898
946
|
if anchor_time == pd.Timestamp.min:
|
|
@@ -906,57 +954,71 @@ class KumoRFM:
|
|
|
906
954
|
|
|
907
955
|
assert anchor_time is not None
|
|
908
956
|
if isinstance(anchor_time, pd.Timestamp):
|
|
957
|
+
if context_anchor_time == 'entity':
|
|
958
|
+
raise ValueError("Anchor time 'entity' needs to be shared "
|
|
959
|
+
"for context and prediction examples")
|
|
909
960
|
if context_anchor_time is None:
|
|
910
|
-
context_anchor_time = anchor_time -
|
|
961
|
+
context_anchor_time = anchor_time - end_offset
|
|
911
962
|
self._validate_time(query, anchor_time, context_anchor_time,
|
|
912
963
|
evaluate)
|
|
913
964
|
else:
|
|
914
965
|
assert anchor_time == 'entity'
|
|
915
|
-
if query.
|
|
966
|
+
if query.query_type != QueryType.STATIC:
|
|
967
|
+
raise ValueError("Anchor time 'entity' is only valid for "
|
|
968
|
+
"static predictive queries")
|
|
969
|
+
if query.entity_table not in self._sampler.time_column_dict:
|
|
916
970
|
raise ValueError(f"Anchor time 'entity' requires the entity "
|
|
917
971
|
f"table '{query.entity_table}' to "
|
|
918
972
|
f"have a time column")
|
|
919
|
-
if context_anchor_time
|
|
920
|
-
|
|
921
|
-
|
|
922
|
-
context_anchor_time =
|
|
973
|
+
if isinstance(context_anchor_time, pd.Timestamp):
|
|
974
|
+
raise ValueError("Anchor time 'entity' needs to be shared "
|
|
975
|
+
"for context and prediction examples")
|
|
976
|
+
context_anchor_time = 'entity'
|
|
923
977
|
|
|
924
|
-
|
|
978
|
+
num_train_examples = _MAX_CONTEXT_SIZE[run_mode]
|
|
925
979
|
if evaluate:
|
|
926
|
-
|
|
980
|
+
num_test_examples = _MAX_TEST_SIZE[run_mode]
|
|
927
981
|
if task_type.is_link_pred:
|
|
928
|
-
|
|
982
|
+
num_test_examples = num_test_examples // 5
|
|
983
|
+
else:
|
|
984
|
+
num_test_examples = 0
|
|
985
|
+
|
|
986
|
+
train, test = self._sampler.sample_target(
|
|
987
|
+
query=query,
|
|
988
|
+
num_train_examples=num_train_examples,
|
|
989
|
+
train_anchor_time=context_anchor_time,
|
|
990
|
+
num_train_trials=max_pq_iterations * num_train_examples,
|
|
991
|
+
num_test_examples=num_test_examples,
|
|
992
|
+
test_anchor_time=anchor_time,
|
|
993
|
+
num_test_trials=max_pq_iterations * num_test_examples,
|
|
994
|
+
random_seed=random_seed,
|
|
995
|
+
)
|
|
996
|
+
train_pkey, train_time, y_train = train
|
|
997
|
+
test_pkey, test_time, y_test = test
|
|
929
998
|
|
|
930
|
-
|
|
931
|
-
|
|
932
|
-
|
|
933
|
-
|
|
934
|
-
|
|
935
|
-
|
|
936
|
-
|
|
937
|
-
|
|
938
|
-
|
|
939
|
-
|
|
940
|
-
|
|
941
|
-
|
|
942
|
-
|
|
943
|
-
|
|
944
|
-
|
|
945
|
-
|
|
946
|
-
|
|
947
|
-
|
|
948
|
-
|
|
949
|
-
|
|
950
|
-
num_rhs = y_test.explode().nunique()
|
|
951
|
-
msg = (f"Collected {len(y_test):,} test examples with "
|
|
952
|
-
f"{num_rhs:,} unique items")
|
|
953
|
-
else:
|
|
954
|
-
raise NotImplementedError
|
|
955
|
-
logger.log(msg)
|
|
999
|
+
if evaluate and logger is not None:
|
|
1000
|
+
if task_type == TaskType.BINARY_CLASSIFICATION:
|
|
1001
|
+
pos = 100 * int((y_test > 0).sum()) / len(y_test)
|
|
1002
|
+
msg = (f"Collected {len(y_test):,} test examples with "
|
|
1003
|
+
f"{pos:.2f}% positive cases")
|
|
1004
|
+
elif task_type == TaskType.MULTICLASS_CLASSIFICATION:
|
|
1005
|
+
msg = (f"Collected {len(y_test):,} test examples holding "
|
|
1006
|
+
f"{y_test.nunique()} classes")
|
|
1007
|
+
elif task_type == TaskType.REGRESSION:
|
|
1008
|
+
_min, _max = float(y_test.min()), float(y_test.max())
|
|
1009
|
+
msg = (f"Collected {len(y_test):,} test examples with targets "
|
|
1010
|
+
f"between {format_value(_min)} and "
|
|
1011
|
+
f"{format_value(_max)}")
|
|
1012
|
+
elif task_type == TaskType.TEMPORAL_LINK_PREDICTION:
|
|
1013
|
+
num_rhs = y_test.explode().nunique()
|
|
1014
|
+
msg = (f"Collected {len(y_test):,} test examples with "
|
|
1015
|
+
f"{num_rhs:,} unique items")
|
|
1016
|
+
else:
|
|
1017
|
+
raise NotImplementedError
|
|
1018
|
+
logger.log(msg)
|
|
956
1019
|
|
|
957
|
-
|
|
1020
|
+
if not evaluate:
|
|
958
1021
|
assert indices is not None
|
|
959
|
-
|
|
960
1022
|
if len(indices) > _MAX_PRED_SIZE[task_type]:
|
|
961
1023
|
raise ValueError(f"Cannot predict for more than "
|
|
962
1024
|
f"{_MAX_PRED_SIZE[task_type]:,} entities at "
|
|
@@ -964,26 +1026,12 @@ class KumoRFM:
|
|
|
964
1026
|
f"`KumoRFM.batch_mode` to process entities "
|
|
965
1027
|
f"in batches")
|
|
966
1028
|
|
|
967
|
-
|
|
968
|
-
table_name=query.entity_table,
|
|
969
|
-
pkey=pd.Series(indices),
|
|
970
|
-
)
|
|
971
|
-
|
|
1029
|
+
test_pkey = pd.Series(indices, dtype=train_pkey.dtype)
|
|
972
1030
|
if isinstance(anchor_time, pd.Timestamp):
|
|
973
|
-
test_time = pd.Series(anchor_time).repeat(
|
|
974
|
-
len(
|
|
1031
|
+
test_time = pd.Series([anchor_time]).repeat(
|
|
1032
|
+
len(indices)).reset_index(drop=True)
|
|
975
1033
|
else:
|
|
976
|
-
|
|
977
|
-
time = time[test_node] * 1000**3
|
|
978
|
-
test_time = pd.Series(time, dtype='datetime64[ns]')
|
|
979
|
-
|
|
980
|
-
train_node, train_time, y_train = query_driver.collect_train(
|
|
981
|
-
size=_MAX_CONTEXT_SIZE[run_mode],
|
|
982
|
-
anchor_time=context_anchor_time or 'entity',
|
|
983
|
-
exclude_node=test_node if (query.query_type == QueryType.STATIC
|
|
984
|
-
or anchor_time == 'entity') else None,
|
|
985
|
-
max_iterations=max_pq_iterations,
|
|
986
|
-
)
|
|
1034
|
+
train_time = test_time = 'entity'
|
|
987
1035
|
|
|
988
1036
|
if logger is not None:
|
|
989
1037
|
if task_type == TaskType.BINARY_CLASSIFICATION:
|
|
@@ -1011,7 +1059,7 @@ class KumoRFM:
|
|
|
1011
1059
|
final_aggr = query.get_final_target_aggregation()
|
|
1012
1060
|
assert final_aggr is not None
|
|
1013
1061
|
edge_fkey = final_aggr._get_target_column_name()
|
|
1014
|
-
for edge_type in self.
|
|
1062
|
+
for edge_type in self._sampler.edge_types:
|
|
1015
1063
|
if edge_fkey == f'{edge_type[0]}.{edge_type[1]}':
|
|
1016
1064
|
entity_table_names = (
|
|
1017
1065
|
query.entity_table,
|
|
@@ -1023,20 +1071,24 @@ class KumoRFM:
|
|
|
1023
1071
|
# Exclude the entity anchor time from the feature set to prevent
|
|
1024
1072
|
# running out-of-distribution between in-context and test examples:
|
|
1025
1073
|
exclude_cols_dict = query.get_exclude_cols_dict()
|
|
1026
|
-
if
|
|
1074
|
+
if entity_table_names[0] in self._sampler.time_column_dict:
|
|
1027
1075
|
if entity_table_names[0] not in exclude_cols_dict:
|
|
1028
1076
|
exclude_cols_dict[entity_table_names[0]] = []
|
|
1029
|
-
|
|
1030
|
-
time_column = time_column_dict[entity_table_names[0]]
|
|
1077
|
+
time_column = self._sampler.time_column_dict[entity_table_names[0]]
|
|
1031
1078
|
exclude_cols_dict[entity_table_names[0]].append(time_column)
|
|
1032
1079
|
|
|
1033
|
-
subgraph = self.
|
|
1080
|
+
subgraph = self._sampler.sample_subgraph(
|
|
1034
1081
|
entity_table_names=entity_table_names,
|
|
1035
|
-
|
|
1036
|
-
|
|
1037
|
-
|
|
1038
|
-
|
|
1039
|
-
|
|
1082
|
+
entity_pkey=pd.concat(
|
|
1083
|
+
[train_pkey, test_pkey],
|
|
1084
|
+
axis=0,
|
|
1085
|
+
ignore_index=True,
|
|
1086
|
+
),
|
|
1087
|
+
anchor_time=pd.concat(
|
|
1088
|
+
[train_time, test_time],
|
|
1089
|
+
axis=0,
|
|
1090
|
+
ignore_index=True,
|
|
1091
|
+
) if isinstance(train_time, pd.Series) else 'entity',
|
|
1040
1092
|
num_neighbors=num_neighbors,
|
|
1041
1093
|
exclude_cols_dict=exclude_cols_dict,
|
|
1042
1094
|
)
|
|
@@ -1048,18 +1100,14 @@ class KumoRFM:
|
|
|
1048
1100
|
f"'https://github.com/kumo-ai/kumo-rfm' if you "
|
|
1049
1101
|
f"must go beyond this for your use-case.")
|
|
1050
1102
|
|
|
1051
|
-
step_size: Optional[int] = None
|
|
1052
|
-
if query.query_type == QueryType.TEMPORAL:
|
|
1053
|
-
step_size = date_offset_to_seconds(end_offset)
|
|
1054
|
-
|
|
1055
1103
|
return Context(
|
|
1056
1104
|
task_type=task_type,
|
|
1057
1105
|
entity_table_names=entity_table_names,
|
|
1058
1106
|
subgraph=subgraph,
|
|
1059
1107
|
y_train=y_train,
|
|
1060
|
-
y_test=y_test,
|
|
1108
|
+
y_test=y_test if evaluate else None,
|
|
1061
1109
|
top_k=query.top_k,
|
|
1062
|
-
step_size=
|
|
1110
|
+
step_size=None,
|
|
1063
1111
|
)
|
|
1064
1112
|
|
|
1065
1113
|
@staticmethod
|
|
Binary file
|
|
@@ -370,9 +370,11 @@ class PredictiveQuery:
|
|
|
370
370
|
train_table_job_api = global_state.client.generate_train_table_job_api
|
|
371
371
|
job_id: GenerateTrainTableJobID = train_table_job_api.create(
|
|
372
372
|
GenerateTrainTableRequest(
|
|
373
|
-
dict(custom_tags),
|
|
374
|
-
|
|
375
|
-
|
|
373
|
+
dict(custom_tags),
|
|
374
|
+
pq_id,
|
|
375
|
+
plan,
|
|
376
|
+
None,
|
|
377
|
+
))
|
|
376
378
|
|
|
377
379
|
self._train_table = TrainingTableJob(job_id=job_id)
|
|
378
380
|
if non_blocking:
|
|
@@ -451,9 +453,11 @@ class PredictiveQuery:
|
|
|
451
453
|
bp_table_api = global_state.client.generate_prediction_table_job_api
|
|
452
454
|
job_id: GeneratePredictionTableJobID = bp_table_api.create(
|
|
453
455
|
GeneratePredictionTableRequest(
|
|
454
|
-
dict(custom_tags),
|
|
455
|
-
|
|
456
|
-
|
|
456
|
+
dict(custom_tags),
|
|
457
|
+
pq_id,
|
|
458
|
+
plan,
|
|
459
|
+
None,
|
|
460
|
+
))
|
|
457
461
|
|
|
458
462
|
self._prediction_table = PredictionTableJob(job_id=job_id)
|
|
459
463
|
if non_blocking:
|
{kumoai-2.13.0.dev202512040651.dist-info → kumoai-2.14.0.dev202512111731.dist-info}/METADATA
RENAMED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: kumoai
|
|
3
|
-
Version: 2.
|
|
3
|
+
Version: 2.14.0.dev202512111731
|
|
4
4
|
Summary: AI on the Modern Data Stack
|
|
5
5
|
Author-email: "Kumo.AI" <hello@kumo.ai>
|
|
6
6
|
License-Expression: MIT
|
|
@@ -23,7 +23,7 @@ Requires-Dist: requests>=2.28.2
|
|
|
23
23
|
Requires-Dist: urllib3
|
|
24
24
|
Requires-Dist: plotly
|
|
25
25
|
Requires-Dist: typing_extensions>=4.5.0
|
|
26
|
-
Requires-Dist: kumo-api==0.
|
|
26
|
+
Requires-Dist: kumo-api==0.49.0
|
|
27
27
|
Requires-Dist: tqdm>=4.66.0
|
|
28
28
|
Requires-Dist: aiohttp>=3.10.0
|
|
29
29
|
Requires-Dist: pydantic>=1.10.21
|