kumoai 2.14.0.dev202512211732__cp313-cp313-win_amd64.whl → 2.14.0.dev202601081732__cp313-cp313-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kumoai/__init__.py +23 -26
- kumoai/_version.py +1 -1
- kumoai/client/client.py +6 -0
- kumoai/client/jobs.py +26 -0
- kumoai/connector/utils.py +21 -7
- kumoai/experimental/rfm/__init__.py +24 -22
- kumoai/experimental/rfm/backend/local/graph_store.py +12 -21
- kumoai/experimental/rfm/backend/local/sampler.py +0 -3
- kumoai/experimental/rfm/backend/local/table.py +24 -25
- kumoai/experimental/rfm/backend/snow/sampler.py +184 -70
- kumoai/experimental/rfm/backend/snow/table.py +137 -64
- kumoai/experimental/rfm/backend/sqlite/sampler.py +191 -86
- kumoai/experimental/rfm/backend/sqlite/table.py +85 -55
- kumoai/experimental/rfm/base/__init__.py +6 -9
- kumoai/experimental/rfm/base/column.py +95 -11
- kumoai/experimental/rfm/base/expression.py +44 -0
- kumoai/experimental/rfm/base/sampler.py +26 -17
- kumoai/experimental/rfm/base/source.py +1 -1
- kumoai/experimental/rfm/base/sql_sampler.py +182 -19
- kumoai/experimental/rfm/base/table.py +275 -109
- kumoai/experimental/rfm/graph.py +115 -107
- kumoai/experimental/rfm/infer/dtype.py +4 -1
- kumoai/experimental/rfm/infer/multicategorical.py +1 -1
- kumoai/experimental/rfm/relbench.py +76 -0
- kumoai/experimental/rfm/rfm.py +530 -304
- kumoai/experimental/rfm/task_table.py +292 -0
- kumoai/kumolib.cp313-win_amd64.pyd +0 -0
- kumoai/pquery/training_table.py +16 -2
- kumoai/trainer/distilled_trainer.py +175 -0
- kumoai/utils/display.py +87 -0
- kumoai/utils/progress_logger.py +13 -1
- {kumoai-2.14.0.dev202512211732.dist-info → kumoai-2.14.0.dev202601081732.dist-info}/METADATA +1 -1
- {kumoai-2.14.0.dev202512211732.dist-info → kumoai-2.14.0.dev202601081732.dist-info}/RECORD +36 -33
- kumoai/experimental/rfm/base/column_expression.py +0 -50
- kumoai/experimental/rfm/base/sql_table.py +0 -229
- {kumoai-2.14.0.dev202512211732.dist-info → kumoai-2.14.0.dev202601081732.dist-info}/WHEEL +0 -0
- {kumoai-2.14.0.dev202512211732.dist-info → kumoai-2.14.0.dev202601081732.dist-info}/licenses/LICENSE +0 -0
- {kumoai-2.14.0.dev202512211732.dist-info → kumoai-2.14.0.dev202601081732.dist-info}/top_level.txt +0 -0
kumoai/experimental/rfm/rfm.py
CHANGED
```diff
@@ -1,4 +1,5 @@
 import json
+import math
 import time
 import warnings
 from collections import defaultdict
```
```diff
@@ -7,7 +8,6 @@ from contextlib import contextmanager
 from dataclasses import dataclass, replace
 from typing import Any, Literal, overload

-import numpy as np
 import pandas as pd
 from kumoapi.model_plan import RunMode
 from kumoapi.pquery import QueryType, ValidatedPredictiveQuery
```
```diff
@@ -27,31 +27,37 @@ from kumoapi.rfm import (
 )
 from kumoapi.task import TaskType
 from kumoapi.typing import AggregationType, Stype
+from rich.console import Console
+from rich.markdown import Markdown

-from kumoai import in_notebook
+from kumoai import in_notebook
 from kumoai.client.rfm import RFMAPI
 from kumoai.exceptions import HTTPException
-from kumoai.experimental.rfm import Graph
+from kumoai.experimental.rfm import Graph, TaskTable
 from kumoai.experimental.rfm.base import DataBackend, Sampler
 from kumoai.mixin import CastMixin
-from kumoai.utils import ProgressLogger
+from kumoai.utils import ProgressLogger, display

 _RANDOM_SEED = 42

 _MAX_PRED_SIZE: dict[TaskType, int] = defaultdict(lambda: 1_000)
 _MAX_PRED_SIZE[TaskType.TEMPORAL_LINK_PREDICTION] = 200

+_MAX_TEST_SIZE: dict[TaskType, int] = defaultdict(lambda: 2_000)
+_MAX_TEST_SIZE[TaskType.TEMPORAL_LINK_PREDICTION] = 400
+
 _MAX_CONTEXT_SIZE = {
     RunMode.DEBUG: 100,
     RunMode.FAST: 1_000,
     RunMode.NORMAL: 5_000,
     RunMode.BEST: 10_000,
 }
-
-
-    RunMode.
-    RunMode.
-    RunMode.
+
+_DEFAULT_NUM_NEIGHBORS = {
+    RunMode.DEBUG: [16, 16, 4, 4, 1, 1],
+    RunMode.FAST: [32, 32, 8, 8, 4, 4],
+    RunMode.NORMAL: [64, 64, 8, 8, 4, 4],
+    RunMode.BEST: [64, 64, 8, 8, 4, 4],
 }

 _MAX_SIZE = 30 * 1024 * 1024
```
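The fan-out lists and test-size logic that previously lived inline in the prediction methods (see the removed hunks below) are now module-level tables. A minimal, self-contained sketch of how they are consumed, with stand-in string keys in place of the package's `RunMode`/`TaskType` enums:

```python
from collections import defaultdict

# Stand-ins for the package's RunMode/TaskType enum keys:
DEFAULT_NUM_NEIGHBORS = {
    'DEBUG': [16, 16, 4, 4, 1, 1],
    'FAST': [32, 32, 8, 8, 4, 4],
}
MAX_TEST_SIZE = defaultdict(lambda: 2_000)  # keyed by task type in the diff
MAX_TEST_SIZE['TEMPORAL_LINK_PREDICTION'] = 400

num_hops = 2
print(DEFAULT_NUM_NEIGHBORS['FAST'][:num_hops])  # [32, 32]: per-hop fan-outs
print(MAX_TEST_SIZE['REGRESSION'])               # 2000: defaultdict fallback
```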
```diff
@@ -102,25 +108,20 @@ class Explanation:
     def __repr__(self) -> str:
         return str((self.prediction, self.summary))

+    def __str__(self) -> str:
+        console = Console(soft_wrap=True)
+        with console.capture() as cap:
+            console.print(display.to_rich_table(self.prediction))
+            console.print(Markdown(self.summary))
+        return cap.get()[:-1]
+
     def print(self) -> None:
         r"""Prints the explanation."""
-        if
-
-
-            st.markdown(self.summary)
-        elif in_notebook():
-            from IPython.display import Markdown, display
-            try:
-                if hasattr(self.prediction.style, 'hide'):
-                    display(self.prediction.hide(axis='index'))  # pandas=2
-                else:
-                    display(self.prediction.hide_index())  # pandas <1.3
-            except ImportError:
-                print(self.prediction.to_string(index=False))  # missing jinja2
-            display(Markdown(self.summary))
+        if in_notebook():
+            display.dataframe(self.prediction)
+            display.message(self.summary)
         else:
-            print(self
-            print(self.summary)
+            print(self)

     def _ipython_display_(self) -> None:
         self.print()
```
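The new `__str__` renders the prediction table and the Markdown summary through `rich` instead of branching on the runtime environment. A standalone sketch of the capture pattern it relies on, using only the public `rich` API (the package-internal `display.to_rich_table` helper is replaced by a plain string here):

```python
from rich.console import Console
from rich.markdown import Markdown

console = Console(soft_wrap=True)
with console.capture() as cap:  # buffer output instead of writing to stdout
    console.print('prediction table placeholder')
    console.print(Markdown('**summary** text'))
text = cap.get()[:-1]  # drop the trailing newline, as the diff does
print(text)
```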
```diff
@@ -191,7 +192,7 @@ class KumoRFM:
         self._client: RFMAPI | None = None

         self._batch_size: int | Literal['max'] | None = None
-        self.
+        self._num_retries: int = 0

     @property
     def _api_client(self) -> RFMAPI:
```
```diff
@@ -205,6 +206,30 @@ class KumoRFM:
     def __repr__(self) -> str:
         return f'{self.__class__.__name__}()'

+    @contextmanager
+    def retry(
+        self,
+        num_retries: int = 1,
+    ) -> Generator[None, None, None]:
+        """Context manager to retry failed queries due to unexpected server
+        issues.
+
+        .. code-block:: python
+
+            with model.retry(num_retries=1):
+                df = model.predict(query, indices=...)
+
+        Args:
+            num_retries: The maximum number of retries.
+        """
+        if num_retries < 0:
+            raise ValueError(f"'num_retries' must be greater than or equal to "
+                             f"zero (got {num_retries})")
+
+        self._num_retries = num_retries
+        yield
+        self._num_retries = 0
+
     @contextmanager
     def batch_mode(
         self,
```
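The context manager flips `_num_retries` for the duration of the block and resets it afterwards. A self-contained rendering of that pattern with a stand-in class (note that, exactly as in the diff, the reset runs only when the body returns without raising):

```python
from collections.abc import Generator
from contextlib import contextmanager


class Model:  # stand-in for KumoRFM, not the package's class
    def __init__(self) -> None:
        self._num_retries = 0

    @contextmanager
    def retry(self, num_retries: int = 1) -> Generator[None, None, None]:
        if num_retries < 0:
            raise ValueError(f"'num_retries' must be >= 0 (got {num_retries})")
        self._num_retries = num_retries
        yield  # requests issued inside the block may retry
        self._num_retries = 0


model = Model()
with model.retry(num_retries=2):
    assert model._num_retries == 2  # active inside the block
assert model._num_retries == 0      # reset on clean exit
```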
```diff
@@ -228,15 +253,10 @@ class KumoRFM:
             raise ValueError(f"'batch_size' must be greater than zero "
                              f"(got {batch_size})")

-        if num_retries < 0:
-            raise ValueError(f"'num_retries' must be greater than or equal to "
-                             f"zero (got {num_retries})")
-
         self._batch_size = batch_size
-        self.
-
+        with self.retry(self._num_retries or num_retries):
+            yield
         self._batch_size = None
-        self.num_retries = 0

     @overload
     def predict(
```
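`batch_mode` now delegates its retry bookkeeping to `retry()` via `self.retry(self._num_retries or num_retries)`, so an enclosing `retry()` block takes precedence over the `num_retries` argument. Continuing the stand-in `Model` from the sketch above:

```python
model = Model()
with model.retry(num_retries=3):
    # Inside batch_mode, `self._num_retries or num_retries` would
    # evaluate to 3 here, overriding a smaller per-call value:
    print(model._num_retries or 1)  # 3
```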
```diff
@@ -276,6 +296,25 @@ class KumoRFM:
     ) -> Explanation:
         pass

+    @overload
+    def predict(
+        self,
+        query: str,
+        indices: list[str] | list[float] | list[int] | None = None,
+        *,
+        explain: bool | ExplainConfig | dict[str, Any] = False,
+        anchor_time: pd.Timestamp | Literal['entity'] | None = None,
+        context_anchor_time: pd.Timestamp | None = None,
+        run_mode: RunMode | str = RunMode.FAST,
+        num_neighbors: list[int] | None = None,
+        num_hops: int = 2,
+        max_pq_iterations: int = 10,
+        random_seed: int | None = _RANDOM_SEED,
+        verbose: bool | ProgressLogger = True,
+        use_prediction_time: bool = False,
+    ) -> pd.DataFrame | Explanation:
+        pass
+
     def predict(
         self,
         query: str,
```
```diff
@@ -299,8 +338,7 @@ class KumoRFM:
             indices: The entity primary keys to predict for. Will override the
                 indices given as part of the predictive query. Predictions will
                 be generated for all indices, independent of whether they
-                fulfill entity filter constraints.
-                :meth:`~KumoRFM.is_valid_entity`.
+                fulfill entity filter constraints.
             explain: Configuration for explainability.
                 If set to ``True``, will additionally explain the prediction.
                 Passing in an :class:`ExplainConfig` instance provides control
```
```diff
@@ -333,18 +371,152 @@ class KumoRFM:
             If ``explain`` is provided, returns an :class:`Explanation` object
             containing the prediction, summary, and details.
         """
-        explain_config: ExplainConfig | None = None
-        if explain is True:
-            explain_config = ExplainConfig()
-        elif explain is not False:
-            explain_config = ExplainConfig._cast(explain)
-
         query_def = self._parse_query(query)
-        query_str = query_def.to_string()

+        if indices is None:
+            if query_def.rfm_entity_ids is None:
+                raise ValueError("Cannot find entities to predict for. Please "
+                                 "pass them via `predict(query, indices=...)`")
+            indices = query_def.get_rfm_entity_id_list()
+        query_def = replace(
+            query_def,
+            for_each='FOR EACH',
+            rfm_entity_ids=None,
+        )
+
+        if not isinstance(verbose, ProgressLogger):
+            query_repr = query_def.to_string(rich=True, exclude_predict=True)
+            if explain is not False:
+                msg = f'[bold]EXPLAIN[/bold] {query_repr}'
+            else:
+                msg = f'[bold]PREDICT[/bold] {query_repr}'
+            verbose = ProgressLogger.default(msg=msg, verbose=verbose)
+
+        with verbose as logger:
+            task_table = self._get_task_table(
+                query=query_def,
+                indices=indices,
+                anchor_time=anchor_time,
+                context_anchor_time=context_anchor_time,
+                run_mode=run_mode,
+                max_pq_iterations=max_pq_iterations,
+                random_seed=random_seed,
+                logger=logger,
+            )
+            task_table._query = query_def.to_string()
+
+        return self.predict_task(
+            task_table,
+            explain=explain,
+            run_mode=run_mode,
+            num_neighbors=num_neighbors,
+            num_hops=num_hops,
+            verbose=verbose,
+            exclude_cols_dict=query_def.get_exclude_cols_dict(),
+            use_prediction_time=use_prediction_time,
+            top_k=query_def.top_k,
+        )
+
+    @overload
+    def predict_task(
+        self,
+        task: TaskTable,
+        *,
+        explain: Literal[False] = False,
+        run_mode: RunMode | str = RunMode.FAST,
+        num_neighbors: list[int] | None = None,
+        num_hops: int = 2,
+        verbose: bool | ProgressLogger = True,
+        exclude_cols_dict: dict[str, list[str]] | None = None,
+        use_prediction_time: bool = False,
+        top_k: int | None = None,
+    ) -> pd.DataFrame:
+        pass
+
+    @overload
+    def predict_task(
+        self,
+        task: TaskTable,
+        *,
+        explain: Literal[True] | ExplainConfig | dict[str, Any],
+        run_mode: RunMode | str = RunMode.FAST,
+        num_neighbors: list[int] | None = None,
+        num_hops: int = 2,
+        verbose: bool | ProgressLogger = True,
+        exclude_cols_dict: dict[str, list[str]] | None = None,
+        use_prediction_time: bool = False,
+        top_k: int | None = None,
+    ) -> Explanation:
+        pass
+
+    @overload
+    def predict_task(
+        self,
+        task: TaskTable,
+        *,
+        explain: bool | ExplainConfig | dict[str, Any] = False,
+        run_mode: RunMode | str = RunMode.FAST,
+        num_neighbors: list[int] | None = None,
+        num_hops: int = 2,
+        verbose: bool | ProgressLogger = True,
+        exclude_cols_dict: dict[str, list[str]] | None = None,
+        use_prediction_time: bool = False,
+        top_k: int | None = None,
+    ) -> pd.DataFrame | Explanation:
+        pass
+
+    def predict_task(
+        self,
+        task: TaskTable,
+        *,
+        explain: bool | ExplainConfig | dict[str, Any] = False,
+        run_mode: RunMode | str = RunMode.FAST,
+        num_neighbors: list[int] | None = None,
+        num_hops: int = 2,
+        verbose: bool | ProgressLogger = True,
+        exclude_cols_dict: dict[str, list[str]] | None = None,
+        use_prediction_time: bool = False,
+        top_k: int | None = None,
+    ) -> pd.DataFrame | Explanation:
+        """Returns predictions for a custom task specification.
+
+        Args:
+            task: The custom :class:`TaskTable`.
+            explain: Configuration for explainability.
+                If set to ``True``, will additionally explain the prediction.
+                Passing in an :class:`ExplainConfig` instance provides control
+                over which parts of explanation are generated.
+                Explainability is currently only supported for single entity
+                predictions with ``run_mode="FAST"``.
+            run_mode: The :class:`RunMode` for the query.
+            num_neighbors: The number of neighbors to sample for each hop.
+                If specified, the ``num_hops`` option will be ignored.
+            num_hops: The number of hops to sample when generating the context.
+            verbose: Whether to print verbose output.
+            exclude_cols_dict: Any column in any table to exclude from the
+                model input.
+            use_prediction_time: Whether to use the anchor timestamp as an
+                additional feature during prediction. This is typically
+                beneficial for time series forecasting tasks.
+            top_k: The number of predictions to return per entity.
+
+        Returns:
+            The predictions as a :class:`pandas.DataFrame`.
+            If ``explain`` is provided, returns an :class:`Explanation` object
+            containing the prediction, summary, and details.
+        """
         if num_hops != 2 and num_neighbors is not None:
             warnings.warn(f"Received custom 'num_neighbors' option; ignoring "
                           f"custom 'num_hops={num_hops}' option")
+        if num_neighbors is None:
+            key = RunMode.FAST if task.task_type.is_link_pred else run_mode
+            num_neighbors = _DEFAULT_NUM_NEIGHBORS[key][:num_hops]
+
+        explain_config: ExplainConfig | None = None
+        if explain is True:
+            explain_config = ExplainConfig()
+        elif explain is not False:
+            explain_config = ExplainConfig._cast(explain)

         if explain_config is not None and run_mode in {
             RunMode.NORMAL, RunMode.BEST
```
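`predict` is now a thin wrapper: it parses the query into a `TaskTable` via `_get_task_table` and hands off to the new public `predict_task`. A hypothetical direct call; the keyword names are copied from the `return TaskTable(...)` call later in this diff, but the data, the `users` table, and the configured `model` are assumptions, so treat this as a sketch rather than documented usage:

```python
import pandas as pd

context_df = pd.DataFrame({'ENTITY': [1, 2, 3], 'TARGET': [0, 1, 0]})
pred_df = pd.DataFrame({'ENTITY': [4, 5]})

task = TaskTable(
    task_type=TaskType.BINARY_CLASSIFICATION,
    context_df=context_df,
    pred_df=pred_df,
    entity_table_name=('users', ),  # hypothetical entity table
    entity_column='ENTITY',
    target_column='TARGET',
)
df = model.predict_task(task, run_mode='FAST')
```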
```diff
@@ -353,83 +525,82 @@ class KumoRFM:
                           f"run mode 'FAST' (got '{run_mode}'). Provided run "
                           f"mode has been reset. Please lower the run mode to "
                           f"suppress this warning.")
+            run_mode = RunMode.FAST

-        if
-
-
-
-            indices = query_def.get_rfm_entity_id_list()
-        else:
-            query_def = replace(query_def, rfm_entity_ids=None)
-
-        if len(indices) == 0:
-            raise ValueError("At least one entity is required")
-
-        if explain_config is not None and len(indices) > 1:
-            raise ValueError(
-                f"Cannot explain predictions for more than a single entity "
-                f"(got {len(indices)})")
-
-        query_repr = query_def.to_string(rich=True, exclude_predict=True)
-        if explain_config is not None:
-            msg = f'[bold]EXPLAIN[/bold] {query_repr}'
-        else:
-            msg = f'[bold]PREDICT[/bold] {query_repr}'
+        if explain_config is not None and task.num_prediction_examples > 1:
+            raise ValueError(f"Cannot explain predictions for more than a "
+                             f"single entity "
+                             f"(got {task.num_prediction_examples:,})")

         if not isinstance(verbose, ProgressLogger):
+            if task.task_type == TaskType.BINARY_CLASSIFICATION:
+                task_type_repr = 'binary classification'
+            elif task.task_type == TaskType.MULTICLASS_CLASSIFICATION:
+                task_type_repr = 'multi-class classification'
+            elif task.task_type == TaskType.REGRESSION:
+                task_type_repr = 'regression'
+            elif task.task_type == TaskType.TEMPORAL_LINK_PREDICTION:
+                task_type_repr = 'link prediction'
+            else:
+                task_type_repr = str(task.task_type)
+
+            if explain_config is not None:
+                msg = f"Explaining {task_type_repr} task"
+            else:
+                msg = f"Predicting {task_type_repr} task"
             verbose = ProgressLogger.default(msg=msg, verbose=verbose)

         with verbose as logger:
-
-
-
-
-
-
-
-            batch_size =
+            if task.num_context_examples > _MAX_CONTEXT_SIZE[run_mode]:
+                logger.log(f"Sub-sampled {_MAX_CONTEXT_SIZE[run_mode]:,} "
+                           f"out of {task.num_context_examples:,} in-context "
+                           f"examples")
+                task = task.narrow_context(0, _MAX_CONTEXT_SIZE[run_mode])
+
+            if self._batch_size is None:
+                batch_size = task.num_prediction_examples
+            elif self._batch_size == 'max':
+                batch_size = _MAX_PRED_SIZE[task.task_type]
             else:
                 batch_size = self._batch_size

-            if batch_size
-
-
-
-
+            if batch_size > _MAX_PRED_SIZE[task.task_type]:
+                raise ValueError(f"Cannot predict for more than "
+                                 f"{_MAX_PRED_SIZE[task.task_type]:,} "
+                                 f"entities at once (got {batch_size:,}). Use "
+                                 f"`KumoRFM.batch_mode` to process entities "
+                                 f"in batches with a sufficient batch size.")

-            if
-
-
+            if task.num_prediction_examples > batch_size:
+                num = math.ceil(task.num_prediction_examples / batch_size)
+                logger.log(f"Splitting {task.num_prediction_examples:,} "
+                           f"entities into {num:,} batches of size "
+                           f"{batch_size:,}")

             predictions: list[pd.DataFrame] = []
             summary: str | None = None
             details: Explanation | None = None
-            for
-                # TODO Re-use the context for subsequent predictions.
+            for start in range(0, task.num_prediction_examples, batch_size):
                 context = self._get_context(
-
-
-                    anchor_time=anchor_time,
-                    context_anchor_time=context_anchor_time,
-                    run_mode=RunMode(run_mode),
+                    task=task.narrow_prediction(start, length=batch_size),
+                    run_mode=run_mode,
                     num_neighbors=num_neighbors,
-
-
-                    evaluate=False,
-                    random_seed=random_seed,
-                    logger=logger if i == 0 else None,
+                    exclude_cols_dict=exclude_cols_dict,
+                    top_k=top_k,
                 )
+                context.y_test = None
+
                 request = RFMPredictRequest(
                     context=context,
                     run_mode=RunMode(run_mode),
-                    query=
+                    query=task._query,
                     use_prediction_time=use_prediction_time,
                 )
                 with warnings.catch_warnings():
                     warnings.filterwarnings('ignore', message='gencode')
                     request_msg = request.to_protobuf()
                     _bytes = request_msg.SerializeToString()
-                    if
+                    if start == 0:
                         logger.log(f"Generated context of size "
                                    f"{len(_bytes) / (1024*1024):.2f}MB")

```
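The retry loops added to `predict_task`, `evaluate_task`, and `_parse_query` all share the same shape: attempt, `break` on success, surface the error on the final attempt, otherwise back off exponentially. A self-contained rendering of that loop (the server call is faked here):

```python
import time


def call_with_retries(fn, num_retries: int = 2):
    """Retry `fn` with exponential backoff, mirroring the diff's loops."""
    for attempt in range(num_retries + 1):
        try:
            return fn()
        except RuntimeError:
            if attempt == num_retries:
                raise  # out of retries: surface the error
            time.sleep(2**attempt)  # 1s, 2s, 4s, ... between attempts


print(call_with_retries(lambda: 'ok'))
```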
```diff
@@ -437,13 +608,11 @@ class KumoRFM:
                     stats = Context.get_memory_stats(request_msg.context)
                     raise ValueError(_SIZE_LIMIT_MSG.format(stats=stats))

-                if
-
-
-                    description='Predicting',
-                )
+                if start == 0 and task.num_prediction_examples > batch_size:
+                    num = math.ceil(task.num_prediction_examples / batch_size)
+                    verbose.init_progress(total=num, description='Predicting')

-                for attempt in range(self.
+                for attempt in range(self._num_retries + 1):
                     try:
                         if explain_config is not None:
                             resp = self._api_client.explain(
```
```diff
@@ -459,7 +628,7 @@ class KumoRFM:
                         # Cast 'ENTITY' to correct data type:
                         if 'ENTITY' in df:
                             table_dict = context.subgraph.table_dict
-                            table = table_dict[
+                            table = table_dict[context.entity_table_names[0]]
                             ser = table.df[table.primary_key]
                             df['ENTITY'] = df['ENTITY'].astype(ser.dtype)

```
```diff
@@ -476,12 +645,12 @@ class KumoRFM:

                         predictions.append(df)

-                        if
+                        if task.num_prediction_examples > batch_size:
                             verbose.step()

                         break
                     except HTTPException as e:
-                        if attempt == self.
+                        if attempt == self._num_retries:
                             try:
                                 msg = json.loads(e.detail)['detail']
                             except Exception:
```
```diff
@@ -511,51 +680,6 @@ class KumoRFM:

         return prediction

-    def is_valid_entity(
-        self,
-        query: str,
-        indices: list[str] | list[float] | list[int] | None = None,
-        *,
-        anchor_time: pd.Timestamp | Literal['entity'] | None = None,
-    ) -> np.ndarray:
-        r"""Returns a mask that denotes which entities are valid for the
-        given predictive query, *i.e.*, which entities fulfill (temporal)
-        entity filter constraints.
-
-        Args:
-            query: The predictive query.
-            indices: The entity primary keys to predict for. Will override the
-                indices given as part of the predictive query.
-            anchor_time: The anchor timestamp for the prediction. If set to
-                ``None``, will use the maximum timestamp in the data.
-                If set to ``"entity"``, will use the timestamp of the entity.
-        """
-        query_def = self._parse_query(query)
-
-        if indices is None:
-            if query_def.rfm_entity_ids is None:
-                raise ValueError("Cannot find entities to predict for. Please "
-                                 "pass them via "
-                                 "`is_valid_entity(query, indices=...)`")
-            indices = query_def.get_rfm_entity_id_list()
-
-        if len(indices) == 0:
-            raise ValueError("At least one entity is required")
-
-        if anchor_time is None:
-            anchor_time = self._get_default_anchor_time(query_def)
-
-        if isinstance(anchor_time, pd.Timestamp):
-            self._validate_time(query_def, anchor_time, None, False)
-        else:
-            assert anchor_time == 'entity'
-            if query_def.entity_table not in self._sampler.time_column_dict:
-                raise ValueError(f"Anchor time 'entity' requires the entity "
-                                 f"table '{query_def.entity_table}' "
-                                 f"to have a time column.")
-
-        raise NotImplementedError
-
     def evaluate(
         self,
         query: str,
```
```diff
@@ -600,41 +724,120 @@ class KumoRFM:
         Returns:
             The metrics as a :class:`pandas.DataFrame`
         """
-        query_def =
+        query_def = replace(
+            self._parse_query(query),
+            for_each='FOR EACH',
+            rfm_entity_ids=None,
+        )
+
+        if not isinstance(verbose, ProgressLogger):
+            query_repr = query_def.to_string(rich=True, exclude_predict=True)
+            msg = f'[bold]EVALUATE[/bold] {query_repr}'
+            verbose = ProgressLogger.default(msg=msg, verbose=verbose)
+
+        with verbose as logger:
+            task_table = self._get_task_table(
+                query=query_def,
+                indices=None,
+                anchor_time=anchor_time,
+                context_anchor_time=context_anchor_time,
+                run_mode=run_mode,
+                max_pq_iterations=max_pq_iterations,
+                random_seed=random_seed,
+                logger=logger,
+            )
+
+        return self.evaluate_task(
+            task_table,
+            metrics=metrics,
+            run_mode=run_mode,
+            num_neighbors=num_neighbors,
+            num_hops=num_hops,
+            verbose=verbose,
+            exclude_cols_dict=query_def.get_exclude_cols_dict(),
+            use_prediction_time=use_prediction_time,
+        )
+
+    def evaluate_task(
+        self,
+        task: TaskTable,
+        *,
+        metrics: list[str] | None = None,
+        run_mode: RunMode | str = RunMode.FAST,
+        num_neighbors: list[int] | None = None,
+        num_hops: int = 2,
+        verbose: bool | ProgressLogger = True,
+        exclude_cols_dict: dict[str, list[str]] | None = None,
+        use_prediction_time: bool = False,
+    ) -> pd.DataFrame:
+        """Evaluates a custom task specification.

+        Args:
+            task: The custom :class:`TaskTable`.
+            metrics: The metrics to use.
+            run_mode: The :class:`RunMode` for the query.
+            num_neighbors: The number of neighbors to sample for each hop.
+                If specified, the ``num_hops`` option will be ignored.
+            num_hops: The number of hops to sample when generating the context.
+            verbose: Whether to print verbose output.
+            exclude_cols_dict: Any column in any table to exclude from the
+                model input.
+            use_prediction_time: Whether to use the anchor timestamp as an
+                additional feature during prediction. This is typically
+                beneficial for time series forecasting tasks.
+
+        Returns:
+            The metrics as a :class:`pandas.DataFrame`
+        """
         if num_hops != 2 and num_neighbors is not None:
             warnings.warn(f"Received custom 'num_neighbors' option; ignoring "
                           f"custom 'num_hops={num_hops}' option")
+        if num_neighbors is None:
+            key = RunMode.FAST if task.task_type.is_link_pred else run_mode
+            num_neighbors = _DEFAULT_NUM_NEIGHBORS[key][:num_hops]

-        if
-
-
-            rfm_entity_ids=None,
-        )
-
-        query_repr = query_def.to_string(rich=True, exclude_predict=True)
-        msg = f'[bold]EVALUATE[/bold] {query_repr}'
+        if metrics is not None and len(metrics) > 0:
+            self._validate_metrics(metrics, task.task_type)
+            metrics = list(dict.fromkeys(metrics))

         if not isinstance(verbose, ProgressLogger):
+            if task.task_type == TaskType.BINARY_CLASSIFICATION:
+                task_type_repr = 'binary classification'
+            elif task.task_type == TaskType.MULTICLASS_CLASSIFICATION:
+                task_type_repr = 'multi-class classification'
+            elif task.task_type == TaskType.REGRESSION:
+                task_type_repr = 'regression'
+            elif task.task_type == TaskType.TEMPORAL_LINK_PREDICTION:
+                task_type_repr = 'link prediction'
+            else:
+                task_type_repr = str(task.task_type)
+
+            msg = f"Evaluating {task_type_repr} task"
             verbose = ProgressLogger.default(msg=msg, verbose=verbose)

         with verbose as logger:
+            if task.num_context_examples > _MAX_CONTEXT_SIZE[run_mode]:
+                logger.log(f"Sub-sampled {_MAX_CONTEXT_SIZE[run_mode]:,} "
+                           f"out of {task.num_context_examples:,} in-context "
+                           f"examples")
+                task = task.narrow_context(0, _MAX_CONTEXT_SIZE[run_mode])
+
+            if task.num_prediction_examples > _MAX_TEST_SIZE[task.task_type]:
+                logger.log(f"Sub-sampled {_MAX_TEST_SIZE[task.task_type]:,} "
+                           f"out of {task.num_prediction_examples:,} test "
+                           f"examples")
+                task = task.narrow_prediction(
+                    start=0,
+                    length=_MAX_TEST_SIZE[task.task_type],
+                )
+
             context = self._get_context(
-
-
-                anchor_time=anchor_time,
-                context_anchor_time=context_anchor_time,
-                run_mode=RunMode(run_mode),
+                task=task,
+                run_mode=run_mode,
                 num_neighbors=num_neighbors,
-
-                max_pq_iterations=max_pq_iterations,
-                evaluate=True,
-                random_seed=random_seed,
-                logger=logger if verbose else None,
+                exclude_cols_dict=exclude_cols_dict,
             )
-
-            self._validate_metrics(metrics, context.task_type)
-            metrics = list(dict.fromkeys(metrics))
+
             request = RFMEvaluateRequest(
                 context=context,
                 run_mode=RunMode(run_mode),
```
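As with prediction, evaluation is split into a query-driven `evaluate` and a `TaskTable`-driven `evaluate_task`. A hypothetical call mirroring the new signature; `model` and `task` are assumed to exist (see the `TaskTable` sketch above), and the metric names are illustrative only:

```python
metrics_df = model.evaluate_task(
    task,
    metrics=['auroc', 'auprc'],  # illustrative metric names
    run_mode='FAST',
    num_hops=2,
)
print(metrics_df)
```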
```diff
@@ -652,17 +855,23 @@ class KumoRFM:
             stats_msg = Context.get_memory_stats(request_msg.context)
             raise ValueError(_SIZE_LIMIT_MSG.format(stats=stats_msg))

-
-                resp = self._api_client.evaluate(request_bytes)
-            except HTTPException as e:
+        for attempt in range(self._num_retries + 1):
             try:
-
-
-
-
-
-
-
+                resp = self._api_client.evaluate(request_bytes)
+                break
+            except HTTPException as e:
+                if attempt == self._num_retries:
+                    try:
+                        msg = json.loads(e.detail)['detail']
+                    except Exception:
+                        msg = e.detail
+                    raise RuntimeError(
+                        f"An unexpected exception occurred. Please create "
+                        f"an issue at "
+                        f"'https://github.com/kumo-ai/kumo-rfm'. {msg}"
+                    ) from None
+
+                time.sleep(2**attempt)  # 1s, 2s, 4s, 8s, ...

         return pd.DataFrame.from_dict(
             resp.metrics,
```
```diff
@@ -714,7 +923,7 @@ class KumoRFM:
                              f"to have a time column")

         train, test = self._sampler.sample_target(
-            query=
+            query=query_def,
             num_train_examples=0,
             train_anchor_time=anchor_time,
             num_train_trials=0,
```
```diff
@@ -742,30 +951,34 @@ class KumoRFM:
                              "`predict()` or `evaluate()` methods to perform "
                              "predictions or evaluations.")

-
-
-
-
-        )
+        request = RFMParseQueryRequest(
+            query=query,
+            graph_definition=self._graph_def,
+        )

-
+        for attempt in range(self._num_retries + 1):
+            try:
+                resp = self._api_client.parse_query(request)
+                break
+            except HTTPException as e:
+                if attempt == self._num_retries:
+                    try:
+                        msg = json.loads(e.detail)['detail']
+                    except Exception:
+                        msg = e.detail
+                    raise ValueError(f"Failed to parse query '{query}'. {msg}")

-
-        msg = '\n'.join([
-            f'{i+1}. {warning.title}: {warning.message}' for i, warning
-            in enumerate(resp.validation_response.warnings)
-        ])
-        warnings.warn(f"Encountered the following warnings during "
-                      f"parsing:\n{msg}")
+                time.sleep(2**attempt)  # 1s, 2s, 4s, 8s, ...

-
-
-
-
-
-
-
-
+        if len(resp.validation_response.warnings) > 0:
+            msg = '\n'.join([
+                f'{i+1}. {warning.title}: {warning.message}'
+                for i, warning in enumerate(resp.validation_response.warnings)
+            ])
+            warnings.warn(f"Encountered the following warnings during "
+                          f"parsing:\n{msg}")
+
+        return resp.query

     @staticmethod
     def _get_task_type(
```
```diff
@@ -809,16 +1022,15 @@ class KumoRFM:

     def _get_default_anchor_time(
         self,
-        query: ValidatedPredictiveQuery,
+        query: ValidatedPredictiveQuery | None = None,
     ) -> pd.Timestamp:
-        if query.query_type == QueryType.TEMPORAL:
+        if query is not None and query.query_type == QueryType.TEMPORAL:
             aggr_table_names = [
                 aggr._get_target_column_name().split('.')[0]
                 for aggr in query.get_all_target_aggregations()
             ]
             return self._sampler.get_max_time(aggr_table_names)

-        assert query.query_type == QueryType.STATIC
         return self._sampler.get_max_time()

     def _validate_time(
```
```diff
@@ -888,40 +1100,26 @@ class KumoRFM:
                 f"Anchor timestamp for evaluation is after the latest "
                 f"supported timestamp '{max_time - end_offset}'.")

-    def
+    def _get_task_table(
         self,
         query: ValidatedPredictiveQuery,
         indices: list[str] | list[float] | list[int] | None,
-        anchor_time: pd.Timestamp | Literal['entity'] | None,
-        context_anchor_time: pd.Timestamp | None,
-        run_mode: RunMode,
-
-        num_hops: int,
-        max_pq_iterations: int,
-        evaluate: bool,
+        anchor_time: pd.Timestamp | Literal['entity'] | None = None,
+        context_anchor_time: pd.Timestamp | None = None,
+        run_mode: RunMode = RunMode.FAST,
+        max_pq_iterations: int = 10,
         random_seed: int | None = _RANDOM_SEED,
         logger: ProgressLogger | None = None,
-    ) ->
-
-        if num_neighbors is not None:
-            num_hops = len(num_neighbors)
-
-        if num_hops < 0:
-            raise ValueError(f"'num_hops' must be non-negative "
-                             f"(got {num_hops})")
-        if num_hops > 6:
-            raise ValueError(f"Cannot predict on subgraphs with more than 6 "
-                             f"hops (got {num_hops}). Please reduce the "
-                             f"number of hops and try again. Please create a "
-                             f"feature request at "
-                             f"'https://github.com/kumo-ai/kumo-rfm' if you "
-                             f"must go beyond this for your use-case.")
+    ) -> TaskTable:

         task_type = self._get_task_type(
             query=query,
             edge_types=self._sampler.edge_types,
         )
+
+        num_train_examples = _MAX_CONTEXT_SIZE[run_mode]
+        num_test_examples = _MAX_TEST_SIZE[task_type] if indices is None else 0
+
         if logger is not None:
             if task_type == TaskType.BINARY_CLASSIFICATION:
                 task_type_repr = 'binary classification'
```
```diff
@@ -935,21 +1133,6 @@ class KumoRFM:
                 task_type_repr = str(task_type)
             logger.log(f"Identified {query.query_type} {task_type_repr} task")

-        if task_type.is_link_pred and num_hops < 2:
-            raise ValueError(f"Cannot perform link prediction on subgraphs "
-                             f"with less than 2 hops (got {num_hops}) since "
-                             f"historical target entities need to be part of "
-                             f"the context. Please increase the number of "
-                             f"hops and try again.")
-
-        if num_neighbors is None:
-            if run_mode == RunMode.DEBUG:
-                num_neighbors = [16, 16, 4, 4, 1, 1][:num_hops]
-            elif run_mode == RunMode.FAST or task_type.is_link_pred:
-                num_neighbors = [32, 32, 8, 8, 4, 4][:num_hops]
-            else:
-                num_neighbors = [64, 64, 8, 8, 4, 4][:num_hops]
-
         if query.target_ast.date_offset_range is None:
             step_offset = pd.DateOffset(0)
         else:
```
```diff
@@ -958,8 +1141,7 @@ class KumoRFM:

         if anchor_time is None:
             anchor_time = self._get_default_anchor_time(query)
-
-            if evaluate:
+            if num_test_examples > 0:
                 anchor_time = anchor_time - end_offset

         if logger is not None:
```
```diff
@@ -973,7 +1155,6 @@ class KumoRFM:
             else:
                 logger.log(f"Derived anchor time {anchor_time}")

-        assert anchor_time is not None
         if isinstance(anchor_time, pd.Timestamp):
             if context_anchor_time == 'entity':
                 raise ValueError("Anchor time 'entity' needs to be shared "
```
```diff
@@ -981,7 +1162,7 @@ class KumoRFM:
             if context_anchor_time is None:
                 context_anchor_time = anchor_time - end_offset
             self._validate_time(query, anchor_time, context_anchor_time,
-                                evaluate)
+                                evaluate=num_test_examples > 0)
         else:
             assert anchor_time == 'entity'
             if query.query_type != QueryType.STATIC:
```
```diff
@@ -996,14 +1177,6 @@ class KumoRFM:
                                  "for context and prediction examples")
             context_anchor_time = 'entity'

-        num_train_examples = _MAX_CONTEXT_SIZE[run_mode]
-        if evaluate:
-            num_test_examples = _MAX_TEST_SIZE[run_mode]
-            if task_type.is_link_pred:
-                num_test_examples = num_test_examples // 5
-        else:
-            num_test_examples = 0
-
         train, test = self._sampler.sample_target(
             query=query,
             num_train_examples=num_train_examples,
```
```diff
@@ -1014,39 +1187,32 @@ class KumoRFM:
             num_test_trials=max_pq_iterations * num_test_examples,
             random_seed=random_seed,
         )
-        train_pkey, train_time,
-        test_pkey, test_time,
+        train_pkey, train_time, train_y = train
+        test_pkey, test_time, test_y = test

-        if
+        if num_test_examples > 0 and logger is not None:
             if task_type == TaskType.BINARY_CLASSIFICATION:
-                pos = 100 * int((
-                msg = (f"Collected {len(
+                pos = 100 * int((test_y > 0).sum()) / len(test_y)
+                msg = (f"Collected {len(test_y):,} test examples with "
                        f"{pos:.2f}% positive cases")
             elif task_type == TaskType.MULTICLASS_CLASSIFICATION:
-                msg = (f"Collected {len(
-                       f"{
+                msg = (f"Collected {len(test_y):,} test examples holding "
+                       f"{test_y.nunique()} classes")
             elif task_type == TaskType.REGRESSION:
-                _min, _max = float(
-                msg = (f"Collected {len(
+                _min, _max = float(test_y.min()), float(test_y.max())
+                msg = (f"Collected {len(test_y):,} test examples with targets "
                        f"between {format_value(_min)} and "
                        f"{format_value(_max)}")
             elif task_type == TaskType.TEMPORAL_LINK_PREDICTION:
-                num_rhs =
-                msg = (f"Collected {len(
+                num_rhs = test_y.explode().nunique()
+                msg = (f"Collected {len(test_y):,} test examples with "
                        f"{num_rhs:,} unique items")
             else:
                 raise NotImplementedError
             logger.log(msg)

-        if
+        if num_test_examples == 0:
             assert indices is not None
-            if len(indices) > _MAX_PRED_SIZE[task_type]:
-                raise ValueError(f"Cannot predict for more than "
-                                 f"{_MAX_PRED_SIZE[task_type]:,} entities at "
-                                 f"once (got {len(indices):,}). Use "
-                                 f"`KumoRFM.batch_mode` to process entities "
-                                 f"in batches")
-
             test_pkey = pd.Series(indices, dtype=train_pkey.dtype)
             if isinstance(anchor_time, pd.Timestamp):
                 test_time = pd.Series([anchor_time]).repeat(
```
|
|
|
1056
1222
|
|
|
1057
1223
|
if logger is not None:
|
|
1058
1224
|
if task_type == TaskType.BINARY_CLASSIFICATION:
|
|
1059
|
-
pos = 100 * int((
|
|
1060
|
-
msg = (f"Collected {len(
|
|
1225
|
+
pos = 100 * int((train_y > 0).sum()) / len(train_y)
|
|
1226
|
+
msg = (f"Collected {len(train_y):,} in-context examples with "
|
|
1061
1227
|
f"{pos:.2f}% positive cases")
|
|
1062
1228
|
elif task_type == TaskType.MULTICLASS_CLASSIFICATION:
|
|
1063
|
-
msg = (f"Collected {len(
|
|
1064
|
-
f"holding {
|
|
1229
|
+
msg = (f"Collected {len(train_y):,} in-context examples "
|
|
1230
|
+
f"holding {train_y.nunique()} classes")
|
|
1065
1231
|
elif task_type == TaskType.REGRESSION:
|
|
1066
|
-
_min, _max = float(
|
|
1067
|
-
msg = (f"Collected {len(
|
|
1232
|
+
_min, _max = float(train_y.min()), float(train_y.max())
|
|
1233
|
+
msg = (f"Collected {len(train_y):,} in-context examples with "
|
|
1068
1234
|
f"targets between {format_value(_min)} and "
|
|
1069
1235
|
f"{format_value(_max)}")
|
|
1070
1236
|
elif task_type == TaskType.TEMPORAL_LINK_PREDICTION:
|
|
1071
|
-
num_rhs =
|
|
1072
|
-
msg = (f"Collected {len(
|
|
1237
|
+
num_rhs = train_y.explode().nunique()
|
|
1238
|
+
msg = (f"Collected {len(train_y):,} in-context examples with "
|
|
1073
1239
|
f"{num_rhs:,} unique items")
|
|
1074
1240
|
else:
|
|
1075
1241
|
raise NotImplementedError
|
|
1076
1242
|
logger.log(msg)
|
|
1077
1243
|
|
|
1078
|
-
entity_table_names: tuple[str,
|
|
1244
|
+
entity_table_names: tuple[str] | tuple[str, str]
|
|
1079
1245
|
if task_type.is_link_pred:
|
|
1080
1246
|
final_aggr = query.get_final_target_aggregation()
|
|
1081
1247
|
assert final_aggr is not None
|
|
```diff
@@ -1089,27 +1255,80 @@ class KumoRFM:
         else:
             entity_table_names = (query.entity_table, )

+        context_df = pd.DataFrame({'ENTITY': train_pkey, 'TARGET': train_y})
+        if isinstance(train_time, pd.Series):
+            context_df['ANCHOR_TIMESTAMP'] = train_time
+        pred_df = pd.DataFrame({'ENTITY': test_pkey})
+        if num_test_examples > 0:
+            pred_df['TARGET'] = test_y
+        if isinstance(test_time, pd.Series):
+            pred_df['ANCHOR_TIMESTAMP'] = test_time
+
+        return TaskTable(
+            task_type=task_type,
+            context_df=context_df,
+            pred_df=pred_df,
+            entity_table_name=entity_table_names,
+            entity_column='ENTITY',
+            target_column='TARGET',
+            time_column='ANCHOR_TIMESTAMP' if isinstance(
+                train_time, pd.Series) else TaskTable.ENTITY_TIME,
+        )
+
+    def _get_context(
+        self,
+        task: TaskTable,
+        run_mode: RunMode | str = RunMode.FAST,
+        num_neighbors: list[int] | None = None,
+        exclude_cols_dict: dict[str, list[str]] | None = None,
+        top_k: int | None = None,
+    ) -> Context:
+
+        if num_neighbors is None:
+            key = RunMode.FAST if task.task_type.is_link_pred else run_mode
+            num_neighbors = _DEFAULT_NUM_NEIGHBORS[key][:2]
+
+        if len(num_neighbors) > 6:
+            raise ValueError(f"Cannot predict on subgraphs with more than 6 "
+                             f"hops (got {len(num_neighbors)}). Reduce the "
+                             f"number of hops and try again. Please create a "
+                             f"feature request at "
+                             f"'https://github.com/kumo-ai/kumo-rfm' if you "
+                             f"must go beyond this for your use-case.")
+
         # Exclude the entity anchor time from the feature set to prevent
         # running out-of-distribution between in-context and test examples:
-        exclude_cols_dict =
-        if
-        if
-            exclude_cols_dict[
-
-            exclude_cols_dict[
+        exclude_cols_dict = exclude_cols_dict or {}
+        if task.entity_table_name in self._sampler.time_column_dict:
+            if task.entity_table_name not in exclude_cols_dict:
+                exclude_cols_dict[task.entity_table_name] = []
+            time_col = self._sampler.time_column_dict[task.entity_table_name]
+            exclude_cols_dict[task.entity_table_name].append(time_col)
+
+        entity_pkey = pd.concat([
+            task._context_df[task._entity_column],
+            task._pred_df[task._entity_column],
+        ], axis=0, ignore_index=True)
+
+        if task.use_entity_time:
+            if task.entity_table_name not in self._sampler.time_column_dict:
+                raise ValueError(f"The given annchor time requires the entity "
+                                 f"table '{task.entity_table_name}' to have a "
+                                 f"time column")
+            anchor_time = 'entity'
+        elif task._time_column is not None:
+            anchor_time = pd.concat([
+                task._context_df[task._time_column],
+                task._pred_df[task._time_column],
+            ], axis=0, ignore_index=True)
+        else:
+            anchor_time = pd.Series(self._get_default_anchor_time()).repeat(
+                (len(entity_pkey))).reset_index(drop=True)

         subgraph = self._sampler.sample_subgraph(
-            entity_table_names=entity_table_names,
-            entity_pkey=
-
-                axis=0,
-                ignore_index=True,
-            ),
-            anchor_time=pd.concat(
-                [train_time, test_time],
-                axis=0,
-                ignore_index=True,
-            ) if isinstance(train_time, pd.Series) else 'entity',
+            entity_table_names=task.entity_table_names,
+            entity_pkey=entity_pkey,
+            anchor_time=anchor_time,
             num_neighbors=num_neighbors,
             exclude_cols_dict=exclude_cols_dict,
         )
```
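`_get_task_table` now packages the sampled targets into two plain data frames, with the anchor time carried as an optional column. A self-contained illustration of that layout (column names copied from the diff, values invented):

```python
import pandas as pd

train_pkey = pd.Series([1, 2, 3])
train_y = pd.Series([0, 1, 0])
train_time = pd.Series(pd.to_datetime(['2024-01-01'] * 3))

# In-context examples: entity key, target, and optional anchor time:
context_df = pd.DataFrame({'ENTITY': train_pkey, 'TARGET': train_y})
context_df['ANCHOR_TIMESTAMP'] = train_time

# Prediction examples: a TARGET column is attached only when evaluating:
pred_df = pd.DataFrame({'ENTITY': pd.Series([4, 5])})

print(context_df.columns.tolist())  # ['ENTITY', 'TARGET', 'ANCHOR_TIMESTAMP']
```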
```diff
@@ -1121,13 +1340,20 @@ class KumoRFM:
                              f"'https://github.com/kumo-ai/kumo-rfm' if you "
                              f"must go beyond this for your use-case.")

+        if (task.task_type.is_link_pred
+                and task.entity_table_names[-1] not in subgraph.table_dict):
+            raise ValueError("Cannot perform link prediction on subgraphs "
+                             "without any historical target entities. Please "
+                             "increase the number of hops and try again.")
+
         return Context(
-            task_type=task_type,
-            entity_table_names=entity_table_names,
+            task_type=task.task_type,
+            entity_table_names=task.entity_table_names,
             subgraph=subgraph,
-            y_train=
-            y_test=
-
+            y_train=task._context_df[task.target_column.name],
+            y_test=task._pred_df[task.target_column.name]
+            if task.evaluate else None,
+            top_k=top_k,
             step_size=None,
         )
```