kumoai 2.13.0.dev202511271731__cp312-cp312-win_amd64.whl → 2.14.0.dev202512111731__cp312-cp312-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kumoai/__init__.py +12 -0
- kumoai/_version.py +1 -1
- kumoai/connector/utils.py +23 -2
- kumoai/experimental/rfm/__init__.py +20 -45
- kumoai/experimental/rfm/backend/__init__.py +0 -0
- kumoai/experimental/rfm/backend/local/__init__.py +42 -0
- kumoai/experimental/rfm/{local_graph_store.py → backend/local/graph_store.py} +37 -90
- kumoai/experimental/rfm/backend/local/sampler.py +313 -0
- kumoai/experimental/rfm/backend/local/table.py +109 -0
- kumoai/experimental/rfm/backend/snow/__init__.py +35 -0
- kumoai/experimental/rfm/backend/snow/table.py +117 -0
- kumoai/experimental/rfm/backend/sqlite/__init__.py +30 -0
- kumoai/experimental/rfm/backend/sqlite/table.py +101 -0
- kumoai/experimental/rfm/base/__init__.py +13 -0
- kumoai/experimental/rfm/base/column.py +66 -0
- kumoai/experimental/rfm/base/sampler.py +763 -0
- kumoai/experimental/rfm/base/source.py +18 -0
- kumoai/experimental/rfm/{local_table.py → base/table.py} +139 -139
- kumoai/experimental/rfm/{local_graph.py → graph.py} +334 -79
- kumoai/experimental/rfm/infer/__init__.py +6 -0
- kumoai/experimental/rfm/infer/dtype.py +79 -0
- kumoai/experimental/rfm/infer/pkey.py +126 -0
- kumoai/experimental/rfm/infer/time_col.py +62 -0
- kumoai/experimental/rfm/pquery/pandas_executor.py +1 -1
- kumoai/experimental/rfm/rfm.py +204 -166
- kumoai/experimental/rfm/sagemaker.py +11 -3
- kumoai/kumolib.cp312-win_amd64.pyd +0 -0
- kumoai/pquery/predictive_query.py +10 -6
- kumoai/testing/decorators.py +1 -1
- {kumoai-2.13.0.dev202511271731.dist-info → kumoai-2.14.0.dev202512111731.dist-info}/METADATA +9 -8
- {kumoai-2.13.0.dev202511271731.dist-info → kumoai-2.14.0.dev202512111731.dist-info}/RECORD +34 -22
- kumoai/experimental/rfm/local_graph_sampler.py +0 -182
- kumoai/experimental/rfm/local_pquery_driver.py +0 -689
- kumoai/experimental/rfm/utils.py +0 -344
- {kumoai-2.13.0.dev202511271731.dist-info → kumoai-2.14.0.dev202512111731.dist-info}/WHEEL +0 -0
- {kumoai-2.13.0.dev202511271731.dist-info → kumoai-2.14.0.dev202512111731.dist-info}/licenses/LICENSE +0 -0
- {kumoai-2.13.0.dev202511271731.dist-info → kumoai-2.14.0.dev202512111731.dist-info}/top_level.txt +0 -0
kumoai/experimental/rfm/rfm.py
CHANGED
|
@@ -21,6 +21,13 @@ import numpy as np
|
|
|
21
21
|
import pandas as pd
|
|
22
22
|
from kumoapi.model_plan import RunMode
|
|
23
23
|
from kumoapi.pquery import QueryType, ValidatedPredictiveQuery
|
|
24
|
+
from kumoapi.pquery.AST import (
|
|
25
|
+
Aggregation,
|
|
26
|
+
Column,
|
|
27
|
+
Condition,
|
|
28
|
+
Join,
|
|
29
|
+
LogicalOperation,
|
|
30
|
+
)
|
|
24
31
|
from kumoapi.rfm import Context
|
|
25
32
|
from kumoapi.rfm import Explanation as ExplanationConfig
|
|
26
33
|
from kumoapi.rfm import (
|
|
@@ -29,16 +36,11 @@ from kumoapi.rfm import (
|
|
|
29
36
|
RFMPredictRequest,
|
|
30
37
|
)
|
|
31
38
|
from kumoapi.task import TaskType
|
|
39
|
+
from kumoapi.typing import AggregationType, Stype
|
|
32
40
|
|
|
33
41
|
from kumoai.client.rfm import RFMAPI
|
|
34
42
|
from kumoai.exceptions import HTTPException
|
|
35
|
-
from kumoai.experimental.rfm import
|
|
36
|
-
from kumoai.experimental.rfm.local_graph_sampler import LocalGraphSampler
|
|
37
|
-
from kumoai.experimental.rfm.local_graph_store import LocalGraphStore
|
|
38
|
-
from kumoai.experimental.rfm.local_pquery_driver import (
|
|
39
|
-
LocalPQueryDriver,
|
|
40
|
-
date_offset_to_seconds,
|
|
41
|
-
)
|
|
43
|
+
from kumoai.experimental.rfm import Graph
|
|
42
44
|
from kumoai.mixin import CastMixin
|
|
43
45
|
from kumoai.utils import InteractiveProgressLogger, ProgressLogger
|
|
44
46
|
|
|
@@ -123,17 +125,17 @@ class KumoRFM:
|
|
|
123
125
|
:class:`KumoRFM` is a foundation model to generate predictions for any
|
|
124
126
|
relational dataset without training.
|
|
125
127
|
The model is pre-trained and the class provides an interface to query the
|
|
126
|
-
model from a :class:`
|
|
128
|
+
model from a :class:`Graph` object.
|
|
127
129
|
|
|
128
130
|
.. code-block:: python
|
|
129
131
|
|
|
130
|
-
from kumoai.experimental.rfm import
|
|
132
|
+
from kumoai.experimental.rfm import Graph, KumoRFM
|
|
131
133
|
|
|
132
134
|
df_users = pd.DataFrame(...)
|
|
133
135
|
df_items = pd.DataFrame(...)
|
|
134
136
|
df_orders = pd.DataFrame(...)
|
|
135
137
|
|
|
136
|
-
graph =
|
|
138
|
+
graph = Graph.from_data({
|
|
137
139
|
'users': df_users,
|
|
138
140
|
'items': df_items,
|
|
139
141
|
'orders': df_orders,
|
|
@@ -150,27 +152,18 @@ class KumoRFM:
|
|
|
150
152
|
|
|
151
153
|
Args:
|
|
152
154
|
graph: The graph.
|
|
153
|
-
preprocess: Whether to pre-process the data in advance during graph
|
|
154
|
-
materialization.
|
|
155
|
-
This is a runtime trade-off between graph materialization and model
|
|
156
|
-
processing speed.
|
|
157
|
-
It can be benefical to preprocess your data once and then run many
|
|
158
|
-
queries on top to achieve maximum model speed.
|
|
159
|
-
However, if activiated, graph materialization can take potentially
|
|
160
|
-
much longer, especially on graphs with many large text columns.
|
|
161
|
-
Best to tune this option manually.
|
|
162
155
|
verbose: Whether to print verbose output.
|
|
163
156
|
"""
|
|
164
157
|
def __init__(
|
|
165
158
|
self,
|
|
166
|
-
graph:
|
|
167
|
-
preprocess: bool = False,
|
|
159
|
+
graph: Graph,
|
|
168
160
|
verbose: Union[bool, ProgressLogger] = True,
|
|
169
161
|
) -> None:
|
|
170
162
|
graph = graph.validate()
|
|
171
163
|
self._graph_def = graph._to_api_graph_definition()
|
|
172
|
-
|
|
173
|
-
|
|
164
|
+
|
|
165
|
+
from kumoai.experimental.rfm.backend.local import LocalSampler
|
|
166
|
+
self._sampler = LocalSampler(graph, verbose)
|
|
174
167
|
|
|
175
168
|
self._client: Optional[RFMAPI] = None
|
|
176
169
|
|
|
@@ -234,7 +227,7 @@ class KumoRFM:
|
|
|
234
227
|
run_mode: Union[RunMode, str] = RunMode.FAST,
|
|
235
228
|
num_neighbors: Optional[List[int]] = None,
|
|
236
229
|
num_hops: int = 2,
|
|
237
|
-
max_pq_iterations: int =
|
|
230
|
+
max_pq_iterations: int = 10,
|
|
238
231
|
random_seed: Optional[int] = _RANDOM_SEED,
|
|
239
232
|
verbose: Union[bool, ProgressLogger] = True,
|
|
240
233
|
use_prediction_time: bool = False,
|
|
@@ -253,7 +246,7 @@ class KumoRFM:
|
|
|
253
246
|
run_mode: Union[RunMode, str] = RunMode.FAST,
|
|
254
247
|
num_neighbors: Optional[List[int]] = None,
|
|
255
248
|
num_hops: int = 2,
|
|
256
|
-
max_pq_iterations: int =
|
|
249
|
+
max_pq_iterations: int = 10,
|
|
257
250
|
random_seed: Optional[int] = _RANDOM_SEED,
|
|
258
251
|
verbose: Union[bool, ProgressLogger] = True,
|
|
259
252
|
use_prediction_time: bool = False,
|
|
@@ -271,7 +264,7 @@ class KumoRFM:
|
|
|
271
264
|
run_mode: Union[RunMode, str] = RunMode.FAST,
|
|
272
265
|
num_neighbors: Optional[List[int]] = None,
|
|
273
266
|
num_hops: int = 2,
|
|
274
|
-
max_pq_iterations: int =
|
|
267
|
+
max_pq_iterations: int = 10,
|
|
275
268
|
random_seed: Optional[int] = _RANDOM_SEED,
|
|
276
269
|
verbose: Union[bool, ProgressLogger] = True,
|
|
277
270
|
use_prediction_time: bool = False,
|
|
@@ -367,9 +360,9 @@ class KumoRFM:
|
|
|
367
360
|
|
|
368
361
|
batch_size: Optional[int] = None
|
|
369
362
|
if self._batch_size == 'max':
|
|
370
|
-
task_type =
|
|
371
|
-
query_def,
|
|
372
|
-
edge_types=self.
|
|
363
|
+
task_type = self._get_task_type(
|
|
364
|
+
query=query_def,
|
|
365
|
+
edge_types=self._sampler.edge_types,
|
|
373
366
|
)
|
|
374
367
|
batch_size = _MAX_PRED_SIZE[task_type]
|
|
375
368
|
else:
|
|
@@ -443,10 +436,10 @@ class KumoRFM:
|
|
|
443
436
|
|
|
444
437
|
# Cast 'ENTITY' to correct data type:
|
|
445
438
|
if 'ENTITY' in df:
|
|
446
|
-
|
|
447
|
-
|
|
448
|
-
|
|
449
|
-
|
|
439
|
+
table_dict = context.subgraph.table_dict
|
|
440
|
+
table = table_dict[query_def.entity_table]
|
|
441
|
+
ser = table.df[table.primary_key]
|
|
442
|
+
df['ENTITY'] = df['ENTITY'].astype(ser.dtype)
|
|
450
443
|
|
|
451
444
|
# Cast 'ANCHOR_TIMESTAMP' to correct data type:
|
|
452
445
|
if 'ANCHOR_TIMESTAMP' in df:
|
|
@@ -529,23 +522,18 @@ class KumoRFM:
|
|
|
529
522
|
raise ValueError("At least one entity is required")
|
|
530
523
|
|
|
531
524
|
if anchor_time is None:
|
|
532
|
-
anchor_time = self.
|
|
525
|
+
anchor_time = self._get_default_anchor_time(query_def)
|
|
533
526
|
|
|
534
527
|
if isinstance(anchor_time, pd.Timestamp):
|
|
535
528
|
self._validate_time(query_def, anchor_time, None, False)
|
|
536
529
|
else:
|
|
537
530
|
assert anchor_time == 'entity'
|
|
538
|
-
if
|
|
531
|
+
if query_def.entity_table not in self._sampler.time_column_dict:
|
|
539
532
|
raise ValueError(f"Anchor time 'entity' requires the entity "
|
|
540
533
|
f"table '{query_def.entity_table}' "
|
|
541
534
|
f"to have a time column.")
|
|
542
535
|
|
|
543
|
-
|
|
544
|
-
table_name=query_def.entity_table,
|
|
545
|
-
pkey=pd.Series(indices),
|
|
546
|
-
)
|
|
547
|
-
query_driver = LocalPQueryDriver(self._graph_store, query_def)
|
|
548
|
-
return query_driver.is_valid(node, anchor_time)
|
|
536
|
+
raise NotImplementedError
|
|
549
537
|
|
|
550
538
|
def evaluate(
|
|
551
539
|
self,
|
|
@@ -557,7 +545,7 @@ class KumoRFM:
|
|
|
557
545
|
run_mode: Union[RunMode, str] = RunMode.FAST,
|
|
558
546
|
num_neighbors: Optional[List[int]] = None,
|
|
559
547
|
num_hops: int = 2,
|
|
560
|
-
max_pq_iterations: int =
|
|
548
|
+
max_pq_iterations: int = 10,
|
|
561
549
|
random_seed: Optional[int] = _RANDOM_SEED,
|
|
562
550
|
verbose: Union[bool, ProgressLogger] = True,
|
|
563
551
|
use_prediction_time: bool = False,
|
|
@@ -668,7 +656,7 @@ class KumoRFM:
|
|
|
668
656
|
*,
|
|
669
657
|
anchor_time: Union[pd.Timestamp, Literal['entity'], None] = None,
|
|
670
658
|
random_seed: Optional[int] = _RANDOM_SEED,
|
|
671
|
-
max_iterations: int =
|
|
659
|
+
max_iterations: int = 10,
|
|
672
660
|
) -> pd.DataFrame:
|
|
673
661
|
"""Returns the labels of a predictive query for a specified anchor
|
|
674
662
|
time.
|
|
@@ -688,40 +676,37 @@ class KumoRFM:
|
|
|
688
676
|
query_def = self._parse_query(query)
|
|
689
677
|
|
|
690
678
|
if anchor_time is None:
|
|
691
|
-
anchor_time = self.
|
|
679
|
+
anchor_time = self._get_default_anchor_time(query_def)
|
|
692
680
|
if query_def.target_ast.date_offset_range is not None:
|
|
693
|
-
|
|
694
|
-
|
|
695
|
-
|
|
681
|
+
offset = query_def.target_ast.date_offset_range.end_date_offset
|
|
682
|
+
offset *= query_def.num_forecasts
|
|
683
|
+
anchor_time -= offset
|
|
696
684
|
|
|
697
685
|
assert anchor_time is not None
|
|
698
686
|
if isinstance(anchor_time, pd.Timestamp):
|
|
699
687
|
self._validate_time(query_def, anchor_time, None, evaluate=True)
|
|
700
688
|
else:
|
|
701
689
|
assert anchor_time == 'entity'
|
|
702
|
-
if
|
|
690
|
+
if query_def.entity_table not in self._sampler.time_column_dict:
|
|
703
691
|
raise ValueError(f"Anchor time 'entity' requires the entity "
|
|
704
692
|
f"table '{query_def.entity_table}' "
|
|
705
693
|
f"to have a time column")
|
|
706
694
|
|
|
707
|
-
|
|
708
|
-
|
|
709
|
-
|
|
710
|
-
|
|
711
|
-
|
|
712
|
-
|
|
713
|
-
|
|
714
|
-
|
|
715
|
-
|
|
695
|
+
train, test = self._sampler.sample_target(
|
|
696
|
+
query=query,
|
|
697
|
+
num_train_examples=0,
|
|
698
|
+
train_anchor_time=anchor_time,
|
|
699
|
+
num_train_trials=0,
|
|
700
|
+
num_test_examples=size,
|
|
701
|
+
test_anchor_time=anchor_time,
|
|
702
|
+
num_test_trials=max_iterations * size,
|
|
703
|
+
random_seed=random_seed,
|
|
716
704
|
)
|
|
717
705
|
|
|
718
|
-
entity = self._graph_store.pkey_map_dict[
|
|
719
|
-
query_def.entity_table].index[node]
|
|
720
|
-
|
|
721
706
|
return pd.DataFrame({
|
|
722
|
-
'ENTITY':
|
|
723
|
-
'ANCHOR_TIMESTAMP':
|
|
724
|
-
'TARGET':
|
|
707
|
+
'ENTITY': test.entity_pkey,
|
|
708
|
+
'ANCHOR_TIMESTAMP': test.anchor_time,
|
|
709
|
+
'TARGET': test.target,
|
|
725
710
|
})
|
|
726
711
|
|
|
727
712
|
# Helpers #################################################################
|
|
@@ -744,8 +729,6 @@ class KumoRFM:
|
|
|
744
729
|
|
|
745
730
|
resp = self._api_client.parse_query(request)
|
|
746
731
|
|
|
747
|
-
# TODO Expose validation warnings.
|
|
748
|
-
|
|
749
732
|
if len(resp.validation_response.warnings) > 0:
|
|
750
733
|
msg = '\n'.join([
|
|
751
734
|
f'{i+1}. {warning.title}: {warning.message}' for i, warning
|
|
@@ -763,6 +746,60 @@ class KumoRFM:
|
|
|
763
746
|
raise ValueError(f"Failed to parse query '{query}'. "
|
|
764
747
|
f"{msg}") from None
|
|
765
748
|
|
|
749
|
+
@staticmethod
|
|
750
|
+
def _get_task_type(
|
|
751
|
+
query: ValidatedPredictiveQuery,
|
|
752
|
+
edge_types: List[Tuple[str, str, str]],
|
|
753
|
+
) -> TaskType:
|
|
754
|
+
if isinstance(query.target_ast, (Condition, LogicalOperation)):
|
|
755
|
+
return TaskType.BINARY_CLASSIFICATION
|
|
756
|
+
|
|
757
|
+
target = query.target_ast
|
|
758
|
+
if isinstance(target, Join):
|
|
759
|
+
target = target.rhs_target
|
|
760
|
+
if isinstance(target, Aggregation):
|
|
761
|
+
if target.aggr == AggregationType.LIST_DISTINCT:
|
|
762
|
+
table_name, col_name = target._get_target_column_name().split(
|
|
763
|
+
'.')
|
|
764
|
+
target_edge_types = [
|
|
765
|
+
edge_type for edge_type in edge_types
|
|
766
|
+
if edge_type[0] == table_name and edge_type[1] == col_name
|
|
767
|
+
]
|
|
768
|
+
if len(target_edge_types) != 1:
|
|
769
|
+
raise NotImplementedError(
|
|
770
|
+
f"Multilabel-classification queries based on "
|
|
771
|
+
f"'LIST_DISTINCT' are not supported yet. If you "
|
|
772
|
+
f"planned to write a link prediction query instead, "
|
|
773
|
+
f"make sure to register '{col_name}' as a "
|
|
774
|
+
f"foreign key.")
|
|
775
|
+
return TaskType.TEMPORAL_LINK_PREDICTION
|
|
776
|
+
|
|
777
|
+
return TaskType.REGRESSION
|
|
778
|
+
|
|
779
|
+
assert isinstance(target, Column)
|
|
780
|
+
|
|
781
|
+
if target.stype in {Stype.ID, Stype.categorical}:
|
|
782
|
+
return TaskType.MULTICLASS_CLASSIFICATION
|
|
783
|
+
|
|
784
|
+
if target.stype in {Stype.numerical}:
|
|
785
|
+
return TaskType.REGRESSION
|
|
786
|
+
|
|
787
|
+
raise NotImplementedError("Task type not yet supported")
|
|
788
|
+
|
|
789
|
+
def _get_default_anchor_time(
|
|
790
|
+
self,
|
|
791
|
+
query: ValidatedPredictiveQuery,
|
|
792
|
+
) -> pd.Timestamp:
|
|
793
|
+
if query.query_type == QueryType.TEMPORAL:
|
|
794
|
+
aggr_table_names = [
|
|
795
|
+
aggr._get_target_column_name().split('.')[0]
|
|
796
|
+
for aggr in query.get_all_target_aggregations()
|
|
797
|
+
]
|
|
798
|
+
return self._sampler.get_max_time(aggr_table_names)
|
|
799
|
+
|
|
800
|
+
assert query.query_type == QueryType.STATIC
|
|
801
|
+
return self._sampler.get_max_time()
|
|
802
|
+
|
|
766
803
|
def _validate_time(
|
|
767
804
|
self,
|
|
768
805
|
query: ValidatedPredictiveQuery,
|
|
@@ -771,28 +808,30 @@ class KumoRFM:
|
|
|
771
808
|
evaluate: bool,
|
|
772
809
|
) -> None:
|
|
773
810
|
|
|
774
|
-
if self.
|
|
811
|
+
if len(self._sampler.time_column_dict) == 0:
|
|
775
812
|
return # Graph without timestamps
|
|
776
813
|
|
|
777
|
-
|
|
814
|
+
min_time = self._sampler.get_min_time()
|
|
815
|
+
max_time = self._sampler.get_max_time()
|
|
816
|
+
|
|
817
|
+
if anchor_time < min_time:
|
|
778
818
|
raise ValueError(f"Anchor timestamp '{anchor_time}' is before "
|
|
779
|
-
f"the earliest timestamp "
|
|
780
|
-
f"
|
|
819
|
+
f"the earliest timestamp '{min_time}' in the "
|
|
820
|
+
f"data.")
|
|
781
821
|
|
|
782
|
-
if
|
|
783
|
-
and context_anchor_time < self._graph_store.min_time):
|
|
822
|
+
if context_anchor_time is not None and context_anchor_time < min_time:
|
|
784
823
|
raise ValueError(f"Context anchor timestamp is too early or "
|
|
785
824
|
f"aggregation time range is too large. To make "
|
|
786
825
|
f"this prediction, we would need data back to "
|
|
787
826
|
f"'{context_anchor_time}', however, your data "
|
|
788
|
-
f"only contains data back to "
|
|
789
|
-
f"'{self._graph_store.min_time}'.")
|
|
827
|
+
f"only contains data back to '{min_time}'.")
|
|
790
828
|
|
|
791
829
|
if query.target_ast.date_offset_range is not None:
|
|
792
830
|
end_offset = query.target_ast.date_offset_range.end_date_offset
|
|
793
831
|
else:
|
|
794
832
|
end_offset = pd.DateOffset(0)
|
|
795
|
-
|
|
833
|
+
end_offset = end_offset * query.num_forecasts
|
|
834
|
+
|
|
796
835
|
if (context_anchor_time is not None
|
|
797
836
|
and context_anchor_time > anchor_time):
|
|
798
837
|
warnings.warn(f"Context anchor timestamp "
|
|
@@ -802,7 +841,7 @@ class KumoRFM:
|
|
|
802
841
|
f"intended.")
|
|
803
842
|
elif (query.query_type == QueryType.TEMPORAL
|
|
804
843
|
and context_anchor_time is not None
|
|
805
|
-
and context_anchor_time +
|
|
844
|
+
and context_anchor_time + end_offset > anchor_time):
|
|
806
845
|
warnings.warn(f"Aggregation for context examples at timestamp "
|
|
807
846
|
f"'{context_anchor_time}' will leak information "
|
|
808
847
|
f"from the prediction anchor timestamp "
|
|
@@ -810,26 +849,23 @@ class KumoRFM:
|
|
|
810
849
|
f"intended.")
|
|
811
850
|
|
|
812
851
|
elif (context_anchor_time is not None
|
|
813
|
-
and context_anchor_time -
|
|
814
|
-
|
|
815
|
-
_time = context_anchor_time - forecast_end_offset
|
|
852
|
+
and context_anchor_time - end_offset < min_time):
|
|
853
|
+
_time = context_anchor_time - end_offset
|
|
816
854
|
warnings.warn(f"Context anchor timestamp is too early or "
|
|
817
855
|
f"aggregation time range is too large. To form "
|
|
818
856
|
f"proper input data, we would need data back to "
|
|
819
857
|
f"'{_time}', however, your data only contains "
|
|
820
|
-
f"data back to '{
|
|
858
|
+
f"data back to '{min_time}'.")
|
|
821
859
|
|
|
822
|
-
if
|
|
823
|
-
> self._graph_store.max_time + pd.DateOffset(days=1)):
|
|
860
|
+
if not evaluate and anchor_time > max_time + pd.DateOffset(days=1):
|
|
824
861
|
warnings.warn(f"Anchor timestamp '{anchor_time}' is after the "
|
|
825
|
-
f"latest timestamp '{
|
|
826
|
-
f"
|
|
862
|
+
f"latest timestamp '{max_time}' in the data. Please "
|
|
863
|
+
f"make sure this is intended.")
|
|
827
864
|
|
|
828
|
-
|
|
829
|
-
if evaluate and anchor_time > max_eval_time:
|
|
865
|
+
if evaluate and anchor_time > max_time - end_offset:
|
|
830
866
|
raise ValueError(
|
|
831
867
|
f"Anchor timestamp for evaluation is after the latest "
|
|
832
|
-
f"supported timestamp '{
|
|
868
|
+
f"supported timestamp '{max_time - end_offset}'.")
|
|
833
869
|
|
|
834
870
|
def _get_context(
|
|
835
871
|
self,
|
|
@@ -860,10 +896,9 @@ class KumoRFM:
|
|
|
860
896
|
f"'https://github.com/kumo-ai/kumo-rfm' if you "
|
|
861
897
|
f"must go beyond this for your use-case.")
|
|
862
898
|
|
|
863
|
-
|
|
864
|
-
|
|
865
|
-
|
|
866
|
-
edge_types=self._graph_store.edge_types,
|
|
899
|
+
task_type = self._get_task_type(
|
|
900
|
+
query=query,
|
|
901
|
+
edge_types=self._sampler.edge_types,
|
|
867
902
|
)
|
|
868
903
|
|
|
869
904
|
if logger is not None:
|
|
@@ -895,14 +930,17 @@ class KumoRFM:
|
|
|
895
930
|
num_neighbors = [64, 64, 8, 8, 4, 4][:num_hops]
|
|
896
931
|
|
|
897
932
|
if query.target_ast.date_offset_range is None:
|
|
898
|
-
|
|
933
|
+
step_offset = pd.DateOffset(0)
|
|
899
934
|
else:
|
|
900
|
-
|
|
901
|
-
|
|
935
|
+
step_offset = query.target_ast.date_offset_range.end_date_offset
|
|
936
|
+
end_offset = step_offset * query.num_forecasts
|
|
937
|
+
|
|
902
938
|
if anchor_time is None:
|
|
903
|
-
anchor_time = self.
|
|
939
|
+
anchor_time = self._get_default_anchor_time(query)
|
|
940
|
+
|
|
904
941
|
if evaluate:
|
|
905
|
-
anchor_time = anchor_time -
|
|
942
|
+
anchor_time = anchor_time - end_offset
|
|
943
|
+
|
|
906
944
|
if logger is not None:
|
|
907
945
|
assert isinstance(anchor_time, pd.Timestamp)
|
|
908
946
|
if anchor_time == pd.Timestamp.min:
|
|
@@ -916,57 +954,71 @@ class KumoRFM:
|
|
|
916
954
|
|
|
917
955
|
assert anchor_time is not None
|
|
918
956
|
if isinstance(anchor_time, pd.Timestamp):
|
|
957
|
+
if context_anchor_time == 'entity':
|
|
958
|
+
raise ValueError("Anchor time 'entity' needs to be shared "
|
|
959
|
+
"for context and prediction examples")
|
|
919
960
|
if context_anchor_time is None:
|
|
920
|
-
context_anchor_time = anchor_time -
|
|
961
|
+
context_anchor_time = anchor_time - end_offset
|
|
921
962
|
self._validate_time(query, anchor_time, context_anchor_time,
|
|
922
963
|
evaluate)
|
|
923
964
|
else:
|
|
924
965
|
assert anchor_time == 'entity'
|
|
925
|
-
if query.
|
|
966
|
+
if query.query_type != QueryType.STATIC:
|
|
967
|
+
raise ValueError("Anchor time 'entity' is only valid for "
|
|
968
|
+
"static predictive queries")
|
|
969
|
+
if query.entity_table not in self._sampler.time_column_dict:
|
|
926
970
|
raise ValueError(f"Anchor time 'entity' requires the entity "
|
|
927
971
|
f"table '{query.entity_table}' to "
|
|
928
972
|
f"have a time column")
|
|
929
|
-
if context_anchor_time
|
|
930
|
-
|
|
931
|
-
|
|
932
|
-
context_anchor_time =
|
|
973
|
+
if isinstance(context_anchor_time, pd.Timestamp):
|
|
974
|
+
raise ValueError("Anchor time 'entity' needs to be shared "
|
|
975
|
+
"for context and prediction examples")
|
|
976
|
+
context_anchor_time = 'entity'
|
|
933
977
|
|
|
934
|
-
|
|
978
|
+
num_train_examples = _MAX_CONTEXT_SIZE[run_mode]
|
|
935
979
|
if evaluate:
|
|
936
|
-
|
|
980
|
+
num_test_examples = _MAX_TEST_SIZE[run_mode]
|
|
937
981
|
if task_type.is_link_pred:
|
|
938
|
-
|
|
982
|
+
num_test_examples = num_test_examples // 5
|
|
983
|
+
else:
|
|
984
|
+
num_test_examples = 0
|
|
985
|
+
|
|
986
|
+
train, test = self._sampler.sample_target(
|
|
987
|
+
query=query,
|
|
988
|
+
num_train_examples=num_train_examples,
|
|
989
|
+
train_anchor_time=context_anchor_time,
|
|
990
|
+
num_train_trials=max_pq_iterations * num_train_examples,
|
|
991
|
+
num_test_examples=num_test_examples,
|
|
992
|
+
test_anchor_time=anchor_time,
|
|
993
|
+
num_test_trials=max_pq_iterations * num_test_examples,
|
|
994
|
+
random_seed=random_seed,
|
|
995
|
+
)
|
|
996
|
+
train_pkey, train_time, y_train = train
|
|
997
|
+
test_pkey, test_time, y_test = test
|
|
939
998
|
|
|
940
|
-
|
|
941
|
-
|
|
942
|
-
|
|
943
|
-
|
|
944
|
-
|
|
945
|
-
|
|
946
|
-
|
|
947
|
-
|
|
948
|
-
|
|
949
|
-
|
|
950
|
-
|
|
951
|
-
|
|
952
|
-
|
|
953
|
-
|
|
954
|
-
|
|
955
|
-
|
|
956
|
-
|
|
957
|
-
|
|
958
|
-
|
|
959
|
-
|
|
960
|
-
num_rhs = y_test.explode().nunique()
|
|
961
|
-
msg = (f"Collected {len(y_test):,} test examples with "
|
|
962
|
-
f"{num_rhs:,} unique items")
|
|
963
|
-
else:
|
|
964
|
-
raise NotImplementedError
|
|
965
|
-
logger.log(msg)
|
|
999
|
+
if evaluate and logger is not None:
|
|
1000
|
+
if task_type == TaskType.BINARY_CLASSIFICATION:
|
|
1001
|
+
pos = 100 * int((y_test > 0).sum()) / len(y_test)
|
|
1002
|
+
msg = (f"Collected {len(y_test):,} test examples with "
|
|
1003
|
+
f"{pos:.2f}% positive cases")
|
|
1004
|
+
elif task_type == TaskType.MULTICLASS_CLASSIFICATION:
|
|
1005
|
+
msg = (f"Collected {len(y_test):,} test examples holding "
|
|
1006
|
+
f"{y_test.nunique()} classes")
|
|
1007
|
+
elif task_type == TaskType.REGRESSION:
|
|
1008
|
+
_min, _max = float(y_test.min()), float(y_test.max())
|
|
1009
|
+
msg = (f"Collected {len(y_test):,} test examples with targets "
|
|
1010
|
+
f"between {format_value(_min)} and "
|
|
1011
|
+
f"{format_value(_max)}")
|
|
1012
|
+
elif task_type == TaskType.TEMPORAL_LINK_PREDICTION:
|
|
1013
|
+
num_rhs = y_test.explode().nunique()
|
|
1014
|
+
msg = (f"Collected {len(y_test):,} test examples with "
|
|
1015
|
+
f"{num_rhs:,} unique items")
|
|
1016
|
+
else:
|
|
1017
|
+
raise NotImplementedError
|
|
1018
|
+
logger.log(msg)
|
|
966
1019
|
|
|
967
|
-
|
|
1020
|
+
if not evaluate:
|
|
968
1021
|
assert indices is not None
|
|
969
|
-
|
|
970
1022
|
if len(indices) > _MAX_PRED_SIZE[task_type]:
|
|
971
1023
|
raise ValueError(f"Cannot predict for more than "
|
|
972
1024
|
f"{_MAX_PRED_SIZE[task_type]:,} entities at "
|
|
@@ -974,26 +1026,12 @@ class KumoRFM:
|
|
|
974
1026
|
f"`KumoRFM.batch_mode` to process entities "
|
|
975
1027
|
f"in batches")
|
|
976
1028
|
|
|
977
|
-
|
|
978
|
-
table_name=query.entity_table,
|
|
979
|
-
pkey=pd.Series(indices),
|
|
980
|
-
)
|
|
981
|
-
|
|
1029
|
+
test_pkey = pd.Series(indices, dtype=train_pkey.dtype)
|
|
982
1030
|
if isinstance(anchor_time, pd.Timestamp):
|
|
983
|
-
test_time = pd.Series(anchor_time).repeat(
|
|
984
|
-
len(
|
|
1031
|
+
test_time = pd.Series([anchor_time]).repeat(
|
|
1032
|
+
len(indices)).reset_index(drop=True)
|
|
985
1033
|
else:
|
|
986
|
-
|
|
987
|
-
time = time[test_node] * 1000**3
|
|
988
|
-
test_time = pd.Series(time, dtype='datetime64[ns]')
|
|
989
|
-
|
|
990
|
-
train_node, train_time, y_train = query_driver.collect_train(
|
|
991
|
-
size=_MAX_CONTEXT_SIZE[run_mode],
|
|
992
|
-
anchor_time=context_anchor_time or 'entity',
|
|
993
|
-
exclude_node=test_node if (query.query_type == QueryType.STATIC
|
|
994
|
-
or anchor_time == 'entity') else None,
|
|
995
|
-
max_iterations=max_pq_iterations,
|
|
996
|
-
)
|
|
1034
|
+
train_time = test_time = 'entity'
|
|
997
1035
|
|
|
998
1036
|
if logger is not None:
|
|
999
1037
|
if task_type == TaskType.BINARY_CLASSIFICATION:
|
|
@@ -1021,7 +1059,7 @@ class KumoRFM:
|
|
|
1021
1059
|
final_aggr = query.get_final_target_aggregation()
|
|
1022
1060
|
assert final_aggr is not None
|
|
1023
1061
|
edge_fkey = final_aggr._get_target_column_name()
|
|
1024
|
-
for edge_type in self.
|
|
1062
|
+
for edge_type in self._sampler.edge_types:
|
|
1025
1063
|
if edge_fkey == f'{edge_type[0]}.{edge_type[1]}':
|
|
1026
1064
|
entity_table_names = (
|
|
1027
1065
|
query.entity_table,
|
|
@@ -1033,20 +1071,24 @@ class KumoRFM:
|
|
|
1033
1071
|
# Exclude the entity anchor time from the feature set to prevent
|
|
1034
1072
|
# running out-of-distribution between in-context and test examples:
|
|
1035
1073
|
exclude_cols_dict = query.get_exclude_cols_dict()
|
|
1036
|
-
if
|
|
1074
|
+
if entity_table_names[0] in self._sampler.time_column_dict:
|
|
1037
1075
|
if entity_table_names[0] not in exclude_cols_dict:
|
|
1038
1076
|
exclude_cols_dict[entity_table_names[0]] = []
|
|
1039
|
-
|
|
1040
|
-
time_column = time_column_dict[entity_table_names[0]]
|
|
1077
|
+
time_column = self._sampler.time_column_dict[entity_table_names[0]]
|
|
1041
1078
|
exclude_cols_dict[entity_table_names[0]].append(time_column)
|
|
1042
1079
|
|
|
1043
|
-
subgraph = self.
|
|
1080
|
+
subgraph = self._sampler.sample_subgraph(
|
|
1044
1081
|
entity_table_names=entity_table_names,
|
|
1045
|
-
|
|
1046
|
-
|
|
1047
|
-
|
|
1048
|
-
|
|
1049
|
-
|
|
1082
|
+
entity_pkey=pd.concat(
|
|
1083
|
+
[train_pkey, test_pkey],
|
|
1084
|
+
axis=0,
|
|
1085
|
+
ignore_index=True,
|
|
1086
|
+
),
|
|
1087
|
+
anchor_time=pd.concat(
|
|
1088
|
+
[train_time, test_time],
|
|
1089
|
+
axis=0,
|
|
1090
|
+
ignore_index=True,
|
|
1091
|
+
) if isinstance(train_time, pd.Series) else 'entity',
|
|
1050
1092
|
num_neighbors=num_neighbors,
|
|
1051
1093
|
exclude_cols_dict=exclude_cols_dict,
|
|
1052
1094
|
)
|
|
@@ -1058,18 +1100,14 @@ class KumoRFM:
|
|
|
1058
1100
|
f"'https://github.com/kumo-ai/kumo-rfm' if you "
|
|
1059
1101
|
f"must go beyond this for your use-case.")
|
|
1060
1102
|
|
|
1061
|
-
step_size: Optional[int] = None
|
|
1062
|
-
if query.query_type == QueryType.TEMPORAL:
|
|
1063
|
-
step_size = date_offset_to_seconds(end_offset)
|
|
1064
|
-
|
|
1065
1103
|
return Context(
|
|
1066
1104
|
task_type=task_type,
|
|
1067
1105
|
entity_table_names=entity_table_names,
|
|
1068
1106
|
subgraph=subgraph,
|
|
1069
1107
|
y_train=y_train,
|
|
1070
|
-
y_test=y_test,
|
|
1108
|
+
y_test=y_test if evaluate else None,
|
|
1071
1109
|
top_k=query.top_k,
|
|
1072
|
-
step_size=
|
|
1110
|
+
step_size=None,
|
|
1073
1111
|
)
|
|
1074
1112
|
|
|
1075
1113
|
@staticmethod
|
|
@@ -2,15 +2,22 @@ import base64
|
|
|
2
2
|
import json
|
|
3
3
|
from typing import Any, Dict, List, Tuple
|
|
4
4
|
|
|
5
|
-
import boto3
|
|
6
5
|
import requests
|
|
7
|
-
from mypy_boto3_sagemaker_runtime.client import SageMakerRuntimeClient
|
|
8
|
-
from mypy_boto3_sagemaker_runtime.type_defs import InvokeEndpointOutputTypeDef
|
|
9
6
|
|
|
10
7
|
from kumoai.client import KumoClient
|
|
11
8
|
from kumoai.client.endpoints import Endpoint, HTTPMethod
|
|
12
9
|
from kumoai.exceptions import HTTPException
|
|
13
10
|
|
|
11
|
+
try:
|
|
12
|
+
# isort: off
|
|
13
|
+
from mypy_boto3_sagemaker_runtime.client import SageMakerRuntimeClient
|
|
14
|
+
from mypy_boto3_sagemaker_runtime.type_defs import (
|
|
15
|
+
InvokeEndpointOutputTypeDef, )
|
|
16
|
+
# isort: on
|
|
17
|
+
except ImportError:
|
|
18
|
+
SageMakerRuntimeClient = Any
|
|
19
|
+
InvokeEndpointOutputTypeDef = Any
|
|
20
|
+
|
|
14
21
|
|
|
15
22
|
class SageMakerResponseAdapter(requests.Response):
|
|
16
23
|
def __init__(self, sm_response: InvokeEndpointOutputTypeDef):
|
|
@@ -34,6 +41,7 @@ class SageMakerResponseAdapter(requests.Response):
|
|
|
34
41
|
|
|
35
42
|
class KumoClient_SageMakerAdapter(KumoClient):
|
|
36
43
|
def __init__(self, region: str, endpoint_name: str):
|
|
44
|
+
import boto3
|
|
37
45
|
self._client: SageMakerRuntimeClient = boto3.client(
|
|
38
46
|
service_name="sagemaker-runtime", region_name=region)
|
|
39
47
|
self._endpoint_name = endpoint_name
|
|
Binary file
|