kumoai 2.12.0.dev202510231830__cp311-cp311-win_amd64.whl → 2.14.0.dev202512311733__cp311-cp311-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kumoai/__init__.py +41 -35
- kumoai/_version.py +1 -1
- kumoai/client/client.py +15 -13
- kumoai/client/endpoints.py +1 -0
- kumoai/client/jobs.py +24 -0
- kumoai/client/pquery.py +6 -2
- kumoai/client/rfm.py +35 -7
- kumoai/connector/utils.py +23 -2
- kumoai/experimental/rfm/__init__.py +191 -48
- kumoai/experimental/rfm/authenticate.py +3 -4
- kumoai/experimental/rfm/backend/__init__.py +0 -0
- kumoai/experimental/rfm/backend/local/__init__.py +42 -0
- kumoai/experimental/rfm/{local_graph_store.py → backend/local/graph_store.py} +65 -127
- kumoai/experimental/rfm/backend/local/sampler.py +312 -0
- kumoai/experimental/rfm/backend/local/table.py +113 -0
- kumoai/experimental/rfm/backend/snow/__init__.py +37 -0
- kumoai/experimental/rfm/backend/snow/sampler.py +297 -0
- kumoai/experimental/rfm/backend/snow/table.py +242 -0
- kumoai/experimental/rfm/backend/sqlite/__init__.py +32 -0
- kumoai/experimental/rfm/backend/sqlite/sampler.py +398 -0
- kumoai/experimental/rfm/backend/sqlite/table.py +184 -0
- kumoai/experimental/rfm/base/__init__.py +30 -0
- kumoai/experimental/rfm/base/column.py +152 -0
- kumoai/experimental/rfm/base/expression.py +44 -0
- kumoai/experimental/rfm/base/sampler.py +761 -0
- kumoai/experimental/rfm/base/source.py +19 -0
- kumoai/experimental/rfm/base/sql_sampler.py +143 -0
- kumoai/experimental/rfm/base/table.py +735 -0
- kumoai/experimental/rfm/graph.py +1237 -0
- kumoai/experimental/rfm/infer/__init__.py +8 -0
- kumoai/experimental/rfm/infer/dtype.py +82 -0
- kumoai/experimental/rfm/infer/multicategorical.py +1 -1
- kumoai/experimental/rfm/infer/pkey.py +128 -0
- kumoai/experimental/rfm/infer/stype.py +35 -0
- kumoai/experimental/rfm/infer/time_col.py +61 -0
- kumoai/experimental/rfm/pquery/__init__.py +0 -4
- kumoai/experimental/rfm/pquery/executor.py +27 -27
- kumoai/experimental/rfm/pquery/pandas_executor.py +64 -40
- kumoai/experimental/rfm/relbench.py +76 -0
- kumoai/experimental/rfm/rfm.py +386 -276
- kumoai/experimental/rfm/sagemaker.py +138 -0
- kumoai/kumolib.cp311-win_amd64.pyd +0 -0
- kumoai/pquery/predictive_query.py +10 -6
- kumoai/spcs.py +1 -3
- kumoai/testing/decorators.py +1 -1
- kumoai/testing/snow.py +50 -0
- kumoai/trainer/distilled_trainer.py +175 -0
- kumoai/trainer/trainer.py +9 -10
- kumoai/utils/__init__.py +3 -2
- kumoai/utils/display.py +51 -0
- kumoai/utils/progress_logger.py +188 -16
- kumoai/utils/sql.py +3 -0
- {kumoai-2.12.0.dev202510231830.dist-info → kumoai-2.14.0.dev202512311733.dist-info}/METADATA +13 -2
- {kumoai-2.12.0.dev202510231830.dist-info → kumoai-2.14.0.dev202512311733.dist-info}/RECORD +57 -36
- kumoai/experimental/rfm/local_graph.py +0 -810
- kumoai/experimental/rfm/local_graph_sampler.py +0 -184
- kumoai/experimental/rfm/local_pquery_driver.py +0 -494
- kumoai/experimental/rfm/local_table.py +0 -545
- kumoai/experimental/rfm/pquery/backend.py +0 -136
- kumoai/experimental/rfm/pquery/pandas_backend.py +0 -478
- kumoai/experimental/rfm/utils.py +0 -344
- {kumoai-2.12.0.dev202510231830.dist-info → kumoai-2.14.0.dev202512311733.dist-info}/WHEEL +0 -0
- {kumoai-2.12.0.dev202510231830.dist-info → kumoai-2.14.0.dev202512311733.dist-info}/licenses/LICENSE +0 -0
- {kumoai-2.12.0.dev202510231830.dist-info → kumoai-2.14.0.dev202512311733.dist-info}/top_level.txt +0 -0
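The headline change in this release is the removal of the `local_graph.py`/`local_table.py`/`local_pquery_driver.py` stack in favor of a backend-aware `Graph` abstraction with pluggable samplers (`backend/local`, `backend/sqlite`, `backend/snow`). Below is a minimal sketch of the new entry point, adapted from the `KumoRFM` class docstring in the diff that follows; the table names and the query are illustrative:

```python
import pandas as pd

from kumoai.experimental.rfm import Graph, KumoRFM

# Any set of relational tables with primary/foreign keys works here:
df_users = pd.DataFrame(...)
df_items = pd.DataFrame(...)
df_orders = pd.DataFrame(...)

# Graph.from_data appears to take over the role of the removed
# LocalGraph constructor (local_graph.py is deleted in this release):
graph = Graph.from_data({
    'users': df_users,
    'items': df_items,
    'orders': df_orders,
})

rfm = KumoRFM(graph)

# Predict whether user 1 places an order within the next 30 days:
query = ("PREDICT COUNT(orders.*, 0, 30, days)>0 "
         "FOR users.user_id=1")
result = rfm.predict(query)
```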
kumoai/experimental/rfm/rfm.py CHANGED
@@ -2,35 +2,38 @@ import json
 import time
 import warnings
 from collections import defaultdict
-from collections.abc import Generator
+from collections.abc import Generator, Iterator
 from contextlib import contextmanager
 from dataclasses import dataclass, replace
-from typing import
+from typing import Any, Literal, overload

 import numpy as np
 import pandas as pd
 from kumoapi.model_plan import RunMode
-from kumoapi.pquery import QueryType
+from kumoapi.pquery import QueryType, ValidatedPredictiveQuery
+from kumoapi.pquery.AST import (
+    Aggregation,
+    Column,
+    Condition,
+    Join,
+    LogicalOperation,
+)
 from kumoapi.rfm import Context
 from kumoapi.rfm import Explanation as ExplanationConfig
 from kumoapi.rfm import (
-    PQueryDefinition,
     RFMEvaluateRequest,
+    RFMParseQueryRequest,
     RFMPredictRequest,
-    RFMValidateQueryRequest,
 )
 from kumoapi.task import TaskType
+from kumoapi.typing import AggregationType, Stype

-from kumoai import
+from kumoai.client.rfm import RFMAPI
 from kumoai.exceptions import HTTPException
-from kumoai.experimental.rfm import
-from kumoai.experimental.rfm.
-from kumoai.
-from kumoai.
-    LocalPQueryDriver,
-    date_offset_to_seconds,
-)
-from kumoai.utils import InteractiveProgressLogger, ProgressLogger
+from kumoai.experimental.rfm import Graph
+from kumoai.experimental.rfm.base import DataBackend, Sampler
+from kumoai.mixin import CastMixin
+from kumoai.utils import ProgressLogger, display

 _RANDOM_SEED = 42

@@ -60,6 +63,17 @@ _SIZE_LIMIT_MSG = ("Context size exceeds the 30MB limit. {stats}\nPlease "
                    "beyond this for your use-case.")


+@dataclass(repr=False)
+class ExplainConfig(CastMixin):
+    """Configuration for explainability.
+
+    Args:
+        skip_summary: Whether to skip generating a human-readable summary of
+            the explanation.
+    """
+    skip_summary: bool = False
+
+
 @dataclass(repr=False)
 class Explanation:
     prediction: pd.DataFrame
@@ -74,19 +88,27 @@ class Explanation:
     def __getitem__(self, index: Literal[1]) -> str:
         pass

-    def __getitem__(self, index: int) ->
+    def __getitem__(self, index: int) -> pd.DataFrame | str:
         if index == 0:
             return self.prediction
         if index == 1:
             return self.summary
         raise IndexError("Index out of range")

-    def __iter__(self) -> Iterator[
+    def __iter__(self) -> Iterator[pd.DataFrame | str]:
         return iter((self.prediction, self.summary))

     def __repr__(self) -> str:
         return str((self.prediction, self.summary))

+    def print(self) -> None:
+        r"""Prints the explanation."""
+        display.dataframe(self.prediction)
+        display.message(self.summary)
+
+    def _ipython_display_(self) -> None:
+        self.print()
+

 class KumoRFM:
     r"""The Kumo Relational Foundation model (RFM) from the `KumoRFM: A
@@ -96,17 +118,17 @@ class KumoRFM:
     :class:`KumoRFM` is a foundation model to generate predictions for any
     relational dataset without training.
     The model is pre-trained and the class provides an interface to query the
-    model from a :class:`
+    model from a :class:`Graph` object.

     .. code-block:: python

-        from kumoai.experimental.rfm import
+        from kumoai.experimental.rfm import Graph, KumoRFM

         df_users = pd.DataFrame(...)
         df_items = pd.DataFrame(...)
         df_orders = pd.DataFrame(...)

-        graph =
+        graph = Graph.from_data({
             'users': df_users,
             'items': df_items,
             'orders': df_orders,
@@ -114,47 +136,63 @@ class KumoRFM:

         rfm = KumoRFM(graph)

-        query = ("PREDICT COUNT(
-                 "FOR users.user_id=
-        result = rfm.
+        query = ("PREDICT COUNT(orders.*, 0, 30, days)>0 "
+                 "FOR users.user_id=1")
+        result = rfm.predict(query)

         print(result)  # user_id COUNT(transactions.*, 0, 30, days) > 0
                        # 1       0.85

     Args:
         graph: The graph.
-        preprocess: Whether to pre-process the data in advance during graph
-            materialization.
-            This is a runtime trade-off between graph materialization and model
-            processing speed.
-            It can be benefical to preprocess your data once and then run many
-            queries on top to achieve maximum model speed.
-            However, if activiated, graph materialization can take potentially
-            much longer, especially on graphs with many large text columns.
-            Best to tune this option manually.
         verbose: Whether to print verbose output.
+        optimize: If set to ``True``, will optimize the underlying data backend
+            for optimal querying. For example, for transactional database
+            backends, will create any missing indices. Requires write-access to
+            the data backend.
     """
     def __init__(
         self,
-        graph:
-
-
+        graph: Graph,
+        verbose: bool | ProgressLogger = True,
+        optimize: bool = False,
     ) -> None:
         graph = graph.validate()
         self._graph_def = graph._to_api_graph_definition()
-        self._graph_store = LocalGraphStore(graph, preprocess, verbose)
-        self._graph_sampler = LocalGraphSampler(self._graph_store)

-
+        if graph.backend == DataBackend.LOCAL:
+            from kumoai.experimental.rfm.backend.local import LocalSampler
+            self._sampler: Sampler = LocalSampler(graph, verbose)
+        elif graph.backend == DataBackend.SQLITE:
+            from kumoai.experimental.rfm.backend.sqlite import SQLiteSampler
+            self._sampler = SQLiteSampler(graph, verbose, optimize)
+        elif graph.backend == DataBackend.SNOWFLAKE:
+            from kumoai.experimental.rfm.backend.snow import SnowSampler
+            self._sampler = SnowSampler(graph, verbose)
+        else:
+            raise NotImplementedError
+
+        self._client: RFMAPI | None = None
+
+        self._batch_size: int | Literal['max'] | None = None
         self.num_retries: int = 0

+    @property
+    def _api_client(self) -> RFMAPI:
+        if self._client is not None:
+            return self._client
+
+        from kumoai.experimental.rfm import global_state
+        self._client = RFMAPI(global_state.client)
+        return self._client
+
     def __repr__(self) -> str:
         return f'{self.__class__.__name__}()'

     @contextmanager
     def batch_mode(
         self,
-        batch_size:
+        batch_size: int | Literal['max'] = 'max',
         num_retries: int = 1,
     ) -> Generator[None, None, None]:
         """Context manager to predict in batches.
@@ -188,17 +226,17 @@ class KumoRFM:
     def predict(
         self,
         query: str,
-        indices:
+        indices: list[str] | list[float] | list[int] | None = None,
         *,
         explain: Literal[False] = False,
-        anchor_time:
-        context_anchor_time:
-        run_mode:
-        num_neighbors:
+        anchor_time: pd.Timestamp | Literal['entity'] | None = None,
+        context_anchor_time: pd.Timestamp | None = None,
+        run_mode: RunMode | str = RunMode.FAST,
+        num_neighbors: list[int] | None = None,
         num_hops: int = 2,
-        max_pq_iterations: int =
-        random_seed:
-        verbose:
+        max_pq_iterations: int = 10,
+        random_seed: int | None = _RANDOM_SEED,
+        verbose: bool | ProgressLogger = True,
         use_prediction_time: bool = False,
     ) -> pd.DataFrame:
         pass
@@ -207,17 +245,17 @@ class KumoRFM:
     def predict(
         self,
         query: str,
-        indices:
+        indices: list[str] | list[float] | list[int] | None = None,
         *,
-        explain: Literal[True],
-        anchor_time:
-        context_anchor_time:
-        run_mode:
-        num_neighbors:
+        explain: Literal[True] | ExplainConfig | dict[str, Any],
+        anchor_time: pd.Timestamp | Literal['entity'] | None = None,
+        context_anchor_time: pd.Timestamp | None = None,
+        run_mode: RunMode | str = RunMode.FAST,
+        num_neighbors: list[int] | None = None,
         num_hops: int = 2,
-        max_pq_iterations: int =
-        random_seed:
-        verbose:
+        max_pq_iterations: int = 10,
+        random_seed: int | None = _RANDOM_SEED,
+        verbose: bool | ProgressLogger = True,
         use_prediction_time: bool = False,
     ) -> Explanation:
         pass
@@ -225,19 +263,19 @@ class KumoRFM:
     def predict(
         self,
         query: str,
-        indices:
+        indices: list[str] | list[float] | list[int] | None = None,
         *,
-        explain: bool = False,
-        anchor_time:
-        context_anchor_time:
-        run_mode:
-        num_neighbors:
+        explain: bool | ExplainConfig | dict[str, Any] = False,
+        anchor_time: pd.Timestamp | Literal['entity'] | None = None,
+        context_anchor_time: pd.Timestamp | None = None,
+        run_mode: RunMode | str = RunMode.FAST,
+        num_neighbors: list[int] | None = None,
         num_hops: int = 2,
-        max_pq_iterations: int =
-        random_seed:
-        verbose:
+        max_pq_iterations: int = 10,
+        random_seed: int | None = _RANDOM_SEED,
+        verbose: bool | ProgressLogger = True,
         use_prediction_time: bool = False,
-    ) ->
+    ) -> pd.DataFrame | Explanation:
         """Returns predictions for a predictive query.

         Args:
@@ -247,9 +285,12 @@ class KumoRFM:
                 be generated for all indices, independent of whether they
                 fulfill entity filter constraints. To pre-filter entities, use
                 :meth:`~KumoRFM.is_valid_entity`.
-            explain:
-
-
+            explain: Configuration for explainability.
+                If set to ``True``, will additionally explain the prediction.
+                Passing in an :class:`ExplainConfig` instance provides control
+                over which parts of explanation are generated.
+                Explainability is currently only supported for single entity
+                predictions with ``run_mode="FAST"``.
             anchor_time: The anchor timestamp for the prediction. If set to
                 ``None``, will use the maximum timestamp in the data.
                 If set to ``"entity"``, will use the timestamp of the entity.
@@ -273,56 +314,62 @@ class KumoRFM:

         Returns:
             The predictions as a :class:`pandas.DataFrame`.
-            If ``explain
-
+            If ``explain`` is provided, returns an :class:`Explanation` object
+            containing the prediction, summary, and details.
         """
+        explain_config: ExplainConfig | None = None
+        if explain is True:
+            explain_config = ExplainConfig()
+        elif explain is not False:
+            explain_config = ExplainConfig._cast(explain)
+
         query_def = self._parse_query(query)
+        query_str = query_def.to_string()

         if num_hops != 2 and num_neighbors is not None:
             warnings.warn(f"Received custom 'num_neighbors' option; ignoring "
                           f"custom 'num_hops={num_hops}' option")

-        if
+        if explain_config is not None and run_mode in {
+                RunMode.NORMAL, RunMode.BEST
+        }:
             warnings.warn(f"Explainability is currently only supported for "
                           f"run mode 'FAST' (got '{run_mode}'). Provided run "
                           f"mode has been reset. Please lower the run mode to "
                           f"suppress this warning.")

         if indices is None:
-            if query_def.
+            if query_def.rfm_entity_ids is None:
                 raise ValueError("Cannot find entities to predict for. Please "
                                  "pass them via `predict(query, indices=...)`")
-            indices = query_def.
+            indices = query_def.get_rfm_entity_id_list()
         else:
-            query_def = replace(
-                query_def,
-                entity=replace(query_def.entity, ids=None),
-            )
+            query_def = replace(query_def, rfm_entity_ids=None)

         if len(indices) == 0:
             raise ValueError("At least one entity is required")

-        if
+        if explain_config is not None and len(indices) > 1:
             raise ValueError(
                 f"Cannot explain predictions for more than a single entity "
                 f"(got {len(indices)})")

         query_repr = query_def.to_string(rich=True, exclude_predict=True)
-        if
+        if explain_config is not None:
             msg = f'[bold]EXPLAIN[/bold] {query_repr}'
         else:
             msg = f'[bold]PREDICT[/bold] {query_repr}'

         if not isinstance(verbose, ProgressLogger):
-            verbose =
+            verbose = ProgressLogger.default(msg=msg, verbose=verbose)

         with verbose as logger:

-            batch_size:
+            batch_size: int | None = None
             if self._batch_size == 'max':
-                task_type =
-
-                    edge_types=self.
+                task_type = self._get_task_type(
+                    query=query_def,
+                    edge_types=self._sampler.edge_types,
                 )
                 batch_size = _MAX_PRED_SIZE[task_type]
             else:
@@ -338,9 +385,9 @@ class KumoRFM:
            logger.log(f"Splitting {len(indices):,} entities into "
                       f"{len(batches):,} batches of size {batch_size:,}")

-            predictions:
-            summary:
-            details:
+            predictions: list[pd.DataFrame] = []
+            summary: str | None = None
+            details: Explanation | None = None
             for i, batch in enumerate(batches):
                 # TODO Re-use the context for subsequent predictions.
                 context = self._get_context(
@@ -359,6 +406,7 @@ class KumoRFM:
                 request = RFMPredictRequest(
                     context=context,
                     run_mode=RunMode(run_mode),
+                    query=query_str,
                     use_prediction_time=use_prediction_time,
                 )
                 with warnings.catch_warnings():
@@ -373,8 +421,7 @@ class KumoRFM:
                     stats = Context.get_memory_stats(request_msg.context)
                     raise ValueError(_SIZE_LIMIT_MSG.format(stats=stats))

-                if
-                        and len(batches) > 1):
+                if i == 0 and len(batches) > 1:
                     verbose.init_progress(
                         total=len(batches),
                         description='Predicting',
@@ -382,20 +429,23 @@ class KumoRFM:

                 for attempt in range(self.num_retries + 1):
                     try:
-                        if
-                            resp =
+                        if explain_config is not None:
+                            resp = self._api_client.explain(
+                                request=_bytes,
+                                skip_summary=explain_config.skip_summary,
+                            )
                             summary = resp.summary
                             details = resp.details
                         else:
-                            resp =
+                            resp = self._api_client.predict(_bytes)
                         df = pd.DataFrame(**resp.prediction)

                         # Cast 'ENTITY' to correct data type:
                         if 'ENTITY' in df:
-
-
-
-
+                            table_dict = context.subgraph.table_dict
+                            table = table_dict[query_def.entity_table]
+                            ser = table.df[table.primary_key]
+                            df['ENTITY'] = df['ENTITY'].astype(ser.dtype)

                         # Cast 'ANCHOR_TIMESTAMP' to correct data type:
                         if 'ANCHOR_TIMESTAMP' in df:
@@ -410,8 +460,7 @@ class KumoRFM:

                         predictions.append(df)

-                        if (
-                                and len(batches) > 1):
+                        if len(batches) > 1:
                             verbose.step()

                         break
@@ -434,7 +483,7 @@ class KumoRFM:
         else:
             prediction = pd.concat(predictions, ignore_index=True)

-        if
+        if explain_config is not None:
             assert len(predictions) == 1
             assert summary is not None
             assert details is not None
@@ -449,9 +498,9 @@ class KumoRFM:
     def is_valid_entity(
         self,
         query: str,
-        indices:
+        indices: list[str] | list[float] | list[int] | None = None,
         *,
-        anchor_time:
+        anchor_time: pd.Timestamp | Literal['entity'] | None = None,
     ) -> np.ndarray:
         r"""Returns a mask that denotes which entities are valid for the
         given predictive query, *i.e.*, which entities fulfill (temporal)
@@ -468,48 +517,42 @@ class KumoRFM:
         query_def = self._parse_query(query)

         if indices is None:
-            if query_def.
+            if query_def.rfm_entity_ids is None:
                 raise ValueError("Cannot find entities to predict for. Please "
                                  "pass them via "
                                  "`is_valid_entity(query, indices=...)`")
-            indices = query_def.
+            indices = query_def.get_rfm_entity_id_list()

         if len(indices) == 0:
             raise ValueError("At least one entity is required")

         if anchor_time is None:
-            anchor_time = self.
+            anchor_time = self._get_default_anchor_time(query_def)

         if isinstance(anchor_time, pd.Timestamp):
             self._validate_time(query_def, anchor_time, None, False)
         else:
             assert anchor_time == 'entity'
-            if
-                    not in self._graph_store.time_dict):
+            if query_def.entity_table not in self._sampler.time_column_dict:
                 raise ValueError(f"Anchor time 'entity' requires the entity "
-                                 f"table '{query_def.
-                                 f"to have a time column")
+                                 f"table '{query_def.entity_table}' "
+                                 f"to have a time column.")

-
-            table_name=query_def.entity.pkey.table_name,
-            pkey=pd.Series(indices),
-        )
-        query_driver = LocalPQueryDriver(self._graph_store, query_def)
-        return query_driver.is_valid(node, anchor_time)
+        raise NotImplementedError

     def evaluate(
         self,
         query: str,
         *,
-        metrics:
-        anchor_time:
-        context_anchor_time:
-        run_mode:
-        num_neighbors:
+        metrics: list[str] | None = None,
+        anchor_time: pd.Timestamp | Literal['entity'] | None = None,
+        context_anchor_time: pd.Timestamp | None = None,
+        run_mode: RunMode | str = RunMode.FAST,
+        num_neighbors: list[int] | None = None,
         num_hops: int = 2,
-        max_pq_iterations: int =
-        random_seed:
-        verbose:
+        max_pq_iterations: int = 10,
+        random_seed: int | None = _RANDOM_SEED,
+        verbose: bool | ProgressLogger = True,
         use_prediction_time: bool = False,
     ) -> pd.DataFrame:
         """Evaluates a predictive query.
@@ -547,17 +590,17 @@ class KumoRFM:
             warnings.warn(f"Received custom 'num_neighbors' option; ignoring "
                           f"custom 'num_hops={num_hops}' option")

-        if query_def.
+        if query_def.rfm_entity_ids is not None:
             query_def = replace(
                 query_def,
-
+                rfm_entity_ids=None,
             )

         query_repr = query_def.to_string(rich=True, exclude_predict=True)
         msg = f'[bold]EVALUATE[/bold] {query_repr}'

         if not isinstance(verbose, ProgressLogger):
-            verbose =
+            verbose = ProgressLogger.default(msg=msg, verbose=verbose)

         with verbose as logger:
             context = self._get_context(
@@ -591,10 +634,10 @@ class KumoRFM:

             if len(request_bytes) > _MAX_SIZE:
                 stats_msg = Context.get_memory_stats(request_msg.context)
-                raise ValueError(_SIZE_LIMIT_MSG.format(
+                raise ValueError(_SIZE_LIMIT_MSG.format(stats=stats_msg))

             try:
-                resp =
+                resp = self._api_client.evaluate(request_bytes)
             except HTTPException as e:
                 try:
                     msg = json.loads(e.detail)['detail']
@@ -616,9 +659,9 @@ class KumoRFM:
         query: str,
         size: int,
         *,
-        anchor_time:
-        random_seed:
-        max_iterations: int =
+        anchor_time: pd.Timestamp | Literal['entity'] | None = None,
+        random_seed: int | None = _RANDOM_SEED,
+        max_iterations: int = 10,
     ) -> pd.DataFrame:
         """Returns the labels of a predictive query for a specified anchor
         time.
@@ -638,45 +681,43 @@ class KumoRFM:
         query_def = self._parse_query(query)

         if anchor_time is None:
-            anchor_time = self.
-
-
+            anchor_time = self._get_default_anchor_time(query_def)
+            if query_def.target_ast.date_offset_range is not None:
+                offset = query_def.target_ast.date_offset_range.end_date_offset
+                offset *= query_def.num_forecasts
+                anchor_time -= offset

         assert anchor_time is not None
         if isinstance(anchor_time, pd.Timestamp):
             self._validate_time(query_def, anchor_time, None, evaluate=True)
         else:
             assert anchor_time == 'entity'
-            if
-                    not in self._graph_store.time_dict):
+            if query_def.entity_table not in self._sampler.time_column_dict:
                 raise ValueError(f"Anchor time 'entity' requires the entity "
-                                 f"table '{query_def.
+                                 f"table '{query_def.entity_table}' "
                                  f"to have a time column")

-
-
-
-
-
-
-
-
-
+        train, test = self._sampler.sample_target(
+            query=query_def,
+            num_train_examples=0,
+            train_anchor_time=anchor_time,
+            num_train_trials=0,
+            num_test_examples=size,
+            test_anchor_time=anchor_time,
+            num_test_trials=max_iterations * size,
+            random_seed=random_seed,
         )

-        entity = self._graph_store.pkey_map_dict[
-            query_def.entity.pkey.table_name].index[node]
-
         return pd.DataFrame({
-            'ENTITY':
-            'ANCHOR_TIMESTAMP':
-            'TARGET':
+            'ENTITY': test.entity_pkey,
+            'ANCHOR_TIMESTAMP': test.anchor_time,
+            'TARGET': test.target,
         })

     # Helpers #################################################################

-    def _parse_query(self, query: str) ->
-        if isinstance(query,
+    def _parse_query(self, query: str) -> ValidatedPredictiveQuery:
+        if isinstance(query, ValidatedPredictiveQuery):
             return query

         if isinstance(query, str) and query.strip()[:9].lower() == 'evaluate ':
@@ -686,13 +727,12 @@ class KumoRFM:
                 "predictions or evaluations.")

         try:
-            request =
+            request = RFMParseQueryRequest(
                 query=query,
                 graph_definition=self._graph_def,
             )

-            resp =
-            # TODO Expose validation warnings.
+            resp = self._api_client.parse_query(request)

             if len(resp.validation_response.warnings) > 0:
                 msg = '\n'.join([
@@ -702,7 +742,7 @@ class KumoRFM:
                 warnings.warn(f"Encountered the following warnings during "
                               f"parsing:\n{msg}")

-            return resp.
+            return resp.query
         except HTTPException as e:
             try:
                 msg = json.loads(e.detail)['detail']
@@ -711,30 +751,91 @@ class KumoRFM:
             raise ValueError(f"Failed to parse query '{query}'. "
                              f"{msg}") from None

+    @staticmethod
+    def _get_task_type(
+        query: ValidatedPredictiveQuery,
+        edge_types: list[tuple[str, str, str]],
+    ) -> TaskType:
+        if isinstance(query.target_ast, (Condition, LogicalOperation)):
+            return TaskType.BINARY_CLASSIFICATION
+
+        target = query.target_ast
+        if isinstance(target, Join):
+            target = target.rhs_target
+        if isinstance(target, Aggregation):
+            if target.aggr == AggregationType.LIST_DISTINCT:
+                table_name, col_name = target._get_target_column_name().split(
+                    '.')
+                target_edge_types = [
+                    edge_type for edge_type in edge_types
+                    if edge_type[0] == table_name and edge_type[1] == col_name
+                ]
+                if len(target_edge_types) != 1:
+                    raise NotImplementedError(
+                        f"Multilabel-classification queries based on "
+                        f"'LIST_DISTINCT' are not supported yet. If you "
+                        f"planned to write a link prediction query instead, "
+                        f"make sure to register '{col_name}' as a "
+                        f"foreign key.")
+                return TaskType.TEMPORAL_LINK_PREDICTION
+
+            return TaskType.REGRESSION
+
+        assert isinstance(target, Column)
+
+        if target.stype in {Stype.ID, Stype.categorical}:
+            return TaskType.MULTICLASS_CLASSIFICATION
+
+        if target.stype in {Stype.numerical}:
+            return TaskType.REGRESSION
+
+        raise NotImplementedError("Task type not yet supported")
+
+    def _get_default_anchor_time(
+        self,
+        query: ValidatedPredictiveQuery,
+    ) -> pd.Timestamp:
+        if query.query_type == QueryType.TEMPORAL:
+            aggr_table_names = [
+                aggr._get_target_column_name().split('.')[0]
+                for aggr in query.get_all_target_aggregations()
+            ]
+            return self._sampler.get_max_time(aggr_table_names)
+
+        assert query.query_type == QueryType.STATIC
+        return self._sampler.get_max_time()
+
     def _validate_time(
         self,
-        query:
+        query: ValidatedPredictiveQuery,
         anchor_time: pd.Timestamp,
-        context_anchor_time:
+        context_anchor_time: pd.Timestamp | None,
         evaluate: bool,
     ) -> None:

-        if self.
+        if len(self._sampler.time_column_dict) == 0:
             return  # Graph without timestamps

-
+        min_time = self._sampler.get_min_time()
+        max_time = self._sampler.get_max_time()
+
+        if anchor_time < min_time:
             raise ValueError(f"Anchor timestamp '{anchor_time}' is before "
-                             f"the earliest timestamp "
-                             f"
+                             f"the earliest timestamp '{min_time}' in the "
+                             f"data.")

-        if
-                and context_anchor_time < self._graph_store.min_time):
+        if context_anchor_time is not None and context_anchor_time < min_time:
             raise ValueError(f"Context anchor timestamp is too early or "
                              f"aggregation time range is too large. To make "
                              f"this prediction, we would need data back to "
                              f"'{context_anchor_time}', however, your data "
-                             f"only contains data back to "
-
+                             f"only contains data back to '{min_time}'.")
+
+        if query.target_ast.date_offset_range is not None:
+            end_offset = query.target_ast.date_offset_range.end_date_offset
+        else:
+            end_offset = pd.DateOffset(0)
+        end_offset = end_offset * query.num_forecasts

         if (context_anchor_time is not None
                 and context_anchor_time > anchor_time):
@@ -744,51 +845,46 @@ class KumoRFM:
                           f"(got '{anchor_time}'). Please make sure this is "
                           f"intended.")
         elif (query.query_type == QueryType.TEMPORAL
-                and context_anchor_time is not None
-
+              and context_anchor_time is not None
+              and context_anchor_time + end_offset > anchor_time):
             warnings.warn(f"Aggregation for context examples at timestamp "
                           f"'{context_anchor_time}' will leak information "
                           f"from the prediction anchor timestamp "
                           f"'{anchor_time}'. Please make sure this is "
                           f"intended.")

-        elif (context_anchor_time is not None
-
-
-            _time = context_anchor_time - (query.target.end_offset *
-                                           query.num_forecasts)
+        elif (context_anchor_time is not None
+              and context_anchor_time - end_offset < min_time):
+            _time = context_anchor_time - end_offset
             warnings.warn(f"Context anchor timestamp is too early or "
                           f"aggregation time range is too large. To form "
                           f"proper input data, we would need data back to "
                           f"'{_time}', however, your data only contains "
-                          f"data back to '{
+                          f"data back to '{min_time}'.")

-        if
-                > self._graph_store.max_time + pd.DateOffset(days=1)):
+        if not evaluate and anchor_time > max_time + pd.DateOffset(days=1):
             warnings.warn(f"Anchor timestamp '{anchor_time}' is after the "
-                          f"latest timestamp '{
-                          f"
+                          f"latest timestamp '{max_time}' in the data. Please "
+                          f"make sure this is intended.")

-
-            query.target.end_offset * query.num_forecasts)
-        if evaluate and anchor_time > max_eval_time:
+        if evaluate and anchor_time > max_time - end_offset:
             raise ValueError(
                 f"Anchor timestamp for evaluation is after the latest "
-                f"supported timestamp '{
+                f"supported timestamp '{max_time - end_offset}'.")

     def _get_context(
         self,
-        query:
-        indices:
-        anchor_time:
-        context_anchor_time:
+        query: ValidatedPredictiveQuery,
+        indices: list[str] | list[float] | list[int] | None,
+        anchor_time: pd.Timestamp | Literal['entity'] | None,
+        context_anchor_time: pd.Timestamp | None,
         run_mode: RunMode,
-        num_neighbors:
+        num_neighbors: list[int] | None,
         num_hops: int,
         max_pq_iterations: int,
         evaluate: bool,
-        random_seed:
-        logger:
+        random_seed: int | None = _RANDOM_SEED,
+        logger: ProgressLogger | None = None,
     ) -> Context:

         if num_neighbors is not None:
@@ -805,10 +901,9 @@ class KumoRFM:
                              f"'https://github.com/kumo-ai/kumo-rfm' if you "
                              f"must go beyond this for your use-case.")

-
-
-
-            edge_types=self._graph_store.edge_types,
+        task_type = self._get_task_type(
+            query=query,
+            edge_types=self._sampler.edge_types,
         )

         if logger is not None:
@@ -839,11 +934,18 @@ class KumoRFM:
         else:
             num_neighbors = [64, 64, 8, 8, 4, 4][:num_hops]

+        if query.target_ast.date_offset_range is None:
+            step_offset = pd.DateOffset(0)
+        else:
+            step_offset = query.target_ast.date_offset_range.end_date_offset
+        end_offset = step_offset * query.num_forecasts
+
         if anchor_time is None:
-            anchor_time = self.
+            anchor_time = self._get_default_anchor_time(query)
+
             if evaluate:
-                anchor_time = anchor_time -
-
+                anchor_time = anchor_time - end_offset
+
         if logger is not None:
             assert isinstance(anchor_time, pd.Timestamp)
             if anchor_time == pd.Timestamp.min:
@@ -857,58 +959,71 @@ class KumoRFM:

         assert anchor_time is not None
         if isinstance(anchor_time, pd.Timestamp):
+            if context_anchor_time == 'entity':
+                raise ValueError("Anchor time 'entity' needs to be shared "
+                                 "for context and prediction examples")
             if context_anchor_time is None:
-                context_anchor_time = anchor_time -
-                    query.num_forecasts)
+                context_anchor_time = anchor_time - end_offset
             self._validate_time(query, anchor_time, context_anchor_time,
                                 evaluate)
         else:
             assert anchor_time == 'entity'
-            if query.
+            if query.query_type != QueryType.STATIC:
+                raise ValueError("Anchor time 'entity' is only valid for "
+                                 "static predictive queries")
+            if query.entity_table not in self._sampler.time_column_dict:
                 raise ValueError(f"Anchor time 'entity' requires the entity "
-                                 f"table '{query.
+                                 f"table '{query.entity_table}' to "
                                  f"have a time column")
-            if context_anchor_time
-
-
-                context_anchor_time =
+            if isinstance(context_anchor_time, pd.Timestamp):
+                raise ValueError("Anchor time 'entity' needs to be shared "
+                                 "for context and prediction examples")
+            context_anchor_time = 'entity'

-
+        num_train_examples = _MAX_CONTEXT_SIZE[run_mode]
         if evaluate:
-
+            num_test_examples = _MAX_TEST_SIZE[run_mode]
             if task_type.is_link_pred:
-
+                num_test_examples = num_test_examples // 5
+        else:
+            num_test_examples = 0
+
+        train, test = self._sampler.sample_target(
+            query=query,
+            num_train_examples=num_train_examples,
+            train_anchor_time=context_anchor_time,
+            num_train_trials=max_pq_iterations * num_train_examples,
+            num_test_examples=num_test_examples,
+            test_anchor_time=anchor_time,
+            num_test_trials=max_pq_iterations * num_test_examples,
+            random_seed=random_seed,
+        )
+        train_pkey, train_time, y_train = train
+        test_pkey, test_time, y_test = test

-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-            num_rhs = y_test.explode().nunique()
-            msg = (f"Collected {len(y_test):,} test examples with "
-                   f"{num_rhs:,} unique items")
-        else:
-            raise NotImplementedError
-        logger.log(msg)
+        if evaluate and logger is not None:
+            if task_type == TaskType.BINARY_CLASSIFICATION:
+                pos = 100 * int((y_test > 0).sum()) / len(y_test)
+                msg = (f"Collected {len(y_test):,} test examples with "
+                       f"{pos:.2f}% positive cases")
+            elif task_type == TaskType.MULTICLASS_CLASSIFICATION:
+                msg = (f"Collected {len(y_test):,} test examples holding "
+                       f"{y_test.nunique()} classes")
+            elif task_type == TaskType.REGRESSION:
+                _min, _max = float(y_test.min()), float(y_test.max())
+                msg = (f"Collected {len(y_test):,} test examples with targets "
+                       f"between {format_value(_min)} and "
+                       f"{format_value(_max)}")
+            elif task_type == TaskType.TEMPORAL_LINK_PREDICTION:
+                num_rhs = y_test.explode().nunique()
+                msg = (f"Collected {len(y_test):,} test examples with "
+                       f"{num_rhs:,} unique items")
+            else:
+                raise NotImplementedError
+            logger.log(msg)

-
+        if not evaluate:
             assert indices is not None
-
             if len(indices) > _MAX_PRED_SIZE[task_type]:
                 raise ValueError(f"Cannot predict for more than "
                                  f"{_MAX_PRED_SIZE[task_type]:,} entities at "
@@ -916,27 +1031,12 @@ class KumoRFM:
                                 f"`KumoRFM.batch_mode` to process entities "
                                 f"in batches")

-
-                table_name=query.entity.pkey.table_name,
-                pkey=pd.Series(indices),
-            )
-
+            test_pkey = pd.Series(indices, dtype=train_pkey.dtype)
             if isinstance(anchor_time, pd.Timestamp):
-                test_time = pd.Series(anchor_time).repeat(
-                    len(
+                test_time = pd.Series([anchor_time]).repeat(
+                    len(indices)).reset_index(drop=True)
             else:
-
-                    query.entity.pkey.table_name]
-                time = time[test_node] * 1000**3
-                test_time = pd.Series(time, dtype='datetime64[ns]')
-
-            train_node, train_time, y_train = query_driver.collect_train(
-                size=_MAX_CONTEXT_SIZE[run_mode],
-                anchor_time=context_anchor_time or 'entity',
-                exclude_node=test_node if (query.query_type == QueryType.STATIC
-                                           or anchor_time == 'entity') else None,
-                max_iterations=max_pq_iterations,
-            )
+                train_time = test_time = 'entity'

         if logger is not None:
             if task_type == TaskType.BINARY_CLASSIFICATION:
@@ -959,27 +1059,41 @@ class KumoRFM:
                 raise NotImplementedError
             logger.log(msg)

-        entity_table_names
-
+        entity_table_names: tuple[str, ...]
+        if task_type.is_link_pred:
+            final_aggr = query.get_final_target_aggregation()
+            assert final_aggr is not None
+            edge_fkey = final_aggr._get_target_column_name()
+            for edge_type in self._sampler.edge_types:
+                if edge_fkey == f'{edge_type[0]}.{edge_type[1]}':
+                    entity_table_names = (
+                        query.entity_table,
+                        edge_type[2],
+                    )
+        else:
+            entity_table_names = (query.entity_table, )

         # Exclude the entity anchor time from the feature set to prevent
         # running out-of-distribution between in-context and test examples:
-        exclude_cols_dict = query.
-        if
+        exclude_cols_dict = query.get_exclude_cols_dict()
+        if entity_table_names[0] in self._sampler.time_column_dict:
            if entity_table_names[0] not in exclude_cols_dict:
                exclude_cols_dict[entity_table_names[0]] = []
-
-            time_column = time_column_dict[entity_table_names[0]]
+            time_column = self._sampler.time_column_dict[entity_table_names[0]]
            exclude_cols_dict[entity_table_names[0]].append(time_column)

-        subgraph = self.
+        subgraph = self._sampler.sample_subgraph(
            entity_table_names=entity_table_names,
-
-
-
-
-
-
+            entity_pkey=pd.concat(
+                [train_pkey, test_pkey],
+                axis=0,
+                ignore_index=True,
+            ),
+            anchor_time=pd.concat(
+                [train_time, test_time],
+                axis=0,
+                ignore_index=True,
+            ) if isinstance(train_time, pd.Series) else 'entity',
            num_neighbors=num_neighbors,
            exclude_cols_dict=exclude_cols_dict,
        )
@@ -991,23 +1105,19 @@ class KumoRFM:
                             f"'https://github.com/kumo-ai/kumo-rfm' if you "
                             f"must go beyond this for your use-case.")

-        step_size: Optional[int] = None
-        if query.query_type == QueryType.TEMPORAL:
-            step_size = date_offset_to_seconds(query.target.end_offset)
-
         return Context(
            task_type=task_type,
            entity_table_names=entity_table_names,
            subgraph=subgraph,
            y_train=y_train,
-            y_test=y_test,
+            y_test=y_test if evaluate else None,
            top_k=query.top_k,
-            step_size=
+            step_size=None,
        )

    @staticmethod
    def _validate_metrics(
-        metrics:
+        metrics: list[str],
        task_type: TaskType,
    ) -> None:

@@ -1064,7 +1174,7 @@ class KumoRFM:
            f"'https://github.com/kumo-ai/kumo-rfm'.")


-def format_value(value:
+def format_value(value: int | float) -> str:
    if value == int(value):
        return f'{int(value):,}'
    if abs(value) >= 1000: