mantatech-sdk 0.5b0.dev65__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- manta/__init__.light.py +22 -0
- manta/__init__.py +83 -0
- manta/__main__.py +21 -0
- manta/apis/__init__.py +7 -0
- manta/apis/async_user_api.py +6458 -0
- manta/apis/graph.py +498 -0
- manta/apis/module.py +316 -0
- manta/apis/results.py +251 -0
- manta/apis/swarm.py +206 -0
- manta/apis/user_api.py +1016 -0
- manta/cli/__init__.py +1 -0
- manta/cli/commands/__init__.py +1 -0
- manta/cli/commands/base_handler.py +229 -0
- manta/cli/commands/doc.py +192 -0
- manta/cli/commands/install.py +346 -0
- manta/cli/commands/sdk.py +9 -0
- manta/cli/commands/sdk_cluster.py +211 -0
- manta/cli/commands/sdk_config.py +347 -0
- manta/cli/commands/sdk_globals.py +280 -0
- manta/cli/commands/sdk_logs.py +174 -0
- manta/cli/commands/sdk_main.py +167 -0
- manta/cli/commands/sdk_module.py +516 -0
- manta/cli/commands/sdk_nodes.py +168 -0
- manta/cli/commands/sdk_original.py +3873 -0
- manta/cli/commands/sdk_results.py +265 -0
- manta/cli/commands/sdk_swarm.py +454 -0
- manta/cli/commands/sdk_user.py +234 -0
- manta/cli/commands/status.py +292 -0
- manta/cli/component_detector.py +112 -0
- manta/cli/config_manager.py +445 -0
- manta/cli/main.py +265 -0
- manta/cli/utils/__init__.py +27 -0
- manta/cli/utils/converters.py +140 -0
- manta/clients/cluster_management_client.py +486 -0
- manta/clients/local_client.py +149 -0
- manta/clients/module_management_client.py +217 -0
- manta/clients/swarm_management_client.py +562 -0
- manta/clients/user_management_client.py +395 -0
- manta/clients/world_client.py +195 -0
- manta/light/__init__.py +31 -0
- manta/light/globals.py +245 -0
- manta/light/local.py +407 -0
- manta/light/logging_config.py +39 -0
- manta/light/path.py +116 -0
- manta/light/results.py +236 -0
- manta/light/task.py +100 -0
- manta/light/utils.py +217 -0
- manta/light/world.py +177 -0
- mantatech_sdk-0.5b0.dev65.dist-info/METADATA +1039 -0
- mantatech_sdk-0.5b0.dev65.dist-info/RECORD +54 -0
- mantatech_sdk-0.5b0.dev65.dist-info/WHEEL +5 -0
- mantatech_sdk-0.5b0.dev65.dist-info/entry_points.txt +2 -0
- mantatech_sdk-0.5b0.dev65.dist-info/licenses/LICENSE +683 -0
- mantatech_sdk-0.5b0.dev65.dist-info/top_level.txt +1 -0
manta/light/path.py
ADDED
|
@@ -0,0 +1,116 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import io
|
|
4
|
+
from pathlib import Path
|
|
5
|
+
from typing import List, Optional, Union
|
|
6
|
+
|
|
7
|
+
from manta_common.build.node.light_service import ProtoPath
|
|
8
|
+
from manta_common.event_loop import EventLoopManager
|
|
9
|
+
from .local import Local
|
|
10
|
+
|
|
11
|
+
# Module-level singleton: constructed once, with no arguments, when this module
# is imported. Every MantaPath stores a reference to this same object in
# ``self._local`` instead of building its own Local.
_shared_local = Local() # shared Local instance between files
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class MantaPath:
    """
    Path-like handle whose filesystem operations are delegated to ``Local``

    The pure-path helpers (``name``, ``stem``, ``suffix`` and the ``/``
    operator) are computed with :class:`pathlib.Path` on the stored string.
    Everything that touches the filesystem is forwarded to the module-level
    shared ``Local`` instance as a ``ProtoPath`` message.

    Parameters
    ----------
    path : str
        Path value, stored as given
    is_file : bool, optional
        Flag forwarded in the ``ProtoPath`` message, by default True
    """

    def __init__(self, path: str, is_file: bool = True):
        self._is_file = is_file
        self._path = path
        self._local = _shared_local
        self.loop_manager = EventLoopManager.get_instance()

    def _to_proto(self):
        """Build the ``ProtoPath`` message describing this path."""
        return ProtoPath(value=str(self._path), is_file=self._is_file)

    @property
    def name(self):
        """Final path component, as computed by :class:`pathlib.Path`."""
        return Path(self._path).name

    @property
    def stem(self):
        """Final path component without its last suffix."""
        return Path(self._path).stem

    @property
    def suffix(self):
        """Last suffix of the final path component (including the dot)."""
        return Path(self._path).suffix

    def __str__(self):
        return self._path

    def __repr__(self):
        return f"{type(self).__name__}({self._path!r}, is_file={self._is_file!r})"

    def __truediv__(self, other: Union[str, Path, MantaPath]) -> MantaPath:
        """
        Join this path with ``other``

        Parameters
        ----------
        other : Union[str, Path, MantaPath]
            Segment to append; joined with :class:`pathlib.Path` semantics

        Returns
        -------
        MantaPath
            New path. It is built with the default ``is_file=True``
            regardless of the flags of the operands.

        Raises
        ------
        TypeError
            If ``other`` is not a ``str``, ``Path`` or ``MantaPath``
        """
        if isinstance(other, MantaPath):
            other_path = Path(other._path)
        elif isinstance(other, (str, Path)):
            other_path = Path(other)
        else:
            raise TypeError(f"Unsupported type (found: {type(other)})")

        return MantaPath(str(Path(self._path) / other_path))

    async def async_exists(self) -> bool:
        """Asynchronously ask ``Local`` whether this path exists."""
        response = await self._local.async_exists(self._to_proto())
        return response

    def exists(self) -> bool:
        """Ask ``Local`` whether this path exists."""
        return self._local.exists(self._to_proto())

    async def async_read_bytes(self) -> io.BytesIO:
        """Asynchronously fetch the binary content through ``Local``."""
        return await self._local.async_get_binary_data(self._to_proto())

    def read_bytes(self):
        """
        Fetch the binary content through ``Local``

        Returns whatever ``Local.get_binary_data`` returns (presumably an
        ``io.BytesIO`` like the asynchronous variant -- to be confirmed
        against ``Local``).
        """
        return self._local.get_binary_data(self._to_proto())

    async def async_read_text(
        self,
        encoding: Optional[str] = None,
        errors: Optional[str] = None,
        newline: Optional[str] = None,
    ) -> io.StringIO:
        """
        Asynchronous variant of :meth:`read_text`

        Parameters
        ----------
        encoding : Optional[str], optional
            Specifies the encoding to use for decoding the file contents, by default None
        errors : Optional[str], optional
            Specifies how encoding/decoding errors should be handled, by default None
        newline : Optional[str], optional
            Specifies how newlines should be handled, by default None

        Returns
        -------
        io.StringIO
            Result of ``Local.async_read_file_lines`` for this path
        """
        return await self._local.async_read_file_lines(
            self._to_proto(), encoding, errors, newline
        )

    def read_text(
        self,
        encoding: Optional[str] = None,
        errors: Optional[str] = None,
        newline: Optional[str] = None,
    ) -> io.StringIO:
        """
        Read Text

        Parameters
        ----------
        encoding : Optional[str], optional
            Specifies the encoding to use for decoding the file contents, by default None
        errors : Optional[str], optional
            Specifies how encoding/decoding errors should be handled, by default None
        newline : Optional[str], optional
            Specifies how newlines (`\n`, `\r\n`, `\r`) should be handled, by default None

        Returns
        -------
        io.StringIO
            Result of ``Local.read_file_lines`` for this path
        """
        return self._local.read_file_lines(self._to_proto(), encoding, errors, newline)

    async def async_iterdir(self) -> List[MantaPath]:
        """
        Asynchronously list the entries of this directory

        Returns
        -------
        List[MantaPath]
            One ``MantaPath`` per entry of the ``paths`` field in the
            ``Local.async_list_dir`` response, keeping each entry's
            ``is_file`` flag
        """
        response = await self._local.async_list_dir(self._to_proto())
        return [MantaPath(p.value, is_file=p.is_file) for p in response.paths]

    def iterdir(self) -> List[MantaPath]:
        """Blocking wrapper that runs :meth:`async_iterdir` on the event loop manager."""
        return self.loop_manager.run_coroutine(self.async_iterdir())
|
|
103
|
+
|
|
104
|
+
# async def async_glob(self, pattern, *, case_sensitive=None, recurse_symlinks=False):
|
|
105
|
+
# request = Glob(self._to_proto(), pattern, case_sensitive, recurse_symlinks)
|
|
106
|
+
# response = await self._local.glob(request)
|
|
107
|
+
# return [MantaPath(p.value, is_file=p.is_file) for p in response.paths]
|
|
108
|
+
|
|
109
|
+
# def glob(self, pattern, *, case_sensitive=None, recurse_symlinks=False):
|
|
110
|
+
# return asyncio.run(
|
|
111
|
+
# self.async_glob(
|
|
112
|
+
# pattern,
|
|
113
|
+
# case_sensitive=case_sensitive,
|
|
114
|
+
# recurse_symlinks=recurse_symlinks,
|
|
115
|
+
# )
|
|
116
|
+
# )
|
manta/light/results.py
ADDED
|
@@ -0,0 +1,236 @@
|
|
|
1
|
+
import io
|
|
2
|
+
import logging
|
|
3
|
+
import os
|
|
4
|
+
from typing import AsyncIterable, Dict, Optional
|
|
5
|
+
|
|
6
|
+
from ..clients.world_client import WorldClient
|
|
7
|
+
from manta_common.build.common.results import ResultMethod
|
|
8
|
+
from manta_common.build.node.light_service import LightResult, LightResultQuery
|
|
9
|
+
from manta_common.const import CHUNK_SIZE
|
|
10
|
+
from manta_common.conversions import ID
|
|
11
|
+
from manta_common.event_loop import EventLoopManager
|
|
12
|
+
from .utils import bytes_to_dict, dict_to_bytes
|
|
13
|
+
|
|
14
|
+
__all__ = ["Results"]
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class Results:
    """
    Class for accessing results from the shared database or
    adding results into the shared database

    Parameters
    ----------
    world_client : Optional[WorldClient]
        Existing client to reuse. When None, a new ``WorldClient`` is
        created from ``host`` and ``port``.
    host : Optional[str]
        Manager host. Only used when ``world_client`` is None; falls back to
        the ``RPC_HOST`` environment variable, then ``"host.docker.internal"``.
    port : Optional[int]
        Manager port. Only used when ``world_client`` is None; falls back to
        the ``RPC_PORT`` environment variable, then ``50051``.
    swarm_id : Optional[ID]
        Swarm ID; falls back to ``ID(os.getenv("SWARM_ID"))``
    task_id : Optional[ID]
        Task ID; falls back to ``ID(os.getenv("TASK_ID"))``
    chunk_size : int
        Maximum number of bytes carried by each ``LightResult`` message
        when a result is uploaded
    """

    __slots__ = [
        "world_client",
        "swarm_id",
        "task_id",
        "logger",
        "chunk_size",
        "loop_manager",
    ]

    def __init__(
        self,
        world_client: Optional[WorldClient] = None,
        host: Optional[str] = None,
        port: Optional[int] = None,
        swarm_id: Optional[ID] = None,
        task_id: Optional[ID] = None,
        chunk_size: int = CHUNK_SIZE,
    ):
        # Retrieve env variables for RPC host and port
        if world_client is None:
            self.world_client = WorldClient(
                host=host or os.getenv("RPC_HOST", "host.docker.internal"),
                port=int(port or os.getenv("RPC_PORT", 50051)),
            )
        else:
            self.world_client = world_client

        self.task_id: ID = task_id or ID(os.getenv("TASK_ID"))
        self.swarm_id: ID = swarm_id or ID(os.getenv("SWARM_ID"))

        self.chunk_size = chunk_size
        self.logger = logging.getLogger(__name__)
        self.loop_manager = EventLoopManager.get_instance()

    def select(self, tag: str, size: int = -1, method: ResultMethod = ResultMethod.ALL):
        """
        Get the results of a Task

        Parameters
        ----------
        tag : str
            The tag of the result to get
        size : int, optional
            The number of results to get, by default -1
        method : ResultMethod, optional
            Method to use to select the results, by default ``ResultMethod.ALL``

        Returns
        -------
        dict
            Mapping from node identifier to the decoded result of that node
            (see :meth:`async_select`)

        Examples
        --------

        Inside a :class:`Task <manta_light.task.Task>` class, you can
        select results stored in the Manager database from the attribute
        :code:`self.world` automatically created by
        :class:`Task <manta_light.task.Task>`:

        >>> params = self.world.results.select("model_params")
        """
        return self.loop_manager.run_coroutine(self.async_select(tag, size, method))

    def add(self, tag: str, result: dict):
        """
        Set a result of a task

        Parameters
        ----------
        tag : str
            The tag of the result to set
        result : dict
            The result to set

        Examples
        --------

        Inside a :class:`Task <manta_light.task.Task>` class, you can
        add results to be stored in the Manager database from the
        attribute :code:`self.world` automatically created by
        :class:`Task <manta_light.task.Task>`:

        >>> self.world.results.add("metrics", metrics)
        """
        self.loop_manager.run_coroutine(self.async_add(tag, result))

    def __str__(self):  # pragma: no cover
        return f"Results(host={self.world_client.host}, port={self.world_client.port}, swarm_id={self.swarm_id}, task_id={self.task_id})"

    def __repr__(self):  # pragma: no cover
        return str(self)

    # Asynchronous methods

    async def async_select(
        self,
        tag: str,
        size: int = -1,
        method: ResultMethod = ResultMethod.ALL,
    ) -> dict:
        """
        Get the results of tasks

        Parameters
        ----------
        tag : str
            Tag of the result to get
        size : int, optional
            Number of results to get, by default -1 (same default as
            :meth:`select`)
        method : ResultMethod, optional
            The method to use to select the results, by default
            ``ResultMethod.ALL`` (same default as :meth:`select`)

        Returns
        -------
        dict
            Mapping from node identifier (the ``xid`` of each chunk's
            ``node_id``) to the dict decoded from that node's
            concatenated chunks

        Examples
        --------

        Same as :meth:`select <manta_light.results.Results.select>`
        but asynchronous:

        >>> params = await self.world.results.async_select("model_params")
        """
        # One buffer per node: chunks of a node are appended in arrival order
        buffer_dict: Dict[str, io.BytesIO] = {}

        query = LightResultQuery(
            task_id=self.task_id.oid,
            swarm_id=self.swarm_id.oid,
            tag=tag,
            size=size,
            method=method,
        )

        # Iterate over the chunks
        async for chunk in self.world_client.get_task_result(query):
            node_id = ID(chunk.node_id).xid
            if node_id not in buffer_dict:
                buffer_dict[node_id] = io.BytesIO()
            buffer_dict[node_id].write(chunk.data)

        self.logger.info("Results received")
        return {
            node_id: bytes_to_dict(buffer.getvalue())
            for node_id, buffer in buffer_dict.items()
        }

    async def chunked_light_result(
        self, request: LightResult
    ) -> AsyncIterable[LightResult]:
        """
        This function chunks the data into smaller pieces and yields LightResult messages.

        Each yielded message copies ``task_id``, ``swarm_id`` and ``tag``
        from ``request`` and carries at most ``self.chunk_size`` bytes of
        ``request.data``. If ``request.data`` is empty, nothing is yielded.

        Parameters
        ----------
        request : LightResult
            Request

        Returns
        -------
        AsyncIterable[LightResult]
            AsyncIterable of LightResult messages
        """
        data_stream = io.BytesIO(request.data)
        while chunk := data_stream.read(self.chunk_size):
            yield LightResult(
                task_id=request.task_id,
                swarm_id=request.swarm_id,
                tag=request.tag,
                data=chunk,
            )

    async def async_add(self, tag: str, result: dict):
        """
        Set a result of a task

        Parameters
        ----------
        tag : str
            Tag of the result to set
        result : dict
            Result to add

        Examples
        --------

        Same as :meth:`add <manta_light.results.Results.add>`
        but asynchronous:

        >>> await self.world.results.async_add("metrics", metrics)
        """
        self.logger.info("Setting result for tag: %s", tag)
        # Serialize once, then stream the payload to the world client in chunks
        payload = LightResult(
            task_id=self.task_id.oid,
            swarm_id=self.swarm_id.oid,
            tag=tag,
            data=dict_to_bytes(result),
        )
        await self.world_client.add_task_result(self.chunked_light_result(payload))
        self.logger.info("Set result response")
|
manta/light/task.py
ADDED
|
@@ -0,0 +1,100 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
import os
|
|
3
|
+
from abc import ABC
|
|
4
|
+
|
|
5
|
+
from manta_common.conversions import ID
|
|
6
|
+
from manta_common.event_loop import EventLoopManager
|
|
7
|
+
from .local import Local
|
|
8
|
+
from .world import World
|
|
9
|
+
|
|
10
|
+
__all__ = ["Task"]
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class Task(ABC):
    """
    Task abstract module.

    - Set the name of the task
    - Set the host and port for the RPC connection
    - Set the Task ID
    - Initialize the Local and World services
    - Initialize the logger

    Attributes
    ----------
    world: World
        For accessing and sending global data
    local: Local
        For accessing local data
    logger: logging.Logger
        Convenient logger whose records are collected and stored in Manager
    """

    __slots__ = [
        "name",
        "host",
        "port",
        "task_id",
        "swarm_id",
        "local",
        "world",
        "logger",
        "loop_manager",
        "_cleaned_up",
    ]

    def __init__(self):
        """Initialize the task from the RPC_HOST, RPC_PORT, TASK_ID and SWARM_ID environment variables."""
        # Set first so that cleanup()/__del__ behave even if the rest of
        # __init__ raises part-way through.
        self._cleaned_up = False

        self.logger = logging.getLogger(__name__)
        self.logger.debug("Initializing Task")

        # type(self) rather than the bare __class__ cell: __class__ always
        # refers to this base class, so every subclass used to get the same name.
        self.name: str = str(type(self))

        self.host = os.getenv("RPC_HOST", "host.docker.internal")
        self.port = int(os.getenv("RPC_PORT", 50051))
        self.task_id = ID(os.getenv("TASK_ID"))
        self.swarm_id = ID(os.getenv("SWARM_ID"))

        self.local = Local(
            host=self.host, port=self.port, swarm_id=self.swarm_id, task_id=self.task_id
        )
        self.world = World(
            host=self.host, port=self.port, swarm_id=self.swarm_id, task_id=self.task_id
        )

        self.loop_manager = EventLoopManager.get_instance()

    def __str__(self):  # pragma: no cover
        return f"Task(name={self.name}, host={self.host}, port={self.port}, swarm_id={self.swarm_id}, task_id={self.task_id})"

    def __repr__(self):  # pragma: no cover
        return str(self)

    def cleanup(self) -> None:
        """Clean up any resources used by the task.

        This method should be called when the task is done to ensure all
        resources are properly released. It's especially important to
        close any open event loops to prevent resource leaks.

        The three steps (disconnect the local client, disconnect the world
        client, close the event loop manager) are attempted independently:
        a failure in one is logged and does not prevent the others from
        running. Calling this method again is a no-op, so an explicit call
        followed by the destructor does not repeat the work.
        """
        if getattr(self, "_cleaned_up", False):
            return
        self._cleaned_up = True

        self.logger.info("Cleaning up task resources")

        # NOTE(review): loop_manager comes from EventLoopManager.get_instance();
        # if that is a process-wide singleton, closing it here also affects
        # other objects holding the same manager -- confirm this is intended.
        steps = (
            (
                "disconnecting local client",
                lambda: self.loop_manager.run_coroutine(
                    self.local.local_client.disconnect()
                ),
            ),
            (
                "disconnecting world client",
                lambda: self.loop_manager.run_coroutine(
                    self.world.world_client.disconnect()
                ),
            ),
            ("closing event loop", lambda: self.loop_manager.close()),
        )

        loop_closed = False
        for description, action in steps:
            try:
                action()
            except Exception as e:
                self.logger.error(f"Error {description}: {e}")
            else:
                loop_closed = description == "closing event loop"

        if loop_closed:
            self.logger.info("Event loop closed successfully")

    def __del__(self) -> None:
        """Destructor to ensure cleanup is called."""
        try:
            self.cleanup()
        except Exception as e:
            # Log the exception before suppressing to avoid silent failures
            try:
                self.logger.error(f"Error during task cleanup in destructor: {e}")
            except Exception:
                # If logging fails, at least print to stderr
                import sys

                print(f"Task cleanup error: {e}", file=sys.stderr)
|
manta/light/utils.py
ADDED
|
@@ -0,0 +1,217 @@
|
|
|
1
|
+
import io
|
|
2
|
+
from typing import Union
|
|
3
|
+
|
|
4
|
+
import numpy as np
|
|
5
|
+
|
|
6
|
+
from manta_common.conversions import bytes_to_dict, dict_to_bytes
|
|
7
|
+
|
|
8
|
+
__all__ = [
|
|
9
|
+
"dict_to_bytes",
|
|
10
|
+
"bytes_to_dict",
|
|
11
|
+
"numpy_to_bytes",
|
|
12
|
+
"bytes_to_numpy",
|
|
13
|
+
"torchmodel_to_bytes",
|
|
14
|
+
"bytes_to_torchmodel",
|
|
15
|
+
]
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def numpy_to_bytes(data: Union[np.ndarray, list, dict]) -> Union[bytes, dict]:
    """
    Recursive function which converts a numpy array or a container of numpy arrays to bytes

    Parameters
    ----------
    data : Union[np.ndarray, list, tuple, dict]
        The numpy array to convert. A list or tuple is handed to
        :func:`numpy.save` as a whole, so it is stored as one array.
        A dict is converted value by value, recursively.

    Returns
    -------
    Union[bytes, dict]
        The ``.npy`` bytes representation of the array, or, for a dict
        input, a dict with the same keys whose values are converted

    Raises
    ------
    ValueError
        If ``data`` (or a nested dict value) has an unsupported type

    Examples
    --------

    * From :code:`np.array`

    >>> import numpy as np
    >>> from manta_light.utils import bytes_to_numpy, numpy_to_bytes
    >>> np_array = np.array([1, 2, 3])
    >>> numpy_to_bytes(np_array)

    * From :code:`Dict[str, np.array]`

    >>> import numpy as np
    >>> from manta_light.utils import bytes_to_numpy, numpy_to_bytes
    >>> dict_np_array = {"key1": np.array([1, 2, 3]), "key2": np.array([4, 5, 6])}
    >>> numpy_to_bytes(dict_np_array)

    * From :code:`Dict[str, list]`

    >>> import numpy as np
    >>> from manta_light.utils import bytes_to_numpy, numpy_to_bytes
    >>> dict_list = {"key1": [1, 2, 3], "key2": [4, 5, 6]}
    >>> numpy_to_bytes(dict_list)

    * From :code:`List[np.array]`

    >>> import numpy as np
    >>> from manta_light.utils import bytes_to_numpy, numpy_to_bytes
    >>> list_np_array = [np.array([1, 2, 3]), np.array([4, 5, 6])]
    >>> numpy_to_bytes(list_np_array)
    """
    if isinstance(data, dict):
        return {key: numpy_to_bytes(value) for key, value in data.items()}

    if isinstance(data, (np.ndarray, list, tuple)):
        buffer = io.BytesIO()
        np.save(buffer, data)
        return buffer.getvalue()

    raise ValueError(f"Unsupported type: {type(data)}")
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
def bytes_to_numpy(b: Union[bytes, dict]) -> Union[np.ndarray, dict]:
    """
    Convert bytes to a numpy array

    Parameters
    ----------
    b : Union[bytes, bytearray, memoryview, dict]
        The bytes to convert (``.npy`` content as produced by
        :func:`numpy_to_bytes`), or a dict whose values are converted
        recursively

    Returns
    -------
    Union[np.ndarray, dict]
        The numpy array representation of the bytes, or, for a dict input,
        a dict with the same keys whose values are converted

    Raises
    ------
    ValueError
        If ``b`` (or a nested dict value) has an unsupported type

    Examples
    --------

    * From :code:`np.array`

    >>> import numpy as np
    >>> from manta_light.utils import bytes_to_numpy, numpy_to_bytes
    >>> np_array = np.array([1, 2, 3])
    >>> np_bytes = numpy_to_bytes(np_array)
    >>> bytes_to_numpy(np_bytes)

    * From :code:`Dict[str, np.array]`

    >>> import numpy as np
    >>> from manta_light.utils import bytes_to_numpy, numpy_to_bytes
    >>> dict_np_array = {"key1": np.array([1, 2, 3]), "key2": np.array([4, 5, 6])}
    >>> np_bytes = numpy_to_bytes(dict_np_array)
    >>> bytes_to_numpy(np_bytes)

    * From :code:`Dict[str, list]`

    >>> import numpy as np
    >>> from manta_light.utils import bytes_to_numpy, numpy_to_bytes
    >>> dict_list = {"key1": [1, 2, 3], "key2": [4, 5, 6]}
    >>> np_bytes = numpy_to_bytes(dict_list)
    >>> bytes_to_numpy(np_bytes)

    * From :code:`List[np.array]`

    >>> import numpy as np
    >>> from manta_light.utils import bytes_to_numpy, numpy_to_bytes
    >>> list_np_array = [np.array([1, 2, 3]), np.array([4, 5, 6])]
    >>> np_bytes = numpy_to_bytes(list_np_array)
    >>> bytes_to_numpy(np_bytes)
    """
    if isinstance(b, dict):
        return {key: bytes_to_numpy(value) for key, value in b.items()}

    # io.BytesIO accepts any bytes-like object, so bytearray and memoryview
    # payloads are loaded the same way as bytes.
    if isinstance(b, (bytes, bytearray, memoryview)):
        # np.load is called with its defaults here (no allow_pickle override).
        return np.load(io.BytesIO(b))

    raise ValueError(f"Unsupported type: {type(b)}")
|
|
133
|
+
|
|
134
|
+
|
|
135
|
+
def torchmodel_to_bytes(model: "torch.nn.Module") -> bytes:  # type: ignore # noqa: F821
    """
    Serialize a torch model into raw bytes

    The model object is written with :func:`torch.save` into an in-memory
    stream and the content of that stream is returned.

    Parameters
    ----------
    model : "torch.nn.Module"
        Torch model to serialize

    Returns
    -------
    bytes
        Serialized form of the torch model

    Examples
    --------

    >>> from torch.nn import Linear, ReLU, Sequential
    >>> from manta_light.utils import bytes_to_torchmodel, torchmodel_to_bytes
    >>> torch_model = Sequential(
    ...     Linear(3, 2), ReLU(), Linear(2, 1), ReLU(), Linear(1, 1), ReLU()
    ... )
    >>> torchmodel_to_bytes(torch_model)
    """
    # Imported lazily so that torch is only required when this helper is used
    import torch

    with io.BytesIO() as stream:
        torch.save(model, stream)
        serialized = stream.getvalue()
    return serialized
|
|
164
|
+
|
|
165
|
+
|
|
166
|
+
def bytes_to_torchmodel(b: bytes, *, allow_unsafe_fallback: bool = True) -> "torch.nn.Module":  # type: ignore # noqa: F821
    """
    Transform bytes to torch model

    Parameters
    ----------
    b : bytes
        Bytes from a torch model
    allow_unsafe_fallback : bool, optional
        When True (the default, matching the historical behavior), a failed
        ``weights_only=True`` load is retried with ``weights_only=False``
        after emitting a ``UserWarning``. When False, the original loading
        error is re-raised and no unrestricted unpickling is attempted.

    Returns
    -------
    "torch.nn.Module"
        Torch model (or whatever object was serialized in ``b``)

    Raises
    ------
    RuntimeError, TypeError, pickle.UnpicklingError
        If the restricted load fails and ``allow_unsafe_fallback`` is False

    Examples
    --------

    >>> from torch.nn import Linear, ReLU, Sequential
    >>> from manta_light.utils import bytes_to_torchmodel, torchmodel_to_bytes
    >>> torch_model = Sequential(
    ...     Linear(3, 2), ReLU(), Linear(2, 1), ReLU(), Linear(1, 1), ReLU()
    ... )
    >>> model_bytes = torchmodel_to_bytes(torch_model)
    >>> bytes_to_torchmodel(model_bytes)

    Security Note
    -------------
    This function first tries torch.load with weights_only=True. If that
    fails and ``allow_unsafe_fallback`` is True, it falls back to
    weights_only=False, which can execute arbitrary code. Only use the
    fallback with trusted model sources. In the Manta platform, models
    should only come from authenticated users and trusted containers.
    """
    import pickle
    import warnings

    import torch

    buffer = io.BytesIO(b)

    # Try to load with weights_only=True first (safer)
    try:
        return torch.load(buffer, weights_only=True)
    except (RuntimeError, TypeError, pickle.UnpicklingError):
        if not allow_unsafe_fallback:
            # Caller opted out of unrestricted unpickling: surface the
            # original error instead of silently widening trust.
            raise

        # SECURITY: the fallback below unpickles arbitrary objects. A payload
        # that the restricted loader rejects ends up on this path, so the
        # restricted attempt offers no protection when the fallback is
        # enabled. It is kept as the default for backward compatibility,
        # because it is necessary for complex models with custom layers.
        buffer.seek(0)  # Reset buffer position
        warnings.warn(
            "Loading PyTorch model with weights_only=False. "
            "This can execute arbitrary code. Ensure the model source is trusted.",
            UserWarning,
            stacklevel=2,
        )
        return torch.load(buffer, weights_only=False)
|