kumoai 2.14.0.dev202512151351__cp313-cp313-macosx_11_0_arm64.whl → 2.14.0.dev202512211732__cp313-cp313-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (33):
  1. kumoai/_version.py +1 -1
  2. kumoai/experimental/rfm/__init__.py +33 -8
  3. kumoai/experimental/rfm/authenticate.py +3 -4
  4. kumoai/experimental/rfm/backend/local/graph_store.py +25 -25
  5. kumoai/experimental/rfm/backend/local/table.py +16 -21
  6. kumoai/experimental/rfm/backend/snow/sampler.py +22 -34
  7. kumoai/experimental/rfm/backend/snow/table.py +67 -33
  8. kumoai/experimental/rfm/backend/sqlite/__init__.py +2 -2
  9. kumoai/experimental/rfm/backend/sqlite/sampler.py +21 -26
  10. kumoai/experimental/rfm/backend/sqlite/table.py +54 -26
  11. kumoai/experimental/rfm/base/__init__.py +8 -0
  12. kumoai/experimental/rfm/base/column.py +14 -12
  13. kumoai/experimental/rfm/base/column_expression.py +50 -0
  14. kumoai/experimental/rfm/base/sql_sampler.py +31 -3
  15. kumoai/experimental/rfm/base/sql_table.py +229 -0
  16. kumoai/experimental/rfm/base/table.py +162 -143
  17. kumoai/experimental/rfm/graph.py +242 -95
  18. kumoai/experimental/rfm/infer/__init__.py +6 -4
  19. kumoai/experimental/rfm/infer/dtype.py +3 -3
  20. kumoai/experimental/rfm/infer/pkey.py +4 -2
  21. kumoai/experimental/rfm/infer/stype.py +35 -0
  22. kumoai/experimental/rfm/infer/time_col.py +1 -2
  23. kumoai/experimental/rfm/pquery/executor.py +27 -27
  24. kumoai/experimental/rfm/pquery/pandas_executor.py +29 -31
  25. kumoai/experimental/rfm/rfm.py +86 -80
  26. kumoai/experimental/rfm/sagemaker.py +4 -4
  27. kumoai/utils/__init__.py +1 -2
  28. kumoai/utils/progress_logger.py +178 -12
  29. {kumoai-2.14.0.dev202512151351.dist-info → kumoai-2.14.0.dev202512211732.dist-info}/METADATA +2 -1
  30. {kumoai-2.14.0.dev202512151351.dist-info → kumoai-2.14.0.dev202512211732.dist-info}/RECORD +33 -30
  31. {kumoai-2.14.0.dev202512151351.dist-info → kumoai-2.14.0.dev202512211732.dist-info}/WHEEL +0 -0
  32. {kumoai-2.14.0.dev202512151351.dist-info → kumoai-2.14.0.dev202512211732.dist-info}/licenses/LICENSE +0 -0
  33. {kumoai-2.14.0.dev202512151351.dist-info → kumoai-2.14.0.dev202512211732.dist-info}/top_level.txt +0 -0
@@ -1,6 +1,5 @@
1
1
  import re
2
2
  import warnings
3
- from typing import Optional
4
3
 
5
4
  import pandas as pd
6
5
 
@@ -8,7 +7,7 @@ import pandas as pd
8
7
  def infer_time_column(
9
8
  df: pd.DataFrame,
10
9
  candidates: list[str],
11
- ) -> Optional[str]:
10
+ ) -> str | None:
12
11
  r"""Auto-detect potential time column.
13
12
 
14
13
  Args:
@@ -1,5 +1,5 @@
1
1
  from abc import ABC, abstractmethod
2
- from typing import Dict, Generic, Tuple, TypeVar
2
+ from typing import Generic, TypeVar
3
3
 
4
4
  from kumoapi.pquery import ValidatedPredictiveQuery
5
5
  from kumoapi.pquery.AST import (
@@ -21,82 +21,82 @@ class PQueryExecutor(Generic[TableData, ColumnData, IndexData], ABC):
21
21
  def execute_column(
22
22
  self,
23
23
  column: Column,
24
- feat_dict: Dict[str, TableData],
24
+ feat_dict: dict[str, TableData],
25
25
  filter_na: bool = True,
26
- ) -> Tuple[ColumnData, IndexData]:
26
+ ) -> tuple[ColumnData, IndexData]:
27
27
  pass
28
28
 
29
29
  @abstractmethod
30
30
  def execute_aggregation(
31
31
  self,
32
32
  aggr: Aggregation,
33
- feat_dict: Dict[str, TableData],
34
- time_dict: Dict[str, ColumnData],
35
- batch_dict: Dict[str, IndexData],
33
+ feat_dict: dict[str, TableData],
34
+ time_dict: dict[str, ColumnData],
35
+ batch_dict: dict[str, IndexData],
36
36
  anchor_time: ColumnData,
37
37
  filter_na: bool = True,
38
38
  num_forecasts: int = 1,
39
- ) -> Tuple[ColumnData, IndexData]:
39
+ ) -> tuple[ColumnData, IndexData]:
40
40
  pass
41
41
 
42
42
  @abstractmethod
43
43
  def execute_condition(
44
44
  self,
45
45
  condition: Condition,
46
- feat_dict: Dict[str, TableData],
47
- time_dict: Dict[str, ColumnData],
48
- batch_dict: Dict[str, IndexData],
46
+ feat_dict: dict[str, TableData],
47
+ time_dict: dict[str, ColumnData],
48
+ batch_dict: dict[str, IndexData],
49
49
  anchor_time: ColumnData,
50
50
  filter_na: bool = True,
51
51
  num_forecasts: int = 1,
52
- ) -> Tuple[ColumnData, IndexData]:
52
+ ) -> tuple[ColumnData, IndexData]:
53
53
  pass
54
54
 
55
55
  @abstractmethod
56
56
  def execute_logical_operation(
57
57
  self,
58
58
  logical_operation: LogicalOperation,
59
- feat_dict: Dict[str, TableData],
60
- time_dict: Dict[str, ColumnData],
61
- batch_dict: Dict[str, IndexData],
59
+ feat_dict: dict[str, TableData],
60
+ time_dict: dict[str, ColumnData],
61
+ batch_dict: dict[str, IndexData],
62
62
  anchor_time: ColumnData,
63
63
  filter_na: bool = True,
64
64
  num_forecasts: int = 1,
65
- ) -> Tuple[ColumnData, IndexData]:
65
+ ) -> tuple[ColumnData, IndexData]:
66
66
  pass
67
67
 
68
68
  @abstractmethod
69
69
  def execute_join(
70
70
  self,
71
71
  join: Join,
72
- feat_dict: Dict[str, TableData],
73
- time_dict: Dict[str, ColumnData],
74
- batch_dict: Dict[str, IndexData],
72
+ feat_dict: dict[str, TableData],
73
+ time_dict: dict[str, ColumnData],
74
+ batch_dict: dict[str, IndexData],
75
75
  anchor_time: ColumnData,
76
76
  filter_na: bool = True,
77
77
  num_forecasts: int = 1,
78
- ) -> Tuple[ColumnData, IndexData]:
78
+ ) -> tuple[ColumnData, IndexData]:
79
79
  pass
80
80
 
81
81
  @abstractmethod
82
82
  def execute_filter(
83
83
  self,
84
84
  filter: Filter,
85
- feat_dict: Dict[str, TableData],
86
- time_dict: Dict[str, ColumnData],
87
- batch_dict: Dict[str, IndexData],
85
+ feat_dict: dict[str, TableData],
86
+ time_dict: dict[str, ColumnData],
87
+ batch_dict: dict[str, IndexData],
88
88
  anchor_time: ColumnData,
89
- ) -> Tuple[ColumnData, IndexData]:
89
+ ) -> tuple[ColumnData, IndexData]:
90
90
  pass
91
91
 
92
92
  @abstractmethod
93
93
  def execute(
94
94
  self,
95
95
  query: ValidatedPredictiveQuery,
96
- feat_dict: Dict[str, TableData],
97
- time_dict: Dict[str, ColumnData],
98
- batch_dict: Dict[str, IndexData],
96
+ feat_dict: dict[str, TableData],
97
+ time_dict: dict[str, ColumnData],
98
+ batch_dict: dict[str, IndexData],
99
99
  anchor_time: ColumnData,
100
100
  num_forecasts: int = 1,
101
- ) -> Tuple[ColumnData, IndexData]:
101
+ ) -> tuple[ColumnData, IndexData]:
102
102
  pass
@@ -1,5 +1,3 @@
1
- from typing import Dict, List, Tuple
2
-
3
1
  import numpy as np
4
2
  import pandas as pd
5
3
  from kumoapi.pquery import ValidatedPredictiveQuery
@@ -22,9 +20,9 @@ class PQueryPandasExecutor(PQueryExecutor[pd.DataFrame, pd.Series,
22
20
  def execute_column(
23
21
  self,
24
22
  column: Column,
25
- feat_dict: Dict[str, pd.DataFrame],
23
+ feat_dict: dict[str, pd.DataFrame],
26
24
  filter_na: bool = True,
27
- ) -> Tuple[pd.Series, np.ndarray]:
25
+ ) -> tuple[pd.Series, np.ndarray]:
28
26
  table_name, column_name = column.fqn.split(".")
29
27
  if column_name == '*':
30
28
  out = pd.Series(np.ones(len(feat_dict[table_name]), dtype='int64'))
@@ -60,7 +58,7 @@ class PQueryPandasExecutor(PQueryExecutor[pd.DataFrame, pd.Series,
60
58
  batch: np.ndarray,
61
59
  batch_size: int,
62
60
  filter_na: bool = True,
63
- ) -> Tuple[pd.Series, np.ndarray]:
61
+ ) -> tuple[pd.Series, np.ndarray]:
64
62
 
65
63
  mask = feat.notna()
66
64
  feat, batch = feat[mask], batch[mask]
@@ -104,13 +102,13 @@ class PQueryPandasExecutor(PQueryExecutor[pd.DataFrame, pd.Series,
104
102
  def execute_aggregation(
105
103
  self,
106
104
  aggr: Aggregation,
107
- feat_dict: Dict[str, pd.DataFrame],
108
- time_dict: Dict[str, pd.Series],
109
- batch_dict: Dict[str, np.ndarray],
105
+ feat_dict: dict[str, pd.DataFrame],
106
+ time_dict: dict[str, pd.Series],
107
+ batch_dict: dict[str, np.ndarray],
110
108
  anchor_time: pd.Series,
111
109
  filter_na: bool = True,
112
110
  num_forecasts: int = 1,
113
- ) -> Tuple[pd.Series, np.ndarray]:
111
+ ) -> tuple[pd.Series, np.ndarray]:
114
112
  target_table = aggr._get_target_column_name().split('.')[0]
115
113
  target_batch = batch_dict[target_table]
116
114
  target_time = time_dict[target_table]
@@ -131,8 +129,8 @@ class PQueryPandasExecutor(PQueryExecutor[pd.DataFrame, pd.Series,
131
129
  filter_na=True,
132
130
  )
133
131
 
134
- outs: List[pd.Series] = []
135
- masks: List[np.ndarray] = []
132
+ outs: list[pd.Series] = []
133
+ masks: list[np.ndarray] = []
136
134
  for _ in range(num_forecasts):
137
135
  anchor_target_time = anchor_time.iloc[target_batch]
138
136
  anchor_target_time = anchor_target_time.reset_index(drop=True)
@@ -226,13 +224,13 @@ class PQueryPandasExecutor(PQueryExecutor[pd.DataFrame, pd.Series,
226
224
  def execute_condition(
227
225
  self,
228
226
  condition: Condition,
229
- feat_dict: Dict[str, pd.DataFrame],
230
- time_dict: Dict[str, pd.Series],
231
- batch_dict: Dict[str, np.ndarray],
227
+ feat_dict: dict[str, pd.DataFrame],
228
+ time_dict: dict[str, pd.Series],
229
+ batch_dict: dict[str, np.ndarray],
232
230
  anchor_time: pd.Series,
233
231
  filter_na: bool = True,
234
232
  num_forecasts: int = 1,
235
- ) -> Tuple[pd.Series, np.ndarray]:
233
+ ) -> tuple[pd.Series, np.ndarray]:
236
234
  if num_forecasts > 1:
237
235
  raise NotImplementedError("Forecasting not yet implemented for "
238
236
  "non-regression tasks")
@@ -306,13 +304,13 @@ class PQueryPandasExecutor(PQueryExecutor[pd.DataFrame, pd.Series,
306
304
  def execute_logical_operation(
307
305
  self,
308
306
  logical_operation: LogicalOperation,
309
- feat_dict: Dict[str, pd.DataFrame],
310
- time_dict: Dict[str, pd.Series],
311
- batch_dict: Dict[str, np.ndarray],
307
+ feat_dict: dict[str, pd.DataFrame],
308
+ time_dict: dict[str, pd.Series],
309
+ batch_dict: dict[str, np.ndarray],
312
310
  anchor_time: pd.Series,
313
311
  filter_na: bool = True,
314
312
  num_forecasts: int = 1,
315
- ) -> Tuple[pd.Series, np.ndarray]:
313
+ ) -> tuple[pd.Series, np.ndarray]:
316
314
  if num_forecasts > 1:
317
315
  raise NotImplementedError("Forecasting not yet implemented for "
318
316
  "non-regression tasks")
@@ -370,13 +368,13 @@ class PQueryPandasExecutor(PQueryExecutor[pd.DataFrame, pd.Series,
370
368
  def execute_join(
371
369
  self,
372
370
  join: Join,
373
- feat_dict: Dict[str, pd.DataFrame],
374
- time_dict: Dict[str, pd.Series],
375
- batch_dict: Dict[str, np.ndarray],
371
+ feat_dict: dict[str, pd.DataFrame],
372
+ time_dict: dict[str, pd.Series],
373
+ batch_dict: dict[str, np.ndarray],
376
374
  anchor_time: pd.Series,
377
375
  filter_na: bool = True,
378
376
  num_forecasts: int = 1,
379
- ) -> Tuple[pd.Series, np.ndarray]:
377
+ ) -> tuple[pd.Series, np.ndarray]:
380
378
  if isinstance(join.rhs_target, Aggregation):
381
379
  return self.execute_aggregation(
382
380
  aggr=join.rhs_target,
@@ -393,12 +391,12 @@ class PQueryPandasExecutor(PQueryExecutor[pd.DataFrame, pd.Series,
393
391
  def execute_filter(
394
392
  self,
395
393
  filter: Filter,
396
- feat_dict: Dict[str, pd.DataFrame],
397
- time_dict: Dict[str, pd.Series],
398
- batch_dict: Dict[str, np.ndarray],
394
+ feat_dict: dict[str, pd.DataFrame],
395
+ time_dict: dict[str, pd.Series],
396
+ batch_dict: dict[str, np.ndarray],
399
397
  anchor_time: pd.Series,
400
398
  filter_na: bool = True,
401
- ) -> Tuple[pd.Series, np.ndarray]:
399
+ ) -> tuple[pd.Series, np.ndarray]:
402
400
  out, mask = self.execute_column(
403
401
  column=filter.target,
404
402
  feat_dict=feat_dict,
@@ -431,12 +429,12 @@ class PQueryPandasExecutor(PQueryExecutor[pd.DataFrame, pd.Series,
431
429
  def execute(
432
430
  self,
433
431
  query: ValidatedPredictiveQuery,
434
- feat_dict: Dict[str, pd.DataFrame],
435
- time_dict: Dict[str, pd.Series],
436
- batch_dict: Dict[str, np.ndarray],
432
+ feat_dict: dict[str, pd.DataFrame],
433
+ time_dict: dict[str, pd.Series],
434
+ batch_dict: dict[str, np.ndarray],
437
435
  anchor_time: pd.Series,
438
436
  num_forecasts: int = 1,
439
- ) -> Tuple[pd.Series, np.ndarray]:
437
+ ) -> tuple[pd.Series, np.ndarray]:
440
438
  if isinstance(query.entity_ast, Column):
441
439
  out, mask = self.execute_column(
442
440
  column=query.entity_ast,
@@ -2,20 +2,10 @@ import json
2
2
  import time
3
3
  import warnings
4
4
  from collections import defaultdict
5
- from collections.abc import Generator
5
+ from collections.abc import Generator, Iterator
6
6
  from contextlib import contextmanager
7
7
  from dataclasses import dataclass, replace
8
- from typing import (
9
- Any,
10
- Dict,
11
- Iterator,
12
- List,
13
- Literal,
14
- Optional,
15
- Tuple,
16
- Union,
17
- overload,
18
- )
8
+ from typing import Any, Literal, overload
19
9
 
20
10
  import numpy as np
21
11
  import pandas as pd
@@ -38,12 +28,13 @@ from kumoapi.rfm import (
38
28
  from kumoapi.task import TaskType
39
29
  from kumoapi.typing import AggregationType, Stype
40
30
 
31
+ from kumoai import in_notebook, in_snowflake_notebook
41
32
  from kumoai.client.rfm import RFMAPI
42
33
  from kumoai.exceptions import HTTPException
43
34
  from kumoai.experimental.rfm import Graph
44
35
  from kumoai.experimental.rfm.base import DataBackend, Sampler
45
36
  from kumoai.mixin import CastMixin
46
- from kumoai.utils import InteractiveProgressLogger, ProgressLogger
37
+ from kumoai.utils import ProgressLogger
47
38
 
48
39
  _RANDOM_SEED = 42
49
40
 
@@ -98,24 +89,41 @@ class Explanation:
98
89
  def __getitem__(self, index: Literal[1]) -> str:
99
90
  pass
100
91
 
101
- def __getitem__(self, index: int) -> Union[pd.DataFrame, str]:
92
+ def __getitem__(self, index: int) -> pd.DataFrame | str:
102
93
  if index == 0:
103
94
  return self.prediction
104
95
  if index == 1:
105
96
  return self.summary
106
97
  raise IndexError("Index out of range")
107
98
 
108
- def __iter__(self) -> Iterator[Union[pd.DataFrame, str]]:
99
+ def __iter__(self) -> Iterator[pd.DataFrame | str]:
109
100
  return iter((self.prediction, self.summary))
110
101
 
111
102
  def __repr__(self) -> str:
112
103
  return str((self.prediction, self.summary))
113
104
 
114
- def _ipython_display_(self) -> None:
115
- from IPython.display import Markdown, display
105
+ def print(self) -> None:
106
+ r"""Prints the explanation."""
107
+ if in_snowflake_notebook():
108
+ import streamlit as st
109
+ st.dataframe(self.prediction, hide_index=True)
110
+ st.markdown(self.summary)
111
+ elif in_notebook():
112
+ from IPython.display import Markdown, display
113
+ try:
114
+ if hasattr(self.prediction.style, 'hide'):
115
+ display(self.prediction.hide(axis='index')) # pandas=2
116
+ else:
117
+ display(self.prediction.hide_index()) # pandas <1.3
118
+ except ImportError:
119
+ print(self.prediction.to_string(index=False)) # missing jinja2
120
+ display(Markdown(self.summary))
121
+ else:
122
+ print(self.prediction.to_string(index=False))
123
+ print(self.summary)
116
124
 
117
- display(self.prediction)
118
- display(Markdown(self.summary))
125
+ def _ipython_display_(self) -> None:
126
+ self.print()
119
127
 
120
128
 
121
129
  class KumoRFM:
@@ -162,7 +170,7 @@ class KumoRFM:
162
170
  def __init__(
163
171
  self,
164
172
  graph: Graph,
165
- verbose: Union[bool, ProgressLogger] = True,
173
+ verbose: bool | ProgressLogger = True,
166
174
  optimize: bool = False,
167
175
  ) -> None:
168
176
  graph = graph.validate()
@@ -180,9 +188,9 @@ class KumoRFM:
180
188
  else:
181
189
  raise NotImplementedError
182
190
 
183
- self._client: Optional[RFMAPI] = None
191
+ self._client: RFMAPI | None = None
184
192
 
185
- self._batch_size: Optional[int | Literal['max']] = None
193
+ self._batch_size: int | Literal['max'] | None = None
186
194
  self.num_retries: int = 0
187
195
 
188
196
  @property
@@ -200,7 +208,7 @@ class KumoRFM:
200
208
  @contextmanager
201
209
  def batch_mode(
202
210
  self,
203
- batch_size: Union[int, Literal['max']] = 'max',
211
+ batch_size: int | Literal['max'] = 'max',
204
212
  num_retries: int = 1,
205
213
  ) -> Generator[None, None, None]:
206
214
  """Context manager to predict in batches.
@@ -234,17 +242,17 @@ class KumoRFM:
234
242
  def predict(
235
243
  self,
236
244
  query: str,
237
- indices: Union[List[str], List[float], List[int], None] = None,
245
+ indices: list[str] | list[float] | list[int] | None = None,
238
246
  *,
239
247
  explain: Literal[False] = False,
240
- anchor_time: Union[pd.Timestamp, Literal['entity'], None] = None,
241
- context_anchor_time: Union[pd.Timestamp, None] = None,
242
- run_mode: Union[RunMode, str] = RunMode.FAST,
243
- num_neighbors: Optional[List[int]] = None,
248
+ anchor_time: pd.Timestamp | Literal['entity'] | None = None,
249
+ context_anchor_time: pd.Timestamp | None = None,
250
+ run_mode: RunMode | str = RunMode.FAST,
251
+ num_neighbors: list[int] | None = None,
244
252
  num_hops: int = 2,
245
253
  max_pq_iterations: int = 10,
246
- random_seed: Optional[int] = _RANDOM_SEED,
247
- verbose: Union[bool, ProgressLogger] = True,
254
+ random_seed: int | None = _RANDOM_SEED,
255
+ verbose: bool | ProgressLogger = True,
248
256
  use_prediction_time: bool = False,
249
257
  ) -> pd.DataFrame:
250
258
  pass
@@ -253,17 +261,17 @@ class KumoRFM:
253
261
  def predict(
254
262
  self,
255
263
  query: str,
256
- indices: Union[List[str], List[float], List[int], None] = None,
264
+ indices: list[str] | list[float] | list[int] | None = None,
257
265
  *,
258
- explain: Union[Literal[True], ExplainConfig, Dict[str, Any]],
259
- anchor_time: Union[pd.Timestamp, Literal['entity'], None] = None,
260
- context_anchor_time: Union[pd.Timestamp, None] = None,
261
- run_mode: Union[RunMode, str] = RunMode.FAST,
262
- num_neighbors: Optional[List[int]] = None,
266
+ explain: Literal[True] | ExplainConfig | dict[str, Any],
267
+ anchor_time: pd.Timestamp | Literal['entity'] | None = None,
268
+ context_anchor_time: pd.Timestamp | None = None,
269
+ run_mode: RunMode | str = RunMode.FAST,
270
+ num_neighbors: list[int] | None = None,
263
271
  num_hops: int = 2,
264
272
  max_pq_iterations: int = 10,
265
- random_seed: Optional[int] = _RANDOM_SEED,
266
- verbose: Union[bool, ProgressLogger] = True,
273
+ random_seed: int | None = _RANDOM_SEED,
274
+ verbose: bool | ProgressLogger = True,
267
275
  use_prediction_time: bool = False,
268
276
  ) -> Explanation:
269
277
  pass
@@ -271,19 +279,19 @@ class KumoRFM:
271
279
  def predict(
272
280
  self,
273
281
  query: str,
274
- indices: Union[List[str], List[float], List[int], None] = None,
282
+ indices: list[str] | list[float] | list[int] | None = None,
275
283
  *,
276
- explain: Union[bool, ExplainConfig, Dict[str, Any]] = False,
277
- anchor_time: Union[pd.Timestamp, Literal['entity'], None] = None,
278
- context_anchor_time: Union[pd.Timestamp, None] = None,
279
- run_mode: Union[RunMode, str] = RunMode.FAST,
280
- num_neighbors: Optional[List[int]] = None,
284
+ explain: bool | ExplainConfig | dict[str, Any] = False,
285
+ anchor_time: pd.Timestamp | Literal['entity'] | None = None,
286
+ context_anchor_time: pd.Timestamp | None = None,
287
+ run_mode: RunMode | str = RunMode.FAST,
288
+ num_neighbors: list[int] | None = None,
281
289
  num_hops: int = 2,
282
290
  max_pq_iterations: int = 10,
283
- random_seed: Optional[int] = _RANDOM_SEED,
284
- verbose: Union[bool, ProgressLogger] = True,
291
+ random_seed: int | None = _RANDOM_SEED,
292
+ verbose: bool | ProgressLogger = True,
285
293
  use_prediction_time: bool = False,
286
- ) -> Union[pd.DataFrame, Explanation]:
294
+ ) -> pd.DataFrame | Explanation:
287
295
  """Returns predictions for a predictive query.
288
296
 
289
297
  Args:
@@ -325,7 +333,7 @@ class KumoRFM:
325
333
  If ``explain`` is provided, returns an :class:`Explanation` object
326
334
  containing the prediction, summary, and details.
327
335
  """
328
- explain_config: Optional[ExplainConfig] = None
336
+ explain_config: ExplainConfig | None = None
329
337
  if explain is True:
330
338
  explain_config = ExplainConfig()
331
339
  elif explain is not False:
@@ -369,11 +377,11 @@ class KumoRFM:
369
377
  msg = f'[bold]PREDICT[/bold] {query_repr}'
370
378
 
371
379
  if not isinstance(verbose, ProgressLogger):
372
- verbose = InteractiveProgressLogger(msg, verbose=verbose)
380
+ verbose = ProgressLogger.default(msg=msg, verbose=verbose)
373
381
 
374
382
  with verbose as logger:
375
383
 
376
- batch_size: Optional[int] = None
384
+ batch_size: int | None = None
377
385
  if self._batch_size == 'max':
378
386
  task_type = self._get_task_type(
379
387
  query=query_def,
@@ -393,9 +401,9 @@ class KumoRFM:
393
401
  logger.log(f"Splitting {len(indices):,} entities into "
394
402
  f"{len(batches):,} batches of size {batch_size:,}")
395
403
 
396
- predictions: List[pd.DataFrame] = []
397
- summary: Optional[str] = None
398
- details: Optional[Explanation] = None
404
+ predictions: list[pd.DataFrame] = []
405
+ summary: str | None = None
406
+ details: Explanation | None = None
399
407
  for i, batch in enumerate(batches):
400
408
  # TODO Re-use the context for subsequent predictions.
401
409
  context = self._get_context(
@@ -429,8 +437,7 @@ class KumoRFM:
429
437
  stats = Context.get_memory_stats(request_msg.context)
430
438
  raise ValueError(_SIZE_LIMIT_MSG.format(stats=stats))
431
439
 
432
- if (isinstance(verbose, InteractiveProgressLogger) and i == 0
433
- and len(batches) > 1):
440
+ if i == 0 and len(batches) > 1:
434
441
  verbose.init_progress(
435
442
  total=len(batches),
436
443
  description='Predicting',
@@ -469,8 +476,7 @@ class KumoRFM:
469
476
 
470
477
  predictions.append(df)
471
478
 
472
- if (isinstance(verbose, InteractiveProgressLogger)
473
- and len(batches) > 1):
479
+ if len(batches) > 1:
474
480
  verbose.step()
475
481
 
476
482
  break
@@ -508,9 +514,9 @@ class KumoRFM:
508
514
  def is_valid_entity(
509
515
  self,
510
516
  query: str,
511
- indices: Union[List[str], List[float], List[int], None] = None,
517
+ indices: list[str] | list[float] | list[int] | None = None,
512
518
  *,
513
- anchor_time: Union[pd.Timestamp, Literal['entity'], None] = None,
519
+ anchor_time: pd.Timestamp | Literal['entity'] | None = None,
514
520
  ) -> np.ndarray:
515
521
  r"""Returns a mask that denotes which entities are valid for the
516
522
  given predictive query, *i.e.*, which entities fulfill (temporal)
@@ -554,15 +560,15 @@ class KumoRFM:
554
560
  self,
555
561
  query: str,
556
562
  *,
557
- metrics: Optional[List[str]] = None,
558
- anchor_time: Union[pd.Timestamp, Literal['entity'], None] = None,
559
- context_anchor_time: Union[pd.Timestamp, None] = None,
560
- run_mode: Union[RunMode, str] = RunMode.FAST,
561
- num_neighbors: Optional[List[int]] = None,
563
+ metrics: list[str] | None = None,
564
+ anchor_time: pd.Timestamp | Literal['entity'] | None = None,
565
+ context_anchor_time: pd.Timestamp | None = None,
566
+ run_mode: RunMode | str = RunMode.FAST,
567
+ num_neighbors: list[int] | None = None,
562
568
  num_hops: int = 2,
563
569
  max_pq_iterations: int = 10,
564
- random_seed: Optional[int] = _RANDOM_SEED,
565
- verbose: Union[bool, ProgressLogger] = True,
570
+ random_seed: int | None = _RANDOM_SEED,
571
+ verbose: bool | ProgressLogger = True,
566
572
  use_prediction_time: bool = False,
567
573
  ) -> pd.DataFrame:
568
574
  """Evaluates a predictive query.
@@ -610,7 +616,7 @@ class KumoRFM:
610
616
  msg = f'[bold]EVALUATE[/bold] {query_repr}'
611
617
 
612
618
  if not isinstance(verbose, ProgressLogger):
613
- verbose = InteractiveProgressLogger(msg, verbose=verbose)
619
+ verbose = ProgressLogger.default(msg=msg, verbose=verbose)
614
620
 
615
621
  with verbose as logger:
616
622
  context = self._get_context(
@@ -669,8 +675,8 @@ class KumoRFM:
669
675
  query: str,
670
676
  size: int,
671
677
  *,
672
- anchor_time: Union[pd.Timestamp, Literal['entity'], None] = None,
673
- random_seed: Optional[int] = _RANDOM_SEED,
678
+ anchor_time: pd.Timestamp | Literal['entity'] | None = None,
679
+ random_seed: int | None = _RANDOM_SEED,
674
680
  max_iterations: int = 10,
675
681
  ) -> pd.DataFrame:
676
682
  """Returns the labels of a predictive query for a specified anchor
@@ -764,7 +770,7 @@ class KumoRFM:
764
770
  @staticmethod
765
771
  def _get_task_type(
766
772
  query: ValidatedPredictiveQuery,
767
- edge_types: List[Tuple[str, str, str]],
773
+ edge_types: list[tuple[str, str, str]],
768
774
  ) -> TaskType:
769
775
  if isinstance(query.target_ast, (Condition, LogicalOperation)):
770
776
  return TaskType.BINARY_CLASSIFICATION
@@ -819,7 +825,7 @@ class KumoRFM:
819
825
  self,
820
826
  query: ValidatedPredictiveQuery,
821
827
  anchor_time: pd.Timestamp,
822
- context_anchor_time: Union[pd.Timestamp, None],
828
+ context_anchor_time: pd.Timestamp | None,
823
829
  evaluate: bool,
824
830
  ) -> None:
825
831
 
@@ -885,16 +891,16 @@ class KumoRFM:
885
891
  def _get_context(
886
892
  self,
887
893
  query: ValidatedPredictiveQuery,
888
- indices: Union[List[str], List[float], List[int], None],
889
- anchor_time: Union[pd.Timestamp, Literal['entity'], None],
890
- context_anchor_time: Union[pd.Timestamp, None],
894
+ indices: list[str] | list[float] | list[int] | None,
895
+ anchor_time: pd.Timestamp | Literal['entity'] | None,
896
+ context_anchor_time: pd.Timestamp | None,
891
897
  run_mode: RunMode,
892
- num_neighbors: Optional[List[int]],
898
+ num_neighbors: list[int] | None,
893
899
  num_hops: int,
894
900
  max_pq_iterations: int,
895
901
  evaluate: bool,
896
- random_seed: Optional[int] = _RANDOM_SEED,
897
- logger: Optional[ProgressLogger] = None,
902
+ random_seed: int | None = _RANDOM_SEED,
903
+ logger: ProgressLogger | None = None,
898
904
  ) -> Context:
899
905
 
900
906
  if num_neighbors is not None:
@@ -1069,7 +1075,7 @@ class KumoRFM:
1069
1075
  raise NotImplementedError
1070
1076
  logger.log(msg)
1071
1077
 
1072
- entity_table_names: Tuple[str, ...]
1078
+ entity_table_names: tuple[str, ...]
1073
1079
  if task_type.is_link_pred:
1074
1080
  final_aggr = query.get_final_target_aggregation()
1075
1081
  assert final_aggr is not None
@@ -1127,7 +1133,7 @@ class KumoRFM:
1127
1133
 
1128
1134
  @staticmethod
1129
1135
  def _validate_metrics(
1130
- metrics: List[str],
1136
+ metrics: list[str],
1131
1137
  task_type: TaskType,
1132
1138
  ) -> None:
1133
1139
 
@@ -1184,7 +1190,7 @@ class KumoRFM:
1184
1190
  f"'https://github.com/kumo-ai/kumo-rfm'.")
1185
1191
 
1186
1192
 
1187
- def format_value(value: Union[int, float]) -> str:
1193
+ def format_value(value: int | float) -> str:
1188
1194
  if value == int(value):
1189
1195
  return f'{int(value):,}'
1190
1196
  if abs(value) >= 1000:
@@ -1,6 +1,6 @@
1
1
  import base64
2
2
  import json
3
- from typing import Any, Dict, List, Tuple
3
+ from typing import Any
4
4
 
5
5
  import requests
6
6
 
@@ -48,8 +48,8 @@ class KumoClient_SageMakerAdapter(KumoClient):
48
48
 
49
49
  # Recording buffers.
50
50
  self._recording_active = False
51
- self._recorded_reqs: List[Dict[str, Any]] = []
52
- self._recorded_resps: List[Dict[str, Any]] = []
51
+ self._recorded_reqs: list[dict[str, Any]] = []
52
+ self._recorded_resps: list[dict[str, Any]] = []
53
53
 
54
54
  def authenticate(self) -> None:
55
55
  # TODO(siyang): call /ping to verify?
@@ -92,7 +92,7 @@ class KumoClient_SageMakerAdapter(KumoClient):
92
92
  self._recorded_reqs.clear()
93
93
  self._recorded_resps.clear()
94
94
 
95
- def end_recording(self) -> List[Tuple[Dict[str, Any], Dict[str, Any]]]:
95
+ def end_recording(self) -> list[tuple[dict[str, Any], dict[str, Any]]]:
96
96
  """Stop recording and return recorded requests/responses."""
97
97
  assert self._recording_active
98
98
  self._recording_active = False