kumoai 2.13.0.dev202512091732__cp311-cp311-macosx_11_0_arm64.whl → 2.14.0.dev202601051732__cp311-cp311-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kumoai/__init__.py +23 -26
- kumoai/_version.py +1 -1
- kumoai/client/client.py +6 -0
- kumoai/client/jobs.py +24 -0
- kumoai/client/pquery.py +6 -2
- kumoai/connector/utils.py +21 -7
- kumoai/experimental/rfm/__init__.py +51 -24
- kumoai/experimental/rfm/authenticate.py +3 -4
- kumoai/experimental/rfm/backend/local/graph_store.py +52 -104
- kumoai/experimental/rfm/backend/local/sampler.py +125 -55
- kumoai/experimental/rfm/backend/local/table.py +35 -31
- kumoai/experimental/rfm/backend/snow/__init__.py +2 -0
- kumoai/experimental/rfm/backend/snow/sampler.py +297 -0
- kumoai/experimental/rfm/backend/snow/table.py +174 -49
- kumoai/experimental/rfm/backend/sqlite/__init__.py +4 -2
- kumoai/experimental/rfm/backend/sqlite/sampler.py +398 -0
- kumoai/experimental/rfm/backend/sqlite/table.py +131 -48
- kumoai/experimental/rfm/base/__init__.py +21 -5
- kumoai/experimental/rfm/base/column.py +96 -10
- kumoai/experimental/rfm/base/expression.py +44 -0
- kumoai/experimental/rfm/base/sampler.py +422 -35
- kumoai/experimental/rfm/base/source.py +2 -1
- kumoai/experimental/rfm/base/sql_sampler.py +144 -0
- kumoai/experimental/rfm/base/table.py +386 -195
- kumoai/experimental/rfm/graph.py +350 -178
- kumoai/experimental/rfm/infer/__init__.py +6 -4
- kumoai/experimental/rfm/infer/dtype.py +7 -4
- kumoai/experimental/rfm/infer/multicategorical.py +1 -1
- kumoai/experimental/rfm/infer/pkey.py +4 -2
- kumoai/experimental/rfm/infer/stype.py +35 -0
- kumoai/experimental/rfm/infer/time_col.py +1 -2
- kumoai/experimental/rfm/pquery/executor.py +27 -27
- kumoai/experimental/rfm/pquery/pandas_executor.py +29 -31
- kumoai/experimental/rfm/relbench.py +76 -0
- kumoai/experimental/rfm/rfm.py +630 -408
- kumoai/experimental/rfm/sagemaker.py +4 -4
- kumoai/experimental/rfm/task_table.py +290 -0
- kumoai/pquery/predictive_query.py +10 -6
- kumoai/testing/snow.py +50 -0
- kumoai/trainer/distilled_trainer.py +175 -0
- kumoai/utils/__init__.py +3 -2
- kumoai/utils/display.py +51 -0
- kumoai/utils/progress_logger.py +190 -12
- kumoai/utils/sql.py +3 -0
- {kumoai-2.13.0.dev202512091732.dist-info → kumoai-2.14.0.dev202601051732.dist-info}/METADATA +3 -2
- {kumoai-2.13.0.dev202512091732.dist-info → kumoai-2.14.0.dev202601051732.dist-info}/RECORD +49 -40
- kumoai/experimental/rfm/local_graph_sampler.py +0 -223
- kumoai/experimental/rfm/local_pquery_driver.py +0 -689
- {kumoai-2.13.0.dev202512091732.dist-info → kumoai-2.14.0.dev202601051732.dist-info}/WHEEL +0 -0
- {kumoai-2.13.0.dev202512091732.dist-info → kumoai-2.14.0.dev202601051732.dist-info}/licenses/LICENSE +0 -0
- {kumoai-2.13.0.dev202512091732.dist-info → kumoai-2.14.0.dev202601051732.dist-info}/top_level.txt +0 -0
kumoai/experimental/rfm/rfm.py
CHANGED
|
@@ -1,26 +1,24 @@
|
|
|
1
1
|
import json
|
|
2
|
+
import math
|
|
2
3
|
import time
|
|
3
4
|
import warnings
|
|
4
5
|
from collections import defaultdict
|
|
5
|
-
from collections.abc import Generator
|
|
6
|
+
from collections.abc import Generator, Iterator
|
|
6
7
|
from contextlib import contextmanager
|
|
7
8
|
from dataclasses import dataclass, replace
|
|
8
|
-
from typing import
|
|
9
|
-
Any,
|
|
10
|
-
Dict,
|
|
11
|
-
Iterator,
|
|
12
|
-
List,
|
|
13
|
-
Literal,
|
|
14
|
-
Optional,
|
|
15
|
-
Tuple,
|
|
16
|
-
Union,
|
|
17
|
-
overload,
|
|
18
|
-
)
|
|
9
|
+
from typing import Any, Literal, overload
|
|
19
10
|
|
|
20
11
|
import numpy as np
|
|
21
12
|
import pandas as pd
|
|
22
13
|
from kumoapi.model_plan import RunMode
|
|
23
14
|
from kumoapi.pquery import QueryType, ValidatedPredictiveQuery
|
|
15
|
+
from kumoapi.pquery.AST import (
|
|
16
|
+
Aggregation,
|
|
17
|
+
Column,
|
|
18
|
+
Condition,
|
|
19
|
+
Join,
|
|
20
|
+
LogicalOperation,
|
|
21
|
+
)
|
|
24
22
|
from kumoapi.rfm import Context
|
|
25
23
|
from kumoapi.rfm import Explanation as ExplanationConfig
|
|
26
24
|
from kumoapi.rfm import (
|
|
@@ -29,35 +27,35 @@ from kumoapi.rfm import (
|
|
|
29
27
|
RFMPredictRequest,
|
|
30
28
|
)
|
|
31
29
|
from kumoapi.task import TaskType
|
|
30
|
+
from kumoapi.typing import AggregationType, Stype
|
|
32
31
|
|
|
33
32
|
from kumoai.client.rfm import RFMAPI
|
|
34
33
|
from kumoai.exceptions import HTTPException
|
|
35
|
-
from kumoai.experimental.rfm import Graph
|
|
36
|
-
from kumoai.experimental.rfm.
|
|
37
|
-
from kumoai.experimental.rfm.local_graph_sampler import LocalGraphSampler
|
|
38
|
-
from kumoai.experimental.rfm.local_pquery_driver import (
|
|
39
|
-
LocalPQueryDriver,
|
|
40
|
-
date_offset_to_seconds,
|
|
41
|
-
)
|
|
34
|
+
from kumoai.experimental.rfm import Graph, TaskTable
|
|
35
|
+
from kumoai.experimental.rfm.base import DataBackend, Sampler
|
|
42
36
|
from kumoai.mixin import CastMixin
|
|
43
|
-
from kumoai.utils import
|
|
37
|
+
from kumoai.utils import ProgressLogger, display
|
|
44
38
|
|
|
45
39
|
_RANDOM_SEED = 42
|
|
46
40
|
|
|
47
41
|
_MAX_PRED_SIZE: dict[TaskType, int] = defaultdict(lambda: 1_000)
|
|
48
42
|
_MAX_PRED_SIZE[TaskType.TEMPORAL_LINK_PREDICTION] = 200
|
|
49
43
|
|
|
44
|
+
_MAX_TEST_SIZE: dict[TaskType, int] = defaultdict(lambda: 2_000)
|
|
45
|
+
_MAX_TEST_SIZE[TaskType.TEMPORAL_LINK_PREDICTION] = 400
|
|
46
|
+
|
|
50
47
|
_MAX_CONTEXT_SIZE = {
|
|
51
48
|
RunMode.DEBUG: 100,
|
|
52
49
|
RunMode.FAST: 1_000,
|
|
53
50
|
RunMode.NORMAL: 5_000,
|
|
54
51
|
RunMode.BEST: 10_000,
|
|
55
52
|
}
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
RunMode.
|
|
59
|
-
RunMode.
|
|
60
|
-
RunMode.
|
|
53
|
+
|
|
54
|
+
_DEFAULT_NUM_NEIGHBORS = {
|
|
55
|
+
RunMode.DEBUG: [16, 16, 4, 4, 1, 1],
|
|
56
|
+
RunMode.FAST: [32, 32, 8, 8, 4, 4],
|
|
57
|
+
RunMode.NORMAL: [64, 64, 8, 8, 4, 4],
|
|
58
|
+
RunMode.BEST: [64, 64, 8, 8, 4, 4],
|
|
61
59
|
}
|
|
62
60
|
|
|
63
61
|
_MAX_SIZE = 30 * 1024 * 1024
|
|
@@ -95,24 +93,26 @@ class Explanation:
|
|
|
95
93
|
def __getitem__(self, index: Literal[1]) -> str:
|
|
96
94
|
pass
|
|
97
95
|
|
|
98
|
-
def __getitem__(self, index: int) ->
|
|
96
|
+
def __getitem__(self, index: int) -> pd.DataFrame | str:
|
|
99
97
|
if index == 0:
|
|
100
98
|
return self.prediction
|
|
101
99
|
if index == 1:
|
|
102
100
|
return self.summary
|
|
103
101
|
raise IndexError("Index out of range")
|
|
104
102
|
|
|
105
|
-
def __iter__(self) -> Iterator[
|
|
103
|
+
def __iter__(self) -> Iterator[pd.DataFrame | str]:
|
|
106
104
|
return iter((self.prediction, self.summary))
|
|
107
105
|
|
|
108
106
|
def __repr__(self) -> str:
|
|
109
107
|
return str((self.prediction, self.summary))
|
|
110
108
|
|
|
111
|
-
def
|
|
112
|
-
|
|
109
|
+
def print(self) -> None:
|
|
110
|
+
r"""Prints the explanation."""
|
|
111
|
+
display.dataframe(self.prediction)
|
|
112
|
+
display.message(self.summary)
|
|
113
113
|
|
|
114
|
-
|
|
115
|
-
|
|
114
|
+
def _ipython_display_(self) -> None:
|
|
115
|
+
self.print()
|
|
116
116
|
|
|
117
117
|
|
|
118
118
|
class KumoRFM:
|
|
@@ -151,20 +151,35 @@ class KumoRFM:
|
|
|
151
151
|
Args:
|
|
152
152
|
graph: The graph.
|
|
153
153
|
verbose: Whether to print verbose output.
|
|
154
|
+
optimize: If set to ``True``, will optimize the underlying data backend
|
|
155
|
+
for optimal querying. For example, for transactional database
|
|
156
|
+
backends, will create any missing indices. Requires write-access to
|
|
157
|
+
the data backend.
|
|
154
158
|
"""
|
|
155
159
|
def __init__(
|
|
156
160
|
self,
|
|
157
161
|
graph: Graph,
|
|
158
|
-
verbose:
|
|
162
|
+
verbose: bool | ProgressLogger = True,
|
|
163
|
+
optimize: bool = False,
|
|
159
164
|
) -> None:
|
|
160
165
|
graph = graph.validate()
|
|
161
166
|
self._graph_def = graph._to_api_graph_definition()
|
|
162
|
-
self._graph_store = LocalGraphStore(graph, verbose)
|
|
163
|
-
self._graph_sampler = LocalGraphSampler(self._graph_store)
|
|
164
167
|
|
|
165
|
-
|
|
168
|
+
if graph.backend == DataBackend.LOCAL:
|
|
169
|
+
from kumoai.experimental.rfm.backend.local import LocalSampler
|
|
170
|
+
self._sampler: Sampler = LocalSampler(graph, verbose)
|
|
171
|
+
elif graph.backend == DataBackend.SQLITE:
|
|
172
|
+
from kumoai.experimental.rfm.backend.sqlite import SQLiteSampler
|
|
173
|
+
self._sampler = SQLiteSampler(graph, verbose, optimize)
|
|
174
|
+
elif graph.backend == DataBackend.SNOWFLAKE:
|
|
175
|
+
from kumoai.experimental.rfm.backend.snow import SnowSampler
|
|
176
|
+
self._sampler = SnowSampler(graph, verbose)
|
|
177
|
+
else:
|
|
178
|
+
raise NotImplementedError
|
|
166
179
|
|
|
167
|
-
self.
|
|
180
|
+
self._client: RFMAPI | None = None
|
|
181
|
+
|
|
182
|
+
self._batch_size: int | Literal['max'] | None = None
|
|
168
183
|
self.num_retries: int = 0
|
|
169
184
|
|
|
170
185
|
@property
|
|
@@ -182,7 +197,7 @@ class KumoRFM:
|
|
|
182
197
|
@contextmanager
|
|
183
198
|
def batch_mode(
|
|
184
199
|
self,
|
|
185
|
-
batch_size:
|
|
200
|
+
batch_size: int | Literal['max'] = 'max',
|
|
186
201
|
num_retries: int = 1,
|
|
187
202
|
) -> Generator[None, None, None]:
|
|
188
203
|
"""Context manager to predict in batches.
|
|
@@ -216,17 +231,17 @@ class KumoRFM:
|
|
|
216
231
|
def predict(
|
|
217
232
|
self,
|
|
218
233
|
query: str,
|
|
219
|
-
indices:
|
|
234
|
+
indices: list[str] | list[float] | list[int] | None = None,
|
|
220
235
|
*,
|
|
221
236
|
explain: Literal[False] = False,
|
|
222
|
-
anchor_time:
|
|
223
|
-
context_anchor_time:
|
|
224
|
-
run_mode:
|
|
225
|
-
num_neighbors:
|
|
237
|
+
anchor_time: pd.Timestamp | Literal['entity'] | None = None,
|
|
238
|
+
context_anchor_time: pd.Timestamp | None = None,
|
|
239
|
+
run_mode: RunMode | str = RunMode.FAST,
|
|
240
|
+
num_neighbors: list[int] | None = None,
|
|
226
241
|
num_hops: int = 2,
|
|
227
|
-
max_pq_iterations: int =
|
|
228
|
-
random_seed:
|
|
229
|
-
verbose:
|
|
242
|
+
max_pq_iterations: int = 10,
|
|
243
|
+
random_seed: int | None = _RANDOM_SEED,
|
|
244
|
+
verbose: bool | ProgressLogger = True,
|
|
230
245
|
use_prediction_time: bool = False,
|
|
231
246
|
) -> pd.DataFrame:
|
|
232
247
|
pass
|
|
@@ -235,17 +250,17 @@ class KumoRFM:
|
|
|
235
250
|
def predict(
|
|
236
251
|
self,
|
|
237
252
|
query: str,
|
|
238
|
-
indices:
|
|
253
|
+
indices: list[str] | list[float] | list[int] | None = None,
|
|
239
254
|
*,
|
|
240
|
-
explain:
|
|
241
|
-
anchor_time:
|
|
242
|
-
context_anchor_time:
|
|
243
|
-
run_mode:
|
|
244
|
-
num_neighbors:
|
|
255
|
+
explain: Literal[True] | ExplainConfig | dict[str, Any],
|
|
256
|
+
anchor_time: pd.Timestamp | Literal['entity'] | None = None,
|
|
257
|
+
context_anchor_time: pd.Timestamp | None = None,
|
|
258
|
+
run_mode: RunMode | str = RunMode.FAST,
|
|
259
|
+
num_neighbors: list[int] | None = None,
|
|
245
260
|
num_hops: int = 2,
|
|
246
|
-
max_pq_iterations: int =
|
|
247
|
-
random_seed:
|
|
248
|
-
verbose:
|
|
261
|
+
max_pq_iterations: int = 10,
|
|
262
|
+
random_seed: int | None = _RANDOM_SEED,
|
|
263
|
+
verbose: bool | ProgressLogger = True,
|
|
249
264
|
use_prediction_time: bool = False,
|
|
250
265
|
) -> Explanation:
|
|
251
266
|
pass
|
|
@@ -253,19 +268,19 @@ class KumoRFM:
|
|
|
253
268
|
def predict(
|
|
254
269
|
self,
|
|
255
270
|
query: str,
|
|
256
|
-
indices:
|
|
271
|
+
indices: list[str] | list[float] | list[int] | None = None,
|
|
257
272
|
*,
|
|
258
|
-
explain:
|
|
259
|
-
anchor_time:
|
|
260
|
-
context_anchor_time:
|
|
261
|
-
run_mode:
|
|
262
|
-
num_neighbors:
|
|
273
|
+
explain: bool | ExplainConfig | dict[str, Any] = False,
|
|
274
|
+
anchor_time: pd.Timestamp | Literal['entity'] | None = None,
|
|
275
|
+
context_anchor_time: pd.Timestamp | None = None,
|
|
276
|
+
run_mode: RunMode | str = RunMode.FAST,
|
|
277
|
+
num_neighbors: list[int] | None = None,
|
|
263
278
|
num_hops: int = 2,
|
|
264
|
-
max_pq_iterations: int =
|
|
265
|
-
random_seed:
|
|
266
|
-
verbose:
|
|
279
|
+
max_pq_iterations: int = 10,
|
|
280
|
+
random_seed: int | None = _RANDOM_SEED,
|
|
281
|
+
verbose: bool | ProgressLogger = True,
|
|
267
282
|
use_prediction_time: bool = False,
|
|
268
|
-
) ->
|
|
283
|
+
) -> pd.DataFrame | Explanation:
|
|
269
284
|
"""Returns predictions for a predictive query.
|
|
270
285
|
|
|
271
286
|
Args:
|
|
@@ -307,18 +322,133 @@ class KumoRFM:
|
|
|
307
322
|
If ``explain`` is provided, returns an :class:`Explanation` object
|
|
308
323
|
containing the prediction, summary, and details.
|
|
309
324
|
"""
|
|
310
|
-
explain_config: Optional[ExplainConfig] = None
|
|
311
|
-
if explain is True:
|
|
312
|
-
explain_config = ExplainConfig()
|
|
313
|
-
elif explain is not False:
|
|
314
|
-
explain_config = ExplainConfig._cast(explain)
|
|
315
|
-
|
|
316
325
|
query_def = self._parse_query(query)
|
|
317
|
-
query_str = query_def.to_string()
|
|
318
326
|
|
|
327
|
+
if indices is None:
|
|
328
|
+
if query_def.rfm_entity_ids is None:
|
|
329
|
+
raise ValueError("Cannot find entities to predict for. Please "
|
|
330
|
+
"pass them via `predict(query, indices=...)`")
|
|
331
|
+
indices = query_def.get_rfm_entity_id_list()
|
|
332
|
+
else:
|
|
333
|
+
query_def = replace(query_def, rfm_entity_ids=None)
|
|
334
|
+
|
|
335
|
+
if not isinstance(verbose, ProgressLogger):
|
|
336
|
+
query_repr = query_def.to_string(rich=True, exclude_predict=True)
|
|
337
|
+
if explain is not False:
|
|
338
|
+
msg = f'[bold]EXPLAIN[/bold] {query_repr}'
|
|
339
|
+
else:
|
|
340
|
+
msg = f'[bold]PREDICT[/bold] {query_repr}'
|
|
341
|
+
verbose = ProgressLogger.default(msg=msg, verbose=verbose)
|
|
342
|
+
|
|
343
|
+
with verbose as logger:
|
|
344
|
+
task_table = self._get_task_table(
|
|
345
|
+
query=query_def,
|
|
346
|
+
indices=indices,
|
|
347
|
+
anchor_time=anchor_time,
|
|
348
|
+
context_anchor_time=context_anchor_time,
|
|
349
|
+
run_mode=run_mode,
|
|
350
|
+
max_pq_iterations=max_pq_iterations,
|
|
351
|
+
random_seed=random_seed,
|
|
352
|
+
logger=logger,
|
|
353
|
+
)
|
|
354
|
+
task_table._query = query_def.to_string() # type: ignore
|
|
355
|
+
|
|
356
|
+
return self.predict_task(
|
|
357
|
+
task_table,
|
|
358
|
+
explain=explain, # type: ignore
|
|
359
|
+
run_mode=run_mode,
|
|
360
|
+
num_neighbors=num_neighbors,
|
|
361
|
+
num_hops=num_hops,
|
|
362
|
+
verbose=verbose,
|
|
363
|
+
exclude_cols_dict=query_def.get_exclude_cols_dict(),
|
|
364
|
+
use_prediction_time=use_prediction_time,
|
|
365
|
+
top_k=query_def.top_k,
|
|
366
|
+
)
|
|
367
|
+
|
|
368
|
+
@overload
|
|
369
|
+
def predict_task(
|
|
370
|
+
self,
|
|
371
|
+
task: TaskTable,
|
|
372
|
+
*,
|
|
373
|
+
explain: Literal[False] = False,
|
|
374
|
+
run_mode: RunMode | str = RunMode.FAST,
|
|
375
|
+
num_neighbors: list[int] | None = None,
|
|
376
|
+
num_hops: int = 2,
|
|
377
|
+
verbose: bool | ProgressLogger = True,
|
|
378
|
+
exclude_cols_dict: dict[str, list[str]] | None = None,
|
|
379
|
+
use_prediction_time: bool = False,
|
|
380
|
+
top_k: int | None = None,
|
|
381
|
+
) -> pd.DataFrame:
|
|
382
|
+
pass
|
|
383
|
+
|
|
384
|
+
@overload
|
|
385
|
+
def predict_task(
|
|
386
|
+
self,
|
|
387
|
+
task: TaskTable,
|
|
388
|
+
*,
|
|
389
|
+
explain: Literal[True] | ExplainConfig | dict[str, Any],
|
|
390
|
+
run_mode: RunMode | str = RunMode.FAST,
|
|
391
|
+
num_neighbors: list[int] | None = None,
|
|
392
|
+
num_hops: int = 2,
|
|
393
|
+
verbose: bool | ProgressLogger = True,
|
|
394
|
+
exclude_cols_dict: dict[str, list[str]] | None = None,
|
|
395
|
+
use_prediction_time: bool = False,
|
|
396
|
+
top_k: int | None = None,
|
|
397
|
+
) -> Explanation:
|
|
398
|
+
pass
|
|
399
|
+
|
|
400
|
+
def predict_task(
|
|
401
|
+
self,
|
|
402
|
+
task: TaskTable,
|
|
403
|
+
*,
|
|
404
|
+
explain: bool | ExplainConfig | dict[str, Any] = False,
|
|
405
|
+
run_mode: RunMode | str = RunMode.FAST,
|
|
406
|
+
num_neighbors: list[int] | None = None,
|
|
407
|
+
num_hops: int = 2,
|
|
408
|
+
verbose: bool | ProgressLogger = True,
|
|
409
|
+
exclude_cols_dict: dict[str, list[str]] | None = None,
|
|
410
|
+
use_prediction_time: bool = False,
|
|
411
|
+
top_k: int | None = None,
|
|
412
|
+
) -> pd.DataFrame | Explanation:
|
|
413
|
+
"""Returns predictions for a custom task specification.
|
|
414
|
+
|
|
415
|
+
Args:
|
|
416
|
+
task: The custom :class:`TaskTable`.
|
|
417
|
+
explain: Configuration for explainability.
|
|
418
|
+
If set to ``True``, will additionally explain the prediction.
|
|
419
|
+
Passing in an :class:`ExplainConfig` instance provides control
|
|
420
|
+
over which parts of explanation are generated.
|
|
421
|
+
Explainability is currently only supported for single entity
|
|
422
|
+
predictions with ``run_mode="FAST"``.
|
|
423
|
+
run_mode: The :class:`RunMode` for the query.
|
|
424
|
+
num_neighbors: The number of neighbors to sample for each hop.
|
|
425
|
+
If specified, the ``num_hops`` option will be ignored.
|
|
426
|
+
num_hops: The number of hops to sample when generating the context.
|
|
427
|
+
verbose: Whether to print verbose output.
|
|
428
|
+
exclude_cols_dict: Any column in any table to exclude from the
|
|
429
|
+
model input.
|
|
430
|
+
use_prediction_time: Whether to use the anchor timestamp as an
|
|
431
|
+
additional feature during prediction. This is typically
|
|
432
|
+
beneficial for time series forecasting tasks.
|
|
433
|
+
top_k: The number of predictions to return per entity.
|
|
434
|
+
|
|
435
|
+
Returns:
|
|
436
|
+
The predictions as a :class:`pandas.DataFrame`.
|
|
437
|
+
If ``explain`` is provided, returns an :class:`Explanation` object
|
|
438
|
+
containing the prediction, summary, and details.
|
|
439
|
+
"""
|
|
319
440
|
if num_hops != 2 and num_neighbors is not None:
|
|
320
441
|
warnings.warn(f"Received custom 'num_neighbors' option; ignoring "
|
|
321
442
|
f"custom 'num_hops={num_hops}' option")
|
|
443
|
+
if num_neighbors is None:
|
|
444
|
+
key = RunMode.FAST if task.task_type.is_link_pred else run_mode
|
|
445
|
+
num_neighbors = _DEFAULT_NUM_NEIGHBORS[key][:num_hops]
|
|
446
|
+
|
|
447
|
+
explain_config: ExplainConfig | None = None
|
|
448
|
+
if explain is True:
|
|
449
|
+
explain_config = ExplainConfig()
|
|
450
|
+
elif explain is not False:
|
|
451
|
+
explain_config = ExplainConfig._cast(explain)
|
|
322
452
|
|
|
323
453
|
if explain_config is not None and run_mode in {
|
|
324
454
|
RunMode.NORMAL, RunMode.BEST
|
|
@@ -327,83 +457,82 @@ class KumoRFM:
|
|
|
327
457
|
f"run mode 'FAST' (got '{run_mode}'). Provided run "
|
|
328
458
|
f"mode has been reset. Please lower the run mode to "
|
|
329
459
|
f"suppress this warning.")
|
|
460
|
+
run_mode = RunMode.FAST
|
|
330
461
|
|
|
331
|
-
if
|
|
332
|
-
|
|
333
|
-
|
|
334
|
-
|
|
335
|
-
indices = query_def.get_rfm_entity_id_list()
|
|
336
|
-
else:
|
|
337
|
-
query_def = replace(query_def, rfm_entity_ids=None)
|
|
338
|
-
|
|
339
|
-
if len(indices) == 0:
|
|
340
|
-
raise ValueError("At least one entity is required")
|
|
341
|
-
|
|
342
|
-
if explain_config is not None and len(indices) > 1:
|
|
343
|
-
raise ValueError(
|
|
344
|
-
f"Cannot explain predictions for more than a single entity "
|
|
345
|
-
f"(got {len(indices)})")
|
|
346
|
-
|
|
347
|
-
query_repr = query_def.to_string(rich=True, exclude_predict=True)
|
|
348
|
-
if explain_config is not None:
|
|
349
|
-
msg = f'[bold]EXPLAIN[/bold] {query_repr}'
|
|
350
|
-
else:
|
|
351
|
-
msg = f'[bold]PREDICT[/bold] {query_repr}'
|
|
462
|
+
if explain_config is not None and task.num_prediction_examples > 1:
|
|
463
|
+
raise ValueError(f"Cannot explain predictions for more than a "
|
|
464
|
+
f"single entity "
|
|
465
|
+
f"(got {task.num_prediction_examples:,})")
|
|
352
466
|
|
|
353
467
|
if not isinstance(verbose, ProgressLogger):
|
|
354
|
-
|
|
355
|
-
|
|
356
|
-
|
|
357
|
-
|
|
358
|
-
|
|
359
|
-
|
|
360
|
-
|
|
361
|
-
|
|
362
|
-
edge_types=self._graph_store.edge_types,
|
|
363
|
-
)
|
|
364
|
-
batch_size = _MAX_PRED_SIZE[task_type]
|
|
468
|
+
if task.task_type == TaskType.BINARY_CLASSIFICATION:
|
|
469
|
+
task_type_repr = 'binary classification'
|
|
470
|
+
elif task.task_type == TaskType.MULTICLASS_CLASSIFICATION:
|
|
471
|
+
task_type_repr = 'multi-class classification'
|
|
472
|
+
elif task.task_type == TaskType.REGRESSION:
|
|
473
|
+
task_type_repr = 'regression'
|
|
474
|
+
elif task.task_type == TaskType.TEMPORAL_LINK_PREDICTION:
|
|
475
|
+
task_type_repr = 'link prediction'
|
|
365
476
|
else:
|
|
366
|
-
|
|
477
|
+
task_type_repr = str(task.task_type)
|
|
367
478
|
|
|
368
|
-
if
|
|
369
|
-
|
|
370
|
-
batches = [indices[step:step + batch_size] for step in offsets]
|
|
479
|
+
if explain_config is not None:
|
|
480
|
+
msg = f'Explain {task_type_repr} task'
|
|
371
481
|
else:
|
|
372
|
-
|
|
482
|
+
msg = f'Predict {task_type_repr} task'
|
|
483
|
+
verbose = ProgressLogger.default(msg=msg, verbose=verbose)
|
|
373
484
|
|
|
374
|
-
|
|
375
|
-
|
|
376
|
-
|
|
485
|
+
with verbose as logger:
|
|
486
|
+
if task.num_context_examples > _MAX_CONTEXT_SIZE[run_mode]:
|
|
487
|
+
logger.log(f"Sub-sampled {_MAX_CONTEXT_SIZE[run_mode]:,} "
|
|
488
|
+
f"out of {task.num_context_examples:,} in-context "
|
|
489
|
+
f"examples")
|
|
490
|
+
task = task.narrow_context(0, _MAX_CONTEXT_SIZE[run_mode])
|
|
491
|
+
|
|
492
|
+
if self._batch_size is None:
|
|
493
|
+
batch_size = task.num_prediction_examples
|
|
494
|
+
elif self._batch_size == 'max':
|
|
495
|
+
batch_size = _MAX_PRED_SIZE[task.task_type]
|
|
496
|
+
else:
|
|
497
|
+
batch_size = self._batch_size
|
|
377
498
|
|
|
378
|
-
|
|
379
|
-
|
|
380
|
-
|
|
381
|
-
|
|
382
|
-
|
|
499
|
+
if batch_size > _MAX_PRED_SIZE[task.task_type]:
|
|
500
|
+
raise ValueError(f"Cannot predict for more than "
|
|
501
|
+
f"{_MAX_PRED_SIZE[task.task_type]:,} "
|
|
502
|
+
f"entities at once (got {batch_size:,}). Use "
|
|
503
|
+
f"`KumoRFM.batch_mode` to process entities "
|
|
504
|
+
f"in batches with a sufficient batch size.")
|
|
505
|
+
|
|
506
|
+
if task.num_prediction_examples > batch_size:
|
|
507
|
+
num = math.ceil(task.num_prediction_examples / batch_size)
|
|
508
|
+
logger.log(f"Splitting {task.num_prediction_examples:,} "
|
|
509
|
+
f"entities into {num:,} batches of size "
|
|
510
|
+
f"{batch_size:,}")
|
|
511
|
+
|
|
512
|
+
predictions: list[pd.DataFrame] = []
|
|
513
|
+
summary: str | None = None
|
|
514
|
+
details: Explanation | None = None
|
|
515
|
+
for start in range(0, task.num_prediction_examples, batch_size):
|
|
383
516
|
context = self._get_context(
|
|
384
|
-
|
|
385
|
-
|
|
386
|
-
anchor_time=anchor_time,
|
|
387
|
-
context_anchor_time=context_anchor_time,
|
|
388
|
-
run_mode=RunMode(run_mode),
|
|
517
|
+
task=task.narrow_prediction(start, length=batch_size),
|
|
518
|
+
run_mode=run_mode,
|
|
389
519
|
num_neighbors=num_neighbors,
|
|
390
|
-
|
|
391
|
-
|
|
392
|
-
evaluate=False,
|
|
393
|
-
random_seed=random_seed,
|
|
394
|
-
logger=logger if i == 0 else None,
|
|
520
|
+
exclude_cols_dict=exclude_cols_dict,
|
|
521
|
+
top_k=top_k,
|
|
395
522
|
)
|
|
523
|
+
context.y_test = None
|
|
524
|
+
|
|
396
525
|
request = RFMPredictRequest(
|
|
397
526
|
context=context,
|
|
398
527
|
run_mode=RunMode(run_mode),
|
|
399
|
-
query=
|
|
528
|
+
query=getattr(task, '_query', ''),
|
|
400
529
|
use_prediction_time=use_prediction_time,
|
|
401
530
|
)
|
|
402
531
|
with warnings.catch_warnings():
|
|
403
532
|
warnings.filterwarnings('ignore', message='gencode')
|
|
404
533
|
request_msg = request.to_protobuf()
|
|
405
534
|
_bytes = request_msg.SerializeToString()
|
|
406
|
-
if
|
|
535
|
+
if start == 0:
|
|
407
536
|
logger.log(f"Generated context of size "
|
|
408
537
|
f"{len(_bytes) / (1024*1024):.2f}MB")
|
|
409
538
|
|
|
@@ -411,12 +540,9 @@ class KumoRFM:
|
|
|
411
540
|
stats = Context.get_memory_stats(request_msg.context)
|
|
412
541
|
raise ValueError(_SIZE_LIMIT_MSG.format(stats=stats))
|
|
413
542
|
|
|
414
|
-
if
|
|
415
|
-
|
|
416
|
-
verbose.init_progress(
|
|
417
|
-
total=len(batches),
|
|
418
|
-
description='Predicting',
|
|
419
|
-
)
|
|
543
|
+
if start == 0 and task.num_prediction_examples > batch_size:
|
|
544
|
+
num = math.ceil(task.num_prediction_examples / batch_size)
|
|
545
|
+
verbose.init_progress(total=num, description='Predicting')
|
|
420
546
|
|
|
421
547
|
for attempt in range(self.num_retries + 1):
|
|
422
548
|
try:
|
|
@@ -433,10 +559,10 @@ class KumoRFM:
|
|
|
433
559
|
|
|
434
560
|
# Cast 'ENTITY' to correct data type:
|
|
435
561
|
if 'ENTITY' in df:
|
|
436
|
-
|
|
437
|
-
|
|
438
|
-
|
|
439
|
-
|
|
562
|
+
table_dict = context.subgraph.table_dict
|
|
563
|
+
table = table_dict[context.entity_table_names[0]]
|
|
564
|
+
ser = table.df[table.primary_key]
|
|
565
|
+
df['ENTITY'] = df['ENTITY'].astype(ser.dtype)
|
|
440
566
|
|
|
441
567
|
# Cast 'ANCHOR_TIMESTAMP' to correct data type:
|
|
442
568
|
if 'ANCHOR_TIMESTAMP' in df:
|
|
@@ -451,8 +577,7 @@ class KumoRFM:
|
|
|
451
577
|
|
|
452
578
|
predictions.append(df)
|
|
453
579
|
|
|
454
|
-
if
|
|
455
|
-
and len(batches) > 1):
|
|
580
|
+
if task.num_prediction_examples > batch_size:
|
|
456
581
|
verbose.step()
|
|
457
582
|
|
|
458
583
|
break
|
|
@@ -490,9 +615,9 @@ class KumoRFM:
|
|
|
490
615
|
def is_valid_entity(
|
|
491
616
|
self,
|
|
492
617
|
query: str,
|
|
493
|
-
indices:
|
|
618
|
+
indices: list[str] | list[float] | list[int] | None = None,
|
|
494
619
|
*,
|
|
495
|
-
anchor_time:
|
|
620
|
+
anchor_time: pd.Timestamp | Literal['entity'] | None = None,
|
|
496
621
|
) -> np.ndarray:
|
|
497
622
|
r"""Returns a mask that denotes which entities are valid for the
|
|
498
623
|
given predictive query, *i.e.*, which entities fulfill (temporal)
|
|
@@ -519,37 +644,32 @@ class KumoRFM:
|
|
|
519
644
|
raise ValueError("At least one entity is required")
|
|
520
645
|
|
|
521
646
|
if anchor_time is None:
|
|
522
|
-
anchor_time = self.
|
|
647
|
+
anchor_time = self._get_default_anchor_time(query_def)
|
|
523
648
|
|
|
524
649
|
if isinstance(anchor_time, pd.Timestamp):
|
|
525
650
|
self._validate_time(query_def, anchor_time, None, False)
|
|
526
651
|
else:
|
|
527
652
|
assert anchor_time == 'entity'
|
|
528
|
-
if
|
|
653
|
+
if query_def.entity_table not in self._sampler.time_column_dict:
|
|
529
654
|
raise ValueError(f"Anchor time 'entity' requires the entity "
|
|
530
655
|
f"table '{query_def.entity_table}' "
|
|
531
656
|
f"to have a time column.")
|
|
532
657
|
|
|
533
|
-
|
|
534
|
-
table_name=query_def.entity_table,
|
|
535
|
-
pkey=pd.Series(indices),
|
|
536
|
-
)
|
|
537
|
-
query_driver = LocalPQueryDriver(self._graph_store, query_def)
|
|
538
|
-
return query_driver.is_valid(node, anchor_time)
|
|
658
|
+
raise NotImplementedError
|
|
539
659
|
|
|
540
660
|
def evaluate(
|
|
541
661
|
self,
|
|
542
662
|
query: str,
|
|
543
663
|
*,
|
|
544
|
-
metrics:
|
|
545
|
-
anchor_time:
|
|
546
|
-
context_anchor_time:
|
|
547
|
-
run_mode:
|
|
548
|
-
num_neighbors:
|
|
664
|
+
metrics: list[str] | None = None,
|
|
665
|
+
anchor_time: pd.Timestamp | Literal['entity'] | None = None,
|
|
666
|
+
context_anchor_time: pd.Timestamp | None = None,
|
|
667
|
+
run_mode: RunMode | str = RunMode.FAST,
|
|
668
|
+
num_neighbors: list[int] | None = None,
|
|
549
669
|
num_hops: int = 2,
|
|
550
|
-
max_pq_iterations: int =
|
|
551
|
-
random_seed:
|
|
552
|
-
verbose:
|
|
670
|
+
max_pq_iterations: int = 10,
|
|
671
|
+
random_seed: int | None = _RANDOM_SEED,
|
|
672
|
+
verbose: bool | ProgressLogger = True,
|
|
553
673
|
use_prediction_time: bool = False,
|
|
554
674
|
) -> pd.DataFrame:
|
|
555
675
|
"""Evaluates a predictive query.
|
|
@@ -582,40 +702,51 @@ class KumoRFM:
|
|
|
582
702
|
The metrics as a :class:`pandas.DataFrame`
|
|
583
703
|
"""
|
|
584
704
|
query_def = self._parse_query(query)
|
|
585
|
-
|
|
586
|
-
if num_hops != 2 and num_neighbors is not None:
|
|
587
|
-
warnings.warn(f"Received custom 'num_neighbors' option; ignoring "
|
|
588
|
-
f"custom 'num_hops={num_hops}' option")
|
|
589
|
-
|
|
590
705
|
if query_def.rfm_entity_ids is not None:
|
|
591
706
|
query_def = replace(
|
|
592
707
|
query_def,
|
|
593
708
|
rfm_entity_ids=None,
|
|
594
709
|
)
|
|
595
710
|
|
|
596
|
-
|
|
597
|
-
|
|
711
|
+
task_type = self._get_task_type(
|
|
712
|
+
query=query_def,
|
|
713
|
+
edge_types=self._sampler.edge_types,
|
|
714
|
+
)
|
|
715
|
+
|
|
716
|
+
if num_hops != 2 and num_neighbors is not None:
|
|
717
|
+
warnings.warn(f"Received custom 'num_neighbors' option; ignoring "
|
|
718
|
+
f"custom 'num_hops={num_hops}' option")
|
|
719
|
+
if num_neighbors is None:
|
|
720
|
+
key = RunMode.FAST if task_type.is_link_pred else run_mode
|
|
721
|
+
num_neighbors = _DEFAULT_NUM_NEIGHBORS[key][:num_hops]
|
|
722
|
+
|
|
723
|
+
if metrics is not None and len(metrics) > 0:
|
|
724
|
+
self._validate_metrics(metrics, task_type)
|
|
725
|
+
metrics = list(dict.fromkeys(metrics))
|
|
598
726
|
|
|
599
727
|
if not isinstance(verbose, ProgressLogger):
|
|
600
|
-
|
|
728
|
+
query_repr = query_def.to_string(rich=True, exclude_predict=True)
|
|
729
|
+
msg = f'[bold]EVALUATE[/bold] {query_repr}'
|
|
730
|
+
verbose = ProgressLogger.default(msg=msg, verbose=verbose)
|
|
601
731
|
|
|
602
732
|
with verbose as logger:
|
|
603
|
-
|
|
733
|
+
task_table = self._get_task_table(
|
|
604
734
|
query=query_def,
|
|
605
735
|
indices=None,
|
|
606
736
|
anchor_time=anchor_time,
|
|
607
737
|
context_anchor_time=context_anchor_time,
|
|
608
|
-
run_mode=
|
|
609
|
-
num_neighbors=num_neighbors,
|
|
610
|
-
num_hops=num_hops,
|
|
738
|
+
run_mode=run_mode,
|
|
611
739
|
max_pq_iterations=max_pq_iterations,
|
|
612
|
-
evaluate=True,
|
|
613
740
|
random_seed=random_seed,
|
|
614
|
-
logger=logger
|
|
741
|
+
logger=logger,
|
|
742
|
+
)
|
|
743
|
+
context = self._get_context(
|
|
744
|
+
task=task_table,
|
|
745
|
+
run_mode=run_mode,
|
|
746
|
+
num_neighbors=num_neighbors,
|
|
747
|
+
exclude_cols_dict=query_def.get_exclude_cols_dict(),
|
|
615
748
|
)
|
|
616
|
-
|
|
617
|
-
self._validate_metrics(metrics, context.task_type)
|
|
618
|
-
metrics = list(dict.fromkeys(metrics))
|
|
749
|
+
|
|
619
750
|
request = RFMEvaluateRequest(
|
|
620
751
|
context=context,
|
|
621
752
|
run_mode=RunMode(run_mode),
|
|
@@ -633,17 +764,23 @@ class KumoRFM:
|
|
|
633
764
|
stats_msg = Context.get_memory_stats(request_msg.context)
|
|
634
765
|
raise ValueError(_SIZE_LIMIT_MSG.format(stats=stats_msg))
|
|
635
766
|
|
|
636
|
-
|
|
637
|
-
resp = self._api_client.evaluate(request_bytes)
|
|
638
|
-
except HTTPException as e:
|
|
767
|
+
for attempt in range(self.num_retries + 1):
|
|
639
768
|
try:
|
|
640
|
-
|
|
641
|
-
|
|
642
|
-
|
|
643
|
-
|
|
644
|
-
|
|
645
|
-
|
|
646
|
-
|
|
769
|
+
resp = self._api_client.evaluate(request_bytes)
|
|
770
|
+
break
|
|
771
|
+
except HTTPException as e:
|
|
772
|
+
if attempt == self.num_retries:
|
|
773
|
+
try:
|
|
774
|
+
msg = json.loads(e.detail)['detail']
|
|
775
|
+
except Exception:
|
|
776
|
+
msg = e.detail
|
|
777
|
+
raise RuntimeError(
|
|
778
|
+
f"An unexpected exception occurred. Please create "
|
|
779
|
+
f"an issue at "
|
|
780
|
+
f"'https://github.com/kumo-ai/kumo-rfm'. {msg}"
|
|
781
|
+
) from None
|
|
782
|
+
|
|
783
|
+
time.sleep(2**attempt) # 1s, 2s, 4s, 8s, ...
|
|
647
784
|
|
|
648
785
|
return pd.DataFrame.from_dict(
|
|
649
786
|
resp.metrics,
|
|
@@ -656,9 +793,9 @@ class KumoRFM:
|
|
|
656
793
|
query: str,
|
|
657
794
|
size: int,
|
|
658
795
|
*,
|
|
659
|
-
anchor_time:
|
|
660
|
-
random_seed:
|
|
661
|
-
max_iterations: int =
|
|
796
|
+
anchor_time: pd.Timestamp | Literal['entity'] | None = None,
|
|
797
|
+
random_seed: int | None = _RANDOM_SEED,
|
|
798
|
+
max_iterations: int = 10,
|
|
662
799
|
) -> pd.DataFrame:
|
|
663
800
|
"""Returns the labels of a predictive query for a specified anchor
|
|
664
801
|
time.
|
|
@@ -678,40 +815,37 @@ class KumoRFM:
|
|
|
678
815
|
query_def = self._parse_query(query)
|
|
679
816
|
|
|
680
817
|
if anchor_time is None:
|
|
681
|
-
anchor_time = self.
|
|
818
|
+
anchor_time = self._get_default_anchor_time(query_def)
|
|
682
819
|
if query_def.target_ast.date_offset_range is not None:
|
|
683
|
-
|
|
684
|
-
|
|
685
|
-
|
|
820
|
+
offset = query_def.target_ast.date_offset_range.end_date_offset
|
|
821
|
+
offset *= query_def.num_forecasts
|
|
822
|
+
anchor_time -= offset
|
|
686
823
|
|
|
687
824
|
assert anchor_time is not None
|
|
688
825
|
if isinstance(anchor_time, pd.Timestamp):
|
|
689
826
|
self._validate_time(query_def, anchor_time, None, evaluate=True)
|
|
690
827
|
else:
|
|
691
828
|
assert anchor_time == 'entity'
|
|
692
|
-
if
|
|
829
|
+
if query_def.entity_table not in self._sampler.time_column_dict:
|
|
693
830
|
raise ValueError(f"Anchor time 'entity' requires the entity "
|
|
694
831
|
f"table '{query_def.entity_table}' "
|
|
695
832
|
f"to have a time column")
|
|
696
833
|
|
|
697
|
-
|
|
698
|
-
|
|
699
|
-
|
|
700
|
-
|
|
701
|
-
|
|
702
|
-
|
|
703
|
-
|
|
704
|
-
|
|
705
|
-
|
|
834
|
+
train, test = self._sampler.sample_target(
|
|
835
|
+
query=query_def,
|
|
836
|
+
num_train_examples=0,
|
|
837
|
+
train_anchor_time=anchor_time,
|
|
838
|
+
num_train_trials=0,
|
|
839
|
+
num_test_examples=size,
|
|
840
|
+
test_anchor_time=anchor_time,
|
|
841
|
+
num_test_trials=max_iterations * size,
|
|
842
|
+
random_seed=random_seed,
|
|
706
843
|
)
|
|
707
844
|
|
|
708
|
-
entity = self._graph_store.pkey_map_dict[
|
|
709
|
-
query_def.entity_table].index[node]
|
|
710
|
-
|
|
711
845
|
return pd.DataFrame({
|
|
712
|
-
'ENTITY':
|
|
713
|
-
'ANCHOR_TIMESTAMP':
|
|
714
|
-
'TARGET':
|
|
846
|
+
'ENTITY': test.entity_pkey,
|
|
847
|
+
'ANCHOR_TIMESTAMP': test.anchor_time,
|
|
848
|
+
'TARGET': test.target,
|
|
715
849
|
})
|
|
716
850
|
|
|
717
851
|
# Helpers #################################################################
|
|
@@ -726,63 +860,120 @@ class KumoRFM:
|
|
|
726
860
|
"`predict()` or `evaluate()` methods to perform "
|
|
727
861
|
"predictions or evaluations.")
|
|
728
862
|
|
|
729
|
-
|
|
730
|
-
|
|
731
|
-
|
|
732
|
-
|
|
733
|
-
)
|
|
863
|
+
request = RFMParseQueryRequest(
|
|
864
|
+
query=query,
|
|
865
|
+
graph_definition=self._graph_def,
|
|
866
|
+
)
|
|
734
867
|
|
|
735
|
-
|
|
868
|
+
for attempt in range(self.num_retries + 1):
|
|
869
|
+
try:
|
|
870
|
+
resp = self._api_client.parse_query(request)
|
|
871
|
+
break
|
|
872
|
+
except HTTPException as e:
|
|
873
|
+
if attempt == self.num_retries:
|
|
874
|
+
try:
|
|
875
|
+
msg = json.loads(e.detail)['detail']
|
|
876
|
+
except Exception:
|
|
877
|
+
msg = e.detail
|
|
878
|
+
raise ValueError(f"Failed to parse query '{query}'. {msg}")
|
|
736
879
|
|
|
737
|
-
|
|
880
|
+
time.sleep(2**attempt) # 1s, 2s, 4s, 8s, ...
|
|
738
881
|
|
|
739
|
-
|
|
740
|
-
|
|
741
|
-
|
|
742
|
-
|
|
743
|
-
|
|
744
|
-
|
|
745
|
-
|
|
882
|
+
if len(resp.validation_response.warnings) > 0:
|
|
883
|
+
msg = '\n'.join([
|
|
884
|
+
f'{i+1}. {warning.title}: {warning.message}'
|
|
885
|
+
for i, warning in enumerate(resp.validation_response.warnings)
|
|
886
|
+
])
|
|
887
|
+
warnings.warn(f"Encountered the following warnings during "
|
|
888
|
+
f"parsing:\n{msg}")
|
|
746
889
|
|
|
747
|
-
|
|
748
|
-
|
|
749
|
-
|
|
750
|
-
|
|
751
|
-
|
|
752
|
-
|
|
753
|
-
|
|
754
|
-
|
|
890
|
+
return resp.query
|
|
891
|
+
|
|
892
|
+
@staticmethod
|
|
893
|
+
def _get_task_type(
|
|
894
|
+
query: ValidatedPredictiveQuery,
|
|
895
|
+
edge_types: list[tuple[str, str, str]],
|
|
896
|
+
) -> TaskType:
|
|
897
|
+
if isinstance(query.target_ast, (Condition, LogicalOperation)):
|
|
898
|
+
return TaskType.BINARY_CLASSIFICATION
|
|
899
|
+
|
|
900
|
+
target = query.target_ast
|
|
901
|
+
if isinstance(target, Join):
|
|
902
|
+
target = target.rhs_target
|
|
903
|
+
if isinstance(target, Aggregation):
|
|
904
|
+
if target.aggr == AggregationType.LIST_DISTINCT:
|
|
905
|
+
table_name, col_name = target._get_target_column_name().split(
|
|
906
|
+
'.')
|
|
907
|
+
target_edge_types = [
|
|
908
|
+
edge_type for edge_type in edge_types
|
|
909
|
+
if edge_type[0] == table_name and edge_type[1] == col_name
|
|
910
|
+
]
|
|
911
|
+
if len(target_edge_types) != 1:
|
|
912
|
+
raise NotImplementedError(
|
|
913
|
+
f"Multilabel-classification queries based on "
|
|
914
|
+
f"'LIST_DISTINCT' are not supported yet. If you "
|
|
915
|
+
f"planned to write a link prediction query instead, "
|
|
916
|
+
f"make sure to register '{col_name}' as a "
|
|
917
|
+
f"foreign key.")
|
|
918
|
+
return TaskType.TEMPORAL_LINK_PREDICTION
|
|
919
|
+
|
|
920
|
+
return TaskType.REGRESSION
|
|
921
|
+
|
|
922
|
+
assert isinstance(target, Column)
|
|
923
|
+
|
|
924
|
+
if target.stype in {Stype.ID, Stype.categorical}:
|
|
925
|
+
return TaskType.MULTICLASS_CLASSIFICATION
|
|
926
|
+
|
|
927
|
+
if target.stype in {Stype.numerical}:
|
|
928
|
+
return TaskType.REGRESSION
|
|
929
|
+
|
|
930
|
+
raise NotImplementedError("Task type not yet supported")
|
|
931
|
+
|
|
932
|
+
def _get_default_anchor_time(
|
|
933
|
+
self,
|
|
934
|
+
query: ValidatedPredictiveQuery | None = None,
|
|
935
|
+
) -> pd.Timestamp:
|
|
936
|
+
if query is not None and query.query_type == QueryType.TEMPORAL:
|
|
937
|
+
aggr_table_names = [
|
|
938
|
+
aggr._get_target_column_name().split('.')[0]
|
|
939
|
+
for aggr in query.get_all_target_aggregations()
|
|
940
|
+
]
|
|
941
|
+
return self._sampler.get_max_time(aggr_table_names)
|
|
942
|
+
|
|
943
|
+
return self._sampler.get_max_time()
|
|
755
944
|
|
|
756
945
|
def _validate_time(
|
|
757
946
|
self,
|
|
758
947
|
query: ValidatedPredictiveQuery,
|
|
759
948
|
anchor_time: pd.Timestamp,
|
|
760
|
-
context_anchor_time:
|
|
949
|
+
context_anchor_time: pd.Timestamp | None,
|
|
761
950
|
evaluate: bool,
|
|
762
951
|
) -> None:
|
|
763
952
|
|
|
764
|
-
if self.
|
|
953
|
+
if len(self._sampler.time_column_dict) == 0:
|
|
765
954
|
return # Graph without timestamps
|
|
766
955
|
|
|
767
|
-
|
|
956
|
+
min_time = self._sampler.get_min_time()
|
|
957
|
+
max_time = self._sampler.get_max_time()
|
|
958
|
+
|
|
959
|
+
if anchor_time < min_time:
|
|
768
960
|
raise ValueError(f"Anchor timestamp '{anchor_time}' is before "
|
|
769
|
-
f"the earliest timestamp "
|
|
770
|
-
f"
|
|
961
|
+
f"the earliest timestamp '{min_time}' in the "
|
|
962
|
+
f"data.")
|
|
771
963
|
|
|
772
|
-
if
|
|
773
|
-
and context_anchor_time < self._graph_store.min_time):
|
|
964
|
+
if context_anchor_time is not None and context_anchor_time < min_time:
|
|
774
965
|
raise ValueError(f"Context anchor timestamp is too early or "
|
|
775
966
|
f"aggregation time range is too large. To make "
|
|
776
967
|
f"this prediction, we would need data back to "
|
|
777
968
|
f"'{context_anchor_time}', however, your data "
|
|
778
|
-
f"only contains data back to "
|
|
779
|
-
f"'{self._graph_store.min_time}'.")
|
|
969
|
+
f"only contains data back to '{min_time}'.")
|
|
780
970
|
|
|
781
971
|
if query.target_ast.date_offset_range is not None:
|
|
782
972
|
end_offset = query.target_ast.date_offset_range.end_date_offset
|
|
783
973
|
else:
|
|
784
974
|
end_offset = pd.DateOffset(0)
|
|
785
|
-
|
|
975
|
+
end_offset = end_offset * query.num_forecasts
|
|
976
|
+
|
|
786
977
|
if (context_anchor_time is not None
|
|
787
978
|
and context_anchor_time > anchor_time):
|
|
788
979
|
warnings.warn(f"Context anchor timestamp "
|
|
@@ -792,7 +983,7 @@ class KumoRFM:
|
|
|
792
983
|
f"intended.")
|
|
793
984
|
elif (query.query_type == QueryType.TEMPORAL
|
|
794
985
|
and context_anchor_time is not None
|
|
795
|
-
and context_anchor_time +
|
|
986
|
+
and context_anchor_time + end_offset > anchor_time):
|
|
796
987
|
warnings.warn(f"Aggregation for context examples at timestamp "
|
|
797
988
|
f"'{context_anchor_time}' will leak information "
|
|
798
989
|
f"from the prediction anchor timestamp "
|
|
@@ -800,62 +991,44 @@ class KumoRFM:
|
|
|
800
991
|
f"intended.")
|
|
801
992
|
|
|
802
993
|
elif (context_anchor_time is not None
|
|
803
|
-
and context_anchor_time -
|
|
804
|
-
|
|
805
|
-
_time = context_anchor_time - forecast_end_offset
|
|
994
|
+
and context_anchor_time - end_offset < min_time):
|
|
995
|
+
_time = context_anchor_time - end_offset
|
|
806
996
|
warnings.warn(f"Context anchor timestamp is too early or "
|
|
807
997
|
f"aggregation time range is too large. To form "
|
|
808
998
|
f"proper input data, we would need data back to "
|
|
809
999
|
f"'{_time}', however, your data only contains "
|
|
810
|
-
f"data back to '{
|
|
1000
|
+
f"data back to '{min_time}'.")
|
|
811
1001
|
|
|
812
|
-
if
|
|
813
|
-
> self._graph_store.max_time + pd.DateOffset(days=1)):
|
|
1002
|
+
if not evaluate and anchor_time > max_time + pd.DateOffset(days=1):
|
|
814
1003
|
warnings.warn(f"Anchor timestamp '{anchor_time}' is after the "
|
|
815
|
-
f"latest timestamp '{
|
|
816
|
-
f"
|
|
1004
|
+
f"latest timestamp '{max_time}' in the data. Please "
|
|
1005
|
+
f"make sure this is intended.")
|
|
817
1006
|
|
|
818
|
-
|
|
819
|
-
if evaluate and anchor_time > max_eval_time:
|
|
1007
|
+
if evaluate and anchor_time > max_time - end_offset:
|
|
820
1008
|
raise ValueError(
|
|
821
1009
|
f"Anchor timestamp for evaluation is after the latest "
|
|
822
|
-
f"supported timestamp '{
|
|
1010
|
+
f"supported timestamp '{max_time - end_offset}'.")
|
|
823
1011
|
|
|
824
|
-
def
|
|
1012
|
+
def _get_task_table(
|
|
825
1013
|
self,
|
|
826
1014
|
query: ValidatedPredictiveQuery,
|
|
827
|
-
indices:
|
|
828
|
-
anchor_time:
|
|
829
|
-
context_anchor_time:
|
|
830
|
-
run_mode: RunMode,
|
|
831
|
-
|
|
832
|
-
|
|
833
|
-
|
|
834
|
-
|
|
835
|
-
|
|
836
|
-
|
|
837
|
-
|
|
838
|
-
|
|
839
|
-
if num_neighbors is not None:
|
|
840
|
-
num_hops = len(num_neighbors)
|
|
841
|
-
|
|
842
|
-
if num_hops < 0:
|
|
843
|
-
raise ValueError(f"'num_hops' must be non-negative "
|
|
844
|
-
f"(got {num_hops})")
|
|
845
|
-
if num_hops > 6:
|
|
846
|
-
raise ValueError(f"Cannot predict on subgraphs with more than 6 "
|
|
847
|
-
f"hops (got {num_hops}). Please reduce the "
|
|
848
|
-
f"number of hops and try again. Please create a "
|
|
849
|
-
f"feature request at "
|
|
850
|
-
f"'https://github.com/kumo-ai/kumo-rfm' if you "
|
|
851
|
-
f"must go beyond this for your use-case.")
|
|
852
|
-
|
|
853
|
-
query_driver = LocalPQueryDriver(self._graph_store, query, random_seed)
|
|
854
|
-
task_type = LocalPQueryDriver.get_task_type(
|
|
855
|
-
query,
|
|
856
|
-
edge_types=self._graph_store.edge_types,
|
|
1015
|
+
indices: list[str] | list[float] | list[int] | None,
|
|
1016
|
+
anchor_time: pd.Timestamp | Literal['entity'] | None = None,
|
|
1017
|
+
context_anchor_time: pd.Timestamp | None = None,
|
|
1018
|
+
run_mode: RunMode = RunMode.FAST,
|
|
1019
|
+
max_pq_iterations: int = 10,
|
|
1020
|
+
random_seed: int | None = _RANDOM_SEED,
|
|
1021
|
+
logger: ProgressLogger | None = None,
|
|
1022
|
+
) -> TaskTable:
|
|
1023
|
+
|
|
1024
|
+
task_type = self._get_task_type(
|
|
1025
|
+
query=query,
|
|
1026
|
+
edge_types=self._sampler.edge_types,
|
|
857
1027
|
)
|
|
858
1028
|
|
|
1029
|
+
num_train_examples = _MAX_CONTEXT_SIZE[run_mode]
|
|
1030
|
+
num_test_examples = _MAX_TEST_SIZE[task_type] if indices is None else 0
|
|
1031
|
+
|
|
859
1032
|
if logger is not None:
|
|
860
1033
|
if task_type == TaskType.BINARY_CLASSIFICATION:
|
|
861
1034
|
task_type_repr = 'binary classification'
|
|
@@ -869,30 +1042,17 @@ class KumoRFM:
|
|
|
869
1042
|
task_type_repr = str(task_type)
|
|
870
1043
|
logger.log(f"Identified {query.query_type} {task_type_repr} task")
|
|
871
1044
|
|
|
872
|
-
if task_type.is_link_pred and num_hops < 2:
|
|
873
|
-
raise ValueError(f"Cannot perform link prediction on subgraphs "
|
|
874
|
-
f"with less than 2 hops (got {num_hops}) since "
|
|
875
|
-
f"historical target entities need to be part of "
|
|
876
|
-
f"the context. Please increase the number of "
|
|
877
|
-
f"hops and try again.")
|
|
878
|
-
|
|
879
|
-
if num_neighbors is None:
|
|
880
|
-
if run_mode == RunMode.DEBUG:
|
|
881
|
-
num_neighbors = [16, 16, 4, 4, 1, 1][:num_hops]
|
|
882
|
-
elif run_mode == RunMode.FAST or task_type.is_link_pred:
|
|
883
|
-
num_neighbors = [32, 32, 8, 8, 4, 4][:num_hops]
|
|
884
|
-
else:
|
|
885
|
-
num_neighbors = [64, 64, 8, 8, 4, 4][:num_hops]
|
|
886
|
-
|
|
887
1045
|
if query.target_ast.date_offset_range is None:
|
|
888
|
-
|
|
1046
|
+
step_offset = pd.DateOffset(0)
|
|
889
1047
|
else:
|
|
890
|
-
|
|
891
|
-
|
|
1048
|
+
step_offset = query.target_ast.date_offset_range.end_date_offset
|
|
1049
|
+
end_offset = step_offset * query.num_forecasts
|
|
1050
|
+
|
|
892
1051
|
if anchor_time is None:
|
|
893
|
-
anchor_time = self.
|
|
894
|
-
if
|
|
895
|
-
anchor_time = anchor_time -
|
|
1052
|
+
anchor_time = self._get_default_anchor_time(query)
|
|
1053
|
+
if num_test_examples > 0:
|
|
1054
|
+
anchor_time = anchor_time - end_offset
|
|
1055
|
+
|
|
896
1056
|
if logger is not None:
|
|
897
1057
|
assert isinstance(anchor_time, pd.Timestamp)
|
|
898
1058
|
if anchor_time == pd.Timestamp.min:
|
|
@@ -904,114 +1064,98 @@ class KumoRFM:
|
|
|
904
1064
|
else:
|
|
905
1065
|
logger.log(f"Derived anchor time {anchor_time}")
|
|
906
1066
|
|
|
907
|
-
assert anchor_time is not None
|
|
908
1067
|
if isinstance(anchor_time, pd.Timestamp):
|
|
1068
|
+
if context_anchor_time == 'entity':
|
|
1069
|
+
raise ValueError("Anchor time 'entity' needs to be shared "
|
|
1070
|
+
"for context and prediction examples")
|
|
909
1071
|
if context_anchor_time is None:
|
|
910
|
-
context_anchor_time = anchor_time -
|
|
1072
|
+
context_anchor_time = anchor_time - end_offset
|
|
911
1073
|
self._validate_time(query, anchor_time, context_anchor_time,
|
|
912
|
-
evaluate)
|
|
1074
|
+
evaluate=num_test_examples > 0)
|
|
913
1075
|
else:
|
|
914
1076
|
assert anchor_time == 'entity'
|
|
915
|
-
if query.
|
|
1077
|
+
if query.query_type != QueryType.STATIC:
|
|
1078
|
+
raise ValueError("Anchor time 'entity' is only valid for "
|
|
1079
|
+
"static predictive queries")
|
|
1080
|
+
if query.entity_table not in self._sampler.time_column_dict:
|
|
916
1081
|
raise ValueError(f"Anchor time 'entity' requires the entity "
|
|
917
1082
|
f"table '{query.entity_table}' to "
|
|
918
1083
|
f"have a time column")
|
|
919
|
-
if context_anchor_time
|
|
920
|
-
|
|
921
|
-
|
|
922
|
-
context_anchor_time =
|
|
923
|
-
|
|
924
|
-
|
|
925
|
-
|
|
926
|
-
|
|
927
|
-
|
|
928
|
-
|
|
929
|
-
|
|
930
|
-
|
|
931
|
-
|
|
932
|
-
|
|
933
|
-
|
|
934
|
-
|
|
935
|
-
|
|
936
|
-
if logger is not None:
|
|
937
|
-
if task_type == TaskType.BINARY_CLASSIFICATION:
|
|
938
|
-
pos = 100 * int((y_test > 0).sum()) / len(y_test)
|
|
939
|
-
msg = (f"Collected {len(y_test):,} test examples with "
|
|
940
|
-
f"{pos:.2f}% positive cases")
|
|
941
|
-
elif task_type == TaskType.MULTICLASS_CLASSIFICATION:
|
|
942
|
-
msg = (f"Collected {len(y_test):,} test examples "
|
|
943
|
-
f"holding {y_test.nunique()} classes")
|
|
944
|
-
elif task_type == TaskType.REGRESSION:
|
|
945
|
-
_min, _max = float(y_test.min()), float(y_test.max())
|
|
946
|
-
msg = (f"Collected {len(y_test):,} test examples with "
|
|
947
|
-
f"targets between {format_value(_min)} and "
|
|
948
|
-
f"{format_value(_max)}")
|
|
949
|
-
elif task_type == TaskType.TEMPORAL_LINK_PREDICTION:
|
|
950
|
-
num_rhs = y_test.explode().nunique()
|
|
951
|
-
msg = (f"Collected {len(y_test):,} test examples with "
|
|
952
|
-
f"{num_rhs:,} unique items")
|
|
953
|
-
else:
|
|
954
|
-
raise NotImplementedError
|
|
955
|
-
logger.log(msg)
|
|
956
|
-
|
|
957
|
-
else:
|
|
958
|
-
assert indices is not None
|
|
959
|
-
|
|
960
|
-
if len(indices) > _MAX_PRED_SIZE[task_type]:
|
|
961
|
-
raise ValueError(f"Cannot predict for more than "
|
|
962
|
-
f"{_MAX_PRED_SIZE[task_type]:,} entities at "
|
|
963
|
-
f"once (got {len(indices):,}). Use "
|
|
964
|
-
f"`KumoRFM.batch_mode` to process entities "
|
|
965
|
-
f"in batches")
|
|
1084
|
+
if isinstance(context_anchor_time, pd.Timestamp):
|
|
1085
|
+
raise ValueError("Anchor time 'entity' needs to be shared "
|
|
1086
|
+
"for context and prediction examples")
|
|
1087
|
+
context_anchor_time = 'entity'
|
|
1088
|
+
|
|
1089
|
+
train, test = self._sampler.sample_target(
|
|
1090
|
+
query=query,
|
|
1091
|
+
num_train_examples=num_train_examples,
|
|
1092
|
+
train_anchor_time=context_anchor_time,
|
|
1093
|
+
num_train_trials=max_pq_iterations * num_train_examples,
|
|
1094
|
+
num_test_examples=num_test_examples,
|
|
1095
|
+
test_anchor_time=anchor_time,
|
|
1096
|
+
num_test_trials=max_pq_iterations * num_test_examples,
|
|
1097
|
+
random_seed=random_seed,
|
|
1098
|
+
)
|
|
1099
|
+
train_pkey, train_time, train_y = train
|
|
1100
|
+
test_pkey, test_time, test_y = test
|
|
966
1101
|
|
|
967
|
-
|
|
968
|
-
|
|
969
|
-
|
|
970
|
-
|
|
1102
|
+
if num_test_examples > 0 and logger is not None:
|
|
1103
|
+
if task_type == TaskType.BINARY_CLASSIFICATION:
|
|
1104
|
+
pos = 100 * int((test_y > 0).sum()) / len(test_y)
|
|
1105
|
+
msg = (f"Collected {len(test_y):,} test examples with "
|
|
1106
|
+
f"{pos:.2f}% positive cases")
|
|
1107
|
+
elif task_type == TaskType.MULTICLASS_CLASSIFICATION:
|
|
1108
|
+
msg = (f"Collected {len(test_y):,} test examples holding "
|
|
1109
|
+
f"{test_y.nunique()} classes")
|
|
1110
|
+
elif task_type == TaskType.REGRESSION:
|
|
1111
|
+
_min, _max = float(test_y.min()), float(test_y.max())
|
|
1112
|
+
msg = (f"Collected {len(test_y):,} test examples with targets "
|
|
1113
|
+
f"between {format_value(_min)} and "
|
|
1114
|
+
f"{format_value(_max)}")
|
|
1115
|
+
elif task_type == TaskType.TEMPORAL_LINK_PREDICTION:
|
|
1116
|
+
num_rhs = test_y.explode().nunique()
|
|
1117
|
+
msg = (f"Collected {len(test_y):,} test examples with "
|
|
1118
|
+
f"{num_rhs:,} unique items")
|
|
1119
|
+
else:
|
|
1120
|
+
raise NotImplementedError
|
|
1121
|
+
logger.log(msg)
|
|
971
1122
|
|
|
1123
|
+
if num_test_examples == 0:
|
|
1124
|
+
assert indices is not None
|
|
1125
|
+
test_pkey = pd.Series(indices, dtype=train_pkey.dtype)
|
|
972
1126
|
if isinstance(anchor_time, pd.Timestamp):
|
|
973
|
-
test_time = pd.Series(anchor_time).repeat(
|
|
974
|
-
len(
|
|
1127
|
+
test_time = pd.Series([anchor_time]).repeat(
|
|
1128
|
+
len(indices)).reset_index(drop=True)
|
|
975
1129
|
else:
|
|
976
|
-
|
|
977
|
-
time = time[test_node] * 1000**3
|
|
978
|
-
test_time = pd.Series(time, dtype='datetime64[ns]')
|
|
979
|
-
|
|
980
|
-
train_node, train_time, y_train = query_driver.collect_train(
|
|
981
|
-
size=_MAX_CONTEXT_SIZE[run_mode],
|
|
982
|
-
anchor_time=context_anchor_time or 'entity',
|
|
983
|
-
exclude_node=test_node if (query.query_type == QueryType.STATIC
|
|
984
|
-
or anchor_time == 'entity') else None,
|
|
985
|
-
max_iterations=max_pq_iterations,
|
|
986
|
-
)
|
|
1130
|
+
train_time = test_time = 'entity'
|
|
987
1131
|
|
|
988
1132
|
if logger is not None:
|
|
989
1133
|
if task_type == TaskType.BINARY_CLASSIFICATION:
|
|
990
|
-
pos = 100 * int((
|
|
991
|
-
msg = (f"Collected {len(
|
|
1134
|
+
pos = 100 * int((train_y > 0).sum()) / len(train_y)
|
|
1135
|
+
msg = (f"Collected {len(train_y):,} in-context examples with "
|
|
992
1136
|
f"{pos:.2f}% positive cases")
|
|
993
1137
|
elif task_type == TaskType.MULTICLASS_CLASSIFICATION:
|
|
994
|
-
msg = (f"Collected {len(
|
|
995
|
-
f"holding {
|
|
1138
|
+
msg = (f"Collected {len(train_y):,} in-context examples "
|
|
1139
|
+
f"holding {train_y.nunique()} classes")
|
|
996
1140
|
elif task_type == TaskType.REGRESSION:
|
|
997
|
-
_min, _max = float(
|
|
998
|
-
msg = (f"Collected {len(
|
|
1141
|
+
_min, _max = float(train_y.min()), float(train_y.max())
|
|
1142
|
+
msg = (f"Collected {len(train_y):,} in-context examples with "
|
|
999
1143
|
f"targets between {format_value(_min)} and "
|
|
1000
1144
|
f"{format_value(_max)}")
|
|
1001
1145
|
elif task_type == TaskType.TEMPORAL_LINK_PREDICTION:
|
|
1002
|
-
num_rhs =
|
|
1003
|
-
msg = (f"Collected {len(
|
|
1146
|
+
num_rhs = train_y.explode().nunique()
|
|
1147
|
+
msg = (f"Collected {len(train_y):,} in-context examples with "
|
|
1004
1148
|
f"{num_rhs:,} unique items")
|
|
1005
1149
|
else:
|
|
1006
1150
|
raise NotImplementedError
|
|
1007
1151
|
logger.log(msg)
|
|
1008
1152
|
|
|
1009
|
-
entity_table_names:
|
|
1153
|
+
entity_table_names: tuple[str] | tuple[str, str]
|
|
1010
1154
|
if task_type.is_link_pred:
|
|
1011
1155
|
final_aggr = query.get_final_target_aggregation()
|
|
1012
1156
|
assert final_aggr is not None
|
|
1013
1157
|
edge_fkey = final_aggr._get_target_column_name()
|
|
1014
|
-
for edge_type in self.
|
|
1158
|
+
for edge_type in self._sampler.edge_types:
|
|
1015
1159
|
if edge_fkey == f'{edge_type[0]}.{edge_type[1]}':
|
|
1016
1160
|
entity_table_names = (
|
|
1017
1161
|
query.entity_table,
|
|
@@ -1020,23 +1164,98 @@ class KumoRFM:
|
|
|
1020
1164
|
else:
|
|
1021
1165
|
entity_table_names = (query.entity_table, )
|
|
1022
1166
|
|
|
1167
|
+
context_df = pd.DataFrame({'ENTITY': train_pkey, 'TARGET': train_y})
|
|
1168
|
+
if isinstance(train_time, pd.Series):
|
|
1169
|
+
context_df['ANCHOR_TIMESTAMP'] = train_time
|
|
1170
|
+
pred_df = pd.DataFrame({'ENTITY': test_pkey})
|
|
1171
|
+
if num_test_examples > 0:
|
|
1172
|
+
pred_df['TARGET'] = test_y
|
|
1173
|
+
if isinstance(test_time, pd.Series):
|
|
1174
|
+
pred_df['ANCHOR_TIMESTAMP'] = test_time
|
|
1175
|
+
|
|
1176
|
+
return TaskTable(
|
|
1177
|
+
task_type=task_type,
|
|
1178
|
+
context_df=context_df,
|
|
1179
|
+
pred_df=pred_df,
|
|
1180
|
+
entity_table_name=entity_table_names,
|
|
1181
|
+
entity_column='ENTITY',
|
|
1182
|
+
target_column='TARGET',
|
|
1183
|
+
time_column='ANCHOR_TIMESTAMP' if isinstance(
|
|
1184
|
+
train_time, pd.Series) else TaskTable.ENTITY_TIME,
|
|
1185
|
+
)
|
|
1186
|
+
|
|
1187
|
+
def _get_context(
|
|
1188
|
+
self,
|
|
1189
|
+
task: TaskTable,
|
|
1190
|
+
run_mode: RunMode | str = RunMode.FAST,
|
|
1191
|
+
num_neighbors: list[int] | None = None,
|
|
1192
|
+
exclude_cols_dict: dict[str, list[str]] | None = None,
|
|
1193
|
+
top_k: int | None = None,
|
|
1194
|
+
) -> Context:
|
|
1195
|
+
|
|
1196
|
+
# TODO Remove all
|
|
1197
|
+
if task.num_context_examples > max(_MAX_CONTEXT_SIZE.values()):
|
|
1198
|
+
raise ValueError(f"Cannot process a context with more than "
|
|
1199
|
+
f"{max(_MAX_CONTEXT_SIZE.values()):,} samples "
|
|
1200
|
+
f"(got {task.num_context_examples:,})")
|
|
1201
|
+
if task.evaluate:
|
|
1202
|
+
if task.num_prediction_examples > _MAX_TEST_SIZE[task.task_type]:
|
|
1203
|
+
raise ValueError(f"Cannot process a test set with more than "
|
|
1204
|
+
f"{_MAX_TEST_SIZE[task.task_type]:,} samples "
|
|
1205
|
+
f"for evaluation "
|
|
1206
|
+
f"(got {task.num_prediction_examples:,})")
|
|
1207
|
+
else:
|
|
1208
|
+
if task.num_prediction_examples > _MAX_PRED_SIZE[task.task_type]:
|
|
1209
|
+
raise ValueError(f"Cannot predict for more than "
|
|
1210
|
+
f"{_MAX_PRED_SIZE[task.task_type]:,} "
|
|
1211
|
+
f"entities at once "
|
|
1212
|
+
f"(got {task.num_prediction_examples:,})")
|
|
1213
|
+
|
|
1214
|
+
if num_neighbors is None:
|
|
1215
|
+
key = RunMode.FAST if task.task_type.is_link_pred else run_mode
|
|
1216
|
+
num_neighbors = _DEFAULT_NUM_NEIGHBORS[key][:2]
|
|
1217
|
+
|
|
1218
|
+
if len(num_neighbors) > 6:
|
|
1219
|
+
raise ValueError(f"Cannot predict on subgraphs with more than 6 "
|
|
1220
|
+
f"hops (got {len(num_neighbors)}). Reduce the "
|
|
1221
|
+
f"number of hops and try again. Please create a "
|
|
1222
|
+
f"feature request at "
|
|
1223
|
+
f"'https://github.com/kumo-ai/kumo-rfm' if you "
|
|
1224
|
+
f"must go beyond this for your use-case.")
|
|
1225
|
+
|
|
1023
1226
|
# Exclude the entity anchor time from the feature set to prevent
|
|
1024
1227
|
# running out-of-distribution between in-context and test examples:
|
|
1025
|
-
exclude_cols_dict =
|
|
1026
|
-
if
|
|
1027
|
-
if
|
|
1028
|
-
exclude_cols_dict[
|
|
1029
|
-
|
|
1030
|
-
|
|
1031
|
-
|
|
1032
|
-
|
|
1033
|
-
|
|
1034
|
-
|
|
1035
|
-
|
|
1036
|
-
|
|
1037
|
-
|
|
1038
|
-
|
|
1039
|
-
|
|
1228
|
+
exclude_cols_dict = exclude_cols_dict or {}
|
|
1229
|
+
if task.entity_table_name in self._sampler.time_column_dict:
|
|
1230
|
+
if task.entity_table_name not in exclude_cols_dict:
|
|
1231
|
+
exclude_cols_dict[task.entity_table_name] = []
|
|
1232
|
+
time_col = self._sampler.time_column_dict[task.entity_table_name]
|
|
1233
|
+
exclude_cols_dict[task.entity_table_name].append(time_col)
|
|
1234
|
+
|
|
1235
|
+
entity_pkey = pd.concat([
|
|
1236
|
+
task._context_df[task._entity_column],
|
|
1237
|
+
task._pred_df[task._entity_column],
|
|
1238
|
+
], axis=0, ignore_index=True)
|
|
1239
|
+
|
|
1240
|
+
if task.use_entity_time:
|
|
1241
|
+
if task.entity_table_name not in self._sampler.time_column_dict:
|
|
1242
|
+
raise ValueError(f"The given annchor time requires the entity "
|
|
1243
|
+
f"table '{task.entity_table_name}' to have a "
|
|
1244
|
+
f"time column")
|
|
1245
|
+
anchor_time = 'entity'
|
|
1246
|
+
elif task._time_column is not None:
|
|
1247
|
+
anchor_time = pd.concat([
|
|
1248
|
+
task._context_df[task._time_column],
|
|
1249
|
+
task._pred_df[task._time_column],
|
|
1250
|
+
], axis=0, ignore_index=True)
|
|
1251
|
+
else:
|
|
1252
|
+
anchor_time = pd.Series(self._get_default_anchor_time()).repeat(
|
|
1253
|
+
(len(entity_pkey))).reset_index(drop=True)
|
|
1254
|
+
|
|
1255
|
+
subgraph = self._sampler.sample_subgraph(
|
|
1256
|
+
entity_table_names=task.entity_table_names,
|
|
1257
|
+
entity_pkey=entity_pkey,
|
|
1258
|
+
anchor_time=anchor_time,
|
|
1040
1259
|
num_neighbors=num_neighbors,
|
|
1041
1260
|
exclude_cols_dict=exclude_cols_dict,
|
|
1042
1261
|
)
|
|
@@ -1048,23 +1267,26 @@ class KumoRFM:
|
|
|
1048
1267
|
f"'https://github.com/kumo-ai/kumo-rfm' if you "
|
|
1049
1268
|
f"must go beyond this for your use-case.")
|
|
1050
1269
|
|
|
1051
|
-
|
|
1052
|
-
|
|
1053
|
-
|
|
1270
|
+
if (task.task_type.is_link_pred
|
|
1271
|
+
and task.entity_table_names[-1] not in subgraph.table_dict):
|
|
1272
|
+
raise ValueError("Cannot perform link prediction on subgraphs "
|
|
1273
|
+
"without any historical target entities. Please "
|
|
1274
|
+
"increase the number of hops and try again.")
|
|
1054
1275
|
|
|
1055
1276
|
return Context(
|
|
1056
|
-
task_type=task_type,
|
|
1057
|
-
entity_table_names=entity_table_names,
|
|
1277
|
+
task_type=task.task_type,
|
|
1278
|
+
entity_table_names=task.entity_table_names,
|
|
1058
1279
|
subgraph=subgraph,
|
|
1059
|
-
y_train=
|
|
1060
|
-
y_test=
|
|
1061
|
-
|
|
1062
|
-
|
|
1280
|
+
y_train=task._context_df[task.target_column.name],
|
|
1281
|
+
y_test=task._pred_df[task.target_column.name]
|
|
1282
|
+
if task.evaluate else None,
|
|
1283
|
+
top_k=top_k,
|
|
1284
|
+
step_size=None,
|
|
1063
1285
|
)
|
|
1064
1286
|
|
|
1065
1287
|
@staticmethod
|
|
1066
1288
|
def _validate_metrics(
|
|
1067
|
-
metrics:
|
|
1289
|
+
metrics: list[str],
|
|
1068
1290
|
task_type: TaskType,
|
|
1069
1291
|
) -> None:
|
|
1070
1292
|
|
|
@@ -1121,7 +1343,7 @@ class KumoRFM:
|
|
|
1121
1343
|
f"'https://github.com/kumo-ai/kumo-rfm'.")
|
|
1122
1344
|
|
|
1123
1345
|
|
|
1124
|
-
def format_value(value:
|
|
1346
|
+
def format_value(value: int | float) -> str:
|
|
1125
1347
|
if value == int(value):
|
|
1126
1348
|
return f'{int(value):,}'
|
|
1127
1349
|
if abs(value) >= 1000:
|