onnx-diagnostic 0.8.8__py3-none-any.whl → 0.8.10__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -3,5 +3,5 @@ Patches, Investigates onnx models.
3
3
  Functions, classes to dig into a model when this one is right, slow, wrong...
4
4
  """
5
5
 
6
- __version__ = "0.8.8"
6
+ __version__ = "0.8.10"
7
7
  __author__ = "Xavier Dupré"
onnx_diagnostic/doc.py CHANGED
@@ -1,5 +1,11 @@
1
- from typing import Optional
1
+ import os
2
+ import subprocess
3
+ import sys
4
+ import tempfile
5
+ from typing import List, Optional, Tuple, Union
2
6
  import numpy as np
7
+ import onnx
8
+ from .helpers.dot_helper import to_dot
3
9
 
4
10
 
5
11
  def get_latest_pypi_version(package_name="onnx-diagnostic") -> str:
@@ -36,6 +42,15 @@ def reset_torch_transformers(gallery_conf, fname):
36
42
  def plot_legend(
37
43
  text: str, text_bottom: str = "", color: str = "green", fontsize: int = 15
38
44
  ) -> "matplotlib.axes.Axes": # noqa: F821
45
+ """
46
+ Plots a graph with only text (for :epkg:`sphinx-gallery`).
47
+
48
+ :param text: legend
49
+ :param text_bottom: text at the bottom
50
+ :param color: color
51
+ :param fontsize: font size
52
+ :return: axis
53
+ """
39
54
  import matplotlib.pyplot as plt
40
55
 
41
56
  fig = plt.figure(figsize=(2, 2))
@@ -66,17 +81,14 @@ def rotate_align(ax, angle=15, align="right"):
66
81
  return ax
67
82
 
68
83
 
69
def save_fig(ax, name: str, **kwargs) -> "matplotlib.axes.Axes":  # noqa: F821
    """
    Saves the figure containing *ax* and returns *ax*.

    ``tight_layout`` is not called here; call it beforehand if needed.

    :param ax: matplotlib axes belonging to the figure to save
    :param name: output filename
    :param kwargs: additional keyword arguments forwarded to ``Figure.savefig``
    :return: ax
    """
    fig = ax.get_figure()
    fig.savefig(name, **kwargs)
    return ax
77
89
 
78
90
 
79
def title(ax: "matplotlib.axes.Axes", title: str) -> "matplotlib.axes.Axes":  # noqa: F821
    """
    Adds a title to axes and returns them.

    :param ax: matplotlib axes
    :param title: title to display
    :return: ax
    """
    ax.set_title(title)
    return ax
@@ -88,7 +100,7 @@ def plot_histogram(
88
100
  bins: int = 30,
89
101
  color: str = "orange",
90
102
  alpha: float = 0.7,
91
- ) -> "plt.axes": # noqa: F821
103
+ ) -> "matplotlib.axis.Axis": # noqa: F821
92
104
  "Computes the distribution for a tensor."
93
105
  if ax is None:
94
106
  import matplotlib.pyplot as plt
@@ -98,3 +110,241 @@ def plot_histogram(
98
110
  ax.hist(tensor, bins=30, color="orange", alpha=0.7)
99
111
  ax.set_yscale("log")
100
112
  return ax
113
+
114
+
115
def _find_in_PATH(prog: str) -> Optional[str]:
    """
    Looks for a file named *prog* in every directory listed in
    the ``PATH`` environment variable.

    :param prog: program to look for (a file name such as ``dot.exe``)
    :return: the first directory of ``PATH`` containing *prog*,
        None if no directory contains it or if ``PATH`` is not set
    """
    path = os.environ.get("PATH")
    if not path:
        # an unset or empty PATH cannot contain the program
        return None
    # os.pathsep is ";" on Windows and ":" elsewhere
    for directory in path.split(os.pathsep):
        if os.path.exists(os.path.join(directory, prog)):
            return directory
    return None
130
+
131
+
132
def _find_graphviz_dot(exc: bool = True) -> Optional[str]:
    """
    Determines the path to the graphviz executable ``dot``.

    On Windows, the function checks, in this order:

    * ``C:\\Program Files (x86)\\Graphviz2.<v>\\bin\\dot.exe`` for
      *v* in 34 to 59 and then the same versions suffixed by ``.1``,
    * ``build/update_modules/Graphviz/bin/dot.exe``,
    * every directory listed in ``PATH``.

    On other platforms, it returns ``"dot"`` and relies on the executable
    being reachable through ``PATH``.

    :param exc: raise an exception if graphviz is not found (True)
        or silently return None (False)
    :return: path to dot, None if not found and *exc* is False
    :raises FileNotFoundError: if graphviz is not found and *exc* is True
    """
    if not sys.platform.startswith("win"):
        # linux, macos: the executable is resolved through PATH
        return "dot"

    versions = list(range(34, 60))
    versions.extend([f"{v}.1" for v in versions])
    candidates = [f"C:\\Program Files (x86)\\Graphviz2.{v}\\bin\\dot.exe" for v in versions]
    candidates.extend(
        os.path.join(ext, "dot.exe") for ext in ["build/update_modules/Graphviz/bin"]
    )
    for graphviz_dot in candidates:
        if os.path.exists(graphviz_dot):
            return graphviz_dot

    p = _find_in_PATH("dot.exe")
    if p is None:
        if exc:
            raise FileNotFoundError(
                f"Unable to find graphviz, look into paths such as {graphviz_dot}."
            )
        return None
    return os.path.join(p, "dot.exe")
165
+
166
+
167
def _run_subprocess(args: List[str], cwd: Optional[str] = None) -> str:
    """
    Runs a command and returns its standard output followed by its standard error.

    :param args: command line as a sequence of strings (no shell is involved)
    :param cwd: working directory, None for the current one
    :return: standard output, a newline, then standard error, both decoded
        with undecodable bytes ignored
    :raises RuntimeError: if standard output contains a known build error
        marker and standard error is not empty
    """
    assert not isinstance(args, str), "args should be a sequence of strings, not a string."

    # subprocess.run drains stdout and stderr together, which avoids the
    # deadlock that happens when the child fills an unread stderr pipe,
    # and keeps the whole output instead of only its last line.
    proc = subprocess.run(
        args,
        cwd=cwd,
        shell=False,
        env=os.environ,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
    )
    output = proc.stdout.decode(errors="ignore")
    error = proc.stderr.decode(errors="ignore")

    # none of the markers contains a newline, so searching the whole output
    # is equivalent to searching it line by line
    markers = ("fatal error", "CMake Error", "gmake: ***", "): error C", ": error: ")
    raise_exception = any(marker in output for marker in markers)
    if error and raise_exception:
        raise RuntimeError(
            "An error was found in the output. The build is stopped."
            f"\n{output}\n---\n{error}"
        )
    return output + "\n" + error
203
+
204
+
205
def _run_graphviz(filename: str, image: str, engine: str = "dot") -> str:
    """
    Calls a :epkg:`Graphviz` engine to render a graph file into an image.

    The output format is deduced from the extension of *image*.
    Any existing file *image* is removed before the engine runs.

    :param filename: file containing the graph definition (dot language)
    :param image: output image, its extension selects the output format
    :param engine: *dot* or *neato*
    :return: output of graphviz (stdout followed by stderr)
    """
    supported_extensions = {
        ".png",
        ".bmp",
        ".fig",
        ".gif",
        ".ico",
        ".jpg",
        ".jpeg",
        ".pdf",
        ".ps",
        ".svg",
        ".vrml",
        ".tif",
        ".tiff",
        ".wbmp",
    }
    suffix = os.path.splitext(image)[-1]
    assert suffix in supported_extensions, f"Unexpected extension {suffix!r} for {image!r}."

    if sys.platform.startswith("win"):
        # engines live next to dot.exe in the graphviz bin folder
        executable = os.path.join(os.path.dirname(_find_graphviz_dot()), engine)
    else:
        executable = engine

    if os.path.exists(image):
        os.remove(image)
    command = [executable, f"-T{suffix[1:]}", filename, "-o", image]
    result = _run_subprocess(command)
    assert os.path.exists(image), (
        f"Unable to find {image!r}, command line is "
        f"{' '.join(command)!r}, Graphviz failed due to\n{result}"
    )
    return result
247
+
248
+
249
def draw_graph_graphviz(
    dot: Union[str, onnx.ModelProto], image: Optional[str], engine: str = "dot"
) -> str:
    """
    Draws a graph using :epkg:`Graphviz`.

    :param dot: dot graph (a string containing ``{``), the path to an
        onnx file (ending with ``.onnx``) or a ModelProto
    :param image: output image; if None, nothing is drawn and the
        dot text is returned
    :param engine: *dot* or *neato*
    :return: :epkg:`Graphviz` output or
        the dot text if *image* is None

    If *image* is not None, the dot text is stored in a temporary file
    which is removed once graphviz has run, even if it fails.
    """
    if isinstance(dot, onnx.ModelProto):
        sdot = to_dot(dot)
    elif "{" not in dot:
        # no brace: the string is expected to be the path to an onnx file
        assert dot.endswith(".onnx"), f"Unexpected file extension for {dot!r}"
        sdot = to_dot(onnx.load(dot))
    else:
        sdot = dot
    assert "{" in sdot, f"This string is not a dot string\n{sdot}"

    if image is None:
        return sdot

    # delete=False so that graphviz can reopen the file by name (Windows)
    with tempfile.NamedTemporaryFile(delete=False) as fp:
        fp.write(sdot.encode("utf-8"))
    filename = fp.name
    try:
        assert os.path.exists(
            filename
        ), f"File {filename!r} cannot be created to store the graph."
        out = _run_graphviz(filename, image, engine=engine)
        assert os.path.exists(
            image
        ), f"Graphviz failed with no reason, {image!r} not found, output is {out}."
    finally:
        if os.path.exists(filename):
            os.remove(filename)
    return out
287
+
288
+
289
def plot_dot(
    dot: Union[str, onnx.ModelProto],
    ax: Optional["matplotlib.axes.Axes"] = None,  # noqa: F821
    engine: str = "dot",
    figsize: Optional[Tuple[int, int]] = None,
) -> "matplotlib.axes.Axes":  # noqa: F821
    """
    Draws a dot graph into a matplotlib graph.

    :param dot: dot graph, path to an onnx file or ModelProto
    :param ax: axes to draw into; a new figure is created if None
    :param engine: *dot* or *neato*
    :param figsize: figure size, only used if *ax* is None
    :return: the axes containing the rendered graph

    .. plot::

        import matplotlib.pyplot as plt
        import onnx.parser
        from onnx_diagnostic.doc import plot_dot

        model = onnx.parser.parse_model(
            '''
            <ir_version: 8, opset_import: [ "": 18]>
            agraph (float[N] x) => (float[N] z) {
                two = Constant <value_float=2.0> ()
                four = Add(two, two)
                z = Mul(four, four)
            }
            ''')

        ax = plot_dot(model)
        ax.set_title("Dummy graph")
        plt.show()
    """
    if ax is None:
        import matplotlib.pyplot as plt

        _, ax = plt.subplots(1, 1, figsize=figsize)
        # axes created here: hide ticks and frame around the picture
        clean = True
    else:
        clean = False

    from PIL import Image

    # only reserve a file name; the handle is closed so graphviz can write to it
    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as fp:
        pass
    try:
        draw_graph_graphviz(dot, fp.name, engine=engine)
        # closing the image releases the file, required on Windows
        # before it can be removed; np.asarray copies the pixels
        with Image.open(fp.name) as pil_image:
            img = np.asarray(pil_image)
    finally:
        if os.path.exists(fp.name):
            os.remove(fp.name)

    ax.imshow(img)

    if clean:
        ax.set_xticks([])
        ax.set_yticks([])
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
        ax.set_axis_off()
        ax.get_figure().tight_layout()
    return ax