kumoai 2.14.0.dev202512151351__cp313-cp313-macosx_11_0_arm64.whl → 2.15.0.dev202601121731__cp313-cp313-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kumoai/__init__.py +23 -26
- kumoai/_version.py +1 -1
- kumoai/client/client.py +6 -0
- kumoai/client/jobs.py +26 -0
- kumoai/connector/utils.py +21 -7
- kumoai/experimental/rfm/__init__.py +51 -24
- kumoai/experimental/rfm/authenticate.py +3 -4
- kumoai/experimental/rfm/backend/local/graph_store.py +37 -46
- kumoai/experimental/rfm/backend/local/sampler.py +0 -3
- kumoai/experimental/rfm/backend/local/table.py +24 -30
- kumoai/experimental/rfm/backend/snow/sampler.py +197 -90
- kumoai/experimental/rfm/backend/snow/table.py +159 -52
- kumoai/experimental/rfm/backend/sqlite/__init__.py +2 -2
- kumoai/experimental/rfm/backend/sqlite/sampler.py +199 -99
- kumoai/experimental/rfm/backend/sqlite/table.py +103 -45
- kumoai/experimental/rfm/base/__init__.py +6 -1
- kumoai/experimental/rfm/base/column.py +96 -10
- kumoai/experimental/rfm/base/expression.py +44 -0
- kumoai/experimental/rfm/base/mapper.py +69 -0
- kumoai/experimental/rfm/base/sampler.py +28 -18
- kumoai/experimental/rfm/base/source.py +1 -1
- kumoai/experimental/rfm/base/sql_sampler.py +342 -13
- kumoai/experimental/rfm/base/table.py +374 -208
- kumoai/experimental/rfm/base/utils.py +27 -0
- kumoai/experimental/rfm/graph.py +335 -180
- kumoai/experimental/rfm/infer/__init__.py +6 -4
- kumoai/experimental/rfm/infer/dtype.py +7 -4
- kumoai/experimental/rfm/infer/multicategorical.py +1 -1
- kumoai/experimental/rfm/infer/pkey.py +4 -2
- kumoai/experimental/rfm/infer/stype.py +35 -0
- kumoai/experimental/rfm/infer/time_col.py +5 -4
- kumoai/experimental/rfm/pquery/executor.py +27 -27
- kumoai/experimental/rfm/pquery/pandas_executor.py +29 -31
- kumoai/experimental/rfm/relbench.py +76 -0
- kumoai/experimental/rfm/rfm.py +600 -360
- kumoai/experimental/rfm/sagemaker.py +4 -4
- kumoai/experimental/rfm/task_table.py +292 -0
- kumoai/pquery/training_table.py +16 -2
- kumoai/testing/snow.py +3 -3
- kumoai/trainer/distilled_trainer.py +175 -0
- kumoai/utils/__init__.py +1 -2
- kumoai/utils/display.py +87 -0
- kumoai/utils/progress_logger.py +190 -12
- {kumoai-2.14.0.dev202512151351.dist-info → kumoai-2.15.0.dev202601121731.dist-info}/METADATA +3 -2
- {kumoai-2.14.0.dev202512151351.dist-info → kumoai-2.15.0.dev202601121731.dist-info}/RECORD +48 -40
- {kumoai-2.14.0.dev202512151351.dist-info → kumoai-2.15.0.dev202601121731.dist-info}/WHEEL +0 -0
- {kumoai-2.14.0.dev202512151351.dist-info → kumoai-2.15.0.dev202601121731.dist-info}/licenses/LICENSE +0 -0
- {kumoai-2.14.0.dev202512151351.dist-info → kumoai-2.15.0.dev202601121731.dist-info}/top_level.txt +0 -0
kumoai/experimental/rfm/rfm.py
CHANGED
@@ -1,23 +1,13 @@
 import json
+import math
 import time
 import warnings
 from collections import defaultdict
-from collections.abc import Generator
+from collections.abc import Generator, Iterator
 from contextlib import contextmanager
 from dataclasses import dataclass, replace
-from typing import (
-    Any,
-    Dict,
-    Iterator,
-    List,
-    Literal,
-    Optional,
-    Tuple,
-    Union,
-    overload,
-)
+from typing import Any, Literal, overload
 
-import numpy as np
 import pandas as pd
 from kumoapi.model_plan import RunMode
 from kumoapi.pquery import QueryType, ValidatedPredictiveQuery
@@ -37,30 +27,37 @@ from kumoapi.rfm import (
 )
 from kumoapi.task import TaskType
 from kumoapi.typing import AggregationType, Stype
+from rich.console import Console
+from rich.markdown import Markdown
 
+from kumoai import in_notebook
 from kumoai.client.rfm import RFMAPI
 from kumoai.exceptions import HTTPException
-from kumoai.experimental.rfm import Graph
+from kumoai.experimental.rfm import Graph, TaskTable
 from kumoai.experimental.rfm.base import DataBackend, Sampler
 from kumoai.mixin import CastMixin
-from kumoai.utils import ProgressLogger
+from kumoai.utils import ProgressLogger, display
 
 _RANDOM_SEED = 42
 
 _MAX_PRED_SIZE: dict[TaskType, int] = defaultdict(lambda: 1_000)
 _MAX_PRED_SIZE[TaskType.TEMPORAL_LINK_PREDICTION] = 200
 
+_MAX_TEST_SIZE: dict[TaskType, int] = defaultdict(lambda: 2_000)
+_MAX_TEST_SIZE[TaskType.TEMPORAL_LINK_PREDICTION] = 400
+
 _MAX_CONTEXT_SIZE = {
     RunMode.DEBUG: 100,
     RunMode.FAST: 1_000,
     RunMode.NORMAL: 5_000,
     RunMode.BEST: 10_000,
 }
-
-_MAX_TEST_SIZE = {
-    RunMode.
-    RunMode.
-    RunMode.
+
+_DEFAULT_NUM_NEIGHBORS = {
+    RunMode.DEBUG: [16, 16, 4, 4, 1, 1],
+    RunMode.FAST: [32, 32, 8, 8, 4, 4],
+    RunMode.NORMAL: [64, 64, 8, 8, 4, 4],
+    RunMode.BEST: [64, 64, 8, 8, 4, 4],
 }
 
 _MAX_SIZE = 30 * 1024 * 1024
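For reference, a minimal sketch of how the new defaultdict-based caps behave (the values mirror the constants above; plain strings stand in for the TaskType enum so the sketch stays self-contained):

    from collections import defaultdict

    # Every task type falls back to the default cap; temporal link
    # prediction gets a tighter, explicitly registered limit.
    max_test_size = defaultdict(lambda: 2_000)
    max_test_size['TEMPORAL_LINK_PREDICTION'] = 400

    print(max_test_size['REGRESSION'])                # 2000 (default)
    print(max_test_size['TEMPORAL_LINK_PREDICTION'])  # 400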
@@ -98,24 +95,36 @@ class Explanation:
     def __getitem__(self, index: Literal[1]) -> str:
         pass
 
-    def __getitem__(self, index: int) -> Union[pd.DataFrame, str]:
+    def __getitem__(self, index: int) -> pd.DataFrame | str:
         if index == 0:
             return self.prediction
         if index == 1:
             return self.summary
         raise IndexError("Index out of range")
 
-    def __iter__(self) -> Iterator[Union[pd.DataFrame, str]]:
+    def __iter__(self) -> Iterator[pd.DataFrame | str]:
         return iter((self.prediction, self.summary))
 
     def __repr__(self) -> str:
         return str((self.prediction, self.summary))
 
-    def
-
+    def __str__(self) -> str:
+        console = Console(soft_wrap=True)
+        with console.capture() as cap:
+            console.print(display.to_rich_table(self.prediction))
+            console.print(Markdown(self.summary))
+        return cap.get()[:-1]
+
+    def print(self) -> None:
+        r"""Prints the explanation."""
+        if in_notebook():
+            display.dataframe(self.prediction)
+            display.message(self.summary)
+        else:
+            print(self)
 
-
-
+    def _ipython_display_(self) -> None:
+        self.print()
 
 
 class KumoRFM:
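With __iter__, __getitem__, and the new print/_ipython_display_ hooks, an Explanation behaves like a (prediction, summary) pair and renders itself in notebooks. A hedged usage sketch (assumes `model` is a KumoRFM instance and `query` a valid predictive query):

    explanation = model.predict(query, indices=[1], explain=True)

    prediction, summary = explanation  # tuple-style unpacking via __iter__
    df = explanation[0]                # the prediction DataFrame
    text = explanation[1]              # the Markdown summary string

    # Rich table + Markdown in notebooks, plain text elsewhere:
    explanation.print()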
@@ -162,7 +171,7 @@ class KumoRFM:
     def __init__(
         self,
         graph: Graph,
-        verbose: Union[bool, ProgressLogger] = True,
+        verbose: bool | ProgressLogger = True,
         optimize: bool = False,
     ) -> None:
         graph = graph.validate()
@@ -180,10 +189,10 @@ class KumoRFM:
         else:
             raise NotImplementedError
 
-        self._client: Optional[RFMAPI] = None
+        self._client: RFMAPI | None = None
 
-        self._batch_size: Union[int, Literal['max'], None] = None
-        self.num_retries: int = 0
+        self._batch_size: int | Literal['max'] | None = None
+        self._num_retries: int = 0
 
     @property
     def _api_client(self) -> RFMAPI:
@@ -197,10 +206,34 @@ class KumoRFM:
     def __repr__(self) -> str:
         return f'{self.__class__.__name__}()'
 
+    @contextmanager
+    def retry(
+        self,
+        num_retries: int = 1,
+    ) -> Generator[None, None, None]:
+        """Context manager to retry failed queries due to unexpected server
+        issues.
+
+        .. code-block:: python
+
+            with model.retry(num_retries=1):
+                df = model.predict(query, indices=...)
+
+        Args:
+            num_retries: The maximum number of retries.
+        """
+        if num_retries < 0:
+            raise ValueError(f"'num_retries' must be greater than or equal to "
+                             f"zero (got {num_retries})")
+
+        self._num_retries = num_retries
+        yield
+        self._num_retries = 0
+
     @contextmanager
     def batch_mode(
         self,
-        batch_size: Union[int, Literal['max']] = 'max',
+        batch_size: int | Literal['max'] = 'max',
         num_retries: int = 1,
     ) -> Generator[None, None, None]:
         """Context manager to predict in batches.
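Expanding on the docstring's code block, a usage sketch for the new retry context manager (`graph` and `query` are assumed to exist):

    model = KumoRFM(graph)

    # Retry transient server-side failures up to two times before the
    # underlying HTTPException is surfaced as a RuntimeError.
    with model.retry(num_retries=2):
        df = model.predict(query, indices=[42, 43])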
@@ -220,31 +253,26 @@ class KumoRFM:
             raise ValueError(f"'batch_size' must be greater than zero "
                              f"(got {batch_size})")
 
-        if num_retries < 0:
-            raise ValueError(f"'num_retries' must be greater than or equal to "
-                             f"zero (got {num_retries})")
-
         self._batch_size = batch_size
-        self.num_retries = num_retries
-        yield
+        with self.retry(self._num_retries or num_retries):
+            yield
         self._batch_size = None
-        self.num_retries = 0
 
     @overload
     def predict(
         self,
         query: str,
-        indices: Union[List[str], List[float], List[int], None] = None,
+        indices: list[str] | list[float] | list[int] | None = None,
         *,
         explain: Literal[False] = False,
-        anchor_time: Union[pd.Timestamp, Literal['entity'], None] = None,
-        context_anchor_time: Optional[pd.Timestamp] = None,
-        run_mode: Union[RunMode, str] = RunMode.FAST,
-        num_neighbors: Optional[List[int]] = None,
+        anchor_time: pd.Timestamp | Literal['entity'] | None = None,
+        context_anchor_time: pd.Timestamp | None = None,
+        run_mode: RunMode | str = RunMode.FAST,
+        num_neighbors: list[int] | None = None,
         num_hops: int = 2,
         max_pq_iterations: int = 10,
-        random_seed: Optional[int] = _RANDOM_SEED,
-        verbose: Union[bool, ProgressLogger] = True,
+        random_seed: int | None = _RANDOM_SEED,
+        verbose: bool | ProgressLogger = True,
         use_prediction_time: bool = False,
     ) -> pd.DataFrame:
         pass
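batch_mode now delegates its retry handling to retry() (note the `self._num_retries or num_retries` fallback above, which lets an enclosing retry() block take precedence). A hedged sketch:

    # Score 10,000 entities in batches of 1,000, retrying each failed
    # batch once; the indices are illustrative.
    with model.batch_mode(batch_size=1_000, num_retries=1):
        df = model.predict(query, indices=list(range(10_000)))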
@@ -253,37 +281,56 @@ class KumoRFM:
     def predict(
         self,
         query: str,
-        indices: Union[List[str], List[float], List[int], None] = None,
+        indices: list[str] | list[float] | list[int] | None = None,
         *,
-        explain: Union[Literal[True], ExplainConfig, Dict[str, Any]],
-        anchor_time: Union[pd.Timestamp, Literal['entity'], None] = None,
-        context_anchor_time: Optional[pd.Timestamp] = None,
-        run_mode: Union[RunMode, str] = RunMode.FAST,
-        num_neighbors: Optional[List[int]] = None,
+        explain: Literal[True] | ExplainConfig | dict[str, Any],
+        anchor_time: pd.Timestamp | Literal['entity'] | None = None,
+        context_anchor_time: pd.Timestamp | None = None,
+        run_mode: RunMode | str = RunMode.FAST,
+        num_neighbors: list[int] | None = None,
         num_hops: int = 2,
         max_pq_iterations: int = 10,
-        random_seed: Optional[int] = _RANDOM_SEED,
-        verbose: Union[bool, ProgressLogger] = True,
+        random_seed: int | None = _RANDOM_SEED,
+        verbose: bool | ProgressLogger = True,
         use_prediction_time: bool = False,
     ) -> Explanation:
         pass
 
+    @overload
+    def predict(
+        self,
+        query: str,
+        indices: list[str] | list[float] | list[int] | None = None,
+        *,
+        explain: bool | ExplainConfig | dict[str, Any] = False,
+        anchor_time: pd.Timestamp | Literal['entity'] | None = None,
+        context_anchor_time: pd.Timestamp | None = None,
+        run_mode: RunMode | str = RunMode.FAST,
+        num_neighbors: list[int] | None = None,
+        num_hops: int = 2,
+        max_pq_iterations: int = 10,
+        random_seed: int | None = _RANDOM_SEED,
+        verbose: bool | ProgressLogger = True,
+        use_prediction_time: bool = False,
+    ) -> pd.DataFrame | Explanation:
+        pass
+
     def predict(
         self,
         query: str,
-        indices: Union[List[str], List[float], List[int], None] = None,
+        indices: list[str] | list[float] | list[int] | None = None,
         *,
-        explain: Union[bool, ExplainConfig, Dict[str, Any]] = False,
-        anchor_time: Union[pd.Timestamp, Literal['entity'], None] = None,
-        context_anchor_time: Optional[pd.Timestamp] = None,
-        run_mode: Union[RunMode, str] = RunMode.FAST,
-        num_neighbors: Optional[List[int]] = None,
+        explain: bool | ExplainConfig | dict[str, Any] = False,
+        anchor_time: pd.Timestamp | Literal['entity'] | None = None,
+        context_anchor_time: pd.Timestamp | None = None,
+        run_mode: RunMode | str = RunMode.FAST,
+        num_neighbors: list[int] | None = None,
         num_hops: int = 2,
         max_pq_iterations: int = 10,
-        random_seed: Optional[int] = _RANDOM_SEED,
-        verbose: Union[bool, ProgressLogger] = True,
+        random_seed: int | None = _RANDOM_SEED,
+        verbose: bool | ProgressLogger = True,
         use_prediction_time: bool = False,
-    ) -> Union[pd.DataFrame, Explanation]:
+    ) -> pd.DataFrame | Explanation:
         """Returns predictions for a predictive query.
 
         Args:
@@ -291,8 +338,7 @@ class KumoRFM:
             indices: The entity primary keys to predict for. Will override the
                 indices given as part of the predictive query. Predictions will
                 be generated for all indices, independent of whether they
-                fulfill entity filter constraints.
-                :meth:`~KumoRFM.is_valid_entity`.
+                fulfill entity filter constraints.
             explain: Configuration for explainability.
                 If set to ``True``, will additionally explain the prediction.
                 Passing in an :class:`ExplainConfig` instance provides control
@@ -325,18 +371,152 @@ class KumoRFM:
             If ``explain`` is provided, returns an :class:`Explanation` object
             containing the prediction, summary, and details.
         """
-        explain_config: Optional[ExplainConfig] = None
-        if explain is True:
-            explain_config = ExplainConfig()
-        elif explain is not False:
-            explain_config = ExplainConfig._cast(explain)
-
         query_def = self._parse_query(query)
-        query_str = query_def.to_string()
 
+        if indices is None:
+            if query_def.rfm_entity_ids is None:
+                raise ValueError("Cannot find entities to predict for. Please "
+                                 "pass them via `predict(query, indices=...)`")
+            indices = query_def.get_rfm_entity_id_list()
+        query_def = replace(
+            query_def,
+            for_each='FOR EACH',
+            rfm_entity_ids=None,
+        )
+
+        if not isinstance(verbose, ProgressLogger):
+            query_repr = query_def.to_string(rich=True, exclude_predict=True)
+            if explain is not False:
+                msg = f'[bold]EXPLAIN[/bold] {query_repr}'
+            else:
+                msg = f'[bold]PREDICT[/bold] {query_repr}'
+            verbose = ProgressLogger.default(msg=msg, verbose=verbose)
+
+        with verbose as logger:
+            task_table = self._get_task_table(
+                query=query_def,
+                indices=indices,
+                anchor_time=anchor_time,
+                context_anchor_time=context_anchor_time,
+                run_mode=run_mode,
+                max_pq_iterations=max_pq_iterations,
+                random_seed=random_seed,
+                logger=logger,
+            )
+            task_table._query = query_def.to_string()
+
+            return self.predict_task(
+                task_table,
+                explain=explain,
+                run_mode=run_mode,
+                num_neighbors=num_neighbors,
+                num_hops=num_hops,
+                verbose=verbose,
+                exclude_cols_dict=query_def.get_exclude_cols_dict(),
+                use_prediction_time=use_prediction_time,
+                top_k=query_def.top_k,
+            )
+
+    @overload
+    def predict_task(
+        self,
+        task: TaskTable,
+        *,
+        explain: Literal[False] = False,
+        run_mode: RunMode | str = RunMode.FAST,
+        num_neighbors: list[int] | None = None,
+        num_hops: int = 2,
+        verbose: bool | ProgressLogger = True,
+        exclude_cols_dict: dict[str, list[str]] | None = None,
+        use_prediction_time: bool = False,
+        top_k: int | None = None,
+    ) -> pd.DataFrame:
+        pass
+
+    @overload
+    def predict_task(
+        self,
+        task: TaskTable,
+        *,
+        explain: Literal[True] | ExplainConfig | dict[str, Any],
+        run_mode: RunMode | str = RunMode.FAST,
+        num_neighbors: list[int] | None = None,
+        num_hops: int = 2,
+        verbose: bool | ProgressLogger = True,
+        exclude_cols_dict: dict[str, list[str]] | None = None,
+        use_prediction_time: bool = False,
+        top_k: int | None = None,
+    ) -> Explanation:
+        pass
+
+    @overload
+    def predict_task(
+        self,
+        task: TaskTable,
+        *,
+        explain: bool | ExplainConfig | dict[str, Any] = False,
+        run_mode: RunMode | str = RunMode.FAST,
+        num_neighbors: list[int] | None = None,
+        num_hops: int = 2,
+        verbose: bool | ProgressLogger = True,
+        exclude_cols_dict: dict[str, list[str]] | None = None,
+        use_prediction_time: bool = False,
+        top_k: int | None = None,
+    ) -> pd.DataFrame | Explanation:
+        pass
+
+    def predict_task(
+        self,
+        task: TaskTable,
+        *,
+        explain: bool | ExplainConfig | dict[str, Any] = False,
+        run_mode: RunMode | str = RunMode.FAST,
+        num_neighbors: list[int] | None = None,
+        num_hops: int = 2,
+        verbose: bool | ProgressLogger = True,
+        exclude_cols_dict: dict[str, list[str]] | None = None,
+        use_prediction_time: bool = False,
+        top_k: int | None = None,
+    ) -> pd.DataFrame | Explanation:
+        """Returns predictions for a custom task specification.
+
+        Args:
+            task: The custom :class:`TaskTable`.
+            explain: Configuration for explainability.
+                If set to ``True``, will additionally explain the prediction.
+                Passing in an :class:`ExplainConfig` instance provides control
+                over which parts of explanation are generated.
+                Explainability is currently only supported for single entity
+                predictions with ``run_mode="FAST"``.
+            run_mode: The :class:`RunMode` for the query.
+            num_neighbors: The number of neighbors to sample for each hop.
+                If specified, the ``num_hops`` option will be ignored.
+            num_hops: The number of hops to sample when generating the context.
+            verbose: Whether to print verbose output.
+            exclude_cols_dict: Any column in any table to exclude from the
+                model input.
+            use_prediction_time: Whether to use the anchor timestamp as an
+                additional feature during prediction. This is typically
+                beneficial for time series forecasting tasks.
+            top_k: The number of predictions to return per entity.
+
+        Returns:
+            The predictions as a :class:`pandas.DataFrame`.
+            If ``explain`` is provided, returns an :class:`Explanation` object
+            containing the prediction, summary, and details.
+        """
         if num_hops != 2 and num_neighbors is not None:
             warnings.warn(f"Received custom 'num_neighbors' option; ignoring "
                           f"custom 'num_hops={num_hops}' option")
+        if num_neighbors is None:
+            key = RunMode.FAST if task.task_type.is_link_pred else run_mode
+            num_neighbors = _DEFAULT_NUM_NEIGHBORS[key][:num_hops]
+
+        explain_config: ExplainConfig | None = None
+        if explain is True:
+            explain_config = ExplainConfig()
+        elif explain is not False:
+            explain_config = ExplainConfig._cast(explain)
 
         if explain_config is not None and run_mode in {
             RunMode.NORMAL, RunMode.BEST
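predict is now a thin wrapper that builds a TaskTable and forwards to the new predict_task. A hedged sketch of calling predict_task directly with a hand-built task (the column and table names are illustrative; the keyword arguments mirror the TaskTable(...) call in _get_task_table below):

    import pandas as pd

    from kumoapi.task import TaskType
    from kumoai.experimental.rfm import TaskTable

    # In-context rows carry entity, label, and anchor time; prediction
    # rows carry the entities to score.
    context_df = pd.DataFrame({
        'ENTITY': [1, 2, 3],
        'TARGET': [0, 1, 0],
        'ANCHOR_TIMESTAMP': pd.to_datetime(['2024-01-01'] * 3),
    })
    pred_df = pd.DataFrame({
        'ENTITY': [4, 5],
        'ANCHOR_TIMESTAMP': pd.to_datetime(['2024-02-01'] * 2),
    })

    task = TaskTable(
        task_type=TaskType.BINARY_CLASSIFICATION,
        context_df=context_df,
        pred_df=pred_df,
        entity_table_name='users',  # assumed table in the graph
        entity_column='ENTITY',
        target_column='TARGET',
        time_column='ANCHOR_TIMESTAMP',
    )
    df = model.predict_task(task, run_mode='FAST')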
@@ -345,83 +525,82 @@ class KumoRFM:
                           f"run mode 'FAST' (got '{run_mode}'). Provided run "
                           f"mode has been reset. Please lower the run mode to "
                           f"suppress this warning.")
+            run_mode = RunMode.FAST
 
-        if
-
-
-
-            indices = query_def.get_rfm_entity_id_list()
-        else:
-            query_def = replace(query_def, rfm_entity_ids=None)
-
-        if len(indices) == 0:
-            raise ValueError("At least one entity is required")
-
-        if explain_config is not None and len(indices) > 1:
-            raise ValueError(
-                f"Cannot explain predictions for more than a single entity "
-                f"(got {len(indices)})")
-
-        query_repr = query_def.to_string(rich=True, exclude_predict=True)
-        if explain_config is not None:
-            msg = f'[bold]EXPLAIN[/bold] {query_repr}'
-        else:
-            msg = f'[bold]PREDICT[/bold] {query_repr}'
+        if explain_config is not None and task.num_prediction_examples > 1:
+            raise ValueError(f"Cannot explain predictions for more than a "
+                             f"single entity "
+                             f"(got {task.num_prediction_examples:,})")
 
         if not isinstance(verbose, ProgressLogger):
-
-
-
-
-
-
-
-
-                edge_types=self._sampler.edge_types,
-            )
-            batch_size = _MAX_PRED_SIZE[task_type]
+            if task.task_type == TaskType.BINARY_CLASSIFICATION:
+                task_type_repr = 'binary classification'
+            elif task.task_type == TaskType.MULTICLASS_CLASSIFICATION:
+                task_type_repr = 'multi-class classification'
+            elif task.task_type == TaskType.REGRESSION:
+                task_type_repr = 'regression'
+            elif task.task_type == TaskType.TEMPORAL_LINK_PREDICTION:
+                task_type_repr = 'link prediction'
             else:
-
+                task_type_repr = str(task.task_type)
 
-        if
-
-            batches = [indices[step:step + batch_size] for step in offsets]
+            if explain_config is not None:
+                msg = f"Explaining {task_type_repr} task"
             else:
-
+                msg = f"Predicting {task_type_repr} task"
+            verbose = ProgressLogger.default(msg=msg, verbose=verbose)
 
-
-
-
+        with verbose as logger:
+            if task.num_context_examples > _MAX_CONTEXT_SIZE[run_mode]:
+                logger.log(f"Sub-sampled {_MAX_CONTEXT_SIZE[run_mode]:,} "
+                           f"out of {task.num_context_examples:,} in-context "
+                           f"examples")
+                task = task.narrow_context(0, _MAX_CONTEXT_SIZE[run_mode])
+
+            if self._batch_size is None:
+                batch_size = task.num_prediction_examples
+            elif self._batch_size == 'max':
+                batch_size = _MAX_PRED_SIZE[task.task_type]
+            else:
+                batch_size = self._batch_size
 
-
-
-
-
-
+            if batch_size > _MAX_PRED_SIZE[task.task_type]:
+                raise ValueError(f"Cannot predict for more than "
+                                 f"{_MAX_PRED_SIZE[task.task_type]:,} "
+                                 f"entities at once (got {batch_size:,}). Use "
+                                 f"`KumoRFM.batch_mode` to process entities "
+                                 f"in batches with a sufficient batch size.")
+
+            if task.num_prediction_examples > batch_size:
+                num = math.ceil(task.num_prediction_examples / batch_size)
+                logger.log(f"Splitting {task.num_prediction_examples:,} "
+                           f"entities into {num:,} batches of size "
+                           f"{batch_size:,}")
+
+            predictions: list[pd.DataFrame] = []
+            summary: str | None = None
+            details: Explanation | None = None
+            for start in range(0, task.num_prediction_examples, batch_size):
                 context = self._get_context(
-
-
-                    anchor_time=anchor_time,
-                    context_anchor_time=context_anchor_time,
-                    run_mode=RunMode(run_mode),
+                    task=task.narrow_prediction(start, length=batch_size),
+                    run_mode=run_mode,
                     num_neighbors=num_neighbors,
-
-
-                    evaluate=False,
-                    random_seed=random_seed,
-                    logger=logger if i == 0 else None,
+                    exclude_cols_dict=exclude_cols_dict,
+                    top_k=top_k,
                 )
+                context.y_test = None
+
                 request = RFMPredictRequest(
                     context=context,
                     run_mode=RunMode(run_mode),
-                    query=
+                    query=task._query,
                     use_prediction_time=use_prediction_time,
                 )
                 with warnings.catch_warnings():
                     warnings.filterwarnings('ignore', message='gencode')
                     request_msg = request.to_protobuf()
                     _bytes = request_msg.SerializeToString()
-                if
+                if start == 0:
                     logger.log(f"Generated context of size "
                                f"{len(_bytes) / (1024*1024):.2f}MB")
 
@@ -429,14 +608,11 @@ class KumoRFM:
                     stats = Context.get_memory_stats(request_msg.context)
                     raise ValueError(_SIZE_LIMIT_MSG.format(stats=stats))
 
-                if
-
-                    verbose.init_progress(
-                        total=len(batches),
-                        description='Predicting',
-                    )
+                if start == 0 and task.num_prediction_examples > batch_size:
+                    num = math.ceil(task.num_prediction_examples / batch_size)
+                    verbose.init_progress(total=num, description='Predicting')
 
-                for attempt in range(self.num_retries + 1):
+                for attempt in range(self._num_retries + 1):
                     try:
                         if explain_config is not None:
                             resp = self._api_client.explain(
@@ -452,7 +628,7 @@ class KumoRFM:
                         # Cast 'ENTITY' to correct data type:
                         if 'ENTITY' in df:
                             table_dict = context.subgraph.table_dict
-                            table = table_dict[
+                            table = table_dict[context.entity_table_names[0]]
                             ser = table.df[table.primary_key]
                             df['ENTITY'] = df['ENTITY'].astype(ser.dtype)
 
@@ -469,13 +645,12 @@ class KumoRFM:
 
                         predictions.append(df)
 
-                        if
-                                and len(batches) > 1):
+                        if task.num_prediction_examples > batch_size:
                             verbose.step()
 
                         break
                     except HTTPException as e:
-                        if attempt == self.num_retries:
+                        if attempt == self._num_retries:
                             try:
                                 msg = json.loads(e.detail)['detail']
                             except Exception:
@@ -505,64 +680,19 @@ class KumoRFM:
 
         return prediction
 
-    def is_valid_entity(
-        self,
-        query: str,
-        indices: Union[List[str], List[float], List[int], None] = None,
-        *,
-        anchor_time: Union[pd.Timestamp, Literal['entity'], None] = None,
-    ) -> np.ndarray:
-        r"""Returns a mask that denotes which entities are valid for the
-        given predictive query, *i.e.*, which entities fulfill (temporal)
-        entity filter constraints.
-
-        Args:
-            query: The predictive query.
-            indices: The entity primary keys to predict for. Will override the
-                indices given as part of the predictive query.
-            anchor_time: The anchor timestamp for the prediction. If set to
-                ``None``, will use the maximum timestamp in the data.
-                If set to ``"entity"``, will use the timestamp of the entity.
-        """
-        query_def = self._parse_query(query)
-
-        if indices is None:
-            if query_def.rfm_entity_ids is None:
-                raise ValueError("Cannot find entities to predict for. Please "
-                                 "pass them via "
-                                 "`is_valid_entity(query, indices=...)`")
-            indices = query_def.get_rfm_entity_id_list()
-
-        if len(indices) == 0:
-            raise ValueError("At least one entity is required")
-
-        if anchor_time is None:
-            anchor_time = self._get_default_anchor_time(query_def)
-
-        if isinstance(anchor_time, pd.Timestamp):
-            self._validate_time(query_def, anchor_time, None, False)
-        else:
-            assert anchor_time == 'entity'
-            if query_def.entity_table not in self._sampler.time_column_dict:
-                raise ValueError(f"Anchor time 'entity' requires the entity "
-                                 f"table '{query_def.entity_table}' "
-                                 f"to have a time column.")
-
-        raise NotImplementedError
-
     def evaluate(
         self,
         query: str,
         *,
-        metrics: Optional[List[str]] = None,
-        anchor_time: Union[pd.Timestamp, Literal['entity'], None] = None,
-        context_anchor_time: Optional[pd.Timestamp] = None,
-        run_mode: Union[RunMode, str] = RunMode.FAST,
-        num_neighbors: Optional[List[int]] = None,
+        metrics: list[str] | None = None,
+        anchor_time: pd.Timestamp | Literal['entity'] | None = None,
+        context_anchor_time: pd.Timestamp | None = None,
+        run_mode: RunMode | str = RunMode.FAST,
+        num_neighbors: list[int] | None = None,
         num_hops: int = 2,
         max_pq_iterations: int = 10,
-        random_seed: Optional[int] = _RANDOM_SEED,
-        verbose: Union[bool, ProgressLogger] = True,
+        random_seed: int | None = _RANDOM_SEED,
+        verbose: bool | ProgressLogger = True,
         use_prediction_time: bool = False,
     ) -> pd.DataFrame:
         """Evaluates a predictive query.
@@ -594,41 +724,120 @@ class KumoRFM:
         Returns:
             The metrics as a :class:`pandas.DataFrame`
         """
-        query_def = self._parse_query(query)
+        query_def = replace(
+            self._parse_query(query),
+            for_each='FOR EACH',
+            rfm_entity_ids=None,
+        )
+
+        if not isinstance(verbose, ProgressLogger):
+            query_repr = query_def.to_string(rich=True, exclude_predict=True)
+            msg = f'[bold]EVALUATE[/bold] {query_repr}'
+            verbose = ProgressLogger.default(msg=msg, verbose=verbose)
+
+        with verbose as logger:
+            task_table = self._get_task_table(
+                query=query_def,
+                indices=None,
+                anchor_time=anchor_time,
+                context_anchor_time=context_anchor_time,
+                run_mode=run_mode,
+                max_pq_iterations=max_pq_iterations,
+                random_seed=random_seed,
+                logger=logger,
+            )
+
+            return self.evaluate_task(
+                task_table,
+                metrics=metrics,
+                run_mode=run_mode,
+                num_neighbors=num_neighbors,
+                num_hops=num_hops,
+                verbose=verbose,
+                exclude_cols_dict=query_def.get_exclude_cols_dict(),
+                use_prediction_time=use_prediction_time,
+            )
+
+    def evaluate_task(
+        self,
+        task: TaskTable,
+        *,
+        metrics: list[str] | None = None,
+        run_mode: RunMode | str = RunMode.FAST,
+        num_neighbors: list[int] | None = None,
+        num_hops: int = 2,
+        verbose: bool | ProgressLogger = True,
+        exclude_cols_dict: dict[str, list[str]] | None = None,
+        use_prediction_time: bool = False,
+    ) -> pd.DataFrame:
+        """Evaluates a custom task specification.
+
+        Args:
+            task: The custom :class:`TaskTable`.
+            metrics: The metrics to use.
+            run_mode: The :class:`RunMode` for the query.
+            num_neighbors: The number of neighbors to sample for each hop.
+                If specified, the ``num_hops`` option will be ignored.
+            num_hops: The number of hops to sample when generating the context.
+            verbose: Whether to print verbose output.
+            exclude_cols_dict: Any column in any table to exclude from the
+                model input.
+            use_prediction_time: Whether to use the anchor timestamp as an
+                additional feature during prediction. This is typically
+                beneficial for time series forecasting tasks.
 
+        Returns:
+            The metrics as a :class:`pandas.DataFrame`
+        """
         if num_hops != 2 and num_neighbors is not None:
             warnings.warn(f"Received custom 'num_neighbors' option; ignoring "
                           f"custom 'num_hops={num_hops}' option")
+        if num_neighbors is None:
+            key = RunMode.FAST if task.task_type.is_link_pred else run_mode
+            num_neighbors = _DEFAULT_NUM_NEIGHBORS[key][:num_hops]
 
-        if
-
-
-                rfm_entity_ids=None,
-            )
-
-        query_repr = query_def.to_string(rich=True, exclude_predict=True)
-        msg = f'[bold]EVALUATE[/bold] {query_repr}'
+        if metrics is not None and len(metrics) > 0:
+            self._validate_metrics(metrics, task.task_type)
+            metrics = list(dict.fromkeys(metrics))
 
         if not isinstance(verbose, ProgressLogger):
-
+            if task.task_type == TaskType.BINARY_CLASSIFICATION:
+                task_type_repr = 'binary classification'
+            elif task.task_type == TaskType.MULTICLASS_CLASSIFICATION:
+                task_type_repr = 'multi-class classification'
+            elif task.task_type == TaskType.REGRESSION:
+                task_type_repr = 'regression'
+            elif task.task_type == TaskType.TEMPORAL_LINK_PREDICTION:
+                task_type_repr = 'link prediction'
+            else:
+                task_type_repr = str(task.task_type)
+
+            msg = f"Evaluating {task_type_repr} task"
+            verbose = ProgressLogger.default(msg=msg, verbose=verbose)
 
         with verbose as logger:
+            if task.num_context_examples > _MAX_CONTEXT_SIZE[run_mode]:
+                logger.log(f"Sub-sampled {_MAX_CONTEXT_SIZE[run_mode]:,} "
+                           f"out of {task.num_context_examples:,} in-context "
+                           f"examples")
+                task = task.narrow_context(0, _MAX_CONTEXT_SIZE[run_mode])
+
+            if task.num_prediction_examples > _MAX_TEST_SIZE[task.task_type]:
+                logger.log(f"Sub-sampled {_MAX_TEST_SIZE[task.task_type]:,} "
+                           f"out of {task.num_prediction_examples:,} test "
+                           f"examples")
+                task = task.narrow_prediction(
+                    start=0,
+                    length=_MAX_TEST_SIZE[task.task_type],
+                )
+
             context = self._get_context(
-
-
-                anchor_time=anchor_time,
-                context_anchor_time=context_anchor_time,
-                run_mode=RunMode(run_mode),
+                task=task,
+                run_mode=run_mode,
                 num_neighbors=num_neighbors,
-
-                max_pq_iterations=max_pq_iterations,
-                evaluate=True,
-                random_seed=random_seed,
-                logger=logger if verbose else None,
+                exclude_cols_dict=exclude_cols_dict,
             )
-
-            self._validate_metrics(metrics, context.task_type)
-            metrics = list(dict.fromkeys(metrics))
+
             request = RFMEvaluateRequest(
                 context=context,
                 run_mode=RunMode(run_mode),
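Evaluation follows the same split: evaluate builds the TaskTable and defers to evaluate_task. A hedged sketch, rebuilding the hand-built task from the earlier predict_task example with held-out labels on the prediction rows (the metric name is illustrative and must match the task type):

    pred_df['TARGET'] = [1, 0]  # held-out labels for the prediction rows

    task = TaskTable(
        task_type=TaskType.BINARY_CLASSIFICATION,
        context_df=context_df,
        pred_df=pred_df,
        entity_table_name='users',
        entity_column='ENTITY',
        target_column='TARGET',
        time_column='ANCHOR_TIMESTAMP',
    )
    metrics = model.evaluate_task(task, metrics=['auroc'], run_mode='FAST')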
@@ -646,17 +855,23 @@ class KumoRFM:
                 stats_msg = Context.get_memory_stats(request_msg.context)
                 raise ValueError(_SIZE_LIMIT_MSG.format(stats=stats_msg))
 
-
-                resp = self._api_client.evaluate(request_bytes)
-            except HTTPException as e:
+            for attempt in range(self._num_retries + 1):
                 try:
-
-
-
-
-
-
-
+                    resp = self._api_client.evaluate(request_bytes)
+                    break
+                except HTTPException as e:
+                    if attempt == self._num_retries:
+                        try:
+                            msg = json.loads(e.detail)['detail']
+                        except Exception:
+                            msg = e.detail
+                        raise RuntimeError(
+                            f"An unexpected exception occurred. Please create "
+                            f"an issue at "
+                            f"'https://github.com/kumo-ai/kumo-rfm'. {msg}"
+                        ) from None
+
+                    time.sleep(2**attempt)  # 1s, 2s, 4s, 8s, ...
 
         return pd.DataFrame.from_dict(
             resp.metrics,
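The retry loop added here (and mirrored in predict and _parse_query) is a plain exponential-backoff pattern; a generic, self-contained sketch (`do_request` is a hypothetical stand-in for the API call):

    import time

    def with_backoff(do_request, num_retries: int = 3):
        # Attempt 0 runs immediately; after a failure the loop sleeps
        # 2**attempt seconds (1s, 2s, 4s, ...) before the next attempt.
        for attempt in range(num_retries + 1):
            try:
                return do_request()
            except RuntimeError:
                if attempt == num_retries:
                    raise
                time.sleep(2 ** attempt)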
@@ -669,8 +884,8 @@ class KumoRFM:
         query: str,
         size: int,
         *,
-        anchor_time: Union[pd.Timestamp, Literal['entity'], None] = None,
-        random_seed: Optional[int] = _RANDOM_SEED,
+        anchor_time: pd.Timestamp | Literal['entity'] | None = None,
+        random_seed: int | None = _RANDOM_SEED,
         max_iterations: int = 10,
     ) -> pd.DataFrame:
         """Returns the labels of a predictive query for a specified anchor
@@ -708,7 +923,7 @@ class KumoRFM:
                             f"to have a time column")
 
        train, test = self._sampler.sample_target(
-            query=
+            query=query_def,
             num_train_examples=0,
             train_anchor_time=anchor_time,
             num_train_trials=0,
@@ -736,35 +951,39 @@ class KumoRFM:
                 "`predict()` or `evaluate()` methods to perform "
                 "predictions or evaluations.")
 
-
-
-
-
-        )
+        request = RFMParseQueryRequest(
+            query=query,
+            graph_definition=self._graph_def,
+        )
 
-
+        for attempt in range(self._num_retries + 1):
+            try:
+                resp = self._api_client.parse_query(request)
+                break
+            except HTTPException as e:
+                if attempt == self._num_retries:
+                    try:
+                        msg = json.loads(e.detail)['detail']
+                    except Exception:
+                        msg = e.detail
+                    raise ValueError(f"Failed to parse query '{query}'. {msg}")
 
-
-            msg = '\n'.join([
-                f'{i+1}. {warning.title}: {warning.message}' for i, warning
-                in enumerate(resp.validation_response.warnings)
-            ])
-            warnings.warn(f"Encountered the following warnings during "
-                          f"parsing:\n{msg}")
+                time.sleep(2**attempt)  # 1s, 2s, 4s, 8s, ...
 
-
-
-
-
-
-
-
-
+        if len(resp.validation_response.warnings) > 0:
+            msg = '\n'.join([
+                f'{i+1}. {warning.title}: {warning.message}'
+                for i, warning in enumerate(resp.validation_response.warnings)
+            ])
+            warnings.warn(f"Encountered the following warnings during "
+                          f"parsing:\n{msg}")
+
+        return resp.query
 
     @staticmethod
     def _get_task_type(
         query: ValidatedPredictiveQuery,
-        edge_types: List[Tuple[str, str, str]],
+        edge_types: list[tuple[str, str, str]],
     ) -> TaskType:
         if isinstance(query.target_ast, (Condition, LogicalOperation)):
             return TaskType.BINARY_CLASSIFICATION
@@ -803,31 +1022,38 @@ class KumoRFM:
 
     def _get_default_anchor_time(
         self,
-        query: ValidatedPredictiveQuery,
+        query: ValidatedPredictiveQuery | None = None,
     ) -> pd.Timestamp:
-        if query.query_type == QueryType.TEMPORAL:
+        if query is not None and query.query_type == QueryType.TEMPORAL:
             aggr_table_names = [
                 aggr._get_target_column_name().split('.')[0]
                 for aggr in query.get_all_target_aggregations()
             ]
             return self._sampler.get_max_time(aggr_table_names)
 
-        assert query.query_type == QueryType.STATIC
         return self._sampler.get_max_time()
 
     def _validate_time(
         self,
         query: ValidatedPredictiveQuery,
         anchor_time: pd.Timestamp,
-        context_anchor_time: Optional[pd.Timestamp],
+        context_anchor_time: pd.Timestamp | None,
         evaluate: bool,
     ) -> None:
 
         if len(self._sampler.time_column_dict) == 0:
             return  # Graph without timestamps
 
-
-
+        if query.query_type == QueryType.TEMPORAL:
+            aggr_table_names = [
+                aggr._get_target_column_name().split('.')[0]
+                for aggr in query.get_all_target_aggregations()
+            ]
+            min_time = self._sampler.get_min_time(aggr_table_names)
+            max_time = self._sampler.get_max_time(aggr_table_names)
+        else:
+            min_time = self._sampler.get_min_time()
+            max_time = self._sampler.get_max_time()
 
         if anchor_time < min_time:
             raise ValueError(f"Anchor timestamp '{anchor_time}' is before "
@@ -882,40 +1108,26 @@ class KumoRFM:
             f"Anchor timestamp for evaluation is after the latest "
             f"supported timestamp '{max_time - end_offset}'.")
 
-    def
+    def _get_task_table(
         self,
         query: ValidatedPredictiveQuery,
-        indices: Union[List[str], List[float], List[int], None],
-        anchor_time: Union[pd.Timestamp, Literal['entity'], None] = None,
-        context_anchor_time: Optional[pd.Timestamp] = None,
-        run_mode: RunMode,
-
-
-
-
-        random_seed: Optional[int] = _RANDOM_SEED,
-        logger: Optional[ProgressLogger] = None,
-    ) -> Context:
-
-        if num_neighbors is not None:
-            num_hops = len(num_neighbors)
-
-        if num_hops < 0:
-            raise ValueError(f"'num_hops' must be non-negative "
-                             f"(got {num_hops})")
-        if num_hops > 6:
-            raise ValueError(f"Cannot predict on subgraphs with more than 6 "
-                             f"hops (got {num_hops}). Please reduce the "
-                             f"number of hops and try again. Please create a "
-                             f"feature request at "
-                             f"'https://github.com/kumo-ai/kumo-rfm' if you "
-                             f"must go beyond this for your use-case.")
+        indices: list[str] | list[float] | list[int] | None,
+        anchor_time: pd.Timestamp | Literal['entity'] | None = None,
+        context_anchor_time: pd.Timestamp | None = None,
+        run_mode: RunMode = RunMode.FAST,
+        max_pq_iterations: int = 10,
+        random_seed: int | None = _RANDOM_SEED,
+        logger: ProgressLogger | None = None,
+    ) -> TaskTable:
 
         task_type = self._get_task_type(
             query=query,
             edge_types=self._sampler.edge_types,
         )
 
+        num_train_examples = _MAX_CONTEXT_SIZE[run_mode]
+        num_test_examples = _MAX_TEST_SIZE[task_type] if indices is None else 0
+
         if logger is not None:
             if task_type == TaskType.BINARY_CLASSIFICATION:
                 task_type_repr = 'binary classification'
@@ -929,21 +1141,6 @@ class KumoRFM:
                 task_type_repr = str(task_type)
             logger.log(f"Identified {query.query_type} {task_type_repr} task")
 
-        if task_type.is_link_pred and num_hops < 2:
-            raise ValueError(f"Cannot perform link prediction on subgraphs "
-                             f"with less than 2 hops (got {num_hops}) since "
-                             f"historical target entities need to be part of "
-                             f"the context. Please increase the number of "
-                             f"hops and try again.")
-
-        if num_neighbors is None:
-            if run_mode == RunMode.DEBUG:
-                num_neighbors = [16, 16, 4, 4, 1, 1][:num_hops]
-            elif run_mode == RunMode.FAST or task_type.is_link_pred:
-                num_neighbors = [32, 32, 8, 8, 4, 4][:num_hops]
-            else:
-                num_neighbors = [64, 64, 8, 8, 4, 4][:num_hops]
-
         if query.target_ast.date_offset_range is None:
             step_offset = pd.DateOffset(0)
         else:
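The per-call neighbor defaults removed here are centralized in the module-level _DEFAULT_NUM_NEIGHBORS table; slicing by the hop count keeps one fan-out value per hop. A small sketch (plain string keys stand in for RunMode):

    default_num_neighbors = {'FAST': [32, 32, 8, 8, 4, 4]}
    print(default_num_neighbors['FAST'][:2])  # [32, 32] for a 2-hop context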
@@ -952,8 +1149,7 @@ class KumoRFM:
 
         if anchor_time is None:
             anchor_time = self._get_default_anchor_time(query)
-
-            if evaluate:
+            if num_test_examples > 0:
                 anchor_time = anchor_time - end_offset
 
         if logger is not None:
@@ -967,7 +1163,6 @@ class KumoRFM:
         else:
             logger.log(f"Derived anchor time {anchor_time}")
 
-        assert anchor_time is not None
         if isinstance(anchor_time, pd.Timestamp):
             if context_anchor_time == 'entity':
                 raise ValueError("Anchor time 'entity' needs to be shared "
@@ -975,7 +1170,7 @@ class KumoRFM:
             if context_anchor_time is None:
                 context_anchor_time = anchor_time - end_offset
             self._validate_time(query, anchor_time, context_anchor_time,
-                                evaluate)
+                                evaluate=num_test_examples > 0)
         else:
             assert anchor_time == 'entity'
             if query.query_type != QueryType.STATIC:
@@ -990,14 +1185,6 @@ class KumoRFM:
                                  "for context and prediction examples")
             context_anchor_time = 'entity'
 
-        num_train_examples = _MAX_CONTEXT_SIZE[run_mode]
-        if evaluate:
-            num_test_examples = _MAX_TEST_SIZE[run_mode]
-            if task_type.is_link_pred:
-                num_test_examples = num_test_examples // 5
-        else:
-            num_test_examples = 0
-
         train, test = self._sampler.sample_target(
             query=query,
             num_train_examples=num_train_examples,
@@ -1008,39 +1195,32 @@ class KumoRFM:
             num_test_trials=max_pq_iterations * num_test_examples,
             random_seed=random_seed,
         )
-        train_pkey, train_time,
-        test_pkey, test_time,
+        train_pkey, train_time, train_y = train
+        test_pkey, test_time, test_y = test
 
-        if
+        if num_test_examples > 0 and logger is not None:
             if task_type == TaskType.BINARY_CLASSIFICATION:
-                pos = 100 * int((
-                msg = (f"Collected {len(
+                pos = 100 * int((test_y > 0).sum()) / len(test_y)
+                msg = (f"Collected {len(test_y):,} test examples with "
                        f"{pos:.2f}% positive cases")
             elif task_type == TaskType.MULTICLASS_CLASSIFICATION:
-                msg = (f"Collected {len(
-                       f"{
+                msg = (f"Collected {len(test_y):,} test examples holding "
+                       f"{test_y.nunique()} classes")
             elif task_type == TaskType.REGRESSION:
-                _min, _max = float(
-                msg = (f"Collected {len(
+                _min, _max = float(test_y.min()), float(test_y.max())
+                msg = (f"Collected {len(test_y):,} test examples with targets "
                        f"between {format_value(_min)} and "
                        f"{format_value(_max)}")
             elif task_type == TaskType.TEMPORAL_LINK_PREDICTION:
-                num_rhs =
-                msg = (f"Collected {len(
+                num_rhs = test_y.explode().nunique()
+                msg = (f"Collected {len(test_y):,} test examples with "
                        f"{num_rhs:,} unique items")
             else:
                 raise NotImplementedError
             logger.log(msg)
 
-        if
+        if num_test_examples == 0:
             assert indices is not None
-            if len(indices) > _MAX_PRED_SIZE[task_type]:
-                raise ValueError(f"Cannot predict for more than "
-                                 f"{_MAX_PRED_SIZE[task_type]:,} entities at "
-                                 f"once (got {len(indices):,}). Use "
-                                 f"`KumoRFM.batch_mode` to process entities "
-                                 f"in batches")
-
             test_pkey = pd.Series(indices, dtype=train_pkey.dtype)
             if isinstance(anchor_time, pd.Timestamp):
                 test_time = pd.Series([anchor_time]).repeat(
@@ -1050,26 +1230,26 @@ class KumoRFM:
 
         if logger is not None:
             if task_type == TaskType.BINARY_CLASSIFICATION:
-                pos = 100 * int((
-                msg = (f"Collected {len(
+                pos = 100 * int((train_y > 0).sum()) / len(train_y)
+                msg = (f"Collected {len(train_y):,} in-context examples with "
                        f"{pos:.2f}% positive cases")
             elif task_type == TaskType.MULTICLASS_CLASSIFICATION:
-                msg = (f"Collected {len(
-                       f"holding {
+                msg = (f"Collected {len(train_y):,} in-context examples "
+                       f"holding {train_y.nunique()} classes")
             elif task_type == TaskType.REGRESSION:
-                _min, _max = float(
-                msg = (f"Collected {len(
+                _min, _max = float(train_y.min()), float(train_y.max())
+                msg = (f"Collected {len(train_y):,} in-context examples with "
                        f"targets between {format_value(_min)} and "
                        f"{format_value(_max)}")
             elif task_type == TaskType.TEMPORAL_LINK_PREDICTION:
-                num_rhs =
-                msg = (f"Collected {len(
+                num_rhs = train_y.explode().nunique()
+                msg = (f"Collected {len(train_y):,} in-context examples with "
                        f"{num_rhs:,} unique items")
             else:
                 raise NotImplementedError
             logger.log(msg)
 
-        entity_table_names: Union[Tuple[str], Tuple[str, str]]
+        entity_table_names: tuple[str] | tuple[str, str]
         if task_type.is_link_pred:
             final_aggr = query.get_final_target_aggregation()
             assert final_aggr is not None
@@ -1083,27 +1263,80 @@ class KumoRFM:
         else:
             entity_table_names = (query.entity_table, )
 
+        context_df = pd.DataFrame({'ENTITY': train_pkey, 'TARGET': train_y})
+        if isinstance(train_time, pd.Series):
+            context_df['ANCHOR_TIMESTAMP'] = train_time
+        pred_df = pd.DataFrame({'ENTITY': test_pkey})
+        if num_test_examples > 0:
+            pred_df['TARGET'] = test_y
+        if isinstance(test_time, pd.Series):
+            pred_df['ANCHOR_TIMESTAMP'] = test_time
+
+        return TaskTable(
+            task_type=task_type,
+            context_df=context_df,
+            pred_df=pred_df,
+            entity_table_name=entity_table_names,
+            entity_column='ENTITY',
+            target_column='TARGET',
+            time_column='ANCHOR_TIMESTAMP' if isinstance(
+                train_time, pd.Series) else TaskTable.ENTITY_TIME,
+        )
+
+    def _get_context(
+        self,
+        task: TaskTable,
+        run_mode: RunMode | str = RunMode.FAST,
+        num_neighbors: list[int] | None = None,
+        exclude_cols_dict: dict[str, list[str]] | None = None,
+        top_k: int | None = None,
+    ) -> Context:
+
+        if num_neighbors is None:
+            key = RunMode.FAST if task.task_type.is_link_pred else run_mode
+            num_neighbors = _DEFAULT_NUM_NEIGHBORS[key][:2]
+
+        if len(num_neighbors) > 6:
+            raise ValueError(f"Cannot predict on subgraphs with more than 6 "
+                             f"hops (got {len(num_neighbors)}). Reduce the "
+                             f"number of hops and try again. Please create a "
+                             f"feature request at "
+                             f"'https://github.com/kumo-ai/kumo-rfm' if you "
+                             f"must go beyond this for your use-case.")
+
         # Exclude the entity anchor time from the feature set to prevent
         # running out-of-distribution between in-context and test examples:
-        exclude_cols_dict =
-        if
-        if
-            exclude_cols_dict[
-
-            exclude_cols_dict[
+        exclude_cols_dict = exclude_cols_dict or {}
+        if task.entity_table_name in self._sampler.time_column_dict:
+            if task.entity_table_name not in exclude_cols_dict:
+                exclude_cols_dict[task.entity_table_name] = []
+            time_col = self._sampler.time_column_dict[task.entity_table_name]
+            exclude_cols_dict[task.entity_table_name].append(time_col)
+
+        entity_pkey = pd.concat([
+            task._context_df[task._entity_column],
+            task._pred_df[task._entity_column],
+        ], axis=0, ignore_index=True)
+
+        if task.use_entity_time:
+            if task.entity_table_name not in self._sampler.time_column_dict:
+                raise ValueError(f"The given annchor time requires the entity "
+                                 f"table '{task.entity_table_name}' to have a "
+                                 f"time column")
+            anchor_time = 'entity'
+        elif task._time_column is not None:
+            anchor_time = pd.concat([
+                task._context_df[task._time_column],
+                task._pred_df[task._time_column],
+            ], axis=0, ignore_index=True)
+        else:
+            anchor_time = pd.Series(self._get_default_anchor_time()).repeat(
+                (len(entity_pkey))).reset_index(drop=True)
 
         subgraph = self._sampler.sample_subgraph(
-            entity_table_names=entity_table_names,
-            entity_pkey=
-
-                axis=0,
-                ignore_index=True,
-            ),
-            anchor_time=pd.concat(
-                [train_time, test_time],
-                axis=0,
-                ignore_index=True,
-            ) if isinstance(train_time, pd.Series) else 'entity',
+            entity_table_names=task.entity_table_names,
+            entity_pkey=entity_pkey,
+            anchor_time=anchor_time,
             num_neighbors=num_neighbors,
             exclude_cols_dict=exclude_cols_dict,
         )
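When a task has no per-row anchor time, _get_task_table falls back to the TaskTable.ENTITY_TIME sentinel, i.e. each entity's own timestamp becomes its anchor time (the entity table must then have a time column). A hedged sketch reusing the frames from the earlier predict_task example:

    task = TaskTable(
        task_type=TaskType.BINARY_CLASSIFICATION,
        context_df=context_df.drop(columns=['ANCHOR_TIMESTAMP']),
        pred_df=pred_df.drop(columns=['ANCHOR_TIMESTAMP']),
        entity_table_name='users',
        entity_column='ENTITY',
        target_column='TARGET',
        time_column=TaskTable.ENTITY_TIME,  # use each entity's own timestamp
    )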
@@ -1115,19 +1348,26 @@ class KumoRFM:
                              f"'https://github.com/kumo-ai/kumo-rfm' if you "
                              f"must go beyond this for your use-case.")
 
+        if (task.task_type.is_link_pred
+                and task.entity_table_names[-1] not in subgraph.table_dict):
+            raise ValueError("Cannot perform link prediction on subgraphs "
+                             "without any historical target entities. Please "
+                             "increase the number of hops and try again.")
+
         return Context(
-            task_type=task_type,
-            entity_table_names=entity_table_names,
+            task_type=task.task_type,
+            entity_table_names=task.entity_table_names,
             subgraph=subgraph,
-            y_train=
-            y_test=
-
+            y_train=task._context_df[task.target_column.name],
+            y_test=task._pred_df[task.target_column.name]
+            if task.evaluate else None,
+            top_k=top_k,
             step_size=None,
         )
 
     @staticmethod
     def _validate_metrics(
-        metrics: List[str],
+        metrics: list[str],
         task_type: TaskType,
     ) -> None:
 
@@ -1184,7 +1424,7 @@ class KumoRFM:
             f"'https://github.com/kumo-ai/kumo-rfm'.")
 
 
-def format_value(value: Union[int, float]) -> str:
+def format_value(value: int | float) -> str:
     if value == int(value):
         return f'{int(value):,}'
     if abs(value) >= 1000: