kumoai 2.13.0.dev202512031731__cp312-cp312-macosx_11_0_arm64.whl → 2.14.0.dev202512301731__cp312-cp312-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kumoai/__init__.py +35 -26
- kumoai/_version.py +1 -1
- kumoai/client/client.py +6 -0
- kumoai/client/jobs.py +24 -0
- kumoai/client/pquery.py +6 -2
- kumoai/experimental/rfm/__init__.py +49 -24
- kumoai/experimental/rfm/authenticate.py +3 -4
- kumoai/experimental/rfm/backend/local/__init__.py +4 -0
- kumoai/experimental/rfm/{local_graph_store.py → backend/local/graph_store.py} +62 -110
- kumoai/experimental/rfm/backend/local/sampler.py +312 -0
- kumoai/experimental/rfm/backend/local/table.py +32 -14
- kumoai/experimental/rfm/backend/snow/__init__.py +2 -0
- kumoai/experimental/rfm/backend/snow/sampler.py +297 -0
- kumoai/experimental/rfm/backend/snow/table.py +186 -39
- kumoai/experimental/rfm/backend/sqlite/__init__.py +4 -2
- kumoai/experimental/rfm/backend/sqlite/sampler.py +398 -0
- kumoai/experimental/rfm/backend/sqlite/table.py +131 -41
- kumoai/experimental/rfm/base/__init__.py +23 -3
- kumoai/experimental/rfm/base/column.py +96 -10
- kumoai/experimental/rfm/base/expression.py +44 -0
- kumoai/experimental/rfm/base/sampler.py +761 -0
- kumoai/experimental/rfm/base/source.py +2 -1
- kumoai/experimental/rfm/base/sql_sampler.py +143 -0
- kumoai/experimental/rfm/base/table.py +380 -185
- kumoai/experimental/rfm/graph.py +404 -144
- kumoai/experimental/rfm/infer/__init__.py +6 -4
- kumoai/experimental/rfm/infer/dtype.py +52 -60
- kumoai/experimental/rfm/infer/multicategorical.py +1 -1
- kumoai/experimental/rfm/infer/pkey.py +4 -2
- kumoai/experimental/rfm/infer/stype.py +35 -0
- kumoai/experimental/rfm/infer/time_col.py +1 -2
- kumoai/experimental/rfm/pquery/executor.py +27 -27
- kumoai/experimental/rfm/pquery/pandas_executor.py +30 -32
- kumoai/experimental/rfm/relbench.py +76 -0
- kumoai/experimental/rfm/rfm.py +283 -230
- kumoai/experimental/rfm/sagemaker.py +4 -4
- kumoai/pquery/predictive_query.py +10 -6
- kumoai/testing/snow.py +50 -0
- kumoai/trainer/distilled_trainer.py +175 -0
- kumoai/utils/__init__.py +3 -2
- kumoai/utils/display.py +51 -0
- kumoai/utils/progress_logger.py +178 -12
- kumoai/utils/sql.py +3 -0
- {kumoai-2.13.0.dev202512031731.dist-info → kumoai-2.14.0.dev202512301731.dist-info}/METADATA +4 -2
- {kumoai-2.13.0.dev202512031731.dist-info → kumoai-2.14.0.dev202512301731.dist-info}/RECORD +48 -38
- kumoai/experimental/rfm/local_graph_sampler.py +0 -223
- kumoai/experimental/rfm/local_pquery_driver.py +0 -689
- {kumoai-2.13.0.dev202512031731.dist-info → kumoai-2.14.0.dev202512301731.dist-info}/WHEEL +0 -0
- {kumoai-2.13.0.dev202512031731.dist-info → kumoai-2.14.0.dev202512301731.dist-info}/licenses/LICENSE +0 -0
- {kumoai-2.13.0.dev202512031731.dist-info → kumoai-2.14.0.dev202512301731.dist-info}/top_level.txt +0 -0
kumoai/experimental/rfm/rfm.py
CHANGED
|
@@ -2,25 +2,22 @@ import json
|
|
|
2
2
|
import time
|
|
3
3
|
import warnings
|
|
4
4
|
from collections import defaultdict
|
|
5
|
-
from collections.abc import Generator
|
|
5
|
+
from collections.abc import Generator, Iterator
|
|
6
6
|
from contextlib import contextmanager
|
|
7
7
|
from dataclasses import dataclass, replace
|
|
8
|
-
from typing import
|
|
9
|
-
Any,
|
|
10
|
-
Dict,
|
|
11
|
-
Iterator,
|
|
12
|
-
List,
|
|
13
|
-
Literal,
|
|
14
|
-
Optional,
|
|
15
|
-
Tuple,
|
|
16
|
-
Union,
|
|
17
|
-
overload,
|
|
18
|
-
)
|
|
8
|
+
from typing import Any, Literal, overload
|
|
19
9
|
|
|
20
10
|
import numpy as np
|
|
21
11
|
import pandas as pd
|
|
22
12
|
from kumoapi.model_plan import RunMode
|
|
23
13
|
from kumoapi.pquery import QueryType, ValidatedPredictiveQuery
|
|
14
|
+
from kumoapi.pquery.AST import (
|
|
15
|
+
Aggregation,
|
|
16
|
+
Column,
|
|
17
|
+
Condition,
|
|
18
|
+
Join,
|
|
19
|
+
LogicalOperation,
|
|
20
|
+
)
|
|
24
21
|
from kumoapi.rfm import Context
|
|
25
22
|
from kumoapi.rfm import Explanation as ExplanationConfig
|
|
26
23
|
from kumoapi.rfm import (
|
|
@@ -29,18 +26,14 @@ from kumoapi.rfm import (
|
|
|
29
26
|
RFMPredictRequest,
|
|
30
27
|
)
|
|
31
28
|
from kumoapi.task import TaskType
|
|
29
|
+
from kumoapi.typing import AggregationType, Stype
|
|
32
30
|
|
|
33
31
|
from kumoai.client.rfm import RFMAPI
|
|
34
32
|
from kumoai.exceptions import HTTPException
|
|
35
33
|
from kumoai.experimental.rfm import Graph
|
|
36
|
-
from kumoai.experimental.rfm.
|
|
37
|
-
from kumoai.experimental.rfm.local_graph_store import LocalGraphStore
|
|
38
|
-
from kumoai.experimental.rfm.local_pquery_driver import (
|
|
39
|
-
LocalPQueryDriver,
|
|
40
|
-
date_offset_to_seconds,
|
|
41
|
-
)
|
|
34
|
+
from kumoai.experimental.rfm.base import DataBackend, Sampler
|
|
42
35
|
from kumoai.mixin import CastMixin
|
|
43
|
-
from kumoai.utils import
|
|
36
|
+
from kumoai.utils import ProgressLogger, display
|
|
44
37
|
|
|
45
38
|
_RANDOM_SEED = 42
|
|
46
39
|
|
|
@@ -95,24 +88,26 @@ class Explanation:
|
|
|
95
88
|
def __getitem__(self, index: Literal[1]) -> str:
|
|
96
89
|
pass
|
|
97
90
|
|
|
98
|
-
def __getitem__(self, index: int) ->
|
|
91
|
+
def __getitem__(self, index: int) -> pd.DataFrame | str:
|
|
99
92
|
if index == 0:
|
|
100
93
|
return self.prediction
|
|
101
94
|
if index == 1:
|
|
102
95
|
return self.summary
|
|
103
96
|
raise IndexError("Index out of range")
|
|
104
97
|
|
|
105
|
-
def __iter__(self) -> Iterator[
|
|
98
|
+
def __iter__(self) -> Iterator[pd.DataFrame | str]:
|
|
106
99
|
return iter((self.prediction, self.summary))
|
|
107
100
|
|
|
108
101
|
def __repr__(self) -> str:
|
|
109
102
|
return str((self.prediction, self.summary))
|
|
110
103
|
|
|
111
|
-
def
|
|
112
|
-
|
|
104
|
+
def print(self) -> None:
|
|
105
|
+
r"""Prints the explanation."""
|
|
106
|
+
display.dataframe(self.prediction)
|
|
107
|
+
display.message(self.summary)
|
|
113
108
|
|
|
114
|
-
|
|
115
|
-
|
|
109
|
+
def _ipython_display_(self) -> None:
|
|
110
|
+
self.print()
|
|
116
111
|
|
|
117
112
|
|
|
118
113
|
class KumoRFM:
|
|
@@ -151,20 +146,35 @@ class KumoRFM:
|
|
|
151
146
|
Args:
|
|
152
147
|
graph: The graph.
|
|
153
148
|
verbose: Whether to print verbose output.
|
|
149
|
+
optimize: If set to ``True``, will optimize the underlying data backend
|
|
150
|
+
for optimal querying. For example, for transactional database
|
|
151
|
+
backends, will create any missing indices. Requires write-access to
|
|
152
|
+
the data backend.
|
|
154
153
|
"""
|
|
155
154
|
def __init__(
|
|
156
155
|
self,
|
|
157
156
|
graph: Graph,
|
|
158
|
-
verbose:
|
|
157
|
+
verbose: bool | ProgressLogger = True,
|
|
158
|
+
optimize: bool = False,
|
|
159
159
|
) -> None:
|
|
160
160
|
graph = graph.validate()
|
|
161
161
|
self._graph_def = graph._to_api_graph_definition()
|
|
162
|
-
self._graph_store = LocalGraphStore(graph, verbose)
|
|
163
|
-
self._graph_sampler = LocalGraphSampler(self._graph_store)
|
|
164
162
|
|
|
165
|
-
|
|
163
|
+
if graph.backend == DataBackend.LOCAL:
|
|
164
|
+
from kumoai.experimental.rfm.backend.local import LocalSampler
|
|
165
|
+
self._sampler: Sampler = LocalSampler(graph, verbose)
|
|
166
|
+
elif graph.backend == DataBackend.SQLITE:
|
|
167
|
+
from kumoai.experimental.rfm.backend.sqlite import SQLiteSampler
|
|
168
|
+
self._sampler = SQLiteSampler(graph, verbose, optimize)
|
|
169
|
+
elif graph.backend == DataBackend.SNOWFLAKE:
|
|
170
|
+
from kumoai.experimental.rfm.backend.snow import SnowSampler
|
|
171
|
+
self._sampler = SnowSampler(graph, verbose)
|
|
172
|
+
else:
|
|
173
|
+
raise NotImplementedError
|
|
166
174
|
|
|
167
|
-
self.
|
|
175
|
+
self._client: RFMAPI | None = None
|
|
176
|
+
|
|
177
|
+
self._batch_size: int | Literal['max'] | None = None
|
|
168
178
|
self.num_retries: int = 0
|
|
169
179
|
|
|
170
180
|
@property
|
|
@@ -182,7 +192,7 @@ class KumoRFM:
|
|
|
182
192
|
@contextmanager
|
|
183
193
|
def batch_mode(
|
|
184
194
|
self,
|
|
185
|
-
batch_size:
|
|
195
|
+
batch_size: int | Literal['max'] = 'max',
|
|
186
196
|
num_retries: int = 1,
|
|
187
197
|
) -> Generator[None, None, None]:
|
|
188
198
|
"""Context manager to predict in batches.
|
|
@@ -216,17 +226,17 @@ class KumoRFM:
|
|
|
216
226
|
def predict(
|
|
217
227
|
self,
|
|
218
228
|
query: str,
|
|
219
|
-
indices:
|
|
229
|
+
indices: list[str] | list[float] | list[int] | None = None,
|
|
220
230
|
*,
|
|
221
231
|
explain: Literal[False] = False,
|
|
222
|
-
anchor_time:
|
|
223
|
-
context_anchor_time:
|
|
224
|
-
run_mode:
|
|
225
|
-
num_neighbors:
|
|
232
|
+
anchor_time: pd.Timestamp | Literal['entity'] | None = None,
|
|
233
|
+
context_anchor_time: pd.Timestamp | None = None,
|
|
234
|
+
run_mode: RunMode | str = RunMode.FAST,
|
|
235
|
+
num_neighbors: list[int] | None = None,
|
|
226
236
|
num_hops: int = 2,
|
|
227
|
-
max_pq_iterations: int =
|
|
228
|
-
random_seed:
|
|
229
|
-
verbose:
|
|
237
|
+
max_pq_iterations: int = 10,
|
|
238
|
+
random_seed: int | None = _RANDOM_SEED,
|
|
239
|
+
verbose: bool | ProgressLogger = True,
|
|
230
240
|
use_prediction_time: bool = False,
|
|
231
241
|
) -> pd.DataFrame:
|
|
232
242
|
pass
|
|
@@ -235,17 +245,17 @@ class KumoRFM:
|
|
|
235
245
|
def predict(
|
|
236
246
|
self,
|
|
237
247
|
query: str,
|
|
238
|
-
indices:
|
|
248
|
+
indices: list[str] | list[float] | list[int] | None = None,
|
|
239
249
|
*,
|
|
240
|
-
explain:
|
|
241
|
-
anchor_time:
|
|
242
|
-
context_anchor_time:
|
|
243
|
-
run_mode:
|
|
244
|
-
num_neighbors:
|
|
250
|
+
explain: Literal[True] | ExplainConfig | dict[str, Any],
|
|
251
|
+
anchor_time: pd.Timestamp | Literal['entity'] | None = None,
|
|
252
|
+
context_anchor_time: pd.Timestamp | None = None,
|
|
253
|
+
run_mode: RunMode | str = RunMode.FAST,
|
|
254
|
+
num_neighbors: list[int] | None = None,
|
|
245
255
|
num_hops: int = 2,
|
|
246
|
-
max_pq_iterations: int =
|
|
247
|
-
random_seed:
|
|
248
|
-
verbose:
|
|
256
|
+
max_pq_iterations: int = 10,
|
|
257
|
+
random_seed: int | None = _RANDOM_SEED,
|
|
258
|
+
verbose: bool | ProgressLogger = True,
|
|
249
259
|
use_prediction_time: bool = False,
|
|
250
260
|
) -> Explanation:
|
|
251
261
|
pass
|
|
@@ -253,19 +263,19 @@ class KumoRFM:
|
|
|
253
263
|
def predict(
|
|
254
264
|
self,
|
|
255
265
|
query: str,
|
|
256
|
-
indices:
|
|
266
|
+
indices: list[str] | list[float] | list[int] | None = None,
|
|
257
267
|
*,
|
|
258
|
-
explain:
|
|
259
|
-
anchor_time:
|
|
260
|
-
context_anchor_time:
|
|
261
|
-
run_mode:
|
|
262
|
-
num_neighbors:
|
|
268
|
+
explain: bool | ExplainConfig | dict[str, Any] = False,
|
|
269
|
+
anchor_time: pd.Timestamp | Literal['entity'] | None = None,
|
|
270
|
+
context_anchor_time: pd.Timestamp | None = None,
|
|
271
|
+
run_mode: RunMode | str = RunMode.FAST,
|
|
272
|
+
num_neighbors: list[int] | None = None,
|
|
263
273
|
num_hops: int = 2,
|
|
264
|
-
max_pq_iterations: int =
|
|
265
|
-
random_seed:
|
|
266
|
-
verbose:
|
|
274
|
+
max_pq_iterations: int = 10,
|
|
275
|
+
random_seed: int | None = _RANDOM_SEED,
|
|
276
|
+
verbose: bool | ProgressLogger = True,
|
|
267
277
|
use_prediction_time: bool = False,
|
|
268
|
-
) ->
|
|
278
|
+
) -> pd.DataFrame | Explanation:
|
|
269
279
|
"""Returns predictions for a predictive query.
|
|
270
280
|
|
|
271
281
|
Args:
|
|
@@ -307,7 +317,7 @@ class KumoRFM:
|
|
|
307
317
|
If ``explain`` is provided, returns an :class:`Explanation` object
|
|
308
318
|
containing the prediction, summary, and details.
|
|
309
319
|
"""
|
|
310
|
-
explain_config:
|
|
320
|
+
explain_config: ExplainConfig | None = None
|
|
311
321
|
if explain is True:
|
|
312
322
|
explain_config = ExplainConfig()
|
|
313
323
|
elif explain is not False:
|
|
@@ -351,15 +361,15 @@ class KumoRFM:
|
|
|
351
361
|
msg = f'[bold]PREDICT[/bold] {query_repr}'
|
|
352
362
|
|
|
353
363
|
if not isinstance(verbose, ProgressLogger):
|
|
354
|
-
verbose =
|
|
364
|
+
verbose = ProgressLogger.default(msg=msg, verbose=verbose)
|
|
355
365
|
|
|
356
366
|
with verbose as logger:
|
|
357
367
|
|
|
358
|
-
batch_size:
|
|
368
|
+
batch_size: int | None = None
|
|
359
369
|
if self._batch_size == 'max':
|
|
360
|
-
task_type =
|
|
361
|
-
query_def,
|
|
362
|
-
edge_types=self.
|
|
370
|
+
task_type = self._get_task_type(
|
|
371
|
+
query=query_def,
|
|
372
|
+
edge_types=self._sampler.edge_types,
|
|
363
373
|
)
|
|
364
374
|
batch_size = _MAX_PRED_SIZE[task_type]
|
|
365
375
|
else:
|
|
@@ -375,9 +385,9 @@ class KumoRFM:
|
|
|
375
385
|
logger.log(f"Splitting {len(indices):,} entities into "
|
|
376
386
|
f"{len(batches):,} batches of size {batch_size:,}")
|
|
377
387
|
|
|
378
|
-
predictions:
|
|
379
|
-
summary:
|
|
380
|
-
details:
|
|
388
|
+
predictions: list[pd.DataFrame] = []
|
|
389
|
+
summary: str | None = None
|
|
390
|
+
details: Explanation | None = None
|
|
381
391
|
for i, batch in enumerate(batches):
|
|
382
392
|
# TODO Re-use the context for subsequent predictions.
|
|
383
393
|
context = self._get_context(
|
|
@@ -411,8 +421,7 @@ class KumoRFM:
|
|
|
411
421
|
stats = Context.get_memory_stats(request_msg.context)
|
|
412
422
|
raise ValueError(_SIZE_LIMIT_MSG.format(stats=stats))
|
|
413
423
|
|
|
414
|
-
if
|
|
415
|
-
and len(batches) > 1):
|
|
424
|
+
if i == 0 and len(batches) > 1:
|
|
416
425
|
verbose.init_progress(
|
|
417
426
|
total=len(batches),
|
|
418
427
|
description='Predicting',
|
|
@@ -433,10 +442,10 @@ class KumoRFM:
|
|
|
433
442
|
|
|
434
443
|
# Cast 'ENTITY' to correct data type:
|
|
435
444
|
if 'ENTITY' in df:
|
|
436
|
-
|
|
437
|
-
|
|
438
|
-
|
|
439
|
-
|
|
445
|
+
table_dict = context.subgraph.table_dict
|
|
446
|
+
table = table_dict[query_def.entity_table]
|
|
447
|
+
ser = table.df[table.primary_key]
|
|
448
|
+
df['ENTITY'] = df['ENTITY'].astype(ser.dtype)
|
|
440
449
|
|
|
441
450
|
# Cast 'ANCHOR_TIMESTAMP' to correct data type:
|
|
442
451
|
if 'ANCHOR_TIMESTAMP' in df:
|
|
@@ -451,8 +460,7 @@ class KumoRFM:
|
|
|
451
460
|
|
|
452
461
|
predictions.append(df)
|
|
453
462
|
|
|
454
|
-
if (
|
|
455
|
-
and len(batches) > 1):
|
|
463
|
+
if len(batches) > 1:
|
|
456
464
|
verbose.step()
|
|
457
465
|
|
|
458
466
|
break
|
|
@@ -490,9 +498,9 @@ class KumoRFM:
|
|
|
490
498
|
def is_valid_entity(
|
|
491
499
|
self,
|
|
492
500
|
query: str,
|
|
493
|
-
indices:
|
|
501
|
+
indices: list[str] | list[float] | list[int] | None = None,
|
|
494
502
|
*,
|
|
495
|
-
anchor_time:
|
|
503
|
+
anchor_time: pd.Timestamp | Literal['entity'] | None = None,
|
|
496
504
|
) -> np.ndarray:
|
|
497
505
|
r"""Returns a mask that denotes which entities are valid for the
|
|
498
506
|
given predictive query, *i.e.*, which entities fulfill (temporal)
|
|
@@ -519,37 +527,32 @@ class KumoRFM:
|
|
|
519
527
|
raise ValueError("At least one entity is required")
|
|
520
528
|
|
|
521
529
|
if anchor_time is None:
|
|
522
|
-
anchor_time = self.
|
|
530
|
+
anchor_time = self._get_default_anchor_time(query_def)
|
|
523
531
|
|
|
524
532
|
if isinstance(anchor_time, pd.Timestamp):
|
|
525
533
|
self._validate_time(query_def, anchor_time, None, False)
|
|
526
534
|
else:
|
|
527
535
|
assert anchor_time == 'entity'
|
|
528
|
-
if
|
|
536
|
+
if query_def.entity_table not in self._sampler.time_column_dict:
|
|
529
537
|
raise ValueError(f"Anchor time 'entity' requires the entity "
|
|
530
538
|
f"table '{query_def.entity_table}' "
|
|
531
539
|
f"to have a time column.")
|
|
532
540
|
|
|
533
|
-
|
|
534
|
-
table_name=query_def.entity_table,
|
|
535
|
-
pkey=pd.Series(indices),
|
|
536
|
-
)
|
|
537
|
-
query_driver = LocalPQueryDriver(self._graph_store, query_def)
|
|
538
|
-
return query_driver.is_valid(node, anchor_time)
|
|
541
|
+
raise NotImplementedError
|
|
539
542
|
|
|
540
543
|
def evaluate(
|
|
541
544
|
self,
|
|
542
545
|
query: str,
|
|
543
546
|
*,
|
|
544
|
-
metrics:
|
|
545
|
-
anchor_time:
|
|
546
|
-
context_anchor_time:
|
|
547
|
-
run_mode:
|
|
548
|
-
num_neighbors:
|
|
547
|
+
metrics: list[str] | None = None,
|
|
548
|
+
anchor_time: pd.Timestamp | Literal['entity'] | None = None,
|
|
549
|
+
context_anchor_time: pd.Timestamp | None = None,
|
|
550
|
+
run_mode: RunMode | str = RunMode.FAST,
|
|
551
|
+
num_neighbors: list[int] | None = None,
|
|
549
552
|
num_hops: int = 2,
|
|
550
|
-
max_pq_iterations: int =
|
|
551
|
-
random_seed:
|
|
552
|
-
verbose:
|
|
553
|
+
max_pq_iterations: int = 10,
|
|
554
|
+
random_seed: int | None = _RANDOM_SEED,
|
|
555
|
+
verbose: bool | ProgressLogger = True,
|
|
553
556
|
use_prediction_time: bool = False,
|
|
554
557
|
) -> pd.DataFrame:
|
|
555
558
|
"""Evaluates a predictive query.
|
|
@@ -597,7 +600,7 @@ class KumoRFM:
|
|
|
597
600
|
msg = f'[bold]EVALUATE[/bold] {query_repr}'
|
|
598
601
|
|
|
599
602
|
if not isinstance(verbose, ProgressLogger):
|
|
600
|
-
verbose =
|
|
603
|
+
verbose = ProgressLogger.default(msg=msg, verbose=verbose)
|
|
601
604
|
|
|
602
605
|
with verbose as logger:
|
|
603
606
|
context = self._get_context(
|
|
@@ -656,9 +659,9 @@ class KumoRFM:
|
|
|
656
659
|
query: str,
|
|
657
660
|
size: int,
|
|
658
661
|
*,
|
|
659
|
-
anchor_time:
|
|
660
|
-
random_seed:
|
|
661
|
-
max_iterations: int =
|
|
662
|
+
anchor_time: pd.Timestamp | Literal['entity'] | None = None,
|
|
663
|
+
random_seed: int | None = _RANDOM_SEED,
|
|
664
|
+
max_iterations: int = 10,
|
|
662
665
|
) -> pd.DataFrame:
|
|
663
666
|
"""Returns the labels of a predictive query for a specified anchor
|
|
664
667
|
time.
|
|
@@ -678,40 +681,37 @@ class KumoRFM:
|
|
|
678
681
|
query_def = self._parse_query(query)
|
|
679
682
|
|
|
680
683
|
if anchor_time is None:
|
|
681
|
-
anchor_time = self.
|
|
684
|
+
anchor_time = self._get_default_anchor_time(query_def)
|
|
682
685
|
if query_def.target_ast.date_offset_range is not None:
|
|
683
|
-
|
|
684
|
-
|
|
685
|
-
|
|
686
|
+
offset = query_def.target_ast.date_offset_range.end_date_offset
|
|
687
|
+
offset *= query_def.num_forecasts
|
|
688
|
+
anchor_time -= offset
|
|
686
689
|
|
|
687
690
|
assert anchor_time is not None
|
|
688
691
|
if isinstance(anchor_time, pd.Timestamp):
|
|
689
692
|
self._validate_time(query_def, anchor_time, None, evaluate=True)
|
|
690
693
|
else:
|
|
691
694
|
assert anchor_time == 'entity'
|
|
692
|
-
if
|
|
695
|
+
if query_def.entity_table not in self._sampler.time_column_dict:
|
|
693
696
|
raise ValueError(f"Anchor time 'entity' requires the entity "
|
|
694
697
|
f"table '{query_def.entity_table}' "
|
|
695
698
|
f"to have a time column")
|
|
696
699
|
|
|
697
|
-
|
|
698
|
-
|
|
699
|
-
|
|
700
|
-
|
|
701
|
-
|
|
702
|
-
|
|
703
|
-
|
|
704
|
-
|
|
705
|
-
|
|
700
|
+
train, test = self._sampler.sample_target(
|
|
701
|
+
query=query_def,
|
|
702
|
+
num_train_examples=0,
|
|
703
|
+
train_anchor_time=anchor_time,
|
|
704
|
+
num_train_trials=0,
|
|
705
|
+
num_test_examples=size,
|
|
706
|
+
test_anchor_time=anchor_time,
|
|
707
|
+
num_test_trials=max_iterations * size,
|
|
708
|
+
random_seed=random_seed,
|
|
706
709
|
)
|
|
707
710
|
|
|
708
|
-
entity = self._graph_store.pkey_map_dict[
|
|
709
|
-
query_def.entity_table].index[node]
|
|
710
|
-
|
|
711
711
|
return pd.DataFrame({
|
|
712
|
-
'ENTITY':
|
|
713
|
-
'ANCHOR_TIMESTAMP':
|
|
714
|
-
'TARGET':
|
|
712
|
+
'ENTITY': test.entity_pkey,
|
|
713
|
+
'ANCHOR_TIMESTAMP': test.anchor_time,
|
|
714
|
+
'TARGET': test.target,
|
|
715
715
|
})
|
|
716
716
|
|
|
717
717
|
# Helpers #################################################################
|
|
@@ -734,8 +734,6 @@ class KumoRFM:
|
|
|
734
734
|
|
|
735
735
|
resp = self._api_client.parse_query(request)
|
|
736
736
|
|
|
737
|
-
# TODO Expose validation warnings.
|
|
738
|
-
|
|
739
737
|
if len(resp.validation_response.warnings) > 0:
|
|
740
738
|
msg = '\n'.join([
|
|
741
739
|
f'{i+1}. {warning.title}: {warning.message}' for i, warning
|
|
@@ -753,36 +751,92 @@ class KumoRFM:
|
|
|
753
751
|
raise ValueError(f"Failed to parse query '{query}'. "
|
|
754
752
|
f"{msg}") from None
|
|
755
753
|
|
|
754
|
+
@staticmethod
|
|
755
|
+
def _get_task_type(
|
|
756
|
+
query: ValidatedPredictiveQuery,
|
|
757
|
+
edge_types: list[tuple[str, str, str]],
|
|
758
|
+
) -> TaskType:
|
|
759
|
+
if isinstance(query.target_ast, (Condition, LogicalOperation)):
|
|
760
|
+
return TaskType.BINARY_CLASSIFICATION
|
|
761
|
+
|
|
762
|
+
target = query.target_ast
|
|
763
|
+
if isinstance(target, Join):
|
|
764
|
+
target = target.rhs_target
|
|
765
|
+
if isinstance(target, Aggregation):
|
|
766
|
+
if target.aggr == AggregationType.LIST_DISTINCT:
|
|
767
|
+
table_name, col_name = target._get_target_column_name().split(
|
|
768
|
+
'.')
|
|
769
|
+
target_edge_types = [
|
|
770
|
+
edge_type for edge_type in edge_types
|
|
771
|
+
if edge_type[0] == table_name and edge_type[1] == col_name
|
|
772
|
+
]
|
|
773
|
+
if len(target_edge_types) != 1:
|
|
774
|
+
raise NotImplementedError(
|
|
775
|
+
f"Multilabel-classification queries based on "
|
|
776
|
+
f"'LIST_DISTINCT' are not supported yet. If you "
|
|
777
|
+
f"planned to write a link prediction query instead, "
|
|
778
|
+
f"make sure to register '{col_name}' as a "
|
|
779
|
+
f"foreign key.")
|
|
780
|
+
return TaskType.TEMPORAL_LINK_PREDICTION
|
|
781
|
+
|
|
782
|
+
return TaskType.REGRESSION
|
|
783
|
+
|
|
784
|
+
assert isinstance(target, Column)
|
|
785
|
+
|
|
786
|
+
if target.stype in {Stype.ID, Stype.categorical}:
|
|
787
|
+
return TaskType.MULTICLASS_CLASSIFICATION
|
|
788
|
+
|
|
789
|
+
if target.stype in {Stype.numerical}:
|
|
790
|
+
return TaskType.REGRESSION
|
|
791
|
+
|
|
792
|
+
raise NotImplementedError("Task type not yet supported")
|
|
793
|
+
|
|
794
|
+
def _get_default_anchor_time(
|
|
795
|
+
self,
|
|
796
|
+
query: ValidatedPredictiveQuery,
|
|
797
|
+
) -> pd.Timestamp:
|
|
798
|
+
if query.query_type == QueryType.TEMPORAL:
|
|
799
|
+
aggr_table_names = [
|
|
800
|
+
aggr._get_target_column_name().split('.')[0]
|
|
801
|
+
for aggr in query.get_all_target_aggregations()
|
|
802
|
+
]
|
|
803
|
+
return self._sampler.get_max_time(aggr_table_names)
|
|
804
|
+
|
|
805
|
+
assert query.query_type == QueryType.STATIC
|
|
806
|
+
return self._sampler.get_max_time()
|
|
807
|
+
|
|
756
808
|
def _validate_time(
|
|
757
809
|
self,
|
|
758
810
|
query: ValidatedPredictiveQuery,
|
|
759
811
|
anchor_time: pd.Timestamp,
|
|
760
|
-
context_anchor_time:
|
|
812
|
+
context_anchor_time: pd.Timestamp | None,
|
|
761
813
|
evaluate: bool,
|
|
762
814
|
) -> None:
|
|
763
815
|
|
|
764
|
-
if self.
|
|
816
|
+
if len(self._sampler.time_column_dict) == 0:
|
|
765
817
|
return # Graph without timestamps
|
|
766
818
|
|
|
767
|
-
|
|
819
|
+
min_time = self._sampler.get_min_time()
|
|
820
|
+
max_time = self._sampler.get_max_time()
|
|
821
|
+
|
|
822
|
+
if anchor_time < min_time:
|
|
768
823
|
raise ValueError(f"Anchor timestamp '{anchor_time}' is before "
|
|
769
|
-
f"the earliest timestamp "
|
|
770
|
-
f"
|
|
824
|
+
f"the earliest timestamp '{min_time}' in the "
|
|
825
|
+
f"data.")
|
|
771
826
|
|
|
772
|
-
if
|
|
773
|
-
and context_anchor_time < self._graph_store.min_time):
|
|
827
|
+
if context_anchor_time is not None and context_anchor_time < min_time:
|
|
774
828
|
raise ValueError(f"Context anchor timestamp is too early or "
|
|
775
829
|
f"aggregation time range is too large. To make "
|
|
776
830
|
f"this prediction, we would need data back to "
|
|
777
831
|
f"'{context_anchor_time}', however, your data "
|
|
778
|
-
f"only contains data back to "
|
|
779
|
-
f"'{self._graph_store.min_time}'.")
|
|
832
|
+
f"only contains data back to '{min_time}'.")
|
|
780
833
|
|
|
781
834
|
if query.target_ast.date_offset_range is not None:
|
|
782
835
|
end_offset = query.target_ast.date_offset_range.end_date_offset
|
|
783
836
|
else:
|
|
784
837
|
end_offset = pd.DateOffset(0)
|
|
785
|
-
|
|
838
|
+
end_offset = end_offset * query.num_forecasts
|
|
839
|
+
|
|
786
840
|
if (context_anchor_time is not None
|
|
787
841
|
and context_anchor_time > anchor_time):
|
|
788
842
|
warnings.warn(f"Context anchor timestamp "
|
|
@@ -792,7 +846,7 @@ class KumoRFM:
|
|
|
792
846
|
f"intended.")
|
|
793
847
|
elif (query.query_type == QueryType.TEMPORAL
|
|
794
848
|
and context_anchor_time is not None
|
|
795
|
-
and context_anchor_time +
|
|
849
|
+
and context_anchor_time + end_offset > anchor_time):
|
|
796
850
|
warnings.warn(f"Aggregation for context examples at timestamp "
|
|
797
851
|
f"'{context_anchor_time}' will leak information "
|
|
798
852
|
f"from the prediction anchor timestamp "
|
|
@@ -800,40 +854,37 @@ class KumoRFM:
|
|
|
800
854
|
f"intended.")
|
|
801
855
|
|
|
802
856
|
elif (context_anchor_time is not None
|
|
803
|
-
and context_anchor_time -
|
|
804
|
-
|
|
805
|
-
_time = context_anchor_time - forecast_end_offset
|
|
857
|
+
and context_anchor_time - end_offset < min_time):
|
|
858
|
+
_time = context_anchor_time - end_offset
|
|
806
859
|
warnings.warn(f"Context anchor timestamp is too early or "
|
|
807
860
|
f"aggregation time range is too large. To form "
|
|
808
861
|
f"proper input data, we would need data back to "
|
|
809
862
|
f"'{_time}', however, your data only contains "
|
|
810
|
-
f"data back to '{
|
|
863
|
+
f"data back to '{min_time}'.")
|
|
811
864
|
|
|
812
|
-
if
|
|
813
|
-
> self._graph_store.max_time + pd.DateOffset(days=1)):
|
|
865
|
+
if not evaluate and anchor_time > max_time + pd.DateOffset(days=1):
|
|
814
866
|
warnings.warn(f"Anchor timestamp '{anchor_time}' is after the "
|
|
815
|
-
f"latest timestamp '{
|
|
816
|
-
f"
|
|
867
|
+
f"latest timestamp '{max_time}' in the data. Please "
|
|
868
|
+
f"make sure this is intended.")
|
|
817
869
|
|
|
818
|
-
|
|
819
|
-
if evaluate and anchor_time > max_eval_time:
|
|
870
|
+
if evaluate and anchor_time > max_time - end_offset:
|
|
820
871
|
raise ValueError(
|
|
821
872
|
f"Anchor timestamp for evaluation is after the latest "
|
|
822
|
-
f"supported timestamp '{
|
|
873
|
+
f"supported timestamp '{max_time - end_offset}'.")
|
|
823
874
|
|
|
824
875
|
def _get_context(
|
|
825
876
|
self,
|
|
826
877
|
query: ValidatedPredictiveQuery,
|
|
827
|
-
indices:
|
|
828
|
-
anchor_time:
|
|
829
|
-
context_anchor_time:
|
|
878
|
+
indices: list[str] | list[float] | list[int] | None,
|
|
879
|
+
anchor_time: pd.Timestamp | Literal['entity'] | None,
|
|
880
|
+
context_anchor_time: pd.Timestamp | None,
|
|
830
881
|
run_mode: RunMode,
|
|
831
|
-
num_neighbors:
|
|
882
|
+
num_neighbors: list[int] | None,
|
|
832
883
|
num_hops: int,
|
|
833
884
|
max_pq_iterations: int,
|
|
834
885
|
evaluate: bool,
|
|
835
|
-
random_seed:
|
|
836
|
-
logger:
|
|
886
|
+
random_seed: int | None = _RANDOM_SEED,
|
|
887
|
+
logger: ProgressLogger | None = None,
|
|
837
888
|
) -> Context:
|
|
838
889
|
|
|
839
890
|
if num_neighbors is not None:
|
|
@@ -850,10 +901,9 @@ class KumoRFM:
|
|
|
850
901
|
f"'https://github.com/kumo-ai/kumo-rfm' if you "
|
|
851
902
|
f"must go beyond this for your use-case.")
|
|
852
903
|
|
|
853
|
-
|
|
854
|
-
|
|
855
|
-
|
|
856
|
-
edge_types=self._graph_store.edge_types,
|
|
904
|
+
task_type = self._get_task_type(
|
|
905
|
+
query=query,
|
|
906
|
+
edge_types=self._sampler.edge_types,
|
|
857
907
|
)
|
|
858
908
|
|
|
859
909
|
if logger is not None:
|
|
@@ -885,14 +935,17 @@ class KumoRFM:
|
|
|
885
935
|
num_neighbors = [64, 64, 8, 8, 4, 4][:num_hops]
|
|
886
936
|
|
|
887
937
|
if query.target_ast.date_offset_range is None:
|
|
888
|
-
|
|
938
|
+
step_offset = pd.DateOffset(0)
|
|
889
939
|
else:
|
|
890
|
-
|
|
891
|
-
|
|
940
|
+
step_offset = query.target_ast.date_offset_range.end_date_offset
|
|
941
|
+
end_offset = step_offset * query.num_forecasts
|
|
942
|
+
|
|
892
943
|
if anchor_time is None:
|
|
893
|
-
anchor_time = self.
|
|
944
|
+
anchor_time = self._get_default_anchor_time(query)
|
|
945
|
+
|
|
894
946
|
if evaluate:
|
|
895
|
-
anchor_time = anchor_time -
|
|
947
|
+
anchor_time = anchor_time - end_offset
|
|
948
|
+
|
|
896
949
|
if logger is not None:
|
|
897
950
|
assert isinstance(anchor_time, pd.Timestamp)
|
|
898
951
|
if anchor_time == pd.Timestamp.min:
|
|
@@ -906,57 +959,71 @@ class KumoRFM:
|
|
|
906
959
|
|
|
907
960
|
assert anchor_time is not None
|
|
908
961
|
if isinstance(anchor_time, pd.Timestamp):
|
|
962
|
+
if context_anchor_time == 'entity':
|
|
963
|
+
raise ValueError("Anchor time 'entity' needs to be shared "
|
|
964
|
+
"for context and prediction examples")
|
|
909
965
|
if context_anchor_time is None:
|
|
910
|
-
context_anchor_time = anchor_time -
|
|
966
|
+
context_anchor_time = anchor_time - end_offset
|
|
911
967
|
self._validate_time(query, anchor_time, context_anchor_time,
|
|
912
968
|
evaluate)
|
|
913
969
|
else:
|
|
914
970
|
assert anchor_time == 'entity'
|
|
915
|
-
if query.
|
|
971
|
+
if query.query_type != QueryType.STATIC:
|
|
972
|
+
raise ValueError("Anchor time 'entity' is only valid for "
|
|
973
|
+
"static predictive queries")
|
|
974
|
+
if query.entity_table not in self._sampler.time_column_dict:
|
|
916
975
|
raise ValueError(f"Anchor time 'entity' requires the entity "
|
|
917
976
|
f"table '{query.entity_table}' to "
|
|
918
977
|
f"have a time column")
|
|
919
|
-
if context_anchor_time
|
|
920
|
-
|
|
921
|
-
|
|
922
|
-
context_anchor_time =
|
|
978
|
+
if isinstance(context_anchor_time, pd.Timestamp):
|
|
979
|
+
raise ValueError("Anchor time 'entity' needs to be shared "
|
|
980
|
+
"for context and prediction examples")
|
|
981
|
+
context_anchor_time = 'entity'
|
|
923
982
|
|
|
924
|
-
|
|
983
|
+
num_train_examples = _MAX_CONTEXT_SIZE[run_mode]
|
|
925
984
|
if evaluate:
|
|
926
|
-
|
|
985
|
+
num_test_examples = _MAX_TEST_SIZE[run_mode]
|
|
927
986
|
if task_type.is_link_pred:
|
|
928
|
-
|
|
987
|
+
num_test_examples = num_test_examples // 5
|
|
988
|
+
else:
|
|
989
|
+
num_test_examples = 0
|
|
990
|
+
|
|
991
|
+
train, test = self._sampler.sample_target(
|
|
992
|
+
query=query,
|
|
993
|
+
num_train_examples=num_train_examples,
|
|
994
|
+
train_anchor_time=context_anchor_time,
|
|
995
|
+
num_train_trials=max_pq_iterations * num_train_examples,
|
|
996
|
+
num_test_examples=num_test_examples,
|
|
997
|
+
test_anchor_time=anchor_time,
|
|
998
|
+
num_test_trials=max_pq_iterations * num_test_examples,
|
|
999
|
+
random_seed=random_seed,
|
|
1000
|
+
)
|
|
1001
|
+
train_pkey, train_time, y_train = train
|
|
1002
|
+
test_pkey, test_time, y_test = test
|
|
929
1003
|
|
|
930
|
-
|
|
931
|
-
|
|
932
|
-
|
|
933
|
-
|
|
934
|
-
|
|
935
|
-
|
|
936
|
-
|
|
937
|
-
|
|
938
|
-
|
|
939
|
-
|
|
940
|
-
|
|
941
|
-
|
|
942
|
-
|
|
943
|
-
|
|
944
|
-
|
|
945
|
-
|
|
946
|
-
|
|
947
|
-
|
|
948
|
-
|
|
949
|
-
|
|
950
|
-
num_rhs = y_test.explode().nunique()
|
|
951
|
-
msg = (f"Collected {len(y_test):,} test examples with "
|
|
952
|
-
f"{num_rhs:,} unique items")
|
|
953
|
-
else:
|
|
954
|
-
raise NotImplementedError
|
|
955
|
-
logger.log(msg)
|
|
1004
|
+
if evaluate and logger is not None:
|
|
1005
|
+
if task_type == TaskType.BINARY_CLASSIFICATION:
|
|
1006
|
+
pos = 100 * int((y_test > 0).sum()) / len(y_test)
|
|
1007
|
+
msg = (f"Collected {len(y_test):,} test examples with "
|
|
1008
|
+
f"{pos:.2f}% positive cases")
|
|
1009
|
+
elif task_type == TaskType.MULTICLASS_CLASSIFICATION:
|
|
1010
|
+
msg = (f"Collected {len(y_test):,} test examples holding "
|
|
1011
|
+
f"{y_test.nunique()} classes")
|
|
1012
|
+
elif task_type == TaskType.REGRESSION:
|
|
1013
|
+
_min, _max = float(y_test.min()), float(y_test.max())
|
|
1014
|
+
msg = (f"Collected {len(y_test):,} test examples with targets "
|
|
1015
|
+
f"between {format_value(_min)} and "
|
|
1016
|
+
f"{format_value(_max)}")
|
|
1017
|
+
elif task_type == TaskType.TEMPORAL_LINK_PREDICTION:
|
|
1018
|
+
num_rhs = y_test.explode().nunique()
|
|
1019
|
+
msg = (f"Collected {len(y_test):,} test examples with "
|
|
1020
|
+
f"{num_rhs:,} unique items")
|
|
1021
|
+
else:
|
|
1022
|
+
raise NotImplementedError
|
|
1023
|
+
logger.log(msg)
|
|
956
1024
|
|
|
957
|
-
|
|
1025
|
+
if not evaluate:
|
|
958
1026
|
assert indices is not None
|
|
959
|
-
|
|
960
1027
|
if len(indices) > _MAX_PRED_SIZE[task_type]:
|
|
961
1028
|
raise ValueError(f"Cannot predict for more than "
|
|
962
1029
|
f"{_MAX_PRED_SIZE[task_type]:,} entities at "
|
|
@@ -964,26 +1031,12 @@ class KumoRFM:
|
|
|
964
1031
|
f"`KumoRFM.batch_mode` to process entities "
|
|
965
1032
|
f"in batches")
|
|
966
1033
|
|
|
967
|
-
|
|
968
|
-
table_name=query.entity_table,
|
|
969
|
-
pkey=pd.Series(indices),
|
|
970
|
-
)
|
|
971
|
-
|
|
1034
|
+
test_pkey = pd.Series(indices, dtype=train_pkey.dtype)
|
|
972
1035
|
if isinstance(anchor_time, pd.Timestamp):
|
|
973
|
-
test_time = pd.Series(anchor_time).repeat(
|
|
974
|
-
len(
|
|
1036
|
+
test_time = pd.Series([anchor_time]).repeat(
|
|
1037
|
+
len(indices)).reset_index(drop=True)
|
|
975
1038
|
else:
|
|
976
|
-
|
|
977
|
-
time = time[test_node] * 1000**3
|
|
978
|
-
test_time = pd.Series(time, dtype='datetime64[ns]')
|
|
979
|
-
|
|
980
|
-
train_node, train_time, y_train = query_driver.collect_train(
|
|
981
|
-
size=_MAX_CONTEXT_SIZE[run_mode],
|
|
982
|
-
anchor_time=context_anchor_time or 'entity',
|
|
983
|
-
exclude_node=test_node if (query.query_type == QueryType.STATIC
|
|
984
|
-
or anchor_time == 'entity') else None,
|
|
985
|
-
max_iterations=max_pq_iterations,
|
|
986
|
-
)
|
|
1039
|
+
train_time = test_time = 'entity'
|
|
987
1040
|
|
|
988
1041
|
if logger is not None:
|
|
989
1042
|
if task_type == TaskType.BINARY_CLASSIFICATION:
|
|
@@ -1006,12 +1059,12 @@ class KumoRFM:
|
|
|
1006
1059
|
raise NotImplementedError
|
|
1007
1060
|
logger.log(msg)
|
|
1008
1061
|
|
|
1009
|
-
entity_table_names:
|
|
1062
|
+
entity_table_names: tuple[str, ...]
|
|
1010
1063
|
if task_type.is_link_pred:
|
|
1011
1064
|
final_aggr = query.get_final_target_aggregation()
|
|
1012
1065
|
assert final_aggr is not None
|
|
1013
1066
|
edge_fkey = final_aggr._get_target_column_name()
|
|
1014
|
-
for edge_type in self.
|
|
1067
|
+
for edge_type in self._sampler.edge_types:
|
|
1015
1068
|
if edge_fkey == f'{edge_type[0]}.{edge_type[1]}':
|
|
1016
1069
|
entity_table_names = (
|
|
1017
1070
|
query.entity_table,
|
|
@@ -1023,20 +1076,24 @@ class KumoRFM:
|
|
|
1023
1076
|
# Exclude the entity anchor time from the feature set to prevent
|
|
1024
1077
|
# running out-of-distribution between in-context and test examples:
|
|
1025
1078
|
exclude_cols_dict = query.get_exclude_cols_dict()
|
|
1026
|
-
if
|
|
1079
|
+
if entity_table_names[0] in self._sampler.time_column_dict:
|
|
1027
1080
|
if entity_table_names[0] not in exclude_cols_dict:
|
|
1028
1081
|
exclude_cols_dict[entity_table_names[0]] = []
|
|
1029
|
-
|
|
1030
|
-
time_column = time_column_dict[entity_table_names[0]]
|
|
1082
|
+
time_column = self._sampler.time_column_dict[entity_table_names[0]]
|
|
1031
1083
|
exclude_cols_dict[entity_table_names[0]].append(time_column)
|
|
1032
1084
|
|
|
1033
|
-
subgraph = self.
|
|
1085
|
+
subgraph = self._sampler.sample_subgraph(
|
|
1034
1086
|
entity_table_names=entity_table_names,
|
|
1035
|
-
|
|
1036
|
-
|
|
1037
|
-
|
|
1038
|
-
|
|
1039
|
-
|
|
1087
|
+
entity_pkey=pd.concat(
|
|
1088
|
+
[train_pkey, test_pkey],
|
|
1089
|
+
axis=0,
|
|
1090
|
+
ignore_index=True,
|
|
1091
|
+
),
|
|
1092
|
+
anchor_time=pd.concat(
|
|
1093
|
+
[train_time, test_time],
|
|
1094
|
+
axis=0,
|
|
1095
|
+
ignore_index=True,
|
|
1096
|
+
) if isinstance(train_time, pd.Series) else 'entity',
|
|
1040
1097
|
num_neighbors=num_neighbors,
|
|
1041
1098
|
exclude_cols_dict=exclude_cols_dict,
|
|
1042
1099
|
)
|
|
@@ -1048,23 +1105,19 @@ class KumoRFM:
|
|
|
1048
1105
|
f"'https://github.com/kumo-ai/kumo-rfm' if you "
|
|
1049
1106
|
f"must go beyond this for your use-case.")
|
|
1050
1107
|
|
|
1051
|
-
step_size: Optional[int] = None
|
|
1052
|
-
if query.query_type == QueryType.TEMPORAL:
|
|
1053
|
-
step_size = date_offset_to_seconds(end_offset)
|
|
1054
|
-
|
|
1055
1108
|
return Context(
|
|
1056
1109
|
task_type=task_type,
|
|
1057
1110
|
entity_table_names=entity_table_names,
|
|
1058
1111
|
subgraph=subgraph,
|
|
1059
1112
|
y_train=y_train,
|
|
1060
|
-
y_test=y_test,
|
|
1113
|
+
y_test=y_test if evaluate else None,
|
|
1061
1114
|
top_k=query.top_k,
|
|
1062
|
-
step_size=
|
|
1115
|
+
step_size=None,
|
|
1063
1116
|
)
|
|
1064
1117
|
|
|
1065
1118
|
@staticmethod
|
|
1066
1119
|
def _validate_metrics(
|
|
1067
|
-
metrics:
|
|
1120
|
+
metrics: list[str],
|
|
1068
1121
|
task_type: TaskType,
|
|
1069
1122
|
) -> None:
|
|
1070
1123
|
|
|
@@ -1121,7 +1174,7 @@ class KumoRFM:
|
|
|
1121
1174
|
f"'https://github.com/kumo-ai/kumo-rfm'.")
|
|
1122
1175
|
|
|
1123
1176
|
|
|
1124
|
-
def format_value(value:
|
|
1177
|
+
def format_value(value: int | float) -> str:
|
|
1125
1178
|
if value == int(value):
|
|
1126
1179
|
return f'{int(value):,}'
|
|
1127
1180
|
if abs(value) >= 1000:
|