kumoai 2.14.0.dev202512151351__cp313-cp313-macosx_11_0_arm64.whl → 2.14.0.dev202512211732__cp313-cp313-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kumoai/_version.py +1 -1
- kumoai/experimental/rfm/__init__.py +33 -8
- kumoai/experimental/rfm/authenticate.py +3 -4
- kumoai/experimental/rfm/backend/local/graph_store.py +25 -25
- kumoai/experimental/rfm/backend/local/table.py +16 -21
- kumoai/experimental/rfm/backend/snow/sampler.py +22 -34
- kumoai/experimental/rfm/backend/snow/table.py +67 -33
- kumoai/experimental/rfm/backend/sqlite/__init__.py +2 -2
- kumoai/experimental/rfm/backend/sqlite/sampler.py +21 -26
- kumoai/experimental/rfm/backend/sqlite/table.py +54 -26
- kumoai/experimental/rfm/base/__init__.py +8 -0
- kumoai/experimental/rfm/base/column.py +14 -12
- kumoai/experimental/rfm/base/column_expression.py +50 -0
- kumoai/experimental/rfm/base/sql_sampler.py +31 -3
- kumoai/experimental/rfm/base/sql_table.py +229 -0
- kumoai/experimental/rfm/base/table.py +162 -143
- kumoai/experimental/rfm/graph.py +242 -95
- kumoai/experimental/rfm/infer/__init__.py +6 -4
- kumoai/experimental/rfm/infer/dtype.py +3 -3
- kumoai/experimental/rfm/infer/pkey.py +4 -2
- kumoai/experimental/rfm/infer/stype.py +35 -0
- kumoai/experimental/rfm/infer/time_col.py +1 -2
- kumoai/experimental/rfm/pquery/executor.py +27 -27
- kumoai/experimental/rfm/pquery/pandas_executor.py +29 -31
- kumoai/experimental/rfm/rfm.py +86 -80
- kumoai/experimental/rfm/sagemaker.py +4 -4
- kumoai/utils/__init__.py +1 -2
- kumoai/utils/progress_logger.py +178 -12
- {kumoai-2.14.0.dev202512151351.dist-info → kumoai-2.14.0.dev202512211732.dist-info}/METADATA +2 -1
- {kumoai-2.14.0.dev202512151351.dist-info → kumoai-2.14.0.dev202512211732.dist-info}/RECORD +33 -30
- {kumoai-2.14.0.dev202512151351.dist-info → kumoai-2.14.0.dev202512211732.dist-info}/WHEEL +0 -0
- {kumoai-2.14.0.dev202512151351.dist-info → kumoai-2.14.0.dev202512211732.dist-info}/licenses/LICENSE +0 -0
- {kumoai-2.14.0.dev202512151351.dist-info → kumoai-2.14.0.dev202512211732.dist-info}/top_level.txt +0 -0
|
@@ -1,6 +1,5 @@
|
|
|
1
1
|
import re
|
|
2
2
|
import warnings
|
|
3
|
-
from typing import Optional
|
|
4
3
|
|
|
5
4
|
import pandas as pd
|
|
6
5
|
|
|
@@ -8,7 +7,7 @@ import pandas as pd
|
|
|
8
7
|
def infer_time_column(
|
|
9
8
|
df: pd.DataFrame,
|
|
10
9
|
candidates: list[str],
|
|
11
|
-
) ->
|
|
10
|
+
) -> str | None:
|
|
12
11
|
r"""Auto-detect potential time column.
|
|
13
12
|
|
|
14
13
|
Args:
|
|
@@ -1,5 +1,5 @@
|
|
|
1
1
|
from abc import ABC, abstractmethod
|
|
2
|
-
from typing import
|
|
2
|
+
from typing import Generic, TypeVar
|
|
3
3
|
|
|
4
4
|
from kumoapi.pquery import ValidatedPredictiveQuery
|
|
5
5
|
from kumoapi.pquery.AST import (
|
|
@@ -21,82 +21,82 @@ class PQueryExecutor(Generic[TableData, ColumnData, IndexData], ABC):
|
|
|
21
21
|
def execute_column(
|
|
22
22
|
self,
|
|
23
23
|
column: Column,
|
|
24
|
-
feat_dict:
|
|
24
|
+
feat_dict: dict[str, TableData],
|
|
25
25
|
filter_na: bool = True,
|
|
26
|
-
) ->
|
|
26
|
+
) -> tuple[ColumnData, IndexData]:
|
|
27
27
|
pass
|
|
28
28
|
|
|
29
29
|
@abstractmethod
|
|
30
30
|
def execute_aggregation(
|
|
31
31
|
self,
|
|
32
32
|
aggr: Aggregation,
|
|
33
|
-
feat_dict:
|
|
34
|
-
time_dict:
|
|
35
|
-
batch_dict:
|
|
33
|
+
feat_dict: dict[str, TableData],
|
|
34
|
+
time_dict: dict[str, ColumnData],
|
|
35
|
+
batch_dict: dict[str, IndexData],
|
|
36
36
|
anchor_time: ColumnData,
|
|
37
37
|
filter_na: bool = True,
|
|
38
38
|
num_forecasts: int = 1,
|
|
39
|
-
) ->
|
|
39
|
+
) -> tuple[ColumnData, IndexData]:
|
|
40
40
|
pass
|
|
41
41
|
|
|
42
42
|
@abstractmethod
|
|
43
43
|
def execute_condition(
|
|
44
44
|
self,
|
|
45
45
|
condition: Condition,
|
|
46
|
-
feat_dict:
|
|
47
|
-
time_dict:
|
|
48
|
-
batch_dict:
|
|
46
|
+
feat_dict: dict[str, TableData],
|
|
47
|
+
time_dict: dict[str, ColumnData],
|
|
48
|
+
batch_dict: dict[str, IndexData],
|
|
49
49
|
anchor_time: ColumnData,
|
|
50
50
|
filter_na: bool = True,
|
|
51
51
|
num_forecasts: int = 1,
|
|
52
|
-
) ->
|
|
52
|
+
) -> tuple[ColumnData, IndexData]:
|
|
53
53
|
pass
|
|
54
54
|
|
|
55
55
|
@abstractmethod
|
|
56
56
|
def execute_logical_operation(
|
|
57
57
|
self,
|
|
58
58
|
logical_operation: LogicalOperation,
|
|
59
|
-
feat_dict:
|
|
60
|
-
time_dict:
|
|
61
|
-
batch_dict:
|
|
59
|
+
feat_dict: dict[str, TableData],
|
|
60
|
+
time_dict: dict[str, ColumnData],
|
|
61
|
+
batch_dict: dict[str, IndexData],
|
|
62
62
|
anchor_time: ColumnData,
|
|
63
63
|
filter_na: bool = True,
|
|
64
64
|
num_forecasts: int = 1,
|
|
65
|
-
) ->
|
|
65
|
+
) -> tuple[ColumnData, IndexData]:
|
|
66
66
|
pass
|
|
67
67
|
|
|
68
68
|
@abstractmethod
|
|
69
69
|
def execute_join(
|
|
70
70
|
self,
|
|
71
71
|
join: Join,
|
|
72
|
-
feat_dict:
|
|
73
|
-
time_dict:
|
|
74
|
-
batch_dict:
|
|
72
|
+
feat_dict: dict[str, TableData],
|
|
73
|
+
time_dict: dict[str, ColumnData],
|
|
74
|
+
batch_dict: dict[str, IndexData],
|
|
75
75
|
anchor_time: ColumnData,
|
|
76
76
|
filter_na: bool = True,
|
|
77
77
|
num_forecasts: int = 1,
|
|
78
|
-
) ->
|
|
78
|
+
) -> tuple[ColumnData, IndexData]:
|
|
79
79
|
pass
|
|
80
80
|
|
|
81
81
|
@abstractmethod
|
|
82
82
|
def execute_filter(
|
|
83
83
|
self,
|
|
84
84
|
filter: Filter,
|
|
85
|
-
feat_dict:
|
|
86
|
-
time_dict:
|
|
87
|
-
batch_dict:
|
|
85
|
+
feat_dict: dict[str, TableData],
|
|
86
|
+
time_dict: dict[str, ColumnData],
|
|
87
|
+
batch_dict: dict[str, IndexData],
|
|
88
88
|
anchor_time: ColumnData,
|
|
89
|
-
) ->
|
|
89
|
+
) -> tuple[ColumnData, IndexData]:
|
|
90
90
|
pass
|
|
91
91
|
|
|
92
92
|
@abstractmethod
|
|
93
93
|
def execute(
|
|
94
94
|
self,
|
|
95
95
|
query: ValidatedPredictiveQuery,
|
|
96
|
-
feat_dict:
|
|
97
|
-
time_dict:
|
|
98
|
-
batch_dict:
|
|
96
|
+
feat_dict: dict[str, TableData],
|
|
97
|
+
time_dict: dict[str, ColumnData],
|
|
98
|
+
batch_dict: dict[str, IndexData],
|
|
99
99
|
anchor_time: ColumnData,
|
|
100
100
|
num_forecasts: int = 1,
|
|
101
|
-
) ->
|
|
101
|
+
) -> tuple[ColumnData, IndexData]:
|
|
102
102
|
pass
|
|
@@ -1,5 +1,3 @@
|
|
|
1
|
-
from typing import Dict, List, Tuple
|
|
2
|
-
|
|
3
1
|
import numpy as np
|
|
4
2
|
import pandas as pd
|
|
5
3
|
from kumoapi.pquery import ValidatedPredictiveQuery
|
|
@@ -22,9 +20,9 @@ class PQueryPandasExecutor(PQueryExecutor[pd.DataFrame, pd.Series,
|
|
|
22
20
|
def execute_column(
|
|
23
21
|
self,
|
|
24
22
|
column: Column,
|
|
25
|
-
feat_dict:
|
|
23
|
+
feat_dict: dict[str, pd.DataFrame],
|
|
26
24
|
filter_na: bool = True,
|
|
27
|
-
) ->
|
|
25
|
+
) -> tuple[pd.Series, np.ndarray]:
|
|
28
26
|
table_name, column_name = column.fqn.split(".")
|
|
29
27
|
if column_name == '*':
|
|
30
28
|
out = pd.Series(np.ones(len(feat_dict[table_name]), dtype='int64'))
|
|
@@ -60,7 +58,7 @@ class PQueryPandasExecutor(PQueryExecutor[pd.DataFrame, pd.Series,
|
|
|
60
58
|
batch: np.ndarray,
|
|
61
59
|
batch_size: int,
|
|
62
60
|
filter_na: bool = True,
|
|
63
|
-
) ->
|
|
61
|
+
) -> tuple[pd.Series, np.ndarray]:
|
|
64
62
|
|
|
65
63
|
mask = feat.notna()
|
|
66
64
|
feat, batch = feat[mask], batch[mask]
|
|
@@ -104,13 +102,13 @@ class PQueryPandasExecutor(PQueryExecutor[pd.DataFrame, pd.Series,
|
|
|
104
102
|
def execute_aggregation(
|
|
105
103
|
self,
|
|
106
104
|
aggr: Aggregation,
|
|
107
|
-
feat_dict:
|
|
108
|
-
time_dict:
|
|
109
|
-
batch_dict:
|
|
105
|
+
feat_dict: dict[str, pd.DataFrame],
|
|
106
|
+
time_dict: dict[str, pd.Series],
|
|
107
|
+
batch_dict: dict[str, np.ndarray],
|
|
110
108
|
anchor_time: pd.Series,
|
|
111
109
|
filter_na: bool = True,
|
|
112
110
|
num_forecasts: int = 1,
|
|
113
|
-
) ->
|
|
111
|
+
) -> tuple[pd.Series, np.ndarray]:
|
|
114
112
|
target_table = aggr._get_target_column_name().split('.')[0]
|
|
115
113
|
target_batch = batch_dict[target_table]
|
|
116
114
|
target_time = time_dict[target_table]
|
|
@@ -131,8 +129,8 @@ class PQueryPandasExecutor(PQueryExecutor[pd.DataFrame, pd.Series,
|
|
|
131
129
|
filter_na=True,
|
|
132
130
|
)
|
|
133
131
|
|
|
134
|
-
outs:
|
|
135
|
-
masks:
|
|
132
|
+
outs: list[pd.Series] = []
|
|
133
|
+
masks: list[np.ndarray] = []
|
|
136
134
|
for _ in range(num_forecasts):
|
|
137
135
|
anchor_target_time = anchor_time.iloc[target_batch]
|
|
138
136
|
anchor_target_time = anchor_target_time.reset_index(drop=True)
|
|
@@ -226,13 +224,13 @@ class PQueryPandasExecutor(PQueryExecutor[pd.DataFrame, pd.Series,
|
|
|
226
224
|
def execute_condition(
|
|
227
225
|
self,
|
|
228
226
|
condition: Condition,
|
|
229
|
-
feat_dict:
|
|
230
|
-
time_dict:
|
|
231
|
-
batch_dict:
|
|
227
|
+
feat_dict: dict[str, pd.DataFrame],
|
|
228
|
+
time_dict: dict[str, pd.Series],
|
|
229
|
+
batch_dict: dict[str, np.ndarray],
|
|
232
230
|
anchor_time: pd.Series,
|
|
233
231
|
filter_na: bool = True,
|
|
234
232
|
num_forecasts: int = 1,
|
|
235
|
-
) ->
|
|
233
|
+
) -> tuple[pd.Series, np.ndarray]:
|
|
236
234
|
if num_forecasts > 1:
|
|
237
235
|
raise NotImplementedError("Forecasting not yet implemented for "
|
|
238
236
|
"non-regression tasks")
|
|
@@ -306,13 +304,13 @@ class PQueryPandasExecutor(PQueryExecutor[pd.DataFrame, pd.Series,
|
|
|
306
304
|
def execute_logical_operation(
|
|
307
305
|
self,
|
|
308
306
|
logical_operation: LogicalOperation,
|
|
309
|
-
feat_dict:
|
|
310
|
-
time_dict:
|
|
311
|
-
batch_dict:
|
|
307
|
+
feat_dict: dict[str, pd.DataFrame],
|
|
308
|
+
time_dict: dict[str, pd.Series],
|
|
309
|
+
batch_dict: dict[str, np.ndarray],
|
|
312
310
|
anchor_time: pd.Series,
|
|
313
311
|
filter_na: bool = True,
|
|
314
312
|
num_forecasts: int = 1,
|
|
315
|
-
) ->
|
|
313
|
+
) -> tuple[pd.Series, np.ndarray]:
|
|
316
314
|
if num_forecasts > 1:
|
|
317
315
|
raise NotImplementedError("Forecasting not yet implemented for "
|
|
318
316
|
"non-regression tasks")
|
|
@@ -370,13 +368,13 @@ class PQueryPandasExecutor(PQueryExecutor[pd.DataFrame, pd.Series,
|
|
|
370
368
|
def execute_join(
|
|
371
369
|
self,
|
|
372
370
|
join: Join,
|
|
373
|
-
feat_dict:
|
|
374
|
-
time_dict:
|
|
375
|
-
batch_dict:
|
|
371
|
+
feat_dict: dict[str, pd.DataFrame],
|
|
372
|
+
time_dict: dict[str, pd.Series],
|
|
373
|
+
batch_dict: dict[str, np.ndarray],
|
|
376
374
|
anchor_time: pd.Series,
|
|
377
375
|
filter_na: bool = True,
|
|
378
376
|
num_forecasts: int = 1,
|
|
379
|
-
) ->
|
|
377
|
+
) -> tuple[pd.Series, np.ndarray]:
|
|
380
378
|
if isinstance(join.rhs_target, Aggregation):
|
|
381
379
|
return self.execute_aggregation(
|
|
382
380
|
aggr=join.rhs_target,
|
|
@@ -393,12 +391,12 @@ class PQueryPandasExecutor(PQueryExecutor[pd.DataFrame, pd.Series,
|
|
|
393
391
|
def execute_filter(
|
|
394
392
|
self,
|
|
395
393
|
filter: Filter,
|
|
396
|
-
feat_dict:
|
|
397
|
-
time_dict:
|
|
398
|
-
batch_dict:
|
|
394
|
+
feat_dict: dict[str, pd.DataFrame],
|
|
395
|
+
time_dict: dict[str, pd.Series],
|
|
396
|
+
batch_dict: dict[str, np.ndarray],
|
|
399
397
|
anchor_time: pd.Series,
|
|
400
398
|
filter_na: bool = True,
|
|
401
|
-
) ->
|
|
399
|
+
) -> tuple[pd.Series, np.ndarray]:
|
|
402
400
|
out, mask = self.execute_column(
|
|
403
401
|
column=filter.target,
|
|
404
402
|
feat_dict=feat_dict,
|
|
@@ -431,12 +429,12 @@ class PQueryPandasExecutor(PQueryExecutor[pd.DataFrame, pd.Series,
|
|
|
431
429
|
def execute(
|
|
432
430
|
self,
|
|
433
431
|
query: ValidatedPredictiveQuery,
|
|
434
|
-
feat_dict:
|
|
435
|
-
time_dict:
|
|
436
|
-
batch_dict:
|
|
432
|
+
feat_dict: dict[str, pd.DataFrame],
|
|
433
|
+
time_dict: dict[str, pd.Series],
|
|
434
|
+
batch_dict: dict[str, np.ndarray],
|
|
437
435
|
anchor_time: pd.Series,
|
|
438
436
|
num_forecasts: int = 1,
|
|
439
|
-
) ->
|
|
437
|
+
) -> tuple[pd.Series, np.ndarray]:
|
|
440
438
|
if isinstance(query.entity_ast, Column):
|
|
441
439
|
out, mask = self.execute_column(
|
|
442
440
|
column=query.entity_ast,
|
kumoai/experimental/rfm/rfm.py
CHANGED
|
@@ -2,20 +2,10 @@ import json
|
|
|
2
2
|
import time
|
|
3
3
|
import warnings
|
|
4
4
|
from collections import defaultdict
|
|
5
|
-
from collections.abc import Generator
|
|
5
|
+
from collections.abc import Generator, Iterator
|
|
6
6
|
from contextlib import contextmanager
|
|
7
7
|
from dataclasses import dataclass, replace
|
|
8
|
-
from typing import
|
|
9
|
-
Any,
|
|
10
|
-
Dict,
|
|
11
|
-
Iterator,
|
|
12
|
-
List,
|
|
13
|
-
Literal,
|
|
14
|
-
Optional,
|
|
15
|
-
Tuple,
|
|
16
|
-
Union,
|
|
17
|
-
overload,
|
|
18
|
-
)
|
|
8
|
+
from typing import Any, Literal, overload
|
|
19
9
|
|
|
20
10
|
import numpy as np
|
|
21
11
|
import pandas as pd
|
|
@@ -38,12 +28,13 @@ from kumoapi.rfm import (
|
|
|
38
28
|
from kumoapi.task import TaskType
|
|
39
29
|
from kumoapi.typing import AggregationType, Stype
|
|
40
30
|
|
|
31
|
+
from kumoai import in_notebook, in_snowflake_notebook
|
|
41
32
|
from kumoai.client.rfm import RFMAPI
|
|
42
33
|
from kumoai.exceptions import HTTPException
|
|
43
34
|
from kumoai.experimental.rfm import Graph
|
|
44
35
|
from kumoai.experimental.rfm.base import DataBackend, Sampler
|
|
45
36
|
from kumoai.mixin import CastMixin
|
|
46
|
-
from kumoai.utils import
|
|
37
|
+
from kumoai.utils import ProgressLogger
|
|
47
38
|
|
|
48
39
|
_RANDOM_SEED = 42
|
|
49
40
|
|
|
@@ -98,24 +89,41 @@ class Explanation:
|
|
|
98
89
|
def __getitem__(self, index: Literal[1]) -> str:
|
|
99
90
|
pass
|
|
100
91
|
|
|
101
|
-
def __getitem__(self, index: int) ->
|
|
92
|
+
def __getitem__(self, index: int) -> pd.DataFrame | str:
|
|
102
93
|
if index == 0:
|
|
103
94
|
return self.prediction
|
|
104
95
|
if index == 1:
|
|
105
96
|
return self.summary
|
|
106
97
|
raise IndexError("Index out of range")
|
|
107
98
|
|
|
108
|
-
def __iter__(self) -> Iterator[
|
|
99
|
+
def __iter__(self) -> Iterator[pd.DataFrame | str]:
|
|
109
100
|
return iter((self.prediction, self.summary))
|
|
110
101
|
|
|
111
102
|
def __repr__(self) -> str:
|
|
112
103
|
return str((self.prediction, self.summary))
|
|
113
104
|
|
|
114
|
-
def
|
|
115
|
-
|
|
105
|
+
def print(self) -> None:
|
|
106
|
+
r"""Prints the explanation."""
|
|
107
|
+
if in_snowflake_notebook():
|
|
108
|
+
import streamlit as st
|
|
109
|
+
st.dataframe(self.prediction, hide_index=True)
|
|
110
|
+
st.markdown(self.summary)
|
|
111
|
+
elif in_notebook():
|
|
112
|
+
from IPython.display import Markdown, display
|
|
113
|
+
try:
|
|
114
|
+
if hasattr(self.prediction.style, 'hide'):
|
|
115
|
+
display(self.prediction.hide(axis='index')) # pandas=2
|
|
116
|
+
else:
|
|
117
|
+
display(self.prediction.hide_index()) # pandas <1.3
|
|
118
|
+
except ImportError:
|
|
119
|
+
print(self.prediction.to_string(index=False)) # missing jinja2
|
|
120
|
+
display(Markdown(self.summary))
|
|
121
|
+
else:
|
|
122
|
+
print(self.prediction.to_string(index=False))
|
|
123
|
+
print(self.summary)
|
|
116
124
|
|
|
117
|
-
|
|
118
|
-
|
|
125
|
+
def _ipython_display_(self) -> None:
|
|
126
|
+
self.print()
|
|
119
127
|
|
|
120
128
|
|
|
121
129
|
class KumoRFM:
|
|
@@ -162,7 +170,7 @@ class KumoRFM:
|
|
|
162
170
|
def __init__(
|
|
163
171
|
self,
|
|
164
172
|
graph: Graph,
|
|
165
|
-
verbose:
|
|
173
|
+
verbose: bool | ProgressLogger = True,
|
|
166
174
|
optimize: bool = False,
|
|
167
175
|
) -> None:
|
|
168
176
|
graph = graph.validate()
|
|
@@ -180,9 +188,9 @@ class KumoRFM:
|
|
|
180
188
|
else:
|
|
181
189
|
raise NotImplementedError
|
|
182
190
|
|
|
183
|
-
self._client:
|
|
191
|
+
self._client: RFMAPI | None = None
|
|
184
192
|
|
|
185
|
-
self._batch_size:
|
|
193
|
+
self._batch_size: int | Literal['max'] | None = None
|
|
186
194
|
self.num_retries: int = 0
|
|
187
195
|
|
|
188
196
|
@property
|
|
@@ -200,7 +208,7 @@ class KumoRFM:
|
|
|
200
208
|
@contextmanager
|
|
201
209
|
def batch_mode(
|
|
202
210
|
self,
|
|
203
|
-
batch_size:
|
|
211
|
+
batch_size: int | Literal['max'] = 'max',
|
|
204
212
|
num_retries: int = 1,
|
|
205
213
|
) -> Generator[None, None, None]:
|
|
206
214
|
"""Context manager to predict in batches.
|
|
@@ -234,17 +242,17 @@ class KumoRFM:
|
|
|
234
242
|
def predict(
|
|
235
243
|
self,
|
|
236
244
|
query: str,
|
|
237
|
-
indices:
|
|
245
|
+
indices: list[str] | list[float] | list[int] | None = None,
|
|
238
246
|
*,
|
|
239
247
|
explain: Literal[False] = False,
|
|
240
|
-
anchor_time:
|
|
241
|
-
context_anchor_time:
|
|
242
|
-
run_mode:
|
|
243
|
-
num_neighbors:
|
|
248
|
+
anchor_time: pd.Timestamp | Literal['entity'] | None = None,
|
|
249
|
+
context_anchor_time: pd.Timestamp | None = None,
|
|
250
|
+
run_mode: RunMode | str = RunMode.FAST,
|
|
251
|
+
num_neighbors: list[int] | None = None,
|
|
244
252
|
num_hops: int = 2,
|
|
245
253
|
max_pq_iterations: int = 10,
|
|
246
|
-
random_seed:
|
|
247
|
-
verbose:
|
|
254
|
+
random_seed: int | None = _RANDOM_SEED,
|
|
255
|
+
verbose: bool | ProgressLogger = True,
|
|
248
256
|
use_prediction_time: bool = False,
|
|
249
257
|
) -> pd.DataFrame:
|
|
250
258
|
pass
|
|
@@ -253,17 +261,17 @@ class KumoRFM:
|
|
|
253
261
|
def predict(
|
|
254
262
|
self,
|
|
255
263
|
query: str,
|
|
256
|
-
indices:
|
|
264
|
+
indices: list[str] | list[float] | list[int] | None = None,
|
|
257
265
|
*,
|
|
258
|
-
explain:
|
|
259
|
-
anchor_time:
|
|
260
|
-
context_anchor_time:
|
|
261
|
-
run_mode:
|
|
262
|
-
num_neighbors:
|
|
266
|
+
explain: Literal[True] | ExplainConfig | dict[str, Any],
|
|
267
|
+
anchor_time: pd.Timestamp | Literal['entity'] | None = None,
|
|
268
|
+
context_anchor_time: pd.Timestamp | None = None,
|
|
269
|
+
run_mode: RunMode | str = RunMode.FAST,
|
|
270
|
+
num_neighbors: list[int] | None = None,
|
|
263
271
|
num_hops: int = 2,
|
|
264
272
|
max_pq_iterations: int = 10,
|
|
265
|
-
random_seed:
|
|
266
|
-
verbose:
|
|
273
|
+
random_seed: int | None = _RANDOM_SEED,
|
|
274
|
+
verbose: bool | ProgressLogger = True,
|
|
267
275
|
use_prediction_time: bool = False,
|
|
268
276
|
) -> Explanation:
|
|
269
277
|
pass
|
|
@@ -271,19 +279,19 @@ class KumoRFM:
|
|
|
271
279
|
def predict(
|
|
272
280
|
self,
|
|
273
281
|
query: str,
|
|
274
|
-
indices:
|
|
282
|
+
indices: list[str] | list[float] | list[int] | None = None,
|
|
275
283
|
*,
|
|
276
|
-
explain:
|
|
277
|
-
anchor_time:
|
|
278
|
-
context_anchor_time:
|
|
279
|
-
run_mode:
|
|
280
|
-
num_neighbors:
|
|
284
|
+
explain: bool | ExplainConfig | dict[str, Any] = False,
|
|
285
|
+
anchor_time: pd.Timestamp | Literal['entity'] | None = None,
|
|
286
|
+
context_anchor_time: pd.Timestamp | None = None,
|
|
287
|
+
run_mode: RunMode | str = RunMode.FAST,
|
|
288
|
+
num_neighbors: list[int] | None = None,
|
|
281
289
|
num_hops: int = 2,
|
|
282
290
|
max_pq_iterations: int = 10,
|
|
283
|
-
random_seed:
|
|
284
|
-
verbose:
|
|
291
|
+
random_seed: int | None = _RANDOM_SEED,
|
|
292
|
+
verbose: bool | ProgressLogger = True,
|
|
285
293
|
use_prediction_time: bool = False,
|
|
286
|
-
) ->
|
|
294
|
+
) -> pd.DataFrame | Explanation:
|
|
287
295
|
"""Returns predictions for a predictive query.
|
|
288
296
|
|
|
289
297
|
Args:
|
|
@@ -325,7 +333,7 @@ class KumoRFM:
|
|
|
325
333
|
If ``explain`` is provided, returns an :class:`Explanation` object
|
|
326
334
|
containing the prediction, summary, and details.
|
|
327
335
|
"""
|
|
328
|
-
explain_config:
|
|
336
|
+
explain_config: ExplainConfig | None = None
|
|
329
337
|
if explain is True:
|
|
330
338
|
explain_config = ExplainConfig()
|
|
331
339
|
elif explain is not False:
|
|
@@ -369,11 +377,11 @@ class KumoRFM:
|
|
|
369
377
|
msg = f'[bold]PREDICT[/bold] {query_repr}'
|
|
370
378
|
|
|
371
379
|
if not isinstance(verbose, ProgressLogger):
|
|
372
|
-
verbose =
|
|
380
|
+
verbose = ProgressLogger.default(msg=msg, verbose=verbose)
|
|
373
381
|
|
|
374
382
|
with verbose as logger:
|
|
375
383
|
|
|
376
|
-
batch_size:
|
|
384
|
+
batch_size: int | None = None
|
|
377
385
|
if self._batch_size == 'max':
|
|
378
386
|
task_type = self._get_task_type(
|
|
379
387
|
query=query_def,
|
|
@@ -393,9 +401,9 @@ class KumoRFM:
|
|
|
393
401
|
logger.log(f"Splitting {len(indices):,} entities into "
|
|
394
402
|
f"{len(batches):,} batches of size {batch_size:,}")
|
|
395
403
|
|
|
396
|
-
predictions:
|
|
397
|
-
summary:
|
|
398
|
-
details:
|
|
404
|
+
predictions: list[pd.DataFrame] = []
|
|
405
|
+
summary: str | None = None
|
|
406
|
+
details: Explanation | None = None
|
|
399
407
|
for i, batch in enumerate(batches):
|
|
400
408
|
# TODO Re-use the context for subsequent predictions.
|
|
401
409
|
context = self._get_context(
|
|
@@ -429,8 +437,7 @@ class KumoRFM:
|
|
|
429
437
|
stats = Context.get_memory_stats(request_msg.context)
|
|
430
438
|
raise ValueError(_SIZE_LIMIT_MSG.format(stats=stats))
|
|
431
439
|
|
|
432
|
-
if
|
|
433
|
-
and len(batches) > 1):
|
|
440
|
+
if i == 0 and len(batches) > 1:
|
|
434
441
|
verbose.init_progress(
|
|
435
442
|
total=len(batches),
|
|
436
443
|
description='Predicting',
|
|
@@ -469,8 +476,7 @@ class KumoRFM:
|
|
|
469
476
|
|
|
470
477
|
predictions.append(df)
|
|
471
478
|
|
|
472
|
-
if (
|
|
473
|
-
and len(batches) > 1):
|
|
479
|
+
if len(batches) > 1:
|
|
474
480
|
verbose.step()
|
|
475
481
|
|
|
476
482
|
break
|
|
@@ -508,9 +514,9 @@ class KumoRFM:
|
|
|
508
514
|
def is_valid_entity(
|
|
509
515
|
self,
|
|
510
516
|
query: str,
|
|
511
|
-
indices:
|
|
517
|
+
indices: list[str] | list[float] | list[int] | None = None,
|
|
512
518
|
*,
|
|
513
|
-
anchor_time:
|
|
519
|
+
anchor_time: pd.Timestamp | Literal['entity'] | None = None,
|
|
514
520
|
) -> np.ndarray:
|
|
515
521
|
r"""Returns a mask that denotes which entities are valid for the
|
|
516
522
|
given predictive query, *i.e.*, which entities fulfill (temporal)
|
|
@@ -554,15 +560,15 @@ class KumoRFM:
|
|
|
554
560
|
self,
|
|
555
561
|
query: str,
|
|
556
562
|
*,
|
|
557
|
-
metrics:
|
|
558
|
-
anchor_time:
|
|
559
|
-
context_anchor_time:
|
|
560
|
-
run_mode:
|
|
561
|
-
num_neighbors:
|
|
563
|
+
metrics: list[str] | None = None,
|
|
564
|
+
anchor_time: pd.Timestamp | Literal['entity'] | None = None,
|
|
565
|
+
context_anchor_time: pd.Timestamp | None = None,
|
|
566
|
+
run_mode: RunMode | str = RunMode.FAST,
|
|
567
|
+
num_neighbors: list[int] | None = None,
|
|
562
568
|
num_hops: int = 2,
|
|
563
569
|
max_pq_iterations: int = 10,
|
|
564
|
-
random_seed:
|
|
565
|
-
verbose:
|
|
570
|
+
random_seed: int | None = _RANDOM_SEED,
|
|
571
|
+
verbose: bool | ProgressLogger = True,
|
|
566
572
|
use_prediction_time: bool = False,
|
|
567
573
|
) -> pd.DataFrame:
|
|
568
574
|
"""Evaluates a predictive query.
|
|
@@ -610,7 +616,7 @@ class KumoRFM:
|
|
|
610
616
|
msg = f'[bold]EVALUATE[/bold] {query_repr}'
|
|
611
617
|
|
|
612
618
|
if not isinstance(verbose, ProgressLogger):
|
|
613
|
-
verbose =
|
|
619
|
+
verbose = ProgressLogger.default(msg=msg, verbose=verbose)
|
|
614
620
|
|
|
615
621
|
with verbose as logger:
|
|
616
622
|
context = self._get_context(
|
|
@@ -669,8 +675,8 @@ class KumoRFM:
|
|
|
669
675
|
query: str,
|
|
670
676
|
size: int,
|
|
671
677
|
*,
|
|
672
|
-
anchor_time:
|
|
673
|
-
random_seed:
|
|
678
|
+
anchor_time: pd.Timestamp | Literal['entity'] | None = None,
|
|
679
|
+
random_seed: int | None = _RANDOM_SEED,
|
|
674
680
|
max_iterations: int = 10,
|
|
675
681
|
) -> pd.DataFrame:
|
|
676
682
|
"""Returns the labels of a predictive query for a specified anchor
|
|
@@ -764,7 +770,7 @@ class KumoRFM:
|
|
|
764
770
|
@staticmethod
|
|
765
771
|
def _get_task_type(
|
|
766
772
|
query: ValidatedPredictiveQuery,
|
|
767
|
-
edge_types:
|
|
773
|
+
edge_types: list[tuple[str, str, str]],
|
|
768
774
|
) -> TaskType:
|
|
769
775
|
if isinstance(query.target_ast, (Condition, LogicalOperation)):
|
|
770
776
|
return TaskType.BINARY_CLASSIFICATION
|
|
@@ -819,7 +825,7 @@ class KumoRFM:
|
|
|
819
825
|
self,
|
|
820
826
|
query: ValidatedPredictiveQuery,
|
|
821
827
|
anchor_time: pd.Timestamp,
|
|
822
|
-
context_anchor_time:
|
|
828
|
+
context_anchor_time: pd.Timestamp | None,
|
|
823
829
|
evaluate: bool,
|
|
824
830
|
) -> None:
|
|
825
831
|
|
|
@@ -885,16 +891,16 @@ class KumoRFM:
|
|
|
885
891
|
def _get_context(
|
|
886
892
|
self,
|
|
887
893
|
query: ValidatedPredictiveQuery,
|
|
888
|
-
indices:
|
|
889
|
-
anchor_time:
|
|
890
|
-
context_anchor_time:
|
|
894
|
+
indices: list[str] | list[float] | list[int] | None,
|
|
895
|
+
anchor_time: pd.Timestamp | Literal['entity'] | None,
|
|
896
|
+
context_anchor_time: pd.Timestamp | None,
|
|
891
897
|
run_mode: RunMode,
|
|
892
|
-
num_neighbors:
|
|
898
|
+
num_neighbors: list[int] | None,
|
|
893
899
|
num_hops: int,
|
|
894
900
|
max_pq_iterations: int,
|
|
895
901
|
evaluate: bool,
|
|
896
|
-
random_seed:
|
|
897
|
-
logger:
|
|
902
|
+
random_seed: int | None = _RANDOM_SEED,
|
|
903
|
+
logger: ProgressLogger | None = None,
|
|
898
904
|
) -> Context:
|
|
899
905
|
|
|
900
906
|
if num_neighbors is not None:
|
|
@@ -1069,7 +1075,7 @@ class KumoRFM:
|
|
|
1069
1075
|
raise NotImplementedError
|
|
1070
1076
|
logger.log(msg)
|
|
1071
1077
|
|
|
1072
|
-
entity_table_names:
|
|
1078
|
+
entity_table_names: tuple[str, ...]
|
|
1073
1079
|
if task_type.is_link_pred:
|
|
1074
1080
|
final_aggr = query.get_final_target_aggregation()
|
|
1075
1081
|
assert final_aggr is not None
|
|
@@ -1127,7 +1133,7 @@ class KumoRFM:
|
|
|
1127
1133
|
|
|
1128
1134
|
@staticmethod
|
|
1129
1135
|
def _validate_metrics(
|
|
1130
|
-
metrics:
|
|
1136
|
+
metrics: list[str],
|
|
1131
1137
|
task_type: TaskType,
|
|
1132
1138
|
) -> None:
|
|
1133
1139
|
|
|
@@ -1184,7 +1190,7 @@ class KumoRFM:
|
|
|
1184
1190
|
f"'https://github.com/kumo-ai/kumo-rfm'.")
|
|
1185
1191
|
|
|
1186
1192
|
|
|
1187
|
-
def format_value(value:
|
|
1193
|
+
def format_value(value: int | float) -> str:
|
|
1188
1194
|
if value == int(value):
|
|
1189
1195
|
return f'{int(value):,}'
|
|
1190
1196
|
if abs(value) >= 1000:
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
import base64
|
|
2
2
|
import json
|
|
3
|
-
from typing import Any
|
|
3
|
+
from typing import Any
|
|
4
4
|
|
|
5
5
|
import requests
|
|
6
6
|
|
|
@@ -48,8 +48,8 @@ class KumoClient_SageMakerAdapter(KumoClient):
|
|
|
48
48
|
|
|
49
49
|
# Recording buffers.
|
|
50
50
|
self._recording_active = False
|
|
51
|
-
self._recorded_reqs:
|
|
52
|
-
self._recorded_resps:
|
|
51
|
+
self._recorded_reqs: list[dict[str, Any]] = []
|
|
52
|
+
self._recorded_resps: list[dict[str, Any]] = []
|
|
53
53
|
|
|
54
54
|
def authenticate(self) -> None:
|
|
55
55
|
# TODO(siyang): call /ping to verify?
|
|
@@ -92,7 +92,7 @@ class KumoClient_SageMakerAdapter(KumoClient):
|
|
|
92
92
|
self._recorded_reqs.clear()
|
|
93
93
|
self._recorded_resps.clear()
|
|
94
94
|
|
|
95
|
-
def end_recording(self) ->
|
|
95
|
+
def end_recording(self) -> list[tuple[dict[str, Any], dict[str, Any]]]:
|
|
96
96
|
"""Stop recording and return recorded requests/responses."""
|
|
97
97
|
assert self._recording_active
|
|
98
98
|
self._recording_active = False
|