kumoai 2.14.0.dev202512181731__cp312-cp312-macosx_11_0_arm64.whl → 2.14.0.dev202512301731__cp312-cp312-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kumoai/__init__.py +23 -26
- kumoai/_version.py +1 -1
- kumoai/client/client.py +6 -0
- kumoai/client/jobs.py +24 -0
- kumoai/experimental/rfm/__init__.py +22 -22
- kumoai/experimental/rfm/backend/local/graph_store.py +12 -21
- kumoai/experimental/rfm/backend/local/sampler.py +0 -3
- kumoai/experimental/rfm/backend/local/table.py +25 -24
- kumoai/experimental/rfm/backend/snow/sampler.py +106 -61
- kumoai/experimental/rfm/backend/snow/table.py +146 -51
- kumoai/experimental/rfm/backend/sqlite/sampler.py +127 -78
- kumoai/experimental/rfm/backend/sqlite/table.py +94 -47
- kumoai/experimental/rfm/base/__init__.py +6 -7
- kumoai/experimental/rfm/base/column.py +97 -5
- kumoai/experimental/rfm/base/expression.py +44 -0
- kumoai/experimental/rfm/base/sampler.py +5 -17
- kumoai/experimental/rfm/base/source.py +1 -1
- kumoai/experimental/rfm/base/sql_sampler.py +68 -9
- kumoai/experimental/rfm/base/table.py +284 -120
- kumoai/experimental/rfm/graph.py +139 -86
- kumoai/experimental/rfm/infer/__init__.py +6 -4
- kumoai/experimental/rfm/infer/dtype.py +6 -1
- kumoai/experimental/rfm/infer/multicategorical.py +1 -1
- kumoai/experimental/rfm/infer/stype.py +35 -0
- kumoai/experimental/rfm/relbench.py +76 -0
- kumoai/experimental/rfm/rfm.py +4 -20
- kumoai/trainer/distilled_trainer.py +175 -0
- kumoai/utils/display.py +51 -0
- {kumoai-2.14.0.dev202512181731.dist-info → kumoai-2.14.0.dev202512301731.dist-info}/METADATA +1 -1
- {kumoai-2.14.0.dev202512181731.dist-info → kumoai-2.14.0.dev202512301731.dist-info}/RECORD +33 -30
- kumoai/experimental/rfm/base/column_expression.py +0 -16
- kumoai/experimental/rfm/base/sql_table.py +0 -113
- {kumoai-2.14.0.dev202512181731.dist-info → kumoai-2.14.0.dev202512301731.dist-info}/WHEEL +0 -0
- {kumoai-2.14.0.dev202512181731.dist-info → kumoai-2.14.0.dev202512301731.dist-info}/licenses/LICENSE +0 -0
- {kumoai-2.14.0.dev202512181731.dist-info → kumoai-2.14.0.dev202512301731.dist-info}/top_level.txt +0 -0
kumoai/__init__.py
CHANGED
|
@@ -1,3 +1,4 @@
|
|
|
1
|
+
import warnings
|
|
1
2
|
import os
|
|
2
3
|
import sys
|
|
3
4
|
import threading
|
|
@@ -68,9 +69,8 @@ class GlobalState(metaclass=Singleton):
|
|
|
68
69
|
if self._url is None or (self._api_key is None
|
|
69
70
|
and self._spcs_token is None
|
|
70
71
|
and self._snowpark_session is None):
|
|
71
|
-
raise ValueError(
|
|
72
|
-
|
|
73
|
-
"your client before proceeding.")
|
|
72
|
+
raise ValueError("Client creation or authentication failed. "
|
|
73
|
+
"Please re-create your client before proceeding.")
|
|
74
74
|
|
|
75
75
|
if hasattr(self.thread_local, '_client'):
|
|
76
76
|
# Set the spcs token in the client to ensure it has the latest.
|
|
@@ -123,10 +123,9 @@ def init(
|
|
|
123
123
|
""" # noqa
|
|
124
124
|
# Avoid mutations to the global state after it is set:
|
|
125
125
|
if global_state.initialized:
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
"session.")
|
|
126
|
+
warnings.warn("Kumo SDK already initialized. To re-initialize the "
|
|
127
|
+
"SDK, please start a new interpreter. No changes will "
|
|
128
|
+
"be made to the current session.")
|
|
130
129
|
return
|
|
131
130
|
|
|
132
131
|
set_log_level(os.getenv(_ENV_KUMO_LOG, log_level))
|
|
@@ -138,15 +137,15 @@ def init(
|
|
|
138
137
|
if snowflake_application:
|
|
139
138
|
if url is not None:
|
|
140
139
|
raise ValueError(
|
|
141
|
-
"
|
|
142
|
-
"are specified. If running from a
|
|
143
|
-
"only snowflake_application.")
|
|
140
|
+
"Kumo SDK initialization failed. Both 'snowflake_application' "
|
|
141
|
+
"and 'url' are specified. If running from a Snowflake "
|
|
142
|
+
"notebook, specify only 'snowflake_application'.")
|
|
144
143
|
snowpark_session = _get_active_session()
|
|
145
144
|
if not snowpark_session:
|
|
146
145
|
raise ValueError(
|
|
147
|
-
"
|
|
148
|
-
"without an active
|
|
149
|
-
"a
|
|
146
|
+
"Kumo SDK initialization failed. 'snowflake_application' is "
|
|
147
|
+
"specified without an active Snowpark session. If running "
|
|
148
|
+
"outside a Snowflake notebook, specify a URL and credentials.")
|
|
150
149
|
description = snowpark_session.sql(
|
|
151
150
|
f"DESCRIBE SERVICE {snowflake_application}."
|
|
152
151
|
"USER_SCHEMA.KUMO_SERVICE").collect()[0]
|
|
@@ -155,14 +154,14 @@ def init(
|
|
|
155
154
|
if api_key is None and not snowflake_application:
|
|
156
155
|
if snowflake_credentials is None:
|
|
157
156
|
raise ValueError(
|
|
158
|
-
"
|
|
159
|
-
"credentials provided. Please either set the
|
|
160
|
-
"or explicitly call `kumoai.init(...)`.")
|
|
157
|
+
"Kumo SDK initialization failed. Neither an API key nor "
|
|
158
|
+
"Snowflake credentials provided. Please either set the "
|
|
159
|
+
"'KUMO_API_KEY' or explicitly call `kumoai.init(...)`.")
|
|
161
160
|
if (set(snowflake_credentials.keys())
|
|
162
161
|
!= {'user', 'password', 'account'}):
|
|
163
162
|
raise ValueError(
|
|
164
|
-
f"Provided credentials should be a dictionary with
|
|
165
|
-
f"'user', 'password', and 'account'. Only "
|
|
163
|
+
f"Provided Snowflake credentials should be a dictionary with "
|
|
164
|
+
f"keys 'user', 'password', and 'account'. Only "
|
|
166
165
|
f"{set(snowflake_credentials.keys())} were provided.")
|
|
167
166
|
|
|
168
167
|
# Get or infer URL:
|
|
@@ -173,10 +172,10 @@ def init(
|
|
|
173
172
|
except KeyError:
|
|
174
173
|
pass
|
|
175
174
|
if url is None:
|
|
176
|
-
raise ValueError(
|
|
177
|
-
|
|
178
|
-
|
|
179
|
-
|
|
175
|
+
raise ValueError("Kumo SDK initialization failed since no endpoint "
|
|
176
|
+
"URL was provided. Please either set the "
|
|
177
|
+
"'KUMO_API_ENDPOINT' environment variable or "
|
|
178
|
+
"explicitly call `kumoai.init(...)`.")
|
|
180
179
|
|
|
181
180
|
# Assign global state after verification that client can be created and
|
|
182
181
|
# authenticated successfully:
|
|
@@ -198,10 +197,8 @@ def init(
|
|
|
198
197
|
logger = logging.getLogger('kumoai')
|
|
199
198
|
log_level = logging.getLevelName(logger.getEffectiveLevel())
|
|
200
199
|
|
|
201
|
-
logger.info(
|
|
202
|
-
|
|
203
|
-
f"against deployment {url}, with "
|
|
204
|
-
f"log level {log_level}.")
|
|
200
|
+
logger.info(f"Initialized Kumo SDK v{__version__} against deployment "
|
|
201
|
+
f"'{url}'")
|
|
205
202
|
|
|
206
203
|
|
|
207
204
|
def set_log_level(level: str) -> None:
|
kumoai/_version.py
CHANGED
|
@@ -1 +1 @@
|
|
|
1
|
-
__version__ = '2.14.0.dev202512181731'
|
|
1
|
+
__version__ = '2.14.0.dev202512301731'
|
kumoai/client/client.py
CHANGED
|
@@ -13,6 +13,7 @@ if TYPE_CHECKING:
|
|
|
13
13
|
ArtifactExportJobAPI,
|
|
14
14
|
BaselineJobAPI,
|
|
15
15
|
BatchPredictionJobAPI,
|
|
16
|
+
DistillationJobAPI,
|
|
16
17
|
GeneratePredictionTableJobAPI,
|
|
17
18
|
GenerateTrainTableJobAPI,
|
|
18
19
|
LLMJobAPI,
|
|
@@ -132,6 +133,11 @@ class KumoClient:
|
|
|
132
133
|
from kumoai.client.jobs import TrainingJobAPI
|
|
133
134
|
return TrainingJobAPI(self)
|
|
134
135
|
|
|
136
|
+
@property
|
|
137
|
+
def distillation_job_api(self) -> 'DistillationJobAPI':
|
|
138
|
+
from kumoai.client.jobs import DistillationJobAPI
|
|
139
|
+
return DistillationJobAPI(self)
|
|
140
|
+
|
|
135
141
|
@property
|
|
136
142
|
def batch_prediction_job_api(self) -> 'BatchPredictionJobAPI':
|
|
137
143
|
from kumoai.client.jobs import BatchPredictionJobAPI
|
kumoai/client/jobs.py
CHANGED
|
@@ -22,6 +22,8 @@ from kumoapi.jobs import (
|
|
|
22
22
|
BatchPredictionRequest,
|
|
23
23
|
CancelBatchPredictionJobResponse,
|
|
24
24
|
CancelTrainingJobResponse,
|
|
25
|
+
DistillationJobRequest,
|
|
26
|
+
DistillationJobResource,
|
|
25
27
|
ErrorDetails,
|
|
26
28
|
GeneratePredictionTableJobResource,
|
|
27
29
|
GeneratePredictionTableRequest,
|
|
@@ -171,6 +173,28 @@ class TrainingJobAPI(CommonJobAPI[TrainingJobRequest, TrainingJobResource]):
|
|
|
171
173
|
return resource.config
|
|
172
174
|
|
|
173
175
|
|
|
176
|
+
class DistillationJobAPI(CommonJobAPI[DistillationJobRequest,
|
|
177
|
+
DistillationJobResource]):
|
|
178
|
+
r"""Typed API definition for the distillation job resource."""
|
|
179
|
+
def __init__(self, client: KumoClient) -> None:
|
|
180
|
+
super().__init__(client, '/training_jobs/distilled_training_job',
|
|
181
|
+
DistillationJobResource)
|
|
182
|
+
|
|
183
|
+
def get_config(self, job_id: str) -> DistillationJobRequest:
|
|
184
|
+
raise NotImplementedError(
|
|
185
|
+
"Getting the configuration for a distillation job is "
|
|
186
|
+
"not implemented yet.")
|
|
187
|
+
|
|
188
|
+
def get_progress(self, id: str) -> AutoTrainerProgress:
|
|
189
|
+
raise NotImplementedError(
|
|
190
|
+
"Getting the progress for a distillation job is not "
|
|
191
|
+
"implemented yet.")
|
|
192
|
+
|
|
193
|
+
def cancel(self, id: str) -> CancelTrainingJobResponse:
|
|
194
|
+
raise NotImplementedError(
|
|
195
|
+
"Cancelling a distillation job is not implemented yet.")
|
|
196
|
+
|
|
197
|
+
|
|
174
198
|
class BatchPredictionJobAPI(CommonJobAPI[BatchPredictionRequest,
|
|
175
199
|
BatchPredictionJobResource]):
|
|
176
200
|
r"""Typed API definition for the prediction job resource."""
|
|
@@ -78,9 +78,9 @@ def _get_snowflake_url(snowflake_application: str) -> str:
|
|
|
78
78
|
snowpark_session = _get_active_session()
|
|
79
79
|
if not snowpark_session:
|
|
80
80
|
raise ValueError(
|
|
81
|
-
"
|
|
82
|
-
"without an active
|
|
83
|
-
"a
|
|
81
|
+
"KumoRFM initialization failed. 'snowflake_application' is "
|
|
82
|
+
"specified without an active Snowpark session. If running outside "
|
|
83
|
+
"a Snowflake notebook, specify a URL and credentials.")
|
|
84
84
|
with snowpark_session.connection.cursor() as cur:
|
|
85
85
|
cur.execute(
|
|
86
86
|
f"DESCRIBE SERVICE {snowflake_application}.user_schema.rfm_service"
|
|
@@ -103,6 +103,9 @@ class RfmGlobalState:
|
|
|
103
103
|
|
|
104
104
|
@property
|
|
105
105
|
def client(self) -> KumoClient:
|
|
106
|
+
if self._backend == InferenceBackend.UNKNOWN:
|
|
107
|
+
raise RuntimeError("KumoRFM is not yet initialized")
|
|
108
|
+
|
|
106
109
|
if self._backend == InferenceBackend.REST:
|
|
107
110
|
return kumoai.global_state.client
|
|
108
111
|
|
|
@@ -146,18 +149,19 @@ def init(
|
|
|
146
149
|
with global_state._lock:
|
|
147
150
|
if global_state._initialized:
|
|
148
151
|
if url != global_state._url:
|
|
149
|
-
raise
|
|
150
|
-
"
|
|
151
|
-
"URL. Re-initialization with a different URL is not "
|
|
152
|
+
raise RuntimeError(
|
|
153
|
+
"KumoRFM has already been initialized with a different "
|
|
154
|
+
"API URL. Re-initialization with a different URL is not "
|
|
152
155
|
"supported.")
|
|
153
156
|
return
|
|
154
157
|
|
|
155
158
|
if snowflake_application:
|
|
156
159
|
if url is not None:
|
|
157
160
|
raise ValueError(
|
|
158
|
-
"
|
|
159
|
-
"url are specified. If
|
|
160
|
-
"specify only
|
|
161
|
+
"KumoRFM initialization failed. Both "
|
|
162
|
+
"'snowflake_application' and 'url' are specified. If "
|
|
163
|
+
"running from a Snowflake notebook, specify only "
|
|
164
|
+
"'snowflake_application'.")
|
|
161
165
|
url = _get_snowflake_url(snowflake_application)
|
|
162
166
|
api_key = "test:DISABLED"
|
|
163
167
|
|
|
@@ -166,32 +170,28 @@ def init(
|
|
|
166
170
|
|
|
167
171
|
backend, region, endpoint_name = _detect_backend(url)
|
|
168
172
|
if backend == InferenceBackend.REST:
|
|
169
|
-
|
|
170
|
-
|
|
171
|
-
|
|
172
|
-
|
|
173
|
-
|
|
174
|
-
|
|
175
|
-
|
|
176
|
-
kumoai.init(url=url, api_key=api_key,
|
|
177
|
-
snowflake_credentials=snowflake_credentials,
|
|
178
|
-
snowflake_application=snowflake_application,
|
|
179
|
-
log_level=log_level)
|
|
173
|
+
kumoai.init(
|
|
174
|
+
url=url,
|
|
175
|
+
api_key=api_key,
|
|
176
|
+
snowflake_credentials=snowflake_credentials,
|
|
177
|
+
snowflake_application=snowflake_application,
|
|
178
|
+
log_level=log_level,
|
|
179
|
+
)
|
|
180
180
|
elif backend == InferenceBackend.AWS_SAGEMAKER:
|
|
181
181
|
assert region
|
|
182
182
|
assert endpoint_name
|
|
183
183
|
KumoClient_SageMakerAdapter(region, endpoint_name).authenticate()
|
|
184
|
+
logger.info("KumoRFM initialized in AWS SageMaker")
|
|
184
185
|
else:
|
|
185
186
|
assert backend == InferenceBackend.LOCAL_SAGEMAKER
|
|
186
187
|
KumoClient_SageMakerProxy_Local(url).authenticate()
|
|
188
|
+
logger.info(f"KumoRFM initialized in local SageMaker at '{url}'")
|
|
187
189
|
|
|
188
190
|
global_state._url = url
|
|
189
191
|
global_state._backend = backend
|
|
190
192
|
global_state._region = region
|
|
191
193
|
global_state._endpoint_name = endpoint_name
|
|
192
194
|
global_state._initialized = True
|
|
193
|
-
logger.info("Kumo RFM initialized with backend: %s, url: %s", backend,
|
|
194
|
-
url)
|
|
195
195
|
|
|
196
196
|
|
|
197
197
|
LocalGraph = Graph # NOTE Backward compatibility - do not use anymore.
|
|
@@ -1,12 +1,11 @@
|
|
|
1
|
-
import warnings
|
|
2
1
|
from typing import TYPE_CHECKING
|
|
3
2
|
|
|
4
3
|
import numpy as np
|
|
5
4
|
import pandas as pd
|
|
6
5
|
from kumoapi.rfm.context import Subgraph
|
|
7
|
-
from kumoapi.typing import Stype
|
|
8
6
|
|
|
9
7
|
from kumoai.experimental.rfm.backend.local import LocalTable
|
|
8
|
+
from kumoai.experimental.rfm.base import Table
|
|
10
9
|
from kumoai.utils import ProgressLogger
|
|
11
10
|
|
|
12
11
|
try:
|
|
@@ -106,26 +105,20 @@ class LocalGraphStore:
|
|
|
106
105
|
df_dict: dict[str, pd.DataFrame] = {}
|
|
107
106
|
for table_name, table in graph.tables.items():
|
|
108
107
|
assert isinstance(table, LocalTable)
|
|
109
|
-
|
|
110
|
-
|
|
108
|
+
df_dict[table_name] = Table._sanitize(
|
|
109
|
+
df=table._data.copy(deep=False).reset_index(drop=True),
|
|
110
|
+
dtype_dict={
|
|
111
|
+
column.name: column.dtype
|
|
112
|
+
for column in table.columns
|
|
113
|
+
},
|
|
114
|
+
stype_dict={
|
|
115
|
+
column.name: column.stype
|
|
116
|
+
for column in table.columns
|
|
117
|
+
},
|
|
118
|
+
)
|
|
111
119
|
|
|
112
120
|
mask_dict: dict[str, np.ndarray] = {}
|
|
113
121
|
for table in graph.tables.values():
|
|
114
|
-
for col in table.columns:
|
|
115
|
-
if col.stype == Stype.timestamp:
|
|
116
|
-
ser = df_dict[table.name][col.name]
|
|
117
|
-
if not pd.api.types.is_datetime64_any_dtype(ser):
|
|
118
|
-
with warnings.catch_warnings():
|
|
119
|
-
warnings.filterwarnings(
|
|
120
|
-
'ignore',
|
|
121
|
-
message='Could not infer format',
|
|
122
|
-
)
|
|
123
|
-
ser = pd.to_datetime(ser, errors='coerce')
|
|
124
|
-
df_dict[table.name][col.name] = ser
|
|
125
|
-
if isinstance(ser.dtype, pd.DatetimeTZDtype):
|
|
126
|
-
ser = ser.dt.tz_localize(None)
|
|
127
|
-
df_dict[table.name][col.name] = ser
|
|
128
|
-
|
|
129
122
|
mask: np.ndarray | None = None
|
|
130
123
|
if table._time_column is not None:
|
|
131
124
|
ser = df_dict[table.name][table._time_column]
|
|
@@ -188,8 +181,6 @@ class LocalGraphStore:
|
|
|
188
181
|
continue
|
|
189
182
|
|
|
190
183
|
time = self.df_dict[table.name][table._time_column]
|
|
191
|
-
if time.dtype != 'datetime64[ns]':
|
|
192
|
-
time = time.astype('datetime64[ns]')
|
|
193
184
|
time_dict[table.name] = time.astype(int).to_numpy() // 1000**3
|
|
194
185
|
|
|
195
186
|
if table.name in self.mask_dict.keys():
|
|
@@ -219,9 +219,6 @@ class LocalSampler(Sampler):
|
|
|
219
219
|
for edge_type in set(self.edge_types) - set(time_offset_dict.keys()):
|
|
220
220
|
num_neighbors_dict['__'.join(edge_type)] = [0] * num_hops
|
|
221
221
|
|
|
222
|
-
if anchor_time.dtype != 'datetime64[ns]':
|
|
223
|
-
anchor_time = anchor_time.astype('datetime64')
|
|
224
|
-
|
|
225
222
|
count = 0
|
|
226
223
|
ys: list[pd.Series] = []
|
|
227
224
|
mask = np.full(len(index), False, dtype=bool)
|
|
@@ -1,11 +1,15 @@
|
|
|
1
|
-
import warnings
|
|
2
|
-
from typing import cast
|
|
1
|
+
from typing import Sequence, cast
|
|
3
2
|
|
|
4
3
|
import pandas as pd
|
|
5
4
|
from kumoapi.model_plan import MissingType
|
|
6
5
|
|
|
7
|
-
from kumoai.experimental.rfm.base import
|
|
8
|
-
|
|
6
|
+
from kumoai.experimental.rfm.base import (
|
|
7
|
+
ColumnSpec,
|
|
8
|
+
DataBackend,
|
|
9
|
+
SourceColumn,
|
|
10
|
+
SourceForeignKey,
|
|
11
|
+
Table,
|
|
12
|
+
)
|
|
9
13
|
|
|
10
14
|
|
|
11
15
|
class LocalTable(Table):
|
|
@@ -71,7 +75,6 @@ class LocalTable(Table):
|
|
|
71
75
|
|
|
72
76
|
super().__init__(
|
|
73
77
|
name=name,
|
|
74
|
-
columns=list(df.columns),
|
|
75
78
|
primary_key=primary_key,
|
|
76
79
|
time_column=time_column,
|
|
77
80
|
end_time_column=end_time_column,
|
|
@@ -82,31 +85,29 @@ class LocalTable(Table):
|
|
|
82
85
|
return cast(DataBackend, DataBackend.LOCAL)
|
|
83
86
|
|
|
84
87
|
def _get_source_columns(self) -> list[SourceColumn]:
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
dtype = infer_dtype(ser)
|
|
90
|
-
except Exception:
|
|
91
|
-
warnings.warn(f"Data type inference for column '{column}' in "
|
|
92
|
-
f"table '{self.name}' failed. Consider changing "
|
|
93
|
-
f"the data type of the column to use it within "
|
|
94
|
-
f"this table.")
|
|
95
|
-
continue
|
|
96
|
-
|
|
97
|
-
source_column = SourceColumn(
|
|
98
|
-
name=column,
|
|
99
|
-
dtype=dtype,
|
|
88
|
+
return [
|
|
89
|
+
SourceColumn(
|
|
90
|
+
name=column_name,
|
|
91
|
+
dtype=None,
|
|
100
92
|
is_primary_key=False,
|
|
101
93
|
is_unique_key=False,
|
|
102
94
|
is_nullable=True,
|
|
103
|
-
)
|
|
104
|
-
|
|
95
|
+
) for column_name in self._data.columns
|
|
96
|
+
]
|
|
105
97
|
|
|
106
|
-
|
|
98
|
+
def _get_source_foreign_keys(self) -> list[SourceForeignKey]:
|
|
99
|
+
return []
|
|
107
100
|
|
|
108
|
-
def
|
|
101
|
+
def _get_source_sample_df(self) -> pd.DataFrame:
|
|
109
102
|
return self._data
|
|
110
103
|
|
|
104
|
+
def _get_expr_sample_df(
|
|
105
|
+
self,
|
|
106
|
+
columns: Sequence[ColumnSpec],
|
|
107
|
+
) -> pd.DataFrame:
|
|
108
|
+
raise RuntimeError(f"Column expressions are not supported in "
|
|
109
|
+
f"'{self.__class__.__name__}'. Please apply your "
|
|
110
|
+
f"expressions on the `pd.DataFrame` directly.")
|
|
111
|
+
|
|
111
112
|
def _get_num_rows(self) -> int | None:
|
|
112
113
|
return len(self._data)
|