onnx-diagnostic 0.5.0__py3-none-any.whl → 0.6.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,7 +1,7 @@
1
1
  """
2
- Investigates onnx models.
2
+ Patches, Investigates onnx models.
3
3
  Functions, classes to dig into a model when this one is right, slow, wrong...
4
4
  """
5
5
 
6
- __version__ = "0.5.0"
6
+ __version__ = "0.6.0"
7
7
  __author__ = "Xavier Dupré"
@@ -126,9 +126,23 @@ def get_parser_print() -> ArgumentParser:
126
126
  """
127
127
  ),
128
128
  epilog="To show a model.",
129
+ formatter_class=RawTextHelpFormatter,
129
130
  )
130
131
  parser.add_argument(
131
- "fmt", choices=["pretty", "raw"], help="Format to use.", default="pretty"
132
+ "fmt",
133
+ choices=["pretty", "raw", "text", "printer"],
134
+ default="pretty",
135
+ help=textwrap.dedent(
136
+ """
137
+ Prints out a model on the standard output.
138
+ raw - just prints the model with print(...)
139
+ printer - onnx.printer.to_text(...)
140
+ pretty - an improved rendering
141
+ text - uses GraphRendering
142
+ """.strip(
143
+ "\n"
144
+ )
145
+ ),
132
146
  )
133
147
  parser.add_argument("input", type=str, help="onnx model to load")
134
148
  return parser
@@ -144,6 +158,12 @@ def _cmd_print(argv: List[Any]):
144
158
  from .helpers.onnx_helper import pretty_onnx
145
159
 
146
160
  print(pretty_onnx(onx))
161
+ elif args.fmt == "printer":
162
+ print(onnx.printer.to_text(onx))
163
+ elif args.fmt == "text":
164
+ from .helpers.graph_helper import GraphRendering
165
+
166
+ print(GraphRendering(onx).text_rendering())
147
167
  else:
148
168
  raise ValueError(f"Unexpected value fmt={args.fmt!r}")
149
169
 
@@ -379,8 +379,9 @@ class CoupleInputsDynamicShapes:
379
379
  return torch.utils._pytree.tree_unflatten(res, spec)
380
380
 
381
381
  class ChangeDimensionProcessor:
382
- def __init__(self, desired_values):
382
+ def __init__(self, desired_values, only_desired):
383
383
  self.mapping = desired_values or {}
384
+ self.only_desired = only_desired
384
385
 
385
386
  def _build_new_shape(
386
387
  self, shape: Tuple[int, ...], ds: Dict[int, Any]
@@ -397,14 +398,16 @@ class CoupleInputsDynamicShapes:
397
398
  torch.export.dynamic_shapes._Dim,
398
399
  ),
399
400
  ):
400
- d = str(ds[i])
401
+ d = ds[i].__name__
401
402
  elif not isinstance(ds[i], int):
402
403
  raise NotImplementedError(f"Unable to handle type {ds[i]} in {ds}")
403
404
  if d in self.mapping:
404
405
  new_dim = self.mapping[d]
405
- else:
406
+ elif not self.only_desired:
406
407
  new_dim = shape[i] + 1
407
408
  self.mapping[d] = new_dim
409
+ else:
410
+ new_dim = shape[i]
408
411
  new_shape[i] = new_dim
409
412
  return tuple(new_shape)
410
413
 
@@ -447,7 +450,10 @@ class CoupleInputsDynamicShapes:
447
450
  return self._build_new_tensor(inputs, new_shape)
448
451
 
449
452
  def change_dynamic_dimensions(
450
- self, desired_values: Optional[Dict[str, int]] = None, args_kwargs: bool = False
453
+ self,
454
+ desired_values: Optional[Dict[str, int]] = None,
455
+ args_kwargs: bool = False,
456
+ only_desired: bool = False,
451
457
  ):
452
458
  """
453
459
  A model exported with dynamic shapes is not necessarily dynamic
@@ -460,6 +466,8 @@ class CoupleInputsDynamicShapes:
460
466
 
461
467
  :param desired_values: to fixed named dimension to have the desired value
462
468
  :param args_kwargs: return both args, kwargs even if empty
469
+ :param only_desired: if True, only change the dimension specified in
470
+ ``desired_values``
463
471
  :return: new inputs
464
472
 
465
473
  Example:
@@ -483,7 +491,8 @@ class CoupleInputsDynamicShapes:
483
491
  print("-after:", string_type(new_kwargs, with_shape=True))
484
492
  """
485
493
  return self._generic_walker(
486
- self.ChangeDimensionProcessor(desired_values), args_kwargs=args_kwargs
494
+ self.ChangeDimensionProcessor(desired_values, only_desired=only_desired),
495
+ args_kwargs=args_kwargs,
487
496
  )
488
497
 
489
498
 
@@ -776,6 +776,13 @@ class ExtTestCase(unittest.TestCase):
776
776
  os.mkdir(folder)
777
777
  return os.path.join(folder, name)
778
778
 
779
def get_dump_folder(self, folder: str) -> str:
    """
    Returns the folder ``dump_test/<folder>`` and creates it if it is missing.

    :param folder: sub-folder name, joined under ``dump_test``
        (a path relative to the current working directory)
    :return: ``os.path.join("dump_test", folder)``
    """
    folder = os.path.join("dump_test", folder)
    # exist_ok=True removes the race between checking for the folder and
    # creating it when several tests or processes request the same folder.
    os.makedirs(folder, exist_ok=True)
    return folder
785
+
779
786
  def dump_onnx(
780
787
  self,
781
788
  name: str,
@@ -813,6 +820,11 @@ class ExtTestCase(unittest.TestCase):
813
820
  msg or f"Unable to find the list of strings {tofind!r} in\n--\n{text}"
814
821
  )
815
822
 
823
def assertHasAttr(self, obj: Any, name: str, msg=None):
    """
    Checks that ``obj`` has an attribute ``name``.

    :param obj: any object
    :param name: attribute name looked up with :func:`hasattr`
    :param msg: optional message replacing the default one
    :raises AssertionError: if the attribute is missing
    """
    # An explicit raise (instead of an ``assert`` statement) keeps the check
    # active when Python runs with -O.
    if not hasattr(obj, name):
        raise AssertionError(
            msg or f"Unable to find attribute {name!r} in object type {type(obj)}"
        )
827
+
816
828
  def assertSetContained(self, set1, set2):
817
829
  "Checks that ``set1`` is contained in ``set2``."
818
830
  set1 = set(set1)
@@ -0,0 +1,386 @@
1
+ import pprint
2
+ from typing import Dict, List, Optional, Sequence, Set, Tuple, Union
3
+ import onnx
4
+ import onnx.helper as oh
5
+
6
+
7
class GraphRendering:
    """
    Helper to render a graph as text.

    :param proto: model, graph or function to render
    """

    def __init__(self, proto: Union[onnx.FunctionProto, onnx.GraphProto, onnx.ModelProto]):
        self.proto = proto

    def __repr__(self) -> str:
        "usual"
        return f"{self.__class__.__name__}(<{self.proto.__class__.__name__}>)"

    @classmethod
    def computation_order(
        cls,
        nodes: Sequence[onnx.NodeProto],
        existing: Optional[List[str]] = None,
        start: int = 1,
    ) -> List[int]:
        """
        Returns the soonest step every node can be computed at.
        A node numbered ``k`` only needs results produced by nodes numbered
        strictly below ``k`` or names listed in ``existing``.

        :param nodes: list of nodes, every input must be produced by a previous
            node or appear in ``existing``
        :param existing: names available before any computation starts
        :param start: number given to the first step
        :return: one number per node (a node without inputs gets ``start``)
        """
        assert not ({"If", "Scan", "Loop", "SequenceMap"} & set(n.op_type for n in nodes)), (
            f"This algorithme is not yet implemented if the sequence contains "
            f"a control flow, types={sorted(set(n.op_type for n in nodes))}"
        )
        # number[name] = step at which the result ``name`` becomes available
        number = {e: start - 1 for e in (existing or [])}  # noqa: C420
        results = [start for _ in nodes]
        for i_node, node in enumerate(nodes):
            assert all(i in number for i in node.input), (
                f"Missing input in node {i_node} type={node.op_type}: "
                f"{[i for i in node.input if i not in number]}"
            )
            if node.input:
                # one step after the latest of its inputs
                mx = max(number[i] for i in node.input) + 1
                results[i_node] = mx
            else:
                # A constant: its outputs are considered available at the
                # latest step reached so far.
                mx = max(number.values()) if number else 0
            for i in node.output:
                number[i] = mx
        return results

    @classmethod
    def graph_positions(
        cls,
        nodes: Sequence[onnx.NodeProto],
        order: List[int],
        existing: Optional[List[str]] = None,
    ) -> List[Tuple[int, int]]:
        """
        Returns positions on a plan for every node in a graph.
        The function tries to limit the number of lines crossing each other.
        It goes forward row by row: the nodes of a row are sorted by the column
        of their rightmost input (ties keep the node order).
        It could be improved with more iterations.

        :param nodes: list of nodes
        :param order: computation order returned by
            :meth:`onnx_diagnostic.helpers.graph_helper.GraphRendering.computation_order`
        :param existing: existing names (not used by the current implementation)
        :return: list of tuple (row, column)
        """
        # initialization
        min_row = min(order)
        n_rows = max(order) + 1
        # names[result] = column of the node producing it
        names: Dict[str, int] = {}

        positions = [(min_row, i) for i in range(len(order))]
        for row in range(min_row, n_rows):
            indices = [i for i, o in enumerate(order) if o == row]
            assert indices, f"indices cannot be empty for row={row}, order={order}"
            ns = [nodes[i] for i in indices]
            mx = [(max(names.get(i, 0) for i in n.input) if n.input else 0) for n in ns]
            mix = [(m, i) for i, m in enumerate(mx)]
            mix.sort()
            for c, (_m, i) in enumerate(mix):
                positions[indices[i]] = (row, c)
                n = nodes[indices[i]]
                for o in n.output:
                    names[o] = c

        return positions

    @classmethod
    def text_positions(
        cls, nodes: Sequence[onnx.NodeProto], positions: List[Tuple[int, int]]
    ) -> List[Tuple[int, int]]:
        """
        Returns positions for the nodes assuming it is rendered into text.
        Graph row ``r`` becomes text row ``4r``, graph column ``c`` on row ``r``
        becomes virtual column ``2c + r``, and every virtual column is made wide
        enough for the longest ``op_type`` placed in it (``len + 5``, at least 3).

        :param nodes: list of nodes
        :param positions: positions returned by
            :meth:`onnx_diagnostic.helpers.graph_helper.GraphRendering.graph_positions`
        :return: text positions (row, column of the center of the node name)
        """
        # Adding the row means two consecutive rows never share a virtual
        # column (different parity).
        new_positions = [(row * 4, col * 2 + row) for row, col in positions]
        column_size = {col: 3 for _, col in new_positions}
        for i, (_row, col) in enumerate(new_positions):
            size = len(nodes[i].op_type) + 5
            column_size[col] = max(column_size[col], size)
            assert column_size[col] < 200, (
                f"column_size[{col}]={column_size[col]}, this is quite big, i={i}, "
                f"nodes[i].op_type={nodes[i].op_type}"
            )

        # cumulated: virtual columns are laid side by side, results[col] is
        # the center of virtual column col.
        sort = sorted(column_size.items())
        cumul = dict(sort[:1])
        results = {sort[0][0]: sort[0][1] // 2}
        for col, size in sort[1:]:
            c = max(cumul.values())
            cumul[col] = c + size
            results[col] = c + size // 2
        return [(row, results[col]) for row, col in new_positions]

    @property
    def nodes(self) -> List[onnx.NodeProto]:
        "Returns the list of nodes"
        return (
            self.proto.graph.node
            if isinstance(self.proto, onnx.ModelProto)
            else self.proto.node
        )

    @property
    def start_names(self) -> List[str]:
        "Returns the list of known names, inputs and initializers"
        graph = self.proto.graph if isinstance(self.proto, onnx.ModelProto) else self.proto
        input_names = (
            list(graph.input)
            if isinstance(graph, onnx.FunctionProto)
            else [i.name for i in graph.input]
        )
        init_names = (
            []
            if isinstance(graph, onnx.FunctionProto)
            else [
                *[i.name for i in graph.initializer],
                *[i.name for i in graph.sparse_initializer],
            ]
        )
        return [*input_names, *init_names]

    @property
    def input_names(self) -> List[str]:
        "Returns the list of input names."
        return (
            self.proto.input
            if isinstance(self.proto, onnx.FunctionProto)
            else [
                i.name
                for i in (
                    self.proto if isinstance(self.proto, onnx.GraphProto) else self.proto.graph
                ).input
            ]
        )

    @property
    def output_names(self) -> List[str]:
        "Returns the list of output names."
        return (
            self.proto.output
            if isinstance(self.proto, onnx.FunctionProto)
            else [
                i.name
                for i in (
                    self.proto if isinstance(self.proto, onnx.GraphProto) else self.proto.graph
                ).output
            ]
        )

    @classmethod
    def build_node_edges(cls, nodes: Sequence[onnx.NodeProto]) -> Set[Tuple[int, int]]:
        """
        Builds the set of edges between nodes.

        :param nodes: list of nodes
        :return: set of ``(producer index, consumer index)``, names produced
            by no node in ``nodes`` create no edge
        """
        produced = {}
        for i, node in enumerate(nodes):
            for o in node.output:
                produced[o] = i
        edges = set()
        for i, node in enumerate(nodes):
            for name in node.input:
                if name in produced:
                    edge = produced[name], i
                    edges.add(edge)
        return edges

    # (symbol already in the cell, symbol being drawn) -> merged symbol
    ADD_RULES = {
        ("┴", "┘"): "┴",
        ("┴", "└"): "┴",
        ("┬", "┐"): "┬",
        ("┬", "┌"): "┬",
        ("-", "└"): "┴",
        ("-", "|"): "┼",
        ("-", "┐"): "┬",
        ("┐", "-"): "┬",
        ("┘", "-"): "┴",
        ("┴", "-"): "┴",
        ("-", "┘"): "┴",
        ("┌", "-"): "┬",
        ("┬", "-"): "┬",
        ("-", "┌"): "┬",
        ("|", "-"): "┼",
        ("└", "-"): "┴",
        ("|", "└"): "├",
        ("|", "┘"): "┤",
        ("┐", "|"): "┤",
        ("┬", "|"): "┼",
        ("|", "┐"): "┤",
        ("|", "┌"): "├",
        ("├", "-"): "┼",
        ("└", "|"): "├",
        ("┤", "┐"): "┤",
        ("┤", "|"): "┤",
        ("├", "|"): "├",
        ("┴", "┌"): "┼",
        ("┐", "┌"): "┬",
        ("┌", "┐"): "┬",
        ("┌", "|"): "┼",
        ("┴", "┐"): "┼",
        ("┐", "└"): "┼",
        ("┬", "┘"): "┼",
        ("├", "└"): "├",
        ("┤", "┌"): "┼",
        ("┘", "|"): "┤",
        ("┴", "|"): "┼",
        ("┤", "-"): "┼",
        ("┘", "└"): "┴",
    }

    @classmethod
    def text_grid(cls, grid: List[List[str]], position: Tuple[int, int], text: str):
        """
        Prints inplace a text in a grid. The text is centered.

        :param grid: grid, modified in place
        :param position: (row, column of the center of the text)
        :param text: text to print
        """
        row, col = position
        begin = col - len(text) // 2
        grid[row][begin : begin + len(text)] = list(text)

    @classmethod
    def text_edge(
        cls,
        grid: List[List[str]],
        p1: Tuple[int, int],
        p2: Tuple[int, int],
        mode: str = "square",
    ):
        """
        Draws in place an edge from ``p1`` down to ``p2`` in a grid with
        box-drawing characters, merging every symbol with what the cell
        already contains (see ``ADD_RULES``).

        The edge leaves ``p1`` downwards and turns horizontally two rows below
        ``p1``. If ``p2`` is further down, it goes down along the column halfway
        between both points and turns horizontally again two rows above ``p2``.
        It always reaches ``p2`` from above. When both points share the same
        column, the edge is a straight vertical line.

        :param grid: grid, modified in place
        :param p1: position (row, column) of the first node
        :param p2: position (row, column) of the second node, below ``p1``
        :param mode: ``'square'`` is the only supported value
        """
        assert mode == "square", f"mode={mode!r} not supported"
        assert p1[0] < p2[0], f"Unexpected edge p1={p1}, p2={p2}"
        assert p1[0] + 2 <= p2[0] - 2, f"Unexpected edge p1={p1}, p2={p2}"
        # removes this when the algorithm is ready
        assert 0 <= p1[0] < len(grid) - 3, f"p1={p1}, grid:{len(grid)},{len(grid[0])}"
        assert 2 <= p2[0] < len(grid) - 1, f"p2={p2}, grid:{len(grid)},{len(grid[0])}"
        assert (
            0 <= p1[1] < min(len(g) for g in grid)
        ), f"p1={p1}, sizes={[len(g) for g in grid]}"
        assert (
            0 <= p2[1] < min(len(g) for g in grid)
        ), f"p2={p2}, sizes={[len(g) for g in grid]}"

        def add(s1, s2):
            # merges the symbol s2 drawn over the existing symbol s1
            assert s2 != " ", f"s1={s1!r}, s2={s2!r}"
            if s1 == " " or s1 == s2:
                return s2
            if s1 == "┼" or s2 == "┼":
                return "┼"
            if (s1, s2) in cls.ADD_RULES:
                return cls.ADD_RULES[s1, s2]
            raise NotImplementedError(f"Unable to add: ({s1!r},{s2!r}): '',")

        def place(grid, x, y, symbol):
            grid[x][y] = add(grid[x][y], symbol)

        if p1[1] == p2[1]:
            # Same column: the corner-based drawing below would try to merge
            # two opposite corners in one cell, a straight line is enough.
            for i in range(p1[0] + 1, p2[0]):
                place(grid, i, p1[1], "|")
            return

        place(grid, p1[0] + 1, p1[1], "|")
        place(grid, p1[0] + 2, p1[1], "└" if p1[1] < p2[1] else "┘")

        if p1[0] + 2 == p2[0] - 2:
            # Both horizontal turns happen on the same row: one segment.
            a, b = (p1[1] + 1, p2[1] - 1) if p1[1] < p2[1] else (p2[1] + 1, p1[1] - 1)
            for i in range(a, b + 1):
                place(grid, p1[0] + 2, i, "-")
        else:
            middle = (p1[1] + p2[1]) // 2
            # horizontal segment below p1, from p1 to the middle column
            a, b = (p1[1] + 1, middle - 1) if p1[1] < middle else (middle + 1, p1[1] - 1)
            for i in range(a, b + 1):
                place(grid, p1[0] + 2, i, "-")
            # horizontal segment above p2, from the middle column to p2
            a, b = (middle + 1, p2[1] - 1) if middle < p2[1] else (p2[1] + 1, middle - 1)
            for i in range(a, b + 1):
                place(grid, p2[0] - 2, i, "-")

            place(grid, p1[0] + 2, middle, "┐" if p1[1] < p2[1] else "┌")
            place(grid, p2[0] - 2, middle, "└" if p1[1] < p2[1] else "┘")

            # vertical segment along the middle column
            for i in range(p1[0] + 2 + 1, p2[0] - 2):
                place(grid, i, middle, "|")

        place(grid, p2[0] - 2, p2[1], "┐" if p1[1] < p2[1] else "┌")
        place(grid, p2[0] - 1, p2[1], "|")

    def text_rendering(self, prefix="") -> str:
        """
        Renders a model in text.

        :param prefix: string added at the beginning of every line
        :return: the rendering, lines are right-stripped and joined with ``\\n``

        .. runpython::
            :showcode:

            import textwrap
            import onnx
            import onnx.helper as oh
            from onnx_diagnostic.helpers.graph_helper import GraphRendering

            TFLOAT = onnx.TensorProto.FLOAT

            proto = oh.make_model(
                oh.make_graph(
                    [
                        oh.make_node("Add", ["X", "Y"], ["xy"]),
                        oh.make_node("Neg", ["Y"], ["ny"]),
                        oh.make_node("Mul", ["xy", "ny"], ["a"]),
                        oh.make_node("Mul", ["a", "Y"], ["Z"]),
                    ],
                    "-nd-",
                    [
                        oh.make_tensor_value_info("X", TFLOAT, ["a", "b", "c"]),
                        oh.make_tensor_value_info("Y", TFLOAT, ["a", "b", "c"]),
                    ],
                    [oh.make_tensor_value_info("Z", TFLOAT, ["a", "b", "c"])],
                ),
                opset_imports=[oh.make_opsetid("", 18)],
                ir_version=9,
            )
            graph = GraphRendering(proto)
            text = textwrap.dedent(graph.text_rendering()).strip("\\n")
            print(text)
        """
        # Inputs and outputs are rendered as pseudo-nodes whose op_type is
        # their name, linked to the artificial names BEGIN and END.
        nodes = [
            *[oh.make_node(i, ["BEGIN"], [i]) for i in self.input_names],
            *self.nodes,
            *[oh.make_node(i, [i], ["END"]) for i in self.output_names],
        ]
        exist = set(self.start_names) - set(self.input_names)
        exist |= {"BEGIN"}
        existing = sorted(exist)
        order = self.computation_order(nodes, existing)
        positions = self.graph_positions(nodes, order, existing)
        text_pos = self.text_positions(nodes, positions)
        edges = self.build_node_edges(nodes)
        max_len = max(col for _, col in text_pos) + max(len(n.op_type) for n in nodes)
        assert max_len < 1e6, f"max_len={max_len}, text_pos=\n{pprint.pformat(text_pos)}"
        max_row = max(row for row, _ in text_pos) + 2
        grid = [[" " for i in range(max_len + 1)] for _ in range(max_row + 1)]

        # edges first, node names are written over them
        for n1, n2 in edges:
            self.text_edge(grid, text_pos[n1], text_pos[n2])
        assert len(set(len(g) for g in grid)) == 1, f"lengths={[len(g) for g in grid]}"
        for node, pos in zip(nodes, text_pos):
            self.text_grid(grid, pos, node.op_type)
        assert len(set(len(g) for g in grid)) == 1, f"lengths={[len(g) for g in grid]}"

        return "\n".join(
            f"{prefix}{line.rstrip()}" for line in ["".join(line) for line in grid]
        )
@@ -112,6 +112,30 @@ def string_type(
112
112
  :param verbose: verbosity (to show the path it followed to get that print)
113
113
  :return: str
114
114
 
115
+ The function displays something like the following for a tensor.
116
+
117
+ .. code-block:: text
118
+
119
+ T7s2x7[0.5:6:A3.56]
120
+ ^^^+-^^----+------^
121
+ || |       |
122
+ || |       +-- information about the content of a tensor or array
123
+ || |           [min,max:A<average>]
124
+ || |
125
+ || +-- a shape
126
+ ||
127
+ |+-- integer following the code defined by onnx.TensorProto,
128
+ |    7 is onnx.TensorProto.INT64 (see onnx_dtype_name)
129
+ |
130
+ +-- A,T,F
131
+     A is an array from numpy
132
+     T is a Tensor from pytorch
133
+     F is a FakeTensor from pytorch
134
+
135
+ The element types for a tensor are displayed as integer to shorten the message.
136
+ The semantic is defined by :class:`onnx.TensorProto` and can be obtained
137
+ by :func:`onnx_diagnostic.helpers.onnx_helper.onnx_dtype_name`.
138
+
115
139
  .. runpython::
116
140
  :showcode:
117
141