tetra_rp-0.17.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of tetra-rp might be problematic.

Files changed (66)
  1. tetra_rp/__init__.py +43 -0
  2. tetra_rp/cli/__init__.py +0 -0
  3. tetra_rp/cli/commands/__init__.py +1 -0
  4. tetra_rp/cli/commands/build.py +534 -0
  5. tetra_rp/cli/commands/deploy.py +370 -0
  6. tetra_rp/cli/commands/init.py +119 -0
  7. tetra_rp/cli/commands/resource.py +191 -0
  8. tetra_rp/cli/commands/run.py +100 -0
  9. tetra_rp/cli/main.py +85 -0
  10. tetra_rp/cli/utils/__init__.py +1 -0
  11. tetra_rp/cli/utils/conda.py +127 -0
  12. tetra_rp/cli/utils/deployment.py +172 -0
  13. tetra_rp/cli/utils/ignore.py +139 -0
  14. tetra_rp/cli/utils/skeleton.py +184 -0
  15. tetra_rp/cli/utils/skeleton_template/.env.example +3 -0
  16. tetra_rp/cli/utils/skeleton_template/.flashignore +40 -0
  17. tetra_rp/cli/utils/skeleton_template/.gitignore +44 -0
  18. tetra_rp/cli/utils/skeleton_template/README.md +256 -0
  19. tetra_rp/cli/utils/skeleton_template/main.py +43 -0
  20. tetra_rp/cli/utils/skeleton_template/requirements.txt +1 -0
  21. tetra_rp/cli/utils/skeleton_template/workers/__init__.py +0 -0
  22. tetra_rp/cli/utils/skeleton_template/workers/cpu/__init__.py +20 -0
  23. tetra_rp/cli/utils/skeleton_template/workers/cpu/endpoint.py +38 -0
  24. tetra_rp/cli/utils/skeleton_template/workers/gpu/__init__.py +20 -0
  25. tetra_rp/cli/utils/skeleton_template/workers/gpu/endpoint.py +62 -0
  26. tetra_rp/client.py +128 -0
  27. tetra_rp/config.py +29 -0
  28. tetra_rp/core/__init__.py +0 -0
  29. tetra_rp/core/api/__init__.py +6 -0
  30. tetra_rp/core/api/runpod.py +319 -0
  31. tetra_rp/core/exceptions.py +50 -0
  32. tetra_rp/core/resources/__init__.py +37 -0
  33. tetra_rp/core/resources/base.py +47 -0
  34. tetra_rp/core/resources/cloud.py +4 -0
  35. tetra_rp/core/resources/constants.py +4 -0
  36. tetra_rp/core/resources/cpu.py +146 -0
  37. tetra_rp/core/resources/environment.py +41 -0
  38. tetra_rp/core/resources/gpu.py +68 -0
  39. tetra_rp/core/resources/live_serverless.py +62 -0
  40. tetra_rp/core/resources/network_volume.py +148 -0
  41. tetra_rp/core/resources/resource_manager.py +145 -0
  42. tetra_rp/core/resources/serverless.py +463 -0
  43. tetra_rp/core/resources/serverless_cpu.py +162 -0
  44. tetra_rp/core/resources/template.py +94 -0
  45. tetra_rp/core/resources/utils.py +50 -0
  46. tetra_rp/core/utils/__init__.py +0 -0
  47. tetra_rp/core/utils/backoff.py +43 -0
  48. tetra_rp/core/utils/constants.py +10 -0
  49. tetra_rp/core/utils/file_lock.py +260 -0
  50. tetra_rp/core/utils/json.py +33 -0
  51. tetra_rp/core/utils/lru_cache.py +75 -0
  52. tetra_rp/core/utils/singleton.py +21 -0
  53. tetra_rp/core/validation.py +44 -0
  54. tetra_rp/execute_class.py +319 -0
  55. tetra_rp/logger.py +34 -0
  56. tetra_rp/protos/__init__.py +0 -0
  57. tetra_rp/protos/remote_execution.py +148 -0
  58. tetra_rp/stubs/__init__.py +5 -0
  59. tetra_rp/stubs/live_serverless.py +155 -0
  60. tetra_rp/stubs/registry.py +117 -0
  61. tetra_rp/stubs/serverless.py +30 -0
  62. tetra_rp-0.17.1.dist-info/METADATA +976 -0
  63. tetra_rp-0.17.1.dist-info/RECORD +66 -0
  64. tetra_rp-0.17.1.dist-info/WHEEL +5 -0
  65. tetra_rp-0.17.1.dist-info/entry_points.txt +2 -0
  66. tetra_rp-0.17.1.dist-info/top_level.txt +1 -0
tetra_rp/execute_class.py ADDED
@@ -0,0 +1,319 @@
+ """
+ Class execution module for remote class instantiation and method calls.
+
+ This module provides functionality to create and execute remote class instances,
+ with automatic caching of class serialization data to improve performance and
+ prevent memory leaks through LRU eviction.
+ """
+
+ import base64
+ import hashlib
+ import inspect
+ import logging
+ import textwrap
+ import uuid
+ from typing import List, Optional, Type
+
+ import cloudpickle
+
+ from .core.resources import ResourceManager, ServerlessResource
+ from .core.utils.constants import HASH_TRUNCATE_LENGTH, UUID_FALLBACK_LENGTH
+ from .core.utils.lru_cache import LRUCache
+ from .protos.remote_execution import FunctionRequest
+ from .stubs import stub_resource
+
+ log = logging.getLogger(__name__)
+
+ # Global in-memory cache for serialized class data with LRU eviction
+ _SERIALIZED_CLASS_CACHE = LRUCache(max_size=1000)
+
+
+ def serialize_constructor_args(args, kwargs):
+     """Serialize constructor arguments for caching."""
+     serialized_args = [
+         base64.b64encode(cloudpickle.dumps(arg)).decode("utf-8") for arg in args
+     ]
+     serialized_kwargs = {
+         k: base64.b64encode(cloudpickle.dumps(v)).decode("utf-8")
+         for k, v in kwargs.items()
+     }
+     return serialized_args, serialized_kwargs
+
+
+ def get_or_cache_class_data(
+     cls: Type, args: tuple, kwargs: dict, cache_key: str
+ ) -> str:
+     """Get class code from cache or extract and cache it."""
+     if cache_key not in _SERIALIZED_CLASS_CACHE:
+         # Cache miss - extract and cache class code
+         clean_class_code = extract_class_code_simple(cls)
+
+         try:
+             serialized_args, serialized_kwargs = serialize_constructor_args(
+                 args, kwargs
+             )
+
+             # Cache the serialized data
+             _SERIALIZED_CLASS_CACHE.set(
+                 cache_key,
+                 {
+                     "class_code": clean_class_code,
+                     "constructor_args": serialized_args,
+                     "constructor_kwargs": serialized_kwargs,
+                 },
+             )
+
+             log.debug(f"Cached class data for {cls.__name__} with key: {cache_key}")
+
+         except (TypeError, AttributeError, OSError) as e:
+             log.warning(
+                 f"Could not serialize constructor arguments for {cls.__name__}: {e}"
+             )
+             log.warning(
+                 f"Skipping constructor argument caching for {cls.__name__} due to unserializable arguments"
+             )
+
+             # Store minimal cache entry to avoid repeated attempts
+             _SERIALIZED_CLASS_CACHE.set(
+                 cache_key,
+                 {
+                     "class_code": clean_class_code,
+                     "constructor_args": None,  # Signal that args couldn't be cached
+                     "constructor_kwargs": None,
+                 },
+             )
+
+         return clean_class_code
+     else:
+         # Cache hit - retrieve cached data
+         cached_data = _SERIALIZED_CLASS_CACHE.get(cache_key)
+         log.debug(
+             f"Retrieved cached class data for {cls.__name__} with key: {cache_key}"
+         )
+         return cached_data["class_code"]
+
+
+ def extract_class_code_simple(cls: Type) -> str:
+     """Extract clean class code without decorators and proper indentation"""
+     try:
+         # Get source code
+         source = inspect.getsource(cls)
+
+         # Split into lines
+         lines = source.split("\n")
+
+         # Find the class definition line (starts with 'class' and contains ':')
+         class_start_idx = -1
+         for i, line in enumerate(lines):
+             stripped = line.strip()
+             if stripped.startswith("class ") and ":" in stripped:
+                 class_start_idx = i
+                 break
+
+         if class_start_idx == -1:
+             raise ValueError("Could not find class definition")
+
+         # Take lines from class definition onwards (ignore everything before)
+         class_lines = lines[class_start_idx:]
+
+         # Remove empty lines at the end
+         while class_lines and not class_lines[-1].strip():
+             class_lines.pop()
+
+         # Join back and dedent to remove any leading indentation
+         class_code = "\n".join(class_lines)
+         class_code = textwrap.dedent(class_code)
+
+         # Validate the code by trying to compile it
+         compile(class_code, "<string>", "exec")
+
+         log.debug(f"Successfully extracted class code for {cls.__name__}")
+         return class_code
+
+     except Exception as e:
+         log.warning(f"Could not extract class code for {cls.__name__}: {e}")
+         log.warning("Falling back to basic class structure")
+
+         # Enhanced fallback: try to preserve method signatures
+         fallback_methods = []
+         for name, method in inspect.getmembers(cls, predicate=inspect.isfunction):
+             try:
+                 sig = inspect.signature(method)
+                 fallback_methods.append(f"    def {name}{sig}:")
+                 fallback_methods.append("        pass")
+                 fallback_methods.append("")
+             except (TypeError, ValueError, OSError) as e:
+                 log.warning(f"Could not extract method signature for {name}: {e}")
+                 fallback_methods.append(f"    def {name}(self, *args, **kwargs):")
+                 fallback_methods.append("        pass")
+                 fallback_methods.append("")
+
+         fallback_code = f"""class {cls.__name__}:
+     def __init__(self, *args, **kwargs):
+         pass
+
+ {chr(10).join(fallback_methods)}"""
+
+         return fallback_code
+
+
+ def get_class_cache_key(
+     cls: Type, constructor_args: tuple, constructor_kwargs: dict
+ ) -> str:
+     """Generate a cache key for class serialization based on class source and constructor args.
+
+     Args:
+         cls: The class type to generate a key for
+         constructor_args: Positional arguments passed to class constructor
+         constructor_kwargs: Keyword arguments passed to class constructor
+
+     Returns:
+         A unique cache key string, or a UUID-based fallback if serialization fails
+
+     Note:
+         Falls back to UUID-based key if constructor arguments cannot be serialized,
+         which disables caching benefits but maintains functionality.
+     """
+     try:
+         # Get class source code for hashing
+         class_source = extract_class_code_simple(cls)
+
+         # Create hash of class source
+         class_hash = hashlib.sha256(class_source.encode()).hexdigest()
+
+         # Create hash of constructor arguments
+         args_data = cloudpickle.dumps((constructor_args, constructor_kwargs))
+         args_hash = hashlib.sha256(args_data).hexdigest()
+
+         # Combine hashes for final cache key
+         cache_key = f"{cls.__name__}_{class_hash[:HASH_TRUNCATE_LENGTH]}_{args_hash[:HASH_TRUNCATE_LENGTH]}"
+
+         log.debug(f"Generated cache key for {cls.__name__}: {cache_key}")
+         return cache_key
+
+     except (TypeError, AttributeError, OSError) as e:
+         log.warning(f"Could not generate cache key for {cls.__name__}: {e}")
+         # Fallback to basic key without caching benefits
+         return f"{cls.__name__}_{uuid.uuid4().hex[:UUID_FALLBACK_LENGTH]}"
+
+
+ def create_remote_class(
+     cls: Type,
+     resource_config: ServerlessResource,
+     dependencies: Optional[List[str]],
+     system_dependencies: Optional[List[str]],
+     accelerate_downloads: bool,
+     extra: dict,
+ ):
+     """
+     Create a remote class wrapper.
+     """
+     # Validate inputs
+     if not inspect.isclass(cls):
+         raise TypeError(f"Expected a class, got {type(cls).__name__}")
+     if not hasattr(cls, "__name__"):
+         raise ValueError("Class must have a __name__ attribute")
+
+     class RemoteClassWrapper:
+         def __init__(self, *args, **kwargs):
+             self._class_type = cls
+             self._resource_config = resource_config
+             self._dependencies = dependencies or []
+             self._system_dependencies = system_dependencies or []
+             self._accelerate_downloads = accelerate_downloads
+             self._extra = extra
+             self._constructor_args = args
+             self._constructor_kwargs = kwargs
+             self._instance_id = (
+                 f"{cls.__name__}_{uuid.uuid4().hex[:UUID_FALLBACK_LENGTH]}"
+             )
+             self._initialized = False
+
+             # Generate cache key and get class code
+             self._cache_key = get_class_cache_key(cls, args, kwargs)
+             self._clean_class_code = get_or_cache_class_data(
+                 cls, args, kwargs, self._cache_key
+             )
+
+             log.debug(f"Created remote class wrapper for {cls.__name__}")
+
+         async def _ensure_initialized(self):
+             """Ensure the remote instance is created."""
+             if self._initialized:
+                 return
+
+             # Get remote resource
+             resource_manager = ResourceManager()
+             remote_resource = await resource_manager.get_or_deploy_resource(
+                 self._resource_config
+             )
+             self._stub = stub_resource(remote_resource, **self._extra)
+
+             # Create the remote instance by calling a method (which will trigger instance creation)
+             # We'll do this on first method call
+             self._initialized = True
+
+         def __getattr__(self, name):
+             """Dynamically create method proxies for all class methods."""
+             if name.startswith("_"):
+                 raise AttributeError(
+                     f"'{self.__class__.__name__}' object has no attribute '{name}'"
+                 )
+
+             async def method_proxy(*args, **kwargs):
+                 await self._ensure_initialized()
+
+                 # Get cached data
+                 cached_data = _SERIALIZED_CLASS_CACHE.get(self._cache_key)
+
+                 # Serialize method arguments (these change per call, so no caching)
+                 method_args = [
+                     base64.b64encode(cloudpickle.dumps(arg)).decode("utf-8")
+                     for arg in args
+                 ]
+                 method_kwargs = {
+                     k: base64.b64encode(cloudpickle.dumps(v)).decode("utf-8")
+                     for k, v in kwargs.items()
+                 }
+
+                 # Handle constructor args - use cached if available, else serialize fresh
+                 if cached_data["constructor_args"] is not None:
+                     # Use cached constructor args
+                     constructor_args = cached_data["constructor_args"]
+                     constructor_kwargs = cached_data["constructor_kwargs"]
+                 else:
+                     # Constructor args couldn't be cached due to serialization issues
+                     # Serialize them fresh for each method call (fallback behavior)
+                     constructor_args = [
+                         base64.b64encode(cloudpickle.dumps(arg)).decode("utf-8")
+                         for arg in self._constructor_args
+                     ]
+                     constructor_kwargs = {
+                         k: base64.b64encode(cloudpickle.dumps(v)).decode("utf-8")
+                         for k, v in self._constructor_kwargs.items()
+                     }
+
+                 request = FunctionRequest(
+                     execution_type="class",
+                     class_name=self._class_type.__name__,
+                     class_code=cached_data["class_code"],
+                     method_name=name,
+                     args=method_args,
+                     kwargs=method_kwargs,
+                     constructor_args=constructor_args,
+                     constructor_kwargs=constructor_kwargs,
+                     dependencies=self._dependencies,
+                     system_dependencies=self._system_dependencies,
+                     accelerate_downloads=self._accelerate_downloads,
+                     instance_id=self._instance_id,
+                     create_new_instance=not hasattr(
+                         self, "_stub"
+                     ),  # Create new only on first call
+                 )
+
+                 # Execute via stub
+                 return await self._stub.execute_class_method(request)  # type: ignore
+
+             return method_proxy
+
+     return RemoteClassWrapper
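
To make the flow above concrete, here is a minimal, illustrative usage sketch (not part of the diff). `create_remote_class`, `LiveServerless`, and the wrapper behavior come from the code above; the `Counter` class and the resource fields are assumptions for demonstration only.

import asyncio

from tetra_rp.core.resources import LiveServerless
from tetra_rp.execute_class import create_remote_class


class Counter:
    def __init__(self, start: int = 0):
        self.value = start

    def add(self, n: int) -> int:
        self.value += n
        return self.value


# Hypothetical resource config; the real ServerlessResource fields may differ.
RemoteCounter = create_remote_class(
    Counter,
    resource_config=LiveServerless(name="counter-demo"),
    dependencies=None,
    system_dependencies=None,
    accelerate_downloads=True,
    extra={},
)


async def main():
    counter = RemoteCounter(start=10)  # class code + constructor args cached under one key
    print(await counter.add(5))  # first call deploys the resource, then proxies add()


asyncio.run(main())

Note that the cache key hashes both the class source and the constructor arguments, so two wrappers of the same class with different arguments occupy separate LRU entries.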
tetra_rp/logger.py ADDED
@@ -0,0 +1,34 @@
+ import logging
+ import os
+ import sys
+ from typing import Union, Optional
+
+
+ def setup_logging(
+     level: Union[int, str] = logging.INFO, stream=sys.stdout, fmt: Optional[str] = None
+ ):
+     """
+     Sets up the root logger with a stream handler and basic formatting.
+     Does nothing if handlers are already configured.
+     """
+     if isinstance(level, str):
+         level = getattr(logging, level.upper(), logging.INFO)
+
+     if fmt is None:
+         if level == logging.DEBUG:
+             fmt = "%(asctime)s | %(levelname)-5s | %(name)s | %(filename)s:%(lineno)d | %(message)s"
+         else:
+             # Default format for INFO level and above
+             fmt = "%(asctime)s | %(levelname)-5s | %(message)s"
+
+     root_logger = logging.getLogger()
+     if not root_logger.hasHandlers():
+         handler = logging.StreamHandler(stream)
+         handler.setFormatter(logging.Formatter(fmt))
+         root_logger.setLevel(level)
+         root_logger.addHandler(handler)
+
+     # Optionally allow log level override via env var
+     env_level = os.environ.get("LOG_LEVEL")
+     if env_level:
+         root_logger.setLevel(env_level.upper())
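
A short, illustrative sketch of how `setup_logging` behaves (not part of the diff); the module and function names are from the code above, the rest is assumed:

import logging
import os

from tetra_rp.logger import setup_logging

# First call with no handlers installed: adds a stream handler at DEBUG,
# which also selects the verbose filename:lineno format.
setup_logging(level="debug")
logging.getLogger("demo").debug("visible at DEBUG")

# LOG_LEVEL overrides the argument on every call, even when handlers exist.
os.environ["LOG_LEVEL"] = "WARNING"
setup_logging(level="debug")  # handler already present; only the env override runs
logging.getLogger("demo").info("suppressed: root level is now WARNING")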
tetra_rp/protos/__init__.py ADDED
File without changes
tetra_rp/protos/remote_execution.py ADDED
@@ -0,0 +1,148 @@
+ """Remote execution protocol definitions using Pydantic models.
+
+ This module defines the request/response protocol for remote function and class execution.
+ The models align with the protobuf schema for communication with remote workers.
+ """
+
+ from abc import ABC, abstractmethod
+ from typing import Any, Dict, List, Optional
+
+ from pydantic import BaseModel, Field, model_validator
+
+
+ class FunctionRequest(BaseModel):
+     """Request model for remote function or class execution.
+
+     Supports both function-based execution and class instantiation with method calls.
+     All serialized data (args, kwargs, etc.) are base64-encoded cloudpickle strings.
+     """
+
+     # MADE OPTIONAL - can be None for class-only execution
+     function_name: Optional[str] = Field(
+         default=None,
+         description="Name of the function to execute",
+     )
+     function_code: Optional[str] = Field(
+         default=None,
+         description="Source code of the function to execute",
+     )
+     args: List[str] = Field(
+         default_factory=list,
+         description="List of base64-encoded cloudpickle-serialized arguments",
+     )
+     kwargs: Dict[str, str] = Field(
+         default_factory=dict,
+         description="Dictionary of base64-encoded cloudpickle-serialized keyword arguments",
+     )
+     dependencies: Optional[List[str]] = Field(
+         default=None,
+         description="Optional list of pip packages to install before executing the function",
+     )
+     system_dependencies: Optional[List[str]] = Field(
+         default=None,
+         description="Optional list of system dependencies to install before executing the function",
+     )
+
+     # NEW FIELDS FOR CLASS SUPPORT
+     execution_type: str = Field(
+         default="function", description="Type of execution: 'function' or 'class'"
+     )
+     class_name: Optional[str] = Field(
+         default=None,
+         description="Name of the class to instantiate (for class execution)",
+     )
+     class_code: Optional[str] = Field(
+         default=None,
+         description="Source code of the class to instantiate (for class execution)",
+     )
+     constructor_args: List[str] = Field(
+         default_factory=list,
+         description="List of base64-encoded cloudpickle-serialized constructor arguments",
+     )
+     constructor_kwargs: Dict[str, str] = Field(
+         default_factory=dict,
+         description="Dictionary of base64-encoded cloudpickle-serialized constructor keyword arguments",
+     )
+     method_name: str = Field(
+         default="__call__",
+         description="Name of the method to call on the class instance",
+     )
+     instance_id: Optional[str] = Field(
+         default=None,
+         description="Unique identifier for the class instance (for persistence)",
+     )
+     create_new_instance: bool = Field(
+         default=True,
+         description="Whether to create a new instance or reuse existing one",
+     )
+
+     # Download acceleration fields
+     accelerate_downloads: bool = Field(
+         default=True,
+         description="Enable download acceleration for dependencies and models",
+     )
+
+     @model_validator(mode="after")
+     def validate_execution_requirements(self) -> "FunctionRequest":
+         """Validate that required fields are provided based on execution_type"""
+         if self.execution_type == "function":
+             if self.function_name is None:
+                 raise ValueError(
+                     'function_name is required when execution_type is "function"'
+                 )
+             if self.function_code is None:
+                 raise ValueError(
+                     'function_code is required when execution_type is "function"'
+                 )
+
+         elif self.execution_type == "class":
+             if self.class_name is None:
+                 raise ValueError(
+                     'class_name is required when execution_type is "class"'
+                 )
+             if self.class_code is None:
+                 raise ValueError(
+                     'class_code is required when execution_type is "class"'
+                 )
+
+         return self
+
+
+ class FunctionResponse(BaseModel):
+     """Response model for remote function or class execution results.
+
+     Contains execution results, error information, and metadata about class instances
+     when applicable. The result field contains base64-encoded cloudpickle data.
+     """
+
+     success: bool = Field(
+         description="Indicates if the function execution was successful",
+     )
+     result: Optional[str] = Field(
+         default=None,
+         description="Base64-encoded cloudpickle-serialized result of the function",
+     )
+     error: Optional[str] = Field(
+         default=None,
+         description="Error message if the function execution failed",
+     )
+     stdout: Optional[str] = Field(
+         default=None,
+         description="Captured standard output from the function execution",
+     )
+     instance_id: Optional[str] = Field(
+         default=None, description="ID of the class instance that was used/created"
+     )
+     instance_info: Optional[Dict[str, Any]] = Field(
+         default=None,
+         description="Metadata about the class instance (creation time, call count, etc.)",
+     )
+
+
+ class RemoteExecutorStub(ABC):
+     """Abstract base class for remote execution."""
+
+     @abstractmethod
+     async def ExecuteFunction(self, request: FunctionRequest) -> FunctionResponse:
+         """Execute a function on the remote resource."""
+         raise NotImplementedError("Subclasses should implement this method.")
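
An illustrative check of the validator above (not part of the diff); it uses only the `FunctionRequest` fields defined in this module:

import base64

import cloudpickle
from pydantic import ValidationError

from tetra_rp.protos.remote_execution import FunctionRequest

# A well-formed function request: source code plus pickled, base64-encoded args.
req = FunctionRequest(
    function_name="add",
    function_code="def add(a, b):\n    return a + b\n",
    args=[base64.b64encode(cloudpickle.dumps(v)).decode("utf-8") for v in (2, 3)],
)
assert req.execution_type == "function"
assert req.method_name == "__call__"  # default, used for class execution

# A class request without class_code fails the model_validator.
try:
    FunctionRequest(execution_type="class", class_name="Counter")
except ValidationError as exc:
    print(exc.errors()[0]["msg"])  # Value error: class_code is required when execution_type is "class"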
tetra_rp/stubs/__init__.py ADDED
@@ -0,0 +1,5 @@
+ from .registry import stub_resource
+
+ __all__ = [
+     "stub_resource",
+ ]
tetra_rp/stubs/live_serverless.py ADDED
@@ -0,0 +1,155 @@
+ import ast
+ import base64
+ import inspect
+ import textwrap
+ import hashlib
+ import traceback
+ import threading
+ import cloudpickle
+ import logging
+ from ..core.resources import LiveServerless
+ from ..protos.remote_execution import (
+     FunctionRequest,
+     FunctionResponse,
+     RemoteExecutorStub,
+ )
+
+ log = logging.getLogger(__name__)
+
+
+ # Global in-memory cache with thread safety
+ _SERIALIZED_FUNCTION_CACHE = {}
+ _function_cache_lock = threading.RLock()
+
+
+ def get_function_source(func):
+     """Extract the function source code without the decorator."""
+     # Unwrap any decorators to get the original function
+     func = inspect.unwrap(func)
+
+     # Get the source code of the decorated function
+     source = inspect.getsource(func)
+
+     # Dedent the source to handle functions defined in classes or indented contexts
+     source = textwrap.dedent(source)
+
+     # Parse the source code
+     module = ast.parse(source)
+
+     # Find the function definition node (both sync and async)
+     function_def = None
+     for node in ast.walk(module):
+         if (
+             isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef))
+             and node.name == func.__name__
+         ):
+             function_def = node
+             break
+
+     if not function_def:
+         raise ValueError(f"Could not find function definition for {func.__name__}")
+
+     # Get the line and column offsets
+     lineno = function_def.lineno - 1  # Line numbers are 1-based
+
+     # Split into lines and extract just the function part
+     lines = source.split("\n")
+     function_lines = lines[lineno:]
+
+     # Dedent to remove any extra indentation
+     function_source = textwrap.dedent("\n".join(function_lines))
+
+     # Return the function hash for cache key
+     source_hash = hashlib.sha256(function_source.encode("utf-8")).hexdigest()
+
+     return function_source, source_hash
+
+
+ class LiveServerlessStub(RemoteExecutorStub):
+     """Adapter class to make Runpod endpoints look like gRPC stubs."""
+
+     def __init__(self, server: LiveServerless):
+         self.server = server
+
+     def prepare_request(
+         self,
+         func,
+         dependencies,
+         system_dependencies,
+         accelerate_downloads,
+         *args,
+         **kwargs,
+     ):
+         source, src_hash = get_function_source(func)
+
+         request = {
+             "function_name": func.__name__,
+             "dependencies": dependencies,
+             "system_dependencies": system_dependencies,
+             "accelerate_downloads": accelerate_downloads,
+         }
+
+         # Thread-safe cache access
+         with _function_cache_lock:
+             # check if the function is already cached
+             if src_hash not in _SERIALIZED_FUNCTION_CACHE:
+                 # Cache the serialized function
+                 _SERIALIZED_FUNCTION_CACHE[src_hash] = source
+
+             request["function_code"] = _SERIALIZED_FUNCTION_CACHE[src_hash]
+
+         # Serialize arguments using cloudpickle
+         if args:
+             request["args"] = [
+                 base64.b64encode(cloudpickle.dumps(arg)).decode("utf-8") for arg in args
+             ]
+         if kwargs:
+             request["kwargs"] = {
+                 k: base64.b64encode(cloudpickle.dumps(v)).decode("utf-8")
+                 for k, v in kwargs.items()
+             }
+
+         return FunctionRequest(**request)
+
+     def handle_response(self, response: FunctionResponse):
+         if not (response.success or response.error):
+             raise ValueError("Invalid response from server")
+
+         if response.stdout:
+             for line in response.stdout.splitlines():
+                 print(line)
+
+         if response.success:
+             if response.result is None:
+                 raise ValueError("Response result is None")
+             return cloudpickle.loads(base64.b64decode(response.result))
+         else:
+             raise Exception(f"Remote execution failed: {response.error}")
+
+     async def ExecuteFunction(
+         self, request: FunctionRequest, sync: bool = False
+     ) -> FunctionResponse:
+         try:
+             # Convert the gRPC request to Runpod format
+             payload = request.model_dump(exclude_none=True)
+
+             if sync:
+                 job = await self.server.run_sync(payload)
+             else:
+                 job = await self.server.run(payload)
+
+             if job.error:
+                 return FunctionResponse(
+                     success=False,
+                     error=job.error,
+                     stdout=job.output.get("stdout", ""),
+                 )
+
+             return FunctionResponse(**job.output)
+
+         except Exception as e:
+             error_traceback = traceback.format_exc()
+             return FunctionResponse(
+                 success=False,
+                 error=f"{str(e)}\n{error_traceback}",
+             )
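
To illustrate `get_function_source` (not part of the diff): the decorator line is dropped because extraction starts at the `def` node's line, and `inspect.unwrap` only sees through decorators that set `__wrapped__` (e.g. via `functools.wraps`). The decorator and function below are assumed for demonstration; run this from a .py file so `inspect.getsource` can locate the source.

import functools

from tetra_rp.stubs.live_serverless import get_function_source


def my_decorator(fn):
    @functools.wraps(fn)  # sets __wrapped__, which inspect.unwrap() follows
    def inner(*args, **kwargs):
        return fn(*args, **kwargs)

    return inner


@my_decorator
def greet(name: str) -> str:
    return f"hello {name}"


source, src_hash = get_function_source(greet)
print(source.splitlines()[0])  # def greet(name: str) -> str:
print(src_hash[:12])  # stable sha256 prefix, the _SERIALIZED_FUNCTION_CACHE key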