kumoai 2.10.0.dev202510021830__py3-none-any.whl → 2.12.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kumoai/__init__.py +4 -2
- kumoai/_version.py +1 -1
- kumoai/client/client.py +10 -5
- kumoai/client/endpoints.py +1 -0
- kumoai/client/rfm.py +35 -7
- kumoai/experimental/rfm/__init__.py +5 -3
- kumoai/experimental/rfm/infer/timestamp.py +5 -4
- kumoai/experimental/rfm/local_graph.py +90 -74
- kumoai/experimental/rfm/local_graph_sampler.py +16 -8
- kumoai/experimental/rfm/local_graph_store.py +13 -1
- kumoai/experimental/rfm/local_pquery_driver.py +323 -38
- kumoai/experimental/rfm/local_table.py +100 -22
- kumoai/experimental/rfm/pquery/__init__.py +4 -4
- kumoai/experimental/rfm/pquery/{backend.py → executor.py} +24 -58
- kumoai/experimental/rfm/pquery/{pandas_backend.py → pandas_executor.py} +277 -223
- kumoai/experimental/rfm/rfm.py +220 -79
- kumoai/jobs.py +1 -0
- kumoai/pquery/predictive_query.py +10 -6
- kumoai/trainer/trainer.py +9 -10
- kumoai/utils/progress_logger.py +13 -0
- {kumoai-2.10.0.dev202510021830.dist-info → kumoai-2.12.1.dist-info}/METADATA +4 -5
- {kumoai-2.10.0.dev202510021830.dist-info → kumoai-2.12.1.dist-info}/RECORD +25 -25
- {kumoai-2.10.0.dev202510021830.dist-info → kumoai-2.12.1.dist-info}/WHEEL +0 -0
- {kumoai-2.10.0.dev202510021830.dist-info → kumoai-2.12.1.dist-info}/licenses/LICENSE +0 -0
- {kumoai-2.10.0.dev202510021830.dist-info → kumoai-2.12.1.dist-info}/top_level.txt +0 -0
kumoai/experimental/rfm/rfm.py
CHANGED
|
@@ -5,21 +5,28 @@ from collections import defaultdict
|
|
|
5
5
|
from collections.abc import Generator
|
|
6
6
|
from contextlib import contextmanager
|
|
7
7
|
from dataclasses import dataclass, replace
|
|
8
|
-
from typing import
|
|
8
|
+
from typing import (
|
|
9
|
+
Any,
|
|
10
|
+
Dict,
|
|
11
|
+
Iterator,
|
|
12
|
+
List,
|
|
13
|
+
Literal,
|
|
14
|
+
Optional,
|
|
15
|
+
Tuple,
|
|
16
|
+
Union,
|
|
17
|
+
overload,
|
|
18
|
+
)
|
|
9
19
|
|
|
10
20
|
import numpy as np
|
|
11
21
|
import pandas as pd
|
|
12
22
|
from kumoapi.model_plan import RunMode
|
|
13
|
-
from kumoapi.pquery import QueryType
|
|
23
|
+
from kumoapi.pquery import QueryType, ValidatedPredictiveQuery
|
|
14
24
|
from kumoapi.rfm import Context
|
|
15
25
|
from kumoapi.rfm import Explanation as ExplanationConfig
|
|
16
26
|
from kumoapi.rfm import (
|
|
17
|
-
PQueryDefinition,
|
|
18
27
|
RFMEvaluateRequest,
|
|
19
|
-
|
|
28
|
+
RFMParseQueryRequest,
|
|
20
29
|
RFMPredictRequest,
|
|
21
|
-
RFMPredictResponse,
|
|
22
|
-
RFMValidateQueryRequest,
|
|
23
30
|
)
|
|
24
31
|
from kumoapi.task import TaskType
|
|
25
32
|
|
|
@@ -32,6 +39,7 @@ from kumoai.experimental.rfm.local_pquery_driver import (
|
|
|
32
39
|
LocalPQueryDriver,
|
|
33
40
|
date_offset_to_seconds,
|
|
34
41
|
)
|
|
42
|
+
from kumoai.mixin import CastMixin
|
|
35
43
|
from kumoai.utils import InteractiveProgressLogger, ProgressLogger
|
|
36
44
|
|
|
37
45
|
_RANDOM_SEED = 42
|
|
@@ -62,6 +70,17 @@ _SIZE_LIMIT_MSG = ("Context size exceeds the 30MB limit. {stats}\nPlease "
|
|
|
62
70
|
"beyond this for your use-case.")
|
|
63
71
|
|
|
64
72
|
|
|
73
|
+
@dataclass(repr=False)
|
|
74
|
+
class ExplainConfig(CastMixin):
|
|
75
|
+
"""Configuration for explainability.
|
|
76
|
+
|
|
77
|
+
Args:
|
|
78
|
+
skip_summary: Whether to skip generating a human-readable summary of
|
|
79
|
+
the explanation.
|
|
80
|
+
"""
|
|
81
|
+
skip_summary: bool = False
|
|
82
|
+
|
|
83
|
+
|
|
65
84
|
@dataclass(repr=False)
|
|
66
85
|
class Explanation:
|
|
67
86
|
prediction: pd.DataFrame
|
|
@@ -89,6 +108,12 @@ class Explanation:
|
|
|
89
108
|
def __repr__(self) -> str:
|
|
90
109
|
return str((self.prediction, self.summary))
|
|
91
110
|
|
|
111
|
+
def _ipython_display_(self) -> None:
|
|
112
|
+
from IPython.display import Markdown, display
|
|
113
|
+
|
|
114
|
+
display(self.prediction)
|
|
115
|
+
display(Markdown(self.summary))
|
|
116
|
+
|
|
92
117
|
|
|
93
118
|
class KumoRFM:
|
|
94
119
|
r"""The Kumo Relational Foundation model (RFM) from the `KumoRFM: A
|
|
@@ -201,6 +226,7 @@ class KumoRFM:
|
|
|
201
226
|
max_pq_iterations: int = 20,
|
|
202
227
|
random_seed: Optional[int] = _RANDOM_SEED,
|
|
203
228
|
verbose: Union[bool, ProgressLogger] = True,
|
|
229
|
+
use_prediction_time: bool = False,
|
|
204
230
|
) -> pd.DataFrame:
|
|
205
231
|
pass
|
|
206
232
|
|
|
@@ -210,7 +236,7 @@ class KumoRFM:
|
|
|
210
236
|
query: str,
|
|
211
237
|
indices: Union[List[str], List[float], List[int], None] = None,
|
|
212
238
|
*,
|
|
213
|
-
explain: Literal[True],
|
|
239
|
+
explain: Union[Literal[True], ExplainConfig, Dict[str, Any]],
|
|
214
240
|
anchor_time: Union[pd.Timestamp, Literal['entity'], None] = None,
|
|
215
241
|
context_anchor_time: Union[pd.Timestamp, None] = None,
|
|
216
242
|
run_mode: Union[RunMode, str] = RunMode.FAST,
|
|
@@ -219,6 +245,7 @@ class KumoRFM:
|
|
|
219
245
|
max_pq_iterations: int = 20,
|
|
220
246
|
random_seed: Optional[int] = _RANDOM_SEED,
|
|
221
247
|
verbose: Union[bool, ProgressLogger] = True,
|
|
248
|
+
use_prediction_time: bool = False,
|
|
222
249
|
) -> Explanation:
|
|
223
250
|
pass
|
|
224
251
|
|
|
@@ -227,7 +254,7 @@ class KumoRFM:
|
|
|
227
254
|
query: str,
|
|
228
255
|
indices: Union[List[str], List[float], List[int], None] = None,
|
|
229
256
|
*,
|
|
230
|
-
explain: bool = False,
|
|
257
|
+
explain: Union[bool, ExplainConfig, Dict[str, Any]] = False,
|
|
231
258
|
anchor_time: Union[pd.Timestamp, Literal['entity'], None] = None,
|
|
232
259
|
context_anchor_time: Union[pd.Timestamp, None] = None,
|
|
233
260
|
run_mode: Union[RunMode, str] = RunMode.FAST,
|
|
@@ -236,16 +263,23 @@ class KumoRFM:
|
|
|
236
263
|
max_pq_iterations: int = 20,
|
|
237
264
|
random_seed: Optional[int] = _RANDOM_SEED,
|
|
238
265
|
verbose: Union[bool, ProgressLogger] = True,
|
|
266
|
+
use_prediction_time: bool = False,
|
|
239
267
|
) -> Union[pd.DataFrame, Explanation]:
|
|
240
268
|
"""Returns predictions for a predictive query.
|
|
241
269
|
|
|
242
270
|
Args:
|
|
243
271
|
query: The predictive query.
|
|
244
272
|
indices: The entity primary keys to predict for. Will override the
|
|
245
|
-
indices given as part of the predictive query.
|
|
246
|
-
|
|
247
|
-
|
|
248
|
-
|
|
273
|
+
indices given as part of the predictive query. Predictions will
|
|
274
|
+
be generated for all indices, independent of whether they
|
|
275
|
+
fulfill entity filter constraints. To pre-filter entities, use
|
|
276
|
+
:meth:`~KumoRFM.is_valid_entity`.
|
|
277
|
+
explain: Configuration for explainability.
|
|
278
|
+
If set to ``True``, will additionally explain the prediction.
|
|
279
|
+
Passing in an :class:`ExplainConfig` instance provides control
|
|
280
|
+
over which parts of explanation are generated.
|
|
281
|
+
Explainability is currently only supported for single entity
|
|
282
|
+
predictions with ``run_mode="FAST"``.
|
|
249
283
|
anchor_time: The anchor timestamp for the prediction. If set to
|
|
250
284
|
``None``, will use the maximum timestamp in the data.
|
|
251
285
|
If set to ``"entity"``, will use the timestamp of the entity.
|
|
@@ -263,46 +297,54 @@ class KumoRFM:
|
|
|
263
297
|
entities to find valid labels.
|
|
264
298
|
random_seed: A manual seed for generating pseudo-random numbers.
|
|
265
299
|
verbose: Whether to print verbose output.
|
|
300
|
+
use_prediction_time: Whether to use the anchor timestamp as an
|
|
301
|
+
additional feature during prediction. This is typically
|
|
302
|
+
beneficial for time series forecasting tasks.
|
|
266
303
|
|
|
267
304
|
Returns:
|
|
268
305
|
The predictions as a :class:`pandas.DataFrame`.
|
|
269
|
-
If ``explain
|
|
270
|
-
|
|
306
|
+
If ``explain`` is provided, returns an :class:`Explanation` object
|
|
307
|
+
containing the prediction, summary, and details.
|
|
271
308
|
"""
|
|
309
|
+
explain_config: Optional[ExplainConfig] = None
|
|
310
|
+
if explain is True:
|
|
311
|
+
explain_config = ExplainConfig()
|
|
312
|
+
elif explain is not False:
|
|
313
|
+
explain_config = ExplainConfig._cast(explain)
|
|
314
|
+
|
|
272
315
|
query_def = self._parse_query(query)
|
|
316
|
+
query_str = query_def.to_string()
|
|
273
317
|
|
|
274
318
|
if num_hops != 2 and num_neighbors is not None:
|
|
275
319
|
warnings.warn(f"Received custom 'num_neighbors' option; ignoring "
|
|
276
320
|
f"custom 'num_hops={num_hops}' option")
|
|
277
321
|
|
|
278
|
-
if
|
|
322
|
+
if explain_config is not None and run_mode in {
|
|
323
|
+
RunMode.NORMAL, RunMode.BEST
|
|
324
|
+
}:
|
|
279
325
|
warnings.warn(f"Explainability is currently only supported for "
|
|
280
326
|
f"run mode 'FAST' (got '{run_mode}'). Provided run "
|
|
281
327
|
f"mode has been reset. Please lower the run mode to "
|
|
282
328
|
f"suppress this warning.")
|
|
283
329
|
|
|
284
330
|
if indices is None:
|
|
285
|
-
if query_def.
|
|
331
|
+
if query_def.rfm_entity_ids is None:
|
|
286
332
|
raise ValueError("Cannot find entities to predict for. Please "
|
|
287
333
|
"pass them via `predict(query, indices=...)`")
|
|
288
|
-
indices = query_def.
|
|
334
|
+
indices = query_def.get_rfm_entity_id_list()
|
|
289
335
|
else:
|
|
290
|
-
query_def = replace(
|
|
291
|
-
query_def,
|
|
292
|
-
entity=replace(query_def.entity, ids=None),
|
|
293
|
-
)
|
|
336
|
+
query_def = replace(query_def, rfm_entity_ids=None)
|
|
294
337
|
|
|
295
338
|
if len(indices) == 0:
|
|
296
|
-
raise ValueError("At least one entity is required
|
|
297
|
-
"prediction")
|
|
339
|
+
raise ValueError("At least one entity is required")
|
|
298
340
|
|
|
299
|
-
if
|
|
341
|
+
if explain_config is not None and len(indices) > 1:
|
|
300
342
|
raise ValueError(
|
|
301
343
|
f"Cannot explain predictions for more than a single entity "
|
|
302
344
|
f"(got {len(indices)})")
|
|
303
345
|
|
|
304
346
|
query_repr = query_def.to_string(rich=True, exclude_predict=True)
|
|
305
|
-
if
|
|
347
|
+
if explain_config is not None:
|
|
306
348
|
msg = f'[bold]EXPLAIN[/bold] {query_repr}'
|
|
307
349
|
else:
|
|
308
350
|
msg = f'[bold]PREDICT[/bold] {query_repr}'
|
|
@@ -314,8 +356,8 @@ class KumoRFM:
|
|
|
314
356
|
|
|
315
357
|
batch_size: Optional[int] = None
|
|
316
358
|
if self._batch_size == 'max':
|
|
317
|
-
task_type =
|
|
318
|
-
|
|
359
|
+
task_type = LocalPQueryDriver.get_task_type(
|
|
360
|
+
query_def,
|
|
319
361
|
edge_types=self._graph_store.edge_types,
|
|
320
362
|
)
|
|
321
363
|
batch_size = _MAX_PRED_SIZE[task_type]
|
|
@@ -332,10 +374,9 @@ class KumoRFM:
|
|
|
332
374
|
logger.log(f"Splitting {len(indices):,} entities into "
|
|
333
375
|
f"{len(batches):,} batches of size {batch_size:,}")
|
|
334
376
|
|
|
335
|
-
|
|
336
|
-
|
|
337
|
-
|
|
338
|
-
] = []
|
|
377
|
+
predictions: List[pd.DataFrame] = []
|
|
378
|
+
summary: Optional[str] = None
|
|
379
|
+
details: Optional[Explanation] = None
|
|
339
380
|
for i, batch in enumerate(batches):
|
|
340
381
|
# TODO Re-use the context for subsequent predictions.
|
|
341
382
|
context = self._get_context(
|
|
@@ -354,6 +395,8 @@ class KumoRFM:
|
|
|
354
395
|
request = RFMPredictRequest(
|
|
355
396
|
context=context,
|
|
356
397
|
run_mode=RunMode(run_mode),
|
|
398
|
+
query=query_str,
|
|
399
|
+
use_prediction_time=use_prediction_time,
|
|
357
400
|
)
|
|
358
401
|
with warnings.catch_warnings():
|
|
359
402
|
warnings.filterwarnings('ignore', message='gencode')
|
|
@@ -376,11 +419,36 @@ class KumoRFM:
|
|
|
376
419
|
|
|
377
420
|
for attempt in range(self.num_retries + 1):
|
|
378
421
|
try:
|
|
379
|
-
if
|
|
380
|
-
resp = global_state.client.rfm_api.explain(
|
|
422
|
+
if explain_config is not None:
|
|
423
|
+
resp = global_state.client.rfm_api.explain(
|
|
424
|
+
request=_bytes,
|
|
425
|
+
skip_summary=explain_config.skip_summary,
|
|
426
|
+
)
|
|
427
|
+
summary = resp.summary
|
|
428
|
+
details = resp.details
|
|
381
429
|
else:
|
|
382
430
|
resp = global_state.client.rfm_api.predict(_bytes)
|
|
383
|
-
|
|
431
|
+
df = pd.DataFrame(**resp.prediction)
|
|
432
|
+
|
|
433
|
+
# Cast 'ENTITY' to correct data type:
|
|
434
|
+
if 'ENTITY' in df:
|
|
435
|
+
entity = query_def.entity_table
|
|
436
|
+
pkey_map = self._graph_store.pkey_map_dict[entity]
|
|
437
|
+
df['ENTITY'] = df['ENTITY'].astype(
|
|
438
|
+
type(pkey_map.index[0]))
|
|
439
|
+
|
|
440
|
+
# Cast 'ANCHOR_TIMESTAMP' to correct data type:
|
|
441
|
+
if 'ANCHOR_TIMESTAMP' in df:
|
|
442
|
+
ser = df['ANCHOR_TIMESTAMP']
|
|
443
|
+
if not pd.api.types.is_datetime64_any_dtype(ser):
|
|
444
|
+
if isinstance(ser.iloc[0], str):
|
|
445
|
+
unit = None
|
|
446
|
+
else:
|
|
447
|
+
unit = 'ms'
|
|
448
|
+
df['ANCHOR_TIMESTAMP'] = pd.to_datetime(
|
|
449
|
+
ser, errors='coerce', unit=unit)
|
|
450
|
+
|
|
451
|
+
predictions.append(df)
|
|
384
452
|
|
|
385
453
|
if (isinstance(verbose, InteractiveProgressLogger)
|
|
386
454
|
and len(batches) > 1):
|
|
@@ -401,22 +469,73 @@ class KumoRFM:
|
|
|
401
469
|
|
|
402
470
|
time.sleep(2**attempt) # 1s, 2s, 4s, 8s, ...
|
|
403
471
|
|
|
404
|
-
predictions = [pd.DataFrame(**resp.prediction) for resp in resps]
|
|
405
472
|
if len(predictions) == 1:
|
|
406
473
|
prediction = predictions[0]
|
|
407
474
|
else:
|
|
408
475
|
prediction = pd.concat(predictions, ignore_index=True)
|
|
409
476
|
|
|
410
|
-
if
|
|
411
|
-
assert len(
|
|
477
|
+
if explain_config is not None:
|
|
478
|
+
assert len(predictions) == 1
|
|
479
|
+
assert summary is not None
|
|
480
|
+
assert details is not None
|
|
412
481
|
return Explanation(
|
|
413
482
|
prediction=prediction,
|
|
414
|
-
summary=
|
|
415
|
-
details=
|
|
483
|
+
summary=summary,
|
|
484
|
+
details=details,
|
|
416
485
|
)
|
|
417
486
|
|
|
418
487
|
return prediction
|
|
419
488
|
|
|
489
|
+
def is_valid_entity(
|
|
490
|
+
self,
|
|
491
|
+
query: str,
|
|
492
|
+
indices: Union[List[str], List[float], List[int], None] = None,
|
|
493
|
+
*,
|
|
494
|
+
anchor_time: Union[pd.Timestamp, Literal['entity'], None] = None,
|
|
495
|
+
) -> np.ndarray:
|
|
496
|
+
r"""Returns a mask that denotes which entities are valid for the
|
|
497
|
+
given predictive query, *i.e.*, which entities fulfill (temporal)
|
|
498
|
+
entity filter constraints.
|
|
499
|
+
|
|
500
|
+
Args:
|
|
501
|
+
query: The predictive query.
|
|
502
|
+
indices: The entity primary keys to predict for. Will override the
|
|
503
|
+
indices given as part of the predictive query.
|
|
504
|
+
anchor_time: The anchor timestamp for the prediction. If set to
|
|
505
|
+
``None``, will use the maximum timestamp in the data.
|
|
506
|
+
If set to ``"entity"``, will use the timestamp of the entity.
|
|
507
|
+
"""
|
|
508
|
+
query_def = self._parse_query(query)
|
|
509
|
+
|
|
510
|
+
if indices is None:
|
|
511
|
+
if query_def.rfm_entity_ids is None:
|
|
512
|
+
raise ValueError("Cannot find entities to predict for. Please "
|
|
513
|
+
"pass them via "
|
|
514
|
+
"`is_valid_entity(query, indices=...)`")
|
|
515
|
+
indices = query_def.get_rfm_entity_id_list()
|
|
516
|
+
|
|
517
|
+
if len(indices) == 0:
|
|
518
|
+
raise ValueError("At least one entity is required")
|
|
519
|
+
|
|
520
|
+
if anchor_time is None:
|
|
521
|
+
anchor_time = self._graph_store.max_time
|
|
522
|
+
|
|
523
|
+
if isinstance(anchor_time, pd.Timestamp):
|
|
524
|
+
self._validate_time(query_def, anchor_time, None, False)
|
|
525
|
+
else:
|
|
526
|
+
assert anchor_time == 'entity'
|
|
527
|
+
if (query_def.entity_table not in self._graph_store.time_dict):
|
|
528
|
+
raise ValueError(f"Anchor time 'entity' requires the entity "
|
|
529
|
+
f"table '{query_def.entity_table}' "
|
|
530
|
+
f"to have a time column.")
|
|
531
|
+
|
|
532
|
+
node = self._graph_store.get_node_id(
|
|
533
|
+
table_name=query_def.entity_table,
|
|
534
|
+
pkey=pd.Series(indices),
|
|
535
|
+
)
|
|
536
|
+
query_driver = LocalPQueryDriver(self._graph_store, query_def)
|
|
537
|
+
return query_driver.is_valid(node, anchor_time)
|
|
538
|
+
|
|
420
539
|
def evaluate(
|
|
421
540
|
self,
|
|
422
541
|
query: str,
|
|
@@ -430,6 +549,7 @@ class KumoRFM:
|
|
|
430
549
|
max_pq_iterations: int = 20,
|
|
431
550
|
random_seed: Optional[int] = _RANDOM_SEED,
|
|
432
551
|
verbose: Union[bool, ProgressLogger] = True,
|
|
552
|
+
use_prediction_time: bool = False,
|
|
433
553
|
) -> pd.DataFrame:
|
|
434
554
|
"""Evaluates a predictive query.
|
|
435
555
|
|
|
@@ -453,6 +573,9 @@ class KumoRFM:
|
|
|
453
573
|
entities to find valid labels.
|
|
454
574
|
random_seed: A manual seed for generating pseudo-random numbers.
|
|
455
575
|
verbose: Whether to print verbose output.
|
|
576
|
+
use_prediction_time: Whether to use the anchor timestamp as an
|
|
577
|
+
additional feature during prediction. This is typically
|
|
578
|
+
beneficial for time series forecasting tasks.
|
|
456
579
|
|
|
457
580
|
Returns:
|
|
458
581
|
The metrics as a :class:`pandas.DataFrame`
|
|
@@ -463,10 +586,10 @@ class KumoRFM:
|
|
|
463
586
|
warnings.warn(f"Received custom 'num_neighbors' option; ignoring "
|
|
464
587
|
f"custom 'num_hops={num_hops}' option")
|
|
465
588
|
|
|
466
|
-
if query_def.
|
|
589
|
+
if query_def.rfm_entity_ids is not None:
|
|
467
590
|
query_def = replace(
|
|
468
591
|
query_def,
|
|
469
|
-
|
|
592
|
+
rfm_entity_ids=None,
|
|
470
593
|
)
|
|
471
594
|
|
|
472
595
|
query_repr = query_def.to_string(rich=True, exclude_predict=True)
|
|
@@ -496,6 +619,7 @@ class KumoRFM:
|
|
|
496
619
|
context=context,
|
|
497
620
|
run_mode=RunMode(run_mode),
|
|
498
621
|
metrics=metrics,
|
|
622
|
+
use_prediction_time=use_prediction_time,
|
|
499
623
|
)
|
|
500
624
|
with warnings.catch_warnings():
|
|
501
625
|
warnings.filterwarnings('ignore', message='Protobuf gencode')
|
|
@@ -506,7 +630,7 @@ class KumoRFM:
|
|
|
506
630
|
|
|
507
631
|
if len(request_bytes) > _MAX_SIZE:
|
|
508
632
|
stats_msg = Context.get_memory_stats(request_msg.context)
|
|
509
|
-
raise ValueError(_SIZE_LIMIT_MSG.format(
|
|
633
|
+
raise ValueError(_SIZE_LIMIT_MSG.format(stats=stats_msg))
|
|
510
634
|
|
|
511
635
|
try:
|
|
512
636
|
resp = global_state.client.rfm_api.evaluate(request_bytes)
|
|
@@ -554,18 +678,19 @@ class KumoRFM:
|
|
|
554
678
|
|
|
555
679
|
if anchor_time is None:
|
|
556
680
|
anchor_time = self._graph_store.max_time
|
|
557
|
-
|
|
558
|
-
|
|
681
|
+
if query_def.target_ast.date_offset_range is not None:
|
|
682
|
+
anchor_time = anchor_time - (
|
|
683
|
+
query_def.target_ast.date_offset_range.end_date_offset *
|
|
684
|
+
query_def.num_forecasts)
|
|
559
685
|
|
|
560
686
|
assert anchor_time is not None
|
|
561
687
|
if isinstance(anchor_time, pd.Timestamp):
|
|
562
688
|
self._validate_time(query_def, anchor_time, None, evaluate=True)
|
|
563
689
|
else:
|
|
564
690
|
assert anchor_time == 'entity'
|
|
565
|
-
if (query_def.
|
|
566
|
-
not in self._graph_store.time_dict):
|
|
691
|
+
if (query_def.entity_table not in self._graph_store.time_dict):
|
|
567
692
|
raise ValueError(f"Anchor time 'entity' requires the entity "
|
|
568
|
-
f"table '{query_def.
|
|
693
|
+
f"table '{query_def.entity_table}' "
|
|
569
694
|
f"to have a time column")
|
|
570
695
|
|
|
571
696
|
query_driver = LocalPQueryDriver(self._graph_store, query_def,
|
|
@@ -580,7 +705,7 @@ class KumoRFM:
|
|
|
580
705
|
)
|
|
581
706
|
|
|
582
707
|
entity = self._graph_store.pkey_map_dict[
|
|
583
|
-
query_def.
|
|
708
|
+
query_def.entity_table].index[node]
|
|
584
709
|
|
|
585
710
|
return pd.DataFrame({
|
|
586
711
|
'ENTITY': entity,
|
|
@@ -590,8 +715,8 @@ class KumoRFM:
|
|
|
590
715
|
|
|
591
716
|
# Helpers #################################################################
|
|
592
717
|
|
|
593
|
-
def _parse_query(self, query: str) ->
|
|
594
|
-
if isinstance(query,
|
|
718
|
+
def _parse_query(self, query: str) -> ValidatedPredictiveQuery:
|
|
719
|
+
if isinstance(query, ValidatedPredictiveQuery):
|
|
595
720
|
return query
|
|
596
721
|
|
|
597
722
|
if isinstance(query, str) and query.strip()[:9].lower() == 'evaluate ':
|
|
@@ -601,12 +726,12 @@ class KumoRFM:
|
|
|
601
726
|
"predictions or evaluations.")
|
|
602
727
|
|
|
603
728
|
try:
|
|
604
|
-
request =
|
|
729
|
+
request = RFMParseQueryRequest(
|
|
605
730
|
query=query,
|
|
606
731
|
graph_definition=self._graph_def,
|
|
607
732
|
)
|
|
608
733
|
|
|
609
|
-
resp = global_state.client.rfm_api.
|
|
734
|
+
resp = global_state.client.rfm_api.parse_query(request)
|
|
610
735
|
# TODO Expose validation warnings.
|
|
611
736
|
|
|
612
737
|
if len(resp.validation_response.warnings) > 0:
|
|
@@ -617,7 +742,7 @@ class KumoRFM:
|
|
|
617
742
|
warnings.warn(f"Encountered the following warnings during "
|
|
618
743
|
f"parsing:\n{msg}")
|
|
619
744
|
|
|
620
|
-
return resp.
|
|
745
|
+
return resp.query
|
|
621
746
|
except HTTPException as e:
|
|
622
747
|
try:
|
|
623
748
|
msg = json.loads(e.detail)['detail']
|
|
@@ -628,7 +753,7 @@ class KumoRFM:
|
|
|
628
753
|
|
|
629
754
|
def _validate_time(
|
|
630
755
|
self,
|
|
631
|
-
query:
|
|
756
|
+
query: ValidatedPredictiveQuery,
|
|
632
757
|
anchor_time: pd.Timestamp,
|
|
633
758
|
context_anchor_time: Union[pd.Timestamp, None],
|
|
634
759
|
evaluate: bool,
|
|
@@ -651,6 +776,11 @@ class KumoRFM:
|
|
|
651
776
|
f"only contains data back to "
|
|
652
777
|
f"'{self._graph_store.min_time}'.")
|
|
653
778
|
|
|
779
|
+
if query.target_ast.date_offset_range is not None:
|
|
780
|
+
end_offset = query.target_ast.date_offset_range.end_date_offset
|
|
781
|
+
else:
|
|
782
|
+
end_offset = pd.DateOffset(0)
|
|
783
|
+
forecast_end_offset = end_offset * query.num_forecasts
|
|
654
784
|
if (context_anchor_time is not None
|
|
655
785
|
and context_anchor_time > anchor_time):
|
|
656
786
|
warnings.warn(f"Context anchor timestamp "
|
|
@@ -659,19 +789,18 @@ class KumoRFM:
|
|
|
659
789
|
f"(got '{anchor_time}'). Please make sure this is "
|
|
660
790
|
f"intended.")
|
|
661
791
|
elif (query.query_type == QueryType.TEMPORAL
|
|
662
|
-
and context_anchor_time is not None
|
|
663
|
-
|
|
792
|
+
and context_anchor_time is not None
|
|
793
|
+
and context_anchor_time + forecast_end_offset > anchor_time):
|
|
664
794
|
warnings.warn(f"Aggregation for context examples at timestamp "
|
|
665
795
|
f"'{context_anchor_time}' will leak information "
|
|
666
796
|
f"from the prediction anchor timestamp "
|
|
667
797
|
f"'{anchor_time}'. Please make sure this is "
|
|
668
798
|
f"intended.")
|
|
669
799
|
|
|
670
|
-
elif (context_anchor_time is not None
|
|
671
|
-
|
|
800
|
+
elif (context_anchor_time is not None
|
|
801
|
+
and context_anchor_time - forecast_end_offset
|
|
672
802
|
< self._graph_store.min_time):
|
|
673
|
-
_time = context_anchor_time -
|
|
674
|
-
query.num_forecasts)
|
|
803
|
+
_time = context_anchor_time - forecast_end_offset
|
|
675
804
|
warnings.warn(f"Context anchor timestamp is too early or "
|
|
676
805
|
f"aggregation time range is too large. To form "
|
|
677
806
|
f"proper input data, we would need data back to "
|
|
@@ -684,8 +813,7 @@ class KumoRFM:
|
|
|
684
813
|
f"latest timestamp '{self._graph_store.max_time}' "
|
|
685
814
|
f"in the data. Please make sure this is intended.")
|
|
686
815
|
|
|
687
|
-
max_eval_time =
|
|
688
|
-
query.target.end_offset * query.num_forecasts)
|
|
816
|
+
max_eval_time = self._graph_store.max_time - forecast_end_offset
|
|
689
817
|
if evaluate and anchor_time > max_eval_time:
|
|
690
818
|
raise ValueError(
|
|
691
819
|
f"Anchor timestamp for evaluation is after the latest "
|
|
@@ -693,7 +821,7 @@ class KumoRFM:
|
|
|
693
821
|
|
|
694
822
|
def _get_context(
|
|
695
823
|
self,
|
|
696
|
-
query:
|
|
824
|
+
query: ValidatedPredictiveQuery,
|
|
697
825
|
indices: Union[List[str], List[float], List[int], None],
|
|
698
826
|
anchor_time: Union[pd.Timestamp, Literal['entity'], None],
|
|
699
827
|
context_anchor_time: Union[pd.Timestamp, None],
|
|
@@ -721,8 +849,8 @@ class KumoRFM:
|
|
|
721
849
|
f"must go beyond this for your use-case.")
|
|
722
850
|
|
|
723
851
|
query_driver = LocalPQueryDriver(self._graph_store, query, random_seed)
|
|
724
|
-
task_type =
|
|
725
|
-
|
|
852
|
+
task_type = LocalPQueryDriver.get_task_type(
|
|
853
|
+
query,
|
|
726
854
|
edge_types=self._graph_store.edge_types,
|
|
727
855
|
)
|
|
728
856
|
|
|
@@ -754,11 +882,15 @@ class KumoRFM:
|
|
|
754
882
|
else:
|
|
755
883
|
num_neighbors = [64, 64, 8, 8, 4, 4][:num_hops]
|
|
756
884
|
|
|
885
|
+
if query.target_ast.date_offset_range is None:
|
|
886
|
+
end_offset = pd.DateOffset(0)
|
|
887
|
+
else:
|
|
888
|
+
end_offset = query.target_ast.date_offset_range.end_date_offset
|
|
889
|
+
forecast_end_offset = end_offset * query.num_forecasts
|
|
757
890
|
if anchor_time is None:
|
|
758
891
|
anchor_time = self._graph_store.max_time
|
|
759
892
|
if evaluate:
|
|
760
|
-
anchor_time = anchor_time -
|
|
761
|
-
query.num_forecasts)
|
|
893
|
+
anchor_time = anchor_time - forecast_end_offset
|
|
762
894
|
if logger is not None:
|
|
763
895
|
assert isinstance(anchor_time, pd.Timestamp)
|
|
764
896
|
if anchor_time == pd.Timestamp.min:
|
|
@@ -773,15 +905,14 @@ class KumoRFM:
|
|
|
773
905
|
assert anchor_time is not None
|
|
774
906
|
if isinstance(anchor_time, pd.Timestamp):
|
|
775
907
|
if context_anchor_time is None:
|
|
776
|
-
context_anchor_time = anchor_time -
|
|
777
|
-
query.num_forecasts)
|
|
908
|
+
context_anchor_time = anchor_time - forecast_end_offset
|
|
778
909
|
self._validate_time(query, anchor_time, context_anchor_time,
|
|
779
910
|
evaluate)
|
|
780
911
|
else:
|
|
781
912
|
assert anchor_time == 'entity'
|
|
782
|
-
if query.
|
|
913
|
+
if query.entity_table not in self._graph_store.time_dict:
|
|
783
914
|
raise ValueError(f"Anchor time 'entity' requires the entity "
|
|
784
|
-
f"table '{query.
|
|
915
|
+
f"table '{query.entity_table}' to "
|
|
785
916
|
f"have a time column")
|
|
786
917
|
if context_anchor_time is not None:
|
|
787
918
|
warnings.warn("Ignoring option 'context_anchor_time' for "
|
|
@@ -832,7 +963,7 @@ class KumoRFM:
|
|
|
832
963
|
f"in batches")
|
|
833
964
|
|
|
834
965
|
test_node = self._graph_store.get_node_id(
|
|
835
|
-
table_name=query.
|
|
966
|
+
table_name=query.entity_table,
|
|
836
967
|
pkey=pd.Series(indices),
|
|
837
968
|
)
|
|
838
969
|
|
|
@@ -840,8 +971,7 @@ class KumoRFM:
|
|
|
840
971
|
test_time = pd.Series(anchor_time).repeat(
|
|
841
972
|
len(test_node)).reset_index(drop=True)
|
|
842
973
|
else:
|
|
843
|
-
time = self._graph_store.time_dict[
|
|
844
|
-
query.entity.pkey.table_name]
|
|
974
|
+
time = self._graph_store.time_dict[query.entity_table]
|
|
845
975
|
time = time[test_node] * 1000**3
|
|
846
976
|
test_time = pd.Series(time, dtype='datetime64[ns]')
|
|
847
977
|
|
|
@@ -874,12 +1004,23 @@ class KumoRFM:
|
|
|
874
1004
|
raise NotImplementedError
|
|
875
1005
|
logger.log(msg)
|
|
876
1006
|
|
|
877
|
-
entity_table_names
|
|
878
|
-
|
|
1007
|
+
entity_table_names: Tuple[str, ...]
|
|
1008
|
+
if task_type.is_link_pred:
|
|
1009
|
+
final_aggr = query.get_final_target_aggregation()
|
|
1010
|
+
assert final_aggr is not None
|
|
1011
|
+
edge_fkey = final_aggr._get_target_column_name()
|
|
1012
|
+
for edge_type in self._graph_store.edge_types:
|
|
1013
|
+
if edge_fkey == f'{edge_type[0]}.{edge_type[1]}':
|
|
1014
|
+
entity_table_names = (
|
|
1015
|
+
query.entity_table,
|
|
1016
|
+
edge_type[2],
|
|
1017
|
+
)
|
|
1018
|
+
else:
|
|
1019
|
+
entity_table_names = (query.entity_table, )
|
|
879
1020
|
|
|
880
1021
|
# Exclude the entity anchor time from the feature set to prevent
|
|
881
1022
|
# running out-of-distribution between in-context and test examples:
|
|
882
|
-
exclude_cols_dict = query.
|
|
1023
|
+
exclude_cols_dict = query.get_exclude_cols_dict()
|
|
883
1024
|
if anchor_time == 'entity':
|
|
884
1025
|
if entity_table_names[0] not in exclude_cols_dict:
|
|
885
1026
|
exclude_cols_dict[entity_table_names[0]] = []
|
|
@@ -908,7 +1049,7 @@ class KumoRFM:
|
|
|
908
1049
|
|
|
909
1050
|
step_size: Optional[int] = None
|
|
910
1051
|
if query.query_type == QueryType.TEMPORAL:
|
|
911
|
-
step_size = date_offset_to_seconds(
|
|
1052
|
+
step_size = date_offset_to_seconds(end_offset)
|
|
912
1053
|
|
|
913
1054
|
return Context(
|
|
914
1055
|
task_type=task_type,
|
|
@@ -933,7 +1074,7 @@ class KumoRFM:
|
|
|
933
1074
|
elif task_type == TaskType.MULTICLASS_CLASSIFICATION:
|
|
934
1075
|
supported_metrics = ['acc', 'precision', 'recall', 'f1', 'mrr']
|
|
935
1076
|
elif task_type == TaskType.REGRESSION:
|
|
936
|
-
supported_metrics = ['mae', 'mape', 'mse', 'rmse', 'smape']
|
|
1077
|
+
supported_metrics = ['mae', 'mape', 'mse', 'rmse', 'smape', 'r2']
|
|
937
1078
|
elif task_type == TaskType.TEMPORAL_LINK_PREDICTION:
|
|
938
1079
|
supported_metrics = [
|
|
939
1080
|
'map@', 'ndcg@', 'mrr@', 'precision@', 'recall@', 'f1@',
|
kumoai/jobs.py
CHANGED
|
@@ -26,6 +26,7 @@ class JobInterface(ABC, Generic[IDType, JobRequestType, JobResourceType]):
|
|
|
26
26
|
limit (int): Max number of jobs to list, default 10.
|
|
27
27
|
|
|
28
28
|
Example:
|
|
29
|
+
>>> # doctest: +SKIP
|
|
29
30
|
>>> tags = {'pquery_name': 'my_pquery_name'}
|
|
30
31
|
>>> jobs = BatchPredictionJob.search_by_tags(tags)
|
|
31
32
|
Search limited to 10 results based on the `limit` parameter.
|
|
@@ -370,9 +370,11 @@ class PredictiveQuery:
|
|
|
370
370
|
train_table_job_api = global_state.client.generate_train_table_job_api
|
|
371
371
|
job_id: GenerateTrainTableJobID = train_table_job_api.create(
|
|
372
372
|
GenerateTrainTableRequest(
|
|
373
|
-
dict(custom_tags),
|
|
374
|
-
|
|
375
|
-
|
|
373
|
+
dict(custom_tags),
|
|
374
|
+
pq_id,
|
|
375
|
+
plan,
|
|
376
|
+
None,
|
|
377
|
+
))
|
|
376
378
|
|
|
377
379
|
self._train_table = TrainingTableJob(job_id=job_id)
|
|
378
380
|
if non_blocking:
|
|
@@ -451,9 +453,11 @@ class PredictiveQuery:
|
|
|
451
453
|
bp_table_api = global_state.client.generate_prediction_table_job_api
|
|
452
454
|
job_id: GeneratePredictionTableJobID = bp_table_api.create(
|
|
453
455
|
GeneratePredictionTableRequest(
|
|
454
|
-
dict(custom_tags),
|
|
455
|
-
|
|
456
|
-
|
|
456
|
+
dict(custom_tags),
|
|
457
|
+
pq_id,
|
|
458
|
+
plan,
|
|
459
|
+
None,
|
|
460
|
+
))
|
|
457
461
|
|
|
458
462
|
self._prediction_table = PredictionTableJob(job_id=job_id)
|
|
459
463
|
if non_blocking:
|