onnx-diagnostic 0.8.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (132) hide show
  1. onnx_diagnostic/__init__.py +7 -0
  2. onnx_diagnostic/__main__.py +4 -0
  3. onnx_diagnostic/_command_lines_parser.py +1141 -0
  4. onnx_diagnostic/api.py +15 -0
  5. onnx_diagnostic/doc.py +100 -0
  6. onnx_diagnostic/export/__init__.py +2 -0
  7. onnx_diagnostic/export/api.py +124 -0
  8. onnx_diagnostic/export/dynamic_shapes.py +1083 -0
  9. onnx_diagnostic/export/shape_helper.py +296 -0
  10. onnx_diagnostic/export/validate.py +173 -0
  11. onnx_diagnostic/ext_test_case.py +1290 -0
  12. onnx_diagnostic/helpers/__init__.py +1 -0
  13. onnx_diagnostic/helpers/_log_helper.py +463 -0
  14. onnx_diagnostic/helpers/args_helper.py +132 -0
  15. onnx_diagnostic/helpers/bench_run.py +450 -0
  16. onnx_diagnostic/helpers/cache_helper.py +687 -0
  17. onnx_diagnostic/helpers/config_helper.py +170 -0
  18. onnx_diagnostic/helpers/doc_helper.py +163 -0
  19. onnx_diagnostic/helpers/fake_tensor_helper.py +273 -0
  20. onnx_diagnostic/helpers/graph_helper.py +386 -0
  21. onnx_diagnostic/helpers/helper.py +1707 -0
  22. onnx_diagnostic/helpers/log_helper.py +2245 -0
  23. onnx_diagnostic/helpers/memory_peak.py +249 -0
  24. onnx_diagnostic/helpers/mini_onnx_builder.py +600 -0
  25. onnx_diagnostic/helpers/model_builder_helper.py +469 -0
  26. onnx_diagnostic/helpers/onnx_helper.py +1200 -0
  27. onnx_diagnostic/helpers/ort_session.py +736 -0
  28. onnx_diagnostic/helpers/rt_helper.py +476 -0
  29. onnx_diagnostic/helpers/torch_helper.py +987 -0
  30. onnx_diagnostic/reference/__init__.py +4 -0
  31. onnx_diagnostic/reference/evaluator.py +254 -0
  32. onnx_diagnostic/reference/ops/__init__.py +1 -0
  33. onnx_diagnostic/reference/ops/op_add_add_mul_mul.py +68 -0
  34. onnx_diagnostic/reference/ops/op_attention.py +60 -0
  35. onnx_diagnostic/reference/ops/op_average_pool_grad.py +63 -0
  36. onnx_diagnostic/reference/ops/op_bias_softmax.py +16 -0
  37. onnx_diagnostic/reference/ops/op_cast_like.py +46 -0
  38. onnx_diagnostic/reference/ops/op_complex.py +26 -0
  39. onnx_diagnostic/reference/ops/op_concat.py +15 -0
  40. onnx_diagnostic/reference/ops/op_constant_of_shape.py +67 -0
  41. onnx_diagnostic/reference/ops/op_fused_matmul.py +31 -0
  42. onnx_diagnostic/reference/ops/op_gather.py +29 -0
  43. onnx_diagnostic/reference/ops/op_gather_elements.py +45 -0
  44. onnx_diagnostic/reference/ops/op_gather_grad.py +12 -0
  45. onnx_diagnostic/reference/ops/op_memcpy_host.py +11 -0
  46. onnx_diagnostic/reference/ops/op_mul_sigmoid.py +23 -0
  47. onnx_diagnostic/reference/ops/op_negxplus1.py +8 -0
  48. onnx_diagnostic/reference/ops/op_qlinear_average_pool.py +40 -0
  49. onnx_diagnostic/reference/ops/op_qlinear_conv.py +102 -0
  50. onnx_diagnostic/reference/ops/op_quick_gelu.py +23 -0
  51. onnx_diagnostic/reference/ops/op_replace_zero.py +13 -0
  52. onnx_diagnostic/reference/ops/op_rotary.py +19 -0
  53. onnx_diagnostic/reference/ops/op_scan.py +65 -0
  54. onnx_diagnostic/reference/ops/op_scatter_elements.py +107 -0
  55. onnx_diagnostic/reference/ops/op_scatternd_of_shape.py +22 -0
  56. onnx_diagnostic/reference/ops/op_simplified_layer_normalization.py +8 -0
  57. onnx_diagnostic/reference/ops/op_skip_layer_normalization.py +13 -0
  58. onnx_diagnostic/reference/ops/op_slice.py +20 -0
  59. onnx_diagnostic/reference/ops/op_transpose_cast.py +16 -0
  60. onnx_diagnostic/reference/ops/op_tri_matrix.py +17 -0
  61. onnx_diagnostic/reference/ort_evaluator.py +652 -0
  62. onnx_diagnostic/reference/quantized_tensor.py +46 -0
  63. onnx_diagnostic/reference/report_results_comparison.py +95 -0
  64. onnx_diagnostic/reference/torch_evaluator.py +669 -0
  65. onnx_diagnostic/reference/torch_ops/__init__.py +56 -0
  66. onnx_diagnostic/reference/torch_ops/_op_run.py +335 -0
  67. onnx_diagnostic/reference/torch_ops/access_ops.py +94 -0
  68. onnx_diagnostic/reference/torch_ops/binary_ops.py +108 -0
  69. onnx_diagnostic/reference/torch_ops/controlflow_ops.py +121 -0
  70. onnx_diagnostic/reference/torch_ops/generator_ops.py +36 -0
  71. onnx_diagnostic/reference/torch_ops/nn_ops.py +196 -0
  72. onnx_diagnostic/reference/torch_ops/other_ops.py +106 -0
  73. onnx_diagnostic/reference/torch_ops/reduce_ops.py +130 -0
  74. onnx_diagnostic/reference/torch_ops/sequence_ops.py +65 -0
  75. onnx_diagnostic/reference/torch_ops/shape_ops.py +121 -0
  76. onnx_diagnostic/reference/torch_ops/unary_ops.py +93 -0
  77. onnx_diagnostic/tasks/__init__.py +90 -0
  78. onnx_diagnostic/tasks/automatic_speech_recognition.py +188 -0
  79. onnx_diagnostic/tasks/data/__init__.py +13 -0
  80. onnx_diagnostic/tasks/data/dummies_imagetext2text_generation_gemma3.onnx +0 -0
  81. onnx_diagnostic/tasks/feature_extraction.py +162 -0
  82. onnx_diagnostic/tasks/fill_mask.py +89 -0
  83. onnx_diagnostic/tasks/image_classification.py +144 -0
  84. onnx_diagnostic/tasks/image_text_to_text.py +581 -0
  85. onnx_diagnostic/tasks/image_to_video.py +127 -0
  86. onnx_diagnostic/tasks/mask_generation.py +143 -0
  87. onnx_diagnostic/tasks/mixture_of_expert.py +79 -0
  88. onnx_diagnostic/tasks/object_detection.py +134 -0
  89. onnx_diagnostic/tasks/sentence_similarity.py +89 -0
  90. onnx_diagnostic/tasks/summarization.py +227 -0
  91. onnx_diagnostic/tasks/text2text_generation.py +230 -0
  92. onnx_diagnostic/tasks/text_classification.py +89 -0
  93. onnx_diagnostic/tasks/text_generation.py +352 -0
  94. onnx_diagnostic/tasks/text_to_image.py +95 -0
  95. onnx_diagnostic/tasks/zero_shot_image_classification.py +128 -0
  96. onnx_diagnostic/torch_export_patches/__init__.py +21 -0
  97. onnx_diagnostic/torch_export_patches/eval/__init__.py +725 -0
  98. onnx_diagnostic/torch_export_patches/eval/model_cases.py +898 -0
  99. onnx_diagnostic/torch_export_patches/onnx_export_errors.py +1098 -0
  100. onnx_diagnostic/torch_export_patches/onnx_export_serialization.py +311 -0
  101. onnx_diagnostic/torch_export_patches/patch_details.py +340 -0
  102. onnx_diagnostic/torch_export_patches/patch_expressions.py +108 -0
  103. onnx_diagnostic/torch_export_patches/patch_inputs.py +211 -0
  104. onnx_diagnostic/torch_export_patches/patch_module.py +1047 -0
  105. onnx_diagnostic/torch_export_patches/patch_module_helper.py +184 -0
  106. onnx_diagnostic/torch_export_patches/patches/__init__.py +0 -0
  107. onnx_diagnostic/torch_export_patches/patches/patch_torch.py +1090 -0
  108. onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +2139 -0
  109. onnx_diagnostic/torch_export_patches/serialization/__init__.py +46 -0
  110. onnx_diagnostic/torch_export_patches/serialization/diffusers_impl.py +34 -0
  111. onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py +313 -0
  112. onnx_diagnostic/torch_models/__init__.py +0 -0
  113. onnx_diagnostic/torch_models/code_sample.py +343 -0
  114. onnx_diagnostic/torch_models/hghub/__init__.py +1 -0
  115. onnx_diagnostic/torch_models/hghub/hub_api.py +422 -0
  116. onnx_diagnostic/torch_models/hghub/hub_data.py +234 -0
  117. onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +4905 -0
  118. onnx_diagnostic/torch_models/hghub/model_inputs.py +388 -0
  119. onnx_diagnostic/torch_models/hghub/model_specific.py +76 -0
  120. onnx_diagnostic/torch_models/llms.py +2 -0
  121. onnx_diagnostic/torch_models/untrained/__init__.py +0 -0
  122. onnx_diagnostic/torch_models/untrained/llm_phi2.py +113 -0
  123. onnx_diagnostic/torch_models/untrained/llm_tiny_llm.py +76 -0
  124. onnx_diagnostic/torch_models/validate.py +2124 -0
  125. onnx_diagnostic/torch_onnx/__init__.py +0 -0
  126. onnx_diagnostic/torch_onnx/runtime_info.py +289 -0
  127. onnx_diagnostic/torch_onnx/sbs.py +440 -0
  128. onnx_diagnostic-0.8.0.dist-info/METADATA +213 -0
  129. onnx_diagnostic-0.8.0.dist-info/RECORD +132 -0
  130. onnx_diagnostic-0.8.0.dist-info/WHEEL +5 -0
  131. onnx_diagnostic-0.8.0.dist-info/licenses/LICENSE.txt +19 -0
  132. onnx_diagnostic-0.8.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,45 @@
1
+ # Copyright (c) ONNX Project Contributors
2
+
3
+ # SPDX-License-Identifier: Apache-2.0
4
+ from __future__ import annotations
5
+
6
+ import numpy as np
7
+
8
+ from onnx.reference.op_run import OpRun
9
+
10
+
11
def gather_numpy_2(self: np.ndarray, index: np.ndarray) -> np.ndarray:
    """Row-by-row gather fallback: picks ``row[index_row[0]]`` for every
    row of *self* and reshapes the result to ``index.shape``."""
    picked = [row[idx_row[0]] for row, idx_row in zip(self, index)]
    return np.array(picked, dtype=self.dtype).reshape(index.shape)
16
+
17
+
18
def gather_numpy(self: np.ndarray, dim: int, index: np.ndarray) -> np.ndarray:
    """Gather values along axis *dim* (torch.gather / ONNX GatherElements
    semantics): for ``dim == 0``, ``out[i][j] = self[index[i][j]][j]``.

    :param self: source array
    :param dim: axis along which to gather
    :param index: integer array with the same rank as *self*
    :return: array with the same shape as *index*
    :raises ValueError: if the non-*dim* dimensions of *index* and *self* differ
    """
    idx_xsection_shape = index.shape[:dim] + index.shape[dim + 1 :]
    self_xsection_shape = self.shape[:dim] + self.shape[dim + 1 :]
    if idx_xsection_shape != self_xsection_shape:
        raise ValueError(
            f"Except for dimension {dim!r}, all dimensions of "
            "index and self should be the same size."
        )
    # Move the gathered axis to the front so np.choose selects one
    # sub-array per index value (mode="wrap" wraps out-of-range indices).
    data_swaped = np.swapaxes(self, 0, dim)
    index_swaped = np.swapaxes(index, 0, dim)

    try:
        gathered = np.choose(index_swaped, data_swaped, mode="wrap")
    except ValueError:
        # np.choose refuses more than 32 choices. np.take_along_axis
        # implements the same gather without that limit — and for any rank,
        # not only the 2D case the previous fallback covered.
        return np.take_along_axis(self, index, axis=dim)

    return np.swapaxes(gathered, 0, dim)
37
+
38
+
39
class GatherElements(OpRun):
    """Reference implementation of the ONNX GatherElements operator.

    Delegates to :func:`gather_numpy` and retries with indices cast to
    ``int`` when the first attempt fails with a ``TypeError``.
    """

    def _run(self, data, indices, axis=None):
        try:
            return (gather_numpy(data, axis, indices),)
        except TypeError:
            # distribution x86 requires int32.
            return (gather_numpy(data, axis, indices.astype(int)),)
@@ -0,0 +1,12 @@
1
+ import numpy as np
2
+ from onnx.reference.op_run import OpRun
3
+ from onnx.reference.ops.op_scatternd import _scatter_nd_impl
4
+
5
+
6
class GatherGrad(OpRun):
    """com.microsoft GatherGrad: scatters the incoming gradients back into
    a zero tensor with the original input shape (a ScatterND)."""

    op_domain = "com.microsoft"

    def _run(self, shape, indices, updates, reduction=None):
        # NOTE(review): a Gather gradient normally needs reduction="add" so
        # duplicate indices accumulate; here `reduction` is forwarded as-is
        # and defaults to None (overwrite) — verify with the evaluator.
        data = np.zeros(shape, dtype=updates.dtype)
        y = _scatter_nd_impl(data, indices, updates, reduction=reduction)
        return (y,)
@@ -0,0 +1,11 @@
1
+ from onnx.reference.op_run import OpRun
2
+
3
+
4
class MemcpyFromHost(OpRun):
    """Device-copy operator; an identity in this pure-numpy runtime."""

    def _run(self, x):
        return (x,)
7
+
8
+
9
class MemcpyToHost(OpRun):
    """Device-copy operator; an identity in this pure-numpy runtime."""

    def _run(self, x):
        return (x,)
@@ -0,0 +1,23 @@
1
+ import numpy as np
2
+ from onnx.reference.op_run import OpRun
3
+
4
+
5
def sigmoid(x):  # type: ignore
    """Numerically stable scalar sigmoid: ``1 / (1 + exp(-x))``."""
    if x > 0:
        # exp(-x) cannot overflow for positive x.
        return 1 / (1 + np.exp(-x))
    # Equivalent algebraic form that is safe for non-positive x.
    ex = np.exp(x)
    return ex / (1 + ex)
9
+
10
+
11
class MulSigmoid(OpRun):
    """Computes ``X * sigmoid(X)`` (SiLU), element-wise."""

    op_domain = "onnx_extended.ortops.optim.cuda"

    def __init__(self, onnx_node, run_params):  # type: ignore
        OpRun.__init__(self, onnx_node, run_params)
        # Element-wise application of the stable scalar sigmoid.
        self.vf = np.vectorize(sigmoid)

    def _run(self, X):
        if X.ndim == 0:
            # 0-d tensor: call the scalar sigmoid directly.
            gate = sigmoid(X)
        elif X.size == 0:
            # Empty tensors pass through unchanged (np.vectorize cannot
            # infer an output type from zero elements).
            return (X,)
        else:
            gate = self.vf(X)
        return ((X * gate).astype(X.dtype),)
@@ -0,0 +1,8 @@
1
+ from onnx.reference.op_run import OpRun
2
+
3
+
4
class NegXplus1(OpRun):
    """Computes ``1 - X``, preserving X's dtype."""

    op_domain = "onnx_extended.ortops.optim.cuda"

    def _run(self, X):
        return ((1 - X).astype(X.dtype),)
@@ -0,0 +1,40 @@
1
+ from onnx.reference.op_run import OpRun
2
+ from onnx.reference.ops.op_average_pool import AveragePool_19 as AveragePool
3
+ from onnx.reference.ops.op_dequantize_linear import DequantizeLinear_19 as DequantizeLinear
4
+ from onnx.reference.ops.op_quantize_linear import QuantizeLinear_19 as QuantizeLinear
5
+
6
+
7
class QLinearAveragePool(OpRun):
    """com.microsoft QLinearAveragePool: dequantize the input, run a float
    AveragePool, then requantize with the output scale/zero-point.

    Only the channel-first layout is supported here (``channels_last``
    must be absent or 0).
    """

    op_domain = "com.microsoft"

    def _run(
        self,
        x,
        x_scale,
        x_zero_point,
        y_scale,
        y_zero_point,
        auto_pad=None,
        ceil_mode=None,
        channels_last=None,
        count_include_pad=None,
        kernel_shape=None,
        pads=None,
        strides=None,
    ):
        assert channels_last in (
            None,
            0,
        ), f"QLinearAveragePool not implemented if channels_last={channels_last}"
        # Compute in float precision, then quantize the pooled result back.
        dqx = DequantizeLinear.eval(x, x_scale, x_zero_point)
        y = AveragePool.eval(
            dqx,
            auto_pad=auto_pad,
            ceil_mode=ceil_mode,
            count_include_pad=count_include_pad,
            kernel_shape=kernel_shape,
            pads=pads,
            strides=strides,
        )
        qy = QuantizeLinear.eval(y, y_scale, y_zero_point)
        return (qy,)
@@ -0,0 +1,102 @@
1
+ from typing import Tuple
2
+ from onnx.defs import OpSchema
3
+ from onnx.helper import make_attribute
4
+ from onnx.reference.op_run import OpRun
5
+ from onnx.reference.ops.op_conv import Conv
6
+ from onnx.reference.ops.op_dequantize_linear import DequantizeLinear_19 as DequantizeLinear
7
+ from onnx.reference.ops.op_quantize_linear import QuantizeLinear_19 as QuantizeLinear
8
+
9
+
10
+ def _switch_dims_nchw_nhwc(dims: Tuple[int, ...], from_nchw_to_nhwc: bool):
11
+ if len(dims) == 4:
12
+ if from_nchw_to_nhwc:
13
+ return (dims[0], *dims[2:], dims[1])
14
+ return (dims[0], dims[-1], *dims[1:-1])
15
+ if len(dims) == 3:
16
+ if from_nchw_to_nhwc:
17
+ return (*dims[1:], dims[0])
18
+ return (dims[-1], *dims[:-1])
19
+ raise NotImplementedError(f"Unable to process shape={dims}")
20
+
21
+
22
class QLinearConv(OpRun):
    """com.microsoft QLinearConv: dequantize x, w (and B), run a float
    Conv, then requantize the result with ``y_scale`` / ``y_zero_point``.

    ``channels_last=1`` means x and y are stored channel-last (NHWC, or
    NWC for 1D); the data is transposed to channel-first for the float
    convolution and transposed back afterwards.
    """

    op_domain = "com.microsoft"

    op_schema = OpSchema(
        "QLinearConv",
        "com.microsoft",
        1,
        inputs=[
            OpSchema.FormalParameter("x", "T"),
            OpSchema.FormalParameter("x_scale", "T"),
            OpSchema.FormalParameter("x_zero_point", "T1"),
            OpSchema.FormalParameter("w", "T"),
            OpSchema.FormalParameter("w_scale", "T"),
            OpSchema.FormalParameter("w_zero_point", "T2"),
            OpSchema.FormalParameter("y_scale", "T"),
            OpSchema.FormalParameter("y_zero_point", "T3"),
            OpSchema.FormalParameter(
                "B", "T3", param_option=OpSchema.FormalParameterOption.Optional
            ),
        ],
        outputs=[OpSchema.FormalParameter("y", "T3")],
        type_constraints=[
            ("T", ["tensor(float)"], ""),
            ("T1", ["tensor(int8)", "tensor(uint8)"], ""),
            ("T2", ["tensor(int8)", "tensor(uint8)"], ""),
            ("T3", ["tensor(int8)", "tensor(uint8)"], ""),
        ],
        attributes=[
            OpSchema.Attribute("auto_pad", make_attribute("auto_pad", "NOTSET"), ""),
            OpSchema.Attribute("kernel_shape", OpSchema.AttrType.INTS, "", required=False),
            OpSchema.Attribute("dilations", OpSchema.AttrType.INTS, "", required=False),
            OpSchema.Attribute("strides", OpSchema.AttrType.INTS, "", required=False),
            OpSchema.Attribute("pads", OpSchema.AttrType.INTS, "", required=False),
            OpSchema.Attribute("group", make_attribute("group", 1), ""),
            OpSchema.Attribute("channels_last", make_attribute("channels_last", 0), ""),
        ],
    )

    def _run(
        self,
        x,
        x_scale,
        x_zero_point,
        w,
        w_scale,
        w_zero_point,
        y_scale,
        y_zero_point,
        B=None,
        auto_pad=None,
        channels_last=None,
        dilations=None,
        group=None,
        kernel_shape=None,
        pads=None,
        strides=None,
    ):
        dqx = DequantizeLinear.eval(x, x_scale, x_zero_point)
        dqw = DequantizeLinear.eval(w, w_scale, w_zero_point)
        if channels_last:
            # x is stored channel-last but Conv expects channel-first.
            # A plain reshape keeps the memory order and scrambles the
            # values, so the axes must be transposed instead.
            # NOTE(review): 1D inputs are assumed to be (N, W, C) — confirm.
            perm = (0, 3, 1, 2) if dqx.ndim == 4 else (0, 2, 1)
            dqx = dqx.transpose(perm)
        # The bias is quantized with scale x_scale * w_scale and zero point 0.
        dqb = (
            DequantizeLinear.eval(B, x_scale * w_scale, 0).astype(dqx.dtype)
            if B is not None
            else None
        )
        y = Conv.eval(
            dqx,
            dqw,
            dqb,
            auto_pad=auto_pad,
            dilations=dilations,
            group=group,
            kernel_shape=kernel_shape,
            pads=pads,
            strides=strides,
        )
        if channels_last:
            # Transpose back to the channel-last layout the caller expects.
            perm = (0, 2, 3, 1) if y.ndim == 4 else (0, 2, 1)
            y = y.transpose(perm)
        qy = QuantizeLinear.eval(y, y_scale, y_zero_point)
        return (qy,)
@@ -0,0 +1,23 @@
1
+ import numpy as np
2
+ from onnx.reference.op_run import OpRun
3
+
4
+
5
def sigmoid(x):  # type: ignore
    """Numerically stable scalar sigmoid."""
    # exp(-|x|) never overflows; select the algebraic form by sign.
    z = np.exp(-abs(x))
    return 1 / (1 + z) if x > 0 else z / (1 + z)
9
+
10
+
11
class QuickGelu(OpRun):
    """com.microsoft QuickGelu: ``X * sigmoid(alpha * X)``."""

    op_domain = "com.microsoft"

    def __init__(self, onnx_node, run_params):  # type: ignore
        OpRun.__init__(self, onnx_node, run_params)
        # Element-wise application of the stable scalar sigmoid.
        self.vf = np.vectorize(sigmoid)

    def _run(self, X, alpha=1.0):
        if X.size == 0 and X.ndim > 0:
            # Empty tensors pass through unchanged.
            return (X,)
        # 0-d tensors use the scalar sigmoid; everything else the
        # vectorized form.
        gate = sigmoid(X * alpha) if X.ndim == 0 else self.vf(X * alpha)
        return ((X * gate).astype(X.dtype),)
@@ -0,0 +1,13 @@
1
+ from onnx.reference.op_run import OpRun
2
+
3
+
4
class ReplaceZero(OpRun):
    """Replaces either the zero entries (``equal`` truthy) or the non-zero
    entries (otherwise) of X with ``by``."""

    op_domain = "onnx_extended.ortops.optim.cuda"

    def _run(self, X, by=None, equal=None):
        # flatten() always returns a copy, so X itself is never modified.
        flat = X.flatten()
        mask = (flat == 0) if equal else (flat != 0)
        flat[mask] = by
        return (flat.reshape(X.shape),)
@@ -0,0 +1,19 @@
1
+ from onnx.reference.op_run import OpRun
2
+
3
+
4
class Rotary(OpRun):
    """Rotates the two halves of the last axis: ``(a, b) -> (b, -a)`` when
    ``side == "left"``, ``(a, b) -> (-b, a)`` otherwise."""

    op_domain = "onnx_extended.ortops.optim.cuda"

    def _run(self, X, splits=None, side=None):
        assert splits is None or (
            splits.shape == (2,) and splits[0] == splits[1]
        ), f"Unexpected split value {splits}"
        half = X.shape[-1] // 2
        first, second = X[..., :half], X[..., half:]
        out = X.copy()
        if side == "left":
            out[..., :half] = second
            out[..., half:] = -first
        else:
            out[..., :half] = -second
            out[..., half:] = first
        return (out,)
@@ -0,0 +1,65 @@
1
+ import numpy as np
2
+ from onnx.reference.ops.op_scan import Scan as _Scan
3
+
4
+
5
class Scan(_Scan):
    """Scan variant that asks the runtime for the evaluation context so the
    body can read results produced earlier in the enclosing graph."""

    def need_context(self) -> bool:
        """Tells the runtime if this node needs the context
        (all the results produced so far) as it may silently access
        one of them (operator Loop).
        The default answer is `False`.
        """
        return True

    def _run(
        self,
        *args,
        context=None,
        body=None,
        num_scan_inputs=None,
        scan_input_axes=None,
        scan_input_directions=None,
        scan_output_axes=None,
        scan_output_directions=None,
        attributes=None,
    ):
        # Decode inputs into loop-state variables and scanned sequences.
        (
            num_loop_state_vars,
            _num_scan_outputs,
            _output_directions,
            _max_dir_out,
            _output_axes,
            _max_axe_out,
            state_names_in,
            state_names_out,
            scan_names_in,
            scan_names_out,
            scan_values,
            states,
        ) = self._common_run_shape(*args)

        # Number of iterations = length of the first scanned input along
        # its scan axis.
        max_iter = args[num_loop_state_vars].shape[self.input_axes_[0]]
        # One accumulator list per scan output.
        results = [[] for _ in scan_names_out]  # type: ignore

        for it in range(max_iter):
            # Body inputs: outer context + current states + slice `it` of
            # every scanned input.
            inputs = context.copy()
            inputs.update(dict(zip(state_names_in, states)))
            inputs.update({name: value[it] for name, value in zip(scan_names_in, scan_values)})

            try:
                outputs_list = self._run_body(inputs)  # type: ignore
            except TypeError as e:
                raise TypeError(
                    f"Unable to call 'run' for type '{type(self.body)}'."  # type: ignore
                ) from e

            # Thread the updated states into the next iteration and stash
            # the scan outputs (with a new leading axis for stacking).
            outputs = dict(zip(self.output_names, outputs_list))
            states = [outputs[name] for name in state_names_out]
            for i, name in enumerate(scan_names_out):
                results[i].append(np.expand_dims(outputs[name], axis=0))

        # Final outputs: last states followed by the stacked scan outputs.
        for res in results:
            conc = np.vstack(res)
            states.append(conc)
        return self._check_and_fix_outputs(tuple(states))
@@ -0,0 +1,107 @@
1
+ import numpy as np
2
+
3
+ from onnx.reference.op_run import OpRun
4
+
5
+
6
def scatter_elements(data, indices, updates, axis=0, reduction=None):  # type: ignore
    """Scatter *updates* into a copy of *data* along *axis* (ONNX
    ScatterElements / torch.scatter semantics): for ``axis == 0`` and 2D
    inputs, ``out[indices[i][j]][j] = f(out[indices[i][j]][j], updates[i][j])``.

    :param data: source array (not modified)
    :param indices: integer array, same rank as *data*
    :param updates: array with the same shape as *indices*
    :param axis: axis along which the index applies (may be negative)
    :param reduction: None (overwrite), "add", "mul", "min" or "max"
    :return: new array with the same shape and dtype as *data*
    """
    # Dispatch table replaces the previous if/elif chain of nested defs.
    reducers = {
        "add": lambda x, y: x + y,
        "mul": lambda x, y: x * y,
        "min": np.minimum,
        "max": np.maximum,
    }
    f = reducers.get(reduction, lambda x, y: y)

    if axis < 0:
        axis = data.ndim + axis

    scattered = np.copy(data)
    # Generic n-dimensional walk: replaces the hand-written 1D/2D/3D/4D
    # special cases (the 4D version only supported axes 0 and 3) and works
    # for every rank/axis combination.
    for idx in np.ndindex(*indices.shape):
        target = list(idx)
        target[axis] = indices[idx]
        target = tuple(target)
        scattered[target] = f(scattered[target], updates[idx])
    return scattered
102
+
103
+
104
class ScatterElements(OpRun):
    """Reference implementation of the ONNX ScatterElements operator;
    delegates to :func:`scatter_elements`."""

    def _run(self, data, indices, updates, axis=None, reduction=None):  # type: ignore
        res = scatter_elements(data, indices, updates, axis=axis, reduction=reduction)
        return (res,)
@@ -0,0 +1,22 @@
1
+ import numpy as np
2
+ from onnx.reference.op_run import OpRun
3
+ from onnx.reference.ops.op_scatternd import _scatter_nd_impl
4
+
5
+
6
class ScatterNDOfShape(OpRun):
    """ScatterND into a freshly-created zero tensor of the given shape.

    ``strategy`` is accepted for schema compatibility but unused here.
    """

    op_domain = "onnx_extended.ortops.optim.cuda"

    def _run(self, shape, indices, updates, reduction=None, strategy=None):
        # Start from zeros and scatter the updates in.
        data = np.zeros(shape, dtype=updates.dtype)
        y = _scatter_nd_impl(data, indices, updates, reduction=reduction)
        return (y,)
13
+
14
+
15
class MaskedScatterNDOfShape(OpRun):
    """ScatterND into a zero tensor, zeroing the updates whose index equals
    ``maskedValue`` before scattering."""

    op_domain = "onnx_extended.ortops.optim.cuda"

    def _run(self, shape, indices, updates, reduction=None, maskedValue=None):
        data = np.zeros(shape, dtype=updates.dtype)
        # NOTE(review): relies on `indices == maskedValue` broadcasting
        # against `updates`; appears to assume the last dimension of
        # indices is 1 — verify against callers.
        new_updates = np.where(indices == maskedValue, 0, updates)
        y = _scatter_nd_impl(data, indices, new_updates, reduction=reduction)
        return (y,)
@@ -0,0 +1,8 @@
1
+ from onnx.reference.op_run import OpRun
2
+
3
+
4
class SimplifiedLayerNormalization(OpRun):
    """SimplifiedLayerNormalization (RMS normalization):
    ``Y = X / sqrt(ReduceMean(X^2, axis) + epsilon) * scale``.

    Returns ``(Y, inv_std_var)``.
    """

    def _run(self, x, scale, bias=None, axis=None, epsilon=None, stash_type=None):
        # stash_type (intermediate precision) is ignored: everything is
        # computed in x's dtype.
        xm = (x**2).mean(axis=axis, keepdims=True) + epsilon
        inv_std = xm ** (-0.5)
        # Fix: the previous implementation dropped `scale`; the contrib-op
        # spec multiplies the normalized input by it.
        y = x * inv_std * scale
        # NOTE(review): `bias` is accepted but not applied; the ORT
        # SimplifiedLayerNormalization schema has no bias input — verify.
        return (y, inv_std)
@@ -0,0 +1,13 @@
1
+ from onnx.reference.op_run import OpRun
2
+ from onnx.reference.ops.op_layer_normalization import _layer_normalization
3
+
4
+
5
class SkipLayerNormalization(OpRun):
    """com.microsoft SkipLayerNormalization:
    ``LayerNorm(x + skip [+ bias])`` over the last axis.

    Returns the layer-normalization outputs followed by the
    pre-normalization sum.
    """

    op_domain = "com.microsoft"

    def _run(self, x, skip, gamma=None, beta=None, bias=None, epsilon=None):
        # Sum first, then normalize over the last axis.
        add = x + skip
        if bias is not None:
            add = add + bias
        res = _layer_normalization(add, gamma, beta, axis=-1, epsilon=epsilon)
        return (*res, add)
@@ -0,0 +1,20 @@
1
+ from onnx.reference.ops.op_slice import SliceCommon
2
+
3
+
4
class Slice_10(SliceCommon):
    """Slice (opset 10+): starts/ends/axes/steps arrive as inputs;
    behavior is entirely inherited from SliceCommon."""

    def __init__(self, onnx_node, run_params):
        SliceCommon.__init__(self, onnx_node, run_params)
7
+
8
+
9
class Slice_1(SliceCommon):
    """Slice (opset 1): starts/ends/axes come in as node attributes.

    Empty attribute lists are normalized to ``None`` so SliceCommon treats
    them as absent.
    """

    def __init__(self, onnx_node, run_params):
        # Fix: removed a leftover debug `print(onnx_node)` that wrote the
        # whole node proto to stdout on every construction.
        SliceCommon.__init__(self, onnx_node, run_params)
        for f in ["starts", "ends", "steps", "axes"]:
            if not hasattr(self, f):
                continue
            if getattr(self, f) is not None and len(getattr(self, f)) == 0:
                setattr(self, f, None)

    def _run(self, data, axes=None, ends=None, starts=None):
        # Attributes arrive as keyword arguments; forward them positionally.
        return SliceCommon._run(self, data, starts, ends, axes)
@@ -0,0 +1,16 @@
1
+ import numpy as np
2
+ from onnx.reference.op_run import OpRun
3
+
4
+
5
class Transpose2DCastFP16(OpRun):
    """Transposes a 2D matrix and casts it to float16 in one step."""

    op_domain = "onnx_extended.ortops.optim.cuda"

    def _run(self, X):
        return (X.transpose().astype(np.float16),)
10
+
11
+
12
class Transpose2DCastFP32(OpRun):
    """Transposes a 2D matrix and casts it to float32 in one step."""

    op_domain = "onnx_extended.ortops.optim.cuda"

    def _run(self, X):
        return (X.transpose().astype(np.float32),)
@@ -0,0 +1,17 @@
1
+ import numpy as np
2
+ from onnx.reference.op_run import OpRun
3
+
4
+
5
class TriMatrix(OpRun):
    """Fills a 2D matrix with three constants taken from ``csts``:
    below the diagonal, on it, and above it."""

    op_domain = "onnx_extended.ortops.optim.cuda"

    def _run(self, shape, csts):
        lower, diag, upper = csts
        # Broadcast row/column index grids to compare positions against
        # the diagonal.
        rows = np.arange(shape[0], dtype=np.int32).reshape((-1, 1))
        cols = np.arange(shape[1], dtype=np.int32).reshape((1, -1))
        mat = np.empty(tuple(shape), dtype=csts.dtype)
        mat[rows > cols] = lower
        mat[rows < cols] = upper
        mat[rows == cols] = diag
        return (mat,)
+ return (mat,)