fundamental_client-0.2.3-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fundamental/__init__.py +34 -0
- fundamental/clients/__init__.py +7 -0
- fundamental/clients/base.py +37 -0
- fundamental/clients/ec2.py +37 -0
- fundamental/clients/fundamental.py +20 -0
- fundamental/config.py +138 -0
- fundamental/constants.py +41 -0
- fundamental/deprecated.py +43 -0
- fundamental/estimator/__init__.py +16 -0
- fundamental/estimator/base.py +263 -0
- fundamental/estimator/classification.py +46 -0
- fundamental/estimator/nexus_estimator.py +120 -0
- fundamental/estimator/regression.py +22 -0
- fundamental/exceptions.py +78 -0
- fundamental/models/__init__.py +4 -0
- fundamental/models/generated.py +431 -0
- fundamental/services/__init__.py +25 -0
- fundamental/services/feature_importance.py +172 -0
- fundamental/services/inference.py +283 -0
- fundamental/services/models.py +186 -0
- fundamental/utils/__init__.py +0 -0
- fundamental/utils/data.py +437 -0
- fundamental/utils/http.py +294 -0
- fundamental/utils/polling.py +97 -0
- fundamental/utils/safetensors_deserialize.py +98 -0
- fundamental_client-0.2.3.dist-info/METADATA +241 -0
- fundamental_client-0.2.3.dist-info/RECORD +29 -0
- fundamental_client-0.2.3.dist-info/WHEEL +4 -0
- fundamental_client-0.2.3.dist-info/licenses/LICENSE +201 -0
fundamental/utils/http.py
@@ -0,0 +1,294 @@
+"""Simple HTTP utilities for NEXUS client."""
+
+import json as json_lib
+import logging
+import random
+import time
+from typing import Any, Dict, Optional
+
+import httpx
+
+from fundamental.clients.base import BaseClient
+from fundamental.config import Config
+from fundamental.constants import DEFAULT_TIMEOUT_SECONDS, SIGV4_SERVICE_NAME
+from fundamental.exceptions import (
+    AuthenticationError,
+    AuthorizationError,
+    HTTPError,
+    NetworkError,
+    NotFoundError,
+    RateLimitError,
+    RequestTimeoutError,
+    ServerError,
+    ValidationError,
+)
+
+logger = logging.getLogger(__name__)
+logging.getLogger("httpx").setLevel(logging.WARNING)
+
+
+def _sign_request_sigv4(
+    method: str,
+    url: str,
+    headers: Dict[str, str],
+    body: Optional[bytes],
+    region: str,
+) -> Dict[str, str]:
+    try:
+        import boto3  # pyright: ignore[reportMissingImports]
+        from botocore.auth import SigV4Auth  # pyright: ignore[reportMissingImports]
+        from botocore.awsrequest import AWSRequest  # pyright: ignore[reportMissingImports]
+    except ImportError as e:
+        raise ImportError(
+            "boto3 is required for EC2/SigV4 authentication. "
+            'Install it with: pip install "fundamental-client[ec2]" '
+            'or: uv add "fundamental-client[ec2]"'
+        ) from e
+
+    session = boto3.Session()
+    credentials = session.get_credentials()
+
+    if credentials is None:
+        raise AuthenticationError(
+            message="AWS credentials not found. Configure AWS credentials via "
+            "environment variables, AWS config file, or IAM role."
+        )
+
+    aws_request = AWSRequest(method=method, url=url, headers=headers, data=body or b"")
+
+    SigV4Auth(credentials, SIGV4_SERVICE_NAME, region).add_auth(aws_request)
+
+    return dict(aws_request.headers)
+
+
+def _prepare_request(
+    method: str,
+    url: str,
+    client_config: Config,
+    trace_dict: Optional[Dict[str, str]] = None,
+    extra_headers: Optional[Dict[str, str]] = None,
+    content: Optional[bytes] = None,
+    json: Optional[Dict[str, Any]] = None,
+) -> tuple[Dict[str, str], Optional[bytes]]:
+    """Prepare headers and body for an API request.
+
+    Returns:
+        Tuple of (headers, body_bytes)
+    """
+    headers: Dict[str, str] = {
+        "User-Agent": f"NEXUS-Client-SDK/{client_config.get_version()}",
+    }
+
+    if trace_dict:
+        trace_id = trace_dict.get("trace_id")
+        span_id = trace_dict.get("span_id")
+        if trace_id:
+            headers["x-datadog-trace-id"] = trace_id
+        if span_id:
+            headers["x-datadog-parent-id"] = span_id
+
+    if extra_headers:
+        headers.update(extra_headers)
+
+    # Serialize JSON to bytes
+    body = content
+    if json is not None:
+        body = json_lib.dumps(json).encode("utf-8")
+        headers["Content-Type"] = "application/json"
+
+    if client_config.use_sigv4_auth:
+        headers = _sign_request_sigv4(
+            method=method,
+            url=url,
+            headers=headers,
+            body=body,
+            region=client_config.aws_region,
+        )
+    else:
+        headers["x-api-key"] = f"{client_config.get_api_key()}"
+
+    return headers, body
+
+
+def _calculate_backoff(attempt: int) -> float:
+    """Calculate exponential backoff delay with jitter."""
+    delay = min(2**attempt, 10)  # Cap at 10 seconds
+    jitter = random.uniform(0.1, 0.3) * delay
+    return float(delay + jitter)
+
+
+def _handle_response_error(response: httpx.Response, trace_id: Optional[str] = None) -> None:
+    """Handle HTTP error responses by raising appropriate exceptions."""
+    try:
+        error_data = response.json()
+        detail = error_data.get("detail", response.text[:200])
+    except ValueError:
+        detail = response.text[:200] or f"HTTP error {response.status_code}"
+
+    if response.status_code == 400:
+        raise ValidationError(
+            message=f"Invalid request. Error: {detail}",
+            trace_id=trace_id,
+        )
+    if response.status_code == 401:
+        raise AuthenticationError(
+            message="Invalid API key. Check your credentials.",
+            trace_id=trace_id,
+        )
+    if response.status_code == 403:
+        raise AuthorizationError(
+            message="Access denied. Check your API permissions.",
+            trace_id=trace_id,
+        )
+    if response.status_code == 404:
+        raise NotFoundError(
+            message=f"Resource not found. Error: {detail}",
+            trace_id=trace_id,
+        )
+    if response.status_code == 429:
+        retry_after = response.headers.get("Retry-After")
+        raise RateLimitError(
+            message=f"Rate limit exceeded. Retry after {retry_after} seconds.",
+            trace_id=trace_id,
+        )
+    if response.status_code >= 500:
+        raise ServerError(
+            message="Internal server error.",
+            trace_id=trace_id,
+        )
+    http_error = HTTPError(message=f"Request failed: {detail}", trace_id=trace_id)
+    http_error.status_code = response.status_code
+    raise http_error
+
+
+def api_call(
+    method: str,
+    full_url: str,
+    client: BaseClient,
+    files: Optional[Dict] = None,
+    content: Optional[bytes] = None,
+    data: Optional[Dict[str, Any]] = None,
+    json: Optional[Dict[str, Any]] = None,
+    headers: Optional[Dict[str, str]] = None,
+    timeout: Optional[float] = DEFAULT_TIMEOUT_SECONDS,
+    max_retries: Optional[int] = None,
+) -> httpx.Response:
+    """Make an HTTP request to the API.
+
+    Args:
+        method: HTTP method (GET, POST, etc.)
+        full_url: Complete URL for the request
+        client: Client instance
+        files: Files to upload
+        content: Raw content bytes
+        data: Form data
+        json: JSON payload
+        headers: Additional headers
+        timeout: Request timeout in seconds
+        max_retries: Number of retries
+
+    Returns:
+        httpx.Response object
+
+    Raises:
+        Various HTTP exceptions based on response status
+    """
+    client_config = client.config
+    trace_dict = client.get_trace_dict()
+    if max_retries is None:
+        max_retries = client_config.retries
+
+    # Only add auth for requests to our API, not external URLs (e.g., S3)
+    is_api_request = full_url.startswith(client_config.api_url)
+    merged_headers = headers or {}
+
+    if is_api_request:
+        # _prepare_request serializes JSON to bytes (content) so SigV4 can sign the exact body.
+        # For API key auth this is harmless - we're just doing what httpx does behind the scenes.
+        # We set json=None so httpx doesn't re-serialize it.
+        merged_headers, content = _prepare_request(
+            method=method,
+            url=full_url,
+            client_config=client_config,
+            trace_dict=trace_dict,
+            extra_headers=headers,
+            content=content,
+            json=json,
+        )
+        json = None

+    trace_id = trace_dict.get("trace_id") if trace_dict else None
+    with httpx.Client(
+        timeout=httpx.Timeout(timeout, connect=5.0),
+        follow_redirects=True,
+    ) as http_client:
+        for attempt in range(max_retries + 1):
+            try:
+                logger.debug(f"Attempt {attempt + 1} of {max_retries + 1} to {method} {full_url}")
+                response = http_client.request(
+                    method=method,
+                    url=full_url,
+                    files=files,
+                    headers=merged_headers,
+                    content=content,
+                    data=data,
+                    json=json,
+                )
+
+                if 200 <= response.status_code < 300:
+                    return response
+
+                # Handle non-retryable errors immediately
+                if 400 <= response.status_code < 500 and response.status_code != 429:
+                    _handle_response_error(response, trace_id=trace_id)
+
+                # Handle retryable errors (429, 5xx)
+                if attempt < max_retries and response.status_code in (
+                    429,
+                    500,
+                    502,
+                    503,
+                    504,
+                ):
+                    wait_time = _calculate_backoff(attempt)
+
+                    # Use Retry-After header for rate limiting
+                    if response.status_code == 429:
+                        retry_after = response.headers.get("Retry-After")
+                        if retry_after:
+                            try:
+                                wait_time = float(retry_after)
+                            except ValueError:
+                                pass
+
+                    time.sleep(wait_time)
+                    continue
+
+                # Final attempt failed
+                _handle_response_error(response, trace_id=trace_id)
+
+            except (httpx.TimeoutException, httpx.RequestError) as e:
+                if attempt < max_retries:
+                    wait_time = _calculate_backoff(attempt)
+                    logger.debug(f"Network error, retrying in {wait_time:.1f}s")
+                    time.sleep(wait_time)
+                    continue
+
+                # Final attempt failed
+                if isinstance(e, httpx.TimeoutException):
+                    raise RequestTimeoutError(
+                        message="Request timed out. Check your connection or try again.",
+                        trace_id=trace_id,
+                    ) from e
+                if "Connection refused" in str(e):
+                    raise NetworkError(
+                        message="Connection refused, cannot connect to API server.\n"
+                        "Check your connection or try again.",
+                        trace_id=trace_id,
+                    ) from e
+                raise NetworkError(
+                    message="Network error occurred. Please try again.",
+                    trace_id=trace_id,
+                ) from e
+
+    raise RuntimeError("Network error occurred. Please try again.")
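
Usage note: api_call retries 429 and 5xx responses with exponential backoff (capped at 10 s, plus 10-30% jitter via _calculate_backoff), prefers the server's Retry-After value for 429s, and raises a typed exception once retries are exhausted. A minimal caller sketch follows; the FundamentalClient import and the /v1/models path are assumptions for illustration, and only api_call's signature above comes from this file:

# Hedged sketch: FundamentalClient and the endpoint path are assumed, not
# confirmed by this diff; api_call's signature is from fundamental/utils/http.py.
from fundamental.clients.fundamental import FundamentalClient  # assumed class name
from fundamental.exceptions import NEXUSError  # assumed common base exception
from fundamental.utils.http import api_call

client = FundamentalClient()  # assumed to load API key / SigV4 settings from config
try:
    response = api_call(
        method="GET",
        full_url=f"{client.config.api_url}/v1/models",  # hypothetical endpoint
        client=client,
        timeout=30.0,
        max_retries=3,  # overrides client.config.retries for this call
    )
    print(response.json())
except NEXUSError as err:  # e.g. ValidationError, RateLimitError, ServerError
    print(f"Request failed after retries: {err}")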
fundamental/utils/polling.py
@@ -0,0 +1,97 @@
+"""
+Polling utilities for task status checks.
+"""
+
+import time
+
+from fundamental.clients.base import BaseClient
+from fundamental.exceptions import (
+    NEXUSError,
+    RequestTimeoutError,
+    ServerError,
+)
+from fundamental.models import TaskStatus, TaskStatusResponse
+from fundamental.utils.http import api_call
+
+
+def _raise_exception_with_trace(
+    exception_class: type[NEXUSError],
+    message: str,
+    client: BaseClient,
+) -> None:
+    trace_dict = client.get_trace_dict()
+    trace_id = trace_dict.get("trace_id") if trace_dict else None
+    raise exception_class(message=message, trace_id=trace_id)
+
+
+def wait_for_task_status(
+    client: BaseClient,
+    status_url: str,
+    timeout: float,
+    polling_interval: float,
+    polling_requests_without_delay: int = 0,
+    wait_for_completion: bool = True,
+) -> TaskStatusResponse:
+    """
+    Poll for task status until completion or timeout.
+
+    Parameters
+    ----------
+    client : BaseClient
+        The client instance.
+    status_url : str
+        The URL to poll for status.
+    timeout : float
+        Maximum time to wait in seconds.
+    polling_interval : float
+        Time between polling requests in seconds.
+    polling_requests_without_delay : int, default=0
+        Number of initial requests without delay.
+    wait_for_completion : bool, default=True
+        Whether to wait for completion or return immediately on in_progress.
+
+    Returns
+    -------
+    TaskStatusResponse
+        The task status response.
+
+    Raises
+    ------
+    RequestTimeoutError
+        If the request times out.
+    ServerError
+        If the server returns an error.
+    ValidationError
+        If the request fails validation.
+    """
+    start_time = time.perf_counter()
+    request_counter = 0
+
+    while True:
+        if time.perf_counter() - start_time > timeout:
+            _raise_exception_with_trace(
+                RequestTimeoutError,
+                message=f"Request timed out after {timeout}s",
+                client=client,
+            )
+
+        response = api_call(method="GET", full_url=status_url, client=client)
+        status_response = TaskStatusResponse(**response.json())
+
+        if status_response.status == TaskStatus.SUCCESS:
+            return status_response
+
+        if status_response.status == TaskStatus.IN_PROGRESS:
+            if not wait_for_completion:
+                return status_response
+            if request_counter > polling_requests_without_delay:
+                time.sleep(polling_interval)
+        else:
+            # Unexpected state
+            _raise_exception_with_trace(
+                ServerError,
+                message=f"Unexpected task status: {status_response.status}",
+                client=client,
+            )
+
+        request_counter += 1
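
Usage note: wait_for_task_status fires the first polling_requests_without_delay + 1 requests back-to-back, then sleeps polling_interval between polls until the task reports SUCCESS, an unexpected status raises ServerError, or the deadline raises RequestTimeoutError. A hedged sketch, reusing the assumed FundamentalClient from the earlier note and a made-up status URL:

# Hedged sketch: the client class and status URL are illustrative only;
# wait_for_task_status's signature is from fundamental/utils/polling.py.
from fundamental.clients.fundamental import FundamentalClient  # assumed class name
from fundamental.models import TaskStatus
from fundamental.utils.polling import wait_for_task_status

client = FundamentalClient()
status = wait_for_task_status(
    client=client,
    status_url=f"{client.config.api_url}/v1/tasks/<task-id>/status",  # hypothetical
    timeout=300.0,                     # give up after 5 minutes
    polling_interval=2.0,              # sleep 2 s between polls...
    polling_requests_without_delay=3,  # ...after the first 4 immediate polls
    wait_for_completion=False,         # return as soon as a status comes back
)
if status.status == TaskStatus.IN_PROGRESS:
    print("Task still running; poll again later.")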
fundamental/utils/safetensors_deserialize.py
@@ -0,0 +1,98 @@
+"""Safe deserialization for estimator fields using safetensors.
+
+Deserializes estimator_fields.safetensors from the API into a dict.
+This replaces pickle to eliminate RCE vulnerabilities.
+
+Coupling:
+- Input: estimator_fields.safetensors bytes from API (base64 decoded)
+- Producer: Controller's safetensors_serialization.save_estimator_fields()
+- Output: dict matching LTM's sync_fields() structure
+
+Format:
+- Tensors -> numeric numpy arrays
+- Metadata {"value": [...], "type": "string_array"} -> string arrays
+- Metadata {"type": "array_list", "length": N} + field::0, field::1 -> list of arrays
+- Metadata {"value": ...} -> scalars/lists/dicts
+"""
+
+import json
+import tempfile
+from pathlib import Path
+from typing import Any
+
+import numpy as np
+from safetensors import safe_open
+from safetensors.numpy import load
+
+
+def _deserialize_fields(tensors: dict[str, np.ndarray], metadata: dict[str, str]) -> dict[str, Any]:
+    """Deserialize tensors and metadata back into fields.
+
+    Metadata format:
+    - {"value": ...} -> scalar/list/dict
+    - {"value": [...], "type": "string_array"} -> numpy array (string/object dtype)
+    - {"type": "array_list", "length": N} -> list of arrays (indexed as field::0, field::1, ...)
+
+    Args:
+        tensors: Dictionary of tensor names to numpy arrays
+        metadata: Dictionary of metadata keys to JSON strings
+
+    Returns:
+        Reconstructed fields dictionary
+    """
+    fields: dict[str, Any] = {}
+
+    # Add all tensors that aren't indexed (list items have :: in name)
+    fields = {k: v for k, v in tensors.items() if "::" not in k}
+
+    # Process metadata entries
+    for k, v in metadata.items():
+        if "::" in k:  # Skip indexed list items (processed via parent)
+            continue
+
+        parsed = json.loads(v)
+
+        if parsed.get("type") == "array_list":
+            # Reconstruct list of arrays from indexed entries
+            length = parsed["length"]
+            arrays = []
+            for i in range(length):
+                idx_key = f"{k}::{i}"
+                if idx_key in tensors:
+                    arrays.append(tensors[idx_key])
+                elif idx_key in metadata:
+                    idx_parsed = json.loads(metadata[idx_key])
+                    arrays.append(np.array(idx_parsed["value"]))
+                else:
+                    raise ValueError(f"Missing array_list entry for {idx_key}")
+            fields[k] = arrays
+        elif parsed.get("type") == "string_array":
+            fields[k] = np.array(parsed["value"])
+        else:
+            fields[k] = parsed["value"]
+
+    return fields
+
+
+def load_estimator_fields_from_bytes(data: bytes) -> dict[str, Any]:
+    """Load estimator fields from safetensors bytes.
+
+    Args:
+        data: Bytes containing the safetensors data (from API response)
+
+    Returns:
+        Reconstructed fields dictionary with numpy arrays
+    """
+    tensors = load(data)
+
+    # To get metadata, we need to use safe_open with a file-like object.
+    # safetensors supports loading from bytes directly for tensors,
+    # but metadata requires safe_open.
+    # Use a temp directory so Windows can reopen the file without lock issues.
+    with tempfile.TemporaryDirectory() as tmp_dir:
+        tmp_path = Path(tmp_dir) / "estimator_fields.safetensors"
+        tmp_path.write_bytes(data)
+        with safe_open(str(tmp_path), framework="np") as f:
+            metadata = f.metadata() or {}
+
+    return _deserialize_fields(tensors, metadata)
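
Usage note: the producer for this format is the Controller's safetensors_serialization module, which is not part of this wheel, so a round trip can only be sketched. The snippet below builds bytes matching the documented metadata format by hand; the field names are made up, and safetensors.numpy.save stands in for the real serializer:

# Hedged sketch: field names and the hand-built metadata are illustrative;
# only load_estimator_fields_from_bytes comes from this package.
import json

import numpy as np
from safetensors.numpy import save

from fundamental.utils.safetensors_deserialize import load_estimator_fields_from_bytes

tensors = {
    "coef_": np.array([0.5, -1.2]),           # plain tensor -> numpy array
    "support_vectors_::0": np.zeros((2, 2)),  # indexed entries backing an array_list
    "support_vectors_::1": np.ones((3, 2)),
}
metadata = {
    "classes_": json.dumps({"value": ["cat", "dog"], "type": "string_array"}),
    "support_vectors_": json.dumps({"type": "array_list", "length": 2}),
    "n_iter_": json.dumps({"value": 7}),      # plain scalar
}

fields = load_estimator_fields_from_bytes(save(tensors, metadata=metadata))
assert fields["n_iter_"] == 7
assert fields["classes_"].tolist() == ["cat", "dog"]
assert len(fields["support_vectors_"]) == 2   # list of two numpy arrays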