kumoai 2.12.0.dev202511021731__cp310-cp310-macosx_11_0_arm64.whl → 2.13.0.dev202511191731__cp310-cp310-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kumoai/__init__.py +6 -9
- kumoai/_version.py +1 -1
- kumoai/client/client.py +9 -13
- kumoai/client/endpoints.py +1 -0
- kumoai/client/rfm.py +35 -7
- kumoai/experimental/rfm/__init__.py +151 -8
- kumoai/experimental/rfm/local_graph_sampler.py +0 -2
- kumoai/experimental/rfm/local_pquery_driver.py +221 -26
- kumoai/experimental/rfm/pquery/__init__.py +0 -4
- kumoai/experimental/rfm/pquery/pandas_executor.py +34 -8
- kumoai/experimental/rfm/rfm.py +137 -79
- kumoai/experimental/rfm/sagemaker.py +130 -0
- kumoai/spcs.py +1 -3
- kumoai/utils/progress_logger.py +10 -4
- {kumoai-2.12.0.dev202511021731.dist-info → kumoai-2.13.0.dev202511191731.dist-info}/METADATA +11 -2
- {kumoai-2.12.0.dev202511021731.dist-info → kumoai-2.13.0.dev202511191731.dist-info}/RECORD +19 -20
- kumoai/experimental/rfm/pquery/backend.py +0 -136
- kumoai/experimental/rfm/pquery/pandas_backend.py +0 -478
- {kumoai-2.12.0.dev202511021731.dist-info → kumoai-2.13.0.dev202511191731.dist-info}/WHEEL +0 -0
- {kumoai-2.12.0.dev202511021731.dist-info → kumoai-2.13.0.dev202511191731.dist-info}/licenses/LICENSE +0 -0
- {kumoai-2.12.0.dev202511021731.dist-info → kumoai-2.13.0.dev202511191731.dist-info}/top_level.txt +0 -0
kumoai/experimental/rfm/rfm.py
CHANGED
|
@@ -5,23 +5,32 @@ from collections import defaultdict
|
|
|
5
5
|
from collections.abc import Generator
|
|
6
6
|
from contextlib import contextmanager
|
|
7
7
|
from dataclasses import dataclass, replace
|
|
8
|
-
from typing import
|
|
8
|
+
from typing import (
|
|
9
|
+
Any,
|
|
10
|
+
Dict,
|
|
11
|
+
Iterator,
|
|
12
|
+
List,
|
|
13
|
+
Literal,
|
|
14
|
+
Optional,
|
|
15
|
+
Tuple,
|
|
16
|
+
Union,
|
|
17
|
+
overload,
|
|
18
|
+
)
|
|
9
19
|
|
|
10
20
|
import numpy as np
|
|
11
21
|
import pandas as pd
|
|
12
22
|
from kumoapi.model_plan import RunMode
|
|
13
|
-
from kumoapi.pquery import QueryType
|
|
23
|
+
from kumoapi.pquery import QueryType, ValidatedPredictiveQuery
|
|
14
24
|
from kumoapi.rfm import Context
|
|
15
25
|
from kumoapi.rfm import Explanation as ExplanationConfig
|
|
16
26
|
from kumoapi.rfm import (
|
|
17
|
-
PQueryDefinition,
|
|
18
27
|
RFMEvaluateRequest,
|
|
28
|
+
RFMParseQueryRequest,
|
|
19
29
|
RFMPredictRequest,
|
|
20
|
-
RFMValidateQueryRequest,
|
|
21
30
|
)
|
|
22
31
|
from kumoapi.task import TaskType
|
|
23
32
|
|
|
24
|
-
from kumoai import
|
|
33
|
+
from kumoai.client.rfm import RFMAPI
|
|
25
34
|
from kumoai.exceptions import HTTPException
|
|
26
35
|
from kumoai.experimental.rfm import LocalGraph
|
|
27
36
|
from kumoai.experimental.rfm.local_graph_sampler import LocalGraphSampler
|
|
@@ -30,6 +39,7 @@ from kumoai.experimental.rfm.local_pquery_driver import (
|
|
|
30
39
|
LocalPQueryDriver,
|
|
31
40
|
date_offset_to_seconds,
|
|
32
41
|
)
|
|
42
|
+
from kumoai.mixin import CastMixin
|
|
33
43
|
from kumoai.utils import InteractiveProgressLogger, ProgressLogger
|
|
34
44
|
|
|
35
45
|
_RANDOM_SEED = 42
|
|
@@ -60,6 +70,17 @@ _SIZE_LIMIT_MSG = ("Context size exceeds the 30MB limit. {stats}\nPlease "
|
|
|
60
70
|
"beyond this for your use-case.")
|
|
61
71
|
|
|
62
72
|
|
|
73
|
+
@dataclass(repr=False)
class ExplainConfig(CastMixin):
    r"""Options that control how a prediction is explained.

    Args:
        skip_summary: If set to ``True``, no human-readable summary of the
            explanation is generated. Defaults to ``False``.
    """
    skip_summary: bool = False
|
|
82
|
+
|
|
83
|
+
|
|
63
84
|
@dataclass(repr=False)
|
|
64
85
|
class Explanation:
|
|
65
86
|
prediction: pd.DataFrame
|
|
@@ -87,6 +108,12 @@ class Explanation:
|
|
|
87
108
|
def __repr__(self) -> str:
|
|
88
109
|
return str((self.prediction, self.summary))
|
|
89
110
|
|
|
111
|
+
def _ipython_display_(self) -> None:
    """Render this explanation in IPython/Jupyter.

    The prediction table is displayed first, followed by the summary
    rendered as Markdown.
    """
    from IPython.display import Markdown, display

    display(self.prediction)
    summary_markdown = Markdown(self.summary)
    display(summary_markdown)
|
|
116
|
+
|
|
90
117
|
|
|
91
118
|
class KumoRFM:
|
|
92
119
|
r"""The Kumo Relational Foundation model (RFM) from the `KumoRFM: A
|
|
@@ -114,9 +141,9 @@ class KumoRFM:
|
|
|
114
141
|
|
|
115
142
|
rfm = KumoRFM(graph)
|
|
116
143
|
|
|
117
|
-
query = ("PREDICT COUNT(
|
|
118
|
-
"FOR users.user_id=
|
|
119
|
-
result = rfm.
|
|
144
|
+
query = ("PREDICT COUNT(orders.*, 0, 30, days)>0 "
|
|
145
|
+
"FOR users.user_id=1")
|
|
146
|
+
result = rfm.predict(query)
|
|
120
147
|
|
|
121
148
|
print(result) # user_id COUNT(transactions.*, 0, 30, days) > 0
|
|
122
149
|
# 1 0.85
|
|
@@ -147,6 +174,8 @@ class KumoRFM:
|
|
|
147
174
|
|
|
148
175
|
self._batch_size: Optional[int | Literal['max']] = None
|
|
149
176
|
self.num_retries: int = 0
|
|
177
|
+
from kumoai.experimental.rfm import global_state
|
|
178
|
+
self._api_client = RFMAPI(global_state.client)
|
|
150
179
|
|
|
151
180
|
def __repr__(self) -> str:
|
|
152
181
|
return f'{self.__class__.__name__}()'
|
|
@@ -209,7 +238,7 @@ class KumoRFM:
|
|
|
209
238
|
query: str,
|
|
210
239
|
indices: Union[List[str], List[float], List[int], None] = None,
|
|
211
240
|
*,
|
|
212
|
-
explain: Literal[True],
|
|
241
|
+
explain: Union[Literal[True], ExplainConfig, Dict[str, Any]],
|
|
213
242
|
anchor_time: Union[pd.Timestamp, Literal['entity'], None] = None,
|
|
214
243
|
context_anchor_time: Union[pd.Timestamp, None] = None,
|
|
215
244
|
run_mode: Union[RunMode, str] = RunMode.FAST,
|
|
@@ -227,7 +256,7 @@ class KumoRFM:
|
|
|
227
256
|
query: str,
|
|
228
257
|
indices: Union[List[str], List[float], List[int], None] = None,
|
|
229
258
|
*,
|
|
230
|
-
explain: bool = False,
|
|
259
|
+
explain: Union[bool, ExplainConfig, Dict[str, Any]] = False,
|
|
231
260
|
anchor_time: Union[pd.Timestamp, Literal['entity'], None] = None,
|
|
232
261
|
context_anchor_time: Union[pd.Timestamp, None] = None,
|
|
233
262
|
run_mode: Union[RunMode, str] = RunMode.FAST,
|
|
@@ -247,9 +276,12 @@ class KumoRFM:
|
|
|
247
276
|
be generated for all indices, independent of whether they
|
|
248
277
|
fulfill entity filter constraints. To pre-filter entities, use
|
|
249
278
|
:meth:`~KumoRFM.is_valid_entity`.
|
|
250
|
-
explain:
|
|
251
|
-
|
|
252
|
-
|
|
279
|
+
explain: Configuration for explainability.
|
|
280
|
+
If set to ``True``, will additionally explain the prediction.
|
|
281
|
+
Passing in an :class:`ExplainConfig` instance provides control
|
|
282
|
+
over which parts of explanation are generated.
|
|
283
|
+
Explainability is currently only supported for single entity
|
|
284
|
+
predictions with ``run_mode="FAST"``.
|
|
253
285
|
anchor_time: The anchor timestamp for the prediction. If set to
|
|
254
286
|
``None``, will use the maximum timestamp in the data.
|
|
255
287
|
If set to ``"entity"``, will use the timestamp of the entity.
|
|
@@ -273,42 +305,48 @@ class KumoRFM:
|
|
|
273
305
|
|
|
274
306
|
Returns:
|
|
275
307
|
The predictions as a :class:`pandas.DataFrame`.
|
|
276
|
-
If ``explain
|
|
277
|
-
|
|
308
|
+
If ``explain`` is provided, returns an :class:`Explanation` object
|
|
309
|
+
containing the prediction, summary, and details.
|
|
278
310
|
"""
|
|
311
|
+
explain_config: Optional[ExplainConfig] = None
|
|
312
|
+
if explain is True:
|
|
313
|
+
explain_config = ExplainConfig()
|
|
314
|
+
elif explain is not False:
|
|
315
|
+
explain_config = ExplainConfig._cast(explain)
|
|
316
|
+
|
|
279
317
|
query_def = self._parse_query(query)
|
|
318
|
+
query_str = query_def.to_string()
|
|
280
319
|
|
|
281
320
|
if num_hops != 2 and num_neighbors is not None:
|
|
282
321
|
warnings.warn(f"Received custom 'num_neighbors' option; ignoring "
|
|
283
322
|
f"custom 'num_hops={num_hops}' option")
|
|
284
323
|
|
|
285
|
-
if
|
|
324
|
+
if explain_config is not None and run_mode in {
|
|
325
|
+
RunMode.NORMAL, RunMode.BEST
|
|
326
|
+
}:
|
|
286
327
|
warnings.warn(f"Explainability is currently only supported for "
|
|
287
328
|
f"run mode 'FAST' (got '{run_mode}'). Provided run "
|
|
288
329
|
f"mode has been reset. Please lower the run mode to "
|
|
289
330
|
f"suppress this warning.")
|
|
290
331
|
|
|
291
332
|
if indices is None:
|
|
292
|
-
if query_def.
|
|
333
|
+
if query_def.rfm_entity_ids is None:
|
|
293
334
|
raise ValueError("Cannot find entities to predict for. Please "
|
|
294
335
|
"pass them via `predict(query, indices=...)`")
|
|
295
|
-
indices = query_def.
|
|
336
|
+
indices = query_def.get_rfm_entity_id_list()
|
|
296
337
|
else:
|
|
297
|
-
query_def = replace(
|
|
298
|
-
query_def,
|
|
299
|
-
entity=replace(query_def.entity, ids=None),
|
|
300
|
-
)
|
|
338
|
+
query_def = replace(query_def, rfm_entity_ids=None)
|
|
301
339
|
|
|
302
340
|
if len(indices) == 0:
|
|
303
341
|
raise ValueError("At least one entity is required")
|
|
304
342
|
|
|
305
|
-
if
|
|
343
|
+
if explain_config is not None and len(indices) > 1:
|
|
306
344
|
raise ValueError(
|
|
307
345
|
f"Cannot explain predictions for more than a single entity "
|
|
308
346
|
f"(got {len(indices)})")
|
|
309
347
|
|
|
310
348
|
query_repr = query_def.to_string(rich=True, exclude_predict=True)
|
|
311
|
-
if
|
|
349
|
+
if explain_config is not None:
|
|
312
350
|
msg = f'[bold]EXPLAIN[/bold] {query_repr}'
|
|
313
351
|
else:
|
|
314
352
|
msg = f'[bold]PREDICT[/bold] {query_repr}'
|
|
@@ -320,8 +358,8 @@ class KumoRFM:
|
|
|
320
358
|
|
|
321
359
|
batch_size: Optional[int] = None
|
|
322
360
|
if self._batch_size == 'max':
|
|
323
|
-
task_type =
|
|
324
|
-
|
|
361
|
+
task_type = LocalPQueryDriver.get_task_type(
|
|
362
|
+
query_def,
|
|
325
363
|
edge_types=self._graph_store.edge_types,
|
|
326
364
|
)
|
|
327
365
|
batch_size = _MAX_PRED_SIZE[task_type]
|
|
@@ -359,6 +397,7 @@ class KumoRFM:
|
|
|
359
397
|
request = RFMPredictRequest(
|
|
360
398
|
context=context,
|
|
361
399
|
run_mode=RunMode(run_mode),
|
|
400
|
+
query=query_str,
|
|
362
401
|
use_prediction_time=use_prediction_time,
|
|
363
402
|
)
|
|
364
403
|
with warnings.catch_warnings():
|
|
@@ -382,17 +421,20 @@ class KumoRFM:
|
|
|
382
421
|
|
|
383
422
|
for attempt in range(self.num_retries + 1):
|
|
384
423
|
try:
|
|
385
|
-
if
|
|
386
|
-
resp =
|
|
424
|
+
if explain_config is not None:
|
|
425
|
+
resp = self._api_client.explain(
|
|
426
|
+
request=_bytes,
|
|
427
|
+
skip_summary=explain_config.skip_summary,
|
|
428
|
+
)
|
|
387
429
|
summary = resp.summary
|
|
388
430
|
details = resp.details
|
|
389
431
|
else:
|
|
390
|
-
resp =
|
|
432
|
+
resp = self._api_client.predict(_bytes)
|
|
391
433
|
df = pd.DataFrame(**resp.prediction)
|
|
392
434
|
|
|
393
435
|
# Cast 'ENTITY' to correct data type:
|
|
394
436
|
if 'ENTITY' in df:
|
|
395
|
-
entity = query_def.
|
|
437
|
+
entity = query_def.entity_table
|
|
396
438
|
pkey_map = self._graph_store.pkey_map_dict[entity]
|
|
397
439
|
df['ENTITY'] = df['ENTITY'].astype(
|
|
398
440
|
type(pkey_map.index[0]))
|
|
@@ -434,7 +476,7 @@ class KumoRFM:
|
|
|
434
476
|
else:
|
|
435
477
|
prediction = pd.concat(predictions, ignore_index=True)
|
|
436
478
|
|
|
437
|
-
if
|
|
479
|
+
if explain_config is not None:
|
|
438
480
|
assert len(predictions) == 1
|
|
439
481
|
assert summary is not None
|
|
440
482
|
assert details is not None
|
|
@@ -468,11 +510,11 @@ class KumoRFM:
|
|
|
468
510
|
query_def = self._parse_query(query)
|
|
469
511
|
|
|
470
512
|
if indices is None:
|
|
471
|
-
if query_def.
|
|
513
|
+
if query_def.rfm_entity_ids is None:
|
|
472
514
|
raise ValueError("Cannot find entities to predict for. Please "
|
|
473
515
|
"pass them via "
|
|
474
516
|
"`is_valid_entity(query, indices=...)`")
|
|
475
|
-
indices = query_def.
|
|
517
|
+
indices = query_def.get_rfm_entity_id_list()
|
|
476
518
|
|
|
477
519
|
if len(indices) == 0:
|
|
478
520
|
raise ValueError("At least one entity is required")
|
|
@@ -484,14 +526,13 @@ class KumoRFM:
|
|
|
484
526
|
self._validate_time(query_def, anchor_time, None, False)
|
|
485
527
|
else:
|
|
486
528
|
assert anchor_time == 'entity'
|
|
487
|
-
if (query_def.
|
|
488
|
-
not in self._graph_store.time_dict):
|
|
529
|
+
if (query_def.entity_table not in self._graph_store.time_dict):
|
|
489
530
|
raise ValueError(f"Anchor time 'entity' requires the entity "
|
|
490
|
-
f"table '{query_def.
|
|
491
|
-
f"to have a time column")
|
|
531
|
+
f"table '{query_def.entity_table}' "
|
|
532
|
+
f"to have a time column.")
|
|
492
533
|
|
|
493
534
|
node = self._graph_store.get_node_id(
|
|
494
|
-
table_name=query_def.
|
|
535
|
+
table_name=query_def.entity_table,
|
|
495
536
|
pkey=pd.Series(indices),
|
|
496
537
|
)
|
|
497
538
|
query_driver = LocalPQueryDriver(self._graph_store, query_def)
|
|
@@ -547,10 +588,10 @@ class KumoRFM:
|
|
|
547
588
|
warnings.warn(f"Received custom 'num_neighbors' option; ignoring "
|
|
548
589
|
f"custom 'num_hops={num_hops}' option")
|
|
549
590
|
|
|
550
|
-
if query_def.
|
|
591
|
+
if query_def.rfm_entity_ids is not None:
|
|
551
592
|
query_def = replace(
|
|
552
593
|
query_def,
|
|
553
|
-
|
|
594
|
+
rfm_entity_ids=None,
|
|
554
595
|
)
|
|
555
596
|
|
|
556
597
|
query_repr = query_def.to_string(rich=True, exclude_predict=True)
|
|
@@ -591,10 +632,10 @@ class KumoRFM:
|
|
|
591
632
|
|
|
592
633
|
if len(request_bytes) > _MAX_SIZE:
|
|
593
634
|
stats_msg = Context.get_memory_stats(request_msg.context)
|
|
594
|
-
raise ValueError(_SIZE_LIMIT_MSG.format(
|
|
635
|
+
raise ValueError(_SIZE_LIMIT_MSG.format(stats=stats_msg))
|
|
595
636
|
|
|
596
637
|
try:
|
|
597
|
-
resp =
|
|
638
|
+
resp = self._api_client.evaluate(request_bytes)
|
|
598
639
|
except HTTPException as e:
|
|
599
640
|
try:
|
|
600
641
|
msg = json.loads(e.detail)['detail']
|
|
@@ -639,18 +680,19 @@ class KumoRFM:
|
|
|
639
680
|
|
|
640
681
|
if anchor_time is None:
|
|
641
682
|
anchor_time = self._graph_store.max_time
|
|
642
|
-
|
|
643
|
-
|
|
683
|
+
if query_def.target_ast.date_offset_range is not None:
|
|
684
|
+
anchor_time = anchor_time - (
|
|
685
|
+
query_def.target_ast.date_offset_range.end_date_offset *
|
|
686
|
+
query_def.num_forecasts)
|
|
644
687
|
|
|
645
688
|
assert anchor_time is not None
|
|
646
689
|
if isinstance(anchor_time, pd.Timestamp):
|
|
647
690
|
self._validate_time(query_def, anchor_time, None, evaluate=True)
|
|
648
691
|
else:
|
|
649
692
|
assert anchor_time == 'entity'
|
|
650
|
-
if (query_def.
|
|
651
|
-
not in self._graph_store.time_dict):
|
|
693
|
+
if (query_def.entity_table not in self._graph_store.time_dict):
|
|
652
694
|
raise ValueError(f"Anchor time 'entity' requires the entity "
|
|
653
|
-
f"table '{query_def.
|
|
695
|
+
f"table '{query_def.entity_table}' "
|
|
654
696
|
f"to have a time column")
|
|
655
697
|
|
|
656
698
|
query_driver = LocalPQueryDriver(self._graph_store, query_def,
|
|
@@ -665,7 +707,7 @@ class KumoRFM:
|
|
|
665
707
|
)
|
|
666
708
|
|
|
667
709
|
entity = self._graph_store.pkey_map_dict[
|
|
668
|
-
query_def.
|
|
710
|
+
query_def.entity_table].index[node]
|
|
669
711
|
|
|
670
712
|
return pd.DataFrame({
|
|
671
713
|
'ENTITY': entity,
|
|
@@ -675,8 +717,8 @@ class KumoRFM:
|
|
|
675
717
|
|
|
676
718
|
# Helpers #################################################################
|
|
677
719
|
|
|
678
|
-
def _parse_query(self, query: str) ->
|
|
679
|
-
if isinstance(query,
|
|
720
|
+
def _parse_query(self, query: str) -> ValidatedPredictiveQuery:
|
|
721
|
+
if isinstance(query, ValidatedPredictiveQuery):
|
|
680
722
|
return query
|
|
681
723
|
|
|
682
724
|
if isinstance(query, str) and query.strip()[:9].lower() == 'evaluate ':
|
|
@@ -686,12 +728,13 @@ class KumoRFM:
|
|
|
686
728
|
"predictions or evaluations.")
|
|
687
729
|
|
|
688
730
|
try:
|
|
689
|
-
request =
|
|
731
|
+
request = RFMParseQueryRequest(
|
|
690
732
|
query=query,
|
|
691
733
|
graph_definition=self._graph_def,
|
|
692
734
|
)
|
|
693
735
|
|
|
694
|
-
resp =
|
|
736
|
+
resp = self._api_client.parse_query(request)
|
|
737
|
+
|
|
695
738
|
# TODO Expose validation warnings.
|
|
696
739
|
|
|
697
740
|
if len(resp.validation_response.warnings) > 0:
|
|
@@ -702,7 +745,7 @@ class KumoRFM:
|
|
|
702
745
|
warnings.warn(f"Encountered the following warnings during "
|
|
703
746
|
f"parsing:\n{msg}")
|
|
704
747
|
|
|
705
|
-
return resp.
|
|
748
|
+
return resp.query
|
|
706
749
|
except HTTPException as e:
|
|
707
750
|
try:
|
|
708
751
|
msg = json.loads(e.detail)['detail']
|
|
@@ -713,7 +756,7 @@ class KumoRFM:
|
|
|
713
756
|
|
|
714
757
|
def _validate_time(
|
|
715
758
|
self,
|
|
716
|
-
query:
|
|
759
|
+
query: ValidatedPredictiveQuery,
|
|
717
760
|
anchor_time: pd.Timestamp,
|
|
718
761
|
context_anchor_time: Union[pd.Timestamp, None],
|
|
719
762
|
evaluate: bool,
|
|
@@ -736,6 +779,11 @@ class KumoRFM:
|
|
|
736
779
|
f"only contains data back to "
|
|
737
780
|
f"'{self._graph_store.min_time}'.")
|
|
738
781
|
|
|
782
|
+
if query.target_ast.date_offset_range is not None:
|
|
783
|
+
end_offset = query.target_ast.date_offset_range.end_date_offset
|
|
784
|
+
else:
|
|
785
|
+
end_offset = pd.DateOffset(0)
|
|
786
|
+
forecast_end_offset = end_offset * query.num_forecasts
|
|
739
787
|
if (context_anchor_time is not None
|
|
740
788
|
and context_anchor_time > anchor_time):
|
|
741
789
|
warnings.warn(f"Context anchor timestamp "
|
|
@@ -744,19 +792,18 @@ class KumoRFM:
|
|
|
744
792
|
f"(got '{anchor_time}'). Please make sure this is "
|
|
745
793
|
f"intended.")
|
|
746
794
|
elif (query.query_type == QueryType.TEMPORAL
|
|
747
|
-
and context_anchor_time is not None
|
|
748
|
-
|
|
795
|
+
and context_anchor_time is not None
|
|
796
|
+
and context_anchor_time + forecast_end_offset > anchor_time):
|
|
749
797
|
warnings.warn(f"Aggregation for context examples at timestamp "
|
|
750
798
|
f"'{context_anchor_time}' will leak information "
|
|
751
799
|
f"from the prediction anchor timestamp "
|
|
752
800
|
f"'{anchor_time}'. Please make sure this is "
|
|
753
801
|
f"intended.")
|
|
754
802
|
|
|
755
|
-
elif (context_anchor_time is not None
|
|
756
|
-
|
|
803
|
+
elif (context_anchor_time is not None
|
|
804
|
+
and context_anchor_time - forecast_end_offset
|
|
757
805
|
< self._graph_store.min_time):
|
|
758
|
-
_time = context_anchor_time -
|
|
759
|
-
query.num_forecasts)
|
|
806
|
+
_time = context_anchor_time - forecast_end_offset
|
|
760
807
|
warnings.warn(f"Context anchor timestamp is too early or "
|
|
761
808
|
f"aggregation time range is too large. To form "
|
|
762
809
|
f"proper input data, we would need data back to "
|
|
@@ -769,8 +816,7 @@ class KumoRFM:
|
|
|
769
816
|
f"latest timestamp '{self._graph_store.max_time}' "
|
|
770
817
|
f"in the data. Please make sure this is intended.")
|
|
771
818
|
|
|
772
|
-
max_eval_time =
|
|
773
|
-
query.target.end_offset * query.num_forecasts)
|
|
819
|
+
max_eval_time = self._graph_store.max_time - forecast_end_offset
|
|
774
820
|
if evaluate and anchor_time > max_eval_time:
|
|
775
821
|
raise ValueError(
|
|
776
822
|
f"Anchor timestamp for evaluation is after the latest "
|
|
@@ -778,7 +824,7 @@ class KumoRFM:
|
|
|
778
824
|
|
|
779
825
|
def _get_context(
|
|
780
826
|
self,
|
|
781
|
-
query:
|
|
827
|
+
query: ValidatedPredictiveQuery,
|
|
782
828
|
indices: Union[List[str], List[float], List[int], None],
|
|
783
829
|
anchor_time: Union[pd.Timestamp, Literal['entity'], None],
|
|
784
830
|
context_anchor_time: Union[pd.Timestamp, None],
|
|
@@ -806,8 +852,8 @@ class KumoRFM:
|
|
|
806
852
|
f"must go beyond this for your use-case.")
|
|
807
853
|
|
|
808
854
|
query_driver = LocalPQueryDriver(self._graph_store, query, random_seed)
|
|
809
|
-
task_type =
|
|
810
|
-
|
|
855
|
+
task_type = LocalPQueryDriver.get_task_type(
|
|
856
|
+
query,
|
|
811
857
|
edge_types=self._graph_store.edge_types,
|
|
812
858
|
)
|
|
813
859
|
|
|
@@ -839,11 +885,15 @@ class KumoRFM:
|
|
|
839
885
|
else:
|
|
840
886
|
num_neighbors = [64, 64, 8, 8, 4, 4][:num_hops]
|
|
841
887
|
|
|
888
|
+
if query.target_ast.date_offset_range is None:
|
|
889
|
+
end_offset = pd.DateOffset(0)
|
|
890
|
+
else:
|
|
891
|
+
end_offset = query.target_ast.date_offset_range.end_date_offset
|
|
892
|
+
forecast_end_offset = end_offset * query.num_forecasts
|
|
842
893
|
if anchor_time is None:
|
|
843
894
|
anchor_time = self._graph_store.max_time
|
|
844
895
|
if evaluate:
|
|
845
|
-
anchor_time = anchor_time -
|
|
846
|
-
query.num_forecasts)
|
|
896
|
+
anchor_time = anchor_time - forecast_end_offset
|
|
847
897
|
if logger is not None:
|
|
848
898
|
assert isinstance(anchor_time, pd.Timestamp)
|
|
849
899
|
if anchor_time == pd.Timestamp.min:
|
|
@@ -858,15 +908,14 @@ class KumoRFM:
|
|
|
858
908
|
assert anchor_time is not None
|
|
859
909
|
if isinstance(anchor_time, pd.Timestamp):
|
|
860
910
|
if context_anchor_time is None:
|
|
861
|
-
context_anchor_time = anchor_time -
|
|
862
|
-
query.num_forecasts)
|
|
911
|
+
context_anchor_time = anchor_time - forecast_end_offset
|
|
863
912
|
self._validate_time(query, anchor_time, context_anchor_time,
|
|
864
913
|
evaluate)
|
|
865
914
|
else:
|
|
866
915
|
assert anchor_time == 'entity'
|
|
867
|
-
if query.
|
|
916
|
+
if query.entity_table not in self._graph_store.time_dict:
|
|
868
917
|
raise ValueError(f"Anchor time 'entity' requires the entity "
|
|
869
|
-
f"table '{query.
|
|
918
|
+
f"table '{query.entity_table}' to "
|
|
870
919
|
f"have a time column")
|
|
871
920
|
if context_anchor_time is not None:
|
|
872
921
|
warnings.warn("Ignoring option 'context_anchor_time' for "
|
|
@@ -917,7 +966,7 @@ class KumoRFM:
|
|
|
917
966
|
f"in batches")
|
|
918
967
|
|
|
919
968
|
test_node = self._graph_store.get_node_id(
|
|
920
|
-
table_name=query.
|
|
969
|
+
table_name=query.entity_table,
|
|
921
970
|
pkey=pd.Series(indices),
|
|
922
971
|
)
|
|
923
972
|
|
|
@@ -925,8 +974,7 @@ class KumoRFM:
|
|
|
925
974
|
test_time = pd.Series(anchor_time).repeat(
|
|
926
975
|
len(test_node)).reset_index(drop=True)
|
|
927
976
|
else:
|
|
928
|
-
time = self._graph_store.time_dict[
|
|
929
|
-
query.entity.pkey.table_name]
|
|
977
|
+
time = self._graph_store.time_dict[query.entity_table]
|
|
930
978
|
time = time[test_node] * 1000**3
|
|
931
979
|
test_time = pd.Series(time, dtype='datetime64[ns]')
|
|
932
980
|
|
|
@@ -959,12 +1007,23 @@ class KumoRFM:
|
|
|
959
1007
|
raise NotImplementedError
|
|
960
1008
|
logger.log(msg)
|
|
961
1009
|
|
|
962
|
-
entity_table_names
|
|
963
|
-
|
|
1010
|
+
entity_table_names: Tuple[str, ...]
|
|
1011
|
+
if task_type.is_link_pred:
|
|
1012
|
+
final_aggr = query.get_final_target_aggregation()
|
|
1013
|
+
assert final_aggr is not None
|
|
1014
|
+
edge_fkey = final_aggr._get_target_column_name()
|
|
1015
|
+
for edge_type in self._graph_store.edge_types:
|
|
1016
|
+
if edge_fkey == f'{edge_type[0]}.{edge_type[1]}':
|
|
1017
|
+
entity_table_names = (
|
|
1018
|
+
query.entity_table,
|
|
1019
|
+
edge_type[2],
|
|
1020
|
+
)
|
|
1021
|
+
else:
|
|
1022
|
+
entity_table_names = (query.entity_table, )
|
|
964
1023
|
|
|
965
1024
|
# Exclude the entity anchor time from the feature set to prevent
|
|
966
1025
|
# running out-of-distribution between in-context and test examples:
|
|
967
|
-
exclude_cols_dict = query.
|
|
1026
|
+
exclude_cols_dict = query.get_exclude_cols_dict()
|
|
968
1027
|
if anchor_time == 'entity':
|
|
969
1028
|
if entity_table_names[0] not in exclude_cols_dict:
|
|
970
1029
|
exclude_cols_dict[entity_table_names[0]] = []
|
|
@@ -979,7 +1038,6 @@ class KumoRFM:
|
|
|
979
1038
|
train_time.astype('datetime64[ns]').astype(int).to_numpy(),
|
|
980
1039
|
test_time.astype('datetime64[ns]').astype(int).to_numpy(),
|
|
981
1040
|
]),
|
|
982
|
-
run_mode=run_mode,
|
|
983
1041
|
num_neighbors=num_neighbors,
|
|
984
1042
|
exclude_cols_dict=exclude_cols_dict,
|
|
985
1043
|
)
|
|
@@ -993,7 +1051,7 @@ class KumoRFM:
|
|
|
993
1051
|
|
|
994
1052
|
step_size: Optional[int] = None
|
|
995
1053
|
if query.query_type == QueryType.TEMPORAL:
|
|
996
|
-
step_size = date_offset_to_seconds(
|
|
1054
|
+
step_size = date_offset_to_seconds(end_offset)
|
|
997
1055
|
|
|
998
1056
|
return Context(
|
|
999
1057
|
task_type=task_type,
|
|
@@ -0,0 +1,130 @@
|
|
|
1
|
+
import base64
|
|
2
|
+
import json
|
|
3
|
+
from typing import Any, Dict, List, Tuple
|
|
4
|
+
|
|
5
|
+
import boto3
|
|
6
|
+
import requests
|
|
7
|
+
from mypy_boto3_sagemaker_runtime.client import SageMakerRuntimeClient
|
|
8
|
+
from mypy_boto3_sagemaker_runtime.type_defs import InvokeEndpointOutputTypeDef
|
|
9
|
+
|
|
10
|
+
from kumoai.client import KumoClient
|
|
11
|
+
from kumoai.client.endpoints import Endpoint, HTTPMethod
|
|
12
|
+
from kumoai.exceptions import HTTPException
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class SageMakerResponseAdapter(requests.Response):
    r"""Presents a SageMaker ``invoke_endpoint`` output as a
    :class:`requests.Response`.

    The whole body is read eagerly. The response always reports status code
    ``200``, and its ``Content-Type`` header falls back to
    ``'application/json'`` when SageMaker does not report one. The original
    SageMaker output is kept on ``sm_response`` for debugging.
    """
    def __init__(self, sm_response: InvokeEndpointOutputTypeDef):
        super().__init__()
        body_stream = sm_response['Body']
        content_type = sm_response.get('ContentType', 'application/json')

        # ``requests.Response`` stores the raw payload bytes in ``_content``:
        self._content = body_stream.read()
        self.status_code = 200
        self.headers['Content-Type'] = content_type
        self.sm_response = sm_response

    @property
    def text(self) -> str:
        """The response body decoded as UTF-8."""
        raw_bytes = self._content
        assert isinstance(raw_bytes, bytes)
        return raw_bytes.decode('utf-8')

    def json(self, **kwargs) -> dict[str, Any]:  # type: ignore
        """Parse the response body as JSON.

        Keyword arguments are forwarded to :func:`json.loads`.
        """
        decoded_text = self.text
        return json.loads(decoded_text, **kwargs)
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
class KumoClient_SageMakerAdapter(KumoClient):
    r"""A :class:`KumoClient` that sends KumoRFM requests to an AWS SageMaker
    endpoint.

    Each request is wrapped into a JSON envelope
    ``{'method': <last path segment of the endpoint>, 'payload': <str>}``
    and sent via the ``sagemaker-runtime`` ``invoke_endpoint`` call.
    Request/response pairs can optionally be captured with
    :meth:`start_recording` and :meth:`end_recording`.

    Args:
        region: The AWS region of the SageMaker endpoint.
        endpoint_name: The name of the SageMaker endpoint.
    """
    def __init__(self, region: str, endpoint_name: str):
        self._client: SageMakerRuntimeClient = boto3.client(
            service_name="sagemaker-runtime", region_name=region)
        self._endpoint_name = endpoint_name

        # Recording buffers. Requests and responses are always appended
        # together, so that ``end_recording`` can pair them by position.
        self._recording_active = False
        self._recorded_reqs: List[Dict[str, Any]] = []
        self._recorded_resps: List[Dict[str, Any]] = []

    def authenticate(self) -> None:
        # TODO(siyang): call /ping to verify?
        pass

    def _request(self, endpoint: Endpoint, **kwargs: Any) -> requests.Response:
        """Send a request for ``endpoint`` to the SageMaker endpoint.

        A JSON body (``json=...``) is serialized with :func:`json.dumps`.
        A binary body (``data=...``, which must be ``bytes``) is
        base64-encoded so it can be embedded in the JSON envelope. Any other
        keyword arguments are ignored.

        Raises:
            HTTPException: If neither ``json`` nor ``data`` is given.
        """
        assert endpoint.method == HTTPMethod.POST
        if 'json' in kwargs:
            payload = json.dumps(kwargs.pop('json'))
        elif 'data' in kwargs:
            raw_payload = kwargs.pop('data')
            assert isinstance(raw_payload, bytes)
            payload = base64.b64encode(raw_payload).decode()
        else:
            raise HTTPException(400, 'Unable to send data to KumoRFM.')

        request = {
            # Only the last path segment identifies the method server-side:
            'method': endpoint.get_path().rsplit('/', 1)[-1],
            'payload': payload,
        }
        response: InvokeEndpointOutputTypeDef = self._client.invoke_endpoint(
            EndpointName=self._endpoint_name,
            ContentType="application/json",
            Body=json.dumps(request),
        )

        adapted_response = SageMakerResponseAdapter(response)

        # If recording is active, store input/output. Parse the response
        # *before* appending anything: if parsing fails, neither buffer is
        # touched, so the request/response pairs stay aligned.
        if self._recording_active:
            recorded_resp = adapted_response.json()
            self._recorded_reqs.append(request)
            self._recorded_resps.append(recorded_resp)

        return adapted_response

    def start_recording(self) -> None:
        """Start recording requests/responses to/from sagemaker endpoint."""
        assert not self._recording_active
        self._recording_active = True
        self._recorded_reqs.clear()
        self._recorded_resps.clear()

    def end_recording(self) -> List[Tuple[Dict[str, Any], Dict[str, Any]]]:
        """Stop recording and return recorded requests/responses.

        Returns:
            A list of ``(request, response)`` pairs in the order they were
            sent.
        """
        assert self._recording_active
        self._recording_active = False
        recorded = list(zip(self._recorded_reqs, self._recorded_resps))
        self._recorded_reqs.clear()
        self._recorded_resps.clear()
        return recorded
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
class KumoClient_SageMakerProxy_Local(KumoClient):
    r"""A :class:`KumoClient` that talks to a local server exposing the
    ``/ping`` and ``/invocations`` routes.

    Every request is wrapped into a JSON envelope
    ``{'method': <last path segment of the endpoint>, 'payload': <str>}``
    and posted to ``/invocations``.

    Args:
        url: The base URL of the local server.
    """
    def __init__(self, url: str):
        # ``super().__init__()`` is not called. All HTTP state (URL, session,
        # SSL settings) therefore lives on this inner client, not on ``self``.
        self._client = KumoClient(url, api_key=None)
        # Point API requests at the bare server URL. NOTE(review): presumably
        # ``_api_url`` otherwise carries an API-version suffix; confirm.
        self._client._api_url = self._client._url
        self._endpoint = Endpoint('/invocations', HTTPMethod.POST)

    def authenticate(self) -> None:
        """Check that the local server answers ``GET /ping``.

        Raises:
            ValueError: If the ping request fails or returns an error status.
        """
        try:
            # URL and SSL settings must be read from the inner client;
            # ``self`` never has ``_url``/``_verify_ssl`` set (see __init__).
            self._client._session.get(
                self._client._url + '/ping',
                verify=self._client._verify_ssl).raise_for_status()
        except Exception as e:
            raise ValueError(
                "Client authentication failed. Please check if you "
                "have a valid API key/credentials.") from e

    def _request(self, endpoint: Endpoint, **kwargs: Any) -> requests.Response:
        """Forward a request for ``endpoint`` to the ``/invocations`` route.

        A JSON body (``json=...``) is serialized with :func:`json.dumps`.
        A binary body (``data=...``, which must be ``bytes``) is
        base64-encoded. Remaining keyword arguments are passed on to the
        inner client.

        Raises:
            HTTPException: If neither ``json`` nor ``data`` is given.
        """
        assert endpoint.method == HTTPMethod.POST
        if 'json' in kwargs:
            payload = json.dumps(kwargs.pop('json'))
        elif 'data' in kwargs:
            raw_payload = kwargs.pop('data')
            assert isinstance(raw_payload, bytes)
            payload = base64.b64encode(raw_payload).decode()
        else:
            raise HTTPException(400, 'Unable to send data to KumoRFM.')
        return self._client._request(
            self._endpoint,
            json={
                # Only the last path segment identifies the method:
                'method': endpoint.get_path().rsplit('/', 1)[-1],
                'payload': payload,
            },
            **kwargs,
        )
|
kumoai/spcs.py
CHANGED
|
@@ -54,9 +54,7 @@ def _refresh_spcs_token() -> None:
|
|
|
54
54
|
api_key=global_state._api_key,
|
|
55
55
|
spcs_token=spcs_token,
|
|
56
56
|
)
|
|
57
|
-
|
|
58
|
-
raise ValueError("Client authentication failed. Please check if you "
|
|
59
|
-
"have a valid API key.")
|
|
57
|
+
client.authenticate()
|
|
60
58
|
|
|
61
59
|
# Update state:
|
|
62
60
|
global_state.set_spcs_token(spcs_token)
|