onnx-ir 0.1.15__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnx_ir/__init__.py +176 -0
- onnx_ir/_cloner.py +229 -0
- onnx_ir/_convenience/__init__.py +558 -0
- onnx_ir/_convenience/_constructors.py +291 -0
- onnx_ir/_convenience/_extractor.py +191 -0
- onnx_ir/_core.py +4435 -0
- onnx_ir/_display.py +54 -0
- onnx_ir/_enums.py +474 -0
- onnx_ir/_graph_comparison.py +23 -0
- onnx_ir/_graph_containers.py +373 -0
- onnx_ir/_io.py +133 -0
- onnx_ir/_linked_list.py +284 -0
- onnx_ir/_metadata.py +45 -0
- onnx_ir/_name_authority.py +72 -0
- onnx_ir/_polyfill.py +26 -0
- onnx_ir/_protocols.py +627 -0
- onnx_ir/_safetensors/__init__.py +510 -0
- onnx_ir/_tape.py +242 -0
- onnx_ir/_thirdparty/asciichartpy.py +310 -0
- onnx_ir/_type_casting.py +89 -0
- onnx_ir/_version_utils.py +48 -0
- onnx_ir/analysis/__init__.py +21 -0
- onnx_ir/analysis/_implicit_usage.py +74 -0
- onnx_ir/convenience.py +38 -0
- onnx_ir/external_data.py +459 -0
- onnx_ir/passes/__init__.py +41 -0
- onnx_ir/passes/_pass_infra.py +351 -0
- onnx_ir/passes/common/__init__.py +54 -0
- onnx_ir/passes/common/_c_api_utils.py +76 -0
- onnx_ir/passes/common/clear_metadata_and_docstring.py +60 -0
- onnx_ir/passes/common/common_subexpression_elimination.py +207 -0
- onnx_ir/passes/common/constant_manipulation.py +230 -0
- onnx_ir/passes/common/default_attributes.py +99 -0
- onnx_ir/passes/common/identity_elimination.py +120 -0
- onnx_ir/passes/common/initializer_deduplication.py +179 -0
- onnx_ir/passes/common/inliner.py +223 -0
- onnx_ir/passes/common/naming.py +280 -0
- onnx_ir/passes/common/onnx_checker.py +57 -0
- onnx_ir/passes/common/output_fix.py +141 -0
- onnx_ir/passes/common/shape_inference.py +112 -0
- onnx_ir/passes/common/topological_sort.py +37 -0
- onnx_ir/passes/common/unused_removal.py +215 -0
- onnx_ir/py.typed +1 -0
- onnx_ir/serde.py +2043 -0
- onnx_ir/tape.py +15 -0
- onnx_ir/tensor_adapters.py +210 -0
- onnx_ir/testing.py +197 -0
- onnx_ir/traversal.py +118 -0
- onnx_ir-0.1.15.dist-info/METADATA +68 -0
- onnx_ir-0.1.15.dist-info/RECORD +53 -0
- onnx_ir-0.1.15.dist-info/WHEEL +5 -0
- onnx_ir-0.1.15.dist-info/licenses/LICENSE +202 -0
- onnx_ir-0.1.15.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,510 @@
|
|
|
1
|
+
# Copyright (c) ONNX Project Contributors
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
"""Utilities for using safetensors as an external data format."""
|
|
4
|
+
|
|
5
|
+
from __future__ import annotations
|
|
6
|
+
|
|
7
|
+
__all__ = ["save_safetensors"]
|
|
8
|
+
|
|
9
|
+
import functools
|
|
10
|
+
import io
|
|
11
|
+
import json
|
|
12
|
+
import os
|
|
13
|
+
import struct
|
|
14
|
+
from collections.abc import Callable, Sequence
|
|
15
|
+
from typing import Any
|
|
16
|
+
|
|
17
|
+
import packaging.version
|
|
18
|
+
|
|
19
|
+
import onnx_ir as ir
|
|
20
|
+
|
|
21
|
+
# Size in bytes of the little-endian u64 length prefix that precedes the JSON
# header of every safetensors file (parsed with struct format "<Q" below).
_HEADER_SIZE_NUMBER_SIZE = 8

# Header dtype string (as written in a safetensors file) -> ONNX IR dtype.
# Used when reading a safetensors file back into ExternalTensors.
# https://github.com/huggingface/safetensors/blob/806426784adb43631e9a1102d4621126bb589347/safetensors/src/tensor.rs#L811
_SAFETENSORS_DTYPE_TO_IR_DTYPE = {
    # Boolean
    "BOOL": ir.DataType.BOOL,
    # Low-precision floating point
    "F4": ir.DataType.FLOAT4E2M1,
    "F8_E5M2": ir.DataType.FLOAT8E5M2,
    "F8_E4M3": ir.DataType.FLOAT8E4M3FN,
    "F8_E8M0": ir.DataType.FLOAT8E8M0,
    # Standard floating point
    "BF16": ir.DataType.BFLOAT16,
    "F16": ir.DataType.FLOAT16,
    "F32": ir.DataType.FLOAT,
    "F64": ir.DataType.DOUBLE,
    # Signed integers
    "I8": ir.DataType.INT8,
    "I16": ir.DataType.INT16,
    "I32": ir.DataType.INT32,
    "I64": ir.DataType.INT64,
    # Unsigned integers
    "U8": ir.DataType.UINT8,
    "U16": ir.DataType.UINT16,
    "U32": ir.DataType.UINT32,
    "U64": ir.DataType.UINT64,
    # Complex
    "C64": ir.DataType.COMPLEX64,
}

# ONNX IR dtype -> dtype name passed to ``safetensors.serialize_file`` when writing.
# Types without a native safetensors equivalent (the FNUZ float8 variants and the
# 2/4-bit integers) are written as raw "uint8" bytes; their real dtype and shape
# are restored from the in-memory model by ``_migrate_tensor_shape_dtype``.
# https://github.com/huggingface/safetensors/blob/806426784adb43631e9a1102d4621126bb589347/bindings/python/src/view.rs#L77
_IR_DTYPE_TO_SAFETENSORS_DTYPE = {
    # Boolean
    ir.DataType.BOOL: "bool",
    # Low-precision floating point
    ir.DataType.FLOAT4E2M1: "float4_e2m1fn_x2",
    ir.DataType.FLOAT8E5M2: "float8_e5m2",
    ir.DataType.FLOAT8E4M3FN: "float8_e4m3fn",
    ir.DataType.FLOAT8E8M0: "float8_e8m0",
    ir.DataType.FLOAT8E4M3FNUZ: "uint8",  # no safetensors equivalent; raw bytes
    ir.DataType.FLOAT8E5M2FNUZ: "uint8",  # no safetensors equivalent; raw bytes
    # Standard floating point
    ir.DataType.BFLOAT16: "bfloat16",
    ir.DataType.FLOAT16: "float16",
    ir.DataType.FLOAT: "float32",
    ir.DataType.DOUBLE: "float64",
    # Signed integers
    ir.DataType.INT2: "uint8",  # packed sub-byte; raw bytes
    ir.DataType.INT4: "uint8",  # packed sub-byte; raw bytes
    ir.DataType.INT8: "int8",
    ir.DataType.INT16: "int16",
    ir.DataType.INT32: "int32",
    ir.DataType.INT64: "int64",
    # Unsigned integers
    ir.DataType.UINT2: "uint8",  # packed sub-byte; raw bytes
    ir.DataType.UINT4: "uint8",  # packed sub-byte; raw bytes
    ir.DataType.UINT8: "uint8",
    ir.DataType.UINT16: "uint16",
    ir.DataType.UINT32: "uint32",
    ir.DataType.UINT64: "uint64",
    # Complex
    ir.DataType.COMPLEX64: "complex64",
}
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
@functools.lru_cache(maxsize=1)
def _import_safetensors():
    """Import and return the ``safetensors`` module after checking its version.

    The successful result is memoized by ``lru_cache``, so the import and the
    version check run once per process. Exceptions are not cached, so a failed
    attempt is retried on the next call.

    Returns:
        The imported ``safetensors`` module.

    Raises:
        ImportError: If safetensors is not installed, does not expose
            ``__version__``, or is older than 0.7.0.
    """
    try:
        import safetensors
    except ImportError as e:
        raise ImportError(
            "safetensors is required for using safetensors external data format. "
            "Please install it with 'pip install --upgrade safetensors'."
        ) from e

    installed = getattr(safetensors, "__version__", None)
    # A missing __version__ is treated the same as an outdated install.
    too_old = installed is None or (
        packaging.version.parse(installed) < packaging.version.parse("0.7.0")
    )
    if too_old:
        raise ImportError(
            f"safetensors version 0.7.0 or higher is required, but version {installed} is installed. "
            "Please upgrade it with 'pip install --upgrade safetensors'."
        )

    return safetensors
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
def _get_shard_filename(base_name: str, shard_idx: int, total_shards: int) -> str:
|
|
95
|
+
"""Generate a filename for a shard.
|
|
96
|
+
|
|
97
|
+
Args:
|
|
98
|
+
base_name: The base filename (e.g., 'model.safetensors').
|
|
99
|
+
shard_idx: The index of this shard (1-indexed).
|
|
100
|
+
total_shards: The total number of shards.
|
|
101
|
+
|
|
102
|
+
Returns:
|
|
103
|
+
The shard filename (e.g., 'model-00001-of-00003.safetensors').
|
|
104
|
+
"""
|
|
105
|
+
if total_shards == 1:
|
|
106
|
+
return base_name
|
|
107
|
+
|
|
108
|
+
# Extract extension
|
|
109
|
+
if "." in base_name:
|
|
110
|
+
name, ext = base_name.rsplit(".", 1)
|
|
111
|
+
ext = f".{ext}"
|
|
112
|
+
else:
|
|
113
|
+
name = base_name
|
|
114
|
+
ext = ""
|
|
115
|
+
|
|
116
|
+
# Always use 5 digits to follow transformers convention
|
|
117
|
+
return f"{name}-{shard_idx:05d}-of-{total_shards:05d}{ext}"
|
|
118
|
+
|
|
119
|
+
|
|
120
|
+
def _shard_tensors(
|
|
121
|
+
tensors: Sequence[ir.TensorProtocol], max_shard_size_bytes: int | None
|
|
122
|
+
) -> list[list[ir.TensorProtocol]]:
|
|
123
|
+
"""Shard tensors into multiple files based on max_shard_size_bytes.
|
|
124
|
+
|
|
125
|
+
Args:
|
|
126
|
+
tensors: The tensors to shard.
|
|
127
|
+
max_shard_size_bytes: Maximum size for each shard in bytes. When None,
|
|
128
|
+
no sharding is performed.
|
|
129
|
+
|
|
130
|
+
Returns:
|
|
131
|
+
A list of tensor lists for each shard.
|
|
132
|
+
"""
|
|
133
|
+
if max_shard_size_bytes is None:
|
|
134
|
+
# No sharding
|
|
135
|
+
return [list(tensors)]
|
|
136
|
+
|
|
137
|
+
# Shard the tensors by current order
|
|
138
|
+
shards: list[list[ir.TensorProtocol]] = [[]]
|
|
139
|
+
current_shard_size = 0
|
|
140
|
+
|
|
141
|
+
for tensor in tensors:
|
|
142
|
+
tensor_size = tensor.nbytes
|
|
143
|
+
# Check if adding this tensor would exceed max_shard_size_bytes
|
|
144
|
+
if current_shard_size + tensor_size > max_shard_size_bytes and current_shard_size > 0:
|
|
145
|
+
# Start a new shard
|
|
146
|
+
shards.append([])
|
|
147
|
+
current_shard_size = 0
|
|
148
|
+
|
|
149
|
+
shards[-1].append(tensor)
|
|
150
|
+
current_shard_size += tensor_size
|
|
151
|
+
|
|
152
|
+
return shards
|
|
153
|
+
|
|
154
|
+
|
|
155
|
+
def _replace_tensors(
    values: Sequence[ir.Value],
    /,
    location: str | os.PathLike,
    base_dir: str | os.PathLike,
) -> None:
    """Point the ``const_value`` of matching values at tensors in a safetensors file.

    Every tensor stored in the file must belong to one of ``values``, matched by
    ``value.name``. Each matching value's ``const_value`` is replaced in place
    by an external tensor read from the file, with dtype/shape migrated from
    the original in-memory tensor where needed.

    Args:
        values: List of initialized values to replace constant values from.
        location: Path to the safetensors file relative to the ONNX model file.
        base_dir: Directory where the ONNX model file is stored.
    """
    stored = _read_safetensors(location, base_dir=base_dir)
    by_name: dict[str, ir.Value] = {v.name: v for v in values}  # type: ignore[misc]

    for tensor_name, external in stored.items():
        assert tensor_name in by_name, f"Bug: Tensor '{tensor_name}' not found in model initializers."
        target = by_name[tensor_name]
        in_memory = target.const_value
        assert in_memory is not None
        target.const_value = _migrate_tensor_shape_dtype(in_memory, external)
|
|
177
|
+
|
|
178
|
+
|
|
179
|
+
def _get_tensor_storage_shape(tensor: ir.TensorProtocol) -> Sequence[int]:
|
|
180
|
+
"""Get the storage shape of a tensor for safetensors."""
|
|
181
|
+
# Handle sub-byte dtypes
|
|
182
|
+
if tensor.dtype.bitwidth < 8:
|
|
183
|
+
return [tensor.nbytes]
|
|
184
|
+
return tensor.shape.numpy()
|
|
185
|
+
|
|
186
|
+
|
|
187
|
+
def _save_file(
    initialized_values: Sequence[ir.Value],
    /,
    location: str | os.PathLike,
    base_dir: str | os.PathLike = "",
    *,
    size_threshold_bytes: int,
    max_shard_size_bytes: int | None,
    callback: Callable[[ir.TensorProtocol, ir.external_data.CallbackInfo], None] | None = None,
) -> None:
    """Save the initializer tensors of ``initialized_values`` to safetensors file(s).

    Tensors are keyed by the *value* name (``ir.Value.name``), which is the name
    the graph refers to. ``tensor.name`` may be missing or may disagree with
    ``value.name``, so it is never used as a key. This is also the key that
    ``_replace_tensors`` looks up.

    As a side effect, after the file(s) are written, the ``const_value`` of every
    saved value is replaced by an external tensor that points into the
    safetensors file. Callers that must not mutate the model have to restore
    the original tensors themselves.

    When more than one shard is produced, an index file named
    ``<location>.index.json`` is also written. It maps each tensor name to its
    shard filename and records the total byte size.

    Args:
        initialized_values: List of initialized values to consider for saving.
            Every value must have a non-None ``const_value`` and ``name``.
        location: Path to the safetensors file relative to the ONNX model file.
        base_dir: Directory where the ONNX model file is stored. Shard and
            index files are written relative to it.
        size_threshold_bytes: Save to external data if the tensor size in bytes
            is not smaller than this threshold.
        max_shard_size_bytes: Maximum size in bytes (as int) a safetensors file
            before being sharded. If None, no sharding is performed.
        callback: Invoked once per saved tensor, in save order, before the
            shard containing that tensor is written to disk.

    Raises:
        ValueError: If ``location`` does not end with ``.safetensors``.
    """
    safetensors = _import_safetensors()

    location_str = str(location)
    # Ensure that external_data ends with .safetensors
    if not location_str.endswith(".safetensors"):
        raise ValueError(
            f'The path to safetensors file must have a .safetensors extension, got: "{location}"'
        )

    # Collect the values to externalize without loading any tensor data yet.
    # The two lists are kept parallel: tensors_to_save[i] is values_to_save[i].const_value.
    values_to_save: list[ir.Value] = []
    tensors_to_save: list[ir.TensorProtocol] = []
    for value in initialized_values:
        tensor = value.const_value
        assert tensor is not None
        if tensor.nbytes < size_threshold_bytes:
            continue
        values_to_save.append(value)
        tensors_to_save.append(tensor)

    if not tensors_to_save:
        return

    total_size = sum(tensor.nbytes for tensor in tensors_to_save)

    # When max_shard_size_bytes is None this yields exactly one shard. Shards are
    # consecutive, order-preserving slices of tensors_to_save, so a running index
    # recovers the value (and hence the name) that owns each tensor.
    tensor_shards = _shard_tensors(tensors_to_save, max_shard_size_bytes)
    total_shards = len(tensor_shards)

    all_filenames: list[str] = []
    weight_map: dict[str, str] = {}  # Maps tensor (value) name to shard filename
    # NOTE(review): offset keeps accumulating across shards and does not include
    # the safetensors header; preserved as-is, confirm the intended semantics.
    current_offset = 0
    current_index = 0
    for shard_idx, tensor_shard in enumerate(tensor_shards, start=1):
        shard_filename = _get_shard_filename(location_str, shard_idx, total_shards)
        shard_path = os.path.join(base_dir, shard_filename)
        all_filenames.append(shard_filename)

        # Only this shard's tensor bytes are materialized at once.
        shard_dict: dict[str, Any] = {}
        for tensor in tensor_shard:
            value = values_to_save[current_index]
            assert tensor is value.const_value, "Bug: shard order does not match value order."
            name = value.name
            assert name is not None
            if callback is not None:
                callback(
                    tensor,
                    ir.external_data.CallbackInfo(
                        total=len(tensors_to_save),
                        index=current_index,
                        offset=current_offset,
                        filename=shard_filename,
                    ),
                )
            shard_dict[name] = {
                "dtype": _IR_DTYPE_TO_SAFETENSORS_DTYPE[tensor.dtype],
                "shape": _get_tensor_storage_shape(tensor),
                "data": tensor.tobytes(),
            }
            weight_map[name] = shard_filename
            current_offset += tensor.nbytes
            current_index += 1

        safetensors.serialize_file(shard_dict, shard_path)

    # Save index file if sharding occurred. location_str is known to end with
    # ".safetensors", so "<stem>.safetensors.index.json" == location_str + ".index.json".
    if total_shards > 1:
        index_filename = f"{location_str}.index.json"
        index_path = os.path.join(base_dir, index_filename)
        index_data = {
            "metadata": {"total_size": total_size},
            "weight_map": weight_map,
        }
        with open(index_path, "w", encoding="utf-8") as f:
            json.dump(index_data, f, indent=2)

    # Swap every saved value's in-memory tensor for one backed by its shard file.
    for filename in all_filenames:
        _replace_tensors(values_to_save, filename, base_dir)
|
|
293
|
+
|
|
294
|
+
|
|
295
|
+
def save_safetensors(
    model: ir.Model,
    path: str | os.PathLike,
    /,
    *,
    format: str | None = None,
    size_threshold_bytes: int = 256,
    max_shard_size_bytes: int | None = None,
    callback: Callable[[ir.TensorProtocol, ir.external_data.CallbackInfo], None] | None = None,
) -> None:
    """Save an ONNX model to a file with external data in a safetensors file.

    The model object is unmodified after this operation.

    The safetensors file is written next to the ONNX file and shares its stem,
    e.g. ``dir/model.onnx`` -> ``dir/model.safetensors``.

    When sharding is enabled, multiple safetensors files will be created
    with names like "model-00001-of-00003.safetensors", and an index
    file "model.safetensors.index.json" will be created to map tensors
    to their respective shard files. The shards will be created only if
    the total size of tensors exceeds the specified max_shard_size_bytes.

    .. note::
        Because the safetensors data format uses key-value mapping to store tensors,
        all initializer names in the model (across subgraphs) must be unique.
        Externalizing tensor attributes in nodes to safetensors files is currently not
        supported. If you have tensors from Constant nodes that you want to externalize,
        consider converting them to initializers first with
        :class:`~onnx_ir.passes.common.LiftConstantsToInitializersPass`.

    Example::

        import onnx_ir as ir

        model = ir.load("model.onnx")

        # Save model with tensors larger than 100 bytes to safetensors external data,
        # sharding files larger than 5GB.
        ir.save_safetensors(
            model,
            "model.onnx",
            size_threshold_bytes=100,
            max_shard_size_bytes=int(5 * 1000**3),  # Shard safetensors files larger than 5GB
        )

    .. tip::

        A simple progress bar can be implemented by passing a callback function as the following::

            import onnx_ir as ir
            import tqdm

            with tqdm.tqdm() as pbar:
                total_set = False

                def callback(tensor: ir.TensorProtocol, metadata: ir.external_data.CallbackInfo) -> None:
                    nonlocal total_set
                    if not total_set:
                        pbar.total = metadata.total
                        total_set = True

                    pbar.update()
                    pbar.set_description(f"Saving {metadata.filename}: {tensor.name} ({tensor.dtype}, {tensor.shape})")

                ir.save_safetensors(
                    ...,
                    callback=callback,
                )

    .. versionadded:: 0.1.15

    Args:
        model: ONNX model to save.
        path: Path to the ONNX model file. E.g. "model.onnx".
        format: The format of the file (e.g. ``protobuf``, ``textproto``, ``json``, etc.).
            If None, the format is inferred from the file extension.
        size_threshold_bytes: Save to external data if the tensor size in bytes
            is not smaller than this threshold.
        max_shard_size_bytes: Maximum size in bytes (as int) a safetensors file
            before being sharded. If None, no sharding is performed.
        callback: A callback function that is called once for each externalized
            tensor, before the safetensors file containing it is written.
            The callback must have signature ``Callable[[ir.TensorProtocol, ir.external_data.CallbackInfo], None]``,
            where the first argument is the tensor being saved and the second contains metadata such as filename and progress.

    Raises:
        ValueError: If an initializer has no name, or if duplicate initializer
            names are found in the model.
    """
    # The safetensors file name is derived from the ONNX file name:
    # "<stem>.onnx" -> "<stem>.safetensors" (os.path.splitext is a no-op
    # when the basename has no extension).
    base_name = os.path.splitext(os.path.basename(str(path)))[0]
    external_data = f"{base_name}.safetensors"

    # Record every initializer together with its current tensor. _save_file
    # replaces const_value on the saved values, and these pairs are used to
    # restore the originals afterwards so the caller's model is left unmodified.
    value_tensor_pairs: list[tuple[ir.Value, ir.TensorProtocol]] = []
    initializer_names: set[str] = set()
    for graph in model.graphs():
        for value in graph.initializers.values():
            tensor = value.const_value
            # The value.name should be the same as tensor.name. However,
            # in case there is a conflict, we do not care and will prefer value.name.
            name = value.name
            if name is None:
                raise ValueError(
                    f"Initializer value '{value!r}' has no name (in graph {graph.name!r}). "
                    "All initializers must have names."
                )
            if tensor is None:
                continue
            if name in initializer_names:
                raise ValueError(
                    f"Duplicate initializer name found: {name} (in graph {graph.name!r})."
                    " Rename the initializers to have unique names before saving to safetensors."
                )
            initializer_names.add(name)
            value_tensor_pairs.append((value, tensor))

    try:
        _save_file(
            [value for value, _ in value_tensor_pairs],
            external_data,
            os.path.dirname(path),
            size_threshold_bytes=size_threshold_bytes,
            max_shard_size_bytes=max_shard_size_bytes,
            callback=callback,
        )
        ir.save(model, path, format=format)
    finally:
        # Restore original initializers to avoid side effects, even on failure.
        for value, tensor in value_tensor_pairs:
            value.const_value = tensor
|
|
427
|
+
|
|
428
|
+
|
|
429
|
+
def _read_safetensors_header(file: io.IOBase) -> tuple[dict[str, dict[str, Any]], int]:
    """Parse the JSON header at the start of a safetensors file.

    The file is rewound to offset 0. The first ``_HEADER_SIZE_NUMBER_SIZE``
    bytes are decoded as an unsigned little-endian 64-bit integer N, and the
    next N bytes are decoded as UTF-8 JSON.

    Args:
        file: The safetensors file to read, opened in binary mode and seekable.

    Returns:
        A ``(header, header_size)`` tuple: the decoded JSON header and N, the
        length of the JSON header in bytes.
    """
    file.seek(0)
    size_prefix = file.read(_HEADER_SIZE_NUMBER_SIZE)
    (header_size,) = struct.unpack_from("<Q", size_prefix)
    raw_header = file.read(header_size)
    parsed = json.loads(raw_header.decode("utf-8"))
    return parsed, header_size
|
|
442
|
+
|
|
443
|
+
|
|
444
|
+
def _read_safetensors(
    location: str | os.PathLike, base_dir: str | os.PathLike
) -> dict[str, ir.ExternalTensor]:
    """Describe every tensor in a safetensors file as an ``ir.ExternalTensor``.

    Only the header is read. The returned tensors reference their bytes by
    absolute file offset and length, and the reserved ``__metadata__`` entry
    is skipped.

    Args:
        location: The safetensors file to read, relative to ``base_dir``.
        base_dir: Directory where the ONNX model file is stored.

    Returns:
        A mapping from tensor name to external tensor, in header order.
    """
    with open(os.path.join(base_dir, location), "rb") as file:
        header, header_size = _read_safetensors_header(file)

    # "data_offsets" in the header are relative to the start of the data
    # section, which follows the size prefix and the JSON header.
    data_start = _HEADER_SIZE_NUMBER_SIZE + header_size

    result: dict[str, ir.ExternalTensor] = {}
    for tensor_name, info in header.items():
        if tensor_name == "__metadata__":
            continue
        begin = info["data_offsets"][0]
        end = info["data_offsets"][1]
        result[tensor_name] = ir.ExternalTensor(
            location=location,
            offset=data_start + begin,
            length=end - begin,
            dtype=_SAFETENSORS_DTYPE_TO_IR_DTYPE[info["dtype"]],
            shape=ir.Shape(info["shape"]),
            name=tensor_name,
            base_dir=base_dir,
        )
    return result
|
|
475
|
+
|
|
476
|
+
|
|
477
|
+
def _migrate_tensor_shape_dtype(
    model_tensor: ir.TensorProtocol, safe_tensor: ir.ExternalTensor
) -> ir.ExternalTensor:
    """Restore the model's dtype and shape on a tensor read back from safetensors.

    Some dtypes cannot be stored natively in safetensors and are written as raw
    bytes (FNUZ float8 variants and 2/4-bit integers as UINT8). FLOAT4E2M1 is
    stored with a flat byte shape. For those, a new external tensor is built
    that points at the same bytes but carries ``model_tensor``'s dtype and
    shape. Every other tensor is returned unchanged.

    Args:
        model_tensor: The original in-memory tensor (source of dtype and shape).
        safe_tensor: The tensor as read from the safetensors file.

    Returns:
        ``safe_tensor`` itself, or a re-typed copy referencing the same bytes.
    """
    needs_migration = {
        # Types that safetensors does not support directly
        ir.DataType.FLOAT8E4M3FNUZ,
        ir.DataType.FLOAT8E5M2FNUZ,
        ir.DataType.FLOAT4E2M1,  # Still need to migrate shape
        ir.DataType.INT4,
        ir.DataType.INT2,
        ir.DataType.UINT4,
        ir.DataType.UINT2,
    }
    if model_tensor.dtype not in needs_migration:
        return safe_tensor

    return ir.ExternalTensor(
        location=safe_tensor.location,
        offset=safe_tensor.offset,
        length=safe_tensor.length,
        dtype=model_tensor.dtype,
        shape=model_tensor.shape,  # type: ignore[arg-type]
        name=safe_tensor.name,
        base_dir=safe_tensor.base_dir,
    )
|