kumoai 2.13.0.dev202512081731__cp313-cp313-macosx_11_0_arm64.whl → 2.14.0.dev202512151351__cp313-cp313-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kumoai/_version.py +1 -1
- kumoai/client/pquery.py +6 -2
- kumoai/experimental/rfm/backend/local/graph_store.py +19 -62
- kumoai/experimental/rfm/backend/local/sampler.py +213 -14
- kumoai/experimental/rfm/backend/local/table.py +12 -2
- kumoai/experimental/rfm/backend/snow/__init__.py +2 -0
- kumoai/experimental/rfm/backend/snow/sampler.py +264 -0
- kumoai/experimental/rfm/backend/snow/table.py +35 -17
- kumoai/experimental/rfm/backend/sqlite/__init__.py +2 -0
- kumoai/experimental/rfm/backend/sqlite/sampler.py +354 -0
- kumoai/experimental/rfm/backend/sqlite/table.py +36 -11
- kumoai/experimental/rfm/base/__init__.py +17 -6
- kumoai/experimental/rfm/base/sampler.py +438 -38
- kumoai/experimental/rfm/base/source.py +1 -0
- kumoai/experimental/rfm/base/sql_sampler.py +56 -0
- kumoai/experimental/rfm/base/table.py +12 -1
- kumoai/experimental/rfm/graph.py +26 -9
- kumoai/experimental/rfm/pquery/pandas_executor.py +1 -1
- kumoai/experimental/rfm/rfm.py +214 -151
- kumoai/pquery/predictive_query.py +10 -6
- kumoai/testing/snow.py +50 -0
- kumoai/utils/__init__.py +2 -0
- kumoai/utils/sql.py +3 -0
- {kumoai-2.13.0.dev202512081731.dist-info → kumoai-2.14.0.dev202512151351.dist-info}/METADATA +2 -2
- {kumoai-2.13.0.dev202512081731.dist-info → kumoai-2.14.0.dev202512151351.dist-info}/RECORD +28 -25
- kumoai/experimental/rfm/local_graph_sampler.py +0 -223
- kumoai/experimental/rfm/local_pquery_driver.py +0 -689
- {kumoai-2.13.0.dev202512081731.dist-info → kumoai-2.14.0.dev202512151351.dist-info}/WHEEL +0 -0
- {kumoai-2.13.0.dev202512081731.dist-info → kumoai-2.14.0.dev202512151351.dist-info}/licenses/LICENSE +0 -0
- {kumoai-2.13.0.dev202512081731.dist-info → kumoai-2.14.0.dev202512151351.dist-info}/top_level.txt +0 -0
kumoai/experimental/rfm/rfm.py
CHANGED
@@ -21,6 +21,13 @@ import numpy as np
 import pandas as pd
 from kumoapi.model_plan import RunMode
 from kumoapi.pquery import QueryType, ValidatedPredictiveQuery
+from kumoapi.pquery.AST import (
+    Aggregation,
+    Column,
+    Condition,
+    Join,
+    LogicalOperation,
+)
 from kumoapi.rfm import Context
 from kumoapi.rfm import Explanation as ExplanationConfig
 from kumoapi.rfm import (
@@ -29,16 +36,12 @@ from kumoapi.rfm import (
     RFMPredictRequest,
 )
 from kumoapi.task import TaskType
+from kumoapi.typing import AggregationType, Stype
 
 from kumoai.client.rfm import RFMAPI
 from kumoai.exceptions import HTTPException
 from kumoai.experimental.rfm import Graph
-from kumoai.experimental.rfm.
-from kumoai.experimental.rfm.local_graph_sampler import LocalGraphSampler
-from kumoai.experimental.rfm.local_pquery_driver import (
-    LocalPQueryDriver,
-    date_offset_to_seconds,
-)
+from kumoai.experimental.rfm.base import DataBackend, Sampler
 from kumoai.mixin import CastMixin
 from kumoai.utils import InteractiveProgressLogger, ProgressLogger
 
@@ -151,16 +154,31 @@ class KumoRFM:
     Args:
         graph: The graph.
         verbose: Whether to print verbose output.
+        optimize: If set to ``True``, will optimize the underlying data backend
+            for optimal querying. For example, for transactional database
+            backends, will create any missing indices. Requires write-access to
+            the data backend.
     """
     def __init__(
         self,
         graph: Graph,
         verbose: Union[bool, ProgressLogger] = True,
+        optimize: bool = False,
     ) -> None:
         graph = graph.validate()
         self._graph_def = graph._to_api_graph_definition()
-
-
+
+        if graph.backend == DataBackend.LOCAL:
+            from kumoai.experimental.rfm.backend.local import LocalSampler
+            self._sampler: Sampler = LocalSampler(graph, verbose)
+        elif graph.backend == DataBackend.SQLITE:
+            from kumoai.experimental.rfm.backend.sqlite import SQLiteSampler
+            self._sampler = SQLiteSampler(graph, verbose, optimize)
+        elif graph.backend == DataBackend.SNOWFLAKE:
+            from kumoai.experimental.rfm.backend.snow import SnowSampler
+            self._sampler = SnowSampler(graph, verbose)
+        else:
+            raise NotImplementedError
 
         self._client: Optional[RFMAPI] = None
 
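Note: the constructor now dispatches on `graph.backend` to a backend-specific `Sampler` instead of always building a local graph store. A minimal usage sketch, assuming `KumoRFM` is re-exported from `kumoai.experimental.rfm` alongside `Graph` (the graph-construction API itself is not shown in this diff):

```python
from kumoai.experimental.rfm import Graph, KumoRFM


def make_rfm(graph: Graph) -> KumoRFM:
    # The data backend is picked up from `graph.backend` internally.
    # `optimize=True` only has an effect for backends that support index
    # creation (e.g. SQLite) and requires write access to the database.
    return KumoRFM(graph, optimize=True)
```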
@@ -224,7 +242,7 @@ class KumoRFM:
         run_mode: Union[RunMode, str] = RunMode.FAST,
         num_neighbors: Optional[List[int]] = None,
         num_hops: int = 2,
-        max_pq_iterations: int =
+        max_pq_iterations: int = 10,
         random_seed: Optional[int] = _RANDOM_SEED,
         verbose: Union[bool, ProgressLogger] = True,
         use_prediction_time: bool = False,
@@ -243,7 +261,7 @@ class KumoRFM:
         run_mode: Union[RunMode, str] = RunMode.FAST,
         num_neighbors: Optional[List[int]] = None,
         num_hops: int = 2,
-        max_pq_iterations: int =
+        max_pq_iterations: int = 10,
         random_seed: Optional[int] = _RANDOM_SEED,
         verbose: Union[bool, ProgressLogger] = True,
         use_prediction_time: bool = False,
@@ -261,7 +279,7 @@ class KumoRFM:
         run_mode: Union[RunMode, str] = RunMode.FAST,
         num_neighbors: Optional[List[int]] = None,
         num_hops: int = 2,
-        max_pq_iterations: int =
+        max_pq_iterations: int = 10,
         random_seed: Optional[int] = _RANDOM_SEED,
         verbose: Union[bool, ProgressLogger] = True,
         use_prediction_time: bool = False,
@@ -357,9 +375,9 @@ class KumoRFM:
 
         batch_size: Optional[int] = None
         if self._batch_size == 'max':
-            task_type =
-                query_def,
-                edge_types=self.
+            task_type = self._get_task_type(
+                query=query_def,
+                edge_types=self._sampler.edge_types,
             )
             batch_size = _MAX_PRED_SIZE[task_type]
         else:
@@ -433,10 +451,10 @@ class KumoRFM:
 
         # Cast 'ENTITY' to correct data type:
         if 'ENTITY' in df:
-
-
-
-
+            table_dict = context.subgraph.table_dict
+            table = table_dict[query_def.entity_table]
+            ser = table.df[table.primary_key]
+            df['ENTITY'] = df['ENTITY'].astype(ser.dtype)
 
         # Cast 'ANCHOR_TIMESTAMP' to correct data type:
         if 'ANCHOR_TIMESTAMP' in df:
@@ -519,23 +537,18 @@ class KumoRFM:
             raise ValueError("At least one entity is required")
 
         if anchor_time is None:
-            anchor_time = self.
+            anchor_time = self._get_default_anchor_time(query_def)
 
         if isinstance(anchor_time, pd.Timestamp):
             self._validate_time(query_def, anchor_time, None, False)
         else:
             assert anchor_time == 'entity'
-            if
+            if query_def.entity_table not in self._sampler.time_column_dict:
                 raise ValueError(f"Anchor time 'entity' requires the entity "
                                  f"table '{query_def.entity_table}' "
                                  f"to have a time column.")
 
-
-            table_name=query_def.entity_table,
-            pkey=pd.Series(indices),
-        )
-        query_driver = LocalPQueryDriver(self._graph_store, query_def)
-        return query_driver.is_valid(node, anchor_time)
+        raise NotImplementedError
 
     def evaluate(
         self,
@@ -547,7 +560,7 @@ class KumoRFM:
         run_mode: Union[RunMode, str] = RunMode.FAST,
         num_neighbors: Optional[List[int]] = None,
         num_hops: int = 2,
-        max_pq_iterations: int =
+        max_pq_iterations: int = 10,
         random_seed: Optional[int] = _RANDOM_SEED,
         verbose: Union[bool, ProgressLogger] = True,
         use_prediction_time: bool = False,
@@ -658,7 +671,7 @@ class KumoRFM:
         *,
         anchor_time: Union[pd.Timestamp, Literal['entity'], None] = None,
         random_seed: Optional[int] = _RANDOM_SEED,
-        max_iterations: int =
+        max_iterations: int = 10,
     ) -> pd.DataFrame:
         """Returns the labels of a predictive query for a specified anchor
         time.
@@ -678,40 +691,37 @@ class KumoRFM:
         query_def = self._parse_query(query)
 
         if anchor_time is None:
-            anchor_time = self.
+            anchor_time = self._get_default_anchor_time(query_def)
         if query_def.target_ast.date_offset_range is not None:
-
-
-
+            offset = query_def.target_ast.date_offset_range.end_date_offset
+            offset *= query_def.num_forecasts
+            anchor_time -= offset
 
         assert anchor_time is not None
         if isinstance(anchor_time, pd.Timestamp):
             self._validate_time(query_def, anchor_time, None, evaluate=True)
         else:
             assert anchor_time == 'entity'
-            if
+            if query_def.entity_table not in self._sampler.time_column_dict:
                 raise ValueError(f"Anchor time 'entity' requires the entity "
                                  f"table '{query_def.entity_table}' "
                                  f"to have a time column")
 
-
-
-
-
-
-
-
-
+        train, test = self._sampler.sample_target(
+            query=query,
+            num_train_examples=0,
+            train_anchor_time=anchor_time,
+            num_train_trials=0,
+            num_test_examples=size,
+            test_anchor_time=anchor_time,
+            num_test_trials=max_iterations * size,
+            random_seed=random_seed,
         )
 
-        entity = self._graph_store.pkey_map_dict[
-            query_def.entity_table].index[node]
-
         return pd.DataFrame({
-            'ENTITY':
-            'ANCHOR_TIMESTAMP':
-            'TARGET':
+            'ENTITY': test.entity_pkey,
+            'ANCHOR_TIMESTAMP': test.anchor_time,
+            'TARGET': test.target,
         })
 
     # Helpers #################################################################
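As the hunk above shows, `get_labels` now delegates label collection to `Sampler.sample_target`, which returns a train split and a test split whose elements expose `entity_pkey`, `anchor_time`, and `target` (and which `_get_context` further below unpacks as 3-tuples). A hypothetical stand-in for that return shape, inferred purely from the call sites in this diff; the real definition lives in `kumoai/experimental/rfm/base/sampler.py`, which is not shown here:

```python
from typing import NamedTuple

import pandas as pd


class TargetSplit(NamedTuple):
    """Hypothetical mirror of one split returned by sample_target."""
    entity_pkey: pd.Series  # primary key of each sampled entity
    anchor_time: pd.Series  # anchor timestamp of each example
    target: pd.Series       # collected label of each example


# get_labels() requests test examples only (num_train_examples=0) and caps
# the sampling budget at max_iterations * size trials.
```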
@@ -734,8 +744,6 @@ class KumoRFM:
 
         resp = self._api_client.parse_query(request)
 
-        # TODO Expose validation warnings.
-
         if len(resp.validation_response.warnings) > 0:
             msg = '\n'.join([
                 f'{i+1}. {warning.title}: {warning.message}' for i, warning
@@ -753,6 +761,60 @@ class KumoRFM:
         raise ValueError(f"Failed to parse query '{query}'. "
                          f"{msg}") from None
 
+    @staticmethod
+    def _get_task_type(
+        query: ValidatedPredictiveQuery,
+        edge_types: List[Tuple[str, str, str]],
+    ) -> TaskType:
+        if isinstance(query.target_ast, (Condition, LogicalOperation)):
+            return TaskType.BINARY_CLASSIFICATION
+
+        target = query.target_ast
+        if isinstance(target, Join):
+            target = target.rhs_target
+        if isinstance(target, Aggregation):
+            if target.aggr == AggregationType.LIST_DISTINCT:
+                table_name, col_name = target._get_target_column_name().split(
+                    '.')
+                target_edge_types = [
+                    edge_type for edge_type in edge_types
+                    if edge_type[0] == table_name and edge_type[1] == col_name
+                ]
+                if len(target_edge_types) != 1:
+                    raise NotImplementedError(
+                        f"Multilabel-classification queries based on "
+                        f"'LIST_DISTINCT' are not supported yet. If you "
+                        f"planned to write a link prediction query instead, "
+                        f"make sure to register '{col_name}' as a "
+                        f"foreign key.")
+                return TaskType.TEMPORAL_LINK_PREDICTION
+
+            return TaskType.REGRESSION
+
+        assert isinstance(target, Column)
+
+        if target.stype in {Stype.ID, Stype.categorical}:
+            return TaskType.MULTICLASS_CLASSIFICATION
+
+        if target.stype in {Stype.numerical}:
+            return TaskType.REGRESSION
+
+        raise NotImplementedError("Task type not yet supported")
+
+    def _get_default_anchor_time(
+        self,
+        query: ValidatedPredictiveQuery,
+    ) -> pd.Timestamp:
+        if query.query_type == QueryType.TEMPORAL:
+            aggr_table_names = [
+                aggr._get_target_column_name().split('.')[0]
+                for aggr in query.get_all_target_aggregations()
+            ]
+            return self._sampler.get_max_time(aggr_table_names)
+
+        assert query.query_type == QueryType.STATIC
+        return self._sampler.get_max_time()
+
     def _validate_time(
         self,
         query: ValidatedPredictiveQuery,
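The new `_get_task_type` helper maps the query's target AST onto a `TaskType`. A simplified, self-contained mirror of that decision table (stub classes stand in for the `kumoapi` AST nodes and enums; the `Join` unwrapping and the edge-type check for `LIST_DISTINCT` are omitted):

```python
from dataclasses import dataclass
from enum import Enum, auto


class TaskType(Enum):
    BINARY_CLASSIFICATION = auto()
    MULTICLASS_CLASSIFICATION = auto()
    REGRESSION = auto()
    TEMPORAL_LINK_PREDICTION = auto()


# Stub AST nodes standing in for kumoapi.pquery.AST:
@dataclass
class Condition:
    pass


@dataclass
class Aggregation:
    aggr: str  # e.g. 'LIST_DISTINCT', 'COUNT', 'SUM'


@dataclass
class Column:
    stype: str  # e.g. 'ID', 'categorical', 'numerical'


def task_type_of(target) -> TaskType:
    # Boolean targets (conditions) map to binary classification.
    if isinstance(target, Condition):
        return TaskType.BINARY_CLASSIFICATION
    # LIST_DISTINCT over a foreign key maps to link prediction; any other
    # aggregation maps to regression.
    if isinstance(target, Aggregation):
        if target.aggr == 'LIST_DISTINCT':
            return TaskType.TEMPORAL_LINK_PREDICTION
        return TaskType.REGRESSION
    # Raw columns: ID/categorical -> multi-class, numerical -> regression.
    assert isinstance(target, Column)
    if target.stype in {'ID', 'categorical'}:
        return TaskType.MULTICLASS_CLASSIFICATION
    if target.stype == 'numerical':
        return TaskType.REGRESSION
    raise NotImplementedError


assert task_type_of(Condition()) is TaskType.BINARY_CLASSIFICATION
assert task_type_of(Aggregation('COUNT')) is TaskType.REGRESSION
assert task_type_of(Column('numerical')) is TaskType.REGRESSION
```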
@@ -761,28 +823,30 @@ class KumoRFM:
         evaluate: bool,
     ) -> None:
 
-        if self.
+        if len(self._sampler.time_column_dict) == 0:
             return  # Graph without timestamps
 
-
+        min_time = self._sampler.get_min_time()
+        max_time = self._sampler.get_max_time()
+
+        if anchor_time < min_time:
             raise ValueError(f"Anchor timestamp '{anchor_time}' is before "
-                             f"the earliest timestamp "
-                             f"
+                             f"the earliest timestamp '{min_time}' in the "
+                             f"data.")
 
-        if
-            and context_anchor_time < self._graph_store.min_time):
+        if context_anchor_time is not None and context_anchor_time < min_time:
             raise ValueError(f"Context anchor timestamp is too early or "
                              f"aggregation time range is too large. To make "
                              f"this prediction, we would need data back to "
                              f"'{context_anchor_time}', however, your data "
-                             f"only contains data back to "
-                             f"'{self._graph_store.min_time}'.")
+                             f"only contains data back to '{min_time}'.")
 
         if query.target_ast.date_offset_range is not None:
             end_offset = query.target_ast.date_offset_range.end_date_offset
         else:
             end_offset = pd.DateOffset(0)
-
+        end_offset = end_offset * query.num_forecasts
+
         if (context_anchor_time is not None
                 and context_anchor_time > anchor_time):
             warnings.warn(f"Context anchor timestamp "
@@ -792,7 +856,7 @@ class KumoRFM:
                           f"intended.")
         elif (query.query_type == QueryType.TEMPORAL
               and context_anchor_time is not None
-              and context_anchor_time +
+              and context_anchor_time + end_offset > anchor_time):
             warnings.warn(f"Aggregation for context examples at timestamp "
                           f"'{context_anchor_time}' will leak information "
                           f"from the prediction anchor timestamp "
@@ -800,26 +864,23 @@ class KumoRFM:
                           f"intended.")
 
         elif (context_anchor_time is not None
-              and context_anchor_time -
-
-            _time = context_anchor_time - forecast_end_offset
+              and context_anchor_time - end_offset < min_time):
+            _time = context_anchor_time - end_offset
             warnings.warn(f"Context anchor timestamp is too early or "
                           f"aggregation time range is too large. To form "
                           f"proper input data, we would need data back to "
                           f"'{_time}', however, your data only contains "
-                          f"data back to '{
+                          f"data back to '{min_time}'.")
 
-        if
-            > self._graph_store.max_time + pd.DateOffset(days=1)):
+        if not evaluate and anchor_time > max_time + pd.DateOffset(days=1):
             warnings.warn(f"Anchor timestamp '{anchor_time}' is after the "
-                          f"latest timestamp '{
-                          f"
+                          f"latest timestamp '{max_time}' in the data. Please "
+                          f"make sure this is intended.")
 
-
-        if evaluate and anchor_time > max_eval_time:
+        if evaluate and anchor_time > max_time - end_offset:
             raise ValueError(
                 f"Anchor timestamp for evaluation is after the latest "
-                f"supported timestamp '{
+                f"supported timestamp '{max_time - end_offset}'.")
 
     def _get_context(
         self,
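`_validate_time` now scales the target window by `query.num_forecasts` (`end_offset = end_offset * query.num_forecasts`) and compares timestamps against the sampler's `min_time`/`max_time`. The arithmetic relies on `pd.DateOffset` supporting integer multiplication and timestamp subtraction; a quick standalone check of the evaluation bound used above:

```python
import pandas as pd

end_offset = pd.DateOffset(days=7) * 4      # four 7-day forecast steps
max_time = pd.Timestamp('2024-12-31')
anchor_time = pd.Timestamp('2024-12-10')

# Evaluation needs `end_offset` of data after the anchor time, so the
# latest valid evaluation anchor is `max_time - end_offset`:
print(max_time - end_offset)                # 2024-12-03 00:00:00
print(anchor_time > max_time - end_offset)  # True -> would raise ValueError
```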
@@ -850,10 +911,9 @@ class KumoRFM:
                          f"'https://github.com/kumo-ai/kumo-rfm' if you "
                          f"must go beyond this for your use-case.")
 
-
-
-
-            edge_types=self._graph_store.edge_types,
+        task_type = self._get_task_type(
+            query=query,
+            edge_types=self._sampler.edge_types,
         )
 
         if logger is not None:
@@ -885,14 +945,17 @@ class KumoRFM:
             num_neighbors = [64, 64, 8, 8, 4, 4][:num_hops]
 
         if query.target_ast.date_offset_range is None:
-
+            step_offset = pd.DateOffset(0)
         else:
-
-
+            step_offset = query.target_ast.date_offset_range.end_date_offset
+        end_offset = step_offset * query.num_forecasts
+
         if anchor_time is None:
-            anchor_time = self.
+            anchor_time = self._get_default_anchor_time(query)
+
         if evaluate:
-            anchor_time = anchor_time -
+            anchor_time = anchor_time - end_offset
+
         if logger is not None:
             assert isinstance(anchor_time, pd.Timestamp)
             if anchor_time == pd.Timestamp.min:
@@ -906,57 +969,71 @@ class KumoRFM:
 
         assert anchor_time is not None
         if isinstance(anchor_time, pd.Timestamp):
+            if context_anchor_time == 'entity':
+                raise ValueError("Anchor time 'entity' needs to be shared "
+                                 "for context and prediction examples")
             if context_anchor_time is None:
-                context_anchor_time = anchor_time -
+                context_anchor_time = anchor_time - end_offset
             self._validate_time(query, anchor_time, context_anchor_time,
                                 evaluate)
         else:
             assert anchor_time == 'entity'
-            if query.
+            if query.query_type != QueryType.STATIC:
+                raise ValueError("Anchor time 'entity' is only valid for "
+                                 "static predictive queries")
+            if query.entity_table not in self._sampler.time_column_dict:
                 raise ValueError(f"Anchor time 'entity' requires the entity "
                                  f"table '{query.entity_table}' to "
                                  f"have a time column")
-            if context_anchor_time
-
-
-                context_anchor_time =
+            if isinstance(context_anchor_time, pd.Timestamp):
+                raise ValueError("Anchor time 'entity' needs to be shared "
+                                 "for context and prediction examples")
+            context_anchor_time = 'entity'
 
-
+        num_train_examples = _MAX_CONTEXT_SIZE[run_mode]
         if evaluate:
-
+            num_test_examples = _MAX_TEST_SIZE[run_mode]
             if task_type.is_link_pred:
-
+                num_test_examples = num_test_examples // 5
+        else:
+            num_test_examples = 0
+
+        train, test = self._sampler.sample_target(
+            query=query,
+            num_train_examples=num_train_examples,
+            train_anchor_time=context_anchor_time,
+            num_train_trials=max_pq_iterations * num_train_examples,
+            num_test_examples=num_test_examples,
+            test_anchor_time=anchor_time,
+            num_test_trials=max_pq_iterations * num_test_examples,
+            random_seed=random_seed,
+        )
+        train_pkey, train_time, y_train = train
+        test_pkey, test_time, y_test = test
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-            num_rhs = y_test.explode().nunique()
-            msg = (f"Collected {len(y_test):,} test examples with "
-                   f"{num_rhs:,} unique items")
-        else:
-            raise NotImplementedError
-        logger.log(msg)
+        if evaluate and logger is not None:
+            if task_type == TaskType.BINARY_CLASSIFICATION:
+                pos = 100 * int((y_test > 0).sum()) / len(y_test)
+                msg = (f"Collected {len(y_test):,} test examples with "
+                       f"{pos:.2f}% positive cases")
+            elif task_type == TaskType.MULTICLASS_CLASSIFICATION:
+                msg = (f"Collected {len(y_test):,} test examples holding "
+                       f"{y_test.nunique()} classes")
+            elif task_type == TaskType.REGRESSION:
+                _min, _max = float(y_test.min()), float(y_test.max())
+                msg = (f"Collected {len(y_test):,} test examples with targets "
+                       f"between {format_value(_min)} and "
+                       f"{format_value(_max)}")
+            elif task_type == TaskType.TEMPORAL_LINK_PREDICTION:
+                num_rhs = y_test.explode().nunique()
+                msg = (f"Collected {len(y_test):,} test examples with "
+                       f"{num_rhs:,} unique items")
+            else:
+                raise NotImplementedError
+            logger.log(msg)
 
-
+        if not evaluate:
             assert indices is not None
-
             if len(indices) > _MAX_PRED_SIZE[task_type]:
                 raise ValueError(f"Cannot predict for more than "
                                  f"{_MAX_PRED_SIZE[task_type]:,} entities at "
@@ -964,26 +1041,12 @@ class KumoRFM:
                                  f"`KumoRFM.batch_mode` to process entities "
                                  f"in batches")
 
-
-                table_name=query.entity_table,
-                pkey=pd.Series(indices),
-            )
-
+            test_pkey = pd.Series(indices, dtype=train_pkey.dtype)
             if isinstance(anchor_time, pd.Timestamp):
-                test_time = pd.Series(anchor_time).repeat(
-                    len(
+                test_time = pd.Series([anchor_time]).repeat(
+                    len(indices)).reset_index(drop=True)
             else:
-
-                time = time[test_node] * 1000**3
-                test_time = pd.Series(time, dtype='datetime64[ns]')
-
-            train_node, train_time, y_train = query_driver.collect_train(
-                size=_MAX_CONTEXT_SIZE[run_mode],
-                anchor_time=context_anchor_time or 'entity',
-                exclude_node=test_node if (query.query_type == QueryType.STATIC
-                                           or anchor_time == 'entity') else None,
-                max_iterations=max_pq_iterations,
-            )
+                train_time = test_time = 'entity'
 
         if logger is not None:
             if task_type == TaskType.BINARY_CLASSIFICATION:
@@ -1011,7 +1074,7 @@ class KumoRFM:
         final_aggr = query.get_final_target_aggregation()
         assert final_aggr is not None
         edge_fkey = final_aggr._get_target_column_name()
-        for edge_type in self.
+        for edge_type in self._sampler.edge_types:
             if edge_fkey == f'{edge_type[0]}.{edge_type[1]}':
                 entity_table_names = (
                     query.entity_table,
@@ -1023,20 +1086,24 @@ class KumoRFM:
         # Exclude the entity anchor time from the feature set to prevent
         # running out-of-distribution between in-context and test examples:
         exclude_cols_dict = query.get_exclude_cols_dict()
-        if
+        if entity_table_names[0] in self._sampler.time_column_dict:
             if entity_table_names[0] not in exclude_cols_dict:
                 exclude_cols_dict[entity_table_names[0]] = []
-
-            time_column = time_column_dict[entity_table_names[0]]
+            time_column = self._sampler.time_column_dict[entity_table_names[0]]
             exclude_cols_dict[entity_table_names[0]].append(time_column)
 
-        subgraph = self.
+        subgraph = self._sampler.sample_subgraph(
             entity_table_names=entity_table_names,
-
-
-
-
-
+            entity_pkey=pd.concat(
+                [train_pkey, test_pkey],
+                axis=0,
+                ignore_index=True,
+            ),
+            anchor_time=pd.concat(
+                [train_time, test_time],
+                axis=0,
+                ignore_index=True,
+            ) if isinstance(train_time, pd.Series) else 'entity',
             num_neighbors=num_neighbors,
             exclude_cols_dict=exclude_cols_dict,
         )
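For subgraph sampling, context and prediction entities are concatenated into a single series; `ignore_index=True` re-indexes the result 0..n-1 so train and test examples occupy contiguous positions in the sampled batch. A standalone illustration of the `pd.concat` call above:

```python
import pandas as pd

train_pkey = pd.Series([101, 102, 103])
test_pkey = pd.Series([901, 902])

entity_pkey = pd.concat([train_pkey, test_pkey], axis=0, ignore_index=True)
print(entity_pkey.tolist())        # [101, 102, 103, 901, 902]
print(entity_pkey.index.tolist())  # [0, 1, 2, 3, 4]
```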
@@ -1048,18 +1115,14 @@ class KumoRFM:
                          f"'https://github.com/kumo-ai/kumo-rfm' if you "
                          f"must go beyond this for your use-case.")
 
-        step_size: Optional[int] = None
-        if query.query_type == QueryType.TEMPORAL:
-            step_size = date_offset_to_seconds(end_offset)
-
         return Context(
             task_type=task_type,
             entity_table_names=entity_table_names,
             subgraph=subgraph,
             y_train=y_train,
-            y_test=y_test,
+            y_test=y_test if evaluate else None,
             top_k=query.top_k,
-            step_size=
+            step_size=None,
         )
 
     @staticmethod
kumoai/pquery/predictive_query.py
CHANGED
@@ -370,9 +370,11 @@ class PredictiveQuery:
         train_table_job_api = global_state.client.generate_train_table_job_api
         job_id: GenerateTrainTableJobID = train_table_job_api.create(
             GenerateTrainTableRequest(
-                dict(custom_tags),
-
-
+                dict(custom_tags),
+                pq_id,
+                plan,
+                None,
+            ))
 
         self._train_table = TrainingTableJob(job_id=job_id)
         if non_blocking:
@@ -451,9 +453,11 @@ class PredictiveQuery:
         bp_table_api = global_state.client.generate_prediction_table_job_api
         job_id: GeneratePredictionTableJobID = bp_table_api.create(
             GeneratePredictionTableRequest(
-                dict(custom_tags),
-
-
+                dict(custom_tags),
+                pq_id,
+                plan,
+                None,
+            ))
 
         self._prediction_table = PredictionTableJob(job_id=job_id)
         if non_blocking: