shared-tensor 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,27 @@
1
+ """
2
+ Shared Tensor Library
3
+
4
+ A library for sharing GPU memory objects across processes using IPC mechanisms.
5
+ Enables model and inference engine separation architecture using JSON-RPC 2.0 protocol.
6
+ """
7
+
8
+ from shared_tensor.provider import SharedTensorProvider
9
+ from shared_tensor.client import SharedTensorClient
10
+ from shared_tensor.server import SharedTensorServer
11
+ from shared_tensor.async_provider import AsyncSharedTensorProvider
12
+ from shared_tensor.async_client import AsyncSharedTensorClient
13
+ from shared_tensor.async_task import TaskStatus, TaskInfo
14
+
15
+ __version__ = "0.1.0"
16
+ __author__ = "Athena Team"
17
+
18
+ # Export main functionality
19
+ __all__ = [
20
+ "SharedTensorProvider",
21
+ "SharedTensorClient",
22
+ "SharedTensorServer",
23
+ "AsyncSharedTensorProvider",
24
+ "AsyncSharedTensorClient",
25
+ "TaskStatus",
26
+ "TaskInfo",
27
+ ]
@@ -0,0 +1,302 @@
1
+ """
2
+ Async Shared Tensor Client
3
+
4
+ Supports long-running task execution without HTTP timeout limitations.
5
+ """
6
+
7
+ import time
8
+ import logging
9
+ from typing import Any, Dict, Optional, Callable
10
+
11
+ import torch
12
+
13
+ from shared_tensor.errors import SharedTensorServerError
14
+ from shared_tensor.client import SharedTensorClient
15
+ from shared_tensor.async_task import TaskStatus, TaskInfo
16
+ from shared_tensor.utils import serialize_result
17
+
18
+
19
# Public API of this module.
__all__ = ["AsyncSharedTensorClient", "execute_remote_function_async"]


# Module-level logger; handlers/levels are configured by the host application.
logger = logging.getLogger(__name__)
23
+
24
+
25
class AsyncSharedTensorClient:
    """
    Async client for shared tensor operations.

    Submits long-running tasks to the shared tensor server and polls for
    their results, so callers are not limited by a single HTTP request
    timeout.
    """

    def __init__(self, server_port: int = 2537, verbose_debug: bool = False, poll_interval: float = 1.0):
        """
        Initialize async client.

        Args:
            server_port: Port of the shared tensor server (on localhost)
            verbose_debug: Whether to enable verbose debug logging
            poll_interval: Interval in seconds between task-status polls
        """
        self.server_url = f"http://localhost:{server_port}"
        self.verbose_debug = verbose_debug
        self.poll_interval = poll_interval
        # Reuse the synchronous JSON-RPC client as the transport layer.
        self._client = SharedTensorClient(server_port, verbose_debug=verbose_debug)

    def submit_task(self, function_path: str, args: tuple = (), kwargs: Dict[str, Any] = None, options: Dict[str, Any] = None) -> str:
        """
        Submit a task for async execution.

        Args:
            function_path: Function path in format "module.submodule:function_name"
            args: Positional arguments
            kwargs: Keyword arguments
            options: Options for the task

        Returns:
            Task ID for tracking the execution

        Raises:
            SharedTensorServerError: If the server reports an error or does
                not return a task ID.
        """
        if kwargs is None:
            kwargs = {}

        # Arguments travel over JSON-RPC as pickled bytes rendered to hex;
        # empty args/kwargs are sent as an empty string.
        args_hex = serialize_result(args).hex() if args else ""
        kwargs_hex = serialize_result(kwargs).hex() if kwargs else ""

        if self.verbose_debug:
            logger.debug(f"Submitting task with function path {function_path}, args {args}, kwargs {kwargs}, and options {options}")
        else:
            logger.debug(f"Submitting task with function path {function_path}")

        response = self._client._send_request(
            self._client._create_request("submit_task", {
                "function_path": function_path,
                "args": args_hex,
                "kwargs": kwargs_hex,
                "options": options,
                "encoding": "pickle_hex",
            })
        )

        if response.error:
            raise SharedTensorServerError(f"Failed to submit task: {response.error}")

        task_id = response.result.get("task_id")
        if not task_id:
            raise SharedTensorServerError("Server did not return task ID")

        logger.debug(f"Task submitted: {task_id}")
        return task_id

    def get_task_status(self, task_id: str) -> TaskInfo:
        """
        Get current status of a task.

        Args:
            task_id: Task ID returned by submit_task

        Returns:
            TaskInfo object with current status

        Raises:
            SharedTensorServerError: If the server reports an error or the
                task is unknown.
        """
        logger.debug(f"Getting task status for task {task_id}")
        response = self._client._send_request(
            self._client._create_request("get_task_status", {"task_id": task_id})
        )

        if response.error:
            logger.debug(f"Failed to get task status: {response.error}")
            raise SharedTensorServerError(f"Failed to get task status: {response.error}")

        task_data = response.result
        if not task_data:
            raise SharedTensorServerError(f"Task {task_id} not found")

        return TaskInfo.from_dict(task_data)

    def get_task_result(self, task_id: str) -> Any:
        """
        Get result of a completed task.

        Args:
            task_id: Task ID

        Returns:
            Task result (deserialized), or None if the task produced no
            result payload.

        Raises:
            SharedTensorServerError: If the task failed or is not yet
                completed.
        """
        task_info = self.get_task_status(task_id)

        if task_info.status == TaskStatus.FAILED:
            raise SharedTensorServerError(f"Task failed: {task_info.error_message}")

        if task_info.status != TaskStatus.COMPLETED:
            raise SharedTensorServerError(f"Task not completed, current status: {task_info.status.value}")

        if task_info.result_hex:
            result_bytes = bytes.fromhex(task_info.result_hex)
            # Deserialize with torch's multiprocessing pickler — presumably the
            # counterpart of serialize_result, so shared-tensor IPC handles are
            # rebuilt; verify against the server-side encoder.
            # NOTE(review): unpickling is code execution — only poll servers
            # you trust.
            return torch.multiprocessing.reducer.ForkingPickler.loads(result_bytes)
        return None

    def wait_for_task(self, task_id: str, timeout: Optional[float] = None,
                      callback: Optional[Callable[[TaskInfo], None]] = None) -> Any:
        """
        Wait for a task to complete and return its result.

        Args:
            task_id: Task ID
            timeout: Maximum time to wait in seconds. None waits forever;
                0 raises right after the first status poll.
            callback: Optional callback invoked with the TaskInfo on each poll

        Returns:
            Task result

        Raises:
            SharedTensorServerError: If the task fails, is cancelled, or the
                timeout elapses.
        """
        start_time = time.time()

        while True:
            task_info = self.get_task_status(task_id)

            # Callback failures are logged but never abort the wait.
            if callback:
                try:
                    callback(task_info)
                except Exception as e:
                    logger.warning(f"Callback error: {e}")

            if task_info.status == TaskStatus.COMPLETED:
                return self.get_task_result(task_id)

            if task_info.status == TaskStatus.FAILED:
                raise SharedTensorServerError(f"Task {task_id} failed: {task_info.error_message}")

            if task_info.status == TaskStatus.CANCELLED:
                raise SharedTensorServerError("Task was cancelled")

            # BUGFIX: compare against None explicitly so timeout=0 means
            # "time out immediately" rather than "wait forever" (0 is falsy,
            # so the previous `if timeout and ...` skipped the check).
            if timeout is not None and (time.time() - start_time) > timeout:
                raise SharedTensorServerError(f"Task {task_id} did not complete within {timeout} seconds")

            # Sleep before the next poll.
            time.sleep(self.poll_interval)

    def execute_function_async(self, function_path: str, args: tuple = (),
                               kwargs: Dict[str, Any] = None, options: Dict[str, Any] = None, wait: bool = True,
                               timeout: Optional[float] = None,
                               callback: Optional[Callable[[TaskInfo], None]] = None) -> Any:
        """
        Execute a function asynchronously.

        Args:
            function_path: Function path
            args: Positional arguments
            kwargs: Keyword arguments
            options: Options for the task
            wait: Whether to block until completion
            timeout: Maximum time to wait if wait=True
            callback: Status update callback

        Returns:
            Function result if wait=True, otherwise the task ID.
        """
        task_id = self.submit_task(function_path, args, kwargs, options)

        if wait:
            return self.wait_for_task(task_id, timeout, callback)
        return task_id

    def cancel_task(self, task_id: str) -> bool:
        """
        Cancel a task.

        Args:
            task_id: Task ID

        Returns:
            True if successfully cancelled; False on any server error
            (best-effort — errors are logged, not raised).
        """
        response = self._client._send_request(
            self._client._create_request("cancel_task", {"task_id": task_id})
        )

        if response.error:
            logger.error(f"Failed to cancel task: {response.error}")
            return False

        return response.result.get("cancelled", False)

    def list_tasks(self, status: Optional[str] = None) -> Dict[str, TaskInfo]:
        """
        List tasks on the server.

        Args:
            status: Optional status filter

        Returns:
            Dictionary mapping task ID -> TaskInfo

        Raises:
            SharedTensorServerError: If the server reports an error.
        """
        params = {}
        if status:
            params["status"] = status

        response = self._client._send_request(
            self._client._create_request("list_tasks", params)
        )

        if response.error:
            raise SharedTensorServerError(f"Failed to list tasks: {response.error}")

        return {task_id: TaskInfo.from_dict(task_data)
                for task_id, task_data in response.result.items()}

    def close(self):
        """Close the underlying transport client."""
        self._client.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Never suppress exceptions; just release the connection.
        self.close()
269
+
270
+
271
def execute_remote_function_async(
    function_path: str,
    args: tuple = (),
    kwargs: Dict[str, Any] = None,
    options: Dict[str, Any] = None,
    server_port: int = 2537,
    verbose_debug: bool = False,
    poll_interval: float = 1.0,
    wait: bool = True,
    timeout: Optional[float] = None,
    callback: Optional[Callable[[TaskInfo], None]] = None
) -> Any:
    """
    Convenience wrapper: run one remote function call through a short-lived
    AsyncSharedTensorClient.

    Args:
        function_path: Function path in "module.submodule:function_name" form
        args: Positional arguments
        kwargs: Keyword arguments
        options: Options for the task
        server_port: Port of the shared tensor server
        verbose_debug: Whether to enable verbose debug logging
        poll_interval: Interval in seconds between task-status polls
        wait: Whether to block until completion
        timeout: Maximum time to wait
        callback: Status update callback

    Returns:
        Function result if wait=True, task ID if wait=False.
    """
    client = AsyncSharedTensorClient(server_port, verbose_debug, poll_interval)
    try:
        return client.execute_function_async(
            function_path, args, kwargs, options, wait, timeout, callback
        )
    finally:
        # Mirrors the context-manager protocol: always release the client.
        client.close()
@@ -0,0 +1,173 @@
1
+ """
2
+ Async Provider for Shared Tensor
3
+
4
+ Extends the provider pattern to support async task execution
5
+ """
6
+
7
+ import os
8
+ import logging
9
+ from functools import wraps
10
+ from typing import Any, Dict, Callable, Optional
11
+
12
+ from shared_tensor.errors import SharedTensorProviderError
13
+ from shared_tensor.provider import SharedTensorProvider
14
+ from shared_tensor.async_client import AsyncSharedTensorClient
15
+ from shared_tensor.async_task import TaskInfo
16
+
17
+
18
# Public API of this module.
__all__ = ["AsyncSharedTensorProvider"]

# Module-level logger; handlers/levels are configured by the host application.
logger = logging.getLogger(__name__)
# Rank of this process, read once at import time from the RANK environment
# variable (0 when unset); used to offset the default server port per rank.
global_rank = int(os.getenv("RANK", 0))
22
+
23
+
24
class AsyncSharedTensorProvider(SharedTensorProvider):
    """
    Async provider for shared tensor operations.

    Extends SharedTensorProvider with async task submission, status polling,
    waiting, and cancellation through a lazily created AsyncSharedTensorClient.
    """

    def __init__(self, server_port: int = 2537 + global_rank, verbose_debug: bool = False, poll_interval: float = 1.0):
        """
        Initialize the async provider.

        Args:
            server_port: Server port. The default is offset by the process
                rank; note it is computed once at import time from the RANK
                environment variable.
            verbose_debug: Whether to enable verbose debug logging
            poll_interval: Interval in seconds between task-status polls
        """
        super().__init__(server_port=server_port, verbose_debug=verbose_debug)
        self.poll_interval = poll_interval
        logger.debug(f"AsyncSharedTensorProvider initialized with server port {server_port}, verbose debug {verbose_debug}, and poll interval {poll_interval}")
        # Created lazily on first async operation.
        self._async_client = None

    def _get_async_client(self) -> AsyncSharedTensorClient:
        """Get or lazily create the async client."""
        if self._async_client is None:
            logger.debug(f"Creating new async client with server port {self.server_port} and poll interval {self.poll_interval}")
            self._async_client = AsyncSharedTensorClient(self.server_port, self.verbose_debug, self.poll_interval)
            logger.debug(f"Async client created with server port {self.server_port} and poll interval {self.poll_interval}")
        return self._async_client

    def share_async(self, name: Optional[str] = None, wait: bool = True, singleton: bool = True, singleton_key_formatter: Optional[str] = None):
        """
        Decorator to register a function for async remote sharing.

        Args:
            name: Optional custom name for the function
            wait: Whether to wait for completion by default
            singleton: Whether to use a singleton instance of the function result
            singleton_key_formatter: Formatter for cached results

        Returns:
            A decorator that replaces the function with a remote-executing
            wrapper exposing `submit_async` and `execute_async` attributes.
        """
        def decorator(func: Callable):
            func_name = name or func.__name__

            # NOTE(review): server_mode appears to be a string flag inherited
            # from SharedTensorProvider; in server mode the function runs
            # locally and is returned unwrapped — confirm against the parent.
            if self.server_mode == "true":
                logger.debug(f"Server mode is true, returning function {func_name} without registering")
                return func

            logger.debug(f"Server mode is false, registering function {func_name}")

            function_path = self._get_function_path(func)
            logger.debug(f"Function {func_name} registered with function path {function_path}")

            options = {
                'name': func_name,
                'singleton': singleton,
                'singleton_key_formatter': singleton_key_formatter,
            }

            function_info = {
                'name': func_name,
                'function_path': function_path,
                'options': options,
                'async_default_wait': wait
            }

            self._registered_functions[func_name] = function_info

            @wraps(func)
            def wrapper(*args, **kwargs):
                return self._execute_async_function(func_name, args, kwargs, options)

            # Extra entry points: fire-and-forget submission, and execution
            # with explicit wait/timeout/callback overrides. `wait=wait` binds
            # the decorator's default now, avoiding late-binding surprises.
            wrapper.submit_async = lambda *args, **kwargs: self._submit_async_function(func_name, args, kwargs, options)
            wrapper.execute_async = lambda *args, wait=wait, timeout=None, callback=None, **kwargs: \
                self._execute_async_function_with_options(func_name, args, kwargs, options, wait, timeout, callback)

            return wrapper
        return decorator

    def _submit_async_function(self, func_name: str, args: tuple, kwargs: dict, options: dict) -> str:
        """Submit function for async execution, return the task ID."""
        # BUGFIX: the registration check sits outside the try block so an
        # unregistered function raises a clean "not registered" error instead
        # of being caught below and double-wrapped.
        if func_name not in self._registered_functions:
            raise SharedTensorProviderError(f"Function {func_name} not registered")

        function_info = self._registered_functions[func_name]
        function_path = function_info['function_path']

        try:
            async_client = self._get_async_client()
            logger.debug(f"Submitting async function {func_name} with function path {function_path} and options {options}")
            return async_client.submit_task(function_path, args, kwargs, options)
        except Exception as e:
            # Chain the cause so the original traceback is preserved.
            raise SharedTensorProviderError(f"Failed to submit async function {func_name}: {str(e)}") from e

    def _execute_async_function(self, func_name: str, args: tuple, kwargs: dict, options: dict) -> Any:
        """Execute function using the wait mode chosen at registration time."""
        function_info = self._registered_functions[func_name]
        wait = function_info.get('async_default_wait', True)
        if wait:
            return self._execute_async_function_with_options(func_name, args, kwargs, options, True, None, None)
        return self._submit_async_function(func_name, args, kwargs, options)

    def _execute_async_function_with_options(self, func_name: str, args: tuple, kwargs: dict, options: dict,
                                             wait: bool, timeout: Optional[float],
                                             callback: Optional[Callable[[TaskInfo], None]]) -> Any:
        """Execute function with explicit wait/timeout/callback options."""
        # Same rationale as _submit_async_function: fail fast, un-wrapped.
        if func_name not in self._registered_functions:
            raise SharedTensorProviderError(f"Function {func_name} not registered")

        function_info = self._registered_functions[func_name]
        function_path = function_info['function_path']

        try:
            async_client = self._get_async_client()
            logger.debug(f"Executing async function {func_name} with function path {function_path} and options {options}")
            return async_client.execute_function_async(function_path, args, kwargs, options, wait, timeout, callback)
        except Exception as e:
            raise SharedTensorProviderError(f"Failed to execute async function {func_name}: {str(e)}") from e

    def get_task_status(self, task_id: str) -> TaskInfo:
        """Get status of a task."""
        async_client = self._get_async_client()
        logger.debug(f"Getting status of task {task_id}")
        return async_client.get_task_status(task_id)

    def get_task_result(self, task_id: str) -> Any:
        """Get result of a completed task."""
        async_client = self._get_async_client()
        logger.debug(f"Getting result of task {task_id}")
        return async_client.get_task_result(task_id)

    def wait_for_task(self, task_id: str, timeout: Optional[float] = None,
                      callback: Optional[Callable[[TaskInfo], None]] = None) -> Any:
        """Wait for a task to complete and return its result."""
        async_client = self._get_async_client()
        logger.debug(f"Waiting for task {task_id} with timeout {timeout} and callback {callback}")
        return async_client.wait_for_task(task_id, timeout, callback)

    def cancel_task(self, task_id: str) -> bool:
        """Cancel a task."""
        async_client = self._get_async_client()
        logger.debug(f"Cancelling task {task_id}")
        return async_client.cancel_task(task_id)

    def list_tasks(self, status: Optional[str] = None) -> Dict[str, TaskInfo]:
        """List tasks on the server."""
        async_client = self._get_async_client()
        logger.debug(f"Listing tasks with status {status}")
        return async_client.list_tasks(status)

    def close(self):
        """Close the provider and its clients."""
        super().close()
        if self._async_client:
            logger.debug("Closing async client")
            self._async_client.close()
            logger.debug("Async client closed")
            self._async_client = None