kumoai 2.12.0.dev202511031731__cp313-cp313-macosx_11_0_arm64.whl → 2.13.0.dev202511261731__cp313-cp313-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kumoai/__init__.py +6 -9
- kumoai/_version.py +1 -1
- kumoai/client/client.py +9 -13
- kumoai/client/endpoints.py +1 -0
- kumoai/client/rfm.py +35 -7
- kumoai/experimental/rfm/__init__.py +151 -8
- kumoai/experimental/rfm/local_graph_sampler.py +0 -2
- kumoai/experimental/rfm/local_pquery_driver.py +221 -26
- kumoai/experimental/rfm/pquery/__init__.py +0 -4
- kumoai/experimental/rfm/pquery/pandas_executor.py +34 -8
- kumoai/experimental/rfm/rfm.py +146 -79
- kumoai/experimental/rfm/sagemaker.py +130 -0
- kumoai/spcs.py +1 -3
- kumoai/utils/progress_logger.py +10 -4
- {kumoai-2.12.0.dev202511031731.dist-info → kumoai-2.13.0.dev202511261731.dist-info}/METADATA +11 -2
- {kumoai-2.12.0.dev202511031731.dist-info → kumoai-2.13.0.dev202511261731.dist-info}/RECORD +19 -20
- kumoai/experimental/rfm/pquery/backend.py +0 -136
- kumoai/experimental/rfm/pquery/pandas_backend.py +0 -478
- {kumoai-2.12.0.dev202511031731.dist-info → kumoai-2.13.0.dev202511261731.dist-info}/WHEEL +0 -0
- {kumoai-2.12.0.dev202511031731.dist-info → kumoai-2.13.0.dev202511261731.dist-info}/licenses/LICENSE +0 -0
- {kumoai-2.12.0.dev202511031731.dist-info → kumoai-2.13.0.dev202511261731.dist-info}/top_level.txt +0 -0
kumoai/experimental/rfm/rfm.py
CHANGED
|
@@ -5,23 +5,32 @@ from collections import defaultdict
|
|
|
5
5
|
from collections.abc import Generator
|
|
6
6
|
from contextlib import contextmanager
|
|
7
7
|
from dataclasses import dataclass, replace
|
|
8
|
-
from typing import
|
|
8
|
+
from typing import (
|
|
9
|
+
Any,
|
|
10
|
+
Dict,
|
|
11
|
+
Iterator,
|
|
12
|
+
List,
|
|
13
|
+
Literal,
|
|
14
|
+
Optional,
|
|
15
|
+
Tuple,
|
|
16
|
+
Union,
|
|
17
|
+
overload,
|
|
18
|
+
)
|
|
9
19
|
|
|
10
20
|
import numpy as np
|
|
11
21
|
import pandas as pd
|
|
12
22
|
from kumoapi.model_plan import RunMode
|
|
13
|
-
from kumoapi.pquery import QueryType
|
|
23
|
+
from kumoapi.pquery import QueryType, ValidatedPredictiveQuery
|
|
14
24
|
from kumoapi.rfm import Context
|
|
15
25
|
from kumoapi.rfm import Explanation as ExplanationConfig
|
|
16
26
|
from kumoapi.rfm import (
|
|
17
|
-
PQueryDefinition,
|
|
18
27
|
RFMEvaluateRequest,
|
|
28
|
+
RFMParseQueryRequest,
|
|
19
29
|
RFMPredictRequest,
|
|
20
|
-
RFMValidateQueryRequest,
|
|
21
30
|
)
|
|
22
31
|
from kumoapi.task import TaskType
|
|
23
32
|
|
|
24
|
-
from kumoai import
|
|
33
|
+
from kumoai.client.rfm import RFMAPI
|
|
25
34
|
from kumoai.exceptions import HTTPException
|
|
26
35
|
from kumoai.experimental.rfm import LocalGraph
|
|
27
36
|
from kumoai.experimental.rfm.local_graph_sampler import LocalGraphSampler
|
|
@@ -30,6 +39,7 @@ from kumoai.experimental.rfm.local_pquery_driver import (
|
|
|
30
39
|
LocalPQueryDriver,
|
|
31
40
|
date_offset_to_seconds,
|
|
32
41
|
)
|
|
42
|
+
from kumoai.mixin import CastMixin
|
|
33
43
|
from kumoai.utils import InteractiveProgressLogger, ProgressLogger
|
|
34
44
|
|
|
35
45
|
_RANDOM_SEED = 42
|
|
@@ -60,6 +70,17 @@ _SIZE_LIMIT_MSG = ("Context size exceeds the 30MB limit. {stats}\nPlease "
|
|
|
60
70
|
"beyond this for your use-case.")
|
|
61
71
|
|
|
62
72
|
|
|
73
|
+
@dataclass(repr=False)
class ExplainConfig(CastMixin):
    """Options that control how an explanation is generated.

    Args:
        skip_summary: When ``True``, the human-readable summary of the
            explanation is not generated. Defaults to ``False``.
    """
    skip_summary: bool = False
|
|
82
|
+
|
|
83
|
+
|
|
63
84
|
@dataclass(repr=False)
|
|
64
85
|
class Explanation:
|
|
65
86
|
prediction: pd.DataFrame
|
|
@@ -87,6 +108,12 @@ class Explanation:
|
|
|
87
108
|
def __repr__(self) -> str:
|
|
88
109
|
return str((self.prediction, self.summary))
|
|
89
110
|
|
|
111
|
+
def _ipython_display_(self) -> None:
    """Rich-display hook: show the prediction, then the summary.

    The prediction is handed to ``display`` as-is and the summary is wrapped
    in ``Markdown`` before being displayed.
    """
    # Imported inside the method so IPython is only needed when this hook is
    # actually invoked.
    from IPython.display import Markdown, display

    display(self.prediction)
    summary_md = Markdown(self.summary)
    display(summary_md)
|
|
116
|
+
|
|
90
117
|
|
|
91
118
|
class KumoRFM:
|
|
92
119
|
r"""The Kumo Relational Foundation model (RFM) from the `KumoRFM: A
|
|
@@ -114,9 +141,9 @@ class KumoRFM:
|
|
|
114
141
|
|
|
115
142
|
rfm = KumoRFM(graph)
|
|
116
143
|
|
|
117
|
-
query = ("PREDICT COUNT(
|
|
118
|
-
"FOR users.user_id=
|
|
119
|
-
result = rfm.
|
|
144
|
+
query = ("PREDICT COUNT(orders.*, 0, 30, days)>0 "
|
|
145
|
+
"FOR users.user_id=1")
|
|
146
|
+
result = rfm.predict(query)
|
|
120
147
|
|
|
121
148
|
print(result)  # user_id  COUNT(orders.*, 0, 30, days) > 0
|
|
122
149
|
# 1 0.85
|
|
@@ -145,9 +172,20 @@ class KumoRFM:
|
|
|
145
172
|
self._graph_store = LocalGraphStore(graph, preprocess, verbose)
|
|
146
173
|
self._graph_sampler = LocalGraphSampler(self._graph_store)
|
|
147
174
|
|
|
175
|
+
self._client: Optional[RFMAPI] = None
|
|
176
|
+
|
|
148
177
|
self._batch_size: Optional[int | Literal['max']] = None
|
|
149
178
|
self.num_retries: int = 0
|
|
150
179
|
|
|
180
|
+
@property
def _api_client(self) -> RFMAPI:
    """Return the ``RFMAPI`` client, creating and caching it on first use.

    On first access the client is built from ``global_state.client`` and
    stored in ``self._client``; later accesses return the stored instance.
    """
    if self._client is None:
        # Imported here rather than at module level, as in the original.
        from kumoai.experimental.rfm import global_state

        self._client = RFMAPI(global_state.client)
    return self._client
|
|
188
|
+
|
|
151
189
|
def __repr__(self) -> str:
|
|
152
190
|
return f'{self.__class__.__name__}()'
|
|
153
191
|
|
|
@@ -209,7 +247,7 @@ class KumoRFM:
|
|
|
209
247
|
query: str,
|
|
210
248
|
indices: Union[List[str], List[float], List[int], None] = None,
|
|
211
249
|
*,
|
|
212
|
-
explain: Literal[True],
|
|
250
|
+
explain: Union[Literal[True], ExplainConfig, Dict[str, Any]],
|
|
213
251
|
anchor_time: Union[pd.Timestamp, Literal['entity'], None] = None,
|
|
214
252
|
context_anchor_time: Union[pd.Timestamp, None] = None,
|
|
215
253
|
run_mode: Union[RunMode, str] = RunMode.FAST,
|
|
@@ -227,7 +265,7 @@ class KumoRFM:
|
|
|
227
265
|
query: str,
|
|
228
266
|
indices: Union[List[str], List[float], List[int], None] = None,
|
|
229
267
|
*,
|
|
230
|
-
explain: bool = False,
|
|
268
|
+
explain: Union[bool, ExplainConfig, Dict[str, Any]] = False,
|
|
231
269
|
anchor_time: Union[pd.Timestamp, Literal['entity'], None] = None,
|
|
232
270
|
context_anchor_time: Union[pd.Timestamp, None] = None,
|
|
233
271
|
run_mode: Union[RunMode, str] = RunMode.FAST,
|
|
@@ -247,9 +285,12 @@ class KumoRFM:
|
|
|
247
285
|
be generated for all indices, independent of whether they
|
|
248
286
|
fulfill entity filter constraints. To pre-filter entities, use
|
|
249
287
|
:meth:`~KumoRFM.is_valid_entity`.
|
|
250
|
-
explain:
|
|
251
|
-
|
|
252
|
-
|
|
288
|
+
explain: Configuration for explainability.
|
|
289
|
+
If set to ``True``, will additionally explain the prediction.
|
|
290
|
+
Passing in an :class:`ExplainConfig` instance provides control
|
|
291
|
+
over which parts of explanation are generated.
|
|
292
|
+
Explainability is currently only supported for single entity
|
|
293
|
+
predictions with ``run_mode="FAST"``.
|
|
253
294
|
anchor_time: The anchor timestamp for the prediction. If set to
|
|
254
295
|
``None``, will use the maximum timestamp in the data.
|
|
255
296
|
If set to ``"entity"``, will use the timestamp of the entity.
|
|
@@ -273,42 +314,48 @@ class KumoRFM:
|
|
|
273
314
|
|
|
274
315
|
Returns:
|
|
275
316
|
The predictions as a :class:`pandas.DataFrame`.
|
|
276
|
-
If ``explain
|
|
277
|
-
|
|
317
|
+
If ``explain`` is provided, returns an :class:`Explanation` object
|
|
318
|
+
containing the prediction, summary, and details.
|
|
278
319
|
"""
|
|
320
|
+
explain_config: Optional[ExplainConfig] = None
|
|
321
|
+
if explain is True:
|
|
322
|
+
explain_config = ExplainConfig()
|
|
323
|
+
elif explain is not False:
|
|
324
|
+
explain_config = ExplainConfig._cast(explain)
|
|
325
|
+
|
|
279
326
|
query_def = self._parse_query(query)
|
|
327
|
+
query_str = query_def.to_string()
|
|
280
328
|
|
|
281
329
|
if num_hops != 2 and num_neighbors is not None:
|
|
282
330
|
warnings.warn(f"Received custom 'num_neighbors' option; ignoring "
|
|
283
331
|
f"custom 'num_hops={num_hops}' option")
|
|
284
332
|
|
|
285
|
-
if
|
|
333
|
+
if explain_config is not None and run_mode in {
|
|
334
|
+
RunMode.NORMAL, RunMode.BEST
|
|
335
|
+
}:
|
|
286
336
|
warnings.warn(f"Explainability is currently only supported for "
|
|
287
337
|
f"run mode 'FAST' (got '{run_mode}'). Provided run "
|
|
288
338
|
f"mode has been reset. Please lower the run mode to "
|
|
289
339
|
f"suppress this warning.")
|
|
290
340
|
|
|
291
341
|
if indices is None:
|
|
292
|
-
if query_def.
|
|
342
|
+
if query_def.rfm_entity_ids is None:
|
|
293
343
|
raise ValueError("Cannot find entities to predict for. Please "
|
|
294
344
|
"pass them via `predict(query, indices=...)`")
|
|
295
|
-
indices = query_def.
|
|
345
|
+
indices = query_def.get_rfm_entity_id_list()
|
|
296
346
|
else:
|
|
297
|
-
query_def = replace(
|
|
298
|
-
query_def,
|
|
299
|
-
entity=replace(query_def.entity, ids=None),
|
|
300
|
-
)
|
|
347
|
+
query_def = replace(query_def, rfm_entity_ids=None)
|
|
301
348
|
|
|
302
349
|
if len(indices) == 0:
|
|
303
350
|
raise ValueError("At least one entity is required")
|
|
304
351
|
|
|
305
|
-
if
|
|
352
|
+
if explain_config is not None and len(indices) > 1:
|
|
306
353
|
raise ValueError(
|
|
307
354
|
f"Cannot explain predictions for more than a single entity "
|
|
308
355
|
f"(got {len(indices)})")
|
|
309
356
|
|
|
310
357
|
query_repr = query_def.to_string(rich=True, exclude_predict=True)
|
|
311
|
-
if
|
|
358
|
+
if explain_config is not None:
|
|
312
359
|
msg = f'[bold]EXPLAIN[/bold] {query_repr}'
|
|
313
360
|
else:
|
|
314
361
|
msg = f'[bold]PREDICT[/bold] {query_repr}'
|
|
@@ -320,8 +367,8 @@ class KumoRFM:
|
|
|
320
367
|
|
|
321
368
|
batch_size: Optional[int] = None
|
|
322
369
|
if self._batch_size == 'max':
|
|
323
|
-
task_type =
|
|
324
|
-
|
|
370
|
+
task_type = LocalPQueryDriver.get_task_type(
|
|
371
|
+
query_def,
|
|
325
372
|
edge_types=self._graph_store.edge_types,
|
|
326
373
|
)
|
|
327
374
|
batch_size = _MAX_PRED_SIZE[task_type]
|
|
@@ -359,6 +406,7 @@ class KumoRFM:
|
|
|
359
406
|
request = RFMPredictRequest(
|
|
360
407
|
context=context,
|
|
361
408
|
run_mode=RunMode(run_mode),
|
|
409
|
+
query=query_str,
|
|
362
410
|
use_prediction_time=use_prediction_time,
|
|
363
411
|
)
|
|
364
412
|
with warnings.catch_warnings():
|
|
@@ -382,17 +430,20 @@ class KumoRFM:
|
|
|
382
430
|
|
|
383
431
|
for attempt in range(self.num_retries + 1):
|
|
384
432
|
try:
|
|
385
|
-
if
|
|
386
|
-
resp =
|
|
433
|
+
if explain_config is not None:
|
|
434
|
+
resp = self._api_client.explain(
|
|
435
|
+
request=_bytes,
|
|
436
|
+
skip_summary=explain_config.skip_summary,
|
|
437
|
+
)
|
|
387
438
|
summary = resp.summary
|
|
388
439
|
details = resp.details
|
|
389
440
|
else:
|
|
390
|
-
resp =
|
|
441
|
+
resp = self._api_client.predict(_bytes)
|
|
391
442
|
df = pd.DataFrame(**resp.prediction)
|
|
392
443
|
|
|
393
444
|
# Cast 'ENTITY' to correct data type:
|
|
394
445
|
if 'ENTITY' in df:
|
|
395
|
-
entity = query_def.
|
|
446
|
+
entity = query_def.entity_table
|
|
396
447
|
pkey_map = self._graph_store.pkey_map_dict[entity]
|
|
397
448
|
df['ENTITY'] = df['ENTITY'].astype(
|
|
398
449
|
type(pkey_map.index[0]))
|
|
@@ -434,7 +485,7 @@ class KumoRFM:
|
|
|
434
485
|
else:
|
|
435
486
|
prediction = pd.concat(predictions, ignore_index=True)
|
|
436
487
|
|
|
437
|
-
if
|
|
488
|
+
if explain_config is not None:
|
|
438
489
|
assert len(predictions) == 1
|
|
439
490
|
assert summary is not None
|
|
440
491
|
assert details is not None
|
|
@@ -468,11 +519,11 @@ class KumoRFM:
|
|
|
468
519
|
query_def = self._parse_query(query)
|
|
469
520
|
|
|
470
521
|
if indices is None:
|
|
471
|
-
if query_def.
|
|
522
|
+
if query_def.rfm_entity_ids is None:
|
|
472
523
|
raise ValueError("Cannot find entities to predict for. Please "
|
|
473
524
|
"pass them via "
|
|
474
525
|
"`is_valid_entity(query, indices=...)`")
|
|
475
|
-
indices = query_def.
|
|
526
|
+
indices = query_def.get_rfm_entity_id_list()
|
|
476
527
|
|
|
477
528
|
if len(indices) == 0:
|
|
478
529
|
raise ValueError("At least one entity is required")
|
|
@@ -484,14 +535,13 @@ class KumoRFM:
|
|
|
484
535
|
self._validate_time(query_def, anchor_time, None, False)
|
|
485
536
|
else:
|
|
486
537
|
assert anchor_time == 'entity'
|
|
487
|
-
if (query_def.
|
|
488
|
-
not in self._graph_store.time_dict):
|
|
538
|
+
if (query_def.entity_table not in self._graph_store.time_dict):
|
|
489
539
|
raise ValueError(f"Anchor time 'entity' requires the entity "
|
|
490
|
-
f"table '{query_def.
|
|
491
|
-
f"to have a time column")
|
|
540
|
+
f"table '{query_def.entity_table}' "
|
|
541
|
+
f"to have a time column.")
|
|
492
542
|
|
|
493
543
|
node = self._graph_store.get_node_id(
|
|
494
|
-
table_name=query_def.
|
|
544
|
+
table_name=query_def.entity_table,
|
|
495
545
|
pkey=pd.Series(indices),
|
|
496
546
|
)
|
|
497
547
|
query_driver = LocalPQueryDriver(self._graph_store, query_def)
|
|
@@ -547,10 +597,10 @@ class KumoRFM:
|
|
|
547
597
|
warnings.warn(f"Received custom 'num_neighbors' option; ignoring "
|
|
548
598
|
f"custom 'num_hops={num_hops}' option")
|
|
549
599
|
|
|
550
|
-
if query_def.
|
|
600
|
+
if query_def.rfm_entity_ids is not None:
|
|
551
601
|
query_def = replace(
|
|
552
602
|
query_def,
|
|
553
|
-
|
|
603
|
+
rfm_entity_ids=None,
|
|
554
604
|
)
|
|
555
605
|
|
|
556
606
|
query_repr = query_def.to_string(rich=True, exclude_predict=True)
|
|
@@ -591,10 +641,10 @@ class KumoRFM:
|
|
|
591
641
|
|
|
592
642
|
if len(request_bytes) > _MAX_SIZE:
|
|
593
643
|
stats_msg = Context.get_memory_stats(request_msg.context)
|
|
594
|
-
raise ValueError(_SIZE_LIMIT_MSG.format(
|
|
644
|
+
raise ValueError(_SIZE_LIMIT_MSG.format(stats=stats_msg))
|
|
595
645
|
|
|
596
646
|
try:
|
|
597
|
-
resp =
|
|
647
|
+
resp = self._api_client.evaluate(request_bytes)
|
|
598
648
|
except HTTPException as e:
|
|
599
649
|
try:
|
|
600
650
|
msg = json.loads(e.detail)['detail']
|
|
@@ -639,18 +689,19 @@ class KumoRFM:
|
|
|
639
689
|
|
|
640
690
|
if anchor_time is None:
|
|
641
691
|
anchor_time = self._graph_store.max_time
|
|
642
|
-
|
|
643
|
-
|
|
692
|
+
if query_def.target_ast.date_offset_range is not None:
|
|
693
|
+
anchor_time = anchor_time - (
|
|
694
|
+
query_def.target_ast.date_offset_range.end_date_offset *
|
|
695
|
+
query_def.num_forecasts)
|
|
644
696
|
|
|
645
697
|
assert anchor_time is not None
|
|
646
698
|
if isinstance(anchor_time, pd.Timestamp):
|
|
647
699
|
self._validate_time(query_def, anchor_time, None, evaluate=True)
|
|
648
700
|
else:
|
|
649
701
|
assert anchor_time == 'entity'
|
|
650
|
-
if (query_def.
|
|
651
|
-
not in self._graph_store.time_dict):
|
|
702
|
+
if (query_def.entity_table not in self._graph_store.time_dict):
|
|
652
703
|
raise ValueError(f"Anchor time 'entity' requires the entity "
|
|
653
|
-
f"table '{query_def.
|
|
704
|
+
f"table '{query_def.entity_table}' "
|
|
654
705
|
f"to have a time column")
|
|
655
706
|
|
|
656
707
|
query_driver = LocalPQueryDriver(self._graph_store, query_def,
|
|
@@ -665,7 +716,7 @@ class KumoRFM:
|
|
|
665
716
|
)
|
|
666
717
|
|
|
667
718
|
entity = self._graph_store.pkey_map_dict[
|
|
668
|
-
query_def.
|
|
719
|
+
query_def.entity_table].index[node]
|
|
669
720
|
|
|
670
721
|
return pd.DataFrame({
|
|
671
722
|
'ENTITY': entity,
|
|
@@ -675,8 +726,8 @@ class KumoRFM:
|
|
|
675
726
|
|
|
676
727
|
# Helpers #################################################################
|
|
677
728
|
|
|
678
|
-
def _parse_query(self, query: str) ->
|
|
679
|
-
if isinstance(query,
|
|
729
|
+
def _parse_query(self, query: str) -> ValidatedPredictiveQuery:
|
|
730
|
+
if isinstance(query, ValidatedPredictiveQuery):
|
|
680
731
|
return query
|
|
681
732
|
|
|
682
733
|
if isinstance(query, str) and query.strip()[:9].lower() == 'evaluate ':
|
|
@@ -686,12 +737,13 @@ class KumoRFM:
|
|
|
686
737
|
"predictions or evaluations.")
|
|
687
738
|
|
|
688
739
|
try:
|
|
689
|
-
request =
|
|
740
|
+
request = RFMParseQueryRequest(
|
|
690
741
|
query=query,
|
|
691
742
|
graph_definition=self._graph_def,
|
|
692
743
|
)
|
|
693
744
|
|
|
694
|
-
resp =
|
|
745
|
+
resp = self._api_client.parse_query(request)
|
|
746
|
+
|
|
695
747
|
# TODO Expose validation warnings.
|
|
696
748
|
|
|
697
749
|
if len(resp.validation_response.warnings) > 0:
|
|
@@ -702,7 +754,7 @@ class KumoRFM:
|
|
|
702
754
|
warnings.warn(f"Encountered the following warnings during "
|
|
703
755
|
f"parsing:\n{msg}")
|
|
704
756
|
|
|
705
|
-
return resp.
|
|
757
|
+
return resp.query
|
|
706
758
|
except HTTPException as e:
|
|
707
759
|
try:
|
|
708
760
|
msg = json.loads(e.detail)['detail']
|
|
@@ -713,7 +765,7 @@ class KumoRFM:
|
|
|
713
765
|
|
|
714
766
|
def _validate_time(
|
|
715
767
|
self,
|
|
716
|
-
query:
|
|
768
|
+
query: ValidatedPredictiveQuery,
|
|
717
769
|
anchor_time: pd.Timestamp,
|
|
718
770
|
context_anchor_time: Union[pd.Timestamp, None],
|
|
719
771
|
evaluate: bool,
|
|
@@ -736,6 +788,11 @@ class KumoRFM:
|
|
|
736
788
|
f"only contains data back to "
|
|
737
789
|
f"'{self._graph_store.min_time}'.")
|
|
738
790
|
|
|
791
|
+
if query.target_ast.date_offset_range is not None:
|
|
792
|
+
end_offset = query.target_ast.date_offset_range.end_date_offset
|
|
793
|
+
else:
|
|
794
|
+
end_offset = pd.DateOffset(0)
|
|
795
|
+
forecast_end_offset = end_offset * query.num_forecasts
|
|
739
796
|
if (context_anchor_time is not None
|
|
740
797
|
and context_anchor_time > anchor_time):
|
|
741
798
|
warnings.warn(f"Context anchor timestamp "
|
|
@@ -744,19 +801,18 @@ class KumoRFM:
|
|
|
744
801
|
f"(got '{anchor_time}'). Please make sure this is "
|
|
745
802
|
f"intended.")
|
|
746
803
|
elif (query.query_type == QueryType.TEMPORAL
|
|
747
|
-
and context_anchor_time is not None
|
|
748
|
-
|
|
804
|
+
and context_anchor_time is not None
|
|
805
|
+
and context_anchor_time + forecast_end_offset > anchor_time):
|
|
749
806
|
warnings.warn(f"Aggregation for context examples at timestamp "
|
|
750
807
|
f"'{context_anchor_time}' will leak information "
|
|
751
808
|
f"from the prediction anchor timestamp "
|
|
752
809
|
f"'{anchor_time}'. Please make sure this is "
|
|
753
810
|
f"intended.")
|
|
754
811
|
|
|
755
|
-
elif (context_anchor_time is not None
|
|
756
|
-
|
|
812
|
+
elif (context_anchor_time is not None
|
|
813
|
+
and context_anchor_time - forecast_end_offset
|
|
757
814
|
< self._graph_store.min_time):
|
|
758
|
-
_time = context_anchor_time -
|
|
759
|
-
query.num_forecasts)
|
|
815
|
+
_time = context_anchor_time - forecast_end_offset
|
|
760
816
|
warnings.warn(f"Context anchor timestamp is too early or "
|
|
761
817
|
f"aggregation time range is too large. To form "
|
|
762
818
|
f"proper input data, we would need data back to "
|
|
@@ -769,8 +825,7 @@ class KumoRFM:
|
|
|
769
825
|
f"latest timestamp '{self._graph_store.max_time}' "
|
|
770
826
|
f"in the data. Please make sure this is intended.")
|
|
771
827
|
|
|
772
|
-
max_eval_time =
|
|
773
|
-
query.target.end_offset * query.num_forecasts)
|
|
828
|
+
max_eval_time = self._graph_store.max_time - forecast_end_offset
|
|
774
829
|
if evaluate and anchor_time > max_eval_time:
|
|
775
830
|
raise ValueError(
|
|
776
831
|
f"Anchor timestamp for evaluation is after the latest "
|
|
@@ -778,7 +833,7 @@ class KumoRFM:
|
|
|
778
833
|
|
|
779
834
|
def _get_context(
|
|
780
835
|
self,
|
|
781
|
-
query:
|
|
836
|
+
query: ValidatedPredictiveQuery,
|
|
782
837
|
indices: Union[List[str], List[float], List[int], None],
|
|
783
838
|
anchor_time: Union[pd.Timestamp, Literal['entity'], None],
|
|
784
839
|
context_anchor_time: Union[pd.Timestamp, None],
|
|
@@ -806,8 +861,8 @@ class KumoRFM:
|
|
|
806
861
|
f"must go beyond this for your use-case.")
|
|
807
862
|
|
|
808
863
|
query_driver = LocalPQueryDriver(self._graph_store, query, random_seed)
|
|
809
|
-
task_type =
|
|
810
|
-
|
|
864
|
+
task_type = LocalPQueryDriver.get_task_type(
|
|
865
|
+
query,
|
|
811
866
|
edge_types=self._graph_store.edge_types,
|
|
812
867
|
)
|
|
813
868
|
|
|
@@ -839,11 +894,15 @@ class KumoRFM:
|
|
|
839
894
|
else:
|
|
840
895
|
num_neighbors = [64, 64, 8, 8, 4, 4][:num_hops]
|
|
841
896
|
|
|
897
|
+
if query.target_ast.date_offset_range is None:
|
|
898
|
+
end_offset = pd.DateOffset(0)
|
|
899
|
+
else:
|
|
900
|
+
end_offset = query.target_ast.date_offset_range.end_date_offset
|
|
901
|
+
forecast_end_offset = end_offset * query.num_forecasts
|
|
842
902
|
if anchor_time is None:
|
|
843
903
|
anchor_time = self._graph_store.max_time
|
|
844
904
|
if evaluate:
|
|
845
|
-
anchor_time = anchor_time -
|
|
846
|
-
query.num_forecasts)
|
|
905
|
+
anchor_time = anchor_time - forecast_end_offset
|
|
847
906
|
if logger is not None:
|
|
848
907
|
assert isinstance(anchor_time, pd.Timestamp)
|
|
849
908
|
if anchor_time == pd.Timestamp.min:
|
|
@@ -858,15 +917,14 @@ class KumoRFM:
|
|
|
858
917
|
assert anchor_time is not None
|
|
859
918
|
if isinstance(anchor_time, pd.Timestamp):
|
|
860
919
|
if context_anchor_time is None:
|
|
861
|
-
context_anchor_time = anchor_time -
|
|
862
|
-
query.num_forecasts)
|
|
920
|
+
context_anchor_time = anchor_time - forecast_end_offset
|
|
863
921
|
self._validate_time(query, anchor_time, context_anchor_time,
|
|
864
922
|
evaluate)
|
|
865
923
|
else:
|
|
866
924
|
assert anchor_time == 'entity'
|
|
867
|
-
if query.
|
|
925
|
+
if query.entity_table not in self._graph_store.time_dict:
|
|
868
926
|
raise ValueError(f"Anchor time 'entity' requires the entity "
|
|
869
|
-
f"table '{query.
|
|
927
|
+
f"table '{query.entity_table}' to "
|
|
870
928
|
f"have a time column")
|
|
871
929
|
if context_anchor_time is not None:
|
|
872
930
|
warnings.warn("Ignoring option 'context_anchor_time' for "
|
|
@@ -917,7 +975,7 @@ class KumoRFM:
|
|
|
917
975
|
f"in batches")
|
|
918
976
|
|
|
919
977
|
test_node = self._graph_store.get_node_id(
|
|
920
|
-
table_name=query.
|
|
978
|
+
table_name=query.entity_table,
|
|
921
979
|
pkey=pd.Series(indices),
|
|
922
980
|
)
|
|
923
981
|
|
|
@@ -925,8 +983,7 @@ class KumoRFM:
|
|
|
925
983
|
test_time = pd.Series(anchor_time).repeat(
|
|
926
984
|
len(test_node)).reset_index(drop=True)
|
|
927
985
|
else:
|
|
928
|
-
time = self._graph_store.time_dict[
|
|
929
|
-
query.entity.pkey.table_name]
|
|
986
|
+
time = self._graph_store.time_dict[query.entity_table]
|
|
930
987
|
time = time[test_node] * 1000**3
|
|
931
988
|
test_time = pd.Series(time, dtype='datetime64[ns]')
|
|
932
989
|
|
|
@@ -959,12 +1016,23 @@ class KumoRFM:
|
|
|
959
1016
|
raise NotImplementedError
|
|
960
1017
|
logger.log(msg)
|
|
961
1018
|
|
|
962
|
-
entity_table_names
|
|
963
|
-
|
|
1019
|
+
entity_table_names: Tuple[str, ...]
|
|
1020
|
+
if task_type.is_link_pred:
|
|
1021
|
+
final_aggr = query.get_final_target_aggregation()
|
|
1022
|
+
assert final_aggr is not None
|
|
1023
|
+
edge_fkey = final_aggr._get_target_column_name()
|
|
1024
|
+
for edge_type in self._graph_store.edge_types:
|
|
1025
|
+
if edge_fkey == f'{edge_type[0]}.{edge_type[1]}':
|
|
1026
|
+
entity_table_names = (
|
|
1027
|
+
query.entity_table,
|
|
1028
|
+
edge_type[2],
|
|
1029
|
+
)
|
|
1030
|
+
else:
|
|
1031
|
+
entity_table_names = (query.entity_table, )
|
|
964
1032
|
|
|
965
1033
|
# Exclude the entity anchor time from the feature set to prevent
|
|
966
1034
|
# running out-of-distribution between in-context and test examples:
|
|
967
|
-
exclude_cols_dict = query.
|
|
1035
|
+
exclude_cols_dict = query.get_exclude_cols_dict()
|
|
968
1036
|
if anchor_time == 'entity':
|
|
969
1037
|
if entity_table_names[0] not in exclude_cols_dict:
|
|
970
1038
|
exclude_cols_dict[entity_table_names[0]] = []
|
|
@@ -979,7 +1047,6 @@ class KumoRFM:
|
|
|
979
1047
|
train_time.astype('datetime64[ns]').astype(int).to_numpy(),
|
|
980
1048
|
test_time.astype('datetime64[ns]').astype(int).to_numpy(),
|
|
981
1049
|
]),
|
|
982
|
-
run_mode=run_mode,
|
|
983
1050
|
num_neighbors=num_neighbors,
|
|
984
1051
|
exclude_cols_dict=exclude_cols_dict,
|
|
985
1052
|
)
|
|
@@ -993,7 +1060,7 @@ class KumoRFM:
|
|
|
993
1060
|
|
|
994
1061
|
step_size: Optional[int] = None
|
|
995
1062
|
if query.query_type == QueryType.TEMPORAL:
|
|
996
|
-
step_size = date_offset_to_seconds(
|
|
1063
|
+
step_size = date_offset_to_seconds(end_offset)
|
|
997
1064
|
|
|
998
1065
|
return Context(
|
|
999
1066
|
task_type=task_type,
|
|
@@ -0,0 +1,130 @@
|
|
|
1
|
+
import base64
|
|
2
|
+
import json
|
|
3
|
+
from typing import Any, Dict, List, Tuple
|
|
4
|
+
|
|
5
|
+
import boto3
|
|
6
|
+
import requests
|
|
7
|
+
from mypy_boto3_sagemaker_runtime.client import SageMakerRuntimeClient
|
|
8
|
+
from mypy_boto3_sagemaker_runtime.type_defs import InvokeEndpointOutputTypeDef
|
|
9
|
+
|
|
10
|
+
from kumoai.client import KumoClient
|
|
11
|
+
from kumoai.client.endpoints import Endpoint, HTTPMethod
|
|
12
|
+
from kumoai.exceptions import HTTPException
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class SageMakerResponseAdapter(requests.Response):
    """Expose a SageMaker ``invoke_endpoint`` result as a ``requests.Response``.

    The streamed body is read once at construction time and kept as the
    response content, so ``text`` and ``json()`` work on the buffered bytes.
    """
    def __init__(self, sm_response: InvokeEndpointOutputTypeDef):
        super().__init__()
        # Drain the body stream into the response content.
        body_bytes = sm_response['Body'].read()
        self._content = body_bytes
        # The status is always reported as 200 by this adapter.
        self.status_code = 200
        content_type = sm_response.get('ContentType', 'application/json')
        self.headers['Content-Type'] = content_type
        # Keep the raw SageMaker response around for debugging.
        self.sm_response = sm_response

    @property
    def text(self) -> str:
        """The buffered body decoded as UTF-8."""
        raw = self._content
        assert isinstance(raw, bytes)
        return raw.decode('utf-8')

    def json(self, **kwargs) -> dict[str, Any]:  # type: ignore
        """Parse the body as JSON; ``kwargs`` are passed to ``json.loads``."""
        parsed = json.loads(self.text, **kwargs)
        return parsed
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
class KumoClient_SageMakerAdapter(KumoClient):
    """``KumoClient`` variant that sends requests to a SageMaker endpoint.

    Each request is wrapped into a JSON envelope of the form
    ``{'method': <last path segment>, 'payload': <encoded body>}`` and sent
    via ``invoke_endpoint``. Requests and responses can optionally be
    recorded between :meth:`start_recording` and :meth:`end_recording`.

    Args:
        region: Region name passed to the ``sagemaker-runtime`` boto3 client.
        endpoint_name: Name of the SageMaker endpoint to invoke.
    """
    def __init__(self, region: str, endpoint_name: str):
        self._client: SageMakerRuntimeClient = boto3.client(
            service_name="sagemaker-runtime", region_name=region)
        self._endpoint_name = endpoint_name

        # Buffers used while recording is switched on.
        self._recording_active = False
        self._recorded_reqs: List[Dict[str, Any]] = []
        self._recorded_resps: List[Dict[str, Any]] = []

    def authenticate(self) -> None:
        """No-op: nothing is verified for the SageMaker endpoint."""
        # TODO(siyang): call /ping to verify?
        pass

    def _request(self, endpoint: Endpoint, **kwargs: Any) -> requests.Response:
        """Send one POST request through the SageMaker endpoint.

        Args:
            endpoint: The logical endpoint; only its last path segment is
                forwarded, as the envelope's ``'method'``.
            **kwargs: Must contain either ``json`` (serialized with
                ``json.dumps``) or ``data`` (``bytes``, base64-encoded).

        Returns:
            The endpoint's reply wrapped in a ``SageMakerResponseAdapter``.

        Raises:
            HTTPException: If neither ``json`` nor ``data`` is given.
        """
        assert endpoint.method == HTTPMethod.POST

        if 'json' in kwargs:
            payload = json.dumps(kwargs.pop('json'))
        elif 'data' not in kwargs:
            raise HTTPException(400, 'Unable to send data to KumoRFM.')
        else:
            raw_payload = kwargs.pop('data')
            assert isinstance(raw_payload, bytes)
            # Binary bodies are sent base64-encoded inside the JSON envelope.
            payload = base64.b64encode(raw_payload).decode()

        method_name = endpoint.get_path().rsplit('/')[-1]
        request = {
            'method': method_name,
            'payload': payload,
        }
        response: InvokeEndpointOutputTypeDef = self._client.invoke_endpoint(
            EndpointName=self._endpoint_name,
            ContentType="application/json",
            Body=json.dumps(request),
        )

        adapted_response = SageMakerResponseAdapter(response)

        # While recording, keep the envelope and the decoded JSON reply.
        if self._recording_active:
            self._recorded_reqs.append(request)
            self._recorded_resps.append(adapted_response.json())

        return adapted_response

    def start_recording(self) -> None:
        """Start recording requests/responses to/from sagemaker endpoint."""
        assert not self._recording_active
        self._recording_active = True
        for buffer in (self._recorded_reqs, self._recorded_resps):
            buffer.clear()

    def end_recording(self) -> List[Tuple[Dict[str, Any], Dict[str, Any]]]:
        """Stop recording and return recorded requests/responses."""
        assert self._recording_active
        self._recording_active = False
        # Pair up requests and responses before the buffers are emptied.
        recorded = list(zip(self._recorded_reqs, self._recorded_resps))
        for buffer in (self._recorded_reqs, self._recorded_resps):
            buffer.clear()
        return recorded
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
class KumoClient_SageMakerProxy_Local(KumoClient):
    """``KumoClient`` variant that talks to a SageMaker-style server by URL.

    All requests are forwarded to the ``/invocations`` endpoint of a wrapped
    ``KumoClient``, using the JSON envelope
    ``{'method': <last path segment>, 'payload': <encoded body>}``.

    Args:
        url: Base URL of the server. It is also used as the API URL of the
            wrapped client.
    """
    def __init__(self, url: str):
        # The base-class initializer is not called here, so connection state
        # (URL, session, SSL settings) lives on the wrapped client only.
        self._client = KumoClient(url, api_key=None)
        self._client._api_url = self._client._url
        self._endpoint = Endpoint('/invocations', HTTPMethod.POST)

    def authenticate(self) -> None:
        """Check that the server is reachable by calling ``<url>/ping``.

        Raises:
            ValueError: If the ping request fails for any reason.
        """
        try:
            # Read the URL and SSL setting from the wrapped client: this
            # instance never sets `_url`/`_verify_ssl` itself, so reading them
            # from `self` made every call fail with an AttributeError that
            # was then reported as an authentication failure.
            # NOTE(review): assumes `KumoClient.__init__` sets `_verify_ssl`
            # on the wrapped client - confirm against kumoai/client/client.py.
            self._client._session.get(
                self._client._url + '/ping',
                verify=self._client._verify_ssl).raise_for_status()
        except Exception as exc:
            raise ValueError(
                "Client authentication failed. Please check if you "
                "have a valid API key/credentials.") from exc

    def _request(self, endpoint: Endpoint, **kwargs: Any) -> requests.Response:
        """Forward one POST request to the server's ``/invocations`` route.

        Args:
            endpoint: The logical endpoint; only its last path segment is
                forwarded, as the envelope's ``'method'``.
            **kwargs: Must contain either ``json`` (serialized with
                ``json.dumps``) or ``data`` (``bytes``, base64-encoded).
                Remaining keyword arguments are passed to the wrapped client.

        Returns:
            The response returned by the wrapped client.

        Raises:
            HTTPException: If neither ``json`` nor ``data`` is given.
        """
        assert endpoint.method == HTTPMethod.POST
        if 'json' in kwargs:
            payload = json.dumps(kwargs.pop('json'))
        elif 'data' in kwargs:
            raw_payload = kwargs.pop('data')
            assert isinstance(raw_payload, bytes)
            # Binary bodies are sent base64-encoded inside the JSON envelope.
            payload = base64.b64encode(raw_payload).decode()
        else:
            raise HTTPException(400, 'Unable to send data to KumoRFM.')
        return self._client._request(
            self._endpoint,
            json={
                'method': endpoint.get_path().rsplit('/')[-1],
                'payload': payload,
            },
            **kwargs,
        )
|