kumoai 2.11.0.dev202510161830__py3-none-any.whl → 2.12.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -5,19 +5,28 @@ from collections import defaultdict
5
5
  from collections.abc import Generator
6
6
  from contextlib import contextmanager
7
7
  from dataclasses import dataclass, replace
8
- from typing import Iterator, List, Literal, Optional, Union, overload
8
+ from typing import (
9
+ Any,
10
+ Dict,
11
+ Iterator,
12
+ List,
13
+ Literal,
14
+ Optional,
15
+ Tuple,
16
+ Union,
17
+ overload,
18
+ )
9
19
 
10
20
  import numpy as np
11
21
  import pandas as pd
12
22
  from kumoapi.model_plan import RunMode
13
- from kumoapi.pquery import QueryType
23
+ from kumoapi.pquery import QueryType, ValidatedPredictiveQuery
14
24
  from kumoapi.rfm import Context
15
25
  from kumoapi.rfm import Explanation as ExplanationConfig
16
26
  from kumoapi.rfm import (
17
- PQueryDefinition,
18
27
  RFMEvaluateRequest,
28
+ RFMParseQueryRequest,
19
29
  RFMPredictRequest,
20
- RFMValidateQueryRequest,
21
30
  )
22
31
  from kumoapi.task import TaskType
23
32
 
@@ -30,6 +39,7 @@ from kumoai.experimental.rfm.local_pquery_driver import (
30
39
  LocalPQueryDriver,
31
40
  date_offset_to_seconds,
32
41
  )
42
+ from kumoai.mixin import CastMixin
33
43
  from kumoai.utils import InteractiveProgressLogger, ProgressLogger
34
44
 
35
45
  _RANDOM_SEED = 42
@@ -60,6 +70,17 @@ _SIZE_LIMIT_MSG = ("Context size exceeds the 30MB limit. {stats}\nPlease "
60
70
  "beyond this for your use-case.")
61
71
 
62
72
 
73
+ @dataclass(repr=False)
74
+ class ExplainConfig(CastMixin):
75
+ """Configuration for explainability.
76
+
77
+ Args:
78
+ skip_summary: Whether to skip generating a human-readable summary of
79
+ the explanation.
80
+ """
81
+ skip_summary: bool = False
82
+
83
+
63
84
  @dataclass(repr=False)
64
85
  class Explanation:
65
86
  prediction: pd.DataFrame
@@ -87,6 +108,12 @@ class Explanation:
87
108
  def __repr__(self) -> str:
88
109
  return str((self.prediction, self.summary))
89
110
 
111
+ def _ipython_display_(self) -> None:
112
+ from IPython.display import Markdown, display
113
+
114
+ display(self.prediction)
115
+ display(Markdown(self.summary))
116
+
90
117
 
91
118
  class KumoRFM:
92
119
  r"""The Kumo Relational Foundation model (RFM) from the `KumoRFM: A
@@ -199,6 +226,7 @@ class KumoRFM:
199
226
  max_pq_iterations: int = 20,
200
227
  random_seed: Optional[int] = _RANDOM_SEED,
201
228
  verbose: Union[bool, ProgressLogger] = True,
229
+ use_prediction_time: bool = False,
202
230
  ) -> pd.DataFrame:
203
231
  pass
204
232
 
@@ -208,7 +236,7 @@ class KumoRFM:
208
236
  query: str,
209
237
  indices: Union[List[str], List[float], List[int], None] = None,
210
238
  *,
211
- explain: Literal[True],
239
+ explain: Union[Literal[True], ExplainConfig, Dict[str, Any]],
212
240
  anchor_time: Union[pd.Timestamp, Literal['entity'], None] = None,
213
241
  context_anchor_time: Union[pd.Timestamp, None] = None,
214
242
  run_mode: Union[RunMode, str] = RunMode.FAST,
@@ -217,6 +245,7 @@ class KumoRFM:
217
245
  max_pq_iterations: int = 20,
218
246
  random_seed: Optional[int] = _RANDOM_SEED,
219
247
  verbose: Union[bool, ProgressLogger] = True,
248
+ use_prediction_time: bool = False,
220
249
  ) -> Explanation:
221
250
  pass
222
251
 
@@ -225,7 +254,7 @@ class KumoRFM:
225
254
  query: str,
226
255
  indices: Union[List[str], List[float], List[int], None] = None,
227
256
  *,
228
- explain: bool = False,
257
+ explain: Union[bool, ExplainConfig, Dict[str, Any]] = False,
229
258
  anchor_time: Union[pd.Timestamp, Literal['entity'], None] = None,
230
259
  context_anchor_time: Union[pd.Timestamp, None] = None,
231
260
  run_mode: Union[RunMode, str] = RunMode.FAST,
@@ -234,6 +263,7 @@ class KumoRFM:
234
263
  max_pq_iterations: int = 20,
235
264
  random_seed: Optional[int] = _RANDOM_SEED,
236
265
  verbose: Union[bool, ProgressLogger] = True,
266
+ use_prediction_time: bool = False,
237
267
  ) -> Union[pd.DataFrame, Explanation]:
238
268
  """Returns predictions for a predictive query.
239
269
 
@@ -244,9 +274,12 @@ class KumoRFM:
244
274
  be generated for all indices, independent of whether they
245
275
  fulfill entity filter constraints. To pre-filter entities, use
246
276
  :meth:`~KumoRFM.is_valid_entity`.
247
- explain: If set to ``True``, will additionally explain the
248
- prediction. Explainability is currently only supported for
249
- single entity predictions with ``run_mode="FAST"``.
277
+ explain: Configuration for explainability.
278
+ If set to ``True``, will additionally explain the prediction.
279
+ Passing in an :class:`ExplainConfig` instance provides control
280
+ over which parts of explanation are generated.
281
+ Explainability is currently only supported for single entity
282
+ predictions with ``run_mode="FAST"``.
250
283
  anchor_time: The anchor timestamp for the prediction. If set to
251
284
  ``None``, will use the maximum timestamp in the data.
252
285
  If set to ``"entity"``, will use the timestamp of the entity.
@@ -264,45 +297,54 @@ class KumoRFM:
264
297
  entities to find valid labels.
265
298
  random_seed: A manual seed for generating pseudo-random numbers.
266
299
  verbose: Whether to print verbose output.
300
+ use_prediction_time: Whether to use the anchor timestamp as an
301
+ additional feature during prediction. This is typically
302
+ beneficial for time series forecasting tasks.
267
303
 
268
304
  Returns:
269
305
  The predictions as a :class:`pandas.DataFrame`.
270
- If ``explain=True``, additionally returns a textual summary that
271
- explains the prediction.
306
+ If ``explain`` is provided, returns an :class:`Explanation` object
307
+ containing the prediction, summary, and details.
272
308
  """
309
+ explain_config: Optional[ExplainConfig] = None
310
+ if explain is True:
311
+ explain_config = ExplainConfig()
312
+ elif explain is not False:
313
+ explain_config = ExplainConfig._cast(explain)
314
+
273
315
  query_def = self._parse_query(query)
316
+ query_str = query_def.to_string()
274
317
 
275
318
  if num_hops != 2 and num_neighbors is not None:
276
319
  warnings.warn(f"Received custom 'num_neighbors' option; ignoring "
277
320
  f"custom 'num_hops={num_hops}' option")
278
321
 
279
- if explain and run_mode in {RunMode.NORMAL, RunMode.BEST}:
322
+ if explain_config is not None and run_mode in {
323
+ RunMode.NORMAL, RunMode.BEST
324
+ }:
280
325
  warnings.warn(f"Explainability is currently only supported for "
281
326
  f"run mode 'FAST' (got '{run_mode}'). Provided run "
282
327
  f"mode has been reset. Please lower the run mode to "
283
328
  f"suppress this warning.")
284
329
 
285
330
  if indices is None:
286
- if query_def.entity.ids is None:
331
+ if query_def.rfm_entity_ids is None:
287
332
  raise ValueError("Cannot find entities to predict for. Please "
288
333
  "pass them via `predict(query, indices=...)`")
289
- indices = query_def.entity.ids.value
334
+ indices = query_def.get_rfm_entity_id_list()
290
335
  else:
291
- query_def = replace(
292
- query_def,
293
- entity=replace(query_def.entity, ids=None),
294
- )
336
+ query_def = replace(query_def, rfm_entity_ids=None)
295
337
 
296
338
  if len(indices) == 0:
297
339
  raise ValueError("At least one entity is required")
298
340
 
299
- if explain and len(indices) > 1:
341
+ if explain_config is not None and len(indices) > 1:
300
342
  raise ValueError(
301
343
  f"Cannot explain predictions for more than a single entity "
302
344
  f"(got {len(indices)})")
303
345
 
304
346
  query_repr = query_def.to_string(rich=True, exclude_predict=True)
305
- if explain:
347
+ if explain_config is not None:
306
348
  msg = f'[bold]EXPLAIN[/bold] {query_repr}'
307
349
  else:
308
350
  msg = f'[bold]PREDICT[/bold] {query_repr}'
@@ -314,8 +356,8 @@ class KumoRFM:
314
356
 
315
357
  batch_size: Optional[int] = None
316
358
  if self._batch_size == 'max':
317
- task_type = query_def.get_task_type(
318
- stypes=self._graph_store.stype_dict,
359
+ task_type = LocalPQueryDriver.get_task_type(
360
+ query_def,
319
361
  edge_types=self._graph_store.edge_types,
320
362
  )
321
363
  batch_size = _MAX_PRED_SIZE[task_type]
@@ -353,6 +395,8 @@ class KumoRFM:
353
395
  request = RFMPredictRequest(
354
396
  context=context,
355
397
  run_mode=RunMode(run_mode),
398
+ query=query_str,
399
+ use_prediction_time=use_prediction_time,
356
400
  )
357
401
  with warnings.catch_warnings():
358
402
  warnings.filterwarnings('ignore', message='gencode')
@@ -375,8 +419,11 @@ class KumoRFM:
375
419
 
376
420
  for attempt in range(self.num_retries + 1):
377
421
  try:
378
- if explain:
379
- resp = global_state.client.rfm_api.explain(_bytes)
422
+ if explain_config is not None:
423
+ resp = global_state.client.rfm_api.explain(
424
+ request=_bytes,
425
+ skip_summary=explain_config.skip_summary,
426
+ )
380
427
  summary = resp.summary
381
428
  details = resp.details
382
429
  else:
@@ -385,7 +432,7 @@ class KumoRFM:
385
432
 
386
433
  # Cast 'ENTITY' to correct data type:
387
434
  if 'ENTITY' in df:
388
- entity = query_def.entity.pkey.table_name
435
+ entity = query_def.entity_table
389
436
  pkey_map = self._graph_store.pkey_map_dict[entity]
390
437
  df['ENTITY'] = df['ENTITY'].astype(
391
438
  type(pkey_map.index[0]))
@@ -427,7 +474,7 @@ class KumoRFM:
427
474
  else:
428
475
  prediction = pd.concat(predictions, ignore_index=True)
429
476
 
430
- if explain:
477
+ if explain_config is not None:
431
478
  assert len(predictions) == 1
432
479
  assert summary is not None
433
480
  assert details is not None
@@ -461,11 +508,11 @@ class KumoRFM:
461
508
  query_def = self._parse_query(query)
462
509
 
463
510
  if indices is None:
464
- if query_def.entity.ids is None:
511
+ if query_def.rfm_entity_ids is None:
465
512
  raise ValueError("Cannot find entities to predict for. Please "
466
513
  "pass them via "
467
514
  "`is_valid_entity(query, indices=...)`")
468
- indices = query_def.entity.ids.value
515
+ indices = query_def.get_rfm_entity_id_list()
469
516
 
470
517
  if len(indices) == 0:
471
518
  raise ValueError("At least one entity is required")
@@ -477,14 +524,13 @@ class KumoRFM:
477
524
  self._validate_time(query_def, anchor_time, None, False)
478
525
  else:
479
526
  assert anchor_time == 'entity'
480
- if (query_def.entity.pkey.table_name
481
- not in self._graph_store.time_dict):
527
+ if (query_def.entity_table not in self._graph_store.time_dict):
482
528
  raise ValueError(f"Anchor time 'entity' requires the entity "
483
- f"table '{query_def.entity.pkey.table_name}' "
484
- f"to have a time column")
529
+ f"table '{query_def.entity_table}' "
530
+ f"to have a time column.")
485
531
 
486
532
  node = self._graph_store.get_node_id(
487
- table_name=query_def.entity.pkey.table_name,
533
+ table_name=query_def.entity_table,
488
534
  pkey=pd.Series(indices),
489
535
  )
490
536
  query_driver = LocalPQueryDriver(self._graph_store, query_def)
@@ -503,6 +549,7 @@ class KumoRFM:
503
549
  max_pq_iterations: int = 20,
504
550
  random_seed: Optional[int] = _RANDOM_SEED,
505
551
  verbose: Union[bool, ProgressLogger] = True,
552
+ use_prediction_time: bool = False,
506
553
  ) -> pd.DataFrame:
507
554
  """Evaluates a predictive query.
508
555
 
@@ -526,6 +573,9 @@ class KumoRFM:
526
573
  entities to find valid labels.
527
574
  random_seed: A manual seed for generating pseudo-random numbers.
528
575
  verbose: Whether to print verbose output.
576
+ use_prediction_time: Whether to use the anchor timestamp as an
577
+ additional feature during prediction. This is typically
578
+ beneficial for time series forecasting tasks.
529
579
 
530
580
  Returns:
531
581
  The metrics as a :class:`pandas.DataFrame`
@@ -536,10 +586,10 @@ class KumoRFM:
536
586
  warnings.warn(f"Received custom 'num_neighbors' option; ignoring "
537
587
  f"custom 'num_hops={num_hops}' option")
538
588
 
539
- if query_def.entity.ids is not None:
589
+ if query_def.rfm_entity_ids is not None:
540
590
  query_def = replace(
541
591
  query_def,
542
- entity=replace(query_def.entity, ids=None),
592
+ rfm_entity_ids=None,
543
593
  )
544
594
 
545
595
  query_repr = query_def.to_string(rich=True, exclude_predict=True)
@@ -569,6 +619,7 @@ class KumoRFM:
569
619
  context=context,
570
620
  run_mode=RunMode(run_mode),
571
621
  metrics=metrics,
622
+ use_prediction_time=use_prediction_time,
572
623
  )
573
624
  with warnings.catch_warnings():
574
625
  warnings.filterwarnings('ignore', message='Protobuf gencode')
@@ -579,7 +630,7 @@ class KumoRFM:
579
630
 
580
631
  if len(request_bytes) > _MAX_SIZE:
581
632
  stats_msg = Context.get_memory_stats(request_msg.context)
582
- raise ValueError(_SIZE_LIMIT_MSG.format(stats_msg=stats_msg))
633
+ raise ValueError(_SIZE_LIMIT_MSG.format(stats=stats_msg))
583
634
 
584
635
  try:
585
636
  resp = global_state.client.rfm_api.evaluate(request_bytes)
@@ -627,18 +678,19 @@ class KumoRFM:
627
678
 
628
679
  if anchor_time is None:
629
680
  anchor_time = self._graph_store.max_time
630
- anchor_time = anchor_time - (query_def.target.end_offset *
631
- query_def.num_forecasts)
681
+ if query_def.target_ast.date_offset_range is not None:
682
+ anchor_time = anchor_time - (
683
+ query_def.target_ast.date_offset_range.end_date_offset *
684
+ query_def.num_forecasts)
632
685
 
633
686
  assert anchor_time is not None
634
687
  if isinstance(anchor_time, pd.Timestamp):
635
688
  self._validate_time(query_def, anchor_time, None, evaluate=True)
636
689
  else:
637
690
  assert anchor_time == 'entity'
638
- if (query_def.entity.pkey.table_name
639
- not in self._graph_store.time_dict):
691
+ if (query_def.entity_table not in self._graph_store.time_dict):
640
692
  raise ValueError(f"Anchor time 'entity' requires the entity "
641
- f"table '{query_def.entity.pkey.table_name}' "
693
+ f"table '{query_def.entity_table}' "
642
694
  f"to have a time column")
643
695
 
644
696
  query_driver = LocalPQueryDriver(self._graph_store, query_def,
@@ -653,7 +705,7 @@ class KumoRFM:
653
705
  )
654
706
 
655
707
  entity = self._graph_store.pkey_map_dict[
656
- query_def.entity.pkey.table_name].index[node]
708
+ query_def.entity_table].index[node]
657
709
 
658
710
  return pd.DataFrame({
659
711
  'ENTITY': entity,
@@ -663,8 +715,8 @@ class KumoRFM:
663
715
 
664
716
  # Helpers #################################################################
665
717
 
666
- def _parse_query(self, query: str) -> PQueryDefinition:
667
- if isinstance(query, PQueryDefinition):
718
+ def _parse_query(self, query: str) -> ValidatedPredictiveQuery:
719
+ if isinstance(query, ValidatedPredictiveQuery):
668
720
  return query
669
721
 
670
722
  if isinstance(query, str) and query.strip()[:9].lower() == 'evaluate ':
@@ -674,12 +726,12 @@ class KumoRFM:
674
726
  "predictions or evaluations.")
675
727
 
676
728
  try:
677
- request = RFMValidateQueryRequest(
729
+ request = RFMParseQueryRequest(
678
730
  query=query,
679
731
  graph_definition=self._graph_def,
680
732
  )
681
733
 
682
- resp = global_state.client.rfm_api.validate_query(request)
734
+ resp = global_state.client.rfm_api.parse_query(request)
683
735
  # TODO Expose validation warnings.
684
736
 
685
737
  if len(resp.validation_response.warnings) > 0:
@@ -690,7 +742,7 @@ class KumoRFM:
690
742
  warnings.warn(f"Encountered the following warnings during "
691
743
  f"parsing:\n{msg}")
692
744
 
693
- return resp.query_definition
745
+ return resp.query
694
746
  except HTTPException as e:
695
747
  try:
696
748
  msg = json.loads(e.detail)['detail']
@@ -701,7 +753,7 @@ class KumoRFM:
701
753
 
702
754
  def _validate_time(
703
755
  self,
704
- query: PQueryDefinition,
756
+ query: ValidatedPredictiveQuery,
705
757
  anchor_time: pd.Timestamp,
706
758
  context_anchor_time: Union[pd.Timestamp, None],
707
759
  evaluate: bool,
@@ -724,6 +776,11 @@ class KumoRFM:
724
776
  f"only contains data back to "
725
777
  f"'{self._graph_store.min_time}'.")
726
778
 
779
+ if query.target_ast.date_offset_range is not None:
780
+ end_offset = query.target_ast.date_offset_range.end_date_offset
781
+ else:
782
+ end_offset = pd.DateOffset(0)
783
+ forecast_end_offset = end_offset * query.num_forecasts
727
784
  if (context_anchor_time is not None
728
785
  and context_anchor_time > anchor_time):
729
786
  warnings.warn(f"Context anchor timestamp "
@@ -732,19 +789,18 @@ class KumoRFM:
732
789
  f"(got '{anchor_time}'). Please make sure this is "
733
790
  f"intended.")
734
791
  elif (query.query_type == QueryType.TEMPORAL
735
- and context_anchor_time is not None and context_anchor_time +
736
- query.target.end_offset * query.num_forecasts > anchor_time):
792
+ and context_anchor_time is not None
793
+ and context_anchor_time + forecast_end_offset > anchor_time):
737
794
  warnings.warn(f"Aggregation for context examples at timestamp "
738
795
  f"'{context_anchor_time}' will leak information "
739
796
  f"from the prediction anchor timestamp "
740
797
  f"'{anchor_time}'. Please make sure this is "
741
798
  f"intended.")
742
799
 
743
- elif (context_anchor_time is not None and context_anchor_time -
744
- query.target.end_offset * query.num_forecasts
800
+ elif (context_anchor_time is not None
801
+ and context_anchor_time - forecast_end_offset
745
802
  < self._graph_store.min_time):
746
- _time = context_anchor_time - (query.target.end_offset *
747
- query.num_forecasts)
803
+ _time = context_anchor_time - forecast_end_offset
748
804
  warnings.warn(f"Context anchor timestamp is too early or "
749
805
  f"aggregation time range is too large. To form "
750
806
  f"proper input data, we would need data back to "
@@ -757,8 +813,7 @@ class KumoRFM:
757
813
  f"latest timestamp '{self._graph_store.max_time}' "
758
814
  f"in the data. Please make sure this is intended.")
759
815
 
760
- max_eval_time = (self._graph_store.max_time -
761
- query.target.end_offset * query.num_forecasts)
816
+ max_eval_time = self._graph_store.max_time - forecast_end_offset
762
817
  if evaluate and anchor_time > max_eval_time:
763
818
  raise ValueError(
764
819
  f"Anchor timestamp for evaluation is after the latest "
@@ -766,7 +821,7 @@ class KumoRFM:
766
821
 
767
822
  def _get_context(
768
823
  self,
769
- query: PQueryDefinition,
824
+ query: ValidatedPredictiveQuery,
770
825
  indices: Union[List[str], List[float], List[int], None],
771
826
  anchor_time: Union[pd.Timestamp, Literal['entity'], None],
772
827
  context_anchor_time: Union[pd.Timestamp, None],
@@ -794,8 +849,8 @@ class KumoRFM:
794
849
  f"must go beyond this for your use-case.")
795
850
 
796
851
  query_driver = LocalPQueryDriver(self._graph_store, query, random_seed)
797
- task_type = query.get_task_type(
798
- stypes=self._graph_store.stype_dict,
852
+ task_type = LocalPQueryDriver.get_task_type(
853
+ query,
799
854
  edge_types=self._graph_store.edge_types,
800
855
  )
801
856
 
@@ -827,11 +882,15 @@ class KumoRFM:
827
882
  else:
828
883
  num_neighbors = [64, 64, 8, 8, 4, 4][:num_hops]
829
884
 
885
+ if query.target_ast.date_offset_range is None:
886
+ end_offset = pd.DateOffset(0)
887
+ else:
888
+ end_offset = query.target_ast.date_offset_range.end_date_offset
889
+ forecast_end_offset = end_offset * query.num_forecasts
830
890
  if anchor_time is None:
831
891
  anchor_time = self._graph_store.max_time
832
892
  if evaluate:
833
- anchor_time = anchor_time - (query.target.end_offset *
834
- query.num_forecasts)
893
+ anchor_time = anchor_time - forecast_end_offset
835
894
  if logger is not None:
836
895
  assert isinstance(anchor_time, pd.Timestamp)
837
896
  if anchor_time == pd.Timestamp.min:
@@ -846,15 +905,14 @@ class KumoRFM:
846
905
  assert anchor_time is not None
847
906
  if isinstance(anchor_time, pd.Timestamp):
848
907
  if context_anchor_time is None:
849
- context_anchor_time = anchor_time - (query.target.end_offset *
850
- query.num_forecasts)
908
+ context_anchor_time = anchor_time - forecast_end_offset
851
909
  self._validate_time(query, anchor_time, context_anchor_time,
852
910
  evaluate)
853
911
  else:
854
912
  assert anchor_time == 'entity'
855
- if query.entity.pkey.table_name not in self._graph_store.time_dict:
913
+ if query.entity_table not in self._graph_store.time_dict:
856
914
  raise ValueError(f"Anchor time 'entity' requires the entity "
857
- f"table '{query.entity.pkey.table_name}' to "
915
+ f"table '{query.entity_table}' to "
858
916
  f"have a time column")
859
917
  if context_anchor_time is not None:
860
918
  warnings.warn("Ignoring option 'context_anchor_time' for "
@@ -905,7 +963,7 @@ class KumoRFM:
905
963
  f"in batches")
906
964
 
907
965
  test_node = self._graph_store.get_node_id(
908
- table_name=query.entity.pkey.table_name,
966
+ table_name=query.entity_table,
909
967
  pkey=pd.Series(indices),
910
968
  )
911
969
 
@@ -913,8 +971,7 @@ class KumoRFM:
913
971
  test_time = pd.Series(anchor_time).repeat(
914
972
  len(test_node)).reset_index(drop=True)
915
973
  else:
916
- time = self._graph_store.time_dict[
917
- query.entity.pkey.table_name]
974
+ time = self._graph_store.time_dict[query.entity_table]
918
975
  time = time[test_node] * 1000**3
919
976
  test_time = pd.Series(time, dtype='datetime64[ns]')
920
977
 
@@ -947,12 +1004,23 @@ class KumoRFM:
947
1004
  raise NotImplementedError
948
1005
  logger.log(msg)
949
1006
 
950
- entity_table_names = query.get_entity_table_names(
951
- self._graph_store.edge_types)
1007
+ entity_table_names: Tuple[str, ...]
1008
+ if task_type.is_link_pred:
1009
+ final_aggr = query.get_final_target_aggregation()
1010
+ assert final_aggr is not None
1011
+ edge_fkey = final_aggr._get_target_column_name()
1012
+ for edge_type in self._graph_store.edge_types:
1013
+ if edge_fkey == f'{edge_type[0]}.{edge_type[1]}':
1014
+ entity_table_names = (
1015
+ query.entity_table,
1016
+ edge_type[2],
1017
+ )
1018
+ else:
1019
+ entity_table_names = (query.entity_table, )
952
1020
 
953
1021
  # Exclude the entity anchor time from the feature set to prevent
954
1022
  # running out-of-distribution between in-context and test examples:
955
- exclude_cols_dict = query.exclude_cols_dict
1023
+ exclude_cols_dict = query.get_exclude_cols_dict()
956
1024
  if anchor_time == 'entity':
957
1025
  if entity_table_names[0] not in exclude_cols_dict:
958
1026
  exclude_cols_dict[entity_table_names[0]] = []
@@ -981,7 +1049,7 @@ class KumoRFM:
981
1049
 
982
1050
  step_size: Optional[int] = None
983
1051
  if query.query_type == QueryType.TEMPORAL:
984
- step_size = date_offset_to_seconds(query.target.end_offset)
1052
+ step_size = date_offset_to_seconds(end_offset)
985
1053
 
986
1054
  return Context(
987
1055
  task_type=task_type,
@@ -370,9 +370,11 @@ class PredictiveQuery:
370
370
  train_table_job_api = global_state.client.generate_train_table_job_api
371
371
  job_id: GenerateTrainTableJobID = train_table_job_api.create(
372
372
  GenerateTrainTableRequest(
373
- dict(custom_tags), pq_id, plan,
374
- graph_snapshot_id=self.graph.snapshot(
375
- non_blocking=non_blocking)))
373
+ dict(custom_tags),
374
+ pq_id,
375
+ plan,
376
+ None,
377
+ ))
376
378
 
377
379
  self._train_table = TrainingTableJob(job_id=job_id)
378
380
  if non_blocking:
@@ -451,9 +453,11 @@ class PredictiveQuery:
451
453
  bp_table_api = global_state.client.generate_prediction_table_job_api
452
454
  job_id: GeneratePredictionTableJobID = bp_table_api.create(
453
455
  GeneratePredictionTableRequest(
454
- dict(custom_tags), pq_id, plan,
455
- graph_snapshot_id=self.graph.snapshot(
456
- non_blocking=non_blocking)))
456
+ dict(custom_tags),
457
+ pq_id,
458
+ plan,
459
+ None,
460
+ ))
457
461
 
458
462
  self._prediction_table = PredictionTableJob(job_id=job_id)
459
463
  if non_blocking:
kumoai/trainer/trainer.py CHANGED
@@ -20,7 +20,6 @@ from kumoapi.jobs import (
20
20
  TrainingJobResource,
21
21
  )
22
22
  from kumoapi.model_plan import ModelPlan
23
- from kumoapi.task import TaskType
24
23
 
25
24
  from kumoai import global_state
26
25
  from kumoai.artifact_export.config import OutputConfig
@@ -405,15 +404,15 @@ class Trainer:
405
404
  pred_table_data_path = prediction_table.table_data_uri
406
405
 
407
406
  api = global_state.client.batch_prediction_job_api
408
-
409
- from kumoai.pquery.predictive_query import PredictiveQuery
410
- pquery = PredictiveQuery.load_from_training_job(training_job_id)
411
- if pquery.get_task_type() == TaskType.BINARY_CLASSIFICATION:
412
- if binary_classification_threshold is None:
413
- logger.warning("No binary classification threshold provided. "
414
- "Using default threshold of 0.5.")
415
- binary_classification_threshold = 0.5
416
-
407
+ # Remove to resolve https://github.com/kumo-ai/kumo/issues/24250
408
+ # from kumoai.pquery.predictive_query import PredictiveQuery
409
+ # pquery = PredictiveQuery.load_from_training_job(training_job_id)
410
+ # if pquery.get_task_type() == TaskType.BINARY_CLASSIFICATION:
411
+ # if binary_classification_threshold is None:
412
+ # logger.warning(
413
+ # "No binary classification threshold provided. "
414
+ # "Using default threshold of 0.5.")
415
+ # binary_classification_threshold = 0.5
417
416
  job_id, response = api.maybe_create(
418
417
  BatchPredictionRequest(
419
418
  dict(custom_tags),
@@ -103,10 +103,13 @@ class InteractiveProgressLogger(ProgressLogger):
103
103
  self._progress.update(self._task, advance=1) # type: ignore
104
104
 
105
105
  def __enter__(self) -> Self:
106
+ from kumoai import in_notebook
107
+
106
108
  super().__enter__()
107
109
 
108
- sys.stdout.write("\x1b]9;4;3\x07")
109
- sys.stdout.flush()
110
+ if not in_notebook(): # Render progress bar in TUI.
111
+ sys.stdout.write("\x1b]9;4;3\x07")
112
+ sys.stdout.flush()
110
113
 
111
114
  if self.verbose:
112
115
  self._live = Live(
@@ -119,6 +122,8 @@ class InteractiveProgressLogger(ProgressLogger):
119
122
  return self
120
123
 
121
124
  def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
125
+ from kumoai import in_notebook
126
+
122
127
  super().__exit__(exc_type, exc_val, exc_tb)
123
128
 
124
129
  if exc_type is not None:
@@ -134,8 +139,9 @@ class InteractiveProgressLogger(ProgressLogger):
134
139
  self._live.stop()
135
140
  self._live = None
136
141
 
137
- sys.stdout.write("\x1b]9;4;0\x07")
138
- sys.stdout.flush()
142
+ if not in_notebook():
143
+ sys.stdout.write("\x1b]9;4;0\x07")
144
+ sys.stdout.flush()
139
145
 
140
146
  def __rich_console__(
141
147
  self,
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: kumoai
3
- Version: 2.11.0.dev202510161830
3
+ Version: 2.12.1
4
4
  Summary: AI on the Modern Data Stack
5
5
  Author-email: "Kumo.AI" <hello@kumo.ai>
6
6
  License-Expression: MIT
@@ -23,7 +23,7 @@ Requires-Dist: requests>=2.28.2
23
23
  Requires-Dist: urllib3
24
24
  Requires-Dist: plotly
25
25
  Requires-Dist: typing_extensions>=4.5.0
26
- Requires-Dist: kumo-api==0.38.0
26
+ Requires-Dist: kumo-api==0.45.0
27
27
  Requires-Dist: tqdm>=4.66.0
28
28
  Requires-Dist: aiohttp>=3.10.0
29
29
  Requires-Dist: pydantic>=1.10.21