kumoai 2.11.0.dev202510151831__cp312-cp312-macosx_11_0_arm64.whl → 2.13.0.dev202511161731__cp312-cp312-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kumoai/__init__.py +4 -2
- kumoai/_version.py +1 -1
- kumoai/client/endpoints.py +1 -0
- kumoai/client/rfm.py +35 -7
- kumoai/experimental/rfm/__init__.py +3 -1
- kumoai/experimental/rfm/local_pquery_driver.py +221 -26
- kumoai/experimental/rfm/pquery/__init__.py +4 -4
- kumoai/experimental/rfm/pquery/{backend.py → executor.py} +24 -58
- kumoai/experimental/rfm/pquery/{pandas_backend.py → pandas_executor.py} +277 -223
- kumoai/experimental/rfm/rfm.py +140 -72
- kumoai/trainer/trainer.py +9 -10
- kumoai/utils/progress_logger.py +10 -4
- {kumoai-2.11.0.dev202510151831.dist-info → kumoai-2.13.0.dev202511161731.dist-info}/METADATA +2 -2
- {kumoai-2.11.0.dev202510151831.dist-info → kumoai-2.13.0.dev202511161731.dist-info}/RECORD +17 -17
- {kumoai-2.11.0.dev202510151831.dist-info → kumoai-2.13.0.dev202511161731.dist-info}/WHEEL +0 -0
- {kumoai-2.11.0.dev202510151831.dist-info → kumoai-2.13.0.dev202511161731.dist-info}/licenses/LICENSE +0 -0
- {kumoai-2.11.0.dev202510151831.dist-info → kumoai-2.13.0.dev202511161731.dist-info}/top_level.txt +0 -0
kumoai/experimental/rfm/rfm.py
CHANGED
|
@@ -5,19 +5,28 @@ from collections import defaultdict
|
|
|
5
5
|
from collections.abc import Generator
|
|
6
6
|
from contextlib import contextmanager
|
|
7
7
|
from dataclasses import dataclass, replace
|
|
8
|
-
from typing import
|
|
8
|
+
from typing import (
|
|
9
|
+
Any,
|
|
10
|
+
Dict,
|
|
11
|
+
Iterator,
|
|
12
|
+
List,
|
|
13
|
+
Literal,
|
|
14
|
+
Optional,
|
|
15
|
+
Tuple,
|
|
16
|
+
Union,
|
|
17
|
+
overload,
|
|
18
|
+
)
|
|
9
19
|
|
|
10
20
|
import numpy as np
|
|
11
21
|
import pandas as pd
|
|
12
22
|
from kumoapi.model_plan import RunMode
|
|
13
|
-
from kumoapi.pquery import QueryType
|
|
23
|
+
from kumoapi.pquery import QueryType, ValidatedPredictiveQuery
|
|
14
24
|
from kumoapi.rfm import Context
|
|
15
25
|
from kumoapi.rfm import Explanation as ExplanationConfig
|
|
16
26
|
from kumoapi.rfm import (
|
|
17
|
-
PQueryDefinition,
|
|
18
27
|
RFMEvaluateRequest,
|
|
28
|
+
RFMParseQueryRequest,
|
|
19
29
|
RFMPredictRequest,
|
|
20
|
-
RFMValidateQueryRequest,
|
|
21
30
|
)
|
|
22
31
|
from kumoapi.task import TaskType
|
|
23
32
|
|
|
@@ -30,6 +39,7 @@ from kumoai.experimental.rfm.local_pquery_driver import (
|
|
|
30
39
|
LocalPQueryDriver,
|
|
31
40
|
date_offset_to_seconds,
|
|
32
41
|
)
|
|
42
|
+
from kumoai.mixin import CastMixin
|
|
33
43
|
from kumoai.utils import InteractiveProgressLogger, ProgressLogger
|
|
34
44
|
|
|
35
45
|
_RANDOM_SEED = 42
|
|
@@ -60,6 +70,17 @@ _SIZE_LIMIT_MSG = ("Context size exceeds the 30MB limit. {stats}\nPlease "
|
|
|
60
70
|
"beyond this for your use-case.")
|
|
61
71
|
|
|
62
72
|
|
|
73
|
+
@dataclass(repr=False)
|
|
74
|
+
class ExplainConfig(CastMixin):
|
|
75
|
+
"""Configuration for explainability.
|
|
76
|
+
|
|
77
|
+
Args:
|
|
78
|
+
skip_summary: Whether to skip generating a human-readable summary of
|
|
79
|
+
the explanation.
|
|
80
|
+
"""
|
|
81
|
+
skip_summary: bool = False
|
|
82
|
+
|
|
83
|
+
|
|
63
84
|
@dataclass(repr=False)
|
|
64
85
|
class Explanation:
|
|
65
86
|
prediction: pd.DataFrame
|
|
@@ -87,6 +108,12 @@ class Explanation:
|
|
|
87
108
|
def __repr__(self) -> str:
|
|
88
109
|
return str((self.prediction, self.summary))
|
|
89
110
|
|
|
111
|
+
def _ipython_display_(self) -> None:
|
|
112
|
+
from IPython.display import Markdown, display
|
|
113
|
+
|
|
114
|
+
display(self.prediction)
|
|
115
|
+
display(Markdown(self.summary))
|
|
116
|
+
|
|
90
117
|
|
|
91
118
|
class KumoRFM:
|
|
92
119
|
r"""The Kumo Relational Foundation model (RFM) from the `KumoRFM: A
|
|
@@ -199,6 +226,7 @@ class KumoRFM:
|
|
|
199
226
|
max_pq_iterations: int = 20,
|
|
200
227
|
random_seed: Optional[int] = _RANDOM_SEED,
|
|
201
228
|
verbose: Union[bool, ProgressLogger] = True,
|
|
229
|
+
use_prediction_time: bool = False,
|
|
202
230
|
) -> pd.DataFrame:
|
|
203
231
|
pass
|
|
204
232
|
|
|
@@ -208,7 +236,7 @@ class KumoRFM:
|
|
|
208
236
|
query: str,
|
|
209
237
|
indices: Union[List[str], List[float], List[int], None] = None,
|
|
210
238
|
*,
|
|
211
|
-
explain: Literal[True],
|
|
239
|
+
explain: Union[Literal[True], ExplainConfig, Dict[str, Any]],
|
|
212
240
|
anchor_time: Union[pd.Timestamp, Literal['entity'], None] = None,
|
|
213
241
|
context_anchor_time: Union[pd.Timestamp, None] = None,
|
|
214
242
|
run_mode: Union[RunMode, str] = RunMode.FAST,
|
|
@@ -217,6 +245,7 @@ class KumoRFM:
|
|
|
217
245
|
max_pq_iterations: int = 20,
|
|
218
246
|
random_seed: Optional[int] = _RANDOM_SEED,
|
|
219
247
|
verbose: Union[bool, ProgressLogger] = True,
|
|
248
|
+
use_prediction_time: bool = False,
|
|
220
249
|
) -> Explanation:
|
|
221
250
|
pass
|
|
222
251
|
|
|
@@ -225,7 +254,7 @@ class KumoRFM:
|
|
|
225
254
|
query: str,
|
|
226
255
|
indices: Union[List[str], List[float], List[int], None] = None,
|
|
227
256
|
*,
|
|
228
|
-
explain: bool = False,
|
|
257
|
+
explain: Union[bool, ExplainConfig, Dict[str, Any]] = False,
|
|
229
258
|
anchor_time: Union[pd.Timestamp, Literal['entity'], None] = None,
|
|
230
259
|
context_anchor_time: Union[pd.Timestamp, None] = None,
|
|
231
260
|
run_mode: Union[RunMode, str] = RunMode.FAST,
|
|
@@ -234,6 +263,7 @@ class KumoRFM:
|
|
|
234
263
|
max_pq_iterations: int = 20,
|
|
235
264
|
random_seed: Optional[int] = _RANDOM_SEED,
|
|
236
265
|
verbose: Union[bool, ProgressLogger] = True,
|
|
266
|
+
use_prediction_time: bool = False,
|
|
237
267
|
) -> Union[pd.DataFrame, Explanation]:
|
|
238
268
|
"""Returns predictions for a predictive query.
|
|
239
269
|
|
|
@@ -244,9 +274,12 @@ class KumoRFM:
|
|
|
244
274
|
be generated for all indices, independent of whether they
|
|
245
275
|
fulfill entity filter constraints. To pre-filter entities, use
|
|
246
276
|
:meth:`~KumoRFM.is_valid_entity`.
|
|
247
|
-
explain:
|
|
248
|
-
|
|
249
|
-
|
|
277
|
+
explain: Configuration for explainability.
|
|
278
|
+
If set to ``True``, will additionally explain the prediction.
|
|
279
|
+
Passing in an :class:`ExplainConfig` instance provides control
|
|
280
|
+
over which parts of explanation are generated.
|
|
281
|
+
Explainability is currently only supported for single entity
|
|
282
|
+
predictions with ``run_mode="FAST"``.
|
|
250
283
|
anchor_time: The anchor timestamp for the prediction. If set to
|
|
251
284
|
``None``, will use the maximum timestamp in the data.
|
|
252
285
|
If set to ``"entity"``, will use the timestamp of the entity.
|
|
@@ -264,45 +297,54 @@ class KumoRFM:
|
|
|
264
297
|
entities to find valid labels.
|
|
265
298
|
random_seed: A manual seed for generating pseudo-random numbers.
|
|
266
299
|
verbose: Whether to print verbose output.
|
|
300
|
+
use_prediction_time: Whether to use the anchor timestamp as an
|
|
301
|
+
additional feature during prediction. This is typically
|
|
302
|
+
beneficial for time series forecasting tasks.
|
|
267
303
|
|
|
268
304
|
Returns:
|
|
269
305
|
The predictions as a :class:`pandas.DataFrame`.
|
|
270
|
-
If ``explain
|
|
271
|
-
|
|
306
|
+
If ``explain`` is provided, returns an :class:`Explanation` object
|
|
307
|
+
containing the prediction, summary, and details.
|
|
272
308
|
"""
|
|
309
|
+
explain_config: Optional[ExplainConfig] = None
|
|
310
|
+
if explain is True:
|
|
311
|
+
explain_config = ExplainConfig()
|
|
312
|
+
elif explain is not False:
|
|
313
|
+
explain_config = ExplainConfig._cast(explain)
|
|
314
|
+
|
|
273
315
|
query_def = self._parse_query(query)
|
|
316
|
+
query_str = query_def.to_string()
|
|
274
317
|
|
|
275
318
|
if num_hops != 2 and num_neighbors is not None:
|
|
276
319
|
warnings.warn(f"Received custom 'num_neighbors' option; ignoring "
|
|
277
320
|
f"custom 'num_hops={num_hops}' option")
|
|
278
321
|
|
|
279
|
-
if
|
|
322
|
+
if explain_config is not None and run_mode in {
|
|
323
|
+
RunMode.NORMAL, RunMode.BEST
|
|
324
|
+
}:
|
|
280
325
|
warnings.warn(f"Explainability is currently only supported for "
|
|
281
326
|
f"run mode 'FAST' (got '{run_mode}'). Provided run "
|
|
282
327
|
f"mode has been reset. Please lower the run mode to "
|
|
283
328
|
f"suppress this warning.")
|
|
284
329
|
|
|
285
330
|
if indices is None:
|
|
286
|
-
if query_def.
|
|
331
|
+
if query_def.rfm_entity_ids is None:
|
|
287
332
|
raise ValueError("Cannot find entities to predict for. Please "
|
|
288
333
|
"pass them via `predict(query, indices=...)`")
|
|
289
|
-
indices = query_def.
|
|
334
|
+
indices = query_def.get_rfm_entity_id_list()
|
|
290
335
|
else:
|
|
291
|
-
query_def = replace(
|
|
292
|
-
query_def,
|
|
293
|
-
entity=replace(query_def.entity, ids=None),
|
|
294
|
-
)
|
|
336
|
+
query_def = replace(query_def, rfm_entity_ids=None)
|
|
295
337
|
|
|
296
338
|
if len(indices) == 0:
|
|
297
339
|
raise ValueError("At least one entity is required")
|
|
298
340
|
|
|
299
|
-
if
|
|
341
|
+
if explain_config is not None and len(indices) > 1:
|
|
300
342
|
raise ValueError(
|
|
301
343
|
f"Cannot explain predictions for more than a single entity "
|
|
302
344
|
f"(got {len(indices)})")
|
|
303
345
|
|
|
304
346
|
query_repr = query_def.to_string(rich=True, exclude_predict=True)
|
|
305
|
-
if
|
|
347
|
+
if explain_config is not None:
|
|
306
348
|
msg = f'[bold]EXPLAIN[/bold] {query_repr}'
|
|
307
349
|
else:
|
|
308
350
|
msg = f'[bold]PREDICT[/bold] {query_repr}'
|
|
@@ -314,8 +356,8 @@ class KumoRFM:
|
|
|
314
356
|
|
|
315
357
|
batch_size: Optional[int] = None
|
|
316
358
|
if self._batch_size == 'max':
|
|
317
|
-
task_type =
|
|
318
|
-
|
|
359
|
+
task_type = LocalPQueryDriver.get_task_type(
|
|
360
|
+
query_def,
|
|
319
361
|
edge_types=self._graph_store.edge_types,
|
|
320
362
|
)
|
|
321
363
|
batch_size = _MAX_PRED_SIZE[task_type]
|
|
@@ -353,6 +395,8 @@ class KumoRFM:
|
|
|
353
395
|
request = RFMPredictRequest(
|
|
354
396
|
context=context,
|
|
355
397
|
run_mode=RunMode(run_mode),
|
|
398
|
+
query=query_str,
|
|
399
|
+
use_prediction_time=use_prediction_time,
|
|
356
400
|
)
|
|
357
401
|
with warnings.catch_warnings():
|
|
358
402
|
warnings.filterwarnings('ignore', message='gencode')
|
|
@@ -375,8 +419,11 @@ class KumoRFM:
|
|
|
375
419
|
|
|
376
420
|
for attempt in range(self.num_retries + 1):
|
|
377
421
|
try:
|
|
378
|
-
if
|
|
379
|
-
resp = global_state.client.rfm_api.explain(
|
|
422
|
+
if explain_config is not None:
|
|
423
|
+
resp = global_state.client.rfm_api.explain(
|
|
424
|
+
request=_bytes,
|
|
425
|
+
skip_summary=explain_config.skip_summary,
|
|
426
|
+
)
|
|
380
427
|
summary = resp.summary
|
|
381
428
|
details = resp.details
|
|
382
429
|
else:
|
|
@@ -385,7 +432,7 @@ class KumoRFM:
|
|
|
385
432
|
|
|
386
433
|
# Cast 'ENTITY' to correct data type:
|
|
387
434
|
if 'ENTITY' in df:
|
|
388
|
-
entity = query_def.
|
|
435
|
+
entity = query_def.entity_table
|
|
389
436
|
pkey_map = self._graph_store.pkey_map_dict[entity]
|
|
390
437
|
df['ENTITY'] = df['ENTITY'].astype(
|
|
391
438
|
type(pkey_map.index[0]))
|
|
@@ -427,7 +474,7 @@ class KumoRFM:
|
|
|
427
474
|
else:
|
|
428
475
|
prediction = pd.concat(predictions, ignore_index=True)
|
|
429
476
|
|
|
430
|
-
if
|
|
477
|
+
if explain_config is not None:
|
|
431
478
|
assert len(predictions) == 1
|
|
432
479
|
assert summary is not None
|
|
433
480
|
assert details is not None
|
|
@@ -461,11 +508,11 @@ class KumoRFM:
|
|
|
461
508
|
query_def = self._parse_query(query)
|
|
462
509
|
|
|
463
510
|
if indices is None:
|
|
464
|
-
if query_def.
|
|
511
|
+
if query_def.rfm_entity_ids is None:
|
|
465
512
|
raise ValueError("Cannot find entities to predict for. Please "
|
|
466
513
|
"pass them via "
|
|
467
514
|
"`is_valid_entity(query, indices=...)`")
|
|
468
|
-
indices = query_def.
|
|
515
|
+
indices = query_def.get_rfm_entity_id_list()
|
|
469
516
|
|
|
470
517
|
if len(indices) == 0:
|
|
471
518
|
raise ValueError("At least one entity is required")
|
|
@@ -477,14 +524,13 @@ class KumoRFM:
|
|
|
477
524
|
self._validate_time(query_def, anchor_time, None, False)
|
|
478
525
|
else:
|
|
479
526
|
assert anchor_time == 'entity'
|
|
480
|
-
if (query_def.
|
|
481
|
-
not in self._graph_store.time_dict):
|
|
527
|
+
if (query_def.entity_table not in self._graph_store.time_dict):
|
|
482
528
|
raise ValueError(f"Anchor time 'entity' requires the entity "
|
|
483
|
-
f"table '{query_def.
|
|
484
|
-
f"to have a time column")
|
|
529
|
+
f"table '{query_def.entity_table}' "
|
|
530
|
+
f"to have a time column.")
|
|
485
531
|
|
|
486
532
|
node = self._graph_store.get_node_id(
|
|
487
|
-
table_name=query_def.
|
|
533
|
+
table_name=query_def.entity_table,
|
|
488
534
|
pkey=pd.Series(indices),
|
|
489
535
|
)
|
|
490
536
|
query_driver = LocalPQueryDriver(self._graph_store, query_def)
|
|
@@ -503,6 +549,7 @@ class KumoRFM:
|
|
|
503
549
|
max_pq_iterations: int = 20,
|
|
504
550
|
random_seed: Optional[int] = _RANDOM_SEED,
|
|
505
551
|
verbose: Union[bool, ProgressLogger] = True,
|
|
552
|
+
use_prediction_time: bool = False,
|
|
506
553
|
) -> pd.DataFrame:
|
|
507
554
|
"""Evaluates a predictive query.
|
|
508
555
|
|
|
@@ -526,6 +573,9 @@ class KumoRFM:
|
|
|
526
573
|
entities to find valid labels.
|
|
527
574
|
random_seed: A manual seed for generating pseudo-random numbers.
|
|
528
575
|
verbose: Whether to print verbose output.
|
|
576
|
+
use_prediction_time: Whether to use the anchor timestamp as an
|
|
577
|
+
additional feature during prediction. This is typically
|
|
578
|
+
beneficial for time series forecasting tasks.
|
|
529
579
|
|
|
530
580
|
Returns:
|
|
531
581
|
The metrics as a :class:`pandas.DataFrame`
|
|
@@ -536,10 +586,10 @@ class KumoRFM:
|
|
|
536
586
|
warnings.warn(f"Received custom 'num_neighbors' option; ignoring "
|
|
537
587
|
f"custom 'num_hops={num_hops}' option")
|
|
538
588
|
|
|
539
|
-
if query_def.
|
|
589
|
+
if query_def.rfm_entity_ids is not None:
|
|
540
590
|
query_def = replace(
|
|
541
591
|
query_def,
|
|
542
|
-
|
|
592
|
+
rfm_entity_ids=None,
|
|
543
593
|
)
|
|
544
594
|
|
|
545
595
|
query_repr = query_def.to_string(rich=True, exclude_predict=True)
|
|
@@ -569,6 +619,7 @@ class KumoRFM:
|
|
|
569
619
|
context=context,
|
|
570
620
|
run_mode=RunMode(run_mode),
|
|
571
621
|
metrics=metrics,
|
|
622
|
+
use_prediction_time=use_prediction_time,
|
|
572
623
|
)
|
|
573
624
|
with warnings.catch_warnings():
|
|
574
625
|
warnings.filterwarnings('ignore', message='Protobuf gencode')
|
|
@@ -579,7 +630,7 @@ class KumoRFM:
|
|
|
579
630
|
|
|
580
631
|
if len(request_bytes) > _MAX_SIZE:
|
|
581
632
|
stats_msg = Context.get_memory_stats(request_msg.context)
|
|
582
|
-
raise ValueError(_SIZE_LIMIT_MSG.format(
|
|
633
|
+
raise ValueError(_SIZE_LIMIT_MSG.format(stats=stats_msg))
|
|
583
634
|
|
|
584
635
|
try:
|
|
585
636
|
resp = global_state.client.rfm_api.evaluate(request_bytes)
|
|
@@ -627,18 +678,19 @@ class KumoRFM:
|
|
|
627
678
|
|
|
628
679
|
if anchor_time is None:
|
|
629
680
|
anchor_time = self._graph_store.max_time
|
|
630
|
-
|
|
631
|
-
|
|
681
|
+
if query_def.target_ast.date_offset_range is not None:
|
|
682
|
+
anchor_time = anchor_time - (
|
|
683
|
+
query_def.target_ast.date_offset_range.end_date_offset *
|
|
684
|
+
query_def.num_forecasts)
|
|
632
685
|
|
|
633
686
|
assert anchor_time is not None
|
|
634
687
|
if isinstance(anchor_time, pd.Timestamp):
|
|
635
688
|
self._validate_time(query_def, anchor_time, None, evaluate=True)
|
|
636
689
|
else:
|
|
637
690
|
assert anchor_time == 'entity'
|
|
638
|
-
if (query_def.
|
|
639
|
-
not in self._graph_store.time_dict):
|
|
691
|
+
if (query_def.entity_table not in self._graph_store.time_dict):
|
|
640
692
|
raise ValueError(f"Anchor time 'entity' requires the entity "
|
|
641
|
-
f"table '{query_def.
|
|
693
|
+
f"table '{query_def.entity_table}' "
|
|
642
694
|
f"to have a time column")
|
|
643
695
|
|
|
644
696
|
query_driver = LocalPQueryDriver(self._graph_store, query_def,
|
|
@@ -653,7 +705,7 @@ class KumoRFM:
|
|
|
653
705
|
)
|
|
654
706
|
|
|
655
707
|
entity = self._graph_store.pkey_map_dict[
|
|
656
|
-
query_def.
|
|
708
|
+
query_def.entity_table].index[node]
|
|
657
709
|
|
|
658
710
|
return pd.DataFrame({
|
|
659
711
|
'ENTITY': entity,
|
|
@@ -663,8 +715,8 @@ class KumoRFM:
|
|
|
663
715
|
|
|
664
716
|
# Helpers #################################################################
|
|
665
717
|
|
|
666
|
-
def _parse_query(self, query: str) ->
|
|
667
|
-
if isinstance(query,
|
|
718
|
+
def _parse_query(self, query: str) -> ValidatedPredictiveQuery:
|
|
719
|
+
if isinstance(query, ValidatedPredictiveQuery):
|
|
668
720
|
return query
|
|
669
721
|
|
|
670
722
|
if isinstance(query, str) and query.strip()[:9].lower() == 'evaluate ':
|
|
@@ -674,12 +726,12 @@ class KumoRFM:
|
|
|
674
726
|
"predictions or evaluations.")
|
|
675
727
|
|
|
676
728
|
try:
|
|
677
|
-
request =
|
|
729
|
+
request = RFMParseQueryRequest(
|
|
678
730
|
query=query,
|
|
679
731
|
graph_definition=self._graph_def,
|
|
680
732
|
)
|
|
681
733
|
|
|
682
|
-
resp = global_state.client.rfm_api.
|
|
734
|
+
resp = global_state.client.rfm_api.parse_query(request)
|
|
683
735
|
# TODO Expose validation warnings.
|
|
684
736
|
|
|
685
737
|
if len(resp.validation_response.warnings) > 0:
|
|
@@ -690,7 +742,7 @@ class KumoRFM:
|
|
|
690
742
|
warnings.warn(f"Encountered the following warnings during "
|
|
691
743
|
f"parsing:\n{msg}")
|
|
692
744
|
|
|
693
|
-
return resp.
|
|
745
|
+
return resp.query
|
|
694
746
|
except HTTPException as e:
|
|
695
747
|
try:
|
|
696
748
|
msg = json.loads(e.detail)['detail']
|
|
@@ -701,7 +753,7 @@ class KumoRFM:
|
|
|
701
753
|
|
|
702
754
|
def _validate_time(
|
|
703
755
|
self,
|
|
704
|
-
query:
|
|
756
|
+
query: ValidatedPredictiveQuery,
|
|
705
757
|
anchor_time: pd.Timestamp,
|
|
706
758
|
context_anchor_time: Union[pd.Timestamp, None],
|
|
707
759
|
evaluate: bool,
|
|
@@ -724,6 +776,11 @@ class KumoRFM:
|
|
|
724
776
|
f"only contains data back to "
|
|
725
777
|
f"'{self._graph_store.min_time}'.")
|
|
726
778
|
|
|
779
|
+
if query.target_ast.date_offset_range is not None:
|
|
780
|
+
end_offset = query.target_ast.date_offset_range.end_date_offset
|
|
781
|
+
else:
|
|
782
|
+
end_offset = pd.DateOffset(0)
|
|
783
|
+
forecast_end_offset = end_offset * query.num_forecasts
|
|
727
784
|
if (context_anchor_time is not None
|
|
728
785
|
and context_anchor_time > anchor_time):
|
|
729
786
|
warnings.warn(f"Context anchor timestamp "
|
|
@@ -732,19 +789,18 @@ class KumoRFM:
|
|
|
732
789
|
f"(got '{anchor_time}'). Please make sure this is "
|
|
733
790
|
f"intended.")
|
|
734
791
|
elif (query.query_type == QueryType.TEMPORAL
|
|
735
|
-
and context_anchor_time is not None
|
|
736
|
-
|
|
792
|
+
and context_anchor_time is not None
|
|
793
|
+
and context_anchor_time + forecast_end_offset > anchor_time):
|
|
737
794
|
warnings.warn(f"Aggregation for context examples at timestamp "
|
|
738
795
|
f"'{context_anchor_time}' will leak information "
|
|
739
796
|
f"from the prediction anchor timestamp "
|
|
740
797
|
f"'{anchor_time}'. Please make sure this is "
|
|
741
798
|
f"intended.")
|
|
742
799
|
|
|
743
|
-
elif (context_anchor_time is not None
|
|
744
|
-
|
|
800
|
+
elif (context_anchor_time is not None
|
|
801
|
+
and context_anchor_time - forecast_end_offset
|
|
745
802
|
< self._graph_store.min_time):
|
|
746
|
-
_time = context_anchor_time -
|
|
747
|
-
query.num_forecasts)
|
|
803
|
+
_time = context_anchor_time - forecast_end_offset
|
|
748
804
|
warnings.warn(f"Context anchor timestamp is too early or "
|
|
749
805
|
f"aggregation time range is too large. To form "
|
|
750
806
|
f"proper input data, we would need data back to "
|
|
@@ -757,8 +813,7 @@ class KumoRFM:
|
|
|
757
813
|
f"latest timestamp '{self._graph_store.max_time}' "
|
|
758
814
|
f"in the data. Please make sure this is intended.")
|
|
759
815
|
|
|
760
|
-
max_eval_time =
|
|
761
|
-
query.target.end_offset * query.num_forecasts)
|
|
816
|
+
max_eval_time = self._graph_store.max_time - forecast_end_offset
|
|
762
817
|
if evaluate and anchor_time > max_eval_time:
|
|
763
818
|
raise ValueError(
|
|
764
819
|
f"Anchor timestamp for evaluation is after the latest "
|
|
@@ -766,7 +821,7 @@ class KumoRFM:
|
|
|
766
821
|
|
|
767
822
|
def _get_context(
|
|
768
823
|
self,
|
|
769
|
-
query:
|
|
824
|
+
query: ValidatedPredictiveQuery,
|
|
770
825
|
indices: Union[List[str], List[float], List[int], None],
|
|
771
826
|
anchor_time: Union[pd.Timestamp, Literal['entity'], None],
|
|
772
827
|
context_anchor_time: Union[pd.Timestamp, None],
|
|
@@ -794,8 +849,8 @@ class KumoRFM:
|
|
|
794
849
|
f"must go beyond this for your use-case.")
|
|
795
850
|
|
|
796
851
|
query_driver = LocalPQueryDriver(self._graph_store, query, random_seed)
|
|
797
|
-
task_type =
|
|
798
|
-
|
|
852
|
+
task_type = LocalPQueryDriver.get_task_type(
|
|
853
|
+
query,
|
|
799
854
|
edge_types=self._graph_store.edge_types,
|
|
800
855
|
)
|
|
801
856
|
|
|
@@ -827,11 +882,15 @@ class KumoRFM:
|
|
|
827
882
|
else:
|
|
828
883
|
num_neighbors = [64, 64, 8, 8, 4, 4][:num_hops]
|
|
829
884
|
|
|
885
|
+
if query.target_ast.date_offset_range is None:
|
|
886
|
+
end_offset = pd.DateOffset(0)
|
|
887
|
+
else:
|
|
888
|
+
end_offset = query.target_ast.date_offset_range.end_date_offset
|
|
889
|
+
forecast_end_offset = end_offset * query.num_forecasts
|
|
830
890
|
if anchor_time is None:
|
|
831
891
|
anchor_time = self._graph_store.max_time
|
|
832
892
|
if evaluate:
|
|
833
|
-
anchor_time = anchor_time -
|
|
834
|
-
query.num_forecasts)
|
|
893
|
+
anchor_time = anchor_time - forecast_end_offset
|
|
835
894
|
if logger is not None:
|
|
836
895
|
assert isinstance(anchor_time, pd.Timestamp)
|
|
837
896
|
if anchor_time == pd.Timestamp.min:
|
|
@@ -846,15 +905,14 @@ class KumoRFM:
|
|
|
846
905
|
assert anchor_time is not None
|
|
847
906
|
if isinstance(anchor_time, pd.Timestamp):
|
|
848
907
|
if context_anchor_time is None:
|
|
849
|
-
context_anchor_time = anchor_time -
|
|
850
|
-
query.num_forecasts)
|
|
908
|
+
context_anchor_time = anchor_time - forecast_end_offset
|
|
851
909
|
self._validate_time(query, anchor_time, context_anchor_time,
|
|
852
910
|
evaluate)
|
|
853
911
|
else:
|
|
854
912
|
assert anchor_time == 'entity'
|
|
855
|
-
if query.
|
|
913
|
+
if query.entity_table not in self._graph_store.time_dict:
|
|
856
914
|
raise ValueError(f"Anchor time 'entity' requires the entity "
|
|
857
|
-
f"table '{query.
|
|
915
|
+
f"table '{query.entity_table}' to "
|
|
858
916
|
f"have a time column")
|
|
859
917
|
if context_anchor_time is not None:
|
|
860
918
|
warnings.warn("Ignoring option 'context_anchor_time' for "
|
|
@@ -905,7 +963,7 @@ class KumoRFM:
|
|
|
905
963
|
f"in batches")
|
|
906
964
|
|
|
907
965
|
test_node = self._graph_store.get_node_id(
|
|
908
|
-
table_name=query.
|
|
966
|
+
table_name=query.entity_table,
|
|
909
967
|
pkey=pd.Series(indices),
|
|
910
968
|
)
|
|
911
969
|
|
|
@@ -913,8 +971,7 @@ class KumoRFM:
|
|
|
913
971
|
test_time = pd.Series(anchor_time).repeat(
|
|
914
972
|
len(test_node)).reset_index(drop=True)
|
|
915
973
|
else:
|
|
916
|
-
time = self._graph_store.time_dict[
|
|
917
|
-
query.entity.pkey.table_name]
|
|
974
|
+
time = self._graph_store.time_dict[query.entity_table]
|
|
918
975
|
time = time[test_node] * 1000**3
|
|
919
976
|
test_time = pd.Series(time, dtype='datetime64[ns]')
|
|
920
977
|
|
|
@@ -947,12 +1004,23 @@ class KumoRFM:
|
|
|
947
1004
|
raise NotImplementedError
|
|
948
1005
|
logger.log(msg)
|
|
949
1006
|
|
|
950
|
-
entity_table_names
|
|
951
|
-
|
|
1007
|
+
entity_table_names: Tuple[str, ...]
|
|
1008
|
+
if task_type.is_link_pred:
|
|
1009
|
+
final_aggr = query.get_final_target_aggregation()
|
|
1010
|
+
assert final_aggr is not None
|
|
1011
|
+
edge_fkey = final_aggr._get_target_column_name()
|
|
1012
|
+
for edge_type in self._graph_store.edge_types:
|
|
1013
|
+
if edge_fkey == f'{edge_type[0]}.{edge_type[1]}':
|
|
1014
|
+
entity_table_names = (
|
|
1015
|
+
query.entity_table,
|
|
1016
|
+
edge_type[2],
|
|
1017
|
+
)
|
|
1018
|
+
else:
|
|
1019
|
+
entity_table_names = (query.entity_table, )
|
|
952
1020
|
|
|
953
1021
|
# Exclude the entity anchor time from the feature set to prevent
|
|
954
1022
|
# running out-of-distribution between in-context and test examples:
|
|
955
|
-
exclude_cols_dict = query.
|
|
1023
|
+
exclude_cols_dict = query.get_exclude_cols_dict()
|
|
956
1024
|
if anchor_time == 'entity':
|
|
957
1025
|
if entity_table_names[0] not in exclude_cols_dict:
|
|
958
1026
|
exclude_cols_dict[entity_table_names[0]] = []
|
|
@@ -981,7 +1049,7 @@ class KumoRFM:
|
|
|
981
1049
|
|
|
982
1050
|
step_size: Optional[int] = None
|
|
983
1051
|
if query.query_type == QueryType.TEMPORAL:
|
|
984
|
-
step_size = date_offset_to_seconds(
|
|
1052
|
+
step_size = date_offset_to_seconds(end_offset)
|
|
985
1053
|
|
|
986
1054
|
return Context(
|
|
987
1055
|
task_type=task_type,
|
kumoai/trainer/trainer.py
CHANGED
|
@@ -20,7 +20,6 @@ from kumoapi.jobs import (
|
|
|
20
20
|
TrainingJobResource,
|
|
21
21
|
)
|
|
22
22
|
from kumoapi.model_plan import ModelPlan
|
|
23
|
-
from kumoapi.task import TaskType
|
|
24
23
|
|
|
25
24
|
from kumoai import global_state
|
|
26
25
|
from kumoai.artifact_export.config import OutputConfig
|
|
@@ -405,15 +404,15 @@ class Trainer:
|
|
|
405
404
|
pred_table_data_path = prediction_table.table_data_uri
|
|
406
405
|
|
|
407
406
|
api = global_state.client.batch_prediction_job_api
|
|
408
|
-
|
|
409
|
-
from kumoai.pquery.predictive_query import PredictiveQuery
|
|
410
|
-
pquery = PredictiveQuery.load_from_training_job(training_job_id)
|
|
411
|
-
if pquery.get_task_type() == TaskType.BINARY_CLASSIFICATION:
|
|
412
|
-
|
|
413
|
-
|
|
414
|
-
|
|
415
|
-
|
|
416
|
-
|
|
407
|
+
# Remove to resolve https://github.com/kumo-ai/kumo/issues/24250
|
|
408
|
+
# from kumoai.pquery.predictive_query import PredictiveQuery
|
|
409
|
+
# pquery = PredictiveQuery.load_from_training_job(training_job_id)
|
|
410
|
+
# if pquery.get_task_type() == TaskType.BINARY_CLASSIFICATION:
|
|
411
|
+
# if binary_classification_threshold is None:
|
|
412
|
+
# logger.warning(
|
|
413
|
+
# "No binary classification threshold provided. "
|
|
414
|
+
# "Using default threshold of 0.5.")
|
|
415
|
+
# binary_classification_threshold = 0.5
|
|
417
416
|
job_id, response = api.maybe_create(
|
|
418
417
|
BatchPredictionRequest(
|
|
419
418
|
dict(custom_tags),
|
kumoai/utils/progress_logger.py
CHANGED
|
@@ -103,10 +103,13 @@ class InteractiveProgressLogger(ProgressLogger):
|
|
|
103
103
|
self._progress.update(self._task, advance=1) # type: ignore
|
|
104
104
|
|
|
105
105
|
def __enter__(self) -> Self:
|
|
106
|
+
from kumoai import in_notebook
|
|
107
|
+
|
|
106
108
|
super().__enter__()
|
|
107
109
|
|
|
108
|
-
|
|
109
|
-
|
|
110
|
+
if not in_notebook(): # Render progress bar in TUI.
|
|
111
|
+
sys.stdout.write("\x1b]9;4;3\x07")
|
|
112
|
+
sys.stdout.flush()
|
|
110
113
|
|
|
111
114
|
if self.verbose:
|
|
112
115
|
self._live = Live(
|
|
@@ -119,6 +122,8 @@ class InteractiveProgressLogger(ProgressLogger):
|
|
|
119
122
|
return self
|
|
120
123
|
|
|
121
124
|
def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
|
|
125
|
+
from kumoai import in_notebook
|
|
126
|
+
|
|
122
127
|
super().__exit__(exc_type, exc_val, exc_tb)
|
|
123
128
|
|
|
124
129
|
if exc_type is not None:
|
|
@@ -134,8 +139,9 @@ class InteractiveProgressLogger(ProgressLogger):
|
|
|
134
139
|
self._live.stop()
|
|
135
140
|
self._live = None
|
|
136
141
|
|
|
137
|
-
|
|
138
|
-
|
|
142
|
+
if not in_notebook():
|
|
143
|
+
sys.stdout.write("\x1b]9;4;0\x07")
|
|
144
|
+
sys.stdout.flush()
|
|
139
145
|
|
|
140
146
|
def __rich_console__(
|
|
141
147
|
self,
|
{kumoai-2.11.0.dev202510151831.dist-info → kumoai-2.13.0.dev202511161731.dist-info}/METADATA
RENAMED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: kumoai
|
|
3
|
-
Version: 2.
|
|
3
|
+
Version: 2.13.0.dev202511161731
|
|
4
4
|
Summary: AI on the Modern Data Stack
|
|
5
5
|
Author-email: "Kumo.AI" <hello@kumo.ai>
|
|
6
6
|
License-Expression: MIT
|
|
@@ -23,7 +23,7 @@ Requires-Dist: requests>=2.28.2
|
|
|
23
23
|
Requires-Dist: urllib3
|
|
24
24
|
Requires-Dist: plotly
|
|
25
25
|
Requires-Dist: typing_extensions>=4.5.0
|
|
26
|
-
Requires-Dist: kumo-api==0.
|
|
26
|
+
Requires-Dist: kumo-api==0.45.0
|
|
27
27
|
Requires-Dist: tqdm>=4.66.0
|
|
28
28
|
Requires-Dist: aiohttp>=3.10.0
|
|
29
29
|
Requires-Dist: pydantic>=1.10.21
|