fundamental-client 0.2.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fundamental/__init__.py +34 -0
- fundamental/clients/__init__.py +7 -0
- fundamental/clients/base.py +37 -0
- fundamental/clients/ec2.py +37 -0
- fundamental/clients/fundamental.py +20 -0
- fundamental/config.py +138 -0
- fundamental/constants.py +41 -0
- fundamental/deprecated.py +43 -0
- fundamental/estimator/__init__.py +16 -0
- fundamental/estimator/base.py +263 -0
- fundamental/estimator/classification.py +46 -0
- fundamental/estimator/nexus_estimator.py +120 -0
- fundamental/estimator/regression.py +22 -0
- fundamental/exceptions.py +78 -0
- fundamental/models/__init__.py +4 -0
- fundamental/models/generated.py +431 -0
- fundamental/services/__init__.py +25 -0
- fundamental/services/feature_importance.py +172 -0
- fundamental/services/inference.py +283 -0
- fundamental/services/models.py +186 -0
- fundamental/utils/__init__.py +0 -0
- fundamental/utils/data.py +437 -0
- fundamental/utils/http.py +294 -0
- fundamental/utils/polling.py +97 -0
- fundamental/utils/safetensors_deserialize.py +98 -0
- fundamental_client-0.2.3.dist-info/METADATA +241 -0
- fundamental_client-0.2.3.dist-info/RECORD +29 -0
- fundamental_client-0.2.3.dist-info/WHEEL +4 -0
- fundamental_client-0.2.3.dist-info/licenses/LICENSE +201 -0
|
@@ -0,0 +1,172 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Feature importance services for NEXUS client.
|
|
3
|
+
Handles feature importance computation with the NEXUS service.
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
from typing import Optional
|
|
7
|
+
|
|
8
|
+
import numpy as np
|
|
9
|
+
|
|
10
|
+
from fundamental.clients.base import BaseClient
|
|
11
|
+
from fundamental.constants import (
|
|
12
|
+
DEFAULT_POLLING_INTERVAL_SECONDS,
|
|
13
|
+
DEFAULT_SUBMIT_REQUEST_TIMEOUT_SECONDS,
|
|
14
|
+
)
|
|
15
|
+
from fundamental.exceptions import ServerError
|
|
16
|
+
from fundamental.models import TaskStatus
|
|
17
|
+
from fundamental.utils.data import (
|
|
18
|
+
XType,
|
|
19
|
+
api_call,
|
|
20
|
+
create_feature_importance_task_metadata,
|
|
21
|
+
download_result_from_url,
|
|
22
|
+
serialize_df_to_parquet_bytes,
|
|
23
|
+
upload_feature_importance_data,
|
|
24
|
+
)
|
|
25
|
+
from fundamental.utils.polling import wait_for_task_status
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def submit_feature_importance_task(
    X: XType,
    trained_model_id: str,
    client: BaseClient,
) -> str:
    """
    Queue a feature importance computation on the NEXUS service.

    Serializes the features to parquet, uploads them, and submits the task.
    Returns immediately without waiting for the computation to finish.

    Parameters
    ----------
    X : XType
        Input features for feature importance computation.
    trained_model_id : str
        The trained model ID.
    client : BaseClient
        The client instance.

    Returns
    -------
    str
        The task ID to use with poll_feature_importance_result.
    """
    features_bytes = serialize_df_to_parquet_bytes(data=X)

    # The metadata carries the request_id that ties the upload to the task.
    metadata = create_feature_importance_task_metadata(
        trained_model_id=trained_model_id,
        x_size=len(features_bytes),
        client=client,
    )

    upload_feature_importance_data(
        X_serialized=features_bytes,
        metadata=metadata,
        trained_model_id=trained_model_id,
        client=client,
    )

    payload = {
        "trained_model_id": trained_model_id,
        "request_id": metadata.request_id,
        "timeout": client.config.feature_importance_timeout,
    }

    response = api_call(
        method="POST",
        full_url=client.config.get_full_feature_importance_url(),
        client=client,
        json=payload,
        timeout=DEFAULT_SUBMIT_REQUEST_TIMEOUT_SECONDS,
    )
    submitted_task_id: str = response.json()["task_id"]
    return submitted_task_id
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
def remote_get_feature_importance(
    X: XType,
    trained_model_id: str,
    client: BaseClient,
) -> np.ndarray:
    """
    Compute feature importance for a trained model, blocking until done.

    Submits the task, polls until the service reports completion, then
    downloads and returns the result.

    Parameters
    ----------
    X : XType
        Input features for feature importance computation.
    trained_model_id : str
        The trained model ID.
    client : BaseClient
        The client instance.

    Returns
    -------
    np.ndarray
        Feature importance values.
    """
    task_id = submit_feature_importance_task(
        X=X,
        trained_model_id=trained_model_id,
        client=client,
    )

    status = wait_for_task_status(
        client=client,
        status_url=f"{client.config.get_full_feature_importance_status_url()}/{trained_model_id}/{task_id}",
        timeout=client.config.feature_importance_timeout,
        polling_interval=DEFAULT_POLLING_INTERVAL_SECONDS,
    )

    # A completed task with no result payload means the server failed.
    if not status.result:
        raise ServerError("Request failed: Internal Server Error")

    result_payload = download_result_from_url(
        download_url=status.result.download_url,
        client=client,
        timeout=client.config.download_feature_importance_result_timeout,
    )

    return np.array(result_payload)
|
|
130
|
+
|
|
131
|
+
|
|
132
|
+
def poll_feature_importance_result(
    task_id: str,
    trained_model_id: str,
    client: BaseClient,
) -> Optional[np.ndarray]:
    """
    Perform a single, non-blocking status check for a feature importance task.

    Parameters
    ----------
    task_id : str
        The task ID returned by submit_feature_importance_task.
    trained_model_id : str
        The trained model ID.
    client : BaseClient
        The client instance.

    Returns
    -------
    Optional[np.ndarray]
        Feature importance values if completed, None if still in progress.
    """
    # wait_for_completion=False makes this a one-shot status query.
    status = wait_for_task_status(
        client=client,
        status_url=f"{client.config.get_full_feature_importance_status_url()}/{trained_model_id}/{task_id}",
        timeout=client.config.feature_importance_timeout,
        polling_interval=DEFAULT_POLLING_INTERVAL_SECONDS,
        wait_for_completion=False,
    )

    if status.status != TaskStatus.SUCCESS:
        return None

    # Success without a result payload indicates a server-side failure.
    if not status.result:
        raise ServerError("Request failed: Internal Server Error")

    result_payload = download_result_from_url(
        download_url=status.result.download_url,
        client=client,
        timeout=client.config.download_feature_importance_result_timeout,
    )
    return np.array(result_payload)
|
|
@@ -0,0 +1,283 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Model inference services for NEXUS client.
|
|
3
|
+
Handles fit and predict operations with the NEXUS service.
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
import logging
from typing import Literal, Optional

import numpy as np
from pydantic import BaseModel

from fundamental.clients.base import BaseClient
from fundamental.constants import (
    DEFAULT_POLLING_INTERVAL_SECONDS,
    DEFAULT_PREDICT_POLLING_REQUESTS_WITHOUT_DELAY,
    DEFAULT_SUBMIT_REQUEST_TIMEOUT_SECONDS,
)
from fundamental.exceptions import ServerError
from fundamental.models import TaskStatus
from fundamental.services.models import ModelsService
from fundamental.utils.data import (
    XType,
    YType,
    api_call,
    create_fit_task_metadata,
    create_predict_task_metadata,
    download_result_from_url,
    serialize_df_to_parquet_bytes,
    upload_fit_data,
    upload_predict_data,
)
from fundamental.utils.polling import wait_for_task_status
|
|
32
|
+
|
|
33
|
+
logger = logging.getLogger(__name__)
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
class RemoteFitResponse(BaseModel):
    """Result of a completed fit operation."""

    # ID identifying the newly trained model on the service.
    trained_model_id: str
    # Deserialized estimator state returned by the service after training.
    estimator_fields: dict
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
class SubmitFitTaskResult(BaseModel):
    """Result of submitting a fit task."""

    # ID of the asynchronous fit task (use with poll_fit_result).
    task_id: str
    # ID assigned to the model being trained.
    trained_model_id: str
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
def submit_fit_task(
    X: XType,
    y: YType,
    task: Literal["classification", "regression"],
    mode: Literal["quality", "speed"],
    client: BaseClient,
) -> SubmitFitTaskResult:
    """
    Queue a fit task on the NEXUS service and return immediately.

    Training data is serialized to parquet, uploaded, and the fit job is
    submitted; completion is not awaited here.

    Parameters
    ----------
    X : XType
        Training features.
    y : YType
        Training targets.
    task : {"classification", "regression"}
        Task type.
    mode : {"quality", "speed"}
        Model fit mode.
    client : BaseClient
        The client instance.

    Returns
    -------
    SubmitFitTaskResult
        Result containing task_id and trained_model_id.
    """
    features_bytes = serialize_df_to_parquet_bytes(data=X)
    targets_bytes = serialize_df_to_parquet_bytes(data=y)

    # The metadata allocates the trained_model_id that ties upload and fit.
    metadata = create_fit_task_metadata(
        x_train_size=len(features_bytes),
        y_train_size=len(targets_bytes),
        client=client,
    )

    upload_fit_data(
        X_serialized=features_bytes,
        y_serialized=targets_bytes,
        metadata=metadata,
        client=client,
    )

    payload = {
        "task": task,
        "mode": mode,
        "trained_model_id": metadata.trained_model_id,
        "timeout": client.config.fit_timeout,
    }

    response = api_call(
        method="POST",
        full_url=client.config.get_full_fit_url(),
        client=client,
        json=payload,
        timeout=DEFAULT_SUBMIT_REQUEST_TIMEOUT_SECONDS,
    )
    submitted_task_id: str = response.json()["task_id"]

    return SubmitFitTaskResult(
        task_id=submitted_task_id,
        trained_model_id=metadata.trained_model_id,
    )
|
|
113
|
+
|
|
114
|
+
|
|
115
|
+
def poll_fit_result(
    task_id: str,
    trained_model_id: str,
    client: BaseClient,
) -> Optional[RemoteFitResponse]:
    """
    Perform a single, non-blocking status check for a fit task.

    Parameters
    ----------
    task_id : str
        The task ID returned by submit_fit_task.
    trained_model_id : str
        The trained model ID from the submit result.
    client : BaseClient
        The client instance.

    Returns
    -------
    Optional[RemoteFitResponse]
        RemoteFitResponse with trained_model_id and estimator_fields if completed,
        None if still in progress.
    """
    # wait_for_completion=False makes this a one-shot status query.
    status = wait_for_task_status(
        client=client,
        status_url=f"{client.config.get_full_fit_status_url()}/{task_id}",
        timeout=client.config.fit_timeout,
        polling_interval=DEFAULT_POLLING_INTERVAL_SECONDS,
        wait_for_completion=False,
    )

    if status.status != TaskStatus.SUCCESS:
        return None

    logger.debug("Loading trained model metadata")
    loaded = ModelsService(client=client).load(trained_model_id=trained_model_id)

    return RemoteFitResponse(
        trained_model_id=trained_model_id,
        estimator_fields=loaded.estimator_fields,
    )
|
|
157
|
+
|
|
158
|
+
|
|
159
|
+
def remote_fit(
    X: XType,
    y: YType,
    task: Literal["classification", "regression"],
    mode: Literal["quality", "speed"],
    client: BaseClient,
) -> RemoteFitResponse:
    """
    Fit a model on the NEXUS service, blocking until training completes.

    Submits the fit task, polls its status until done, then loads the
    resulting model metadata.

    Parameters
    ----------
    X : XType
        Training features.
    y : YType
        Training targets.
    task : {"classification", "regression"}
        Task type.
    mode : {"quality", "speed"}
        Model fit mode.
    client : BaseClient
        The client instance.

    Returns
    -------
    RemoteFitResponse
        Service response with trained_model_id and estimator_fields.
    """
    submission = submit_fit_task(
        X=X,
        y=y,
        task=task,
        mode=mode,
        client=client,
    )

    # Block until the fit task reaches a terminal state.
    wait_for_task_status(
        client=client,
        status_url=f"{client.config.get_full_fit_status_url()}/{submission.task_id}",
        timeout=client.config.fit_timeout,
        polling_interval=DEFAULT_POLLING_INTERVAL_SECONDS,
    )

    logger.debug("Loading trained model metadata")
    loaded = ModelsService(client=client).load(
        trained_model_id=submission.trained_model_id
    )

    return RemoteFitResponse(
        trained_model_id=submission.trained_model_id,
        estimator_fields=loaded.estimator_fields,
    )
|
|
212
|
+
|
|
213
|
+
|
|
214
|
+
def remote_predict(
    X: XType,
    output_type: Literal["preds", "probas"],
    trained_model_id: str,
    client: BaseClient,
) -> np.ndarray:
    """
    Make predictions using a trained model identified by trained_model_id.

    Uploads the serialized features, submits the prediction task, polls
    until completion, and downloads the result.

    Parameters
    ----------
    X : XType
        Input features for prediction.
    output_type : {"preds", "probas"}
        Output type.
    trained_model_id : str
        The model ID generated by the fit operation.
    client : BaseClient
        The client instance.

    Returns
    -------
    np.ndarray
        Prediction results.

    Raises
    ------
    ServerError
        If the task completes without a downloadable result.
    """
    X_serialized = serialize_df_to_parquet_bytes(data=X)

    # The metadata carries the request_id that ties the upload to the task.
    metadata = create_predict_task_metadata(
        trained_model_id=trained_model_id,
        x_test_size=len(X_serialized),
        client=client,
    )

    upload_predict_data(
        X_serialized=X_serialized,
        metadata=metadata,
        trained_model_id=trained_model_id,
        client=client,
    )

    json_data = {
        "output_type": output_type,
        "trained_model_id": trained_model_id,
        "request_id": metadata.request_id,
        "timeout": client.config.predict_timeout,
    }

    response = api_call(
        method="POST",
        full_url=client.config.get_full_predict_url(),
        client=client,
        json=json_data,
        timeout=DEFAULT_SUBMIT_REQUEST_TIMEOUT_SECONDS,
    )
    data = response.json()
    task_id = data["task_id"]

    status_response = wait_for_task_status(
        client=client,
        status_url=f"{client.config.get_full_predict_status_url()}/{trained_model_id}/{task_id}",
        timeout=client.config.predict_timeout,
        polling_interval=DEFAULT_POLLING_INTERVAL_SECONDS,
        polling_requests_without_delay=DEFAULT_PREDICT_POLLING_REQUESTS_WITHOUT_DELAY,
    )

    # Consistent with the feature-importance flows: a completed task with no
    # result payload is a server-side failure, not an AttributeError.
    if not status_response.result:
        raise ServerError("Request failed: Internal Server Error")

    preds = download_result_from_url(
        download_url=status_response.result.download_url,
        client=client,
        timeout=client.config.download_prediction_result_timeout,
    )

    return np.array(preds)
|
|
@@ -0,0 +1,186 @@
|
|
|
1
|
+
"""Model management services for NEXUS client."""
|
|
2
|
+
|
|
3
|
+
import base64
|
|
4
|
+
import logging
|
|
5
|
+
from typing import Any, Dict
|
|
6
|
+
|
|
7
|
+
from pydantic import BaseModel
|
|
8
|
+
from typing_extensions import TypeAlias
|
|
9
|
+
|
|
10
|
+
from fundamental.clients.base import BaseClient
|
|
11
|
+
from fundamental.exceptions import ValidationError
|
|
12
|
+
from fundamental.models import (
|
|
13
|
+
DeleteTrainedModelResponse,
|
|
14
|
+
TrainedModelMetadata,
|
|
15
|
+
UpdateAttributesResponse,
|
|
16
|
+
)
|
|
17
|
+
from fundamental.utils.http import api_call
|
|
18
|
+
from fundamental.utils.safetensors_deserialize import load_estimator_fields_from_bytes
|
|
19
|
+
|
|
20
|
+
logger = logging.getLogger(__name__)
|
|
21
|
+
|
|
22
|
+
TrainedModelsListResponse: TypeAlias = list[str]
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class LoadedModelResponse(BaseModel):
    """Deserialized estimator state of a loaded trained model."""

    # Estimator fields decoded from the service's safetensors payload.
    estimator_fields: Dict[str, Any]
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
class ModelsService:
    """Service for model management operations."""

    def __init__(self, client: BaseClient) -> None:
        """Initialize the models service.

        Args:
            client: Client instance
        """
        self.client = client

    def _validate_model_id(self, model_id: str) -> None:
        """Raise ValidationError if model_id is empty or whitespace-only."""
        if not model_id or not model_id.strip():
            raise ValidationError("model_id cannot be empty. Please provide a valid model_id.")

    def list(self) -> TrainedModelsListResponse:
        """
        List all trained models.

        Returns
        -------
        TrainedModelsListResponse
            List of trained model IDs.
        """
        response = api_call(
            method="GET",
            full_url=self.client.config.get_full_model_management_url(),
            client=self.client,
        )

        res: list[str] = response.json()
        return res

    def delete(self, model_id: str) -> DeleteTrainedModelResponse:
        """
        Delete a specific trained model.

        Parameters
        ----------
        model_id : str
            The ID of the model to delete.

        Returns
        -------
        DeleteTrainedModelResponse
            Response from the deletion operation.

        Raises
        ------
        ValidationError
            If model_id is empty or invalid.
        """
        self._validate_model_id(model_id)

        response = api_call(
            method="DELETE",
            full_url=f"{self.client.config.get_full_model_management_url()}/{model_id}",
            client=self.client,
        )

        return DeleteTrainedModelResponse(**response.json())

    def get(self, model_id: str) -> TrainedModelMetadata:
        """
        Get information about a specific trained model.

        Parameters
        ----------
        model_id : str
            The ID of the model to retrieve.

        Returns
        -------
        TrainedModelMetadata
            Model information dictionary.

        Raises
        ------
        ValidationError
            If model_id is empty or invalid.
        """
        self._validate_model_id(model_id)

        # Read-only request, so retries are safe here (unlike delete/patch).
        response = api_call(
            method="GET",
            full_url=f"{self.client.config.get_full_model_management_url()}/{model_id}",
            client=self.client,
            max_retries=3,
        )

        return TrainedModelMetadata(**response.json())

    def set_attributes(
        self,
        model_id: str,
        attributes: dict[str, str],
    ) -> UpdateAttributesResponse:
        """
        Set attributes for a specific trained model.

        Parameters
        ----------
        model_id : str
            The ID of the model to update.
        attributes : dict[str, str]
            The attributes to set.

        Returns
        -------
        UpdateAttributesResponse
            Response containing the updated attributes.

        Raises
        ------
        ValidationError
            If model_id is empty or invalid.
        """
        self._validate_model_id(model_id)

        response = api_call(
            method="PATCH",
            full_url=f"{self.client.config.get_full_model_management_url()}/{model_id}/attributes",
            client=self.client,
            json={"attributes": attributes},
        )

        return UpdateAttributesResponse(**response.json())

    def load(
        self,
        trained_model_id: str,
    ) -> LoadedModelResponse:
        """
        Load a specific trained model.

        Parameters
        ----------
        trained_model_id : str
            The ID of the model to load.

        Returns
        -------
        LoadedModelResponse
            Model information dictionary.

        Raises
        ------
        ValidationError
            If trained_model_id is empty or invalid.
        """
        self._validate_model_id(trained_model_id)

        response = api_call(
            method="GET",
            full_url=f"{self.client.config.get_full_model_management_url()}/{trained_model_id}/load_model",
            client=self.client,
        )
        response_json = response.json()
        # validate=True rejects malformed base64 instead of silently
        # decoding garbage into the safetensors deserializer.
        raw_bytes = base64.b64decode(response_json["estimator_fields"], validate=True)
        estimator_fields = load_estimator_fields_from_bytes(raw_bytes)
        return LoadedModelResponse(
            estimator_fields=estimator_fields,
        )
|
|
File without changes
|