tetra-rp 0.6.0__py3-none-any.whl → 0.24.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (97) hide show
  1. tetra_rp/__init__.py +109 -19
  2. tetra_rp/cli/commands/__init__.py +1 -0
  3. tetra_rp/cli/commands/apps.py +143 -0
  4. tetra_rp/cli/commands/build.py +1082 -0
  5. tetra_rp/cli/commands/build_utils/__init__.py +1 -0
  6. tetra_rp/cli/commands/build_utils/handler_generator.py +176 -0
  7. tetra_rp/cli/commands/build_utils/lb_handler_generator.py +309 -0
  8. tetra_rp/cli/commands/build_utils/manifest.py +430 -0
  9. tetra_rp/cli/commands/build_utils/mothership_handler_generator.py +75 -0
  10. tetra_rp/cli/commands/build_utils/scanner.py +596 -0
  11. tetra_rp/cli/commands/deploy.py +580 -0
  12. tetra_rp/cli/commands/init.py +123 -0
  13. tetra_rp/cli/commands/resource.py +108 -0
  14. tetra_rp/cli/commands/run.py +296 -0
  15. tetra_rp/cli/commands/test_mothership.py +458 -0
  16. tetra_rp/cli/commands/undeploy.py +533 -0
  17. tetra_rp/cli/main.py +97 -0
  18. tetra_rp/cli/utils/__init__.py +1 -0
  19. tetra_rp/cli/utils/app.py +15 -0
  20. tetra_rp/cli/utils/conda.py +127 -0
  21. tetra_rp/cli/utils/deployment.py +530 -0
  22. tetra_rp/cli/utils/ignore.py +143 -0
  23. tetra_rp/cli/utils/skeleton.py +184 -0
  24. tetra_rp/cli/utils/skeleton_template/.env.example +4 -0
  25. tetra_rp/cli/utils/skeleton_template/.flashignore +40 -0
  26. tetra_rp/cli/utils/skeleton_template/.gitignore +44 -0
  27. tetra_rp/cli/utils/skeleton_template/README.md +263 -0
  28. tetra_rp/cli/utils/skeleton_template/main.py +44 -0
  29. tetra_rp/cli/utils/skeleton_template/mothership.py +55 -0
  30. tetra_rp/cli/utils/skeleton_template/pyproject.toml +58 -0
  31. tetra_rp/cli/utils/skeleton_template/requirements.txt +1 -0
  32. tetra_rp/cli/utils/skeleton_template/workers/__init__.py +0 -0
  33. tetra_rp/cli/utils/skeleton_template/workers/cpu/__init__.py +19 -0
  34. tetra_rp/cli/utils/skeleton_template/workers/cpu/endpoint.py +36 -0
  35. tetra_rp/cli/utils/skeleton_template/workers/gpu/__init__.py +19 -0
  36. tetra_rp/cli/utils/skeleton_template/workers/gpu/endpoint.py +61 -0
  37. tetra_rp/client.py +136 -33
  38. tetra_rp/config.py +29 -0
  39. tetra_rp/core/api/runpod.py +591 -39
  40. tetra_rp/core/deployment.py +232 -0
  41. tetra_rp/core/discovery.py +425 -0
  42. tetra_rp/core/exceptions.py +50 -0
  43. tetra_rp/core/resources/__init__.py +27 -9
  44. tetra_rp/core/resources/app.py +738 -0
  45. tetra_rp/core/resources/base.py +139 -4
  46. tetra_rp/core/resources/constants.py +21 -0
  47. tetra_rp/core/resources/cpu.py +115 -13
  48. tetra_rp/core/resources/gpu.py +182 -16
  49. tetra_rp/core/resources/live_serverless.py +153 -16
  50. tetra_rp/core/resources/load_balancer_sls_resource.py +440 -0
  51. tetra_rp/core/resources/network_volume.py +126 -31
  52. tetra_rp/core/resources/resource_manager.py +436 -35
  53. tetra_rp/core/resources/serverless.py +537 -120
  54. tetra_rp/core/resources/serverless_cpu.py +201 -0
  55. tetra_rp/core/resources/template.py +1 -59
  56. tetra_rp/core/utils/constants.py +10 -0
  57. tetra_rp/core/utils/file_lock.py +260 -0
  58. tetra_rp/core/utils/http.py +67 -0
  59. tetra_rp/core/utils/lru_cache.py +75 -0
  60. tetra_rp/core/utils/singleton.py +36 -1
  61. tetra_rp/core/validation.py +44 -0
  62. tetra_rp/execute_class.py +301 -0
  63. tetra_rp/protos/remote_execution.py +98 -9
  64. tetra_rp/runtime/__init__.py +1 -0
  65. tetra_rp/runtime/circuit_breaker.py +274 -0
  66. tetra_rp/runtime/config.py +12 -0
  67. tetra_rp/runtime/exceptions.py +49 -0
  68. tetra_rp/runtime/generic_handler.py +206 -0
  69. tetra_rp/runtime/lb_handler.py +189 -0
  70. tetra_rp/runtime/load_balancer.py +160 -0
  71. tetra_rp/runtime/manifest_fetcher.py +192 -0
  72. tetra_rp/runtime/metrics.py +325 -0
  73. tetra_rp/runtime/models.py +73 -0
  74. tetra_rp/runtime/mothership_provisioner.py +512 -0
  75. tetra_rp/runtime/production_wrapper.py +266 -0
  76. tetra_rp/runtime/reliability_config.py +149 -0
  77. tetra_rp/runtime/retry_manager.py +118 -0
  78. tetra_rp/runtime/serialization.py +124 -0
  79. tetra_rp/runtime/service_registry.py +346 -0
  80. tetra_rp/runtime/state_manager_client.py +248 -0
  81. tetra_rp/stubs/live_serverless.py +35 -17
  82. tetra_rp/stubs/load_balancer_sls.py +357 -0
  83. tetra_rp/stubs/registry.py +145 -19
  84. {tetra_rp-0.6.0.dist-info → tetra_rp-0.24.0.dist-info}/METADATA +398 -60
  85. tetra_rp-0.24.0.dist-info/RECORD +99 -0
  86. {tetra_rp-0.6.0.dist-info → tetra_rp-0.24.0.dist-info}/WHEEL +1 -1
  87. tetra_rp-0.24.0.dist-info/entry_points.txt +2 -0
  88. tetra_rp/core/pool/cluster_manager.py +0 -177
  89. tetra_rp/core/pool/dataclass.py +0 -18
  90. tetra_rp/core/pool/ex.py +0 -38
  91. tetra_rp/core/pool/job.py +0 -22
  92. tetra_rp/core/pool/worker.py +0 -19
  93. tetra_rp/core/resources/utils.py +0 -50
  94. tetra_rp/core/utils/json.py +0 -33
  95. tetra_rp-0.6.0.dist-info/RECORD +0 -39
  96. /tetra_rp/{core/pool → cli}/__init__.py +0 -0
  97. {tetra_rp-0.6.0.dist-info → tetra_rp-0.24.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,346 @@
1
+ """Runtime service registry for cross-endpoint function routing."""
2
+
3
+ import asyncio
4
+ import json
5
+ import logging
6
+ import os
7
+ import time
8
+ from pathlib import Path
9
+ from typing import Dict, Optional
10
+ from urllib.parse import urlparse
11
+
12
+ from tetra_rp.core.resources.serverless import ServerlessResource
13
+
14
+ from .config import DEFAULT_CACHE_TTL
15
+ from .state_manager_client import StateManagerClient, ManifestServiceUnavailableError
16
+ from .models import Manifest
17
+
18
+ logger = logging.getLogger(__name__)
19
+
20
+
21
class ServiceRegistry:
    """Service discovery and routing for cross-endpoint function calls.

    Loads a local ``flash_manifest.json`` that maps function names to
    resource config names, fetches the ``resources_endpoints`` mapping
    (resource config name -> endpoint URL) from the State Manager, and
    decides whether a function call is local (runs on this endpoint) or
    remote (must be routed to another endpoint's URL).
    """

    def __init__(
        self,
        manifest_path: Optional[Path] = None,
        cache_ttl: int = DEFAULT_CACHE_TTL,
    ):
        """Initialize the registry and its State Manager client.

        Every endpoint queries the State Manager directly for the endpoint
        mapping; lookups do not go through a mothership endpoint.

        Args:
            manifest_path: Path to flash_manifest.json (a ``str`` is also
                accepted). Tried first; the FLASH_MANIFEST_PATH env var and
                auto-detected locations are tried after it.
            cache_ttl: Lifetime in seconds of the endpoint mapping cached
                from the State Manager. Defaults to DEFAULT_CACHE_TTL.

        Environment Variables:
            FLASH_RESOURCE_NAME: Resource config name this endpoint serves;
                used to recognise local functions.
            RUNPOD_ENDPOINT_ID: Endpoint ID. Used as the State Manager
                lookup key, and as this endpoint's identity when
                FLASH_RESOURCE_NAME is unset.
            RUNPOD_API_KEY: API key for State Manager GraphQL access. It is
                not read by this class directly.

        Note:
            A missing or unreadable manifest does not raise. A warning is
            logged and an empty manifest is used, which disables
            cross-endpoint routing. If the State Manager client cannot be
            created, a warning is logged and remote refreshes are skipped.
        """
        self.cache_ttl = cache_ttl
        # resource config name -> endpoint URL, as returned by State Manager.
        self._endpoint_registry: Dict[str, str] = {}
        # time.time() of the last successful refresh; 0.0 forces a refresh.
        self._endpoint_registry_loaded_at = 0.0
        self._manifest: Manifest = self._empty_manifest()
        # Serializes refreshes so concurrent callers don't all hit the API.
        self._endpoint_registry_lock = asyncio.Lock()

        self._load_manifest(manifest_path)

        # Peer-to-peer: every endpoint talks to the State Manager itself.
        self._manifest_client: Optional[StateManagerClient]
        try:
            self._manifest_client = StateManagerClient()
        except Exception as e:
            logger.warning(f"Failed to initialize State Manager client: {e}")
            self._manifest_client = None

        # Identity used to decide local vs remote. NOTE(review): when only
        # RUNPOD_ENDPOINT_ID is set, an endpoint ID is compared against
        # resource config names; confirm that this fallback is intended.
        self._current_endpoint = os.getenv("FLASH_RESOURCE_NAME") or os.getenv(
            "RUNPOD_ENDPOINT_ID"
        )

    @staticmethod
    def _empty_manifest() -> Manifest:
        """Return a placeholder manifest with no functions or resources."""
        return Manifest(
            version="1.0",
            generated_at="",
            project_name="",
            function_registry={},
            resources={},
        )

    def _load_manifest(self, manifest_path: Optional[Path]) -> None:
        """Load flash_manifest.json into ``self._manifest``.

        Candidates are tried in order: the explicit path, then
        FLASH_MANIFEST_PATH, then ``flash_manifest.json`` three directory
        levels above this module, then the current working directory. The
        first candidate that exists and parses wins.

        Args:
            manifest_path: Explicit path to the manifest, or None.

        Never raises for a missing or invalid manifest. Each failure is
        logged, and if no candidate loads, an empty manifest is installed.
        """
        paths_to_try = []

        if manifest_path:
            # Path() also accepts str, so callers may pass either.
            paths_to_try.append(Path(manifest_path))

        env_path = os.getenv("FLASH_MANIFEST_PATH")
        if env_path:
            paths_to_try.append(Path(env_path))

        # Auto-detection: three levels above this file (the directory that
        # contains the package), then the current working directory.
        paths_to_try.extend(
            [
                Path(__file__).parent.parent.parent / "flash_manifest.json",
                Path.cwd() / "flash_manifest.json",
            ]
        )

        for path in paths_to_try:
            if not path.exists():
                continue
            try:
                # JSON is UTF-8 (RFC 8259). Don't rely on the locale encoding.
                with open(path, encoding="utf-8") as f:
                    manifest_dict = json.load(f)
                self._manifest = Manifest.from_dict(manifest_dict)
            except Exception as e:
                logger.warning(f"Failed to load manifest from {path}: {e}")
                continue
            logger.debug(f"Manifest loaded from {path}")
            return

        # No manifest found - log warning but don't fail
        logger.warning(
            "flash_manifest.json not found. Cross-endpoint routing disabled. "
            "Manifest is required for routing functions between endpoints."
        )
        self._manifest = self._empty_manifest()

    async def _ensure_manifest_loaded(self) -> None:
        """Refresh the endpoint mapping from the State Manager when stale.

        Query flow, performed by StateManagerClient.get_persisted_manifest
        with RUNPOD_ENDPOINT_ID as the lookup key:
            1. get_flash_environment(...) -> activeBuildId
            2. get_flash_build(activeBuildId) -> manifest
            3. manifest["resources_endpoints"] becomes the endpoint mapping
        A successful result is cached for ``self.cache_ttl`` seconds.

        The refresh is skipped when the cache is fresh, when no State
        Manager client exists, or when RUNPOD_ENDPOINT_ID is unset. A
        missing, null, or non-object ``resources_endpoints`` value is
        treated as an empty mapping.

        Returns:
            None. Updates ``self._endpoint_registry`` internally.
        """
        async with self._endpoint_registry_lock:
            now = time.time()
            if now - self._endpoint_registry_loaded_at <= self.cache_ttl:
                return  # Cached mapping is still fresh.

            if self._manifest_client is None:
                logger.debug("State Manager client not available, skipping refresh")
                return

            mothership_id = os.getenv("RUNPOD_ENDPOINT_ID")
            if not mothership_id:
                logger.warning(
                    "RUNPOD_ENDPOINT_ID not set, cannot query State Manager"
                )
                return

            try:
                full_manifest = await self._manifest_client.get_persisted_manifest(
                    mothership_id
                )
            except ManifestServiceUnavailableError as e:
                logger.warning(
                    f"Failed to load manifest from State Manager: {e}. "
                    f"Cross-endpoint routing unavailable."
                )
                # NOTE(review): the timestamp is not advanced on failure, so
                # each subsequent call retries immediately (including the
                # client's own retry/backoff sleeps) while holding this lock.
                self._endpoint_registry = {}
                return

            # A null value would otherwise be stored as None and break the
            # len() below and every later .get() lookup.
            resources_endpoints = (full_manifest or {}).get("resources_endpoints")
            if not isinstance(resources_endpoints, dict):
                if resources_endpoints is not None:
                    logger.warning(
                        "Ignoring non-object resources_endpoints in State Manager "
                        f"manifest: {type(resources_endpoints).__name__}"
                    )
                resources_endpoints = {}

            self._endpoint_registry = resources_endpoints
            self._endpoint_registry_loaded_at = now
            logger.debug(
                f"Manifest loaded from State Manager: {len(self._endpoint_registry)} endpoints, "
                f"cache TTL {self.cache_ttl}s"
            )

    async def get_endpoint_for_function(self, function_name: str) -> Optional[str]:
        """Get the endpoint URL for a function.

        Refreshes the endpoint mapping first if the cache has expired.

        Args:
            function_name: Name of the function to route.

        Returns:
            The remote endpoint URL. Returns None when the function belongs
            to this endpoint, or when its resource has no URL in the mapping.

        Raises:
            ValueError: If the function is not in the local manifest.
        """
        await self._ensure_manifest_loaded()

        function_registry = self._manifest.function_registry

        if function_name not in function_registry:
            raise ValueError(
                f"Function '{function_name}' not found in manifest. "
                f"Available functions: {list(function_registry.keys())}"
            )

        resource_config_name = function_registry[function_name]

        if resource_config_name == self._current_endpoint:
            return None  # Local function.

        endpoint_url = self._endpoint_registry.get(resource_config_name)
        if not endpoint_url:
            # NOTE(review): a missing URL is returned as None, which callers
            # treat the same as "local". Confirm this fallback is intended.
            logger.debug(
                f"Endpoint URL for '{resource_config_name}' not in manifest. "
                f"Manifest has: {list(self._endpoint_registry.keys())}"
            )

        return endpoint_url

    async def get_resource_for_function(
        self, function_name: str
    ) -> Optional[ServerlessResource]:
        """Get a ServerlessResource pointing at a function's remote endpoint.

        Args:
            function_name: Name of the function to route.

        Returns:
            A ServerlessResource named ``remote_<function_name>`` whose ``id``
            is the last path segment of the endpoint URL. Returns None when
            get_endpoint_for_function returns None.

        Raises:
            ValueError: If the function is not in the manifest, or its
                endpoint URL has no trailing path segment to use as an ID.
        """
        endpoint_url = await self.get_endpoint_for_function(function_name)

        if endpoint_url is None:
            return None  # Local function

        # Expected format: https://{endpoint_base_url}/v2/{endpoint_id}
        try:
            parsed = urlparse(endpoint_url)
            # str.split always yields at least one item, so [-1] is safe.
            # A URL with no path produces "" and is rejected below.
            endpoint_id = parsed.path.rstrip("/").split("/")[-1]

            if not endpoint_id:
                raise ValueError(
                    f"Invalid endpoint URL format: {endpoint_url} - no endpoint ID found"
                )
        except Exception as e:
            raise ValueError(
                f"Failed to parse endpoint URL '{endpoint_url}': {e}"
            ) from e

        resource = ServerlessResource(name=f"remote_{function_name}")
        resource.id = endpoint_id

        return resource

    async def is_local_function(self, function_name: str) -> bool:
        """Check whether a function executes on the current endpoint.

        Args:
            function_name: Name of the function.

        Returns:
            True if get_endpoint_for_function returns None, or if the
            function is not in the manifest. False otherwise.
        """
        try:
            endpoint_url = await self.get_endpoint_for_function(function_name)
            return endpoint_url is None
        except ValueError:
            # Function not in manifest, assume local (will execute and fail)
            return True

    def get_current_endpoint_id(self) -> Optional[str]:
        """Get the current endpoint's identity.

        Returns:
            FLASH_RESOURCE_NAME or RUNPOD_ENDPOINT_ID, as read at
            construction time, or None if neither was set.
        """
        return self._current_endpoint

    def refresh_manifest(self) -> None:
        """Force the endpoint mapping to be re-fetched from the State Manager on next access."""
        self._endpoint_registry_loaded_at = 0.0

    def get_manifest(self) -> Manifest:
        """Get the loaded manifest.

        Returns:
            The Manifest loaded from disk, or an empty placeholder.
        """
        return self._manifest

    def get_all_resources(self) -> Dict[str, Dict]:
        """Get all resource configs from the manifest.

        Returns:
            Mapping of resource name to its config converted with
            ``dataclasses.asdict``.
        """
        from dataclasses import asdict

        return {
            name: asdict(config) for name, config in self._manifest.resources.items()
        }

    def get_resource_functions(self, resource_name: str) -> list:
        """Get the functions declared for a resource.

        Args:
            resource_name: Name of the resource config.

        Returns:
            List of function metadata dicts (via ``dataclasses.asdict``), or
            an empty list when the resource is unknown.
        """
        resource = self._manifest.resources.get(resource_name)
        if not resource:
            return []
        from dataclasses import asdict

        return [asdict(func) for func in resource.functions]
@@ -0,0 +1,248 @@
1
+ """GraphQL client for State Manager API to persist and reconcile manifests."""
2
+
3
+ import asyncio
4
+ import logging
5
+ from typing import Any, Dict, Optional
6
+
7
+ from tetra_rp.core.api.runpod import RunpodGraphQLClient
8
+
9
+ from .config import DEFAULT_MAX_RETRIES
10
+ from .exceptions import GraphQLError, ManifestServiceUnavailableError
11
+
12
+ logger = logging.getLogger(__name__)
13
+
14
+
15
class StateManagerClient:
    """GraphQL client for State Manager manifest persistence.

    The State Manager persists manifest state via the RunPod GraphQL API,
    letting the mothership reconcile deployed resources across boots.

    Concurrency:
        Each instance owns an ``asyncio.Lock`` that serializes its
        read-modify-write updates. Coroutines sharing one client therefore
        cannot interleave and drop each other's changes. The lock does not
        coordinate separate client instances or processes, and it is not a
        thread lock.

    Architecture:
        Manifest updates follow a read-modify-write pattern:
        1. Fetch environment -> activeBuildId
        2. Fetch build -> manifest
        3. Merge changes into manifest
        4. Call updateFlashBuildManifest mutation

    Retries:
        Each public operation is attempted up to ``max_retries`` times
        (always at least once), with exponential backoff of 1s, 2s, 4s, ...
        Only timeouts, connection errors, GraphQL errors and
        ManifestServiceUnavailableError are retried.

    Performance:
        Each update requires 3 GraphQL roundtrips. Consider batching
        updates when provisioning multiple resources.
    """

    # Exceptions treated as transient and retried by _run_with_retries.
    _RETRYABLE_ERRORS = (
        asyncio.TimeoutError,
        ManifestServiceUnavailableError,
        GraphQLError,
        ConnectionError,
    )

    def __init__(
        self,
        max_retries: int = DEFAULT_MAX_RETRIES,
    ):
        """Initialize State Manager client.

        No GraphQL client is created here. A RunpodGraphQLClient is opened
        per operation, so any credential error it raises (e.g. a missing
        RUNPOD_API_KEY) surfaces when an operation runs, not at construction.

        Args:
            max_retries: Maximum attempts per operation. Values below 1
                still result in a single attempt.
        """
        self.max_retries = max_retries
        self._manifest_lock = asyncio.Lock()

    async def _run_with_retries(self, operation: Any, action: str) -> Any:
        """Await ``operation()`` with exponential backoff on transient errors.

        Args:
            operation: Zero-argument callable returning a fresh awaitable for
                each attempt.
            action: Failure-message prefix, e.g. "Failed to update resource state".

        Returns:
            Whatever the first successful attempt returned.

        Raises:
            ManifestServiceUnavailableError: If every attempt failed with a
                retryable error. Non-retryable errors propagate unchanged.
        """
        # Clamp so a non-positive max_retries still runs the operation once
        # instead of failing without ever trying.
        attempts = max(1, self.max_retries)
        last_exception: Optional[Exception] = None

        for attempt in range(attempts):
            try:
                return await operation()
            except self._RETRYABLE_ERRORS as e:
                last_exception = e
                if attempt < attempts - 1:
                    backoff = 2**attempt
                    logger.warning(
                        f"State Manager request failed (attempt {attempt + 1}): {e}, "
                        f"retrying in {backoff}s..."
                    )
                    # Sleep outside any lock held by the operation.
                    await asyncio.sleep(backoff)

        raise ManifestServiceUnavailableError(
            f"{action} after {attempts} attempts: {last_exception}"
        )

    async def get_persisted_manifest(
        self, mothership_id: str
    ) -> Optional[Dict[str, Any]]:
        """Fetch persisted manifest from State Manager.

        Args:
            mothership_id: ID passed to the State Manager as the flash
                environment ID.

        Returns:
            Manifest dict of the environment's active build.

        Raises:
            ManifestServiceUnavailableError: If State Manager unavailable after retries.
        """

        async def _fetch() -> Dict[str, Any]:
            async with RunpodGraphQLClient() as client:
                _, manifest = await self._fetch_build_and_manifest(
                    client, mothership_id
                )
            return manifest

        manifest = await self._run_with_retries(
            _fetch, "Failed to fetch persisted manifest"
        )
        logger.debug(f"Persisted manifest loaded for {mothership_id}")
        return manifest

    async def update_resource_state(
        self,
        mothership_id: str,
        resource_name: str,
        resource_data: Dict[str, Any],
    ) -> None:
        """Merge ``resource_data`` into one resource entry in State Manager.

        Keys in ``resource_data`` override existing keys. A missing or
        non-dict existing entry is replaced. The instance lock is held for
        the whole read-modify-write of each attempt.

        Args:
            mothership_id: ID passed to the State Manager as the flash
                environment ID.
            resource_name: Name of the resource.
            resource_data: Resource metadata (config_hash, endpoint_url, status, etc).

        Raises:
            ManifestServiceUnavailableError: If State Manager unavailable.
        """

        async def _merge_resource() -> None:
            async with self._manifest_lock:
                async with RunpodGraphQLClient() as client:
                    build_id, manifest = await self._fetch_build_and_manifest(
                        client, mothership_id
                    )
                    resources = self._resources_section(manifest)
                    existing = resources.get(resource_name)
                    if not isinstance(existing, dict):
                        existing = {}
                    resources[resource_name] = {**existing, **resource_data}
                    await client.update_build_manifest(build_id, manifest)

        await self._run_with_retries(
            _merge_resource, "Failed to update resource state"
        )
        logger.debug(
            f"Updated resource state in State Manager: {mothership_id}/{resource_name}"
        )

    async def remove_resource_state(
        self, mothership_id: str, resource_name: str
    ) -> None:
        """Remove a resource entry from State Manager (no-op if absent).

        The instance lock is held for the whole read-modify-write of each
        attempt.

        Args:
            mothership_id: ID passed to the State Manager as the flash
                environment ID.
            resource_name: Name of the resource.

        Raises:
            ManifestServiceUnavailableError: If State Manager unavailable.
        """

        async def _drop_resource() -> None:
            async with self._manifest_lock:
                async with RunpodGraphQLClient() as client:
                    build_id, manifest = await self._fetch_build_and_manifest(
                        client, mothership_id
                    )
                    self._resources_section(manifest).pop(resource_name, None)
                    await client.update_build_manifest(build_id, manifest)

        await self._run_with_retries(
            _drop_resource, "Failed to remove resource state"
        )
        logger.debug(
            f"Removed resource state from State Manager: {mothership_id}/{resource_name}"
        )

    @staticmethod
    def _resources_section(manifest: Dict[str, Any]) -> Dict[str, Any]:
        """Return ``manifest["resources"]``, creating it when missing or null."""
        resources = manifest.get("resources")
        if resources is None:
            # A null section would otherwise crash .get()/.pop() callers.
            resources = manifest["resources"] = {}
        return resources

    async def _fetch_build_and_manifest(
        self, client: RunpodGraphQLClient, mothership_id: str
    ) -> tuple[str, Dict[str, Any]]:
        """Fetch active build ID and manifest for an environment.

        Args:
            client: Authenticated GraphQL client.
            mothership_id: Flash environment ID.

        Returns:
            Tuple of (build_id, manifest_dict).

        Raises:
            ManifestServiceUnavailableError: If the environment has no active
                build, or the build has no manifest. A None response from
                the API is treated the same as a missing one.
        """
        environment = await client.get_flash_environment(
            {"flashEnvironmentId": mothership_id}
        )
        build_id = (environment or {}).get("activeBuildId")
        if not build_id:
            raise ManifestServiceUnavailableError(
                f"Active build not found for environment {mothership_id}. "
                f"Environment may not be fully initialized or has no deployed build."
            )

        build = await client.get_flash_build(build_id) or {}
        manifest = build.get("manifest")
        if not manifest:
            raise ManifestServiceUnavailableError(
                f"Manifest not found for build {build.get('id', build_id)}. "
                f"Build may be corrupted, not yet published, or manifest was not generated."
            )

        return build_id, manifest
@@ -4,6 +4,7 @@ import inspect
4
4
  import textwrap
5
5
  import hashlib
6
6
  import traceback
7
+ import threading
7
8
  import cloudpickle
8
9
  import logging
9
10
  from ..core.resources import LiveServerless
@@ -12,26 +13,37 @@ from ..protos.remote_execution import (
12
13
  FunctionResponse,
13
14
  RemoteExecutorStub,
14
15
  )
16
+ from ..runtime.serialization import serialize_args, serialize_kwargs
15
17
 
16
18
  log = logging.getLogger(__name__)
17
19
 
18
20
 
19
- # global in memory cache, TODO: use a more robust cache in future
21
+ # Global in-memory cache with thread safety
20
22
  _SERIALIZED_FUNCTION_CACHE = {}
23
+ _function_cache_lock = threading.RLock()
21
24
 
22
25
 
23
26
  def get_function_source(func):
24
27
  """Extract the function source code without the decorator."""
28
+ # Unwrap any decorators to get the original function
29
+ func = inspect.unwrap(func)
30
+
25
31
  # Get the source code of the decorated function
26
32
  source = inspect.getsource(func)
27
33
 
34
+ # Dedent the source to handle functions defined in classes or indented contexts
35
+ source = textwrap.dedent(source)
36
+
28
37
  # Parse the source code
29
38
  module = ast.parse(source)
30
39
 
31
- # Find the function definition node
40
+ # Find the function definition node (both sync and async)
32
41
  function_def = None
33
42
  for node in ast.walk(module):
34
- if isinstance(node, ast.FunctionDef) and node.name == func.__name__:
43
+ if (
44
+ isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef))
45
+ and node.name == func.__name__
46
+ ):
35
47
  function_def = node
36
48
  break
37
49
 
@@ -60,32 +72,38 @@ class LiveServerlessStub(RemoteExecutorStub):
60
72
  def __init__(self, server: LiveServerless):
61
73
  self.server = server
62
74
 
63
- def prepare_request(self, func, dependencies, system_dependencies, *args, **kwargs):
75
+ def prepare_request(
76
+ self,
77
+ func,
78
+ dependencies,
79
+ system_dependencies,
80
+ accelerate_downloads,
81
+ *args,
82
+ **kwargs,
83
+ ):
64
84
  source, src_hash = get_function_source(func)
65
85
 
66
86
  request = {
67
87
  "function_name": func.__name__,
68
88
  "dependencies": dependencies,
69
89
  "system_dependencies": system_dependencies,
90
+ "accelerate_downloads": accelerate_downloads,
70
91
  }
71
92
 
72
- # check if the function is already cached
73
- if src_hash not in _SERIALIZED_FUNCTION_CACHE:
74
- # Cache the serialized function
75
- _SERIALIZED_FUNCTION_CACHE[src_hash] = source
93
+ # Thread-safe cache access
94
+ with _function_cache_lock:
95
+ # check if the function is already cached
96
+ if src_hash not in _SERIALIZED_FUNCTION_CACHE:
97
+ # Cache the serialized function
98
+ _SERIALIZED_FUNCTION_CACHE[src_hash] = source
76
99
 
77
- request["function_code"] = _SERIALIZED_FUNCTION_CACHE[src_hash]
100
+ request["function_code"] = _SERIALIZED_FUNCTION_CACHE[src_hash]
78
101
 
79
102
  # Serialize arguments using cloudpickle
80
103
  if args:
81
- request["args"] = [
82
- base64.b64encode(cloudpickle.dumps(arg)).decode("utf-8") for arg in args
83
- ]
104
+ request["args"] = serialize_args(args)
84
105
  if kwargs:
85
- request["kwargs"] = {
86
- k: base64.b64encode(cloudpickle.dumps(v)).decode("utf-8")
87
- for k, v in kwargs.items()
88
- }
106
+ request["kwargs"] = serialize_kwargs(kwargs)
89
107
 
90
108
  return FunctionRequest(**request)
91
109
 
@@ -95,7 +113,7 @@ class LiveServerlessStub(RemoteExecutorStub):
95
113
 
96
114
  if response.stdout:
97
115
  for line in response.stdout.splitlines():
98
- log.info(f"Remote | {line}")
116
+ print(line)
99
117
 
100
118
  if response.success:
101
119
  if response.result is None: