tetra-rp 0.6.0__py3-none-any.whl → 0.24.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tetra_rp/__init__.py +109 -19
- tetra_rp/cli/commands/__init__.py +1 -0
- tetra_rp/cli/commands/apps.py +143 -0
- tetra_rp/cli/commands/build.py +1082 -0
- tetra_rp/cli/commands/build_utils/__init__.py +1 -0
- tetra_rp/cli/commands/build_utils/handler_generator.py +176 -0
- tetra_rp/cli/commands/build_utils/lb_handler_generator.py +309 -0
- tetra_rp/cli/commands/build_utils/manifest.py +430 -0
- tetra_rp/cli/commands/build_utils/mothership_handler_generator.py +75 -0
- tetra_rp/cli/commands/build_utils/scanner.py +596 -0
- tetra_rp/cli/commands/deploy.py +580 -0
- tetra_rp/cli/commands/init.py +123 -0
- tetra_rp/cli/commands/resource.py +108 -0
- tetra_rp/cli/commands/run.py +296 -0
- tetra_rp/cli/commands/test_mothership.py +458 -0
- tetra_rp/cli/commands/undeploy.py +533 -0
- tetra_rp/cli/main.py +97 -0
- tetra_rp/cli/utils/__init__.py +1 -0
- tetra_rp/cli/utils/app.py +15 -0
- tetra_rp/cli/utils/conda.py +127 -0
- tetra_rp/cli/utils/deployment.py +530 -0
- tetra_rp/cli/utils/ignore.py +143 -0
- tetra_rp/cli/utils/skeleton.py +184 -0
- tetra_rp/cli/utils/skeleton_template/.env.example +4 -0
- tetra_rp/cli/utils/skeleton_template/.flashignore +40 -0
- tetra_rp/cli/utils/skeleton_template/.gitignore +44 -0
- tetra_rp/cli/utils/skeleton_template/README.md +263 -0
- tetra_rp/cli/utils/skeleton_template/main.py +44 -0
- tetra_rp/cli/utils/skeleton_template/mothership.py +55 -0
- tetra_rp/cli/utils/skeleton_template/pyproject.toml +58 -0
- tetra_rp/cli/utils/skeleton_template/requirements.txt +1 -0
- tetra_rp/cli/utils/skeleton_template/workers/__init__.py +0 -0
- tetra_rp/cli/utils/skeleton_template/workers/cpu/__init__.py +19 -0
- tetra_rp/cli/utils/skeleton_template/workers/cpu/endpoint.py +36 -0
- tetra_rp/cli/utils/skeleton_template/workers/gpu/__init__.py +19 -0
- tetra_rp/cli/utils/skeleton_template/workers/gpu/endpoint.py +61 -0
- tetra_rp/client.py +136 -33
- tetra_rp/config.py +29 -0
- tetra_rp/core/api/runpod.py +591 -39
- tetra_rp/core/deployment.py +232 -0
- tetra_rp/core/discovery.py +425 -0
- tetra_rp/core/exceptions.py +50 -0
- tetra_rp/core/resources/__init__.py +27 -9
- tetra_rp/core/resources/app.py +738 -0
- tetra_rp/core/resources/base.py +139 -4
- tetra_rp/core/resources/constants.py +21 -0
- tetra_rp/core/resources/cpu.py +115 -13
- tetra_rp/core/resources/gpu.py +182 -16
- tetra_rp/core/resources/live_serverless.py +153 -16
- tetra_rp/core/resources/load_balancer_sls_resource.py +440 -0
- tetra_rp/core/resources/network_volume.py +126 -31
- tetra_rp/core/resources/resource_manager.py +436 -35
- tetra_rp/core/resources/serverless.py +537 -120
- tetra_rp/core/resources/serverless_cpu.py +201 -0
- tetra_rp/core/resources/template.py +1 -59
- tetra_rp/core/utils/constants.py +10 -0
- tetra_rp/core/utils/file_lock.py +260 -0
- tetra_rp/core/utils/http.py +67 -0
- tetra_rp/core/utils/lru_cache.py +75 -0
- tetra_rp/core/utils/singleton.py +36 -1
- tetra_rp/core/validation.py +44 -0
- tetra_rp/execute_class.py +301 -0
- tetra_rp/protos/remote_execution.py +98 -9
- tetra_rp/runtime/__init__.py +1 -0
- tetra_rp/runtime/circuit_breaker.py +274 -0
- tetra_rp/runtime/config.py +12 -0
- tetra_rp/runtime/exceptions.py +49 -0
- tetra_rp/runtime/generic_handler.py +206 -0
- tetra_rp/runtime/lb_handler.py +189 -0
- tetra_rp/runtime/load_balancer.py +160 -0
- tetra_rp/runtime/manifest_fetcher.py +192 -0
- tetra_rp/runtime/metrics.py +325 -0
- tetra_rp/runtime/models.py +73 -0
- tetra_rp/runtime/mothership_provisioner.py +512 -0
- tetra_rp/runtime/production_wrapper.py +266 -0
- tetra_rp/runtime/reliability_config.py +149 -0
- tetra_rp/runtime/retry_manager.py +118 -0
- tetra_rp/runtime/serialization.py +124 -0
- tetra_rp/runtime/service_registry.py +346 -0
- tetra_rp/runtime/state_manager_client.py +248 -0
- tetra_rp/stubs/live_serverless.py +35 -17
- tetra_rp/stubs/load_balancer_sls.py +357 -0
- tetra_rp/stubs/registry.py +145 -19
- {tetra_rp-0.6.0.dist-info → tetra_rp-0.24.0.dist-info}/METADATA +398 -60
- tetra_rp-0.24.0.dist-info/RECORD +99 -0
- {tetra_rp-0.6.0.dist-info → tetra_rp-0.24.0.dist-info}/WHEEL +1 -1
- tetra_rp-0.24.0.dist-info/entry_points.txt +2 -0
- tetra_rp/core/pool/cluster_manager.py +0 -177
- tetra_rp/core/pool/dataclass.py +0 -18
- tetra_rp/core/pool/ex.py +0 -38
- tetra_rp/core/pool/job.py +0 -22
- tetra_rp/core/pool/worker.py +0 -19
- tetra_rp/core/resources/utils.py +0 -50
- tetra_rp/core/utils/json.py +0 -33
- tetra_rp-0.6.0.dist-info/RECORD +0 -39
- /tetra_rp/{core/pool → cli}/__init__.py +0 -0
- {tetra_rp-0.6.0.dist-info → tetra_rp-0.24.0.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,274 @@
|
|
|
1
|
+
"""Circuit breaker pattern for handling endpoint failures."""
|
|
2
|
+
|
|
3
|
+
import asyncio
|
|
4
|
+
import logging
|
|
5
|
+
import time
|
|
6
|
+
from dataclasses import dataclass, field
|
|
7
|
+
from datetime import datetime, timezone
|
|
8
|
+
from enum import Enum
|
|
9
|
+
from typing import Any, Callable, Optional
|
|
10
|
+
|
|
11
|
+
logger = logging.getLogger(__name__)
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class CircuitState(Enum):
    """Enumeration of the states a circuit breaker moves between."""

    # Calls are passed through to the endpoint.
    CLOSED = "closed"
    # Calls are rejected until the recovery timeout has elapsed.
    OPEN = "open"
    # Calls are let through again to probe whether the endpoint recovered.
    HALF_OPEN = "half_open"
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
@dataclass
class CircuitBreakerStats:
    """Mutable counters and timestamps describing one circuit breaker."""

    # Current position in the state machine; a fresh instance starts CLOSED.
    state: CircuitState = CircuitState.CLOSED

    # Counters that the breaker resets on state transitions.
    failure_count: int = 0
    success_count: int = 0

    # When the most recent failure / success was recorded (None until one is).
    last_failure_at: Optional[datetime] = None
    last_success_at: Optional[datetime] = None

    # Timezone-aware (UTC) moment of the last state change; defaults to the
    # creation time of this stats object.
    state_changed_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))

    # Lifetime totals.
    total_requests: int = 0
    total_failures: int = 0
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
class EndpointCircuitBreaker:
    """Circuit breaker for a single endpoint with sliding window.

    State machine, as implemented by the methods below:

    * CLOSED: calls run. Each failure is timestamped; the circuit opens once
      ``failure_threshold`` failures lie within the last ``timeout_seconds``
      seconds. Any success clears the recorded failures.
    * OPEN: calls are rejected with ``CircuitBreakerOpenError`` until
      ``timeout_seconds`` have elapsed since the state last changed; the next
      call then switches the circuit to HALF_OPEN and is executed.
    * HALF_OPEN: calls run. The first failure re-opens the circuit;
      ``success_threshold`` successes close it.

    NOTE(review): ``window_size`` is stored but not read anywhere in this
    class; the failure window length is ``timeout_seconds``. Confirm whether
    that is intended.
    """

    def __init__(
        self,
        endpoint_url: str,
        failure_threshold: int = 5,
        success_threshold: int = 2,
        timeout_seconds: int = 60,
        window_size: int = 10,
    ):
        """Initialize circuit breaker for an endpoint.

        Args:
            endpoint_url: URL of the endpoint to protect
            failure_threshold: Failures required to open circuit
            success_threshold: Successes required to close circuit
            timeout_seconds: Time before attempting recovery; also the length
                of the window in which failures are counted
            window_size: Size of sliding window for counting failures
                (currently only stored, see class note)
        """
        self.endpoint_url = endpoint_url
        self.failure_threshold = failure_threshold
        self.success_threshold = success_threshold
        self.timeout_seconds = timeout_seconds
        self.window_size = window_size
        self.stats = CircuitBreakerStats()
        self._lock = asyncio.Lock()
        # time.monotonic() readings of recent failures (see _on_failure).
        self._failure_times: list[float] = []

    async def execute(self, func: Callable[..., Any], *args: Any, **kwargs: Any) -> Any:
        """Execute function with circuit breaker protection.

        Args:
            func: Async function to execute
            *args: Positional arguments for func
            **kwargs: Keyword arguments for func

        Returns:
            Result from func

        Raises:
            CircuitBreakerOpenError: If circuit is open
            Exception: Any exception raised by func
        """
        # The lock guards only the state check/transition. It must be released
        # before calling func because _on_success/_on_failure acquire it again
        # and asyncio.Lock is not reentrant.
        async with self._lock:
            if self.stats.state == CircuitState.OPEN:
                if not self._should_attempt_recovery():
                    raise CircuitBreakerOpenError(
                        f"Circuit OPEN for {self.endpoint_url}. "
                        f"Retry in {self._seconds_until_recovery()}s"
                    )
                self._transition_to_half_open()

        try:
            result = await func(*args, **kwargs)
        except Exception as e:
            await self._on_failure(e)
            raise

        # Recorded outside the try so that an error in the bookkeeping itself
        # is never counted as a failure of the endpoint.
        await self._on_success()
        return result

    async def _on_success(self) -> None:
        """Record successful request."""
        async with self._lock:
            self.stats.success_count += 1
            self.stats.total_requests += 1
            self.stats.last_success_at = datetime.now(timezone.utc)

            # Lazy %-formatting: the message is only built when DEBUG is on.
            logger.debug(
                "Circuit breaker %s: success %s/%s",
                self.endpoint_url,
                self.stats.success_count,
                self.success_threshold,
            )

            if self.stats.state == CircuitState.HALF_OPEN:
                # Close circuit after successes
                if self.stats.success_count >= self.success_threshold:
                    self._transition_to_closed()
            elif self.stats.state == CircuitState.CLOSED:
                # Reset failure count on success
                self.stats.failure_count = 0
                self._failure_times.clear()

    async def _on_failure(self, error: Exception) -> None:
        """Record failed request."""
        async with self._lock:
            self.stats.failure_count += 1
            self.stats.total_failures += 1
            self.stats.total_requests += 1
            self.stats.last_failure_at = datetime.now(timezone.utc)

            # Track failure times for sliding window. A monotonic clock is
            # used so that wall-clock adjustments cannot distort the window.
            now = time.monotonic()
            self._failure_times.append(now)

            # Keep only failures within window
            cutoff = now - self.timeout_seconds
            self._failure_times = [t for t in self._failure_times if t > cutoff]

            logger.debug(
                "Circuit breaker %s: failure %s/%s, error: %s",
                self.endpoint_url,
                self.stats.failure_count,
                self.failure_threshold,
                error,
            )

            if self.stats.state == CircuitState.HALF_OPEN:
                # Open circuit on first failure in half-open
                self._transition_to_open()
            elif self.stats.state == CircuitState.CLOSED:
                # Open circuit if threshold reached
                if len(self._failure_times) >= self.failure_threshold:
                    self._transition_to_open()

    def _transition_to_open(self) -> None:
        """Transition circuit to OPEN state."""
        if self.stats.state == CircuitState.OPEN:
            return  # Already open
        self.stats.state = CircuitState.OPEN
        self.stats.state_changed_at = datetime.now(timezone.utc)
        self.stats.success_count = 0
        logger.warning(
            f"Circuit breaker OPEN for {self.endpoint_url} "
            f"after {self.stats.failure_count} failures"
        )

    def _transition_to_half_open(self) -> None:
        """Transition circuit to HALF_OPEN state."""
        self.stats.state = CircuitState.HALF_OPEN
        self.stats.state_changed_at = datetime.now(timezone.utc)
        self.stats.failure_count = 0
        self.stats.success_count = 0
        logger.info(
            f"Circuit breaker HALF_OPEN for {self.endpoint_url}, testing recovery"
        )

    def _transition_to_closed(self) -> None:
        """Transition circuit to CLOSED state."""
        self.stats.state = CircuitState.CLOSED
        self.stats.state_changed_at = datetime.now(timezone.utc)
        self.stats.failure_count = 0
        self.stats.success_count = 0
        self._failure_times.clear()
        logger.info(f"Circuit breaker CLOSED for {self.endpoint_url}, recovered")

    def _should_attempt_recovery(self) -> bool:
        """Check if enough time has passed to attempt recovery."""
        # Without any recorded failure, recovery is never attempted.
        if not self.stats.last_failure_at:
            return False
        elapsed = datetime.now(timezone.utc) - self.stats.state_changed_at
        return elapsed.total_seconds() >= self.timeout_seconds

    def _seconds_until_recovery(self) -> int:
        """Get seconds until recovery can be attempted."""
        if not self.stats.state_changed_at:
            return self.timeout_seconds
        elapsed = datetime.now(timezone.utc) - self.stats.state_changed_at
        remaining = self.timeout_seconds - int(elapsed.total_seconds())
        return max(0, remaining)

    def get_state(self) -> CircuitState:
        """Get current circuit state."""
        return self.stats.state

    def get_stats(self) -> CircuitBreakerStats:
        """Get circuit breaker statistics (the live object, not a copy)."""
        return self.stats
|
|
208
|
+
|
|
209
|
+
|
|
210
|
+
class CircuitBreakerRegistry:
    """Keeps one EndpointCircuitBreaker per endpoint URL, created on demand."""

    def __init__(
        self,
        failure_threshold: int = 5,
        success_threshold: int = 2,
        timeout_seconds: int = 60,
    ):
        """Store the settings applied to every breaker this registry creates.

        Args:
            failure_threshold: Failures required to open circuit
            success_threshold: Successes required to close circuit
            timeout_seconds: Time before attempting recovery
        """
        self.failure_threshold = failure_threshold
        self.success_threshold = success_threshold
        self.timeout_seconds = timeout_seconds
        self._breakers: dict[str, EndpointCircuitBreaker] = {}
        self._lock = asyncio.Lock()

    def get_breaker(self, endpoint_url: str) -> EndpointCircuitBreaker:
        """Return the breaker for an endpoint, creating it on first use.

        Args:
            endpoint_url: URL of the endpoint

        Returns:
            EndpointCircuitBreaker instance
        """
        known = self._breakers.get(endpoint_url)
        if known is not None:
            return known

        created = EndpointCircuitBreaker(
            endpoint_url,
            failure_threshold=self.failure_threshold,
            success_threshold=self.success_threshold,
            timeout_seconds=self.timeout_seconds,
        )
        self._breakers[endpoint_url] = created
        return created

    def get_state(self, endpoint_url: str) -> CircuitState:
        """Return the circuit state for an endpoint.

        Goes through get_breaker, so a breaker is created for an endpoint
        that has not been seen before.

        Args:
            endpoint_url: URL of the endpoint

        Returns:
            Current circuit state
        """
        return self.get_breaker(endpoint_url).get_state()

    def get_all_stats(self) -> dict[str, CircuitBreakerStats]:
        """Collect the statistics of every registered breaker.

        Returns:
            Mapping of endpoint URLs to statistics
        """
        snapshot: dict[str, CircuitBreakerStats] = {}
        for url, breaker in self._breakers.items():
            snapshot[url] = breaker.get_stats()
        return snapshot
|
|
269
|
+
|
|
270
|
+
|
|
271
|
+
class CircuitBreakerOpenError(Exception):
    """Signals that a call was rejected because the circuit is open."""
|
|
@@ -0,0 +1,12 @@
|
|
|
1
|
+
"""Default values shared by the runtime module."""

# --- HTTP client ---
DEFAULT_REQUEST_TIMEOUT = 10  # seconds
DEFAULT_MAX_RETRIES = 3
DEFAULT_BACKOFF_BASE = 2  # presumably the base of the retry backoff; confirm at call sites

# --- Manifest cache ---
DEFAULT_CACHE_TTL = 300  # seconds

# --- Serialization ---
MAX_PAYLOAD_SIZE = 10 * 1024**2  # 10MB, expressed in bytes
|
|
@@ -0,0 +1,49 @@
|
|
|
1
|
+
"""Custom exceptions for cross-endpoint runtime."""
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
class FlashRuntimeError(Exception):
    """Common base class for errors raised by the cross-endpoint runtime."""
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class RemoteExecutionError(FlashRuntimeError):
    """Signals that executing a function remotely failed."""
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class SerializationError(FlashRuntimeError):
    """Signals that arguments could not be serialized or deserialized."""
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class GraphQLError(FlashRuntimeError):
    """Common base class for GraphQL-related errors."""
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
class GraphQLMutationError(GraphQLError):
    """Signals that a GraphQL mutation failed unexpectedly."""
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
class GraphQLQueryError(GraphQLError):
    """Signals that a GraphQL query failed unexpectedly."""
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
class ManifestError(FlashRuntimeError):
    """Signals a manifest that is missing, invalid, or unexpectedly structured."""
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
class ManifestServiceUnavailableError(FlashRuntimeError):
    """Signals that the manifest service is unavailable."""
|
|
@@ -0,0 +1,206 @@
|
|
|
1
|
+
"""Generic RunPod serverless handler factory for Flash."""
|
|
2
|
+
|
|
3
|
+
import json
|
|
4
|
+
import logging
|
|
5
|
+
import traceback
|
|
6
|
+
from pathlib import Path
|
|
7
|
+
from typing import Any, Callable, Dict
|
|
8
|
+
|
|
9
|
+
from .serialization import deserialize_args, deserialize_kwargs, serialize_arg
|
|
10
|
+
|
|
11
|
+
logger = logging.getLogger(__name__)
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def load_manifest(manifest_path: Path | None = None) -> Dict[str, Any]:
|
|
15
|
+
"""Load flash_manifest.json with fallback search.
|
|
16
|
+
|
|
17
|
+
Searches multiple locations for manifest:
|
|
18
|
+
1. Provided path (if given)
|
|
19
|
+
2. Current working directory
|
|
20
|
+
3. Module directory
|
|
21
|
+
4. Three levels up (legacy location)
|
|
22
|
+
|
|
23
|
+
Args:
|
|
24
|
+
manifest_path: Optional explicit path to manifest file
|
|
25
|
+
|
|
26
|
+
Returns:
|
|
27
|
+
Manifest dictionary, or empty dict if not found
|
|
28
|
+
"""
|
|
29
|
+
if manifest_path and manifest_path.exists():
|
|
30
|
+
try:
|
|
31
|
+
with open(manifest_path) as f:
|
|
32
|
+
return json.load(f)
|
|
33
|
+
except Exception as e:
|
|
34
|
+
logger.warning(f"Failed to load manifest from {manifest_path}: {e}")
|
|
35
|
+
return {"resources": {}, "function_registry": {}}
|
|
36
|
+
|
|
37
|
+
# Search multiple locations
|
|
38
|
+
search_paths = [
|
|
39
|
+
Path.cwd() / "flash_manifest.json",
|
|
40
|
+
Path(__file__).parent / "flash_manifest.json",
|
|
41
|
+
Path(__file__).parent.parent.parent / "flash_manifest.json",
|
|
42
|
+
]
|
|
43
|
+
|
|
44
|
+
for path in search_paths:
|
|
45
|
+
if path.exists():
|
|
46
|
+
try:
|
|
47
|
+
with open(path) as f:
|
|
48
|
+
return json.load(f)
|
|
49
|
+
except Exception as e:
|
|
50
|
+
logger.debug(f"Failed to load manifest from {path}: {e}")
|
|
51
|
+
continue
|
|
52
|
+
|
|
53
|
+
logger.warning("flash_manifest.json not found in any expected location")
|
|
54
|
+
return {"resources": {}, "function_registry": {}}
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
def deserialize_arguments(job_input: Dict[str, Any]) -> tuple[list, dict]:
    """Decode the call arguments carried by a job input.

    Args:
        job_input: Input dict from a job; its 'args' entry (default ``[]``)
            and 'kwargs' entry (default ``{}``) are handed to the
            serialization helpers

    Returns:
        Tuple of (deserialized args, deserialized kwargs)
    """
    positional = deserialize_args(job_input.get("args", []))
    keyword = deserialize_kwargs(job_input.get("kwargs", {}))
    return positional, keyword
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
def serialize_result(result: Any) -> str:
    """Encode a function's return value for the response.

    Delegates to ``serialize_arg``, so results use the same encoding as
    arguments.

    Args:
        result: Return value from function

    Returns:
        The serialized form produced by ``serialize_arg``
    """
    encoded = serialize_arg(result)
    return encoded
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
def execute_function(
    func_or_class: Callable,
    args: list,
    kwargs: dict,
    execution_type: str,
    job_input: Dict[str, Any],
) -> Any:
    """Run a registered function, or a method on a freshly built instance.

    Args:
        func_or_class: Function or class to execute
        args: Positional arguments (constructor arguments in class mode)
        kwargs: Keyword arguments (constructor arguments in class mode)
        execution_type: "class" selects class mode; any other value calls
            ``func_or_class`` directly
        job_input: Full job input; in class mode it supplies 'method_name'
            (default "__call__"), 'method_args' and 'method_kwargs'

    Returns:
        Result of execution

    Raises:
        Exception: Whatever the callable, constructor, or method raises
    """
    # Plain function: nothing to construct, just call it.
    if execution_type != "class":
        return func_or_class(*args, **kwargs)

    # Class mode: build the instance from the constructor arguments first.
    instance = func_or_class(*args, **kwargs)
    target_name = job_input.get("method_name", "__call__")
    bound_method = getattr(instance, target_name)

    # The method's own arguments arrive serialized under separate keys.
    serialized_call = {
        "args": job_input.get("method_args", []),
        "kwargs": job_input.get("method_kwargs", {}),
    }
    call_args, call_kwargs = deserialize_arguments(serialized_call)
    return bound_method(*call_args, **call_kwargs)
|
|
122
|
+
|
|
123
|
+
|
|
124
|
+
def create_handler(function_registry: Dict[str, Callable]) -> Callable:
    """Create a RunPod serverless handler with given function registry.

    This factory function creates a handler that:
    1. Deserializes function arguments from the job input
    2. Looks up function/class in registry by name
    3. Executes function or class method
    4. Serializes result for the response
    5. Returns RunPod-compatible response dict

    Args:
        function_registry: Dict mapping function names to function/class objects

    Returns:
        Handler function compatible with runpod.serverless.start()

    Example:
        ```python
        from tetra_rp.runtime.generic_handler import create_handler
        from workers.gpu import process_data, analyze_data

        registry = {
            "process_data": process_data,
            "analyze_data": analyze_data,
        }

        handler = create_handler(registry)

        if __name__ == "__main__":
            import runpod
            runpod.serverless.start({"handler": handler})
        ```
    """

    def handler(job: Dict[str, Any]) -> Dict[str, Any]:
        """RunPod serverless handler.

        Never raises for a failing job: lookup, deserialization, execution
        and serialization errors are all reported in the response dict.

        Args:
            job: RunPod job dict with 'input' key

        Returns:
            ``{"success": True, "result": ...}`` on success, otherwise
            ``{"success": False, "error": ..., "traceback": ...}``
        """
        # A job whose "input" is missing or explicitly null is treated as an
        # empty input, so it yields the structured "not found" error below
        # instead of an AttributeError escaping the handler.
        job_input = job.get("input") or {}
        function_name = job_input.get("function_name")
        execution_type = job_input.get("execution_type", "function")

        if function_name not in function_registry:
            return {
                "success": False,
                "error": f"Function '{function_name}' not found in registry. "
                f"Available: {list(function_registry.keys())}",
                "traceback": "",
            }

        try:
            # Deserialize arguments
            args, kwargs = deserialize_arguments(job_input)

            # Get function/class from registry
            func_or_class = function_registry[function_name]

            # Execute function or class
            result = execute_function(
                func_or_class, args, kwargs, execution_type, job_input
            )

            # Serialize result
            serialized_result = serialize_result(result)

            return {
                "success": True,
                "result": serialized_result,
            }

        except Exception as e:
            # Boundary of the worker: convert any failure into an error
            # response carrying the formatted traceback.
            return {
                "success": False,
                "error": str(e),
                "traceback": traceback.format_exc(),
            }

    return handler
|