kumoai 2.13.0.dev202511211730__py3-none-any.whl → 2.14.0.dev202512141732__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kumoai/__init__.py +12 -0
- kumoai/_version.py +1 -1
- kumoai/client/pquery.py +6 -2
- kumoai/connector/utils.py +23 -2
- kumoai/experimental/rfm/__init__.py +20 -45
- kumoai/experimental/rfm/backend/__init__.py +0 -0
- kumoai/experimental/rfm/backend/local/__init__.py +42 -0
- kumoai/experimental/rfm/{local_graph_store.py → backend/local/graph_store.py} +37 -90
- kumoai/experimental/rfm/backend/local/sampler.py +313 -0
- kumoai/experimental/rfm/backend/local/table.py +119 -0
- kumoai/experimental/rfm/backend/snow/__init__.py +37 -0
- kumoai/experimental/rfm/backend/snow/sampler.py +119 -0
- kumoai/experimental/rfm/backend/snow/table.py +135 -0
- kumoai/experimental/rfm/backend/sqlite/__init__.py +32 -0
- kumoai/experimental/rfm/backend/sqlite/sampler.py +112 -0
- kumoai/experimental/rfm/backend/sqlite/table.py +115 -0
- kumoai/experimental/rfm/base/__init__.py +23 -0
- kumoai/experimental/rfm/base/column.py +66 -0
- kumoai/experimental/rfm/base/sampler.py +773 -0
- kumoai/experimental/rfm/base/source.py +19 -0
- kumoai/experimental/rfm/{local_table.py → base/table.py} +152 -141
- kumoai/experimental/rfm/{local_graph.py → graph.py} +352 -80
- kumoai/experimental/rfm/infer/__init__.py +6 -0
- kumoai/experimental/rfm/infer/dtype.py +79 -0
- kumoai/experimental/rfm/infer/pkey.py +126 -0
- kumoai/experimental/rfm/infer/time_col.py +62 -0
- kumoai/experimental/rfm/pquery/pandas_executor.py +1 -1
- kumoai/experimental/rfm/rfm.py +224 -167
- kumoai/experimental/rfm/sagemaker.py +11 -3
- kumoai/pquery/predictive_query.py +10 -6
- kumoai/testing/decorators.py +1 -1
- kumoai/testing/snow.py +50 -0
- kumoai/utils/__init__.py +2 -0
- kumoai/utils/sql.py +3 -0
- {kumoai-2.13.0.dev202511211730.dist-info → kumoai-2.14.0.dev202512141732.dist-info}/METADATA +9 -8
- {kumoai-2.13.0.dev202511211730.dist-info → kumoai-2.14.0.dev202512141732.dist-info}/RECORD +39 -23
- kumoai/experimental/rfm/local_graph_sampler.py +0 -182
- kumoai/experimental/rfm/local_pquery_driver.py +0 -689
- kumoai/experimental/rfm/utils.py +0 -344
- {kumoai-2.13.0.dev202511211730.dist-info → kumoai-2.14.0.dev202512141732.dist-info}/WHEEL +0 -0
- {kumoai-2.13.0.dev202511211730.dist-info → kumoai-2.14.0.dev202512141732.dist-info}/licenses/LICENSE +0 -0
- {kumoai-2.13.0.dev202511211730.dist-info → kumoai-2.14.0.dev202512141732.dist-info}/top_level.txt +0 -0
kumoai/experimental/rfm/rfm.py
CHANGED
|
@@ -21,6 +21,13 @@ import numpy as np
|
|
|
21
21
|
import pandas as pd
|
|
22
22
|
from kumoapi.model_plan import RunMode
|
|
23
23
|
from kumoapi.pquery import QueryType, ValidatedPredictiveQuery
|
|
24
|
+
from kumoapi.pquery.AST import (
|
|
25
|
+
Aggregation,
|
|
26
|
+
Column,
|
|
27
|
+
Condition,
|
|
28
|
+
Join,
|
|
29
|
+
LogicalOperation,
|
|
30
|
+
)
|
|
24
31
|
from kumoapi.rfm import Context
|
|
25
32
|
from kumoapi.rfm import Explanation as ExplanationConfig
|
|
26
33
|
from kumoapi.rfm import (
|
|
@@ -29,16 +36,12 @@ from kumoapi.rfm import (
|
|
|
29
36
|
RFMPredictRequest,
|
|
30
37
|
)
|
|
31
38
|
from kumoapi.task import TaskType
|
|
39
|
+
from kumoapi.typing import AggregationType, Stype
|
|
32
40
|
|
|
33
41
|
from kumoai.client.rfm import RFMAPI
|
|
34
42
|
from kumoai.exceptions import HTTPException
|
|
35
|
-
from kumoai.experimental.rfm import
|
|
36
|
-
from kumoai.experimental.rfm.
|
|
37
|
-
from kumoai.experimental.rfm.local_graph_store import LocalGraphStore
|
|
38
|
-
from kumoai.experimental.rfm.local_pquery_driver import (
|
|
39
|
-
LocalPQueryDriver,
|
|
40
|
-
date_offset_to_seconds,
|
|
41
|
-
)
|
|
43
|
+
from kumoai.experimental.rfm import Graph
|
|
44
|
+
from kumoai.experimental.rfm.base import DataBackend, Sampler
|
|
42
45
|
from kumoai.mixin import CastMixin
|
|
43
46
|
from kumoai.utils import InteractiveProgressLogger, ProgressLogger
|
|
44
47
|
|
|
@@ -123,17 +126,17 @@ class KumoRFM:
|
|
|
123
126
|
:class:`KumoRFM` is a foundation model to generate predictions for any
|
|
124
127
|
relational dataset without training.
|
|
125
128
|
The model is pre-trained and the class provides an interface to query the
|
|
126
|
-
model from a :class:`
|
|
129
|
+
model from a :class:`Graph` object.
|
|
127
130
|
|
|
128
131
|
.. code-block:: python
|
|
129
132
|
|
|
130
|
-
from kumoai.experimental.rfm import
|
|
133
|
+
from kumoai.experimental.rfm import Graph, KumoRFM
|
|
131
134
|
|
|
132
135
|
df_users = pd.DataFrame(...)
|
|
133
136
|
df_items = pd.DataFrame(...)
|
|
134
137
|
df_orders = pd.DataFrame(...)
|
|
135
138
|
|
|
136
|
-
graph =
|
|
139
|
+
graph = Graph.from_data({
|
|
137
140
|
'users': df_users,
|
|
138
141
|
'items': df_items,
|
|
139
142
|
'orders': df_orders,
|
|
@@ -150,32 +153,41 @@ class KumoRFM:
|
|
|
150
153
|
|
|
151
154
|
Args:
|
|
152
155
|
graph: The graph.
|
|
153
|
-
preprocess: Whether to pre-process the data in advance during graph
|
|
154
|
-
materialization.
|
|
155
|
-
This is a runtime trade-off between graph materialization and model
|
|
156
|
-
processing speed.
|
|
157
|
-
It can be benefical to preprocess your data once and then run many
|
|
158
|
-
queries on top to achieve maximum model speed.
|
|
159
|
-
However, if activiated, graph materialization can take potentially
|
|
160
|
-
much longer, especially on graphs with many large text columns.
|
|
161
|
-
Best to tune this option manually.
|
|
162
156
|
verbose: Whether to print verbose output.
|
|
163
157
|
"""
|
|
164
158
|
def __init__(
|
|
165
159
|
self,
|
|
166
|
-
graph:
|
|
167
|
-
preprocess: bool = False,
|
|
160
|
+
graph: Graph,
|
|
168
161
|
verbose: Union[bool, ProgressLogger] = True,
|
|
169
162
|
) -> None:
|
|
170
163
|
graph = graph.validate()
|
|
171
164
|
self._graph_def = graph._to_api_graph_definition()
|
|
172
|
-
|
|
173
|
-
|
|
165
|
+
|
|
166
|
+
if graph.backend == DataBackend.LOCAL:
|
|
167
|
+
from kumoai.experimental.rfm.backend.local import LocalSampler
|
|
168
|
+
self._sampler: Sampler = LocalSampler(graph, verbose)
|
|
169
|
+
elif graph.backend == DataBackend.SQLITE:
|
|
170
|
+
from kumoai.experimental.rfm.backend.sqlite import SQLiteSampler
|
|
171
|
+
self._sampler = SQLiteSampler(graph, verbose)
|
|
172
|
+
elif graph.backend == DataBackend.SNOWFLAKE:
|
|
173
|
+
from kumoai.experimental.rfm.backend.snow import SnowSampler
|
|
174
|
+
self._sampler = SnowSampler(graph, verbose)
|
|
175
|
+
else:
|
|
176
|
+
raise NotImplementedError
|
|
177
|
+
|
|
178
|
+
self._client: Optional[RFMAPI] = None
|
|
174
179
|
|
|
175
180
|
self._batch_size: Optional[int | Literal['max']] = None
|
|
176
181
|
self.num_retries: int = 0
|
|
182
|
+
|
|
183
|
+
@property
|
|
184
|
+
def _api_client(self) -> RFMAPI:
|
|
185
|
+
if self._client is not None:
|
|
186
|
+
return self._client
|
|
187
|
+
|
|
177
188
|
from kumoai.experimental.rfm import global_state
|
|
178
|
-
self.
|
|
189
|
+
self._client = RFMAPI(global_state.client)
|
|
190
|
+
return self._client
|
|
179
191
|
|
|
180
192
|
def __repr__(self) -> str:
|
|
181
193
|
return f'{self.__class__.__name__}()'
|
|
@@ -225,7 +237,7 @@ class KumoRFM:
|
|
|
225
237
|
run_mode: Union[RunMode, str] = RunMode.FAST,
|
|
226
238
|
num_neighbors: Optional[List[int]] = None,
|
|
227
239
|
num_hops: int = 2,
|
|
228
|
-
max_pq_iterations: int =
|
|
240
|
+
max_pq_iterations: int = 10,
|
|
229
241
|
random_seed: Optional[int] = _RANDOM_SEED,
|
|
230
242
|
verbose: Union[bool, ProgressLogger] = True,
|
|
231
243
|
use_prediction_time: bool = False,
|
|
@@ -244,7 +256,7 @@ class KumoRFM:
|
|
|
244
256
|
run_mode: Union[RunMode, str] = RunMode.FAST,
|
|
245
257
|
num_neighbors: Optional[List[int]] = None,
|
|
246
258
|
num_hops: int = 2,
|
|
247
|
-
max_pq_iterations: int =
|
|
259
|
+
max_pq_iterations: int = 10,
|
|
248
260
|
random_seed: Optional[int] = _RANDOM_SEED,
|
|
249
261
|
verbose: Union[bool, ProgressLogger] = True,
|
|
250
262
|
use_prediction_time: bool = False,
|
|
@@ -262,7 +274,7 @@ class KumoRFM:
|
|
|
262
274
|
run_mode: Union[RunMode, str] = RunMode.FAST,
|
|
263
275
|
num_neighbors: Optional[List[int]] = None,
|
|
264
276
|
num_hops: int = 2,
|
|
265
|
-
max_pq_iterations: int =
|
|
277
|
+
max_pq_iterations: int = 10,
|
|
266
278
|
random_seed: Optional[int] = _RANDOM_SEED,
|
|
267
279
|
verbose: Union[bool, ProgressLogger] = True,
|
|
268
280
|
use_prediction_time: bool = False,
|
|
@@ -358,9 +370,9 @@ class KumoRFM:
|
|
|
358
370
|
|
|
359
371
|
batch_size: Optional[int] = None
|
|
360
372
|
if self._batch_size == 'max':
|
|
361
|
-
task_type =
|
|
362
|
-
query_def,
|
|
363
|
-
edge_types=self.
|
|
373
|
+
task_type = self._get_task_type(
|
|
374
|
+
query=query_def,
|
|
375
|
+
edge_types=self._sampler.edge_types,
|
|
364
376
|
)
|
|
365
377
|
batch_size = _MAX_PRED_SIZE[task_type]
|
|
366
378
|
else:
|
|
@@ -434,10 +446,10 @@ class KumoRFM:
|
|
|
434
446
|
|
|
435
447
|
# Cast 'ENTITY' to correct data type:
|
|
436
448
|
if 'ENTITY' in df:
|
|
437
|
-
|
|
438
|
-
|
|
439
|
-
|
|
440
|
-
|
|
449
|
+
table_dict = context.subgraph.table_dict
|
|
450
|
+
table = table_dict[query_def.entity_table]
|
|
451
|
+
ser = table.df[table.primary_key]
|
|
452
|
+
df['ENTITY'] = df['ENTITY'].astype(ser.dtype)
|
|
441
453
|
|
|
442
454
|
# Cast 'ANCHOR_TIMESTAMP' to correct data type:
|
|
443
455
|
if 'ANCHOR_TIMESTAMP' in df:
|
|
@@ -520,23 +532,18 @@ class KumoRFM:
|
|
|
520
532
|
raise ValueError("At least one entity is required")
|
|
521
533
|
|
|
522
534
|
if anchor_time is None:
|
|
523
|
-
anchor_time = self.
|
|
535
|
+
anchor_time = self._get_default_anchor_time(query_def)
|
|
524
536
|
|
|
525
537
|
if isinstance(anchor_time, pd.Timestamp):
|
|
526
538
|
self._validate_time(query_def, anchor_time, None, False)
|
|
527
539
|
else:
|
|
528
540
|
assert anchor_time == 'entity'
|
|
529
|
-
if
|
|
541
|
+
if query_def.entity_table not in self._sampler.time_column_dict:
|
|
530
542
|
raise ValueError(f"Anchor time 'entity' requires the entity "
|
|
531
543
|
f"table '{query_def.entity_table}' "
|
|
532
544
|
f"to have a time column.")
|
|
533
545
|
|
|
534
|
-
|
|
535
|
-
table_name=query_def.entity_table,
|
|
536
|
-
pkey=pd.Series(indices),
|
|
537
|
-
)
|
|
538
|
-
query_driver = LocalPQueryDriver(self._graph_store, query_def)
|
|
539
|
-
return query_driver.is_valid(node, anchor_time)
|
|
546
|
+
raise NotImplementedError
|
|
540
547
|
|
|
541
548
|
def evaluate(
|
|
542
549
|
self,
|
|
@@ -548,7 +555,7 @@ class KumoRFM:
|
|
|
548
555
|
run_mode: Union[RunMode, str] = RunMode.FAST,
|
|
549
556
|
num_neighbors: Optional[List[int]] = None,
|
|
550
557
|
num_hops: int = 2,
|
|
551
|
-
max_pq_iterations: int =
|
|
558
|
+
max_pq_iterations: int = 10,
|
|
552
559
|
random_seed: Optional[int] = _RANDOM_SEED,
|
|
553
560
|
verbose: Union[bool, ProgressLogger] = True,
|
|
554
561
|
use_prediction_time: bool = False,
|
|
@@ -659,7 +666,7 @@ class KumoRFM:
|
|
|
659
666
|
*,
|
|
660
667
|
anchor_time: Union[pd.Timestamp, Literal['entity'], None] = None,
|
|
661
668
|
random_seed: Optional[int] = _RANDOM_SEED,
|
|
662
|
-
max_iterations: int =
|
|
669
|
+
max_iterations: int = 10,
|
|
663
670
|
) -> pd.DataFrame:
|
|
664
671
|
"""Returns the labels of a predictive query for a specified anchor
|
|
665
672
|
time.
|
|
@@ -679,40 +686,37 @@ class KumoRFM:
|
|
|
679
686
|
query_def = self._parse_query(query)
|
|
680
687
|
|
|
681
688
|
if anchor_time is None:
|
|
682
|
-
anchor_time = self.
|
|
689
|
+
anchor_time = self._get_default_anchor_time(query_def)
|
|
683
690
|
if query_def.target_ast.date_offset_range is not None:
|
|
684
|
-
|
|
685
|
-
|
|
686
|
-
|
|
691
|
+
offset = query_def.target_ast.date_offset_range.end_date_offset
|
|
692
|
+
offset *= query_def.num_forecasts
|
|
693
|
+
anchor_time -= offset
|
|
687
694
|
|
|
688
695
|
assert anchor_time is not None
|
|
689
696
|
if isinstance(anchor_time, pd.Timestamp):
|
|
690
697
|
self._validate_time(query_def, anchor_time, None, evaluate=True)
|
|
691
698
|
else:
|
|
692
699
|
assert anchor_time == 'entity'
|
|
693
|
-
if
|
|
700
|
+
if query_def.entity_table not in self._sampler.time_column_dict:
|
|
694
701
|
raise ValueError(f"Anchor time 'entity' requires the entity "
|
|
695
702
|
f"table '{query_def.entity_table}' "
|
|
696
703
|
f"to have a time column")
|
|
697
704
|
|
|
698
|
-
|
|
699
|
-
|
|
700
|
-
|
|
701
|
-
|
|
702
|
-
|
|
703
|
-
|
|
704
|
-
|
|
705
|
-
|
|
706
|
-
|
|
705
|
+
train, test = self._sampler.sample_target(
|
|
706
|
+
query=query,
|
|
707
|
+
num_train_examples=0,
|
|
708
|
+
train_anchor_time=anchor_time,
|
|
709
|
+
num_train_trials=0,
|
|
710
|
+
num_test_examples=size,
|
|
711
|
+
test_anchor_time=anchor_time,
|
|
712
|
+
num_test_trials=max_iterations * size,
|
|
713
|
+
random_seed=random_seed,
|
|
707
714
|
)
|
|
708
715
|
|
|
709
|
-
entity = self._graph_store.pkey_map_dict[
|
|
710
|
-
query_def.entity_table].index[node]
|
|
711
|
-
|
|
712
716
|
return pd.DataFrame({
|
|
713
|
-
'ENTITY':
|
|
714
|
-
'ANCHOR_TIMESTAMP':
|
|
715
|
-
'TARGET':
|
|
717
|
+
'ENTITY': test.entity_pkey,
|
|
718
|
+
'ANCHOR_TIMESTAMP': test.anchor_time,
|
|
719
|
+
'TARGET': test.target,
|
|
716
720
|
})
|
|
717
721
|
|
|
718
722
|
# Helpers #################################################################
|
|
@@ -735,8 +739,6 @@ class KumoRFM:
|
|
|
735
739
|
|
|
736
740
|
resp = self._api_client.parse_query(request)
|
|
737
741
|
|
|
738
|
-
# TODO Expose validation warnings.
|
|
739
|
-
|
|
740
742
|
if len(resp.validation_response.warnings) > 0:
|
|
741
743
|
msg = '\n'.join([
|
|
742
744
|
f'{i+1}. {warning.title}: {warning.message}' for i, warning
|
|
@@ -754,6 +756,60 @@ class KumoRFM:
|
|
|
754
756
|
raise ValueError(f"Failed to parse query '{query}'. "
|
|
755
757
|
f"{msg}") from None
|
|
756
758
|
|
|
759
|
+
@staticmethod
|
|
760
|
+
def _get_task_type(
|
|
761
|
+
query: ValidatedPredictiveQuery,
|
|
762
|
+
edge_types: List[Tuple[str, str, str]],
|
|
763
|
+
) -> TaskType:
|
|
764
|
+
if isinstance(query.target_ast, (Condition, LogicalOperation)):
|
|
765
|
+
return TaskType.BINARY_CLASSIFICATION
|
|
766
|
+
|
|
767
|
+
target = query.target_ast
|
|
768
|
+
if isinstance(target, Join):
|
|
769
|
+
target = target.rhs_target
|
|
770
|
+
if isinstance(target, Aggregation):
|
|
771
|
+
if target.aggr == AggregationType.LIST_DISTINCT:
|
|
772
|
+
table_name, col_name = target._get_target_column_name().split(
|
|
773
|
+
'.')
|
|
774
|
+
target_edge_types = [
|
|
775
|
+
edge_type for edge_type in edge_types
|
|
776
|
+
if edge_type[0] == table_name and edge_type[1] == col_name
|
|
777
|
+
]
|
|
778
|
+
if len(target_edge_types) != 1:
|
|
779
|
+
raise NotImplementedError(
|
|
780
|
+
f"Multilabel-classification queries based on "
|
|
781
|
+
f"'LIST_DISTINCT' are not supported yet. If you "
|
|
782
|
+
f"planned to write a link prediction query instead, "
|
|
783
|
+
f"make sure to register '{col_name}' as a "
|
|
784
|
+
f"foreign key.")
|
|
785
|
+
return TaskType.TEMPORAL_LINK_PREDICTION
|
|
786
|
+
|
|
787
|
+
return TaskType.REGRESSION
|
|
788
|
+
|
|
789
|
+
assert isinstance(target, Column)
|
|
790
|
+
|
|
791
|
+
if target.stype in {Stype.ID, Stype.categorical}:
|
|
792
|
+
return TaskType.MULTICLASS_CLASSIFICATION
|
|
793
|
+
|
|
794
|
+
if target.stype in {Stype.numerical}:
|
|
795
|
+
return TaskType.REGRESSION
|
|
796
|
+
|
|
797
|
+
raise NotImplementedError("Task type not yet supported")
|
|
798
|
+
|
|
799
|
+
def _get_default_anchor_time(
|
|
800
|
+
self,
|
|
801
|
+
query: ValidatedPredictiveQuery,
|
|
802
|
+
) -> pd.Timestamp:
|
|
803
|
+
if query.query_type == QueryType.TEMPORAL:
|
|
804
|
+
aggr_table_names = [
|
|
805
|
+
aggr._get_target_column_name().split('.')[0]
|
|
806
|
+
for aggr in query.get_all_target_aggregations()
|
|
807
|
+
]
|
|
808
|
+
return self._sampler.get_max_time(aggr_table_names)
|
|
809
|
+
|
|
810
|
+
assert query.query_type == QueryType.STATIC
|
|
811
|
+
return self._sampler.get_max_time()
|
|
812
|
+
|
|
757
813
|
def _validate_time(
|
|
758
814
|
self,
|
|
759
815
|
query: ValidatedPredictiveQuery,
|
|
@@ -762,28 +818,30 @@ class KumoRFM:
|
|
|
762
818
|
evaluate: bool,
|
|
763
819
|
) -> None:
|
|
764
820
|
|
|
765
|
-
if self.
|
|
821
|
+
if len(self._sampler.time_column_dict) == 0:
|
|
766
822
|
return # Graph without timestamps
|
|
767
823
|
|
|
768
|
-
|
|
824
|
+
min_time = self._sampler.get_min_time()
|
|
825
|
+
max_time = self._sampler.get_max_time()
|
|
826
|
+
|
|
827
|
+
if anchor_time < min_time:
|
|
769
828
|
raise ValueError(f"Anchor timestamp '{anchor_time}' is before "
|
|
770
|
-
f"the earliest timestamp "
|
|
771
|
-
f"
|
|
829
|
+
f"the earliest timestamp '{min_time}' in the "
|
|
830
|
+
f"data.")
|
|
772
831
|
|
|
773
|
-
if
|
|
774
|
-
and context_anchor_time < self._graph_store.min_time):
|
|
832
|
+
if context_anchor_time is not None and context_anchor_time < min_time:
|
|
775
833
|
raise ValueError(f"Context anchor timestamp is too early or "
|
|
776
834
|
f"aggregation time range is too large. To make "
|
|
777
835
|
f"this prediction, we would need data back to "
|
|
778
836
|
f"'{context_anchor_time}', however, your data "
|
|
779
|
-
f"only contains data back to "
|
|
780
|
-
f"'{self._graph_store.min_time}'.")
|
|
837
|
+
f"only contains data back to '{min_time}'.")
|
|
781
838
|
|
|
782
839
|
if query.target_ast.date_offset_range is not None:
|
|
783
840
|
end_offset = query.target_ast.date_offset_range.end_date_offset
|
|
784
841
|
else:
|
|
785
842
|
end_offset = pd.DateOffset(0)
|
|
786
|
-
|
|
843
|
+
end_offset = end_offset * query.num_forecasts
|
|
844
|
+
|
|
787
845
|
if (context_anchor_time is not None
|
|
788
846
|
and context_anchor_time > anchor_time):
|
|
789
847
|
warnings.warn(f"Context anchor timestamp "
|
|
@@ -793,7 +851,7 @@ class KumoRFM:
|
|
|
793
851
|
f"intended.")
|
|
794
852
|
elif (query.query_type == QueryType.TEMPORAL
|
|
795
853
|
and context_anchor_time is not None
|
|
796
|
-
and context_anchor_time +
|
|
854
|
+
and context_anchor_time + end_offset > anchor_time):
|
|
797
855
|
warnings.warn(f"Aggregation for context examples at timestamp "
|
|
798
856
|
f"'{context_anchor_time}' will leak information "
|
|
799
857
|
f"from the prediction anchor timestamp "
|
|
@@ -801,26 +859,23 @@ class KumoRFM:
|
|
|
801
859
|
f"intended.")
|
|
802
860
|
|
|
803
861
|
elif (context_anchor_time is not None
|
|
804
|
-
and context_anchor_time -
|
|
805
|
-
|
|
806
|
-
_time = context_anchor_time - forecast_end_offset
|
|
862
|
+
and context_anchor_time - end_offset < min_time):
|
|
863
|
+
_time = context_anchor_time - end_offset
|
|
807
864
|
warnings.warn(f"Context anchor timestamp is too early or "
|
|
808
865
|
f"aggregation time range is too large. To form "
|
|
809
866
|
f"proper input data, we would need data back to "
|
|
810
867
|
f"'{_time}', however, your data only contains "
|
|
811
|
-
f"data back to '{
|
|
868
|
+
f"data back to '{min_time}'.")
|
|
812
869
|
|
|
813
|
-
if
|
|
814
|
-
> self._graph_store.max_time + pd.DateOffset(days=1)):
|
|
870
|
+
if not evaluate and anchor_time > max_time + pd.DateOffset(days=1):
|
|
815
871
|
warnings.warn(f"Anchor timestamp '{anchor_time}' is after the "
|
|
816
|
-
f"latest timestamp '{
|
|
817
|
-
f"
|
|
872
|
+
f"latest timestamp '{max_time}' in the data. Please "
|
|
873
|
+
f"make sure this is intended.")
|
|
818
874
|
|
|
819
|
-
|
|
820
|
-
if evaluate and anchor_time > max_eval_time:
|
|
875
|
+
if evaluate and anchor_time > max_time - end_offset:
|
|
821
876
|
raise ValueError(
|
|
822
877
|
f"Anchor timestamp for evaluation is after the latest "
|
|
823
|
-
f"supported timestamp '{
|
|
878
|
+
f"supported timestamp '{max_time - end_offset}'.")
|
|
824
879
|
|
|
825
880
|
def _get_context(
|
|
826
881
|
self,
|
|
@@ -851,10 +906,9 @@ class KumoRFM:
|
|
|
851
906
|
f"'https://github.com/kumo-ai/kumo-rfm' if you "
|
|
852
907
|
f"must go beyond this for your use-case.")
|
|
853
908
|
|
|
854
|
-
|
|
855
|
-
|
|
856
|
-
|
|
857
|
-
edge_types=self._graph_store.edge_types,
|
|
909
|
+
task_type = self._get_task_type(
|
|
910
|
+
query=query,
|
|
911
|
+
edge_types=self._sampler.edge_types,
|
|
858
912
|
)
|
|
859
913
|
|
|
860
914
|
if logger is not None:
|
|
@@ -886,14 +940,17 @@ class KumoRFM:
|
|
|
886
940
|
num_neighbors = [64, 64, 8, 8, 4, 4][:num_hops]
|
|
887
941
|
|
|
888
942
|
if query.target_ast.date_offset_range is None:
|
|
889
|
-
|
|
943
|
+
step_offset = pd.DateOffset(0)
|
|
890
944
|
else:
|
|
891
|
-
|
|
892
|
-
|
|
945
|
+
step_offset = query.target_ast.date_offset_range.end_date_offset
|
|
946
|
+
end_offset = step_offset * query.num_forecasts
|
|
947
|
+
|
|
893
948
|
if anchor_time is None:
|
|
894
|
-
anchor_time = self.
|
|
949
|
+
anchor_time = self._get_default_anchor_time(query)
|
|
950
|
+
|
|
895
951
|
if evaluate:
|
|
896
|
-
anchor_time = anchor_time -
|
|
952
|
+
anchor_time = anchor_time - end_offset
|
|
953
|
+
|
|
897
954
|
if logger is not None:
|
|
898
955
|
assert isinstance(anchor_time, pd.Timestamp)
|
|
899
956
|
if anchor_time == pd.Timestamp.min:
|
|
@@ -907,57 +964,71 @@ class KumoRFM:
|
|
|
907
964
|
|
|
908
965
|
assert anchor_time is not None
|
|
909
966
|
if isinstance(anchor_time, pd.Timestamp):
|
|
967
|
+
if context_anchor_time == 'entity':
|
|
968
|
+
raise ValueError("Anchor time 'entity' needs to be shared "
|
|
969
|
+
"for context and prediction examples")
|
|
910
970
|
if context_anchor_time is None:
|
|
911
|
-
context_anchor_time = anchor_time -
|
|
971
|
+
context_anchor_time = anchor_time - end_offset
|
|
912
972
|
self._validate_time(query, anchor_time, context_anchor_time,
|
|
913
973
|
evaluate)
|
|
914
974
|
else:
|
|
915
975
|
assert anchor_time == 'entity'
|
|
916
|
-
if query.
|
|
976
|
+
if query.query_type != QueryType.STATIC:
|
|
977
|
+
raise ValueError("Anchor time 'entity' is only valid for "
|
|
978
|
+
"static predictive queries")
|
|
979
|
+
if query.entity_table not in self._sampler.time_column_dict:
|
|
917
980
|
raise ValueError(f"Anchor time 'entity' requires the entity "
|
|
918
981
|
f"table '{query.entity_table}' to "
|
|
919
982
|
f"have a time column")
|
|
920
|
-
if context_anchor_time
|
|
921
|
-
|
|
922
|
-
|
|
923
|
-
context_anchor_time =
|
|
983
|
+
if isinstance(context_anchor_time, pd.Timestamp):
|
|
984
|
+
raise ValueError("Anchor time 'entity' needs to be shared "
|
|
985
|
+
"for context and prediction examples")
|
|
986
|
+
context_anchor_time = 'entity'
|
|
924
987
|
|
|
925
|
-
|
|
988
|
+
num_train_examples = _MAX_CONTEXT_SIZE[run_mode]
|
|
926
989
|
if evaluate:
|
|
927
|
-
|
|
990
|
+
num_test_examples = _MAX_TEST_SIZE[run_mode]
|
|
928
991
|
if task_type.is_link_pred:
|
|
929
|
-
|
|
992
|
+
num_test_examples = num_test_examples // 5
|
|
993
|
+
else:
|
|
994
|
+
num_test_examples = 0
|
|
995
|
+
|
|
996
|
+
train, test = self._sampler.sample_target(
|
|
997
|
+
query=query,
|
|
998
|
+
num_train_examples=num_train_examples,
|
|
999
|
+
train_anchor_time=context_anchor_time,
|
|
1000
|
+
num_train_trials=max_pq_iterations * num_train_examples,
|
|
1001
|
+
num_test_examples=num_test_examples,
|
|
1002
|
+
test_anchor_time=anchor_time,
|
|
1003
|
+
num_test_trials=max_pq_iterations * num_test_examples,
|
|
1004
|
+
random_seed=random_seed,
|
|
1005
|
+
)
|
|
1006
|
+
train_pkey, train_time, y_train = train
|
|
1007
|
+
test_pkey, test_time, y_test = test
|
|
930
1008
|
|
|
931
|
-
|
|
932
|
-
|
|
933
|
-
|
|
934
|
-
|
|
935
|
-
|
|
936
|
-
|
|
937
|
-
|
|
938
|
-
|
|
939
|
-
|
|
940
|
-
|
|
941
|
-
|
|
942
|
-
|
|
943
|
-
|
|
944
|
-
|
|
945
|
-
|
|
946
|
-
|
|
947
|
-
|
|
948
|
-
|
|
949
|
-
|
|
950
|
-
|
|
951
|
-
num_rhs = y_test.explode().nunique()
|
|
952
|
-
msg = (f"Collected {len(y_test):,} test examples with "
|
|
953
|
-
f"{num_rhs:,} unique items")
|
|
954
|
-
else:
|
|
955
|
-
raise NotImplementedError
|
|
956
|
-
logger.log(msg)
|
|
1009
|
+
if evaluate and logger is not None:
|
|
1010
|
+
if task_type == TaskType.BINARY_CLASSIFICATION:
|
|
1011
|
+
pos = 100 * int((y_test > 0).sum()) / len(y_test)
|
|
1012
|
+
msg = (f"Collected {len(y_test):,} test examples with "
|
|
1013
|
+
f"{pos:.2f}% positive cases")
|
|
1014
|
+
elif task_type == TaskType.MULTICLASS_CLASSIFICATION:
|
|
1015
|
+
msg = (f"Collected {len(y_test):,} test examples holding "
|
|
1016
|
+
f"{y_test.nunique()} classes")
|
|
1017
|
+
elif task_type == TaskType.REGRESSION:
|
|
1018
|
+
_min, _max = float(y_test.min()), float(y_test.max())
|
|
1019
|
+
msg = (f"Collected {len(y_test):,} test examples with targets "
|
|
1020
|
+
f"between {format_value(_min)} and "
|
|
1021
|
+
f"{format_value(_max)}")
|
|
1022
|
+
elif task_type == TaskType.TEMPORAL_LINK_PREDICTION:
|
|
1023
|
+
num_rhs = y_test.explode().nunique()
|
|
1024
|
+
msg = (f"Collected {len(y_test):,} test examples with "
|
|
1025
|
+
f"{num_rhs:,} unique items")
|
|
1026
|
+
else:
|
|
1027
|
+
raise NotImplementedError
|
|
1028
|
+
logger.log(msg)
|
|
957
1029
|
|
|
958
|
-
|
|
1030
|
+
if not evaluate:
|
|
959
1031
|
assert indices is not None
|
|
960
|
-
|
|
961
1032
|
if len(indices) > _MAX_PRED_SIZE[task_type]:
|
|
962
1033
|
raise ValueError(f"Cannot predict for more than "
|
|
963
1034
|
f"{_MAX_PRED_SIZE[task_type]:,} entities at "
|
|
@@ -965,26 +1036,12 @@ class KumoRFM:
|
|
|
965
1036
|
f"`KumoRFM.batch_mode` to process entities "
|
|
966
1037
|
f"in batches")
|
|
967
1038
|
|
|
968
|
-
|
|
969
|
-
table_name=query.entity_table,
|
|
970
|
-
pkey=pd.Series(indices),
|
|
971
|
-
)
|
|
972
|
-
|
|
1039
|
+
test_pkey = pd.Series(indices, dtype=train_pkey.dtype)
|
|
973
1040
|
if isinstance(anchor_time, pd.Timestamp):
|
|
974
|
-
test_time = pd.Series(anchor_time).repeat(
|
|
975
|
-
len(
|
|
1041
|
+
test_time = pd.Series([anchor_time]).repeat(
|
|
1042
|
+
len(indices)).reset_index(drop=True)
|
|
976
1043
|
else:
|
|
977
|
-
|
|
978
|
-
time = time[test_node] * 1000**3
|
|
979
|
-
test_time = pd.Series(time, dtype='datetime64[ns]')
|
|
980
|
-
|
|
981
|
-
train_node, train_time, y_train = query_driver.collect_train(
|
|
982
|
-
size=_MAX_CONTEXT_SIZE[run_mode],
|
|
983
|
-
anchor_time=context_anchor_time or 'entity',
|
|
984
|
-
exclude_node=test_node if (query.query_type == QueryType.STATIC
|
|
985
|
-
or anchor_time == 'entity') else None,
|
|
986
|
-
max_iterations=max_pq_iterations,
|
|
987
|
-
)
|
|
1044
|
+
train_time = test_time = 'entity'
|
|
988
1045
|
|
|
989
1046
|
if logger is not None:
|
|
990
1047
|
if task_type == TaskType.BINARY_CLASSIFICATION:
|
|
@@ -1012,7 +1069,7 @@ class KumoRFM:
|
|
|
1012
1069
|
final_aggr = query.get_final_target_aggregation()
|
|
1013
1070
|
assert final_aggr is not None
|
|
1014
1071
|
edge_fkey = final_aggr._get_target_column_name()
|
|
1015
|
-
for edge_type in self.
|
|
1072
|
+
for edge_type in self._sampler.edge_types:
|
|
1016
1073
|
if edge_fkey == f'{edge_type[0]}.{edge_type[1]}':
|
|
1017
1074
|
entity_table_names = (
|
|
1018
1075
|
query.entity_table,
|
|
@@ -1024,20 +1081,24 @@ class KumoRFM:
|
|
|
1024
1081
|
# Exclude the entity anchor time from the feature set to prevent
|
|
1025
1082
|
# running out-of-distribution between in-context and test examples:
|
|
1026
1083
|
exclude_cols_dict = query.get_exclude_cols_dict()
|
|
1027
|
-
if
|
|
1084
|
+
if entity_table_names[0] in self._sampler.time_column_dict:
|
|
1028
1085
|
if entity_table_names[0] not in exclude_cols_dict:
|
|
1029
1086
|
exclude_cols_dict[entity_table_names[0]] = []
|
|
1030
|
-
|
|
1031
|
-
time_column = time_column_dict[entity_table_names[0]]
|
|
1087
|
+
time_column = self._sampler.time_column_dict[entity_table_names[0]]
|
|
1032
1088
|
exclude_cols_dict[entity_table_names[0]].append(time_column)
|
|
1033
1089
|
|
|
1034
|
-
subgraph = self.
|
|
1090
|
+
subgraph = self._sampler.sample_subgraph(
|
|
1035
1091
|
entity_table_names=entity_table_names,
|
|
1036
|
-
|
|
1037
|
-
|
|
1038
|
-
|
|
1039
|
-
|
|
1040
|
-
|
|
1092
|
+
entity_pkey=pd.concat(
|
|
1093
|
+
[train_pkey, test_pkey],
|
|
1094
|
+
axis=0,
|
|
1095
|
+
ignore_index=True,
|
|
1096
|
+
),
|
|
1097
|
+
anchor_time=pd.concat(
|
|
1098
|
+
[train_time, test_time],
|
|
1099
|
+
axis=0,
|
|
1100
|
+
ignore_index=True,
|
|
1101
|
+
) if isinstance(train_time, pd.Series) else 'entity',
|
|
1041
1102
|
num_neighbors=num_neighbors,
|
|
1042
1103
|
exclude_cols_dict=exclude_cols_dict,
|
|
1043
1104
|
)
|
|
@@ -1049,18 +1110,14 @@ class KumoRFM:
|
|
|
1049
1110
|
f"'https://github.com/kumo-ai/kumo-rfm' if you "
|
|
1050
1111
|
f"must go beyond this for your use-case.")
|
|
1051
1112
|
|
|
1052
|
-
step_size: Optional[int] = None
|
|
1053
|
-
if query.query_type == QueryType.TEMPORAL:
|
|
1054
|
-
step_size = date_offset_to_seconds(end_offset)
|
|
1055
|
-
|
|
1056
1113
|
return Context(
|
|
1057
1114
|
task_type=task_type,
|
|
1058
1115
|
entity_table_names=entity_table_names,
|
|
1059
1116
|
subgraph=subgraph,
|
|
1060
1117
|
y_train=y_train,
|
|
1061
|
-
y_test=y_test,
|
|
1118
|
+
y_test=y_test if evaluate else None,
|
|
1062
1119
|
top_k=query.top_k,
|
|
1063
|
-
step_size=
|
|
1120
|
+
step_size=None,
|
|
1064
1121
|
)
|
|
1065
1122
|
|
|
1066
1123
|
@staticmethod
|