kumoai 2.14.0.dev202512141732__py3-none-any.whl → 2.15.0.dev202601131732__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kumoai/__init__.py +23 -26
- kumoai/_version.py +1 -1
- kumoai/client/client.py +6 -0
- kumoai/client/jobs.py +26 -0
- kumoai/connector/utils.py +21 -7
- kumoai/experimental/rfm/__init__.py +51 -24
- kumoai/experimental/rfm/authenticate.py +3 -4
- kumoai/experimental/rfm/backend/local/graph_store.py +37 -46
- kumoai/experimental/rfm/backend/local/sampler.py +4 -5
- kumoai/experimental/rfm/backend/local/table.py +24 -30
- kumoai/experimental/rfm/backend/snow/sampler.py +331 -43
- kumoai/experimental/rfm/backend/snow/table.py +166 -56
- kumoai/experimental/rfm/backend/sqlite/__init__.py +2 -2
- kumoai/experimental/rfm/backend/sqlite/sampler.py +372 -30
- kumoai/experimental/rfm/backend/sqlite/table.py +117 -48
- kumoai/experimental/rfm/base/__init__.py +8 -1
- kumoai/experimental/rfm/base/column.py +96 -10
- kumoai/experimental/rfm/base/expression.py +44 -0
- kumoai/experimental/rfm/base/mapper.py +69 -0
- kumoai/experimental/rfm/base/sampler.py +28 -18
- kumoai/experimental/rfm/base/source.py +1 -1
- kumoai/experimental/rfm/base/sql_sampler.py +385 -0
- kumoai/experimental/rfm/base/table.py +374 -208
- kumoai/experimental/rfm/base/utils.py +36 -0
- kumoai/experimental/rfm/graph.py +335 -180
- kumoai/experimental/rfm/infer/__init__.py +6 -4
- kumoai/experimental/rfm/infer/dtype.py +10 -5
- kumoai/experimental/rfm/infer/multicategorical.py +1 -1
- kumoai/experimental/rfm/infer/pkey.py +4 -2
- kumoai/experimental/rfm/infer/stype.py +35 -0
- kumoai/experimental/rfm/infer/time_col.py +5 -4
- kumoai/experimental/rfm/pquery/executor.py +27 -27
- kumoai/experimental/rfm/pquery/pandas_executor.py +29 -31
- kumoai/experimental/rfm/relbench.py +76 -0
- kumoai/experimental/rfm/rfm.py +606 -361
- kumoai/experimental/rfm/sagemaker.py +4 -4
- kumoai/experimental/rfm/task_table.py +292 -0
- kumoai/pquery/training_table.py +16 -2
- kumoai/testing/snow.py +3 -3
- kumoai/trainer/distilled_trainer.py +175 -0
- kumoai/utils/__init__.py +1 -2
- kumoai/utils/display.py +87 -0
- kumoai/utils/progress_logger.py +192 -13
- kumoai/utils/sql.py +2 -2
- {kumoai-2.14.0.dev202512141732.dist-info → kumoai-2.15.0.dev202601131732.dist-info}/METADATA +3 -2
- {kumoai-2.14.0.dev202512141732.dist-info → kumoai-2.15.0.dev202601131732.dist-info}/RECORD +49 -40
- {kumoai-2.14.0.dev202512141732.dist-info → kumoai-2.15.0.dev202601131732.dist-info}/WHEEL +0 -0
- {kumoai-2.14.0.dev202512141732.dist-info → kumoai-2.15.0.dev202601131732.dist-info}/licenses/LICENSE +0 -0
- {kumoai-2.14.0.dev202512141732.dist-info → kumoai-2.15.0.dev202601131732.dist-info}/top_level.txt +0 -0
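The most visible API changes in this release sit in kumoai/experimental/rfm/rfm.py: KumoRFM gains an `optimize` flag, a `retry()` context manager, and TaskTable-based `predict_task()`/`evaluate_task()` entry points. A minimal sketch of how these pieces appear to fit together, based solely on the signatures visible in the diff below; the pre-built `graph` object and the example query/schema are hypothetical placeholders, not part of this package diff:

    from kumoai.experimental.rfm import KumoRFM

    # `graph` is assumed to be an already-constructed rfm Graph; building it
    # is outside the scope of this diff.
    model = KumoRFM(graph, optimize=True)  # `optimize` is new in this release

    # Hypothetical predictive query over a hypothetical users/orders schema:
    query = "PREDICT COUNT(orders.*, 0, 30, days) > 0 FOR EACH users.user_id"

    # New retry() context manager: re-issues calls that fail with server errors.
    with model.retry(num_retries=1):
        df = model.predict(query, indices=[1, 2, 3])

    # batch_mode() now accepts batch_size='max' and delegates retries to retry():
    with model.batch_mode(batch_size='max', num_retries=2):
        df = model.predict(query, indices=[1, 2, 3])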
kumoai/experimental/rfm/rfm.py
CHANGED
@@ -1,23 +1,13 @@
 import json
+import math
 import time
 import warnings
 from collections import defaultdict
-from collections.abc import Generator
+from collections.abc import Generator, Iterator
 from contextlib import contextmanager
 from dataclasses import dataclass, replace
-from typing import
-    Any,
-    Dict,
-    Iterator,
-    List,
-    Literal,
-    Optional,
-    Tuple,
-    Union,
-    overload,
-)
+from typing import Any, Literal, overload
 
-import numpy as np
 import pandas as pd
 from kumoapi.model_plan import RunMode
 from kumoapi.pquery import QueryType, ValidatedPredictiveQuery
@@ -37,30 +27,37 @@ from kumoapi.rfm (
 )
 from kumoapi.task import TaskType
 from kumoapi.typing import AggregationType, Stype
+from rich.console import Console
+from rich.markdown import Markdown
 
+from kumoai import in_notebook
 from kumoai.client.rfm import RFMAPI
 from kumoai.exceptions import HTTPException
-from kumoai.experimental.rfm import Graph
+from kumoai.experimental.rfm import Graph, TaskTable
 from kumoai.experimental.rfm.base import DataBackend, Sampler
 from kumoai.mixin import CastMixin
-from kumoai.utils import
+from kumoai.utils import ProgressLogger, display
 
 _RANDOM_SEED = 42
 
 _MAX_PRED_SIZE: dict[TaskType, int] = defaultdict(lambda: 1_000)
 _MAX_PRED_SIZE[TaskType.TEMPORAL_LINK_PREDICTION] = 200
 
+_MAX_TEST_SIZE: dict[TaskType, int] = defaultdict(lambda: 2_000)
+_MAX_TEST_SIZE[TaskType.TEMPORAL_LINK_PREDICTION] = 400
+
 _MAX_CONTEXT_SIZE = {
     RunMode.DEBUG: 100,
     RunMode.FAST: 1_000,
     RunMode.NORMAL: 5_000,
     RunMode.BEST: 10_000,
 }
-
-
-RunMode.
-RunMode.
-RunMode.
+
+_DEFAULT_NUM_NEIGHBORS = {
+    RunMode.DEBUG: [16, 16, 4, 4, 1, 1],
+    RunMode.FAST: [32, 32, 8, 8, 4, 4],
+    RunMode.NORMAL: [64, 64, 8, 8, 4, 4],
+    RunMode.BEST: [64, 64, 8, 8, 4, 4],
 }
 
 _MAX_SIZE = 30 * 1024 * 1024
@@ -98,24 +95,36 @@ class Explanation:
     def __getitem__(self, index: Literal[1]) -> str:
         pass
 
-    def __getitem__(self, index: int) ->
+    def __getitem__(self, index: int) -> pd.DataFrame | str:
         if index == 0:
             return self.prediction
         if index == 1:
             return self.summary
         raise IndexError("Index out of range")
 
-    def __iter__(self) -> Iterator[
+    def __iter__(self) -> Iterator[pd.DataFrame | str]:
         return iter((self.prediction, self.summary))
 
     def __repr__(self) -> str:
         return str((self.prediction, self.summary))
 
-    def
-
+    def __str__(self) -> str:
+        console = Console(soft_wrap=True)
+        with console.capture() as cap:
+            console.print(display.to_rich_table(self.prediction))
+            console.print(Markdown(self.summary))
+        return cap.get()[:-1]
+
+    def print(self) -> None:
+        r"""Prints the explanation."""
+        if in_notebook():
+            display.dataframe(self.prediction)
+            display.message(self.summary)
+        else:
+            print(self)
 
-
-
+    def _ipython_display_(self) -> None:
+        self.print()
 
 
 class KumoRFM:
@@ -154,11 +163,16 @@ class KumoRFM:
     Args:
         graph: The graph.
         verbose: Whether to print verbose output.
+        optimize: If set to ``True``, will optimize the underlying data backend
+            for optimal querying. For example, for transactional database
+            backends, will create any missing indices. Requires write-access to
+            the data backend.
     """
     def __init__(
        self,
        graph: Graph,
-        verbose:
+        verbose: bool | ProgressLogger = True,
+        optimize: bool = False,
    ) -> None:
        graph = graph.validate()
        self._graph_def = graph._to_api_graph_definition()
@@ -168,17 +182,17 @@ class KumoRFM:
             self._sampler: Sampler = LocalSampler(graph, verbose)
         elif graph.backend == DataBackend.SQLITE:
             from kumoai.experimental.rfm.backend.sqlite import SQLiteSampler
-            self._sampler = SQLiteSampler(graph, verbose)
+            self._sampler = SQLiteSampler(graph, verbose, optimize)
         elif graph.backend == DataBackend.SNOWFLAKE:
             from kumoai.experimental.rfm.backend.snow import SnowSampler
             self._sampler = SnowSampler(graph, verbose)
         else:
             raise NotImplementedError
 
-        self._client:
+        self._client: RFMAPI | None = None
 
-        self._batch_size:
-        self.
+        self._batch_size: int | Literal['max'] | None = None
+        self._num_retries: int = 0
 
     @property
     def _api_client(self) -> RFMAPI:
@@ -192,10 +206,34 @@ class KumoRFM:
     def __repr__(self) -> str:
         return f'{self.__class__.__name__}()'
 
+    @contextmanager
+    def retry(
+        self,
+        num_retries: int = 1,
+    ) -> Generator[None, None, None]:
+        """Context manager to retry failed queries due to unexpected server
+        issues.
+
+        .. code-block:: python
+
+            with model.retry(num_retries=1):
+                df = model.predict(query, indices=...)
+
+        Args:
+            num_retries: The maximum number of retries.
+        """
+        if num_retries < 0:
+            raise ValueError(f"'num_retries' must be greater than or equal to "
+                             f"zero (got {num_retries})")
+
+        self._num_retries = num_retries
+        yield
+        self._num_retries = 0
+
     @contextmanager
     def batch_mode(
         self,
-        batch_size:
+        batch_size: int | Literal['max'] = 'max',
         num_retries: int = 1,
     ) -> Generator[None, None, None]:
         """Context manager to predict in batches.
@@ -215,31 +253,26 @@ class KumoRFM:
             raise ValueError(f"'batch_size' must be greater than zero "
                              f"(got {batch_size})")
 
-        if num_retries < 0:
-            raise ValueError(f"'num_retries' must be greater than or equal to "
-                             f"zero (got {num_retries})")
-
         self._batch_size = batch_size
-        self.
-
+        with self.retry(self._num_retries or num_retries):
+            yield
         self._batch_size = None
-        self.num_retries = 0
 
     @overload
     def predict(
         self,
         query: str,
-        indices:
+        indices: list[str] | list[float] | list[int] | None = None,
         *,
         explain: Literal[False] = False,
-        anchor_time:
-        context_anchor_time:
-        run_mode:
-        num_neighbors:
+        anchor_time: pd.Timestamp | Literal['entity'] | None = None,
+        context_anchor_time: pd.Timestamp | None = None,
+        run_mode: RunMode | str = RunMode.FAST,
+        num_neighbors: list[int] | None = None,
         num_hops: int = 2,
         max_pq_iterations: int = 10,
-        random_seed:
-        verbose:
+        random_seed: int | None = _RANDOM_SEED,
+        verbose: bool | ProgressLogger = True,
         use_prediction_time: bool = False,
     ) -> pd.DataFrame:
         pass
@@ -248,37 +281,56 @@ class KumoRFM:
     def predict(
         self,
         query: str,
-        indices:
+        indices: list[str] | list[float] | list[int] | None = None,
         *,
-        explain:
-        anchor_time:
-        context_anchor_time:
-        run_mode:
-        num_neighbors:
+        explain: Literal[True] | ExplainConfig | dict[str, Any],
+        anchor_time: pd.Timestamp | Literal['entity'] | None = None,
+        context_anchor_time: pd.Timestamp | None = None,
+        run_mode: RunMode | str = RunMode.FAST,
+        num_neighbors: list[int] | None = None,
         num_hops: int = 2,
         max_pq_iterations: int = 10,
-        random_seed:
-        verbose:
+        random_seed: int | None = _RANDOM_SEED,
+        verbose: bool | ProgressLogger = True,
         use_prediction_time: bool = False,
     ) -> Explanation:
         pass
 
+    @overload
+    def predict(
+        self,
+        query: str,
+        indices: list[str] | list[float] | list[int] | None = None,
+        *,
+        explain: bool | ExplainConfig | dict[str, Any] = False,
+        anchor_time: pd.Timestamp | Literal['entity'] | None = None,
+        context_anchor_time: pd.Timestamp | None = None,
+        run_mode: RunMode | str = RunMode.FAST,
+        num_neighbors: list[int] | None = None,
+        num_hops: int = 2,
+        max_pq_iterations: int = 10,
+        random_seed: int | None = _RANDOM_SEED,
+        verbose: bool | ProgressLogger = True,
+        use_prediction_time: bool = False,
+    ) -> pd.DataFrame | Explanation:
+        pass
+
     def predict(
         self,
         query: str,
-        indices:
+        indices: list[str] | list[float] | list[int] | None = None,
         *,
-        explain:
-        anchor_time:
-        context_anchor_time:
-        run_mode:
-        num_neighbors:
+        explain: bool | ExplainConfig | dict[str, Any] = False,
+        anchor_time: pd.Timestamp | Literal['entity'] | None = None,
+        context_anchor_time: pd.Timestamp | None = None,
+        run_mode: RunMode | str = RunMode.FAST,
+        num_neighbors: list[int] | None = None,
         num_hops: int = 2,
         max_pq_iterations: int = 10,
-        random_seed:
-        verbose:
+        random_seed: int | None = _RANDOM_SEED,
+        verbose: bool | ProgressLogger = True,
         use_prediction_time: bool = False,
-    ) ->
+    ) -> pd.DataFrame | Explanation:
         """Returns predictions for a predictive query.
 
         Args:
@@ -286,8 +338,7 @@ class KumoRFM:
             indices: The entity primary keys to predict for. Will override the
                 indices given as part of the predictive query. Predictions will
                 be generated for all indices, independent of whether they
-                fulfill entity filter constraints.
-                :meth:`~KumoRFM.is_valid_entity`.
+                fulfill entity filter constraints.
             explain: Configuration for explainability.
                 If set to ``True``, will additionally explain the prediction.
                 Passing in an :class:`ExplainConfig` instance provides control
@@ -320,18 +371,152 @@ class KumoRFM:
             If ``explain`` is provided, returns an :class:`Explanation` object
             containing the prediction, summary, and details.
         """
-        explain_config: Optional[ExplainConfig] = None
-        if explain is True:
-            explain_config = ExplainConfig()
-        elif explain is not False:
-            explain_config = ExplainConfig._cast(explain)
-
         query_def = self._parse_query(query)
-        query_str = query_def.to_string()
 
+        if indices is None:
+            if query_def.rfm_entity_ids is None:
+                raise ValueError("Cannot find entities to predict for. Please "
+                                 "pass them via `predict(query, indices=...)`")
+            indices = query_def.get_rfm_entity_id_list()
+        query_def = replace(
+            query_def,
+            for_each='FOR EACH',
+            rfm_entity_ids=None,
+        )
+
+        if not isinstance(verbose, ProgressLogger):
+            query_repr = query_def.to_string(rich=True, exclude_predict=True)
+            if explain is not False:
+                msg = f'[bold]EXPLAIN[/bold] {query_repr}'
+            else:
+                msg = f'[bold]PREDICT[/bold] {query_repr}'
+            verbose = ProgressLogger.default(msg=msg, verbose=verbose)
+
+        with verbose as logger:
+            task_table = self._get_task_table(
+                query=query_def,
+                indices=indices,
+                anchor_time=anchor_time,
+                context_anchor_time=context_anchor_time,
+                run_mode=run_mode,
+                max_pq_iterations=max_pq_iterations,
+                random_seed=random_seed,
+                logger=logger,
+            )
+            task_table._query = query_def.to_string()
+
+            return self.predict_task(
+                task_table,
+                explain=explain,
+                run_mode=run_mode,
+                num_neighbors=num_neighbors,
+                num_hops=num_hops,
+                verbose=verbose,
+                exclude_cols_dict=query_def.get_exclude_cols_dict(),
+                use_prediction_time=use_prediction_time,
+                top_k=query_def.top_k,
+            )
+
+    @overload
+    def predict_task(
+        self,
+        task: TaskTable,
+        *,
+        explain: Literal[False] = False,
+        run_mode: RunMode | str = RunMode.FAST,
+        num_neighbors: list[int] | None = None,
+        num_hops: int = 2,
+        verbose: bool | ProgressLogger = True,
+        exclude_cols_dict: dict[str, list[str]] | None = None,
+        use_prediction_time: bool = False,
+        top_k: int | None = None,
+    ) -> pd.DataFrame:
+        pass
+
+    @overload
+    def predict_task(
+        self,
+        task: TaskTable,
+        *,
+        explain: Literal[True] | ExplainConfig | dict[str, Any],
+        run_mode: RunMode | str = RunMode.FAST,
+        num_neighbors: list[int] | None = None,
+        num_hops: int = 2,
+        verbose: bool | ProgressLogger = True,
+        exclude_cols_dict: dict[str, list[str]] | None = None,
+        use_prediction_time: bool = False,
+        top_k: int | None = None,
+    ) -> Explanation:
+        pass
+
+    @overload
+    def predict_task(
+        self,
+        task: TaskTable,
+        *,
+        explain: bool | ExplainConfig | dict[str, Any] = False,
+        run_mode: RunMode | str = RunMode.FAST,
+        num_neighbors: list[int] | None = None,
+        num_hops: int = 2,
+        verbose: bool | ProgressLogger = True,
+        exclude_cols_dict: dict[str, list[str]] | None = None,
+        use_prediction_time: bool = False,
+        top_k: int | None = None,
+    ) -> pd.DataFrame | Explanation:
+        pass
+
+    def predict_task(
+        self,
+        task: TaskTable,
+        *,
+        explain: bool | ExplainConfig | dict[str, Any] = False,
+        run_mode: RunMode | str = RunMode.FAST,
+        num_neighbors: list[int] | None = None,
+        num_hops: int = 2,
+        verbose: bool | ProgressLogger = True,
+        exclude_cols_dict: dict[str, list[str]] | None = None,
+        use_prediction_time: bool = False,
+        top_k: int | None = None,
+    ) -> pd.DataFrame | Explanation:
+        """Returns predictions for a custom task specification.
+
+        Args:
+            task: The custom :class:`TaskTable`.
+            explain: Configuration for explainability.
+                If set to ``True``, will additionally explain the prediction.
+                Passing in an :class:`ExplainConfig` instance provides control
+                over which parts of explanation are generated.
+                Explainability is currently only supported for single entity
+                predictions with ``run_mode="FAST"``.
+            run_mode: The :class:`RunMode` for the query.
+            num_neighbors: The number of neighbors to sample for each hop.
+                If specified, the ``num_hops`` option will be ignored.
+            num_hops: The number of hops to sample when generating the context.
+            verbose: Whether to print verbose output.
+            exclude_cols_dict: Any column in any table to exclude from the
+                model input.
+            use_prediction_time: Whether to use the anchor timestamp as an
+                additional feature during prediction. This is typically
+                beneficial for time series forecasting tasks.
+            top_k: The number of predictions to return per entity.
+
+        Returns:
+            The predictions as a :class:`pandas.DataFrame`.
+            If ``explain`` is provided, returns an :class:`Explanation` object
+            containing the prediction, summary, and details.
+        """
         if num_hops != 2 and num_neighbors is not None:
             warnings.warn(f"Received custom 'num_neighbors' option; ignoring "
                           f"custom 'num_hops={num_hops}' option")
+        if num_neighbors is None:
+            key = RunMode.FAST if task.task_type.is_link_pred else run_mode
+            num_neighbors = _DEFAULT_NUM_NEIGHBORS[key][:num_hops]
+
+        explain_config: ExplainConfig | None = None
+        if explain is True:
+            explain_config = ExplainConfig()
+        elif explain is not False:
+            explain_config = ExplainConfig._cast(explain)
 
         if explain_config is not None and run_mode in {
                 RunMode.NORMAL, RunMode.BEST
@@ -340,83 +525,82 @@ class KumoRFM:
                           f"run mode 'FAST' (got '{run_mode}'). Provided run "
                           f"mode has been reset. Please lower the run mode to "
                           f"suppress this warning.")
+            run_mode = RunMode.FAST
 
-        if
-
-
-
-            indices = query_def.get_rfm_entity_id_list()
-        else:
-            query_def = replace(query_def, rfm_entity_ids=None)
-
-        if len(indices) == 0:
-            raise ValueError("At least one entity is required")
-
-        if explain_config is not None and len(indices) > 1:
-            raise ValueError(
-                f"Cannot explain predictions for more than a single entity "
-                f"(got {len(indices)})")
-
-        query_repr = query_def.to_string(rich=True, exclude_predict=True)
-        if explain_config is not None:
-            msg = f'[bold]EXPLAIN[/bold] {query_repr}'
-        else:
-            msg = f'[bold]PREDICT[/bold] {query_repr}'
+        if explain_config is not None and task.num_prediction_examples > 1:
+            raise ValueError(f"Cannot explain predictions for more than a "
+                             f"single entity "
+                             f"(got {task.num_prediction_examples:,})")
 
         if not isinstance(verbose, ProgressLogger):
-
-
-
-
-
-
-
-
-                edge_types=self._sampler.edge_types,
-            )
-            batch_size = _MAX_PRED_SIZE[task_type]
+            if task.task_type == TaskType.BINARY_CLASSIFICATION:
+                task_type_repr = 'binary classification'
+            elif task.task_type == TaskType.MULTICLASS_CLASSIFICATION:
+                task_type_repr = 'multi-class classification'
+            elif task.task_type == TaskType.REGRESSION:
+                task_type_repr = 'regression'
+            elif task.task_type == TaskType.TEMPORAL_LINK_PREDICTION:
+                task_type_repr = 'link prediction'
             else:
-
+                task_type_repr = str(task.task_type)
 
-            if
-
-            batches = [indices[step:step + batch_size] for step in offsets]
+            if explain_config is not None:
+                msg = f"Explaining {task_type_repr} task"
             else:
-
+                msg = f"Predicting {task_type_repr} task"
+            verbose = ProgressLogger.default(msg=msg, verbose=verbose)
 
-
-
-
+        with verbose as logger:
+            if task.num_context_examples > _MAX_CONTEXT_SIZE[run_mode]:
+                logger.log(f"Sub-sampled {_MAX_CONTEXT_SIZE[run_mode]:,} "
+                           f"out of {task.num_context_examples:,} in-context "
+                           f"examples")
+                task = task.narrow_context(0, _MAX_CONTEXT_SIZE[run_mode])
+
+            if self._batch_size is None:
+                batch_size = task.num_prediction_examples
+            elif self._batch_size == 'max':
+                batch_size = _MAX_PRED_SIZE[task.task_type]
+            else:
+                batch_size = self._batch_size
 
-
-
-
-
-
+            if batch_size > _MAX_PRED_SIZE[task.task_type]:
+                raise ValueError(f"Cannot predict for more than "
+                                 f"{_MAX_PRED_SIZE[task.task_type]:,} "
+                                 f"entities at once (got {batch_size:,}). Use "
+                                 f"`KumoRFM.batch_mode` to process entities "
+                                 f"in batches with a sufficient batch size.")
+
+            if task.num_prediction_examples > batch_size:
+                num = math.ceil(task.num_prediction_examples / batch_size)
+                logger.log(f"Splitting {task.num_prediction_examples:,} "
+                           f"entities into {num:,} batches of size "
+                           f"{batch_size:,}")
+
+            predictions: list[pd.DataFrame] = []
+            summary: str | None = None
+            details: Explanation | None = None
+            for start in range(0, task.num_prediction_examples, batch_size):
                 context = self._get_context(
-
-
-                    anchor_time=anchor_time,
-                    context_anchor_time=context_anchor_time,
-                    run_mode=RunMode(run_mode),
+                    task=task.narrow_prediction(start, length=batch_size),
+                    run_mode=run_mode,
                     num_neighbors=num_neighbors,
-
-
-                    evaluate=False,
-                    random_seed=random_seed,
-                    logger=logger if i == 0 else None,
+                    exclude_cols_dict=exclude_cols_dict,
+                    top_k=top_k,
                 )
+                context.y_test = None
+
                 request = RFMPredictRequest(
                     context=context,
                     run_mode=RunMode(run_mode),
-                    query=
+                    query=task._query,
                     use_prediction_time=use_prediction_time,
                 )
                 with warnings.catch_warnings():
                     warnings.filterwarnings('ignore', message='gencode')
                     request_msg = request.to_protobuf()
                     _bytes = request_msg.SerializeToString()
-                if
+                if start == 0:
                     logger.log(f"Generated context of size "
                                f"{len(_bytes) / (1024*1024):.2f}MB")
 
@@ -424,14 +608,11 @@ class KumoRFM:
                     stats = Context.get_memory_stats(request_msg.context)
                     raise ValueError(_SIZE_LIMIT_MSG.format(stats=stats))
 
-                if
-
-                    verbose.init_progress(
-                        total=len(batches),
-                        description='Predicting',
-                    )
+                if start == 0 and task.num_prediction_examples > batch_size:
+                    num = math.ceil(task.num_prediction_examples / batch_size)
+                    verbose.init_progress(total=num, description='Predicting')
 
-                for attempt in range(self.
+                for attempt in range(self._num_retries + 1):
                     try:
                         if explain_config is not None:
                             resp = self._api_client.explain(
@@ -447,7 +628,7 @@ class KumoRFM:
                         # Cast 'ENTITY' to correct data type:
                         if 'ENTITY' in df:
                             table_dict = context.subgraph.table_dict
-                            table = table_dict[
+                            table = table_dict[context.entity_table_names[0]]
                             ser = table.df[table.primary_key]
                             df['ENTITY'] = df['ENTITY'].astype(ser.dtype)
 
@@ -464,13 +645,12 @@ class KumoRFM:
 
                         predictions.append(df)
 
-                        if
-                                and len(batches) > 1):
+                        if task.num_prediction_examples > batch_size:
                             verbose.step()
 
                         break
                     except HTTPException as e:
-                        if attempt == self.
+                        if attempt == self._num_retries:
                             try:
                                 msg = json.loads(e.detail)['detail']
                             except Exception:
@@ -500,64 +680,19 @@ class KumoRFM:
 
         return prediction
 
-    def is_valid_entity(
-        self,
-        query: str,
-        indices: Union[List[str], List[float], List[int], None] = None,
-        *,
-        anchor_time: Union[pd.Timestamp, Literal['entity'], None] = None,
-    ) -> np.ndarray:
-        r"""Returns a mask that denotes which entities are valid for the
-        given predictive query, *i.e.*, which entities fulfill (temporal)
-        entity filter constraints.
-
-        Args:
-            query: The predictive query.
-            indices: The entity primary keys to predict for. Will override the
-                indices given as part of the predictive query.
-            anchor_time: The anchor timestamp for the prediction. If set to
-                ``None``, will use the maximum timestamp in the data.
-                If set to ``"entity"``, will use the timestamp of the entity.
-        """
-        query_def = self._parse_query(query)
-
-        if indices is None:
-            if query_def.rfm_entity_ids is None:
-                raise ValueError("Cannot find entities to predict for. Please "
-                                 "pass them via "
-                                 "`is_valid_entity(query, indices=...)`")
-            indices = query_def.get_rfm_entity_id_list()
-
-        if len(indices) == 0:
-            raise ValueError("At least one entity is required")
-
-        if anchor_time is None:
-            anchor_time = self._get_default_anchor_time(query_def)
-
-        if isinstance(anchor_time, pd.Timestamp):
-            self._validate_time(query_def, anchor_time, None, False)
-        else:
-            assert anchor_time == 'entity'
-            if query_def.entity_table not in self._sampler.time_column_dict:
-                raise ValueError(f"Anchor time 'entity' requires the entity "
-                                 f"table '{query_def.entity_table}' "
-                                 f"to have a time column.")
-
-        raise NotImplementedError
-
     def evaluate(
         self,
         query: str,
         *,
-        metrics:
-        anchor_time:
-        context_anchor_time:
-        run_mode:
-        num_neighbors:
+        metrics: list[str] | None = None,
+        anchor_time: pd.Timestamp | Literal['entity'] | None = None,
+        context_anchor_time: pd.Timestamp | None = None,
+        run_mode: RunMode | str = RunMode.FAST,
+        num_neighbors: list[int] | None = None,
         num_hops: int = 2,
         max_pq_iterations: int = 10,
-        random_seed:
-        verbose:
+        random_seed: int | None = _RANDOM_SEED,
+        verbose: bool | ProgressLogger = True,
         use_prediction_time: bool = False,
     ) -> pd.DataFrame:
         """Evaluates a predictive query.
@@ -589,41 +724,120 @@ class KumoRFM:
         Returns:
             The metrics as a :class:`pandas.DataFrame`
         """
-        query_def =
+        query_def = replace(
+            self._parse_query(query),
+            for_each='FOR EACH',
+            rfm_entity_ids=None,
+        )
+
+        if not isinstance(verbose, ProgressLogger):
+            query_repr = query_def.to_string(rich=True, exclude_predict=True)
+            msg = f'[bold]EVALUATE[/bold] {query_repr}'
+            verbose = ProgressLogger.default(msg=msg, verbose=verbose)
+
+        with verbose as logger:
+            task_table = self._get_task_table(
+                query=query_def,
+                indices=None,
+                anchor_time=anchor_time,
+                context_anchor_time=context_anchor_time,
+                run_mode=run_mode,
+                max_pq_iterations=max_pq_iterations,
+                random_seed=random_seed,
+                logger=logger,
+            )
+
+            return self.evaluate_task(
+                task_table,
+                metrics=metrics,
+                run_mode=run_mode,
+                num_neighbors=num_neighbors,
+                num_hops=num_hops,
+                verbose=verbose,
+                exclude_cols_dict=query_def.get_exclude_cols_dict(),
+                use_prediction_time=use_prediction_time,
+            )
+
+    def evaluate_task(
+        self,
+        task: TaskTable,
+        *,
+        metrics: list[str] | None = None,
+        run_mode: RunMode | str = RunMode.FAST,
+        num_neighbors: list[int] | None = None,
+        num_hops: int = 2,
+        verbose: bool | ProgressLogger = True,
+        exclude_cols_dict: dict[str, list[str]] | None = None,
+        use_prediction_time: bool = False,
+    ) -> pd.DataFrame:
+        """Evaluates a custom task specification.
+
+        Args:
+            task: The custom :class:`TaskTable`.
+            metrics: The metrics to use.
+            run_mode: The :class:`RunMode` for the query.
+            num_neighbors: The number of neighbors to sample for each hop.
+                If specified, the ``num_hops`` option will be ignored.
+            num_hops: The number of hops to sample when generating the context.
+            verbose: Whether to print verbose output.
+            exclude_cols_dict: Any column in any table to exclude from the
+                model input.
+            use_prediction_time: Whether to use the anchor timestamp as an
+                additional feature during prediction. This is typically
+                beneficial for time series forecasting tasks.
 
+        Returns:
+            The metrics as a :class:`pandas.DataFrame`
+        """
         if num_hops != 2 and num_neighbors is not None:
             warnings.warn(f"Received custom 'num_neighbors' option; ignoring "
                           f"custom 'num_hops={num_hops}' option")
+        if num_neighbors is None:
+            key = RunMode.FAST if task.task_type.is_link_pred else run_mode
+            num_neighbors = _DEFAULT_NUM_NEIGHBORS[key][:num_hops]
 
-        if
-
-
-            rfm_entity_ids=None,
-        )
-
-        query_repr = query_def.to_string(rich=True, exclude_predict=True)
-        msg = f'[bold]EVALUATE[/bold] {query_repr}'
+        if metrics is not None and len(metrics) > 0:
+            self._validate_metrics(metrics, task.task_type)
+            metrics = list(dict.fromkeys(metrics))
 
         if not isinstance(verbose, ProgressLogger):
-
+            if task.task_type == TaskType.BINARY_CLASSIFICATION:
+                task_type_repr = 'binary classification'
+            elif task.task_type == TaskType.MULTICLASS_CLASSIFICATION:
+                task_type_repr = 'multi-class classification'
+            elif task.task_type == TaskType.REGRESSION:
+                task_type_repr = 'regression'
+            elif task.task_type == TaskType.TEMPORAL_LINK_PREDICTION:
+                task_type_repr = 'link prediction'
+            else:
+                task_type_repr = str(task.task_type)
+
+            msg = f"Evaluating {task_type_repr} task"
+            verbose = ProgressLogger.default(msg=msg, verbose=verbose)
 
         with verbose as logger:
+            if task.num_context_examples > _MAX_CONTEXT_SIZE[run_mode]:
+                logger.log(f"Sub-sampled {_MAX_CONTEXT_SIZE[run_mode]:,} "
+                           f"out of {task.num_context_examples:,} in-context "
+                           f"examples")
+                task = task.narrow_context(0, _MAX_CONTEXT_SIZE[run_mode])
+
+            if task.num_prediction_examples > _MAX_TEST_SIZE[task.task_type]:
+                logger.log(f"Sub-sampled {_MAX_TEST_SIZE[task.task_type]:,} "
+                           f"out of {task.num_prediction_examples:,} test "
+                           f"examples")
+                task = task.narrow_prediction(
+                    start=0,
+                    length=_MAX_TEST_SIZE[task.task_type],
+                )
+
             context = self._get_context(
-
-
-                anchor_time=anchor_time,
-                context_anchor_time=context_anchor_time,
-                run_mode=RunMode(run_mode),
+                task=task,
+                run_mode=run_mode,
                 num_neighbors=num_neighbors,
-
-                max_pq_iterations=max_pq_iterations,
-                evaluate=True,
-                random_seed=random_seed,
-                logger=logger if verbose else None,
+                exclude_cols_dict=exclude_cols_dict,
             )
-
-            self._validate_metrics(metrics, context.task_type)
-            metrics = list(dict.fromkeys(metrics))
+
             request = RFMEvaluateRequest(
                 context=context,
                 run_mode=RunMode(run_mode),
@@ -641,17 +855,23 @@ class KumoRFM:
                 stats_msg = Context.get_memory_stats(request_msg.context)
                 raise ValueError(_SIZE_LIMIT_MSG.format(stats=stats_msg))
 
-
-            resp = self._api_client.evaluate(request_bytes)
-        except HTTPException as e:
+            for attempt in range(self._num_retries + 1):
                 try:
-
-
-
-
-
-
-
+                    resp = self._api_client.evaluate(request_bytes)
+                    break
+                except HTTPException as e:
+                    if attempt == self._num_retries:
+                        try:
+                            msg = json.loads(e.detail)['detail']
+                        except Exception:
+                            msg = e.detail
+                        raise RuntimeError(
+                            f"An unexpected exception occurred. Please create "
+                            f"an issue at "
+                            f"'https://github.com/kumo-ai/kumo-rfm'. {msg}"
+                        ) from None
+
+                    time.sleep(2**attempt)  # 1s, 2s, 4s, 8s, ...
 
         return pd.DataFrame.from_dict(
             resp.metrics,
@@ -664,8 +884,8 @@ class KumoRFM:
         query: str,
         size: int,
         *,
-        anchor_time:
-        random_seed:
+        anchor_time: pd.Timestamp | Literal['entity'] | None = None,
+        random_seed: int | None = _RANDOM_SEED,
         max_iterations: int = 10,
     ) -> pd.DataFrame:
         """Returns the labels of a predictive query for a specified anchor
@@ -703,7 +923,7 @@ class KumoRFM:
                              f"to have a time column")
 
         train, test = self._sampler.sample_target(
-            query=
+            query=query_def,
             num_train_examples=0,
             train_anchor_time=anchor_time,
             num_train_trials=0,
@@ -731,35 +951,39 @@ class KumoRFM:
                 "`predict()` or `evaluate()` methods to perform "
                 "predictions or evaluations.")
 
-
-
-
-
-        )
+        request = RFMParseQueryRequest(
+            query=query,
+            graph_definition=self._graph_def,
+        )
 
-
+        for attempt in range(self._num_retries + 1):
+            try:
+                resp = self._api_client.parse_query(request)
+                break
+            except HTTPException as e:
+                if attempt == self._num_retries:
+                    try:
+                        msg = json.loads(e.detail)['detail']
+                    except Exception:
+                        msg = e.detail
+                    raise ValueError(f"Failed to parse query '{query}'. {msg}")
 
-
-            msg = '\n'.join([
-                f'{i+1}. {warning.title}: {warning.message}' for i, warning
-                in enumerate(resp.validation_response.warnings)
-            ])
-            warnings.warn(f"Encountered the following warnings during "
-                          f"parsing:\n{msg}")
+                time.sleep(2**attempt)  # 1s, 2s, 4s, 8s, ...
 
-
-
-
-
-
-
-
-
+        if len(resp.validation_response.warnings) > 0:
+            msg = '\n'.join([
+                f'{i+1}. {warning.title}: {warning.message}'
+                for i, warning in enumerate(resp.validation_response.warnings)
+            ])
+            warnings.warn(f"Encountered the following warnings during "
+                          f"parsing:\n{msg}")
+
+        return resp.query
 
     @staticmethod
     def _get_task_type(
         query: ValidatedPredictiveQuery,
-        edge_types:
+        edge_types: list[tuple[str, str, str]],
     ) -> TaskType:
         if isinstance(query.target_ast, (Condition, LogicalOperation)):
             return TaskType.BINARY_CLASSIFICATION
@@ -798,31 +1022,38 @@ class KumoRFM:
 
     def _get_default_anchor_time(
         self,
-        query: ValidatedPredictiveQuery,
+        query: ValidatedPredictiveQuery | None = None,
     ) -> pd.Timestamp:
-        if query.query_type == QueryType.TEMPORAL:
+        if query is not None and query.query_type == QueryType.TEMPORAL:
             aggr_table_names = [
                 aggr._get_target_column_name().split('.')[0]
                 for aggr in query.get_all_target_aggregations()
             ]
             return self._sampler.get_max_time(aggr_table_names)
 
-        assert query.query_type == QueryType.STATIC
         return self._sampler.get_max_time()
 
     def _validate_time(
         self,
         query: ValidatedPredictiveQuery,
         anchor_time: pd.Timestamp,
-        context_anchor_time:
+        context_anchor_time: pd.Timestamp | None,
         evaluate: bool,
     ) -> None:
 
         if len(self._sampler.time_column_dict) == 0:
             return  # Graph without timestamps
 
-
-
+        if query.query_type == QueryType.TEMPORAL:
+            aggr_table_names = [
+                aggr._get_target_column_name().split('.')[0]
+                for aggr in query.get_all_target_aggregations()
+            ]
+            min_time = self._sampler.get_min_time(aggr_table_names)
+            max_time = self._sampler.get_max_time(aggr_table_names)
+        else:
+            min_time = self._sampler.get_min_time()
+            max_time = self._sampler.get_max_time()
 
         if anchor_time < min_time:
             raise ValueError(f"Anchor timestamp '{anchor_time}' is before "
@@ -877,40 +1108,26 @@ class KumoRFM:
                 f"Anchor timestamp for evaluation is after the latest "
                 f"supported timestamp '{max_time - end_offset}'.")
 
-    def
+    def _get_task_table(
         self,
         query: ValidatedPredictiveQuery,
-        indices:
-        anchor_time:
-        context_anchor_time:
-        run_mode: RunMode,
-
-
-
-
-        random_seed: Optional[int] = _RANDOM_SEED,
-        logger: Optional[ProgressLogger] = None,
-    ) -> Context:
-
-        if num_neighbors is not None:
-            num_hops = len(num_neighbors)
-
-        if num_hops < 0:
-            raise ValueError(f"'num_hops' must be non-negative "
-                             f"(got {num_hops})")
-        if num_hops > 6:
-            raise ValueError(f"Cannot predict on subgraphs with more than 6 "
-                             f"hops (got {num_hops}). Please reduce the "
-                             f"number of hops and try again. Please create a "
-                             f"feature request at "
-                             f"'https://github.com/kumo-ai/kumo-rfm' if you "
-                             f"must go beyond this for your use-case.")
+        indices: list[str] | list[float] | list[int] | None,
+        anchor_time: pd.Timestamp | Literal['entity'] | None = None,
+        context_anchor_time: pd.Timestamp | None = None,
+        run_mode: RunMode = RunMode.FAST,
+        max_pq_iterations: int = 10,
+        random_seed: int | None = _RANDOM_SEED,
+        logger: ProgressLogger | None = None,
+    ) -> TaskTable:
 
         task_type = self._get_task_type(
             query=query,
             edge_types=self._sampler.edge_types,
         )
 
+        num_train_examples = _MAX_CONTEXT_SIZE[run_mode]
+        num_test_examples = _MAX_TEST_SIZE[task_type] if indices is None else 0
+
         if logger is not None:
             if task_type == TaskType.BINARY_CLASSIFICATION:
                 task_type_repr = 'binary classification'
@@ -924,21 +1141,6 @@ class KumoRFM:
                 task_type_repr = str(task_type)
             logger.log(f"Identified {query.query_type} {task_type_repr} task")
 
-        if task_type.is_link_pred and num_hops < 2:
-            raise ValueError(f"Cannot perform link prediction on subgraphs "
-                             f"with less than 2 hops (got {num_hops}) since "
-                             f"historical target entities need to be part of "
-                             f"the context. Please increase the number of "
-                             f"hops and try again.")
-
-        if num_neighbors is None:
-            if run_mode == RunMode.DEBUG:
-                num_neighbors = [16, 16, 4, 4, 1, 1][:num_hops]
-            elif run_mode == RunMode.FAST or task_type.is_link_pred:
-                num_neighbors = [32, 32, 8, 8, 4, 4][:num_hops]
-            else:
-                num_neighbors = [64, 64, 8, 8, 4, 4][:num_hops]
-
         if query.target_ast.date_offset_range is None:
             step_offset = pd.DateOffset(0)
         else:
@@ -947,8 +1149,7 @@ class KumoRFM:
 
         if anchor_time is None:
             anchor_time = self._get_default_anchor_time(query)
-
-            if evaluate:
+            if num_test_examples > 0:
                 anchor_time = anchor_time - end_offset
 
         if logger is not None:
@@ -962,7 +1163,6 @@ class KumoRFM:
             else:
                 logger.log(f"Derived anchor time {anchor_time}")
 
-        assert anchor_time is not None
         if isinstance(anchor_time, pd.Timestamp):
             if context_anchor_time == 'entity':
                 raise ValueError("Anchor time 'entity' needs to be shared "
@@ -970,7 +1170,7 @@ class KumoRFM:
             if context_anchor_time is None:
                 context_anchor_time = anchor_time - end_offset
             self._validate_time(query, anchor_time, context_anchor_time,
-                                evaluate)
+                                evaluate=num_test_examples > 0)
         else:
             assert anchor_time == 'entity'
             if query.query_type != QueryType.STATIC:
@@ -985,14 +1185,6 @@ class KumoRFM:
                                  "for context and prediction examples")
             context_anchor_time = 'entity'
 
-        num_train_examples = _MAX_CONTEXT_SIZE[run_mode]
-        if evaluate:
-            num_test_examples = _MAX_TEST_SIZE[run_mode]
-            if task_type.is_link_pred:
-                num_test_examples = num_test_examples // 5
-        else:
-            num_test_examples = 0
-
         train, test = self._sampler.sample_target(
             query=query,
             num_train_examples=num_train_examples,
@@ -1003,39 +1195,32 @@ class KumoRFM:
             num_test_trials=max_pq_iterations * num_test_examples,
             random_seed=random_seed,
         )
-        train_pkey, train_time,
-        test_pkey, test_time,
+        train_pkey, train_time, train_y = train
+        test_pkey, test_time, test_y = test
 
-        if
+        if num_test_examples > 0 and logger is not None:
             if task_type == TaskType.BINARY_CLASSIFICATION:
-                pos = 100 * int((
-                msg = (f"Collected {len(
+                pos = 100 * int((test_y > 0).sum()) / len(test_y)
+                msg = (f"Collected {len(test_y):,} test examples with "
                        f"{pos:.2f}% positive cases")
             elif task_type == TaskType.MULTICLASS_CLASSIFICATION:
-                msg = (f"Collected {len(
-                       f"{
+                msg = (f"Collected {len(test_y):,} test examples holding "
+                       f"{test_y.nunique()} classes")
             elif task_type == TaskType.REGRESSION:
-                _min, _max = float(
-                msg = (f"Collected {len(
+                _min, _max = float(test_y.min()), float(test_y.max())
+                msg = (f"Collected {len(test_y):,} test examples with targets "
                        f"between {format_value(_min)} and "
                        f"{format_value(_max)}")
             elif task_type == TaskType.TEMPORAL_LINK_PREDICTION:
-                num_rhs =
-                msg = (f"Collected {len(
+                num_rhs = test_y.explode().nunique()
+                msg = (f"Collected {len(test_y):,} test examples with "
                        f"{num_rhs:,} unique items")
             else:
                 raise NotImplementedError
             logger.log(msg)
 
-        if
+        if num_test_examples == 0:
             assert indices is not None
-            if len(indices) > _MAX_PRED_SIZE[task_type]:
-                raise ValueError(f"Cannot predict for more than "
-                                 f"{_MAX_PRED_SIZE[task_type]:,} entities at "
-                                 f"once (got {len(indices):,}). Use "
-                                 f"`KumoRFM.batch_mode` to process entities "
-                                 f"in batches")
-
             test_pkey = pd.Series(indices, dtype=train_pkey.dtype)
             if isinstance(anchor_time, pd.Timestamp):
                 test_time = pd.Series([anchor_time]).repeat(
@@ -1045,26 +1230,26 @@ class KumoRFM:
 
         if logger is not None:
             if task_type == TaskType.BINARY_CLASSIFICATION:
-                pos = 100 * int((
-                msg = (f"Collected {len(
+                pos = 100 * int((train_y > 0).sum()) / len(train_y)
+                msg = (f"Collected {len(train_y):,} in-context examples with "
                        f"{pos:.2f}% positive cases")
             elif task_type == TaskType.MULTICLASS_CLASSIFICATION:
-                msg = (f"Collected {len(
-                       f"holding {
+                msg = (f"Collected {len(train_y):,} in-context examples "
+                       f"holding {train_y.nunique()} classes")
             elif task_type == TaskType.REGRESSION:
-                _min, _max = float(
-                msg = (f"Collected {len(
+                _min, _max = float(train_y.min()), float(train_y.max())
+                msg = (f"Collected {len(train_y):,} in-context examples with "
                        f"targets between {format_value(_min)} and "
                        f"{format_value(_max)}")
             elif task_type == TaskType.TEMPORAL_LINK_PREDICTION:
-                num_rhs =
-                msg = (f"Collected {len(
+                num_rhs = train_y.explode().nunique()
+                msg = (f"Collected {len(train_y):,} in-context examples with "
                        f"{num_rhs:,} unique items")
             else:
                 raise NotImplementedError
             logger.log(msg)
 
-        entity_table_names:
+        entity_table_names: tuple[str] | tuple[str, str]
         if task_type.is_link_pred:
             final_aggr = query.get_final_target_aggregation()
             assert final_aggr is not None
@@ -1078,27 +1263,80 @@ class KumoRFM:
         else:
             entity_table_names = (query.entity_table, )
 
+        context_df = pd.DataFrame({'ENTITY': train_pkey, 'TARGET': train_y})
+        if isinstance(train_time, pd.Series):
+            context_df['ANCHOR_TIMESTAMP'] = train_time
+        pred_df = pd.DataFrame({'ENTITY': test_pkey})
+        if num_test_examples > 0:
+            pred_df['TARGET'] = test_y
+        if isinstance(test_time, pd.Series):
+            pred_df['ANCHOR_TIMESTAMP'] = test_time
+
+        return TaskTable(
+            task_type=task_type,
+            context_df=context_df,
+            pred_df=pred_df,
+            entity_table_name=entity_table_names,
+            entity_column='ENTITY',
+            target_column='TARGET',
+            time_column='ANCHOR_TIMESTAMP' if isinstance(
+                train_time, pd.Series) else TaskTable.ENTITY_TIME,
+        )
+
+    def _get_context(
+        self,
+        task: TaskTable,
+        run_mode: RunMode | str = RunMode.FAST,
+        num_neighbors: list[int] | None = None,
+        exclude_cols_dict: dict[str, list[str]] | None = None,
+        top_k: int | None = None,
+    ) -> Context:
+
+        if num_neighbors is None:
+            key = RunMode.FAST if task.task_type.is_link_pred else run_mode
+            num_neighbors = _DEFAULT_NUM_NEIGHBORS[key][:2]
+
+        if len(num_neighbors) > 6:
+            raise ValueError(f"Cannot predict on subgraphs with more than 6 "
+                             f"hops (got {len(num_neighbors)}). Reduce the "
+                             f"number of hops and try again. Please create a "
+                             f"feature request at "
+                             f"'https://github.com/kumo-ai/kumo-rfm' if you "
+                             f"must go beyond this for your use-case.")
+
         # Exclude the entity anchor time from the feature set to prevent
         # running out-of-distribution between in-context and test examples:
-        exclude_cols_dict =
-        if
-            if
-                exclude_cols_dict[
-
-            exclude_cols_dict[
+        exclude_cols_dict = exclude_cols_dict or {}
+        if task.entity_table_name in self._sampler.time_column_dict:
+            if task.entity_table_name not in exclude_cols_dict:
+                exclude_cols_dict[task.entity_table_name] = []
+            time_col = self._sampler.time_column_dict[task.entity_table_name]
+            exclude_cols_dict[task.entity_table_name].append(time_col)
+
+        entity_pkey = pd.concat([
+            task._context_df[task._entity_column],
+            task._pred_df[task._entity_column],
+        ], axis=0, ignore_index=True)
+
+        if task.use_entity_time:
+            if task.entity_table_name not in self._sampler.time_column_dict:
+                raise ValueError(f"The given annchor time requires the entity "
+                                 f"table '{task.entity_table_name}' to have a "
+                                 f"time column")
+            anchor_time = 'entity'
+        elif task._time_column is not None:
+            anchor_time = pd.concat([
+                task._context_df[task._time_column],
+                task._pred_df[task._time_column],
+            ], axis=0, ignore_index=True)
+        else:
+            anchor_time = pd.Series(self._get_default_anchor_time()).repeat(
+                (len(entity_pkey))).reset_index(drop=True)
 
         subgraph = self._sampler.sample_subgraph(
-            entity_table_names=entity_table_names,
-            entity_pkey=
-
-                axis=0,
-                ignore_index=True,
-            ),
-            anchor_time=pd.concat(
-                [train_time, test_time],
-                axis=0,
-                ignore_index=True,
-            ) if isinstance(train_time, pd.Series) else 'entity',
+            entity_table_names=task.entity_table_names,
+            entity_pkey=entity_pkey,
+            anchor_time=anchor_time,
             num_neighbors=num_neighbors,
             exclude_cols_dict=exclude_cols_dict,
         )
@@ -1110,19 +1348,26 @@ class KumoRFM:
                              f"'https://github.com/kumo-ai/kumo-rfm' if you "
                              f"must go beyond this for your use-case.")
 
+        if (task.task_type.is_link_pred
+                and task.entity_table_names[-1] not in subgraph.table_dict):
+            raise ValueError("Cannot perform link prediction on subgraphs "
+                             "without any historical target entities. Please "
+                             "increase the number of hops and try again.")
+
         return Context(
-            task_type=task_type,
-            entity_table_names=entity_table_names,
+            task_type=task.task_type,
+            entity_table_names=task.entity_table_names,
             subgraph=subgraph,
-            y_train=
-            y_test=
-
+            y_train=task._context_df[task.target_column.name],
+            y_test=task._pred_df[task.target_column.name]
+            if task.evaluate else None,
+            top_k=top_k,
             step_size=None,
         )
 
     @staticmethod
     def _validate_metrics(
-        metrics:
+        metrics: list[str],
         task_type: TaskType,
     ) -> None:
 
@@ -1179,7 +1424,7 @@ class KumoRFM:
                 f"'https://github.com/kumo-ai/kumo-rfm'.")
 
 
-def format_value(value:
+def format_value(value: int | float) -> str:
    if value == int(value):
        return f'{int(value):,}'
    if abs(value) >= 1000: