kumoai 2.10.0.dev202510061830__cp313-cp313-macosx_11_0_arm64.whl → 2.13.0.dev202511261731__cp313-cp313-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kumoai/__init__.py +10 -11
- kumoai/_version.py +1 -1
- kumoai/client/client.py +9 -13
- kumoai/client/endpoints.py +1 -0
- kumoai/client/rfm.py +35 -7
- kumoai/experimental/rfm/__init__.py +153 -10
- kumoai/experimental/rfm/infer/timestamp.py +5 -4
- kumoai/experimental/rfm/local_graph.py +90 -74
- kumoai/experimental/rfm/local_graph_sampler.py +16 -10
- kumoai/experimental/rfm/local_graph_store.py +13 -1
- kumoai/experimental/rfm/local_pquery_driver.py +249 -49
- kumoai/experimental/rfm/local_table.py +100 -22
- kumoai/experimental/rfm/pquery/__init__.py +4 -4
- kumoai/experimental/rfm/pquery/{backend.py → executor.py} +24 -58
- kumoai/experimental/rfm/pquery/{pandas_backend.py → pandas_executor.py} +277 -223
- kumoai/experimental/rfm/rfm.py +174 -91
- kumoai/experimental/rfm/sagemaker.py +130 -0
- kumoai/jobs.py +1 -0
- kumoai/spcs.py +1 -3
- kumoai/trainer/trainer.py +9 -10
- kumoai/utils/progress_logger.py +10 -4
- {kumoai-2.10.0.dev202510061830.dist-info → kumoai-2.13.0.dev202511261731.dist-info}/METADATA +13 -5
- {kumoai-2.10.0.dev202510061830.dist-info → kumoai-2.13.0.dev202511261731.dist-info}/RECORD +26 -25
- {kumoai-2.10.0.dev202510061830.dist-info → kumoai-2.13.0.dev202511261731.dist-info}/WHEEL +0 -0
- {kumoai-2.10.0.dev202510061830.dist-info → kumoai-2.13.0.dev202511261731.dist-info}/licenses/LICENSE +0 -0
- {kumoai-2.10.0.dev202510061830.dist-info → kumoai-2.13.0.dev202511261731.dist-info}/top_level.txt +0 -0
kumoai/experimental/rfm/rfm.py
CHANGED
|
@@ -5,23 +5,32 @@ from collections import defaultdict
|
|
|
5
5
|
from collections.abc import Generator
|
|
6
6
|
from contextlib import contextmanager
|
|
7
7
|
from dataclasses import dataclass, replace
|
|
8
|
-
from typing import
|
|
8
|
+
from typing import (
|
|
9
|
+
Any,
|
|
10
|
+
Dict,
|
|
11
|
+
Iterator,
|
|
12
|
+
List,
|
|
13
|
+
Literal,
|
|
14
|
+
Optional,
|
|
15
|
+
Tuple,
|
|
16
|
+
Union,
|
|
17
|
+
overload,
|
|
18
|
+
)
|
|
9
19
|
|
|
10
20
|
import numpy as np
|
|
11
21
|
import pandas as pd
|
|
12
22
|
from kumoapi.model_plan import RunMode
|
|
13
|
-
from kumoapi.pquery import QueryType
|
|
23
|
+
from kumoapi.pquery import QueryType, ValidatedPredictiveQuery
|
|
14
24
|
from kumoapi.rfm import Context
|
|
15
25
|
from kumoapi.rfm import Explanation as ExplanationConfig
|
|
16
26
|
from kumoapi.rfm import (
|
|
17
|
-
PQueryDefinition,
|
|
18
27
|
RFMEvaluateRequest,
|
|
28
|
+
RFMParseQueryRequest,
|
|
19
29
|
RFMPredictRequest,
|
|
20
|
-
RFMValidateQueryRequest,
|
|
21
30
|
)
|
|
22
31
|
from kumoapi.task import TaskType
|
|
23
32
|
|
|
24
|
-
from kumoai import
|
|
33
|
+
from kumoai.client.rfm import RFMAPI
|
|
25
34
|
from kumoai.exceptions import HTTPException
|
|
26
35
|
from kumoai.experimental.rfm import LocalGraph
|
|
27
36
|
from kumoai.experimental.rfm.local_graph_sampler import LocalGraphSampler
|
|
@@ -30,6 +39,7 @@ from kumoai.experimental.rfm.local_pquery_driver import (
|
|
|
30
39
|
LocalPQueryDriver,
|
|
31
40
|
date_offset_to_seconds,
|
|
32
41
|
)
|
|
42
|
+
from kumoai.mixin import CastMixin
|
|
33
43
|
from kumoai.utils import InteractiveProgressLogger, ProgressLogger
|
|
34
44
|
|
|
35
45
|
_RANDOM_SEED = 42
|
|
@@ -60,6 +70,17 @@ _SIZE_LIMIT_MSG = ("Context size exceeds the 30MB limit. {stats}\nPlease "
|
|
|
60
70
|
"beyond this for your use-case.")
|
|
61
71
|
|
|
62
72
|
|
|
73
|
+
@dataclass(repr=False)
|
|
74
|
+
class ExplainConfig(CastMixin):
|
|
75
|
+
"""Configuration for explainability.
|
|
76
|
+
|
|
77
|
+
Args:
|
|
78
|
+
skip_summary: Whether to skip generating a human-readable summary of
|
|
79
|
+
the explanation.
|
|
80
|
+
"""
|
|
81
|
+
skip_summary: bool = False
|
|
82
|
+
|
|
83
|
+
|
|
63
84
|
@dataclass(repr=False)
|
|
64
85
|
class Explanation:
|
|
65
86
|
prediction: pd.DataFrame
|
|
@@ -87,6 +108,12 @@ class Explanation:
|
|
|
87
108
|
def __repr__(self) -> str:
|
|
88
109
|
return str((self.prediction, self.summary))
|
|
89
110
|
|
|
111
|
+
def _ipython_display_(self) -> None:
|
|
112
|
+
from IPython.display import Markdown, display
|
|
113
|
+
|
|
114
|
+
display(self.prediction)
|
|
115
|
+
display(Markdown(self.summary))
|
|
116
|
+
|
|
90
117
|
|
|
91
118
|
class KumoRFM:
|
|
92
119
|
r"""The Kumo Relational Foundation model (RFM) from the `KumoRFM: A
|
|
@@ -114,9 +141,9 @@ class KumoRFM:
|
|
|
114
141
|
|
|
115
142
|
rfm = KumoRFM(graph)
|
|
116
143
|
|
|
117
|
-
query = ("PREDICT COUNT(
|
|
118
|
-
"FOR users.user_id=
|
|
119
|
-
result = rfm.
|
|
144
|
+
query = ("PREDICT COUNT(orders.*, 0, 30, days)>0 "
|
|
145
|
+
"FOR users.user_id=1")
|
|
146
|
+
result = rfm.predict(query)
|
|
120
147
|
|
|
121
148
|
print(result) # user_id COUNT(transactions.*, 0, 30, days) > 0
|
|
122
149
|
# 1 0.85
|
|
@@ -145,9 +172,20 @@ class KumoRFM:
|
|
|
145
172
|
self._graph_store = LocalGraphStore(graph, preprocess, verbose)
|
|
146
173
|
self._graph_sampler = LocalGraphSampler(self._graph_store)
|
|
147
174
|
|
|
175
|
+
self._client: Optional[RFMAPI] = None
|
|
176
|
+
|
|
148
177
|
self._batch_size: Optional[int | Literal['max']] = None
|
|
149
178
|
self.num_retries: int = 0
|
|
150
179
|
|
|
180
|
+
@property
|
|
181
|
+
def _api_client(self) -> RFMAPI:
|
|
182
|
+
if self._client is not None:
|
|
183
|
+
return self._client
|
|
184
|
+
|
|
185
|
+
from kumoai.experimental.rfm import global_state
|
|
186
|
+
self._client = RFMAPI(global_state.client)
|
|
187
|
+
return self._client
|
|
188
|
+
|
|
151
189
|
def __repr__(self) -> str:
|
|
152
190
|
return f'{self.__class__.__name__}()'
|
|
153
191
|
|
|
@@ -199,6 +237,7 @@ class KumoRFM:
|
|
|
199
237
|
max_pq_iterations: int = 20,
|
|
200
238
|
random_seed: Optional[int] = _RANDOM_SEED,
|
|
201
239
|
verbose: Union[bool, ProgressLogger] = True,
|
|
240
|
+
use_prediction_time: bool = False,
|
|
202
241
|
) -> pd.DataFrame:
|
|
203
242
|
pass
|
|
204
243
|
|
|
@@ -208,7 +247,7 @@ class KumoRFM:
|
|
|
208
247
|
query: str,
|
|
209
248
|
indices: Union[List[str], List[float], List[int], None] = None,
|
|
210
249
|
*,
|
|
211
|
-
explain: Literal[True],
|
|
250
|
+
explain: Union[Literal[True], ExplainConfig, Dict[str, Any]],
|
|
212
251
|
anchor_time: Union[pd.Timestamp, Literal['entity'], None] = None,
|
|
213
252
|
context_anchor_time: Union[pd.Timestamp, None] = None,
|
|
214
253
|
run_mode: Union[RunMode, str] = RunMode.FAST,
|
|
@@ -217,6 +256,7 @@ class KumoRFM:
|
|
|
217
256
|
max_pq_iterations: int = 20,
|
|
218
257
|
random_seed: Optional[int] = _RANDOM_SEED,
|
|
219
258
|
verbose: Union[bool, ProgressLogger] = True,
|
|
259
|
+
use_prediction_time: bool = False,
|
|
220
260
|
) -> Explanation:
|
|
221
261
|
pass
|
|
222
262
|
|
|
@@ -225,7 +265,7 @@ class KumoRFM:
|
|
|
225
265
|
query: str,
|
|
226
266
|
indices: Union[List[str], List[float], List[int], None] = None,
|
|
227
267
|
*,
|
|
228
|
-
explain: bool = False,
|
|
268
|
+
explain: Union[bool, ExplainConfig, Dict[str, Any]] = False,
|
|
229
269
|
anchor_time: Union[pd.Timestamp, Literal['entity'], None] = None,
|
|
230
270
|
context_anchor_time: Union[pd.Timestamp, None] = None,
|
|
231
271
|
run_mode: Union[RunMode, str] = RunMode.FAST,
|
|
@@ -234,6 +274,7 @@ class KumoRFM:
|
|
|
234
274
|
max_pq_iterations: int = 20,
|
|
235
275
|
random_seed: Optional[int] = _RANDOM_SEED,
|
|
236
276
|
verbose: Union[bool, ProgressLogger] = True,
|
|
277
|
+
use_prediction_time: bool = False,
|
|
237
278
|
) -> Union[pd.DataFrame, Explanation]:
|
|
238
279
|
"""Returns predictions for a predictive query.
|
|
239
280
|
|
|
@@ -244,9 +285,12 @@ class KumoRFM:
|
|
|
244
285
|
be generated for all indices, independent of whether they
|
|
245
286
|
fulfill entity filter constraints. To pre-filter entities, use
|
|
246
287
|
:meth:`~KumoRFM.is_valid_entity`.
|
|
247
|
-
explain:
|
|
248
|
-
|
|
249
|
-
|
|
288
|
+
explain: Configuration for explainability.
|
|
289
|
+
If set to ``True``, will additionally explain the prediction.
|
|
290
|
+
Passing in an :class:`ExplainConfig` instance provides control
|
|
291
|
+
over which parts of explanation are generated.
|
|
292
|
+
Explainability is currently only supported for single entity
|
|
293
|
+
predictions with ``run_mode="FAST"``.
|
|
250
294
|
anchor_time: The anchor timestamp for the prediction. If set to
|
|
251
295
|
``None``, will use the maximum timestamp in the data.
|
|
252
296
|
If set to ``"entity"``, will use the timestamp of the entity.
|
|
@@ -264,45 +308,54 @@ class KumoRFM:
|
|
|
264
308
|
entities to find valid labels.
|
|
265
309
|
random_seed: A manual seed for generating pseudo-random numbers.
|
|
266
310
|
verbose: Whether to print verbose output.
|
|
311
|
+
use_prediction_time: Whether to use the anchor timestamp as an
|
|
312
|
+
additional feature during prediction. This is typically
|
|
313
|
+
beneficial for time series forecasting tasks.
|
|
267
314
|
|
|
268
315
|
Returns:
|
|
269
316
|
The predictions as a :class:`pandas.DataFrame`.
|
|
270
|
-
If ``explain
|
|
271
|
-
|
|
317
|
+
If ``explain`` is provided, returns an :class:`Explanation` object
|
|
318
|
+
containing the prediction, summary, and details.
|
|
272
319
|
"""
|
|
320
|
+
explain_config: Optional[ExplainConfig] = None
|
|
321
|
+
if explain is True:
|
|
322
|
+
explain_config = ExplainConfig()
|
|
323
|
+
elif explain is not False:
|
|
324
|
+
explain_config = ExplainConfig._cast(explain)
|
|
325
|
+
|
|
273
326
|
query_def = self._parse_query(query)
|
|
327
|
+
query_str = query_def.to_string()
|
|
274
328
|
|
|
275
329
|
if num_hops != 2 and num_neighbors is not None:
|
|
276
330
|
warnings.warn(f"Received custom 'num_neighbors' option; ignoring "
|
|
277
331
|
f"custom 'num_hops={num_hops}' option")
|
|
278
332
|
|
|
279
|
-
if
|
|
333
|
+
if explain_config is not None and run_mode in {
|
|
334
|
+
RunMode.NORMAL, RunMode.BEST
|
|
335
|
+
}:
|
|
280
336
|
warnings.warn(f"Explainability is currently only supported for "
|
|
281
337
|
f"run mode 'FAST' (got '{run_mode}'). Provided run "
|
|
282
338
|
f"mode has been reset. Please lower the run mode to "
|
|
283
339
|
f"suppress this warning.")
|
|
284
340
|
|
|
285
341
|
if indices is None:
|
|
286
|
-
if query_def.
|
|
342
|
+
if query_def.rfm_entity_ids is None:
|
|
287
343
|
raise ValueError("Cannot find entities to predict for. Please "
|
|
288
344
|
"pass them via `predict(query, indices=...)`")
|
|
289
|
-
indices = query_def.
|
|
345
|
+
indices = query_def.get_rfm_entity_id_list()
|
|
290
346
|
else:
|
|
291
|
-
query_def = replace(
|
|
292
|
-
query_def,
|
|
293
|
-
entity=replace(query_def.entity, ids=None),
|
|
294
|
-
)
|
|
347
|
+
query_def = replace(query_def, rfm_entity_ids=None)
|
|
295
348
|
|
|
296
349
|
if len(indices) == 0:
|
|
297
350
|
raise ValueError("At least one entity is required")
|
|
298
351
|
|
|
299
|
-
if
|
|
352
|
+
if explain_config is not None and len(indices) > 1:
|
|
300
353
|
raise ValueError(
|
|
301
354
|
f"Cannot explain predictions for more than a single entity "
|
|
302
355
|
f"(got {len(indices)})")
|
|
303
356
|
|
|
304
357
|
query_repr = query_def.to_string(rich=True, exclude_predict=True)
|
|
305
|
-
if
|
|
358
|
+
if explain_config is not None:
|
|
306
359
|
msg = f'[bold]EXPLAIN[/bold] {query_repr}'
|
|
307
360
|
else:
|
|
308
361
|
msg = f'[bold]PREDICT[/bold] {query_repr}'
|
|
@@ -314,8 +367,8 @@ class KumoRFM:
|
|
|
314
367
|
|
|
315
368
|
batch_size: Optional[int] = None
|
|
316
369
|
if self._batch_size == 'max':
|
|
317
|
-
task_type =
|
|
318
|
-
|
|
370
|
+
task_type = LocalPQueryDriver.get_task_type(
|
|
371
|
+
query_def,
|
|
319
372
|
edge_types=self._graph_store.edge_types,
|
|
320
373
|
)
|
|
321
374
|
batch_size = _MAX_PRED_SIZE[task_type]
|
|
@@ -353,6 +406,8 @@ class KumoRFM:
|
|
|
353
406
|
request = RFMPredictRequest(
|
|
354
407
|
context=context,
|
|
355
408
|
run_mode=RunMode(run_mode),
|
|
409
|
+
query=query_str,
|
|
410
|
+
use_prediction_time=use_prediction_time,
|
|
356
411
|
)
|
|
357
412
|
with warnings.catch_warnings():
|
|
358
413
|
warnings.filterwarnings('ignore', message='gencode')
|
|
@@ -375,29 +430,36 @@ class KumoRFM:
|
|
|
375
430
|
|
|
376
431
|
for attempt in range(self.num_retries + 1):
|
|
377
432
|
try:
|
|
378
|
-
if
|
|
379
|
-
resp =
|
|
433
|
+
if explain_config is not None:
|
|
434
|
+
resp = self._api_client.explain(
|
|
435
|
+
request=_bytes,
|
|
436
|
+
skip_summary=explain_config.skip_summary,
|
|
437
|
+
)
|
|
380
438
|
summary = resp.summary
|
|
381
439
|
details = resp.details
|
|
382
440
|
else:
|
|
383
|
-
resp =
|
|
384
|
-
|
|
441
|
+
resp = self._api_client.predict(_bytes)
|
|
442
|
+
df = pd.DataFrame(**resp.prediction)
|
|
385
443
|
|
|
386
444
|
# Cast 'ENTITY' to correct data type:
|
|
387
|
-
if 'ENTITY' in
|
|
388
|
-
|
|
389
|
-
|
|
390
|
-
|
|
445
|
+
if 'ENTITY' in df:
|
|
446
|
+
entity = query_def.entity_table
|
|
447
|
+
pkey_map = self._graph_store.pkey_map_dict[entity]
|
|
448
|
+
df['ENTITY'] = df['ENTITY'].astype(
|
|
391
449
|
type(pkey_map.index[0]))
|
|
392
450
|
|
|
393
451
|
# Cast 'ANCHOR_TIMESTAMP' to correct data type:
|
|
394
|
-
if
|
|
395
|
-
|
|
396
|
-
|
|
397
|
-
|
|
398
|
-
|
|
399
|
-
|
|
400
|
-
|
|
452
|
+
if 'ANCHOR_TIMESTAMP' in df:
|
|
453
|
+
ser = df['ANCHOR_TIMESTAMP']
|
|
454
|
+
if not pd.api.types.is_datetime64_any_dtype(ser):
|
|
455
|
+
if isinstance(ser.iloc[0], str):
|
|
456
|
+
unit = None
|
|
457
|
+
else:
|
|
458
|
+
unit = 'ms'
|
|
459
|
+
df['ANCHOR_TIMESTAMP'] = pd.to_datetime(
|
|
460
|
+
ser, errors='coerce', unit=unit)
|
|
461
|
+
|
|
462
|
+
predictions.append(df)
|
|
401
463
|
|
|
402
464
|
if (isinstance(verbose, InteractiveProgressLogger)
|
|
403
465
|
and len(batches) > 1):
|
|
@@ -423,7 +485,7 @@ class KumoRFM:
|
|
|
423
485
|
else:
|
|
424
486
|
prediction = pd.concat(predictions, ignore_index=True)
|
|
425
487
|
|
|
426
|
-
if
|
|
488
|
+
if explain_config is not None:
|
|
427
489
|
assert len(predictions) == 1
|
|
428
490
|
assert summary is not None
|
|
429
491
|
assert details is not None
|
|
@@ -457,11 +519,11 @@ class KumoRFM:
|
|
|
457
519
|
query_def = self._parse_query(query)
|
|
458
520
|
|
|
459
521
|
if indices is None:
|
|
460
|
-
if query_def.
|
|
522
|
+
if query_def.rfm_entity_ids is None:
|
|
461
523
|
raise ValueError("Cannot find entities to predict for. Please "
|
|
462
524
|
"pass them via "
|
|
463
525
|
"`is_valid_entity(query, indices=...)`")
|
|
464
|
-
indices = query_def.
|
|
526
|
+
indices = query_def.get_rfm_entity_id_list()
|
|
465
527
|
|
|
466
528
|
if len(indices) == 0:
|
|
467
529
|
raise ValueError("At least one entity is required")
|
|
@@ -473,14 +535,13 @@ class KumoRFM:
|
|
|
473
535
|
self._validate_time(query_def, anchor_time, None, False)
|
|
474
536
|
else:
|
|
475
537
|
assert anchor_time == 'entity'
|
|
476
|
-
if (query_def.
|
|
477
|
-
not in self._graph_store.time_dict):
|
|
538
|
+
if (query_def.entity_table not in self._graph_store.time_dict):
|
|
478
539
|
raise ValueError(f"Anchor time 'entity' requires the entity "
|
|
479
|
-
f"table '{query_def.
|
|
480
|
-
f"to have a time column")
|
|
540
|
+
f"table '{query_def.entity_table}' "
|
|
541
|
+
f"to have a time column.")
|
|
481
542
|
|
|
482
543
|
node = self._graph_store.get_node_id(
|
|
483
|
-
table_name=query_def.
|
|
544
|
+
table_name=query_def.entity_table,
|
|
484
545
|
pkey=pd.Series(indices),
|
|
485
546
|
)
|
|
486
547
|
query_driver = LocalPQueryDriver(self._graph_store, query_def)
|
|
@@ -499,6 +560,7 @@ class KumoRFM:
|
|
|
499
560
|
max_pq_iterations: int = 20,
|
|
500
561
|
random_seed: Optional[int] = _RANDOM_SEED,
|
|
501
562
|
verbose: Union[bool, ProgressLogger] = True,
|
|
563
|
+
use_prediction_time: bool = False,
|
|
502
564
|
) -> pd.DataFrame:
|
|
503
565
|
"""Evaluates a predictive query.
|
|
504
566
|
|
|
@@ -522,6 +584,9 @@ class KumoRFM:
|
|
|
522
584
|
entities to find valid labels.
|
|
523
585
|
random_seed: A manual seed for generating pseudo-random numbers.
|
|
524
586
|
verbose: Whether to print verbose output.
|
|
587
|
+
use_prediction_time: Whether to use the anchor timestamp as an
|
|
588
|
+
additional feature during prediction. This is typically
|
|
589
|
+
beneficial for time series forecasting tasks.
|
|
525
590
|
|
|
526
591
|
Returns:
|
|
527
592
|
The metrics as a :class:`pandas.DataFrame`
|
|
@@ -532,10 +597,10 @@ class KumoRFM:
|
|
|
532
597
|
warnings.warn(f"Received custom 'num_neighbors' option; ignoring "
|
|
533
598
|
f"custom 'num_hops={num_hops}' option")
|
|
534
599
|
|
|
535
|
-
if query_def.
|
|
600
|
+
if query_def.rfm_entity_ids is not None:
|
|
536
601
|
query_def = replace(
|
|
537
602
|
query_def,
|
|
538
|
-
|
|
603
|
+
rfm_entity_ids=None,
|
|
539
604
|
)
|
|
540
605
|
|
|
541
606
|
query_repr = query_def.to_string(rich=True, exclude_predict=True)
|
|
@@ -565,6 +630,7 @@ class KumoRFM:
|
|
|
565
630
|
context=context,
|
|
566
631
|
run_mode=RunMode(run_mode),
|
|
567
632
|
metrics=metrics,
|
|
633
|
+
use_prediction_time=use_prediction_time,
|
|
568
634
|
)
|
|
569
635
|
with warnings.catch_warnings():
|
|
570
636
|
warnings.filterwarnings('ignore', message='Protobuf gencode')
|
|
@@ -575,10 +641,10 @@ class KumoRFM:
|
|
|
575
641
|
|
|
576
642
|
if len(request_bytes) > _MAX_SIZE:
|
|
577
643
|
stats_msg = Context.get_memory_stats(request_msg.context)
|
|
578
|
-
raise ValueError(_SIZE_LIMIT_MSG.format(
|
|
644
|
+
raise ValueError(_SIZE_LIMIT_MSG.format(stats=stats_msg))
|
|
579
645
|
|
|
580
646
|
try:
|
|
581
|
-
resp =
|
|
647
|
+
resp = self._api_client.evaluate(request_bytes)
|
|
582
648
|
except HTTPException as e:
|
|
583
649
|
try:
|
|
584
650
|
msg = json.loads(e.detail)['detail']
|
|
@@ -623,18 +689,19 @@ class KumoRFM:
|
|
|
623
689
|
|
|
624
690
|
if anchor_time is None:
|
|
625
691
|
anchor_time = self._graph_store.max_time
|
|
626
|
-
|
|
627
|
-
|
|
692
|
+
if query_def.target_ast.date_offset_range is not None:
|
|
693
|
+
anchor_time = anchor_time - (
|
|
694
|
+
query_def.target_ast.date_offset_range.end_date_offset *
|
|
695
|
+
query_def.num_forecasts)
|
|
628
696
|
|
|
629
697
|
assert anchor_time is not None
|
|
630
698
|
if isinstance(anchor_time, pd.Timestamp):
|
|
631
699
|
self._validate_time(query_def, anchor_time, None, evaluate=True)
|
|
632
700
|
else:
|
|
633
701
|
assert anchor_time == 'entity'
|
|
634
|
-
if (query_def.
|
|
635
|
-
not in self._graph_store.time_dict):
|
|
702
|
+
if (query_def.entity_table not in self._graph_store.time_dict):
|
|
636
703
|
raise ValueError(f"Anchor time 'entity' requires the entity "
|
|
637
|
-
f"table '{query_def.
|
|
704
|
+
f"table '{query_def.entity_table}' "
|
|
638
705
|
f"to have a time column")
|
|
639
706
|
|
|
640
707
|
query_driver = LocalPQueryDriver(self._graph_store, query_def,
|
|
@@ -649,7 +716,7 @@ class KumoRFM:
|
|
|
649
716
|
)
|
|
650
717
|
|
|
651
718
|
entity = self._graph_store.pkey_map_dict[
|
|
652
|
-
query_def.
|
|
719
|
+
query_def.entity_table].index[node]
|
|
653
720
|
|
|
654
721
|
return pd.DataFrame({
|
|
655
722
|
'ENTITY': entity,
|
|
@@ -659,8 +726,8 @@ class KumoRFM:
|
|
|
659
726
|
|
|
660
727
|
# Helpers #################################################################
|
|
661
728
|
|
|
662
|
-
def _parse_query(self, query: str) ->
|
|
663
|
-
if isinstance(query,
|
|
729
|
+
def _parse_query(self, query: str) -> ValidatedPredictiveQuery:
|
|
730
|
+
if isinstance(query, ValidatedPredictiveQuery):
|
|
664
731
|
return query
|
|
665
732
|
|
|
666
733
|
if isinstance(query, str) and query.strip()[:9].lower() == 'evaluate ':
|
|
@@ -670,12 +737,13 @@ class KumoRFM:
|
|
|
670
737
|
"predictions or evaluations.")
|
|
671
738
|
|
|
672
739
|
try:
|
|
673
|
-
request =
|
|
740
|
+
request = RFMParseQueryRequest(
|
|
674
741
|
query=query,
|
|
675
742
|
graph_definition=self._graph_def,
|
|
676
743
|
)
|
|
677
744
|
|
|
678
|
-
resp =
|
|
745
|
+
resp = self._api_client.parse_query(request)
|
|
746
|
+
|
|
679
747
|
# TODO Expose validation warnings.
|
|
680
748
|
|
|
681
749
|
if len(resp.validation_response.warnings) > 0:
|
|
@@ -686,7 +754,7 @@ class KumoRFM:
|
|
|
686
754
|
warnings.warn(f"Encountered the following warnings during "
|
|
687
755
|
f"parsing:\n{msg}")
|
|
688
756
|
|
|
689
|
-
return resp.
|
|
757
|
+
return resp.query
|
|
690
758
|
except HTTPException as e:
|
|
691
759
|
try:
|
|
692
760
|
msg = json.loads(e.detail)['detail']
|
|
@@ -697,7 +765,7 @@ class KumoRFM:
|
|
|
697
765
|
|
|
698
766
|
def _validate_time(
|
|
699
767
|
self,
|
|
700
|
-
query:
|
|
768
|
+
query: ValidatedPredictiveQuery,
|
|
701
769
|
anchor_time: pd.Timestamp,
|
|
702
770
|
context_anchor_time: Union[pd.Timestamp, None],
|
|
703
771
|
evaluate: bool,
|
|
@@ -720,6 +788,11 @@ class KumoRFM:
|
|
|
720
788
|
f"only contains data back to "
|
|
721
789
|
f"'{self._graph_store.min_time}'.")
|
|
722
790
|
|
|
791
|
+
if query.target_ast.date_offset_range is not None:
|
|
792
|
+
end_offset = query.target_ast.date_offset_range.end_date_offset
|
|
793
|
+
else:
|
|
794
|
+
end_offset = pd.DateOffset(0)
|
|
795
|
+
forecast_end_offset = end_offset * query.num_forecasts
|
|
723
796
|
if (context_anchor_time is not None
|
|
724
797
|
and context_anchor_time > anchor_time):
|
|
725
798
|
warnings.warn(f"Context anchor timestamp "
|
|
@@ -728,19 +801,18 @@ class KumoRFM:
|
|
|
728
801
|
f"(got '{anchor_time}'). Please make sure this is "
|
|
729
802
|
f"intended.")
|
|
730
803
|
elif (query.query_type == QueryType.TEMPORAL
|
|
731
|
-
and context_anchor_time is not None
|
|
732
|
-
|
|
804
|
+
and context_anchor_time is not None
|
|
805
|
+
and context_anchor_time + forecast_end_offset > anchor_time):
|
|
733
806
|
warnings.warn(f"Aggregation for context examples at timestamp "
|
|
734
807
|
f"'{context_anchor_time}' will leak information "
|
|
735
808
|
f"from the prediction anchor timestamp "
|
|
736
809
|
f"'{anchor_time}'. Please make sure this is "
|
|
737
810
|
f"intended.")
|
|
738
811
|
|
|
739
|
-
elif (context_anchor_time is not None
|
|
740
|
-
|
|
812
|
+
elif (context_anchor_time is not None
|
|
813
|
+
and context_anchor_time - forecast_end_offset
|
|
741
814
|
< self._graph_store.min_time):
|
|
742
|
-
_time = context_anchor_time -
|
|
743
|
-
query.num_forecasts)
|
|
815
|
+
_time = context_anchor_time - forecast_end_offset
|
|
744
816
|
warnings.warn(f"Context anchor timestamp is too early or "
|
|
745
817
|
f"aggregation time range is too large. To form "
|
|
746
818
|
f"proper input data, we would need data back to "
|
|
@@ -753,8 +825,7 @@ class KumoRFM:
|
|
|
753
825
|
f"latest timestamp '{self._graph_store.max_time}' "
|
|
754
826
|
f"in the data. Please make sure this is intended.")
|
|
755
827
|
|
|
756
|
-
max_eval_time =
|
|
757
|
-
query.target.end_offset * query.num_forecasts)
|
|
828
|
+
max_eval_time = self._graph_store.max_time - forecast_end_offset
|
|
758
829
|
if evaluate and anchor_time > max_eval_time:
|
|
759
830
|
raise ValueError(
|
|
760
831
|
f"Anchor timestamp for evaluation is after the latest "
|
|
@@ -762,7 +833,7 @@ class KumoRFM:
|
|
|
762
833
|
|
|
763
834
|
def _get_context(
|
|
764
835
|
self,
|
|
765
|
-
query:
|
|
836
|
+
query: ValidatedPredictiveQuery,
|
|
766
837
|
indices: Union[List[str], List[float], List[int], None],
|
|
767
838
|
anchor_time: Union[pd.Timestamp, Literal['entity'], None],
|
|
768
839
|
context_anchor_time: Union[pd.Timestamp, None],
|
|
@@ -790,8 +861,8 @@ class KumoRFM:
|
|
|
790
861
|
f"must go beyond this for your use-case.")
|
|
791
862
|
|
|
792
863
|
query_driver = LocalPQueryDriver(self._graph_store, query, random_seed)
|
|
793
|
-
task_type =
|
|
794
|
-
|
|
864
|
+
task_type = LocalPQueryDriver.get_task_type(
|
|
865
|
+
query,
|
|
795
866
|
edge_types=self._graph_store.edge_types,
|
|
796
867
|
)
|
|
797
868
|
|
|
@@ -823,11 +894,15 @@ class KumoRFM:
|
|
|
823
894
|
else:
|
|
824
895
|
num_neighbors = [64, 64, 8, 8, 4, 4][:num_hops]
|
|
825
896
|
|
|
897
|
+
if query.target_ast.date_offset_range is None:
|
|
898
|
+
end_offset = pd.DateOffset(0)
|
|
899
|
+
else:
|
|
900
|
+
end_offset = query.target_ast.date_offset_range.end_date_offset
|
|
901
|
+
forecast_end_offset = end_offset * query.num_forecasts
|
|
826
902
|
if anchor_time is None:
|
|
827
903
|
anchor_time = self._graph_store.max_time
|
|
828
904
|
if evaluate:
|
|
829
|
-
anchor_time = anchor_time -
|
|
830
|
-
query.num_forecasts)
|
|
905
|
+
anchor_time = anchor_time - forecast_end_offset
|
|
831
906
|
if logger is not None:
|
|
832
907
|
assert isinstance(anchor_time, pd.Timestamp)
|
|
833
908
|
if anchor_time == pd.Timestamp.min:
|
|
@@ -842,15 +917,14 @@ class KumoRFM:
|
|
|
842
917
|
assert anchor_time is not None
|
|
843
918
|
if isinstance(anchor_time, pd.Timestamp):
|
|
844
919
|
if context_anchor_time is None:
|
|
845
|
-
context_anchor_time = anchor_time -
|
|
846
|
-
query.num_forecasts)
|
|
920
|
+
context_anchor_time = anchor_time - forecast_end_offset
|
|
847
921
|
self._validate_time(query, anchor_time, context_anchor_time,
|
|
848
922
|
evaluate)
|
|
849
923
|
else:
|
|
850
924
|
assert anchor_time == 'entity'
|
|
851
|
-
if query.
|
|
925
|
+
if query.entity_table not in self._graph_store.time_dict:
|
|
852
926
|
raise ValueError(f"Anchor time 'entity' requires the entity "
|
|
853
|
-
f"table '{query.
|
|
927
|
+
f"table '{query.entity_table}' to "
|
|
854
928
|
f"have a time column")
|
|
855
929
|
if context_anchor_time is not None:
|
|
856
930
|
warnings.warn("Ignoring option 'context_anchor_time' for "
|
|
@@ -901,7 +975,7 @@ class KumoRFM:
|
|
|
901
975
|
f"in batches")
|
|
902
976
|
|
|
903
977
|
test_node = self._graph_store.get_node_id(
|
|
904
|
-
table_name=query.
|
|
978
|
+
table_name=query.entity_table,
|
|
905
979
|
pkey=pd.Series(indices),
|
|
906
980
|
)
|
|
907
981
|
|
|
@@ -909,8 +983,7 @@ class KumoRFM:
|
|
|
909
983
|
test_time = pd.Series(anchor_time).repeat(
|
|
910
984
|
len(test_node)).reset_index(drop=True)
|
|
911
985
|
else:
|
|
912
|
-
time = self._graph_store.time_dict[
|
|
913
|
-
query.entity.pkey.table_name]
|
|
986
|
+
time = self._graph_store.time_dict[query.entity_table]
|
|
914
987
|
time = time[test_node] * 1000**3
|
|
915
988
|
test_time = pd.Series(time, dtype='datetime64[ns]')
|
|
916
989
|
|
|
@@ -943,12 +1016,23 @@ class KumoRFM:
|
|
|
943
1016
|
raise NotImplementedError
|
|
944
1017
|
logger.log(msg)
|
|
945
1018
|
|
|
946
|
-
entity_table_names
|
|
947
|
-
|
|
1019
|
+
entity_table_names: Tuple[str, ...]
|
|
1020
|
+
if task_type.is_link_pred:
|
|
1021
|
+
final_aggr = query.get_final_target_aggregation()
|
|
1022
|
+
assert final_aggr is not None
|
|
1023
|
+
edge_fkey = final_aggr._get_target_column_name()
|
|
1024
|
+
for edge_type in self._graph_store.edge_types:
|
|
1025
|
+
if edge_fkey == f'{edge_type[0]}.{edge_type[1]}':
|
|
1026
|
+
entity_table_names = (
|
|
1027
|
+
query.entity_table,
|
|
1028
|
+
edge_type[2],
|
|
1029
|
+
)
|
|
1030
|
+
else:
|
|
1031
|
+
entity_table_names = (query.entity_table, )
|
|
948
1032
|
|
|
949
1033
|
# Exclude the entity anchor time from the feature set to prevent
|
|
950
1034
|
# running out-of-distribution between in-context and test examples:
|
|
951
|
-
exclude_cols_dict = query.
|
|
1035
|
+
exclude_cols_dict = query.get_exclude_cols_dict()
|
|
952
1036
|
if anchor_time == 'entity':
|
|
953
1037
|
if entity_table_names[0] not in exclude_cols_dict:
|
|
954
1038
|
exclude_cols_dict[entity_table_names[0]] = []
|
|
@@ -963,7 +1047,6 @@ class KumoRFM:
|
|
|
963
1047
|
train_time.astype('datetime64[ns]').astype(int).to_numpy(),
|
|
964
1048
|
test_time.astype('datetime64[ns]').astype(int).to_numpy(),
|
|
965
1049
|
]),
|
|
966
|
-
run_mode=run_mode,
|
|
967
1050
|
num_neighbors=num_neighbors,
|
|
968
1051
|
exclude_cols_dict=exclude_cols_dict,
|
|
969
1052
|
)
|
|
@@ -977,7 +1060,7 @@ class KumoRFM:
|
|
|
977
1060
|
|
|
978
1061
|
step_size: Optional[int] = None
|
|
979
1062
|
if query.query_type == QueryType.TEMPORAL:
|
|
980
|
-
step_size = date_offset_to_seconds(
|
|
1063
|
+
step_size = date_offset_to_seconds(end_offset)
|
|
981
1064
|
|
|
982
1065
|
return Context(
|
|
983
1066
|
task_type=task_type,
|
|
@@ -1002,7 +1085,7 @@ class KumoRFM:
|
|
|
1002
1085
|
elif task_type == TaskType.MULTICLASS_CLASSIFICATION:
|
|
1003
1086
|
supported_metrics = ['acc', 'precision', 'recall', 'f1', 'mrr']
|
|
1004
1087
|
elif task_type == TaskType.REGRESSION:
|
|
1005
|
-
supported_metrics = ['mae', 'mape', 'mse', 'rmse', 'smape']
|
|
1088
|
+
supported_metrics = ['mae', 'mape', 'mse', 'rmse', 'smape', 'r2']
|
|
1006
1089
|
elif task_type == TaskType.TEMPORAL_LINK_PREDICTION:
|
|
1007
1090
|
supported_metrics = [
|
|
1008
1091
|
'map@', 'ndcg@', 'mrr@', 'precision@', 'recall@', 'f1@',
|