onnx-diagnostic 0.2.2__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (54)
  1. onnx_diagnostic/__init__.py +1 -1
  2. onnx_diagnostic/__main__.py +4 -0
  3. onnx_diagnostic/_command_lines_parser.py +452 -0
  4. onnx_diagnostic/doc.py +4 -4
  5. onnx_diagnostic/export/__init__.py +2 -1
  6. onnx_diagnostic/export/dynamic_shapes.py +574 -23
  7. onnx_diagnostic/export/validate.py +170 -0
  8. onnx_diagnostic/ext_test_case.py +151 -31
  9. onnx_diagnostic/helpers/__init__.py +1 -0
  10. onnx_diagnostic/helpers/bench_run.py +450 -0
  11. onnx_diagnostic/helpers/cache_helper.py +216 -0
  12. onnx_diagnostic/helpers/config_helper.py +80 -0
  13. onnx_diagnostic/{helpers.py → helpers/helper.py} +341 -662
  14. onnx_diagnostic/helpers/memory_peak.py +249 -0
  15. onnx_diagnostic/helpers/onnx_helper.py +921 -0
  16. onnx_diagnostic/{ort_session.py → helpers/ort_session.py} +4 -3
  17. onnx_diagnostic/helpers/rt_helper.py +47 -0
  18. onnx_diagnostic/{torch_test_helper.py → helpers/torch_test_helper.py} +149 -55
  19. onnx_diagnostic/reference/ops/op_cast_like.py +1 -1
  20. onnx_diagnostic/reference/ort_evaluator.py +7 -2
  21. onnx_diagnostic/tasks/__init__.py +48 -0
  22. onnx_diagnostic/tasks/automatic_speech_recognition.py +165 -0
  23. onnx_diagnostic/tasks/fill_mask.py +67 -0
  24. onnx_diagnostic/tasks/image_classification.py +96 -0
  25. onnx_diagnostic/tasks/image_text_to_text.py +145 -0
  26. onnx_diagnostic/tasks/sentence_similarity.py +67 -0
  27. onnx_diagnostic/tasks/text2text_generation.py +172 -0
  28. onnx_diagnostic/tasks/text_classification.py +67 -0
  29. onnx_diagnostic/tasks/text_generation.py +248 -0
  30. onnx_diagnostic/tasks/zero_shot_image_classification.py +106 -0
  31. onnx_diagnostic/torch_export_patches/onnx_export_errors.py +111 -146
  32. onnx_diagnostic/torch_export_patches/onnx_export_serialization.py +346 -57
  33. onnx_diagnostic/torch_export_patches/patch_inputs.py +203 -0
  34. onnx_diagnostic/torch_export_patches/patches/patch_torch.py +41 -2
  35. onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +39 -49
  36. onnx_diagnostic/torch_models/hghub/__init__.py +1 -0
  37. onnx_diagnostic/torch_models/hghub/hub_api.py +254 -0
  38. onnx_diagnostic/torch_models/hghub/hub_data.py +203 -0
  39. onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +3571 -0
  40. onnx_diagnostic/torch_models/hghub/model_inputs.py +151 -0
  41. onnx_diagnostic/torch_models/test_helper.py +1250 -0
  42. onnx_diagnostic/torch_models/untrained/llm_phi2.py +3 -4
  43. onnx_diagnostic/torch_models/untrained/llm_tiny_llm.py +3 -4
  44. onnx_diagnostic/torch_onnx/__init__.py +0 -0
  45. onnx_diagnostic/torch_onnx/sbs.py +439 -0
  46. {onnx_diagnostic-0.2.2.dist-info → onnx_diagnostic-0.4.0.dist-info}/METADATA +14 -4
  47. onnx_diagnostic-0.4.0.dist-info/RECORD +86 -0
  48. {onnx_diagnostic-0.2.2.dist-info → onnx_diagnostic-0.4.0.dist-info}/WHEEL +1 -1
  49. onnx_diagnostic/cache_helpers.py +0 -104
  50. onnx_diagnostic/onnx_tools.py +0 -260
  51. onnx_diagnostic-0.2.2.dist-info/RECORD +0 -59
  52. /onnx_diagnostic/{args.py → helpers/args_helper.py} +0 -0
  53. {onnx_diagnostic-0.2.2.dist-info → onnx_diagnostic-0.4.0.dist-info}/licenses/LICENSE.txt +0 -0
  54. {onnx_diagnostic-0.2.2.dist-info → onnx_diagnostic-0.4.0.dist-info}/top_level.txt +0 -0
onnx_diagnostic/helpers/onnx_helper.py ADDED
@@ -0,0 +1,921 @@
+ import ctypes
+ import functools
+ import json
+ import os
+ import sys
+ from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, Union
+ import numpy as np
+ import numpy.typing as npt
+ import onnx
+ import onnx.helper as oh
+ import onnx.numpy_helper as onh
+ from onnx import (
+     AttributeProto,
+     FunctionProto,
+     GraphProto,
+     ModelProto,
+     NodeProto,
+     TensorProto,
+     ValueInfoProto,
+     load as onnx_load,
+ )
+
+
+ def _make_stat(init: TensorProto) -> Dict[str, float]:
+     """
+     Produces statistics.
+
+     :param init: tensor
+     :return: statistics
+     """
+     ar = onh.to_array(init)
+     return dict(
+         mean=float(ar.mean()),
+         std=float(ar.std()),
+         shape=ar.shape,
+         itype=np_dtype_to_tensor_dtype(ar.dtype),
+         min=float(ar.min()),
+         max=float(ar.max()),
+     )
+
+
+ def onnx_lighten(
+     onx: Union[str, ModelProto],
+     verbose: int = 0,
+ ) -> Tuple[ModelProto, Dict[str, Dict[str, float]]]:
+     """
+     Creates a model without big initializers but stores statistics
+     into dictionaries. The function can be reversed with
+     :func:`onnx_diagnostic.helpers.onnx_helper.onnx_unlighten`.
+     The model is modified in place.
+
+     :param onx: model
+     :param verbose: verbosity
+     :return: new model, statistics
+     """
+     if isinstance(onx, str):
+         if verbose:
+             print(f"[onnx_lighten] load {onx!r}")
+         model = onnx.load(onx)
+     else:
+         assert isinstance(onx, ModelProto), f"Unexpected type {type(onx)}"
+         model = onx
+
+     keep = []
+     stats = []
+     for init in model.graph.initializer:
+         shape = init.dims
+         size = np.prod(shape)
+         if size > 2**12:
+             stat = _make_stat(init)
+             stats.append((init.name, stat))
+             if verbose:
+                 print(f"[onnx_lighten] remove initializer {init.name!r} stat={stat}")
+         else:
+             keep.append(init)
+
+     del model.graph.initializer[:]
+     model.graph.initializer.extend(keep)
+     return model, dict(stats)
+
+
+ def _get_tensor(min=None, max=None, mean=None, std=None, shape=None, itype=None):
+     assert itype is not None, "itype must be specified."
+     assert shape is not None, "shape must be specified."
+     dtype = tensor_dtype_to_np_dtype(itype)
+     if (mean is None or std is None) or (
+         min is not None and max is not None and abs(max - min - 1) < 0.01
+     ):
+         # uniform draw over [min, max]
+         if min is None:
+             min = 0
+         if max is None:
+             max = 1
+         return (np.random.random(shape) * (max - min) + min).astype(dtype)
+     assert std is not None and mean is not None, f"mean={mean} or std={std} is None"
+     # normal draw scaled back to the stored statistics
+     t = (np.random.randn(*shape) * std + mean).astype(dtype)
+     return t
+
+
+ def onnx_unlighten(
+     onx: Union[str, ModelProto],
+     stats: Optional[Dict[str, Dict[str, float]]] = None,
+     verbose: int = 0,
+ ) -> ModelProto:
+     """
+     Reverses what :func:`onnx_diagnostic.helpers.onnx_helper.onnx_lighten`
+     produced: the removed initializers are regenerated from the stored
+     statistics. The model is modified in place.
+
+     :param onx: model
+     :param stats: statistics; can be None if onx is a file,
+         in which case the file ``<filename>.stats`` is loaded,
+         assumed to be JSON format
+     :param verbose: verbosity
+     :return: new model
+     """
+     if isinstance(onx, str):
+         if stats is None:
+             fstats = f"{onx}.stats"
+             assert os.path.exists(fstats), f"File {fstats!r} is missing."
+             if verbose:
+                 print(f"[onnx_unlighten] load {fstats!r}")
+             with open(fstats, "r") as f:
+                 stats = json.load(f)
+         if verbose:
+             print(f"[onnx_unlighten] load {onx!r}")
+         model = onnx.load(onx)
+     else:
+         assert isinstance(onx, ModelProto), f"Unexpected type {type(onx)}"
+         model = onx
+         assert stats is not None, "stats is missing"
+
+     keep = []
+     for name, stat in stats.items():
+         t = _get_tensor(**stat)
+         init = from_array_extended(t, name=name)
+         keep.append(init)
+
+     model.graph.initializer.extend(keep)
+     return model
+
+
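A minimal round-trip sketch (hypothetical paths; `model.onnx` stands for any model with initializers above the 2**12-element threshold):

```python
import json
import onnx
from onnx_diagnostic.helpers.onnx_helper import onnx_lighten, onnx_unlighten

light, stats = onnx_lighten("model.onnx", verbose=1)
onnx.save(light, "model.light.onnx")
with open("model.light.onnx.stats", "w") as f:
    json.dump(stats, f)  # <filename>.stats is what onnx_unlighten looks for

# later: rebuild a model whose random initializers match the statistics
restored = onnx_unlighten("model.light.onnx", verbose=1)
```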
+ def _validate_graph(
+     g: GraphProto,
+     existing: Set[str],
+     verbose: int = 0,
+     watch: Optional[Set[str]] = None,
+     path: Optional[Sequence[str]] = None,
+ ):
+     found = []
+     path = path or ["root"]
+     set_init = set(i.name for i in g.initializer)
+     set_input = set(i.name for i in g.input)
+     existing |= set_init | set_input
+     if watch and set_init & watch:
+         if verbose:
+             print(f"-- found init {set_init & watch} in {path}")
+         found.extend([i for i in g.initializer if i.name in set_init & watch])
+     if watch and set_input & watch:
+         if verbose:
+             print(f"-- found input {set_input & watch} in {path}")
+         found.extend([i for i in g.input if i.name in set_input & watch])
+     try:
+         import tqdm
+
+         loop = tqdm.tqdm(g.node) if verbose else g.node
+     except ImportError:
+         loop = g.node
+
+     for node in loop:
+         ins = set(node.input) & existing
+         if ins != set(node.input):
+             raise AssertionError(
+                 f"One input is missing from node.input={node.input}, "
+                 f"existing={ins}, path={'/'.join(path)}, "
+                 f"node: {node.op_type}[{node.name}]"
+             )
+         if watch and ins & watch:
+             if verbose:
+                 print(
+                     f"-- found input {ins & watch} in "
+                     f"{'/'.join(path)}/{node.op_type}[{node.name}]"
+                 )
+             found.append(node)
+         for att in node.attribute:
+             if att.type == AttributeProto.GRAPH:
+                 found.extend(
+                     _validate_graph(
+                         att.g,
+                         existing.copy(),
+                         watch=watch,
+                         path=[*path, f"{node.op_type}[{node.name}]"],
+                         verbose=verbose,
+                     )
+                 )
+         existing |= set(node.output)
+         if watch and set(node.output) & watch:
+             if verbose:
+                 print(
+                     f"-- found output {set(node.output) & watch} "
+                     f"in {'/'.join(path)}/{node.op_type}[{node.name}]"
+                 )
+             found.append(node)
+     out = set(o.name for o in g.output)
+     ins = out & existing
+     if ins != out:
+         raise AssertionError(
+             f"One output is missing, out={out}, existing={ins}, path={path}"
+         )
+     return found
+
+
+ def _validate_function(g: FunctionProto, verbose: int = 0, watch: Optional[Set[str]] = None):
+     existing = set(g.input)
+     found = []
+     for node in g.node:
+         ins = set(node.input) & existing
+         if ins != set(node.input):
+             raise AssertionError(
+                 f"One input is missing from node.input={node.input}, existing={ins}"
+             )
+         if watch and ins & watch:
+             if verbose:
+                 print(f"-- found input {ins & watch} in {node.op_type}[{node.name}]")
+             found.append(node)
+         for att in node.attribute:
+             if att.type == AttributeProto.GRAPH:
+                 found.extend(
+                     _validate_graph(
+                         att.g, existing.copy(), watch=watch, path=[g.name], verbose=verbose
+                     )
+                 )
+         existing |= set(node.output)
+         if watch and set(node.output) & watch:
+             if verbose:
+                 print(
+                     f"-- found output {set(node.output) & watch} "
+                     f"in {node.op_type}[{node.name}]"
+                 )
+             found.append(node)
+     out = set(g.output)
+     ins = out & existing
+     if ins != out:
+         raise AssertionError(
+             f"One output is missing, out={out}, existing={ins}, path={g.name}"
+         )
+     return found
+
+
+ def onnx_find(
+     onx: Union[str, ModelProto], verbose: int = 0, watch: Optional[Set[str]] = None
+ ) -> List[Union[NodeProto, TensorProto]]:
+     """
+     Looks for the nodes producing or consuming a given set of results.
+
+     :param onx: model
+     :param verbose: verbosity
+     :param watch: names to search for
+     :return: list of nodes
+     """
+
+     if isinstance(onx, str):
+         onx = onnx.load(onx, load_external_data=False)
+     found = []
+     found.extend(_validate_graph(onx.graph, set(), verbose=verbose, watch=watch))
+     for f in onx.functions:
+         found.extend(_validate_function(f, watch=watch, verbose=verbose))
+     if verbose and found:
+         print(f"-- found {len(found)} nodes")
+     return found
+
+
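A short usage sketch; `model.onnx` and the watched name `attention_mask` are placeholders:

```python
from onnx_diagnostic.helpers.onnx_helper import onnx_find, pretty_onnx

# every node (or initializer/input) touching "attention_mask"
for found in onnx_find("model.onnx", verbose=1, watch={"attention_mask"}):
    print(pretty_onnx(found))
```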
+ def check_model_ort(
+     onx: Union[str, ModelProto],
+     providers: Optional[Union[str, List[Any]]] = None,
+     dump_file: Optional[str] = None,
+ ) -> "onnxruntime.InferenceSession":  # noqa: F821
+     """
+     Loads a model with onnxruntime.
+
+     :param onx: ModelProto or path to a model
+     :param providers: list of providers, None or "cpu" for CPU, "cuda" for CUDA
+     :param dump_file: if not empty, dumps the model into this file if
+         an error happened
+     :return: InferenceSession
+     """
+     from onnxruntime import InferenceSession
+
+     if providers is None or providers == "cpu":
+         providers = ["CPUExecutionProvider"]
+     elif not isinstance(providers, list) and providers.startswith("cuda"):
+         device_id = 0 if ":" not in providers else int(providers.split(":")[1])
+         providers = [
+             ("CUDAExecutionProvider", {"device_id": device_id}),
+             ("CPUExecutionProvider", {}),
+         ]
+
+     if isinstance(onx, str):
+         try:
+             return InferenceSession(onx, providers=providers)
+         except Exception as e:
+             model = onnx.load(onx)
+             if dump_file:
+                 onnx.save(model, dump_file)
+             raise AssertionError(
+                 f"onnxruntime cannot load the model "
+                 f"due to {e}\n{pretty_onnx(model)}"
+             ) from e
+     try:
+         return InferenceSession(onx.SerializeToString(), providers=providers)
+     except Exception as e:
+         if dump_file:
+             onnx.save(onx, dump_file)
+         raise AssertionError(
+             f"onnxruntime cannot load the model due to {e}\n{pretty_onnx(onx)}"
+         ) from e
+
+
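A self-contained sketch of the happy path (assumes `onnxruntime` is installed); the tiny Identity model only exercises the call:

```python
import onnx.helper as oh
from onnx import TensorProto
from onnx_diagnostic.helpers.onnx_helper import check_model_ort

model = oh.make_model(
    oh.make_graph(
        [oh.make_node("Identity", ["X"], ["Y"])],
        "g",
        [oh.make_tensor_value_info("X", TensorProto.FLOAT, [2])],
        [oh.make_tensor_value_info("Y", TensorProto.FLOAT, [2])],
    ),
    opset_imports=[oh.make_opsetid("", 18)],
)
sess = check_model_ort(model, providers="cpu", dump_file="failed.onnx")
print([i.name for i in sess.get_inputs()])  # ['X']
```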
+ @functools.cache
+ def onnx_dtype_name(itype: int) -> str:
+     """
+     Returns the ONNX name for a specific element type.
+
+     .. runpython::
+         :showcode:
+
+         import onnx
+         from onnx_diagnostic.helpers.onnx_helper import onnx_dtype_name
+
+         itype = onnx.TensorProto.BFLOAT16
+         print(onnx_dtype_name(itype))
+         print(onnx_dtype_name(7))
+     """
+     for k in dir(TensorProto):
+         v = getattr(TensorProto, k)
+         if v == itype:
+             return k
+     raise ValueError(f"Unexpected value itype: {itype}")
+
+
+ def pretty_onnx(
+     onx: Union[
+         AttributeProto, FunctionProto, GraphProto, ModelProto,
+         NodeProto, TensorProto, ValueInfoProto, str,
+     ],
+     with_attributes: bool = False,
+     highlight: Optional[Set[str]] = None,
+     shape_inference: bool = False,
+ ) -> str:
+     """
+     Displays an onnx proto in a readable way.
+
+     :param onx: proto to display
+     :param with_attributes: displays attributes as well, if only a node is printed
+     :param highlight: names to highlight
+     :param shape_inference: run shape inference before printing the model
+     :return: text
+     """
+     assert onx is not None, "onx cannot be None"
+     if isinstance(onx, str):
+         onx = onnx_load(onx, load_external_data=False)
+     assert onx is not None, "onx cannot be None"
+
+     if shape_inference:
+         onx = onnx.shape_inference.infer_shapes(onx)
+
+     if isinstance(onx, ValueInfoProto):
+         name = onx.name
+         itype = onx.type.tensor_type.elem_type
+         shape = tuple((d.dim_param or d.dim_value) for d in onx.type.tensor_type.shape.dim)
+         shape_str = ",".join(map(str, shape))
+         return f"{onnx_dtype_name(itype)}[{shape_str}] {name}"
+
+     if isinstance(onx, AttributeProto):
+         att = onx
+         if att.type == AttributeProto.INT:
+             return f"{att.name}={att.i}"
+         if att.type == AttributeProto.INTS:
+             return f"{att.name}={att.ints}"
+         if att.type == AttributeProto.FLOAT:
+             return f"{att.name}={att.f}"
+         if att.type == AttributeProto.FLOATS:
+             return f"{att.name}={att.floats}"
+         if att.type == AttributeProto.STRING:
+             return f"{att.name}={att.s!r}"
+         if att.type == AttributeProto.TENSOR:
+             v = to_array_extended(att.t)
+             assert hasattr(v, "reshape"), f"not a tensor {type(v)}"
+             assert hasattr(v, "shape"), f"not a tensor {type(v)}"
+             vf = v.reshape((-1,))
+             if vf.size < 10:
+                 tt = f"[{', '.join(map(str, vf))}]"
+             else:
+                 tt = f"[{', '.join(map(str, vf[:10]))}, ...]"
+             if len(v.shape) != 1:
+                 return f"{att.name}=tensor({tt}, dtype={v.dtype}).reshape({v.shape})"
+             return f"{att.name}=tensor({tt}, dtype={v.dtype})"
+         raise NotImplementedError(
+             f"pretty_onnx not implemented yet for AttributeProto={att!r}"
+         )
+
+     if isinstance(onx, NodeProto):
+
+         def _high(n):
+             if highlight and n in highlight:
+                 return f"**{n}**"
+             return n
+
+         text = (
+             f"{onx.op_type}({', '.join(map(_high, onx.input))})"
+             f" -> {', '.join(map(_high, onx.output))}"
+         )
+         if onx.domain:
+             text = f"{onx.domain}.{text}"
+         if not with_attributes or not onx.attribute:
+             return text
+         rows = []
+         for att in onx.attribute:
+             rows.append(pretty_onnx(att))
+         if len(rows) > 1:
+             suffix = "\n".join(f"    {s}" for s in rows)
+             return f"{text}\n{suffix}"
+         return f"{text} --- {rows[0]}"
+
+     if isinstance(onx, TensorProto):
+         shape = "x".join(map(str, onx.dims))
+         return f"TensorProto:{onx.data_type}:{shape}:{onx.name}"
+
+     try:
+         from onnx_array_api.plotting.text_plot import onnx_simple_text_plot
+
+         if isinstance(onx, FunctionProto):
+             return (
+                 f"function: {onx.name}[{onx.domain}]\n"
+                 f"{onnx_simple_text_plot(onx, recursive=True)}"
+             )
+         return onnx_simple_text_plot(onx, recursive=True)
+     except ImportError:
+         from onnx.printer import to_text
+
+         return to_text(onx)
+
+
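A quick sketch of the NodeProto rendering, with `highlight` marking one name:

```python
import onnx.helper as oh
from onnx_diagnostic.helpers.onnx_helper import pretty_onnx

node = oh.make_node("MatMul", ["A", "B"], ["C"], name="mm")
print(pretty_onnx(node, highlight={"B"}))  # MatMul(A, **B**) -> C
```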
+ def get_onnx_signature(model: ModelProto) -> Tuple[Tuple[str, Any], ...]:
+     """
+     Produces a tuple of tuples describing the model input signature.
+
+     :param model: model
+     :return: signature
+     """
+     sig: List[Any] = []
+     for i in model.graph.input:
+         dt = i.type
+         if dt.HasField("sequence_type"):
+             dst = dt.sequence_type.elem_type
+             tdt = dst.tensor_type
+             el = tdt.elem_type
+             shape = tuple(d.dim_param or d.dim_value for d in tdt.shape.dim)
+             sig.append((i.name, [(i.name, el, shape)]))
+         elif dt.HasField("tensor_type"):
+             el = dt.tensor_type.elem_type
+             shape = tuple(d.dim_param or d.dim_value for d in dt.tensor_type.shape.dim)
+             sig.append((i.name, el, shape))
+         else:
+             raise AssertionError(f"Unable to interpret dt={dt!r} in {i!r}")
+     return tuple(sig)
+
+
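For a plain tensor input the entry is `(name, element_type, shape)`; a tiny sketch:

```python
import onnx.helper as oh
from onnx import TensorProto
from onnx_diagnostic.helpers.onnx_helper import get_onnx_signature

model = oh.make_model(
    oh.make_graph(
        [oh.make_node("Identity", ["X"], ["Y"])],
        "g",
        [oh.make_tensor_value_info("X", TensorProto.FLOAT, ["batch", 4])],
        [oh.make_tensor_value_info("Y", TensorProto.FLOAT, ["batch", 4])],
    )
)
print(get_onnx_signature(model))  # (('X', 1, ('batch', 4)),)
```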
+ def convert_endian(tensor: TensorProto) -> None:
+     """Call to convert endianness of raw data in tensor.
+
+     Args:
+         tensor: TensorProto to be converted.
+     """
+     tensor_dtype = tensor.data_type
+     np_dtype = tensor_dtype_to_np_dtype(tensor_dtype)
+     tensor.raw_data = np.frombuffer(tensor.raw_data, dtype=np_dtype).byteswap().tobytes()
+
+
+ def from_array_ml_dtypes(arr: npt.ArrayLike, name: Optional[str] = None) -> TensorProto:
+     """
+     Converts a numpy array to a TensorProto assuming the dtype
+     is defined in ml_dtypes.
+
+     Args:
+         arr: a numpy array.
+         name: (optional) the name of the tensor.
+
+     Returns:
+         TensorProto: the converted tensor def.
+     """
+     import ml_dtypes
+
+     assert isinstance(arr, np.ndarray), f"arr must be of type numpy.ndarray, got {type(arr)}"
+
+     tensor = TensorProto()
+     tensor.dims.extend(arr.shape)
+     if name:
+         tensor.name = name
+
+     if arr.dtype == ml_dtypes.bfloat16:
+         dtype = TensorProto.BFLOAT16
+     elif arr.dtype == ml_dtypes.float8_e4m3fn:
+         dtype = TensorProto.FLOAT8E4M3FN
+     elif arr.dtype == ml_dtypes.float8_e4m3fnuz:
+         dtype = TensorProto.FLOAT8E4M3FNUZ
+     elif arr.dtype == ml_dtypes.float8_e5m2:
+         dtype = TensorProto.FLOAT8E5M2
+     elif arr.dtype == ml_dtypes.float8_e5m2fnuz:
+         dtype = TensorProto.FLOAT8E5M2FNUZ
+     else:
+         raise NotImplementedError(f"No conversion from {arr.dtype}")
+     tensor.data_type = dtype
+     tensor.raw_data = arr.tobytes()  # note: tobytes() requires numpy >= 1.9
+     if sys.byteorder == "big":
+         convert_endian(tensor)
+     return tensor
+
+
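A small sketch (assumes the `ml_dtypes` package is installed):

```python
import ml_dtypes
import numpy as np
from onnx import TensorProto
from onnx_diagnostic.helpers.onnx_helper import from_array_ml_dtypes

arr = np.array([1.0, -2.5], dtype=ml_dtypes.bfloat16)
proto = from_array_ml_dtypes(arr, name="w")
assert proto.data_type == TensorProto.BFLOAT16
```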
+ _STORAGE_TYPE = {
+     TensorProto.FLOAT16: np.int16,
+     TensorProto.BFLOAT16: np.int16,
+ }
+
+
+ def proto_from_tensor(
+     arr: "torch.Tensor",  # noqa: F821
+     name: Optional[str] = None,
+     verbose: int = 0,
+ ) -> TensorProto:
+     """
+     Converts a torch Tensor into a TensorProto.
+
+     :param arr: tensor
+     :param name: name of the initializer
+     :param verbose: display the type and shape
+     :return: a TensorProto
+     """
+     import torch
+
+     if not isinstance(arr, torch.Tensor):
+         raise TypeError(f"Unexpected type {type(arr)}.")
+     if arr.is_sparse:
+         raise NotImplementedError(
+             f"Sparse tensor is not supported yet but initializer {name!r} is."
+         )
+
+     # arr.contiguous() is slow after a transpose, maybe there is a way to optimize this.
+     if arr.is_contiguous():
+         arr_cpu = arr.cpu()
+     else:
+         arr_cpu = arr.contiguous().cpu()
+
+     numel = torch.numel(arr_cpu)
+     element_size = arr_cpu.element_size()
+
+     if arr_cpu.dtype in {torch.bfloat16}:
+         np_arr = arr_cpu
+     elif arr_cpu.data_ptr() == arr.data_ptr():
+         copy = arr_cpu.clone().detach().requires_grad_(False)
+         assert (
+             arr_cpu.data_ptr() == 0 or arr_cpu.data_ptr() != copy.data_ptr()
+         ), f"Pointers are not null and different {arr_cpu.data_ptr()} != {copy.data_ptr()}"
+         np_arr = np.from_dlpack(copy)
+     else:
+         np_arr = np.from_dlpack(arr_cpu.detach())
+
+     tensor = TensorProto()
+     tensor.dims.extend(arr_cpu.shape)
+     if name:
+         tensor.name = name
+     itype = torch_dtype_to_onnx_dtype(arr_cpu.dtype)
+     assert not hasattr(TensorProto, "INT4") or itype not in {
+         TensorProto.INT4,
+         TensorProto.UINT4,
+     }, f"Type {arr.dtype} is not supported yet for name={name!r}"
+     tensor.data_type = itype
+
+     if verbose > 1 and numel > 100:
+         print(f"[proto_from_tensor] {tensor.data_type}[{arr_cpu.shape}]")
+
+     if isinstance(np_arr, torch.Tensor):
+         byte_data = (ctypes.c_ubyte * numel * element_size).from_address(np_arr.data_ptr())
+         tensor.raw_data = bytes(byte_data)
+         if sys.byteorder == "big":
+             np_dtype = _STORAGE_TYPE[tensor.data_type]  # type: ignore
+             tensor.raw_data = (
+                 np.frombuffer(tensor.raw_data, dtype=np_dtype).byteswap().tobytes()
+             )
+     else:
+         tensor.raw_data = np_arr.tobytes()
+         if sys.byteorder == "big":
+             np_dtype = tensor_dtype_to_np_dtype(tensor.data_type)
+             tensor.raw_data = (
+                 np.frombuffer(tensor.raw_data, dtype=np_dtype).byteswap().tobytes()
+             )
+
+     return tensor
+
+
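A brief sketch (assumes torch); the transpose makes the tensor non-contiguous so the `contiguous()` path runs:

```python
import torch
from onnx_diagnostic.helpers.onnx_helper import proto_from_tensor

t = torch.arange(6, dtype=torch.float32).reshape(2, 3).t()  # non-contiguous
proto = proto_from_tensor(t, name="weight")
print(list(proto.dims), proto.data_type)  # [3, 2] 1  (1 == TensorProto.FLOAT)
```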
+ def from_array_extended(tensor: npt.ArrayLike, name: Optional[str] = None) -> TensorProto:
+     """
+     Converts an array into a :class:`onnx.TensorProto`.
+
+     :param tensor: numpy array or torch tensor
+     :param name: name
+     :return: TensorProto
+     """
+     try:
+         import torch
+     except ImportError:
+         torch = None
+     if torch is not None and isinstance(tensor, torch.Tensor):
+         return proto_from_tensor(tensor, name=name)
+
+     from onnx.reference.ops.op_cast import (
+         bfloat16,
+         float8e4m3fn,
+         float8e4m3fnuz,
+         float8e5m2,
+         float8e5m2fnuz,
+     )
+
+     dt = tensor.dtype
+     if dt == float8e4m3fn and dt.descr[0][0] == "e4m3fn":
+         to = TensorProto.FLOAT8E4M3FN
+         dt_to = np.uint8
+     elif dt == float8e4m3fnuz and dt.descr[0][0] == "e4m3fnuz":
+         to = TensorProto.FLOAT8E4M3FNUZ
+         dt_to = np.uint8
+     elif dt == float8e5m2 and dt.descr[0][0] == "e5m2":
+         to = TensorProto.FLOAT8E5M2
+         dt_to = np.uint8
+     elif dt == float8e5m2fnuz and dt.descr[0][0] == "e5m2fnuz":
+         to = TensorProto.FLOAT8E5M2FNUZ
+         dt_to = np.uint8
+     elif dt == bfloat16 and dt.descr[0][0] == "bfloat16":
+         to = TensorProto.BFLOAT16
+         dt_to = np.uint16
+     else:
+         try:
+             import ml_dtypes
+         except ImportError:
+             ml_dtypes = None
+         if ml_dtypes is not None and (
+             tensor.dtype == ml_dtypes.bfloat16
+             or tensor.dtype == ml_dtypes.float8_e4m3fn
+             or tensor.dtype == ml_dtypes.float8_e4m3fnuz
+             or tensor.dtype == ml_dtypes.float8_e5m2
+             or tensor.dtype == ml_dtypes.float8_e5m2fnuz
+         ):
+             return from_array_ml_dtypes(tensor, name)
+         return onh.from_array(tensor, name)
+
+     t = onh.from_array(tensor.astype(dt_to), name)
+     t.data_type = to
+     return t
+
+
+ def to_array_extended(proto: TensorProto) -> npt.ArrayLike:
+     """Converts :class:`onnx.TensorProto` into a numpy array."""
+     arr = onh.to_array(proto)
+     if proto.data_type >= onnx.TensorProto.BFLOAT16:
+         # Types not supported by numpy: view through the ml_dtypes dtype.
+         np_dtype = onnx_dtype_to_np_dtype(proto.data_type)
+         return arr.view(np_dtype)
+     return arr
+
+
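A round-trip sketch through the extended converters (again assuming `ml_dtypes`):

```python
import ml_dtypes
import numpy as np
from onnx_diagnostic.helpers.onnx_helper import from_array_extended, to_array_extended

a = np.array([[0.5, -1.0]], dtype=ml_dtypes.bfloat16)
proto = from_array_extended(a, name="t")
b = to_array_extended(proto)
assert b.dtype == ml_dtypes.bfloat16 and b.shape == (1, 2)
```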
+ def onnx_dtype_to_torch_dtype(itype: int) -> "torch.dtype":  # noqa: F821
+     """
+     Converts an onnx type into a torch dtype.
+
+     :param itype: onnx dtype
+     :return: torch dtype
+     """
+     import torch
+
+     if itype == TensorProto.FLOAT:
+         return torch.float32
+     if itype == TensorProto.FLOAT16:
+         return torch.float16
+     if itype == TensorProto.BFLOAT16:
+         return torch.bfloat16
+     if itype == TensorProto.DOUBLE:
+         return torch.float64
+     if itype == TensorProto.INT32:
+         return torch.int32
+     if itype == TensorProto.INT64:
+         return torch.int64
+     if itype == TensorProto.UINT32:
+         return torch.uint32
+     if itype == TensorProto.UINT64:
+         return torch.uint64
+     if itype == TensorProto.BOOL:
+         return torch.bool
+     if itype == TensorProto.INT16:
+         return torch.int16
+     if itype == TensorProto.UINT16:
+         return torch.uint16
+     if itype == TensorProto.INT8:
+         return torch.int8
+     if itype == TensorProto.UINT8:
+         return torch.uint8
+     if itype == TensorProto.COMPLEX64:
+         return torch.complex64
+     if itype == TensorProto.COMPLEX128:
+         return torch.complex128
+     raise NotImplementedError(
+         f"Unable to convert onnx type {onnx_dtype_name(itype)} to a torch dtype."
+     )
+
+
+ def onnx_dtype_to_np_dtype(itype: int) -> Any:
+     """
+     Converts an onnx type into a numpy dtype.
+     That includes :epkg:`ml_dtypes` dtypes.
+
+     :param itype: onnx dtype
+     :return: numpy dtype
+     """
+     if itype == TensorProto.FLOAT:
+         return np.float32
+     if itype == TensorProto.FLOAT16:
+         return np.float16
+     if itype == TensorProto.BFLOAT16:
+         import ml_dtypes
+
+         return ml_dtypes.bfloat16
+     if itype == TensorProto.DOUBLE:
+         return np.float64
+     if itype == TensorProto.INT32:
+         return np.int32
+     if itype == TensorProto.INT64:
+         return np.int64
+     if itype == TensorProto.UINT32:
+         return np.uint32
+     if itype == TensorProto.UINT64:
+         return np.uint64
+     if itype == TensorProto.BOOL:
+         return np.bool_
+     if itype == TensorProto.INT16:
+         return np.int16
+     if itype == TensorProto.UINT16:
+         return np.uint16
+     if itype == TensorProto.INT8:
+         return np.int8
+     if itype == TensorProto.UINT8:
+         return np.uint8
+     if itype == TensorProto.COMPLEX64:
+         return np.complex64
+     if itype == TensorProto.COMPLEX128:
+         return np.complex128
+     raise NotImplementedError(
+         f"Unable to convert onnx type {onnx_dtype_name(itype)} to a numpy dtype."
+     )
+
+
+ def torch_dtype_to_onnx_dtype(to: "torch.dtype") -> int:  # noqa: F821
+     """
+     Converts a torch dtype into an onnx element type.
+
+     :param to: torch dtype
+     :return: onnx type
+     """
+     import torch
+
+     if to == torch.float32:
+         return TensorProto.FLOAT
+     if to == torch.float16:
+         return TensorProto.FLOAT16
+     if to == torch.bfloat16:
+         return TensorProto.BFLOAT16
+     if to == torch.float64:
+         return TensorProto.DOUBLE
+     if to == torch.int64:
+         return TensorProto.INT64
+     if to == torch.int32:
+         return TensorProto.INT32
+     if to == torch.uint64:
+         return TensorProto.UINT64
+     if to == torch.uint32:
+         return TensorProto.UINT32
+     if to == torch.bool:
+         return TensorProto.BOOL
+     if to == torch.SymInt:
+         return TensorProto.INT64
+     if to == torch.int16:
+         return TensorProto.INT16
+     if to == torch.uint16:
+         return TensorProto.UINT16
+     if to == torch.int8:
+         return TensorProto.INT8
+     if to == torch.uint8:
+         return TensorProto.UINT8
+     if to == torch.SymFloat:
+         return TensorProto.FLOAT
+     if to == torch.complex64:
+         return TensorProto.COMPLEX64
+     if to == torch.complex128:
+         return TensorProto.COMPLEX128
+     raise NotImplementedError(f"Unable to convert torch dtype {to!r} to onnx dtype.")
+
+
+ def dtype_to_tensor_dtype(dt: Union[np.dtype, "torch.dtype"]) -> int:  # noqa: F821
+     """
+     Converts a torch dtype or numpy dtype into an onnx element type.
+
+     :param dt: dtype
+     :return: onnx type
+     """
+     try:
+         return np_dtype_to_tensor_dtype(dt)
+     except (KeyError, TypeError, ValueError):
+         pass
+     return torch_dtype_to_onnx_dtype(dt)
+
+
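A tiny sketch of the dispatch, numpy first and torch as fallback (assumes torch is installed):

```python
import numpy as np
import torch
from onnx_diagnostic.helpers.onnx_helper import dtype_to_tensor_dtype

print(dtype_to_tensor_dtype(np.dtype("float32")))  # 1  == TensorProto.FLOAT
print(dtype_to_tensor_dtype(torch.bfloat16))       # 16 == TensorProto.BFLOAT16
```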
+ def np_dtype_to_tensor_dtype(dt: np.dtype) -> int:
+     """
+     Converts a numpy dtype into an onnx element type.
+
+     :param dt: dtype
+     :return: onnx type
+     """
+     try:
+         return oh.np_dtype_to_tensor_dtype(dt)
+     except ValueError:
+         try:
+             import ml_dtypes
+         except ImportError:
+             ml_dtypes = None  # type: ignore
+         if ml_dtypes is not None:
+             if dt == ml_dtypes.bfloat16:
+                 return TensorProto.BFLOAT16
+             if dt == ml_dtypes.float8_e4m3fn:
+                 return TensorProto.FLOAT8E4M3FN
+             if dt == ml_dtypes.float8_e4m3fnuz:
+                 return TensorProto.FLOAT8E4M3FNUZ
+             if dt == ml_dtypes.float8_e5m2:
+                 return TensorProto.FLOAT8E5M2
+             if dt == ml_dtypes.float8_e5m2fnuz:
+                 return TensorProto.FLOAT8E5M2FNUZ
+     if dt == np.float32:
+         return TensorProto.FLOAT
+     if dt == np.float16:
+         return TensorProto.FLOAT16
+     if dt == np.float64:
+         return TensorProto.DOUBLE
+     if dt == np.int64:
+         return TensorProto.INT64
+     if dt == np.uint64:
+         return TensorProto.UINT64
+     if dt == np.int16:
+         return TensorProto.INT16
+     if dt == np.uint16:
+         return TensorProto.UINT16
+     if dt == np.int32:
+         return TensorProto.INT32
+     if dt == np.int8:
+         return TensorProto.INT8
+     if dt == np.uint8:
+         return TensorProto.UINT8
+     if dt == np.uint32:
+         return TensorProto.UINT32
+     if dt == np.bool_:
+         return TensorProto.BOOL
+     if dt == np.complex64:
+         return TensorProto.COMPLEX64
+     if dt == np.complex128:
+         return TensorProto.COMPLEX128
+     raise ValueError(f"Unable to convert type {dt}")
+
+
+ def type_info(itype: int, att: str):
+     """
+     Returns the minimum or maximum value for a type.
+
+     :param itype: onnx type
+     :param att: 'min' or 'max'
+     :return: value
+     """
+     if itype in {TensorProto.FLOAT, TensorProto.FLOAT16, TensorProto.DOUBLE}:
+         dtype = tensor_dtype_to_np_dtype(itype)
+         fi = np.finfo(dtype)
+     elif itype == TensorProto.BFLOAT16:
+         import ml_dtypes
+
+         dtype = tensor_dtype_to_np_dtype(itype)
+         fi = ml_dtypes.finfo(dtype)  # type: ignore
+     else:
+         dtype = tensor_dtype_to_np_dtype(itype)
+         fi = np.iinfo(dtype)  # type: ignore
+     if att == "min":
+         return fi.min
+     if att == "max":
+         return fi.max
+     raise ValueError(f"Unexpected value {att!r}")
+
+
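For instance, with values coming from numpy's `finfo`/`iinfo`:

```python
from onnx import TensorProto
from onnx_diagnostic.helpers.onnx_helper import type_info

print(type_info(TensorProto.FLOAT16, "max"))  # 65504.0
print(type_info(TensorProto.INT8, "min"))     # -128
```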
+ def tensor_dtype_to_np_dtype(tensor_dtype: int) -> np.dtype:
+     """
+     Converts a TensorProto's data_type to the corresponding numpy dtype.
+     It can be used while making tensors.
+
+     :param tensor_dtype: TensorProto's data_type
+     :return: numpy's data_type
+     """
+     if tensor_dtype >= 16:
+         try:
+             import ml_dtypes  # noqa: F401
+         except ImportError as e:
+             raise ValueError(
+                 f"Unsupported value for tensor_dtype, "
+                 f"numpy does not support onnx type {tensor_dtype}. "
+                 f"ml_dtypes can be used."
+             ) from e
+
+         mapping: Dict[int, np.dtype] = {
+             TensorProto.BFLOAT16: ml_dtypes.bfloat16,
+             TensorProto.FLOAT8E4M3FN: ml_dtypes.float8_e4m3fn,
+             TensorProto.FLOAT8E4M3FNUZ: ml_dtypes.float8_e4m3fnuz,
+             TensorProto.FLOAT8E5M2: ml_dtypes.float8_e5m2,
+             TensorProto.FLOAT8E5M2FNUZ: ml_dtypes.float8_e5m2fnuz,
+         }
+         assert (
+             tensor_dtype in mapping
+         ), f"Unable to find tensor_dtype={tensor_dtype!r} in mapping={mapping}"
+         return mapping[tensor_dtype]
+
+     return oh.tensor_dtype_to_np_dtype(tensor_dtype)
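A final sketch; element types numpy lacks come back as `ml_dtypes` types:

```python
from onnx import TensorProto
from onnx_diagnostic.helpers.onnx_helper import tensor_dtype_to_np_dtype

print(tensor_dtype_to_np_dtype(TensorProto.FLOAT))     # float32
print(tensor_dtype_to_np_dtype(TensorProto.BFLOAT16))  # <class 'ml_dtypes.bfloat16'>
```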