tetra-rp 0.6.0__py3-none-any.whl → 0.24.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tetra_rp/__init__.py +109 -19
- tetra_rp/cli/commands/__init__.py +1 -0
- tetra_rp/cli/commands/apps.py +143 -0
- tetra_rp/cli/commands/build.py +1082 -0
- tetra_rp/cli/commands/build_utils/__init__.py +1 -0
- tetra_rp/cli/commands/build_utils/handler_generator.py +176 -0
- tetra_rp/cli/commands/build_utils/lb_handler_generator.py +309 -0
- tetra_rp/cli/commands/build_utils/manifest.py +430 -0
- tetra_rp/cli/commands/build_utils/mothership_handler_generator.py +75 -0
- tetra_rp/cli/commands/build_utils/scanner.py +596 -0
- tetra_rp/cli/commands/deploy.py +580 -0
- tetra_rp/cli/commands/init.py +123 -0
- tetra_rp/cli/commands/resource.py +108 -0
- tetra_rp/cli/commands/run.py +296 -0
- tetra_rp/cli/commands/test_mothership.py +458 -0
- tetra_rp/cli/commands/undeploy.py +533 -0
- tetra_rp/cli/main.py +97 -0
- tetra_rp/cli/utils/__init__.py +1 -0
- tetra_rp/cli/utils/app.py +15 -0
- tetra_rp/cli/utils/conda.py +127 -0
- tetra_rp/cli/utils/deployment.py +530 -0
- tetra_rp/cli/utils/ignore.py +143 -0
- tetra_rp/cli/utils/skeleton.py +184 -0
- tetra_rp/cli/utils/skeleton_template/.env.example +4 -0
- tetra_rp/cli/utils/skeleton_template/.flashignore +40 -0
- tetra_rp/cli/utils/skeleton_template/.gitignore +44 -0
- tetra_rp/cli/utils/skeleton_template/README.md +263 -0
- tetra_rp/cli/utils/skeleton_template/main.py +44 -0
- tetra_rp/cli/utils/skeleton_template/mothership.py +55 -0
- tetra_rp/cli/utils/skeleton_template/pyproject.toml +58 -0
- tetra_rp/cli/utils/skeleton_template/requirements.txt +1 -0
- tetra_rp/cli/utils/skeleton_template/workers/__init__.py +0 -0
- tetra_rp/cli/utils/skeleton_template/workers/cpu/__init__.py +19 -0
- tetra_rp/cli/utils/skeleton_template/workers/cpu/endpoint.py +36 -0
- tetra_rp/cli/utils/skeleton_template/workers/gpu/__init__.py +19 -0
- tetra_rp/cli/utils/skeleton_template/workers/gpu/endpoint.py +61 -0
- tetra_rp/client.py +136 -33
- tetra_rp/config.py +29 -0
- tetra_rp/core/api/runpod.py +591 -39
- tetra_rp/core/deployment.py +232 -0
- tetra_rp/core/discovery.py +425 -0
- tetra_rp/core/exceptions.py +50 -0
- tetra_rp/core/resources/__init__.py +27 -9
- tetra_rp/core/resources/app.py +738 -0
- tetra_rp/core/resources/base.py +139 -4
- tetra_rp/core/resources/constants.py +21 -0
- tetra_rp/core/resources/cpu.py +115 -13
- tetra_rp/core/resources/gpu.py +182 -16
- tetra_rp/core/resources/live_serverless.py +153 -16
- tetra_rp/core/resources/load_balancer_sls_resource.py +440 -0
- tetra_rp/core/resources/network_volume.py +126 -31
- tetra_rp/core/resources/resource_manager.py +436 -35
- tetra_rp/core/resources/serverless.py +537 -120
- tetra_rp/core/resources/serverless_cpu.py +201 -0
- tetra_rp/core/resources/template.py +1 -59
- tetra_rp/core/utils/constants.py +10 -0
- tetra_rp/core/utils/file_lock.py +260 -0
- tetra_rp/core/utils/http.py +67 -0
- tetra_rp/core/utils/lru_cache.py +75 -0
- tetra_rp/core/utils/singleton.py +36 -1
- tetra_rp/core/validation.py +44 -0
- tetra_rp/execute_class.py +301 -0
- tetra_rp/protos/remote_execution.py +98 -9
- tetra_rp/runtime/__init__.py +1 -0
- tetra_rp/runtime/circuit_breaker.py +274 -0
- tetra_rp/runtime/config.py +12 -0
- tetra_rp/runtime/exceptions.py +49 -0
- tetra_rp/runtime/generic_handler.py +206 -0
- tetra_rp/runtime/lb_handler.py +189 -0
- tetra_rp/runtime/load_balancer.py +160 -0
- tetra_rp/runtime/manifest_fetcher.py +192 -0
- tetra_rp/runtime/metrics.py +325 -0
- tetra_rp/runtime/models.py +73 -0
- tetra_rp/runtime/mothership_provisioner.py +512 -0
- tetra_rp/runtime/production_wrapper.py +266 -0
- tetra_rp/runtime/reliability_config.py +149 -0
- tetra_rp/runtime/retry_manager.py +118 -0
- tetra_rp/runtime/serialization.py +124 -0
- tetra_rp/runtime/service_registry.py +346 -0
- tetra_rp/runtime/state_manager_client.py +248 -0
- tetra_rp/stubs/live_serverless.py +35 -17
- tetra_rp/stubs/load_balancer_sls.py +357 -0
- tetra_rp/stubs/registry.py +145 -19
- {tetra_rp-0.6.0.dist-info → tetra_rp-0.24.0.dist-info}/METADATA +398 -60
- tetra_rp-0.24.0.dist-info/RECORD +99 -0
- {tetra_rp-0.6.0.dist-info → tetra_rp-0.24.0.dist-info}/WHEEL +1 -1
- tetra_rp-0.24.0.dist-info/entry_points.txt +2 -0
- tetra_rp/core/pool/cluster_manager.py +0 -177
- tetra_rp/core/pool/dataclass.py +0 -18
- tetra_rp/core/pool/ex.py +0 -38
- tetra_rp/core/pool/job.py +0 -22
- tetra_rp/core/pool/worker.py +0 -19
- tetra_rp/core/resources/utils.py +0 -50
- tetra_rp/core/utils/json.py +0 -33
- tetra_rp-0.6.0.dist-info/RECORD +0 -39
- /tetra_rp/{core/pool → cli}/__init__.py +0 -0
- {tetra_rp-0.6.0.dist-info → tetra_rp-0.24.0.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,346 @@
|
|
|
1
|
+
"""Runtime service registry for cross-endpoint function routing."""
|
|
2
|
+
|
|
3
|
+
import asyncio
|
|
4
|
+
import json
|
|
5
|
+
import logging
|
|
6
|
+
import os
|
|
7
|
+
import time
|
|
8
|
+
from pathlib import Path
|
|
9
|
+
from typing import Dict, Optional
|
|
10
|
+
from urllib.parse import urlparse
|
|
11
|
+
|
|
12
|
+
from tetra_rp.core.resources.serverless import ServerlessResource
|
|
13
|
+
|
|
14
|
+
from .config import DEFAULT_CACHE_TTL
|
|
15
|
+
from .state_manager_client import StateManagerClient, ManifestServiceUnavailableError
|
|
16
|
+
from .models import Manifest
|
|
17
|
+
|
|
18
|
+
logger = logging.getLogger(__name__)
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class ServiceRegistry:
|
|
22
|
+
"""Service discovery and routing for cross-endpoint function calls.
|
|
23
|
+
|
|
24
|
+
Loads manifest to map functions to resource configs, queries mothership
|
|
25
|
+
manifest for endpoint URLs, and determines if function calls are local
|
|
26
|
+
or remote.
|
|
27
|
+
"""
|
|
28
|
+
|
|
29
|
+
def __init__(
    self,
    manifest_path: Optional[Path] = None,
    cache_ttl: int = DEFAULT_CACHE_TTL,
):
    """Set up manifest data, the endpoint-URL cache and the State Manager client.

    Args:
        manifest_path: Explicit location of flash_manifest.json. When omitted,
            ``_load_manifest`` tries the FLASH_MANIFEST_PATH environment
            variable and then its auto-detection locations.
        cache_ttl: Seconds the endpoint registry fetched from State Manager
            is considered fresh before it is queried again.

    Environment Variables:
        FLASH_RESOURCE_NAME: Resource config name of this endpoint; preferred
            identity for local-vs-remote detection.
        RUNPOD_ENDPOINT_ID: Identity used when FLASH_RESOURCE_NAME is unset or
            empty; also read later when querying State Manager.

    Note:
        A missing or unparsable manifest is not an error here:
        ``_load_manifest`` logs a warning and an empty manifest stays in place.
    """
    self.cache_ttl = cache_ttl

    # Resource-name -> endpoint-URL cache. A load time of 0.0 marks the cache
    # as "never loaded", so the first lookup triggers a refresh.
    self._endpoint_registry: Dict[str, str] = {}
    self._endpoint_registry_loaded_at = 0.0
    self._endpoint_registry_lock = asyncio.Lock()

    # Empty placeholder; replaced by _load_manifest() just below.
    self._manifest: Manifest = Manifest(
        version="1.0",
        generated_at="",
        project_name="",
        function_registry={},
        resources={},
    )
    self._load_manifest(manifest_path)

    # Best effort: without a client the registry simply never refreshes.
    try:
        state_client = StateManagerClient()
    except Exception as exc:
        logger.warning(f"Failed to initialize State Manager client: {exc}")
        state_client = None
    self._manifest_client = state_client

    # Identity of this endpoint, compared against manifest resource names.
    resource_name = os.getenv("FLASH_RESOURCE_NAME")
    self._current_endpoint = resource_name or os.getenv("RUNPOD_ENDPOINT_ID")
|
|
79
|
+
|
|
80
|
+
def _load_manifest(self, manifest_path: Optional[Path]) -> None:
    """Load flash_manifest.json from the first usable candidate location.

    Candidates are tried in order: the explicit ``manifest_path``, the
    FLASH_MANIFEST_PATH environment variable, ``flash_manifest.json`` three
    directory levels above this module, and ``flash_manifest.json`` in the
    current working directory.

    Args:
        manifest_path: Explicit path to the manifest (``Path`` or ``str``),
            or None to rely on the environment variable and auto-detection.

    Note:
        This method does not raise when no manifest can be loaded. It logs a
        warning and installs an empty manifest, which disables cross-endpoint
        routing.
    """
    candidates = []

    # Explicit path; Path() also accepts a plain string from callers.
    if manifest_path:
        candidates.append(Path(manifest_path))

    # Environment variable
    env_path = os.getenv("FLASH_MANIFEST_PATH")
    if env_path:
        candidates.append(Path(env_path))

    # Auto-detection: three levels above this module, or cwd
    candidates.append(Path(__file__).parent.parent.parent / "flash_manifest.json")
    candidates.append(Path.cwd() / "flash_manifest.json")

    for path in candidates:
        if not path.exists():
            continue
        try:
            # JSON is UTF-8 by specification; do not depend on the locale codec.
            with open(path, encoding="utf-8") as f:
                manifest_dict = json.load(f)
            self._manifest = Manifest.from_dict(manifest_dict)
        except Exception as e:
            # Best effort: a broken candidate must not block later ones.
            logger.warning(f"Failed to load manifest from {path}: {e}")
            continue
        logger.debug(f"Manifest loaded from {path}")
        return

    # No manifest found - log warning but don't fail
    logger.warning(
        "flash_manifest.json not found. Cross-endpoint routing disabled. "
        "Manifest is required for routing functions between endpoints."
    )
    self._manifest = Manifest(
        version="1.0",
        generated_at="",
        project_name="",
        function_registry={},
        resources={},
    )
|
|
134
|
+
|
|
135
|
+
async def _ensure_manifest_loaded(self) -> None:
    """Refresh the endpoint registry from State Manager when the cache is stale.

    The cache is stale when more than ``self.cache_ttl`` seconds have passed
    since ``self._endpoint_registry_loaded_at``. The refresh is serialized
    by ``self._endpoint_registry_lock``.

    Query flow:
        1. Read RUNPOD_ENDPOINT_ID from the environment.
        2. Ask the State Manager client for the persisted manifest.
        3. Store its ``"resources_endpoints"`` mapping in
           ``self._endpoint_registry`` and record the load time.

    The refresh is skipped, leaving the current registry untouched, when no
    State Manager client is available or RUNPOD_ENDPOINT_ID is unset. If the
    State Manager is unavailable, the registry is reset to an empty dict and
    the load time is left unchanged, so the next call tries again.

    Returns:
        None. Updates ``self._endpoint_registry`` internally.
    """
    async with self._endpoint_registry_lock:
        now = time.time()
        if now - self._endpoint_registry_loaded_at <= self.cache_ttl:
            return  # cache is still fresh

        if self._manifest_client is None:
            logger.debug("State Manager client not available, skipping refresh")
            return

        endpoint_id = os.getenv("RUNPOD_ENDPOINT_ID")
        if not endpoint_id:
            logger.warning(
                "RUNPOD_ENDPOINT_ID not set, cannot query State Manager"
            )
            return

        try:
            # Query State Manager directly for full manifest
            full_manifest = await self._manifest_client.get_persisted_manifest(
                endpoint_id
            )
        except ManifestServiceUnavailableError as e:
            logger.warning(
                f"Failed to load manifest from State Manager: {e}. "
                f"Cross-endpoint routing unavailable."
            )
            self._endpoint_registry = {}
            return

        # A missing manifest or a null "resources_endpoints" entry must not
        # leave the registry as None; later lookups call .get() on it.
        resources_endpoints = (full_manifest or {}).get("resources_endpoints") or {}

        self._endpoint_registry = resources_endpoints
        self._endpoint_registry_loaded_at = now
        logger.debug(
            f"Manifest loaded from State Manager: {len(self._endpoint_registry)} endpoints, "
            f"cache TTL {self.cache_ttl}s"
        )
|
|
194
|
+
|
|
195
|
+
async def get_endpoint_for_function(self, function_name: str) -> Optional[str]:
    """Resolve the endpoint URL that serves ``function_name``.

    The endpoint registry is refreshed first if its cache has expired. The
    function's resource config name is then looked up in the manifest and
    compared with this endpoint's own identity.

    Args:
        function_name: Name of the function to route.

    Returns:
        None when the function belongs to the current endpoint. Otherwise
        the URL registered for the function's resource. That URL is also
        falsy (None) when the registry has no entry for the resource, so a
        None result does not by itself prove the function is local.

    Raises:
        ValueError: If the function is not listed in the manifest.
    """
    # Refresh the cached endpoint registry if it has expired.
    await self._ensure_manifest_loaded()

    known_functions = self._manifest.function_registry
    if function_name not in known_functions:
        raise ValueError(
            f"Function '{function_name}' not found in manifest. "
            f"Available functions: {list(known_functions.keys())}"
        )

    target_resource = known_functions[function_name]

    # Same resource as this endpoint: the call stays local.
    if target_resource == self._current_endpoint:
        return None

    # Remote resource: look up its URL in the cached registry.
    url = self._endpoint_registry.get(target_resource)
    if url:
        return url

    logger.debug(
        f"Endpoint URL for '{target_resource}' not in manifest. "
        f"Manifest has: {list(self._endpoint_registry.keys())}"
    )
    return url
|
|
238
|
+
|
|
239
|
+
async def get_resource_for_function(
    self, function_name: str
) -> Optional[ServerlessResource]:
    """Build a ServerlessResource that targets the endpoint serving a function.

    Args:
        function_name: Name of the function to route.

    Returns:
        None when ``get_endpoint_for_function`` returns None (the function
        runs on the current endpoint). Otherwise a ``ServerlessResource``
        named ``remote_<function_name>`` whose ``id`` is the last path
        segment of the endpoint URL.

    Raises:
        ValueError: If the function is not in the manifest, if the endpoint
            URL cannot be parsed, or if the URL path has no final segment
            to use as the endpoint ID.
    """
    endpoint_url = await self.get_endpoint_for_function(function_name)

    if endpoint_url is None:
        return None  # Local function

    # Extract endpoint ID from URL (format: https://{endpoint_base_url}/v2/{endpoint_id})
    try:
        url_path = urlparse(endpoint_url).path
        # Last path component is the endpoint ID; ignore a trailing slash.
        endpoint_id = url_path.rstrip("/").rsplit("/", 1)[-1]
    except (AttributeError, TypeError, ValueError) as e:
        # Malformed or non-string URL: callers still get a ValueError.
        raise ValueError(
            f"Failed to parse endpoint URL '{endpoint_url}': {e}"
        ) from e

    # Raised outside the try so it is not re-wrapped into a nested message.
    if not endpoint_id:
        raise ValueError(
            f"Invalid endpoint URL format: {endpoint_url} - no endpoint ID found"
        )

    # Create and return ServerlessResource
    resource = ServerlessResource(name=f"remote_{function_name}")
    resource.id = endpoint_id

    return resource
|
|
283
|
+
|
|
284
|
+
async def is_local_function(self, function_name: str) -> bool:
    """Report whether ``function_name`` should execute on the current endpoint.

    Args:
        function_name: Name of the function.

    Returns:
        True when ``get_endpoint_for_function`` yields None, or when it
        raises ValueError because the function is not in the manifest.
        False when it yields any other value.
    """
    try:
        return (await self.get_endpoint_for_function(function_name)) is None
    except ValueError:
        # Function not in manifest, assume local (will execute and fail)
        return True
|
|
299
|
+
|
|
300
|
+
def get_current_endpoint_id(self) -> Optional[str]:
    """Return the identity this registry uses for the current endpoint.

    Returns:
        The value captured at construction time from FLASH_RESOURCE_NAME,
        or from RUNPOD_ENDPOINT_ID when the former was unset. None if
        neither variable was set.
    """
    return self._current_endpoint
|
|
307
|
+
|
|
308
|
+
def refresh_manifest(self) -> None:
    """Invalidate the endpoint registry cache.

    Resets the recorded load time so the next call to
    ``_ensure_manifest_loaded`` sees the cache as expired and queries
    State Manager again. No network request is made here.
    """
    # Float, matching the initial value assigned in __init__.
    self._endpoint_registry_loaded_at = 0.0
|
|
311
|
+
|
|
312
|
+
def get_manifest(self) -> Manifest:
    """Return the manifest currently held by this registry.

    Returns:
        The ``Manifest`` stored in ``self._manifest``. This is the empty
        placeholder manifest if no flash_manifest.json could be loaded.
    """
    return self._manifest
|
|
319
|
+
|
|
320
|
+
def get_all_resources(self) -> Dict[str, Dict]:
    """Return every resource config in the manifest as plain dictionaries.

    Returns:
        Mapping of resource name to the ``dataclasses.asdict`` form of its
        config, in the manifest's own order.
    """
    from dataclasses import asdict

    converted: Dict[str, Dict] = {}
    for resource_name, resource_config in self._manifest.resources.items():
        converted[resource_name] = asdict(resource_config)
    return converted
|
|
331
|
+
|
|
332
|
+
def get_resource_functions(self, resource_name: str) -> list:
    """List the functions declared for one resource config.

    Args:
        resource_name: Name of the resource config.

    Returns:
        One ``dataclasses.asdict`` dictionary per entry in the resource's
        ``functions``, or an empty list when the manifest has no (truthy)
        resource under that name.
    """
    from dataclasses import asdict

    resource = self._manifest.resources.get(resource_name)
    return [asdict(func) for func in resource.functions] if resource else []
|
|
@@ -0,0 +1,248 @@
|
|
|
1
|
+
"""GraphQL client for State Manager API to persist and reconcile manifests."""
|
|
2
|
+
|
|
3
|
+
import asyncio
|
|
4
|
+
import logging
|
|
5
|
+
from typing import Any, Dict, Optional
|
|
6
|
+
|
|
7
|
+
from tetra_rp.core.api.runpod import RunpodGraphQLClient
|
|
8
|
+
|
|
9
|
+
from .config import DEFAULT_MAX_RETRIES
|
|
10
|
+
from .exceptions import GraphQLError, ManifestServiceUnavailableError
|
|
11
|
+
|
|
12
|
+
logger = logging.getLogger(__name__)
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class StateManagerClient:
|
|
16
|
+
"""GraphQL client for State Manager manifest persistence.
|
|
17
|
+
|
|
18
|
+
The State Manager persists manifest state via RunPod GraphQL API,
|
|
19
|
+
providing reconciliation capabilities for the mothership to track
|
|
20
|
+
deployed resources across boots.
|
|
21
|
+
|
|
22
|
+
Thread Safety:
|
|
23
|
+
Uses asyncio.Lock to serialize read-modify-write operations,
|
|
24
|
+
preventing race conditions during concurrent resource updates.
|
|
25
|
+
|
|
26
|
+
Architecture:
|
|
27
|
+
Manifest updates follow a read-modify-write pattern:
|
|
28
|
+
1. Fetch environment -> activeBuildId
|
|
29
|
+
2. Fetch build -> manifest
|
|
30
|
+
3. Merge changes into manifest
|
|
31
|
+
4. Call updateFlashBuildManifest mutation
|
|
32
|
+
|
|
33
|
+
Performance:
|
|
34
|
+
Each update requires 3 GraphQL roundtrips. Consider batching
|
|
35
|
+
updates when provisioning multiple resources.
|
|
36
|
+
"""
|
|
37
|
+
|
|
38
|
+
def __init__(
    self,
    max_retries: int = DEFAULT_MAX_RETRIES,
):
    """Store the retry budget and create the manifest update lock.

    Args:
        max_retries: Maximum number of attempts made by each State Manager
            operation before it gives up.

    Note:
        No GraphQL client is created here; each operation opens its own
        ``RunpodGraphQLClient``, so errors from that client surface when an
        operation runs rather than at construction.
    """
    # Serializes read-modify-write manifest updates made through this instance.
    self._manifest_lock = asyncio.Lock()
    self.max_retries = max_retries
|
|
52
|
+
|
|
53
|
+
async def get_persisted_manifest(
    self, mothership_id: str
) -> Optional[Dict[str, Any]]:
    """Fetch the persisted manifest from State Manager, retrying on failure.

    Each attempt opens a fresh ``RunpodGraphQLClient`` and resolves the
    manifest through ``_fetch_build_and_manifest``. Failed attempts are
    followed by an exponential backoff of ``2**attempt`` seconds, except
    after the last one.

    Args:
        mothership_id: Environment/endpoint ID used to look up the active build.

    Returns:
        Manifest dict.

    Raises:
        ManifestServiceUnavailableError: If every attempt failed. The last
            underlying error is attached as ``__cause__``.
    """
    last_exception: Optional[Exception] = None

    for attempt in range(self.max_retries):
        try:
            async with RunpodGraphQLClient() as client:
                _, manifest = await self._fetch_build_and_manifest(
                    client, mothership_id
                )
        except (
            asyncio.TimeoutError,
            ManifestServiceUnavailableError,
            GraphQLError,
            ConnectionError,
        ) as e:
            last_exception = e
            # Back off only when another attempt will follow.
            if attempt < self.max_retries - 1:
                backoff = 2**attempt
                logger.warning(
                    f"State Manager request failed (attempt {attempt + 1}): {e}, "
                    f"retrying in {backoff}s..."
                )
                await asyncio.sleep(backoff)
        else:
            logger.debug(f"Persisted manifest loaded for {mothership_id}")
            return manifest

    # Chain the root cause so the traceback shows why the last attempt failed.
    raise ManifestServiceUnavailableError(
        f"Failed to fetch persisted manifest after {self.max_retries} attempts: "
        f"{last_exception}"
    ) from last_exception
|
|
99
|
+
|
|
100
|
+
async def update_resource_state(
    self,
    mothership_id: str,
    resource_name: str,
    resource_data: Dict[str, Any],
) -> None:
    """Merge ``resource_data`` into one resource entry of the persisted manifest.

    Performs a read-modify-write: fetch the active build's manifest, merge
    the new fields over the existing entry (new values win), and write the
    manifest back. The whole cycle runs under ``self._manifest_lock`` so
    concurrent updates through this client instance do not overwrite each
    other. The backoff sleep happens outside the lock.

    Args:
        mothership_id: Environment/endpoint ID used to look up the active build.
        resource_name: Name of the resource.
        resource_data: Resource metadata (config_hash, endpoint_url, status, etc).

    Raises:
        ManifestServiceUnavailableError: If every attempt failed. The last
            underlying error is attached as ``__cause__``.
    """
    last_exception: Optional[Exception] = None

    for attempt in range(self.max_retries):
        try:
            async with self._manifest_lock:
                async with RunpodGraphQLClient() as client:
                    build_id, manifest = await self._fetch_build_and_manifest(
                        client, mothership_id
                    )
                    resources = manifest.setdefault("resources", {})
                    existing = resources.get(resource_name)
                    # A missing or malformed entry is replaced, not merged.
                    if not isinstance(existing, dict):
                        existing = {}
                    resources[resource_name] = {**existing, **resource_data}
                    await client.update_build_manifest(build_id, manifest)
        except (
            asyncio.TimeoutError,
            ManifestServiceUnavailableError,
            GraphQLError,
            ConnectionError,
        ) as e:
            last_exception = e
            # Back off only when another attempt will follow.
            if attempt < self.max_retries - 1:
                backoff = 2**attempt
                logger.warning(
                    f"State Manager request failed (attempt {attempt + 1}): {e}, "
                    f"retrying in {backoff}s..."
                )
                await asyncio.sleep(backoff)
        else:
            logger.debug(
                f"Updated resource state in State Manager: {mothership_id}/{resource_name}"
            )
            return

    # Chain the root cause so the traceback shows why the last attempt failed.
    raise ManifestServiceUnavailableError(
        f"Failed to update resource state after {self.max_retries} attempts: "
        f"{last_exception}"
    ) from last_exception
|
|
160
|
+
|
|
161
|
+
async def remove_resource_state(
    self, mothership_id: str, resource_name: str
) -> None:
    """Remove one resource entry from the persisted manifest.

    Performs a read-modify-write under ``self._manifest_lock``: fetch the
    active build's manifest, drop the entry if present, and write the
    manifest back. Removing a name that is not in the manifest is not an
    error. The backoff sleep happens outside the lock.

    Args:
        mothership_id: Environment/endpoint ID used to look up the active build.
        resource_name: Name of the resource.

    Raises:
        ManifestServiceUnavailableError: If every attempt failed. The last
            underlying error is attached as ``__cause__``.
    """
    last_exception: Optional[Exception] = None

    for attempt in range(self.max_retries):
        try:
            async with self._manifest_lock:
                async with RunpodGraphQLClient() as client:
                    build_id, manifest = await self._fetch_build_and_manifest(
                        client, mothership_id
                    )
                    resources = manifest.setdefault("resources", {})
                    resources.pop(resource_name, None)
                    await client.update_build_manifest(build_id, manifest)
        except (
            asyncio.TimeoutError,
            ManifestServiceUnavailableError,
            GraphQLError,
            ConnectionError,
        ) as e:
            last_exception = e
            # Back off only when another attempt will follow.
            if attempt < self.max_retries - 1:
                backoff = 2**attempt
                logger.warning(
                    f"State Manager request failed (attempt {attempt + 1}): {e}, "
                    f"retrying in {backoff}s..."
                )
                await asyncio.sleep(backoff)
        else:
            logger.debug(
                f"Removed resource state from State Manager: {mothership_id}/{resource_name}"
            )
            return

    # Chain the root cause so the traceback shows why the last attempt failed.
    raise ManifestServiceUnavailableError(
        f"Failed to remove resource state after {self.max_retries} attempts: "
        f"{last_exception}"
    ) from last_exception
|
|
214
|
+
|
|
215
|
+
async def _fetch_build_and_manifest(
    self, client: RunpodGraphQLClient, mothership_id: str
) -> tuple[str, Dict[str, Any]]:
    """Fetch the active build ID and its manifest for an environment.

    Args:
        client: Open GraphQL client used for both queries.
        mothership_id: Flash environment ID.

    Returns:
        Tuple of (build_id, manifest_dict).

    Raises:
        ManifestServiceUnavailableError: If the environment lookup returns
            nothing, the environment has no ``activeBuildId``, the build
            lookup returns nothing, or the build carries no manifest.
    """
    environment = await client.get_flash_environment(
        {"flashEnvironmentId": mothership_id}
    )
    # An empty lookup result must surface as the documented (and retried)
    # error rather than an AttributeError on ``.get``.
    if not environment:
        raise ManifestServiceUnavailableError(
            f"Flash environment {mothership_id} not found in State Manager."
        )

    build_id = environment.get("activeBuildId")
    if not build_id:
        raise ManifestServiceUnavailableError(
            f"Active build not found for environment {mothership_id}. "
            f"Environment may not be fully initialized or has no deployed build."
        )

    build = await client.get_flash_build(build_id)
    if not build:
        raise ManifestServiceUnavailableError(
            f"Build {build_id} not found for environment {mothership_id}."
        )

    manifest = build.get("manifest")
    if not manifest:
        raise ManifestServiceUnavailableError(
            f"Manifest not found for build {build.get('id', build_id)}. "
            f"Build may be corrupted, not yet published, or manifest was not generated."
        )

    return build_id, manifest
|
|
@@ -4,6 +4,7 @@ import inspect
|
|
|
4
4
|
import textwrap
|
|
5
5
|
import hashlib
|
|
6
6
|
import traceback
|
|
7
|
+
import threading
|
|
7
8
|
import cloudpickle
|
|
8
9
|
import logging
|
|
9
10
|
from ..core.resources import LiveServerless
|
|
@@ -12,26 +13,37 @@ from ..protos.remote_execution import (
|
|
|
12
13
|
FunctionResponse,
|
|
13
14
|
RemoteExecutorStub,
|
|
14
15
|
)
|
|
16
|
+
from ..runtime.serialization import serialize_args, serialize_kwargs
|
|
15
17
|
|
|
16
18
|
log = logging.getLogger(__name__)
|
|
17
19
|
|
|
18
20
|
|
|
19
|
-
#
|
|
21
|
+
# Global in-memory cache with thread safety
|
|
20
22
|
_SERIALIZED_FUNCTION_CACHE = {}
|
|
23
|
+
_function_cache_lock = threading.RLock()
|
|
21
24
|
|
|
22
25
|
|
|
23
26
|
def get_function_source(func):
|
|
24
27
|
"""Extract the function source code without the decorator."""
|
|
28
|
+
# Unwrap any decorators to get the original function
|
|
29
|
+
func = inspect.unwrap(func)
|
|
30
|
+
|
|
25
31
|
# Get the source code of the decorated function
|
|
26
32
|
source = inspect.getsource(func)
|
|
27
33
|
|
|
34
|
+
# Dedent the source to handle functions defined in classes or indented contexts
|
|
35
|
+
source = textwrap.dedent(source)
|
|
36
|
+
|
|
28
37
|
# Parse the source code
|
|
29
38
|
module = ast.parse(source)
|
|
30
39
|
|
|
31
|
-
# Find the function definition node
|
|
40
|
+
# Find the function definition node (both sync and async)
|
|
32
41
|
function_def = None
|
|
33
42
|
for node in ast.walk(module):
|
|
34
|
-
if
|
|
43
|
+
if (
|
|
44
|
+
isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef))
|
|
45
|
+
and node.name == func.__name__
|
|
46
|
+
):
|
|
35
47
|
function_def = node
|
|
36
48
|
break
|
|
37
49
|
|
|
@@ -60,32 +72,38 @@ class LiveServerlessStub(RemoteExecutorStub):
|
|
|
60
72
|
def __init__(self, server: LiveServerless):
|
|
61
73
|
self.server = server
|
|
62
74
|
|
|
63
|
-
def prepare_request(
|
|
75
|
+
def prepare_request(
|
|
76
|
+
self,
|
|
77
|
+
func,
|
|
78
|
+
dependencies,
|
|
79
|
+
system_dependencies,
|
|
80
|
+
accelerate_downloads,
|
|
81
|
+
*args,
|
|
82
|
+
**kwargs,
|
|
83
|
+
):
|
|
64
84
|
source, src_hash = get_function_source(func)
|
|
65
85
|
|
|
66
86
|
request = {
|
|
67
87
|
"function_name": func.__name__,
|
|
68
88
|
"dependencies": dependencies,
|
|
69
89
|
"system_dependencies": system_dependencies,
|
|
90
|
+
"accelerate_downloads": accelerate_downloads,
|
|
70
91
|
}
|
|
71
92
|
|
|
72
|
-
#
|
|
73
|
-
|
|
74
|
-
#
|
|
75
|
-
|
|
93
|
+
# Thread-safe cache access
|
|
94
|
+
with _function_cache_lock:
|
|
95
|
+
# check if the function is already cached
|
|
96
|
+
if src_hash not in _SERIALIZED_FUNCTION_CACHE:
|
|
97
|
+
# Cache the serialized function
|
|
98
|
+
_SERIALIZED_FUNCTION_CACHE[src_hash] = source
|
|
76
99
|
|
|
77
|
-
|
|
100
|
+
request["function_code"] = _SERIALIZED_FUNCTION_CACHE[src_hash]
|
|
78
101
|
|
|
79
102
|
# Serialize arguments using cloudpickle
|
|
80
103
|
if args:
|
|
81
|
-
request["args"] =
|
|
82
|
-
base64.b64encode(cloudpickle.dumps(arg)).decode("utf-8") for arg in args
|
|
83
|
-
]
|
|
104
|
+
request["args"] = serialize_args(args)
|
|
84
105
|
if kwargs:
|
|
85
|
-
request["kwargs"] =
|
|
86
|
-
k: base64.b64encode(cloudpickle.dumps(v)).decode("utf-8")
|
|
87
|
-
for k, v in kwargs.items()
|
|
88
|
-
}
|
|
106
|
+
request["kwargs"] = serialize_kwargs(kwargs)
|
|
89
107
|
|
|
90
108
|
return FunctionRequest(**request)
|
|
91
109
|
|
|
@@ -95,7 +113,7 @@ class LiveServerlessStub(RemoteExecutorStub):
|
|
|
95
113
|
|
|
96
114
|
if response.stdout:
|
|
97
115
|
for line in response.stdout.splitlines():
|
|
98
|
-
|
|
116
|
+
print(line)
|
|
99
117
|
|
|
100
118
|
if response.success:
|
|
101
119
|
if response.result is None:
|