kumoai 2.12.0.dev202511111731__cp311-cp311-macosx_11_0_arm64.whl → 2.13.0.dev202512091732__cp311-cp311-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kumoai/__init__.py +18 -9
- kumoai/_version.py +1 -1
- kumoai/client/client.py +9 -13
- kumoai/connector/utils.py +23 -2
- kumoai/experimental/rfm/__init__.py +162 -46
- kumoai/experimental/rfm/backend/__init__.py +0 -0
- kumoai/experimental/rfm/backend/local/__init__.py +42 -0
- kumoai/experimental/rfm/{local_graph_store.py → backend/local/graph_store.py} +20 -30
- kumoai/experimental/rfm/backend/local/sampler.py +242 -0
- kumoai/experimental/rfm/backend/local/table.py +109 -0
- kumoai/experimental/rfm/backend/snow/__init__.py +35 -0
- kumoai/experimental/rfm/backend/snow/table.py +117 -0
- kumoai/experimental/rfm/backend/sqlite/__init__.py +30 -0
- kumoai/experimental/rfm/backend/sqlite/table.py +101 -0
- kumoai/experimental/rfm/base/__init__.py +14 -0
- kumoai/experimental/rfm/base/column.py +66 -0
- kumoai/experimental/rfm/base/sampler.py +374 -0
- kumoai/experimental/rfm/base/source.py +18 -0
- kumoai/experimental/rfm/{local_table.py → base/table.py} +139 -139
- kumoai/experimental/rfm/{local_graph.py → graph.py} +334 -79
- kumoai/experimental/rfm/infer/__init__.py +6 -0
- kumoai/experimental/rfm/infer/dtype.py +79 -0
- kumoai/experimental/rfm/infer/pkey.py +126 -0
- kumoai/experimental/rfm/infer/time_col.py +62 -0
- kumoai/experimental/rfm/local_graph_sampler.py +43 -4
- kumoai/experimental/rfm/local_pquery_driver.py +1 -1
- kumoai/experimental/rfm/pquery/pandas_executor.py +1 -1
- kumoai/experimental/rfm/rfm.py +28 -27
- kumoai/experimental/rfm/sagemaker.py +138 -0
- kumoai/spcs.py +1 -3
- kumoai/testing/decorators.py +1 -1
- {kumoai-2.12.0.dev202511111731.dist-info → kumoai-2.13.0.dev202512091732.dist-info}/METADATA +12 -2
- {kumoai-2.12.0.dev202511111731.dist-info → kumoai-2.13.0.dev202512091732.dist-info}/RECORD +36 -21
- kumoai/experimental/rfm/utils.py +0 -344
- {kumoai-2.12.0.dev202511111731.dist-info → kumoai-2.13.0.dev202512091732.dist-info}/WHEEL +0 -0
- {kumoai-2.12.0.dev202511111731.dist-info → kumoai-2.13.0.dev202512091732.dist-info}/licenses/LICENSE +0 -0
- {kumoai-2.12.0.dev202511111731.dist-info → kumoai-2.13.0.dev202512091732.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,79 @@
|
|
|
1
|
+
from typing import Dict
|
|
2
|
+
|
|
3
|
+
import numpy as np
|
|
4
|
+
import pandas as pd
|
|
5
|
+
import pyarrow as pa
|
|
6
|
+
from kumoapi.typing import Dtype
|
|
7
|
+
|
|
8
|
+
# Lookup table from a normalized (lower-cased, backend suffix stripped where
# applicable) pandas dtype name to the corresponding Kumo `Dtype`.
# Keys are grouped by the `Dtype` they resolve to.
PANDAS_TO_DTYPE: Dict[str, Dtype] = {
    **dict.fromkeys(('bool', 'boolean'), Dtype.bool),
    **dict.fromkeys(('int8', 'int16', 'int32', 'int64'), Dtype.int),
    **dict.fromkeys(('float16', 'float32', 'float64'), Dtype.float),
    **dict.fromkeys(
        ('object', 'string', 'string[python]', 'string[pyarrow]'),
        Dtype.string,
    ),
    'binary': Dtype.binary,
}
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
def infer_dtype(ser: pd.Series) -> Dtype:
    """Extracts the :class:`Dtype` from a :class:`pandas.Series`.

    Date-time, time-delta and categorical series are resolved directly.
    Object series whose first non-null value (within the first 1000 rows) is
    list-like are converted to an Arrow-backed sample so that the list element
    type can be inspected. Everything else is resolved by looking up the
    normalized dtype name in ``PANDAS_TO_DTYPE``.

    Args:
        ser: A :class:`pandas.Series` to analyze.

    Returns:
        The data type.

    Raises:
        ValueError: If the dtype name is not part of ``PANDAS_TO_DTYPE``.
    """
    if pd.api.types.is_datetime64_any_dtype(ser.dtype):
        return Dtype.date
    if pd.api.types.is_timedelta64_dtype(ser.dtype):
        return Dtype.timedelta
    if isinstance(ser.dtype, pd.CategoricalDtype):
        return Dtype.string

    if (pd.api.types.is_object_dtype(ser.dtype)
            and not isinstance(ser.dtype, pd.ArrowDtype)):
        # Locate the first non-null element *by position*. A label-based
        # lookup (`first_valid_index()` + `Index.get_loc()`) yields a slice or
        # a boolean mask for non-unique indices, which previously tripped an
        # assertion.
        valid_mask = ser.iloc[:1000].notna().to_numpy()
        if valid_mask.any():
            pos = int(valid_mask.argmax())
            if pd.api.types.is_list_like(ser.iloc[pos]):
                # Only a bounded window is converted to Arrow to infer the
                # list element type:
                sample = ser.iloc[pos:pos + 1000].dropna()
                arr = pa.array(sample.tolist())
                ser = pd.Series(arr, dtype=pd.ArrowDtype(arr.type))

    if isinstance(ser.dtype, pd.ArrowDtype):
        if pa.types.is_list(ser.dtype.pyarrow_dtype):
            elem_dtype = ser.dtype.pyarrow_dtype.value_type
            if pa.types.is_integer(elem_dtype):
                return Dtype.intlist
            if pa.types.is_floating(elem_dtype):
                return Dtype.floatlist
            if pa.types.is_decimal(elem_dtype):
                return Dtype.floatlist
            if pa.types.is_string(elem_dtype):
                return Dtype.stringlist
            if pa.types.is_null(elem_dtype):
                # Lists without any typed element are treated as float lists.
                return Dtype.floatlist

    if isinstance(ser.dtype, np.dtype):
        dtype_str = str(ser.dtype).lower()
    elif isinstance(ser.dtype, pd.api.extensions.ExtensionDtype):
        dtype_str = ser.dtype.name.lower()
        dtype_str = dtype_str.split('[')[0]  # Remove backend metadata
    elif isinstance(ser.dtype, pa.DataType):
        # NOTE(review): kept from the original as a defensive branch; confirm
        # whether a raw pyarrow type can ever appear as a series dtype.
        dtype_str = str(ser.dtype).lower()
    else:
        dtype_str = 'object'

    if dtype_str not in PANDAS_TO_DTYPE:
        raise ValueError(f"Unsupported data type '{ser.dtype}'")

    return PANDAS_TO_DTYPE[dtype_str]
|
|
@@ -0,0 +1,126 @@
|
|
|
1
|
+
import re
|
|
2
|
+
import warnings
|
|
3
|
+
from typing import Optional
|
|
4
|
+
|
|
5
|
+
import pandas as pd
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def _table_name_variants(table_name: str) -> set[str]:
    r"""Returns lower-cased name variants of a table name to match against."""
    # A list of (potentially modified) table names that are eligible to match
    # with a primary key, i.e.:
    # - UserInfo -> User
    # - snakecase <-> camelcase
    # - camelcase <-> snakecase
    # - plural <-> singular (users -> user, eligibilities -> eligibility)
    # - verb -> noun (qualifying -> qualify)
    base_names = {table_name}
    if table_name.lower().endswith('_info'):
        base_names.add(table_name[:-5])
    elif table_name.lower().endswith('info'):
        base_names.add(table_name[:-4])

    variants: set[str] = set()
    for name in base_names:
        snakecase = re.sub(r'(.)([A-Z][a-z]+)', r'\1_\2', name)
        snakecase = re.sub(r'([a-z0-9])([A-Z])', r'\1_\2', snakecase)
        camelcase = name.replace('_', '')
        forms = (name.lower(), snakecase.lower(), camelcase.lower())
        variants.update(forms)

        lowered = name.lower()
        if lowered.endswith('s'):
            variants.update(form[:-1] for form in forms)
        else:
            variants.update(form + 's' for form in forms)
        if lowered.endswith('ies'):
            variants.update(form[:-3] + 'y' for form in forms)
        elif lowered.endswith('y'):
            variants.update(form[:-1] + 'ies' for form in forms)
        if lowered.endswith('ing'):
            variants.update(form[:-3] for form in forms)

    return variants


def infer_primary_key(
    table_name: str,
    df: pd.DataFrame,
    candidates: list[str],
) -> Optional[str]:
    r"""Auto-detect potential primary key column.

    Every candidate receives a score built from name heuristics (``id``,
    table-name matches, ``<table>_<suffix>`` patterns) plus up to three points
    for the ratio of unique values within the first million rows. Candidates
    scoring below 4 are discarded.

    Args:
        table_name: The table name.
        df: The pandas DataFrame to analyze.
        candidates: A list of potential candidates.

    Returns:
        The name of the detected primary key, or ``None`` if not found. If
        several candidates share the top score, a warning listing them is
        emitted and ``None`` is returned.
    """
    table_names = _table_name_variants(table_name)

    scores: list[tuple[str, float]] = []
    for col_name in candidates:
        col_name_lower = col_name.lower()

        score = 0.0

        if col_name_lower == 'id':
            score += 4

        for table_name_lower in table_names:

            if col_name_lower == table_name_lower:
                score += 4  # USER -> USER
                break

            for suffix in ['id', 'hash', 'key', 'code', 'uuid']:
                if not col_name_lower.endswith(suffix):
                    continue

                if col_name_lower == f'{table_name_lower}_{suffix}':
                    score += 5  # USER -> USER_ID
                    break

                if col_name_lower == f'{table_name_lower}{suffix}':
                    score += 5  # User -> UserId
                    break

                if col_name_lower.endswith(f'{table_name_lower}_{suffix}'):
                    score += 2

                if col_name_lower.endswith(f'{table_name_lower}{suffix}'):
                    score += 2

        # `rel-bench` hard-coding :(
        if table_name == 'studies' and col_name == 'nct_id':
            score += 1

        ser = df[col_name].iloc[:1_000_000]
        # An empty column carries no uniqueness evidence; guard against a
        # division by zero:
        if len(ser) > 0:
            score += 3 * (ser.nunique() / len(ser))

        scores.append((col_name, score))

    scores = [x for x in scores if x[-1] >= 4]
    scores.sort(key=lambda x: x[-1], reverse=True)

    if len(scores) == 0:
        return None

    if len(scores) == 1:
        return scores[0][0]

    # In case of multiple candidates, only return one if its score is unique:
    if scores[0][1] != scores[1][1]:
        return scores[0][0]

    # `scores` is sorted in descending order, so the first entry holds the
    # maximum score. Compare against the score itself (not the whole tuple) so
    # that the warning lists the tied columns:
    max_score = scores[0][1]
    tied = [col_name for col_name, score in scores if score == max_score]
    warnings.warn(f"Found multiple potential primary keys in table "
                  f"'{table_name}': {tied}. Please specify the primary "
                  f"key for this table manually.")

    return None
|
|
@@ -0,0 +1,62 @@
|
|
|
1
|
+
import re
|
|
2
|
+
import warnings
|
|
3
|
+
from typing import Optional
|
|
4
|
+
|
|
5
|
+
import pandas as pd
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def infer_time_column(
    df: pd.DataFrame,
    candidates: list[str],
) -> Optional[str]:
    r"""Pick the most plausible time column among a set of candidates.

    Columns with a stand-alone ``last`` token in their name are ignored.
    A unique ``create*`` column wins outright; otherwise the column whose
    earliest parsable timestamp (within the first 10,000 rows) is the oldest
    gets selected.

    Args:
        df: The pandas DataFrame to analyze.
        candidates: A list of potential candidates.

    Returns:
        The name of the detected time column, or ``None`` if not found.
    """
    # Exclude all candidates with `*last*` in column names:
    last_token = re.compile(r'(^|_)last(_|$)', re.IGNORECASE)
    remaining = [
        name for name in candidates if last_token.search(name) is None
    ]

    if not remaining:
        return None
    if len(remaining) == 1:
        return remaining[0]

    # A dedicated `create*` column takes precedence over everything else:
    created = [
        name for name in remaining if name.lower().startswith('create')
    ]
    if len(created) == 1:
        return created[0]
    if created:
        remaining = created

    # Otherwise, prefer the column pointing to the oldest timestamps:
    earliest = {}
    with warnings.catch_warnings():
        warnings.filterwarnings('ignore', message='Could not infer format')
        for name in remaining:
            parsed = pd.to_datetime(df[name].iloc[:10_000], errors='coerce')
            first = parsed.min().tz_localize(None)
            if not pd.isna(first):
                earliest[name] = first

    if not earliest:
        return None

    return min(earliest, key=earliest.get)  # type: ignore
|
|
@@ -1,14 +1,54 @@
|
|
|
1
|
+
import re
|
|
1
2
|
from typing import Dict, List, Optional, Tuple
|
|
2
3
|
|
|
3
4
|
import numpy as np
|
|
4
5
|
import pandas as pd
|
|
5
|
-
from kumoapi.model_plan import RunMode
|
|
6
6
|
from kumoapi.rfm.context import EdgeLayout, Link, Subgraph, Table
|
|
7
7
|
from kumoapi.typing import Stype
|
|
8
8
|
|
|
9
9
|
import kumoai.kumolib as kumolib
|
|
10
|
-
from kumoai.experimental.rfm.
|
|
11
|
-
|
|
10
|
+
from kumoai.experimental.rfm.backend.local import LocalGraphStore
|
|
11
|
+
|
|
12
|
+
PUNCTUATION = re.compile(r"[\'\"\.,\(\)\!\?\;\:]")
MULTISPACE = re.compile(r"\s+")


def normalize_text(
    ser: pd.Series,
    max_words: Optional[int] = 50,
) -> pd.Series:
    r"""Turns each text entry into a list of lower-case words.

    Empty series and series whose first element is already list-like are
    returned untouched.

    Args:
        ser: The :class:`pandas.Series` to normalize.
        max_words: The maximum number of words to return.
            This will auto-shrink any large text column to avoid blowing up
            context size.
    """
    if len(ser) == 0:
        return ser
    if pd.api.types.is_list_like(ser.iloc[0]):
        return ser

    def _tokenize(text: str) -> list[str]:
        text = PUNCTUATION.sub(" ", text)
        text = re.sub(r"<br\s*/?>", " ", text)  # Handle <br /> or <br>
        tokens = MULTISPACE.sub(" ", text).split()
        return tokens if max_words is None else tokens[:max_words]

    text_ser = ser.fillna('').astype(str)

    if max_words is not None:
        # Cheap pre-filter: assume roughly 5 characters + 1 space per English
        # word, so that we never split a giant text into words:
        text_ser = text_ser.str[:6 * max_words]

    return text_ser.str.lower().map(_tokenize)
|
|
12
52
|
|
|
13
53
|
|
|
14
54
|
class LocalGraphSampler:
|
|
@@ -33,7 +73,6 @@ class LocalGraphSampler:
|
|
|
33
73
|
entity_table_names: Tuple[str, ...],
|
|
34
74
|
node: np.ndarray,
|
|
35
75
|
time: np.ndarray,
|
|
36
|
-
run_mode: RunMode,
|
|
37
76
|
num_neighbors: List[int],
|
|
38
77
|
exclude_cols_dict: Dict[str, List[str]],
|
|
39
78
|
) -> Subgraph:
|
|
@@ -17,7 +17,7 @@ from kumoapi.task import TaskType
|
|
|
17
17
|
from kumoapi.typing import AggregationType, DateOffset, Stype
|
|
18
18
|
|
|
19
19
|
import kumoai.kumolib as kumolib
|
|
20
|
-
from kumoai.experimental.rfm.
|
|
20
|
+
from kumoai.experimental.rfm.backend.local import LocalGraphStore
|
|
21
21
|
from kumoai.experimental.rfm.pquery import PQueryPandasExecutor
|
|
22
22
|
|
|
23
23
|
_coverage_warned = False
|
|
@@ -134,7 +134,7 @@ class PQueryPandasExecutor(PQueryExecutor[pd.DataFrame, pd.Series,
|
|
|
134
134
|
outs: List[pd.Series] = []
|
|
135
135
|
masks: List[np.ndarray] = []
|
|
136
136
|
for _ in range(num_forecasts):
|
|
137
|
-
anchor_target_time = anchor_time[target_batch]
|
|
137
|
+
anchor_target_time = anchor_time.iloc[target_batch]
|
|
138
138
|
anchor_target_time = anchor_target_time.reset_index(drop=True)
|
|
139
139
|
|
|
140
140
|
time_filter_mask = (target_time <= anchor_target_time +
|
kumoai/experimental/rfm/rfm.py
CHANGED
|
@@ -30,11 +30,11 @@ from kumoapi.rfm import (
|
|
|
30
30
|
)
|
|
31
31
|
from kumoapi.task import TaskType
|
|
32
32
|
|
|
33
|
-
from kumoai import
|
|
33
|
+
from kumoai.client.rfm import RFMAPI
|
|
34
34
|
from kumoai.exceptions import HTTPException
|
|
35
|
-
from kumoai.experimental.rfm import
|
|
35
|
+
from kumoai.experimental.rfm import Graph
|
|
36
|
+
from kumoai.experimental.rfm.backend.local import LocalGraphStore
|
|
36
37
|
from kumoai.experimental.rfm.local_graph_sampler import LocalGraphSampler
|
|
37
|
-
from kumoai.experimental.rfm.local_graph_store import LocalGraphStore
|
|
38
38
|
from kumoai.experimental.rfm.local_pquery_driver import (
|
|
39
39
|
LocalPQueryDriver,
|
|
40
40
|
date_offset_to_seconds,
|
|
@@ -123,17 +123,17 @@ class KumoRFM:
|
|
|
123
123
|
:class:`KumoRFM` is a foundation model to generate predictions for any
|
|
124
124
|
relational dataset without training.
|
|
125
125
|
The model is pre-trained and the class provides an interface to query the
|
|
126
|
-
model from a :class:`
|
|
126
|
+
model from a :class:`Graph` object.
|
|
127
127
|
|
|
128
128
|
.. code-block:: python
|
|
129
129
|
|
|
130
|
-
from kumoai.experimental.rfm import
|
|
130
|
+
from kumoai.experimental.rfm import Graph, KumoRFM
|
|
131
131
|
|
|
132
132
|
df_users = pd.DataFrame(...)
|
|
133
133
|
df_items = pd.DataFrame(...)
|
|
134
134
|
df_orders = pd.DataFrame(...)
|
|
135
135
|
|
|
136
|
-
graph =
|
|
136
|
+
graph = Graph.from_data({
|
|
137
137
|
'users': df_users,
|
|
138
138
|
'items': df_items,
|
|
139
139
|
'orders': df_orders,
|
|
@@ -141,40 +141,41 @@ class KumoRFM:
|
|
|
141
141
|
|
|
142
142
|
rfm = KumoRFM(graph)
|
|
143
143
|
|
|
144
|
-
query = ("PREDICT COUNT(
|
|
145
|
-
"FOR users.user_id=
|
|
146
|
-
result = rfm.
|
|
144
|
+
query = ("PREDICT COUNT(orders.*, 0, 30, days)>0 "
|
|
145
|
+
"FOR users.user_id=1")
|
|
146
|
+
result = rfm.predict(query)
|
|
147
147
|
|
|
148
148
|
print(result) # user_id COUNT(orders.*, 0, 30, days) > 0
|
|
149
149
|
# 1 0.85
|
|
150
150
|
|
|
151
151
|
Args:
|
|
152
152
|
graph: The graph.
|
|
153
|
-
preprocess: Whether to pre-process the data in advance during graph
|
|
154
|
-
materialization.
|
|
155
|
-
This is a runtime trade-off between graph materialization and model
|
|
156
|
-
processing speed.
|
|
157
|
-
It can be benefical to preprocess your data once and then run many
|
|
158
|
-
queries on top to achieve maximum model speed.
|
|
159
|
-
However, if activiated, graph materialization can take potentially
|
|
160
|
-
much longer, especially on graphs with many large text columns.
|
|
161
|
-
Best to tune this option manually.
|
|
162
153
|
verbose: Whether to print verbose output.
|
|
163
154
|
"""
|
|
164
155
|
def __init__(
|
|
165
156
|
self,
|
|
166
|
-
graph:
|
|
167
|
-
preprocess: bool = False,
|
|
157
|
+
graph: Graph,
|
|
168
158
|
verbose: Union[bool, ProgressLogger] = True,
|
|
169
159
|
) -> None:
|
|
170
160
|
graph = graph.validate()
|
|
171
161
|
self._graph_def = graph._to_api_graph_definition()
|
|
172
|
-
self._graph_store = LocalGraphStore(graph,
|
|
162
|
+
self._graph_store = LocalGraphStore(graph, verbose)
|
|
173
163
|
self._graph_sampler = LocalGraphSampler(self._graph_store)
|
|
174
164
|
|
|
165
|
+
self._client: Optional[RFMAPI] = None
|
|
166
|
+
|
|
175
167
|
self._batch_size: Optional[int | Literal['max']] = None
|
|
176
168
|
self.num_retries: int = 0
|
|
177
169
|
|
|
170
|
+
@property
def _api_client(self) -> RFMAPI:
    r"""The :class:`RFMAPI` client, created on first access and then
    re-used.

    The client wraps ``global_state.client``; the import is deferred until
    the first access.
    """
    if self._client is None:
        from kumoai.experimental.rfm import global_state
        self._client = RFMAPI(global_state.client)
    return self._client
|
|
178
|
+
|
|
178
179
|
def __repr__(self) -> str:
|
|
179
180
|
return f'{self.__class__.__name__}()'
|
|
180
181
|
|
|
@@ -420,14 +421,14 @@ class KumoRFM:
|
|
|
420
421
|
for attempt in range(self.num_retries + 1):
|
|
421
422
|
try:
|
|
422
423
|
if explain_config is not None:
|
|
423
|
-
resp =
|
|
424
|
+
resp = self._api_client.explain(
|
|
424
425
|
request=_bytes,
|
|
425
426
|
skip_summary=explain_config.skip_summary,
|
|
426
427
|
)
|
|
427
428
|
summary = resp.summary
|
|
428
429
|
details = resp.details
|
|
429
430
|
else:
|
|
430
|
-
resp =
|
|
431
|
+
resp = self._api_client.predict(_bytes)
|
|
431
432
|
df = pd.DataFrame(**resp.prediction)
|
|
432
433
|
|
|
433
434
|
# Cast 'ENTITY' to correct data type:
|
|
@@ -630,10 +631,10 @@ class KumoRFM:
|
|
|
630
631
|
|
|
631
632
|
if len(request_bytes) > _MAX_SIZE:
|
|
632
633
|
stats_msg = Context.get_memory_stats(request_msg.context)
|
|
633
|
-
raise ValueError(_SIZE_LIMIT_MSG.format(
|
|
634
|
+
raise ValueError(_SIZE_LIMIT_MSG.format(stats=stats_msg))
|
|
634
635
|
|
|
635
636
|
try:
|
|
636
|
-
resp =
|
|
637
|
+
resp = self._api_client.evaluate(request_bytes)
|
|
637
638
|
except HTTPException as e:
|
|
638
639
|
try:
|
|
639
640
|
msg = json.loads(e.detail)['detail']
|
|
@@ -731,7 +732,8 @@ class KumoRFM:
|
|
|
731
732
|
graph_definition=self._graph_def,
|
|
732
733
|
)
|
|
733
734
|
|
|
734
|
-
resp =
|
|
735
|
+
resp = self._api_client.parse_query(request)
|
|
736
|
+
|
|
735
737
|
# TODO Expose validation warnings.
|
|
736
738
|
|
|
737
739
|
if len(resp.validation_response.warnings) > 0:
|
|
@@ -1035,7 +1037,6 @@ class KumoRFM:
|
|
|
1035
1037
|
train_time.astype('datetime64[ns]').astype(int).to_numpy(),
|
|
1036
1038
|
test_time.astype('datetime64[ns]').astype(int).to_numpy(),
|
|
1037
1039
|
]),
|
|
1038
|
-
run_mode=run_mode,
|
|
1039
1040
|
num_neighbors=num_neighbors,
|
|
1040
1041
|
exclude_cols_dict=exclude_cols_dict,
|
|
1041
1042
|
)
|
|
@@ -0,0 +1,138 @@
|
|
|
1
|
+
import base64
|
|
2
|
+
import json
|
|
3
|
+
from typing import Any, Dict, List, Tuple
|
|
4
|
+
|
|
5
|
+
import requests
|
|
6
|
+
|
|
7
|
+
from kumoai.client import KumoClient
|
|
8
|
+
from kumoai.client.endpoints import Endpoint, HTTPMethod
|
|
9
|
+
from kumoai.exceptions import HTTPException
|
|
10
|
+
|
|
11
|
+
try:
|
|
12
|
+
# isort: off
|
|
13
|
+
from mypy_boto3_sagemaker_runtime.client import SageMakerRuntimeClient
|
|
14
|
+
from mypy_boto3_sagemaker_runtime.type_defs import (
|
|
15
|
+
InvokeEndpointOutputTypeDef, )
|
|
16
|
+
# isort: on
|
|
17
|
+
except ImportError:
|
|
18
|
+
SageMakerRuntimeClient = Any
|
|
19
|
+
InvokeEndpointOutputTypeDef = Any
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class SageMakerResponseAdapter(requests.Response):
    """Exposes a SageMaker ``invoke_endpoint`` output as a
    :class:`requests.Response`.

    The body stream is read once during construction, the status code is
    fixed to ``200``, and the content type falls back to
    ``application/json`` if the SageMaker output does not provide one.
    """
    def __init__(self, sm_response: InvokeEndpointOutputTypeDef):
        super().__init__()
        # Read the body bytes
        body = sm_response['Body'].read()
        self._content = body
        self.status_code = 200
        content_type = sm_response.get('ContentType', 'application/json')
        self.headers['Content-Type'] = content_type
        # The original SageMaker output is kept around for debugging:
        self.sm_response = sm_response

    @property
    def text(self) -> str:
        """The response body decoded as UTF-8."""
        raw = self._content
        assert isinstance(raw, bytes)
        return raw.decode('utf-8')

    def json(self, **kwargs) -> dict[str, Any]:  # type: ignore
        """Parses the response body as JSON, forwarding ``kwargs`` to
        :func:`json.loads`.
        """
        return json.loads(self.text, **kwargs)
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
class KumoClient_SageMakerAdapter(KumoClient):
    """A :class:`KumoClient` that sends requests to a SageMaker runtime
    endpoint via ``boto3`` and can optionally record every request/response
    pair.
    """
    def __init__(self, region: str, endpoint_name: str):
        import boto3
        self._client: SageMakerRuntimeClient = boto3.client(
            service_name="sagemaker-runtime", region_name=region)
        self._endpoint_name = endpoint_name

        # Recording buffers.
        self._recording_active = False
        self._recorded_reqs: List[Dict[str, Any]] = []
        self._recorded_resps: List[Dict[str, Any]] = []

    def authenticate(self) -> None:
        # TODO(siyang): call /ping to verify?
        pass

    def _request(self, endpoint: Endpoint, **kwargs: Any) -> requests.Response:
        """Wraps the request into a ``{'method', 'payload'}`` JSON envelope
        and invokes the SageMaker endpoint with it.

        A ``json`` keyword argument is serialized to a JSON string, while a
        ``data`` keyword argument (raw bytes) is base64-encoded.
        """
        assert endpoint.method == HTTPMethod.POST

        if 'json' not in kwargs and 'data' not in kwargs:
            raise HTTPException(400, 'Unable to send data to KumoRFM.')

        if 'json' in kwargs:
            encoded = json.dumps(kwargs.pop('json'))
        else:
            raw = kwargs.pop('data')
            assert isinstance(raw, bytes)
            encoded = base64.b64encode(raw).decode()

        # The last path segment of the endpoint is sent as the method name:
        method_name = endpoint.get_path().rsplit('/')[-1]
        request = {'method': method_name, 'payload': encoded}

        response: InvokeEndpointOutputTypeDef = self._client.invoke_endpoint(
            EndpointName=self._endpoint_name,
            ContentType="application/json",
            Body=json.dumps(request),
        )
        adapted = SageMakerResponseAdapter(response)

        # While recording is active, keep track of inputs and outputs:
        if self._recording_active:
            self._recorded_reqs.append(request)
            self._recorded_resps.append(adapted.json())

        return adapted

    def start_recording(self) -> None:
        """Start recording requests/responses to/from sagemaker endpoint."""
        assert not self._recording_active
        self._recorded_reqs.clear()
        self._recorded_resps.clear()
        self._recording_active = True

    def end_recording(self) -> List[Tuple[Dict[str, Any], Dict[str, Any]]]:
        """Stop recording and return recorded requests/responses."""
        assert self._recording_active
        self._recording_active = False
        pairs = list(zip(self._recorded_reqs, self._recorded_resps))
        self._recorded_reqs.clear()
        self._recorded_resps.clear()
        return pairs
|
|
103
|
+
|
|
104
|
+
|
|
105
|
+
class KumoClient_SageMakerProxy_Local(KumoClient):
    """A :class:`KumoClient` that forwards every request to the
    ``/invocations`` endpoint of a server at ``url``, using the same
    ``{'method', 'payload'}`` envelope as the SageMaker adapter.

    All HTTP traffic goes through a wrapped :class:`KumoClient`; the base
    class initializer is not invoked on this object itself.
    """
    def __init__(self, url: str):
        self._client = KumoClient(url, api_key=None)
        # Requests are sent to the plain URL rather than the API URL:
        self._client._api_url = self._client._url
        self._endpoint = Endpoint('/invocations', HTTPMethod.POST)

    def authenticate(self) -> None:
        """Verifies connectivity by issuing a ``GET /ping`` request.

        Raises:
            ValueError: If the request fails or returns an error status.
        """
        # The connection state lives on the wrapped client. This object never
        # runs the base initializer, so reading `self._url` here would raise
        # an `AttributeError` that the handler below turns into a failure.
        client = self._client
        try:
            client._session.get(
                client._url + '/ping',
                # NOTE(review): `_verify_ssl` is presumably set by
                # `KumoClient.__init__` - confirm. Falls back to the
                # `requests` default otherwise.
                verify=getattr(client, '_verify_ssl', True),
            ).raise_for_status()
        except Exception as e:
            raise ValueError(
                "Client authentication failed. Please check if you "
                "have a valid API key/credentials.") from e

    def _request(self, endpoint: Endpoint, **kwargs: Any) -> requests.Response:
        """Wraps the request into a ``{'method', 'payload'}`` JSON envelope
        and posts it to ``/invocations`` via the wrapped client.

        A ``json`` keyword argument is serialized to a JSON string, while a
        ``data`` keyword argument (raw bytes) is base64-encoded. Remaining
        keyword arguments are passed through.
        """
        assert endpoint.method == HTTPMethod.POST
        if 'json' in kwargs:
            payload = json.dumps(kwargs.pop('json'))
        elif 'data' in kwargs:
            raw_payload = kwargs.pop('data')
            assert isinstance(raw_payload, bytes)
            payload = base64.b64encode(raw_payload).decode()
        else:
            raise HTTPException(400, 'Unable to send data to KumoRFM.')
        return self._client._request(
            self._endpoint,
            json={
                # The last path segment of the endpoint is the method name:
                'method': endpoint.get_path().rsplit('/')[-1],
                'payload': payload,
            },
            **kwargs,
        )
|
kumoai/spcs.py
CHANGED
|
@@ -54,9 +54,7 @@ def _refresh_spcs_token() -> None:
|
|
|
54
54
|
api_key=global_state._api_key,
|
|
55
55
|
spcs_token=spcs_token,
|
|
56
56
|
)
|
|
57
|
-
|
|
58
|
-
raise ValueError("Client authentication failed. Please check if you "
|
|
59
|
-
"have a valid API key.")
|
|
57
|
+
client.authenticate()
|
|
60
58
|
|
|
61
59
|
# Update state:
|
|
62
60
|
global_state.set_spcs_token(spcs_token)
|
kumoai/testing/decorators.py
CHANGED
|
@@ -25,7 +25,7 @@ def onlyFullTest(func: Callable) -> Callable:
|
|
|
25
25
|
def has_package(package: str) -> bool:
|
|
26
26
|
r"""Returns ``True`` in case ``package`` is installed."""
|
|
27
27
|
req = Requirement(package)
|
|
28
|
-
if importlib.util.find_spec(req.name) is None:
|
|
28
|
+
if importlib.util.find_spec(req.name) is None: # type: ignore
|
|
29
29
|
return False
|
|
30
30
|
|
|
31
31
|
try:
|
{kumoai-2.12.0.dev202511111731.dist-info → kumoai-2.13.0.dev202512091732.dist-info}/METADATA
RENAMED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: kumoai
|
|
3
|
-
Version: 2.
|
|
3
|
+
Version: 2.13.0.dev202512091732
|
|
4
4
|
Summary: AI on the Modern Data Stack
|
|
5
5
|
Author-email: "Kumo.AI" <hello@kumo.ai>
|
|
6
6
|
License-Expression: MIT
|
|
@@ -23,7 +23,7 @@ Requires-Dist: requests>=2.28.2
|
|
|
23
23
|
Requires-Dist: urllib3
|
|
24
24
|
Requires-Dist: plotly
|
|
25
25
|
Requires-Dist: typing_extensions>=4.5.0
|
|
26
|
-
Requires-Dist: kumo-api==0.
|
|
26
|
+
Requires-Dist: kumo-api==0.48.0
|
|
27
27
|
Requires-Dist: tqdm>=4.66.0
|
|
28
28
|
Requires-Dist: aiohttp>=3.10.0
|
|
29
29
|
Requires-Dist: pydantic>=1.10.21
|
|
@@ -38,6 +38,16 @@ Provides-Extra: test
|
|
|
38
38
|
Requires-Dist: pytest; extra == "test"
|
|
39
39
|
Requires-Dist: pytest-mock; extra == "test"
|
|
40
40
|
Requires-Dist: requests-mock; extra == "test"
|
|
41
|
+
Provides-Extra: sqlite
|
|
42
|
+
Requires-Dist: adbc_driver_sqlite; extra == "sqlite"
|
|
43
|
+
Provides-Extra: snowflake
|
|
44
|
+
Requires-Dist: snowflake-connector-python; extra == "snowflake"
|
|
45
|
+
Requires-Dist: pyyaml; extra == "snowflake"
|
|
46
|
+
Provides-Extra: sagemaker
|
|
47
|
+
Requires-Dist: boto3<2.0,>=1.30.0; extra == "sagemaker"
|
|
48
|
+
Requires-Dist: mypy-boto3-sagemaker-runtime<2.0,>=1.34.0; extra == "sagemaker"
|
|
49
|
+
Provides-Extra: test-sagemaker
|
|
50
|
+
Requires-Dist: sagemaker<3.0; extra == "test-sagemaker"
|
|
41
51
|
Dynamic: license-file
|
|
42
52
|
Dynamic: requires-dist
|
|
43
53
|
|