shared-tensor 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- shared_tensor/__init__.py +27 -0
- shared_tensor/async_client.py +302 -0
- shared_tensor/async_provider.py +173 -0
- shared_tensor/async_task.py +361 -0
- shared_tensor/client.py +265 -0
- shared_tensor/errors.py +16 -0
- shared_tensor/jsonrpc.py +163 -0
- shared_tensor/provider.py +155 -0
- shared_tensor/server.py +458 -0
- shared_tensor/utils.py +122 -0
- shared_tensor-0.1.0.dist-info/METADATA +420 -0
- shared_tensor-0.1.0.dist-info/RECORD +16 -0
- shared_tensor-0.1.0.dist-info/WHEEL +5 -0
- shared_tensor-0.1.0.dist-info/entry_points.txt +2 -0
- shared_tensor-0.1.0.dist-info/licenses/LICENSE +181 -0
- shared_tensor-0.1.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,27 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Shared Tensor Library
|
|
3
|
+
|
|
4
|
+
A library for sharing GPU memory objects across processes using IPC mechanisms.
|
|
5
|
+
Enables model and inference engine separation architecture using JSON-RPC 2.0 protocol.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from shared_tensor.provider import SharedTensorProvider
|
|
9
|
+
from shared_tensor.client import SharedTensorClient
|
|
10
|
+
from shared_tensor.server import SharedTensorServer
|
|
11
|
+
from shared_tensor.async_provider import AsyncSharedTensorProvider
|
|
12
|
+
from shared_tensor.async_client import AsyncSharedTensorClient
|
|
13
|
+
from shared_tensor.async_task import TaskStatus, TaskInfo
|
|
14
|
+
|
|
15
|
+
__version__ = "0.1.0"
__author__ = "Athena Team"

# Names re-exported at package level (what ``from shared_tensor import *``
# exposes). The order of the list is kept stable on purpose.
__all__ = [
    # synchronous API
    "SharedTensorProvider",
    "SharedTensorClient",
    "SharedTensorServer",
    # asynchronous, task-based API and its status types
    "AsyncSharedTensorProvider",
    "AsyncSharedTensorClient",
    "TaskStatus",
    "TaskInfo",
]
|
|
@@ -0,0 +1,302 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Async Shared Tensor Client
|
|
3
|
+
|
|
4
|
+
Supports long-running task execution without HTTP timeout limitations.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
import time
|
|
8
|
+
import logging
|
|
9
|
+
from typing import Any, Dict, Optional, Callable
|
|
10
|
+
|
|
11
|
+
import torch
|
|
12
|
+
|
|
13
|
+
from shared_tensor.errors import SharedTensorServerError
|
|
14
|
+
from shared_tensor.client import SharedTensorClient
|
|
15
|
+
from shared_tensor.async_task import TaskStatus, TaskInfo
|
|
16
|
+
from shared_tensor.utils import serialize_result
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
# Public API of this module.
__all__ = [
    "AsyncSharedTensorClient",
    "execute_remote_function_async",
]

# Module-scoped logger; records are emitted under this module's dotted name.
logger = logging.getLogger(__name__)
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class AsyncSharedTensorClient:
    """
    Async client for shared tensor operations.

    Long-running work is submitted to the server as a task (``submit_task``)
    and its outcome is retrieved later by polling (``get_task_status`` /
    ``wait_for_task``), so no single HTTP request has to stay open for the
    whole computation.

    Every RPC is delegated to an internal ``SharedTensorClient`` targeting
    ``http://localhost:<server_port>``. Instances are context managers:
    leaving the ``with`` block calls ``close()``.
    """

    def __init__(self, server_port: int = 2537, verbose_debug: bool = False, poll_interval: float = 1.0):
        """
        Initialize async client

        Args:
            server_port: Port of the shared tensor server on localhost
            verbose_debug: Whether to enable verbose debug logging; when True,
                task args/kwargs/options are included in debug records
            poll_interval: Interval in seconds between task status polls in
                ``wait_for_task``
        """
        self.server_url = f"http://localhost:{server_port}"
        self.verbose_debug = verbose_debug
        self.poll_interval = poll_interval
        self._client = SharedTensorClient(server_port, verbose_debug=verbose_debug)

    def submit_task(self, function_path: str, args: tuple = (), kwargs: Optional[Dict[str, Any]] = None,
                    options: Optional[Dict[str, Any]] = None) -> str:
        """
        Submit a task for async execution

        Args:
            function_path: Function path in format "module.submodule:function_name"
            args: Positional arguments, serialized with ``serialize_result``
                and hex-encoded (the request declares ``encoding: pickle_hex``);
                an empty tuple is sent as ""
            kwargs: Keyword arguments, encoded the same way as ``args``
            options: Options for the task, forwarded to the server unchanged

        Returns:
            Task ID for tracking the execution

        Raises:
            SharedTensorServerError: If the server reports an error or does
                not return a task ID.
        """
        if kwargs is None:
            kwargs = {}

        # Empty containers are sent as "" rather than as serialized empties.
        args_hex = serialize_result(args).hex() if args else ""
        kwargs_hex = serialize_result(kwargs).hex() if kwargs else ""

        if self.verbose_debug:
            # Lazy %-style args: the (possibly huge, e.g. tensor) reprs are
            # only built when a DEBUG record is actually emitted.
            logger.debug(
                "Submitting task with function path %s, args %s, kwargs %s, and options %s",
                function_path, args, kwargs, options,
            )
        else:
            logger.debug("Submitting task with function path %s", function_path)

        response = self._client._send_request(
            self._client._create_request("submit_task", {
                "function_path": function_path,
                "args": args_hex,
                "kwargs": kwargs_hex,
                "options": options,
                "encoding": "pickle_hex",
            })
        )

        if response.error:
            raise SharedTensorServerError(f"Failed to submit task: {response.error}")

        # A success response without a payload must surface as "no task ID",
        # not as an AttributeError on None.
        task_id = (response.result or {}).get("task_id")
        if not task_id:
            raise SharedTensorServerError("Server did not return task ID")

        logger.debug("Task submitted: %s", task_id)
        return task_id

    def get_task_status(self, task_id: str) -> TaskInfo:
        """
        Get current status of a task

        Args:
            task_id: Task ID returned by submit_task

        Returns:
            TaskInfo object with current status

        Raises:
            SharedTensorServerError: If the server reports an error or returns
                no data for ``task_id``.
        """
        logger.debug("Getting task status for task %s", task_id)
        response = self._client._send_request(
            self._client._create_request("get_task_status", {"task_id": task_id})
        )

        if response.error:
            logger.debug("Failed to get task status: %s", response.error)
            raise SharedTensorServerError(f"Failed to get task status: {response.error}")

        task_data = response.result
        if not task_data:
            raise SharedTensorServerError(f"Task {task_id} not found")

        return TaskInfo.from_dict(task_data)

    @staticmethod
    def _decode_result(task_info: TaskInfo) -> Any:
        """Deserialize the hex-encoded result carried by ``task_info`` (None when it has none)."""
        if not task_info.result_hex:
            return None
        result_bytes = bytes.fromhex(task_info.result_hex)
        # SECURITY NOTE(review): this unpickles bytes sent by the server, which
        # can execute arbitrary code. Only talk to a trusted server (this
        # client always targets localhost).
        return torch.multiprocessing.reducer.ForkingPickler.loads(result_bytes)

    def get_task_result(self, task_id: str) -> Any:
        """
        Get result of a completed task

        Args:
            task_id: Task ID

        Returns:
            Task result (deserialized), or None if the task produced no result

        Raises:
            SharedTensorServerError: If the task failed, is not completed yet,
                or its status cannot be fetched.
        """
        task_info = self.get_task_status(task_id)

        if task_info.status == TaskStatus.FAILED:
            raise SharedTensorServerError(f"Task failed: {task_info.error_message}")

        if task_info.status != TaskStatus.COMPLETED:
            raise SharedTensorServerError(f"Task not completed, current status: {task_info.status.value}")

        return self._decode_result(task_info)

    def wait_for_task(self, task_id: str, timeout: Optional[float] = None,
                      callback: Optional[Callable[[TaskInfo], None]] = None) -> Any:
        """
        Wait for a task to complete and return its result

        Polls ``get_task_status`` every ``poll_interval`` seconds.

        Args:
            task_id: Task ID
            timeout: Maximum time to wait in seconds (None for no timeout;
                0 checks the status exactly once)
            callback: Optional callback called with the TaskInfo of every
                poll, including the final one; exceptions it raises are
                logged and ignored

        Returns:
            Task result

        Raises:
            SharedTensorServerError: If the task fails, is cancelled, or does
                not finish within ``timeout``.
        """
        # Monotonic clock so wall-clock adjustments cannot stretch or cut the
        # timeout; an explicit None check so timeout=0 is honoured.
        deadline = None if timeout is None else time.monotonic() + timeout

        while True:
            task_info = self.get_task_status(task_id)

            if callback:
                try:
                    callback(task_info)
                except Exception as e:
                    # Best-effort: a faulty callback must not abort the wait.
                    logger.warning("Callback error: %s", e)

            if task_info.status == TaskStatus.COMPLETED:
                # Decode from the status just fetched instead of issuing a
                # second, redundant status RPC.
                return self._decode_result(task_info)

            if task_info.status == TaskStatus.FAILED:
                raise SharedTensorServerError(f"Task {task_id} failed: {task_info.error_message}")

            if task_info.status == TaskStatus.CANCELLED:
                raise SharedTensorServerError("Task was cancelled")

            if deadline is None:
                time.sleep(self.poll_interval)
                continue

            remaining = deadline - time.monotonic()
            if remaining <= 0:
                raise SharedTensorServerError(f"Task {task_id} did not complete within {timeout} seconds")
            # Never sleep past the deadline.
            time.sleep(min(self.poll_interval, remaining))

    def execute_function_async(self, function_path: str, args: tuple = (),
                               kwargs: Optional[Dict[str, Any]] = None, options: Optional[Dict[str, Any]] = None,
                               wait: bool = True,
                               timeout: Optional[float] = None,
                               callback: Optional[Callable[[TaskInfo], None]] = None) -> Any:
        """
        Execute a function asynchronously

        Args:
            function_path: Function path
            args: Positional arguments
            kwargs: Keyword arguments
            options: Options for the task
            wait: Whether to wait for completion
            timeout: Maximum time to wait if wait=True
            callback: Status update callback

        Returns:
            If wait=True: Function result
            If wait=False: Task ID
        """
        task_id = self.submit_task(function_path, args, kwargs, options)
        if not wait:
            return task_id
        return self.wait_for_task(task_id, timeout=timeout, callback=callback)

    def cancel_task(self, task_id: str) -> bool:
        """
        Cancel a task

        Args:
            task_id: Task ID

        Returns:
            True if the server reports the task as cancelled; False otherwise,
            including when the server returns an error (which is logged)
        """
        response = self._client._send_request(
            self._client._create_request("cancel_task", {"task_id": task_id})
        )

        if response.error:
            logger.error("Failed to cancel task: %s", response.error)
            return False

        return (response.result or {}).get("cancelled", False)

    def list_tasks(self, status: Optional[str] = None) -> Dict[str, TaskInfo]:
        """
        List tasks on the server

        Args:
            status: Optional status filter, sent to the server only when truthy

        Returns:
            Dictionary of task ID -> TaskInfo (empty if the server sent no result)

        Raises:
            SharedTensorServerError: If the server reports an error.
        """
        params = {"status": status} if status else {}

        response = self._client._send_request(
            self._client._create_request("list_tasks", params)
        )

        if response.error:
            raise SharedTensorServerError(f"Failed to list tasks: {response.error}")

        return {
            task_id: TaskInfo.from_dict(task_data)
            for task_id, task_data in (response.result or {}).items()
        }

    def close(self):
        """Close the client (closes the underlying SharedTensorClient)"""
        self._client.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Returns None, so exceptions raised inside the with-block propagate.
        self.close()
|
|
269
|
+
|
|
270
|
+
|
|
271
|
+
def execute_remote_function_async(
    function_path: str,
    args: tuple = (),
    kwargs: Dict[str, Any] = None,
    options: Dict[str, Any] = None,
    server_port: int = 2537,
    verbose_debug: bool = False,
    poll_interval: float = 1.0,
    wait: bool = True,
    timeout: Optional[float] = None,
    callback: Optional[Callable[[TaskInfo], None]] = None
) -> Any:
    """
    One-shot helper that runs ``function_path`` on the server as an async task.

    A temporary ``AsyncSharedTensorClient`` is created for this call only and
    is closed on exit from its ``with`` block, whether the call returns or
    raises.

    Args:
        function_path: Function path in format "module.submodule:function_name"
        args: Positional arguments for the remote function
        kwargs: Keyword arguments for the remote function
        options: Options for the task
        server_port: Port of the shared tensor server on localhost
        verbose_debug: Whether to enable verbose debug logging
        poll_interval: Seconds between task status polls
        wait: Whether to block until the task finishes
        timeout: Maximum time to wait in seconds (only used when waiting)
        callback: Called with the TaskInfo of each status poll

    Returns:
        The function result when ``wait`` is True, otherwise the task ID.
    """
    client = AsyncSharedTensorClient(
        server_port=server_port,
        verbose_debug=verbose_debug,
        poll_interval=poll_interval,
    )
    with client:
        return client.execute_function_async(
            function_path,
            args=args,
            kwargs=kwargs,
            options=options,
            wait=wait,
            timeout=timeout,
            callback=callback,
        )
|
|
@@ -0,0 +1,173 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Async Provider for Shared Tensor
|
|
3
|
+
|
|
4
|
+
Extends the provider pattern to support async task execution
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
import os
|
|
8
|
+
import logging
|
|
9
|
+
from functools import wraps
|
|
10
|
+
from typing import Any, Dict, Callable, Optional
|
|
11
|
+
|
|
12
|
+
from shared_tensor.errors import SharedTensorProviderError
|
|
13
|
+
from shared_tensor.provider import SharedTensorProvider
|
|
14
|
+
from shared_tensor.async_client import AsyncSharedTensorClient
|
|
15
|
+
from shared_tensor.async_task import TaskInfo
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
# Public API of this module.
__all__ = [
    "AsyncSharedTensorProvider",
]

# Module-scoped logger; records are emitted under this module's dotted name.
logger = logging.getLogger(__name__)

# Rank of this process taken from the RANK environment variable (0 when the
# variable is unset). Read once, at import time.
global_rank = int(os.environ.get("RANK", 0))
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class AsyncSharedTensorProvider(SharedTensorProvider):
    """
    Async provider for shared tensor operations

    Extends ``SharedTensorProvider`` with ``share_async``, which registers a
    function for remote execution through the server's task API. A call can
    either block until the task finishes or return a task ID that is then
    managed with ``get_task_status`` / ``get_task_result`` /
    ``wait_for_task`` / ``cancel_task`` / ``list_tasks``.
    """

    def __init__(self, server_port: int = 2537 + global_rank, verbose_debug: bool = False, poll_interval: float = 1.0):
        """
        Args:
            server_port: Port of the shared tensor server. The default is
                evaluated once, at import time, as 2537 + ``global_rank``.
            verbose_debug: Forwarded to the base provider and to the async client
            poll_interval: Seconds between status polls while waiting for a task
        """
        super().__init__(server_port=server_port, verbose_debug=verbose_debug)
        self.poll_interval = poll_interval
        logger.debug(
            "AsyncSharedTensorProvider initialized with server port %s, verbose debug %s, and poll interval %s",
            server_port, verbose_debug, poll_interval,
        )
        # Created lazily on first use by _get_async_client().
        self._async_client = None

    def _get_async_client(self) -> AsyncSharedTensorClient:
        """Return the cached async client, creating it on first use."""
        if self._async_client is None:
            logger.debug("Creating new async client with server port %s and poll interval %s",
                         self.server_port, self.poll_interval)
            self._async_client = AsyncSharedTensorClient(self.server_port, self.verbose_debug, self.poll_interval)
            logger.debug("Async client created with server port %s and poll interval %s",
                         self.server_port, self.poll_interval)
        return self._async_client

    def _require_registered(self, func_name: str) -> dict:
        """Return the registration record for ``func_name`` or raise SharedTensorProviderError."""
        if func_name not in self._registered_functions:
            raise SharedTensorProviderError(f"Function {func_name} not registered")
        return self._registered_functions[func_name]

    def share_async(self, name: Optional[str] = None, wait: bool = True, singleton: bool = True, singleton_key_formatter: Optional[str] = None):
        """
        Decorator to register a function for async remote sharing

        Calling the decorated function does not run its local body: the call
        is sent to the server as a task. With ``wait=True`` the call returns
        the task result; with ``wait=False`` it returns the task ID. The
        wrapper also gets two helpers:

        - ``submit_async(*args, **kwargs)`` always returns a task ID;
        - ``execute_async(*args, wait=..., timeout=None, callback=None, **kwargs)``
          overrides the wait mode per call. Because ``wait``, ``timeout`` and
          ``callback`` are consumed here, they cannot be passed through as
          keyword arguments of the remote function.

        When the provider's ``server_mode`` attribute (set outside this class)
        equals the string "true", the function is returned unchanged and is
        not registered.

        Args:
            name: Optional custom name for the function (defaults to ``func.__name__``)
            wait: Whether to wait for completion by default
            singleton: Whether to use a singleton instance of the function result
            singleton_key_formatter: Formatter for cached results
        """
        def decorator(func: Callable):
            func_name = name or func.__name__

            if self.server_mode == "true":
                logger.debug("Server mode is true, returning function %s without registering", func_name)
                return func

            logger.debug("Server mode is false, registering function %s", func_name)

            function_path = self._get_function_path(func)
            logger.debug("Function %s registered with function path %s", func_name, function_path)

            # Sent with every task submission for this function.
            options = {
                'name': func_name,
                'singleton': singleton,
                'singleton_key_formatter': singleton_key_formatter,
            }

            self._registered_functions[func_name] = {
                'name': func_name,
                'function_path': function_path,
                'options': options,
                'async_default_wait': wait,
            }

            @wraps(func)
            def wrapper(*args, **kwargs):
                return self._execute_async_function(func_name, args, kwargs, options)

            def submit_async(*args, **kwargs) -> str:
                """Submit the call as a task and return its task ID without waiting."""
                return self._submit_async_function(func_name, args, kwargs, options)

            # `wait=wait` binds the decorator's default at definition time.
            def execute_async(*args, wait=wait, timeout=None, callback=None, **kwargs):
                """Run the call as a task with per-call wait/timeout/callback settings."""
                return self._execute_async_function_with_options(
                    func_name, args, kwargs, options, wait, timeout, callback)

            wrapper.submit_async = submit_async
            wrapper.execute_async = execute_async
            return wrapper
        return decorator

    def _submit_async_function(self, func_name: str, args: tuple, kwargs: dict, options: dict) -> str:
        """
        Submit a registered function for async execution and return the task ID.

        Raises:
            SharedTensorProviderError: If ``func_name`` is not registered, or
                wrapping (with the cause chained) any error raised while submitting.
        """
        try:
            function_path = self._require_registered(func_name)['function_path']
            async_client = self._get_async_client()
            logger.debug("Submitting async function %s with function path %s and options %s",
                         func_name, function_path, options)
            return async_client.submit_task(function_path, args, kwargs, options)
        except SharedTensorProviderError:
            # Already a provider error (e.g. not registered): do not wrap twice.
            raise
        except Exception as e:
            raise SharedTensorProviderError(f"Failed to submit async function {func_name}: {str(e)}") from e

    def _execute_async_function(self, func_name: str, args: tuple, kwargs: dict, options: dict) -> Any:
        """
        Execute a registered function using the wait mode chosen at decoration time.

        Returns the task result when that mode is "wait", otherwise the task ID.

        Raises:
            SharedTensorProviderError: If ``func_name`` is not registered or execution fails.
        """
        wait = self._require_registered(func_name).get('async_default_wait', True)
        if wait:
            return self._execute_async_function_with_options(func_name, args, kwargs, options, True, None, None)
        return self._submit_async_function(func_name, args, kwargs, options)

    def _execute_async_function_with_options(self, func_name: str, args: tuple, kwargs: dict, options: dict,
                                             wait: bool, timeout: Optional[float],
                                             callback: Optional[Callable[[TaskInfo], None]]) -> Any:
        """
        Execute a registered function with explicit wait/timeout/callback settings.

        Returns the task result when ``wait`` is True, otherwise the task ID.

        Raises:
            SharedTensorProviderError: If ``func_name`` is not registered, or
                wrapping (with the cause chained) any error raised while
                submitting or waiting, including task failure and timeout.
        """
        try:
            function_path = self._require_registered(func_name)['function_path']
            async_client = self._get_async_client()
            logger.debug("Executing async function %s with function path %s and options %s",
                         func_name, function_path, options)
            return async_client.execute_function_async(function_path, args, kwargs, options, wait, timeout, callback)
        except SharedTensorProviderError:
            raise
        except Exception as e:
            raise SharedTensorProviderError(f"Failed to execute async function {func_name}: {str(e)}") from e

    def get_task_status(self, task_id: str) -> TaskInfo:
        """Get status of a task"""
        logger.debug("Getting status of task %s", task_id)
        return self._get_async_client().get_task_status(task_id)

    def get_task_result(self, task_id: str) -> Any:
        """Get result of a completed task"""
        logger.debug("Getting result of task %s", task_id)
        return self._get_async_client().get_task_result(task_id)

    def wait_for_task(self, task_id: str, timeout: Optional[float] = None,
                      callback: Optional[Callable[[TaskInfo], None]] = None) -> Any:
        """Wait for a task to complete and return its result"""
        logger.debug("Waiting for task %s with timeout %s and callback %s", task_id, timeout, callback)
        return self._get_async_client().wait_for_task(task_id, timeout, callback)

    def cancel_task(self, task_id: str) -> bool:
        """Cancel a task; returns True if the server reports it cancelled"""
        logger.debug("Cancelling task %s", task_id)
        return self._get_async_client().cancel_task(task_id)

    def list_tasks(self, status: Optional[str] = None) -> Dict[str, TaskInfo]:
        """List tasks on the server, optionally filtered by status"""
        logger.debug("Listing tasks with status %s", status)
        return self._get_async_client().list_tasks(status)

    def close(self):
        """Close the base provider and, if one was created, the async client."""
        try:
            super().close()
        finally:
            # Close the async client even if the base close() raised, and
            # forget it even if its own close() raises.
            if self._async_client is not None:
                logger.debug("Closing async client")
                try:
                    self._async_client.close()
                    logger.debug("Async client closed")
                finally:
                    self._async_client = None
|