kumoai 2.14.0.dev202512191731__cp311-cp311-macosx_11_0_arm64.whl → 2.15.0.dev202601141731__cp311-cp311-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kumoai/__init__.py +23 -26
- kumoai/_version.py +1 -1
- kumoai/client/client.py +6 -0
- kumoai/client/jobs.py +26 -0
- kumoai/connector/utils.py +21 -7
- kumoai/experimental/rfm/__init__.py +24 -22
- kumoai/experimental/rfm/backend/local/graph_store.py +12 -21
- kumoai/experimental/rfm/backend/local/sampler.py +0 -3
- kumoai/experimental/rfm/backend/local/table.py +24 -25
- kumoai/experimental/rfm/backend/snow/sampler.py +235 -80
- kumoai/experimental/rfm/backend/snow/table.py +146 -70
- kumoai/experimental/rfm/backend/sqlite/sampler.py +196 -89
- kumoai/experimental/rfm/backend/sqlite/table.py +85 -55
- kumoai/experimental/rfm/base/__init__.py +6 -9
- kumoai/experimental/rfm/base/column.py +95 -11
- kumoai/experimental/rfm/base/expression.py +44 -0
- kumoai/experimental/rfm/base/mapper.py +69 -0
- kumoai/experimental/rfm/base/sampler.py +28 -18
- kumoai/experimental/rfm/base/source.py +1 -1
- kumoai/experimental/rfm/base/sql_sampler.py +320 -19
- kumoai/experimental/rfm/base/table.py +256 -109
- kumoai/experimental/rfm/base/utils.py +36 -0
- kumoai/experimental/rfm/graph.py +115 -107
- kumoai/experimental/rfm/infer/dtype.py +7 -2
- kumoai/experimental/rfm/infer/multicategorical.py +1 -1
- kumoai/experimental/rfm/infer/time_col.py +4 -2
- kumoai/experimental/rfm/relbench.py +76 -0
- kumoai/experimental/rfm/rfm.py +540 -306
- kumoai/experimental/rfm/task_table.py +292 -0
- kumoai/pquery/training_table.py +16 -2
- kumoai/testing/snow.py +3 -3
- kumoai/trainer/distilled_trainer.py +175 -0
- kumoai/utils/display.py +87 -0
- kumoai/utils/progress_logger.py +15 -2
- kumoai/utils/sql.py +2 -2
- {kumoai-2.14.0.dev202512191731.dist-info → kumoai-2.15.0.dev202601141731.dist-info}/METADATA +2 -2
- {kumoai-2.14.0.dev202512191731.dist-info → kumoai-2.15.0.dev202601141731.dist-info}/RECORD +40 -35
- kumoai/experimental/rfm/base/column_expression.py +0 -50
- kumoai/experimental/rfm/base/sql_table.py +0 -229
- {kumoai-2.14.0.dev202512191731.dist-info → kumoai-2.15.0.dev202601141731.dist-info}/WHEEL +0 -0
- {kumoai-2.14.0.dev202512191731.dist-info → kumoai-2.15.0.dev202601141731.dist-info}/licenses/LICENSE +0 -0
- {kumoai-2.14.0.dev202512191731.dist-info → kumoai-2.15.0.dev202601141731.dist-info}/top_level.txt +0 -0
kumoai/experimental/rfm/rfm.py
CHANGED
@@ -1,4 +1,5 @@
 import json
+import math
 import time
 import warnings
 from collections import defaultdict
@@ -7,7 +8,6 @@ from contextlib import contextmanager
 from dataclasses import dataclass, replace
 from typing import Any, Literal, overload

-import numpy as np
 import pandas as pd
 from kumoapi.model_plan import RunMode
 from kumoapi.pquery import QueryType, ValidatedPredictiveQuery
@@ -27,31 +27,37 @@ from kumoapi.rfm import (
 )
 from kumoapi.task import TaskType
 from kumoapi.typing import AggregationType, Stype
+from rich.console import Console
+from rich.markdown import Markdown

-from kumoai import in_notebook
+from kumoai import in_notebook
 from kumoai.client.rfm import RFMAPI
 from kumoai.exceptions import HTTPException
-from kumoai.experimental.rfm import Graph
+from kumoai.experimental.rfm import Graph, TaskTable
 from kumoai.experimental.rfm.base import DataBackend, Sampler
 from kumoai.mixin import CastMixin
-from kumoai.utils import ProgressLogger
+from kumoai.utils import ProgressLogger, display

 _RANDOM_SEED = 42

 _MAX_PRED_SIZE: dict[TaskType, int] = defaultdict(lambda: 1_000)
 _MAX_PRED_SIZE[TaskType.TEMPORAL_LINK_PREDICTION] = 200

+_MAX_TEST_SIZE: dict[TaskType, int] = defaultdict(lambda: 2_000)
+_MAX_TEST_SIZE[TaskType.TEMPORAL_LINK_PREDICTION] = 400
+
 _MAX_CONTEXT_SIZE = {
     RunMode.DEBUG: 100,
     RunMode.FAST: 1_000,
     RunMode.NORMAL: 5_000,
     RunMode.BEST: 10_000,
 }
-
-
-    RunMode.
-    RunMode.
-    RunMode.
+
+_DEFAULT_NUM_NEIGHBORS = {
+    RunMode.DEBUG: [16, 16, 4, 4, 1, 1],
+    RunMode.FAST: [32, 32, 8, 8, 4, 4],
+    RunMode.NORMAL: [64, 64, 8, 8, 4, 4],
+    RunMode.BEST: [64, 64, 8, 8, 4, 4],
 }

 _MAX_SIZE = 30 * 1024 * 1024
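The new `_MAX_TEST_SIZE` mirrors `_MAX_PRED_SIZE`: a `defaultdict` keyed by task type with a tighter cap for temporal link prediction. `_DEFAULT_NUM_NEIGHBORS` centralizes the per-run-mode fan-out lists that used to be hard-coded further down; `num_hops` simply truncates the list. A minimal sketch of both lookups, with plain string keys standing in for the `TaskType`/`RunMode` enums:

    from collections import defaultdict

    max_test_size = defaultdict(lambda: 2_000)          # fallback for all task types
    max_test_size['TEMPORAL_LINK_PREDICTION'] = 400     # tighter link-prediction cap

    default_num_neighbors = {'FAST': [32, 32, 8, 8, 4, 4]}

    print(max_test_size['REGRESSION'])                  # 2000 (defaultdict fallback)
    print(max_test_size['TEMPORAL_LINK_PREDICTION'])    # 400
    print(default_num_neighbors['FAST'][:2])            # [32, 32] -> num_hops=2 truncation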
@@ -102,25 +108,20 @@ class Explanation:
     def __repr__(self) -> str:
         return str((self.prediction, self.summary))

+    def __str__(self) -> str:
+        console = Console(soft_wrap=True)
+        with console.capture() as cap:
+            console.print(display.to_rich_table(self.prediction))
+            console.print(Markdown(self.summary))
+        return cap.get()[:-1]
+
     def print(self) -> None:
         r"""Prints the explanation."""
-        if
-
-
-            st.markdown(self.summary)
-        elif in_notebook():
-            from IPython.display import Markdown, display
-            try:
-                if hasattr(self.prediction.style, 'hide'):
-                    display(self.prediction.hide(axis='index'))  # pandas=2
-                else:
-                    display(self.prediction.hide_index())  # pandas <1.3
-            except ImportError:
-                print(self.prediction.to_string(index=False))  # missing jinja2
-            display(Markdown(self.summary))
+        if in_notebook():
+            display.dataframe(self.prediction)
+            display.message(self.summary)
         else:
-            print(self
-            print(self.summary)
+            print(self)

     def _ipython_display_(self) -> None:
         self.print()
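`Explanation.__str__` now renders through `rich`: it prints the prediction table and the Markdown summary into a capturing console and returns the captured text. The same pattern in isolation; the table below is a made-up stand-in for `display.to_rich_table(...)`:

    from rich.console import Console
    from rich.markdown import Markdown
    from rich.table import Table

    table = Table('entity', 'score')      # stand-in for display.to_rich_table(...)
    table.add_row('42', '0.87')

    console = Console(soft_wrap=True)
    with console.capture() as cap:        # nothing is written to stdout here
        console.print(table)
        console.print(Markdown('**Summary:** high churn risk.'))
    print(cap.get()[:-1])                 # drop the trailing newline, as __str__ does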
@@ -191,7 +192,7 @@ class KumoRFM:
         self._client: RFMAPI | None = None

         self._batch_size: int | Literal['max'] | None = None
-        self.
+        self._num_retries: int = 0

     @property
     def _api_client(self) -> RFMAPI:
@@ -205,6 +206,30 @@ class KumoRFM:
     def __repr__(self) -> str:
         return f'{self.__class__.__name__}()'

+    @contextmanager
+    def retry(
+        self,
+        num_retries: int = 1,
+    ) -> Generator[None, None, None]:
+        """Context manager to retry failed queries due to unexpected server
+        issues.
+
+        .. code-block:: python
+
+            with model.retry(num_retries=1):
+                df = model.predict(query, indices=...)
+
+        Args:
+            num_retries: The maximum number of retries.
+        """
+        if num_retries < 0:
+            raise ValueError(f"'num_retries' must be greater than or equal to "
+                             f"zero (got {num_retries})")
+
+        self._num_retries = num_retries
+        yield
+        self._num_retries = 0
+
     @contextmanager
     def batch_mode(
         self,
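`retry` only arms `self._num_retries` for the duration of the `with` block; the actual retry loops live in `predict`, `evaluate`, and `_parse_query`. Usage as given in the docstring above (the model construction, query string, and indices are placeholders):

    model = KumoRFM(graph)  # assumes an already-constructed `graph`

    with model.retry(num_retries=2):  # re-issues failed server calls with 1s, 2s backoff
        df = model.predict(query, indices=[1, 2, 3])  # `query` is a PQL string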
@@ -228,15 +253,10 @@ class KumoRFM:
             raise ValueError(f"'batch_size' must be greater than zero "
                              f"(got {batch_size})")

-        if num_retries < 0:
-            raise ValueError(f"'num_retries' must be greater than or equal to "
-                             f"zero (got {num_retries})")
-
         self._batch_size = batch_size
-        self.
-
+        with self.retry(self._num_retries or num_retries):
+            yield
         self._batch_size = None
-        self.num_retries = 0

     @overload
     def predict(
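`batch_mode` now delegates its retry handling to `retry`, and `self._num_retries or num_retries` means an already-armed outer `retry` takes precedence over `batch_mode`'s own `num_retries` argument. A sketch of composing the two; the batch size, retry counts, and variable names are illustrative:

    with model.retry(num_retries=3):            # outer setting wins
        with model.batch_mode(batch_size=500):  # its num_retries default is ignored
            df = model.predict(query, indices=all_indices)  # runs in batches of 500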
@@ -276,6 +296,25 @@ class KumoRFM:
     ) -> Explanation:
         pass

+    @overload
+    def predict(
+        self,
+        query: str,
+        indices: list[str] | list[float] | list[int] | None = None,
+        *,
+        explain: bool | ExplainConfig | dict[str, Any] = False,
+        anchor_time: pd.Timestamp | Literal['entity'] | None = None,
+        context_anchor_time: pd.Timestamp | None = None,
+        run_mode: RunMode | str = RunMode.FAST,
+        num_neighbors: list[int] | None = None,
+        num_hops: int = 2,
+        max_pq_iterations: int = 10,
+        random_seed: int | None = _RANDOM_SEED,
+        verbose: bool | ProgressLogger = True,
+        use_prediction_time: bool = False,
+    ) -> pd.DataFrame | Explanation:
+        pass
+
     def predict(
         self,
         query: str,
@@ -299,8 +338,7 @@
             indices: The entity primary keys to predict for. Will override the
                 indices given as part of the predictive query. Predictions will
                 be generated for all indices, independent of whether they
-                fulfill entity filter constraints.
-                :meth:`~KumoRFM.is_valid_entity`.
+                fulfill entity filter constraints.
             explain: Configuration for explainability.
                 If set to ``True``, will additionally explain the prediction.
                 Passing in an :class:`ExplainConfig` instance provides control
@@ -333,18 +371,152 @@
             If ``explain`` is provided, returns an :class:`Explanation` object
             containing the prediction, summary, and details.
         """
-        explain_config: ExplainConfig | None = None
-        if explain is True:
-            explain_config = ExplainConfig()
-        elif explain is not False:
-            explain_config = ExplainConfig._cast(explain)
-
         query_def = self._parse_query(query)
-        query_str = query_def.to_string()

+        if indices is None:
+            if query_def.rfm_entity_ids is None:
+                raise ValueError("Cannot find entities to predict for. Please "
+                                 "pass them via `predict(query, indices=...)`")
+            indices = query_def.get_rfm_entity_id_list()
+        query_def = replace(
+            query_def,
+            for_each='FOR EACH',
+            rfm_entity_ids=None,
+        )
+
+        if not isinstance(verbose, ProgressLogger):
+            query_repr = query_def.to_string(rich=True, exclude_predict=True)
+            if explain is not False:
+                msg = f'[bold]EXPLAIN[/bold] {query_repr}'
+            else:
+                msg = f'[bold]PREDICT[/bold] {query_repr}'
+            verbose = ProgressLogger.default(msg=msg, verbose=verbose)
+
+        with verbose as logger:
+            task_table = self._get_task_table(
+                query=query_def,
+                indices=indices,
+                anchor_time=anchor_time,
+                context_anchor_time=context_anchor_time,
+                run_mode=run_mode,
+                max_pq_iterations=max_pq_iterations,
+                random_seed=random_seed,
+                logger=logger,
+            )
+            task_table._query = query_def.to_string()
+
+        return self.predict_task(
+            task_table,
+            explain=explain,
+            run_mode=run_mode,
+            num_neighbors=num_neighbors,
+            num_hops=num_hops,
+            verbose=verbose,
+            exclude_cols_dict=query_def.get_exclude_cols_dict(),
+            use_prediction_time=use_prediction_time,
+            top_k=query_def.top_k,
+        )
+
+    @overload
+    def predict_task(
+        self,
+        task: TaskTable,
+        *,
+        explain: Literal[False] = False,
+        run_mode: RunMode | str = RunMode.FAST,
+        num_neighbors: list[int] | None = None,
+        num_hops: int = 2,
+        verbose: bool | ProgressLogger = True,
+        exclude_cols_dict: dict[str, list[str]] | None = None,
+        use_prediction_time: bool = False,
+        top_k: int | None = None,
+    ) -> pd.DataFrame:
+        pass
+
+    @overload
+    def predict_task(
+        self,
+        task: TaskTable,
+        *,
+        explain: Literal[True] | ExplainConfig | dict[str, Any],
+        run_mode: RunMode | str = RunMode.FAST,
+        num_neighbors: list[int] | None = None,
+        num_hops: int = 2,
+        verbose: bool | ProgressLogger = True,
+        exclude_cols_dict: dict[str, list[str]] | None = None,
+        use_prediction_time: bool = False,
+        top_k: int | None = None,
+    ) -> Explanation:
+        pass
+
+    @overload
+    def predict_task(
+        self,
+        task: TaskTable,
+        *,
+        explain: bool | ExplainConfig | dict[str, Any] = False,
+        run_mode: RunMode | str = RunMode.FAST,
+        num_neighbors: list[int] | None = None,
+        num_hops: int = 2,
+        verbose: bool | ProgressLogger = True,
+        exclude_cols_dict: dict[str, list[str]] | None = None,
+        use_prediction_time: bool = False,
+        top_k: int | None = None,
+    ) -> pd.DataFrame | Explanation:
+        pass
+
+    def predict_task(
+        self,
+        task: TaskTable,
+        *,
+        explain: bool | ExplainConfig | dict[str, Any] = False,
+        run_mode: RunMode | str = RunMode.FAST,
+        num_neighbors: list[int] | None = None,
+        num_hops: int = 2,
+        verbose: bool | ProgressLogger = True,
+        exclude_cols_dict: dict[str, list[str]] | None = None,
+        use_prediction_time: bool = False,
+        top_k: int | None = None,
+    ) -> pd.DataFrame | Explanation:
+        """Returns predictions for a custom task specification.
+
+        Args:
+            task: The custom :class:`TaskTable`.
+            explain: Configuration for explainability.
+                If set to ``True``, will additionally explain the prediction.
+                Passing in an :class:`ExplainConfig` instance provides control
+                over which parts of explanation are generated.
+                Explainability is currently only supported for single entity
+                predictions with ``run_mode="FAST"``.
+            run_mode: The :class:`RunMode` for the query.
+            num_neighbors: The number of neighbors to sample for each hop.
+                If specified, the ``num_hops`` option will be ignored.
+            num_hops: The number of hops to sample when generating the context.
+            verbose: Whether to print verbose output.
+            exclude_cols_dict: Any column in any table to exclude from the
+                model input.
+            use_prediction_time: Whether to use the anchor timestamp as an
+                additional feature during prediction. This is typically
+                beneficial for time series forecasting tasks.
+            top_k: The number of predictions to return per entity.
+
+        Returns:
+            The predictions as a :class:`pandas.DataFrame`.
+            If ``explain`` is provided, returns an :class:`Explanation` object
+            containing the prediction, summary, and details.
+        """
         if num_hops != 2 and num_neighbors is not None:
             warnings.warn(f"Received custom 'num_neighbors' option; ignoring "
                           f"custom 'num_hops={num_hops}' option")
+        if num_neighbors is None:
+            key = RunMode.FAST if task.task_type.is_link_pred else run_mode
+            num_neighbors = _DEFAULT_NUM_NEIGHBORS[key][:num_hops]
+
+        explain_config: ExplainConfig | None = None
+        if explain is True:
+            explain_config = ExplainConfig()
+        elif explain is not False:
+            explain_config = ExplainConfig._cast(explain)

         if explain_config is not None and run_mode in {
                 RunMode.NORMAL, RunMode.BEST
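`predict_task` is the new public entry point beneath `predict`: the query path now just compiles a `ValidatedPredictiveQuery` into a `TaskTable` and forwards it. A hedged sketch of calling it directly with a hand-built task; the constructor arguments follow the `TaskTable(...)` call at the end of this diff, but the table name and data are made up:

    import pandas as pd
    from kumoai.experimental.rfm import TaskTable
    from kumoapi.task import TaskType

    context_df = pd.DataFrame({                # in-context examples with known labels
        'ENTITY': [11, 12, 13],
        'TARGET': [1, 0, 1],
        'ANCHOR_TIMESTAMP': pd.to_datetime(['2024-01-01'] * 3),
    })
    pred_df = pd.DataFrame({                   # entities to score
        'ENTITY': [21, 22],
        'ANCHOR_TIMESTAMP': pd.to_datetime(['2024-02-01'] * 2),
    })

    task = TaskTable(
        task_type=TaskType.BINARY_CLASSIFICATION,
        context_df=context_df,
        pred_df=pred_df,
        entity_table_name=('users', ),         # hypothetical entity table
        entity_column='ENTITY',
        target_column='TARGET',
        time_column='ANCHOR_TIMESTAMP',
    )
    df = model.predict_task(task, run_mode='FAST')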
@@ -353,83 +525,82 @@
                           f"run mode 'FAST' (got '{run_mode}'). Provided run "
                           f"mode has been reset. Please lower the run mode to "
                           f"suppress this warning.")
+            run_mode = RunMode.FAST

-        if
-
-
-
-            indices = query_def.get_rfm_entity_id_list()
-        else:
-            query_def = replace(query_def, rfm_entity_ids=None)
-
-        if len(indices) == 0:
-            raise ValueError("At least one entity is required")
-
-        if explain_config is not None and len(indices) > 1:
-            raise ValueError(
-                f"Cannot explain predictions for more than a single entity "
-                f"(got {len(indices)})")
-
-        query_repr = query_def.to_string(rich=True, exclude_predict=True)
-        if explain_config is not None:
-            msg = f'[bold]EXPLAIN[/bold] {query_repr}'
-        else:
-            msg = f'[bold]PREDICT[/bold] {query_repr}'
+        if explain_config is not None and task.num_prediction_examples > 1:
+            raise ValueError(f"Cannot explain predictions for more than a "
+                             f"single entity "
+                             f"(got {task.num_prediction_examples:,})")

         if not isinstance(verbose, ProgressLogger):
+            if task.task_type == TaskType.BINARY_CLASSIFICATION:
+                task_type_repr = 'binary classification'
+            elif task.task_type == TaskType.MULTICLASS_CLASSIFICATION:
+                task_type_repr = 'multi-class classification'
+            elif task.task_type == TaskType.REGRESSION:
+                task_type_repr = 'regression'
+            elif task.task_type == TaskType.TEMPORAL_LINK_PREDICTION:
+                task_type_repr = 'link prediction'
+            else:
+                task_type_repr = str(task.task_type)
+
+            if explain_config is not None:
+                msg = f"Explaining {task_type_repr} task"
+            else:
+                msg = f"Predicting {task_type_repr} task"
             verbose = ProgressLogger.default(msg=msg, verbose=verbose)

         with verbose as logger:
-
-
-
-
-
-
-
-            batch_size =
+            if task.num_context_examples > _MAX_CONTEXT_SIZE[run_mode]:
+                logger.log(f"Sub-sampled {_MAX_CONTEXT_SIZE[run_mode]:,} "
+                           f"out of {task.num_context_examples:,} in-context "
+                           f"examples")
+                task = task.narrow_context(0, _MAX_CONTEXT_SIZE[run_mode])
+
+            if self._batch_size is None:
+                batch_size = task.num_prediction_examples
+            elif self._batch_size == 'max':
+                batch_size = _MAX_PRED_SIZE[task.task_type]
             else:
                 batch_size = self._batch_size

-            if batch_size
-
-
-
-
+            if batch_size > _MAX_PRED_SIZE[task.task_type]:
+                raise ValueError(f"Cannot predict for more than "
+                                 f"{_MAX_PRED_SIZE[task.task_type]:,} "
+                                 f"entities at once (got {batch_size:,}). Use "
+                                 f"`KumoRFM.batch_mode` to process entities "
+                                 f"in batches with a sufficient batch size.")

-            if
-
-
+            if task.num_prediction_examples > batch_size:
+                num = math.ceil(task.num_prediction_examples / batch_size)
+                logger.log(f"Splitting {task.num_prediction_examples:,} "
+                           f"entities into {num:,} batches of size "
+                           f"{batch_size:,}")

             predictions: list[pd.DataFrame] = []
             summary: str | None = None
             details: Explanation | None = None
-            for
-                # TODO Re-use the context for subsequent predictions.
+            for start in range(0, task.num_prediction_examples, batch_size):
                 context = self._get_context(
-
-
-                    anchor_time=anchor_time,
-                    context_anchor_time=context_anchor_time,
-                    run_mode=RunMode(run_mode),
+                    task=task.narrow_prediction(start, length=batch_size),
+                    run_mode=run_mode,
                     num_neighbors=num_neighbors,
-
-
-                    evaluate=False,
-                    random_seed=random_seed,
-                    logger=logger if i == 0 else None,
+                    exclude_cols_dict=exclude_cols_dict,
+                    top_k=top_k,
                 )
+                context.y_test = None
+
                 request = RFMPredictRequest(
                     context=context,
                     run_mode=RunMode(run_mode),
-                    query=
+                    query=task._query,
                     use_prediction_time=use_prediction_time,
                 )
                 with warnings.catch_warnings():
                     warnings.filterwarnings('ignore', message='gencode')
                     request_msg = request.to_protobuf()
                     _bytes = request_msg.SerializeToString()
-                if
+                if start == 0:
                     logger.log(f"Generated context of size "
                                f"{len(_bytes) / (1024*1024):.2f}MB")
@@ -437,13 +608,11 @@
                     stats = Context.get_memory_stats(request_msg.context)
                     raise ValueError(_SIZE_LIMIT_MSG.format(stats=stats))

-                if
-
-
-                        description='Predicting',
-                    )
+                if start == 0 and task.num_prediction_examples > batch_size:
+                    num = math.ceil(task.num_prediction_examples / batch_size)
+                    verbose.init_progress(total=num, description='Predicting')

-                for attempt in range(self.
+                for attempt in range(self._num_retries + 1):
                     try:
                         if explain_config is not None:
                             resp = self._api_client.explain(
@@ -459,7 +628,7 @@
                         # Cast 'ENTITY' to correct data type:
                         if 'ENTITY' in df:
                             table_dict = context.subgraph.table_dict
-                            table = table_dict[
+                            table = table_dict[context.entity_table_names[0]]
                             ser = table.df[table.primary_key]
                             df['ENTITY'] = df['ENTITY'].astype(ser.dtype)
@@ -476,12 +645,12 @@

                         predictions.append(df)

-                        if
+                        if task.num_prediction_examples > batch_size:
                             verbose.step()

                         break
                     except HTTPException as e:
-                        if attempt == self.
+                        if attempt == self._num_retries:
                             try:
                                 msg = json.loads(e.detail)['detail']
                             except Exception:
@@ -511,51 +680,6 @@

         return prediction

-    def is_valid_entity(
-        self,
-        query: str,
-        indices: list[str] | list[float] | list[int] | None = None,
-        *,
-        anchor_time: pd.Timestamp | Literal['entity'] | None = None,
-    ) -> np.ndarray:
-        r"""Returns a mask that denotes which entities are valid for the
-        given predictive query, *i.e.*, which entities fulfill (temporal)
-        entity filter constraints.
-
-        Args:
-            query: The predictive query.
-            indices: The entity primary keys to predict for. Will override the
-                indices given as part of the predictive query.
-            anchor_time: The anchor timestamp for the prediction. If set to
-                ``None``, will use the maximum timestamp in the data.
-                If set to ``"entity"``, will use the timestamp of the entity.
-        """
-        query_def = self._parse_query(query)
-
-        if indices is None:
-            if query_def.rfm_entity_ids is None:
-                raise ValueError("Cannot find entities to predict for. Please "
-                                 "pass them via "
-                                 "`is_valid_entity(query, indices=...)`")
-            indices = query_def.get_rfm_entity_id_list()
-
-        if len(indices) == 0:
-            raise ValueError("At least one entity is required")
-
-        if anchor_time is None:
-            anchor_time = self._get_default_anchor_time(query_def)
-
-        if isinstance(anchor_time, pd.Timestamp):
-            self._validate_time(query_def, anchor_time, None, False)
-        else:
-            assert anchor_time == 'entity'
-            if query_def.entity_table not in self._sampler.time_column_dict:
-                raise ValueError(f"Anchor time 'entity' requires the entity "
-                                 f"table '{query_def.entity_table}' "
-                                 f"to have a time column.")
-
-        raise NotImplementedError
-
     def evaluate(
         self,
         query: str,
@@ -600,41 +724,120 @@
         Returns:
             The metrics as a :class:`pandas.DataFrame`
         """
-        query_def =
+        query_def = replace(
+            self._parse_query(query),
+            for_each='FOR EACH',
+            rfm_entity_ids=None,
+        )
+
+        if not isinstance(verbose, ProgressLogger):
+            query_repr = query_def.to_string(rich=True, exclude_predict=True)
+            msg = f'[bold]EVALUATE[/bold] {query_repr}'
+            verbose = ProgressLogger.default(msg=msg, verbose=verbose)
+
+        with verbose as logger:
+            task_table = self._get_task_table(
+                query=query_def,
+                indices=None,
+                anchor_time=anchor_time,
+                context_anchor_time=context_anchor_time,
+                run_mode=run_mode,
+                max_pq_iterations=max_pq_iterations,
+                random_seed=random_seed,
+                logger=logger,
+            )
+
+        return self.evaluate_task(
+            task_table,
+            metrics=metrics,
+            run_mode=run_mode,
+            num_neighbors=num_neighbors,
+            num_hops=num_hops,
+            verbose=verbose,
+            exclude_cols_dict=query_def.get_exclude_cols_dict(),
+            use_prediction_time=use_prediction_time,
+        )

+    def evaluate_task(
+        self,
+        task: TaskTable,
+        *,
+        metrics: list[str] | None = None,
+        run_mode: RunMode | str = RunMode.FAST,
+        num_neighbors: list[int] | None = None,
+        num_hops: int = 2,
+        verbose: bool | ProgressLogger = True,
+        exclude_cols_dict: dict[str, list[str]] | None = None,
+        use_prediction_time: bool = False,
+    ) -> pd.DataFrame:
+        """Evaluates a custom task specification.
+
+        Args:
+            task: The custom :class:`TaskTable`.
+            metrics: The metrics to use.
+            run_mode: The :class:`RunMode` for the query.
+            num_neighbors: The number of neighbors to sample for each hop.
+                If specified, the ``num_hops`` option will be ignored.
+            num_hops: The number of hops to sample when generating the context.
+            verbose: Whether to print verbose output.
+            exclude_cols_dict: Any column in any table to exclude from the
+                model input.
+            use_prediction_time: Whether to use the anchor timestamp as an
+                additional feature during prediction. This is typically
+                beneficial for time series forecasting tasks.
+
+        Returns:
+            The metrics as a :class:`pandas.DataFrame`
+        """
         if num_hops != 2 and num_neighbors is not None:
             warnings.warn(f"Received custom 'num_neighbors' option; ignoring "
                           f"custom 'num_hops={num_hops}' option")
+        if num_neighbors is None:
+            key = RunMode.FAST if task.task_type.is_link_pred else run_mode
+            num_neighbors = _DEFAULT_NUM_NEIGHBORS[key][:num_hops]

-        if
-
-
-            rfm_entity_ids=None,
-        )
-
-        query_repr = query_def.to_string(rich=True, exclude_predict=True)
-        msg = f'[bold]EVALUATE[/bold] {query_repr}'
+        if metrics is not None and len(metrics) > 0:
+            self._validate_metrics(metrics, task.task_type)
+            metrics = list(dict.fromkeys(metrics))

         if not isinstance(verbose, ProgressLogger):
+            if task.task_type == TaskType.BINARY_CLASSIFICATION:
+                task_type_repr = 'binary classification'
+            elif task.task_type == TaskType.MULTICLASS_CLASSIFICATION:
+                task_type_repr = 'multi-class classification'
+            elif task.task_type == TaskType.REGRESSION:
+                task_type_repr = 'regression'
+            elif task.task_type == TaskType.TEMPORAL_LINK_PREDICTION:
+                task_type_repr = 'link prediction'
+            else:
+                task_type_repr = str(task.task_type)
+
+            msg = f"Evaluating {task_type_repr} task"
             verbose = ProgressLogger.default(msg=msg, verbose=verbose)

         with verbose as logger:
+            if task.num_context_examples > _MAX_CONTEXT_SIZE[run_mode]:
+                logger.log(f"Sub-sampled {_MAX_CONTEXT_SIZE[run_mode]:,} "
+                           f"out of {task.num_context_examples:,} in-context "
+                           f"examples")
+                task = task.narrow_context(0, _MAX_CONTEXT_SIZE[run_mode])
+
+            if task.num_prediction_examples > _MAX_TEST_SIZE[task.task_type]:
+                logger.log(f"Sub-sampled {_MAX_TEST_SIZE[task.task_type]:,} "
+                           f"out of {task.num_prediction_examples:,} test "
+                           f"examples")
+                task = task.narrow_prediction(
+                    start=0,
+                    length=_MAX_TEST_SIZE[task.task_type],
+                )
+
             context = self._get_context(
-
-
-                anchor_time=anchor_time,
-                context_anchor_time=context_anchor_time,
-                run_mode=RunMode(run_mode),
+                task=task,
+                run_mode=run_mode,
                 num_neighbors=num_neighbors,
-
-                max_pq_iterations=max_pq_iterations,
-                evaluate=True,
-                random_seed=random_seed,
-                logger=logger if verbose else None,
+                exclude_cols_dict=exclude_cols_dict,
             )
-
-            self._validate_metrics(metrics, context.task_type)
-            metrics = list(dict.fromkeys(metrics))
+
             request = RFMEvaluateRequest(
                 context=context,
                 run_mode=RunMode(run_mode),
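`evaluate_task` mirrors `predict_task`, but the labels stay on the prediction side (the `TaskTable` must carry targets in `pred_df`), and both context and test examples are capped via `_MAX_CONTEXT_SIZE` and `_MAX_TEST_SIZE`. A hedged sketch, reusing a `task` built as in the `predict_task` example above; the metric name is illustrative and is checked by `_validate_metrics`:

    metrics_df = model.evaluate_task(
        task,               # pred_df must include the 'TARGET' column here
        metrics=['auroc'],  # hypothetical metric name
        run_mode='FAST',
    )
    print(metrics_df)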
@@ -652,17 +855,23 @@
             stats_msg = Context.get_memory_stats(request_msg.context)
             raise ValueError(_SIZE_LIMIT_MSG.format(stats=stats_msg))

-
-            resp = self._api_client.evaluate(request_bytes)
-        except HTTPException as e:
+        for attempt in range(self._num_retries + 1):
             try:
-
-
-
-
-
-
-
+                resp = self._api_client.evaluate(request_bytes)
+                break
+            except HTTPException as e:
+                if attempt == self._num_retries:
+                    try:
+                        msg = json.loads(e.detail)['detail']
+                    except Exception:
+                        msg = e.detail
+                    raise RuntimeError(
+                        f"An unexpected exception occurred. Please create "
+                        f"an issue at "
+                        f"'https://github.com/kumo-ai/kumo-rfm'. {msg}"
+                    ) from None
+
+                time.sleep(2**attempt)  # 1s, 2s, 4s, 8s, ...

         return pd.DataFrame.from_dict(
             resp.metrics,
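The same retry skeleton now appears three times in this file (`predict`, `evaluate`, `_parse_query`): attempt the call, `break` on success, re-raise on the final attempt, otherwise back off exponentially. Reduced to its core, with generic names that are not part of the package:

    import time

    def call_with_retries(fn, num_retries: int = 3):
        """Retry `fn` with exponential backoff: sleep 1s, 2s, 4s, ... between attempts."""
        for attempt in range(num_retries + 1):
            try:
                return fn()
            except Exception:
                if attempt == num_retries:
                    raise                 # out of retries: surface the last error
                time.sleep(2 ** attempt)  # 2**0 = 1s, then 2s, 4s, ...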
@@ -714,7 +923,7 @@
                              f"to have a time column")

         train, test = self._sampler.sample_target(
-            query=
+            query=query_def,
             num_train_examples=0,
             train_anchor_time=anchor_time,
             num_train_trials=0,
@@ -742,30 +951,34 @@
                 "`predict()` or `evaluate()` methods to perform "
                 "predictions or evaluations.")

-
-
-
-
-
+        request = RFMParseQueryRequest(
+            query=query,
+            graph_definition=self._graph_def,
+        )
+
+        for attempt in range(self._num_retries + 1):
+            try:
+                resp = self._api_client.parse_query(request)
+                break
+            except HTTPException as e:
+                if attempt == self._num_retries:
+                    try:
+                        msg = json.loads(e.detail)['detail']
+                    except Exception:
+                        msg = e.detail
+                    raise ValueError(f"Failed to parse query '{query}'. {msg}")

-
+                time.sleep(2**attempt)  # 1s, 2s, 4s, 8s, ...

-
-
-
-
-
-
-
+        if len(resp.validation_response.warnings) > 0:
+            msg = '\n'.join([
+                f'{i+1}. {warning.title}: {warning.message}'
+                for i, warning in enumerate(resp.validation_response.warnings)
+            ])
+            warnings.warn(f"Encountered the following warnings during "
+                          f"parsing:\n{msg}")

-
-        except HTTPException as e:
-            try:
-                msg = json.loads(e.detail)['detail']
-            except Exception:
-                msg = e.detail
-            raise ValueError(f"Failed to parse query '{query}'. "
-                             f"{msg}") from None
+        return resp.query

     @staticmethod
     def _get_task_type(
@@ -809,16 +1022,15 @@

     def _get_default_anchor_time(
         self,
-        query: ValidatedPredictiveQuery,
+        query: ValidatedPredictiveQuery | None = None,
     ) -> pd.Timestamp:
-        if query.query_type == QueryType.TEMPORAL:
+        if query is not None and query.query_type == QueryType.TEMPORAL:
             aggr_table_names = [
                 aggr._get_target_column_name().split('.')[0]
                 for aggr in query.get_all_target_aggregations()
             ]
             return self._sampler.get_max_time(aggr_table_names)

-        assert query.query_type == QueryType.STATIC
         return self._sampler.get_max_time()

     def _validate_time(
@@ -832,8 +1044,16 @@
         if len(self._sampler.time_column_dict) == 0:
             return  # Graph without timestamps

-
-
+        if query.query_type == QueryType.TEMPORAL:
+            aggr_table_names = [
+                aggr._get_target_column_name().split('.')[0]
+                for aggr in query.get_all_target_aggregations()
+            ]
+            min_time = self._sampler.get_min_time(aggr_table_names)
+            max_time = self._sampler.get_max_time(aggr_table_names)
+        else:
+            min_time = self._sampler.get_min_time()
+            max_time = self._sampler.get_max_time()

         if anchor_time < min_time:
             raise ValueError(f"Anchor timestamp '{anchor_time}' is before "
@@ -888,40 +1108,26 @@
             f"Anchor timestamp for evaluation is after the latest "
             f"supported timestamp '{max_time - end_offset}'.")

-    def
+    def _get_task_table(
         self,
         query: ValidatedPredictiveQuery,
         indices: list[str] | list[float] | list[int] | None,
-        anchor_time: pd.Timestamp | Literal['entity'] | None,
-        context_anchor_time: pd.Timestamp | None,
-        run_mode: RunMode,
-
-        num_hops: int,
-        max_pq_iterations: int,
-        evaluate: bool,
+        anchor_time: pd.Timestamp | Literal['entity'] | None = None,
+        context_anchor_time: pd.Timestamp | None = None,
+        run_mode: RunMode = RunMode.FAST,
+        max_pq_iterations: int = 10,
         random_seed: int | None = _RANDOM_SEED,
         logger: ProgressLogger | None = None,
-    ) ->
-
-        if num_neighbors is not None:
-            num_hops = len(num_neighbors)
-
-        if num_hops < 0:
-            raise ValueError(f"'num_hops' must be non-negative "
-                             f"(got {num_hops})")
-        if num_hops > 6:
-            raise ValueError(f"Cannot predict on subgraphs with more than 6 "
-                             f"hops (got {num_hops}). Please reduce the "
-                             f"number of hops and try again. Please create a "
-                             f"feature request at "
-                             f"'https://github.com/kumo-ai/kumo-rfm' if you "
-                             f"must go beyond this for your use-case.")
+    ) -> TaskTable:

         task_type = self._get_task_type(
             query=query,
             edge_types=self._sampler.edge_types,
         )

+        num_train_examples = _MAX_CONTEXT_SIZE[run_mode]
+        num_test_examples = _MAX_TEST_SIZE[task_type] if indices is None else 0
+
         if logger is not None:
             if task_type == TaskType.BINARY_CLASSIFICATION:
                 task_type_repr = 'binary classification'
@@ -935,21 +1141,6 @@
                 task_type_repr = str(task_type)
             logger.log(f"Identified {query.query_type} {task_type_repr} task")

-        if task_type.is_link_pred and num_hops < 2:
-            raise ValueError(f"Cannot perform link prediction on subgraphs "
-                             f"with less than 2 hops (got {num_hops}) since "
-                             f"historical target entities need to be part of "
-                             f"the context. Please increase the number of "
-                             f"hops and try again.")
-
-        if num_neighbors is None:
-            if run_mode == RunMode.DEBUG:
-                num_neighbors = [16, 16, 4, 4, 1, 1][:num_hops]
-            elif run_mode == RunMode.FAST or task_type.is_link_pred:
-                num_neighbors = [32, 32, 8, 8, 4, 4][:num_hops]
-            else:
-                num_neighbors = [64, 64, 8, 8, 4, 4][:num_hops]
-
         if query.target_ast.date_offset_range is None:
             step_offset = pd.DateOffset(0)
         else:
@@ -958,8 +1149,7 @@

         if anchor_time is None:
             anchor_time = self._get_default_anchor_time(query)
-
-            if evaluate:
+            if num_test_examples > 0:
                 anchor_time = anchor_time - end_offset

         if logger is not None:
@@ -973,7 +1163,6 @@
             else:
                 logger.log(f"Derived anchor time {anchor_time}")

-        assert anchor_time is not None
         if isinstance(anchor_time, pd.Timestamp):
             if context_anchor_time == 'entity':
                 raise ValueError("Anchor time 'entity' needs to be shared "
@@ -981,7 +1170,7 @@
             if context_anchor_time is None:
                 context_anchor_time = anchor_time - end_offset
             self._validate_time(query, anchor_time, context_anchor_time,
-                                evaluate)
+                                evaluate=num_test_examples > 0)
         else:
             assert anchor_time == 'entity'
             if query.query_type != QueryType.STATIC:
@@ -996,14 +1185,6 @@
                     "for context and prediction examples")
             context_anchor_time = 'entity'

-        num_train_examples = _MAX_CONTEXT_SIZE[run_mode]
-        if evaluate:
-            num_test_examples = _MAX_TEST_SIZE[run_mode]
-            if task_type.is_link_pred:
-                num_test_examples = num_test_examples // 5
-        else:
-            num_test_examples = 0
-
         train, test = self._sampler.sample_target(
             query=query,
             num_train_examples=num_train_examples,
@@ -1014,39 +1195,32 @@
             num_test_trials=max_pq_iterations * num_test_examples,
             random_seed=random_seed,
         )
-        train_pkey, train_time,
-        test_pkey, test_time,
+        train_pkey, train_time, train_y = train
+        test_pkey, test_time, test_y = test

-        if
+        if num_test_examples > 0 and logger is not None:
             if task_type == TaskType.BINARY_CLASSIFICATION:
-                pos = 100 * int((
-                msg = (f"Collected {len(
+                pos = 100 * int((test_y > 0).sum()) / len(test_y)
+                msg = (f"Collected {len(test_y):,} test examples with "
                        f"{pos:.2f}% positive cases")
             elif task_type == TaskType.MULTICLASS_CLASSIFICATION:
-                msg = (f"Collected {len(
-                       f"{
+                msg = (f"Collected {len(test_y):,} test examples holding "
+                       f"{test_y.nunique()} classes")
             elif task_type == TaskType.REGRESSION:
-                _min, _max = float(
-                msg = (f"Collected {len(
+                _min, _max = float(test_y.min()), float(test_y.max())
+                msg = (f"Collected {len(test_y):,} test examples with targets "
                        f"between {format_value(_min)} and "
                        f"{format_value(_max)}")
             elif task_type == TaskType.TEMPORAL_LINK_PREDICTION:
-                num_rhs =
-                msg = (f"Collected {len(
+                num_rhs = test_y.explode().nunique()
+                msg = (f"Collected {len(test_y):,} test examples with "
                        f"{num_rhs:,} unique items")
             else:
                 raise NotImplementedError
             logger.log(msg)

-        if
+        if num_test_examples == 0:
             assert indices is not None
-            if len(indices) > _MAX_PRED_SIZE[task_type]:
-                raise ValueError(f"Cannot predict for more than "
-                                 f"{_MAX_PRED_SIZE[task_type]:,} entities at "
-                                 f"once (got {len(indices):,}). Use "
-                                 f"`KumoRFM.batch_mode` to process entities "
-                                 f"in batches")
-
             test_pkey = pd.Series(indices, dtype=train_pkey.dtype)
             if isinstance(anchor_time, pd.Timestamp):
                 test_time = pd.Series([anchor_time]).repeat(
@@ -1056,26 +1230,26 @@

         if logger is not None:
             if task_type == TaskType.BINARY_CLASSIFICATION:
-                pos = 100 * int((
-                msg = (f"Collected {len(
+                pos = 100 * int((train_y > 0).sum()) / len(train_y)
+                msg = (f"Collected {len(train_y):,} in-context examples with "
                        f"{pos:.2f}% positive cases")
             elif task_type == TaskType.MULTICLASS_CLASSIFICATION:
-                msg = (f"Collected {len(
-                       f"holding {
+                msg = (f"Collected {len(train_y):,} in-context examples "
+                       f"holding {train_y.nunique()} classes")
             elif task_type == TaskType.REGRESSION:
-                _min, _max = float(
-                msg = (f"Collected {len(
+                _min, _max = float(train_y.min()), float(train_y.max())
+                msg = (f"Collected {len(train_y):,} in-context examples with "
                        f"targets between {format_value(_min)} and "
                        f"{format_value(_max)}")
             elif task_type == TaskType.TEMPORAL_LINK_PREDICTION:
-                num_rhs =
-                msg = (f"Collected {len(
+                num_rhs = train_y.explode().nunique()
+                msg = (f"Collected {len(train_y):,} in-context examples with "
                        f"{num_rhs:,} unique items")
             else:
                 raise NotImplementedError
             logger.log(msg)

-        entity_table_names: tuple[str,
+        entity_table_names: tuple[str] | tuple[str, str]
         if task_type.is_link_pred:
             final_aggr = query.get_final_target_aggregation()
             assert final_aggr is not None
@@ -1089,27 +1263,80 @@
         else:
             entity_table_names = (query.entity_table, )

+        context_df = pd.DataFrame({'ENTITY': train_pkey, 'TARGET': train_y})
+        if isinstance(train_time, pd.Series):
+            context_df['ANCHOR_TIMESTAMP'] = train_time
+        pred_df = pd.DataFrame({'ENTITY': test_pkey})
+        if num_test_examples > 0:
+            pred_df['TARGET'] = test_y
+        if isinstance(test_time, pd.Series):
+            pred_df['ANCHOR_TIMESTAMP'] = test_time
+
+        return TaskTable(
+            task_type=task_type,
+            context_df=context_df,
+            pred_df=pred_df,
+            entity_table_name=entity_table_names,
+            entity_column='ENTITY',
+            target_column='TARGET',
+            time_column='ANCHOR_TIMESTAMP' if isinstance(
+                train_time, pd.Series) else TaskTable.ENTITY_TIME,
+        )
+
+    def _get_context(
+        self,
+        task: TaskTable,
+        run_mode: RunMode | str = RunMode.FAST,
+        num_neighbors: list[int] | None = None,
+        exclude_cols_dict: dict[str, list[str]] | None = None,
+        top_k: int | None = None,
+    ) -> Context:
+
+        if num_neighbors is None:
+            key = RunMode.FAST if task.task_type.is_link_pred else run_mode
+            num_neighbors = _DEFAULT_NUM_NEIGHBORS[key][:2]
+
+        if len(num_neighbors) > 6:
+            raise ValueError(f"Cannot predict on subgraphs with more than 6 "
+                             f"hops (got {len(num_neighbors)}). Reduce the "
+                             f"number of hops and try again. Please create a "
+                             f"feature request at "
+                             f"'https://github.com/kumo-ai/kumo-rfm' if you "
+                             f"must go beyond this for your use-case.")
+
         # Exclude the entity anchor time from the feature set to prevent
         # running out-of-distribution between in-context and test examples:
-        exclude_cols_dict =
-        if
-            if
-                exclude_cols_dict[
-
-            exclude_cols_dict[
+        exclude_cols_dict = exclude_cols_dict or {}
+        if task.entity_table_name in self._sampler.time_column_dict:
+            if task.entity_table_name not in exclude_cols_dict:
+                exclude_cols_dict[task.entity_table_name] = []
+            time_col = self._sampler.time_column_dict[task.entity_table_name]
+            exclude_cols_dict[task.entity_table_name].append(time_col)
+
+        entity_pkey = pd.concat([
+            task._context_df[task._entity_column],
+            task._pred_df[task._entity_column],
+        ], axis=0, ignore_index=True)
+
+        if task.use_entity_time:
+            if task.entity_table_name not in self._sampler.time_column_dict:
+                raise ValueError(f"The given anchor time requires the entity "
+                                 f"table '{task.entity_table_name}' to have a "
+                                 f"time column")
+            anchor_time = 'entity'
+        elif task._time_column is not None:
+            anchor_time = pd.concat([
+                task._context_df[task._time_column],
+                task._pred_df[task._time_column],
+            ], axis=0, ignore_index=True)
+        else:
+            anchor_time = pd.Series(self._get_default_anchor_time()).repeat(
+                len(entity_pkey)).reset_index(drop=True)

         subgraph = self._sampler.sample_subgraph(
-            entity_table_names=entity_table_names,
-            entity_pkey=
-
-                axis=0,
-                ignore_index=True,
-            ),
-            anchor_time=pd.concat(
-                [train_time, test_time],
-                axis=0,
-                ignore_index=True,
-            ) if isinstance(train_time, pd.Series) else 'entity',
+            entity_table_names=task.entity_table_names,
+            entity_pkey=entity_pkey,
+            anchor_time=anchor_time,
             num_neighbors=num_neighbors,
             exclude_cols_dict=exclude_cols_dict,
         )
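`_get_context` resolves the anchor time in three ways: per-entity timestamps when `task.use_entity_time` is set, an explicit per-row time column, or a graph-wide default from `_get_default_anchor_time()` repeated across all entities. A hedged sketch of driving the first two modes from the `TaskTable` side, reusing the frames from the earlier example; `TaskTable.ENTITY_TIME` is the sentinel used in `_get_task_table` above:

    # Mode 1: explicit per-row anchor times via a time column.
    task_explicit = TaskTable(
        task_type=TaskType.BINARY_CLASSIFICATION,
        context_df=context_df, pred_df=pred_df,
        entity_table_name=('users', ),          # hypothetical
        entity_column='ENTITY', target_column='TARGET',
        time_column='ANCHOR_TIMESTAMP',
    )

    # Mode 2: per-entity anchor times; requires the entity table to have a
    # time column, otherwise _get_context raises a ValueError.
    task_entity = TaskTable(
        task_type=TaskType.BINARY_CLASSIFICATION,
        context_df=context_df.drop(columns=['ANCHOR_TIMESTAMP']),
        pred_df=pred_df.drop(columns=['ANCHOR_TIMESTAMP']),
        entity_table_name=('users', ),
        entity_column='ENTITY', target_column='TARGET',
        time_column=TaskTable.ENTITY_TIME,
    )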
@@ -1121,13 +1348,20 @@
                              f"'https://github.com/kumo-ai/kumo-rfm' if you "
                              f"must go beyond this for your use-case.")

+        if (task.task_type.is_link_pred
+                and task.entity_table_names[-1] not in subgraph.table_dict):
+            raise ValueError("Cannot perform link prediction on subgraphs "
+                             "without any historical target entities. Please "
+                             "increase the number of hops and try again.")
+
         return Context(
-            task_type=task_type,
-            entity_table_names=entity_table_names,
+            task_type=task.task_type,
+            entity_table_names=task.entity_table_names,
             subgraph=subgraph,
-            y_train=
-            y_test=
-
+            y_train=task._context_df[task.target_column.name],
+            y_test=task._pred_df[task.target_column.name]
+            if task.evaluate else None,
+            top_k=top_k,
             step_size=None,
         )