kumoai 2.11.0.dev202510161830__py3-none-any.whl → 2.13.0.dev202511211730__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kumoai/__init__.py +10 -11
- kumoai/_version.py +1 -1
- kumoai/client/client.py +9 -13
- kumoai/client/endpoints.py +1 -0
- kumoai/client/rfm.py +35 -7
- kumoai/experimental/rfm/__init__.py +151 -8
- kumoai/experimental/rfm/local_graph_sampler.py +0 -2
- kumoai/experimental/rfm/local_pquery_driver.py +221 -26
- kumoai/experimental/rfm/pquery/__init__.py +4 -4
- kumoai/experimental/rfm/pquery/{backend.py → executor.py} +24 -58
- kumoai/experimental/rfm/pquery/{pandas_backend.py → pandas_executor.py} +277 -223
- kumoai/experimental/rfm/rfm.py +149 -79
- kumoai/experimental/rfm/sagemaker.py +130 -0
- kumoai/spcs.py +1 -3
- kumoai/trainer/trainer.py +9 -10
- kumoai/utils/progress_logger.py +10 -4
- {kumoai-2.11.0.dev202510161830.dist-info → kumoai-2.13.0.dev202511211730.dist-info}/METADATA +11 -2
- {kumoai-2.11.0.dev202510161830.dist-info → kumoai-2.13.0.dev202511211730.dist-info}/RECORD +21 -20
- {kumoai-2.11.0.dev202510161830.dist-info → kumoai-2.13.0.dev202511211730.dist-info}/WHEEL +0 -0
- {kumoai-2.11.0.dev202510161830.dist-info → kumoai-2.13.0.dev202511211730.dist-info}/licenses/LICENSE +0 -0
- {kumoai-2.11.0.dev202510161830.dist-info → kumoai-2.13.0.dev202511211730.dist-info}/top_level.txt +0 -0
kumoai/experimental/rfm/rfm.py
CHANGED
|
@@ -5,23 +5,32 @@ from collections import defaultdict
|
|
|
5
5
|
from collections.abc import Generator
|
|
6
6
|
from contextlib import contextmanager
|
|
7
7
|
from dataclasses import dataclass, replace
|
|
8
|
-
from typing import
|
|
8
|
+
from typing import (
|
|
9
|
+
Any,
|
|
10
|
+
Dict,
|
|
11
|
+
Iterator,
|
|
12
|
+
List,
|
|
13
|
+
Literal,
|
|
14
|
+
Optional,
|
|
15
|
+
Tuple,
|
|
16
|
+
Union,
|
|
17
|
+
overload,
|
|
18
|
+
)
|
|
9
19
|
|
|
10
20
|
import numpy as np
|
|
11
21
|
import pandas as pd
|
|
12
22
|
from kumoapi.model_plan import RunMode
|
|
13
|
-
from kumoapi.pquery import QueryType
|
|
23
|
+
from kumoapi.pquery import QueryType, ValidatedPredictiveQuery
|
|
14
24
|
from kumoapi.rfm import Context
|
|
15
25
|
from kumoapi.rfm import Explanation as ExplanationConfig
|
|
16
26
|
from kumoapi.rfm import (
|
|
17
|
-
PQueryDefinition,
|
|
18
27
|
RFMEvaluateRequest,
|
|
28
|
+
RFMParseQueryRequest,
|
|
19
29
|
RFMPredictRequest,
|
|
20
|
-
RFMValidateQueryRequest,
|
|
21
30
|
)
|
|
22
31
|
from kumoapi.task import TaskType
|
|
23
32
|
|
|
24
|
-
from kumoai import
|
|
33
|
+
from kumoai.client.rfm import RFMAPI
|
|
25
34
|
from kumoai.exceptions import HTTPException
|
|
26
35
|
from kumoai.experimental.rfm import LocalGraph
|
|
27
36
|
from kumoai.experimental.rfm.local_graph_sampler import LocalGraphSampler
|
|
@@ -30,6 +39,7 @@ from kumoai.experimental.rfm.local_pquery_driver import (
|
|
|
30
39
|
LocalPQueryDriver,
|
|
31
40
|
date_offset_to_seconds,
|
|
32
41
|
)
|
|
42
|
+
from kumoai.mixin import CastMixin
|
|
33
43
|
from kumoai.utils import InteractiveProgressLogger, ProgressLogger
|
|
34
44
|
|
|
35
45
|
_RANDOM_SEED = 42
|
|
@@ -60,6 +70,17 @@ _SIZE_LIMIT_MSG = ("Context size exceeds the 30MB limit. {stats}\nPlease "
|
|
|
60
70
|
"beyond this for your use-case.")
|
|
61
71
|
|
|
62
72
|
|
|
73
|
+
@dataclass(repr=False)
|
|
74
|
+
class ExplainConfig(CastMixin):
|
|
75
|
+
"""Configuration for explainability.
|
|
76
|
+
|
|
77
|
+
Args:
|
|
78
|
+
skip_summary: Whether to skip generating a human-readable summary of
|
|
79
|
+
the explanation.
|
|
80
|
+
"""
|
|
81
|
+
skip_summary: bool = False
|
|
82
|
+
|
|
83
|
+
|
|
63
84
|
@dataclass(repr=False)
|
|
64
85
|
class Explanation:
|
|
65
86
|
prediction: pd.DataFrame
|
|
@@ -87,6 +108,12 @@ class Explanation:
|
|
|
87
108
|
def __repr__(self) -> str:
|
|
88
109
|
return str((self.prediction, self.summary))
|
|
89
110
|
|
|
111
|
+
def _ipython_display_(self) -> None:
|
|
112
|
+
from IPython.display import Markdown, display
|
|
113
|
+
|
|
114
|
+
display(self.prediction)
|
|
115
|
+
display(Markdown(self.summary))
|
|
116
|
+
|
|
90
117
|
|
|
91
118
|
class KumoRFM:
|
|
92
119
|
r"""The Kumo Relational Foundation model (RFM) from the `KumoRFM: A
|
|
@@ -114,9 +141,9 @@ class KumoRFM:
|
|
|
114
141
|
|
|
115
142
|
rfm = KumoRFM(graph)
|
|
116
143
|
|
|
117
|
-
query = ("PREDICT COUNT(
|
|
118
|
-
"FOR users.user_id=
|
|
119
|
-
result = rfm.
|
|
144
|
+
query = ("PREDICT COUNT(orders.*, 0, 30, days)>0 "
|
|
145
|
+
"FOR users.user_id=1")
|
|
146
|
+
result = rfm.predict(query)
|
|
120
147
|
|
|
121
148
|
print(result) # user_id COUNT(transactions.*, 0, 30, days) > 0
|
|
122
149
|
# 1 0.85
|
|
@@ -147,6 +174,8 @@ class KumoRFM:
|
|
|
147
174
|
|
|
148
175
|
self._batch_size: Optional[int | Literal['max']] = None
|
|
149
176
|
self.num_retries: int = 0
|
|
177
|
+
from kumoai.experimental.rfm import global_state
|
|
178
|
+
self._api_client = RFMAPI(global_state.client)
|
|
150
179
|
|
|
151
180
|
def __repr__(self) -> str:
|
|
152
181
|
return f'{self.__class__.__name__}()'
|
|
@@ -199,6 +228,7 @@ class KumoRFM:
|
|
|
199
228
|
max_pq_iterations: int = 20,
|
|
200
229
|
random_seed: Optional[int] = _RANDOM_SEED,
|
|
201
230
|
verbose: Union[bool, ProgressLogger] = True,
|
|
231
|
+
use_prediction_time: bool = False,
|
|
202
232
|
) -> pd.DataFrame:
|
|
203
233
|
pass
|
|
204
234
|
|
|
@@ -208,7 +238,7 @@ class KumoRFM:
|
|
|
208
238
|
query: str,
|
|
209
239
|
indices: Union[List[str], List[float], List[int], None] = None,
|
|
210
240
|
*,
|
|
211
|
-
explain: Literal[True],
|
|
241
|
+
explain: Union[Literal[True], ExplainConfig, Dict[str, Any]],
|
|
212
242
|
anchor_time: Union[pd.Timestamp, Literal['entity'], None] = None,
|
|
213
243
|
context_anchor_time: Union[pd.Timestamp, None] = None,
|
|
214
244
|
run_mode: Union[RunMode, str] = RunMode.FAST,
|
|
@@ -217,6 +247,7 @@ class KumoRFM:
|
|
|
217
247
|
max_pq_iterations: int = 20,
|
|
218
248
|
random_seed: Optional[int] = _RANDOM_SEED,
|
|
219
249
|
verbose: Union[bool, ProgressLogger] = True,
|
|
250
|
+
use_prediction_time: bool = False,
|
|
220
251
|
) -> Explanation:
|
|
221
252
|
pass
|
|
222
253
|
|
|
@@ -225,7 +256,7 @@ class KumoRFM:
|
|
|
225
256
|
query: str,
|
|
226
257
|
indices: Union[List[str], List[float], List[int], None] = None,
|
|
227
258
|
*,
|
|
228
|
-
explain: bool = False,
|
|
259
|
+
explain: Union[bool, ExplainConfig, Dict[str, Any]] = False,
|
|
229
260
|
anchor_time: Union[pd.Timestamp, Literal['entity'], None] = None,
|
|
230
261
|
context_anchor_time: Union[pd.Timestamp, None] = None,
|
|
231
262
|
run_mode: Union[RunMode, str] = RunMode.FAST,
|
|
@@ -234,6 +265,7 @@ class KumoRFM:
|
|
|
234
265
|
max_pq_iterations: int = 20,
|
|
235
266
|
random_seed: Optional[int] = _RANDOM_SEED,
|
|
236
267
|
verbose: Union[bool, ProgressLogger] = True,
|
|
268
|
+
use_prediction_time: bool = False,
|
|
237
269
|
) -> Union[pd.DataFrame, Explanation]:
|
|
238
270
|
"""Returns predictions for a predictive query.
|
|
239
271
|
|
|
@@ -244,9 +276,12 @@ class KumoRFM:
|
|
|
244
276
|
be generated for all indices, independent of whether they
|
|
245
277
|
fulfill entity filter constraints. To pre-filter entities, use
|
|
246
278
|
:meth:`~KumoRFM.is_valid_entity`.
|
|
247
|
-
explain:
|
|
248
|
-
|
|
249
|
-
|
|
279
|
+
explain: Configuration for explainability.
|
|
280
|
+
If set to ``True``, will additionally explain the prediction.
|
|
281
|
+
Passing in an :class:`ExplainConfig` instance provides control
|
|
282
|
+
over which parts of explanation are generated.
|
|
283
|
+
Explainability is currently only supported for single entity
|
|
284
|
+
predictions with ``run_mode="FAST"``.
|
|
250
285
|
anchor_time: The anchor timestamp for the prediction. If set to
|
|
251
286
|
``None``, will use the maximum timestamp in the data.
|
|
252
287
|
If set to ``"entity"``, will use the timestamp of the entity.
|
|
@@ -264,45 +299,54 @@ class KumoRFM:
|
|
|
264
299
|
entities to find valid labels.
|
|
265
300
|
random_seed: A manual seed for generating pseudo-random numbers.
|
|
266
301
|
verbose: Whether to print verbose output.
|
|
302
|
+
use_prediction_time: Whether to use the anchor timestamp as an
|
|
303
|
+
additional feature during prediction. This is typically
|
|
304
|
+
beneficial for time series forecasting tasks.
|
|
267
305
|
|
|
268
306
|
Returns:
|
|
269
307
|
The predictions as a :class:`pandas.DataFrame`.
|
|
270
|
-
If ``explain
|
|
271
|
-
|
|
308
|
+
If ``explain`` is provided, returns an :class:`Explanation` object
|
|
309
|
+
containing the prediction, summary, and details.
|
|
272
310
|
"""
|
|
311
|
+
explain_config: Optional[ExplainConfig] = None
|
|
312
|
+
if explain is True:
|
|
313
|
+
explain_config = ExplainConfig()
|
|
314
|
+
elif explain is not False:
|
|
315
|
+
explain_config = ExplainConfig._cast(explain)
|
|
316
|
+
|
|
273
317
|
query_def = self._parse_query(query)
|
|
318
|
+
query_str = query_def.to_string()
|
|
274
319
|
|
|
275
320
|
if num_hops != 2 and num_neighbors is not None:
|
|
276
321
|
warnings.warn(f"Received custom 'num_neighbors' option; ignoring "
|
|
277
322
|
f"custom 'num_hops={num_hops}' option")
|
|
278
323
|
|
|
279
|
-
if
|
|
324
|
+
if explain_config is not None and run_mode in {
|
|
325
|
+
RunMode.NORMAL, RunMode.BEST
|
|
326
|
+
}:
|
|
280
327
|
warnings.warn(f"Explainability is currently only supported for "
|
|
281
328
|
f"run mode 'FAST' (got '{run_mode}'). Provided run "
|
|
282
329
|
f"mode has been reset. Please lower the run mode to "
|
|
283
330
|
f"suppress this warning.")
|
|
284
331
|
|
|
285
332
|
if indices is None:
|
|
286
|
-
if query_def.
|
|
333
|
+
if query_def.rfm_entity_ids is None:
|
|
287
334
|
raise ValueError("Cannot find entities to predict for. Please "
|
|
288
335
|
"pass them via `predict(query, indices=...)`")
|
|
289
|
-
indices = query_def.
|
|
336
|
+
indices = query_def.get_rfm_entity_id_list()
|
|
290
337
|
else:
|
|
291
|
-
query_def = replace(
|
|
292
|
-
query_def,
|
|
293
|
-
entity=replace(query_def.entity, ids=None),
|
|
294
|
-
)
|
|
338
|
+
query_def = replace(query_def, rfm_entity_ids=None)
|
|
295
339
|
|
|
296
340
|
if len(indices) == 0:
|
|
297
341
|
raise ValueError("At least one entity is required")
|
|
298
342
|
|
|
299
|
-
if
|
|
343
|
+
if explain_config is not None and len(indices) > 1:
|
|
300
344
|
raise ValueError(
|
|
301
345
|
f"Cannot explain predictions for more than a single entity "
|
|
302
346
|
f"(got {len(indices)})")
|
|
303
347
|
|
|
304
348
|
query_repr = query_def.to_string(rich=True, exclude_predict=True)
|
|
305
|
-
if
|
|
349
|
+
if explain_config is not None:
|
|
306
350
|
msg = f'[bold]EXPLAIN[/bold] {query_repr}'
|
|
307
351
|
else:
|
|
308
352
|
msg = f'[bold]PREDICT[/bold] {query_repr}'
|
|
@@ -314,8 +358,8 @@ class KumoRFM:
|
|
|
314
358
|
|
|
315
359
|
batch_size: Optional[int] = None
|
|
316
360
|
if self._batch_size == 'max':
|
|
317
|
-
task_type =
|
|
318
|
-
|
|
361
|
+
task_type = LocalPQueryDriver.get_task_type(
|
|
362
|
+
query_def,
|
|
319
363
|
edge_types=self._graph_store.edge_types,
|
|
320
364
|
)
|
|
321
365
|
batch_size = _MAX_PRED_SIZE[task_type]
|
|
@@ -353,6 +397,8 @@ class KumoRFM:
|
|
|
353
397
|
request = RFMPredictRequest(
|
|
354
398
|
context=context,
|
|
355
399
|
run_mode=RunMode(run_mode),
|
|
400
|
+
query=query_str,
|
|
401
|
+
use_prediction_time=use_prediction_time,
|
|
356
402
|
)
|
|
357
403
|
with warnings.catch_warnings():
|
|
358
404
|
warnings.filterwarnings('ignore', message='gencode')
|
|
@@ -375,17 +421,20 @@ class KumoRFM:
|
|
|
375
421
|
|
|
376
422
|
for attempt in range(self.num_retries + 1):
|
|
377
423
|
try:
|
|
378
|
-
if
|
|
379
|
-
resp =
|
|
424
|
+
if explain_config is not None:
|
|
425
|
+
resp = self._api_client.explain(
|
|
426
|
+
request=_bytes,
|
|
427
|
+
skip_summary=explain_config.skip_summary,
|
|
428
|
+
)
|
|
380
429
|
summary = resp.summary
|
|
381
430
|
details = resp.details
|
|
382
431
|
else:
|
|
383
|
-
resp =
|
|
432
|
+
resp = self._api_client.predict(_bytes)
|
|
384
433
|
df = pd.DataFrame(**resp.prediction)
|
|
385
434
|
|
|
386
435
|
# Cast 'ENTITY' to correct data type:
|
|
387
436
|
if 'ENTITY' in df:
|
|
388
|
-
entity = query_def.
|
|
437
|
+
entity = query_def.entity_table
|
|
389
438
|
pkey_map = self._graph_store.pkey_map_dict[entity]
|
|
390
439
|
df['ENTITY'] = df['ENTITY'].astype(
|
|
391
440
|
type(pkey_map.index[0]))
|
|
@@ -427,7 +476,7 @@ class KumoRFM:
|
|
|
427
476
|
else:
|
|
428
477
|
prediction = pd.concat(predictions, ignore_index=True)
|
|
429
478
|
|
|
430
|
-
if
|
|
479
|
+
if explain_config is not None:
|
|
431
480
|
assert len(predictions) == 1
|
|
432
481
|
assert summary is not None
|
|
433
482
|
assert details is not None
|
|
@@ -461,11 +510,11 @@ class KumoRFM:
|
|
|
461
510
|
query_def = self._parse_query(query)
|
|
462
511
|
|
|
463
512
|
if indices is None:
|
|
464
|
-
if query_def.
|
|
513
|
+
if query_def.rfm_entity_ids is None:
|
|
465
514
|
raise ValueError("Cannot find entities to predict for. Please "
|
|
466
515
|
"pass them via "
|
|
467
516
|
"`is_valid_entity(query, indices=...)`")
|
|
468
|
-
indices = query_def.
|
|
517
|
+
indices = query_def.get_rfm_entity_id_list()
|
|
469
518
|
|
|
470
519
|
if len(indices) == 0:
|
|
471
520
|
raise ValueError("At least one entity is required")
|
|
@@ -477,14 +526,13 @@ class KumoRFM:
|
|
|
477
526
|
self._validate_time(query_def, anchor_time, None, False)
|
|
478
527
|
else:
|
|
479
528
|
assert anchor_time == 'entity'
|
|
480
|
-
if (query_def.
|
|
481
|
-
not in self._graph_store.time_dict):
|
|
529
|
+
if (query_def.entity_table not in self._graph_store.time_dict):
|
|
482
530
|
raise ValueError(f"Anchor time 'entity' requires the entity "
|
|
483
|
-
f"table '{query_def.
|
|
484
|
-
f"to have a time column")
|
|
531
|
+
f"table '{query_def.entity_table}' "
|
|
532
|
+
f"to have a time column.")
|
|
485
533
|
|
|
486
534
|
node = self._graph_store.get_node_id(
|
|
487
|
-
table_name=query_def.
|
|
535
|
+
table_name=query_def.entity_table,
|
|
488
536
|
pkey=pd.Series(indices),
|
|
489
537
|
)
|
|
490
538
|
query_driver = LocalPQueryDriver(self._graph_store, query_def)
|
|
@@ -503,6 +551,7 @@ class KumoRFM:
|
|
|
503
551
|
max_pq_iterations: int = 20,
|
|
504
552
|
random_seed: Optional[int] = _RANDOM_SEED,
|
|
505
553
|
verbose: Union[bool, ProgressLogger] = True,
|
|
554
|
+
use_prediction_time: bool = False,
|
|
506
555
|
) -> pd.DataFrame:
|
|
507
556
|
"""Evaluates a predictive query.
|
|
508
557
|
|
|
@@ -526,6 +575,9 @@ class KumoRFM:
|
|
|
526
575
|
entities to find valid labels.
|
|
527
576
|
random_seed: A manual seed for generating pseudo-random numbers.
|
|
528
577
|
verbose: Whether to print verbose output.
|
|
578
|
+
use_prediction_time: Whether to use the anchor timestamp as an
|
|
579
|
+
additional feature during prediction. This is typically
|
|
580
|
+
beneficial for time series forecasting tasks.
|
|
529
581
|
|
|
530
582
|
Returns:
|
|
531
583
|
The metrics as a :class:`pandas.DataFrame`
|
|
@@ -536,10 +588,10 @@ class KumoRFM:
|
|
|
536
588
|
warnings.warn(f"Received custom 'num_neighbors' option; ignoring "
|
|
537
589
|
f"custom 'num_hops={num_hops}' option")
|
|
538
590
|
|
|
539
|
-
if query_def.
|
|
591
|
+
if query_def.rfm_entity_ids is not None:
|
|
540
592
|
query_def = replace(
|
|
541
593
|
query_def,
|
|
542
|
-
|
|
594
|
+
rfm_entity_ids=None,
|
|
543
595
|
)
|
|
544
596
|
|
|
545
597
|
query_repr = query_def.to_string(rich=True, exclude_predict=True)
|
|
@@ -569,6 +621,7 @@ class KumoRFM:
|
|
|
569
621
|
context=context,
|
|
570
622
|
run_mode=RunMode(run_mode),
|
|
571
623
|
metrics=metrics,
|
|
624
|
+
use_prediction_time=use_prediction_time,
|
|
572
625
|
)
|
|
573
626
|
with warnings.catch_warnings():
|
|
574
627
|
warnings.filterwarnings('ignore', message='Protobuf gencode')
|
|
@@ -579,10 +632,10 @@ class KumoRFM:
|
|
|
579
632
|
|
|
580
633
|
if len(request_bytes) > _MAX_SIZE:
|
|
581
634
|
stats_msg = Context.get_memory_stats(request_msg.context)
|
|
582
|
-
raise ValueError(_SIZE_LIMIT_MSG.format(
|
|
635
|
+
raise ValueError(_SIZE_LIMIT_MSG.format(stats=stats_msg))
|
|
583
636
|
|
|
584
637
|
try:
|
|
585
|
-
resp =
|
|
638
|
+
resp = self._api_client.evaluate(request_bytes)
|
|
586
639
|
except HTTPException as e:
|
|
587
640
|
try:
|
|
588
641
|
msg = json.loads(e.detail)['detail']
|
|
@@ -627,18 +680,19 @@ class KumoRFM:
|
|
|
627
680
|
|
|
628
681
|
if anchor_time is None:
|
|
629
682
|
anchor_time = self._graph_store.max_time
|
|
630
|
-
|
|
631
|
-
|
|
683
|
+
if query_def.target_ast.date_offset_range is not None:
|
|
684
|
+
anchor_time = anchor_time - (
|
|
685
|
+
query_def.target_ast.date_offset_range.end_date_offset *
|
|
686
|
+
query_def.num_forecasts)
|
|
632
687
|
|
|
633
688
|
assert anchor_time is not None
|
|
634
689
|
if isinstance(anchor_time, pd.Timestamp):
|
|
635
690
|
self._validate_time(query_def, anchor_time, None, evaluate=True)
|
|
636
691
|
else:
|
|
637
692
|
assert anchor_time == 'entity'
|
|
638
|
-
if (query_def.
|
|
639
|
-
not in self._graph_store.time_dict):
|
|
693
|
+
if (query_def.entity_table not in self._graph_store.time_dict):
|
|
640
694
|
raise ValueError(f"Anchor time 'entity' requires the entity "
|
|
641
|
-
f"table '{query_def.
|
|
695
|
+
f"table '{query_def.entity_table}' "
|
|
642
696
|
f"to have a time column")
|
|
643
697
|
|
|
644
698
|
query_driver = LocalPQueryDriver(self._graph_store, query_def,
|
|
@@ -653,7 +707,7 @@ class KumoRFM:
|
|
|
653
707
|
)
|
|
654
708
|
|
|
655
709
|
entity = self._graph_store.pkey_map_dict[
|
|
656
|
-
query_def.
|
|
710
|
+
query_def.entity_table].index[node]
|
|
657
711
|
|
|
658
712
|
return pd.DataFrame({
|
|
659
713
|
'ENTITY': entity,
|
|
@@ -663,8 +717,8 @@ class KumoRFM:
|
|
|
663
717
|
|
|
664
718
|
# Helpers #################################################################
|
|
665
719
|
|
|
666
|
-
def _parse_query(self, query: str) ->
|
|
667
|
-
if isinstance(query,
|
|
720
|
+
def _parse_query(self, query: str) -> ValidatedPredictiveQuery:
|
|
721
|
+
if isinstance(query, ValidatedPredictiveQuery):
|
|
668
722
|
return query
|
|
669
723
|
|
|
670
724
|
if isinstance(query, str) and query.strip()[:9].lower() == 'evaluate ':
|
|
@@ -674,12 +728,13 @@ class KumoRFM:
|
|
|
674
728
|
"predictions or evaluations.")
|
|
675
729
|
|
|
676
730
|
try:
|
|
677
|
-
request =
|
|
731
|
+
request = RFMParseQueryRequest(
|
|
678
732
|
query=query,
|
|
679
733
|
graph_definition=self._graph_def,
|
|
680
734
|
)
|
|
681
735
|
|
|
682
|
-
resp =
|
|
736
|
+
resp = self._api_client.parse_query(request)
|
|
737
|
+
|
|
683
738
|
# TODO Expose validation warnings.
|
|
684
739
|
|
|
685
740
|
if len(resp.validation_response.warnings) > 0:
|
|
@@ -690,7 +745,7 @@ class KumoRFM:
|
|
|
690
745
|
warnings.warn(f"Encountered the following warnings during "
|
|
691
746
|
f"parsing:\n{msg}")
|
|
692
747
|
|
|
693
|
-
return resp.
|
|
748
|
+
return resp.query
|
|
694
749
|
except HTTPException as e:
|
|
695
750
|
try:
|
|
696
751
|
msg = json.loads(e.detail)['detail']
|
|
@@ -701,7 +756,7 @@ class KumoRFM:
|
|
|
701
756
|
|
|
702
757
|
def _validate_time(
|
|
703
758
|
self,
|
|
704
|
-
query:
|
|
759
|
+
query: ValidatedPredictiveQuery,
|
|
705
760
|
anchor_time: pd.Timestamp,
|
|
706
761
|
context_anchor_time: Union[pd.Timestamp, None],
|
|
707
762
|
evaluate: bool,
|
|
@@ -724,6 +779,11 @@ class KumoRFM:
|
|
|
724
779
|
f"only contains data back to "
|
|
725
780
|
f"'{self._graph_store.min_time}'.")
|
|
726
781
|
|
|
782
|
+
if query.target_ast.date_offset_range is not None:
|
|
783
|
+
end_offset = query.target_ast.date_offset_range.end_date_offset
|
|
784
|
+
else:
|
|
785
|
+
end_offset = pd.DateOffset(0)
|
|
786
|
+
forecast_end_offset = end_offset * query.num_forecasts
|
|
727
787
|
if (context_anchor_time is not None
|
|
728
788
|
and context_anchor_time > anchor_time):
|
|
729
789
|
warnings.warn(f"Context anchor timestamp "
|
|
@@ -732,19 +792,18 @@ class KumoRFM:
|
|
|
732
792
|
f"(got '{anchor_time}'). Please make sure this is "
|
|
733
793
|
f"intended.")
|
|
734
794
|
elif (query.query_type == QueryType.TEMPORAL
|
|
735
|
-
and context_anchor_time is not None
|
|
736
|
-
|
|
795
|
+
and context_anchor_time is not None
|
|
796
|
+
and context_anchor_time + forecast_end_offset > anchor_time):
|
|
737
797
|
warnings.warn(f"Aggregation for context examples at timestamp "
|
|
738
798
|
f"'{context_anchor_time}' will leak information "
|
|
739
799
|
f"from the prediction anchor timestamp "
|
|
740
800
|
f"'{anchor_time}'. Please make sure this is "
|
|
741
801
|
f"intended.")
|
|
742
802
|
|
|
743
|
-
elif (context_anchor_time is not None
|
|
744
|
-
|
|
803
|
+
elif (context_anchor_time is not None
|
|
804
|
+
and context_anchor_time - forecast_end_offset
|
|
745
805
|
< self._graph_store.min_time):
|
|
746
|
-
_time = context_anchor_time -
|
|
747
|
-
query.num_forecasts)
|
|
806
|
+
_time = context_anchor_time - forecast_end_offset
|
|
748
807
|
warnings.warn(f"Context anchor timestamp is too early or "
|
|
749
808
|
f"aggregation time range is too large. To form "
|
|
750
809
|
f"proper input data, we would need data back to "
|
|
@@ -757,8 +816,7 @@ class KumoRFM:
|
|
|
757
816
|
f"latest timestamp '{self._graph_store.max_time}' "
|
|
758
817
|
f"in the data. Please make sure this is intended.")
|
|
759
818
|
|
|
760
|
-
max_eval_time =
|
|
761
|
-
query.target.end_offset * query.num_forecasts)
|
|
819
|
+
max_eval_time = self._graph_store.max_time - forecast_end_offset
|
|
762
820
|
if evaluate and anchor_time > max_eval_time:
|
|
763
821
|
raise ValueError(
|
|
764
822
|
f"Anchor timestamp for evaluation is after the latest "
|
|
@@ -766,7 +824,7 @@ class KumoRFM:
|
|
|
766
824
|
|
|
767
825
|
def _get_context(
|
|
768
826
|
self,
|
|
769
|
-
query:
|
|
827
|
+
query: ValidatedPredictiveQuery,
|
|
770
828
|
indices: Union[List[str], List[float], List[int], None],
|
|
771
829
|
anchor_time: Union[pd.Timestamp, Literal['entity'], None],
|
|
772
830
|
context_anchor_time: Union[pd.Timestamp, None],
|
|
@@ -794,8 +852,8 @@ class KumoRFM:
|
|
|
794
852
|
f"must go beyond this for your use-case.")
|
|
795
853
|
|
|
796
854
|
query_driver = LocalPQueryDriver(self._graph_store, query, random_seed)
|
|
797
|
-
task_type =
|
|
798
|
-
|
|
855
|
+
task_type = LocalPQueryDriver.get_task_type(
|
|
856
|
+
query,
|
|
799
857
|
edge_types=self._graph_store.edge_types,
|
|
800
858
|
)
|
|
801
859
|
|
|
@@ -827,11 +885,15 @@ class KumoRFM:
|
|
|
827
885
|
else:
|
|
828
886
|
num_neighbors = [64, 64, 8, 8, 4, 4][:num_hops]
|
|
829
887
|
|
|
888
|
+
if query.target_ast.date_offset_range is None:
|
|
889
|
+
end_offset = pd.DateOffset(0)
|
|
890
|
+
else:
|
|
891
|
+
end_offset = query.target_ast.date_offset_range.end_date_offset
|
|
892
|
+
forecast_end_offset = end_offset * query.num_forecasts
|
|
830
893
|
if anchor_time is None:
|
|
831
894
|
anchor_time = self._graph_store.max_time
|
|
832
895
|
if evaluate:
|
|
833
|
-
anchor_time = anchor_time -
|
|
834
|
-
query.num_forecasts)
|
|
896
|
+
anchor_time = anchor_time - forecast_end_offset
|
|
835
897
|
if logger is not None:
|
|
836
898
|
assert isinstance(anchor_time, pd.Timestamp)
|
|
837
899
|
if anchor_time == pd.Timestamp.min:
|
|
@@ -846,15 +908,14 @@ class KumoRFM:
|
|
|
846
908
|
assert anchor_time is not None
|
|
847
909
|
if isinstance(anchor_time, pd.Timestamp):
|
|
848
910
|
if context_anchor_time is None:
|
|
849
|
-
context_anchor_time = anchor_time -
|
|
850
|
-
query.num_forecasts)
|
|
911
|
+
context_anchor_time = anchor_time - forecast_end_offset
|
|
851
912
|
self._validate_time(query, anchor_time, context_anchor_time,
|
|
852
913
|
evaluate)
|
|
853
914
|
else:
|
|
854
915
|
assert anchor_time == 'entity'
|
|
855
|
-
if query.
|
|
916
|
+
if query.entity_table not in self._graph_store.time_dict:
|
|
856
917
|
raise ValueError(f"Anchor time 'entity' requires the entity "
|
|
857
|
-
f"table '{query.
|
|
918
|
+
f"table '{query.entity_table}' to "
|
|
858
919
|
f"have a time column")
|
|
859
920
|
if context_anchor_time is not None:
|
|
860
921
|
warnings.warn("Ignoring option 'context_anchor_time' for "
|
|
@@ -905,7 +966,7 @@ class KumoRFM:
|
|
|
905
966
|
f"in batches")
|
|
906
967
|
|
|
907
968
|
test_node = self._graph_store.get_node_id(
|
|
908
|
-
table_name=query.
|
|
969
|
+
table_name=query.entity_table,
|
|
909
970
|
pkey=pd.Series(indices),
|
|
910
971
|
)
|
|
911
972
|
|
|
@@ -913,8 +974,7 @@ class KumoRFM:
|
|
|
913
974
|
test_time = pd.Series(anchor_time).repeat(
|
|
914
975
|
len(test_node)).reset_index(drop=True)
|
|
915
976
|
else:
|
|
916
|
-
time = self._graph_store.time_dict[
|
|
917
|
-
query.entity.pkey.table_name]
|
|
977
|
+
time = self._graph_store.time_dict[query.entity_table]
|
|
918
978
|
time = time[test_node] * 1000**3
|
|
919
979
|
test_time = pd.Series(time, dtype='datetime64[ns]')
|
|
920
980
|
|
|
@@ -947,12 +1007,23 @@ class KumoRFM:
|
|
|
947
1007
|
raise NotImplementedError
|
|
948
1008
|
logger.log(msg)
|
|
949
1009
|
|
|
950
|
-
entity_table_names
|
|
951
|
-
|
|
1010
|
+
entity_table_names: Tuple[str, ...]
|
|
1011
|
+
if task_type.is_link_pred:
|
|
1012
|
+
final_aggr = query.get_final_target_aggregation()
|
|
1013
|
+
assert final_aggr is not None
|
|
1014
|
+
edge_fkey = final_aggr._get_target_column_name()
|
|
1015
|
+
for edge_type in self._graph_store.edge_types:
|
|
1016
|
+
if edge_fkey == f'{edge_type[0]}.{edge_type[1]}':
|
|
1017
|
+
entity_table_names = (
|
|
1018
|
+
query.entity_table,
|
|
1019
|
+
edge_type[2],
|
|
1020
|
+
)
|
|
1021
|
+
else:
|
|
1022
|
+
entity_table_names = (query.entity_table, )
|
|
952
1023
|
|
|
953
1024
|
# Exclude the entity anchor time from the feature set to prevent
|
|
954
1025
|
# running out-of-distribution between in-context and test examples:
|
|
955
|
-
exclude_cols_dict = query.
|
|
1026
|
+
exclude_cols_dict = query.get_exclude_cols_dict()
|
|
956
1027
|
if anchor_time == 'entity':
|
|
957
1028
|
if entity_table_names[0] not in exclude_cols_dict:
|
|
958
1029
|
exclude_cols_dict[entity_table_names[0]] = []
|
|
@@ -967,7 +1038,6 @@ class KumoRFM:
|
|
|
967
1038
|
train_time.astype('datetime64[ns]').astype(int).to_numpy(),
|
|
968
1039
|
test_time.astype('datetime64[ns]').astype(int).to_numpy(),
|
|
969
1040
|
]),
|
|
970
|
-
run_mode=run_mode,
|
|
971
1041
|
num_neighbors=num_neighbors,
|
|
972
1042
|
exclude_cols_dict=exclude_cols_dict,
|
|
973
1043
|
)
|
|
@@ -981,7 +1051,7 @@ class KumoRFM:
|
|
|
981
1051
|
|
|
982
1052
|
step_size: Optional[int] = None
|
|
983
1053
|
if query.query_type == QueryType.TEMPORAL:
|
|
984
|
-
step_size = date_offset_to_seconds(
|
|
1054
|
+
step_size = date_offset_to_seconds(end_offset)
|
|
985
1055
|
|
|
986
1056
|
return Context(
|
|
987
1057
|
task_type=task_type,
|