onnx-diagnostic 0.8.4__py3-none-any.whl → 0.8.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnx_diagnostic/__init__.py +1 -1
- onnx_diagnostic/_command_lines_parser.py +21 -9
- onnx_diagnostic/export/api.py +15 -4
- onnx_diagnostic/export/onnx_plug.py +60 -6
- onnx_diagnostic/helpers/helper.py +26 -27
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_masking_utils.py +10 -1
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2_5.py +28 -28
- {onnx_diagnostic-0.8.4.dist-info → onnx_diagnostic-0.8.5.dist-info}/METADATA +1 -1
- {onnx_diagnostic-0.8.4.dist-info → onnx_diagnostic-0.8.5.dist-info}/RECORD +12 -12
- {onnx_diagnostic-0.8.4.dist-info → onnx_diagnostic-0.8.5.dist-info}/WHEEL +0 -0
- {onnx_diagnostic-0.8.4.dist-info → onnx_diagnostic-0.8.5.dist-info}/licenses/LICENSE.txt +0 -0
- {onnx_diagnostic-0.8.4.dist-info → onnx_diagnostic-0.8.5.dist-info}/top_level.txt +0 -0
onnx_diagnostic/__init__.py
CHANGED
|
@@ -198,15 +198,19 @@ def get_parser_print() -> ArgumentParser:
|
|
|
198
198
|
)
|
|
199
199
|
parser.add_argument(
|
|
200
200
|
"fmt",
|
|
201
|
-
choices=["pretty", "raw", "
|
|
201
|
+
choices=["dot", "pretty", "printer", "raw", "shape", "text"],
|
|
202
202
|
default="pretty",
|
|
203
203
|
help=textwrap.dedent(
|
|
204
204
|
"""
|
|
205
205
|
Prints out a model on the standard output.
|
|
206
|
-
|
|
207
|
-
|
|
206
|
+
|
|
207
|
+
dot - converts the graph into dot
|
|
208
208
|
pretty - an improved rendering
|
|
209
|
+
printer - onnx.printer.to_text(...)
|
|
210
|
+
raw - just prints the model with print(...)
|
|
211
|
+
shape - prints every node node with input and output shapes
|
|
209
212
|
text - uses GraphRendering
|
|
213
|
+
|
|
210
214
|
""".strip(
|
|
211
215
|
"\n"
|
|
212
216
|
)
|
|
@@ -232,6 +236,14 @@ def _cmd_print(argv: List[Any]):
|
|
|
232
236
|
from .helpers.graph_helper import GraphRendering
|
|
233
237
|
|
|
234
238
|
print(GraphRendering(onx).text_rendering())
|
|
239
|
+
elif args.fmt == "shape":
|
|
240
|
+
from experimental_experiment.xbuilder import GraphBuilder
|
|
241
|
+
|
|
242
|
+
print(GraphBuilder(onx).pretty_text())
|
|
243
|
+
elif args.fmt == "dot":
|
|
244
|
+
from .helpers.dot_helper import to_dot
|
|
245
|
+
|
|
246
|
+
print(to_dot(onx))
|
|
235
247
|
else:
|
|
236
248
|
raise ValueError(f"Unexpected value fmt={args.fmt!r}")
|
|
237
249
|
|
|
@@ -517,12 +529,12 @@ def get_parser_validate(name: str = "validate") -> ArgumentParser:
|
|
|
517
529
|
nargs="*",
|
|
518
530
|
help=textwrap.dedent(
|
|
519
531
|
"""
|
|
520
|
-
|
|
521
|
-
|
|
522
|
-
|
|
523
|
-
|
|
524
|
-
|
|
525
|
-
|
|
532
|
+
Applies patches before exporting, it can be a boolean
|
|
533
|
+
to enable to disable the patches or be more finetuned
|
|
534
|
+
(default is True). It is possible to disable patch for torch
|
|
535
|
+
by adding:
|
|
536
|
+
--patch "patch_sympy=False" --patch "patch_torch=False"
|
|
537
|
+
""".strip(
|
|
526
538
|
"\n"
|
|
527
539
|
)
|
|
528
540
|
),
|
onnx_diagnostic/export/api.py
CHANGED
|
@@ -64,6 +64,7 @@ def to_onnx(
|
|
|
64
64
|
exporter_kwargs: Optional[Dict[str, Any]] = None,
|
|
65
65
|
save_ep: Optional[str] = None,
|
|
66
66
|
optimize: bool = True,
|
|
67
|
+
optimizer_for_ort: bool = True,
|
|
67
68
|
use_control_flow_dispatcher: bool = False,
|
|
68
69
|
onnx_plugs: Optional[List[EagerDirectReplacementWithOnnx]] = None,
|
|
69
70
|
inline: bool = True,
|
|
@@ -88,6 +89,7 @@ def to_onnx(
|
|
|
88
89
|
:param exporter_kwargs: additional parameters sent to the exporter
|
|
89
90
|
:param save_ep: saves the exported program
|
|
90
91
|
:param optimize: optimizes the model
|
|
92
|
+
:param optimizer_for_ort: optimizes the model for onnxruntime
|
|
91
93
|
:param use_control_flow_dispatcher: use the dispatcher created to supported
|
|
92
94
|
custom loops (see :func:`onnx_diagnostic.export.control_flow_onnx.loop_for_onnx`)
|
|
93
95
|
:param onnx_plugs: the code was modified to replace some parts with onnx translation
|
|
@@ -126,8 +128,10 @@ def to_onnx(
|
|
|
126
128
|
options = None
|
|
127
129
|
if exporter_kwargs is not None:
|
|
128
130
|
options = exporter_kwargs.pop("options", None)
|
|
129
|
-
if options is None:
|
|
130
|
-
options = OptimizationOptions(
|
|
131
|
+
if options is None and optimize:
|
|
132
|
+
options = OptimizationOptions(
|
|
133
|
+
patterns="default+onnxruntime" if optimizer_for_ort else "default"
|
|
134
|
+
)
|
|
131
135
|
main_dispatcher = (
|
|
132
136
|
get_main_dispatcher(use_control_flow_dispatcher, onnx_plugs)
|
|
133
137
|
if onnx_plugs or use_control_flow_dispatcher
|
|
@@ -161,6 +165,9 @@ def to_onnx(
|
|
|
161
165
|
assert (
|
|
162
166
|
not output_dynamic_shapes
|
|
163
167
|
), f"output_dynamic_shapes not supported for exporter={exporter!r}"
|
|
168
|
+
assert (
|
|
169
|
+
optimize
|
|
170
|
+
), f"torch.onnx.export always optimizes the model but optimize={optimize}"
|
|
164
171
|
custom_translation_table = {}
|
|
165
172
|
if onnx_plugs:
|
|
166
173
|
for plug in onnx_plugs:
|
|
@@ -180,7 +187,7 @@ def to_onnx(
|
|
|
180
187
|
custom_translation_table=custom_translation_table,
|
|
181
188
|
**(exporter_kwargs or {}),
|
|
182
189
|
)
|
|
183
|
-
if not inline and optimize:
|
|
190
|
+
if not inline and optimize and optimizer_for_ort:
|
|
184
191
|
ort_fusions.optimize_for_ort(epo.model)
|
|
185
192
|
|
|
186
193
|
if onnx_plugs:
|
|
@@ -207,7 +214,7 @@ def to_onnx(
|
|
|
207
214
|
common_passes.InlinePass()(epo.model)
|
|
208
215
|
common_passes.RemoveUnusedOpsetsPass()(epo.model)
|
|
209
216
|
|
|
210
|
-
if inline and optimize:
|
|
217
|
+
if inline and optimize and optimizer_for_ort:
|
|
211
218
|
ort_fusions.optimize_for_ort(epo.model)
|
|
212
219
|
if filename:
|
|
213
220
|
epo.save(filename, external_data=True)
|
|
@@ -232,6 +239,10 @@ def to_onnx(
|
|
|
232
239
|
f"Only a specified set of inputs is supported for exporter={exporter!r}, "
|
|
233
240
|
f"but it is {list(kwargs)}" # type: ignore[arg-type]
|
|
234
241
|
)
|
|
242
|
+
assert optimizer_for_ort and optimize, (
|
|
243
|
+
f"ModelBuilder only produces model optimized for onnxruntime but "
|
|
244
|
+
f"optimizer_for_ort={optimizer_for_ort} and optimize={optimize}"
|
|
245
|
+
)
|
|
235
246
|
flat_inputs = flatten_object(kwargs, drop_keys=True)
|
|
236
247
|
first = flat_inputs[0]
|
|
237
248
|
first_float = [
|
|
@@ -128,7 +128,61 @@ class EagerDirectReplacementWithOnnx:
|
|
|
128
128
|
|
|
129
129
|
print(pretty_onnx(onx))
|
|
130
130
|
|
|
131
|
-
|
|
131
|
+
We do the same with :func:`torch.onnx.export`:
|
|
132
|
+
|
|
133
|
+
.. runpython::
|
|
134
|
+
:showcode:
|
|
135
|
+
|
|
136
|
+
import onnx.helper as oh
|
|
137
|
+
import torch
|
|
138
|
+
from onnx_diagnostic.helpers.onnx_helper import pretty_onnx
|
|
139
|
+
from onnx_diagnostic.export.onnx_plug import EagerDirectReplacementWithOnnx
|
|
140
|
+
from onnx_diagnostic.export.api import to_onnx
|
|
141
|
+
from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str
|
|
142
|
+
|
|
143
|
+
|
|
144
|
+
def demo_customsub(x, y):
|
|
145
|
+
return x - y
|
|
146
|
+
|
|
147
|
+
|
|
148
|
+
def demo_customsub_shape(x, y):
|
|
149
|
+
return torch.empty(torch.broadcast_shapes(x.shape, y.shape), dtype=x.dtype)
|
|
150
|
+
|
|
151
|
+
|
|
152
|
+
def make_function_proto():
|
|
153
|
+
return oh.make_function(
|
|
154
|
+
"onnx_plug",
|
|
155
|
+
"demo_customsub",
|
|
156
|
+
["x", "y"],
|
|
157
|
+
["z"],
|
|
158
|
+
[oh.make_node("Sub", ["x", "y"], ["z"])],
|
|
159
|
+
opset_imports=[oh.make_opsetid("", 22)],
|
|
160
|
+
)
|
|
161
|
+
|
|
162
|
+
|
|
163
|
+
class Model(torch.nn.Module):
|
|
164
|
+
def forward(self, x):
|
|
165
|
+
y = x.sum(axis=1, keepdim=True)
|
|
166
|
+
d = torch.ops.onnx_plug.demo_customsub(x, y)
|
|
167
|
+
return torch.abs(d)
|
|
168
|
+
|
|
169
|
+
|
|
170
|
+
replacements = [
|
|
171
|
+
EagerDirectReplacementWithOnnx(
|
|
172
|
+
demo_customsub, demo_customsub_shape, make_function_proto(), 2, 1
|
|
173
|
+
)
|
|
174
|
+
]
|
|
175
|
+
|
|
176
|
+
x = torch.randn((3, 4), dtype=torch.float32)
|
|
177
|
+
model = Model()
|
|
178
|
+
ds = ({0: "d1", 1: "d2"},)
|
|
179
|
+
|
|
180
|
+
# The exported program shows a custom op.
|
|
181
|
+
ep = torch.export.export(model, (x,), dynamic_shapes=use_dyn_not_str(ds))
|
|
182
|
+
print("ep")
|
|
183
|
+
|
|
184
|
+
# As the exporter knows how the replace this custom op.
|
|
185
|
+
# Let's export.
|
|
132
186
|
|
|
133
187
|
onx = to_onnx(
|
|
134
188
|
model,
|
|
@@ -152,8 +206,8 @@ class EagerDirectReplacementWithOnnx:
|
|
|
152
206
|
dtype = first_tensor.dtype
|
|
153
207
|
itype = torch_dtype_to_onnx_dtype(dtype)
|
|
154
208
|
if dtype == torch.float32:
|
|
155
|
-
if opset >=
|
|
156
|
-
return "
|
|
209
|
+
if opset >= 23:
|
|
210
|
+
return "LOOPA23", itype
|
|
157
211
|
return "LOOPMHA", itype
|
|
158
212
|
if dtype == torch.float16:
|
|
159
213
|
if first_tensor.is_cuda:
|
|
@@ -175,9 +229,9 @@ class EagerDirectReplacementWithOnnx:
|
|
|
175
229
|
("PACKED", onnx.TensorProto.FLOAT16): _add_com_microsoft_opset(
|
|
176
230
|
PackedAttention.to_function_proto()
|
|
177
231
|
),
|
|
178
|
-
("
|
|
179
|
-
("
|
|
180
|
-
onnx.TensorProto.FLOAT16,
|
|
232
|
+
("LOOPA23", onnx.TensorProto.FLOAT): LoopAttention23.to_function_proto(),
|
|
233
|
+
("LOOPA23", onnx.TensorProto.FLOAT16): _update_sequence_type(
|
|
234
|
+
onnx.TensorProto.FLOAT16, LoopAttention23.to_function_proto()
|
|
181
235
|
),
|
|
182
236
|
("LOOPMHA", onnx.TensorProto.FLOAT): _add_com_microsoft_opset(
|
|
183
237
|
LoopMHAAttention.to_function_proto()
|
|
@@ -2,6 +2,7 @@ import ast
|
|
|
2
2
|
import enum
|
|
3
3
|
import inspect
|
|
4
4
|
import itertools
|
|
5
|
+
import json
|
|
5
6
|
from dataclasses import is_dataclass, fields
|
|
6
7
|
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
|
|
7
8
|
import numpy as np
|
|
@@ -1373,11 +1374,7 @@ def max_diff(
|
|
|
1373
1374
|
if hist:
|
|
1374
1375
|
if isinstance(hist, bool):
|
|
1375
1376
|
hist = np.array([0, 0.0001, 0.001, 0.01, 0.1, 1, 10, 100], dtype=diff.dtype)
|
|
1376
|
-
|
|
1377
|
-
cou = np.bincount(ind, minlength=ind.shape[0] + 1)
|
|
1378
|
-
res["rep"] = dict(
|
|
1379
|
-
zip([f">{x}" for x in hist], [int(i) for i in (cou.sum() - np.cumsum(cou))])
|
|
1380
|
-
)
|
|
1377
|
+
res["rep"] = {f">{h}": (diff > h).sum().item() for h in hist}
|
|
1381
1378
|
return res # type: ignore
|
|
1382
1379
|
|
|
1383
1380
|
if isinstance(expected, torch.Tensor) and isinstance(got, torch.Tensor):
|
|
@@ -1493,27 +1490,11 @@ def max_diff(
|
|
|
1493
1490
|
dev=dev,
|
|
1494
1491
|
)
|
|
1495
1492
|
if hist:
|
|
1496
|
-
if isinstance(hist,
|
|
1497
|
-
|
|
1498
|
-
|
|
1499
|
-
res["rep"] = {
|
|
1500
|
-
f">{hist[0]}": (diff > hist[0]).sum().item(),
|
|
1501
|
-
f">{hist[1]}": (diff > hist[1]).sum().item(),
|
|
1502
|
-
}
|
|
1503
|
-
else:
|
|
1504
|
-
if isinstance(hist, bool):
|
|
1505
|
-
hist = torch.tensor(
|
|
1506
|
-
[0, 0.0001, 0.001, 0.01, 0.1, 1, 10, 100], dtype=diff.dtype
|
|
1507
|
-
)
|
|
1508
|
-
hist = torch.tensor(hist).to(diff.device)
|
|
1509
|
-
ind = torch.bucketize(diff.reshape((-1,)), hist, right=False)
|
|
1510
|
-
cou = torch.bincount(ind, minlength=ind.shape[0] + 1)
|
|
1511
|
-
res["rep"] = dict(
|
|
1512
|
-
zip(
|
|
1513
|
-
[f">{x}" for x in hist],
|
|
1514
|
-
[int(i) for i in (cou.sum() - torch.cumsum(cou, 0))],
|
|
1515
|
-
)
|
|
1493
|
+
if isinstance(hist, bool):
|
|
1494
|
+
hist = torch.tensor(
|
|
1495
|
+
[0, 0.0001, 0.001, 0.01, 0.1, 1, 10, 100], dtype=diff.dtype
|
|
1516
1496
|
)
|
|
1497
|
+
res["rep"] = {f">{h}": (diff > h).sum().item() for h in hist}
|
|
1517
1498
|
return res # type: ignore
|
|
1518
1499
|
|
|
1519
1500
|
if isinstance(expected, int) and isinstance(got, torch.Tensor):
|
|
@@ -1750,8 +1731,26 @@ def max_diff(
|
|
|
1750
1731
|
)
|
|
1751
1732
|
|
|
1752
1733
|
|
|
1753
|
-
def string_diff(diff: Dict[str, Any]) -> str:
|
|
1754
|
-
"""
|
|
1734
|
+
def string_diff(diff: Dict[str, Any], js: bool = False, ratio: bool = False, **kwargs) -> str:
|
|
1735
|
+
"""
|
|
1736
|
+
Renders discrepancies return by :func:`max_diff` into one string.
|
|
1737
|
+
|
|
1738
|
+
:param diff: differences
|
|
1739
|
+
:param js: json format
|
|
1740
|
+
:param ratio: display mismatch ratio
|
|
1741
|
+
:param kwargs: addition values to add in the json format
|
|
1742
|
+
"""
|
|
1743
|
+
if js:
|
|
1744
|
+
if "rep" in diff:
|
|
1745
|
+
rep = diff["rep"]
|
|
1746
|
+
diff = {**{k: v for k, v in diff.items() if k != "rep"}, **rep}
|
|
1747
|
+
if ratio:
|
|
1748
|
+
for k, v in rep.items():
|
|
1749
|
+
diff[f"%{k}"] = v / diff["n"]
|
|
1750
|
+
diff["mean"] = diff["sum"] / diff["n"]
|
|
1751
|
+
diff.update(kwargs)
|
|
1752
|
+
return json.dumps(diff)
|
|
1753
|
+
|
|
1755
1754
|
# dict(abs=, rel=, sum=, n=n_diff, dnan=)
|
|
1756
1755
|
if "dev" in diff:
|
|
1757
1756
|
ddiff = {k: v for k, v in diff.items() if k != "dev"}
|
|
@@ -1,3 +1,4 @@
|
|
|
1
|
+
import inspect
|
|
1
2
|
from typing import Callable, List, Optional, Tuple
|
|
2
3
|
import torch
|
|
3
4
|
|
|
@@ -19,6 +20,12 @@ if patch_masking_utils:
|
|
|
19
20
|
prepare_padding_mask,
|
|
20
21
|
)
|
|
21
22
|
|
|
23
|
+
_prepare_padding_mask_kwargs = (
|
|
24
|
+
dict(_slice=False)
|
|
25
|
+
if "_slice" in inspect.signature(prepare_padding_mask).parameters
|
|
26
|
+
else {}
|
|
27
|
+
)
|
|
28
|
+
|
|
22
29
|
try:
|
|
23
30
|
# transformers>=5.0
|
|
24
31
|
from transformers.masking_utils import (
|
|
@@ -132,7 +139,9 @@ if patch_masking_utils:
|
|
|
132
139
|
) -> Optional[torch.Tensor]:
|
|
133
140
|
"""manual patch for function ``transformers.masking_utils.sdpa_mask_recent_torch``."""
|
|
134
141
|
q_length = cache_position.shape[0]
|
|
135
|
-
padding_mask = prepare_padding_mask(
|
|
142
|
+
padding_mask = prepare_padding_mask(
|
|
143
|
+
attention_mask, kv_length, kv_offset, **_prepare_padding_mask_kwargs
|
|
144
|
+
)
|
|
136
145
|
if allow_is_causal_skip and _ignore_causal_mask_sdpa(
|
|
137
146
|
padding_mask, q_length, kv_length, kv_offset, local_size
|
|
138
147
|
):
|
|
@@ -24,7 +24,7 @@ if patch_qwen2_5:
|
|
|
24
24
|
|
|
25
25
|
onnx_plugs_op = onnxscript.values.Opset("onnx_plug", 1)
|
|
26
26
|
op = onnxscript.opset22
|
|
27
|
-
|
|
27
|
+
op23 = onnxscript.onnx_opset.opset23
|
|
28
28
|
msft_op = onnxscript.values.Opset("com.microsoft", 1)
|
|
29
29
|
STOPAT = (
|
|
30
30
|
int(os.environ.get("STOPAT", None))
|
|
@@ -101,7 +101,7 @@ if patch_qwen2_5:
|
|
|
101
101
|
return attn_output_4d
|
|
102
102
|
|
|
103
103
|
@onnxscript.script(opset=onnx_plugs_op)
|
|
104
|
-
def
|
|
104
|
+
def LoopAttention23(
|
|
105
105
|
query_states,
|
|
106
106
|
key_states,
|
|
107
107
|
value_states,
|
|
@@ -109,26 +109,26 @@ if patch_qwen2_5:
|
|
|
109
109
|
scaling: float = 0.11180339887498948,
|
|
110
110
|
num_heads: int = 16,
|
|
111
111
|
):
|
|
112
|
-
to_3d_shape =
|
|
113
|
-
query_transposed =
|
|
114
|
-
output_shape =
|
|
115
|
-
query_3d =
|
|
116
|
-
value_3d =
|
|
117
|
-
key_3d =
|
|
118
|
-
cu_seqlens =
|
|
119
|
-
num_patches =
|
|
120
|
-
seq_axis =
|
|
121
|
-
seq_axis_int32 =
|
|
122
|
-
seq_attn =
|
|
112
|
+
to_3d_shape = op23.Constant(value_ints=[0, 0, -1])
|
|
113
|
+
query_transposed = op23.Transpose(query_states, perm=[0, 2, 1, 3])
|
|
114
|
+
output_shape = op23.Shape(query_transposed)
|
|
115
|
+
query_3d = op23.Reshape(query_transposed, to_3d_shape)
|
|
116
|
+
value_3d = op23.Reshape(op23.Transpose(value_states, perm=[0, 2, 1, 3]), to_3d_shape)
|
|
117
|
+
key_3d = op23.Reshape(op23.Transpose(key_states, perm=[0, 2, 1, 3]), to_3d_shape)
|
|
118
|
+
cu_seqlens = op23.Cast(cu_seqlens, to=onnx.TensorProto.INT32)
|
|
119
|
+
num_patches = op23.Size(cu_seqlens) - 1
|
|
120
|
+
seq_axis = op23.Constant(value_ints=[1])
|
|
121
|
+
seq_axis_int32 = op23.Cast(seq_axis, to=onnx.TensorProto.INT32)
|
|
122
|
+
seq_attn = op23.SequenceEmpty(dtype=onnx.TensorProto.FLOAT)
|
|
123
123
|
for i_patch in range(num_patches):
|
|
124
|
-
i_1d =
|
|
124
|
+
i_1d = op23.Reshape(i_patch, [1])
|
|
125
125
|
i_plus_1_1d = i_1d + 1
|
|
126
|
-
start =
|
|
127
|
-
end =
|
|
128
|
-
query_i =
|
|
129
|
-
key_i =
|
|
130
|
-
value_i =
|
|
131
|
-
mha_output =
|
|
126
|
+
start = op23.Gather(cu_seqlens, i_1d, axis=0)
|
|
127
|
+
end = op23.Gather(cu_seqlens, i_plus_1_1d, axis=0)
|
|
128
|
+
query_i = op23.Slice(query_3d, start, end, seq_axis_int32)
|
|
129
|
+
key_i = op23.Slice(key_3d, start, end, seq_axis_int32)
|
|
130
|
+
value_i = op23.Slice(value_3d, start, end, seq_axis_int32)
|
|
131
|
+
mha_output = op23.Attention(
|
|
132
132
|
query_i,
|
|
133
133
|
key_i,
|
|
134
134
|
value_i,
|
|
@@ -137,9 +137,9 @@ if patch_qwen2_5:
|
|
|
137
137
|
kv_num_heads=num_heads,
|
|
138
138
|
softmax_precision=onnx.TensorProto.FLOAT,
|
|
139
139
|
)
|
|
140
|
-
seq_attn =
|
|
141
|
-
attn_output =
|
|
142
|
-
attn_output_4d =
|
|
140
|
+
seq_attn = op23.SequenceInsert(seq_attn, mha_output)
|
|
141
|
+
attn_output = op23.ConcatFromSequence(seq_attn, axis=1)
|
|
142
|
+
attn_output_4d = op23.Reshape(attn_output, output_shape)
|
|
143
143
|
return attn_output_4d
|
|
144
144
|
|
|
145
145
|
@onnxscript.script(opset=onnx_plugs_op)
|
|
@@ -263,8 +263,8 @@ if patch_qwen2_5:
|
|
|
263
263
|
if strategy is not None:
|
|
264
264
|
return strategy, itype
|
|
265
265
|
if dtype == torch.float32 or itype == onnx.TensorProto.FLOAT:
|
|
266
|
-
if opset >=
|
|
267
|
-
return "
|
|
266
|
+
if opset >= 23:
|
|
267
|
+
return "LOOPA23", itype
|
|
268
268
|
return "LOOPMHA", itype
|
|
269
269
|
if dtype == torch.float16 or itype == onnx.TensorProto.FLOAT16:
|
|
270
270
|
# first_tensor may be a SymbolicTensor (onnx).
|
|
@@ -288,9 +288,9 @@ if patch_qwen2_5:
|
|
|
288
288
|
("PACKED", onnx.TensorProto.FLOAT16): _add_com_microsoft_opset(
|
|
289
289
|
PackedAttention.to_function_proto()
|
|
290
290
|
),
|
|
291
|
-
("
|
|
292
|
-
("
|
|
293
|
-
onnx.TensorProto.FLOAT16,
|
|
291
|
+
("LOOPA23", onnx.TensorProto.FLOAT): LoopAttention23.to_function_proto(),
|
|
292
|
+
("LOOPA23", onnx.TensorProto.FLOAT16): _update_sequence_type(
|
|
293
|
+
onnx.TensorProto.FLOAT16, LoopAttention23.to_function_proto()
|
|
294
294
|
),
|
|
295
295
|
("LOOPMHA", onnx.TensorProto.FLOAT): _add_com_microsoft_opset(
|
|
296
296
|
LoopMHAAttention.to_function_proto()
|
|
@@ -1,16 +1,16 @@
|
|
|
1
|
-
onnx_diagnostic/__init__.py,sha256=
|
|
1
|
+
onnx_diagnostic/__init__.py,sha256=dCiBK_S7EOo_rAsmsgv-laLhtKzE2uny0XIR5aO4eDk,173
|
|
2
2
|
onnx_diagnostic/__main__.py,sha256=YmyV_Aq_ianDlHyKLHMa6h8YK3ZmFPpLVHLKjM91aCk,79
|
|
3
|
-
onnx_diagnostic/_command_lines_parser.py,sha256=
|
|
3
|
+
onnx_diagnostic/_command_lines_parser.py,sha256=ZFJdQP1Ee8D5a_xUch-0CHaYbbILztejTjVdyc9KrMw,52667
|
|
4
4
|
onnx_diagnostic/api.py,sha256=BhCl_yCd78N7TlVtPOHjeYv1QBEy39TjZ647rcHqLh0,345
|
|
5
5
|
onnx_diagnostic/doc.py,sha256=t3RELgfooYnVMAi0JSpggWkQEgUsREz8NmRvn0TnLI8,2829
|
|
6
6
|
onnx_diagnostic/ext_test_case.py,sha256=rVZWqFEfnvwnsD3wF4jeDblh5uj5ckZ8C6DZQ0RGb_E,49599
|
|
7
7
|
onnx_diagnostic/export/__init__.py,sha256=yEIoWiOeTwBsDhyYt2fTKuhtA0Ya1J9u9ZzMTOTWaWs,101
|
|
8
|
-
onnx_diagnostic/export/api.py,sha256=
|
|
8
|
+
onnx_diagnostic/export/api.py,sha256=BX4c99gMlRYsBWk3P15FMRogArxjP4dXYXP5gILjgIk,10626
|
|
9
9
|
onnx_diagnostic/export/control_flow.py,sha256=zU5n_QYhNcBllyMsl1_i6ohZt2CshqG2MokJghrvA60,7751
|
|
10
10
|
onnx_diagnostic/export/control_flow_onnx.py,sha256=sODOD4v7EJj6LWhrfcdCW68r9nYKsRM4SRnqDw4TrSI,18049
|
|
11
11
|
onnx_diagnostic/export/control_flow_research.py,sha256=RuYz9_eM42Bk6TKSiPV6dS68LIMZu-6WBCFCKoSvjrk,5422
|
|
12
12
|
onnx_diagnostic/export/dynamic_shapes.py,sha256=M2hlpHSTbkzZwGKAbrpQXng5HQrwjF5Z6wGGxEgnp74,42061
|
|
13
|
-
onnx_diagnostic/export/onnx_plug.py,sha256=
|
|
13
|
+
onnx_diagnostic/export/onnx_plug.py,sha256=U13fL0BjnhMzcDGxaAOqM4TQte5Z4zKDg4ESS0iktjM,22704
|
|
14
14
|
onnx_diagnostic/export/shape_helper.py,sha256=m628y0oRCQbeZkeh8JDHIfWMsSjoJoeX-IPiPGDHT-w,11273
|
|
15
15
|
onnx_diagnostic/export/validate.py,sha256=_PGUql2DJhIgGKo0WjTGUc5AgsZUx8fEs00MePy-w98,6043
|
|
16
16
|
onnx_diagnostic/helpers/__init__.py,sha256=GJ2GT7cgnlIveVUwMZhuvUwidbTJaKv8CsSIOpZDsJg,83
|
|
@@ -23,7 +23,7 @@ onnx_diagnostic/helpers/doc_helper.py,sha256=pl5MZd3_FaE8BqQnqoBuSBxoNCFcd2OJd3e
|
|
|
23
23
|
onnx_diagnostic/helpers/dot_helper.py,sha256=hwgTJsbsUv0qq7euyPDnc1NsBZDGOwv32JXSZxIHJkE,8118
|
|
24
24
|
onnx_diagnostic/helpers/fake_tensor_helper.py,sha256=J7wnK3WTuVKnYiMzLVTAPkdJr3hQfIfMC9ZlOu7oGmI,11024
|
|
25
25
|
onnx_diagnostic/helpers/graph_helper.py,sha256=hevQT5a7_QuriVPQcbT5qe18n99Doyl5h3-qshx1-uk,14093
|
|
26
|
-
onnx_diagnostic/helpers/helper.py,sha256=
|
|
26
|
+
onnx_diagnostic/helpers/helper.py,sha256=x8EYQmgrz_G5QS_IsbeFIoDcN_sUs-CslJMHseBj1Fw,65482
|
|
27
27
|
onnx_diagnostic/helpers/log_helper.py,sha256=0lJiTF87lliI-LmgpUH_V2N8NuzJ0LryH0mSYpkRaL8,93272
|
|
28
28
|
onnx_diagnostic/helpers/memory_peak.py,sha256=M3m4_thWFIwP5HytbJYEqaijXIv5v5BW_vlcJowIYI4,6434
|
|
29
29
|
onnx_diagnostic/helpers/mini_onnx_builder.py,sha256=jR2lkRZEQ0N30H0FqeBwaxJd_w_6kyxFagrnulqFjhE,23883
|
|
@@ -117,9 +117,9 @@ onnx_diagnostic/torch_export_patches/patches/_patch_transformers_dynamic_cache.p
|
|
|
117
117
|
onnx_diagnostic/torch_export_patches/patches/_patch_transformers_gemma3.py,sha256=nVgYQk0xXpHiictN1wOHVMN2lTH9b0vfIJ4ie-uKopg,1999
|
|
118
118
|
onnx_diagnostic/torch_export_patches/patches/_patch_transformers_generation_mixin.py,sha256=VIZsVHgR8NmAcBQalPl5I6ZzNgcBxjGb6ars31m9gRg,21936
|
|
119
119
|
onnx_diagnostic/torch_export_patches/patches/_patch_transformers_idefics.py,sha256=kTjuTRsfkGGGhspJnMxAMQSchZgGC_IruJzpHh_FmI8,6348
|
|
120
|
-
onnx_diagnostic/torch_export_patches/patches/_patch_transformers_masking_utils.py,sha256=
|
|
120
|
+
onnx_diagnostic/torch_export_patches/patches/_patch_transformers_masking_utils.py,sha256=HE3fovyvMiYe9EPz1UjdD9AWopX3H188SMwPb8w5mzM,7111
|
|
121
121
|
onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2.py,sha256=OxYdlLrwtd_KGHt3E17poduxvWFg-CfGS57-yN1i6gI,3827
|
|
122
|
-
onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2_5.py,sha256=
|
|
122
|
+
onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2_5.py,sha256=GS7IDHyRaLAsbZE5k7KN-ZT5-ezbmEUzXPJ_xG4SulA,31601
|
|
123
123
|
onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen3.py,sha256=cND9Iqo1aKdlX-BXGr9Qlq_Y4EW1L5VWSwZfqYTVazU,4888
|
|
124
124
|
onnx_diagnostic/torch_export_patches/patches/_patch_transformers_rotary_embedding.py,sha256=4bJ_z2gizZQla_fcCVt0dmuhzO9Vu-D7CCMWdxMlrKM,16893
|
|
125
125
|
onnx_diagnostic/torch_export_patches/patches/_patch_transformers_sam_mask_decoder.py,sha256=-6TuBm3sLAFEGuW3vRfOTtE5uP6aINFfu7xMnl27Dws,5703
|
|
@@ -146,8 +146,8 @@ onnx_diagnostic/torch_onnx/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJ
|
|
|
146
146
|
onnx_diagnostic/torch_onnx/runtime_info.py,sha256=u1bD6VXqzBCRmqmbzQtDswaPs1PH_ygr1r-CrcfXpNU,8562
|
|
147
147
|
onnx_diagnostic/torch_onnx/sbs.py,sha256=8okBEIupMgw7TtKc80YFimMtwnY3GchdY05FsA9ooa0,40749
|
|
148
148
|
onnx_diagnostic/torch_onnx/sbs_dataclasses.py,sha256=UctdBjzoPTQG1LS0tZ8A6E9hpoq5HWUYaJLPOPJc9FI,20299
|
|
149
|
-
onnx_diagnostic-0.8.
|
|
150
|
-
onnx_diagnostic-0.8.
|
|
151
|
-
onnx_diagnostic-0.8.
|
|
152
|
-
onnx_diagnostic-0.8.
|
|
153
|
-
onnx_diagnostic-0.8.
|
|
149
|
+
onnx_diagnostic-0.8.5.dist-info/licenses/LICENSE.txt,sha256=Vv6TXglX6Rc0d-f8aREhayhT-6PMQXEyOmI2NKlUCMc,1045
|
|
150
|
+
onnx_diagnostic-0.8.5.dist-info/METADATA,sha256=A54IonPIcnualwiRJhvjRMfhF3p3jdXhEH1vTtZBgyE,6734
|
|
151
|
+
onnx_diagnostic-0.8.5.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
152
|
+
onnx_diagnostic-0.8.5.dist-info/top_level.txt,sha256=KwNkXewmcobM3ZT1DJLVWH6ebJzA5qKg7cWqKfpGNT4,16
|
|
153
|
+
onnx_diagnostic-0.8.5.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|