fundamental-client 0.2.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,294 @@
1
+ """Simple HTTP utilities for NEXUS client."""
2
+
3
+ import json as json_lib
4
+ import logging
5
+ import random
6
+ import time
7
+ from typing import Any, Dict, Optional
8
+
9
+ import httpx
10
+
11
+ from fundamental.clients.base import BaseClient
12
+ from fundamental.config import Config
13
+ from fundamental.constants import DEFAULT_TIMEOUT_SECONDS, SIGV4_SERVICE_NAME
14
+ from fundamental.exceptions import (
15
+ AuthenticationError,
16
+ AuthorizationError,
17
+ HTTPError,
18
+ NetworkError,
19
+ NotFoundError,
20
+ RateLimitError,
21
+ RequestTimeoutError,
22
+ ServerError,
23
+ ValidationError,
24
+ )
25
+
26
+ logger = logging.getLogger(__name__)
27
+ logging.getLogger("httpx").setLevel(logging.WARNING)
28
+
29
+
30
def _sign_request_sigv4(
    method: str,
    url: str,
    headers: Dict[str, str],
    body: Optional[bytes],
    region: str,
) -> Dict[str, str]:
    """Sign an outgoing request with AWS Signature Version 4.

    Builds a botocore ``AWSRequest`` from the given method/url/headers/body,
    signs it for ``SIGV4_SERVICE_NAME`` in *region* using the default boto3
    credential chain, and returns the resulting header mapping (which
    includes the SigV4 Authorization headers).

    Raises:
        ImportError: If boto3/botocore are not installed (optional "ec2" extra).
        AuthenticationError: If no AWS credentials can be resolved.
    """
    # boto3/botocore are optional dependencies shipped via the "ec2" extra,
    # so import lazily and translate a missing install into a clear message.
    try:
        import boto3  # pyright: ignore[reportMissingImports]
        from botocore.auth import SigV4Auth  # pyright: ignore[reportMissingImports]
        from botocore.awsrequest import AWSRequest  # pyright: ignore[reportMissingImports]
    except ImportError as e:
        raise ImportError(
            "boto3 is required for EC2/SigV4 authentication. "
            'Install it with: pip install "fundamental-client[ec2]" '
            'or: uv add "fundamental-client[ec2]"'
        ) from e

    credentials = boto3.Session().get_credentials()
    if credentials is None:
        raise AuthenticationError(
            message="AWS credentials not found. Configure AWS credentials via "
            "environment variables, AWS config file, or IAM role."
        )

    # Sign the exact body bytes; an empty payload is signed as b"".
    signable = AWSRequest(method=method, url=url, headers=headers, data=body or b"")
    SigV4Auth(credentials, SIGV4_SERVICE_NAME, region).add_auth(signable)
    return dict(signable.headers)
62
+
63
+
64
def _prepare_request(
    method: str,
    url: str,
    client_config: Config,
    trace_dict: Optional[Dict[str, str]] = None,
    extra_headers: Optional[Dict[str, str]] = None,
    content: Optional[bytes] = None,
    json: Optional[Dict[str, Any]] = None,
) -> tuple[Dict[str, str], Optional[bytes]]:
    """Build the header map and body bytes for an API request.

    Headers are assembled in order: User-Agent, Datadog trace headers,
    caller-supplied headers (which may override the earlier ones), then
    Content-Type if a JSON body is serialized, and finally authentication
    (SigV4 signature or plain API key).

    Returns:
        Tuple of (headers, body_bytes)
    """
    headers: Dict[str, str] = {
        "User-Agent": f"NEXUS-Client-SDK/{client_config.get_version()}",
    }

    # Propagate distributed-tracing ids when the client supplies them.
    if trace_dict:
        for header_name, trace_key in (
            ("x-datadog-trace-id", "trace_id"),
            ("x-datadog-parent-id", "span_id"),
        ):
            value = trace_dict.get(trace_key)
            if value:
                headers[header_name] = value

    if extra_headers:
        headers.update(extra_headers)

    # Serialize JSON here (instead of letting httpx do it) so SigV4 can
    # sign the exact bytes that will go on the wire.
    body = content
    if json is not None:
        body = json_lib.dumps(json).encode("utf-8")
        headers["Content-Type"] = "application/json"

    if client_config.use_sigv4_auth:
        signed = _sign_request_sigv4(
            method=method,
            url=url,
            headers=headers,
            body=body,
            region=client_config.aws_region,
        )
        return signed, body

    headers["x-api-key"] = f"{client_config.get_api_key()}"
    return headers, body
111
+
112
+
113
+ def _calculate_backoff(attempt: int) -> float:
114
+ """Calculate exponential backoff delay with jitter."""
115
+ delay = min(2**attempt, 10) # Cap at 10 seconds
116
+ jitter = random.uniform(0.1, 0.3) * delay
117
+ return float(delay + jitter)
118
+
119
+
120
def _handle_response_error(response: httpx.Response, trace_id: Optional[str] = None) -> None:
    """Raise the SDK exception that corresponds to *response*'s error status."""
    # Prefer the API's structured "detail" field; fall back to raw body text,
    # and finally to a generic message when the body is empty.
    try:
        detail = response.json().get("detail", response.text[:200])
    except ValueError:
        detail = response.text[:200] or f"HTTP error {response.status_code}"

    status = response.status_code
    if status == 400:
        raise ValidationError(
            message=f"Invalid request. Error: {detail}",
            trace_id=trace_id,
        )
    if status == 401:
        raise AuthenticationError(
            message="Invalid API key. Check your credentials.",
            trace_id=trace_id,
        )
    if status == 403:
        raise AuthorizationError(
            message="Access denied. Check your API permissions.",
            trace_id=trace_id,
        )
    if status == 404:
        raise NotFoundError(
            message=f"Resource not found. Error: {detail}",
            trace_id=trace_id,
        )
    if status == 429:
        retry_after = response.headers.get("Retry-After")
        raise RateLimitError(
            message=f"Rate limit exceeded. Retry after {retry_after} seconds.",
            trace_id=trace_id,
        )
    if status >= 500:
        raise ServerError(
            message="Internal server error.",
            trace_id=trace_id,
        )

    # Anything else (e.g. 3xx that slipped through, unusual 4xx) becomes a
    # generic HTTPError carrying the status code for callers to inspect.
    fallback = HTTPError(message=f"Request failed: {detail}", trace_id=trace_id)
    fallback.status_code = status
    raise fallback
162
+
163
+
164
def api_call(
    method: str,
    full_url: str,
    client: BaseClient,
    files: Optional[Dict] = None,
    content: Optional[bytes] = None,
    data: Optional[Dict[str, Any]] = None,
    json: Optional[Dict[str, Any]] = None,
    headers: Optional[Dict[str, str]] = None,
    timeout: Optional[float] = DEFAULT_TIMEOUT_SECONDS,
    max_retries: Optional[int] = None,
) -> httpx.Response:
    """Make an HTTP request to the API with auth, tracing, and retries.

    Requests whose URL starts with the configured API base get auth headers
    (SigV4 or API key) and trace headers attached; other URLs (e.g. presigned
    S3 links) are sent with only the caller-supplied headers. Retryable
    failures (429, 500, 502, 503, 504, and network/timeout errors) are
    retried up to ``max_retries`` times with exponential backoff and jitter;
    a 429 Retry-After header overrides the computed backoff.

    Args:
        method: HTTP method (GET, POST, etc.)
        full_url: Complete URL for the request
        client: Client instance
        files: Files to upload
        content: Raw content bytes
        data: Form data
        json: JSON payload
        headers: Additional headers
        timeout: Request timeout in seconds (connect timeout is fixed at 5s)
        max_retries: Number of retries; defaults to ``client.config.retries``

    Returns:
        httpx.Response object

    Raises:
        Various HTTP exceptions based on response status
    """
    client_config = client.config
    trace_dict = client.get_trace_dict()
    if max_retries is None:
        max_retries = client_config.retries

    # Only add auth for requests to our API, not external URLs (e.g., S3)
    is_api_request = full_url.startswith(client_config.api_url)
    merged_headers = headers or {}

    if is_api_request:
        # _prepare_request serializes JSON to bytes (content) so SigV4 can sign the exact body.
        # For API key auth this is harmless - we're just doing what httpx does behind the scenes.
        # We set json=None so httpx doesn't re-serialize it.
        merged_headers, content = _prepare_request(
            method=method,
            url=full_url,
            client_config=client_config,
            trace_dict=trace_dict,
            extra_headers=headers,
            content=content,
            json=json,
        )
        json = None

    trace_id = trace_dict.get("trace_id") if trace_dict else None
    # A fresh client per call: no connection reuse across api_call invocations.
    # NOTE(review): for SigV4, headers are signed once before the retry loop;
    # a long backoff could in principle outlive the signature's validity
    # window — confirm this is acceptable for the retry budget used here.
    with httpx.Client(
        timeout=httpx.Timeout(timeout, connect=5.0),
        follow_redirects=True,
    ) as http_client:
        for attempt in range(max_retries + 1):
            try:
                logger.debug(f"Attempt {attempt + 1} of {max_retries + 1} to {method} {full_url}")
                response = http_client.request(
                    method=method,
                    url=full_url,
                    files=files,
                    headers=merged_headers,
                    content=content,
                    data=data,
                    json=json,
                )

                if 200 <= response.status_code < 300:
                    return response

                # Handle non-retryable errors immediately
                if 400 <= response.status_code < 500 and response.status_code != 429:
                    _handle_response_error(response, trace_id=trace_id)

                # Handle retryable errors (429, 5xx)
                if attempt < max_retries and response.status_code in (
                    429,
                    500,
                    502,
                    503,
                    504,
                ):
                    wait_time = _calculate_backoff(attempt)

                    # Use Retry-After header for rate limiting
                    if response.status_code == 429:
                        retry_after = response.headers.get("Retry-After")
                        if retry_after:
                            try:
                                # Non-numeric Retry-After (e.g. an HTTP date)
                                # falls back to the computed backoff.
                                wait_time = float(retry_after)
                            except ValueError:
                                pass

                    time.sleep(wait_time)
                    continue

                # Final attempt failed (or a non-retryable status like 501)
                _handle_response_error(response, trace_id=trace_id)

            except (httpx.TimeoutException, httpx.RequestError) as e:
                # Network-level failures are retried the same way as 5xx.
                if attempt < max_retries:
                    wait_time = _calculate_backoff(attempt)
                    logger.debug(f"Network error, retrying in {wait_time:.1f}s")
                    time.sleep(wait_time)
                    continue

                # Final attempt failed: map the transport error to an SDK error.
                if isinstance(e, httpx.TimeoutException):
                    raise RequestTimeoutError(
                        message="Request timed out. Check your connection or try again.",
                        trace_id=trace_id,
                    ) from e
                if "Connection refused" in str(e):
                    raise NetworkError(
                        message="Connection refused, cannot connect to api server. \n"
                        "Check your connection or try again.",
                        trace_id=trace_id,
                    ) from e
                raise NetworkError(
                    message="Network error occurred. Please try again.",
                    trace_id=trace_id,
                ) from e

    # Defensive: every loop iteration returns, raises, or continues, so this
    # line should be unreachable; it satisfies the declared return type.
    raise RuntimeError("Network error occurred. Please try again.")
@@ -0,0 +1,97 @@
1
+ """
2
+ Polling utilities for task status checks.
3
+ """
4
+
5
+ import time
6
+
7
+ from fundamental.clients.base import BaseClient
8
+ from fundamental.exceptions import (
9
+ NEXUSError,
10
+ RequestTimeoutError,
11
+ ServerError,
12
+ )
13
+ from fundamental.models import TaskStatus, TaskStatusResponse
14
+ from fundamental.utils.http import api_call
15
+
16
+
17
def _raise_exception_with_trace(
    exception_class: type[NEXUSError],
    message: str,
    client: BaseClient,
) -> None:
    """Raise *exception_class* with *message*, attaching the client's trace id (if any)."""
    trace_dict = client.get_trace_dict()
    raise exception_class(
        message=message,
        trace_id=trace_dict.get("trace_id") if trace_dict else None,
    )
25
+
26
+
27
def wait_for_task_status(
    client: BaseClient,
    status_url: str,
    timeout: float,
    polling_interval: float,
    polling_requests_without_delay: int = 0,
    wait_for_completion: bool = True,
) -> TaskStatusResponse:
    """
    Poll for task status until completion or timeout.

    Parameters
    ----------
    client : BaseClient
        The client instance.
    status_url : str
        The URL to poll for status.
    timeout : float
        Maximum time to wait in seconds (wall clock, measured with
        ``time.perf_counter``).
    polling_interval : float
        Time between polling requests in seconds.
    polling_requests_without_delay : int, default=0
        Number of initial follow-up requests issued back-to-back before
        throttling to ``polling_interval``.
    wait_for_completion : bool, default=True
        Whether to wait for completion or return immediately on in_progress.

    Returns
    -------
    TaskStatusResponse
        The task status response.

    Raises
    ------
    RequestTimeoutError
        If the request times out.
    ServerError
        If the server returns an error.
    ValidationError
        If the request fails validation.
    """
    start_time = time.perf_counter()
    request_counter = 0  # completed status requests so far

    while True:
        if time.perf_counter() - start_time > timeout:
            _raise_exception_with_trace(
                RequestTimeoutError,
                message=f"Request timed out after {timeout}s",
                client=client,
            )

        response = api_call(method="GET", full_url=status_url, client=client)
        status_response = TaskStatusResponse(**response.json())

        if status_response.status == TaskStatus.SUCCESS:
            return status_response

        if status_response.status == TaskStatus.IN_PROGRESS:
            if not wait_for_completion:
                return status_response
            # Allow the first `polling_requests_without_delay` follow-up polls
            # to go out immediately, then throttle. Fixed off-by-one: the
            # previous `>` comparison granted one more delay-free poll than
            # configured (with the default of 0, the second request was sent
            # with no delay at all).
            if request_counter >= polling_requests_without_delay:
                time.sleep(polling_interval)
        else:
            # Any status other than SUCCESS/IN_PROGRESS (e.g. a failure
            # state) is surfaced as a server-side error.
            _raise_exception_with_trace(
                ServerError,
                message=f"Unexpected task status: {status_response.status}",
                client=client,
            )

        request_counter += 1
@@ -0,0 +1,98 @@
1
+ """Safe deserialization for estimator fields using safetensors.
2
+
3
+ Deserializes estimator_fields.safetensors from the API into a dict.
4
+ This replaces pickle to eliminate RCE vulnerabilities.
5
+
6
+ Coupling:
7
+ - Input: estimator_fields.safetensors bytes from API (base64 decoded)
8
+ - Producer: Controller's safetensors_serialization.save_estimator_fields()
9
+ - Output: dict matching LTM's sync_fields() structure
10
+
11
+ Format:
12
+ - Tensors -> numeric numpy arrays
13
+ - Metadata {"value": [...], "type": "string_array"} -> string arrays
14
+ - Metadata {"type": "array_list", "length": N} + field::0, field::1 -> list of arrays
15
+ - Metadata {"value": ...} -> scalars/lists/dicts
16
+ """
17
+
18
+ import json
19
+ import tempfile
20
+ from pathlib import Path
21
+ from typing import Any
22
+
23
+ import numpy as np
24
+ from safetensors import safe_open
25
+ from safetensors.numpy import load
26
+
27
+
28
+ def _deserialize_fields(tensors: dict[str, np.ndarray], metadata: dict[str, str]) -> dict[str, Any]:
29
+ """Deserialize tensors and metadata back into fields.
30
+
31
+ Metadata format:
32
+ - {"value": ...} -> scalar/list/dict
33
+ - {"value": [...], "type": "string_array"} -> numpy array (string/object dtype)
34
+ - {"type": "array_list", "length": N} -> list of arrays (indexed as field::0, field::1, ...)
35
+
36
+ Args:
37
+ tensors: Dictionary of tensor names to numpy arrays
38
+ metadata: Dictionary of metadata keys to JSON strings
39
+
40
+ Returns:
41
+ Reconstructed fields dictionary
42
+ """
43
+ fields: dict[str, Any] = {}
44
+
45
+ # Add all tensors that aren't indexed (list items have :: in name)
46
+ fields = {k: v for k, v in tensors.items() if "::" not in k}
47
+
48
+ # Process metadata entries
49
+ for k, v in metadata.items():
50
+ if "::" in k: # Skip indexed list items (processed via parent)
51
+ continue
52
+
53
+ parsed = json.loads(v)
54
+
55
+ if parsed.get("type") == "array_list":
56
+ # Reconstruct list of arrays from indexed entries
57
+ length = parsed["length"]
58
+ arrays = []
59
+ for i in range(length):
60
+ idx_key = f"{k}::{i}"
61
+ if idx_key in tensors:
62
+ arrays.append(tensors[idx_key])
63
+ elif idx_key in metadata:
64
+ idx_parsed = json.loads(metadata[idx_key])
65
+ arrays.append(np.array(idx_parsed["value"]))
66
+ else:
67
+ raise ValueError(f"Missing array_list entry for {idx_key}")
68
+ fields[k] = arrays
69
+ elif parsed.get("type") == "string_array":
70
+ fields[k] = np.array(parsed["value"])
71
+ else:
72
+ fields[k] = parsed["value"]
73
+
74
+ return fields
75
+
76
+
77
def load_estimator_fields_from_bytes(data: bytes) -> dict[str, Any]:
    """Load estimator fields from safetensors bytes.

    Args:
        data: Bytes containing the safetensors data (from API response)

    Returns:
        Reconstructed fields dictionary with numpy arrays
    """
    # Tensors can be loaded straight from the byte payload...
    tensors = load(data)

    # ...but the header metadata is only reachable through safe_open, which
    # needs a file path. Write into a temp directory (rather than a
    # NamedTemporaryFile) so the file can be reopened on Windows without
    # lock issues.
    with tempfile.TemporaryDirectory() as scratch_dir:
        blob_path = Path(scratch_dir) / "estimator_fields.safetensors"
        blob_path.write_bytes(data)
        with safe_open(str(blob_path), framework="np") as handle:
            metadata = handle.metadata() or {}

    return _deserialize_fields(tensors, metadata)