kumoai 2.12.0.dev202511031731__cp311-cp311-macosx_11_0_arm64.whl → 2.12.0.dev202511111731__cp311-cp311-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kumoai/_version.py +1 -1
- kumoai/client/endpoints.py +1 -0
- kumoai/client/rfm.py +35 -7
- kumoai/experimental/rfm/__init__.py +3 -1
- kumoai/experimental/rfm/local_pquery_driver.py +221 -26
- kumoai/experimental/rfm/pquery/__init__.py +0 -4
- kumoai/experimental/rfm/pquery/pandas_executor.py +34 -8
- kumoai/experimental/rfm/rfm.py +127 -71
- kumoai/utils/progress_logger.py +10 -4
- {kumoai-2.12.0.dev202511031731.dist-info → kumoai-2.12.0.dev202511111731.dist-info}/METADATA +2 -2
- {kumoai-2.12.0.dev202511031731.dist-info → kumoai-2.12.0.dev202511111731.dist-info}/RECORD +14 -16
- kumoai/experimental/rfm/pquery/backend.py +0 -136
- kumoai/experimental/rfm/pquery/pandas_backend.py +0 -478
- {kumoai-2.12.0.dev202511031731.dist-info → kumoai-2.12.0.dev202511111731.dist-info}/WHEEL +0 -0
- {kumoai-2.12.0.dev202511031731.dist-info → kumoai-2.12.0.dev202511111731.dist-info}/licenses/LICENSE +0 -0
- {kumoai-2.12.0.dev202511031731.dist-info → kumoai-2.12.0.dev202511111731.dist-info}/top_level.txt +0 -0
kumoai/experimental/rfm/rfm.py
CHANGED
|
@@ -5,19 +5,28 @@ from collections import defaultdict
|
|
|
5
5
|
from collections.abc import Generator
|
|
6
6
|
from contextlib import contextmanager
|
|
7
7
|
from dataclasses import dataclass, replace
|
|
8
|
-
from typing import
|
|
8
|
+
from typing import (
|
|
9
|
+
Any,
|
|
10
|
+
Dict,
|
|
11
|
+
Iterator,
|
|
12
|
+
List,
|
|
13
|
+
Literal,
|
|
14
|
+
Optional,
|
|
15
|
+
Tuple,
|
|
16
|
+
Union,
|
|
17
|
+
overload,
|
|
18
|
+
)
|
|
9
19
|
|
|
10
20
|
import numpy as np
|
|
11
21
|
import pandas as pd
|
|
12
22
|
from kumoapi.model_plan import RunMode
|
|
13
|
-
from kumoapi.pquery import QueryType
|
|
23
|
+
from kumoapi.pquery import QueryType, ValidatedPredictiveQuery
|
|
14
24
|
from kumoapi.rfm import Context
|
|
15
25
|
from kumoapi.rfm import Explanation as ExplanationConfig
|
|
16
26
|
from kumoapi.rfm import (
|
|
17
|
-
PQueryDefinition,
|
|
18
27
|
RFMEvaluateRequest,
|
|
28
|
+
RFMParseQueryRequest,
|
|
19
29
|
RFMPredictRequest,
|
|
20
|
-
RFMValidateQueryRequest,
|
|
21
30
|
)
|
|
22
31
|
from kumoapi.task import TaskType
|
|
23
32
|
|
|
@@ -30,6 +39,7 @@ from kumoai.experimental.rfm.local_pquery_driver import (
|
|
|
30
39
|
LocalPQueryDriver,
|
|
31
40
|
date_offset_to_seconds,
|
|
32
41
|
)
|
|
42
|
+
from kumoai.mixin import CastMixin
|
|
33
43
|
from kumoai.utils import InteractiveProgressLogger, ProgressLogger
|
|
34
44
|
|
|
35
45
|
_RANDOM_SEED = 42
|
|
@@ -60,6 +70,17 @@ _SIZE_LIMIT_MSG = ("Context size exceeds the 30MB limit. {stats}\nPlease "
|
|
|
60
70
|
"beyond this for your use-case.")
|
|
61
71
|
|
|
62
72
|
|
|
73
|
+
@dataclass(repr=False)
|
|
74
|
+
class ExplainConfig(CastMixin):
|
|
75
|
+
"""Configuration for explainability.
|
|
76
|
+
|
|
77
|
+
Args:
|
|
78
|
+
skip_summary: Whether to skip generating a human-readable summary of
|
|
79
|
+
the explanation.
|
|
80
|
+
"""
|
|
81
|
+
skip_summary: bool = False
|
|
82
|
+
|
|
83
|
+
|
|
63
84
|
@dataclass(repr=False)
|
|
64
85
|
class Explanation:
|
|
65
86
|
prediction: pd.DataFrame
|
|
@@ -87,6 +108,12 @@ class Explanation:
|
|
|
87
108
|
def __repr__(self) -> str:
|
|
88
109
|
return str((self.prediction, self.summary))
|
|
89
110
|
|
|
111
|
+
def _ipython_display_(self) -> None:
|
|
112
|
+
from IPython.display import Markdown, display
|
|
113
|
+
|
|
114
|
+
display(self.prediction)
|
|
115
|
+
display(Markdown(self.summary))
|
|
116
|
+
|
|
90
117
|
|
|
91
118
|
class KumoRFM:
|
|
92
119
|
r"""The Kumo Relational Foundation model (RFM) from the `KumoRFM: A
|
|
@@ -209,7 +236,7 @@ class KumoRFM:
|
|
|
209
236
|
query: str,
|
|
210
237
|
indices: Union[List[str], List[float], List[int], None] = None,
|
|
211
238
|
*,
|
|
212
|
-
explain: Literal[True],
|
|
239
|
+
explain: Union[Literal[True], ExplainConfig, Dict[str, Any]],
|
|
213
240
|
anchor_time: Union[pd.Timestamp, Literal['entity'], None] = None,
|
|
214
241
|
context_anchor_time: Union[pd.Timestamp, None] = None,
|
|
215
242
|
run_mode: Union[RunMode, str] = RunMode.FAST,
|
|
@@ -227,7 +254,7 @@ class KumoRFM:
|
|
|
227
254
|
query: str,
|
|
228
255
|
indices: Union[List[str], List[float], List[int], None] = None,
|
|
229
256
|
*,
|
|
230
|
-
explain: bool = False,
|
|
257
|
+
explain: Union[bool, ExplainConfig, Dict[str, Any]] = False,
|
|
231
258
|
anchor_time: Union[pd.Timestamp, Literal['entity'], None] = None,
|
|
232
259
|
context_anchor_time: Union[pd.Timestamp, None] = None,
|
|
233
260
|
run_mode: Union[RunMode, str] = RunMode.FAST,
|
|
@@ -247,9 +274,12 @@ class KumoRFM:
|
|
|
247
274
|
be generated for all indices, independent of whether they
|
|
248
275
|
fulfill entity filter constraints. To pre-filter entities, use
|
|
249
276
|
:meth:`~KumoRFM.is_valid_entity`.
|
|
250
|
-
explain:
|
|
251
|
-
|
|
252
|
-
|
|
277
|
+
explain: Configuration for explainability.
|
|
278
|
+
If set to ``True``, will additionally explain the prediction.
|
|
279
|
+
Passing in an :class:`ExplainConfig` instance provides control
|
|
280
|
+
over which parts of explanation are generated.
|
|
281
|
+
Explainability is currently only supported for single entity
|
|
282
|
+
predictions with ``run_mode="FAST"``.
|
|
253
283
|
anchor_time: The anchor timestamp for the prediction. If set to
|
|
254
284
|
``None``, will use the maximum timestamp in the data.
|
|
255
285
|
If set to ``"entity"``, will use the timestamp of the entity.
|
|
@@ -273,42 +303,48 @@ class KumoRFM:
|
|
|
273
303
|
|
|
274
304
|
Returns:
|
|
275
305
|
The predictions as a :class:`pandas.DataFrame`.
|
|
276
|
-
If ``explain
|
|
277
|
-
|
|
306
|
+
If ``explain`` is provided, returns an :class:`Explanation` object
|
|
307
|
+
containing the prediction, summary, and details.
|
|
278
308
|
"""
|
|
309
|
+
explain_config: Optional[ExplainConfig] = None
|
|
310
|
+
if explain is True:
|
|
311
|
+
explain_config = ExplainConfig()
|
|
312
|
+
elif explain is not False:
|
|
313
|
+
explain_config = ExplainConfig._cast(explain)
|
|
314
|
+
|
|
279
315
|
query_def = self._parse_query(query)
|
|
316
|
+
query_str = query_def.to_string()
|
|
280
317
|
|
|
281
318
|
if num_hops != 2 and num_neighbors is not None:
|
|
282
319
|
warnings.warn(f"Received custom 'num_neighbors' option; ignoring "
|
|
283
320
|
f"custom 'num_hops={num_hops}' option")
|
|
284
321
|
|
|
285
|
-
if
|
|
322
|
+
if explain_config is not None and run_mode in {
|
|
323
|
+
RunMode.NORMAL, RunMode.BEST
|
|
324
|
+
}:
|
|
286
325
|
warnings.warn(f"Explainability is currently only supported for "
|
|
287
326
|
f"run mode 'FAST' (got '{run_mode}'). Provided run "
|
|
288
327
|
f"mode has been reset. Please lower the run mode to "
|
|
289
328
|
f"suppress this warning.")
|
|
290
329
|
|
|
291
330
|
if indices is None:
|
|
292
|
-
if query_def.
|
|
331
|
+
if query_def.rfm_entity_ids is None:
|
|
293
332
|
raise ValueError("Cannot find entities to predict for. Please "
|
|
294
333
|
"pass them via `predict(query, indices=...)`")
|
|
295
|
-
indices = query_def.
|
|
334
|
+
indices = query_def.get_rfm_entity_id_list()
|
|
296
335
|
else:
|
|
297
|
-
query_def = replace(
|
|
298
|
-
query_def,
|
|
299
|
-
entity=replace(query_def.entity, ids=None),
|
|
300
|
-
)
|
|
336
|
+
query_def = replace(query_def, rfm_entity_ids=None)
|
|
301
337
|
|
|
302
338
|
if len(indices) == 0:
|
|
303
339
|
raise ValueError("At least one entity is required")
|
|
304
340
|
|
|
305
|
-
if
|
|
341
|
+
if explain_config is not None and len(indices) > 1:
|
|
306
342
|
raise ValueError(
|
|
307
343
|
f"Cannot explain predictions for more than a single entity "
|
|
308
344
|
f"(got {len(indices)})")
|
|
309
345
|
|
|
310
346
|
query_repr = query_def.to_string(rich=True, exclude_predict=True)
|
|
311
|
-
if
|
|
347
|
+
if explain_config is not None:
|
|
312
348
|
msg = f'[bold]EXPLAIN[/bold] {query_repr}'
|
|
313
349
|
else:
|
|
314
350
|
msg = f'[bold]PREDICT[/bold] {query_repr}'
|
|
@@ -320,8 +356,8 @@ class KumoRFM:
|
|
|
320
356
|
|
|
321
357
|
batch_size: Optional[int] = None
|
|
322
358
|
if self._batch_size == 'max':
|
|
323
|
-
task_type =
|
|
324
|
-
|
|
359
|
+
task_type = LocalPQueryDriver.get_task_type(
|
|
360
|
+
query_def,
|
|
325
361
|
edge_types=self._graph_store.edge_types,
|
|
326
362
|
)
|
|
327
363
|
batch_size = _MAX_PRED_SIZE[task_type]
|
|
@@ -359,6 +395,7 @@ class KumoRFM:
|
|
|
359
395
|
request = RFMPredictRequest(
|
|
360
396
|
context=context,
|
|
361
397
|
run_mode=RunMode(run_mode),
|
|
398
|
+
query=query_str,
|
|
362
399
|
use_prediction_time=use_prediction_time,
|
|
363
400
|
)
|
|
364
401
|
with warnings.catch_warnings():
|
|
@@ -382,8 +419,11 @@ class KumoRFM:
|
|
|
382
419
|
|
|
383
420
|
for attempt in range(self.num_retries + 1):
|
|
384
421
|
try:
|
|
385
|
-
if
|
|
386
|
-
resp = global_state.client.rfm_api.explain(
|
|
422
|
+
if explain_config is not None:
|
|
423
|
+
resp = global_state.client.rfm_api.explain(
|
|
424
|
+
request=_bytes,
|
|
425
|
+
skip_summary=explain_config.skip_summary,
|
|
426
|
+
)
|
|
387
427
|
summary = resp.summary
|
|
388
428
|
details = resp.details
|
|
389
429
|
else:
|
|
@@ -392,7 +432,7 @@ class KumoRFM:
|
|
|
392
432
|
|
|
393
433
|
# Cast 'ENTITY' to correct data type:
|
|
394
434
|
if 'ENTITY' in df:
|
|
395
|
-
entity = query_def.
|
|
435
|
+
entity = query_def.entity_table
|
|
396
436
|
pkey_map = self._graph_store.pkey_map_dict[entity]
|
|
397
437
|
df['ENTITY'] = df['ENTITY'].astype(
|
|
398
438
|
type(pkey_map.index[0]))
|
|
@@ -434,7 +474,7 @@ class KumoRFM:
|
|
|
434
474
|
else:
|
|
435
475
|
prediction = pd.concat(predictions, ignore_index=True)
|
|
436
476
|
|
|
437
|
-
if
|
|
477
|
+
if explain_config is not None:
|
|
438
478
|
assert len(predictions) == 1
|
|
439
479
|
assert summary is not None
|
|
440
480
|
assert details is not None
|
|
@@ -468,11 +508,11 @@ class KumoRFM:
|
|
|
468
508
|
query_def = self._parse_query(query)
|
|
469
509
|
|
|
470
510
|
if indices is None:
|
|
471
|
-
if query_def.
|
|
511
|
+
if query_def.rfm_entity_ids is None:
|
|
472
512
|
raise ValueError("Cannot find entities to predict for. Please "
|
|
473
513
|
"pass them via "
|
|
474
514
|
"`is_valid_entity(query, indices=...)`")
|
|
475
|
-
indices = query_def.
|
|
515
|
+
indices = query_def.get_rfm_entity_id_list()
|
|
476
516
|
|
|
477
517
|
if len(indices) == 0:
|
|
478
518
|
raise ValueError("At least one entity is required")
|
|
@@ -484,14 +524,13 @@ class KumoRFM:
|
|
|
484
524
|
self._validate_time(query_def, anchor_time, None, False)
|
|
485
525
|
else:
|
|
486
526
|
assert anchor_time == 'entity'
|
|
487
|
-
if (query_def.
|
|
488
|
-
not in self._graph_store.time_dict):
|
|
527
|
+
if (query_def.entity_table not in self._graph_store.time_dict):
|
|
489
528
|
raise ValueError(f"Anchor time 'entity' requires the entity "
|
|
490
|
-
f"table '{query_def.
|
|
491
|
-
f"to have a time column")
|
|
529
|
+
f"table '{query_def.entity_table}' "
|
|
530
|
+
f"to have a time column.")
|
|
492
531
|
|
|
493
532
|
node = self._graph_store.get_node_id(
|
|
494
|
-
table_name=query_def.
|
|
533
|
+
table_name=query_def.entity_table,
|
|
495
534
|
pkey=pd.Series(indices),
|
|
496
535
|
)
|
|
497
536
|
query_driver = LocalPQueryDriver(self._graph_store, query_def)
|
|
@@ -547,10 +586,10 @@ class KumoRFM:
|
|
|
547
586
|
warnings.warn(f"Received custom 'num_neighbors' option; ignoring "
|
|
548
587
|
f"custom 'num_hops={num_hops}' option")
|
|
549
588
|
|
|
550
|
-
if query_def.
|
|
589
|
+
if query_def.rfm_entity_ids is not None:
|
|
551
590
|
query_def = replace(
|
|
552
591
|
query_def,
|
|
553
|
-
|
|
592
|
+
rfm_entity_ids=None,
|
|
554
593
|
)
|
|
555
594
|
|
|
556
595
|
query_repr = query_def.to_string(rich=True, exclude_predict=True)
|
|
@@ -639,18 +678,19 @@ class KumoRFM:
|
|
|
639
678
|
|
|
640
679
|
if anchor_time is None:
|
|
641
680
|
anchor_time = self._graph_store.max_time
|
|
642
|
-
|
|
643
|
-
|
|
681
|
+
if query_def.target_ast.date_offset_range is not None:
|
|
682
|
+
anchor_time = anchor_time - (
|
|
683
|
+
query_def.target_ast.date_offset_range.end_date_offset *
|
|
684
|
+
query_def.num_forecasts)
|
|
644
685
|
|
|
645
686
|
assert anchor_time is not None
|
|
646
687
|
if isinstance(anchor_time, pd.Timestamp):
|
|
647
688
|
self._validate_time(query_def, anchor_time, None, evaluate=True)
|
|
648
689
|
else:
|
|
649
690
|
assert anchor_time == 'entity'
|
|
650
|
-
if (query_def.
|
|
651
|
-
not in self._graph_store.time_dict):
|
|
691
|
+
if (query_def.entity_table not in self._graph_store.time_dict):
|
|
652
692
|
raise ValueError(f"Anchor time 'entity' requires the entity "
|
|
653
|
-
f"table '{query_def.
|
|
693
|
+
f"table '{query_def.entity_table}' "
|
|
654
694
|
f"to have a time column")
|
|
655
695
|
|
|
656
696
|
query_driver = LocalPQueryDriver(self._graph_store, query_def,
|
|
@@ -665,7 +705,7 @@ class KumoRFM:
|
|
|
665
705
|
)
|
|
666
706
|
|
|
667
707
|
entity = self._graph_store.pkey_map_dict[
|
|
668
|
-
query_def.
|
|
708
|
+
query_def.entity_table].index[node]
|
|
669
709
|
|
|
670
710
|
return pd.DataFrame({
|
|
671
711
|
'ENTITY': entity,
|
|
@@ -675,8 +715,8 @@ class KumoRFM:
|
|
|
675
715
|
|
|
676
716
|
# Helpers #################################################################
|
|
677
717
|
|
|
678
|
-
def _parse_query(self, query: str) ->
|
|
679
|
-
if isinstance(query,
|
|
718
|
+
def _parse_query(self, query: str) -> ValidatedPredictiveQuery:
|
|
719
|
+
if isinstance(query, ValidatedPredictiveQuery):
|
|
680
720
|
return query
|
|
681
721
|
|
|
682
722
|
if isinstance(query, str) and query.strip()[:9].lower() == 'evaluate ':
|
|
@@ -686,12 +726,12 @@ class KumoRFM:
|
|
|
686
726
|
"predictions or evaluations.")
|
|
687
727
|
|
|
688
728
|
try:
|
|
689
|
-
request =
|
|
729
|
+
request = RFMParseQueryRequest(
|
|
690
730
|
query=query,
|
|
691
731
|
graph_definition=self._graph_def,
|
|
692
732
|
)
|
|
693
733
|
|
|
694
|
-
resp = global_state.client.rfm_api.
|
|
734
|
+
resp = global_state.client.rfm_api.parse_query(request)
|
|
695
735
|
# TODO Expose validation warnings.
|
|
696
736
|
|
|
697
737
|
if len(resp.validation_response.warnings) > 0:
|
|
@@ -702,7 +742,7 @@ class KumoRFM:
|
|
|
702
742
|
warnings.warn(f"Encountered the following warnings during "
|
|
703
743
|
f"parsing:\n{msg}")
|
|
704
744
|
|
|
705
|
-
return resp.
|
|
745
|
+
return resp.query
|
|
706
746
|
except HTTPException as e:
|
|
707
747
|
try:
|
|
708
748
|
msg = json.loads(e.detail)['detail']
|
|
@@ -713,7 +753,7 @@ class KumoRFM:
|
|
|
713
753
|
|
|
714
754
|
def _validate_time(
|
|
715
755
|
self,
|
|
716
|
-
query:
|
|
756
|
+
query: ValidatedPredictiveQuery,
|
|
717
757
|
anchor_time: pd.Timestamp,
|
|
718
758
|
context_anchor_time: Union[pd.Timestamp, None],
|
|
719
759
|
evaluate: bool,
|
|
@@ -736,6 +776,11 @@ class KumoRFM:
|
|
|
736
776
|
f"only contains data back to "
|
|
737
777
|
f"'{self._graph_store.min_time}'.")
|
|
738
778
|
|
|
779
|
+
if query.target_ast.date_offset_range is not None:
|
|
780
|
+
end_offset = query.target_ast.date_offset_range.end_date_offset
|
|
781
|
+
else:
|
|
782
|
+
end_offset = pd.DateOffset(0)
|
|
783
|
+
forecast_end_offset = end_offset * query.num_forecasts
|
|
739
784
|
if (context_anchor_time is not None
|
|
740
785
|
and context_anchor_time > anchor_time):
|
|
741
786
|
warnings.warn(f"Context anchor timestamp "
|
|
@@ -744,19 +789,18 @@ class KumoRFM:
|
|
|
744
789
|
f"(got '{anchor_time}'). Please make sure this is "
|
|
745
790
|
f"intended.")
|
|
746
791
|
elif (query.query_type == QueryType.TEMPORAL
|
|
747
|
-
and context_anchor_time is not None
|
|
748
|
-
|
|
792
|
+
and context_anchor_time is not None
|
|
793
|
+
and context_anchor_time + forecast_end_offset > anchor_time):
|
|
749
794
|
warnings.warn(f"Aggregation for context examples at timestamp "
|
|
750
795
|
f"'{context_anchor_time}' will leak information "
|
|
751
796
|
f"from the prediction anchor timestamp "
|
|
752
797
|
f"'{anchor_time}'. Please make sure this is "
|
|
753
798
|
f"intended.")
|
|
754
799
|
|
|
755
|
-
elif (context_anchor_time is not None
|
|
756
|
-
|
|
800
|
+
elif (context_anchor_time is not None
|
|
801
|
+
and context_anchor_time - forecast_end_offset
|
|
757
802
|
< self._graph_store.min_time):
|
|
758
|
-
_time = context_anchor_time -
|
|
759
|
-
query.num_forecasts)
|
|
803
|
+
_time = context_anchor_time - forecast_end_offset
|
|
760
804
|
warnings.warn(f"Context anchor timestamp is too early or "
|
|
761
805
|
f"aggregation time range is too large. To form "
|
|
762
806
|
f"proper input data, we would need data back to "
|
|
@@ -769,8 +813,7 @@ class KumoRFM:
|
|
|
769
813
|
f"latest timestamp '{self._graph_store.max_time}' "
|
|
770
814
|
f"in the data. Please make sure this is intended.")
|
|
771
815
|
|
|
772
|
-
max_eval_time =
|
|
773
|
-
query.target.end_offset * query.num_forecasts)
|
|
816
|
+
max_eval_time = self._graph_store.max_time - forecast_end_offset
|
|
774
817
|
if evaluate and anchor_time > max_eval_time:
|
|
775
818
|
raise ValueError(
|
|
776
819
|
f"Anchor timestamp for evaluation is after the latest "
|
|
@@ -778,7 +821,7 @@ class KumoRFM:
|
|
|
778
821
|
|
|
779
822
|
def _get_context(
|
|
780
823
|
self,
|
|
781
|
-
query:
|
|
824
|
+
query: ValidatedPredictiveQuery,
|
|
782
825
|
indices: Union[List[str], List[float], List[int], None],
|
|
783
826
|
anchor_time: Union[pd.Timestamp, Literal['entity'], None],
|
|
784
827
|
context_anchor_time: Union[pd.Timestamp, None],
|
|
@@ -806,8 +849,8 @@ class KumoRFM:
|
|
|
806
849
|
f"must go beyond this for your use-case.")
|
|
807
850
|
|
|
808
851
|
query_driver = LocalPQueryDriver(self._graph_store, query, random_seed)
|
|
809
|
-
task_type =
|
|
810
|
-
|
|
852
|
+
task_type = LocalPQueryDriver.get_task_type(
|
|
853
|
+
query,
|
|
811
854
|
edge_types=self._graph_store.edge_types,
|
|
812
855
|
)
|
|
813
856
|
|
|
@@ -839,11 +882,15 @@ class KumoRFM:
|
|
|
839
882
|
else:
|
|
840
883
|
num_neighbors = [64, 64, 8, 8, 4, 4][:num_hops]
|
|
841
884
|
|
|
885
|
+
if query.target_ast.date_offset_range is None:
|
|
886
|
+
end_offset = pd.DateOffset(0)
|
|
887
|
+
else:
|
|
888
|
+
end_offset = query.target_ast.date_offset_range.end_date_offset
|
|
889
|
+
forecast_end_offset = end_offset * query.num_forecasts
|
|
842
890
|
if anchor_time is None:
|
|
843
891
|
anchor_time = self._graph_store.max_time
|
|
844
892
|
if evaluate:
|
|
845
|
-
anchor_time = anchor_time -
|
|
846
|
-
query.num_forecasts)
|
|
893
|
+
anchor_time = anchor_time - forecast_end_offset
|
|
847
894
|
if logger is not None:
|
|
848
895
|
assert isinstance(anchor_time, pd.Timestamp)
|
|
849
896
|
if anchor_time == pd.Timestamp.min:
|
|
@@ -858,15 +905,14 @@ class KumoRFM:
|
|
|
858
905
|
assert anchor_time is not None
|
|
859
906
|
if isinstance(anchor_time, pd.Timestamp):
|
|
860
907
|
if context_anchor_time is None:
|
|
861
|
-
context_anchor_time = anchor_time -
|
|
862
|
-
query.num_forecasts)
|
|
908
|
+
context_anchor_time = anchor_time - forecast_end_offset
|
|
863
909
|
self._validate_time(query, anchor_time, context_anchor_time,
|
|
864
910
|
evaluate)
|
|
865
911
|
else:
|
|
866
912
|
assert anchor_time == 'entity'
|
|
867
|
-
if query.
|
|
913
|
+
if query.entity_table not in self._graph_store.time_dict:
|
|
868
914
|
raise ValueError(f"Anchor time 'entity' requires the entity "
|
|
869
|
-
f"table '{query.
|
|
915
|
+
f"table '{query.entity_table}' to "
|
|
870
916
|
f"have a time column")
|
|
871
917
|
if context_anchor_time is not None:
|
|
872
918
|
warnings.warn("Ignoring option 'context_anchor_time' for "
|
|
@@ -917,7 +963,7 @@ class KumoRFM:
|
|
|
917
963
|
f"in batches")
|
|
918
964
|
|
|
919
965
|
test_node = self._graph_store.get_node_id(
|
|
920
|
-
table_name=query.
|
|
966
|
+
table_name=query.entity_table,
|
|
921
967
|
pkey=pd.Series(indices),
|
|
922
968
|
)
|
|
923
969
|
|
|
@@ -925,8 +971,7 @@ class KumoRFM:
|
|
|
925
971
|
test_time = pd.Series(anchor_time).repeat(
|
|
926
972
|
len(test_node)).reset_index(drop=True)
|
|
927
973
|
else:
|
|
928
|
-
time = self._graph_store.time_dict[
|
|
929
|
-
query.entity.pkey.table_name]
|
|
974
|
+
time = self._graph_store.time_dict[query.entity_table]
|
|
930
975
|
time = time[test_node] * 1000**3
|
|
931
976
|
test_time = pd.Series(time, dtype='datetime64[ns]')
|
|
932
977
|
|
|
@@ -959,12 +1004,23 @@ class KumoRFM:
|
|
|
959
1004
|
raise NotImplementedError
|
|
960
1005
|
logger.log(msg)
|
|
961
1006
|
|
|
962
|
-
entity_table_names
|
|
963
|
-
|
|
1007
|
+
entity_table_names: Tuple[str, ...]
|
|
1008
|
+
if task_type.is_link_pred:
|
|
1009
|
+
final_aggr = query.get_final_target_aggregation()
|
|
1010
|
+
assert final_aggr is not None
|
|
1011
|
+
edge_fkey = final_aggr._get_target_column_name()
|
|
1012
|
+
for edge_type in self._graph_store.edge_types:
|
|
1013
|
+
if edge_fkey == f'{edge_type[0]}.{edge_type[1]}':
|
|
1014
|
+
entity_table_names = (
|
|
1015
|
+
query.entity_table,
|
|
1016
|
+
edge_type[2],
|
|
1017
|
+
)
|
|
1018
|
+
else:
|
|
1019
|
+
entity_table_names = (query.entity_table, )
|
|
964
1020
|
|
|
965
1021
|
# Exclude the entity anchor time from the feature set to prevent
|
|
966
1022
|
# running out-of-distribution between in-context and test examples:
|
|
967
|
-
exclude_cols_dict = query.
|
|
1023
|
+
exclude_cols_dict = query.get_exclude_cols_dict()
|
|
968
1024
|
if anchor_time == 'entity':
|
|
969
1025
|
if entity_table_names[0] not in exclude_cols_dict:
|
|
970
1026
|
exclude_cols_dict[entity_table_names[0]] = []
|
|
@@ -993,7 +1049,7 @@ class KumoRFM:
|
|
|
993
1049
|
|
|
994
1050
|
step_size: Optional[int] = None
|
|
995
1051
|
if query.query_type == QueryType.TEMPORAL:
|
|
996
|
-
step_size = date_offset_to_seconds(
|
|
1052
|
+
step_size = date_offset_to_seconds(end_offset)
|
|
997
1053
|
|
|
998
1054
|
return Context(
|
|
999
1055
|
task_type=task_type,
|
kumoai/utils/progress_logger.py
CHANGED
|
@@ -103,10 +103,13 @@ class InteractiveProgressLogger(ProgressLogger):
|
|
|
103
103
|
self._progress.update(self._task, advance=1) # type: ignore
|
|
104
104
|
|
|
105
105
|
def __enter__(self) -> Self:
|
|
106
|
+
from kumoai import in_notebook
|
|
107
|
+
|
|
106
108
|
super().__enter__()
|
|
107
109
|
|
|
108
|
-
|
|
109
|
-
|
|
110
|
+
if not in_notebook(): # Render progress bar in TUI.
|
|
111
|
+
sys.stdout.write("\x1b]9;4;3\x07")
|
|
112
|
+
sys.stdout.flush()
|
|
110
113
|
|
|
111
114
|
if self.verbose:
|
|
112
115
|
self._live = Live(
|
|
@@ -119,6 +122,8 @@ class InteractiveProgressLogger(ProgressLogger):
|
|
|
119
122
|
return self
|
|
120
123
|
|
|
121
124
|
def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
|
|
125
|
+
from kumoai import in_notebook
|
|
126
|
+
|
|
122
127
|
super().__exit__(exc_type, exc_val, exc_tb)
|
|
123
128
|
|
|
124
129
|
if exc_type is not None:
|
|
@@ -134,8 +139,9 @@ class InteractiveProgressLogger(ProgressLogger):
|
|
|
134
139
|
self._live.stop()
|
|
135
140
|
self._live = None
|
|
136
141
|
|
|
137
|
-
|
|
138
|
-
|
|
142
|
+
if not in_notebook():
|
|
143
|
+
sys.stdout.write("\x1b]9;4;0\x07")
|
|
144
|
+
sys.stdout.flush()
|
|
139
145
|
|
|
140
146
|
def __rich_console__(
|
|
141
147
|
self,
|
{kumoai-2.12.0.dev202511031731.dist-info → kumoai-2.12.0.dev202511111731.dist-info}/METADATA
RENAMED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: kumoai
|
|
3
|
-
Version: 2.12.0.
|
|
3
|
+
Version: 2.12.0.dev202511111731
|
|
4
4
|
Summary: AI on the Modern Data Stack
|
|
5
5
|
Author-email: "Kumo.AI" <hello@kumo.ai>
|
|
6
6
|
License-Expression: MIT
|
|
@@ -23,7 +23,7 @@ Requires-Dist: requests>=2.28.2
|
|
|
23
23
|
Requires-Dist: urllib3
|
|
24
24
|
Requires-Dist: plotly
|
|
25
25
|
Requires-Dist: typing_extensions>=4.5.0
|
|
26
|
-
Requires-Dist: kumo-api==0.
|
|
26
|
+
Requires-Dist: kumo-api==0.45.0
|
|
27
27
|
Requires-Dist: tqdm>=4.66.0
|
|
28
28
|
Requires-Dist: aiohttp>=3.10.0
|
|
29
29
|
Requires-Dist: pydantic>=1.10.21
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
kumoai/_logging.py,sha256=U2_5ROdyk92P4xO4H2WJV8EC7dr6YxmmnM-b7QX9M7I,886
|
|
2
2
|
kumoai/mixin.py,sha256=MP413xzuCqWhxAPUHmloLA3j4ZyF1tEtfi516b_hOXQ,812
|
|
3
|
-
kumoai/_version.py,sha256=
|
|
3
|
+
kumoai/_version.py,sha256=EmBJ4U0JvENPiq7lq8M80mpSdMDFEwNkBsjWDdzaLT4,39
|
|
4
4
|
kumoai/__init__.py,sha256=LU1zmKYc0KV5hy2VGKUuXgSvbJwj2rSRQ_R_bpHyl1o,10708
|
|
5
5
|
kumoai/formatting.py,sha256=jA_rLDCGKZI8WWCha-vtuLenVKTZvli99Tqpurz1H84,953
|
|
6
6
|
kumoai/futures.py,sha256=oJFIfdCM_3nWIqQteBKYMY4fPhoYlYWE_JA2o6tx-ng,3737
|
|
@@ -13,17 +13,15 @@ kumoai/_singleton.py,sha256=UTwrbDkoZSGB8ZelorvprPDDv9uZkUi1q_SrmsyngpQ,836
|
|
|
13
13
|
kumoai/experimental/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
14
14
|
kumoai/experimental/rfm/local_graph_sampler.py,sha256=o60_sdMa_fr60DrdmCIaE6lKQAD2msp1t-GGubFNt-o,6738
|
|
15
15
|
kumoai/experimental/rfm/local_graph.py,sha256=2iJDlsGVzqCe1bD_puXWlhwGkn7YnQyJ4p4C-fwCZNE,30076
|
|
16
|
-
kumoai/experimental/rfm/local_pquery_driver.py,sha256=
|
|
17
|
-
kumoai/experimental/rfm/__init__.py,sha256=
|
|
16
|
+
kumoai/experimental/rfm/local_pquery_driver.py,sha256=aO7Jfwx9gxGKYvpqxZx1LLWdI1MhuZQOPtAITxoOQO0,26162
|
|
17
|
+
kumoai/experimental/rfm/__init__.py,sha256=ornmi2x947jkQLptMn7ZLvTf2Sw-RMcVW73AnjVsWAo,1709
|
|
18
18
|
kumoai/experimental/rfm/utils.py,sha256=3IiBvT_aLBkkcJh3H11_50yt_XlEzHR0cm9Kprrtl8k,11123
|
|
19
19
|
kumoai/experimental/rfm/local_table.py,sha256=r8xZ33Mjs6JD8ud6h23tZ99Dag2DvZ4h6tWjmGrKQg4,19605
|
|
20
|
-
kumoai/experimental/rfm/rfm.py,sha256=
|
|
20
|
+
kumoai/experimental/rfm/rfm.py,sha256=V2NxxhrYi_MqLi_xcZsOYsdciT7V44iS5Fc9Ewq9eiM,48101
|
|
21
21
|
kumoai/experimental/rfm/local_graph_store.py,sha256=8BqonuaMftAAsjgZpB369i5AeNd1PkisMbbEqc0cKBo,13847
|
|
22
22
|
kumoai/experimental/rfm/authenticate.py,sha256=FiuHMvP7V3zBZUlHMDMbNLhc-UgDZgz4hjVSTuQ7DRw,18888
|
|
23
|
-
kumoai/experimental/rfm/pquery/
|
|
24
|
-
kumoai/experimental/rfm/pquery/
|
|
25
|
-
kumoai/experimental/rfm/pquery/pandas_backend.py,sha256=pgHCErSo6U-KJMhgIYijYt96uubtFB2WtsrTdLU7NYc,15396
|
|
26
|
-
kumoai/experimental/rfm/pquery/pandas_executor.py,sha256=BgF3saosisgLHx1RyLj-HSEbMp4xLatNuARdKWwiiLY,17326
|
|
23
|
+
kumoai/experimental/rfm/pquery/__init__.py,sha256=X0O3EIq5SMfBEE-ii5Cq6iDhR3s3XMXB52Cx5htoePw,152
|
|
24
|
+
kumoai/experimental/rfm/pquery/pandas_executor.py,sha256=kiBJq7uVGbasG7TiqsubEl6ey3UYzZiM4bwxILqp_54,18487
|
|
27
25
|
kumoai/experimental/rfm/pquery/executor.py,sha256=f7-pJhL0BgFU9E4o4gQpQyArOvyrZtwxFmks34-QOAE,2741
|
|
28
26
|
kumoai/experimental/rfm/infer/multicategorical.py,sha256=0-cLpDnGryhr76QhZNO-klKokJ6MUSfxXcGdQ61oykY,1102
|
|
29
27
|
kumoai/experimental/rfm/infer/categorical.py,sha256=VwNaKwKbRYkTxEJ1R6gziffC8dGsEThcDEfbi-KqW5c,853
|
|
@@ -40,7 +38,7 @@ kumoai/artifact_export/job.py,sha256=GEisSwvcjK_35RgOfsLXGgxMTXIWm765B_BW_Kgs-V0
|
|
|
40
38
|
kumoai/artifact_export/__init__.py,sha256=BsfDrc3mCHpO9-BqvqKm8qrXDIwfdaoH5UIoG4eQkc4,238
|
|
41
39
|
kumoai/utils/datasets.py,sha256=ptKIUoBONVD55pTVNdRCkQT3NWdN_r9UAUu4xewPa3U,2928
|
|
42
40
|
kumoai/utils/__init__.py,sha256=wGDC_31XJ-7ipm6eawjLAJaP4EfmtNOH8BHzaetQ9Ko,268
|
|
43
|
-
kumoai/utils/progress_logger.py,sha256=
|
|
41
|
+
kumoai/utils/progress_logger.py,sha256=pngEGzMHkiOUKOa6fbzxCEc2xlA4SJKV4TDTVVoqObM,5062
|
|
44
42
|
kumoai/utils/forecasting.py,sha256=-nDS6ucKNfQhTQOfebjefj0wwWH3-KYNslIomxwwMBM,7415
|
|
45
43
|
kumoai/codegen/generate.py,sha256=SvfWWa71xSAOjH9645yQvgoEM-o4BYjupM_EpUxqB_E,7331
|
|
46
44
|
kumoai/codegen/naming.py,sha256=_XVQGxHfuub4bhvyuBKjltD5Lm_oPpibvP_LZteCGk0,3021
|
|
@@ -84,8 +82,8 @@ kumoai/client/jobs.py,sha256=iu_Wrta6BQMlV6ZtzSnmhjwNPKDMQDXOsqVVIyWodqw,17074
|
|
|
84
82
|
kumoai/client/utils.py,sha256=lz1NubwMDHCwzQRowRXm7mjAoYRd5UjRQIwXdtWAl90,3849
|
|
85
83
|
kumoai/client/connector.py,sha256=x3i2aBTJTEMZvYRcWkY-UfWVOANZjqAso4GBbcshFjw,3920
|
|
86
84
|
kumoai/client/table.py,sha256=cQG-RPm-e91idEgse1IPJDvBmzddIDGDkuyrR1rq4wU,3235
|
|
87
|
-
kumoai/client/rfm.py,sha256=
|
|
88
|
-
kumoai/client/endpoints.py,sha256=
|
|
85
|
+
kumoai/client/rfm.py,sha256=NxKk8mH2A-B58rSXhDWaph4KeiSyJYDq-RO-vAHh7es,3726
|
|
86
|
+
kumoai/client/endpoints.py,sha256=iF2ZD25AJCIVbmBJ8tTZ8y1Ch0m6nTp18ydN7h4WiTk,5382
|
|
89
87
|
kumoai/trainer/config.py,sha256=-2RfK10AsVVThSyfWtlyfH4Fc4EwTdu0V3yrDRtIOjk,98
|
|
90
88
|
kumoai/trainer/util.py,sha256=bDPGkMF9KOy4HgtA-OwhXP17z9cbrfMnZGtyGuUq_Eo,4062
|
|
91
89
|
kumoai/trainer/job.py,sha256=Wk69nzFhbvuA3nEvtCstI04z5CxkgvQ6tHnGchE0Lkg,44938
|
|
@@ -93,8 +91,8 @@ kumoai/trainer/baseline_trainer.py,sha256=LlfViNOmswNv4c6zJJLsyv0pC2mM2WKMGYx06o
|
|
|
93
91
|
kumoai/trainer/__init__.py,sha256=zUdFl-f-sBWmm2x8R-rdVzPBeU2FaMzUY5mkcgoTa1k,939
|
|
94
92
|
kumoai/trainer/online_serving.py,sha256=9cddb5paeZaCgbUeceQdAOxysCtV5XP-KcsgFz_XR5w,9566
|
|
95
93
|
kumoai/trainer/trainer.py,sha256=hBXO7gwpo3t59zKFTeIkK65B8QRmWCwO33sbDuEAPlY,20133
|
|
96
|
-
kumoai-2.12.0.
|
|
97
|
-
kumoai-2.12.0.
|
|
98
|
-
kumoai-2.12.0.
|
|
99
|
-
kumoai-2.12.0.
|
|
100
|
-
kumoai-2.12.0.
|
|
94
|
+
kumoai-2.12.0.dev202511111731.dist-info/RECORD,,
|
|
95
|
+
kumoai-2.12.0.dev202511111731.dist-info/WHEEL,sha256=sunMa2yiYbrNLGeMVDqEA0ayyJbHlex7SCn1TZrEq60,136
|
|
96
|
+
kumoai-2.12.0.dev202511111731.dist-info/top_level.txt,sha256=YjU6UcmomoDx30vEXLsOU784ED7VztQOsFApk1SFwvs,7
|
|
97
|
+
kumoai-2.12.0.dev202511111731.dist-info/METADATA,sha256=sNoIEIZxJx58O-0mQyfBmpsnrkAzg3ZVQhucsvlDX64,2052
|
|
98
|
+
kumoai-2.12.0.dev202511111731.dist-info/licenses/LICENSE,sha256=TbWlyqRmhq9PEzCaTI0H0nWLQCCOywQM8wYH8MbjfLo,1102
|