onnx-diagnostic 0.5.0__py3-none-any.whl → 0.6.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnx_diagnostic/__init__.py +2 -2
- onnx_diagnostic/_command_lines_parser.py +21 -1
- onnx_diagnostic/export/dynamic_shapes.py +14 -5
- onnx_diagnostic/ext_test_case.py +12 -0
- onnx_diagnostic/helpers/graph_helper.py +386 -0
- onnx_diagnostic/helpers/helper.py +24 -0
- onnx_diagnostic/helpers/model_builder_helper.py +333 -0
- onnx_diagnostic/helpers/rt_helper.py +65 -1
- onnx_diagnostic/torch_export_patches/eval/__init__.py +621 -0
- onnx_diagnostic/torch_export_patches/eval/model_cases.py +896 -0
- onnx_diagnostic/torch_export_patches/onnx_export_errors.py +34 -1
- onnx_diagnostic/torch_export_patches/onnx_export_serialization.py +6 -1
- onnx_diagnostic/torch_export_patches/patch_module_helper.py +25 -19
- onnx_diagnostic/torch_export_patches/patches/patch_torch.py +91 -0
- onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +29 -1
- onnx_diagnostic/torch_models/test_helper.py +110 -7
- {onnx_diagnostic-0.5.0.dist-info → onnx_diagnostic-0.6.0.dist-info}/METADATA +1 -1
- {onnx_diagnostic-0.5.0.dist-info → onnx_diagnostic-0.6.0.dist-info}/RECORD +21 -17
- {onnx_diagnostic-0.5.0.dist-info → onnx_diagnostic-0.6.0.dist-info}/WHEEL +1 -1
- {onnx_diagnostic-0.5.0.dist-info → onnx_diagnostic-0.6.0.dist-info}/licenses/LICENSE.txt +0 -0
- {onnx_diagnostic-0.5.0.dist-info → onnx_diagnostic-0.6.0.dist-info}/top_level.txt +0 -0
onnx_diagnostic/__init__.py
CHANGED
|
@@ -126,9 +126,23 @@ def get_parser_print() -> ArgumentParser:
|
|
|
126
126
|
"""
|
|
127
127
|
),
|
|
128
128
|
epilog="To show a model.",
|
|
129
|
+
formatter_class=RawTextHelpFormatter,
|
|
129
130
|
)
|
|
130
131
|
parser.add_argument(
|
|
131
|
-
"fmt",
|
|
132
|
+
"fmt",
|
|
133
|
+
choices=["pretty", "raw", "text", "printer"],
|
|
134
|
+
default="pretty",
|
|
135
|
+
help=textwrap.dedent(
|
|
136
|
+
"""
|
|
137
|
+
Prints out a model on the standard output.
|
|
138
|
+
raw - just prints the model with print(...)
|
|
139
|
+
printer - onnx.printer.to_text(...)
|
|
140
|
+
pretty - an improved rendering
|
|
141
|
+
text - uses GraphRendering
|
|
142
|
+
""".strip(
|
|
143
|
+
"\n"
|
|
144
|
+
)
|
|
145
|
+
),
|
|
132
146
|
)
|
|
133
147
|
parser.add_argument("input", type=str, help="onnx model to load")
|
|
134
148
|
return parser
|
|
@@ -144,6 +158,12 @@ def _cmd_print(argv: List[Any]):
|
|
|
144
158
|
from .helpers.onnx_helper import pretty_onnx
|
|
145
159
|
|
|
146
160
|
print(pretty_onnx(onx))
|
|
161
|
+
elif args.fmt == "printer":
|
|
162
|
+
print(onnx.printer.to_text(onx))
|
|
163
|
+
elif args.fmt == "text":
|
|
164
|
+
from .helpers.graph_helper import GraphRendering
|
|
165
|
+
|
|
166
|
+
print(GraphRendering(onx).text_rendering())
|
|
147
167
|
else:
|
|
148
168
|
raise ValueError(f"Unexpected value fmt={args.fmt!r}")
|
|
149
169
|
|
|
@@ -379,8 +379,9 @@ class CoupleInputsDynamicShapes:
|
|
|
379
379
|
return torch.utils._pytree.tree_unflatten(res, spec)
|
|
380
380
|
|
|
381
381
|
class ChangeDimensionProcessor:
|
|
382
|
-
def __init__(self, desired_values):
|
|
382
|
+
def __init__(self, desired_values, only_desired=False):
    """
    Stores the requested dimension values.

    :param desired_values: mapping ``{dimension name: value}`` (or None),
        it is copied and never modified
    :param only_desired: if True, only the dimensions listed in
        ``desired_values`` are changed
    """
    # The mapping is later filled with the dimensions chosen for unnamed
    # values (``self.mapping[d] = new_dim``), copying it prevents those
    # additions from leaking into the dictionary owned by the caller.
    self.mapping = dict(desired_values) if desired_values else {}
    self.only_desired = only_desired
|
|
384
385
|
|
|
385
386
|
def _build_new_shape(
|
|
386
387
|
self, shape: Tuple[int, ...], ds: Dict[int, Any]
|
|
@@ -397,14 +398,16 @@ class CoupleInputsDynamicShapes:
|
|
|
397
398
|
torch.export.dynamic_shapes._Dim,
|
|
398
399
|
),
|
|
399
400
|
):
|
|
400
|
-
d =
|
|
401
|
+
d = ds[i].__name__
|
|
401
402
|
elif not isinstance(ds[i], int):
|
|
402
403
|
raise NotImplementedError(f"Unable to handle type {ds[i]} in {ds}")
|
|
403
404
|
if d in self.mapping:
|
|
404
405
|
new_dim = self.mapping[d]
|
|
405
|
-
|
|
406
|
+
elif not self.only_desired:
|
|
406
407
|
new_dim = shape[i] + 1
|
|
407
408
|
self.mapping[d] = new_dim
|
|
409
|
+
else:
|
|
410
|
+
new_dim = shape[i]
|
|
408
411
|
new_shape[i] = new_dim
|
|
409
412
|
return tuple(new_shape)
|
|
410
413
|
|
|
@@ -447,7 +450,10 @@ class CoupleInputsDynamicShapes:
|
|
|
447
450
|
return self._build_new_tensor(inputs, new_shape)
|
|
448
451
|
|
|
449
452
|
def change_dynamic_dimensions(
|
|
450
|
-
self,
|
|
453
|
+
self,
|
|
454
|
+
desired_values: Optional[Dict[str, int]] = None,
|
|
455
|
+
args_kwargs: bool = False,
|
|
456
|
+
only_desired: bool = False,
|
|
451
457
|
):
|
|
452
458
|
"""
|
|
453
459
|
A model exported with dynamic shapes is not necessarily dynamic
|
|
@@ -460,6 +466,8 @@ class CoupleInputsDynamicShapes:
|
|
|
460
466
|
|
|
461
467
|
:param desired_values: to fixed named dimension to have the desired value
|
|
462
468
|
:param args_kwargs: return both args, kwargs even if empty
|
|
469
|
+
:param only_desired: if True, only change the dimension specified in
|
|
470
|
+
``desired_values``
|
|
463
471
|
:return: new inputs
|
|
464
472
|
|
|
465
473
|
Example:
|
|
@@ -483,7 +491,8 @@ class CoupleInputsDynamicShapes:
|
|
|
483
491
|
print("-after:", string_type(new_kwargs, with_shape=True))
|
|
484
492
|
"""
|
|
485
493
|
return self._generic_walker(
|
|
486
|
-
self.ChangeDimensionProcessor(desired_values
|
|
494
|
+
self.ChangeDimensionProcessor(desired_values, only_desired=only_desired),
|
|
495
|
+
args_kwargs=args_kwargs,
|
|
487
496
|
)
|
|
488
497
|
|
|
489
498
|
|
onnx_diagnostic/ext_test_case.py
CHANGED
|
@@ -776,6 +776,13 @@ class ExtTestCase(unittest.TestCase):
|
|
|
776
776
|
os.mkdir(folder)
|
|
777
777
|
return os.path.join(folder, name)
|
|
778
778
|
|
|
779
|
+
def get_dump_folder(self, folder: str) -> str:
    """
    Returns the folder ``dump_test/<folder>`` and creates it
    (with any missing parent) if it does not exist yet.

    :param folder: sub-folder name, relative to ``dump_test``
    :return: path to the folder
    """
    folder = os.path.join("dump_test", folder)
    # exist_ok=True avoids the check-then-create race when several tests
    # (possibly in parallel processes) request the same folder.
    os.makedirs(folder, exist_ok=True)
    return folder
|
|
785
|
+
|
|
779
786
|
def dump_onnx(
|
|
780
787
|
self,
|
|
781
788
|
name: str,
|
|
@@ -813,6 +820,11 @@ class ExtTestCase(unittest.TestCase):
|
|
|
813
820
|
msg or f"Unable to find the list of strings {tofind!r} in\n--\n{text}"
|
|
814
821
|
)
|
|
815
822
|
|
|
823
|
+
def assertHasAttr(self, obj: Any, name: str):
    """
    Checks that ``obj`` has an attribute ``name``.

    :param obj: any object
    :param name: attribute name
    :raises AssertionError: if the attribute is missing
    """
    # An explicit raise keeps the check active when Python runs with -O,
    # which removes plain ``assert`` statements.
    if not hasattr(obj, name):
        raise AssertionError(
            f"Unable to find attribute {name!r} in object type {type(obj)}"
        )
|
|
827
|
+
|
|
816
828
|
def assertSetContained(self, set1, set2):
|
|
817
829
|
"Checks that ``set1`` is contained in ``set2``."
|
|
818
830
|
set1 = set(set1)
|
|
@@ -0,0 +1,386 @@
|
|
|
1
|
+
import pprint
|
|
2
|
+
from typing import Dict, List, Optional, Sequence, Set, Tuple, Union
|
|
3
|
+
import onnx
|
|
4
|
+
import onnx.helper as oh
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class GraphRendering:
    """
    Helper to render a graph as text.

    :param proto: model, graph or function to render.
    """

    def __init__(self, proto: Union[onnx.FunctionProto, onnx.GraphProto, onnx.ModelProto]):
        self.proto = proto

    def __repr__(self) -> str:
        "usual"
        return f"{self.__class__.__name__}(<{self.proto.__class__.__name__}>)"

    @classmethod
    def computation_order(
        cls,
        nodes: Sequence[onnx.NodeProto],
        existing: Optional[List[str]] = None,
        start: int = 1,
    ) -> List[int]:
        """
        Returns the soonest a node can be computed,
        every node can assume all nodes with a lower number exists.
        Every node with a higher number must wait for the previous one.

        Empty names (omitted optional inputs or outputs) are ignored.
        A node without any input (a constant) is placed at ``start``.

        :param nodes: list of nodes
        :param existing: existing before any computation starts
        :param start: lower number
        :return: computation order
        """
        assert not ({"If", "Scan", "Loop", "SequenceMap"} & set(n.op_type for n in nodes)), (
            f"This algorithme is not yet implemented if the sequence contains "
            f"a control flow, types={sorted(set(n.op_type for n in nodes))}"
        )
        number = {e: start - 1 for e in (existing or [])}  # noqa: C420
        results = [start for _ in nodes]
        for i_node, node in enumerate(nodes):
            # "" marks an omitted optional input, it is not a dependency.
            inputs = [i for i in node.input if i]
            assert all(i in number for i in inputs), (
                f"Missing input in node {i_node} type={node.op_type}: "
                f"{[i for i in inputs if i not in number]}"
            )
            if inputs:
                mx = max(number[i] for i in inputs) + 1
                results[i_node] = mx
            else:
                # A constant stays on row ``start``. Its outputs are never
                # numbered below ``start``, otherwise a consumer could end up
                # on the same row as the constant itself.
                mx = max([start, *number.values()])
            for o in node.output:
                if o:
                    number[o] = mx
        return results

    @classmethod
    def graph_positions(
        cls,
        nodes: Sequence[onnx.NodeProto],
        order: List[int],
        existing: Optional[List[str]] = None,
    ) -> List[Tuple[int, int]]:
        """
        Returns positions on a plan for every node in a graph.
        The function minimizes the number of lines crossing each others.
        It goes forward, every line is optimized depending on what is below.
        It could be improved with more iterations.

        :param nodes: list of nodes
        :param existing: existing names (currently unused)
        :param order: computation order returned by
            :meth:`onnx_diagnostic.helpers.graph_helper.GraphRendering.computation_order`
        :return: list of tuple( row, column)
        """
        # initialization
        min_row = min(order)
        n_rows = max(order) + 1
        names: Dict[str, int] = {}

        positions = [(min_row, i) for i in range(len(order))]
        for row in range(min_row, n_rows):
            indices = [i for i, o in enumerate(order) if o == row]
            assert indices, f"indices cannot be empty for row={row}, order={order}"
            ns = [nodes[i] for i in indices]
            # nodes of a row are sorted by the highest column of their inputs
            mx = [(max(names.get(i, 0) for i in n.input) if n.input else 0) for n in ns]
            mix = [(m, i) for i, m in enumerate(mx)]
            mix.sort()
            for c, (_m, i) in enumerate(mix):
                positions[indices[i]] = (row, c)
                n = nodes[indices[i]]
                for o in n.output:
                    if o:
                        names[o] = c

        return positions

    @classmethod
    def text_positions(
        cls, nodes: Sequence[onnx.NodeProto], positions: List[Tuple[int, int]]
    ) -> List[Tuple[int, int]]:
        """
        Returns positions for the nodes assuming it is rendered into text.

        :param nodes: list of nodes
        :param positions: positions returned by
            :meth:`onnx_diagnostic.helpers.graph_helper.GraphRendering.graph_positions`
        :return: text positions
        """
        new_positions = [(row * 4, col * 2 + row) for row, col in positions]
        column_size = {col: 3 for _, col in new_positions}
        for i, (_row, col) in enumerate(new_positions):
            size = len(nodes[i].op_type) + 5
            column_size[col] = max(column_size[col], size)
            assert column_size[col] < 200, (
                f"column_size[{col}]={column_size[col]}, this is quite big, i={i}, "
                f"nodes[i].op_type={nodes[i].op_type}"
            )

        # cumulated
        sort = sorted(column_size.items())
        cumul = dict(sort[:1])
        results = {sort[0][0]: sort[0][1] // 2}
        for col, size in sort[1:]:
            c = max(cumul.values())
            cumul[col] = c + size
            results[col] = c + size // 2
        return [(row, results[col]) for row, col in new_positions]

    @property
    def nodes(self) -> List[onnx.NodeProto]:
        "Returns the list of nodes"
        return (
            self.proto.graph.node
            if isinstance(self.proto, onnx.ModelProto)
            else self.proto.node
        )

    @property
    def start_names(self) -> List[str]:
        "Returns the list of known names, inputs and initializer"
        graph = self.proto.graph if isinstance(self.proto, onnx.ModelProto) else self.proto
        input_names = (
            list(graph.input)
            if isinstance(graph, onnx.FunctionProto)
            else [i.name for i in graph.input]
        )
        init_names = (
            []
            if isinstance(graph, onnx.FunctionProto)
            else [
                *[i.name for i in graph.initializer],
                *[i.name for i in graph.sparse_initializer],
            ]
        )
        return [*input_names, *init_names]

    @property
    def input_names(self) -> List[str]:
        "Returns the list of input names."
        return (
            list(self.proto.input)
            if isinstance(self.proto, onnx.FunctionProto)
            else [
                i.name
                for i in (
                    self.proto if isinstance(self.proto, onnx.GraphProto) else self.proto.graph
                ).input
            ]
        )

    @property
    def output_names(self) -> List[str]:
        "Returns the list of output names."
        return (
            list(self.proto.output)
            if isinstance(self.proto, onnx.FunctionProto)
            else [
                i.name
                for i in (
                    self.proto if isinstance(self.proto, onnx.GraphProto) else self.proto.graph
                ).output
            ]
        )

    @classmethod
    def build_node_edges(cls, nodes: Sequence[onnx.NodeProto]) -> Set[Tuple[int, int]]:
        """
        Builds the list of edges between nodes.
        Empty names (omitted optional inputs or outputs) never create an edge.
        """
        produced = {}
        for i, node in enumerate(nodes):
            for o in node.output:
                if o:
                    produced[o] = i
        edges = set()
        for i, node in enumerate(nodes):
            for name in node.input:
                if name and name in produced:
                    edge = produced[name], i
                    edges.add(edge)
        return edges

    # Explicit merge table used by text_edge: (symbol in the grid, new symbol)
    # -> resulting symbol. Pairs missing from this table are merged with
    # the union of the arms of both symbols (see _ARMS).
    ADD_RULES = {
        ("┴", "┘"): "┴",
        ("┴", "└"): "┴",
        ("┬", "┐"): "┬",
        ("┬", "┌"): "┬",
        ("-", "└"): "┴",
        ("-", "|"): "┼",
        ("-", "┐"): "┬",
        ("┐", "-"): "┬",
        ("┘", "-"): "┴",
        ("┴", "-"): "┴",
        ("-", "┘"): "┴",
        ("┌", "-"): "┬",
        ("┬", "-"): "┬",
        ("-", "┌"): "┬",
        ("|", "-"): "┼",
        ("└", "-"): "┴",
        ("|", "└"): "├",
        ("|", "┘"): "┤",
        ("┐", "|"): "┤",
        ("┬", "|"): "┼",
        ("|", "┐"): "┤",
        ("|", "┌"): "├",
        ("├", "-"): "┼",
        ("└", "|"): "├",
        ("┤", "┐"): "┤",
        ("┤", "|"): "┤",
        ("├", "|"): "├",
        ("┴", "┌"): "┼",
        ("┐", "┌"): "┬",
        ("┌", "┐"): "┬",
        # "┌" (down, right) + "|" (up, down) has no left arm.
        ("┌", "|"): "├",
        ("┴", "┐"): "┼",
        ("┐", "└"): "┼",
        ("┬", "┘"): "┼",
        ("├", "└"): "├",
        ("┤", "┌"): "┼",
        ("┘", "|"): "┤",
        ("┴", "|"): "┼",
        ("┤", "-"): "┼",
        ("┘", "└"): "┴",
    }

    # Arms (Up, Down, Left, Right) of every symbol drawn by text_edge.
    _ARMS = {
        "|": frozenset("UD"),
        "-": frozenset("LR"),
        "└": frozenset("UR"),
        "┘": frozenset("UL"),
        "┐": frozenset("DL"),
        "┌": frozenset("DR"),
        "┬": frozenset("DLR"),
        "┴": frozenset("ULR"),
        "├": frozenset("UDR"),
        "┤": frozenset("UDL"),
        "┼": frozenset("UDLR"),
    }
    # Inverse mapping: set of arms -> symbol.
    _SYMBOLS = {arms: symbol for symbol, arms in _ARMS.items()}

    @classmethod
    def _merge_symbols(cls, s1: str, s2: str) -> str:
        """
        Merges the new symbol ``s2`` with the symbol ``s1`` already in the grid.

        :raises NotImplementedError: if the two symbols cannot be merged
        """
        assert s2 != " ", f"s1={s1!r}, s2={s2!r}"
        if s1 == " " or s1 == s2:
            return s2
        if s1 == "┼" or s2 == "┼":
            return "┼"
        if (s1, s2) in cls.ADD_RULES:
            return cls.ADD_RULES[s1, s2]
        if s1 in cls._ARMS and s2 in cls._ARMS:
            return cls._SYMBOLS[cls._ARMS[s1] | cls._ARMS[s2]]
        raise NotImplementedError(f"Unable to add: ({s1!r},{s2!r}): '',")

    @classmethod
    def text_grid(cls, grid: List[List[str]], position: Tuple[int, int], text: str):
        """
        Prints inplace a text in a grid. The text is centered.

        :param grid: grid
        :param position: position
        :param text: text to print
        """
        row, col = position
        begin = col - len(text) // 2
        grid[row][begin : begin + len(text)] = list(text)

    @classmethod
    def text_edge(
        cls,
        grid: List[List[str]],
        p1: Tuple[int, int],
        p2: Tuple[int, int],
        mode: str = "square",
    ):
        """
        Prints inplace an edge in a grid.

        The edge leaves ``p1`` downwards, turns on row ``p1[0] + 2`` towards
        the column halfway between both positions, goes down that column,
        turns on row ``p2[0] - 2`` towards ``p2`` and reaches ``p2`` from above.
        Symbols already in the grid are merged with the new ones.

        :param grid: grid
        :param p1: first position (row, column), the edge starts below it
        :param p2: second position (row, column), the edge ends above it
        :param mode: ``'square'`` is the only supported value
        """
        assert mode == "square", f"mode={mode!r} not supported"
        assert p1[0] < p2[0], f"Unexpected edge p1={p1}, p2={p2}"
        assert p1[0] + 2 <= p2[0] - 2, f"Unexpected edge p1={p1}, p2={p2}"
        # removes this when the algorithm is ready
        assert 0 <= p1[0] < len(grid) - 3, f"p1={p1}, grid:{len(grid)},{len(grid[0])}"
        assert 2 <= p2[0] < len(grid) - 1, f"p2={p2}, grid:{len(grid)},{len(grid[0])}"
        assert (
            0 <= p1[1] < min(len(g) for g in grid)
        ), f"p1={p1}, sizes={[len(g) for g in grid]}"
        assert (
            0 <= p2[1] < min(len(g) for g in grid)
        ), f"p2={p2}, sizes={[len(g) for g in grid]}"

        (r1, c1), (r2, c2) = p1, p2
        top, bottom = r1 + 2, r2 - 2  # rows of both horizontal segments
        middle = (c1 + c2) // 2  # column of the vertical segment

        # Cells crossed by the edge, in order, from p1 to p2. Segments collapse
        # naturally when middle == c1, middle == c2 or top == bottom.
        path = [(r1 + 1, c1), (top, c1)]
        step = 1 if middle >= c1 else -1
        path.extend((top, c) for c in range(c1 + step, middle + step, step))
        path.extend((r, middle) for r in range(top + 1, bottom + 1))
        step = 1 if c2 >= middle else -1
        path.extend((bottom, c) for c in range(middle + step, c2 + step, step))
        path.append((r2 - 1, c2))

        def direction(src: Tuple[int, int], dst: Tuple[int, int]) -> str:
            # direction of the neighbouring cell dst seen from src
            if dst[0] < src[0]:
                return "U"
            if dst[0] > src[0]:
                return "D"
            return "L" if dst[1] < src[1] else "R"

        last = len(path) - 1
        for k, (x, y) in enumerate(path):
            # the first cell connects to p1 above, the last one to p2 below
            arms = frozenset(
                (
                    "U" if k == 0 else direction((x, y), path[k - 1]),
                    "D" if k == last else direction((x, y), path[k + 1]),
                )
            )
            grid[x][y] = cls._merge_symbols(grid[x][y], cls._SYMBOLS[arms])

    def text_rendering(self, prefix="") -> str:
        """
        Renders a model in text.

        .. runpython::
            :showcode:

            import textwrap
            import onnx
            import onnx.helper as oh
            from onnx_diagnostic.helpers.graph_helper import GraphRendering

            TFLOAT = onnx.TensorProto.FLOAT

            proto = oh.make_model(
                oh.make_graph(
                    [
                        oh.make_node("Add", ["X", "Y"], ["xy"]),
                        oh.make_node("Neg", ["Y"], ["ny"]),
                        oh.make_node("Mul", ["xy", "ny"], ["a"]),
                        oh.make_node("Mul", ["a", "Y"], ["Z"]),
                    ],
                    "-nd-",
                    [
                        oh.make_tensor_value_info("X", TFLOAT, ["a", "b", "c"]),
                        oh.make_tensor_value_info("Y", TFLOAT, ["a", "b", "c"]),
                    ],
                    [oh.make_tensor_value_info("Z", TFLOAT, ["a", "b", "c"])],
                ),
                opset_imports=[oh.make_opsetid("", 18)],
                ir_version=9,
            )
            graph = GraphRendering(proto)
            text = textwrap.dedent(graph.text_rendering()).strip("\\n")
            print(text)
        """
        # inputs and outputs become pseudo-nodes linked to BEGIN / END
        nodes = [
            *[oh.make_node(i, ["BEGIN"], [i]) for i in self.input_names],
            *self.nodes,
            *[oh.make_node(i, [i], ["END"]) for i in self.output_names],
        ]
        exist = set(self.start_names) - set(self.input_names)
        exist |= {"BEGIN"}
        existing = sorted(exist)
        order = self.computation_order(nodes, existing)
        positions = self.graph_positions(nodes, order, existing)
        text_pos = self.text_positions(nodes, positions)
        edges = self.build_node_edges(nodes)
        max_len = max(col for _, col in text_pos) + max(len(n.op_type) for n in nodes)
        assert max_len < 1e6, f"max_len={max_len}, text_pos=\n{pprint.pformat(text_pos)}"
        max_row = max(row for row, _ in text_pos) + 2
        grid = [[" " for i in range(max_len + 1)] for _ in range(max_row + 1)]

        for n1, n2 in edges:
            self.text_edge(grid, text_pos[n1], text_pos[n2])
        assert len(set(len(g) for g in grid)) == 1, f"lengths={[len(g) for g in grid]}"
        for node, pos in zip(nodes, text_pos):
            self.text_grid(grid, pos, node.op_type)
        assert len(set(len(g) for g in grid)) == 1, f"lengths={[len(g) for g in grid]}"

        return "\n".join(
            f"{prefix}{line.rstrip()}" for line in ["".join(line) for line in grid]
        )
|
|
@@ -112,6 +112,30 @@ def string_type(
|
|
|
112
112
|
:param verbose: verbosity (to show the path it followed to get that print)
|
|
113
113
|
:return: str
|
|
114
114
|
|
|
115
|
+
The function displays something like the following for a tensor.
|
|
116
|
+
|
|
117
|
+
.. code-block:: text
|
|
118
|
+
|
|
119
|
+
T7s2x7[0.5:6:A3.56]
|
|
120
|
+
^^^+-^^----+------^
|
|
121
|
+
|| | |
|
|
122
|
+
|| | +-- information about the content of a tensor or array
|
|
123
|
+
|| | [min,max:A<average>]
|
|
124
|
+
|| |
|
|
125
|
+
|| +-- a shape
|
|
126
|
+
||
|
|
127
|
+
|+-- integer following the code defined by onnx.TensorProto,
|
|
128
|
+
| 7 is onnx.TensorProto.INT64 (see onnx_dtype_name)
|
|
129
|
+
|
|
|
130
|
+
+-- A,T,F
|
|
131
|
+
A is an array from numpy
|
|
132
|
+
T is a Tensor from pytorch
|
|
133
|
+
F is a FakeTensor from pytorch
|
|
134
|
+
|
|
135
|
+
The element types for a tensor are displayed as integer to shorten the message.
|
|
136
|
+
The semantic is defined by :class:`onnx.TensorProto` and can be obtained
|
|
137
|
+
by :func:`onnx_diagnostic.helpers.onnx_helper.onnx_dtype_name`.
|
|
138
|
+
|
|
115
139
|
.. runpython::
|
|
116
140
|
:showcode:
|
|
117
141
|
|