tetra-rp 0.6.0__tar.gz → 0.8.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (46)
  1. {tetra_rp-0.6.0 → tetra_rp-0.8.0}/PKG-INFO +1 -1
  2. {tetra_rp-0.6.0 → tetra_rp-0.8.0}/pyproject.toml +1 -1
  3. {tetra_rp-0.6.0 → tetra_rp-0.8.0}/src/tetra_rp/client.py +24 -25
  4. {tetra_rp-0.6.0 → tetra_rp-0.8.0}/src/tetra_rp/core/api/runpod.py +9 -27
  5. {tetra_rp-0.6.0 → tetra_rp-0.8.0}/src/tetra_rp/core/resources/network_volume.py +20 -11
  6. {tetra_rp-0.6.0 → tetra_rp-0.8.0}/src/tetra_rp/core/resources/serverless.py +42 -10
  7. tetra_rp-0.8.0/src/tetra_rp/execute_class.py +178 -0
  8. tetra_rp-0.8.0/src/tetra_rp/protos/remote_execution.py +128 -0
  9. {tetra_rp-0.6.0 → tetra_rp-0.8.0}/src/tetra_rp/stubs/registry.py +14 -5
  10. {tetra_rp-0.6.0 → tetra_rp-0.8.0}/src/tetra_rp.egg-info/PKG-INFO +1 -1
  11. {tetra_rp-0.6.0 → tetra_rp-0.8.0}/src/tetra_rp.egg-info/SOURCES.txt +1 -0
  12. tetra_rp-0.6.0/src/tetra_rp/protos/remote_execution.py +0 -57
  13. {tetra_rp-0.6.0 → tetra_rp-0.8.0}/README.md +0 -0
  14. {tetra_rp-0.6.0 → tetra_rp-0.8.0}/setup.cfg +0 -0
  15. {tetra_rp-0.6.0 → tetra_rp-0.8.0}/src/tetra_rp/__init__.py +0 -0
  16. {tetra_rp-0.6.0 → tetra_rp-0.8.0}/src/tetra_rp/core/__init__.py +0 -0
  17. {tetra_rp-0.6.0 → tetra_rp-0.8.0}/src/tetra_rp/core/api/__init__.py +0 -0
  18. {tetra_rp-0.6.0 → tetra_rp-0.8.0}/src/tetra_rp/core/pool/__init__.py +0 -0
  19. {tetra_rp-0.6.0 → tetra_rp-0.8.0}/src/tetra_rp/core/pool/cluster_manager.py +0 -0
  20. {tetra_rp-0.6.0 → tetra_rp-0.8.0}/src/tetra_rp/core/pool/dataclass.py +0 -0
  21. {tetra_rp-0.6.0 → tetra_rp-0.8.0}/src/tetra_rp/core/pool/ex.py +0 -0
  22. {tetra_rp-0.6.0 → tetra_rp-0.8.0}/src/tetra_rp/core/pool/job.py +0 -0
  23. {tetra_rp-0.6.0 → tetra_rp-0.8.0}/src/tetra_rp/core/pool/worker.py +0 -0
  24. {tetra_rp-0.6.0 → tetra_rp-0.8.0}/src/tetra_rp/core/resources/__init__.py +0 -0
  25. {tetra_rp-0.6.0 → tetra_rp-0.8.0}/src/tetra_rp/core/resources/base.py +0 -0
  26. {tetra_rp-0.6.0 → tetra_rp-0.8.0}/src/tetra_rp/core/resources/cloud.py +0 -0
  27. {tetra_rp-0.6.0 → tetra_rp-0.8.0}/src/tetra_rp/core/resources/constants.py +0 -0
  28. {tetra_rp-0.6.0 → tetra_rp-0.8.0}/src/tetra_rp/core/resources/cpu.py +0 -0
  29. {tetra_rp-0.6.0 → tetra_rp-0.8.0}/src/tetra_rp/core/resources/environment.py +0 -0
  30. {tetra_rp-0.6.0 → tetra_rp-0.8.0}/src/tetra_rp/core/resources/gpu.py +0 -0
  31. {tetra_rp-0.6.0 → tetra_rp-0.8.0}/src/tetra_rp/core/resources/live_serverless.py +0 -0
  32. {tetra_rp-0.6.0 → tetra_rp-0.8.0}/src/tetra_rp/core/resources/resource_manager.py +0 -0
  33. {tetra_rp-0.6.0 → tetra_rp-0.8.0}/src/tetra_rp/core/resources/template.py +0 -0
  34. {tetra_rp-0.6.0 → tetra_rp-0.8.0}/src/tetra_rp/core/resources/utils.py +0 -0
  35. {tetra_rp-0.6.0 → tetra_rp-0.8.0}/src/tetra_rp/core/utils/__init__.py +0 -0
  36. {tetra_rp-0.6.0 → tetra_rp-0.8.0}/src/tetra_rp/core/utils/backoff.py +0 -0
  37. {tetra_rp-0.6.0 → tetra_rp-0.8.0}/src/tetra_rp/core/utils/json.py +0 -0
  38. {tetra_rp-0.6.0 → tetra_rp-0.8.0}/src/tetra_rp/core/utils/singleton.py +0 -0
  39. {tetra_rp-0.6.0 → tetra_rp-0.8.0}/src/tetra_rp/logger.py +0 -0
  40. {tetra_rp-0.6.0 → tetra_rp-0.8.0}/src/tetra_rp/protos/__init__.py +0 -0
  41. {tetra_rp-0.6.0 → tetra_rp-0.8.0}/src/tetra_rp/stubs/__init__.py +0 -0
  42. {tetra_rp-0.6.0 → tetra_rp-0.8.0}/src/tetra_rp/stubs/live_serverless.py +0 -0
  43. {tetra_rp-0.6.0 → tetra_rp-0.8.0}/src/tetra_rp/stubs/serverless.py +0 -0
  44. {tetra_rp-0.6.0 → tetra_rp-0.8.0}/src/tetra_rp.egg-info/dependency_links.txt +0 -0
  45. {tetra_rp-0.6.0 → tetra_rp-0.8.0}/src/tetra_rp.egg-info/requires.txt +0 -0
  46. {tetra_rp-0.6.0 → tetra_rp-0.8.0}/src/tetra_rp.egg-info/top_level.txt +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: tetra_rp
3
- Version: 0.6.0
3
+ Version: 0.8.0
4
4
  Summary: A Python library for distributed inference and serving of machine learning models
5
5
  Author-email: Marut Pandya <pandyamarut@gmail.com>, Patrick Rachford <prachford@icloud.com>, Dean Quinanola <dean.quinanola@runpod.io>
6
6
  License: MIT
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "tetra_rp"
3
- version = "0.6.0"
3
+ version = "0.8.0"
4
4
  description = "A Python library for distributed inference and serving of machine learning models"
5
5
  authors = [
6
6
  { name = "Marut Pandya", email = "pandyamarut@gmail.com" },
@@ -1,18 +1,19 @@
1
+ import inspect
1
2
  import logging
2
3
  from functools import wraps
3
4
  from typing import List, Optional
4
- from .core.resources import ServerlessResource, ResourceManager, NetworkVolume
5
- from .stubs import stub_resource
6
5
 
6
+ from .core.resources import ResourceManager, ServerlessResource
7
+ from .execute_class import create_remote_class
8
+ from .stubs import stub_resource
7
9
 
8
10
  log = logging.getLogger(__name__)
9
11
 
10
12
 
11
13
  def remote(
12
14
  resource_config: ServerlessResource,
13
- dependencies: List[str] = None,
14
- system_dependencies: List[str] = None,
15
- mount_volume: Optional[NetworkVolume] = None,
15
+ dependencies: Optional[List[str]] = None,
16
+ system_dependencies: Optional[List[str]] = None,
16
17
  **extra,
17
18
  ):
18
19
  """
@@ -25,8 +26,6 @@ def remote(
25
26
  to be provisioned or used.
26
27
  dependencies (List[str], optional): A list of pip package names to be installed in the remote
27
28
  environment before executing the function. Defaults to None.
28
- mount_volume (NetworkVolume, optional): Configuration for creating and mounting a network volume.
29
- Should contain 'size', 'datacenter_id', and 'name' keys. Defaults to None.
30
29
  extra (dict, optional): Additional parameters for the execution of the resource. Defaults to an empty dict.
31
30
 
32
31
  Returns:
@@ -46,26 +45,26 @@ def remote(
46
45
  ```
47
46
  """
48
47
 
49
- def decorator(func):
50
- @wraps(func)
51
- async def wrapper(*args, **kwargs):
52
- # Create netowrk volume if mount_volume is provided
53
- if mount_volume:
54
- try:
55
- network_volume = await mount_volume.deploy()
56
- resource_config.networkVolumeId = network_volume.id
57
- except Exception as e:
58
- log.error(f"Failed to create or mount network volume: {e}")
59
- raise
60
-
61
- resource_manager = ResourceManager()
62
- remote_resource = await resource_manager.get_or_deploy_resource(
63
- resource_config
48
+ def decorator(func_or_class):
49
+ if inspect.isclass(func_or_class):
50
+ # Handle class decoration
51
+ return create_remote_class(
52
+ func_or_class, resource_config, dependencies, system_dependencies, extra
64
53
  )
54
+ else:
55
+ # Handle function decoration (unchanged)
56
+ @wraps(func_or_class)
57
+ async def wrapper(*args, **kwargs):
58
+ resource_manager = ResourceManager()
59
+ remote_resource = await resource_manager.get_or_deploy_resource(
60
+ resource_config
61
+ )
65
62
 
66
- stub = stub_resource(remote_resource, **extra)
67
- return await stub(func, dependencies, system_dependencies, *args, **kwargs)
63
+ stub = stub_resource(remote_resource, **extra)
64
+ return await stub(
65
+ func_or_class, dependencies, system_dependencies, *args, **kwargs
66
+ )
68
67
 
69
- return wrapper
68
+ return wrapper
70
69
 
71
70
  return decorator
@@ -3,11 +3,12 @@ Direct GraphQL communication with Runpod API.
3
3
  Bypasses the outdated runpod-python SDK limitations.
4
4
  """
5
5
 
6
- import os
7
6
  import json
8
- import aiohttp
9
- from typing import Dict, Any, Optional
10
7
  import logging
8
+ import os
9
+ from typing import Any, Dict, Optional
10
+
11
+ import aiohttp
11
12
 
12
13
  log = logging.getLogger(__name__)
13
14
 
@@ -267,31 +268,12 @@ class RunpodRestClient:
267
268
  raise Exception(f"HTTP request failed: {e}")
268
269
 
269
270
  async def create_network_volume(self, payload: Dict[str, Any]) -> Dict[str, Any]:
270
- """
271
- Create a network volume in Runpod.
271
+ """Create a network volume in Runpod."""
272
+ log.debug(f"Creating network volume: {payload.get('name', 'unnamed')}")
272
273
 
273
- Args:
274
- datacenter_id (str): The ID of the datacenter where the volume will be created.
275
- name (str): The name of the network volume.
276
- size_gb (int): The size of the volume in GB.
277
-
278
- Returns:
279
- Dict[str, Any]: The created network volume details.
280
- """
281
- datacenter_id = payload.get("dataCenterId")
282
- if hasattr(datacenter_id, "value"):
283
- # If datacenter_id is an enum, get its value
284
- datacenter_id = datacenter_id.value
285
- data = {
286
- "dataCenterId": datacenter_id,
287
- "name": payload.get("name"),
288
- "size": payload.get("size"),
289
- }
290
- url = f"{RUNPOD_REST_API_URL}/networkvolumes"
291
-
292
- log.debug(f"Creating network volume: {data.get('name', 'unnamed')}")
293
-
294
- result = await self._execute_rest("POST", url, data)
274
+ result = await self._execute_rest(
275
+ "POST", f"{RUNPOD_REST_API_URL}/networkvolumes", payload
276
+ )
295
277
 
296
278
  log.info(
297
279
  f"Created network volume: {result.get('id', 'unknown')} - {result.get('name', 'unnamed')}"
@@ -4,6 +4,7 @@ from typing import Optional
4
4
 
5
5
  from pydantic import (
6
6
  Field,
7
+ field_serializer,
7
8
  )
8
9
 
9
10
  from ..api.runpod import RunpodRestClient
@@ -20,8 +21,6 @@ class DataCenter(str, Enum):
20
21
  """
21
22
 
22
23
  EU_RO_1 = "EU-RO-1"
23
- US_WA_1 = "US-WA-1"
24
- US_CA_1 = "US-CA-1"
25
24
 
26
25
 
27
26
  class NetworkVolume(DeployableResource):
@@ -33,10 +32,20 @@ class NetworkVolume(DeployableResource):
33
32
 
34
33
  """
35
34
 
36
- dataCenterId: Optional[DataCenter] = None
35
+ # Internal fixed value
36
+ dataCenterId: DataCenter = Field(default=DataCenter.EU_RO_1, frozen=True)
37
+
37
38
  id: Optional[str] = Field(default=None)
38
39
  name: Optional[str] = None
39
- size: Optional[int] = None # Size in GB
40
+ size: Optional[int] = Field(default=10, gt=0) # Size in GB
41
+
42
+ def __str__(self) -> str:
43
+ return f"{self.__class__.__name__}:{self.id}"
44
+
45
+ @field_serializer("dataCenterId")
46
+ def serialize_data_center_id(self, value: Optional[DataCenter]) -> Optional[str]:
47
+ """Convert DataCenter enum to string."""
48
+ return value.value if value is not None else None
40
49
 
41
50
  @property
42
51
  def is_created(self) -> bool:
@@ -79,17 +88,17 @@ class NetworkVolume(DeployableResource):
79
88
  try:
80
89
  # If the resource is already deployed, return it
81
90
  if self.is_deployed():
82
- log.debug(
83
- f"Network volume {self.id} is already deployed. Mounting existing volume."
84
- )
85
- log.info(f"Mounted existing network volume: {self.id}")
91
+ log.debug(f"{self} exists")
86
92
  return self
87
93
 
88
94
  # Create the network volume
89
- self = await self.create_network_volume()
95
+ async with RunpodRestClient() as client:
96
+ # Create the network volume
97
+ payload = self.model_dump(exclude_none=True)
98
+ result = await client.create_network_volume(payload)
90
99
 
91
- if self.is_deployed():
92
- return self
100
+ if volume := self.__class__(**result):
101
+ return volume
93
102
 
94
103
  raise ValueError("Deployment failed, no volume was created.")
95
104
 
@@ -1,27 +1,27 @@
1
1
  import asyncio
2
2
  import logging
3
- from typing import Any, Dict, List, Optional
4
3
  from enum import Enum
4
+ from typing import Any, Dict, List, Optional
5
+
5
6
  from pydantic import (
7
+ BaseModel,
8
+ Field,
6
9
  field_serializer,
7
10
  field_validator,
8
11
  model_validator,
9
- BaseModel,
10
- Field,
11
12
  )
12
-
13
13
  from runpod.endpoint.runner import Job
14
14
 
15
15
  from ..api.runpod import RunpodGraphQLClient
16
16
  from ..utils.backoff import get_backoff_delay
17
-
18
- from .cloud import runpod
19
17
  from .base import DeployableResource
20
- from .template import PodTemplate, KeyValuePair
21
- from .gpu import GpuGroup
18
+ from .cloud import runpod
19
+ from .constants import CONSOLE_URL
22
20
  from .cpu import CpuInstanceType
23
21
  from .environment import EnvironmentVars
24
- from .constants import CONSOLE_URL
22
+ from .gpu import GpuGroup
23
+ from .network_volume import NetworkVolume
24
+ from .template import KeyValuePair, PodTemplate
25
25
 
26
26
 
27
27
  # Environment variables are loaded from the .env file
@@ -62,7 +62,15 @@ class ServerlessResource(DeployableResource):
62
62
  Base class for GPU serverless resource
63
63
  """
64
64
 
65
- _input_only = {"id", "cudaVersions", "env", "gpus", "flashboot", "imageName"}
65
+ _input_only = {
66
+ "id",
67
+ "cudaVersions",
68
+ "env",
69
+ "gpus",
70
+ "flashboot",
71
+ "imageName",
72
+ "networkVolume",
73
+ }
66
74
 
67
75
  # === Input-only Fields ===
68
76
  cudaVersions: Optional[List[CudaVersion]] = [] # for allowedCudaVersions
@@ -71,6 +79,8 @@ class ServerlessResource(DeployableResource):
71
79
  gpus: Optional[List[GpuGroup]] = [GpuGroup.ANY] # for gpuIds
72
80
  imageName: Optional[str] = "" # for template.imageName
73
81
 
82
+ networkVolume: Optional[NetworkVolume] = None
83
+
74
84
  # === Input Fields ===
75
85
  executionTimeoutMs: Optional[int] = None
76
86
  gpuCount: Optional[int] = 1
@@ -142,6 +152,10 @@ class ServerlessResource(DeployableResource):
142
152
  if self.flashboot:
143
153
  self.name += "-fb"
144
154
 
155
+ if self.networkVolume and self.networkVolume.is_created:
156
+ # Volume already exists, use its ID
157
+ self.networkVolumeId = self.networkVolume.id
158
+
145
159
  if self.instanceIds:
146
160
  return self._sync_input_fields_cpu()
147
161
  else:
@@ -177,6 +191,21 @@ class ServerlessResource(DeployableResource):
177
191
 
178
192
  return self
179
193
 
194
+ async def _ensure_network_volume_deployed(self) -> None:
195
+ """
196
+ Ensures network volume is deployed and ready.
197
+ Updates networkVolumeId with the deployed volume ID.
198
+ """
199
+ if self.networkVolumeId:
200
+ return
201
+
202
+ if not self.networkVolume:
203
+ log.info(f"{self.name} requires a default network volume")
204
+ self.networkVolume = NetworkVolume(name=f"{self.name}-volume")
205
+
206
+ if deployedNetworkVolume := await self.networkVolume.deploy():
207
+ self.networkVolumeId = deployedNetworkVolume.id
208
+
180
209
  def is_deployed(self) -> bool:
181
210
  """
182
211
  Checks if the serverless resource is deployed and available.
@@ -202,6 +231,9 @@ class ServerlessResource(DeployableResource):
202
231
  log.debug(f"{self} exists")
203
232
  return self
204
233
 
234
+ # NEW: Ensure network volume is deployed first
235
+ await self._ensure_network_volume_deployed()
236
+
205
237
  async with RunpodGraphQLClient() as client:
206
238
  payload = self.model_dump(exclude=self._input_only, exclude_none=True)
207
239
  result = await client.create_endpoint(payload)
@@ -0,0 +1,178 @@
1
+ import base64
2
+ import inspect
3
+ import logging
4
+ import textwrap
5
+ import uuid
6
+ from typing import List, Type, Optional
7
+
8
+ import cloudpickle
9
+
10
+ from .core.resources import ResourceManager, ServerlessResource
11
+ from .protos.remote_execution import FunctionRequest
12
+ from .stubs import stub_resource
13
+
14
+ log = logging.getLogger(__name__)
15
+
16
+
17
+ def extract_class_code_simple(cls: Type) -> str:
18
+ """Extract clean class code without decorators and proper indentation"""
19
+ try:
20
+ # Get source code
21
+ source = inspect.getsource(cls)
22
+
23
+ # Split into lines
24
+ lines = source.split("\n")
25
+
26
+ # Find the class definition line (starts with 'class' and contains ':')
27
+ class_start_idx = -1
28
+ for i, line in enumerate(lines):
29
+ stripped = line.strip()
30
+ if stripped.startswith("class ") and ":" in stripped:
31
+ class_start_idx = i
32
+ break
33
+
34
+ if class_start_idx == -1:
35
+ raise ValueError("Could not find class definition")
36
+
37
+ # Take lines from class definition onwards (ignore everything before)
38
+ class_lines = lines[class_start_idx:]
39
+
40
+ # Remove empty lines at the end
41
+ while class_lines and not class_lines[-1].strip():
42
+ class_lines.pop()
43
+
44
+ # Join back and dedent to remove any leading indentation
45
+ class_code = "\n".join(class_lines)
46
+ class_code = textwrap.dedent(class_code)
47
+
48
+ # Validate the code by trying to compile it
49
+ compile(class_code, "<string>", "exec")
50
+
51
+ log.debug(f"Successfully extracted class code for {cls.__name__}")
52
+ return class_code
53
+
54
+ except Exception as e:
55
+ log.warning(f"Could not extract class code for {cls.__name__}: {e}")
56
+ log.warning("Falling back to basic class structure")
57
+
58
+ # Enhanced fallback: try to preserve method signatures
59
+ fallback_methods = []
60
+ for name, method in inspect.getmembers(cls, predicate=inspect.isfunction):
61
+ try:
62
+ sig = inspect.signature(method)
63
+ fallback_methods.append(f" def {name}{sig}:")
64
+ fallback_methods.append(" pass")
65
+ fallback_methods.append("")
66
+ except (TypeError, ValueError, OSError) as e:
67
+ log.warning(f"Could not extract method signature for {name}: {e}")
68
+ fallback_methods.append(f" def {name}(self, *args, **kwargs):")
69
+ fallback_methods.append(" pass")
70
+ fallback_methods.append("")
71
+
72
+ fallback_code = f"""class {cls.__name__}:
73
+ def __init__(self, *args, **kwargs):
74
+ pass
75
+
76
+ {chr(10).join(fallback_methods)}"""
77
+
78
+ return fallback_code
79
+
80
+
81
+ def create_remote_class(
82
+ cls: Type,
83
+ resource_config: ServerlessResource,
84
+ dependencies: Optional[List[str]],
85
+ system_dependencies: Optional[List[str]],
86
+ extra: dict,
87
+ ):
88
+ """
89
+ Create a remote class wrapper.
90
+ """
91
+ # Validate inputs
92
+ if not inspect.isclass(cls):
93
+ raise TypeError(f"Expected a class, got {type(cls).__name__}")
94
+ if not hasattr(cls, "__name__"):
95
+ raise ValueError("Class must have a __name__ attribute")
96
+
97
+ class RemoteClassWrapper:
98
+ def __init__(self, *args, **kwargs):
99
+ self._class_type = cls
100
+ self._resource_config = resource_config
101
+ self._dependencies = dependencies or []
102
+ self._system_dependencies = system_dependencies or []
103
+ self._extra = extra
104
+ self._constructor_args = args
105
+ self._constructor_kwargs = kwargs
106
+ self._instance_id = f"{cls.__name__}_{uuid.uuid4().hex[:8]}"
107
+ self._initialized = False
108
+
109
+ self._clean_class_code = extract_class_code_simple(cls)
110
+
111
+ log.debug(f"Created remote class wrapper for {cls.__name__}")
112
+
113
+ async def _ensure_initialized(self):
114
+ """Ensure the remote instance is created."""
115
+ if self._initialized:
116
+ return
117
+
118
+ # Get remote resource
119
+ resource_manager = ResourceManager()
120
+ remote_resource = await resource_manager.get_or_deploy_resource(
121
+ self._resource_config
122
+ )
123
+ self._stub = stub_resource(remote_resource, **self._extra)
124
+
125
+ # Create the remote instance by calling a method (which will trigger instance creation)
126
+ # We'll do this on first method call
127
+ self._initialized = True
128
+
129
+ def __getattr__(self, name):
130
+ """Dynamically create method proxies for all class methods."""
131
+ if name.startswith("_"):
132
+ raise AttributeError(
133
+ f"'{self.__class__.__name__}' object has no attribute '{name}'"
134
+ )
135
+
136
+ async def method_proxy(*args, **kwargs):
137
+ await self._ensure_initialized()
138
+
139
+ # Create class method request
140
+
141
+ # class_code = inspect.getsource(self._class_type)
142
+ class_code = self._clean_class_code
143
+
144
+ request = FunctionRequest(
145
+ execution_type="class",
146
+ class_name=self._class_type.__name__,
147
+ class_code=class_code,
148
+ method_name=name,
149
+ args=[
150
+ base64.b64encode(cloudpickle.dumps(arg)).decode("utf-8")
151
+ for arg in args
152
+ ],
153
+ kwargs={
154
+ k: base64.b64encode(cloudpickle.dumps(v)).decode("utf-8")
155
+ for k, v in kwargs.items()
156
+ },
157
+ constructor_args=[
158
+ base64.b64encode(cloudpickle.dumps(arg)).decode("utf-8")
159
+ for arg in self._constructor_args
160
+ ],
161
+ constructor_kwargs={
162
+ k: base64.b64encode(cloudpickle.dumps(v)).decode("utf-8")
163
+ for k, v in self._constructor_kwargs.items()
164
+ },
165
+ dependencies=self._dependencies,
166
+ system_dependencies=self._system_dependencies,
167
+ instance_id=self._instance_id,
168
+ create_new_instance=not hasattr(
169
+ self, "_stub"
170
+ ), # Create new only on first call
171
+ )
172
+
173
+ # Execute via stub
174
+ return await self._stub.execute_class_method(request) # type: ignore
175
+
176
+ return method_proxy
177
+
178
+ return RemoteClassWrapper
@@ -0,0 +1,128 @@
1
+ # TODO: generate using betterproto
2
+ from abc import ABC, abstractmethod
3
+ from typing import Dict, List, Optional
4
+
5
+ from pydantic import BaseModel, Field, model_validator
6
+
7
+
8
+ class FunctionRequest(BaseModel):
9
+ # MADE OPTIONAL - can be None for class-only execution
10
+ function_name: Optional[str] = Field(
11
+ default=None,
12
+ description="Name of the function to execute",
13
+ )
14
+ function_code: Optional[str] = Field(
15
+ default=None,
16
+ description="Source code of the function to execute",
17
+ )
18
+ args: List = Field(
19
+ default_factory=list,
20
+ description="List of base64-encoded cloudpickle-serialized arguments",
21
+ )
22
+ kwargs: Dict = Field(
23
+ default_factory=dict,
24
+ description="Dictionary of base64-encoded cloudpickle-serialized keyword arguments",
25
+ )
26
+ dependencies: Optional[List] = Field(
27
+ default=None,
28
+ description="Optional list of pip packages to install before executing the function",
29
+ )
30
+ system_dependencies: Optional[List] = Field(
31
+ default=None,
32
+ description="Optional list of system dependencies to install before executing the function",
33
+ )
34
+
35
+ # NEW FIELDS FOR CLASS SUPPORT
36
+ execution_type: str = Field(
37
+ default="function", description="Type of execution: 'function' or 'class'"
38
+ )
39
+ class_name: Optional[str] = Field(
40
+ default=None,
41
+ description="Name of the class to instantiate (for class execution)",
42
+ )
43
+ class_code: Optional[str] = Field(
44
+ default=None,
45
+ description="Source code of the class to instantiate (for class execution)",
46
+ )
47
+ constructor_args: Optional[List] = Field(
48
+ default_factory=list,
49
+ description="List of base64-encoded cloudpickle-serialized constructor arguments",
50
+ )
51
+ constructor_kwargs: Optional[Dict] = Field(
52
+ default_factory=dict,
53
+ description="Dictionary of base64-encoded cloudpickle-serialized constructor keyword arguments",
54
+ )
55
+ method_name: str = Field(
56
+ default="__call__",
57
+ description="Name of the method to call on the class instance",
58
+ )
59
+ instance_id: Optional[str] = Field(
60
+ default=None,
61
+ description="Unique identifier for the class instance (for persistence)",
62
+ )
63
+ create_new_instance: bool = Field(
64
+ default=True,
65
+ description="Whether to create a new instance or reuse existing one",
66
+ )
67
+
68
+ @model_validator(mode="after")
69
+ def validate_execution_requirements(self) -> "FunctionRequest":
70
+ """Validate that required fields are provided based on execution_type"""
71
+ if self.execution_type == "function":
72
+ if self.function_name is None:
73
+ raise ValueError(
74
+ 'function_name is required when execution_type is "function"'
75
+ )
76
+ if self.function_code is None:
77
+ raise ValueError(
78
+ 'function_code is required when execution_type is "function"'
79
+ )
80
+
81
+ elif self.execution_type == "class":
82
+ if self.class_name is None:
83
+ raise ValueError(
84
+ 'class_name is required when execution_type is "class"'
85
+ )
86
+ if self.class_code is None:
87
+ raise ValueError(
88
+ 'class_code is required when execution_type is "class"'
89
+ )
90
+
91
+ return self
92
+
93
+
94
+ class FunctionResponse(BaseModel):
95
+ # EXISTING FIELDS (unchanged)
96
+ success: bool = Field(
97
+ description="Indicates if the function execution was successful",
98
+ )
99
+ result: Optional[str] = Field(
100
+ default=None,
101
+ description="Base64-encoded cloudpickle-serialized result of the function",
102
+ )
103
+ error: Optional[str] = Field(
104
+ default=None,
105
+ description="Error message if the function execution failed",
106
+ )
107
+ stdout: Optional[str] = Field(
108
+ default=None,
109
+ description="Captured standard output from the function execution",
110
+ )
111
+
112
+ # NEW FIELDS FOR CLASS SUPPORT
113
+ instance_id: Optional[str] = Field(
114
+ default=None, description="ID of the class instance that was used/created"
115
+ )
116
+ instance_info: Optional[Dict] = Field(
117
+ default=None,
118
+ description="Metadata about the class instance (creation time, call count, etc.)",
119
+ )
120
+
121
+
122
+ class RemoteExecutorStub(ABC):
123
+ """Abstract base class for remote execution."""
124
+
125
+ @abstractmethod
126
+ async def ExecuteFunction(self, request: FunctionRequest) -> FunctionResponse:
127
+ """Execute a function on the remote resource."""
128
+ raise NotImplementedError("Subclasses should implement this method.")
@@ -1,13 +1,13 @@
1
1
  import logging
2
2
  from functools import singledispatch
3
- from .live_serverless import LiveServerlessStub
4
- from .serverless import ServerlessEndpointStub
3
+
5
4
  from ..core.resources import (
6
5
  CpuServerlessEndpoint,
7
6
  LiveServerless,
8
7
  ServerlessEndpoint,
9
8
  )
10
-
9
+ from .live_serverless import LiveServerlessStub
10
+ from .serverless import ServerlessEndpointStub
11
11
 
12
12
  log = logging.getLogger(__name__)
13
13
 
@@ -22,20 +22,29 @@ def stub_resource(resource, **extra):
22
22
 
23
23
  @stub_resource.register(LiveServerless)
24
24
  def _(resource, **extra):
25
+ stub = LiveServerlessStub(resource)
26
+
27
+ # Function execution
25
28
  async def stubbed_resource(
26
29
  func, dependencies, system_dependencies, *args, **kwargs
27
30
  ) -> dict:
28
31
  if args == (None,):
29
- # cleanup: when the function is called with no args
30
32
  args = []
31
33
 
32
- stub = LiveServerlessStub(resource)
33
34
  request = stub.prepare_request(
34
35
  func, dependencies, system_dependencies, *args, **kwargs
35
36
  )
36
37
  response = await stub.ExecuteFunction(request)
37
38
  return stub.handle_response(response)
38
39
 
40
+ # Class method execution
41
+ async def execute_class_method(request):
42
+ response = await stub.ExecuteFunction(request)
43
+ return stub.handle_response(response)
44
+
45
+ # Attach the method to the function
46
+ stubbed_resource.execute_class_method = execute_class_method
47
+
39
48
  return stubbed_resource
40
49
 
41
50
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: tetra_rp
3
- Version: 0.6.0
3
+ Version: 0.8.0
4
4
  Summary: A Python library for distributed inference and serving of machine learning models
5
5
  Author-email: Marut Pandya <pandyamarut@gmail.com>, Patrick Rachford <prachford@icloud.com>, Dean Quinanola <dean.quinanola@runpod.io>
6
6
  License: MIT
@@ -2,6 +2,7 @@ README.md
2
2
  pyproject.toml
3
3
  src/tetra_rp/__init__.py
4
4
  src/tetra_rp/client.py
5
+ src/tetra_rp/execute_class.py
5
6
  src/tetra_rp/logger.py
6
7
  src/tetra_rp.egg-info/PKG-INFO
7
8
  src/tetra_rp.egg-info/SOURCES.txt
@@ -1,57 +0,0 @@
1
- # TODO: generate using betterproto
2
-
3
- from abc import ABC, abstractmethod
4
- from typing import List, Dict, Optional
5
- from pydantic import BaseModel, Field
6
-
7
-
8
- class FunctionRequest(BaseModel):
9
- function_name: str = Field(
10
- description="Name of the function to execute",
11
- )
12
- function_code: str = Field(
13
- description="Source code of the function to execute",
14
- )
15
- args: List = Field(
16
- default_factory=list,
17
- description="List of base64-encoded cloudpickle-serialized arguments",
18
- )
19
- kwargs: Dict = Field(
20
- default_factory=dict,
21
- description="Dictionary of base64-encoded cloudpickle-serialized keyword arguments",
22
- )
23
- dependencies: Optional[List] = Field(
24
- default=None,
25
- description="Optional list of pip packages to install before executing the function",
26
- )
27
- system_dependencies: Optional[List] = Field(
28
- default=None,
29
- description="Optional list of system dependencies to install before executing the function",
30
- )
31
-
32
-
33
- class FunctionResponse(BaseModel):
34
- success: bool = Field(
35
- description="Indicates if the function execution was successful",
36
- )
37
- result: Optional[str] = Field(
38
- default=None,
39
- description="Base64-encoded cloudpickle-serialized result of the function",
40
- )
41
- error: Optional[str] = Field(
42
- default=None,
43
- description="Error message if the function execution failed",
44
- )
45
- stdout: Optional[str] = Field(
46
- default=None,
47
- description="Captured standard output from the function execution",
48
- )
49
-
50
-
51
- class RemoteExecutorStub(ABC):
52
- """Abstract base class for remote execution."""
53
-
54
- @abstractmethod
55
- async def ExecuteFunction(self, request: FunctionRequest) -> FunctionResponse:
56
- """Execute a function on the remote resource."""
57
- raise NotImplementedError("Subclasses should implement this method.")
File without changes
File without changes