onnx-diagnostic 0.8.2__py3-none-any.whl → 0.8.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnx_diagnostic/__init__.py +1 -1
- onnx_diagnostic/_command_lines_parser.py +412 -12
- onnx_diagnostic/export/api.py +111 -8
- onnx_diagnostic/export/control_flow.py +48 -345
- onnx_diagnostic/export/control_flow_onnx.py +528 -0
- onnx_diagnostic/export/control_flow_research.py +12 -7
- onnx_diagnostic/export/onnx_plug.py +531 -0
- onnx_diagnostic/ext_test_case.py +163 -48
- onnx_diagnostic/helpers/cache_helper.py +1 -1
- onnx_diagnostic/helpers/dot_helper.py +222 -0
- onnx_diagnostic/helpers/helper.py +108 -37
- onnx_diagnostic/helpers/mini_onnx_builder.py +3 -1
- onnx_diagnostic/helpers/model_builder_helper.py +27 -0
- onnx_diagnostic/helpers/onnx_helper.py +531 -6
- onnx_diagnostic/helpers/ort_session.py +45 -19
- onnx_diagnostic/helpers/torch_fx_graph_helper.py +164 -0
- onnx_diagnostic/helpers/torch_helper.py +131 -8
- onnx_diagnostic/reference/ort_evaluator.py +228 -46
- onnx_diagnostic/tasks/feature_extraction.py +15 -14
- onnx_diagnostic/tasks/summarization.py +72 -137
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_attention.py +236 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_cache_utils.py +50 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_causal_mask.py +89 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_dynamic_cache.py +177 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_gemma3.py +54 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_generation_mixin.py +486 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_idefics.py +156 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_masking_utils.py +173 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2.py +99 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2_5.py +735 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen3.py +106 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_rotary_embedding.py +412 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_sam_mask_decoder.py +132 -0
- onnx_diagnostic/torch_export_patches/patches/patch_helper.py +28 -0
- onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +64 -2608
- onnx_diagnostic/torch_models/code_sample.py +2 -1
- onnx_diagnostic/torch_models/hghub/model_inputs.py +34 -7
- onnx_diagnostic/torch_models/validate.py +64 -2
- onnx_diagnostic/torch_onnx/runtime_info.py +1 -24
- onnx_diagnostic/torch_onnx/sbs.py +969 -312
- onnx_diagnostic/torch_onnx/sbs_dataclasses.py +535 -0
- {onnx_diagnostic-0.8.2.dist-info → onnx_diagnostic-0.8.4.dist-info}/METADATA +1 -1
- {onnx_diagnostic-0.8.2.dist-info → onnx_diagnostic-0.8.4.dist-info}/RECORD +46 -27
- {onnx_diagnostic-0.8.2.dist-info → onnx_diagnostic-0.8.4.dist-info}/WHEEL +0 -0
- {onnx_diagnostic-0.8.2.dist-info → onnx_diagnostic-0.8.4.dist-info}/licenses/LICENSE.txt +0 -0
- {onnx_diagnostic-0.8.2.dist-info → onnx_diagnostic-0.8.4.dist-info}/top_level.txt +0 -0
onnx_diagnostic/ext_test_case.py
CHANGED
@@ -9,6 +9,7 @@ import itertools
 import logging
 import os
 import re
+import shutil
 import sys
 import unittest
 import warnings
@@ -63,7 +64,7 @@ def skipif_ci_apple(msg) -> Callable:
     return lambda x: x


-def unit_test_going():
+def unit_test_going() -> bool:
     """
     Enables a flag telling the script is running while testing it.
     Avois unit tests to be very long.
@@ -147,7 +148,7 @@ def hide_stdout(f: Optional[Callable] = None) -> Callable:

     def wrapper(fct):
         def call_f(self):
-            if os.environ.get("UNHIDE", ""):
+            if os.environ.get("UNHIDE", "") in (1, "1", "True", "true"):
                 fct(self)
                 return
             st = StringIO()
@@ -609,6 +610,21 @@ def requires_onnxruntime(version: str, msg: str = "") -> Callable:
     return lambda x: x


+def has_onnxruntime(version: str, msg: str = "") -> Callable:
+    """Skips a unit test if :epkg:`onnxruntime` is not recent enough."""
+    import packaging.version as pv
+    import onnxruntime
+
+    if not hasattr(onnxruntime, "__version__"):
+        # development version
+        return True
+
+    if pv.Version(onnxruntime.__version__) < pv.Version(version):
+        msg = f"onnxruntime version {onnxruntime.__version__} < {version}: {msg}"
+        return False
+    return True
+
+
 def has_onnxruntime_training(push_back_batch: bool = False):
     """Tells if onnxruntime_training is installed."""
     try:
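The new module-level `has_onnxruntime` returns a boolean (despite the `Callable` annotation in the diff), so it can gate a test inline. A minimal sketch, assuming `onnxruntime` is installed; the test class and the version threshold are hypothetical:

```python
import unittest
from onnx_diagnostic.ext_test_case import ExtTestCase, has_onnxruntime


class TestRuntimeGate(ExtTestCase):
    # hypothetical threshold: skip when the installed onnxruntime is older than 1.18
    @unittest.skipIf(not has_onnxruntime("1.18"), "needs onnxruntime>=1.18")
    def test_recent_runtime_feature(self):
        # any released onnxruntime satisfies this lower bound
        self.assertTrue(has_onnxruntime("0.1"))
```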
@@ -742,8 +758,15 @@ class ExtTestCase(unittest.TestCase):
     _warns: List[Tuple[str, int, Warning]] = []
     _todos: List[Tuple[Callable, str]] = []

+    def unit_test_going(self) -> bool:
+        """
+        Enables a flag telling the script is running while testing it.
+        Avois unit tests to be very long.
+        """
+        return unit_test_going()
+
     @property
-    def verbose(self):
+    def verbose(self) -> int:
         "Returns the the value of environment variable ``VERBOSE``."
         return int(os.environ.get("VERBOSE", "0"))

@@ -768,13 +791,13 @@ class ExtTestCase(unittest.TestCase):
         cls._todos.append((f, msg))

     @classmethod
-    def ort(cls):
+    def ort(cls) -> unittest.__class__:
         import onnxruntime

         return onnxruntime

     @classmethod
-    def to_onnx(self, *args, **kwargs):
+    def to_onnx(self, *args, **kwargs) -> "ModelProto":  # noqa: F821
         from experimental_experiment.torch_interpreter import to_onnx

         return to_onnx(*args, **kwargs)
@@ -806,18 +829,29 @@ class ExtTestCase(unittest.TestCase):
             os.makedirs(folder)
         return folder

-    def
-
-
-
-
-
+    def clean_dump(self, folder: str = "dump_test"):
+        """Cleans this folder."""
+        for item in os.listdir(folder):
+            item_path = os.path.join(folder, item)
+            if os.path.isfile(item_path) or os.path.islink(item_path):
+                os.remove(item_path)
+            elif os.path.isdir(item_path):
+                shutil.rmtree(item_path)
+
+    def dump_onnx(self, name: str, proto: Any, folder: Optional[str] = None) -> str:
         """Dumps an onnx file."""
         fullname = self.get_dump_file(name, folder=folder)
         with open(fullname, "wb") as f:
             f.write(proto.SerializeToString())
         return fullname

+    def dump_text(self, name: str, text: str, folder: Optional[str] = None) -> str:
+        """Dumps text in a file."""
+        fullname = self.get_dump_file(name, folder=folder)
+        with open(fullname, "w") as f:
+            f.write(text)
+        return fullname
+
     def assertExists(self, name):
         """Checks the existing of a file."""
         if not os.path.exists(name):
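A hedged sketch of how the new dump helpers might be combined in a test; the file name, content and folder are made up for illustration, and `clean_dump` empties its target folder (default `"dump_test"`):

```python
from onnx_diagnostic.ext_test_case import ExtTestCase


class TestDumpSketch(ExtTestCase):
    def test_dump_report(self):
        # dump_text writes the string into the given folder and returns the full path
        path = self.dump_text("report.txt", "exporter output\n", folder="dump_test")
        self.assertExists(path)
        # self.clean_dump("dump_test")  # would remove every file and sub-folder in it
```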
@@ -1094,10 +1128,15 @@ class ExtTestCase(unittest.TestCase):
             value = numpy.array(value).astype(expected.dtype)
         self.assertEqualArray(expected, value, atol=atol, rtol=rtol)

-    def check_ort(
+    def check_ort(
+        self, onx: "onnx.ModelProto"  # noqa: F821
+    ) -> "onnxruntime.InferenceSession":  # noqa: F821
         from onnxruntime import InferenceSession

-        return InferenceSession(
+        return InferenceSession(
+            onx if isinstance(onx, str) else onx.SerializeToString(),
+            providers=["CPUExecutionProvider"],
+        )

     def assertRaise(self, fct: Callable, exc_type: type[Exception], msg: Optional[str] = None):
         """In the name"""
@@ -1137,7 +1176,7 @@ class ExtTestCase(unittest.TestCase):
         if not full.endswith(suffix):
             raise AssertionError(f"suffix={suffix!r} does not end string {full!r}.")

-    def capture(self, fct: Callable):
+    def capture(self, fct: Callable) -> Tuple[Any, str, str]:
         """
         Runs a function and capture standard output and error.

@@ -1179,9 +1218,9 @@ class ExtTestCase(unittest.TestCase):
     def assert_onnx_disc(
         self,
         test_name: str,
-        proto: "onnx.ModelProto",  # noqa: F821
+        proto: Union[str, "onnx.ModelProto"],  # noqa: F821
         model: "torch.nn.Module",  # noqa: F821
-        inputs: Union[Tuple[Any], Dict[str, Any]],
+        inputs: Union[Tuple[Any], Dict[str, Any], List[Any]],
         verbose: int = 0,
         atol: float = 1e-5,
         rtol: float = 1e-3,
@@ -1189,6 +1228,7 @@ class ExtTestCase(unittest.TestCase):
         expected: Optional[Any] = None,
         use_ort: bool = False,
         ort_optimized_graph: bool = False,
+        ep: Optional[Union["torch.export.ExportedProgram", str]] = None,  # noqa: F821
         **kwargs,
     ):
         """
@@ -1208,6 +1248,7 @@ class ExtTestCase(unittest.TestCase):
         :param copy_inputs: to copy the inputs
         :param use_ort: use :class:`onnxruntime.InferenceSession`
         :param ort_optimized_graph: dumps the optimized onnxruntime graph
+        :param ep: exported program (or saved exported program)
         :param kwargs: arguments sent to
             :class:`onnx_diagnostic.helpers.ort_session.InferenceSessionForTorch`
         """
@@ -1223,71 +1264,135 @@ class ExtTestCase(unittest.TestCase):
         name = f"{test_name}.onnx"
         if verbose:
             print(f"[{vname}] save the onnx model into {name!r}")
+        model_file = None
         if isinstance(proto, str):
+            model_file = proto
             name = proto
             proto = onnx.load(name)
-
+        elif not self.unit_test_going():
             assert isinstance(
                 proto, onnx.ModelProto
             ), f"Unexpected type {type(proto)} for proto"
             name = self.dump_onnx(name, proto)
-        if verbose:
+        if verbose and not self.unit_test_going():
             print(f"[{vname}] file size {os.stat(name).st_size // 2**10:1.3f} kb")
         if verbose:
             print(f"[{vname}] make feeds {string_type(inputs, **kws)}")
+
+        if not isinstance(inputs, list):
+            inputs = [inputs]
+            if expected is not None:
+                expected = [expected]
+
+        gots = []
         if use_ort:
             assert isinstance(
                 proto, onnx.ModelProto
             ), f"Unexpected type {type(proto)} for proto"
-            feeds = make_feeds(proto, inputs, use_numpy=True, copy=True)
             import onnxruntime

-            if verbose:
-                print(f"[{vname}] create onnxruntime.InferenceSession")
             options = onnxruntime.SessionOptions()
             if ort_optimized_graph:
                 options.optimized_model_filepath = f"{name}.optort.onnx"
+            if "log_severity_level" in kwargs:
+                options.log_severity_level = kwargs["log_severity_level"]
+            if "log_verbosity_level" in kwargs:
+                options.log_verbosity_level = kwargs["log_verbosity_level"]
+            providers = kwargs.get("providers", ["CPUExecutionProvider"])
+            if verbose:
+                print(f"[{vname}] create onnxruntime.InferenceSession with {providers}")
             sess = onnxruntime.InferenceSession(
-                proto.SerializeToString(),
-                options,
-                providers=kwargs.get("providers", ["CPUExecutionProvider"]),
+                model_file or proto.SerializeToString(), options, providers=providers
             )
-
-
-
+            for inp in inputs:
+                feeds = make_feeds(proto, inp, use_numpy=True, copy=True)
+                if verbose:
+                    print(f"[{vname}] run ort feeds {string_type(feeds, **kws)}")
+                got = sess.run(None, feeds)
+                gots.append(got)
         else:
-            feeds = make_feeds(proto, inputs, copy=True)
             if verbose:
                 print(f"[{vname}] create InferenceSessionForTorch")
             sess = InferenceSessionForTorch(proto, **kwargs)
-
-
-
+            for inp in inputs:
+                feeds = make_feeds(proto, inp, copy=True)
+                if verbose:
+                    print(f"[{vname}] run orttorch feeds {string_type(feeds, **kws)}")
+                got = sess.run(None, feeds)
+                gots.append(got)
         if verbose:
             print(f"[{vname}] compute expected values")
+
         if expected is None:
             if copy_inputs:
-                expected =
-
-
-
-
+                expected = [
+                    (
+                        model(*copy.deepcopy(inp))
+                        if isinstance(inp, tuple)
+                        else model(**copy.deepcopy(inp))
+                    )
+                    for inp in inputs
+                ]
             else:
-                expected =
+                expected = [
+                    model(*inp) if isinstance(inp, tuple) else model(**inp) for inp in inputs
+                ]
+
         if verbose:
             print(f"[{vname}] expected {string_type(expected, **kws)}")
             print(f"[{vname}] obtained {string_type(got, **kws)}")
-
-            if
-
-
-
-
-
-
-
-
+
+        if ep:
+            if isinstance(ep, str):
+                if verbose:
+                    print(f"[{vname}] load exported program {ep!r}")
+                import torch
+
+                ep = torch.export.load(ep)
+
+            ep_model = ep.module()  # type: ignore[union-attr]
+            for expe, inp, got in zip(expected, inputs, gots):
+                ep_inputs = copy.deepcopy(inp) if copy_inputs else inp
+                ep_expected = (
+                    ep_model(*copy.deepcopy(ep_inputs))
+                    if isinstance(ep_inputs, tuple)
+                    else ep_model(**copy.deepcopy(ep_inputs))
+                )
+                if verbose:
+                    print(f"[{vname}] ep_expected {string_type(ep_expected, **kws)}")
+                ep_diff = max_diff(expe, ep_expected, hist=[0.1, 0.01])
+                if verbose:
+                    print(f"[{vname}] ep_diff {string_diff(ep_diff)}")
+                assert (
+                    isinstance(ep_diff["abs"], float)
+                    and isinstance(ep_diff["rel"], float)
+                    and not numpy.isnan(ep_diff["abs"])
+                    and ep_diff["abs"] <= atol
+                    and not numpy.isnan(ep_diff["rel"])
+                    and ep_diff["rel"] <= rtol
+                ), (
+                    f"discrepancies in {test_name!r} between the exported program "
+                    f"and the exported model diff={string_diff(ep_diff)}"
+                )
+                ep_nx_diff = max_diff(ep_expected, got, flatten=True, hist=[0.1, 0.01])
+                if verbose:
+                    print(f"[{vname}] ep_nx_diff {string_diff(ep_nx_diff)}")
+
+        for expe, got in zip(expected, gots):
+            diff = max_diff(expe, got, flatten=True, hist=[0.1, 0.01])
+            if verbose:
+                print(f"[{vname}] diff {string_diff(diff)}")
+            assert (
+                isinstance(diff["abs"], float)
+                and isinstance(diff["rel"], float)
+                and not numpy.isnan(diff["abs"])
+                and diff["abs"] <= atol
+                and not numpy.isnan(diff["rel"])
+                and diff["rel"] <= rtol
+            ), (
+                f"discrepancies in {test_name!r} between the model and "
+                f"the onnx model diff={string_diff(diff)}"
+            )

     def _debug(self):
         "Tells if DEBUG=1 is set up."
@@ -1298,6 +1403,16 @@ class ExtTestCase(unittest.TestCase):

         return string_type(*args, **kwargs)

+    def max_diff(self, *args, **kwargs):
+        from .helpers import max_diff
+
+        return max_diff(*args, **kwargs)
+
+    def use_dyn_not_str(self, *args, **kwargs):
+        from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str
+
+        return use_dyn_not_str(*args, *kwargs)
+
     def subloop(self, *args, verbose: int = 0):
         "Loops over elements and calls :meth:`unittests.TestCase.subTest`."
         if len(args) == 1:
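With these changes, `assert_onnx_disc` accepts a list of input sets and an optional exported program (`ep`) to compare against. A hedged sketch of a call using the new signature, assuming a recent torch with the dynamo-based ONNX exporter; the module, test name and tolerances are hypothetical:

```python
import torch
from onnx_diagnostic.ext_test_case import ExtTestCase


class Tiny(torch.nn.Module):
    def forward(self, x):
        return x.sigmoid() + 1


class TestDiscSketch(ExtTestCase):
    def test_tiny_disc(self):
        model = Tiny()
        # two input sets with the same static shape
        inputs = [(torch.randn(2, 3),), (torch.randn(2, 3),)]
        ep = torch.export.export(model, inputs[0])
        onx = torch.onnx.export(model, inputs[0], dynamo=True).model_proto
        # compares eager outputs, exported-program outputs and onnx outputs
        self.assert_onnx_disc(
            "test_tiny_disc", onx, model, inputs, ep=ep, atol=1e-5, rtol=1e-3
        )
```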
onnx_diagnostic/helpers/cache_helper.py
CHANGED

@@ -80,7 +80,7 @@ def flatten_unflatten_for_dynamic_shapes(
     start = 0
     end = 0
     subtrees = []
-    for subspec in spec.children_specs:
+    for subspec in (spec.children() if hasattr(spec, "children") else spec.children_specs):
         end += subspec.num_leaves
         value = subspec.unflatten(flat[start:end])
         value = flatten_unflatten_for_dynamic_shapes(
onnx_diagnostic/helpers/dot_helper.py
ADDED

@@ -0,0 +1,222 @@
+from typing import Dict
+import numpy as np
+import onnx
+import onnx.numpy_helper as onh
+from ..reference import ExtendedReferenceEvaluator as Inference
+from .onnx_helper import onnx_dtype_name, pretty_onnx, get_hidden_inputs
+
+
+def _make_node_label(node: onnx.NodeProto, tiny_inits: Dict[str, str]) -> str:
+    els = [f"{node.domain}.\\n{node.op_type}" if node.domain else node.op_type, "\\n("]
+    ee = [tiny_inits.get(i, ".") if i else "" for i in node.input]
+    for att in node.attribute:
+        if att.name == "to":
+            ee.append(f"{att.name}={onnx_dtype_name(att.i)}")
+        elif att.name in {"to", "axis", "value_int", "stash_type", "start", "end"}:
+            ee.append(f"{att.name}={att.i}")
+        elif att.name in {"value_float"}:
+            ee.append(f"{att.name}={att.f}")
+        elif att.name in {"value_floats"}:
+            ee.append(f"{att.name}={att.floats}")
+        elif att.name in {"value_ints", "perm"}:
+            ee.append(f"{att.name}={att.ints}")
+    els.append(", ".join(ee))
+    els.append(")")
+    if node.op_type == "Constant":
+        els.extend([" -> ", node.output[0]])
+    res = "".join(els)
+    if len(res) < 40:
+        return res.replace("\\n(", "(")
+    return res
+
+
+def _make_edge_label(value_info: onnx.ValueInfoProto, multi_line: bool = False) -> str:
+    itype = value_info.type.tensor_type.elem_type
+    if itype == onnx.TensorProto.UNDEFINED:
+        return ""
+    shape = tuple(
+        d.dim_param if d.dim_param else d.dim_value
+        for d in value_info.type.tensor_type.shape.dim
+    )
+    res = [
+        str(a)
+        for a in [("?" if isinstance(s, str) and s.startswith("unk") else s) for s in shape]
+    ]
+    sshape = ",".join(res)
+    if multi_line and len(sshape) > 30:
+        sshape = ",\\n".join(res)
+    return f"{onnx_dtype_name(itype)}({sshape})"
+
+
+def to_dot(model: onnx.ModelProto) -> str:
+    """
+    Converts a model into a dot graph.
+    Here is an example:
+
+    .. gdot::
+        :script: DOT-SECTION
+        :process:
+
+        from onnx_diagnostic.helpers.dot_helper import to_dot
+        from onnx_diagnostic.export.api import to_onnx
+        from onnx_diagnostic.torch_export_patches import torch_export_patches
+        from onnx_diagnostic.torch_models.hghub import get_untrained_model_with_inputs
+
+        data = get_untrained_model_with_inputs("arnir0/Tiny-LLM")
+        model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
+        with torch_export_patches(patch_transformers=True):
+            em = to_onnx(model, inputs, dynamic_shapes=ds, exporter="custom")
+        dot = to_dot(em.model_proto)
+        print("DOT-SECTION", dot)
+
+    Or this one obtained with :func:`torch.onnx.export`.
+
+    .. gdot::
+        :script: DOT-SECTION
+        :process:
+
+        from onnx_diagnostic.helpers.dot_helper import to_dot
+        from onnx_diagnostic.export.api import to_onnx
+        from onnx_diagnostic.torch_export_patches import torch_export_patches
+        from onnx_diagnostic.torch_models.hghub import get_untrained_model_with_inputs
+
+        data = get_untrained_model_with_inputs("arnir0/Tiny-LLM")
+        model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
+        with torch_export_patches(patch_transformers=True):
+            em = to_onnx(model, kwargs=inputs, dynamic_shapes=ds, exporter="onnx-dynamo")
+        dot = to_dot(em.model_proto)
+        print("DOT-SECTION", dot)
+    """
+    _unique: Dict[int, int] = {}
+
+    def _mkn(obj: object) -> int:
+        id_obj = id(obj)
+        if id_obj in _unique:
+            return _unique[id_obj]
+        i = len(_unique)
+        _unique[id_obj] = i
+        return i
+
+    model = onnx.shape_inference.infer_shapes(model)
+
+    op_type_colors = {
+        "Shape": "#d2a81f",
+        "MatMul": "#ee9999",
+        "Transpose": "#ee99ee",
+        "Reshape": "#eeeeee",
+        "Squeeze": "#eeeeee",
+        "Unsqueeze": "#eeeeee",
+    }
+
+    edge_label = {}
+    for val in model.graph.value_info:
+        edge_label[val.name] = _make_edge_label(val, multi_line=True)
+
+    rows = [
+        "digraph {",
+        (
+            " graph [rankdir=TB, splines=true, overlap=false, nodesep=0.2, "
+            "ranksep=0.2, fontsize=8];"
+        ),
+        ' node [style="rounded,filled", color="#888888", fontcolor="#222222", shape=box];',
+        " edge [arrowhead=vee, fontsize=7, labeldistance=-5, labelangle=0];",
+    ]
+    inputs = list(model.graph.input)
+    outputs = list(model.graph.output)
+    nodes = list(model.graph.node)
+    inits = list(model.graph.initializer)
+    tiny_inits = {}
+    name_to_ids = {}
+
+    for inp in inputs:
+        if not inp.name:
+            continue
+        lab = _make_edge_label(inp)
+        rows.append(f' I_{_mkn(inp)} [label="{inp.name}\\n{lab}", fillcolor="#aaeeaa"];')
+        name_to_ids[inp.name] = f"I_{_mkn(inp)}"
+        edge_label[inp.name] = _make_edge_label(inp, multi_line=True)
+
+    # Small constant --> initializer
+    output_names = {n.name for n in outputs}
+    for node in nodes:
+        if node.op_type != "Constant" or node.output[0] in output_names:
+            continue
+        skip = False
+        for att in node.attribute:
+            if att.name == "value" and (
+                len(att.t.dims) > 1 or np.prod(tuple(att.t.dims)) > 10
+            ):
+                skip = True
+                break
+        if skip:
+            continue
+
+        sess = Inference(node)
+        value = sess.run(None, {})[0]
+        inits.append(onh.from_array(value, name=node.output[0]))
+
+    for init in inits:
+        if init.name in name_to_ids:
+            # hide optional inputs
+            continue
+        shape = tuple(init.dims)
+        if len(shape) == 0 or (len(shape) == 1 and shape[0] < 10):
+            a = onh.to_array(init)
+            tiny_inits[init.name] = (
+                str(a) if len(shape) == 0 else f"[{', '.join([str(i) for i in a])}]"
+            )
+        else:
+            ls = f"{onnx_dtype_name(init.data_type)}({', '.join(map(str,shape))})"
+            rows.append(f' i_{_mkn(init)} [label="{init.name}\\n{ls}", fillcolor="#cccc00"];')
+            name_to_ids[init.name] = f"i_{_mkn(init)}"
+            edge_label[init.name] = ls
+
+    for node in nodes:
+        if node.op_type == "Constant" and node.output[0] in tiny_inits:
+            continue
+        color = op_type_colors.get(node.op_type, "#cccccc")
+        label = _make_node_label(node, tiny_inits)
+        rows.append(f' {node.op_type}_{_mkn(node)} [label="{label}", fillcolor="{color}"];')
+        name_to_ids.update({o: f"{node.op_type}_{_mkn(node)}" for o in node.output if o})
+
+    # nodes
+    done = set()
+    for node in nodes:
+        names = list(node.input)
+        for i in names:
+            if not i or i in tiny_inits:
+                continue
+            if i not in name_to_ids:
+                raise ValueError(f"Unable to find {i!r}\n{pretty_onnx(model)}")
+            edge = name_to_ids[i], f"{node.op_type}_{_mkn(node)}"
+            if edge in done:
+                continue
+            done.add(edge)
+            lab = edge_label.get(i, "")
+            if lab:
+                ls = ",".join([f'label="{lab}"'])
+                lab = f" [{ls}]"
+            rows.append(f" {edge[0]} -> {edge[1]}{lab};")
+        if node.op_type in {"Scan", "Loop", "If"}:
+            unique = set()
+            for att in node.attribute:
+                if att.type == onnx.AttributeProto.GRAPH:
+                    unique |= get_hidden_inputs(att.g)
+            for i in unique:
+                edge = name_to_ids[i], _mkn(node)  # type: ignore[assignment]
+                if edge in done:
+                    continue
+                done.add(edge)
+                rows.append(f" {edge[0]} -> {edge[1]} [style=dotted];")
+
+    # outputs
+    for out in outputs:
+        if not out.name:
+            continue
+        lab = _make_edge_label(out)
+        rows.append(f' O_{_mkn(out)} [label="{out.name}\\n{lab}", fillcolor="#aaaaee"];')
+        edge = name_to_ids[out.name], f"O_{_mkn(out)}"
+        rows.append(f" {edge[0]} -> {edge[1]};")
+
+    rows.append("}")
+    return "\n".join(rows)
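The string returned by `to_dot` is plain Graphviz source, so it can be written to a file and rendered with the standard `dot` executable. A small sketch, assuming Graphviz is installed; the tiny hand-built model is only there to produce a graph:

```python
import subprocess
import onnx
import onnx.helper as oh
from onnx_diagnostic.helpers.dot_helper import to_dot

# a hypothetical one-node model, just enough to produce a graph
model = oh.make_model(
    oh.make_graph(
        [oh.make_node("Add", ["x", "y"], ["z"])],
        "tiny",
        [
            oh.make_tensor_value_info("x", onnx.TensorProto.FLOAT, [2]),
            oh.make_tensor_value_info("y", onnx.TensorProto.FLOAT, [2]),
        ],
        [oh.make_tensor_value_info("z", onnx.TensorProto.FLOAT, [2])],
    ),
    opset_imports=[oh.make_opsetid("", 18)],
)

# write the dot source and render it with the Graphviz command line tool
with open("tiny.dot", "w") as f:
    f.write(to_dot(model))
subprocess.run(["dot", "-Tsvg", "tiny.dot", "-o", "tiny.svg"], check=True)
```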