kumoai 2.12.1__py3-none-any.whl → 2.14.0.dev202512141732__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kumoai/__init__.py +18 -9
- kumoai/_version.py +1 -1
- kumoai/client/client.py +9 -13
- kumoai/client/pquery.py +6 -2
- kumoai/connector/utils.py +23 -2
- kumoai/experimental/rfm/__init__.py +162 -46
- kumoai/experimental/rfm/backend/__init__.py +0 -0
- kumoai/experimental/rfm/backend/local/__init__.py +42 -0
- kumoai/experimental/rfm/{local_graph_store.py → backend/local/graph_store.py} +37 -90
- kumoai/experimental/rfm/backend/local/sampler.py +313 -0
- kumoai/experimental/rfm/backend/local/table.py +119 -0
- kumoai/experimental/rfm/backend/snow/__init__.py +37 -0
- kumoai/experimental/rfm/backend/snow/sampler.py +119 -0
- kumoai/experimental/rfm/backend/snow/table.py +135 -0
- kumoai/experimental/rfm/backend/sqlite/__init__.py +32 -0
- kumoai/experimental/rfm/backend/sqlite/sampler.py +112 -0
- kumoai/experimental/rfm/backend/sqlite/table.py +115 -0
- kumoai/experimental/rfm/base/__init__.py +23 -0
- kumoai/experimental/rfm/base/column.py +66 -0
- kumoai/experimental/rfm/base/sampler.py +773 -0
- kumoai/experimental/rfm/base/source.py +19 -0
- kumoai/experimental/rfm/{local_table.py → base/table.py} +152 -141
- kumoai/experimental/rfm/{local_graph.py → graph.py} +352 -80
- kumoai/experimental/rfm/infer/__init__.py +6 -0
- kumoai/experimental/rfm/infer/dtype.py +79 -0
- kumoai/experimental/rfm/infer/pkey.py +126 -0
- kumoai/experimental/rfm/infer/time_col.py +62 -0
- kumoai/experimental/rfm/pquery/pandas_executor.py +1 -1
- kumoai/experimental/rfm/rfm.py +233 -174
- kumoai/experimental/rfm/sagemaker.py +138 -0
- kumoai/spcs.py +1 -3
- kumoai/testing/decorators.py +1 -1
- kumoai/testing/snow.py +50 -0
- kumoai/utils/__init__.py +2 -0
- kumoai/utils/sql.py +3 -0
- {kumoai-2.12.1.dist-info → kumoai-2.14.0.dev202512141732.dist-info}/METADATA +12 -2
- {kumoai-2.12.1.dist-info → kumoai-2.14.0.dev202512141732.dist-info}/RECORD +40 -23
- kumoai/experimental/rfm/local_graph_sampler.py +0 -184
- kumoai/experimental/rfm/local_pquery_driver.py +0 -689
- kumoai/experimental/rfm/utils.py +0 -344
- {kumoai-2.12.1.dist-info → kumoai-2.14.0.dev202512141732.dist-info}/WHEEL +0 -0
- {kumoai-2.12.1.dist-info → kumoai-2.14.0.dev202512141732.dist-info}/licenses/LICENSE +0 -0
- {kumoai-2.12.1.dist-info → kumoai-2.14.0.dev202512141732.dist-info}/top_level.txt +0 -0
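The headline change: the local-only `LocalGraphStore` / `LocalPQueryDriver` stack is removed in favor of a pluggable backend package (`backend/local`, `backend/sqlite`, `backend/snow`) behind a shared `base.Sampler` interface, with `LocalGraph`/`LocalTable` generalized to `Graph` and a base `Table`. A minimal end-to-end sketch of the in-memory entry point, assuming the `Graph.from_data` constructor shown in the updated `KumoRFM` docstring below; the toy columns and keys here are illustrative, and key/time-column detection is left to the new `infer` helpers:

import pandas as pd

from kumoai.experimental.rfm import Graph, KumoRFM

# Toy data; primary keys, foreign keys, and time columns are inferred
# (see kumoai/experimental/rfm/infer/{pkey,dtype,time_col}.py).
df_users = pd.DataFrame({'user_id': [1, 2, 3]})
df_items = pd.DataFrame({'item_id': [10, 11]})
df_orders = pd.DataFrame({
    'user_id': [1, 1, 2],
    'item_id': [10, 11, 10],
    'time': pd.to_datetime(['2024-01-05', '2024-02-01', '2024-03-12']),
})

graph = Graph.from_data({
    'users': df_users,
    'items': df_items,
    'orders': df_orders,
})

rfm = KumoRFM(graph)  # The sampler backend is selected from graph.backend.
result = rfm.predict("PREDICT COUNT(orders.*, 0, 30, days)>0 "
                     "FOR users.user_id=1")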
kumoai/experimental/rfm/rfm.py
CHANGED
@@ -21,6 +21,13 @@ import numpy as np
 import pandas as pd
 from kumoapi.model_plan import RunMode
 from kumoapi.pquery import QueryType, ValidatedPredictiveQuery
+from kumoapi.pquery.AST import (
+    Aggregation,
+    Column,
+    Condition,
+    Join,
+    LogicalOperation,
+)
 from kumoapi.rfm import Context
 from kumoapi.rfm import Explanation as ExplanationConfig
 from kumoapi.rfm import (
@@ -29,16 +36,12 @@ from kumoapi.rfm import (
     RFMPredictRequest,
 )
 from kumoapi.task import TaskType
+from kumoapi.typing import AggregationType, Stype
 
-from kumoai import
+from kumoai.client.rfm import RFMAPI
 from kumoai.exceptions import HTTPException
-from kumoai.experimental.rfm import
-from kumoai.experimental.rfm.
-from kumoai.experimental.rfm.local_graph_store import LocalGraphStore
-from kumoai.experimental.rfm.local_pquery_driver import (
-    LocalPQueryDriver,
-    date_offset_to_seconds,
-)
+from kumoai.experimental.rfm import Graph
+from kumoai.experimental.rfm.base import DataBackend, Sampler
 from kumoai.mixin import CastMixin
 from kumoai.utils import InteractiveProgressLogger, ProgressLogger
 
@@ -123,17 +126,17 @@ class KumoRFM:
     :class:`KumoRFM` is a foundation model to generate predictions for any
     relational dataset without training.
     The model is pre-trained and the class provides an interface to query the
-    model from a :class:`
+    model from a :class:`Graph` object.
 
     .. code-block:: python
 
-        from kumoai.experimental.rfm import
+        from kumoai.experimental.rfm import Graph, KumoRFM
 
         df_users = pd.DataFrame(...)
         df_items = pd.DataFrame(...)
         df_orders = pd.DataFrame(...)
 
-        graph =
+        graph = Graph.from_data({
             'users': df_users,
             'items': df_items,
             'orders': df_orders,
@@ -141,40 +144,51 @@ class KumoRFM:
 
         rfm = KumoRFM(graph)
 
-        query = ("PREDICT COUNT(
-        "FOR users.user_id=
-        result = rfm.
+        query = ("PREDICT COUNT(orders.*, 0, 30, days)>0 "
+                 "FOR users.user_id=1")
+        result = rfm.predict(query)
 
         print(result)  # user_id  COUNT(transactions.*, 0, 30, days) > 0
                        # 1        0.85
 
     Args:
         graph: The graph.
-        preprocess: Whether to pre-process the data in advance during graph
-            materialization.
-            This is a runtime trade-off between graph materialization and model
-            processing speed.
-            It can be benefical to preprocess your data once and then run many
-            queries on top to achieve maximum model speed.
-            However, if activiated, graph materialization can take potentially
-            much longer, especially on graphs with many large text columns.
-            Best to tune this option manually.
         verbose: Whether to print verbose output.
     """
     def __init__(
         self,
-        graph:
-        preprocess: bool = False,
+        graph: Graph,
         verbose: Union[bool, ProgressLogger] = True,
     ) -> None:
        graph = graph.validate()
        self._graph_def = graph._to_api_graph_definition()
-
-
+
+        if graph.backend == DataBackend.LOCAL:
+            from kumoai.experimental.rfm.backend.local import LocalSampler
+            self._sampler: Sampler = LocalSampler(graph, verbose)
+        elif graph.backend == DataBackend.SQLITE:
+            from kumoai.experimental.rfm.backend.sqlite import SQLiteSampler
+            self._sampler = SQLiteSampler(graph, verbose)
+        elif graph.backend == DataBackend.SNOWFLAKE:
+            from kumoai.experimental.rfm.backend.snow import SnowSampler
+            self._sampler = SnowSampler(graph, verbose)
+        else:
+            raise NotImplementedError
+
+        self._client: Optional[RFMAPI] = None
 
         self._batch_size: Optional[int | Literal['max']] = None
         self.num_retries: int = 0
 
+    @property
+    def _api_client(self) -> RFMAPI:
+        if self._client is not None:
+            return self._client
+
+        from kumoai.experimental.rfm import global_state
+        self._client = RFMAPI(global_state.client)
+        return self._client
+
     def __repr__(self) -> str:
         return f'{self.__class__.__name__}()'
 
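The constructor now dispatches on `graph.backend`, and the new `_api_client` property defers REST-client construction to the first request. A standalone sketch of the dispatch pattern under the same assumptions as the hunk above (`make_sampler` is a hypothetical free function, not part of the package; the branch-local imports mirror how `__init__` keeps optional backend dependencies, such as the Snowflake connector, unloaded until that backend is selected):

from typing import Any


def make_sampler(graph: Any, verbose: bool = True) -> Any:
    # Hypothetical restatement of the KumoRFM.__init__ branch above.
    from kumoai.experimental.rfm.base import DataBackend

    if graph.backend == DataBackend.LOCAL:
        from kumoai.experimental.rfm.backend.local import LocalSampler
        return LocalSampler(graph, verbose)
    if graph.backend == DataBackend.SQLITE:
        from kumoai.experimental.rfm.backend.sqlite import SQLiteSampler
        return SQLiteSampler(graph, verbose)
    if graph.backend == DataBackend.SNOWFLAKE:
        from kumoai.experimental.rfm.backend.snow import SnowSampler
        return SnowSampler(graph, verbose)
    raise NotImplementedError(f"Unsupported backend: {graph.backend}")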
@@ -223,7 +237,7 @@ class KumoRFM:
         run_mode: Union[RunMode, str] = RunMode.FAST,
         num_neighbors: Optional[List[int]] = None,
         num_hops: int = 2,
-        max_pq_iterations: int =
+        max_pq_iterations: int = 10,
         random_seed: Optional[int] = _RANDOM_SEED,
         verbose: Union[bool, ProgressLogger] = True,
         use_prediction_time: bool = False,
@@ -242,7 +256,7 @@ class KumoRFM:
         run_mode: Union[RunMode, str] = RunMode.FAST,
         num_neighbors: Optional[List[int]] = None,
         num_hops: int = 2,
-        max_pq_iterations: int =
+        max_pq_iterations: int = 10,
         random_seed: Optional[int] = _RANDOM_SEED,
         verbose: Union[bool, ProgressLogger] = True,
         use_prediction_time: bool = False,
@@ -260,7 +274,7 @@ class KumoRFM:
         run_mode: Union[RunMode, str] = RunMode.FAST,
         num_neighbors: Optional[List[int]] = None,
         num_hops: int = 2,
-        max_pq_iterations: int =
+        max_pq_iterations: int = 10,
         random_seed: Optional[int] = _RANDOM_SEED,
         verbose: Union[bool, ProgressLogger] = True,
         use_prediction_time: bool = False,
@@ -356,9 +370,9 @@ class KumoRFM:
 
         batch_size: Optional[int] = None
         if self._batch_size == 'max':
-            task_type =
-                query_def,
-                edge_types=self.
+            task_type = self._get_task_type(
+                query=query_def,
+                edge_types=self._sampler.edge_types,
             )
             batch_size = _MAX_PRED_SIZE[task_type]
         else:
@@ -420,22 +434,22 @@ class KumoRFM:
         for attempt in range(self.num_retries + 1):
             try:
                 if explain_config is not None:
-                    resp =
+                    resp = self._api_client.explain(
                         request=_bytes,
                         skip_summary=explain_config.skip_summary,
                     )
                     summary = resp.summary
                     details = resp.details
                 else:
-                    resp =
+                    resp = self._api_client.predict(_bytes)
                 df = pd.DataFrame(**resp.prediction)
 
                 # Cast 'ENTITY' to correct data type:
                 if 'ENTITY' in df:
-
-
-
-
+                    table_dict = context.subgraph.table_dict
+                    table = table_dict[query_def.entity_table]
+                    ser = table.df[table.primary_key]
+                    df['ENTITY'] = df['ENTITY'].astype(ser.dtype)
 
                 # Cast 'ANCHOR_TIMESTAMP' to correct data type:
                 if 'ANCHOR_TIMESTAMP' in df:
@@ -518,23 +532,18 @@ class KumoRFM:
             raise ValueError("At least one entity is required")
 
         if anchor_time is None:
-            anchor_time = self.
+            anchor_time = self._get_default_anchor_time(query_def)
 
         if isinstance(anchor_time, pd.Timestamp):
             self._validate_time(query_def, anchor_time, None, False)
         else:
             assert anchor_time == 'entity'
-            if
+            if query_def.entity_table not in self._sampler.time_column_dict:
                 raise ValueError(f"Anchor time 'entity' requires the entity "
                                  f"table '{query_def.entity_table}' "
                                  f"to have a time column.")
 
-
-            table_name=query_def.entity_table,
-            pkey=pd.Series(indices),
-        )
-        query_driver = LocalPQueryDriver(self._graph_store, query_def)
-        return query_driver.is_valid(node, anchor_time)
+        raise NotImplementedError
 
     def evaluate(
         self,
@@ -546,7 +555,7 @@ class KumoRFM:
         run_mode: Union[RunMode, str] = RunMode.FAST,
         num_neighbors: Optional[List[int]] = None,
         num_hops: int = 2,
-        max_pq_iterations: int =
+        max_pq_iterations: int = 10,
         random_seed: Optional[int] = _RANDOM_SEED,
         verbose: Union[bool, ProgressLogger] = True,
         use_prediction_time: bool = False,
@@ -633,7 +642,7 @@ class KumoRFM:
             raise ValueError(_SIZE_LIMIT_MSG.format(stats=stats_msg))
 
         try:
-            resp =
+            resp = self._api_client.evaluate(request_bytes)
         except HTTPException as e:
             try:
                 msg = json.loads(e.detail)['detail']
@@ -657,7 +666,7 @@ class KumoRFM:
         *,
         anchor_time: Union[pd.Timestamp, Literal['entity'], None] = None,
         random_seed: Optional[int] = _RANDOM_SEED,
-        max_iterations: int =
+        max_iterations: int = 10,
     ) -> pd.DataFrame:
         """Returns the labels of a predictive query for a specified anchor
         time.
@@ -677,40 +686,37 @@ class KumoRFM:
         query_def = self._parse_query(query)
 
         if anchor_time is None:
-            anchor_time = self.
+            anchor_time = self._get_default_anchor_time(query_def)
             if query_def.target_ast.date_offset_range is not None:
-
-
-
+                offset = query_def.target_ast.date_offset_range.end_date_offset
+                offset *= query_def.num_forecasts
+                anchor_time -= offset
 
         assert anchor_time is not None
         if isinstance(anchor_time, pd.Timestamp):
             self._validate_time(query_def, anchor_time, None, evaluate=True)
         else:
             assert anchor_time == 'entity'
-            if
+            if query_def.entity_table not in self._sampler.time_column_dict:
                 raise ValueError(f"Anchor time 'entity' requires the entity "
                                  f"table '{query_def.entity_table}' "
                                  f"to have a time column")
 
-
-
-
-
-
-
-
-
-
+        train, test = self._sampler.sample_target(
+            query=query,
+            num_train_examples=0,
+            train_anchor_time=anchor_time,
+            num_train_trials=0,
+            num_test_examples=size,
+            test_anchor_time=anchor_time,
+            num_test_trials=max_iterations * size,
+            random_seed=random_seed,
         )
 
-        entity = self._graph_store.pkey_map_dict[
-            query_def.entity_table].index[node]
-
         return pd.DataFrame({
-            'ENTITY':
-            'ANCHOR_TIMESTAMP':
-            'TARGET':
+            'ENTITY': test.entity_pkey,
+            'ANCHOR_TIMESTAMP': test.anchor_time,
+            'TARGET': test.target,
         })
 
     # Helpers #################################################################
@@ -731,8 +737,7 @@ class KumoRFM:
             graph_definition=self._graph_def,
         )
 
-        resp =
-        # TODO Expose validation warnings.
+        resp = self._api_client.parse_query(request)
 
         if len(resp.validation_response.warnings) > 0:
             msg = '\n'.join([
@@ -751,6 +756,60 @@ class KumoRFM:
             raise ValueError(f"Failed to parse query '{query}'. "
                              f"{msg}") from None
 
+    @staticmethod
+    def _get_task_type(
+        query: ValidatedPredictiveQuery,
+        edge_types: List[Tuple[str, str, str]],
+    ) -> TaskType:
+        if isinstance(query.target_ast, (Condition, LogicalOperation)):
+            return TaskType.BINARY_CLASSIFICATION
+
+        target = query.target_ast
+        if isinstance(target, Join):
+            target = target.rhs_target
+        if isinstance(target, Aggregation):
+            if target.aggr == AggregationType.LIST_DISTINCT:
+                table_name, col_name = target._get_target_column_name().split(
+                    '.')
+                target_edge_types = [
+                    edge_type for edge_type in edge_types
+                    if edge_type[0] == table_name and edge_type[1] == col_name
+                ]
+                if len(target_edge_types) != 1:
+                    raise NotImplementedError(
+                        f"Multilabel-classification queries based on "
+                        f"'LIST_DISTINCT' are not supported yet. If you "
+                        f"planned to write a link prediction query instead, "
+                        f"make sure to register '{col_name}' as a "
+                        f"foreign key.")
+                return TaskType.TEMPORAL_LINK_PREDICTION
+
+            return TaskType.REGRESSION
+
+        assert isinstance(target, Column)
+
+        if target.stype in {Stype.ID, Stype.categorical}:
+            return TaskType.MULTICLASS_CLASSIFICATION
+
+        if target.stype in {Stype.numerical}:
+            return TaskType.REGRESSION
+
+        raise NotImplementedError("Task type not yet supported")
+
+    def _get_default_anchor_time(
+        self,
+        query: ValidatedPredictiveQuery,
+    ) -> pd.Timestamp:
+        if query.query_type == QueryType.TEMPORAL:
+            aggr_table_names = [
+                aggr._get_target_column_name().split('.')[0]
+                for aggr in query.get_all_target_aggregations()
+            ]
+            return self._sampler.get_max_time(aggr_table_names)
+
+        assert query.query_type == QueryType.STATIC
+        return self._sampler.get_max_time()
+
     def _validate_time(
         self,
         query: ValidatedPredictiveQuery,
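`_get_task_type` infers the task purely from the shape of the validated query's target AST; the user never specifies a task type. A summary of the dispatch above as a comment table (the example queries are illustrative, not taken from the package):

# Target AST shape                                 -> inferred TaskType
# ------------------------------------------------------------------------
# Condition / LogicalOperation,
#   e.g. PREDICT COUNT(orders.*, 0, 30, days) > 0 -> BINARY_CLASSIFICATION
# Aggregation with aggr == LIST_DISTINCT over a
#   column registered as a foreign key            -> TEMPORAL_LINK_PREDICTION
# Any other Aggregation,
#   e.g. PREDICT SUM(orders.price, 0, 7, days)    -> REGRESSION
# Column with stype in {ID, categorical}          -> MULTICLASS_CLASSIFICATION
# Column with stype == numerical                  -> REGRESSION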
@@ -759,28 +818,30 @@ class KumoRFM:
         evaluate: bool,
     ) -> None:
 
-        if self.
+        if len(self._sampler.time_column_dict) == 0:
             return  # Graph without timestamps
 
-
+        min_time = self._sampler.get_min_time()
+        max_time = self._sampler.get_max_time()
+
+        if anchor_time < min_time:
             raise ValueError(f"Anchor timestamp '{anchor_time}' is before "
-                             f"the earliest timestamp "
-                             f"
+                             f"the earliest timestamp '{min_time}' in the "
+                             f"data.")
 
-        if
-                and context_anchor_time < self._graph_store.min_time):
+        if context_anchor_time is not None and context_anchor_time < min_time:
             raise ValueError(f"Context anchor timestamp is too early or "
                              f"aggregation time range is too large. To make "
                              f"this prediction, we would need data back to "
                              f"'{context_anchor_time}', however, your data "
-                             f"only contains data back to "
-                             f"'{self._graph_store.min_time}'.")
+                             f"only contains data back to '{min_time}'.")
 
         if query.target_ast.date_offset_range is not None:
             end_offset = query.target_ast.date_offset_range.end_date_offset
         else:
             end_offset = pd.DateOffset(0)
-
+        end_offset = end_offset * query.num_forecasts
+
         if (context_anchor_time is not None
                 and context_anchor_time > anchor_time):
             warnings.warn(f"Context anchor timestamp "
@@ -790,7 +851,7 @@ class KumoRFM:
                           f"intended.")
         elif (query.query_type == QueryType.TEMPORAL
               and context_anchor_time is not None
-              and context_anchor_time +
+              and context_anchor_time + end_offset > anchor_time):
             warnings.warn(f"Aggregation for context examples at timestamp "
                           f"'{context_anchor_time}' will leak information "
                           f"from the prediction anchor timestamp "
@@ -798,26 +859,23 @@ class KumoRFM:
                           f"intended.")
 
         elif (context_anchor_time is not None
-              and context_anchor_time -
-
-            _time = context_anchor_time - forecast_end_offset
+              and context_anchor_time - end_offset < min_time):
+            _time = context_anchor_time - end_offset
             warnings.warn(f"Context anchor timestamp is too early or "
                           f"aggregation time range is too large. To form "
                           f"proper input data, we would need data back to "
                           f"'{_time}', however, your data only contains "
-                          f"data back to '{
+                          f"data back to '{min_time}'.")
 
-        if
-                > self._graph_store.max_time + pd.DateOffset(days=1)):
+        if not evaluate and anchor_time > max_time + pd.DateOffset(days=1):
             warnings.warn(f"Anchor timestamp '{anchor_time}' is after the "
-                          f"latest timestamp '{
-                          f"
+                          f"latest timestamp '{max_time}' in the data. Please "
+                          f"make sure this is intended.")
 
-
-        if evaluate and anchor_time > max_eval_time:
+        if evaluate and anchor_time > max_time - end_offset:
             raise ValueError(
                 f"Anchor timestamp for evaluation is after the latest "
-                f"supported timestamp '{
+                f"supported timestamp '{max_time - end_offset}'.")
 
     def _get_context(
         self,
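`_validate_time` now takes its bounds from the sampler (`get_min_time`/`get_max_time`) and scales the aggregation window by `query.num_forecasts`. A small runnable sketch of the evaluation bound it enforces, with made-up timestamps:

import pandas as pd

# Hypothetical values for illustration only.
max_time = pd.Timestamp('2024-06-30')  # latest timestamp in the data
step_offset = pd.DateOffset(days=30)   # aggregation window of a single step
num_forecasts = 3                      # forecast horizon in steps

end_offset = step_offset * num_forecasts    # <DateOffset: days=90>
latest_eval_anchor = max_time - end_offset  # 2024-04-01

# Evaluation needs `end_offset` of labeled data after the anchor time,
# so anchors later than `max_time - end_offset` raise a ValueError:
anchor_time = pd.Timestamp('2024-03-01')
assert anchor_time <= latest_eval_anchor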
@@ -848,10 +906,9 @@ class KumoRFM:
                      f"'https://github.com/kumo-ai/kumo-rfm' if you "
                      f"must go beyond this for your use-case.")
 
-
-
-
-            edge_types=self._graph_store.edge_types,
+        task_type = self._get_task_type(
+            query=query,
+            edge_types=self._sampler.edge_types,
         )
 
         if logger is not None:
@@ -883,14 +940,17 @@ class KumoRFM:
             num_neighbors = [64, 64, 8, 8, 4, 4][:num_hops]
 
         if query.target_ast.date_offset_range is None:
-
+            step_offset = pd.DateOffset(0)
         else:
-
-
+            step_offset = query.target_ast.date_offset_range.end_date_offset
+        end_offset = step_offset * query.num_forecasts
+
         if anchor_time is None:
-            anchor_time = self.
+            anchor_time = self._get_default_anchor_time(query)
+
         if evaluate:
-            anchor_time = anchor_time -
+            anchor_time = anchor_time - end_offset
+
         if logger is not None:
             assert isinstance(anchor_time, pd.Timestamp)
             if anchor_time == pd.Timestamp.min:
@@ -904,57 +964,71 @@ class KumoRFM:
 
         assert anchor_time is not None
         if isinstance(anchor_time, pd.Timestamp):
+            if context_anchor_time == 'entity':
+                raise ValueError("Anchor time 'entity' needs to be shared "
+                                 "for context and prediction examples")
             if context_anchor_time is None:
-                context_anchor_time = anchor_time -
+                context_anchor_time = anchor_time - end_offset
             self._validate_time(query, anchor_time, context_anchor_time,
                                 evaluate)
         else:
             assert anchor_time == 'entity'
-            if query.
+            if query.query_type != QueryType.STATIC:
+                raise ValueError("Anchor time 'entity' is only valid for "
+                                 "static predictive queries")
+            if query.entity_table not in self._sampler.time_column_dict:
                 raise ValueError(f"Anchor time 'entity' requires the entity "
                                  f"table '{query.entity_table}' to "
                                  f"have a time column")
-            if context_anchor_time
-
-
-                context_anchor_time =
+            if isinstance(context_anchor_time, pd.Timestamp):
+                raise ValueError("Anchor time 'entity' needs to be shared "
+                                 "for context and prediction examples")
+            context_anchor_time = 'entity'
 
-
+        num_train_examples = _MAX_CONTEXT_SIZE[run_mode]
         if evaluate:
-
+            num_test_examples = _MAX_TEST_SIZE[run_mode]
             if task_type.is_link_pred:
-
+                num_test_examples = num_test_examples // 5
+        else:
+            num_test_examples = 0
+
+        train, test = self._sampler.sample_target(
+            query=query,
+            num_train_examples=num_train_examples,
+            train_anchor_time=context_anchor_time,
+            num_train_trials=max_pq_iterations * num_train_examples,
+            num_test_examples=num_test_examples,
+            test_anchor_time=anchor_time,
+            num_test_trials=max_pq_iterations * num_test_examples,
+            random_seed=random_seed,
+        )
+        train_pkey, train_time, y_train = train
+        test_pkey, test_time, y_test = test
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-            num_rhs = y_test.explode().nunique()
-            msg = (f"Collected {len(y_test):,} test examples with "
-                   f"{num_rhs:,} unique items")
-        else:
-            raise NotImplementedError
-        logger.log(msg)
+        if evaluate and logger is not None:
+            if task_type == TaskType.BINARY_CLASSIFICATION:
+                pos = 100 * int((y_test > 0).sum()) / len(y_test)
+                msg = (f"Collected {len(y_test):,} test examples with "
+                       f"{pos:.2f}% positive cases")
+            elif task_type == TaskType.MULTICLASS_CLASSIFICATION:
+                msg = (f"Collected {len(y_test):,} test examples holding "
+                       f"{y_test.nunique()} classes")
+            elif task_type == TaskType.REGRESSION:
+                _min, _max = float(y_test.min()), float(y_test.max())
+                msg = (f"Collected {len(y_test):,} test examples with targets "
+                       f"between {format_value(_min)} and "
+                       f"{format_value(_max)}")
+            elif task_type == TaskType.TEMPORAL_LINK_PREDICTION:
+                num_rhs = y_test.explode().nunique()
+                msg = (f"Collected {len(y_test):,} test examples with "
+                       f"{num_rhs:,} unique items")
+            else:
+                raise NotImplementedError
+            logger.log(msg)
 
-
+        if not evaluate:
             assert indices is not None
-
             if len(indices) > _MAX_PRED_SIZE[task_type]:
                 raise ValueError(f"Cannot predict for more than "
                                  f"{_MAX_PRED_SIZE[task_type]:,} entities at "
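Both `get_labels` and `_get_context` now collect labels through the single `Sampler.sample_target` call. The diff unpacks each returned split as a 3-tuple and also reads `.entity_pkey`/`.anchor_time`/`.target` on it, which is consistent with a NamedTuple-style result; a hypothetical sketch of that shape (the actual class lives in the new `base/sampler.py` and may differ):

from typing import NamedTuple, Union

import pandas as pd


class TargetSample(NamedTuple):
    """Hypothetical shape of one split returned by Sampler.sample_target."""
    entity_pkey: pd.Series                       # entity primary keys
    anchor_time: Union[pd.Series, pd.Timestamp]  # per-example anchor times
    target: pd.Series                            # predictive-query labels


sample = TargetSample(
    entity_pkey=pd.Series([1, 2]),
    anchor_time=pd.Series(pd.to_datetime(['2024-01-01', '2024-01-01'])),
    target=pd.Series([0.0, 1.0]),
)
pkey, time, y = sample          # tuple unpacking, as in _get_context()
assert sample.target.equals(y)  # attribute access, as in get_labels()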
@@ -962,26 +1036,12 @@ class KumoRFM:
                                  f"`KumoRFM.batch_mode` to process entities "
                                  f"in batches")
 
-
-                table_name=query.entity_table,
-                pkey=pd.Series(indices),
-            )
-
+            test_pkey = pd.Series(indices, dtype=train_pkey.dtype)
             if isinstance(anchor_time, pd.Timestamp):
-                test_time = pd.Series(anchor_time).repeat(
-                    len(
+                test_time = pd.Series([anchor_time]).repeat(
+                    len(indices)).reset_index(drop=True)
             else:
-
-                time = time[test_node] * 1000**3
-                test_time = pd.Series(time, dtype='datetime64[ns]')
-
-            train_node, train_time, y_train = query_driver.collect_train(
-                size=_MAX_CONTEXT_SIZE[run_mode],
-                anchor_time=context_anchor_time or 'entity',
-                exclude_node=test_node if (query.query_type == QueryType.STATIC
-                                           or anchor_time == 'entity') else None,
-                max_iterations=max_pq_iterations,
-            )
+                train_time = test_time = 'entity'
 
         if logger is not None:
             if task_type == TaskType.BINARY_CLASSIFICATION:
@@ -1009,7 +1069,7 @@ class KumoRFM:
         final_aggr = query.get_final_target_aggregation()
         assert final_aggr is not None
         edge_fkey = final_aggr._get_target_column_name()
-        for edge_type in self.
+        for edge_type in self._sampler.edge_types:
             if edge_fkey == f'{edge_type[0]}.{edge_type[1]}':
                 entity_table_names = (
                     query.entity_table,
@@ -1021,21 +1081,24 @@ class KumoRFM:
         # Exclude the entity anchor time from the feature set to prevent
         # running out-of-distribution between in-context and test examples:
         exclude_cols_dict = query.get_exclude_cols_dict()
-        if
+        if entity_table_names[0] in self._sampler.time_column_dict:
             if entity_table_names[0] not in exclude_cols_dict:
                 exclude_cols_dict[entity_table_names[0]] = []
-
-            time_column = time_column_dict[entity_table_names[0]]
+            time_column = self._sampler.time_column_dict[entity_table_names[0]]
             exclude_cols_dict[entity_table_names[0]].append(time_column)
 
-        subgraph = self.
+        subgraph = self._sampler.sample_subgraph(
             entity_table_names=entity_table_names,
-
-
-
-
-
-
+            entity_pkey=pd.concat(
+                [train_pkey, test_pkey],
+                axis=0,
+                ignore_index=True,
+            ),
+            anchor_time=pd.concat(
+                [train_time, test_time],
+                axis=0,
+                ignore_index=True,
+            ) if isinstance(train_time, pd.Series) else 'entity',
             num_neighbors=num_neighbors,
             exclude_cols_dict=exclude_cols_dict,
         )
@@ -1047,18 +1110,14 @@ class KumoRFM:
                      f"'https://github.com/kumo-ai/kumo-rfm' if you "
                      f"must go beyond this for your use-case.")
 
-        step_size: Optional[int] = None
-        if query.query_type == QueryType.TEMPORAL:
-            step_size = date_offset_to_seconds(end_offset)
-
         return Context(
             task_type=task_type,
             entity_table_names=entity_table_names,
             subgraph=subgraph,
             y_train=y_train,
-            y_test=y_test,
+            y_test=y_test if evaluate else None,
             top_k=query.top_k,
-            step_size=
+            step_size=None,
         )
 
     @staticmethod