leapp 0.5.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- leapp/__init__.py +49 -0
- leapp/backends/export_backend.py +538 -0
- leapp/backends/module_builder.py +395 -0
- leapp/backends/onnx_export_backend.py +304 -0
- leapp/backends/torch_export_backend.py +102 -0
- leapp/buffer_tracker.py +255 -0
- leapp/export_manager.py +934 -0
- leapp/inference_manager.py +465 -0
- leapp/leapp.py +175 -0
- leapp/leapp_graph/datatypes/__init__.py +187 -0
- leapp/leapp_graph/datatypes/global_patching.py +229 -0
- leapp/leapp_graph/datatypes/traced_data.py +363 -0
- leapp/leapp_graph/datatypes/traced_np_array.py +1049 -0
- leapp/leapp_graph/datatypes/traced_tensor.py +1155 -0
- leapp/leapp_graph/function_decorator_node.py +365 -0
- leapp/leapp_graph/graph_gui.py +615 -0
- leapp/leapp_graph/leapp_graph.py +245 -0
- leapp/leapp_graph/leapp_node.py +501 -0
- leapp/leapp_graph/traced_node.py +675 -0
- leapp/utils/__init__.py +16 -0
- leapp/utils/caller_identity.py +135 -0
- leapp/utils/enums.py +55 -0
- leapp/utils/logging.py +157 -0
- leapp/utils/tensor_description.py +838 -0
- leapp/utils/tracing_lock.py +63 -0
- leapp/utils/utils.py +487 -0
- leapp-0.5.1.dist-info/METADATA +173 -0
- leapp-0.5.1.dist-info/RECORD +31 -0
- leapp-0.5.1.dist-info/WHEEL +5 -0
- leapp-0.5.1.dist-info/licenses/LICENSE +204 -0
- leapp-0.5.1.dist-info/top_level.txt +1 -0
leapp/__init__.py
ADDED
|
@@ -0,0 +1,49 @@
|
|
|
1
|
+
#
|
|
2
|
+
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
3
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
4
|
+
#
|
|
5
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
6
|
+
# you may not use this file except in compliance with the License.
|
|
7
|
+
# You may obtain a copy of the License at
|
|
8
|
+
#
|
|
9
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
10
|
+
#
|
|
11
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
12
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
13
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
14
|
+
# See the License for the specific language governing permissions and
|
|
15
|
+
# limitations under the License.
|
|
16
|
+
#
|
|
17
|
+
|
|
18
|
+
"""
|
|
19
|
+
LEAPP - Lightweight Export Annotations for Policy Pipelines
|
|
20
|
+
|
|
21
|
+
A Python package for tracing and exporting computational graphs from PyTorch code.
|
|
22
|
+
LEAPP is specifically designed for robotics and autonomous agent applications, allowing
|
|
23
|
+
you to trace and export complex policy pipelines with interconnected components to
|
|
24
|
+
various formats, including PyTorch JIT and ONNX, and to generate visualizations and YAML specifications.
|
|
25
|
+
"""
|
|
26
|
+
|
|
27
|
+
from .export_manager import ExportManager
|
|
28
|
+
from .inference_manager import InferenceManager
|
|
29
|
+
from .leapp import annotate, start, stop, compile_graph
|
|
30
|
+
from .utils.enums import InputKindEnum, OutputKindEnum
|
|
31
|
+
from .utils.tensor_description import TensorSemantics
|
|
32
|
+
|
|
33
|
+
# Package metadata.
__version__ = "0.5.1"
__config_version__ = "1.1"
__author__ = "Frank Lai"
__email__ = "frlai@nvidia.com"

# Names exported by ``from leapp import *`` (order preserved).
__all__ = [
    "ExportManager", "InferenceManager",
    "InputKindEnum", "OutputKindEnum",
    "annotate", "start", "stop", "compile_graph",
    "__version__", "TensorSemantics",
]
|
|
@@ -0,0 +1,538 @@
|
|
|
1
|
+
#
|
|
2
|
+
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
3
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
4
|
+
#
|
|
5
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
6
|
+
# you may not use this file except in compliance with the License.
|
|
7
|
+
# You may obtain a copy of the License at
|
|
8
|
+
#
|
|
9
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
10
|
+
#
|
|
11
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
12
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
13
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
14
|
+
# See the License for the specific language governing permissions and
|
|
15
|
+
# limitations under the License.
|
|
16
|
+
#
|
|
17
|
+
import abc
|
|
18
|
+
import hashlib
|
|
19
|
+
import os
|
|
20
|
+
import shutil
|
|
21
|
+
from typing import Callable, Tuple, Any
|
|
22
|
+
|
|
23
|
+
import numpy as np
|
|
24
|
+
import torch
|
|
25
|
+
import onnx
|
|
26
|
+
import onnxruntime as ort
|
|
27
|
+
|
|
28
|
+
from leapp.utils.logging import _get_logger
|
|
29
|
+
from leapp.backends.module_builder import ModuleBuilder
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
class SimplifiedONNXProgram:
    """Callable wrapper around an ONNX model file on disk.

    This class mimics the behavior of the ``ONNXProgram`` returned by
    ``torch.onnx.export`` with onnx-dynamo, so the model can be called
    directly for inference.

    The source files stay on disk and are copied on :meth:`save`. An ONNX
    Runtime session is created lazily on the first call. If ``temp_dir`` is
    given, that directory is deleted when the object is garbage-collected.
    """

    def __init__(self, onnx_model_path, temp_dir=None):
        """Initialize the ONNX program.

        Args:
            onnx_model_path: Path to an ONNX file. Other files in the same
                directory (e.g. external data) are copied along with it by
                :meth:`save`.
            temp_dir: If provided, this temp directory will be cleaned up when
                the object is deleted.
        """
        model_path = str(onnx_model_path)
        self._source_dir = os.path.dirname(os.path.abspath(model_path))
        self._source_filename = os.path.basename(model_path)
        self._temp_dir = temp_dir

        # Session state is populated lazily by _ensure_session().
        self._session = None
        self._input_names = None
        self._output_metas = None
        self._active_provider = None

    @staticmethod
    def _torch_dtype_to_numpy_dtype(dtype):
        """Return the NumPy dtype matching a torch dtype, or None if unsupported."""
        mapping = {
            torch.bool: np.bool_,
            torch.uint8: np.uint8,
            torch.int8: np.int8,
            torch.int16: np.int16,
            torch.int32: np.int32,
            torch.int64: np.int64,
            torch.float16: np.float16,
            torch.float32: np.float32,
            torch.float64: np.float64,
        }
        return mapping.get(dtype)

    @staticmethod
    def _onnx_type_to_torch_dtype(type_str):
        """Return the torch dtype for an ORT type string such as ``"tensor(float)"``.

        Returns None for types without a mapping.
        """
        mapping = {
            "tensor(bool)": torch.bool,
            "tensor(uint8)": torch.uint8,
            "tensor(int8)": torch.int8,
            "tensor(int16)": torch.int16,
            "tensor(int32)": torch.int32,
            "tensor(int64)": torch.int64,
            "tensor(float16)": torch.float16,
            "tensor(float)": torch.float32,
            "tensor(double)": torch.float64,
        }
        return mapping.get(type_str)

    @staticmethod
    def _output_shape_is_static(shape):
        """Return True if every dimension of ``shape`` is a non-negative int."""
        if shape is None:
            return False
        return all(isinstance(dim, int) and dim >= 0 for dim in shape)

    def _can_use_cuda_iobinding(self, args):
        """Return True if the zero-copy CUDA I/O-binding path can serve ``args``.

        Requires all of the following:
        - The CUDA execution provider is active.
        - Every input is a CUDA tensor on a single device, with a dtype that
          ONNX Runtime binding supports.
        - Every output has a static shape and a supported dtype, because output
          buffers are pre-allocated in torch.
        """
        if self._active_provider != 'CUDAExecutionProvider' or not args:
            return False

        first_device = None
        for tensor in args:
            if not isinstance(tensor, torch.Tensor) or not tensor.is_cuda:
                return False
            if first_device is None:
                first_device = tensor.device
            elif tensor.device != first_device:
                return False
            # Reject input dtypes that _run_with_cuda_iobinding cannot bind, so
            # the call falls back to the host path instead of raising there.
            # Output dtypes are checked below for the same reason.
            if self._torch_dtype_to_numpy_dtype(tensor.dtype) is None:
                return False

        if self._output_metas is None:
            return False

        return all(
            self._output_shape_is_static(output_meta.shape)
            and self._onnx_type_to_torch_dtype(output_meta.type) is not None
            for output_meta in self._output_metas
        )

    def _run_with_standard_inference(self, args, output_device):
        """Run through host memory.

        Tensor inputs are copied to CPU NumPy arrays. Outputs are converted
        back to torch tensors and moved to ``output_device``.
        """
        input_dict = {}
        for name, tensor in zip(self._input_names, args):
            if isinstance(tensor, torch.Tensor):
                input_dict[name] = tensor.detach().cpu().numpy()
            else:
                input_dict[name] = tensor

        outputs = self._session.run(None, input_dict)
        return tuple(torch.from_numpy(out).to(output_device) for out in outputs)

    def _run_with_cuda_iobinding(self, args):
        """Run with ONNX Runtime I/O binding directly on CUDA tensors.

        Inputs are bound in place. Outputs are written into pre-allocated
        torch tensors on the first input's device. Callers should check
        :meth:`_can_use_cuda_iobinding` first.

        Raises:
            TypeError: If an input or output dtype has no supported mapping.
        """
        binding = self._session.io_binding()

        # Hold references to the contiguous inputs so their device buffers
        # stay alive until run_with_iobinding() completes.
        prepared_inputs = []
        for name, tensor in zip(self._input_names, args):
            prepared = tensor.detach().contiguous()
            prepared_inputs.append(prepared)
            np_dtype = self._torch_dtype_to_numpy_dtype(prepared.dtype)
            if np_dtype is None:
                raise TypeError(
                    f"Unsupported CUDA input dtype for ONNX Runtime I/O binding: {prepared.dtype}"
                )
            binding.bind_input(
                name=name,
                device_type='cuda',
                device_id=prepared.device.index or 0,
                element_type=np_dtype,
                shape=tuple(prepared.shape),
                buffer_ptr=prepared.data_ptr(),
            )

        output_device = prepared_inputs[0].device
        outputs = []
        for output_meta in self._output_metas:
            torch_dtype = self._onnx_type_to_torch_dtype(output_meta.type)
            if torch_dtype is None:
                raise TypeError(
                    f"Unsupported ONNX output dtype for CUDA I/O binding: {output_meta.type}"
                )

            output_tensor = torch.empty(
                tuple(output_meta.shape),
                dtype=torch_dtype,
                device=output_device,
            ).contiguous()
            binding.bind_output(
                name=output_meta.name,
                device_type='cuda',
                device_id=output_device.index or 0,
                element_type=self._torch_dtype_to_numpy_dtype(torch_dtype),
                shape=tuple(output_tensor.shape),
                buffer_ptr=output_tensor.data_ptr(),
            )
            outputs.append(output_tensor)

        self._session.run_with_iobinding(binding)
        return tuple(outputs)

    def __del__(self):
        """Clean up the temp directory if one was handed to us."""
        if getattr(self, '_temp_dir', None) is not None:
            try:
                if os.path.exists(self._temp_dir):
                    shutil.rmtree(self._temp_dir)
            except Exception:
                # Deliberate best-effort: finalizers must not raise, and the
                # interpreter may already be tearing down modules here.
                pass

    def _get_providers(self):
        """Get execution providers, always preferring CUDA when available."""
        available_providers = ort.get_available_providers()

        if 'CUDAExecutionProvider' in available_providers:
            return ['CUDAExecutionProvider', 'CPUExecutionProvider']

        _get_logger().warning(
            "CUDA execution provider not available. Falling back to CPU."
        )
        return ['CPUExecutionProvider']

    def _get_source_size(self):
        """Return the total size in bytes of all regular files in the source directory."""
        total_size = 0
        for name in os.listdir(self._source_dir):
            file_path = os.path.join(self._source_dir, name)
            if os.path.isfile(file_path):
                total_size += os.path.getsize(file_path)
        return total_size

    @staticmethod
    def _copy_file(src, dst):
        """Copy ``src`` to ``dst`` with metadata; skip when both name the same file."""
        if os.path.exists(dst) and os.path.samefile(src, dst):
            return
        shutil.copy2(src, dst)

    def save(self, destination, include_initializers=True, keep_initializers_as_inputs=False):
        """Save the ONNX model to disk.

        If the source directory holds more than 2 GB, the model is loaded and
        re-saved in external data format as ``<destination name>.data``.
        Otherwise the model file and every sibling file in the source
        directory are copied next to ``destination``. The destination
        directory is created if needed. Saving onto the source files is a
        no-op.

        Args:
            destination: Target path of the ``.onnx`` file.
            include_initializers: Accepted for signature compatibility with
                ``ONNXProgram.save``; not used.
            keep_initializers_as_inputs: Accepted for signature compatibility
                with ``ONNXProgram.save``; not used.
        """
        # 2GB threshold for protobuf limit
        SIZE_THRESHOLD = 2 * 1024 * 1024 * 1024

        dest_dir = os.path.dirname(os.path.abspath(destination))
        dest_filename = os.path.basename(destination)
        src_path = os.path.join(self._source_dir, self._source_filename)

        # Needed by both branches; previously only the large-model branch
        # created it, so small-model saves into a new directory failed.
        os.makedirs(dest_dir, exist_ok=True)

        total_size = self._get_source_size()

        if total_size > SIZE_THRESHOLD:
            # Large model - load and re-save with external data
            _get_logger().info(
                f"Large model (~{total_size / (1024**3):.2f} GB) - saving with external data format"
            )

            model_proto = onnx.load(src_path, load_external_data=True)
            data_filename = dest_filename + ".data"
            data_path = os.path.join(dest_dir, data_filename)

            # Delete existing data file to avoid appending (ONNX appends instead of overwriting)
            if os.path.exists(data_path):
                os.remove(data_path)

            _get_logger().info(
                f"Saving model to {destination} with external data in {data_filename}")
            onnx.save(
                model_proto,
                destination,
                save_as_external_data=True,
                all_tensors_to_one_file=True,
                location=data_filename,
                size_threshold=1024,
                convert_attribute=True,  # Also convert Constant node attributes
            )

            # Verify external data file was created
            if os.path.exists(data_path):
                _get_logger().info(
                    f"External data file created: {data_path}")
            else:
                _get_logger().warning(
                    f"External data file NOT found: {data_path}")
        else:
            # Small model - just copy files
            self._copy_file(src_path, destination)

            # Copy ALL other files from source directory (external data files)
            for name in os.listdir(self._source_dir):
                if name == self._source_filename:
                    continue
                src_file = os.path.join(self._source_dir, name)
                if os.path.isfile(src_file):
                    self._copy_file(src_file, os.path.join(dest_dir, name))

    def _ensure_session(self):
        """Create the inference session if not already created."""
        if self._session is not None:
            return

        model_path = os.path.join(self._source_dir, self._source_filename)
        providers = self._get_providers()
        sess_options = ort.SessionOptions()
        # ORT_ENABLE_ALL can silently corrupt results for certain graph
        # patterns (e.g. Gemm chains produced by FX make_fx decomposition).
        # ORT_ENABLE_BASIC is safe and still applies constant folding.
        sess_options.graph_optimization_level = (
            ort.GraphOptimizationLevel.ORT_ENABLE_BASIC)
        self._session = ort.InferenceSession(
            model_path, sess_options, providers=providers)
        self._input_names = [inp.name for inp in self._session.get_inputs()]
        self._output_metas = list(self._session.get_outputs())
        active_providers = self._session.get_providers()
        self._active_provider = (
            active_providers[0] if active_providers else 'CPUExecutionProvider')

    def __call__(self, *args):
        """Run inference on the ONNX model.

        Args:
            *args: Input tensors in order, matching the model's input signature.

        Returns:
            Tuple of output ``torch.Tensor`` objects. They are placed on the
            first CUDA input's device if any input is on CUDA, otherwise on
            the CPU.

        Raises:
            ValueError: If the number of inputs does not match the model.
        """
        self._ensure_session()

        if len(args) != len(self._input_names):
            raise ValueError(
                f"Expected {len(self._input_names)} inputs, got {len(args)}. "
                f"Input names: {self._input_names}"
            )

        # Determine output device from input tensors
        output_device = 'cpu'
        for tensor in args:
            if isinstance(tensor, torch.Tensor) and tensor.is_cuda:
                output_device = tensor.device
                break

        if self._can_use_cuda_iobinding(args):
            return self._run_with_cuda_iobinding(args)

        return self._run_with_standard_inference(args, output_device)
|
|
327
|
+
|
|
328
|
+
|
|
329
|
+
def prepare_tensors_for_export(tensors):
    """Return copies of ``tensors`` that ``torch.export`` will accept.

    ``torch.export.export()`` (used by the dynamo exporter) rejects tensors
    created under ``torch.inference_mode()`` with
    "RuntimeError: Inference tensors cannot be saved for backward."
    Cloning such a tensor outside inference mode yields an ordinary tensor.

    Every ``torch.Tensor`` is cloned, using its ``original_clone`` method when
    it defines one (presumably a traced-tensor wrapper whose regular ``clone``
    is intercepted — confirm). Any other value is passed through unchanged.

    Args:
        tensors: Iterable of tensors and/or other values.

    Returns:
        tuple: The prepared values, in input order.
    """

    def _prepare_one(value):
        if not isinstance(value, torch.Tensor):
            return value
        if hasattr(value, 'original_clone'):
            return value.original_clone()
        return value.clone()

    return tuple(_prepare_one(value) for value in tensors)
|
|
355
|
+
|
|
356
|
+
|
|
357
|
+
class ExportBackend(abc.ABC):
    """Abstract base class for per-node export backends.

    Subclasses implement ``compile`` (build the model), ``save`` (write it and
    report its path and checksums) and ``load`` (read a saved artifact back).

    After loading, three attributes describe the model:
    - ``compiled_model``: the callable used for inference.
    - ``compiled_module``: a module usable for re-export, or ``None`` when the
      artifact (e.g. ONNX) cannot be represented as one.
    - ``runtime_device``: the device the model runs on.
    """

    # Read size used when hashing model files, so multi-GB artifacts are
    # streamed instead of being read into memory at once.
    _HASH_CHUNK_SIZE = 1024 * 1024

    def __init__(self, node_context, backend_params=None):
        """
        Args:
            node_context: The graph node this backend exports; also handed to
                ``ModuleBuilder``.
            backend_params: Optional dict of backend-specific parameters.
                Defaults to a new empty dict.
        """
        self.node_context = node_context
        self.backend_params = {} if backend_params is None else backend_params

        self.module_builder = ModuleBuilder(node_context)

        self.compiled_model = None
        self.compiled_module = None
        self.runtime_device = None

    def override_module_builder(self, module_builder: Callable):
        """Replace the default ``ModuleBuilder`` with a custom builder."""
        self.module_builder = module_builder

    def _verify_model_location_and_get_hash(self, model_path):
        """Return the ``(md5, sha256)`` hex digests of the file at ``model_path``.

        Returns ``(None, None)`` and logs an error when the file does not
        exist. The file is hashed in chunks, so large models are never fully
        loaded into memory.
        """
        if not os.path.exists(model_path):
            _get_logger().error(f"Model file not found at {model_path}")
            return None, None
        md5 = hashlib.md5()
        sha256 = hashlib.sha256()
        with open(model_path, 'rb') as f:
            for chunk in iter(lambda: f.read(self._HASH_CHUNK_SIZE), b''):
                md5.update(chunk)
                sha256.update(chunk)
        return md5.hexdigest(), sha256.hexdigest()

    def _verify_checksum(self, model_path, sha256sum):
        """Raise ValueError unless the SHA-256 of ``model_path`` equals ``sha256sum``."""
        _, actual_sha256sum = self._verify_model_location_and_get_hash(model_path)
        if actual_sha256sum != sha256sum:
            raise ValueError(
                f"SHA256 checksum mismatch for {model_path}: "
                f"expected {sha256sum}, got {actual_sha256sum}"
            )

    def _copy_model_to_path(self, model_path, save_path):
        """Copy ``model_path`` into the directory ``save_path``.

        ``save_path`` is created if needed. Only the model file itself is
        copied; sibling files such as ONNX external data are not.

        Returns:
            - ``None`` if ``model_path`` does not exist.
            - ``model_path`` unchanged if it already lives in ``save_path``.
            - Otherwise, the path of the copy.
        """
        os.makedirs(save_path, exist_ok=True)
        if not os.path.exists(model_path):
            return None

        # No need to copy if the model already lives in save_path.
        model_dir = os.path.dirname(os.path.abspath(model_path))
        if model_dir == os.path.abspath(save_path):
            return model_path

        dest_path = os.path.join(save_path, os.path.basename(model_path))
        shutil.copy2(model_path, dest_path)
        return dest_path

    @abc.abstractmethod
    def get_backend_model_type(self):
        """Return a string identifying the artifact type (e.g. ``"onnx"``, ``"jit"``)."""
        raise NotImplementedError

    def get_backend_metadata(self) -> dict:
        """Optional backend-specific metadata to include in node YAML parameters.

        Subclasses may override this; the default contributes nothing. It is
        intentionally not abstract, matching its documented optional role.
        """
        return {}

    @abc.abstractmethod
    def compile(self, m: torch.nn.Module = None) -> Any:
        '''
        Compiles the model.

        This function should return the compiled model. The resulting compiled model
        can be used in recombination with other models.
        '''
        raise NotImplementedError

    @abc.abstractmethod
    def save(self, save_path: str) -> Tuple[str, str, str]:
        '''
        Save the compiled model to the given path.

        This function should apply all necessary optimizations. It returns
        ``(model_path, md5sum, sha256sum)``.
        '''
        raise NotImplementedError

    @abc.abstractmethod
    def load(self, model_path: str, sha256sum: str):
        """Load a saved artifact; implementations are expected to verify ``sha256sum``."""
        raise NotImplementedError

    def _select_runtime_device(self):
        """Return ``'cuda'`` when torch sees a GPU, else ``'cpu'``."""
        return 'cuda' if torch.cuda.is_available() else 'cpu'

    def _load_onnx(self, model_path: str, sha256sum: str):
        """Verify and wrap an ONNX file in a lazily-initialized ``SimplifiedONNXProgram``.

        ``runtime_device`` is ``'cuda'`` only when both torch CUDA and the ONNX
        Runtime CUDA execution provider are available.

        Raises:
            ValueError: On SHA-256 mismatch.
        """
        self._verify_checksum(model_path, sha256sum)
        self.compiled_model = SimplifiedONNXProgram(model_path)
        has_ort_cuda = 'CUDAExecutionProvider' in ort.get_available_providers()
        self.runtime_device = 'cuda' if (
            torch.cuda.is_available() and has_ort_cuda
        ) else 'cpu'
        self.compiled_module = None  # ONNX models cannot be represented as a module for reexport.

    def _load_torchscript(self, model_path: str, sha256sum: str):
        """Verify and load a TorchScript file onto the runtime device in eval mode.

        Raises:
            ValueError: On SHA-256 mismatch.
        """
        self._verify_checksum(model_path, sha256sum)
        device = self._select_runtime_device()
        model = torch.jit.load(model_path, map_location=device)
        model = model.to(device)
        self.compiled_model = model.eval()
        self.runtime_device = device
        self.compiled_module = model
|
|
471
|
+
|
|
472
|
+
|
|
473
|
+
class NoneExportBackend(ExportBackend):
    """Null export backend for nodes without an active export backend.

    This backend represents "no export" behavior in normal usage. It can also
    optionally act as a thin wrapper for existing model artifacts when
    ``backend_params['model_path']`` is provided, including checksum validation
    and backend-specific loading.
    """

    # Lower-cased file extension -> model type reported by
    # get_backend_model_type(). Only "onnx" and "jit" are loadable by load().
    _SUFFIX_TO_MODEL_TYPE = {
        'pt': 'jit',
        'pt2': 'torchscript2',
        'onnx': 'onnx',
        'cpp': 'cpp',
        'cc': 'cpp',
        'py': 'py',
        'engine': 'trt',
        'plan': 'trt',
    }

    def _model_path(self):
        """Return ``backend_params['model_path']``, or None when absent or unset."""
        return self.backend_params.get('model_path')

    def get_backend_metadata(self):
        """This backend contributes no extra YAML metadata."""
        return {}

    def compile(self, m: torch.nn.Module = None):
        """Load the wrapped artifact if a model path is configured; otherwise warn.

        ``m`` is ignored.
        """
        model_path = self._model_path()
        if model_path is None:
            _get_logger().warning(
                f"No model path provided for {self.node_context.name}")
            _get_logger().warning("if this is intentional, please provide a path to the correct model "
                                  "in the generated yaml file. Otherwise, please manually fill in the backend parameters.")
            return
        # NOTE(review): the checksum is taken from the file itself and
        # re-verified inside load(), so this only guards against the file
        # changing in between.
        _, sha256sum = self._verify_model_location_and_get_hash(model_path)
        self.load(model_path, sha256sum, self.get_backend_model_type())

    def save(self, save_path: str) -> Tuple[str, str, str]:
        """Return ``(model_path, md5sum, sha256sum)`` of the wrapped artifact.

        The artifact is copied into ``save_path`` first when
        ``backend_params['copy_original_model'] is True``. Returns
        ``(None, None, None)`` when no model path is configured.
        """
        model_path = self._model_path()
        if model_path is None:
            return None, None, None
        md5sum, sha256sum = self._verify_model_location_and_get_hash(model_path)

        if self.backend_params.get('copy_original_model') is True:
            model_path = self._copy_model_to_path(model_path, save_path)

        return model_path, md5sum, sha256sum

    def load(self, model_path: str, sha256sum: str, model_type=None):
        """Load ``model_path`` as ``model_type`` (``"onnx"`` or ``"jit"``).

        Does nothing when ``model_type`` is None.

        Raises:
            ValueError: For any other model type, or on checksum mismatch.
        """
        if model_type is None:
            return
        loaders = {
            "onnx": self._load_onnx,
            "jit": self._load_torchscript,
        }
        loader = loaders.get(model_type)
        if loader is None:
            raise ValueError(f"Unsupported model type: {model_type}")
        loader(model_path, sha256sum)

    def get_backend_model_type(self):
        """Infer the model type from the model file's extension (case-insensitive).

        Returns None when no model path is configured, including an explicit
        ``None``, which compile() and save() already treat as "no model".

        Raises:
            ValueError: If the extension is not recognized.
        """
        path = self._model_path()
        if path is None:
            return None

        # Only look at the file name so dots in directory names are ignored.
        suffix = os.path.basename(os.fspath(path)).rsplit('.', 1)[-1]
        model_type = self._SUFFIX_TO_MODEL_TYPE.get(suffix.lower())
        if model_type is None:
            raise ValueError(
                f"Unsupported model file suffix: {suffix}")
        return model_type
|