fundamental-client 0.2.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,437 @@
+ """
+ Data utilities for NEXUS client.
+
+ This module provides data serialization and transformation utilities.
+ """
+
+ from __future__ import annotations
+
+ import logging
+ import typing
+ from io import BytesIO
+ from typing import Any, Dict, List, Optional, Union
+
+ import numpy as np
+ import pandas as pd
+ import pyarrow.lib as palib
+ from pydantic import BaseModel
+ from sklearn.base import BaseEstimator
+
+ from fundamental.clients.base import BaseClient
+ from fundamental.constants import (
+     DEFAULT_COMPLETE_MULTIPART_UPLOAD_TIMEOUT_SECONDS,
+     DEFAULT_DOWNLOAD_RESULT_TIMEOUT_SECONDS,
+     DEFAULT_MODEL_METADATA_GENERATE_TIMEOUT_SECONDS,
+     DEFAULT_MODEL_METADATA_UPLOAD_TIMEOUT_SECONDS,
+ )
+ from fundamental.models import (
+     FeatureImportanceMultipartMetadataResponse,
+     FitMultipartMetadataResponse,
+     MultipartUploadInfo,
+     PredictMultipartMetadataResponse,
+ )
+ from fundamental.utils.http import api_call
+
+ logger = logging.getLogger(__name__)
+
+ XType = Union[pd.DataFrame, np.ndarray]
+ YType = Union[np.ndarray, pd.Series]
+
+
+ # Version-compatible _check_n_features
+ # In sklearn 1.6+, _check_n_features was moved to sklearn.utils.validation
+ # In sklearn 1.5.x, it's a method on BaseEstimator
+ try:
+     from sklearn.utils.validation import (
+         _check_n_features,  # pyright: ignore[reportAttributeAccessIssue]
+     )
+
+     def check_n_features_compat(estimator: BaseEstimator, X: XType, *, reset: bool) -> None:
+         """Version-compatible wrapper for sklearn's _check_n_features."""
+         _check_n_features(estimator, X, reset=reset)
+
+ except ImportError:
+     # sklearn < 1.6.0: _check_n_features is a method on BaseEstimator
+     def check_n_features_compat(estimator: BaseEstimator, X: XType, *, reset: bool) -> None:
+         """Version-compatible wrapper for sklearn's _check_n_features."""
+         estimator._check_n_features(X, reset=reset)  # type: ignore[attr-defined]
+
+
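Editor's note: a minimal usage sketch for the compatibility wrapper above, not part of the packaged file. It works with any scikit-learn estimator; SGDRegressor is just a stand-in here.

    import numpy as np
    from sklearn.linear_model import SGDRegressor

    est = SGDRegressor()
    X = np.zeros((10, 3))
    check_n_features_compat(est, X, reset=True)   # records est.n_features_in_ = 3
    check_n_features_compat(est, X, reset=False)  # later calls are validated against 3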
+ class FileUpload(BaseModel):
+     """File upload with metadata for multipart HTTP requests."""
+
+     file_name: str
+     file_type: str
+     content: bytes
+
+
+ class PredictMetadata(BaseModel):
+     request_id: str
+     x_test_upload: MultipartUploadInfo
+
+
+ def serialize_df_to_parquet_bytes(
+     data: Union[pd.DataFrame, np.ndarray, pd.Series],
+ ) -> bytes:
+     """Serialize data to Parquet-formatted bytes for HTTP transmission.
+
+     Raises
+     ------
+     ValueError
+         If data contains types that cannot be serialized to Parquet.
+     """
+     if isinstance(data, np.ndarray):
+         data = pd.DataFrame(data)
+     elif isinstance(data, pd.Series):
+         data = data.to_frame()
+
+     try:
+         buffer = BytesIO()
+         data.to_parquet(
+             buffer,
+             index=False,
+             engine="pyarrow",
+             write_statistics=False,
+         )
+         return buffer.getvalue()
+     except (palib.ArrowInvalid, palib.ArrowNotImplementedError) as e:
+         error_msg = str(e)
+         # Provide specific error messages for common issues
+         if "cannot mix struct and non-struct" in error_msg:
+             raise ValueError(
+                 "Data contains mixed types that cannot be serialized. "
+                 "Please ensure all columns have consistent types."
+             ) from e
+         if "Unsupported numpy type" in error_msg or "complex" in error_msg.lower():
+             raise ValueError(
+                 "Data contains unsupported types (e.g., complex numbers). "
+                 "Please convert to numeric or string types."
+             ) from e
+         # Generic message for other PyArrow errors
+         raise ValueError(
+             "Data cannot be serialized. Please ensure data contains only numeric or string types."
+         ) from e
+
+
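A minimal round-trip sketch for serialize_df_to_parquet_bytes, illustrative only and not part of the packaged file; the column names and values are made up.

    import pandas as pd
    from io import BytesIO

    df = pd.DataFrame({"age": [31, 45], "income": [52_000.0, 61_500.0]})
    payload = serialize_df_to_parquet_bytes(df)    # Parquet bytes, ready for an HTTP body
    restored = pd.read_parquet(BytesIO(payload))   # round-trips via pyarrow
    assert list(restored.columns) == ["age", "income"]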
+ def to_httpx_post_file_format(
+     file_uploads: List[FileUpload],
+ ) -> Dict[str, tuple[str, bytes]]:
+     """Convert file uploads to httpx multipart format."""
+     files = {}
+     for f in file_uploads:
+         file = f"{f.file_name}.{f.file_type}"
+         files[f.file_name] = (file, f.content)
+     return files
+
+
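A sketch of the resulting httpx files mapping, illustrative only; the field name, file name, and byte content below are made up.

    uploads = [FileUpload(file_name="x_train", file_type="parquet", content=b"PAR1")]
    files = to_httpx_post_file_format(uploads)
    # files == {"x_train": ("x_train.parquet", b"PAR1")}
    # suitable for httpx.post(url, files=files)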
+ def validate_inputs_type(X: Any, y: Optional[Any] = None) -> None:  # noqa: ANN401
+     """
+     Validate X and y inputs match expected types.
+
+     Parameters
+     ----------
+     X : XType
+         Input features.
+     y : YType, optional
+         Target values.
+
+     Raises
+     ------
+     TypeError
+         If inputs don't match expected types.
+     """
+
+     def _validate_param(var: Any, expected_type: Any, param_name: str) -> None:  # noqa: ANN401
+         if var is not None:
+             if hasattr(expected_type, "__args__"):  # Union Type
+                 valid_types = typing.get_args(expected_type)
+             else:
+                 valid_types = (expected_type,)
+
+             if not isinstance(var, valid_types):
+                 type_names = " or ".join(t.__name__ for t in valid_types)
+                 raise TypeError(
+                     f"{param_name} must be {type_names}, got {type(var).__name__}. "
+                     f"Please convert your data to a supported format."
+                 )
+
+     _validate_param(var=X, expected_type=XType, param_name="X")
+     _validate_param(var=y, expected_type=YType, param_name="y")
+
+
+ def validate_data(
+     X: XType,
+     y: Optional[YType] = None,
+ ) -> None:
+     """
+     Validate data shapes and alignment.
+
+     Parameters
+     ----------
+     X : XType
+         Input features.
+     y : YType, optional
+         Target values.
+
+     Raises
+     ------
+     ValueError
+         If data is empty or misaligned.
+     """
+
+     if X.shape[0] == 0:
+         raise ValueError(
+             "Training data X cannot be empty. Please provide feature data with at least one sample."
+         )
+
+     if y is not None and len(y) == 0:
+         raise ValueError(
+             "Target data y cannot be empty. Please provide target values for training."
+         )
+
+     if y is not None and X.shape[0] != y.shape[0]:
+         raise ValueError(
+             f"Mismatched sample counts: X has {X.shape[0]} samples but y has "
+             f"{y.shape[0]} samples. Both must have the same number of samples."
+         )
+
+
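How the two validators above behave, as a small sketch (not part of the packaged file):

    import numpy as np
    import pandas as pd

    X = pd.DataFrame({"a": [1.0, 2.0, 3.0]})
    y = np.array([0, 1, 1])

    validate_inputs_type(X, y)  # OK: DataFrame / ndarray are accepted
    validate_data(X, y)         # OK: three samples on each side

    # Each of the following would raise:
    # validate_inputs_type([[1, 2]], y)   # TypeError: X must be DataFrame or ndarray ...
    # validate_data(X, np.array([0, 1]))  # ValueError: mismatched sample counts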
+ def _upload_data_in_parts(
+     data: bytes, upload_urls: List[str], part_size: int, client: BaseClient
+ ) -> List[Dict[str, Any]]:
+     parts = []
+     num_parts = len(upload_urls)
+
+     for part_number, upload_url in enumerate(upload_urls, start=1):
+         # Calculate the start and end positions for this part
+         start = (part_number - 1) * part_size
+         end = min(start + part_size, len(data))
+         part_data = data[start:end]
+
+         logger.info(f"Uploading part {part_number}/{num_parts} ({len(part_data)} bytes)")
+
+         response = api_call(
+             method="PUT",
+             full_url=upload_url,
+             client=client,
+             content=part_data,
+             timeout=DEFAULT_MODEL_METADATA_UPLOAD_TIMEOUT_SECONDS,
+         )
+
+         etag = response.headers.get("ETag", "").strip('"')
+         parts.append({"PartNumber": part_number, "ETag": etag})
+
+     return parts
+
+
+ def _complete_multipart_upload(
+     trained_model_id: str,
+     file_type: str,
+     upload_id: str,
+     parts: List[Dict[str, Any]],
+     client: BaseClient,
+     request_id: Optional[str] = None,
+ ) -> None:
+     json_data = {
+         "trained_model_id": trained_model_id,
+         "file_type": file_type,
+         "upload_id": upload_id,
+         "parts": parts,
+     }
+
+     if request_id:
+         json_data["request_id"] = request_id
+
+     api_call(
+         method="POST",
+         full_url=client.config.get_full_complete_multipart_upload_url(),
+         client=client,
+         json=json_data,
+         timeout=DEFAULT_COMPLETE_MULTIPART_UPLOAD_TIMEOUT_SECONDS,
+     )
+
+
+ def _upload_and_complete_multipart(
+     data: bytes,
+     upload_info: MultipartUploadInfo,
+     trained_model_id: str,
+     file_type: str,
+     client: BaseClient,
+     request_id: Optional[str] = None,
+ ) -> None:
+     parts = _upload_data_in_parts(
+         data=data,
+         upload_urls=upload_info.upload_urls,
+         part_size=upload_info.part_size,
+         client=client,
+     )
+     _complete_multipart_upload(
+         trained_model_id=trained_model_id,
+         file_type=file_type,
+         upload_id=upload_info.upload_id,
+         parts=parts,
+         client=client,
+         request_id=request_id,
+     )
+
+
+ def _set_client_trace_dict(data: Dict[str, Any], client: BaseClient) -> None:
+     trace_dict = {}
+     for k in ["trace_id", "span_id"]:
+         if k in data:
+             trace_dict[k] = data[k]
+     if trace_dict:
+         client._set_trace_dict(trace_dict)
+
+
+ def create_fit_task_metadata(
+     x_train_size: int,
+     y_train_size: int,
+     client: BaseClient,
+ ) -> FitMultipartMetadataResponse:
+     json_data = {
+         "x_train_size": x_train_size,
+         "y_train_size": y_train_size,
+     }
+
+     client._set_trace_dict(None)
+     response = api_call(
+         method="POST",
+         full_url=client.config.get_full_fit_model_metadata_generate_url(),
+         client=client,
+         json=json_data,
+         timeout=DEFAULT_MODEL_METADATA_GENERATE_TIMEOUT_SECONDS,
+     )
+     data = response.json()
+
+     _set_client_trace_dict(data, client)
+
+     return FitMultipartMetadataResponse(**data)
+
+
+ def create_predict_task_metadata(
+     trained_model_id: str,
+     x_test_size: int,
+     client: BaseClient,
+ ) -> PredictMetadata:
+     json_data = {
+         "trained_model_id": trained_model_id,
+         "x_test_size": x_test_size,
+     }
+
+     client._set_trace_dict(None)
+     response = api_call(
+         method="POST",
+         full_url=client.config.get_full_predict_model_metadata_generate_url(),
+         client=client,
+         json=json_data,
+         timeout=DEFAULT_MODEL_METADATA_GENERATE_TIMEOUT_SECONDS,
+     )
+     data = response.json()
+
+     _set_client_trace_dict(data, client)
+
+     # Adapt PredictMultipartMetadataResponse to PredictMetadata interface
+     api_response = PredictMultipartMetadataResponse(**data)
+     return PredictMetadata(
+         request_id=api_response.request_id,
+         x_test_upload=MultipartUploadInfo(
+             upload_id=api_response.upload_id,
+             upload_urls=api_response.upload_urls,
+             num_parts=api_response.num_parts,
+             part_size=api_response.part_size,
+         ),
+     )
+
+
+ def create_feature_importance_task_metadata(
+     trained_model_id: str,
+     x_size: int,
+     client: BaseClient,
+ ) -> FeatureImportanceMultipartMetadataResponse:
+     json_data: Dict[str, Any] = {
+         "trained_model_id": trained_model_id,
+         "x_size": x_size,
+     }
+
+     client._set_trace_dict(None)
+     response = api_call(
+         method="POST",
+         full_url=client.config.get_full_feature_importance_model_metadata_generate_url(),
+         client=client,
+         json=json_data,
+         timeout=DEFAULT_MODEL_METADATA_GENERATE_TIMEOUT_SECONDS,
+     )
+     data = response.json()
+
+     _set_client_trace_dict(data, client)
+
+     return FeatureImportanceMultipartMetadataResponse(**data)
+
+
+ def upload_fit_data(
+     X_serialized: bytes,
+     y_serialized: bytes,
+     metadata: FitMultipartMetadataResponse,
+     client: BaseClient,
+ ) -> None:
+     _upload_and_complete_multipart(
+         data=X_serialized,
+         upload_info=metadata.x_train_upload,
+         trained_model_id=metadata.trained_model_id,
+         file_type="x_train",
+         client=client,
+     )
+     _upload_and_complete_multipart(
+         data=y_serialized,
+         upload_info=metadata.y_train_upload,
+         trained_model_id=metadata.trained_model_id,
+         file_type="y_train",
+         client=client,
+     )
+
+
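A sketch of how the fit-upload helpers above compose, illustrative only and not part of the packaged file. It assumes an already-configured BaseClient instance named client, and that the *_size arguments are the serialized payload sizes in bytes, which is how the multipart metadata request appears to be used here.

    X_bytes = serialize_df_to_parquet_bytes(X)   # X: pd.DataFrame
    y_bytes = serialize_df_to_parquet_bytes(y)   # y: pd.Series

    fit_meta = create_fit_task_metadata(
        x_train_size=len(X_bytes),
        y_train_size=len(y_bytes),
        client=client,
    )
    upload_fit_data(X_bytes, y_bytes, metadata=fit_meta, client=client)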
+ def upload_predict_data(
+     X_serialized: bytes,
+     metadata: PredictMetadata,
+     trained_model_id: str,
+     client: BaseClient,
+ ) -> None:
+     _upload_and_complete_multipart(
+         data=X_serialized,
+         upload_info=metadata.x_test_upload,
+         trained_model_id=trained_model_id,
+         file_type="x_test",
+         client=client,
+         request_id=metadata.request_id,
+     )
+
+
+ def upload_feature_importance_data(
+     X_serialized: bytes,
+     metadata: FeatureImportanceMultipartMetadataResponse,
+     trained_model_id: str,
+     client: BaseClient,
+ ) -> None:
+     _upload_and_complete_multipart(
+         data=X_serialized,
+         upload_info=metadata.x_upload,
+         trained_model_id=trained_model_id,
+         file_type="x_feature_importance",
+         client=client,
+         request_id=metadata.request_id,
+     )
+
+
+ def download_result_from_url(
+     download_url: str,
+     client: BaseClient,
+     timeout: int = DEFAULT_DOWNLOAD_RESULT_TIMEOUT_SECONDS,
+ ) -> Any:  # noqa: ANN401
+     response = api_call(
+         method="GET",
+         full_url=download_url,
+         client=client,
+         timeout=timeout,
+     )
+     return response.json()
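For completeness, a sketch of the predict round trip, illustrative only and not part of the packaged file. It assumes a configured BaseClient named client, an existing trained_model_id, and that the service later provides a download_url for the finished task; the polling step is not shown because it is not part of this module.

    X_test_bytes = serialize_df_to_parquet_bytes(X_test)

    predict_meta = create_predict_task_metadata(
        trained_model_id=trained_model_id,
        x_test_size=len(X_test_bytes),
        client=client,
    )
    upload_predict_data(
        X_test_bytes,
        metadata=predict_meta,
        trained_model_id=trained_model_id,
        client=client,
    )

    # Once a download_url is available, the JSON result can be fetched with:
    result = download_result_from_url(download_url, client=client)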