mantatech-sdk 0.5b0.dev65__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (54) hide show
  1. manta/__init__.light.py +22 -0
  2. manta/__init__.py +83 -0
  3. manta/__main__.py +21 -0
  4. manta/apis/__init__.py +7 -0
  5. manta/apis/async_user_api.py +6458 -0
  6. manta/apis/graph.py +498 -0
  7. manta/apis/module.py +316 -0
  8. manta/apis/results.py +251 -0
  9. manta/apis/swarm.py +206 -0
  10. manta/apis/user_api.py +1016 -0
  11. manta/cli/__init__.py +1 -0
  12. manta/cli/commands/__init__.py +1 -0
  13. manta/cli/commands/base_handler.py +229 -0
  14. manta/cli/commands/doc.py +192 -0
  15. manta/cli/commands/install.py +346 -0
  16. manta/cli/commands/sdk.py +9 -0
  17. manta/cli/commands/sdk_cluster.py +211 -0
  18. manta/cli/commands/sdk_config.py +347 -0
  19. manta/cli/commands/sdk_globals.py +280 -0
  20. manta/cli/commands/sdk_logs.py +174 -0
  21. manta/cli/commands/sdk_main.py +167 -0
  22. manta/cli/commands/sdk_module.py +516 -0
  23. manta/cli/commands/sdk_nodes.py +168 -0
  24. manta/cli/commands/sdk_original.py +3873 -0
  25. manta/cli/commands/sdk_results.py +265 -0
  26. manta/cli/commands/sdk_swarm.py +454 -0
  27. manta/cli/commands/sdk_user.py +234 -0
  28. manta/cli/commands/status.py +292 -0
  29. manta/cli/component_detector.py +112 -0
  30. manta/cli/config_manager.py +445 -0
  31. manta/cli/main.py +265 -0
  32. manta/cli/utils/__init__.py +27 -0
  33. manta/cli/utils/converters.py +140 -0
  34. manta/clients/cluster_management_client.py +486 -0
  35. manta/clients/local_client.py +149 -0
  36. manta/clients/module_management_client.py +217 -0
  37. manta/clients/swarm_management_client.py +562 -0
  38. manta/clients/user_management_client.py +395 -0
  39. manta/clients/world_client.py +195 -0
  40. manta/light/__init__.py +31 -0
  41. manta/light/globals.py +245 -0
  42. manta/light/local.py +407 -0
  43. manta/light/logging_config.py +39 -0
  44. manta/light/path.py +116 -0
  45. manta/light/results.py +236 -0
  46. manta/light/task.py +100 -0
  47. manta/light/utils.py +217 -0
  48. manta/light/world.py +177 -0
  49. mantatech_sdk-0.5b0.dev65.dist-info/METADATA +1039 -0
  50. mantatech_sdk-0.5b0.dev65.dist-info/RECORD +54 -0
  51. mantatech_sdk-0.5b0.dev65.dist-info/WHEEL +5 -0
  52. mantatech_sdk-0.5b0.dev65.dist-info/entry_points.txt +2 -0
  53. mantatech_sdk-0.5b0.dev65.dist-info/licenses/LICENSE +683 -0
  54. mantatech_sdk-0.5b0.dev65.dist-info/top_level.txt +1 -0
manta/light/path.py ADDED
@@ -0,0 +1,116 @@
1
+ from __future__ import annotations
2
+
3
+ import io
4
+ from pathlib import Path
5
+ from typing import List, Optional, Union
6
+
7
+ from manta_common.build.node.light_service import ProtoPath
8
+ from manta_common.event_loop import EventLoopManager
9
+ from .local import Local
10
+
11
+ _shared_local = Local() # shared Local instance between files
12
+
13
+
14
class MantaPath:
    """
    Path-like handle whose filesystem operations go through ``Local``.

    Pure path manipulation (``name``, ``stem``, ``suffix``, ``/``) is done
    in-process with :class:`pathlib.Path`. Every operation that needs the
    actual filesystem (``exists``, ``read_bytes``, ``read_text``,
    ``iterdir``) is delegated to the module-level shared ``Local`` instance.

    Parameters
    ----------
    path : str
        Path value; it is stringified before being sent as ``ProtoPath.value``
    is_file : bool, optional
        Flag forwarded as ``ProtoPath.is_file``, by default True
    """

    def __init__(self, path: str, is_file: bool = True):
        self._is_file = is_file
        self._path = path
        # One Local instance is shared by every MantaPath of the process.
        self._local = _shared_local
        self.loop_manager = EventLoopManager.get_instance()

    def _to_proto(self):
        """Build the ``ProtoPath`` message describing this path."""
        return ProtoPath(value=str(self._path), is_file=self._is_file)

    @property
    def name(self):
        """Final path component, as computed by :class:`pathlib.Path`."""
        return Path(self._path).name

    @property
    def stem(self):
        """Final path component without its last suffix."""
        return Path(self._path).stem

    @property
    def suffix(self):
        """Last suffix of the final path component (including the dot)."""
        return Path(self._path).suffix

    def __str__(self):
        # str() is required: returning a non-str object (e.g. a pathlib.Path
        # given to the constructor) from __str__ raises TypeError.
        return str(self._path)

    def __repr__(self):
        return f"MantaPath({str(self._path)!r}, is_file={self._is_file!r})"

    def __truediv__(self, other: Union[str, "MantaPath"]) -> "MantaPath":
        """
        Join this path with ``other``

        Parameters
        ----------
        other : Union[str, MantaPath]
            Right-hand path segment

        Returns
        -------
        MantaPath
            The joined path. When ``other`` is a :class:`MantaPath`, its
            file/directory flag is kept; a plain string is assumed to
            designate a file.

        Raises
        ------
        TypeError
            If ``other`` is neither a string nor a :class:`MantaPath`
        """
        if isinstance(other, MantaPath):
            # The joined path designates whatever the right operand
            # designates, so its flag must not be reset to the default.
            joined = Path(self._path) / Path(other._path)
            return MantaPath(str(joined), is_file=other._is_file)
        if isinstance(other, str):
            return MantaPath(str(Path(self._path) / Path(other)))
        raise TypeError(f"Unsupported type (found: {type(other)})")

    async def async_exists(self) -> bool:
        """Asynchronous variant of :meth:`exists`."""
        return await self._local.async_exists(self._to_proto())

    def exists(self) -> bool:
        """Ask ``Local`` whether this path exists."""
        return self._local.exists(self._to_proto())

    async def async_read_bytes(self) -> io.BytesIO:
        """Asynchronous variant of :meth:`read_bytes`."""
        return await self._local.async_get_binary_data(self._to_proto())

    def read_bytes(self) -> io.BytesIO:
        """Fetch the binary content of this path through ``Local``."""
        return self._local.get_binary_data(self._to_proto())

    async def async_read_text(
        self,
        encoding: Optional[str] = None,
        errors: Optional[str] = None,
        newline: Optional[str] = None,
    ) -> io.StringIO:
        """Asynchronous variant of :meth:`read_text` (same parameters)."""
        return await self._local.async_read_file_lines(
            self._to_proto(), encoding, errors, newline
        )

    def read_text(
        self,
        encoding: Optional[str] = None,
        errors: Optional[str] = None,
        newline: Optional[str] = None,
    ) -> io.StringIO:
        r"""
        Read the text content of this path through ``Local``

        Parameters
        ----------
        encoding : Optional[str], optional
            Specifies the encoding to use for decoding the file contents, by default None
        errors : Optional[str], optional
            Specifies how encoding/decoding errors should be handled, by default None
        newline : Optional[str], optional
            Specifies how newlines (``\n``, ``\r\n``, ``\r``) should be handled, by default None

        Returns
        -------
        io.StringIO
            Whatever ``Local.read_file_lines`` returns for this path
        """
        return self._local.read_file_lines(self._to_proto(), encoding, errors, newline)

    async def async_iterdir(self) -> List["MantaPath"]:
        """Asynchronous variant of :meth:`iterdir`."""
        response = await self._local.async_list_dir(self._to_proto())
        return [MantaPath(p.value, is_file=p.is_file) for p in response.paths]

    def iterdir(self) -> List["MantaPath"]:
        """List the entries of this directory as :class:`MantaPath` objects."""
        return self.loop_manager.run_coroutine(self.async_iterdir())
103
+
104
+ # async def async_glob(self, pattern, *, case_sensitive=None, recurse_symlinks=False):
105
+ # request = Glob(self._to_proto(), pattern, case_sensitive, recurse_symlinks)
106
+ # response = await self._local.glob(request)
107
+ # return [MantaPath(p.value, is_file=p.is_file) for p in response.paths]
108
+
109
+ # def glob(self, pattern, *, case_sensitive=None, recurse_symlinks=False):
110
+ # return asyncio.run(
111
+ # self.async_glob(
112
+ # pattern,
113
+ # case_sensitive=case_sensitive,
114
+ # recurse_symlinks=recurse_symlinks,
115
+ # )
116
+ # )
manta/light/results.py ADDED
@@ -0,0 +1,236 @@
1
+ import io
2
+ import logging
3
+ import os
4
+ from typing import AsyncIterable, Dict, Optional
5
+
6
+ from ..clients.world_client import WorldClient
7
+ from manta_common.build.common.results import ResultMethod
8
+ from manta_common.build.node.light_service import LightResult, LightResultQuery
9
+ from manta_common.const import CHUNK_SIZE
10
+ from manta_common.conversions import ID
11
+ from manta_common.event_loop import EventLoopManager
12
+ from .utils import bytes_to_dict, dict_to_bytes
13
+
14
+ __all__ = ["Results"]
15
+
16
+
17
class Results:
    """
    Class for accessing results from the shared database or
    adding results into the shared database

    Parameters
    ----------
    world_client : Optional[WorldClient]
        Already-built client to reuse; when omitted a new one is created
        from ``host`` and ``port``
    host : Optional[str]
        Manager host (falls back to the ``RPC_HOST`` environment variable)
    port : Optional[int]
        Manager port (falls back to the ``RPC_PORT`` environment variable)
    swarm_id : Optional[ID]
        Swarm ID (falls back to the ``SWARM_ID`` environment variable)
    task_id : Optional[ID]
        Task ID (falls back to the ``TASK_ID`` environment variable)
    chunk_size : int
        Maximum number of bytes carried by one streamed message
    """

    __slots__ = [
        "world_client",
        "swarm_id",
        "task_id",
        "logger",
        "chunk_size",
        "loop_manager",
    ]

    def __init__(
        self,
        world_client: Optional[WorldClient] = None,
        host: Optional[str] = None,
        port: Optional[int] = None,
        swarm_id: Optional[ID] = None,
        task_id: Optional[ID] = None,
        chunk_size: int = CHUNK_SIZE,
    ):
        # Explicit arguments win; the environment is only a fallback.
        if world_client is None:
            self.world_client = WorldClient(
                host=host or os.getenv("RPC_HOST", "host.docker.internal"),
                port=int(port or os.getenv("RPC_PORT", 50051)),
            )
        else:
            self.world_client = world_client

        # NOTE(review): when the variable is unset, ID() receives None; how
        # ID handles that is not visible from this module.
        self.task_id: ID = task_id or ID(os.getenv("TASK_ID"))
        self.swarm_id: ID = swarm_id or ID(os.getenv("SWARM_ID"))

        self.chunk_size = chunk_size
        self.logger = logging.getLogger(__name__)
        self.loop_manager = EventLoopManager.get_instance()

    def select(
        self, tag: str, size: int = -1, method: ResultMethod = ResultMethod.ALL
    ) -> dict:
        """
        Get the results of a Task

        Parameters
        ----------
        tag : str
            The tag of the result to get
        size : int, optional
            The number of results to get, by default -1
        method : ResultMethod, optional
            Method to use to select the results, by default ``ResultMethod.ALL``

        Returns
        -------
        dict
            Decoded results keyed by node ID

        Examples
        --------

        Inside a :class:`Task <manta.light.task.Task>` class, you can
        select results stored in the Manager database from the attribute
        :code:`self.world` automatically created by
        :class:`Task <manta.light.task.Task>`:

        >>> params = self.world.results.select("model_params")
        """
        return self.loop_manager.run_coroutine(self.async_select(tag, size, method))

    def add(self, tag: str, result: dict):
        """
        Set a result of a task

        Parameters
        ----------
        tag : str
            The tag of the result to set
        result : dict
            The result to set

        Examples
        --------

        Inside a :class:`Task <manta.light.task.Task>` class, you can
        add results to be stored in the Manager database from the
        attribute :code:`self.world` automatically created by
        :class:`Task <manta.light.task.Task>`:

        >>> self.world.results.add("metrics", metrics)
        """
        self.loop_manager.run_coroutine(self.async_add(tag, result))

    def __str__(self):  # pragma: no cover
        return f"Results(host={self.world_client.host}, port={self.world_client.port}, swarm_id={self.swarm_id}, task_id={self.task_id})"

    def __repr__(self):  # pragma: no cover
        return str(self)

    # Asynchronous methods

    async def async_select(
        self, tag: str, size: int = -1, method: ResultMethod = ResultMethod.ALL
    ) -> dict:
        """
        Get the results of tasks

        Parameters
        ----------
        tag : str
            Tag of the result to get
        size : int, optional
            Number of results to get, by default -1
        method : ResultMethod, optional
            The method to use to select the results, by default
            ``ResultMethod.ALL``

        Returns
        -------
        dict
            Decoded results keyed by node ID

        Examples
        --------

        Same as :meth:`select <manta.light.results.Results.select>`
        but asynchronous:

        >>> params = await self.world.results.async_select("model_params")
        """
        # One buffer per node: chunks are appended in arrival order.
        buffer_dict: Dict[str, io.BytesIO] = {}

        query = LightResultQuery(
            task_id=self.task_id.oid,
            swarm_id=self.swarm_id.oid,
            tag=tag,
            size=size,
            method=method,
        )
        async for chunk in self.world_client.get_task_result(query):
            node_id = ID(chunk.node_id).xid
            if node_id not in buffer_dict:
                buffer_dict[node_id] = io.BytesIO()
            buffer_dict[node_id].write(chunk.data)

        self.logger.info("Results received")
        # Decode only once every chunk of every node has been received.
        return {
            node_id: bytes_to_dict(buffer.getvalue())
            for node_id, buffer in buffer_dict.items()
        }

    async def chunked_light_result(
        self, request: LightResult
    ) -> AsyncIterable[LightResult]:
        """
        This function chunks the data into smaller pieces and yields LightResult messages.

        Parameters
        ----------
        request : LightResult
            Message whose ``data`` field is split into pieces of at most
            ``chunk_size`` bytes

        Returns
        -------
        AsyncIterable[LightResult]
            One message per piece, each carrying the ``task_id``,
            ``swarm_id`` and ``tag`` of ``request``. Nothing is yielded when
            ``request.data`` is empty.
        """
        data_stream = io.BytesIO(request.data)
        while chunk := data_stream.read(self.chunk_size):
            yield LightResult(
                task_id=request.task_id,
                swarm_id=request.swarm_id,
                tag=request.tag,
                data=chunk,
            )

    async def async_add(self, tag: str, result: dict):
        """
        Set a result of a task

        Parameters
        ----------
        tag : str
            Tag of the result to set
        result : dict
            Result to add

        Examples
        --------

        Same as :meth:`add <manta.light.results.Results.add>`
        but asynchronous:

        >>> await self.world.results.async_add("metrics", metrics)
        """
        self.logger.info(f"Setting result for tag: {tag}")
        # The whole result is serialized once, then streamed chunk by chunk.
        message = LightResult(
            task_id=self.task_id.oid,
            swarm_id=self.swarm_id.oid,
            tag=tag,
            data=dict_to_bytes(result),
        )
        await self.world_client.add_task_result(self.chunked_light_result(message))
        self.logger.info("Set result response")
manta/light/task.py ADDED
@@ -0,0 +1,100 @@
1
+ import logging
2
+ import os
3
+ from abc import ABC
4
+
5
+ from manta_common.conversions import ID
6
+ from manta_common.event_loop import EventLoopManager
7
+ from .local import Local
8
+ from .world import World
9
+
10
+ __all__ = ["Task"]
11
+
12
+
13
class Task(ABC):
    """
    Task abstract module.

    - Set the name of the task
    - Set the host and port for the RPC connection
    - Set the Task ID
    - Initialize the Local and World services
    - Initialize the logger

    Attributes
    ----------
    world: World
        For accessing and sending global data
    local: Local
        For accessing local data
    logger: logging.Logger
        Convenient logger which are collected and stored in Manager
    """

    __slots__ = [
        "name",
        "host",
        "port",
        "task_id",
        "swarm_id",
        "local",
        "world",
        "logger",
        "loop_manager",
        "_cleaned_up",
    ]

    def __init__(self):
        """Initialize the task from the ``RPC_HOST``, ``RPC_PORT``,
        ``TASK_ID`` and ``SWARM_ID`` environment variables."""
        self.logger = logging.getLogger(__name__)
        self.logger.debug("Initializing Task")
        self._cleaned_up = False

        # type(self) is the concrete subclass; the implicit ``__class__``
        # cell would always be ``Task`` itself, whatever the subclass.
        self.name: str = type(self).__name__

        self.host = os.getenv("RPC_HOST", "host.docker.internal")
        self.port = int(os.getenv("RPC_PORT", 50051))
        self.task_id = ID(os.getenv("TASK_ID"))
        self.swarm_id = ID(os.getenv("SWARM_ID"))

        self.local = Local(
            host=self.host, port=self.port, swarm_id=self.swarm_id, task_id=self.task_id
        )
        self.world = World(
            host=self.host, port=self.port, swarm_id=self.swarm_id, task_id=self.task_id
        )

        # Assigned last: its presence means local and world both exist.
        self.loop_manager = EventLoopManager.get_instance()

    def __str__(self):  # pragma: no cover
        return f"Task(name={self.name}, host={self.host}, port={self.port}, swarm_id={self.swarm_id}, task_id={self.task_id})"

    def __repr__(self):  # pragma: no cover
        return str(self)

    def cleanup(self) -> None:
        """Clean up any resources used by the task.

        This method should be called when the task is done to ensure all
        resources are properly released. It disconnects the local and world
        clients, then closes the event loop manager.

        The call is idempotent (the destructor calls it again) and each step
        is attempted independently, so one failing disconnect does not
        prevent the remaining resources from being released. Errors are
        logged, never raised.
        """
        if getattr(self, "_cleaned_up", False):
            return
        self._cleaned_up = True

        self.logger.info("Cleaning up task resources")

        loop_manager = getattr(self, "loop_manager", None)
        if loop_manager is None:
            # __init__ did not run to completion: nothing was opened.
            return

        try:
            loop_manager.run_coroutine(self.local.local_client.disconnect())
        except Exception as e:
            self.logger.error(f"Error disconnecting local client: {e}")

        try:
            loop_manager.run_coroutine(self.world.world_client.disconnect())
        except Exception as e:
            self.logger.error(f"Error disconnecting world client: {e}")

        # NOTE(review): the manager comes from EventLoopManager.get_instance();
        # if it is shared process-wide, closing it here affects every other
        # holder. Confirm this is intended.
        try:
            loop_manager.close()
            self.logger.info("Event loop closed successfully")
        except Exception as e:
            self.logger.error(f"Error closing event loop: {e}")

    def __del__(self) -> None:
        """Destructor to ensure cleanup is called."""
        try:
            self.cleanup()
        except Exception as e:
            # Log the exception before suppressing to avoid silent failures
            try:
                self.logger.error(f"Error during task cleanup in destructor: {e}")
            except Exception:
                # If logging fails, at least print to stderr
                import sys

                print(f"Task cleanup error: {e}", file=sys.stderr)
manta/light/utils.py ADDED
@@ -0,0 +1,217 @@
1
+ import io
2
+ from typing import Union
3
+
4
+ import numpy as np
5
+
6
+ from manta_common.conversions import bytes_to_dict, dict_to_bytes
7
+
8
+ __all__ = [
9
+ "dict_to_bytes",
10
+ "bytes_to_dict",
11
+ "numpy_to_bytes",
12
+ "bytes_to_numpy",
13
+ "torchmodel_to_bytes",
14
+ "bytes_to_torchmodel",
15
+ ]
16
+
17
+
18
def numpy_to_bytes(data: Union[np.ndarray, list, dict]) -> Union[bytes, dict]:
    """
    Serialize a numpy array, a list, or a (nested) dict of those to bytes

    Arrays and lists are written with :func:`numpy.save`. A dict is
    converted value by value, recursively, and keeps its keys.

    Parameters
    ----------
    data : Union[np.ndarray, list, dict]
        The array, list, or dict of arrays/lists to convert

    Returns
    -------
    Union[bytes, dict]
        The ``.npy`` payload, or a dict with the same keys whose values
        are payloads (or nested dicts)

    Raises
    ------
    ValueError
        If ``data`` (or a nested value) is not an array, a list or a dict

    Examples
    --------

    >>> import numpy as np
    >>> from manta.light.utils import numpy_to_bytes
    >>> payload = numpy_to_bytes(np.array([1, 2, 3]))
    >>> payloads = numpy_to_bytes({"key1": np.array([1, 2, 3]), "key2": [4, 5, 6]})
    >>> payload = numpy_to_bytes([np.array([1, 2, 3]), np.array([4, 5, 6])])
    """
    if isinstance(data, dict):
        return {name: numpy_to_bytes(item) for name, item in data.items()}

    if not isinstance(data, (np.ndarray, list)):
        raise ValueError(f"Unsupported type: {type(data)}")

    with io.BytesIO() as stream:
        np.save(stream, data)
        return stream.getvalue()
73
+
74
+
75
def bytes_to_numpy(b: Union[bytes, dict]) -> Union[np.ndarray, dict]:
    """
    Rebuild numpy arrays from bytes, or from a (nested) dict of bytes

    A ``bytes`` payload is read with :func:`numpy.load`. A dict is
    converted value by value, recursively, and keeps its keys.

    Parameters
    ----------
    b : Union[bytes, dict]
        A ``.npy`` payload, or a dict whose values are payloads
        (or nested dicts)

    Returns
    -------
    Union[np.ndarray, dict]
        The decoded array, or a dict of decoded arrays with the same keys

    Raises
    ------
    ValueError
        If ``b`` (or a nested value) is neither ``bytes`` nor a dict

    Examples
    --------

    >>> import numpy as np
    >>> from manta.light.utils import bytes_to_numpy, numpy_to_bytes
    >>> bytes_to_numpy(numpy_to_bytes(np.array([1, 2, 3])))
    >>> bytes_to_numpy(numpy_to_bytes({"key1": np.array([1, 2, 3]), "key2": [4, 5, 6]}))
    """
    if isinstance(b, dict):
        return {name: bytes_to_numpy(blob) for name, blob in b.items()}

    if isinstance(b, bytes):
        return np.load(io.BytesIO(b))

    raise ValueError(f"Unsupported type: {type(b)}")
133
+
134
+
135
def torchmodel_to_bytes(model: "torch.nn.Module") -> bytes:  # type: ignore # noqa: F821
    """
    Serialize a torch model into bytes

    The whole model object is written with :func:`torch.save` into an
    in-memory buffer.

    Parameters
    ----------
    model : "torch.nn.Module"
        Torch model

    Returns
    -------
    bytes
        Content of the buffer written by ``torch.save``

    Examples
    --------

    >>> from torch.nn import Linear, ReLU, Sequential
    >>> from manta.light.utils import torchmodel_to_bytes
    >>> torch_model = Sequential(Linear(3, 2), ReLU(), Linear(2, 1), ReLU())
    >>> model_bytes = torchmodel_to_bytes(torch_model)
    """
    # Imported lazily so that the module can be used without torch installed.
    import torch

    with io.BytesIO() as stream:
        torch.save(model, stream)
        serialized = stream.getvalue()
    return serialized
164
+
165
+
166
def bytes_to_torchmodel(
    b: bytes, *, allow_unsafe_fallback: bool = True
) -> "torch.nn.Module":  # type: ignore # noqa: F821
    """
    Transform bytes to torch model

    Parameters
    ----------
    b : bytes
        Bytes from a torch model
    allow_unsafe_fallback : bool, optional
        When True (default, historical behaviour), a payload rejected by the
        restricted loader is loaded again with ``weights_only=False``. When
        False, the error raised by the restricted loader is propagated and
        the payload is never fully unpickled.

    Returns
    -------
    "torch.nn.Module"
        Object returned by ``torch.load``

    Raises
    ------
    RuntimeError, TypeError, pickle.UnpicklingError
        If the restricted load fails and ``allow_unsafe_fallback`` is False,
        or if the fallback load fails as well

    Examples
    --------

    >>> from torch.nn import Linear, ReLU, Sequential
    >>> from manta.light.utils import bytes_to_torchmodel, torchmodel_to_bytes
    >>> torch_model = Sequential(Linear(3, 2), ReLU(), Linear(2, 1), ReLU())
    >>> model_bytes = torchmodel_to_bytes(torch_model)
    >>> bytes_to_torchmodel(model_bytes)

    Security Note
    -------------
    The payload is first loaded with ``weights_only=True``. If that fails
    and ``allow_unsafe_fallback`` is True, it is loaded again with
    ``weights_only=False``, which unpickles arbitrary objects and can
    therefore execute arbitrary code. Because a malicious payload is exactly
    the kind of payload the restricted loader rejects, the fallback offers
    no protection: only use it with trusted model sources, and pass
    ``allow_unsafe_fallback=False`` for anything else.
    """
    import pickle
    import warnings

    import torch

    buffer = io.BytesIO(b)

    # Try to load with weights_only=True first (safer)
    try:
        return torch.load(buffer, weights_only=True)
    except (RuntimeError, TypeError, pickle.UnpicklingError):
        if not allow_unsafe_fallback:
            raise
        # SECURITY: full unpickling of the payload happens below. This is
        # necessary for whole-model payloads with custom layers, but it must
        # only ever see trusted bytes.
        buffer.seek(0)  # the failed attempt may have consumed the buffer
        warnings.warn(
            "Loading PyTorch model with weights_only=False. "
            "This can execute arbitrary code. Ensure the model source is trusted.",
            UserWarning,
            stacklevel=2,
        )
        return torch.load(buffer, weights_only=False)