onnx-diagnostic 0.6.0__py3-none-any.whl → 0.6.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (39)
  1. onnx_diagnostic/__init__.py +1 -1
  2. onnx_diagnostic/_command_lines_parser.py +18 -0
  3. onnx_diagnostic/api.py +15 -0
  4. onnx_diagnostic/ext_test_case.py +3 -1
  5. onnx_diagnostic/helpers/args_helper.py +1 -1
  6. onnx_diagnostic/helpers/doc_helper.py +143 -0
  7. onnx_diagnostic/helpers/helper.py +6 -5
  8. onnx_diagnostic/helpers/model_builder_helper.py +24 -8
  9. onnx_diagnostic/helpers/rt_helper.py +5 -1
  10. onnx_diagnostic/helpers/torch_helper.py +2 -0
  11. onnx_diagnostic/reference/__init__.py +1 -0
  12. onnx_diagnostic/reference/torch_evaluator.py +648 -0
  13. onnx_diagnostic/reference/torch_ops/__init__.py +55 -0
  14. onnx_diagnostic/reference/torch_ops/_op_run.py +335 -0
  15. onnx_diagnostic/reference/torch_ops/access_ops.py +94 -0
  16. onnx_diagnostic/reference/torch_ops/binary_ops.py +108 -0
  17. onnx_diagnostic/reference/torch_ops/controlflow_ops.py +121 -0
  18. onnx_diagnostic/reference/torch_ops/generator_ops.py +36 -0
  19. onnx_diagnostic/reference/torch_ops/nn_ops.py +196 -0
  20. onnx_diagnostic/reference/torch_ops/other_ops.py +106 -0
  21. onnx_diagnostic/reference/torch_ops/reduce_ops.py +130 -0
  22. onnx_diagnostic/reference/torch_ops/sequence_ops.py +65 -0
  23. onnx_diagnostic/reference/torch_ops/shape_ops.py +121 -0
  24. onnx_diagnostic/reference/torch_ops/unary_ops.py +86 -0
  25. onnx_diagnostic/tasks/__init__.py +22 -1
  26. onnx_diagnostic/tasks/image_classification.py +2 -2
  27. onnx_diagnostic/tasks/text_generation.py +3 -3
  28. onnx_diagnostic/torch_export_patches/eval/__init__.py +106 -37
  29. onnx_diagnostic/torch_export_patches/eval/model_cases.py +12 -25
  30. onnx_diagnostic/torch_export_patches/patch_module_helper.py +130 -16
  31. onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +88 -0
  32. onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +142 -0
  33. onnx_diagnostic/torch_models/test_helper.py +133 -16
  34. onnx_diagnostic/torch_onnx/runtime_info.py +289 -0
  35. {onnx_diagnostic-0.6.0.dist-info → onnx_diagnostic-0.6.2.dist-info}/METADATA +1 -1
  36. {onnx_diagnostic-0.6.0.dist-info → onnx_diagnostic-0.6.2.dist-info}/RECORD +39 -23
  37. {onnx_diagnostic-0.6.0.dist-info → onnx_diagnostic-0.6.2.dist-info}/WHEEL +1 -1
  38. {onnx_diagnostic-0.6.0.dist-info → onnx_diagnostic-0.6.2.dist-info}/licenses/LICENSE.txt +0 -0
  39. {onnx_diagnostic-0.6.0.dist-info → onnx_diagnostic-0.6.2.dist-info}/top_level.txt +0 -0
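The largest additions are the new torch-based ONNX runtime (`torch_evaluator.py`, shown first below) and its kernel package (`torch_ops`). As a quick orientation before the diff, here is a minimal usage sketch; it assumes the 0.6.2 API exactly as it appears in the diff, and the one-node model is made up for illustration:

import onnx
import onnx.helper as oh
import torch
from onnx_diagnostic.reference import TorchOnnxEvaluator

# A one-node model: Y = sigmoid(X).
proto = oh.make_model(
    oh.make_graph(
        [oh.make_node("Sigmoid", ["X"], ["Y"])],
        "demo",
        [oh.make_tensor_value_info("X", onnx.TensorProto.FLOAT, ["a"])],
        [oh.make_tensor_value_info("Y", onnx.TensorProto.FLOAT, ["a"])],
    ),
    opset_imports=[oh.make_opsetid("", 18)],
    ir_version=9,
)

# Same calling convention as onnxruntime.InferenceSession.run.
sess = TorchOnnxEvaluator(proto)
print(sess.run(None, dict(X=torch.rand(4))))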
onnx_diagnostic/reference/torch_evaluator.py (new file)
@@ -0,0 +1,648 @@
+ import functools
+ from typing import Dict, List, Optional, Sequence, Tuple, Union
+ import numpy as np
+ import onnx
+ import torch
+ from ..helpers.torch_helper import to_tensor
+ from ..torch_onnx.runtime_info import first_used_last_used, RuntimeValue
+ from . import torch_ops
+
+
+ @functools.lru_cache
+ def get_kernels() -> Dict[Tuple[str, str, int], type[torch_ops.OpRunKernel]]:
+     """
+     Retrieves all the available kernel classes :class:`TorchOnnxEvaluator`
+     can use. The full list is the following.
+
+     .. runpython::
+         :showcode:
+
+         from onnx_diagnostic.reference.torch_evaluator import get_kernels
+
+         for k, v in sorted(get_kernels().items()):
+             domain, name, version = k
+             f = f"{name}({version})" if domain == "" else f"{name}[{domain}]({version})"
+             add = " " * max(25 - len(f), 0)
+             dd = " -- device dependent" if v.device_dependent() else ""
+             print(f"{f}{add} -- {v.__name__}{dd}")
+     """
+     res = {}
+     for _k, v in torch_ops.__dict__.items():
+         if isinstance(v, type) and issubclass(v, torch_ops.OpRunKernel) and "_" in v.__name__:
+             name, version = v.__name__.split("_")
+             domain = getattr(v, "domain", "")
+             res[domain, name, int(version)] = v
+     return res
+
+
+ class TorchOnnxEvaluator:
+     """
+     Torch evaluator for onnx models.
+     The class does not store the original proto it evaluates,
+     to avoid keeping an extra copy in memory.
+
+     :param proto: a proto
+     :param providers: where to run the model
+     :param opsets: needed if proto is a graph
+     :param functions: known local functions
+     :param verbose: verbosity level
+     :param custom_kernels: dictionary of kernels the user can define to override
+         a specific implementation: ``("", "LayerNormalization"): CustomKernel``
+
+     The class holds the following attributes:
+
+     * `providers`: providers
+     * `default_device`: default torch device
+     * `constants`: all initializers or constants
+     * `kernels`: kernels
+     * `runtime_info`: produced by :func:`first_used_last_used
+       <onnx_diagnostic.torch_onnx.runtime_info.first_used_last_used>`
+     * `last_used`: the list of intermediate results to remove after
+       each node execution; this keeps the memory from growing too much
+     * `functions`: local functions
+
+     The class is not multithreaded. `runtime_info` gets updated
+     by the class. The list of available kernels is returned by function
+     :func:`onnx_diagnostic.reference.torch_evaluator.get_kernels`.
+     Example:
+
+     .. runpython::
+         :showcode:
+
+         import onnx
+         import onnx.helper as oh
+         import torch
+         from onnx_diagnostic.helpers import string_type
+         from onnx_diagnostic.reference import TorchOnnxEvaluator
+
+         TFLOAT = onnx.TensorProto.FLOAT
+
+         proto = oh.make_model(
+             oh.make_graph(
+                 [
+                     oh.make_node("Sigmoid", ["Y"], ["sy"]),
+                     oh.make_node("Mul", ["Y", "sy"], ["ysy"]),
+                     oh.make_node("Mul", ["X", "ysy"], ["final"]),
+                 ],
+                 "-nd-",
+                 [
+                     oh.make_tensor_value_info("X", TFLOAT, [1, "b", "c"]),
+                     oh.make_tensor_value_info("Y", TFLOAT, ["a", "b", "c"]),
+                 ],
+                 [oh.make_tensor_value_info("final", TFLOAT, ["a", "b", "c"])],
+             ),
+             opset_imports=[oh.make_opsetid("", 18)],
+             ir_version=9,
+         )
+
+         sess = TorchOnnxEvaluator(proto)
+         feeds = dict(X=torch.rand((4, 5)), Y=torch.rand((4, 5)))
+         result = sess.run(None, feeds)
+         print(string_type(result, with_shape=True, with_min_max=True))
+
+     With ``verbose=1``, the class prints out every kernel run
+     and every result deleted along the run.
+     It shows when a result is not needed anymore. In that case,
+     it is deleted to free the memory it takes.
+
108
+ .. runpython::
109
+ :showcode:
110
+
111
+ import onnx
112
+ import onnx.helper as oh
113
+ import torch
114
+ from onnx_diagnostic.helpers import string_type
115
+ from onnx_diagnostic.reference import TorchOnnxEvaluator
116
+
117
+ TFLOAT = onnx.TensorProto.FLOAT
118
+
119
+ proto = oh.make_model(
120
+ oh.make_graph(
121
+ [
122
+ oh.make_node("Sigmoid", ["Y"], ["sy"]),
123
+ oh.make_node("Mul", ["Y", "sy"], ["ysy"]),
124
+ oh.make_node("Mul", ["X", "ysy"], ["final"]),
125
+ ],
126
+ "-nd-",
127
+ [
128
+ oh.make_tensor_value_info("X", TFLOAT, [1, "b", "c"]),
129
+ oh.make_tensor_value_info("Y", TFLOAT, ["a", "b", "c"]),
130
+ ],
131
+ [oh.make_tensor_value_info("final", TFLOAT, ["a", "b", "c"])],
132
+ ),
133
+ opset_imports=[oh.make_opsetid("", 18)],
134
+ ir_version=9,
135
+ )
136
+
137
+ sess = TorchOnnxEvaluator(proto, verbose=1)
138
+ feeds = dict(X=torch.rand((4, 5)), Y=torch.rand((4, 5)))
139
+ result = sess.run(None, feeds)
140
+ print(string_type(result, with_shape=True, with_min_max=True))
141
+
142
+     The runtime can also execute the onnx model on CUDA.
+     It follows the same logic as :class:`onnxruntime.InferenceSession`:
+     ``providers=["CUDAExecutionProvider"]``.
+     It is better in that case to move the inputs to CUDA. The class
+     tries to move every weight to CUDA but keeps any tensor
+     identified as a shape on CPU. Some bugs may remain as torch
+     raises an exception when devices are expected to be the same.
+     The runtime was validated with model :epkg:`arnir0/Tiny-LLM`.
+     The next example shows how to replace a kernel with a different
+     one based on :epkg:`onnxruntime`.
+
153
+ .. runpython::
154
+ :showcode:
155
+
156
+ import numpy as np
157
+ import onnx
158
+ import onnx.helper as oh
159
+ import onnxruntime
160
+ import torch
161
+ from onnx_diagnostic.helpers import string_type
162
+ from onnx_diagnostic.helpers.torch_helper import onnx_dtype_to_torch_dtype
163
+ from onnx_diagnostic.reference import TorchOnnxEvaluator
164
+ from onnx_diagnostic.reference.torch_ops import OpRunKernel, OpRunTensor
165
+
166
+ TFLOAT16 = onnx.TensorProto.FLOAT16
167
+
168
+ class LayerNormalizationOrt(OpRunKernel):
169
+ "LayerNormalization based on onnxruntime"
170
+
171
+ def __init__(self, node: onnx.NodeProto, version=None):
172
+ super().__init__(node, version)
173
+ self.axis = self.get_attribute_int(node, "axis", -1)
174
+ self.epsilon = self.get_attribute_float(node, "epsilon", 1e-5)
175
+ self.stash_type = onnx_dtype_to_torch_dtype(
176
+ self.get_attribute_int(node, "stash_type", onnx.TensorProto.FLOAT)
177
+ )
178
+ self.compute_std = len(node.output) > 1
179
+                 assert not self.compute_std, "The kernel only computes the first output."
+ layer_model = oh.make_model(
181
+ oh.make_graph(
182
+ [
183
+ oh.make_node(
184
+ "LayerNormalization",
185
+ ["X", "W", "B"],
186
+ ["Z"],
187
+ axis=-1,
188
+ epsilon=9.999999974752427e-7,
189
+ )
190
+ ],
191
+ "dummy",
192
+ [
193
+ oh.make_tensor_value_info("X", TFLOAT16, ["b", "c", "d"]),
194
+ oh.make_tensor_value_info("W", TFLOAT16, ["d"]),
195
+ oh.make_tensor_value_info("B", TFLOAT16, ["d"]),
196
+ ],
197
+ [oh.make_tensor_value_info("Z", TFLOAT16, ["b", "c", "d"])],
198
+ ),
199
+ ir_version=9,
200
+ opset_imports=[oh.make_opsetid("", 17)],
201
+ )
202
+ self.ort_sess = onnxruntime.InferenceSession(
203
+ layer_model.SerializeToString(), providers=["CUDAExecutionProvider"]
204
+ )
205
+
206
+ def run(self, x, scale, bias=None):
207
+ print(f"-- running {self.__class__.__name__}")
208
+ feeds = dict(X=x, W=scale)
209
+ if bias is not None:
210
+ feeds["B"] = bias
211
+ feeds = {k: v.tensor.detach().cpu().numpy() for k, v in feeds.items()}
212
+ got = self.ort_sess.run(None, feeds)[0]
213
+ return OpRunTensor(torch.from_numpy(got).to(x.dtype).to(x.device))
214
+
215
+ # This kernel is tested on this model.
216
+ model = oh.make_model(
217
+ oh.make_graph(
218
+ [
219
+ oh.make_node(
220
+ "LayerNormalization",
221
+ ["X", "W", "B"],
222
+ ["ln"],
223
+ axis=-1,
224
+ epsilon=9.999999974752427e-7,
225
+ ),
226
+                     oh.make_node("Add", ["ln", "W"], ["Z"]),
+ ],
230
+ "dummy",
231
+ [
232
+ oh.make_tensor_value_info("X", TFLOAT16, ["b", "c", "d"]),
233
+ oh.make_tensor_value_info("W", TFLOAT16, ["d"]),
234
+ oh.make_tensor_value_info("B", TFLOAT16, ["d"]),
235
+ ],
236
+ [oh.make_tensor_value_info("Z", TFLOAT16, ["b", "c", "d"])],
237
+ ),
238
+ ir_version=9,
239
+ opset_imports=[oh.make_opsetid("", 17)],
240
+ )
241
+
242
+ torch_sess = TorchOnnxEvaluator(
243
+ model,
244
+ custom_kernels={("", "LayerNormalization"): LayerNormalizationOrt},
245
+ verbose=1,
246
+ )
247
+ feeds = dict(
248
+ zip(
249
+ torch_sess.input_names,
250
+ [
251
+ torch.rand(3, 4, 5, dtype=torch.float16),
252
+ torch.abs(torch.rand(5, dtype=torch.float16)),
253
+ torch.rand(5, dtype=torch.float16),
254
+ ],
255
+ )
256
+ )
257
+ res = torch_sess.run(None, feeds)
258
+ print(string_type(res, with_shape=True, with_min_max=True))
259
+ """
260
+
261
+ class IO:
262
+ "IO"
263
+
264
+ def __init__(self, name: str, type: int, shape: Tuple[Union[str, int], ...]):
265
+ self.name = name
266
+ self.type = type
267
+ self.shape = shape
268
+
269
+ @classmethod
270
+ def _on_cuda(cls, providers) -> int:
271
+ if not providers:
272
+ return -1
273
+ for p in providers:
274
+ if p == "CUDAExecutionProvider":
275
+ return 0
276
+ if isinstance(p, tuple) and p[0] == "CUDAExecutionProvider":
277
+ return p[1]["device_id"]
278
+ return -1
279
+
280
+ def __init__(
281
+ self,
282
+ proto: Union[onnx.FunctionProto, onnx.GraphProto, onnx.ModelProto],
283
+ providers: Tuple[str, ...] = ("CPUExecutionProvider",),
284
+ opsets: Optional[Dict[str, int]] = None,
285
+ local_functions: Optional[Dict[Tuple[str, str], "TorchOnnxEvaluator"]] = None,
286
+ verbose: int = 0,
287
+ custom_kernels: Optional[Dict[Tuple[str, str], type[torch_ops.OpRunKernel]]] = None,
288
+ ):
289
+ self.providers = providers
290
+ self.constants: Dict[str, torch.Tensor] = {}
291
+ self.kernels: List[Optional[torch_ops.OpRunKernel]] = []
292
+ self.functions = local_functions.copy() if local_functions else {}
293
+ self.CPU = torch.tensor([0]).to("cpu").device
294
+ self.verbose = verbose
295
+ self.custom_kernels = custom_kernels or {}
296
+ dev = self._on_cuda(providers)
297
+ if dev < 0:
298
+ self.default_device = self.CPU
299
+ self.CUDA = None
300
+ else:
301
+ self.CUDA = torch.tensor([0]).to(f"cuda:{dev}").device
302
+ self.default_device = self.CUDA
303
+
304
+ if isinstance(proto, str):
305
+ proto = onnx.load(proto)
306
+ if isinstance(proto, onnx.ModelProto):
307
+ assert opsets is None, "proto is a model, opsets must be None in that case"
308
+             assert not proto.graph.sparse_initializer, "sparse_initializer not supported yet"
+ self.opsets = {d.domain: d.version for d in proto.opset_import}
310
+ for f in proto.functions:
311
+ self.functions[f.domain, f.name] = self.__class__(
312
+ f,
313
+ providers=providers,
314
+ local_functions=self.functions,
315
+ verbose=self.verbose,
316
+ )
317
+ self._build_initializers(proto.graph.initializer)
318
+ self._build_initializers(proto.graph.node)
319
+ self._build_kernels(proto.graph.node)
320
+ self.input_names = [i.name for i in proto.graph.input]
321
+ self.output_names = [i.name for i in proto.graph.output]
322
+ self._io_input_names = [
323
+ self.IO(
324
+ name=i.name,
325
+ type=i.type.tensor_type.elem_type,
326
+ shape=tuple(
327
+ d.dim_param or d.dim_value for d in i.type.tensor_type.shape.dim
328
+ ),
329
+ )
330
+ for i in proto.graph.input
331
+ ]
332
+ self._io_output_names = [
333
+ self.IO(
334
+ name=i.name,
335
+ type=i.type.tensor_type.elem_type,
336
+ shape=tuple(
337
+ d.dim_param or d.dim_value for d in i.type.tensor_type.shape.dim
338
+ ),
339
+ )
340
+ for i in proto.graph.output
341
+ ]
342
+ elif isinstance(proto, onnx.GraphProto):
343
+ assert opsets, "opsets must be specified if proto is a graph"
344
+             assert not proto.sparse_initializer, "sparse_initializer not supported yet"
+ self.opsets = opsets
346
+ self._build_initializers(proto.initializer)
347
+ self._build_initializers(proto.node)
348
+ self._build_kernels(proto.node)
349
+ self.input_names = [i.name for i in proto.input]
350
+ self.output_names = [i.name for i in proto.output]
351
+ elif isinstance(proto, onnx.FunctionProto):
352
+             assert opsets is None, "proto is a function, opsets must be None in that case"
+ self.opsets = {d.domain: d.version for d in proto.opset_import}
354
+ self._build_initializers(proto.node)
355
+ self._build_kernels(proto.node)
356
+ self.input_names = list(proto.input)
357
+ self.output_names = list(proto.output)
358
+ else:
359
+ raise TypeError(f"Unexpected type {type(proto)} for proto")
360
+
361
+ self.runtime_info = first_used_last_used(proto, constant_as_initializer=True)
362
+ self.last_used: List[List[str]] = [[] for _ in self.kernels]
363
+ for name, info in self.runtime_info.items():
364
+ assert isinstance(info.last_used, int) or info.is_input, (
365
+ f"Missing field last_used in {info!r}, last_used={info.last_used!r}, "
366
+ f"This may mean the node is unused and it should be removed."
367
+ )
368
+ if info.last_used is None:
369
+ # Not used.
370
+ self.last_used[0].append(name)
371
+ elif not info.is_output and not info.is_initializer:
372
+ self.last_used[info.last_used].append(name)
373
+
374
+ def get_inputs(self):
375
+ "Same API than onnxruntime."
376
+ assert hasattr(self, "_io_input_names"), "Missing attribute '_io_input_names'."
377
+ return self._io_input_names
378
+
379
+ def get_outputs(self):
380
+ "Same API than onnxruntime."
381
+ assert hasattr(self, "_io_output_names"), "Missing attribute '_io_output_names'."
382
+ return self._io_output_names
383
+
384
+ @property
385
+ def on_cuda(self) -> bool:
386
+ "Tells if the default device is CUDA."
387
+ return self.default_device == self.CUDA
388
+
389
+ def _build_initializers(self, inits: Sequence[Union[onnx.NodeProto, onnx.TensorProto]]):
390
+ for init in inits:
391
+ if isinstance(init, onnx.TensorProto):
392
+ self.constants[init.name] = to_tensor(init).to(self.default_device)
393
+ elif (
394
+ isinstance(init, onnx.NodeProto)
395
+ and init.op_type == "Constant"
396
+ and init.domain == ""
397
+ ):
398
+ value = None
399
+ for att in init.attribute:
400
+ if att.name == "value":
401
+ value = to_tensor(att.t).to(self.default_device)
402
+ elif att.name == "value_floats":
403
+ value = torch.tensor(list(att.floats), dtype=torch.float32).to(
404
+ self.default_device
405
+ )
406
+ assert value is not None, f"No attribute value in node {init}"
407
+ self.constants[init.output[0]] = value
408
+
409
+ def _build_kernels(self, nodes: Sequence[onnx.NodeProto]):
410
+ kernels = get_kernels()
411
+ self.kernels.clear()
412
+ for node in nodes:
413
+ kernel_kwargs = dict(verbose=max(0, self.verbose - 1))
414
+ opset = self.opsets[node.domain]
415
+ key = node.domain, node.op_type, opset
416
+ if key[:2] in self.custom_kernels:
417
+ cls = self.custom_kernels[key[:2]]
418
+ ags = [self.default_device] if cls.device_dependent() else []
419
+ kws = dict(parent=self) if cls.has_subgraphs() else {}
420
+ kws.update(kernel_kwargs) # type: ignore[arg-type]
421
+ kernel2 = cls(node, opset, *ags, **kws) # type: ignore[arg-type]
422
+ self.kernels.append(kernel2)
423
+ continue
424
+
425
+ if (node.domain, node.op_type) in self.functions:
426
+ kernel = torch_ops.OpRunFunction(
427
+ self.functions[node.domain, node.op_type],
428
+ node,
429
+ self.opsets[node.domain],
430
+ **kernel_kwargs,
431
+ )
432
+ self.kernels.append(kernel)
433
+ continue
434
+
435
+ if node.op_type == "Constant" and node.domain == "":
436
+ # Treated as a constant.
437
+ self.kernels.append(None)
438
+ continue
439
+
440
+ while key not in kernels and opset > 0:
441
+ opset -= 1
442
+ key = node.domain, node.op_type, opset
443
+ assert key in kernels, (
444
+ f"Missing kernel for node type {node.op_type!r} from domain {node.domain!r}, "
445
+ f"local functions={sorted(self.functions)}"
446
+ )
447
+ cls = kernels[key]
448
+ ags = [self.default_device] if cls.device_dependent() else []
449
+ kws = dict(parent=self) if cls.has_subgraphs() else {}
450
+ kws.update(kernel_kwargs) # type: ignore[arg-type]
451
+ kernel2 = cls(node, opset, *ags, **kws) # type: ignore[arg-type]
452
+ self.kernels.append(kernel2)
453
+
454
+ def run(
455
+ self,
456
+ outputs: Optional[List[str]],
457
+ feeds: Union[Dict[str, torch.Tensor], Dict[str, np.ndarray]],
458
+ ) -> Union[List[Optional[torch.Tensor]], List[Optional[np.ndarray]]]:
459
+ """
460
+ Runs the ONNX model.
461
+
462
+ :param outputs: outputs required
463
+ :param feeds: inputs
464
+ :return: output tensors.
465
+ """
466
+ use_numpy = any(isinstance(t, np.ndarray) for t in feeds.values())
467
+ if use_numpy:
468
+ feeds = {k: torch.from_numpy(v) for k, v in feeds.items()}
469
+ if outputs is None:
470
+ outputs = self.output_names
471
+
472
+ # sets constants
473
+ for k, v in self.constants.items():
474
+ r = self.runtime_info[k]
475
+ if not r.has_value:
476
+ r.set_value(
477
+ torch_ops.OpRunTensor(
478
+ v.to(self.CUDA) if not r.is_shape and self.on_cuda else v,
479
+ is_constant=True,
480
+ may_cpu=len(v.shape) == 1 and v.numel() < 8 and v.dtype == torch.int64,
481
+ )
482
+ )
483
+ if self.verbose:
484
+ print(f"+C {r.name}: {r.string_type()}")
485
+
486
+ # inputs
487
+ for k, v in feeds.items():
488
+ r = self.runtime_info[k]
489
+ r.set_value(
490
+ torch_ops.OpRunTensor(
491
+ v.to(self.CUDA) if not r.is_shape and self.on_cuda else v,
492
+ is_constant=False,
493
+ may_cpu=len(v.shape) == 1 and v.numel() < 8 and v.dtype == torch.int64,
494
+ )
495
+ )
496
+ if self.verbose:
497
+ print(f"+I {r.name}: {r.string_type()}")
498
+
499
+ # node execution
500
+ for it, kernel in enumerate(self.kernels):
501
+ if kernel is not None:
502
+ if self.verbose:
503
+ print(
504
+ f"{kernel.__class__.__name__}"
505
+ f"({', '.join(kernel.input)}) -> "
506
+ f"{', '.join(kernel.output)}"
507
+ )
508
+ # kernel execution
509
+ inputs = [(self.runtime_info[i].value if i else None) for i in kernel.input]
510
+ if kernel.has_subgraphs():
511
+ res = kernel.run(*inputs, context=self.runtime_info) # type: ignore[call-arg]
512
+ else:
513
+ res = kernel.run(*inputs)
514
+ if isinstance(res, tuple):
515
+ # outputs
516
+ assert all(isinstance(o, torch_ops.OpRunValue) for o in res), (
517
+ f"Unexpected output type {[type(o) for o in res]} "
518
+ f"for kernel {type(kernel)}."
519
+ )
520
+ for name, t in zip(kernel.output, res):
521
+ self.runtime_info[name].set_value(t)
522
+ if self.verbose:
523
+ for name in kernel.output:
524
+ print(f"+R {name}: {self.runtime_info[name].string_type()}")
525
+ else:
526
+ assert isinstance(
527
+ res, torch_ops.OpRunValue
528
+ ), f"Unexpected output type {type(res)} for kernel {type(kernel)}."
529
+ self.runtime_info[kernel.output[0]].set_value(res)
530
+ if self.verbose:
531
+ print(
532
+ f"+R {kernel.output[0]}: "
533
+ f"{self.runtime_info[kernel.output[0]].string_type()}"
534
+ )
535
+
536
+ # free intermediate results
537
+ for name in self.last_used[it]:
538
+ self.runtime_info[name].clean_value()
539
+ if self.verbose:
540
+ print(f"- clean {name}")
541
+
542
+ assert all(
543
+ self.runtime_info[o].value is not None for o in outputs
544
+ ), "Not implemented yet when one output is None."
545
+ fres = [self.runtime_info[o].value.tensor for o in outputs] # type: ignore[union-attr]
546
+ if self.verbose:
547
+ print(f"++ outputs {', '.join(outputs)}")
548
+
549
+ # clean previous execution
550
+ for k in feeds:
551
+ self.runtime_info[k].clean_value()
552
+ if self.verbose:
553
+ print(f"- clean {k}")
554
+ for o in outputs:
555
+ self.runtime_info[o].clean_value()
556
+ if self.verbose:
557
+ print(f"- clean {o}")
558
+
559
+ if use_numpy:
560
+ return [None if a is None else a.detach().cpu().numpy() for a in fres]
561
+ return fres
562
+
563
+ def run_with_values(
564
+ self,
565
+ *args: Optional[torch_ops.OpRunTensor],
566
+ context: Optional[Dict[str, RuntimeValue]] = None,
567
+ ) -> Union[torch_ops.OpRunValue, Tuple[torch_ops.OpRunValue, ...]]:
568
+ """
569
+         Runs the ONNX model. The signature differs from :meth:`run`.
+         This method is called by every kernel holding a subgraph.
+         The local variables are stored in `context`.
+
573
+ :param args: inputs
574
+ :param context: local context for the execution of subgraphs
575
+ :return: output OpRunTensor
576
+ """
577
+ assert all(
578
+ isinstance(a, torch_ops.OpRunValue) for a in args
579
+ ), f"Unexpected type in args: {[type(a) for a in args]}"
580
+ outputs = self.output_names
581
+ context = context or {}
582
+
583
+ # sets constants
584
+ for k, v in self.constants.items():
585
+ r = self.runtime_info[k]
586
+ if not r.has_value:
587
+ r.set_value(
588
+ torch_ops.OpRunTensor(
589
+ v.to(self.CUDA) if r.is_shape is False and self.on_cuda else v,
590
+ is_constant=True,
591
+ may_cpu=len(v.shape) == 1 and v.numel() < 8 and v.dtype == torch.int64,
592
+ )
593
+ )
594
+
595
+ # inputs
596
+ for k, v in zip(self.input_names, args):
597
+ r = self.runtime_info[k]
598
+ r.set_value(
599
+ torch_ops.OpRunTensor(None) if v is None else v.__class__(v.tensor_or_sequence)
600
+ )
601
+
602
+ # node execution
603
+ for it, kernel in enumerate(self.kernels):
604
+ if kernel is not None:
605
+ # kernel execution
606
+ inputs = [
607
+ (
608
+ (
609
+ self.runtime_info[i].value
610
+ if i in self.runtime_info
611
+ else context[i].value
612
+ )
613
+ if i
614
+ else None
615
+ )
616
+ for i in kernel.input
617
+ ]
618
+ res = kernel.run(*inputs)
619
+ if isinstance(res, tuple):
620
+ # outputs
621
+ assert all(isinstance(o, torch_ops.OpRunTensor) for o in res), (
622
+ f"Unexpected output type {[type(o) for o in res]} "
623
+ f"for kernel {type(kernel)}."
624
+ )
625
+ for name, t in zip(kernel.output, res):
626
+ self.runtime_info[name].set_value(t)
627
+ else:
628
+ assert isinstance(
629
+ res, torch_ops.OpRunValue
630
+ ), f"Unexpected output type {type(res)} for kernel {type(kernel)}."
631
+ self.runtime_info[kernel.output[0]].set_value(res)
632
+
633
+ # free intermediate results
634
+ for name in self.last_used[it]:
635
+ self.runtime_info[name].clean_value()
636
+
637
+ assert all(
638
+ self.runtime_info[o].value is not None for o in outputs
639
+ ), "Not implemented yet when one output is None."
640
+ res2 = [self.runtime_info[o].value.copy() for o in outputs] # type: ignore[assignment, union-attr]
641
+
642
+ # clean previous execution
643
+ for k in self.input_names:
644
+ self.runtime_info[k].clean_value()
645
+ for o in self.output_names:
646
+ self.runtime_info[o].clean_value()
647
+
648
+ return res2[0] if len(res2) == 1 else tuple(res2) # type: ignore[index, return-value, arg-type]
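The kernel lookup above relies on a naming convention: every class exported by `torch_ops` (next file) is named `Op_version`, and `get_kernels()` indexes them by `(domain, op_type, opset)`. Overriding a built-in kernel only requires the `(domain, op_type)` key. A minimal sketch, assuming the `custom_kernels` API shown above; the `LoggingNeg` kernel and the one-node model are hypothetical:

import onnx
import onnx.helper as oh
import torch
from onnx_diagnostic.reference import TorchOnnxEvaluator
from onnx_diagnostic.reference.torch_ops import OpRunKernel, OpRunTensor

class LoggingNeg(OpRunKernel):
    "Hypothetical replacement for Neg that logs every call."

    def run(self, x):
        print(f"-- running {self.__class__.__name__}")
        # OpRunTensor wraps the underlying torch.Tensor in `.tensor`.
        return OpRunTensor(-x.tensor)

proto = oh.make_model(
    oh.make_graph(
        [oh.make_node("Neg", ["X"], ["Y"])],
        "demo",
        [oh.make_tensor_value_info("X", onnx.TensorProto.FLOAT, ["a"])],
        [oh.make_tensor_value_info("Y", onnx.TensorProto.FLOAT, ["a"])],
    ),
    opset_imports=[oh.make_opsetid("", 18)],
    ir_version=9,
)

# ("", "Neg") takes precedence over the built-in Neg_1 kernel.
sess = TorchOnnxEvaluator(proto, custom_kernels={("", "Neg"): LoggingNeg})
print(sess.run(None, dict(X=torch.rand(4))))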
onnx_diagnostic/reference/torch_ops/__init__.py (new file)
@@ -0,0 +1,55 @@
+ from ._op_run import OpRunKernel, OpRunFunction, OpRunSequence, OpRunTensor, OpRunValue
+ from .access_ops import Gather_1, ScatterND_16, Slice_13
+ from .binary_ops import (
+     And_1,
+     Add_1,
+     Div_1,
+     Equal_1,
+     Greater_1,
+     GreaterOrEqual_1,
+     Less_1,
+     LessOrEqual_1,
+     MatMul_1,
+     Mul_1,
+     Or_1,
+     Pow_12,
+     Sub_1,
+ )
+ from .controlflow_ops import If_1, Loop_16
+ from .generator_ops import Range_11
+ from .nn_ops import AveragePool_11, Conv_11, LayerNormalization_17, Softmax_13, Tanh_6
+ from .other_ops import (
+     Cast_6,
+     CastLike_15,
+     NonZero_13,
+     Concat_1,
+     Tile_6,
+     Transpose_1,
+     Trilu_14,
+     Where_9,
+ )
+ from .reduce_ops import ReduceMax_18, ReduceMean_18, ReduceMin_17, ReduceMin_18, ReduceSum_13
+ from .sequence_ops import ConcatFromSequence_11, SequenceEmpty_11, SequenceInsert_11
+ from .shape_ops import (
+     ConstantOfShape_9,
+     Expand_8,
+     Reshape_14,
+     Shape_15,
+     Squeeze_13,
+     Split_18,
+     Unsqueeze_13,
+ )
+ from .unary_ops import (
+     Abs_1,
+     Cos_1,
+     Erf_9,
+     Exp_1,
+     Identity_1,
+     Log_1,
+     Neg_1,
+     Not_1,
+     Reciprocal_1,
+     Sigmoid_6,
+     Sin_1,
+     Sqrt_1,
+ )
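These exports are exactly what `get_kernels()` in torch_evaluator.py discovers, so the list doubles as the evaluator's op coverage. A sketch for checking whether a model only uses covered op types ("model.onnx" is a placeholder path; local functions and Constant nodes are handled separately by the evaluator):

import onnx
from onnx_diagnostic.reference.torch_evaluator import get_kernels

model = onnx.load("model.onnx")  # placeholder path
covered = {k[:2] for k in get_kernels()}  # {(domain, op_type), ...}
used = {(n.domain, n.op_type) for n in model.graph.node}
print("ops without a torch kernel:", sorted(used - covered))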