onnx-ir 0.1.15__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (53)
  1. onnx_ir/__init__.py +176 -0
  2. onnx_ir/_cloner.py +229 -0
  3. onnx_ir/_convenience/__init__.py +558 -0
  4. onnx_ir/_convenience/_constructors.py +291 -0
  5. onnx_ir/_convenience/_extractor.py +191 -0
  6. onnx_ir/_core.py +4435 -0
  7. onnx_ir/_display.py +54 -0
  8. onnx_ir/_enums.py +474 -0
  9. onnx_ir/_graph_comparison.py +23 -0
  10. onnx_ir/_graph_containers.py +373 -0
  11. onnx_ir/_io.py +133 -0
  12. onnx_ir/_linked_list.py +284 -0
  13. onnx_ir/_metadata.py +45 -0
  14. onnx_ir/_name_authority.py +72 -0
  15. onnx_ir/_polyfill.py +26 -0
  16. onnx_ir/_protocols.py +627 -0
  17. onnx_ir/_safetensors/__init__.py +510 -0
  18. onnx_ir/_tape.py +242 -0
  19. onnx_ir/_thirdparty/asciichartpy.py +310 -0
  20. onnx_ir/_type_casting.py +89 -0
  21. onnx_ir/_version_utils.py +48 -0
  22. onnx_ir/analysis/__init__.py +21 -0
  23. onnx_ir/analysis/_implicit_usage.py +74 -0
  24. onnx_ir/convenience.py +38 -0
  25. onnx_ir/external_data.py +459 -0
  26. onnx_ir/passes/__init__.py +41 -0
  27. onnx_ir/passes/_pass_infra.py +351 -0
  28. onnx_ir/passes/common/__init__.py +54 -0
  29. onnx_ir/passes/common/_c_api_utils.py +76 -0
  30. onnx_ir/passes/common/clear_metadata_and_docstring.py +60 -0
  31. onnx_ir/passes/common/common_subexpression_elimination.py +207 -0
  32. onnx_ir/passes/common/constant_manipulation.py +230 -0
  33. onnx_ir/passes/common/default_attributes.py +99 -0
  34. onnx_ir/passes/common/identity_elimination.py +120 -0
  35. onnx_ir/passes/common/initializer_deduplication.py +179 -0
  36. onnx_ir/passes/common/inliner.py +223 -0
  37. onnx_ir/passes/common/naming.py +280 -0
  38. onnx_ir/passes/common/onnx_checker.py +57 -0
  39. onnx_ir/passes/common/output_fix.py +141 -0
  40. onnx_ir/passes/common/shape_inference.py +112 -0
  41. onnx_ir/passes/common/topological_sort.py +37 -0
  42. onnx_ir/passes/common/unused_removal.py +215 -0
  43. onnx_ir/py.typed +1 -0
  44. onnx_ir/serde.py +2043 -0
  45. onnx_ir/tape.py +15 -0
  46. onnx_ir/tensor_adapters.py +210 -0
  47. onnx_ir/testing.py +197 -0
  48. onnx_ir/traversal.py +118 -0
  49. onnx_ir-0.1.15.dist-info/METADATA +68 -0
  50. onnx_ir-0.1.15.dist-info/RECORD +53 -0
  51. onnx_ir-0.1.15.dist-info/WHEEL +5 -0
  52. onnx_ir-0.1.15.dist-info/licenses/LICENSE +202 -0
  53. onnx_ir-0.1.15.dist-info/top_level.txt +1 -0
@@ -0,0 +1,510 @@
+ # Copyright (c) ONNX Project Contributors
+ # SPDX-License-Identifier: Apache-2.0
+ """Utilities for using safetensors as an external data format."""
+
+ from __future__ import annotations
+
+ __all__ = ["save_safetensors"]
+
+ import functools
+ import io
+ import json
+ import os
+ import struct
+ from collections.abc import Callable, Sequence
+ from typing import Any
+
+ import packaging.version
+
+ import onnx_ir as ir
+
+ _HEADER_SIZE_NUMBER_SIZE = 8
+ # https://github.com/huggingface/safetensors/blob/806426784adb43631e9a1102d4621126bb589347/safetensors/src/tensor.rs#L811
+ _SAFETENSORS_DTYPE_TO_IR_DTYPE = {
+     "BOOL": ir.DataType.BOOL,
+     "F4": ir.DataType.FLOAT4E2M1,
+     "F8_E5M2": ir.DataType.FLOAT8E5M2,
+     "F8_E4M3": ir.DataType.FLOAT8E4M3FN,
+     "F8_E8M0": ir.DataType.FLOAT8E8M0,
+     "BF16": ir.DataType.BFLOAT16,
+     "F16": ir.DataType.FLOAT16,
+     "F32": ir.DataType.FLOAT,
+     "F64": ir.DataType.DOUBLE,
+     "I8": ir.DataType.INT8,
+     "I16": ir.DataType.INT16,
+     "I32": ir.DataType.INT32,
+     "I64": ir.DataType.INT64,
+     "U8": ir.DataType.UINT8,
+     "U16": ir.DataType.UINT16,
+     "U32": ir.DataType.UINT32,
+     "U64": ir.DataType.UINT64,
+     "C64": ir.DataType.COMPLEX64,
+ }
+ # https://github.com/huggingface/safetensors/blob/806426784adb43631e9a1102d4621126bb589347/bindings/python/src/view.rs#L77
+ _IR_DTYPE_TO_SAFETENSORS_DTYPE = {
+     ir.DataType.BOOL: "bool",
+     ir.DataType.FLOAT4E2M1: "float4_e2m1fn_x2",
+     ir.DataType.FLOAT8E5M2: "float8_e5m2",
+     ir.DataType.FLOAT8E4M3FN: "float8_e4m3fn",
+     ir.DataType.FLOAT8E8M0: "float8_e8m0",
+     ir.DataType.FLOAT8E4M3FNUZ: "uint8",
+     ir.DataType.FLOAT8E5M2FNUZ: "uint8",
+     ir.DataType.BFLOAT16: "bfloat16",
+     ir.DataType.FLOAT16: "float16",
+     ir.DataType.FLOAT: "float32",
+     ir.DataType.DOUBLE: "float64",
+     ir.DataType.INT2: "uint8",
+     ir.DataType.INT4: "uint8",
+     ir.DataType.INT8: "int8",
+     ir.DataType.INT16: "int16",
+     ir.DataType.INT32: "int32",
+     ir.DataType.INT64: "int64",
+     ir.DataType.UINT2: "uint8",
+     ir.DataType.UINT4: "uint8",
+     ir.DataType.UINT8: "uint8",
+     ir.DataType.UINT16: "uint16",
+     ir.DataType.UINT32: "uint32",
+     ir.DataType.UINT64: "uint64",
+     ir.DataType.COMPLEX64: "complex64",
+ }
+
+
+ @functools.lru_cache(maxsize=1)
+ def _import_safetensors():
+     """Raise an error if safetensors is not installed."""
+     try:
+         import safetensors
+     except ImportError as e:
+         raise ImportError(
+             "safetensors is required for using safetensors external data format. "
+             "Please install it with 'pip install --upgrade safetensors'."
+         ) from e
+
+     min_required_version = packaging.version.parse("0.7.0")
+     version = getattr(safetensors, "__version__", None)
+     if version is None or packaging.version.parse(version) < min_required_version:
+         raise ImportError(
+             f"safetensors version 0.7.0 or higher is required, but version {version} is installed. "
+             "Please upgrade it with 'pip install --upgrade safetensors'."
+         )
+
+     return safetensors
+
+
+ def _get_shard_filename(base_name: str, shard_idx: int, total_shards: int) -> str:
+     """Generate a filename for a shard.
+
+     Args:
+         base_name: The base filename (e.g., 'model.safetensors').
+         shard_idx: The index of this shard (1-indexed).
+         total_shards: The total number of shards.
+
+     Returns:
+         The shard filename (e.g., 'model-00001-of-00003.safetensors').
+     """
+     if total_shards == 1:
+         return base_name
+
+     # Extract extension
+     if "." in base_name:
+         name, ext = base_name.rsplit(".", 1)
+         ext = f".{ext}"
+     else:
+         name = base_name
+         ext = ""
+
+     # Always use 5 digits to follow transformers convention
+     return f"{name}-{shard_idx:05d}-of-{total_shards:05d}{ext}"
+
+
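For reference, the naming convention implemented above behaves as follows (illustrative calls, not part of the package):

    _get_shard_filename("model.safetensors", 1, 3)  # -> "model-00001-of-00003.safetensors"
    _get_shard_filename("model.safetensors", 3, 3)  # -> "model-00003-of-00003.safetensors"
    _get_shard_filename("model.safetensors", 1, 1)  # -> "model.safetensors" (a single shard keeps the base name)
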
+ def _shard_tensors(
+     tensors: Sequence[ir.TensorProtocol], max_shard_size_bytes: int | None
+ ) -> list[list[ir.TensorProtocol]]:
+     """Shard tensors into multiple files based on max_shard_size_bytes.
+
+     Args:
+         tensors: The tensors to shard.
+         max_shard_size_bytes: Maximum size for each shard in bytes. When None,
+             no sharding is performed.
+
+     Returns:
+         A list of tensor lists for each shard.
+     """
+     if max_shard_size_bytes is None:
+         # No sharding
+         return [list(tensors)]
+
+     # Shard the tensors in their current order
+     shards: list[list[ir.TensorProtocol]] = [[]]
+     current_shard_size = 0
+
+     for tensor in tensors:
+         tensor_size = tensor.nbytes
+         # Check if adding this tensor would exceed max_shard_size_bytes
+         if current_shard_size + tensor_size > max_shard_size_bytes and current_shard_size > 0:
+             # Start a new shard
+             shards.append([])
+             current_shard_size = 0
+
+         shards[-1].append(tensor)
+         current_shard_size += tensor_size
+
+     return shards
+
+
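As a sketch of the greedy packing above, three hypothetical 4 MiB float32 tensors with a 10 MiB shard limit end up in two shards, in their original order (the tensor names and sizes below are made up for illustration):

    import numpy as np
    import onnx_ir as ir

    # Hypothetical tensors; any ir.TensorProtocol with .nbytes behaves the same way.
    tensors = [ir.Tensor(np.zeros((1024, 1024), dtype=np.float32), name=f"w{i}") for i in range(3)]
    shards = _shard_tensors(tensors, max_shard_size_bytes=10 * 1024**2)
    # -> [[w0, w1], [w2]]: w0 + w1 is 8 MiB, adding w2 (4 MiB) would exceed 10 MiB,
    #    so a new shard is started for w2.
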
+ def _replace_tensors(
+     values: Sequence[ir.Value],
+     /,
+     location: str | os.PathLike,
+     base_dir: str | os.PathLike,
+ ) -> None:
+     """Replace all tensors in an ONNX model with external data from a safetensors file.
+
+     Args:
+         values: List of initialized values whose constant values will be replaced.
+         location: Path to the safetensors file relative to the ONNX model file.
+         base_dir: Directory where the ONNX model file is stored.
+     """
+     tensors: dict[str, ir.ExternalTensor] = _read_safetensors(location, base_dir=base_dir)
+     value_map: dict[str, ir.Value] = {value.name: value for value in values}  # type: ignore[misc]
+     for name, tensor in tensors.items():
+         assert name in value_map, f"Bug: Tensor '{name}' not found in model initializers."
+         value = value_map[name]
+         model_tensor = value.const_value
+         assert model_tensor is not None
+         updated_tensor = _migrate_tensor_shape_dtype(model_tensor, tensor)
+         value.const_value = updated_tensor
+
+
+ def _get_tensor_storage_shape(tensor: ir.TensorProtocol) -> Sequence[int]:
+     """Get the storage shape of a tensor for safetensors."""
+     # Handle sub-byte dtypes
+     if tensor.dtype.bitwidth < 8:
+         return [tensor.nbytes]
+     return tensor.shape.numpy()
+
+
+ def _save_file(
+     initialized_values: Sequence[ir.Value],
+     /,
+     location: str | os.PathLike,
+     base_dir: str | os.PathLike = "",
+     *,
+     size_threshold_bytes: int,
+     max_shard_size_bytes: int | None,
+     callback: Callable[[ir.TensorProtocol, ir.external_data.CallbackInfo], None] | None = None,
+ ) -> None:
+     """Save all tensors in an ONNX model to a safetensors file.
+
+     Args:
+         initialized_values: List of initialized values to consider for saving.
+         location: Path to the safetensors file relative to the ONNX model file.
+         base_dir: Directory where the ONNX model file is stored.
+         size_threshold_bytes: Save to external data if the tensor size in bytes
+             is not smaller than this threshold.
+         max_shard_size_bytes: Maximum size in bytes (as int) of a safetensors file
+             before it is sharded. If None, no sharding is performed.
+         callback: A callback function that is called after each tensor is saved.
+     """
+     safetensors = _import_safetensors()
+
+     # Ensure that external_data ends with .safetensors
+     if not str(location).endswith(".safetensors"):
+         raise ValueError(
+             f'The path to safetensors file must have a .safetensors extension, got: "{location}"'
+         )
+
+     # First, collect metadata without loading tensor data
+     tensors_to_save: list[ir.TensorProtocol] = []
+     values_to_save: list[ir.Value] = []
+     for value in initialized_values:
+         tensor = value.const_value
+         assert tensor is not None
+         if tensor.nbytes < size_threshold_bytes:
+             continue
+         tensors_to_save.append(tensor)
+         values_to_save.append(value)
+
+     total_size = sum(tensor.nbytes for tensor in tensors_to_save)
+
+     if tensors_to_save:
+         # Determine sharding based on max_shard_size_bytes. When max_shard_size_bytes is None,
+         # it is the same as a single shard (which is the same as no sharding).
+         tensor_shards = _shard_tensors(tensors_to_save, max_shard_size_bytes)
+         total_shards = len(tensor_shards)
+
+         # Save each shard, loading only necessary tensor data
+         all_filenames = []
+         weight_map: dict[str, str] = {}  # Maps tensor name to shard filename
+         current_offset = 0
+         current_index = 0
+         for shard_idx, tensor_shard in enumerate(tensor_shards, start=1):
+             shard_filename = _get_shard_filename(str(location), shard_idx, total_shards)
+
+             shard_path = os.path.join(base_dir, shard_filename)
+             all_filenames.append(shard_filename)
+
+             # Build tensor_dict for this shard only
+             shard_dict: dict[str, Any] = {}
+             for tensor in tensor_shard:
+                 if callback is not None:
+                     callback(
+                         tensor,
+                         ir.external_data.CallbackInfo(
+                             total=len(tensors_to_save),
+                             index=current_index,
+                             offset=current_offset,
+                             filename=shard_filename,
+                         ),
+                     )
+                 assert tensor.name is not None
+                 shard_dict[tensor.name] = {
+                     "dtype": _IR_DTYPE_TO_SAFETENSORS_DTYPE[tensor.dtype],
+                     "shape": _get_tensor_storage_shape(tensor),
+                     "data": tensor.tobytes(),
+                 }
+                 # Update weight_map with shard filename
+                 weight_map[tensor.name] = shard_filename
+                 current_offset += tensor.nbytes
+                 current_index += 1
+
+             safetensors.serialize_file(shard_dict, shard_path)
+
+         # Save index file if sharding occurred
+         if total_shards > 1:
+             location_str = str(location)
+             if location_str.endswith(".safetensors"):
+                 index_filename = (
+                     location_str.rsplit(".safetensors", 1)[0] + ".safetensors.index.json"
+                 )
+             else:
+                 index_filename = location_str + ".index.json"
+             index_path = os.path.join(base_dir, index_filename)
+             index_data = {
+                 "metadata": {"total_size": total_size},
+                 "weight_map": weight_map,
+             }
+             with open(index_path, "w") as f:
+                 json.dump(index_data, f, indent=2)
+
+         # Replace tensors from each shard file
+         for filename in all_filenames:
+             _replace_tensors(values_to_save, filename, base_dir)
+
+
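When more than one shard is written, the index file produced above follows the transformers-style layout; for a hypothetical two-shard model its contents would be shaped like the dictionary below (tensor names and total_size are made up):

    # Hypothetical index contents written as JSON by the code above
    index_data = {
        "metadata": {"total_size": 7340032},
        "weight_map": {
            "embedding.weight": "model-00001-of-00002.safetensors",
            "decoder.weight": "model-00002-of-00002.safetensors",
        },
    }
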
+ def save_safetensors(
+     model: ir.Model,
+     path: str | os.PathLike,
+     /,
+     *,
+     format: str | None = None,
+     size_threshold_bytes: int = 256,
+     max_shard_size_bytes: int | None = None,
+     callback: Callable[[ir.TensorProtocol, ir.external_data.CallbackInfo], None] | None = None,
+ ) -> None:
+     """Save an ONNX model to a file with external data in a safetensors file.
+
+     The model object is unmodified after this operation.
+
+     When sharding is enabled, multiple safetensors files will be created
+     with names like "model-00001-of-00003.safetensors", and an index
+     file "model.safetensors.index.json" will be created to map tensors
+     to their respective shard files. The shards will be created only if
+     the total size of tensors exceeds the specified max_shard_size_bytes.
+
+     .. note::
+         Because the safetensors data format uses key-value mapping to store tensors,
+         all initializer names in the model (across subgraphs) must be unique.
+         Externalizing tensor attributes in nodes to safetensors files is currently not
+         supported. If you have tensors from Constant nodes that you want to externalize,
+         consider converting them to initializers first with
+         :class:`~onnx_ir.passes.common.LiftConstantsToInitializersPass`.
+
+     Example::
+
+         import onnx_ir as ir
+
+         model = ir.load("model.onnx")
+
+         # Save model with tensors larger than 100 bytes to safetensors external data,
+         # sharding files larger than 5GB.
+         ir.save_safetensors(
+             model,
+             "model.onnx",
+             size_threshold_bytes=100,
+             max_shard_size_bytes=int(5 * 1000**3),  # Shard safetensors files larger than 5GB
+         )
+
+     .. tip::
+
+         A simple progress bar can be implemented by passing a callback function like the following::
+
+             import onnx_ir as ir
+             import tqdm
+
+             with tqdm.tqdm() as pbar:
+                 total_set = False
+
+                 def callback(tensor: ir.TensorProtocol, metadata: ir.external_data.CallbackInfo) -> None:
+                     nonlocal total_set
+                     if not total_set:
+                         pbar.total = metadata.total
+                         total_set = True
+
+                     pbar.update()
+                     pbar.set_description(f"Saving {metadata.filename}: {tensor.name} ({tensor.dtype}, {tensor.shape})")
+
+                 ir.save_safetensors(
+                     ...,
+                     callback=callback,
+                 )
+
+     .. versionadded:: 0.1.15
+
+     Args:
+         model: ONNX model to save.
+         path: Path to the ONNX model file. E.g. "model.onnx".
+         format: The format of the file (e.g. ``protobuf``, ``textproto``, ``json``, etc.).
+             If None, the format is inferred from the file extension.
+         size_threshold_bytes: Save to external data if the tensor size in bytes
+             is not smaller than this threshold.
+         max_shard_size_bytes: Maximum size in bytes (as int) of a safetensors file
+             before it is sharded. If None, no sharding is performed.
+         callback: A callback function that is called after each tensor is saved.
+             The callback must have signature ``Callable[[ir.TensorProtocol, ir.external_data.CallbackInfo], None]``,
+             where the first argument is the tensor being saved and the second contains metadata such as filename and progress.
+
+     Raises:
+         ValueError: If duplicate initializer names are found in the model.
+     """
+     # Derive the safetensors file name from the model path
+     path_str = str(path)
+     # Get the base name without extension
+     if "." in os.path.basename(path_str):
+         base_name = os.path.splitext(os.path.basename(path_str))[0]
+     else:
+         base_name = os.path.basename(path_str)
+     external_data = f"{base_name}.safetensors"
+
+     # Store the original initializer values so they can be restored after saving
+     value_tensor_pairs: list[tuple[ir.Value, ir.TensorProtocol]] = []
+     initializer_names: set[str] = set()
+     for graph in model.graphs():
+         for value in graph.initializers.values():
+             tensor = value.const_value
+             # The value.name should be the same as tensor.name. However,
+             # in case there is a conflict, we do not care and will prefer value.name.
+             name = value.name
+             if name is None:
+                 raise ValueError(
+                     f"Initializer value '{value!r}' has no name (in graph {graph.name!r}). "
+                     "All initializers must have names."
+                 )
+             if tensor is None:
+                 continue
+             if name in initializer_names:
+                 raise ValueError(
+                     f"Duplicate initializer name found: {name} (in graph {graph.name!r})."
+                     " Rename the initializers to have unique names before saving to safetensors."
+                 )
+             initializer_names.add(name)
+             value_tensor_pairs.append((value, tensor))
+
+     try:
+         _save_file(
+             [value for value, _ in value_tensor_pairs],
+             external_data,
+             os.path.dirname(path),
+             size_threshold_bytes=size_threshold_bytes,
+             max_shard_size_bytes=max_shard_size_bytes,
+             callback=callback,
+         )
+         ir.save(model, path, format=format)
+     finally:
+         # Restore original initializers to avoid side effects
+         for value, tensor in value_tensor_pairs:
+             value.const_value = tensor
+
+
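Assuming the sharded configuration from the docstring example, a call like the one below leaves the model, its safetensors shards, and the index file side by side (two shards shown; the actual count depends on tensor sizes):

    import os
    import onnx_ir as ir

    os.makedirs("exported", exist_ok=True)
    model = ir.load("model.onnx")
    ir.save_safetensors(model, "exported/model.onnx", max_shard_size_bytes=2 * 1024**3)
    # Expected files under "exported/" (hypothetical shard count):
    #   model.onnx                            - model with external-data references
    #   model-00001-of-00002.safetensors      - first shard
    #   model-00002-of-00002.safetensors      - second shard
    #   model.safetensors.index.json          - tensor-name -> shard-file map
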
+ def _read_safetensors_header(file: io.IOBase) -> tuple[dict[str, dict[str, Any]], int]:
+     """Read the header of a safetensors file.
+
+     Args:
+         file: The safetensors file to read.
+
+     Returns:
+         The header of the safetensors file and the header size in bytes.
+     """
+     file.seek(0)
+     header_size = struct.unpack_from("<Q", file.read(_HEADER_SIZE_NUMBER_SIZE))[0]
+     header = file.read(header_size)
+     return json.loads(header.decode("utf-8")), header_size
+
+
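The framing parsed here is the standard safetensors layout: an 8-byte little-endian header length, a JSON header, then the raw tensor bytes. A minimal self-contained sketch of the same parse (toy header, no real tensor data):

    import io
    import json
    import struct

    header = json.dumps({"w": {"dtype": "F32", "shape": [2], "data_offsets": [0, 8]}}).encode("utf-8")
    blob = struct.pack("<Q", len(header)) + header + b"\x00" * 8  # toy file contents

    f = io.BytesIO(blob)
    (header_size,) = struct.unpack("<Q", f.read(8))
    parsed = json.loads(f.read(header_size).decode("utf-8"))
    # data_offsets are relative to the first byte after the header, which is why
    # _read_safetensors below adds header_size + _HEADER_SIZE_NUMBER_SIZE to get file offsets.
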
+ def _read_safetensors(
+     location: str | os.PathLike, base_dir: str | os.PathLike
+ ) -> dict[str, ir.ExternalTensor]:
+     """Read a safetensors file.
+
+     Args:
+         location: The safetensors file to read.
+         base_dir: Directory where the ONNX model file is stored.
+
+     Returns:
+         The contents of the safetensors file.
+     """
+     path = os.path.join(base_dir, location)
+     with open(path, "rb") as file:
+         header, header_size = _read_safetensors_header(file)
+     tensors = {}
+     for name, metadata in header.items():
+         if name == "__metadata__":
+             continue
+         offset = metadata["data_offsets"][0] + header_size + _HEADER_SIZE_NUMBER_SIZE
+         length = metadata["data_offsets"][1] - metadata["data_offsets"][0]
+         tensors[name] = ir.ExternalTensor(
+             location=location,
+             offset=offset,
+             length=length,
+             dtype=_SAFETENSORS_DTYPE_TO_IR_DTYPE[metadata["dtype"]],
+             shape=ir.Shape(metadata["shape"]),
+             name=name,
+             base_dir=base_dir,
+         )
+     return tensors
+
+
+ def _migrate_tensor_shape_dtype(
+     model_tensor: ir.TensorProtocol, safe_tensor: ir.ExternalTensor
+ ) -> ir.ExternalTensor:
+     """Migrate the shape and dtype of a tensor.
+
+     This is needed because 4-bit and 2-bit tensors are stored as UINT8 in safetensors.
+
+     Args:
+         model_tensor: The model tensor providing the original dtype and shape.
+         safe_tensor: The safetensors-backed tensor to migrate them to.
+
+     Returns:
+         The migrated tensor.
+     """
+     if model_tensor.dtype in {
+         # Types that safetensors does not support directly
+         ir.DataType.FLOAT8E4M3FNUZ,
+         ir.DataType.FLOAT8E5M2FNUZ,
+         ir.DataType.FLOAT4E2M1,  # Still need to migrate shape
+         ir.DataType.INT4,
+         ir.DataType.INT2,
+         ir.DataType.UINT4,
+         ir.DataType.UINT2,
+     }:
+         return ir.ExternalTensor(
+             location=safe_tensor.location,
+             offset=safe_tensor.offset,
+             length=safe_tensor.length,
+             dtype=model_tensor.dtype,
+             shape=model_tensor.shape,  # type: ignore[arg-type]
+             name=safe_tensor.name,
+             base_dir=safe_tensor.base_dir,
+         )
+     return safe_tensor
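
Putting the pieces together for a sub-byte type: an INT4 initializer with eight elements is packed into four bytes, so it is written to safetensors as "uint8" with a flat storage shape, and the original dtype and shape are restored on the way back (an illustrative trace, values assumed):

    # For a hypothetical INT4 initializer "w" of shape [8]:
    #   tensor.nbytes                                      -> 4      (two 4-bit values per byte)
    #   _get_tensor_storage_shape(tensor)                  -> [4]    (sub-byte dtypes are stored flat)
    #   _IR_DTYPE_TO_SAFETENSORS_DTYPE[ir.DataType.INT4]   -> "uint8"
    # After reading the shard back, _read_safetensors yields an ExternalTensor with
    # dtype UINT8 and shape [4]; _migrate_tensor_shape_dtype swaps in the model's
    # dtype (INT4) and shape ([8]) while keeping the same location, offset, and length.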