fundamental-client 0.2.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,172 @@
+ """
+ Feature importance services for NEXUS client.
+ Handles feature importance computation with the NEXUS service.
+ """
+
+ from typing import Optional
+
+ import numpy as np
+
+ from fundamental.clients.base import BaseClient
+ from fundamental.constants import (
+     DEFAULT_POLLING_INTERVAL_SECONDS,
+     DEFAULT_SUBMIT_REQUEST_TIMEOUT_SECONDS,
+ )
+ from fundamental.exceptions import ServerError
+ from fundamental.models import TaskStatus
+ from fundamental.utils.data import (
+     XType,
+     api_call,
+     create_feature_importance_task_metadata,
+     download_result_from_url,
+     serialize_df_to_parquet_bytes,
+     upload_feature_importance_data,
+ )
+ from fundamental.utils.polling import wait_for_task_status
+
+
+ def submit_feature_importance_task(
+     X: XType,
+     trained_model_id: str,
+     client: BaseClient,
+ ) -> str:
+     """
+     Submit a feature importance computation task without waiting for completion.
+
+     Parameters
+     ----------
+     X : XType
+         Input features for feature importance computation.
+     trained_model_id : str
+         The trained model ID.
+     client : BaseClient
+         The client instance.
+
+     Returns
+     -------
+     str
+         The task ID to use with poll_feature_importance_result.
+     """
+     X_serialized = serialize_df_to_parquet_bytes(data=X)
+
+     metadata = create_feature_importance_task_metadata(
+         trained_model_id=trained_model_id,
+         x_size=len(X_serialized),
+         client=client,
+     )
+
+     upload_feature_importance_data(
+         X_serialized=X_serialized,
+         metadata=metadata,
+         trained_model_id=trained_model_id,
+         client=client,
+     )
+
+     json_data = {
+         "trained_model_id": trained_model_id,
+         "request_id": metadata.request_id,
+         "timeout": client.config.feature_importance_timeout,
+     }
+
+     response = api_call(
+         method="POST",
+         full_url=client.config.get_full_feature_importance_url(),
+         client=client,
+         json=json_data,
+         timeout=DEFAULT_SUBMIT_REQUEST_TIMEOUT_SECONDS,
+     )
+     data = response.json()
+     task_id: str = data["task_id"]
+     return task_id
+
+
+ def remote_get_feature_importance(
+     X: XType,
+     trained_model_id: str,
+     client: BaseClient,
+ ) -> np.ndarray:
+     """
+     Get feature importance for a trained model.
+
+     Submits the task and waits for completion.
+
+     Parameters
+     ----------
+     X : XType
+         Input features for feature importance computation.
+     trained_model_id : str
+         The trained model ID.
+     client : BaseClient
+         The client instance.
+
+     Returns
+     -------
+     np.ndarray
+         Feature importance values.
+     """
+     task_id = submit_feature_importance_task(
+         X=X,
+         trained_model_id=trained_model_id,
+         client=client,
+     )
+
+     status_response = wait_for_task_status(
+         client=client,
+         status_url=f"{client.config.get_full_feature_importance_status_url()}/{trained_model_id}/{task_id}",
+         timeout=client.config.feature_importance_timeout,
+         polling_interval=DEFAULT_POLLING_INTERVAL_SECONDS,
+     )
+
+     if not status_response.result:
+         raise ServerError("Request failed: Internal Server Error")
+
+     downloaded_result = download_result_from_url(
+         download_url=status_response.result.download_url,
+         client=client,
+         timeout=client.config.download_feature_importance_result_timeout,
+     )
+
+     return np.array(downloaded_result)
+
+
+ def poll_feature_importance_result(
+     task_id: str,
+     trained_model_id: str,
+     client: BaseClient,
+ ) -> Optional[np.ndarray]:
+     """
+     Check the status of a feature importance task.
+
+     Parameters
+     ----------
+     task_id : str
+         The task ID returned by submit_feature_importance_task.
+     trained_model_id : str
+         The trained model ID.
+     client : BaseClient
+         The client instance.
+
+     Returns
+     -------
+     Optional[np.ndarray]
+         Feature importance values if completed, None if still in progress.
+     """
+     status_response = wait_for_task_status(
+         client=client,
+         status_url=f"{client.config.get_full_feature_importance_status_url()}/{trained_model_id}/{task_id}",
+         timeout=client.config.feature_importance_timeout,
+         polling_interval=DEFAULT_POLLING_INTERVAL_SECONDS,
+         wait_for_completion=False,
+     )
+
+     if status_response.status == TaskStatus.SUCCESS:
+         if not status_response.result:
+             raise ServerError("Request failed: Internal Server Error")
+         downloaded_result = download_result_from_url(
+             download_url=status_response.result.download_url,
+             client=client,
+             timeout=client.config.download_feature_importance_result_timeout,
+         )
+         return np.array(downloaded_result)
+
+     return None
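Taken together, the helpers above support both a blocking call (remote_get_feature_importance) and a submit-then-poll workflow. The sketch below is illustrative only: it assumes an already configured client instance and a pandas DataFrame of features, and the model ID is a placeholder rather than a real value from this package.

    import time

    import pandas as pd

    client = ...  # placeholder: a configured client instance (construction is not shown in this file)
    X = pd.DataFrame({"f1": [0.1, 0.2, 0.3], "f2": [1.0, 0.5, 0.0]})  # example features
    model_id = "trained-model-id"  # placeholder for an ID returned by a previous fit

    # Blocking: submits the task and waits for the result.
    importances = remote_get_feature_importance(X=X, trained_model_id=model_id, client=client)

    # Non-blocking: submit once, then poll until a result is available.
    task_id = submit_feature_importance_task(X=X, trained_model_id=model_id, client=client)
    result = None
    while result is None:
        result = poll_feature_importance_result(
            task_id=task_id, trained_model_id=model_id, client=client
        )
        if result is None:
            time.sleep(5)  # modest polling delay chosen for the example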
@@ -0,0 +1,283 @@
+ """
+ Model inference services for NEXUS client.
+ Handles fit and predict operations with the NEXUS service.
+ """
+
+ import logging
+ from typing import Literal, Optional
+
+ import numpy as np
+ from pydantic import BaseModel
+
+ from fundamental.clients.base import BaseClient
+ from fundamental.constants import (
+     DEFAULT_POLLING_INTERVAL_SECONDS,
+     DEFAULT_PREDICT_POLLING_REQUESTS_WITHOUT_DELAY,
+     DEFAULT_SUBMIT_REQUEST_TIMEOUT_SECONDS,
+ )
+ from fundamental.models import TaskStatus
+ from fundamental.services.models import ModelsService
+ from fundamental.utils.data import (
+     XType,
+     YType,
+     api_call,
+     create_fit_task_metadata,
+     create_predict_task_metadata,
+     download_result_from_url,
+     serialize_df_to_parquet_bytes,
+     upload_fit_data,
+     upload_predict_data,
+ )
+ from fundamental.utils.polling import wait_for_task_status
+
+ logger = logging.getLogger(__name__)
+
+
+ class RemoteFitResponse(BaseModel):
+     trained_model_id: str
+     estimator_fields: dict
+
+
+ class SubmitFitTaskResult(BaseModel):
+     """Result of submitting a fit task."""
+
+     task_id: str
+     trained_model_id: str
+
+
+ def submit_fit_task(
+     X: XType,
+     y: YType,
+     task: Literal["classification", "regression"],
+     mode: Literal["quality", "speed"],
+     client: BaseClient,
+ ) -> SubmitFitTaskResult:
+     """
+     Submit a fit task without waiting for completion.
+
+     Parameters
+     ----------
+     X : XType
+         Training features.
+     y : YType
+         Training targets.
+     task : {"classification", "regression"}
+         Task type.
+     mode : {"quality", "speed"}
+         Model fit mode.
+     client : BaseClient
+         The client instance.
+
+     Returns
+     -------
+     SubmitFitTaskResult
+         Result containing task_id and trained_model_id.
+     """
+     X_serialized = serialize_df_to_parquet_bytes(data=X)
+     y_serialized = serialize_df_to_parquet_bytes(data=y)
+
+     metadata = create_fit_task_metadata(
+         x_train_size=len(X_serialized),
+         y_train_size=len(y_serialized),
+         client=client,
+     )
+
+     upload_fit_data(
+         X_serialized=X_serialized,
+         y_serialized=y_serialized,
+         metadata=metadata,
+         client=client,
+     )
+
+     json_data = {
+         "task": task,
+         "mode": mode,
+         "trained_model_id": metadata.trained_model_id,
+         "timeout": client.config.fit_timeout,
+     }
+
+     response = api_call(
+         method="POST",
+         full_url=client.config.get_full_fit_url(),
+         client=client,
+         json=json_data,
+         timeout=DEFAULT_SUBMIT_REQUEST_TIMEOUT_SECONDS,
+     )
+     data = response.json()
+     task_id: str = data["task_id"]
+
+     return SubmitFitTaskResult(
+         task_id=task_id,
+         trained_model_id=metadata.trained_model_id,
+     )
+
+
+ def poll_fit_result(
+     task_id: str,
+     trained_model_id: str,
+     client: BaseClient,
+ ) -> Optional[RemoteFitResponse]:
+     """
+     Check the status of a fit task.
+
+     Parameters
+     ----------
+     task_id : str
+         The task ID returned by submit_fit_task.
+     trained_model_id : str
+         The trained model ID from the submit result.
+     client : BaseClient
+         The client instance.
+
+     Returns
+     -------
+     Optional[RemoteFitResponse]
+         RemoteFitResponse with trained_model_id and estimator_fields if completed,
+         None if still in progress.
+     """
+     status_response = wait_for_task_status(
+         client=client,
+         status_url=f"{client.config.get_full_fit_status_url()}/{task_id}",
+         timeout=client.config.fit_timeout,
+         polling_interval=DEFAULT_POLLING_INTERVAL_SECONDS,
+         wait_for_completion=False,
+     )
+
+     if status_response.status == TaskStatus.SUCCESS:
+         logger.debug("Loading trained model metadata")
+         models_service = ModelsService(client=client)
+         loaded_model = models_service.load(trained_model_id=trained_model_id)
+
+         return RemoteFitResponse(
+             trained_model_id=trained_model_id,
+             estimator_fields=loaded_model.estimator_fields,
+         )
+
+     return None
+
+
+ def remote_fit(
+     X: XType,
+     y: YType,
+     task: Literal["classification", "regression"],
+     mode: Literal["quality", "speed"],
+     client: BaseClient,
+ ) -> RemoteFitResponse:
+     """
+     Fit a model (blocking).
+
+     Submits the task and waits for completion.
+
+     Parameters
+     ----------
+     X : XType
+         Training features.
+     y : YType
+         Training targets.
+     task : {"classification", "regression"}
+         Task type.
+     mode : {"quality", "speed"}
+         Model fit mode.
+     client : BaseClient
+         The client instance.
+
+     Returns
+     -------
+     RemoteFitResponse
+         Service response with trained_model_id and estimator_fields.
+     """
+     submit_result = submit_fit_task(
+         X=X,
+         y=y,
+         task=task,
+         mode=mode,
+         client=client,
+     )
+
+     wait_for_task_status(
+         client=client,
+         status_url=f"{client.config.get_full_fit_status_url()}/{submit_result.task_id}",
+         timeout=client.config.fit_timeout,
+         polling_interval=DEFAULT_POLLING_INTERVAL_SECONDS,
+     )
+
+     logger.debug("Loading trained model metadata")
+     models_service = ModelsService(client=client)
+     loaded_model = models_service.load(trained_model_id=submit_result.trained_model_id)
+
+     return RemoteFitResponse(
+         trained_model_id=submit_result.trained_model_id,
+         estimator_fields=loaded_model.estimator_fields,
+     )
+
+
+ def remote_predict(
+     X: XType,
+     output_type: Literal["preds", "probas"],
+     trained_model_id: str,
+     client: BaseClient,
+ ) -> np.ndarray:
+     """
+     Make predictions using a trained model identified by trained_model_id.
+
+     Parameters
+     ----------
+     X : XType
+         Input features for prediction.
+     output_type : {"preds", "probas"}
+         Output type.
+     trained_model_id : str
+         The model ID generated by the fit operation.
+
+     Returns
+     -------
+     np.ndarray
+         Prediction results.
+     """
+     X_serialized = serialize_df_to_parquet_bytes(data=X)
+
+     metadata = create_predict_task_metadata(
+         trained_model_id=trained_model_id,
+         x_test_size=len(X_serialized),
+         client=client,
+     )
+
+     upload_predict_data(
+         X_serialized=X_serialized,
+         metadata=metadata,
+         trained_model_id=trained_model_id,
+         client=client,
+     )
+
+     json_data = {
+         "output_type": output_type,
+         "trained_model_id": trained_model_id,
+         "request_id": metadata.request_id,
+         "timeout": client.config.predict_timeout,
+     }
+
+     response = api_call(
+         method="POST",
+         full_url=client.config.get_full_predict_url(),
+         client=client,
+         json=json_data,
+         timeout=DEFAULT_SUBMIT_REQUEST_TIMEOUT_SECONDS,
+     )
+     data = response.json()
+     task_id = data["task_id"]
+
+     status_response = wait_for_task_status(
+         client=client,
+         status_url=f"{client.config.get_full_predict_status_url()}/{trained_model_id}/{task_id}",
+         timeout=client.config.predict_timeout,
+         polling_interval=DEFAULT_POLLING_INTERVAL_SECONDS,
+         polling_requests_without_delay=DEFAULT_PREDICT_POLLING_REQUESTS_WITHOUT_DELAY,
+     )
+
+     preds = download_result_from_url(
+         download_url=status_response.result.download_url,  # type: ignore[union-attr]
+         client=client,
+         timeout=client.config.download_prediction_result_timeout,
+     )
+
+     return np.array(preds)
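The fit and predict entry points above compose in the obvious way: remote_fit returns the trained_model_id that remote_predict expects. A minimal blocking sketch, assuming a configured client and small in-memory data; all names and values are illustrative, and whether y is passed as a Series or a single-column DataFrame depends on how YType is defined in fundamental.utils.data:

    import pandas as pd

    client = ...  # placeholder: a configured client instance (construction is not shown in this file)
    X = pd.DataFrame({"f1": [0.1, 0.2, 0.3, 0.4], "f2": [1, 0, 1, 0]})
    y = pd.Series([0, 1, 0, 1], name="target")

    # Blocking fit: submits the task and waits until the model is trained.
    fit_response = remote_fit(X=X, y=y, task="classification", mode="speed", client=client)

    # Predict with the model ID returned by the fit.
    preds = remote_predict(
        X=X,
        output_type="preds",
        trained_model_id=fit_response.trained_model_id,
        client=client,
    )

    # For a non-blocking variant, submit_fit_task followed by poll_fit_result
    # mirrors the feature importance submit/poll pattern shown earlier.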
@@ -0,0 +1,186 @@
+ """Model management services for NEXUS client."""
+
+ import base64
+ import logging
+ from typing import Any, Dict
+
+ from pydantic import BaseModel
+ from typing_extensions import TypeAlias
+
+ from fundamental.clients.base import BaseClient
+ from fundamental.exceptions import ValidationError
+ from fundamental.models import (
+     DeleteTrainedModelResponse,
+     TrainedModelMetadata,
+     UpdateAttributesResponse,
+ )
+ from fundamental.utils.http import api_call
+ from fundamental.utils.safetensors_deserialize import load_estimator_fields_from_bytes
+
+ logger = logging.getLogger(__name__)
+
+ TrainedModelsListResponse: TypeAlias = list[str]
+
+
+ class LoadedModelResponse(BaseModel):
+     estimator_fields: Dict[str, Any]
+
+
+ class ModelsService:
+     """Service for model management operations."""
+
+     def __init__(self, client: BaseClient) -> None:
+         """Initialize the models service.
+
+         Args:
+             client: Client instance
+         """
+         self.client = client
+
+     def _validate_model_id(self, model_id: str) -> None:
+         if not model_id or not model_id.strip():
+             raise ValidationError("model_id cannot be empty. Please provide a valid model_id.")
+
+     def list(self) -> TrainedModelsListResponse:
+         """
+         List all trained models.
+
+         Returns
+         -------
+         TrainedModelsListResponse
+             List of trained model IDs.
+         """
+         response = api_call(
+             method="GET",
+             full_url=self.client.config.get_full_model_management_url(),
+             client=self.client,
+         )
+
+         res: list[str] = response.json()
+         return res
+
+     def delete(self, model_id: str) -> DeleteTrainedModelResponse:
+         """
+         Delete a specific trained model.
+
+         Parameters
+         ----------
+         model_id : str
+             The ID of the model to delete.
+
+         Returns
+         -------
+         DeleteTrainedModelResponse
+             Response from the deletion operation.
+
+         Raises
+         ------
+         ValidationError
+             If model_id is empty or invalid.
+         """
+         self._validate_model_id(model_id)
+
+         response = api_call(
+             method="DELETE",
+             full_url=f"{self.client.config.get_full_model_management_url()}/{model_id}",
+             client=self.client,
+         )
+
+         return DeleteTrainedModelResponse(**response.json())
+
+     def get(self, model_id: str) -> TrainedModelMetadata:
+         """
+         Get information about a specific trained model.
+
+         Parameters
+         ----------
+         model_id : str
+             The ID of the model to retrieve.
+
+         Returns
+         -------
+         TrainedModelMetadata
+             Model information dictionary.
+
+         Raises
+         ------
+         ValidationError
+             If model_id is empty or invalid.
+         """
+         self._validate_model_id(model_id)
+
+         response = api_call(
+             method="GET",
+             full_url=f"{self.client.config.get_full_model_management_url()}/{model_id}",
+             client=self.client,
+             max_retries=3,
+         )
+
+         return TrainedModelMetadata(**response.json())
+
+     def set_attributes(
+         self,
+         model_id: str,
+         attributes: dict[str, str],
+     ) -> UpdateAttributesResponse:
+         """
+         Set attributes for a specific trained model.
+
+         Parameters
+         ----------
+         model_id : str
+             The ID of the model to update.
+         attributes : dict[str, str]
+             The attributes to set.
+
+         Returns
+         -------
+         UpdateAttributesResponse
+             Response containing the updated attributes.
+
+         Raises
+         ------
+         ValidationError
+             If model_id is empty or invalid.
+         """
+         self._validate_model_id(model_id)
+
+         response = api_call(
+             method="PATCH",
+             full_url=f"{self.client.config.get_full_model_management_url()}/{model_id}/attributes",
+             client=self.client,
+             json={"attributes": attributes},
+         )
+
+         return UpdateAttributesResponse(**response.json())
+
+     def load(
+         self,
+         trained_model_id: str,
+     ) -> LoadedModelResponse:
+         """
+         Load a specific trained model.
+
+         Parameters
+         ----------
+         trained_model_id : str
+             The ID of the model to load.
+
+         Returns
+         -------
+         LoadedModelResponse
+             Response containing the deserialized estimator fields.
+         """
+         self._validate_model_id(trained_model_id)
+
+         response = api_call(
+             method="GET",
+             full_url=f"{self.client.config.get_full_model_management_url()}/{trained_model_id}/load_model",
+             client=self.client,
+         )
+         response_json = response.json()
+         raw_bytes = base64.b64decode(response_json["estimator_fields"], validate=True)
+         estimator_fields = load_estimator_fields_from_bytes(raw_bytes)
+         return LoadedModelResponse(
+             estimator_fields=estimator_fields,
+         )
File without changes
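For completeness, a short sketch of how the ModelsService methods compose. It assumes a configured client and at least one previously trained model; the attribute key and value are illustrative, not part of the package:

    client = ...  # placeholder: a configured client instance (construction is not shown in this file)
    models = ModelsService(client=client)

    model_ids = models.list()                      # list of trained model IDs
    metadata = models.get(model_id=model_ids[0])   # TrainedModelMetadata for one model
    models.set_attributes(model_id=model_ids[0], attributes={"owner": "research"})

    loaded = models.load(trained_model_id=model_ids[0])
    print(sorted(loaded.estimator_fields))         # deserialized estimator fields

    models.delete(model_id=model_ids[0])           # remove the model when no longer needed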