kumoai 2.13.0.dev202511191731__cp310-cp310-macosx_11_0_arm64.whl → 2.14.0rc2__cp310-cp310-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kumoai/__init__.py +35 -26
- kumoai/_version.py +1 -1
- kumoai/client/client.py +6 -0
- kumoai/client/jobs.py +26 -0
- kumoai/client/pquery.py +6 -2
- kumoai/connector/utils.py +44 -9
- kumoai/experimental/rfm/__init__.py +70 -68
- kumoai/experimental/rfm/authenticate.py +3 -4
- kumoai/experimental/rfm/backend/__init__.py +0 -0
- kumoai/experimental/rfm/backend/local/__init__.py +42 -0
- kumoai/experimental/rfm/{local_graph_store.py → backend/local/graph_store.py} +65 -127
- kumoai/experimental/rfm/backend/local/sampler.py +312 -0
- kumoai/experimental/rfm/backend/local/table.py +113 -0
- kumoai/experimental/rfm/backend/snow/__init__.py +37 -0
- kumoai/experimental/rfm/backend/snow/sampler.py +366 -0
- kumoai/experimental/rfm/backend/snow/table.py +242 -0
- kumoai/experimental/rfm/backend/sqlite/__init__.py +32 -0
- kumoai/experimental/rfm/backend/sqlite/sampler.py +454 -0
- kumoai/experimental/rfm/backend/sqlite/table.py +184 -0
- kumoai/experimental/rfm/base/__init__.py +30 -0
- kumoai/experimental/rfm/base/column.py +152 -0
- kumoai/experimental/rfm/base/expression.py +44 -0
- kumoai/experimental/rfm/base/mapper.py +67 -0
- kumoai/experimental/rfm/base/sampler.py +782 -0
- kumoai/experimental/rfm/base/source.py +19 -0
- kumoai/experimental/rfm/base/sql_sampler.py +366 -0
- kumoai/experimental/rfm/base/table.py +741 -0
- kumoai/experimental/rfm/{local_graph.py → graph.py} +581 -154
- kumoai/experimental/rfm/infer/__init__.py +8 -0
- kumoai/experimental/rfm/infer/dtype.py +82 -0
- kumoai/experimental/rfm/infer/multicategorical.py +1 -1
- kumoai/experimental/rfm/infer/pkey.py +128 -0
- kumoai/experimental/rfm/infer/stype.py +35 -0
- kumoai/experimental/rfm/infer/time_col.py +61 -0
- kumoai/experimental/rfm/pquery/executor.py +27 -27
- kumoai/experimental/rfm/pquery/pandas_executor.py +30 -32
- kumoai/experimental/rfm/relbench.py +76 -0
- kumoai/experimental/rfm/rfm.py +775 -481
- kumoai/experimental/rfm/sagemaker.py +15 -7
- kumoai/experimental/rfm/task_table.py +292 -0
- kumoai/pquery/predictive_query.py +10 -6
- kumoai/pquery/training_table.py +16 -2
- kumoai/testing/decorators.py +1 -1
- kumoai/testing/snow.py +50 -0
- kumoai/trainer/distilled_trainer.py +175 -0
- kumoai/utils/__init__.py +3 -2
- kumoai/utils/display.py +87 -0
- kumoai/utils/progress_logger.py +190 -12
- kumoai/utils/sql.py +3 -0
- {kumoai-2.13.0.dev202511191731.dist-info → kumoai-2.14.0rc2.dist-info}/METADATA +10 -8
- {kumoai-2.13.0.dev202511191731.dist-info → kumoai-2.14.0rc2.dist-info}/RECORD +54 -30
- kumoai/experimental/rfm/local_graph_sampler.py +0 -182
- kumoai/experimental/rfm/local_pquery_driver.py +0 -689
- kumoai/experimental/rfm/local_table.py +0 -545
- kumoai/experimental/rfm/utils.py +0 -344
- {kumoai-2.13.0.dev202511191731.dist-info → kumoai-2.14.0rc2.dist-info}/WHEEL +0 -0
- {kumoai-2.13.0.dev202511191731.dist-info → kumoai-2.14.0rc2.dist-info}/licenses/LICENSE +0 -0
- {kumoai-2.13.0.dev202511191731.dist-info → kumoai-2.14.0rc2.dist-info}/top_level.txt +0 -0
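
The headline change in this release is a backend-agnostic rewrite of the experimental RFM module: `local_graph.py`/`local_table.py` give way to `graph.py`, a `base/` abstraction layer, pluggable `local`, `sqlite`, and `snow` (Snowflake) samplers, and a new `TaskTable` for custom task specifications. A minimal sketch of the new entry point, following the updated `KumoRFM` docstring in the diff below (the DataFrame contents are placeholders):

    import pandas as pd

    from kumoai.experimental.rfm import Graph, KumoRFM

    # Placeholder data; real tables carry primary/foreign keys as usual.
    df_users = pd.DataFrame({'user_id': [1, 2, 3]})
    df_orders = pd.DataFrame({'order_id': [10, 11], 'user_id': [1, 2]})

    # `Graph.from_data` replaces the removed `LocalGraph` constructor.
    graph = Graph.from_data({'users': df_users, 'orders': df_orders})
    model = KumoRFM(graph)
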
kumoai/experimental/rfm/rfm.py
CHANGED
@@ -1,26 +1,23 @@
 import json
+import math
 import time
 import warnings
 from collections import defaultdict
-from collections.abc import Generator
+from collections.abc import Generator, Iterator
 from contextlib import contextmanager
 from dataclasses import dataclass, replace
-from typing import (
-    Any,
-    Dict,
-    Iterator,
-    List,
-    Literal,
-    Optional,
-    Tuple,
-    Union,
-    overload,
-)
+from typing import Any, Literal, overload
 
-import numpy as np
 import pandas as pd
 from kumoapi.model_plan import RunMode
 from kumoapi.pquery import QueryType, ValidatedPredictiveQuery
+from kumoapi.pquery.AST import (
+    Aggregation,
+    Column,
+    Condition,
+    Join,
+    LogicalOperation,
+)
 from kumoapi.rfm import Context
 from kumoapi.rfm import Explanation as ExplanationConfig
 from kumoapi.rfm import (
@@ -29,35 +26,38 @@ from kumoapi.rfm import (
     RFMPredictRequest,
 )
 from kumoapi.task import TaskType
+from kumoapi.typing import AggregationType, Stype
+from rich.console import Console
+from rich.markdown import Markdown
 
+from kumoai import in_notebook
 from kumoai.client.rfm import RFMAPI
 from kumoai.exceptions import HTTPException
-from kumoai.experimental.rfm import
-from kumoai.experimental.rfm.
-from kumoai.experimental.rfm.local_graph_store import LocalGraphStore
-from kumoai.experimental.rfm.local_pquery_driver import (
-    LocalPQueryDriver,
-    date_offset_to_seconds,
-)
+from kumoai.experimental.rfm import Graph, TaskTable
+from kumoai.experimental.rfm.base import DataBackend, Sampler
 from kumoai.mixin import CastMixin
-from kumoai.utils import
+from kumoai.utils import ProgressLogger, display
 
 _RANDOM_SEED = 42
 
 _MAX_PRED_SIZE: dict[TaskType, int] = defaultdict(lambda: 1_000)
 _MAX_PRED_SIZE[TaskType.TEMPORAL_LINK_PREDICTION] = 200
 
+_MAX_TEST_SIZE: dict[TaskType, int] = defaultdict(lambda: 2_000)
+_MAX_TEST_SIZE[TaskType.TEMPORAL_LINK_PREDICTION] = 400
+
 _MAX_CONTEXT_SIZE = {
     RunMode.DEBUG: 100,
     RunMode.FAST: 1_000,
     RunMode.NORMAL: 5_000,
     RunMode.BEST: 10_000,
 }
-
-
-    RunMode.
-    RunMode.
-    RunMode.
+
+_DEFAULT_NUM_NEIGHBORS = {
+    RunMode.DEBUG: [16, 16, 4, 4, 1, 1],
+    RunMode.FAST: [32, 32, 8, 8, 4, 4],
+    RunMode.NORMAL: [64, 64, 8, 8, 4, 4],
+    RunMode.BEST: [64, 64, 8, 8, 4, 4],
 }
 
 _MAX_SIZE = 30 * 1024 * 1024
@@ -95,24 +95,36 @@ class Explanation:
     def __getitem__(self, index: Literal[1]) -> str:
         pass
 
-    def __getitem__(self, index: int) ->
+    def __getitem__(self, index: int) -> pd.DataFrame | str:
         if index == 0:
             return self.prediction
         if index == 1:
             return self.summary
         raise IndexError("Index out of range")
 
-    def __iter__(self) -> Iterator[
+    def __iter__(self) -> Iterator[pd.DataFrame | str]:
         return iter((self.prediction, self.summary))
 
     def __repr__(self) -> str:
         return str((self.prediction, self.summary))
 
-    def
-
+    def __str__(self) -> str:
+        console = Console(soft_wrap=True)
+        with console.capture() as cap:
+            console.print(display.to_rich_table(self.prediction))
+            console.print(Markdown(self.summary))
+        return cap.get()[:-1]
+
+    def print(self) -> None:
+        r"""Prints the explanation."""
+        if in_notebook():
+            display.dataframe(self.prediction)
+            display.message(self.summary)
+        else:
+            print(self)
 
-
-
+    def _ipython_display_(self) -> None:
+        self.print()
 
 
 class KumoRFM:
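
`Explanation` gains rich-console rendering on top of its existing pair-like protocol. Since `__iter__` yields `(prediction, summary)` and `__getitem__` mirrors it, both access styles below are equivalent (a sketch, assuming `model` and `query` as in the surrounding examples):

    result = model.predict(query, indices=[42], explain=True)

    prediction, summary = result     # tuple-style unpacking via __iter__
    assert result[0] is prediction   # index access via __getitem__

    result.print()  # rich table + Markdown summary; auto-used in notebooks
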
@@ -123,17 +135,17 @@ class KumoRFM:
     :class:`KumoRFM` is a foundation model to generate predictions for any
     relational dataset without training.
     The model is pre-trained and the class provides an interface to query the
-    model from a :class:`
+    model from a :class:`Graph` object.
 
     .. code-block:: python
 
-        from kumoai.experimental.rfm import
+        from kumoai.experimental.rfm import Graph, KumoRFM
 
         df_users = pd.DataFrame(...)
         df_items = pd.DataFrame(...)
         df_orders = pd.DataFrame(...)
 
-        graph =
+        graph = Graph.from_data({
             'users': df_users,
             'items': df_items,
             'orders': df_orders,
@@ -150,40 +162,78 @@ class KumoRFM:
 
     Args:
         graph: The graph.
-        preprocess: Whether to pre-process the data in advance during graph
-            materialization.
-            This is a runtime trade-off between graph materialization and model
-            processing speed.
-            It can be benefical to preprocess your data once and then run many
-            queries on top to achieve maximum model speed.
-            However, if activiated, graph materialization can take potentially
-            much longer, especially on graphs with many large text columns.
-            Best to tune this option manually.
         verbose: Whether to print verbose output.
+        optimize: If set to ``True``, will optimize the underlying data backend
+            for optimal querying. For example, for transactional database
+            backends, will create any missing indices. Requires write-access to
+            the data backend.
     """
     def __init__(
         self,
-        graph:
-
-
+        graph: Graph,
+        verbose: bool | ProgressLogger = True,
+        optimize: bool = False,
     ) -> None:
         graph = graph.validate()
         self._graph_def = graph._to_api_graph_definition()
-        self._graph_store = LocalGraphStore(graph, preprocess, verbose)
-        self._graph_sampler = LocalGraphSampler(self._graph_store)
 
-
-
+        if graph.backend == DataBackend.LOCAL:
+            from kumoai.experimental.rfm.backend.local import LocalSampler
+            self._sampler: Sampler = LocalSampler(graph, verbose)
+        elif graph.backend == DataBackend.SQLITE:
+            from kumoai.experimental.rfm.backend.sqlite import SQLiteSampler
+            self._sampler = SQLiteSampler(graph, verbose, optimize)
+        elif graph.backend == DataBackend.SNOWFLAKE:
+            from kumoai.experimental.rfm.backend.snow import SnowSampler
+            self._sampler = SnowSampler(graph, verbose)
+        else:
+            raise NotImplementedError
+
+        self._client: RFMAPI | None = None
+
+        self._batch_size: int | Literal['max'] | None = None
+        self._num_retries: int = 0
+
+    @property
+    def _api_client(self) -> RFMAPI:
+        if self._client is not None:
+            return self._client
+
         from kumoai.experimental.rfm import global_state
-        self.
+        self._client = RFMAPI(global_state.client)
+        return self._client
 
     def __repr__(self) -> str:
         return f'{self.__class__.__name__}()'
 
+    @contextmanager
+    def retry(
+        self,
+        num_retries: int = 1,
+    ) -> Generator[None, None, None]:
+        """Context manager to retry failed queries due to unexpected server
+        issues.
+
+        .. code-block:: python
+
+            with model.retry(num_retries=1):
+                df = model.predict(query, indices=...)
+
+        Args:
+            num_retries: The maximum number of retries.
+        """
+        if num_retries < 0:
+            raise ValueError(f"'num_retries' must be greater than or equal to "
+                             f"zero (got {num_retries})")
+
+        self._num_retries = num_retries
+        yield
+        self._num_retries = 0
+
     @contextmanager
     def batch_mode(
         self,
-        batch_size:
+        batch_size: int | Literal['max'] = 'max',
         num_retries: int = 1,
     ) -> Generator[None, None, None]:
         """Context manager to predict in batches.
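
Construction now dispatches on `graph.backend` instead of always materializing a `LocalGraphStore`, and retries become a first-class, reusable context manager. A usage sketch based on the signatures above (`graph` and `query` as in the class docstring):

    # 'optimize' is only forwarded to the SQLite sampler in the dispatch
    # above; it may create missing indices and thus needs write access.
    model = KumoRFM(graph, optimize=True)

    # Retry transient server failures without entering batch mode:
    with model.retry(num_retries=2):
        df = model.predict(query, indices=[1, 2, 3])
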
@@ -203,31 +253,26 @@ class KumoRFM:
             raise ValueError(f"'batch_size' must be greater than zero "
                              f"(got {batch_size})")
 
-        if num_retries < 0:
-            raise ValueError(f"'num_retries' must be greater than or equal to "
-                             f"zero (got {num_retries})")
-
         self._batch_size = batch_size
-        self.
-
+        with self.retry(self._num_retries or num_retries):
+            yield
         self._batch_size = None
-        self.num_retries = 0
 
     @overload
     def predict(
         self,
         query: str,
-        indices:
+        indices: list[str] | list[float] | list[int] | None = None,
         *,
         explain: Literal[False] = False,
-        anchor_time:
-        context_anchor_time:
-        run_mode:
-        num_neighbors:
+        anchor_time: pd.Timestamp | Literal['entity'] | None = None,
+        context_anchor_time: pd.Timestamp | None = None,
+        run_mode: RunMode | str = RunMode.FAST,
+        num_neighbors: list[int] | None = None,
         num_hops: int = 2,
-        max_pq_iterations: int =
-        random_seed:
-        verbose:
+        max_pq_iterations: int = 10,
+        random_seed: int | None = _RANDOM_SEED,
+        verbose: bool | ProgressLogger = True,
         use_prediction_time: bool = False,
     ) -> pd.DataFrame:
         pass
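
`batch_mode` now delegates its retry handling to `retry` via `self._num_retries or num_retries`, so an enclosing `retry` block takes precedence over the `num_retries` argument here. A sketch of batched prediction under the new defaults:

    # Splits the indices into server-sized chunks ('max' resolves to
    # _MAX_PRED_SIZE for the task type) and retries each chunk once:
    with model.batch_mode(batch_size='max', num_retries=1):
        df = model.predict(query, indices=list(range(10_000)))
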
@@ -236,37 +281,56 @@ class KumoRFM:
     def predict(
         self,
         query: str,
-        indices:
+        indices: list[str] | list[float] | list[int] | None = None,
         *,
-        explain:
-        anchor_time:
-        context_anchor_time:
-        run_mode:
-        num_neighbors:
+        explain: Literal[True] | ExplainConfig | dict[str, Any],
+        anchor_time: pd.Timestamp | Literal['entity'] | None = None,
+        context_anchor_time: pd.Timestamp | None = None,
+        run_mode: RunMode | str = RunMode.FAST,
+        num_neighbors: list[int] | None = None,
         num_hops: int = 2,
-        max_pq_iterations: int =
-        random_seed:
-        verbose:
+        max_pq_iterations: int = 10,
+        random_seed: int | None = _RANDOM_SEED,
+        verbose: bool | ProgressLogger = True,
         use_prediction_time: bool = False,
     ) -> Explanation:
         pass
 
+    @overload
     def predict(
         self,
         query: str,
-        indices:
+        indices: list[str] | list[float] | list[int] | None = None,
         *,
-        explain:
-        anchor_time:
-        context_anchor_time:
-        run_mode:
-        num_neighbors:
+        explain: bool | ExplainConfig | dict[str, Any] = False,
+        anchor_time: pd.Timestamp | Literal['entity'] | None = None,
+        context_anchor_time: pd.Timestamp | None = None,
+        run_mode: RunMode | str = RunMode.FAST,
+        num_neighbors: list[int] | None = None,
         num_hops: int = 2,
-        max_pq_iterations: int =
-        random_seed:
-        verbose:
+        max_pq_iterations: int = 10,
+        random_seed: int | None = _RANDOM_SEED,
+        verbose: bool | ProgressLogger = True,
         use_prediction_time: bool = False,
-    ) ->
+    ) -> pd.DataFrame | Explanation:
+        pass
+
+    def predict(
+        self,
+        query: str,
+        indices: list[str] | list[float] | list[int] | None = None,
+        *,
+        explain: bool | ExplainConfig | dict[str, Any] = False,
+        anchor_time: pd.Timestamp | Literal['entity'] | None = None,
+        context_anchor_time: pd.Timestamp | None = None,
+        run_mode: RunMode | str = RunMode.FAST,
+        num_neighbors: list[int] | None = None,
+        num_hops: int = 2,
+        max_pq_iterations: int = 10,
+        random_seed: int | None = _RANDOM_SEED,
+        verbose: bool | ProgressLogger = True,
+        use_prediction_time: bool = False,
+    ) -> pd.DataFrame | Explanation:
         """Returns predictions for a predictive query.
 
         Args:
@@ -274,8 +338,7 @@ class KumoRFM:
             indices: The entity primary keys to predict for. Will override the
                 indices given as part of the predictive query. Predictions will
                 be generated for all indices, independent of whether they
-                fulfill entity filter constraints.
-                :meth:`~KumoRFM.is_valid_entity`.
+                fulfill entity filter constraints.
             explain: Configuration for explainability.
                 If set to ``True``, will additionally explain the prediction.
                 Passing in an :class:`ExplainConfig` instance provides control
@@ -308,18 +371,152 @@ class KumoRFM:
             If ``explain`` is provided, returns an :class:`Explanation` object
             containing the prediction, summary, and details.
         """
-        explain_config: Optional[ExplainConfig] = None
-        if explain is True:
-            explain_config = ExplainConfig()
-        elif explain is not False:
-            explain_config = ExplainConfig._cast(explain)
-
         query_def = self._parse_query(query)
-        query_str = query_def.to_string()
 
+        if indices is None:
+            if query_def.rfm_entity_ids is None:
+                raise ValueError("Cannot find entities to predict for. Please "
+                                 "pass them via `predict(query, indices=...)`")
+            indices = query_def.get_rfm_entity_id_list()
+        query_def = replace(
+            query_def,
+            for_each='FOR EACH',
+            rfm_entity_ids=None,
+        )
+
+        if not isinstance(verbose, ProgressLogger):
+            query_repr = query_def.to_string(rich=True, exclude_predict=True)
+            if explain is not False:
+                msg = f'[bold]EXPLAIN[/bold] {query_repr}'
+            else:
+                msg = f'[bold]PREDICT[/bold] {query_repr}'
+            verbose = ProgressLogger.default(msg=msg, verbose=verbose)
+
+        with verbose as logger:
+            task_table = self._get_task_table(
+                query=query_def,
+                indices=indices,
+                anchor_time=anchor_time,
+                context_anchor_time=context_anchor_time,
+                run_mode=run_mode,
+                max_pq_iterations=max_pq_iterations,
+                random_seed=random_seed,
+                logger=logger,
+            )
+            task_table._query = query_def.to_string()
+
+            return self.predict_task(
+                task_table,
+                explain=explain,
+                run_mode=run_mode,
+                num_neighbors=num_neighbors,
+                num_hops=num_hops,
+                verbose=verbose,
+                exclude_cols_dict=query_def.get_exclude_cols_dict(),
+                use_prediction_time=use_prediction_time,
+                top_k=query_def.top_k,
+            )
+
+    @overload
+    def predict_task(
+        self,
+        task: TaskTable,
+        *,
+        explain: Literal[False] = False,
+        run_mode: RunMode | str = RunMode.FAST,
+        num_neighbors: list[int] | None = None,
+        num_hops: int = 2,
+        verbose: bool | ProgressLogger = True,
+        exclude_cols_dict: dict[str, list[str]] | None = None,
+        use_prediction_time: bool = False,
+        top_k: int | None = None,
+    ) -> pd.DataFrame:
+        pass
+
+    @overload
+    def predict_task(
+        self,
+        task: TaskTable,
+        *,
+        explain: Literal[True] | ExplainConfig | dict[str, Any],
+        run_mode: RunMode | str = RunMode.FAST,
+        num_neighbors: list[int] | None = None,
+        num_hops: int = 2,
+        verbose: bool | ProgressLogger = True,
+        exclude_cols_dict: dict[str, list[str]] | None = None,
+        use_prediction_time: bool = False,
+        top_k: int | None = None,
+    ) -> Explanation:
+        pass
+
+    @overload
+    def predict_task(
+        self,
+        task: TaskTable,
+        *,
+        explain: bool | ExplainConfig | dict[str, Any] = False,
+        run_mode: RunMode | str = RunMode.FAST,
+        num_neighbors: list[int] | None = None,
+        num_hops: int = 2,
+        verbose: bool | ProgressLogger = True,
+        exclude_cols_dict: dict[str, list[str]] | None = None,
+        use_prediction_time: bool = False,
+        top_k: int | None = None,
+    ) -> pd.DataFrame | Explanation:
+        pass
+
+    def predict_task(
+        self,
+        task: TaskTable,
+        *,
+        explain: bool | ExplainConfig | dict[str, Any] = False,
+        run_mode: RunMode | str = RunMode.FAST,
+        num_neighbors: list[int] | None = None,
+        num_hops: int = 2,
+        verbose: bool | ProgressLogger = True,
+        exclude_cols_dict: dict[str, list[str]] | None = None,
+        use_prediction_time: bool = False,
+        top_k: int | None = None,
+    ) -> pd.DataFrame | Explanation:
+        """Returns predictions for a custom task specification.
+
+        Args:
+            task: The custom :class:`TaskTable`.
+            explain: Configuration for explainability.
+                If set to ``True``, will additionally explain the prediction.
+                Passing in an :class:`ExplainConfig` instance provides control
+                over which parts of explanation are generated.
+                Explainability is currently only supported for single entity
+                predictions with ``run_mode="FAST"``.
+            run_mode: The :class:`RunMode` for the query.
+            num_neighbors: The number of neighbors to sample for each hop.
+                If specified, the ``num_hops`` option will be ignored.
+            num_hops: The number of hops to sample when generating the context.
+            verbose: Whether to print verbose output.
+            exclude_cols_dict: Any column in any table to exclude from the
+                model input.
+            use_prediction_time: Whether to use the anchor timestamp as an
+                additional feature during prediction. This is typically
+                beneficial for time series forecasting tasks.
+            top_k: The number of predictions to return per entity.
+
+        Returns:
+            The predictions as a :class:`pandas.DataFrame`.
+            If ``explain`` is provided, returns an :class:`Explanation` object
+            containing the prediction, summary, and details.
+        """
         if num_hops != 2 and num_neighbors is not None:
             warnings.warn(f"Received custom 'num_neighbors' option; ignoring "
                           f"custom 'num_hops={num_hops}' option")
+        if num_neighbors is None:
+            key = RunMode.FAST if task.task_type.is_link_pred else run_mode
+            num_neighbors = _DEFAULT_NUM_NEIGHBORS[key][:num_hops]
+
+        explain_config: ExplainConfig | None = None
+        if explain is True:
+            explain_config = ExplainConfig()
+        elif explain is not False:
+            explain_config = ExplainConfig._cast(explain)
 
         if explain_config is not None and run_mode in {
                 RunMode.NORMAL, RunMode.BEST
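
`predict` is now a thin wrapper: it parses the query, materializes a `TaskTable` via `_get_task_table`, and forwards to `predict_task`, which callers can also invoke directly with hand-built tasks. A sketch; the `TaskTable` construction is hypothetical here, since the new `task_table.py` module is listed in this diff but not displayed:

    # Hypothetical construction -- see kumoai/experimental/rfm/task_table.py
    # (+292 lines, not shown in this diff) for the actual TaskTable API.
    task = TaskTable(...)

    df = model.predict_task(
        task,
        run_mode='FAST',
        num_hops=2,  # expands to _DEFAULT_NUM_NEIGHBORS[RunMode.FAST][:2]
    )
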
@@ -328,83 +525,82 @@ class KumoRFM:
                 f"run mode 'FAST' (got '{run_mode}'). Provided run "
                 f"mode has been reset. Please lower the run mode to "
                 f"suppress this warning.")
+            run_mode = RunMode.FAST
 
-        if
-
-
-
-            indices = query_def.get_rfm_entity_id_list()
-        else:
-            query_def = replace(query_def, rfm_entity_ids=None)
-
-        if len(indices) == 0:
-            raise ValueError("At least one entity is required")
-
-        if explain_config is not None and len(indices) > 1:
-            raise ValueError(
-                f"Cannot explain predictions for more than a single entity "
-                f"(got {len(indices)})")
-
-        query_repr = query_def.to_string(rich=True, exclude_predict=True)
-        if explain_config is not None:
-            msg = f'[bold]EXPLAIN[/bold] {query_repr}'
-        else:
-            msg = f'[bold]PREDICT[/bold] {query_repr}'
+        if explain_config is not None and task.num_prediction_examples > 1:
+            raise ValueError(f"Cannot explain predictions for more than a "
+                             f"single entity "
+                             f"(got {task.num_prediction_examples:,})")
 
         if not isinstance(verbose, ProgressLogger):
-
-
-
-
-
-
-
-
-                edge_types=self._graph_store.edge_types,
-            )
-            batch_size = _MAX_PRED_SIZE[task_type]
+            if task.task_type == TaskType.BINARY_CLASSIFICATION:
+                task_type_repr = 'binary classification'
+            elif task.task_type == TaskType.MULTICLASS_CLASSIFICATION:
+                task_type_repr = 'multi-class classification'
+            elif task.task_type == TaskType.REGRESSION:
+                task_type_repr = 'regression'
+            elif task.task_type == TaskType.TEMPORAL_LINK_PREDICTION:
+                task_type_repr = 'link prediction'
             else:
-
+                task_type_repr = str(task.task_type)
 
-        if
-
-            batches = [indices[step:step + batch_size] for step in offsets]
+            if explain_config is not None:
+                msg = f"Explaining {task_type_repr} task"
             else:
-
+                msg = f"Predicting {task_type_repr} task"
+            verbose = ProgressLogger.default(msg=msg, verbose=verbose)
 
-
-
-
+        with verbose as logger:
+            if task.num_context_examples > _MAX_CONTEXT_SIZE[run_mode]:
+                logger.log(f"Sub-sampled {_MAX_CONTEXT_SIZE[run_mode]:,} "
+                           f"out of {task.num_context_examples:,} in-context "
+                           f"examples")
+                task = task.narrow_context(0, _MAX_CONTEXT_SIZE[run_mode])
+
+            if self._batch_size is None:
+                batch_size = task.num_prediction_examples
+            elif self._batch_size == 'max':
+                batch_size = _MAX_PRED_SIZE[task.task_type]
+            else:
+                batch_size = self._batch_size
 
-
-
-
-
-
+            if batch_size > _MAX_PRED_SIZE[task.task_type]:
+                raise ValueError(f"Cannot predict for more than "
+                                 f"{_MAX_PRED_SIZE[task.task_type]:,} "
+                                 f"entities at once (got {batch_size:,}). Use "
+                                 f"`KumoRFM.batch_mode` to process entities "
+                                 f"in batches with a sufficient batch size.")
+
+            if task.num_prediction_examples > batch_size:
+                num = math.ceil(task.num_prediction_examples / batch_size)
+                logger.log(f"Splitting {task.num_prediction_examples:,} "
+                           f"entities into {num:,} batches of size "
+                           f"{batch_size:,}")
+
+            predictions: list[pd.DataFrame] = []
+            summary: str | None = None
+            details: Explanation | None = None
+            for start in range(0, task.num_prediction_examples, batch_size):
                 context = self._get_context(
-
-
-                    anchor_time=anchor_time,
-                    context_anchor_time=context_anchor_time,
-                    run_mode=RunMode(run_mode),
+                    task=task.narrow_prediction(start, length=batch_size),
+                    run_mode=run_mode,
                     num_neighbors=num_neighbors,
-
-
-                    evaluate=False,
-                    random_seed=random_seed,
-                    logger=logger if i == 0 else None,
+                    exclude_cols_dict=exclude_cols_dict,
+                    top_k=top_k,
                 )
+                context.y_test = None
+
                 request = RFMPredictRequest(
                     context=context,
                     run_mode=RunMode(run_mode),
-                    query=
+                    query=task._query,
                     use_prediction_time=use_prediction_time,
                 )
                 with warnings.catch_warnings():
                     warnings.filterwarnings('ignore', message='gencode')
                     request_msg = request.to_protobuf()
                     _bytes = request_msg.SerializeToString()
-                if
+                if start == 0:
                     logger.log(f"Generated context of size "
                                f"{len(_bytes) / (1024*1024):.2f}MB")
@@ -412,14 +608,11 @@ class KumoRFM:
                 stats = Context.get_memory_stats(request_msg.context)
                 raise ValueError(_SIZE_LIMIT_MSG.format(stats=stats))
 
-                if
-
-                    verbose.init_progress(
-                        total=len(batches),
-                        description='Predicting',
-                    )
+                if start == 0 and task.num_prediction_examples > batch_size:
+                    num = math.ceil(task.num_prediction_examples / batch_size)
+                    verbose.init_progress(total=num, description='Predicting')
 
-                for attempt in range(self.
+                for attempt in range(self._num_retries + 1):
                     try:
                         if explain_config is not None:
                             resp = self._api_client.explain(
@@ -434,10 +627,10 @@ class KumoRFM:
 
                         # Cast 'ENTITY' to correct data type:
                         if 'ENTITY' in df:
-
-
-
-
+                            table_dict = context.subgraph.table_dict
+                            table = table_dict[context.entity_table_names[0]]
+                            ser = table.df[table.primary_key]
+                            df['ENTITY'] = df['ENTITY'].astype(ser.dtype)
 
                         # Cast 'ANCHOR_TIMESTAMP' to correct data type:
                         if 'ANCHOR_TIMESTAMP' in df:
@@ -452,13 +645,12 @@ class KumoRFM:
 
                         predictions.append(df)
 
-                        if
-                                and len(batches) > 1):
+                        if task.num_prediction_examples > batch_size:
                             verbose.step()
 
                         break
                     except HTTPException as e:
-                        if attempt == self.
+                        if attempt == self._num_retries:
                             try:
                                 msg = json.loads(e.detail)['detail']
                             except Exception:
@@ -488,69 +680,19 @@ class KumoRFM:
 
         return prediction
 
-    def is_valid_entity(
-        self,
-        query: str,
-        indices: Union[List[str], List[float], List[int], None] = None,
-        *,
-        anchor_time: Union[pd.Timestamp, Literal['entity'], None] = None,
-    ) -> np.ndarray:
-        r"""Returns a mask that denotes which entities are valid for the
-        given predictive query, *i.e.*, which entities fulfill (temporal)
-        entity filter constraints.
-
-        Args:
-            query: The predictive query.
-            indices: The entity primary keys to predict for. Will override the
-                indices given as part of the predictive query.
-            anchor_time: The anchor timestamp for the prediction. If set to
-                ``None``, will use the maximum timestamp in the data.
-                If set to ``"entity"``, will use the timestamp of the entity.
-        """
-        query_def = self._parse_query(query)
-
-        if indices is None:
-            if query_def.rfm_entity_ids is None:
-                raise ValueError("Cannot find entities to predict for. Please "
-                                 "pass them via "
-                                 "`is_valid_entity(query, indices=...)`")
-            indices = query_def.get_rfm_entity_id_list()
-
-        if len(indices) == 0:
-            raise ValueError("At least one entity is required")
-
-        if anchor_time is None:
-            anchor_time = self._graph_store.max_time
-
-        if isinstance(anchor_time, pd.Timestamp):
-            self._validate_time(query_def, anchor_time, None, False)
-        else:
-            assert anchor_time == 'entity'
-            if (query_def.entity_table not in self._graph_store.time_dict):
-                raise ValueError(f"Anchor time 'entity' requires the entity "
-                                 f"table '{query_def.entity_table}' "
-                                 f"to have a time column.")
-
-        node = self._graph_store.get_node_id(
-            table_name=query_def.entity_table,
-            pkey=pd.Series(indices),
-        )
-        query_driver = LocalPQueryDriver(self._graph_store, query_def)
-        return query_driver.is_valid(node, anchor_time)
-
     def evaluate(
         self,
         query: str,
         *,
-        metrics:
-        anchor_time:
-        context_anchor_time:
-        run_mode:
-        num_neighbors:
+        metrics: list[str] | None = None,
+        anchor_time: pd.Timestamp | Literal['entity'] | None = None,
+        context_anchor_time: pd.Timestamp | None = None,
+        run_mode: RunMode | str = RunMode.FAST,
+        num_neighbors: list[int] | None = None,
         num_hops: int = 2,
-        max_pq_iterations: int =
-        random_seed:
-        verbose:
+        max_pq_iterations: int = 10,
+        random_seed: int | None = _RANDOM_SEED,
+        verbose: bool | ProgressLogger = True,
         use_prediction_time: bool = False,
     ) -> pd.DataFrame:
         """Evaluates a predictive query.
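
`evaluate` follows the same decomposition as `predict`: parse the query, build a `TaskTable`, and forward to the new `evaluate_task`. A usage sketch (the query string and metric name are illustrative, not taken from this diff):

    metrics_df = model.evaluate(
        "PREDICT COUNT(orders.*, 0, 30, days) > 0 FOR EACH users.user_id",
        metrics=['auroc'],  # de-duplicated and validated against the task
        run_mode='FAST',
    )
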
@@ -582,41 +724,120 @@ class KumoRFM:
         Returns:
             The metrics as a :class:`pandas.DataFrame`
         """
-        query_def =
+        query_def = replace(
+            self._parse_query(query),
+            for_each='FOR EACH',
+            rfm_entity_ids=None,
+        )
 
+        if not isinstance(verbose, ProgressLogger):
+            query_repr = query_def.to_string(rich=True, exclude_predict=True)
+            msg = f'[bold]EVALUATE[/bold] {query_repr}'
+            verbose = ProgressLogger.default(msg=msg, verbose=verbose)
+
+        with verbose as logger:
+            task_table = self._get_task_table(
+                query=query_def,
+                indices=None,
+                anchor_time=anchor_time,
+                context_anchor_time=context_anchor_time,
+                run_mode=run_mode,
+                max_pq_iterations=max_pq_iterations,
+                random_seed=random_seed,
+                logger=logger,
+            )
+
+            return self.evaluate_task(
+                task_table,
+                metrics=metrics,
+                run_mode=run_mode,
+                num_neighbors=num_neighbors,
+                num_hops=num_hops,
+                verbose=verbose,
+                exclude_cols_dict=query_def.get_exclude_cols_dict(),
+                use_prediction_time=use_prediction_time,
+            )
+
+    def evaluate_task(
+        self,
+        task: TaskTable,
+        *,
+        metrics: list[str] | None = None,
+        run_mode: RunMode | str = RunMode.FAST,
+        num_neighbors: list[int] | None = None,
+        num_hops: int = 2,
+        verbose: bool | ProgressLogger = True,
+        exclude_cols_dict: dict[str, list[str]] | None = None,
+        use_prediction_time: bool = False,
+    ) -> pd.DataFrame:
+        """Evaluates a custom task specification.
+
+        Args:
+            task: The custom :class:`TaskTable`.
+            metrics: The metrics to use.
+            run_mode: The :class:`RunMode` for the query.
+            num_neighbors: The number of neighbors to sample for each hop.
+                If specified, the ``num_hops`` option will be ignored.
+            num_hops: The number of hops to sample when generating the context.
+            verbose: Whether to print verbose output.
+            exclude_cols_dict: Any column in any table to exclude from the
+                model input.
+            use_prediction_time: Whether to use the anchor timestamp as an
+                additional feature during prediction. This is typically
+                beneficial for time series forecasting tasks.
+
+        Returns:
+            The metrics as a :class:`pandas.DataFrame`
+        """
         if num_hops != 2 and num_neighbors is not None:
             warnings.warn(f"Received custom 'num_neighbors' option; ignoring "
                           f"custom 'num_hops={num_hops}' option")
+        if num_neighbors is None:
+            key = RunMode.FAST if task.task_type.is_link_pred else run_mode
+            num_neighbors = _DEFAULT_NUM_NEIGHBORS[key][:num_hops]
 
-        if
-
-
-            rfm_entity_ids=None,
-        )
-
-        query_repr = query_def.to_string(rich=True, exclude_predict=True)
-        msg = f'[bold]EVALUATE[/bold] {query_repr}'
+        if metrics is not None and len(metrics) > 0:
+            self._validate_metrics(metrics, task.task_type)
+            metrics = list(dict.fromkeys(metrics))
 
         if not isinstance(verbose, ProgressLogger):
-
+            if task.task_type == TaskType.BINARY_CLASSIFICATION:
+                task_type_repr = 'binary classification'
+            elif task.task_type == TaskType.MULTICLASS_CLASSIFICATION:
+                task_type_repr = 'multi-class classification'
+            elif task.task_type == TaskType.REGRESSION:
+                task_type_repr = 'regression'
+            elif task.task_type == TaskType.TEMPORAL_LINK_PREDICTION:
+                task_type_repr = 'link prediction'
+            else:
+                task_type_repr = str(task.task_type)
+
+            msg = f"Evaluating {task_type_repr} task"
+            verbose = ProgressLogger.default(msg=msg, verbose=verbose)
 
         with verbose as logger:
+            if task.num_context_examples > _MAX_CONTEXT_SIZE[run_mode]:
+                logger.log(f"Sub-sampled {_MAX_CONTEXT_SIZE[run_mode]:,} "
+                           f"out of {task.num_context_examples:,} in-context "
+                           f"examples")
+                task = task.narrow_context(0, _MAX_CONTEXT_SIZE[run_mode])
+
+            if task.num_prediction_examples > _MAX_TEST_SIZE[task.task_type]:
+                logger.log(f"Sub-sampled {_MAX_TEST_SIZE[task.task_type]:,} "
+                           f"out of {task.num_prediction_examples:,} test "
+                           f"examples")
+                task = task.narrow_prediction(
+                    start=0,
+                    length=_MAX_TEST_SIZE[task.task_type],
+                )
+
             context = self._get_context(
-
-
-                anchor_time=anchor_time,
-                context_anchor_time=context_anchor_time,
-                run_mode=RunMode(run_mode),
+                task=task,
+                run_mode=run_mode,
                 num_neighbors=num_neighbors,
-
-                max_pq_iterations=max_pq_iterations,
-                evaluate=True,
-                random_seed=random_seed,
-                logger=logger if verbose else None,
+                exclude_cols_dict=exclude_cols_dict,
             )
-
-            self._validate_metrics(metrics, context.task_type)
-            metrics = list(dict.fromkeys(metrics))
+
             request = RFMEvaluateRequest(
                 context=context,
                 run_mode=RunMode(run_mode),
@@ -634,17 +855,23 @@ class KumoRFM:
                 stats_msg = Context.get_memory_stats(request_msg.context)
                 raise ValueError(_SIZE_LIMIT_MSG.format(stats=stats_msg))
 
-
-                resp = self._api_client.evaluate(request_bytes)
-            except HTTPException as e:
+            for attempt in range(self._num_retries + 1):
                 try:
-
-
-
-
-
-
-
+                    resp = self._api_client.evaluate(request_bytes)
+                    break
+                except HTTPException as e:
+                    if attempt == self._num_retries:
+                        try:
+                            msg = json.loads(e.detail)['detail']
+                        except Exception:
+                            msg = e.detail
+                        raise RuntimeError(
+                            f"An unexpected exception occurred. Please create "
+                            f"an issue at "
+                            f"'https://github.com/kumo-ai/kumo-rfm'. {msg}"
+                        ) from None
+
+                    time.sleep(2**attempt)  # 1s, 2s, 4s, 8s, ...
 
             return pd.DataFrame.from_dict(
                 resp.metrics,
@@ -657,9 +884,9 @@ class KumoRFM:
         query: str,
         size: int,
         *,
-        anchor_time:
-        random_seed:
-        max_iterations: int =
+        anchor_time: pd.Timestamp | Literal['entity'] | None = None,
+        random_seed: int | None = _RANDOM_SEED,
+        max_iterations: int = 10,
     ) -> pd.DataFrame:
         """Returns the labels of a predictive query for a specified anchor
         time.
@@ -679,40 +906,37 @@ class KumoRFM:
         query_def = self._parse_query(query)
 
         if anchor_time is None:
-            anchor_time = self.
+            anchor_time = self._get_default_anchor_time(query_def)
             if query_def.target_ast.date_offset_range is not None:
-
-
-
+                offset = query_def.target_ast.date_offset_range.end_date_offset
+                offset *= query_def.num_forecasts
+                anchor_time -= offset
 
         assert anchor_time is not None
         if isinstance(anchor_time, pd.Timestamp):
             self._validate_time(query_def, anchor_time, None, evaluate=True)
         else:
             assert anchor_time == 'entity'
-            if
+            if query_def.entity_table not in self._sampler.time_column_dict:
                 raise ValueError(f"Anchor time 'entity' requires the entity "
                                  f"table '{query_def.entity_table}' "
                                  f"to have a time column")
 
-
-
-
-
-
-
-
-
-
+        train, test = self._sampler.sample_target(
+            query=query_def,
+            num_train_examples=0,
+            train_anchor_time=anchor_time,
+            num_train_trials=0,
+            num_test_examples=size,
+            test_anchor_time=anchor_time,
+            num_test_trials=max_iterations * size,
+            random_seed=random_seed,
         )
 
-        entity = self._graph_store.pkey_map_dict[
-            query_def.entity_table].index[node]
-
         return pd.DataFrame({
-            'ENTITY':
-            'ANCHOR_TIMESTAMP':
-            'TARGET':
+            'ENTITY': test.entity_pkey,
+            'ANCHOR_TIMESTAMP': test.anchor_time,
+            'TARGET': test.target,
         })
 
     # Helpers #################################################################
@@ -727,63 +951,120 @@ class KumoRFM:
                            "`predict()` or `evaluate()` methods to perform "
                            "predictions or evaluations.")
 
-
-
-
-
-        )
+        request = RFMParseQueryRequest(
+            query=query,
+            graph_definition=self._graph_def,
+        )
 
-
+        for attempt in range(self._num_retries + 1):
+            try:
+                resp = self._api_client.parse_query(request)
+                break
+            except HTTPException as e:
+                if attempt == self._num_retries:
+                    try:
+                        msg = json.loads(e.detail)['detail']
+                    except Exception:
+                        msg = e.detail
+                    raise ValueError(f"Failed to parse query '{query}'. {msg}")
 
-
+                time.sleep(2**attempt)  # 1s, 2s, 4s, 8s, ...
 
-
-
-
-
-
-
-
+        if len(resp.validation_response.warnings) > 0:
+            msg = '\n'.join([
+                f'{i+1}. {warning.title}: {warning.message}'
+                for i, warning in enumerate(resp.validation_response.warnings)
+            ])
+            warnings.warn(f"Encountered the following warnings during "
+                          f"parsing:\n{msg}")
 
-
-
-
-
-
-
-
-
+        return resp.query
+
+    @staticmethod
+    def _get_task_type(
+        query: ValidatedPredictiveQuery,
+        edge_types: list[tuple[str, str, str]],
+    ) -> TaskType:
+        if isinstance(query.target_ast, (Condition, LogicalOperation)):
+            return TaskType.BINARY_CLASSIFICATION
+
+        target = query.target_ast
+        if isinstance(target, Join):
+            target = target.rhs_target
+        if isinstance(target, Aggregation):
+            if target.aggr == AggregationType.LIST_DISTINCT:
+                table_name, col_name = target._get_target_column_name().split(
+                    '.')
+                target_edge_types = [
+                    edge_type for edge_type in edge_types
+                    if edge_type[0] == table_name and edge_type[1] == col_name
+                ]
+                if len(target_edge_types) != 1:
+                    raise NotImplementedError(
+                        f"Multilabel-classification queries based on "
+                        f"'LIST_DISTINCT' are not supported yet. If you "
+                        f"planned to write a link prediction query instead, "
+                        f"make sure to register '{col_name}' as a "
+                        f"foreign key.")
+                return TaskType.TEMPORAL_LINK_PREDICTION
+
+            return TaskType.REGRESSION
+
+        assert isinstance(target, Column)
+
+        if target.stype in {Stype.ID, Stype.categorical}:
+            return TaskType.MULTICLASS_CLASSIFICATION
+
+        if target.stype in {Stype.numerical}:
+            return TaskType.REGRESSION
+
+        raise NotImplementedError("Task type not yet supported")
+
+    def _get_default_anchor_time(
+        self,
+        query: ValidatedPredictiveQuery | None = None,
+    ) -> pd.Timestamp:
+        if query is not None and query.query_type == QueryType.TEMPORAL:
+            aggr_table_names = [
+                aggr._get_target_column_name().split('.')[0]
+                for aggr in query.get_all_target_aggregations()
+            ]
+            return self._sampler.get_max_time(aggr_table_names)
+
+        return self._sampler.get_max_time()
 
     def _validate_time(
         self,
         query: ValidatedPredictiveQuery,
         anchor_time: pd.Timestamp,
-        context_anchor_time:
+        context_anchor_time: pd.Timestamp | None,
        evaluate: bool,
     ) -> None:
 
-        if self.
+        if len(self._sampler.time_column_dict) == 0:
             return  # Graph without timestamps
 
-
+        min_time = self._sampler.get_min_time()
+        max_time = self._sampler.get_max_time()
+
+        if anchor_time < min_time:
             raise ValueError(f"Anchor timestamp '{anchor_time}' is before "
-                             f"the earliest timestamp "
-                             f"
+                             f"the earliest timestamp '{min_time}' in the "
+                             f"data.")
 
-        if
-                and context_anchor_time < self._graph_store.min_time):
+        if context_anchor_time is not None and context_anchor_time < min_time:
             raise ValueError(f"Context anchor timestamp is too early or "
                              f"aggregation time range is too large. To make "
                              f"this prediction, we would need data back to "
                              f"'{context_anchor_time}', however, your data "
-                             f"only contains data back to "
-                             f"'{self._graph_store.min_time}'.")
+                             f"only contains data back to '{min_time}'.")
 
         if query.target_ast.date_offset_range is not None:
             end_offset = query.target_ast.date_offset_range.end_date_offset
         else:
             end_offset = pd.DateOffset(0)
-
+        end_offset = end_offset * query.num_forecasts
+
         if (context_anchor_time is not None
                 and context_anchor_time > anchor_time):
             warnings.warn(f"Context anchor timestamp "
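
`_get_task_type` absorbs the task-type inference that previously lived in `LocalPQueryDriver.get_task_type`. Restating its dispatch on illustrative query shapes (the example queries are assumptions, not from this diff):

    # Condition / LogicalOperation target      -> BINARY_CLASSIFICATION
    #   e.g. PREDICT COUNT(orders.*, 0, 30, days) > 0
    # LIST_DISTINCT over a registered fkey     -> TEMPORAL_LINK_PREDICTION
    #   e.g. PREDICT LIST_DISTINCT(orders.item_id, 0, 7, days)
    # Any other aggregation                    -> REGRESSION
    # Plain column: ID/categorical stype       -> MULTICLASS_CLASSIFICATION
    # Plain column: numerical stype            -> REGRESSION
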
@@ -793,7 +1074,7 @@ class KumoRFM:
                           f"intended.")
         elif (query.query_type == QueryType.TEMPORAL
               and context_anchor_time is not None
-              and context_anchor_time +
+              and context_anchor_time + end_offset > anchor_time):
             warnings.warn(f"Aggregation for context examples at timestamp "
                           f"'{context_anchor_time}' will leak information "
                           f"from the prediction anchor timestamp "
@@ -801,62 +1082,44 @@ class KumoRFM:
                           f"intended.")
 
         elif (context_anchor_time is not None
-              and context_anchor_time -
-
-            _time = context_anchor_time - forecast_end_offset
+              and context_anchor_time - end_offset < min_time):
+            _time = context_anchor_time - end_offset
             warnings.warn(f"Context anchor timestamp is too early or "
                           f"aggregation time range is too large. To form "
                           f"proper input data, we would need data back to "
                           f"'{_time}', however, your data only contains "
-                          f"data back to '{
+                          f"data back to '{min_time}'.")
 
-        if
-                > self._graph_store.max_time + pd.DateOffset(days=1)):
+        if not evaluate and anchor_time > max_time + pd.DateOffset(days=1):
             warnings.warn(f"Anchor timestamp '{anchor_time}' is after the "
-                          f"latest timestamp '{
-                          f"
+                          f"latest timestamp '{max_time}' in the data. Please "
+                          f"make sure this is intended.")
 
-
-        if evaluate and anchor_time > max_eval_time:
+        if evaluate and anchor_time > max_time - end_offset:
             raise ValueError(
                 f"Anchor timestamp for evaluation is after the latest "
-                f"supported timestamp '{
+                f"supported timestamp '{max_time - end_offset}'.")
 
-    def
+    def _get_task_table(
         self,
         query: ValidatedPredictiveQuery,
-        indices:
-        anchor_time:
-        context_anchor_time:
-        run_mode: RunMode,
-
-
-
-
-
-
-
-
-        if num_neighbors is not None:
-            num_hops = len(num_neighbors)
-
-        if num_hops < 0:
-            raise ValueError(f"'num_hops' must be non-negative "
-                             f"(got {num_hops})")
-        if num_hops > 6:
-            raise ValueError(f"Cannot predict on subgraphs with more than 6 "
-                             f"hops (got {num_hops}). Please reduce the "
-                             f"number of hops and try again. Please create a "
-                             f"feature request at "
-                             f"'https://github.com/kumo-ai/kumo-rfm' if you "
-                             f"must go beyond this for your use-case.")
-
-        query_driver = LocalPQueryDriver(self._graph_store, query, random_seed)
-        task_type = LocalPQueryDriver.get_task_type(
-            query,
-            edge_types=self._graph_store.edge_types,
+        indices: list[str] | list[float] | list[int] | None,
+        anchor_time: pd.Timestamp | Literal['entity'] | None = None,
+        context_anchor_time: pd.Timestamp | None = None,
+        run_mode: RunMode = RunMode.FAST,
+        max_pq_iterations: int = 10,
+        random_seed: int | None = _RANDOM_SEED,
+        logger: ProgressLogger | None = None,
+    ) -> TaskTable:
+
+        task_type = self._get_task_type(
+            query=query,
+            edge_types=self._sampler.edge_types,
         )
 
+        num_train_examples = _MAX_CONTEXT_SIZE[run_mode]
+        num_test_examples = _MAX_TEST_SIZE[task_type] if indices is None else 0
+
         if logger is not None:
             if task_type == TaskType.BINARY_CLASSIFICATION:
                 task_type_repr = 'binary classification'
@@ -870,30 +1133,17 @@ class KumoRFM:
                 task_type_repr = str(task_type)
             logger.log(f"Identified {query.query_type} {task_type_repr} task")
 
-        if task_type.is_link_pred and num_hops < 2:
-            raise ValueError(f"Cannot perform link prediction on subgraphs "
-                             f"with less than 2 hops (got {num_hops}) since "
-                             f"historical target entities need to be part of "
-                             f"the context. Please increase the number of "
-                             f"hops and try again.")
-
-        if num_neighbors is None:
-            if run_mode == RunMode.DEBUG:
-                num_neighbors = [16, 16, 4, 4, 1, 1][:num_hops]
-            elif run_mode == RunMode.FAST or task_type.is_link_pred:
-                num_neighbors = [32, 32, 8, 8, 4, 4][:num_hops]
-            else:
-                num_neighbors = [64, 64, 8, 8, 4, 4][:num_hops]
-
         if query.target_ast.date_offset_range is None:
-
+            step_offset = pd.DateOffset(0)
         else:
-
-
+            step_offset = query.target_ast.date_offset_range.end_date_offset
+        end_offset = step_offset * query.num_forecasts
+
         if anchor_time is None:
-            anchor_time = self.
-            if
-                anchor_time = anchor_time -
+            anchor_time = self._get_default_anchor_time(query)
+            if num_test_examples > 0:
+                anchor_time = anchor_time - end_offset
+
         if logger is not None:
             assert isinstance(anchor_time, pd.Timestamp)
             if anchor_time == pd.Timestamp.min:
@@ -905,114 +1155,98 @@ class KumoRFM:
             else:
                 logger.log(f"Derived anchor time {anchor_time}")

-        assert anchor_time is not None
         if isinstance(anchor_time, pd.Timestamp):
+            if context_anchor_time == 'entity':
+                raise ValueError("Anchor time 'entity' needs to be shared "
+                                 "for context and prediction examples")
             if context_anchor_time is None:
-                context_anchor_time = anchor_time -
+                context_anchor_time = anchor_time - end_offset
             self._validate_time(query, anchor_time, context_anchor_time,
-                                evaluate)
+                                evaluate=num_test_examples > 0)
         else:
             assert anchor_time == 'entity'
-            if query.
+            if query.query_type != QueryType.STATIC:
+                raise ValueError("Anchor time 'entity' is only valid for "
+                                 "static predictive queries")
+            if query.entity_table not in self._sampler.time_column_dict:
                 raise ValueError(f"Anchor time 'entity' requires the entity "
                                  f"table '{query.entity_table}' to "
                                  f"have a time column")
-            if context_anchor_time
-
-
-            context_anchor_time =
-
-
-
-
-
-
-
-
-
-
-
-
-
-            if logger is not None:
-                if task_type == TaskType.BINARY_CLASSIFICATION:
-                    pos = 100 * int((y_test > 0).sum()) / len(y_test)
-                    msg = (f"Collected {len(y_test):,} test examples with "
-                           f"{pos:.2f}% positive cases")
-                elif task_type == TaskType.MULTICLASS_CLASSIFICATION:
-                    msg = (f"Collected {len(y_test):,} test examples "
-                           f"holding {y_test.nunique()} classes")
-                elif task_type == TaskType.REGRESSION:
-                    _min, _max = float(y_test.min()), float(y_test.max())
-                    msg = (f"Collected {len(y_test):,} test examples with "
-                           f"targets between {format_value(_min)} and "
-                           f"{format_value(_max)}")
-                elif task_type == TaskType.TEMPORAL_LINK_PREDICTION:
-                    num_rhs = y_test.explode().nunique()
-                    msg = (f"Collected {len(y_test):,} test examples with "
-                           f"{num_rhs:,} unique items")
-                else:
-                    raise NotImplementedError
-                logger.log(msg)
-
-        else:
-            assert indices is not None
-
-            if len(indices) > _MAX_PRED_SIZE[task_type]:
-                raise ValueError(f"Cannot predict for more than "
-                                 f"{_MAX_PRED_SIZE[task_type]:,} entities at "
-                                 f"once (got {len(indices):,}). Use "
-                                 f"`KumoRFM.batch_mode` to process entities "
-                                 f"in batches")
+            if isinstance(context_anchor_time, pd.Timestamp):
+                raise ValueError("Anchor time 'entity' needs to be shared "
+                                 "for context and prediction examples")
+            context_anchor_time = 'entity'
+
+        train, test = self._sampler.sample_target(
+            query=query,
+            num_train_examples=num_train_examples,
+            train_anchor_time=context_anchor_time,
+            num_train_trials=max_pq_iterations * num_train_examples,
+            num_test_examples=num_test_examples,
+            test_anchor_time=anchor_time,
+            num_test_trials=max_pq_iterations * num_test_examples,
+            random_seed=random_seed,
+        )
+        train_pkey, train_time, train_y = train
+        test_pkey, test_time, test_y = test

-
-
-
-
+        if num_test_examples > 0 and logger is not None:
+            if task_type == TaskType.BINARY_CLASSIFICATION:
+                pos = 100 * int((test_y > 0).sum()) / len(test_y)
+                msg = (f"Collected {len(test_y):,} test examples with "
+                       f"{pos:.2f}% positive cases")
+            elif task_type == TaskType.MULTICLASS_CLASSIFICATION:
+                msg = (f"Collected {len(test_y):,} test examples holding "
+                       f"{test_y.nunique()} classes")
+            elif task_type == TaskType.REGRESSION:
+                _min, _max = float(test_y.min()), float(test_y.max())
+                msg = (f"Collected {len(test_y):,} test examples with targets "
+                       f"between {format_value(_min)} and "
+                       f"{format_value(_max)}")
+            elif task_type == TaskType.TEMPORAL_LINK_PREDICTION:
+                num_rhs = test_y.explode().nunique()
+                msg = (f"Collected {len(test_y):,} test examples with "
+                       f"{num_rhs:,} unique items")
+            else:
+                raise NotImplementedError
+            logger.log(msg)

+        if num_test_examples == 0:
+            assert indices is not None
+            test_pkey = pd.Series(indices, dtype=train_pkey.dtype)
             if isinstance(anchor_time, pd.Timestamp):
-                test_time = pd.Series(anchor_time).repeat(
-                    len(
+                test_time = pd.Series([anchor_time]).repeat(
+                    len(indices)).reset_index(drop=True)
             else:
-
-                time = time[test_node] * 1000**3
-                test_time = pd.Series(time, dtype='datetime64[ns]')
-
-        train_node, train_time, y_train = query_driver.collect_train(
-            size=_MAX_CONTEXT_SIZE[run_mode],
-            anchor_time=context_anchor_time or 'entity',
-            exclude_node=test_node if (query.query_type == QueryType.STATIC
-                                       or anchor_time == 'entity') else None,
-            max_iterations=max_pq_iterations,
-        )
+                train_time = test_time = 'entity'

         if logger is not None:
             if task_type == TaskType.BINARY_CLASSIFICATION:
-                pos = 100 * int((
-                msg = (f"Collected {len(
+                pos = 100 * int((train_y > 0).sum()) / len(train_y)
+                msg = (f"Collected {len(train_y):,} in-context examples with "
                        f"{pos:.2f}% positive cases")
             elif task_type == TaskType.MULTICLASS_CLASSIFICATION:
-                msg = (f"Collected {len(
-                       f"holding {
+                msg = (f"Collected {len(train_y):,} in-context examples "
+                       f"holding {train_y.nunique()} classes")
             elif task_type == TaskType.REGRESSION:
-                _min, _max = float(
-                msg = (f"Collected {len(
+                _min, _max = float(train_y.min()), float(train_y.max())
+                msg = (f"Collected {len(train_y):,} in-context examples with "
                        f"targets between {format_value(_min)} and "
                        f"{format_value(_max)}")
             elif task_type == TaskType.TEMPORAL_LINK_PREDICTION:
-                num_rhs =
-                msg = (f"Collected {len(
+                num_rhs = train_y.explode().nunique()
+                msg = (f"Collected {len(train_y):,} in-context examples with "
                        f"{num_rhs:,} unique items")
             else:
                 raise NotImplementedError
             logger.log(msg)

-        entity_table_names:
+        entity_table_names: tuple[str] | tuple[str, str]
         if task_type.is_link_pred:
             final_aggr = query.get_final_target_aggregation()
             assert final_aggr is not None
             edge_fkey = final_aggr._get_target_column_name()
-            for edge_type in self.
+            for edge_type in self._sampler.edge_types:
                 if edge_fkey == f'{edge_type[0]}.{edge_type[1]}':
                     entity_table_names = (
                         query.entity_table,
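The new `sample_target` call above returns one `(pkey, time, y)` triplet for the in-context split and one for the test split; the next hunk assembles these into plain pandas frames. A rough sketch of the shapes involved, with invented values (only the `ENTITY`/`TARGET`/`ANCHOR_TIMESTAMP` column names are taken from the diff):

    import pandas as pd

    # Hypothetical sampler output for a binary classification query:
    train_pkey = pd.Series([101, 102, 103])                     # primary keys
    train_time = pd.Series(pd.to_datetime(['2024-05-01'] * 3))  # anchor times
    train_y = pd.Series([1, 0, 1])                              # labels

    context_df = pd.DataFrame({'ENTITY': train_pkey, 'TARGET': train_y})
    context_df['ANCHOR_TIMESTAMP'] = train_time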
@@ -1021,23 +1255,80 @@ class KumoRFM:
         else:
             entity_table_names = (query.entity_table, )

+        context_df = pd.DataFrame({'ENTITY': train_pkey, 'TARGET': train_y})
+        if isinstance(train_time, pd.Series):
+            context_df['ANCHOR_TIMESTAMP'] = train_time
+        pred_df = pd.DataFrame({'ENTITY': test_pkey})
+        if num_test_examples > 0:
+            pred_df['TARGET'] = test_y
+        if isinstance(test_time, pd.Series):
+            pred_df['ANCHOR_TIMESTAMP'] = test_time
+
+        return TaskTable(
+            task_type=task_type,
+            context_df=context_df,
+            pred_df=pred_df,
+            entity_table_name=entity_table_names,
+            entity_column='ENTITY',
+            target_column='TARGET',
+            time_column='ANCHOR_TIMESTAMP' if isinstance(
+                train_time, pd.Series) else TaskTable.ENTITY_TIME,
+        )
+
+    def _get_context(
+        self,
+        task: TaskTable,
+        run_mode: RunMode | str = RunMode.FAST,
+        num_neighbors: list[int] | None = None,
+        exclude_cols_dict: dict[str, list[str]] | None = None,
+        top_k: int | None = None,
+    ) -> Context:
+
+        if num_neighbors is None:
+            key = RunMode.FAST if task.task_type.is_link_pred else run_mode
+            num_neighbors = _DEFAULT_NUM_NEIGHBORS[key][:2]
+
+        if len(num_neighbors) > 6:
+            raise ValueError(f"Cannot predict on subgraphs with more than 6 "
+                             f"hops (got {len(num_neighbors)}). Reduce the "
+                             f"number of hops and try again. Please create a "
+                             f"feature request at "
+                             f"'https://github.com/kumo-ai/kumo-rfm' if you "
+                             f"must go beyond this for your use-case.")
+
         # Exclude the entity anchor time from the feature set to prevent
         # running out-of-distribution between in-context and test examples:
-        exclude_cols_dict =
-        if
-        if
-            exclude_cols_dict[
-
-
-
-
-
-
-
-
-
-
-
+        exclude_cols_dict = exclude_cols_dict or {}
+        if task.entity_table_name in self._sampler.time_column_dict:
+            if task.entity_table_name not in exclude_cols_dict:
+                exclude_cols_dict[task.entity_table_name] = []
+            time_col = self._sampler.time_column_dict[task.entity_table_name]
+            exclude_cols_dict[task.entity_table_name].append(time_col)
+
+        entity_pkey = pd.concat([
+            task._context_df[task._entity_column],
+            task._pred_df[task._entity_column],
+        ], axis=0, ignore_index=True)
+
+        if task.use_entity_time:
+            if task.entity_table_name not in self._sampler.time_column_dict:
+                raise ValueError(f"The given anchor time requires the entity "
+                                 f"table '{task.entity_table_name}' to have a "
+                                 f"time column")
+            anchor_time = 'entity'
+        elif task._time_column is not None:
+            anchor_time = pd.concat([
+                task._context_df[task._time_column],
+                task._pred_df[task._time_column],
+            ], axis=0, ignore_index=True)
+        else:
+            anchor_time = pd.Series(self._get_default_anchor_time()).repeat(
+                len(entity_pkey)).reset_index(drop=True)
+
+        subgraph = self._sampler.sample_subgraph(
+            entity_table_names=task.entity_table_names,
+            entity_pkey=entity_pkey,
+            anchor_time=anchor_time,
             num_neighbors=num_neighbors,
             exclude_cols_dict=exclude_cols_dict,
         )
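`_get_context` now reads its sampling fanout from a shared `_DEFAULT_NUM_NEIGHBORS` table and defaults to two hops, replacing the hard-coded per-run-mode lists deleted in the earlier hunk. A sketch of how such a per-hop fanout list is consumed; the string keys stand in for the `RunMode` enum, and mapping the old `else` branch to a "best" mode is an assumption:

    # Position i gives the neighbor budget when expanding hop i:
    _DEFAULT_NUM_NEIGHBORS = {
        'debug': [16, 16, 4, 4, 1, 1],
        'fast': [32, 32, 8, 8, 4, 4],
        'best': [64, 64, 8, 8, 4, 4],  # assumed counterpart of the old else
    }

    num_neighbors = _DEFAULT_NUM_NEIGHBORS['fast'][:2]  # two hops -> [32, 32]
    assert len(num_neighbors) <= 6  # deeper subgraphs are rejected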
@@ -1049,23 +1340,26 @@ class KumoRFM:
                              f"'https://github.com/kumo-ai/kumo-rfm' if you "
                              f"must go beyond this for your use-case.")

-
-
-
+        if (task.task_type.is_link_pred
+                and task.entity_table_names[-1] not in subgraph.table_dict):
+            raise ValueError("Cannot perform link prediction on subgraphs "
+                             "without any historical target entities. Please "
+                             "increase the number of hops and try again.")

         return Context(
-            task_type=task_type,
-            entity_table_names=entity_table_names,
+            task_type=task.task_type,
+            entity_table_names=task.entity_table_names,
             subgraph=subgraph,
-            y_train=
-            y_test=
-
-
+            y_train=task._context_df[task.target_column.name],
+            y_test=task._pred_df[task.target_column.name]
+            if task.evaluate else None,
+            top_k=top_k,
+            step_size=None,
         )

     @staticmethod
     def _validate_metrics(
-        metrics:
+        metrics: list[str],
         task_type: TaskType,
     ) -> None:

@@ -1122,7 +1416,7 @@ class KumoRFM:
                          f"'https://github.com/kumo-ai/kumo-rfm'.")


-def format_value(value:
+def format_value(value: int | float) -> str:
    if value == int(value):
        return f'{int(value):,}'
    if abs(value) >= 1000: