kumoai 2.11.0.dev202510191831__cp311-cp311-win_amd64.whl → 2.12.0.dev202511061731__cp311-cp311-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kumoai/__init__.py +4 -2
- kumoai/_version.py +1 -1
- kumoai/client/endpoints.py +1 -0
- kumoai/client/rfm.py +20 -0
- kumoai/experimental/rfm/local_pquery_driver.py +221 -26
- kumoai/experimental/rfm/pquery/__init__.py +0 -4
- kumoai/experimental/rfm/pquery/pandas_executor.py +34 -8
- kumoai/experimental/rfm/rfm.py +82 -58
- kumoai/kumolib.cp311-win_amd64.pyd +0 -0
- kumoai/trainer/trainer.py +9 -10
- {kumoai-2.11.0.dev202510191831.dist-info → kumoai-2.12.0.dev202511061731.dist-info}/METADATA +2 -2
- {kumoai-2.11.0.dev202510191831.dist-info → kumoai-2.12.0.dev202511061731.dist-info}/RECORD +15 -17
- kumoai/experimental/rfm/pquery/backend.py +0 -136
- kumoai/experimental/rfm/pquery/pandas_backend.py +0 -478
- {kumoai-2.11.0.dev202510191831.dist-info → kumoai-2.12.0.dev202511061731.dist-info}/WHEEL +0 -0
- {kumoai-2.11.0.dev202510191831.dist-info → kumoai-2.12.0.dev202511061731.dist-info}/licenses/LICENSE +0 -0
- {kumoai-2.11.0.dev202510191831.dist-info → kumoai-2.12.0.dev202511061731.dist-info}/top_level.txt +0 -0
kumoai/experimental/rfm/rfm.py
CHANGED
|
@@ -5,19 +5,18 @@ from collections import defaultdict
|
|
|
5
5
|
from collections.abc import Generator
|
|
6
6
|
from contextlib import contextmanager
|
|
7
7
|
from dataclasses import dataclass, replace
|
|
8
|
-
from typing import Iterator, List, Literal, Optional, Union, overload
|
|
8
|
+
from typing import Iterator, List, Literal, Optional, Tuple, Union, overload
|
|
9
9
|
|
|
10
10
|
import numpy as np
|
|
11
11
|
import pandas as pd
|
|
12
12
|
from kumoapi.model_plan import RunMode
|
|
13
|
-
from kumoapi.pquery import QueryType
|
|
13
|
+
from kumoapi.pquery import QueryType, ValidatedPredictiveQuery
|
|
14
14
|
from kumoapi.rfm import Context
|
|
15
15
|
from kumoapi.rfm import Explanation as ExplanationConfig
|
|
16
16
|
from kumoapi.rfm import (
|
|
17
|
-
PQueryDefinition,
|
|
18
17
|
RFMEvaluateRequest,
|
|
18
|
+
RFMParseQueryRequest,
|
|
19
19
|
RFMPredictRequest,
|
|
20
|
-
RFMValidateQueryRequest,
|
|
21
20
|
)
|
|
22
21
|
from kumoapi.task import TaskType
|
|
23
22
|
|
|
@@ -199,6 +198,7 @@ class KumoRFM:
|
|
|
199
198
|
max_pq_iterations: int = 20,
|
|
200
199
|
random_seed: Optional[int] = _RANDOM_SEED,
|
|
201
200
|
verbose: Union[bool, ProgressLogger] = True,
|
|
201
|
+
use_prediction_time: bool = False,
|
|
202
202
|
) -> pd.DataFrame:
|
|
203
203
|
pass
|
|
204
204
|
|
|
@@ -217,6 +217,7 @@ class KumoRFM:
|
|
|
217
217
|
max_pq_iterations: int = 20,
|
|
218
218
|
random_seed: Optional[int] = _RANDOM_SEED,
|
|
219
219
|
verbose: Union[bool, ProgressLogger] = True,
|
|
220
|
+
use_prediction_time: bool = False,
|
|
220
221
|
) -> Explanation:
|
|
221
222
|
pass
|
|
222
223
|
|
|
@@ -234,6 +235,7 @@ class KumoRFM:
|
|
|
234
235
|
max_pq_iterations: int = 20,
|
|
235
236
|
random_seed: Optional[int] = _RANDOM_SEED,
|
|
236
237
|
verbose: Union[bool, ProgressLogger] = True,
|
|
238
|
+
use_prediction_time: bool = False,
|
|
237
239
|
) -> Union[pd.DataFrame, Explanation]:
|
|
238
240
|
"""Returns predictions for a predictive query.
|
|
239
241
|
|
|
@@ -264,6 +266,9 @@ class KumoRFM:
|
|
|
264
266
|
entities to find valid labels.
|
|
265
267
|
random_seed: A manual seed for generating pseudo-random numbers.
|
|
266
268
|
verbose: Whether to print verbose output.
|
|
269
|
+
use_prediction_time: Whether to use the anchor timestamp as an
|
|
270
|
+
additional feature during prediction. This is typically
|
|
271
|
+
beneficial for time series forecasting tasks.
|
|
267
272
|
|
|
268
273
|
Returns:
|
|
269
274
|
The predictions as a :class:`pandas.DataFrame`.
|
|
@@ -283,15 +288,12 @@ class KumoRFM:
|
|
|
283
288
|
f"suppress this warning.")
|
|
284
289
|
|
|
285
290
|
if indices is None:
|
|
286
|
-
if query_def.
|
|
291
|
+
if query_def.rfm_entity_ids is None:
|
|
287
292
|
raise ValueError("Cannot find entities to predict for. Please "
|
|
288
293
|
"pass them via `predict(query, indices=...)`")
|
|
289
|
-
indices = query_def.
|
|
294
|
+
indices = query_def.get_rfm_entity_id_list()
|
|
290
295
|
else:
|
|
291
|
-
query_def = replace(
|
|
292
|
-
query_def,
|
|
293
|
-
entity=replace(query_def.entity, ids=None),
|
|
294
|
-
)
|
|
296
|
+
query_def = replace(query_def, rfm_entity_ids=None)
|
|
295
297
|
|
|
296
298
|
if len(indices) == 0:
|
|
297
299
|
raise ValueError("At least one entity is required")
|
|
@@ -314,8 +316,8 @@ class KumoRFM:
|
|
|
314
316
|
|
|
315
317
|
batch_size: Optional[int] = None
|
|
316
318
|
if self._batch_size == 'max':
|
|
317
|
-
task_type =
|
|
318
|
-
|
|
319
|
+
task_type = LocalPQueryDriver.get_task_type(
|
|
320
|
+
query_def,
|
|
319
321
|
edge_types=self._graph_store.edge_types,
|
|
320
322
|
)
|
|
321
323
|
batch_size = _MAX_PRED_SIZE[task_type]
|
|
@@ -353,6 +355,7 @@ class KumoRFM:
|
|
|
353
355
|
request = RFMPredictRequest(
|
|
354
356
|
context=context,
|
|
355
357
|
run_mode=RunMode(run_mode),
|
|
358
|
+
use_prediction_time=use_prediction_time,
|
|
356
359
|
)
|
|
357
360
|
with warnings.catch_warnings():
|
|
358
361
|
warnings.filterwarnings('ignore', message='gencode')
|
|
@@ -385,7 +388,7 @@ class KumoRFM:
|
|
|
385
388
|
|
|
386
389
|
# Cast 'ENTITY' to correct data type:
|
|
387
390
|
if 'ENTITY' in df:
|
|
388
|
-
entity = query_def.
|
|
391
|
+
entity = query_def.entity_table
|
|
389
392
|
pkey_map = self._graph_store.pkey_map_dict[entity]
|
|
390
393
|
df['ENTITY'] = df['ENTITY'].astype(
|
|
391
394
|
type(pkey_map.index[0]))
|
|
@@ -461,11 +464,11 @@ class KumoRFM:
|
|
|
461
464
|
query_def = self._parse_query(query)
|
|
462
465
|
|
|
463
466
|
if indices is None:
|
|
464
|
-
if query_def.
|
|
467
|
+
if query_def.rfm_entity_ids is None:
|
|
465
468
|
raise ValueError("Cannot find entities to predict for. Please "
|
|
466
469
|
"pass them via "
|
|
467
470
|
"`is_valid_entity(query, indices=...)`")
|
|
468
|
-
indices = query_def.
|
|
471
|
+
indices = query_def.get_rfm_entity_id_list()
|
|
469
472
|
|
|
470
473
|
if len(indices) == 0:
|
|
471
474
|
raise ValueError("At least one entity is required")
|
|
@@ -477,14 +480,13 @@ class KumoRFM:
|
|
|
477
480
|
self._validate_time(query_def, anchor_time, None, False)
|
|
478
481
|
else:
|
|
479
482
|
assert anchor_time == 'entity'
|
|
480
|
-
if (query_def.
|
|
481
|
-
not in self._graph_store.time_dict):
|
|
483
|
+
if (query_def.entity_table not in self._graph_store.time_dict):
|
|
482
484
|
raise ValueError(f"Anchor time 'entity' requires the entity "
|
|
483
|
-
f"table '{query_def.
|
|
484
|
-
f"to have a time column")
|
|
485
|
+
f"table '{query_def.entity_table}' "
|
|
486
|
+
f"to have a time column.")
|
|
485
487
|
|
|
486
488
|
node = self._graph_store.get_node_id(
|
|
487
|
-
table_name=query_def.
|
|
489
|
+
table_name=query_def.entity_table,
|
|
488
490
|
pkey=pd.Series(indices),
|
|
489
491
|
)
|
|
490
492
|
query_driver = LocalPQueryDriver(self._graph_store, query_def)
|
|
@@ -503,6 +505,7 @@ class KumoRFM:
|
|
|
503
505
|
max_pq_iterations: int = 20,
|
|
504
506
|
random_seed: Optional[int] = _RANDOM_SEED,
|
|
505
507
|
verbose: Union[bool, ProgressLogger] = True,
|
|
508
|
+
use_prediction_time: bool = False,
|
|
506
509
|
) -> pd.DataFrame:
|
|
507
510
|
"""Evaluates a predictive query.
|
|
508
511
|
|
|
@@ -526,6 +529,9 @@ class KumoRFM:
|
|
|
526
529
|
entities to find valid labels.
|
|
527
530
|
random_seed: A manual seed for generating pseudo-random numbers.
|
|
528
531
|
verbose: Whether to print verbose output.
|
|
532
|
+
use_prediction_time: Whether to use the anchor timestamp as an
|
|
533
|
+
additional feature during prediction. This is typically
|
|
534
|
+
beneficial for time series forecasting tasks.
|
|
529
535
|
|
|
530
536
|
Returns:
|
|
531
537
|
The metrics as a :class:`pandas.DataFrame`
|
|
@@ -536,10 +542,10 @@ class KumoRFM:
|
|
|
536
542
|
warnings.warn(f"Received custom 'num_neighbors' option; ignoring "
|
|
537
543
|
f"custom 'num_hops={num_hops}' option")
|
|
538
544
|
|
|
539
|
-
if query_def.
|
|
545
|
+
if query_def.rfm_entity_ids is not None:
|
|
540
546
|
query_def = replace(
|
|
541
547
|
query_def,
|
|
542
|
-
|
|
548
|
+
rfm_entity_ids=None,
|
|
543
549
|
)
|
|
544
550
|
|
|
545
551
|
query_repr = query_def.to_string(rich=True, exclude_predict=True)
|
|
@@ -569,6 +575,7 @@ class KumoRFM:
|
|
|
569
575
|
context=context,
|
|
570
576
|
run_mode=RunMode(run_mode),
|
|
571
577
|
metrics=metrics,
|
|
578
|
+
use_prediction_time=use_prediction_time,
|
|
572
579
|
)
|
|
573
580
|
with warnings.catch_warnings():
|
|
574
581
|
warnings.filterwarnings('ignore', message='Protobuf gencode')
|
|
@@ -627,18 +634,19 @@ class KumoRFM:
|
|
|
627
634
|
|
|
628
635
|
if anchor_time is None:
|
|
629
636
|
anchor_time = self._graph_store.max_time
|
|
630
|
-
|
|
631
|
-
|
|
637
|
+
if query_def.target_ast.date_offset_range is not None:
|
|
638
|
+
anchor_time = anchor_time - (
|
|
639
|
+
query_def.target_ast.date_offset_range.end_date_offset *
|
|
640
|
+
query_def.num_forecasts)
|
|
632
641
|
|
|
633
642
|
assert anchor_time is not None
|
|
634
643
|
if isinstance(anchor_time, pd.Timestamp):
|
|
635
644
|
self._validate_time(query_def, anchor_time, None, evaluate=True)
|
|
636
645
|
else:
|
|
637
646
|
assert anchor_time == 'entity'
|
|
638
|
-
if (query_def.
|
|
639
|
-
not in self._graph_store.time_dict):
|
|
647
|
+
if (query_def.entity_table not in self._graph_store.time_dict):
|
|
640
648
|
raise ValueError(f"Anchor time 'entity' requires the entity "
|
|
641
|
-
f"table '{query_def.
|
|
649
|
+
f"table '{query_def.entity_table}' "
|
|
642
650
|
f"to have a time column")
|
|
643
651
|
|
|
644
652
|
query_driver = LocalPQueryDriver(self._graph_store, query_def,
|
|
@@ -653,7 +661,7 @@ class KumoRFM:
|
|
|
653
661
|
)
|
|
654
662
|
|
|
655
663
|
entity = self._graph_store.pkey_map_dict[
|
|
656
|
-
query_def.
|
|
664
|
+
query_def.entity_table].index[node]
|
|
657
665
|
|
|
658
666
|
return pd.DataFrame({
|
|
659
667
|
'ENTITY': entity,
|
|
@@ -663,8 +671,8 @@ class KumoRFM:
|
|
|
663
671
|
|
|
664
672
|
# Helpers #################################################################
|
|
665
673
|
|
|
666
|
-
def _parse_query(self, query: str) ->
|
|
667
|
-
if isinstance(query,
|
|
674
|
+
def _parse_query(self, query: str) -> ValidatedPredictiveQuery:
|
|
675
|
+
if isinstance(query, ValidatedPredictiveQuery):
|
|
668
676
|
return query
|
|
669
677
|
|
|
670
678
|
if isinstance(query, str) and query.strip()[:9].lower() == 'evaluate ':
|
|
@@ -674,12 +682,12 @@ class KumoRFM:
|
|
|
674
682
|
"predictions or evaluations.")
|
|
675
683
|
|
|
676
684
|
try:
|
|
677
|
-
request =
|
|
685
|
+
request = RFMParseQueryRequest(
|
|
678
686
|
query=query,
|
|
679
687
|
graph_definition=self._graph_def,
|
|
680
688
|
)
|
|
681
689
|
|
|
682
|
-
resp = global_state.client.rfm_api.
|
|
690
|
+
resp = global_state.client.rfm_api.parse_query(request)
|
|
683
691
|
# TODO Expose validation warnings.
|
|
684
692
|
|
|
685
693
|
if len(resp.validation_response.warnings) > 0:
|
|
@@ -690,7 +698,7 @@ class KumoRFM:
|
|
|
690
698
|
warnings.warn(f"Encountered the following warnings during "
|
|
691
699
|
f"parsing:\n{msg}")
|
|
692
700
|
|
|
693
|
-
return resp.
|
|
701
|
+
return resp.query
|
|
694
702
|
except HTTPException as e:
|
|
695
703
|
try:
|
|
696
704
|
msg = json.loads(e.detail)['detail']
|
|
@@ -701,7 +709,7 @@ class KumoRFM:
|
|
|
701
709
|
|
|
702
710
|
def _validate_time(
|
|
703
711
|
self,
|
|
704
|
-
query:
|
|
712
|
+
query: ValidatedPredictiveQuery,
|
|
705
713
|
anchor_time: pd.Timestamp,
|
|
706
714
|
context_anchor_time: Union[pd.Timestamp, None],
|
|
707
715
|
evaluate: bool,
|
|
@@ -724,6 +732,11 @@ class KumoRFM:
|
|
|
724
732
|
f"only contains data back to "
|
|
725
733
|
f"'{self._graph_store.min_time}'.")
|
|
726
734
|
|
|
735
|
+
if query.target_ast.date_offset_range is not None:
|
|
736
|
+
end_offset = query.target_ast.date_offset_range.end_date_offset
|
|
737
|
+
else:
|
|
738
|
+
end_offset = pd.DateOffset(0)
|
|
739
|
+
forecast_end_offset = end_offset * query.num_forecasts
|
|
727
740
|
if (context_anchor_time is not None
|
|
728
741
|
and context_anchor_time > anchor_time):
|
|
729
742
|
warnings.warn(f"Context anchor timestamp "
|
|
@@ -732,19 +745,18 @@ class KumoRFM:
|
|
|
732
745
|
f"(got '{anchor_time}'). Please make sure this is "
|
|
733
746
|
f"intended.")
|
|
734
747
|
elif (query.query_type == QueryType.TEMPORAL
|
|
735
|
-
and context_anchor_time is not None
|
|
736
|
-
|
|
748
|
+
and context_anchor_time is not None
|
|
749
|
+
and context_anchor_time + forecast_end_offset > anchor_time):
|
|
737
750
|
warnings.warn(f"Aggregation for context examples at timestamp "
|
|
738
751
|
f"'{context_anchor_time}' will leak information "
|
|
739
752
|
f"from the prediction anchor timestamp "
|
|
740
753
|
f"'{anchor_time}'. Please make sure this is "
|
|
741
754
|
f"intended.")
|
|
742
755
|
|
|
743
|
-
elif (context_anchor_time is not None
|
|
744
|
-
|
|
756
|
+
elif (context_anchor_time is not None
|
|
757
|
+
and context_anchor_time - forecast_end_offset
|
|
745
758
|
< self._graph_store.min_time):
|
|
746
|
-
_time = context_anchor_time -
|
|
747
|
-
query.num_forecasts)
|
|
759
|
+
_time = context_anchor_time - forecast_end_offset
|
|
748
760
|
warnings.warn(f"Context anchor timestamp is too early or "
|
|
749
761
|
f"aggregation time range is too large. To form "
|
|
750
762
|
f"proper input data, we would need data back to "
|
|
@@ -757,8 +769,7 @@ class KumoRFM:
|
|
|
757
769
|
f"latest timestamp '{self._graph_store.max_time}' "
|
|
758
770
|
f"in the data. Please make sure this is intended.")
|
|
759
771
|
|
|
760
|
-
max_eval_time =
|
|
761
|
-
query.target.end_offset * query.num_forecasts)
|
|
772
|
+
max_eval_time = self._graph_store.max_time - forecast_end_offset
|
|
762
773
|
if evaluate and anchor_time > max_eval_time:
|
|
763
774
|
raise ValueError(
|
|
764
775
|
f"Anchor timestamp for evaluation is after the latest "
|
|
@@ -766,7 +777,7 @@ class KumoRFM:
|
|
|
766
777
|
|
|
767
778
|
def _get_context(
|
|
768
779
|
self,
|
|
769
|
-
query:
|
|
780
|
+
query: ValidatedPredictiveQuery,
|
|
770
781
|
indices: Union[List[str], List[float], List[int], None],
|
|
771
782
|
anchor_time: Union[pd.Timestamp, Literal['entity'], None],
|
|
772
783
|
context_anchor_time: Union[pd.Timestamp, None],
|
|
@@ -794,8 +805,8 @@ class KumoRFM:
|
|
|
794
805
|
f"must go beyond this for your use-case.")
|
|
795
806
|
|
|
796
807
|
query_driver = LocalPQueryDriver(self._graph_store, query, random_seed)
|
|
797
|
-
task_type =
|
|
798
|
-
|
|
808
|
+
task_type = LocalPQueryDriver.get_task_type(
|
|
809
|
+
query,
|
|
799
810
|
edge_types=self._graph_store.edge_types,
|
|
800
811
|
)
|
|
801
812
|
|
|
@@ -827,11 +838,15 @@ class KumoRFM:
|
|
|
827
838
|
else:
|
|
828
839
|
num_neighbors = [64, 64, 8, 8, 4, 4][:num_hops]
|
|
829
840
|
|
|
841
|
+
if query.target_ast.date_offset_range is None:
|
|
842
|
+
end_offset = pd.DateOffset(0)
|
|
843
|
+
else:
|
|
844
|
+
end_offset = query.target_ast.date_offset_range.end_date_offset
|
|
845
|
+
forecast_end_offset = end_offset * query.num_forecasts
|
|
830
846
|
if anchor_time is None:
|
|
831
847
|
anchor_time = self._graph_store.max_time
|
|
832
848
|
if evaluate:
|
|
833
|
-
anchor_time = anchor_time -
|
|
834
|
-
query.num_forecasts)
|
|
849
|
+
anchor_time = anchor_time - forecast_end_offset
|
|
835
850
|
if logger is not None:
|
|
836
851
|
assert isinstance(anchor_time, pd.Timestamp)
|
|
837
852
|
if anchor_time == pd.Timestamp.min:
|
|
@@ -846,15 +861,14 @@ class KumoRFM:
|
|
|
846
861
|
assert anchor_time is not None
|
|
847
862
|
if isinstance(anchor_time, pd.Timestamp):
|
|
848
863
|
if context_anchor_time is None:
|
|
849
|
-
context_anchor_time = anchor_time -
|
|
850
|
-
query.num_forecasts)
|
|
864
|
+
context_anchor_time = anchor_time - forecast_end_offset
|
|
851
865
|
self._validate_time(query, anchor_time, context_anchor_time,
|
|
852
866
|
evaluate)
|
|
853
867
|
else:
|
|
854
868
|
assert anchor_time == 'entity'
|
|
855
|
-
if query.
|
|
869
|
+
if query.entity_table not in self._graph_store.time_dict:
|
|
856
870
|
raise ValueError(f"Anchor time 'entity' requires the entity "
|
|
857
|
-
f"table '{query.
|
|
871
|
+
f"table '{query.entity_table}' to "
|
|
858
872
|
f"have a time column")
|
|
859
873
|
if context_anchor_time is not None:
|
|
860
874
|
warnings.warn("Ignoring option 'context_anchor_time' for "
|
|
@@ -905,7 +919,7 @@ class KumoRFM:
|
|
|
905
919
|
f"in batches")
|
|
906
920
|
|
|
907
921
|
test_node = self._graph_store.get_node_id(
|
|
908
|
-
table_name=query.
|
|
922
|
+
table_name=query.entity_table,
|
|
909
923
|
pkey=pd.Series(indices),
|
|
910
924
|
)
|
|
911
925
|
|
|
@@ -913,8 +927,7 @@ class KumoRFM:
|
|
|
913
927
|
test_time = pd.Series(anchor_time).repeat(
|
|
914
928
|
len(test_node)).reset_index(drop=True)
|
|
915
929
|
else:
|
|
916
|
-
time = self._graph_store.time_dict[
|
|
917
|
-
query.entity.pkey.table_name]
|
|
930
|
+
time = self._graph_store.time_dict[query.entity_table]
|
|
918
931
|
time = time[test_node] * 1000**3
|
|
919
932
|
test_time = pd.Series(time, dtype='datetime64[ns]')
|
|
920
933
|
|
|
@@ -947,12 +960,23 @@ class KumoRFM:
|
|
|
947
960
|
raise NotImplementedError
|
|
948
961
|
logger.log(msg)
|
|
949
962
|
|
|
950
|
-
entity_table_names
|
|
951
|
-
|
|
963
|
+
entity_table_names: Tuple[str, ...]
|
|
964
|
+
if task_type.is_link_pred:
|
|
965
|
+
final_aggr = query.get_final_target_aggregation()
|
|
966
|
+
assert final_aggr is not None
|
|
967
|
+
edge_fkey = final_aggr._get_target_column_name()
|
|
968
|
+
for edge_type in self._graph_store.edge_types:
|
|
969
|
+
if edge_fkey == f'{edge_type[0]}.{edge_type[1]}':
|
|
970
|
+
entity_table_names = (
|
|
971
|
+
query.entity_table,
|
|
972
|
+
edge_type[2],
|
|
973
|
+
)
|
|
974
|
+
else:
|
|
975
|
+
entity_table_names = (query.entity_table, )
|
|
952
976
|
|
|
953
977
|
# Exclude the entity anchor time from the feature set to prevent
|
|
954
978
|
# running out-of-distribution between in-context and test examples:
|
|
955
|
-
exclude_cols_dict = query.
|
|
979
|
+
exclude_cols_dict = query.get_exclude_cols_dict()
|
|
956
980
|
if anchor_time == 'entity':
|
|
957
981
|
if entity_table_names[0] not in exclude_cols_dict:
|
|
958
982
|
exclude_cols_dict[entity_table_names[0]] = []
|
|
@@ -981,7 +1005,7 @@ class KumoRFM:
|
|
|
981
1005
|
|
|
982
1006
|
step_size: Optional[int] = None
|
|
983
1007
|
if query.query_type == QueryType.TEMPORAL:
|
|
984
|
-
step_size = date_offset_to_seconds(
|
|
1008
|
+
step_size = date_offset_to_seconds(end_offset)
|
|
985
1009
|
|
|
986
1010
|
return Context(
|
|
987
1011
|
task_type=task_type,
|
|
Binary file
|
kumoai/trainer/trainer.py
CHANGED
|
@@ -20,7 +20,6 @@ from kumoapi.jobs import (
|
|
|
20
20
|
TrainingJobResource,
|
|
21
21
|
)
|
|
22
22
|
from kumoapi.model_plan import ModelPlan
|
|
23
|
-
from kumoapi.task import TaskType
|
|
24
23
|
|
|
25
24
|
from kumoai import global_state
|
|
26
25
|
from kumoai.artifact_export.config import OutputConfig
|
|
@@ -405,15 +404,15 @@ class Trainer:
|
|
|
405
404
|
pred_table_data_path = prediction_table.table_data_uri
|
|
406
405
|
|
|
407
406
|
api = global_state.client.batch_prediction_job_api
|
|
408
|
-
|
|
409
|
-
from kumoai.pquery.predictive_query import PredictiveQuery
|
|
410
|
-
pquery = PredictiveQuery.load_from_training_job(training_job_id)
|
|
411
|
-
if pquery.get_task_type() == TaskType.BINARY_CLASSIFICATION:
|
|
412
|
-
|
|
413
|
-
|
|
414
|
-
|
|
415
|
-
|
|
416
|
-
|
|
407
|
+
# Remove to resolve https://github.com/kumo-ai/kumo/issues/24250
|
|
408
|
+
# from kumoai.pquery.predictive_query import PredictiveQuery
|
|
409
|
+
# pquery = PredictiveQuery.load_from_training_job(training_job_id)
|
|
410
|
+
# if pquery.get_task_type() == TaskType.BINARY_CLASSIFICATION:
|
|
411
|
+
# if binary_classification_threshold is None:
|
|
412
|
+
# logger.warning(
|
|
413
|
+
# "No binary classification threshold provided. "
|
|
414
|
+
# "Using default threshold of 0.5.")
|
|
415
|
+
# binary_classification_threshold = 0.5
|
|
417
416
|
job_id, response = api.maybe_create(
|
|
418
417
|
BatchPredictionRequest(
|
|
419
418
|
dict(custom_tags),
|
{kumoai-2.11.0.dev202510191831.dist-info → kumoai-2.12.0.dev202511061731.dist-info}/METADATA
RENAMED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: kumoai
|
|
3
|
-
Version: 2.
|
|
3
|
+
Version: 2.12.0.dev202511061731
|
|
4
4
|
Summary: AI on the Modern Data Stack
|
|
5
5
|
Author-email: "Kumo.AI" <hello@kumo.ai>
|
|
6
6
|
License-Expression: MIT
|
|
@@ -23,7 +23,7 @@ Requires-Dist: requests>=2.28.2
|
|
|
23
23
|
Requires-Dist: urllib3
|
|
24
24
|
Requires-Dist: plotly
|
|
25
25
|
Requires-Dist: typing_extensions>=4.5.0
|
|
26
|
-
Requires-Dist: kumo-api==0.
|
|
26
|
+
Requires-Dist: kumo-api==0.45.0
|
|
27
27
|
Requires-Dist: tqdm>=4.66.0
|
|
28
28
|
Requires-Dist: aiohttp>=3.10.0
|
|
29
29
|
Requires-Dist: pydantic>=1.10.21
|
|
@@ -1,13 +1,13 @@
|
|
|
1
|
-
kumoai/__init__.py,sha256=
|
|
1
|
+
kumoai/__init__.py,sha256=4efagNAotP3c8mj8yyDGfVFcbgQ9l4wRC4FP-Yt0J3E,11002
|
|
2
2
|
kumoai/_logging.py,sha256=qL4JbMQwKXri2f-SEJoFB8TY5ALG12S-nobGTNWxW-A,915
|
|
3
3
|
kumoai/_singleton.py,sha256=i2BHWKpccNh5SJGDyU0IXsnYzJAYr8Xb0wz4c6LRbpo,861
|
|
4
|
-
kumoai/_version.py,sha256=
|
|
4
|
+
kumoai/_version.py,sha256=VXH5_higm9yOZfANxKmE3Z04HWIkHPFRAKVXx3JcG4s,39
|
|
5
5
|
kumoai/databricks.py,sha256=ahwJz6DWLXMkndT0XwEDBxF-hoqhidFR8wBUQ4TLZ68,490
|
|
6
6
|
kumoai/exceptions.py,sha256=7TMs0SC8xrU009_Pgd4QXtSF9lxJq8MtRbeX9pcQUy4,859
|
|
7
7
|
kumoai/formatting.py,sha256=o3uCnLwXPhe1KI5WV9sBgRrcU7ed4rgu_pf89GL9Nc0,983
|
|
8
8
|
kumoai/futures.py,sha256=J8rtZMEYFzdn5xF_x-LAiKJz3KGL6PT02f6rq_2bOJk,3836
|
|
9
9
|
kumoai/jobs.py,sha256=dCi7BAdfm2tCnonYlGU4WJokJWbh3RzFfaOX2EYCIHU,2576
|
|
10
|
-
kumoai/kumolib.cp311-win_amd64.pyd,sha256=
|
|
10
|
+
kumoai/kumolib.cp311-win_amd64.pyd,sha256=gYxmlmFJnGDOi4hWT_K8mo__KVzRBukEivTzuip4TQY,195584
|
|
11
11
|
kumoai/mixin.py,sha256=IaiB8SAI0VqOoMVzzIaUlqMt53-QPUK6OB0HikG-V9E,840
|
|
12
12
|
kumoai/spcs.py,sha256=SWvfkeJvb_7sGkjSqyMBIuPbMTWCP6v0BC9HBXM1uSI,4398
|
|
13
13
|
kumoai/artifact_export/__init__.py,sha256=UXAQI5q92ChBzWAk8o3J6pElzYHudAzFZssQXd4o7i8,247
|
|
@@ -16,12 +16,12 @@ kumoai/artifact_export/job.py,sha256=lOFIdPCrvhwdfvvDhQ2yzW8J4qIdYQoHZO1Rz3kJky4
|
|
|
16
16
|
kumoai/client/__init__.py,sha256=v0ISO1QD8JJhIJS6IzWz5-SL3EhtNCPeX3j1b2HBY0s,69
|
|
17
17
|
kumoai/client/client.py,sha256=IoZ6WH-VIAdwpwmd5DhP4HqjQL_YpB5vaWjtaWrNECk,8801
|
|
18
18
|
kumoai/client/connector.py,sha256=CO2LG5aDpCLxWNYYFRXGZs1AhYH3dRcbqBEUGwHQGzQ,4030
|
|
19
|
-
kumoai/client/endpoints.py,sha256=
|
|
19
|
+
kumoai/client/endpoints.py,sha256=DpEKEQ1yvL15iHZadXZKO94t-qXrYLaeV1sknX4IuPg,5532
|
|
20
20
|
kumoai/client/graph.py,sha256=6MFyPYxDPfGTWeAI_84RUgWx9rVvqbLnR0Ourtgj5rg,3951
|
|
21
21
|
kumoai/client/jobs.py,sha256=Y8wKiTk1I5ywc-2cxR72LaBjfhPTCVOezSCTeDpTs8Q,17521
|
|
22
22
|
kumoai/client/online.py,sha256=4s_8Sv8m_k_tty4CO7RuAt0e6BDMkGvsZZ3VX8zyDb8,2798
|
|
23
23
|
kumoai/client/pquery.py,sha256=0pXgQLxjoaFWDif0XRAuC_P-X3OSnXNWsiVrXej9uMk,7094
|
|
24
|
-
kumoai/client/rfm.py,sha256=
|
|
24
|
+
kumoai/client/rfm.py,sha256=XCLJsSBe82fErLchpuS4Zb7fA3LBY8QxxIhrbw4_NPQ,3678
|
|
25
25
|
kumoai/client/source_table.py,sha256=mMHJtQ_yUHRI9LdHLVHxNGt83bbzmC1_d-NmXjbiTuI,2154
|
|
26
26
|
kumoai/client/table.py,sha256=VhjLEMLQS1Z7zjcb2Yt3gZfiVqiD7b1gj-WNux_504A,3336
|
|
27
27
|
kumoai/client/utils.py,sha256=RSD5Ia0lQQDR1drRFBJFdo2KVHfQqhJuk6m6du7Kl4E,3979
|
|
@@ -58,20 +58,18 @@ kumoai/experimental/rfm/authenticate.py,sha256=G89_4TMeUpr5fG_0VTzMF5sdNhaciitA1
|
|
|
58
58
|
kumoai/experimental/rfm/local_graph.py,sha256=nZ9hDfyWg1dHFLoTEKoLt0ZJPvf9MUA1MNyfTRzJThg,30886
|
|
59
59
|
kumoai/experimental/rfm/local_graph_sampler.py,sha256=ZCnILozG95EzpgMqhGTG2AF85JphLvAhj-3YPaTqoaQ,6922
|
|
60
60
|
kumoai/experimental/rfm/local_graph_store.py,sha256=eUuIMFcdIRqN1kRxnqOdJpKEt-S_oyupAyHr7YuQoSU,14206
|
|
61
|
-
kumoai/experimental/rfm/local_pquery_driver.py,sha256=
|
|
61
|
+
kumoai/experimental/rfm/local_pquery_driver.py,sha256=XHxRTMRVUzKNlTItkOmW_ClEQ1xgvvwIC6MBLt7qihA,26857
|
|
62
62
|
kumoai/experimental/rfm/local_table.py,sha256=5H08657TIyH7n_QnpFKr2g4BtVqdXTymmrfhSGaDmkU,20150
|
|
63
|
-
kumoai/experimental/rfm/rfm.py,sha256=
|
|
63
|
+
kumoai/experimental/rfm/rfm.py,sha256=K9Fm6O3GWkoOCv9Bq8jSdnWvuMyPYk4lmU1WJIpLSPY,47815
|
|
64
64
|
kumoai/experimental/rfm/utils.py,sha256=dLx2wdyTWg7vZI_7R-I0z_lA-2aV5M8h9n3bnnLyylI,11467
|
|
65
65
|
kumoai/experimental/rfm/infer/__init__.py,sha256=fPsdDr4D3hgC8snW0j3pAVpCyR-xrauuogMnTOMrfok,304
|
|
66
66
|
kumoai/experimental/rfm/infer/categorical.py,sha256=bqmfrE5ZCBTcb35lA4SyAkCu3MgttAn29VBJYMBNhVg,893
|
|
67
67
|
kumoai/experimental/rfm/infer/id.py,sha256=xaJBETLZa8ttzZCsDwFSwfyCi3VYsLc_kDWT_t_6Ih4,954
|
|
68
68
|
kumoai/experimental/rfm/infer/multicategorical.py,sha256=D-1KwYRkOSkBrOJr4Xa3eTCoAF9O9hPGa7Vg67V5_HU,1150
|
|
69
69
|
kumoai/experimental/rfm/infer/timestamp.py,sha256=L2VxjtYTSyUBYAo4M-L08xSQlPpqnHMAVF5_vxjh3Y0,1135
|
|
70
|
-
kumoai/experimental/rfm/pquery/__init__.py,sha256=
|
|
71
|
-
kumoai/experimental/rfm/pquery/backend.py,sha256=mGbRdDcZxRGhFGz55bDHCICkEzsYRO3Gyj95QkzxpKY,3423
|
|
70
|
+
kumoai/experimental/rfm/pquery/__init__.py,sha256=RkTn0I74uXOUuOiBpa6S-_QEYctMutkUnBEfF9ztQzI,159
|
|
72
71
|
kumoai/experimental/rfm/pquery/executor.py,sha256=S8wwXbAkH-YSnmEVYB8d6wyJF4JJ003mH_0zFTvOp_I,2843
|
|
73
|
-
kumoai/experimental/rfm/pquery/
|
|
74
|
-
kumoai/experimental/rfm/pquery/pandas_executor.py,sha256=W0CEnjDdqxkBADSyvnupwS1k86N9DhFXejJEDKS1MBo,17832
|
|
72
|
+
kumoai/experimental/rfm/pquery/pandas_executor.py,sha256=QQpOZ_ArH3eSAkenaY3J-gW1Wn5A7f85RiqZxaO5u1Q,19019
|
|
75
73
|
kumoai/graph/__init__.py,sha256=QGk3OMwRzQJSGESdcc7hcQH6UDmNVJYTdqnRren4c7Q,240
|
|
76
74
|
kumoai/graph/column.py,sha256=cQhioibTbIKIBZ-bf8-Bt4F4Iblhidps-CYWrkxRPnE,4295
|
|
77
75
|
kumoai/graph/graph.py,sha256=Pq-dxi4MwoDtrrwm3xeyUB9Hl7ryNfHq4rMHuvyNB3c,39239
|
|
@@ -87,14 +85,14 @@ kumoai/trainer/baseline_trainer.py,sha256=oXweh8j1sar6KhQfr3A7gmQxcDq7SG0Bx3jIen
|
|
|
87
85
|
kumoai/trainer/config.py,sha256=7_Jv1w1mqaokCQwQdJkqCSgVpmh8GqE3fL1Ky_vvttI,100
|
|
88
86
|
kumoai/trainer/job.py,sha256=IBP2SeIk21XpRK1Um1NIs2dEKid319cHu6UkCjKO6jc,46130
|
|
89
87
|
kumoai/trainer/online_serving.py,sha256=T1jicl-qXiiWGQWUCwlfQsyxWUODybj_975gx9yglH4,9824
|
|
90
|
-
kumoai/trainer/trainer.py,sha256=
|
|
88
|
+
kumoai/trainer/trainer.py,sha256=AKumc3X2Vm3qxZSA85Dv_fSLC4JQ3rM7P0ixOWbEex0,20608
|
|
91
89
|
kumoai/trainer/util.py,sha256=LCXkY5MNl6NbEVd2OZ0aVqF6fvr3KiCFh6pH0igAi_g,4165
|
|
92
90
|
kumoai/utils/__init__.py,sha256=wAKgmwtMIGuiauW9D_GGKH95K-24Kgwmld27mm4nsro,278
|
|
93
91
|
kumoai/utils/datasets.py,sha256=UyAII-oAn7x3ombuvpbSQ41aVF9SYKBjQthTD-vcT2A,3011
|
|
94
92
|
kumoai/utils/forecasting.py,sha256=ZgKeUCbWLOot0giAkoigwU5du8LkrwAicFOi5hVn6wg,7624
|
|
95
93
|
kumoai/utils/progress_logger.py,sha256=tzwFrUO5VuiArxx9_tSETno8JF5rnFOedX26I2yDW10,5046
|
|
96
|
-
kumoai-2.
|
|
97
|
-
kumoai-2.
|
|
98
|
-
kumoai-2.
|
|
99
|
-
kumoai-2.
|
|
100
|
-
kumoai-2.
|
|
94
|
+
kumoai-2.12.0.dev202511061731.dist-info/licenses/LICENSE,sha256=ZUilBDp--4vbhsEr6f_Upw9rnIx09zQ3K9fXQ0rfd6w,1111
|
|
95
|
+
kumoai-2.12.0.dev202511061731.dist-info/METADATA,sha256=SPJcCqBkl2zsVJpfx8Z4REvTS6tcT1w4RzwKwk8GJVI,2112
|
|
96
|
+
kumoai-2.12.0.dev202511061731.dist-info/WHEEL,sha256=JLOMsP7F5qtkAkINx5UnzbFguf8CqZeraV8o04b0I8I,101
|
|
97
|
+
kumoai-2.12.0.dev202511061731.dist-info/top_level.txt,sha256=YjU6UcmomoDx30vEXLsOU784ED7VztQOsFApk1SFwvs,7
|
|
98
|
+
kumoai-2.12.0.dev202511061731.dist-info/RECORD,,
|
|
@@ -1,136 +0,0 @@
|
|
|
1
|
-
from abc import ABC, abstractmethod
|
|
2
|
-
from typing import Dict, Generic, Optional, Tuple, TypeVar, Union
|
|
3
|
-
|
|
4
|
-
from kumoapi.rfm import PQueryDefinition
|
|
5
|
-
from kumoapi.rfm.pquery import (
|
|
6
|
-
Aggregation,
|
|
7
|
-
AggregationType,
|
|
8
|
-
BoolOp,
|
|
9
|
-
Column,
|
|
10
|
-
Condition,
|
|
11
|
-
Filter,
|
|
12
|
-
Float,
|
|
13
|
-
FloatList,
|
|
14
|
-
Int,
|
|
15
|
-
IntList,
|
|
16
|
-
LogicalOperation,
|
|
17
|
-
MemberOp,
|
|
18
|
-
RelOp,
|
|
19
|
-
Str,
|
|
20
|
-
StrList,
|
|
21
|
-
)
|
|
22
|
-
|
|
23
|
-
TableData = TypeVar('TableData')
|
|
24
|
-
ColumnData = TypeVar('ColumnData')
|
|
25
|
-
IndexData = TypeVar('IndexData')
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
class PQueryBackend(Generic[TableData, ColumnData, IndexData], ABC):
|
|
29
|
-
@abstractmethod
|
|
30
|
-
def eval_aggregation_type(
|
|
31
|
-
self,
|
|
32
|
-
op: AggregationType,
|
|
33
|
-
feat: Optional[ColumnData],
|
|
34
|
-
batch: IndexData,
|
|
35
|
-
batch_size: int,
|
|
36
|
-
filter_na: bool = True,
|
|
37
|
-
) -> Tuple[ColumnData, IndexData]:
|
|
38
|
-
pass
|
|
39
|
-
|
|
40
|
-
@abstractmethod
|
|
41
|
-
def eval_rel_op(
|
|
42
|
-
self,
|
|
43
|
-
left: ColumnData,
|
|
44
|
-
op: RelOp,
|
|
45
|
-
right: Union[Int, Float, Str, None],
|
|
46
|
-
) -> ColumnData:
|
|
47
|
-
pass
|
|
48
|
-
|
|
49
|
-
@abstractmethod
|
|
50
|
-
def eval_member_op(
|
|
51
|
-
self,
|
|
52
|
-
left: ColumnData,
|
|
53
|
-
op: MemberOp,
|
|
54
|
-
right: Union[IntList, FloatList, StrList],
|
|
55
|
-
) -> ColumnData:
|
|
56
|
-
pass
|
|
57
|
-
|
|
58
|
-
@abstractmethod
|
|
59
|
-
def eval_bool_op(
|
|
60
|
-
self,
|
|
61
|
-
left: ColumnData,
|
|
62
|
-
op: BoolOp,
|
|
63
|
-
right: Optional[ColumnData],
|
|
64
|
-
) -> ColumnData:
|
|
65
|
-
pass
|
|
66
|
-
|
|
67
|
-
@abstractmethod
|
|
68
|
-
def eval_column(
|
|
69
|
-
self,
|
|
70
|
-
column: Column,
|
|
71
|
-
feat_dict: Dict[str, TableData],
|
|
72
|
-
filter_na: bool = True,
|
|
73
|
-
) -> Tuple[ColumnData, IndexData]:
|
|
74
|
-
pass
|
|
75
|
-
|
|
76
|
-
@abstractmethod
|
|
77
|
-
def eval_aggregation(
|
|
78
|
-
self,
|
|
79
|
-
aggr: Aggregation,
|
|
80
|
-
feat_dict: Dict[str, TableData],
|
|
81
|
-
time_dict: Dict[str, ColumnData],
|
|
82
|
-
batch_dict: Dict[str, IndexData],
|
|
83
|
-
anchor_time: ColumnData,
|
|
84
|
-
filter_na: bool = True,
|
|
85
|
-
num_forecasts: int = 1,
|
|
86
|
-
) -> Tuple[ColumnData, IndexData]:
|
|
87
|
-
pass
|
|
88
|
-
|
|
89
|
-
@abstractmethod
|
|
90
|
-
def eval_condition(
|
|
91
|
-
self,
|
|
92
|
-
condition: Condition,
|
|
93
|
-
feat_dict: Dict[str, TableData],
|
|
94
|
-
time_dict: Dict[str, ColumnData],
|
|
95
|
-
batch_dict: Dict[str, IndexData],
|
|
96
|
-
anchor_time: ColumnData,
|
|
97
|
-
filter_na: bool = True,
|
|
98
|
-
num_forecasts: int = 1,
|
|
99
|
-
) -> Tuple[ColumnData, IndexData]:
|
|
100
|
-
pass
|
|
101
|
-
|
|
102
|
-
@abstractmethod
|
|
103
|
-
def eval_logical_operation(
|
|
104
|
-
self,
|
|
105
|
-
logical_operation: LogicalOperation,
|
|
106
|
-
feat_dict: Dict[str, TableData],
|
|
107
|
-
time_dict: Dict[str, ColumnData],
|
|
108
|
-
batch_dict: Dict[str, IndexData],
|
|
109
|
-
anchor_time: ColumnData,
|
|
110
|
-
filter_na: bool = True,
|
|
111
|
-
num_forecasts: int = 1,
|
|
112
|
-
) -> Tuple[ColumnData, IndexData]:
|
|
113
|
-
pass
|
|
114
|
-
|
|
115
|
-
@abstractmethod
|
|
116
|
-
def eval_filter(
|
|
117
|
-
self,
|
|
118
|
-
filter: Filter,
|
|
119
|
-
feat_dict: Dict[str, TableData],
|
|
120
|
-
time_dict: Dict[str, ColumnData],
|
|
121
|
-
batch_dict: Dict[str, IndexData],
|
|
122
|
-
anchor_time: ColumnData,
|
|
123
|
-
) -> IndexData:
|
|
124
|
-
pass
|
|
125
|
-
|
|
126
|
-
@abstractmethod
|
|
127
|
-
def eval_pquery(
|
|
128
|
-
self,
|
|
129
|
-
query: PQueryDefinition,
|
|
130
|
-
feat_dict: Dict[str, TableData],
|
|
131
|
-
time_dict: Dict[str, ColumnData],
|
|
132
|
-
batch_dict: Dict[str, IndexData],
|
|
133
|
-
anchor_time: ColumnData,
|
|
134
|
-
num_forecasts: int = 1,
|
|
135
|
-
) -> Tuple[ColumnData, IndexData]:
|
|
136
|
-
pass
|