onnx-diagnostic 0.8.8__py3-none-any.whl → 0.8.9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnx_diagnostic/__init__.py +1 -1
- onnx_diagnostic/doc.py +258 -8
- onnx_diagnostic/export/api.py +478 -17
- onnx_diagnostic/export/dynamic_shapes.py +21 -6
- onnx_diagnostic/export/shape_helper.py +0 -8
- onnx_diagnostic/helpers/cache_helper.py +98 -13
- onnx_diagnostic/helpers/helper.py +6 -5
- onnx_diagnostic/helpers/onnx_helper.py +7 -0
- onnx_diagnostic/helpers/rt_helper.py +14 -1
- onnx_diagnostic/helpers/torch_helper.py +22 -9
- onnx_diagnostic/tasks/image_text_to_text.py +4 -1
- onnx_diagnostic/tasks/text_generation.py +17 -17
- onnx_diagnostic/torch_export_patches/eval/__init__.py +1 -1
- onnx_diagnostic/torch_export_patches/onnx_export_serialization.py +62 -38
- onnx_diagnostic/torch_export_patches/patches/patch_torch.py +12 -9
- onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py +42 -30
- {onnx_diagnostic-0.8.8.dist-info → onnx_diagnostic-0.8.9.dist-info}/METADATA +2 -2
- {onnx_diagnostic-0.8.8.dist-info → onnx_diagnostic-0.8.9.dist-info}/RECORD +21 -21
- {onnx_diagnostic-0.8.8.dist-info → onnx_diagnostic-0.8.9.dist-info}/WHEEL +0 -0
- {onnx_diagnostic-0.8.8.dist-info → onnx_diagnostic-0.8.9.dist-info}/licenses/LICENSE.txt +0 -0
- {onnx_diagnostic-0.8.8.dist-info → onnx_diagnostic-0.8.9.dist-info}/top_level.txt +0 -0
onnx_diagnostic/__init__.py
CHANGED
onnx_diagnostic/doc.py
CHANGED
|
@@ -1,5 +1,11 @@
|
|
|
1
|
-
|
|
1
|
+
import os
|
|
2
|
+
import subprocess
|
|
3
|
+
import sys
|
|
4
|
+
import tempfile
|
|
5
|
+
from typing import List, Optional, Tuple, Union
|
|
2
6
|
import numpy as np
|
|
7
|
+
import onnx
|
|
8
|
+
from .helpers.dot_helper import to_dot
|
|
3
9
|
|
|
4
10
|
|
|
5
11
|
def get_latest_pypi_version(package_name="onnx-diagnostic") -> str:
|
|
@@ -36,6 +42,15 @@ def reset_torch_transformers(gallery_conf, fname):
|
|
|
36
42
|
def plot_legend(
|
|
37
43
|
text: str, text_bottom: str = "", color: str = "green", fontsize: int = 15
|
|
38
44
|
) -> "matplotlib.axes.Axes": # noqa: F821
|
|
45
|
+
"""
|
|
46
|
+
Plots a graph with only text (for :epkg:`sphinx-gallery`).
|
|
47
|
+
|
|
48
|
+
:param text: legend
|
|
49
|
+
:param text_bottom: text at the bottom
|
|
50
|
+
:param color: color
|
|
51
|
+
:param fontsize: font size
|
|
52
|
+
:return: axis
|
|
53
|
+
"""
|
|
39
54
|
import matplotlib.pyplot as plt
|
|
40
55
|
|
|
41
56
|
fig = plt.figure(figsize=(2, 2))
|
|
@@ -66,17 +81,14 @@ def rotate_align(ax, angle=15, align="right"):
|
|
|
66
81
|
return ax
|
|
67
82
|
|
|
68
83
|
|
|
69
|
-
def save_fig(ax, name: str, **kwargs) -> "matplotlib.axes.Axes":  # noqa: F821
    """
    Saves the figure owning *ax* into a file and returns *ax*.

    The layout is not modified by this function: call ``tight_layout``
    beforehand, or pass a savefig option such as ``bbox_inches="tight"``,
    if the layout needs to be adjusted.

    :param ax: matplotlib axes whose figure is saved
    :param name: destination filename
    :param kwargs: additional arguments forwarded to ``Figure.savefig``
    :return: *ax*, unchanged
    """
    fig = ax.get_figure()
    fig.savefig(name, **kwargs)
    return ax
|
|
77
89
|
|
|
78
90
|
|
|
79
|
-
def title(ax: "plt.axes", title: str) -> "matplotlib.axis.Axis":  # noqa: F821
    """
    Sets the title of *ax* and hands the same object back.

    :param ax: axes to decorate
    :param title: text displayed as the title
    :return: *ax*, so that calls can be chained
    """
    target = ax
    target.set_title(title)
    return target
|
|
@@ -88,7 +100,7 @@ def plot_histogram(
|
|
|
88
100
|
bins: int = 30,
|
|
89
101
|
color: str = "orange",
|
|
90
102
|
alpha: float = 0.7,
|
|
91
|
-
) -> "
|
|
103
|
+
) -> "matplotlib.axis.Axis": # noqa: F821
|
|
92
104
|
"Computes the distribution for a tensor."
|
|
93
105
|
if ax is None:
|
|
94
106
|
import matplotlib.pyplot as plt
|
|
@@ -98,3 +110,241 @@ def plot_histogram(
|
|
|
98
110
|
ax.hist(tensor, bins=30, color="orange", alpha=0.7)
|
|
99
111
|
ax.set_yscale("log")
|
|
100
112
|
return ax
|
|
113
|
+
|
|
114
|
+
|
|
115
|
+
def _find_in_PATH(prog: str) -> Optional[str]:
    """
    Looks for a file named *prog* in every folder listed in the
    ``PATH`` environment variable.

    :param prog: program to look for (file name such as ``dot.exe``)
    :return: the first folder containing *prog*, None if it is not found
        or if ``PATH`` is undefined or empty
    """
    path = os.environ.get("PATH")
    if not path:
        return None
    # os.pathsep is ';' on Windows and ':' elsewhere.
    for folder in path.split(os.pathsep):
        # isfile rather than exists: a directory named like the program
        # must not be reported as a match
        if os.path.isfile(os.path.join(folder, prog)):
            return folder
    return None
|
|
130
|
+
|
|
131
|
+
|
|
132
|
+
def _find_graphviz_dot(exc: bool = True) -> Optional[str]:
    """
    Determines the path to the graphviz executable ``dot``.

    On Windows, the function successively checks:

    * ``C:\\Program Files (x86)\\Graphviz2.<v>\\bin\\dot.exe``
      for ``<v>`` in 34 to 59 and 34.1 to 59.1,
    * ``build/update_modules/Graphviz/bin/dot.exe`` (relative path),
    * every folder listed in ``PATH``.

    On other platforms, it returns ``"dot"`` without checking anything,
    assuming the program can be found through ``PATH``.

    :param exc: raise an exception if graphviz is not found (True)
        or silently return None (False)
    :return: path to dot, None if not found and *exc* is False
    :raises FileNotFoundError: if graphviz is not found and *exc* is True
    """
    if not sys.platform.startswith("win"):
        # linux, macos
        return "dot"

    versions = list(range(34, 60))
    versions.extend([f"{v}.1" for v in versions])
    candidates = [f"C:\\Program Files (x86)\\Graphviz2.{v}\\bin\\dot.exe" for v in versions]
    candidates.append(os.path.join("build/update_modules/Graphviz/bin", "dot.exe"))
    for graphviz_dot in candidates:
        if os.path.exists(graphviz_dot):
            return graphviz_dot

    p = _find_in_PATH("dot.exe")
    if p is not None:
        return os.path.join(p, "dot.exe")
    if exc:
        raise FileNotFoundError(
            f"Unable to find graphviz, look into paths such as {candidates[-1]}."
        )
    return None
|
|
165
|
+
|
|
166
|
+
|
|
167
|
+
def _run_subprocess(args: List[str], cwd: Optional[str] = None) -> str:
    """
    Runs a command (without a shell) and returns what it printed.

    :param args: command line as a sequence of strings
    :param cwd: working directory, None for the current one
    :return: standard output, a newline, then standard error
    :raises RuntimeError: if the standard output contains a known error
        marker (``fatal error``, ``CMake Error``, ``gmake: ***``,
        ``): error C``, ``: error: ``) and the standard error is not empty
    """
    assert not isinstance(args, str), "args should be a sequence of strings, not a string."

    # subprocess.run drains stdout and stderr together (communicate):
    # reading stdout until EOF before touching stderr can deadlock once
    # the child fills the stderr pipe buffer.
    proc = subprocess.run(
        args,
        cwd=cwd,
        shell=False,
        env=os.environ,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
    )
    output = proc.stdout.decode(errors="ignore")
    error = proc.stderr.decode(errors="ignore")

    markers = ("fatal error", "CMake Error", "gmake: ***", "): error C", ": error: ")
    raise_exception = any(marker in output for marker in markers)
    if error and raise_exception:
        raise RuntimeError(
            "An error was found in the output. The build is stopped."
            f"\n{output}\n---\n{error}"
        )
    return output + "\n" + error
|
|
203
|
+
|
|
204
|
+
|
|
205
|
+
def _run_graphviz(filename: str, image: str, engine: str = "dot") -> str:
    """
    Runs :epkg:`Graphviz` to convert a dot file into an image.

    :param filename: filename which contains the graph definition
    :param image: output image, its extension (case insensitive)
        defines the output format
    :param engine: *dot* or *neato*
    :return: output of graphviz
    :raises AssertionError: if the extension is not supported or
        if the image was not produced
    """
    ext = os.path.splitext(image)[-1]
    fmt = ext[1:].lower()
    assert fmt in {
        "png",
        "bmp",
        "fig",
        "gif",
        "ico",
        "jpg",
        "jpeg",
        "pdf",
        "ps",
        "svg",
        "vrml",
        "tif",
        "tiff",
        "wbmp",
    }, f"Unexpected extension {ext!r} for {image!r}."
    if sys.platform.startswith("win"):
        # the engine executable is expected next to dot.exe
        exe = os.path.join(os.path.dirname(_find_graphviz_dot()), engine)
    else:
        exe = engine
    if os.path.exists(image):
        # removes any stale image so that the existence check below is meaningful
        os.remove(image)
    cmd = [exe, f"-T{fmt}", filename, "-o", image]
    output = _run_subprocess(cmd)
    assert os.path.exists(image), (
        f"Unable to find {image!r}, command line is "
        f"{' '.join(cmd)!r}, Graphviz failed due to\n{output}"
    )
    return output
|
|
247
|
+
|
|
248
|
+
|
|
249
|
+
def draw_graph_graphviz(
    dot: Union[str, onnx.ModelProto], image: Optional[str], engine: str = "dot"
) -> str:
    """
    Draws a graph using :epkg:`Graphviz`.

    :param dot: dot graph, ModelProto or the name of an ``.onnx`` file
    :param image: output image; if None, no image is produced and the
        function only returns the dot text
    :param engine: *dot* or *neato*
    :return: :epkg:`Graphviz` output or
        the dot text if *image* is None

    If *image* is not None, the dot text is written into a temporary file
    which is removed once graphviz has run, even if it failed.
    """
    if isinstance(dot, onnx.ModelProto):
        sdot = to_dot(dot)
    elif "{" not in dot:
        # not a dot string, it must be the name of an onnx file
        assert dot.endswith(".onnx"), f"Unexpected file extension for {dot!r}"
        sdot = to_dot(onnx.load(dot))
    else:
        sdot = dot
    assert "{" in sdot, f"This string is not a dot string\n{sdot}"
    if image is None:
        return sdot

    with tempfile.NamedTemporaryFile(suffix=".dot", delete=False) as fp:
        fp.write(sdot.encode("utf-8"))
        filename = fp.name
    assert os.path.exists(
        filename
    ), f"File {filename!r} cannot be created to store the graph."
    try:
        out = _run_graphviz(filename, image, engine=engine)
    finally:
        os.remove(filename)
    assert os.path.exists(
        image
    ), f"Graphviz failed with no reason, {image!r} not found, output is {out}."
    return out
|
|
287
|
+
|
|
288
|
+
|
|
289
|
+
def plot_dot(
    dot: Union[str, onnx.ModelProto],
    ax: Optional["matplotlib.axes.Axes"] = None,  # noqa: F821
    engine: str = "dot",
    figsize: Optional[Tuple[int, int]] = None,
) -> "matplotlib.axes.Axes":  # noqa: F821
    """
    Draws a dot graph into a matplotlib graph.

    The graph is rendered by :epkg:`Graphviz` into a temporary png file,
    which is then displayed with ``ax.imshow`` and removed.

    :param dot: dot graph, ModelProto or the name of an ``.onnx`` file
    :param ax: axes to draw into, a new figure is created if None
    :param engine: *dot* or *neato*
    :param figsize: figure size, only used if *ax* is None
    :return: the axes the graph was drawn into

    .. plot::

        import matplotlib.pyplot as plt
        import onnx.parser
        from onnx_diagnostic.doc import plot_dot

        model = onnx.parser.parse_model(
            '''
            <ir_version: 8, opset_import: [ "": 18]>
            agraph (float[N] x) => (float[N] z) {
                two = Constant <value_float=2.0> ()
                four = Add(two, two)
                z = Mul(four, four)
            }
            ''')

        ax = plot_dot(model)
        ax.set_title("Dummy graph")
        plt.show()
    """
    if ax is None:
        import matplotlib.pyplot as plt

        _, ax = plt.subplots(1, 1, figsize=figsize)
        clean = True
    else:
        clean = False

    from PIL import Image

    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as fp:
        png = fp.name
    try:
        draw_graph_graphviz(dot, png, engine=engine)
        # closes the image file before os.remove (an open file cannot be
        # removed on Windows)
        with Image.open(png) as im:
            img = np.asarray(im)
    finally:
        # removes the temporary image even if graphviz failed
        if os.path.exists(png):
            os.remove(png)

    ax.imshow(img)

    if clean:
        # a new figure was created: hides everything but the image
        ax.set_xticks([])
        ax.set_yticks([])
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
        ax.set_axis_off()
        ax.get_figure().tight_layout()
    return ax
|