kumoai 2.14.0.dev202512191731__cp311-cp311-macosx_11_0_arm64.whl → 2.14.0.dev202601051732__cp311-cp311-macosx_11_0_arm64.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
- kumoai/__init__.py +23 -26
- kumoai/_version.py +1 -1
- kumoai/client/client.py +6 -0
- kumoai/client/jobs.py +24 -0
- kumoai/connector/utils.py +21 -7
- kumoai/experimental/rfm/__init__.py +24 -22
- kumoai/experimental/rfm/backend/local/graph_store.py +12 -21
- kumoai/experimental/rfm/backend/local/sampler.py +0 -3
- kumoai/experimental/rfm/backend/local/table.py +24 -25
- kumoai/experimental/rfm/backend/snow/sampler.py +106 -61
- kumoai/experimental/rfm/backend/snow/table.py +137 -64
- kumoai/experimental/rfm/backend/sqlite/sampler.py +127 -78
- kumoai/experimental/rfm/backend/sqlite/table.py +85 -55
- kumoai/experimental/rfm/base/__init__.py +6 -9
- kumoai/experimental/rfm/base/column.py +95 -11
- kumoai/experimental/rfm/base/expression.py +44 -0
- kumoai/experimental/rfm/base/sampler.py +5 -17
- kumoai/experimental/rfm/base/source.py +1 -1
- kumoai/experimental/rfm/base/sql_sampler.py +69 -9
- kumoai/experimental/rfm/base/table.py +258 -97
- kumoai/experimental/rfm/graph.py +106 -98
- kumoai/experimental/rfm/infer/dtype.py +4 -1
- kumoai/experimental/rfm/infer/multicategorical.py +1 -1
- kumoai/experimental/rfm/relbench.py +76 -0
- kumoai/experimental/rfm/rfm.py +394 -241
- kumoai/experimental/rfm/task_table.py +290 -0
- kumoai/trainer/distilled_trainer.py +175 -0
- kumoai/utils/display.py +51 -0
- kumoai/utils/progress_logger.py +13 -1
- {kumoai-2.14.0.dev202512191731.dist-info → kumoai-2.14.0.dev202601051732.dist-info}/METADATA +1 -1
- {kumoai-2.14.0.dev202512191731.dist-info → kumoai-2.14.0.dev202601051732.dist-info}/RECORD +34 -31
- kumoai/experimental/rfm/base/column_expression.py +0 -50
- kumoai/experimental/rfm/base/sql_table.py +0 -229
- {kumoai-2.14.0.dev202512191731.dist-info → kumoai-2.14.0.dev202601051732.dist-info}/WHEEL +0 -0
- {kumoai-2.14.0.dev202512191731.dist-info → kumoai-2.14.0.dev202601051732.dist-info}/licenses/LICENSE +0 -0
- {kumoai-2.14.0.dev202512191731.dist-info → kumoai-2.14.0.dev202601051732.dist-info}/top_level.txt +0 -0
kumoai/experimental/rfm/rfm.py
CHANGED
```diff
@@ -1,4 +1,5 @@
 import json
+import math
 import time
 import warnings
 from collections import defaultdict
@@ -28,30 +29,33 @@ from kumoapi.rfm import (
 from kumoapi.task import TaskType
 from kumoapi.typing import AggregationType, Stype
 
-from kumoai import in_notebook, in_snowflake_notebook
 from kumoai.client.rfm import RFMAPI
 from kumoai.exceptions import HTTPException
-from kumoai.experimental.rfm import Graph
+from kumoai.experimental.rfm import Graph, TaskTable
 from kumoai.experimental.rfm.base import DataBackend, Sampler
 from kumoai.mixin import CastMixin
-from kumoai.utils import ProgressLogger
+from kumoai.utils import ProgressLogger, display
 
 _RANDOM_SEED = 42
 
 _MAX_PRED_SIZE: dict[TaskType, int] = defaultdict(lambda: 1_000)
 _MAX_PRED_SIZE[TaskType.TEMPORAL_LINK_PREDICTION] = 200
 
+_MAX_TEST_SIZE: dict[TaskType, int] = defaultdict(lambda: 2_000)
+_MAX_TEST_SIZE[TaskType.TEMPORAL_LINK_PREDICTION] = 400
+
 _MAX_CONTEXT_SIZE = {
     RunMode.DEBUG: 100,
     RunMode.FAST: 1_000,
     RunMode.NORMAL: 5_000,
     RunMode.BEST: 10_000,
 }
-
-
-    RunMode.
-    RunMode.
-    RunMode.
+
+_DEFAULT_NUM_NEIGHBORS = {
+    RunMode.DEBUG: [16, 16, 4, 4, 1, 1],
+    RunMode.FAST: [32, 32, 8, 8, 4, 4],
+    RunMode.NORMAL: [64, 64, 8, 8, 4, 4],
+    RunMode.BEST: [64, 64, 8, 8, 4, 4],
 }
 
 _MAX_SIZE = 30 * 1024 * 1024
```
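The new `_DEFAULT_NUM_NEIGHBORS` table replaces the per-run-mode branching that previously lived inside the context builder (removed further down): callers pick a row by run mode and slice the first `num_hops` entries to get per-hop sampling fanouts. A minimal, self-contained sketch of that lookup — the `RunMode` enum is stubbed locally here; the real one comes from `kumoapi`:

```python
# Minimal sketch of the fanout lookup added in this release; `RunMode` is
# stubbed locally, the real enum lives in kumoapi.
from enum import Enum


class RunMode(str, Enum):
    DEBUG = 'DEBUG'
    FAST = 'FAST'
    NORMAL = 'NORMAL'
    BEST = 'BEST'


_DEFAULT_NUM_NEIGHBORS = {
    RunMode.DEBUG: [16, 16, 4, 4, 1, 1],
    RunMode.FAST: [32, 32, 8, 8, 4, 4],
    RunMode.NORMAL: [64, 64, 8, 8, 4, 4],
    RunMode.BEST: [64, 64, 8, 8, 4, 4],
}


def fanouts(run_mode: RunMode, num_hops: int, is_link_pred: bool) -> list[int]:
    # Link prediction tasks always fall back to the FAST fanouts, mirroring
    # the `key = RunMode.FAST if ... else run_mode` lines in the diff:
    key = RunMode.FAST if is_link_pred else run_mode
    return _DEFAULT_NUM_NEIGHBORS[key][:num_hops]


assert fanouts(RunMode.NORMAL, 2, is_link_pred=False) == [64, 64]
assert fanouts(RunMode.BEST, 4, is_link_pred=True) == [32, 32, 8, 8]
```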
```diff
@@ -104,23 +108,8 @@ class Explanation:
 
     def print(self) -> None:
         r"""Prints the explanation."""
-
-
-            st.dataframe(self.prediction, hide_index=True)
-            st.markdown(self.summary)
-        elif in_notebook():
-            from IPython.display import Markdown, display
-            try:
-                if hasattr(self.prediction.style, 'hide'):
-                    display(self.prediction.hide(axis='index'))  # pandas=2
-                else:
-                    display(self.prediction.hide_index())  # pandas <1.3
-            except ImportError:
-                print(self.prediction.to_string(index=False))  # missing jinja2
-            display(Markdown(self.summary))
-        else:
-            print(self.prediction.to_string(index=False))
-            print(self.summary)
+        display.dataframe(self.prediction)
+        display.message(self.summary)
 
     def _ipython_display_(self) -> None:
         self.print()
```
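The Streamlit/IPython/plain-text branching is gone from `Explanation.print()`; rendering now goes through the new `kumoai/utils/display.py` module (+51 lines in this release, not shown in this diff). A hypothetical sketch of what such a helper could look like, mirroring the removed branches — everything below except the `dataframe`/`message` names is an assumption:

```python
# Hypothetical sketch of a `display` helper consolidating the removed
# branches; the actual kumoai/utils/display.py is not shown in this diff.
import pandas as pd


def _in_notebook() -> bool:
    try:
        from IPython import get_ipython
        return get_ipython() is not None
    except ImportError:
        return False


def dataframe(df: pd.DataFrame) -> None:
    if _in_notebook():
        from IPython.display import display as ipy_display
        try:
            ipy_display(df.style.hide(axis='index'))  # pandas>=1.4 Styler
        except ImportError:
            print(df.to_string(index=False))  # missing jinja2
    else:
        print(df.to_string(index=False))


def message(text: str) -> None:
    if _in_notebook():
        from IPython.display import Markdown, display as ipy_display
        ipy_display(Markdown(text))
    else:
        print(text)
```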
```diff
@@ -333,18 +322,133 @@ class KumoRFM:
         If ``explain`` is provided, returns an :class:`Explanation` object
         containing the prediction, summary, and details.
         """
-        explain_config: ExplainConfig | None = None
-        if explain is True:
-            explain_config = ExplainConfig()
-        elif explain is not False:
-            explain_config = ExplainConfig._cast(explain)
-
         query_def = self._parse_query(query)
-        query_str = query_def.to_string()
 
+        if indices is None:
+            if query_def.rfm_entity_ids is None:
+                raise ValueError("Cannot find entities to predict for. Please "
+                                 "pass them via `predict(query, indices=...)`")
+            indices = query_def.get_rfm_entity_id_list()
+        else:
+            query_def = replace(query_def, rfm_entity_ids=None)
+
+        if not isinstance(verbose, ProgressLogger):
+            query_repr = query_def.to_string(rich=True, exclude_predict=True)
+            if explain is not False:
+                msg = f'[bold]EXPLAIN[/bold] {query_repr}'
+            else:
+                msg = f'[bold]PREDICT[/bold] {query_repr}'
+            verbose = ProgressLogger.default(msg=msg, verbose=verbose)
+
+        with verbose as logger:
+            task_table = self._get_task_table(
+                query=query_def,
+                indices=indices,
+                anchor_time=anchor_time,
+                context_anchor_time=context_anchor_time,
+                run_mode=run_mode,
+                max_pq_iterations=max_pq_iterations,
+                random_seed=random_seed,
+                logger=logger,
+            )
+            task_table._query = query_def.to_string()  # type: ignore
+
+        return self.predict_task(
+            task_table,
+            explain=explain,  # type: ignore
+            run_mode=run_mode,
+            num_neighbors=num_neighbors,
+            num_hops=num_hops,
+            verbose=verbose,
+            exclude_cols_dict=query_def.get_exclude_cols_dict(),
+            use_prediction_time=use_prediction_time,
+            top_k=query_def.top_k,
+        )
+
+    @overload
+    def predict_task(
+        self,
+        task: TaskTable,
+        *,
+        explain: Literal[False] = False,
+        run_mode: RunMode | str = RunMode.FAST,
+        num_neighbors: list[int] | None = None,
+        num_hops: int = 2,
+        verbose: bool | ProgressLogger = True,
+        exclude_cols_dict: dict[str, list[str]] | None = None,
+        use_prediction_time: bool = False,
+        top_k: int | None = None,
+    ) -> pd.DataFrame:
+        pass
+
+    @overload
+    def predict_task(
+        self,
+        task: TaskTable,
+        *,
+        explain: Literal[True] | ExplainConfig | dict[str, Any],
+        run_mode: RunMode | str = RunMode.FAST,
+        num_neighbors: list[int] | None = None,
+        num_hops: int = 2,
+        verbose: bool | ProgressLogger = True,
+        exclude_cols_dict: dict[str, list[str]] | None = None,
+        use_prediction_time: bool = False,
+        top_k: int | None = None,
+    ) -> Explanation:
+        pass
+
+    def predict_task(
+        self,
+        task: TaskTable,
+        *,
+        explain: bool | ExplainConfig | dict[str, Any] = False,
+        run_mode: RunMode | str = RunMode.FAST,
+        num_neighbors: list[int] | None = None,
+        num_hops: int = 2,
+        verbose: bool | ProgressLogger = True,
+        exclude_cols_dict: dict[str, list[str]] | None = None,
+        use_prediction_time: bool = False,
+        top_k: int | None = None,
+    ) -> pd.DataFrame | Explanation:
+        """Returns predictions for a custom task specification.
+
+        Args:
+            task: The custom :class:`TaskTable`.
+            explain: Configuration for explainability.
+                If set to ``True``, will additionally explain the prediction.
+                Passing in an :class:`ExplainConfig` instance provides control
+                over which parts of explanation are generated.
+                Explainability is currently only supported for single entity
+                predictions with ``run_mode="FAST"``.
+            run_mode: The :class:`RunMode` for the query.
+            num_neighbors: The number of neighbors to sample for each hop.
+                If specified, the ``num_hops`` option will be ignored.
+            num_hops: The number of hops to sample when generating the context.
+            verbose: Whether to print verbose output.
+            exclude_cols_dict: Any column in any table to exclude from the
+                model input.
+            use_prediction_time: Whether to use the anchor timestamp as an
+                additional feature during prediction. This is typically
+                beneficial for time series forecasting tasks.
+            top_k: The number of predictions to return per entity.
+
+        Returns:
+            The predictions as a :class:`pandas.DataFrame`.
+            If ``explain`` is provided, returns an :class:`Explanation` object
+            containing the prediction, summary, and details.
+        """
         if num_hops != 2 and num_neighbors is not None:
             warnings.warn(f"Received custom 'num_neighbors' option; ignoring "
                           f"custom 'num_hops={num_hops}' option")
+        if num_neighbors is None:
+            key = RunMode.FAST if task.task_type.is_link_pred else run_mode
+            num_neighbors = _DEFAULT_NUM_NEIGHBORS[key][:num_hops]
+
+        explain_config: ExplainConfig | None = None
+        if explain is True:
+            explain_config = ExplainConfig()
+        elif explain is not False:
+            explain_config = ExplainConfig._cast(explain)
 
         if explain_config is not None and run_mode in {
                 RunMode.NORMAL, RunMode.BEST
```
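`predict()` is now a thin wrapper: it parses the query into a `TaskTable` (via the new `_get_task_table`) and hands off to the new public `predict_task()` entry point, so custom tasks no longer need a predictive query string. A usage sketch based only on the signatures visible in this diff — `rfm` is assumed to be an initialized `KumoRFM` instance, and the `'users'` table plus all concrete values are illustrative:

```python
# Usage sketch for the new `predict_task` entry point. The TaskTable kwargs
# mirror the `return TaskTable(...)` call added in `_get_task_table` below.
import pandas as pd

from kumoapi.task import TaskType
from kumoai.experimental.rfm import TaskTable

context_df = pd.DataFrame({
    'ENTITY': [1, 2, 3],  # entities with known labels (in-context examples)
    'TARGET': [0, 1, 0],
    'ANCHOR_TIMESTAMP': pd.to_datetime(['2024-01-01'] * 3),
})
pred_df = pd.DataFrame({
    'ENTITY': [4],  # entity to predict for
    'ANCHOR_TIMESTAMP': pd.to_datetime(['2024-02-01']),
})

task = TaskTable(
    task_type=TaskType.BINARY_CLASSIFICATION,
    context_df=context_df,
    pred_df=pred_df,
    entity_table_name=('users', ),
    entity_column='ENTITY',
    target_column='TARGET',
    time_column='ANCHOR_TIMESTAMP',
)
df = rfm.predict_task(task, run_mode='FAST')
```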
```diff
@@ -353,83 +457,82 @@ class KumoRFM:
                           f"run mode 'FAST' (got '{run_mode}'). Provided run "
                           f"mode has been reset. Please lower the run mode to "
                           f"suppress this warning.")
+            run_mode = RunMode.FAST
 
-        if
-
-
-
-            indices = query_def.get_rfm_entity_id_list()
-        else:
-            query_def = replace(query_def, rfm_entity_ids=None)
-
-        if len(indices) == 0:
-            raise ValueError("At least one entity is required")
-
-        if explain_config is not None and len(indices) > 1:
-            raise ValueError(
-                f"Cannot explain predictions for more than a single entity "
-                f"(got {len(indices)})")
-
-        query_repr = query_def.to_string(rich=True, exclude_predict=True)
-        if explain_config is not None:
-            msg = f'[bold]EXPLAIN[/bold] {query_repr}'
-        else:
-            msg = f'[bold]PREDICT[/bold] {query_repr}'
+        if explain_config is not None and task.num_prediction_examples > 1:
+            raise ValueError(f"Cannot explain predictions for more than a "
+                             f"single entity "
+                             f"(got {task.num_prediction_examples:,})")
 
         if not isinstance(verbose, ProgressLogger):
+            if task.task_type == TaskType.BINARY_CLASSIFICATION:
+                task_type_repr = 'binary classification'
+            elif task.task_type == TaskType.MULTICLASS_CLASSIFICATION:
+                task_type_repr = 'multi-class classification'
+            elif task.task_type == TaskType.REGRESSION:
+                task_type_repr = 'regression'
+            elif task.task_type == TaskType.TEMPORAL_LINK_PREDICTION:
+                task_type_repr = 'link prediction'
+            else:
+                task_type_repr = str(task.task_type)
+
+            if explain_config is not None:
+                msg = f'Explain {task_type_repr} task'
+            else:
+                msg = f'Predict {task_type_repr} task'
             verbose = ProgressLogger.default(msg=msg, verbose=verbose)
 
         with verbose as logger:
-
-
-
-
-
-
-
-            batch_size =
+            if task.num_context_examples > _MAX_CONTEXT_SIZE[run_mode]:
+                logger.log(f"Sub-sampled {_MAX_CONTEXT_SIZE[run_mode]:,} "
+                           f"out of {task.num_context_examples:,} in-context "
+                           f"examples")
+                task = task.narrow_context(0, _MAX_CONTEXT_SIZE[run_mode])
+
+            if self._batch_size is None:
+                batch_size = task.num_prediction_examples
+            elif self._batch_size == 'max':
+                batch_size = _MAX_PRED_SIZE[task.task_type]
             else:
                 batch_size = self._batch_size
 
-            if batch_size
-
-
-
-
+            if batch_size > _MAX_PRED_SIZE[task.task_type]:
+                raise ValueError(f"Cannot predict for more than "
+                                 f"{_MAX_PRED_SIZE[task.task_type]:,} "
+                                 f"entities at once (got {batch_size:,}). Use "
+                                 f"`KumoRFM.batch_mode` to process entities "
+                                 f"in batches with a sufficient batch size.")
 
-            if
-
-
+            if task.num_prediction_examples > batch_size:
+                num = math.ceil(task.num_prediction_examples / batch_size)
+                logger.log(f"Splitting {task.num_prediction_examples:,} "
+                           f"entities into {num:,} batches of size "
+                           f"{batch_size:,}")
 
             predictions: list[pd.DataFrame] = []
             summary: str | None = None
             details: Explanation | None = None
-            for
-                # TODO Re-use the context for subsequent predictions.
+            for start in range(0, task.num_prediction_examples, batch_size):
                 context = self._get_context(
-
-
-                    anchor_time=anchor_time,
-                    context_anchor_time=context_anchor_time,
-                    run_mode=RunMode(run_mode),
+                    task=task.narrow_prediction(start, length=batch_size),
+                    run_mode=run_mode,
                     num_neighbors=num_neighbors,
-
-
-                    evaluate=False,
-                    random_seed=random_seed,
-                    logger=logger if i == 0 else None,
+                    exclude_cols_dict=exclude_cols_dict,
+                    top_k=top_k,
                 )
+                context.y_test = None
+
                 request = RFMPredictRequest(
                     context=context,
                     run_mode=RunMode(run_mode),
-                    query=
+                    query=getattr(task, '_query', ''),
                     use_prediction_time=use_prediction_time,
                 )
                 with warnings.catch_warnings():
                     warnings.filterwarnings('ignore', message='gencode')
                     request_msg = request.to_protobuf()
                 _bytes = request_msg.SerializeToString()
-                if
+                if start == 0:
                     logger.log(f"Generated context of size "
                                f"{len(_bytes) / (1024*1024):.2f}MB")
 
```
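Prediction now runs in batches: the prediction rows are narrowed into windows of `batch_size` via `task.narrow_prediction(start, length=batch_size)` and one context is generated per window (the removed TODO about re-using the context hints this may change later). The batch arithmetic, as a self-contained sketch with `narrow_prediction` assumed to clamp `length` at the end of the table:

```python
# Self-contained sketch of the batching walk in `predict_task`;
# `narrow_prediction` is assumed to clamp `length` to the remaining rows.
import math

num_prediction_examples = 450
batch_size = 200

num_batches = math.ceil(num_prediction_examples / batch_size)  # -> 3
for start in range(0, num_prediction_examples, batch_size):
    length = min(batch_size, num_prediction_examples - start)
    print(f"batch [{start}, {start + length})")
# batch [0, 200)
# batch [200, 400)
# batch [400, 450)
```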
```diff
@@ -437,11 +540,9 @@ class KumoRFM:
                     stats = Context.get_memory_stats(request_msg.context)
                     raise ValueError(_SIZE_LIMIT_MSG.format(stats=stats))
 
-                if
-
-
-                        description='Predicting',
-                    )
+                if start == 0 and task.num_prediction_examples > batch_size:
+                    num = math.ceil(task.num_prediction_examples / batch_size)
+                    verbose.init_progress(total=num, description='Predicting')
 
                 for attempt in range(self.num_retries + 1):
                     try:
@@ -459,7 +560,7 @@
                     # Cast 'ENTITY' to correct data type:
                     if 'ENTITY' in df:
                         table_dict = context.subgraph.table_dict
-                        table = table_dict[
+                        table = table_dict[context.entity_table_names[0]]
                         ser = table.df[table.primary_key]
                         df['ENTITY'] = df['ENTITY'].astype(ser.dtype)
 
@@ -476,7 +577,7 @@
 
                     predictions.append(df)
 
-                    if
+                    if task.num_prediction_examples > batch_size:
                         verbose.step()
 
                     break
@@ -601,40 +702,51 @@
             The metrics as a :class:`pandas.DataFrame`
         """
         query_def = self._parse_query(query)
-
-        if num_hops != 2 and num_neighbors is not None:
-            warnings.warn(f"Received custom 'num_neighbors' option; ignoring "
-                          f"custom 'num_hops={num_hops}' option")
-
         if query_def.rfm_entity_ids is not None:
             query_def = replace(
                 query_def,
                 rfm_entity_ids=None,
             )
 
-
-
+        task_type = self._get_task_type(
+            query=query_def,
+            edge_types=self._sampler.edge_types,
+        )
+
+        if num_hops != 2 and num_neighbors is not None:
+            warnings.warn(f"Received custom 'num_neighbors' option; ignoring "
+                          f"custom 'num_hops={num_hops}' option")
+        if num_neighbors is None:
+            key = RunMode.FAST if task_type.is_link_pred else run_mode
+            num_neighbors = _DEFAULT_NUM_NEIGHBORS[key][:num_hops]
+
+        if metrics is not None and len(metrics) > 0:
+            self._validate_metrics(metrics, task_type)
+            metrics = list(dict.fromkeys(metrics))
 
         if not isinstance(verbose, ProgressLogger):
+            query_repr = query_def.to_string(rich=True, exclude_predict=True)
+            msg = f'[bold]EVALUATE[/bold] {query_repr}'
             verbose = ProgressLogger.default(msg=msg, verbose=verbose)
 
         with verbose as logger:
-
+            task_table = self._get_task_table(
                 query=query_def,
                 indices=None,
                 anchor_time=anchor_time,
                 context_anchor_time=context_anchor_time,
-                run_mode=
-                num_neighbors=num_neighbors,
-                num_hops=num_hops,
+                run_mode=run_mode,
                 max_pq_iterations=max_pq_iterations,
-                evaluate=True,
                 random_seed=random_seed,
-                logger=logger
+                logger=logger,
             )
-
-
-
+            context = self._get_context(
+                task=task_table,
+                run_mode=run_mode,
+                num_neighbors=num_neighbors,
+                exclude_cols_dict=query_def.get_exclude_cols_dict(),
+            )
+
             request = RFMEvaluateRequest(
                 context=context,
                 run_mode=RunMode(run_mode),
@@ -652,17 +764,23 @@
                 stats_msg = Context.get_memory_stats(request_msg.context)
                 raise ValueError(_SIZE_LIMIT_MSG.format(stats=stats_msg))
 
-
-            resp = self._api_client.evaluate(request_bytes)
-        except HTTPException as e:
+            for attempt in range(self.num_retries + 1):
                 try:
-
-
-
-
-
-
-
+                    resp = self._api_client.evaluate(request_bytes)
+                    break
+                except HTTPException as e:
+                    if attempt == self.num_retries:
+                        try:
+                            msg = json.loads(e.detail)['detail']
+                        except Exception:
+                            msg = e.detail
+                        raise RuntimeError(
+                            f"An unexpected exception occurred. Please create "
+                            f"an issue at "
+                            f"'https://github.com/kumo-ai/kumo-rfm'. {msg}"
+                        ) from None
+
+                    time.sleep(2**attempt)  # 1s, 2s, 4s, 8s, ...
 
         return pd.DataFrame.from_dict(
             resp.metrics,
```
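`evaluate()` now uses the same retry loop as `predict()`: up to `num_retries + 1` attempts with exponential backoff, unwrapping the server's JSON `detail` payload only once the final attempt fails. The pattern in isolation, as a sketch — `ServerError` stands in for kumoai's `HTTPException` and `call` for the API client method:

```python
# The retry-with-backoff pattern added to `evaluate()` and `_parse_query()`,
# in isolation; `ServerError` stands in for kumoai's HTTPException.
import json
import time


class ServerError(Exception):
    def __init__(self, detail: str) -> None:
        super().__init__(detail)
        self.detail = detail


def with_retries(call, num_retries: int = 3):
    for attempt in range(num_retries + 1):
        try:
            return call()
        except ServerError as e:
            if attempt == num_retries:  # out of retries: surface the detail
                try:
                    msg = json.loads(e.detail)['detail']
                except Exception:
                    msg = e.detail
                raise RuntimeError(f"Request failed: {msg}") from None
            time.sleep(2**attempt)  # back off 1s, 2s, 4s, ... between attempts
```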
```diff
@@ -714,7 +832,7 @@
                              f"to have a time column")
 
         train, test = self._sampler.sample_target(
-            query=
+            query=query_def,
             num_train_examples=0,
             train_anchor_time=anchor_time,
             num_train_trials=0,
@@ -742,30 +860,34 @@
                              "`predict()` or `evaluate()` methods to perform "
                              "predictions or evaluations.")
 
-
-
-
-
-
+        request = RFMParseQueryRequest(
+            query=query,
+            graph_definition=self._graph_def,
+        )
+
+        for attempt in range(self.num_retries + 1):
+            try:
+                resp = self._api_client.parse_query(request)
+                break
+            except HTTPException as e:
+                if attempt == self.num_retries:
+                    try:
+                        msg = json.loads(e.detail)['detail']
+                    except Exception:
+                        msg = e.detail
+                    raise ValueError(f"Failed to parse query '{query}'. {msg}")
 
-
+                time.sleep(2**attempt)  # 1s, 2s, 4s, 8s, ...
 
-
-
-
-
-
-
-
+        if len(resp.validation_response.warnings) > 0:
+            msg = '\n'.join([
+                f'{i+1}. {warning.title}: {warning.message}'
+                for i, warning in enumerate(resp.validation_response.warnings)
+            ])
+            warnings.warn(f"Encountered the following warnings during "
+                          f"parsing:\n{msg}")
 
-
-        except HTTPException as e:
-            try:
-                msg = json.loads(e.detail)['detail']
-            except Exception:
-                msg = e.detail
-            raise ValueError(f"Failed to parse query '{query}'. "
-                             f"{msg}") from None
+        return resp.query
 
     @staticmethod
     def _get_task_type(
@@ -809,16 +931,15 @@
 
     def _get_default_anchor_time(
         self,
-        query: ValidatedPredictiveQuery,
+        query: ValidatedPredictiveQuery | None = None,
     ) -> pd.Timestamp:
-        if query.query_type == QueryType.TEMPORAL:
+        if query is not None and query.query_type == QueryType.TEMPORAL:
             aggr_table_names = [
                 aggr._get_target_column_name().split('.')[0]
                 for aggr in query.get_all_target_aggregations()
             ]
             return self._sampler.get_max_time(aggr_table_names)
 
-        assert query.query_type == QueryType.STATIC
         return self._sampler.get_max_time()
 
     def _validate_time(
@@ -888,40 +1009,26 @@
             f"Anchor timestamp for evaluation is after the latest "
             f"supported timestamp '{max_time - end_offset}'.")
 
-    def
+    def _get_task_table(
         self,
         query: ValidatedPredictiveQuery,
         indices: list[str] | list[float] | list[int] | None,
-        anchor_time: pd.Timestamp | Literal['entity'] | None,
-        context_anchor_time: pd.Timestamp | None,
-        run_mode: RunMode,
-
-        num_hops: int,
-        max_pq_iterations: int,
-        evaluate: bool,
+        anchor_time: pd.Timestamp | Literal['entity'] | None = None,
+        context_anchor_time: pd.Timestamp | None = None,
+        run_mode: RunMode = RunMode.FAST,
+        max_pq_iterations: int = 10,
         random_seed: int | None = _RANDOM_SEED,
         logger: ProgressLogger | None = None,
-    ) ->
-
-        if num_neighbors is not None:
-            num_hops = len(num_neighbors)
-
-        if num_hops < 0:
-            raise ValueError(f"'num_hops' must be non-negative "
-                             f"(got {num_hops})")
-        if num_hops > 6:
-            raise ValueError(f"Cannot predict on subgraphs with more than 6 "
-                             f"hops (got {num_hops}). Please reduce the "
-                             f"number of hops and try again. Please create a "
-                             f"feature request at "
-                             f"'https://github.com/kumo-ai/kumo-rfm' if you "
-                             f"must go beyond this for your use-case.")
+    ) -> TaskTable:
 
         task_type = self._get_task_type(
             query=query,
            edge_types=self._sampler.edge_types,
         )
 
+        num_train_examples = _MAX_CONTEXT_SIZE[run_mode]
+        num_test_examples = _MAX_TEST_SIZE[task_type] if indices is None else 0
+
         if logger is not None:
             if task_type == TaskType.BINARY_CLASSIFICATION:
                 task_type_repr = 'binary classification'
@@ -935,21 +1042,6 @@
                 task_type_repr = str(task_type)
             logger.log(f"Identified {query.query_type} {task_type_repr} task")
 
-        if task_type.is_link_pred and num_hops < 2:
-            raise ValueError(f"Cannot perform link prediction on subgraphs "
-                             f"with less than 2 hops (got {num_hops}) since "
-                             f"historical target entities need to be part of "
-                             f"the context. Please increase the number of "
-                             f"hops and try again.")
-
-        if num_neighbors is None:
-            if run_mode == RunMode.DEBUG:
-                num_neighbors = [16, 16, 4, 4, 1, 1][:num_hops]
-            elif run_mode == RunMode.FAST or task_type.is_link_pred:
-                num_neighbors = [32, 32, 8, 8, 4, 4][:num_hops]
-            else:
-                num_neighbors = [64, 64, 8, 8, 4, 4][:num_hops]
-
         if query.target_ast.date_offset_range is None:
             step_offset = pd.DateOffset(0)
         else:
@@ -958,8 +1050,7 @@
 
         if anchor_time is None:
             anchor_time = self._get_default_anchor_time(query)
-
-            if evaluate:
+            if num_test_examples > 0:
                 anchor_time = anchor_time - end_offset
 
         if logger is not None:
@@ -973,7 +1064,6 @@
             else:
                 logger.log(f"Derived anchor time {anchor_time}")
 
-        assert anchor_time is not None
         if isinstance(anchor_time, pd.Timestamp):
             if context_anchor_time == 'entity':
                 raise ValueError("Anchor time 'entity' needs to be shared "
@@ -981,7 +1071,7 @@
             if context_anchor_time is None:
                 context_anchor_time = anchor_time - end_offset
             self._validate_time(query, anchor_time, context_anchor_time,
-                                evaluate)
+                                evaluate=num_test_examples > 0)
         else:
             assert anchor_time == 'entity'
             if query.query_type != QueryType.STATIC:
@@ -996,14 +1086,6 @@
                                  "for context and prediction examples")
             context_anchor_time = 'entity'
 
-        num_train_examples = _MAX_CONTEXT_SIZE[run_mode]
-        if evaluate:
-            num_test_examples = _MAX_TEST_SIZE[run_mode]
-            if task_type.is_link_pred:
-                num_test_examples = num_test_examples // 5
-        else:
-            num_test_examples = 0
-
         train, test = self._sampler.sample_target(
             query=query,
             num_train_examples=num_train_examples,
@@ -1014,39 +1096,32 @@
             num_test_trials=max_pq_iterations * num_test_examples,
             random_seed=random_seed,
         )
-        train_pkey, train_time,
-        test_pkey, test_time,
+        train_pkey, train_time, train_y = train
+        test_pkey, test_time, test_y = test
 
-        if
+        if num_test_examples > 0 and logger is not None:
             if task_type == TaskType.BINARY_CLASSIFICATION:
-                pos = 100 * int((
-                msg = (f"Collected {len(
+                pos = 100 * int((test_y > 0).sum()) / len(test_y)
+                msg = (f"Collected {len(test_y):,} test examples with "
                        f"{pos:.2f}% positive cases")
             elif task_type == TaskType.MULTICLASS_CLASSIFICATION:
-                msg = (f"Collected {len(
-                       f"{
+                msg = (f"Collected {len(test_y):,} test examples holding "
+                       f"{test_y.nunique()} classes")
             elif task_type == TaskType.REGRESSION:
-                _min, _max = float(
-                msg = (f"Collected {len(
+                _min, _max = float(test_y.min()), float(test_y.max())
+                msg = (f"Collected {len(test_y):,} test examples with targets "
                        f"between {format_value(_min)} and "
                        f"{format_value(_max)}")
             elif task_type == TaskType.TEMPORAL_LINK_PREDICTION:
-                num_rhs =
-                msg = (f"Collected {len(
+                num_rhs = test_y.explode().nunique()
+                msg = (f"Collected {len(test_y):,} test examples with "
                        f"{num_rhs:,} unique items")
             else:
                 raise NotImplementedError
             logger.log(msg)
 
-        if
+        if num_test_examples == 0:
             assert indices is not None
-            if len(indices) > _MAX_PRED_SIZE[task_type]:
-                raise ValueError(f"Cannot predict for more than "
-                                 f"{_MAX_PRED_SIZE[task_type]:,} entities at "
-                                 f"once (got {len(indices):,}). Use "
-                                 f"`KumoRFM.batch_mode` to process entities "
-                                 f"in batches")
-
             test_pkey = pd.Series(indices, dtype=train_pkey.dtype)
             if isinstance(anchor_time, pd.Timestamp):
                 test_time = pd.Series([anchor_time]).repeat(
@@ -1056,26 +1131,26 @@
 
         if logger is not None:
             if task_type == TaskType.BINARY_CLASSIFICATION:
-                pos = 100 * int((
-                msg = (f"Collected {len(
+                pos = 100 * int((train_y > 0).sum()) / len(train_y)
+                msg = (f"Collected {len(train_y):,} in-context examples with "
                        f"{pos:.2f}% positive cases")
             elif task_type == TaskType.MULTICLASS_CLASSIFICATION:
-                msg = (f"Collected {len(
-                       f"holding {
+                msg = (f"Collected {len(train_y):,} in-context examples "
+                       f"holding {train_y.nunique()} classes")
             elif task_type == TaskType.REGRESSION:
-                _min, _max = float(
-                msg = (f"Collected {len(
+                _min, _max = float(train_y.min()), float(train_y.max())
+                msg = (f"Collected {len(train_y):,} in-context examples with "
                        f"targets between {format_value(_min)} and "
                        f"{format_value(_max)}")
             elif task_type == TaskType.TEMPORAL_LINK_PREDICTION:
-                num_rhs =
-                msg = (f"Collected {len(
+                num_rhs = train_y.explode().nunique()
+                msg = (f"Collected {len(train_y):,} in-context examples with "
                        f"{num_rhs:,} unique items")
             else:
                 raise NotImplementedError
             logger.log(msg)
 
-        entity_table_names: tuple[str,
+        entity_table_names: tuple[str] | tuple[str, str]
         if task_type.is_link_pred:
             final_aggr = query.get_final_target_aggregation()
             assert final_aggr is not None
@@ -1089,27 +1164,98 @@
         else:
            entity_table_names = (query.entity_table, )
 
+        context_df = pd.DataFrame({'ENTITY': train_pkey, 'TARGET': train_y})
+        if isinstance(train_time, pd.Series):
+            context_df['ANCHOR_TIMESTAMP'] = train_time
+        pred_df = pd.DataFrame({'ENTITY': test_pkey})
+        if num_test_examples > 0:
+            pred_df['TARGET'] = test_y
+        if isinstance(test_time, pd.Series):
+            pred_df['ANCHOR_TIMESTAMP'] = test_time
+
+        return TaskTable(
+            task_type=task_type,
+            context_df=context_df,
+            pred_df=pred_df,
+            entity_table_name=entity_table_names,
+            entity_column='ENTITY',
+            target_column='TARGET',
+            time_column='ANCHOR_TIMESTAMP' if isinstance(
+                train_time, pd.Series) else TaskTable.ENTITY_TIME,
+        )
+
+    def _get_context(
+        self,
+        task: TaskTable,
+        run_mode: RunMode | str = RunMode.FAST,
+        num_neighbors: list[int] | None = None,
+        exclude_cols_dict: dict[str, list[str]] | None = None,
+        top_k: int | None = None,
+    ) -> Context:
+
+        # TODO Remove all
+        if task.num_context_examples > max(_MAX_CONTEXT_SIZE.values()):
+            raise ValueError(f"Cannot process a context with more than "
+                             f"{max(_MAX_CONTEXT_SIZE.values()):,} samples "
+                             f"(got {task.num_context_examples:,})")
+        if task.evaluate:
+            if task.num_prediction_examples > _MAX_TEST_SIZE[task.task_type]:
+                raise ValueError(f"Cannot process a test set with more than "
+                                 f"{_MAX_TEST_SIZE[task.task_type]:,} samples "
+                                 f"for evaluation "
+                                 f"(got {task.num_prediction_examples:,})")
+        else:
+            if task.num_prediction_examples > _MAX_PRED_SIZE[task.task_type]:
+                raise ValueError(f"Cannot predict for more than "
+                                 f"{_MAX_PRED_SIZE[task.task_type]:,} "
+                                 f"entities at once "
+                                 f"(got {task.num_prediction_examples:,})")
+
+        if num_neighbors is None:
+            key = RunMode.FAST if task.task_type.is_link_pred else run_mode
+            num_neighbors = _DEFAULT_NUM_NEIGHBORS[key][:2]
+
+        if len(num_neighbors) > 6:
+            raise ValueError(f"Cannot predict on subgraphs with more than 6 "
+                             f"hops (got {len(num_neighbors)}). Reduce the "
+                             f"number of hops and try again. Please create a "
+                             f"feature request at "
+                             f"'https://github.com/kumo-ai/kumo-rfm' if you "
+                             f"must go beyond this for your use-case.")
+
         # Exclude the entity anchor time from the feature set to prevent
         # running out-of-distribution between in-context and test examples:
-        exclude_cols_dict =
-        if
-            if
-                exclude_cols_dict[
-
-            exclude_cols_dict[
+        exclude_cols_dict = exclude_cols_dict or {}
+        if task.entity_table_name in self._sampler.time_column_dict:
+            if task.entity_table_name not in exclude_cols_dict:
+                exclude_cols_dict[task.entity_table_name] = []
+            time_col = self._sampler.time_column_dict[task.entity_table_name]
+            exclude_cols_dict[task.entity_table_name].append(time_col)
+
+        entity_pkey = pd.concat([
+            task._context_df[task._entity_column],
+            task._pred_df[task._entity_column],
+        ], axis=0, ignore_index=True)
+
+        if task.use_entity_time:
+            if task.entity_table_name not in self._sampler.time_column_dict:
+                raise ValueError(f"The given annchor time requires the entity "
+                                 f"table '{task.entity_table_name}' to have a "
+                                 f"time column")
+            anchor_time = 'entity'
+        elif task._time_column is not None:
+            anchor_time = pd.concat([
+                task._context_df[task._time_column],
+                task._pred_df[task._time_column],
+            ], axis=0, ignore_index=True)
+        else:
+            anchor_time = pd.Series(self._get_default_anchor_time()).repeat(
+                (len(entity_pkey))).reset_index(drop=True)
 
         subgraph = self._sampler.sample_subgraph(
-            entity_table_names=entity_table_names,
-            entity_pkey=
-
-                axis=0,
-                ignore_index=True,
-            ),
-            anchor_time=pd.concat(
-                [train_time, test_time],
-                axis=0,
-                ignore_index=True,
-            ) if isinstance(train_time, pd.Series) else 'entity',
+            entity_table_names=task.entity_table_names,
+            entity_pkey=entity_pkey,
+            anchor_time=anchor_time,
             num_neighbors=num_neighbors,
             exclude_cols_dict=exclude_cols_dict,
         )
```
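`_get_context` now derives its sampler inputs from the `TaskTable` alone: entity keys (and per-row anchor times, when a time column is present) are the context rows followed by the prediction rows, concatenated positionally so that row *i* of `entity_pkey` lines up with row *i* of `anchor_time`. A small sketch of that alignment:

```python
# Sketch of the row alignment `_get_context` relies on: context rows first,
# then prediction rows, with anchor times concatenated in the same order.
import pandas as pd

context_df = pd.DataFrame({
    'ENTITY': [1, 2],
    'ANCHOR_TIMESTAMP': pd.to_datetime(['2024-01-01', '2024-01-02']),
})
pred_df = pd.DataFrame({
    'ENTITY': [3],
    'ANCHOR_TIMESTAMP': pd.to_datetime(['2024-02-01']),
})

entity_pkey = pd.concat(
    [context_df['ENTITY'], pred_df['ENTITY']],
    axis=0, ignore_index=True)
anchor_time = pd.concat(
    [context_df['ANCHOR_TIMESTAMP'], pred_df['ANCHOR_TIMESTAMP']],
    axis=0, ignore_index=True)

assert list(entity_pkey) == [1, 2, 3]
assert entity_pkey.index.equals(anchor_time.index)  # positionally aligned
```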
```diff
@@ -1121,13 +1267,20 @@
                              f"'https://github.com/kumo-ai/kumo-rfm' if you "
                              f"must go beyond this for your use-case.")
 
+        if (task.task_type.is_link_pred
+                and task.entity_table_names[-1] not in subgraph.table_dict):
+            raise ValueError("Cannot perform link prediction on subgraphs "
+                             "without any historical target entities. Please "
+                             "increase the number of hops and try again.")
+
         return Context(
-            task_type=task_type,
-            entity_table_names=entity_table_names,
+            task_type=task.task_type,
+            entity_table_names=task.entity_table_names,
             subgraph=subgraph,
-            y_train=
-            y_test=
-
+            y_train=task._context_df[task.target_column.name],
+            y_test=task._pred_df[task.target_column.name]
+            if task.evaluate else None,
+            top_k=top_k,
             step_size=None,
         )
 
```