onnx-diagnostic 0.8.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnx_diagnostic/__init__.py +7 -0
- onnx_diagnostic/__main__.py +4 -0
- onnx_diagnostic/_command_lines_parser.py +1141 -0
- onnx_diagnostic/api.py +15 -0
- onnx_diagnostic/doc.py +100 -0
- onnx_diagnostic/export/__init__.py +2 -0
- onnx_diagnostic/export/api.py +124 -0
- onnx_diagnostic/export/dynamic_shapes.py +1083 -0
- onnx_diagnostic/export/shape_helper.py +296 -0
- onnx_diagnostic/export/validate.py +173 -0
- onnx_diagnostic/ext_test_case.py +1290 -0
- onnx_diagnostic/helpers/__init__.py +1 -0
- onnx_diagnostic/helpers/_log_helper.py +463 -0
- onnx_diagnostic/helpers/args_helper.py +132 -0
- onnx_diagnostic/helpers/bench_run.py +450 -0
- onnx_diagnostic/helpers/cache_helper.py +687 -0
- onnx_diagnostic/helpers/config_helper.py +170 -0
- onnx_diagnostic/helpers/doc_helper.py +163 -0
- onnx_diagnostic/helpers/fake_tensor_helper.py +273 -0
- onnx_diagnostic/helpers/graph_helper.py +386 -0
- onnx_diagnostic/helpers/helper.py +1707 -0
- onnx_diagnostic/helpers/log_helper.py +2245 -0
- onnx_diagnostic/helpers/memory_peak.py +249 -0
- onnx_diagnostic/helpers/mini_onnx_builder.py +600 -0
- onnx_diagnostic/helpers/model_builder_helper.py +469 -0
- onnx_diagnostic/helpers/onnx_helper.py +1200 -0
- onnx_diagnostic/helpers/ort_session.py +736 -0
- onnx_diagnostic/helpers/rt_helper.py +476 -0
- onnx_diagnostic/helpers/torch_helper.py +987 -0
- onnx_diagnostic/reference/__init__.py +4 -0
- onnx_diagnostic/reference/evaluator.py +254 -0
- onnx_diagnostic/reference/ops/__init__.py +1 -0
- onnx_diagnostic/reference/ops/op_add_add_mul_mul.py +68 -0
- onnx_diagnostic/reference/ops/op_attention.py +60 -0
- onnx_diagnostic/reference/ops/op_average_pool_grad.py +63 -0
- onnx_diagnostic/reference/ops/op_bias_softmax.py +16 -0
- onnx_diagnostic/reference/ops/op_cast_like.py +46 -0
- onnx_diagnostic/reference/ops/op_complex.py +26 -0
- onnx_diagnostic/reference/ops/op_concat.py +15 -0
- onnx_diagnostic/reference/ops/op_constant_of_shape.py +67 -0
- onnx_diagnostic/reference/ops/op_fused_matmul.py +31 -0
- onnx_diagnostic/reference/ops/op_gather.py +29 -0
- onnx_diagnostic/reference/ops/op_gather_elements.py +45 -0
- onnx_diagnostic/reference/ops/op_gather_grad.py +12 -0
- onnx_diagnostic/reference/ops/op_memcpy_host.py +11 -0
- onnx_diagnostic/reference/ops/op_mul_sigmoid.py +23 -0
- onnx_diagnostic/reference/ops/op_negxplus1.py +8 -0
- onnx_diagnostic/reference/ops/op_qlinear_average_pool.py +40 -0
- onnx_diagnostic/reference/ops/op_qlinear_conv.py +102 -0
- onnx_diagnostic/reference/ops/op_quick_gelu.py +23 -0
- onnx_diagnostic/reference/ops/op_replace_zero.py +13 -0
- onnx_diagnostic/reference/ops/op_rotary.py +19 -0
- onnx_diagnostic/reference/ops/op_scan.py +65 -0
- onnx_diagnostic/reference/ops/op_scatter_elements.py +107 -0
- onnx_diagnostic/reference/ops/op_scatternd_of_shape.py +22 -0
- onnx_diagnostic/reference/ops/op_simplified_layer_normalization.py +8 -0
- onnx_diagnostic/reference/ops/op_skip_layer_normalization.py +13 -0
- onnx_diagnostic/reference/ops/op_slice.py +20 -0
- onnx_diagnostic/reference/ops/op_transpose_cast.py +16 -0
- onnx_diagnostic/reference/ops/op_tri_matrix.py +17 -0
- onnx_diagnostic/reference/ort_evaluator.py +652 -0
- onnx_diagnostic/reference/quantized_tensor.py +46 -0
- onnx_diagnostic/reference/report_results_comparison.py +95 -0
- onnx_diagnostic/reference/torch_evaluator.py +669 -0
- onnx_diagnostic/reference/torch_ops/__init__.py +56 -0
- onnx_diagnostic/reference/torch_ops/_op_run.py +335 -0
- onnx_diagnostic/reference/torch_ops/access_ops.py +94 -0
- onnx_diagnostic/reference/torch_ops/binary_ops.py +108 -0
- onnx_diagnostic/reference/torch_ops/controlflow_ops.py +121 -0
- onnx_diagnostic/reference/torch_ops/generator_ops.py +36 -0
- onnx_diagnostic/reference/torch_ops/nn_ops.py +196 -0
- onnx_diagnostic/reference/torch_ops/other_ops.py +106 -0
- onnx_diagnostic/reference/torch_ops/reduce_ops.py +130 -0
- onnx_diagnostic/reference/torch_ops/sequence_ops.py +65 -0
- onnx_diagnostic/reference/torch_ops/shape_ops.py +121 -0
- onnx_diagnostic/reference/torch_ops/unary_ops.py +93 -0
- onnx_diagnostic/tasks/__init__.py +90 -0
- onnx_diagnostic/tasks/automatic_speech_recognition.py +188 -0
- onnx_diagnostic/tasks/data/__init__.py +13 -0
- onnx_diagnostic/tasks/data/dummies_imagetext2text_generation_gemma3.onnx +0 -0
- onnx_diagnostic/tasks/feature_extraction.py +162 -0
- onnx_diagnostic/tasks/fill_mask.py +89 -0
- onnx_diagnostic/tasks/image_classification.py +144 -0
- onnx_diagnostic/tasks/image_text_to_text.py +581 -0
- onnx_diagnostic/tasks/image_to_video.py +127 -0
- onnx_diagnostic/tasks/mask_generation.py +143 -0
- onnx_diagnostic/tasks/mixture_of_expert.py +79 -0
- onnx_diagnostic/tasks/object_detection.py +134 -0
- onnx_diagnostic/tasks/sentence_similarity.py +89 -0
- onnx_diagnostic/tasks/summarization.py +227 -0
- onnx_diagnostic/tasks/text2text_generation.py +230 -0
- onnx_diagnostic/tasks/text_classification.py +89 -0
- onnx_diagnostic/tasks/text_generation.py +352 -0
- onnx_diagnostic/tasks/text_to_image.py +95 -0
- onnx_diagnostic/tasks/zero_shot_image_classification.py +128 -0
- onnx_diagnostic/torch_export_patches/__init__.py +21 -0
- onnx_diagnostic/torch_export_patches/eval/__init__.py +725 -0
- onnx_diagnostic/torch_export_patches/eval/model_cases.py +898 -0
- onnx_diagnostic/torch_export_patches/onnx_export_errors.py +1098 -0
- onnx_diagnostic/torch_export_patches/onnx_export_serialization.py +311 -0
- onnx_diagnostic/torch_export_patches/patch_details.py +340 -0
- onnx_diagnostic/torch_export_patches/patch_expressions.py +108 -0
- onnx_diagnostic/torch_export_patches/patch_inputs.py +211 -0
- onnx_diagnostic/torch_export_patches/patch_module.py +1047 -0
- onnx_diagnostic/torch_export_patches/patch_module_helper.py +184 -0
- onnx_diagnostic/torch_export_patches/patches/__init__.py +0 -0
- onnx_diagnostic/torch_export_patches/patches/patch_torch.py +1090 -0
- onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +2139 -0
- onnx_diagnostic/torch_export_patches/serialization/__init__.py +46 -0
- onnx_diagnostic/torch_export_patches/serialization/diffusers_impl.py +34 -0
- onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py +313 -0
- onnx_diagnostic/torch_models/__init__.py +0 -0
- onnx_diagnostic/torch_models/code_sample.py +343 -0
- onnx_diagnostic/torch_models/hghub/__init__.py +1 -0
- onnx_diagnostic/torch_models/hghub/hub_api.py +422 -0
- onnx_diagnostic/torch_models/hghub/hub_data.py +234 -0
- onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +4905 -0
- onnx_diagnostic/torch_models/hghub/model_inputs.py +388 -0
- onnx_diagnostic/torch_models/hghub/model_specific.py +76 -0
- onnx_diagnostic/torch_models/llms.py +2 -0
- onnx_diagnostic/torch_models/untrained/__init__.py +0 -0
- onnx_diagnostic/torch_models/untrained/llm_phi2.py +113 -0
- onnx_diagnostic/torch_models/untrained/llm_tiny_llm.py +76 -0
- onnx_diagnostic/torch_models/validate.py +2124 -0
- onnx_diagnostic/torch_onnx/__init__.py +0 -0
- onnx_diagnostic/torch_onnx/runtime_info.py +289 -0
- onnx_diagnostic/torch_onnx/sbs.py +440 -0
- onnx_diagnostic-0.8.0.dist-info/METADATA +213 -0
- onnx_diagnostic-0.8.0.dist-info/RECORD +132 -0
- onnx_diagnostic-0.8.0.dist-info/WHEEL +5 -0
- onnx_diagnostic-0.8.0.dist-info/licenses/LICENSE.txt +19 -0
- onnx_diagnostic-0.8.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,386 @@
|
|
|
1
|
+
import pprint
|
|
2
|
+
from typing import Dict, List, Optional, Sequence, Set, Tuple, Union
|
|
3
|
+
import onnx
|
|
4
|
+
import onnx.helper as oh
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class GraphRendering:
    """
    Helpers to render a graph.

    :param proto: model or graph to render.
    """

    def __init__(self, proto: Union[onnx.FunctionProto, onnx.GraphProto, onnx.ModelProto]):
        # The wrapped proto; the properties below dispatch on its exact type
        # (ModelProto wraps a GraphProto, FunctionProto stores plain names).
        self.proto = proto

    def __repr__(self) -> str:
        "usual"
        # Shows only the proto's type, not its content, to keep repr short.
        return f"{self.__class__.__name__}(<{self.proto.__class__.__name__}>)"
|
|
20
|
+
|
|
21
|
+
@classmethod
def computation_order(
    cls,
    nodes: Sequence["onnx.NodeProto"],
    existing: Optional[List[str]] = None,
    start: int = 1,
) -> List[int]:
    """
    Returns the soonest a node can be computed,
    every node can assume all nodes with a lower number exists.
    Every node with a higher number must wait for the previous one.

    :param nodes: list of nodes
    :param existing: names available before any computation starts
    :param start: lowest order number assigned
    :return: computation order, one integer per node
    """
    # Control flow operators carry subgraphs whose implicit inputs are not
    # listed in node.input, so the simple propagation below would be wrong.
    assert not ({"If", "Scan", "Loop", "SequenceMap"} & set(n.op_type for n in nodes)), (
        f"This algorithm is not yet implemented if the sequence contains "
        f"a control flow, types={sorted(set(n.op_type for n in nodes))}"
    )
    # names known before the first node get order start - 1
    number = {e: start - 1 for e in (existing or [])}  # noqa: C420
    results = [start for _ in nodes]
    for i_node, node in enumerate(nodes):
        assert all(i in number for i in node.input), (
            f"Missing input in node {i_node} type={node.op_type}: "
            f"{[i for i in node.input if i not in number]}"
        )
        if node.input:
            # one step after the latest producer of any of its inputs
            mx = max(number[i] for i in node.input) + 1
            results[i_node] = mx
        else:
            # A constant: it can be computed at any time; its outputs are
            # stamped with the current maximum so consumers are not delayed.
            mx = max(number.values()) if number else 0
        for i in node.output:
            number[i] = mx
    return results
|
|
58
|
+
|
|
59
|
+
@classmethod
def graph_positions(
    cls,
    nodes: Sequence["onnx.NodeProto"],
    order: List[int],
    existing: Optional[List[str]] = None,
) -> List[Tuple[int, int]]:
    """
    Returns positions on a plan for every node in a graph.
    The function minimizes the number of lines crossing each others.
    It goes forward, every line is optimized depending on what is below.
    It could be improved with more iterations.

    :param nodes: list of nodes
    :param existing: existing names
    :param order: computation order returned by
        :meth:`onnx_diagnostic.helpers.graph_helper.GraphRendering.computation_order`
    :return: list of tuple( row, column)
    """
    first_row = min(order)
    row_count = max(order) + 1
    # column assigned to every produced name, used to rank the next row
    name_column: Dict[str, int] = {}

    coords = [(first_row, i) for i in range(len(order))]
    for row in range(first_row, row_count):
        members = [i for i, o in enumerate(order) if o == row]
        assert members, f"indices cannot be empty for row={row}, order={order}"
        # rank every node of this row by the rightmost column feeding it,
        # ties broken by the original node order (stable sort on pairs)
        ranked = []
        for local, node_index in enumerate(members):
            node = nodes[node_index]
            key = max(name_column.get(name, 0) for name in node.input) if node.input else 0
            ranked.append((key, local))
        ranked.sort()
        for column, (_key, local) in enumerate(ranked):
            node_index = members[local]
            coords[node_index] = (row, column)
            for produced in nodes[node_index].output:
                name_column[produced] = column
    return coords
|
|
98
|
+
|
|
99
|
+
@classmethod
def text_positions(
    cls, nodes: Sequence["onnx.NodeProto"], positions: List[Tuple[int, int]]
) -> List[Tuple[int, int]]:
    """
    Returns positions for the nodes assuming it is rendered into text.

    :param nodes: list of nodes
    :param positions: positions returned by
        :meth:`onnx_diagnostic.helpers.graph_helper.GraphRendering.graph_positions`
    :return: text positions
    """
    # Robustness fix: the cumulated-width code below indexes sort[0] and
    # would raise IndexError on an empty graph.
    if not positions:
        return []
    # 4 text rows per graph row; columns are spread out and sheared by the
    # row index so vertical edges of successive rows do not overlap.
    new_positions = [(row * 4, col * 2 + row) for row, col in positions]
    column_size = {col: 3 for _, col in new_positions}
    for i, (_row, col) in enumerate(new_positions):
        # op_type plus a margin of 5 characters for the surrounding edges
        size = len(nodes[i].op_type) + 5
        column_size[col] = max(column_size[col], size)
        assert column_size[col] < 200, (
            f"column_size[{col}]={column_size[col]}, this is quite big, i={i}, "
            f"nodes[i].op_type={nodes[i].op_type}"
        )

    # cumulated: every column is centered after all columns left of it
    sort = sorted(column_size.items())
    cumul = dict(sort[:1])
    results = {sort[0][0]: sort[0][1] // 2}
    for col, size in sort[1:]:
        c = max(cumul.values())
        cumul[col] = c + size
        results[col] = c + size // 2
    return [(row, results[col]) for row, col in new_positions]
|
|
130
|
+
|
|
131
|
+
@property
def nodes(self) -> List[onnx.NodeProto]:
    """Returns the list of nodes."""
    if isinstance(self.proto, onnx.ModelProto):
        # a model wraps its graph
        return self.proto.graph.node
    # GraphProto and FunctionProto both expose ``node`` directly
    return self.proto.node
|
|
139
|
+
|
|
140
|
+
@property
def start_names(self) -> List[str]:
    # Annotation fixed: the property returns names (str), not NodeProto.
    "Returns the list of known names, inputs and initializer"
    # ModelProto wraps the actual graph; GraphProto/FunctionProto are used directly.
    graph = self.proto.graph if isinstance(self.proto, onnx.ModelProto) else self.proto
    # FunctionProto.input is already a list of strings,
    # GraphProto.input holds ValueInfoProto objects.
    input_names = (
        list(graph.input)
        if isinstance(graph, onnx.FunctionProto)
        else [i.name for i in graph.input]
    )
    # FunctionProto has no (sparse) initializers.
    init_names = (
        []
        if isinstance(graph, onnx.FunctionProto)
        else [
            *[i.name for i in graph.initializer],
            *[i.name for i in graph.sparse_initializer],
        ]
    )
    return [*input_names, *init_names]
|
|
158
|
+
|
|
159
|
+
@property
def input_names(self) -> List[str]:
    """Returns the list of input names."""
    if isinstance(self.proto, onnx.FunctionProto):
        # FunctionProto inputs are plain strings already.
        return self.proto.input
    graph = self.proto if isinstance(self.proto, onnx.GraphProto) else self.proto.graph
    return [value_info.name for value_info in graph.input]
|
|
172
|
+
|
|
173
|
+
@property
def output_names(self) -> List[str]:
    """Returns the list of output names."""
    if isinstance(self.proto, onnx.FunctionProto):
        # FunctionProto outputs are plain strings already.
        return self.proto.output
    graph = self.proto if isinstance(self.proto, onnx.GraphProto) else self.proto.graph
    return [value_info.name for value_info in graph.output]
|
|
186
|
+
|
|
187
|
+
@classmethod
def build_node_edges(cls, nodes: Sequence["onnx.NodeProto"]) -> Set[Tuple[int, int]]:
    """Builds the list of edges between nodes."""
    # index of the node producing every result name
    producer = {out: idx for idx, node in enumerate(nodes) for out in node.output}
    # one edge (producer, consumer) per consumed result; inputs that no
    # node produces (graph inputs, initializers) create no edge
    return {
        (producer[name], idx)
        for idx, node in enumerate(nodes)
        for name in node.input
        if name in producer
    }
|
|
201
|
+
|
|
202
|
+
# Merge table for overlapping box-drawing characters: when two edges are
# drawn through the same grid cell, ADD_RULES[(current, new)] gives the
# combined glyph (e.g. "-" meeting "|" becomes "┼"). Pairs missing from
# this table make text_edge's local add() raise NotImplementedError.
ADD_RULES = {
    ("┴", "┘"): "┴",
    ("┴", "└"): "┴",
    ("┬", "┐"): "┬",
    ("┬", "┌"): "┬",
    ("-", "└"): "┴",
    ("-", "|"): "┼",
    ("-", "┐"): "┬",
    ("┐", "-"): "┬",
    ("┘", "-"): "┴",
    ("┴", "-"): "┴",
    ("-", "┘"): "┴",
    ("┌", "-"): "┬",
    ("┬", "-"): "┬",
    ("-", "┌"): "┬",
    ("|", "-"): "┼",
    ("└", "-"): "┴",
    ("|", "└"): "├",
    ("|", "┘"): "┤",
    ("┐", "|"): "┤",
    ("┬", "|"): "┼",
    ("|", "┐"): "┤",
    ("|", "┌"): "├",
    ("├", "-"): "┼",
    ("└", "|"): "├",
    ("┤", "┐"): "┤",
    ("┤", "|"): "┤",
    ("├", "|"): "├",
    ("┴", "┌"): "┼",
    ("┐", "┌"): "┬",
    ("┌", "┐"): "┬",
    ("┌", "|"): "┼",
    ("┴", "┐"): "┼",
    ("┐", "└"): "┼",
    ("┬", "┘"): "┼",
    ("├", "└"): "├",
    ("┤", "┌"): "┼",
    ("┘", "|"): "┤",
    ("┴", "|"): "┼",
    ("┤", "-"): "┼",
    ("┘", "└"): "┴",
}
|
|
244
|
+
|
|
245
|
+
@classmethod
def text_grid(cls, grid: List[List[str]], position: Tuple[int, int], text: str):
    """
    Writes *text* into *grid* in place, centered on *position*.

    :param grid: grid of single characters, modified in place
    :param position: ``(row, col)`` center of the text
    :param text: text to print
    """
    row, center = position
    first = center - len(text) // 2
    grid[row][first : first + len(text)] = list(text)
|
|
257
|
+
|
|
258
|
+
def text_edge(
    cls,
    grid: List[List[str]],
    p1: Tuple[int, int],
    p2: Tuple[int, int],
    mode: str = "square",
):
    """
    Prints inplace an edge in a grid. The text is centered.

    :param grid: grid, modified in place
    :param p1: first position ``(row, col)``, must be strictly above *p2*
    :param p2: second position ``(row, col)``
    :param mode: ``'square'`` is the only supported value
    """
    # NOTE(review): declared like a classmethod (first parameter ``cls``)
    # but without ``@classmethod``; it works because only ``cls.ADD_RULES``
    # is read — confirm the intent before adding the decorator.
    assert mode == "square", f"mode={mode!r} not supported"
    assert p1[0] < p2[0], f"Unexpected edge p1={p1}, p2={p2}"
    assert p1[0] + 2 <= p2[0] - 2, f"Unexpected edge p1={p1}, p2={p2}"
    # removes this when the algorithm is ready
    assert 0 <= p1[0] < len(grid) - 3, f"p1={p1}, grid:{len(grid)},{len(grid[0])}"
    assert 2 <= p2[0] < len(grid) - 1, f"p2={p2}, grid:{len(grid)},{len(grid[0])}"
    assert (
        0 <= p1[1] < min(len(g) for g in grid)
    ), f"p1={p1}, sizes={[len(g) for g in grid]}"
    assert (
        0 <= p2[1] < min(len(g) for g in grid)
    ), f"p2={p2}, sizes={[len(g) for g in grid]}"

    def add(s1, s2):
        # merges the symbol already present (s1) with the new one (s2);
        # idempotent for equal symbols, "┼" absorbs everything
        assert s2 != " ", f"s1={s1!r}, s2={s2!r}"
        if s1 == " " or s1 == s2:
            return s2
        if s1 == "┼" or s2 == "┼":
            return "┼"
        if (s1, s2) in cls.ADD_RULES:
            return cls.ADD_RULES[s1, s2]
        raise NotImplementedError(f"Unable to add: ({s1!r},{s2!r}): '',")

    def place(grid, x, y, symbol):
        grid[x][y] = add(grid[x][y], symbol)

    # leave the source: one vertical step then a corner
    place(grid, p1[0] + 1, p1[1], "|")
    place(grid, p1[0] + 2, p1[1], "└" if p1[1] < p2[1] else "┘")

    if p1[0] + 2 == p2[0] - 2:
        # both corners sit on the same text row: one horizontal segment
        a, b = (p1[1] + 1, p2[1] - 1) if p1[1] < p2[1] else (p2[1] + 1, p1[1] - 1)
        for i in range(a, b + 1):
            place(grid, p1[0] + 2, i, "-")
    else:
        # route through an intermediate column halfway between the nodes
        middle = (p1[1] + p2[1]) // 2
        a, b = (p1[1] + 1, middle - 1) if p1[1] < middle else (middle + 1, p1[1] - 1)
        for i in range(a, b + 1):
            place(grid, p1[0] + 2, i, "-")
        # NOTE(review): the original drew this exact segment a second time
        # (a no-op because ``add`` is idempotent); the duplicate was
        # removed. The horizontal leg between ``middle`` and ``p2[1]`` on
        # row ``p2[0] - 2`` is never drawn — possibly intentional, TODO
        # confirm.

        place(grid, p1[0] + 2, middle, "┐" if p1[1] < p2[1] else "┌")
        place(grid, p2[0] - 2, middle, "└" if p1[1] < p2[1] else "┘")

        for i in range(p1[0] + 2 + 1, p2[0] - 2):
            place(grid, i, middle, "|")

    # reach the destination: a corner then one vertical step
    place(grid, p2[0] - 2, p2[1], "┐" if p1[1] < p2[1] else "┌")
    place(grid, p2[0] - 1, p2[1], "|")
|
|
323
|
+
|
|
324
|
+
def text_rendering(self, prefix: str = "") -> str:
    """
    Renders a model in text.

    .. runpython::
        :showcode:

        import textwrap
        import onnx
        import onnx.helper as oh
        from onnx_diagnostic.helpers.graph_helper import GraphRendering

        TFLOAT = onnx.TensorProto.FLOAT

        proto = oh.make_model(
            oh.make_graph(
                [
                    oh.make_node("Add", ["X", "Y"], ["xy"]),
                    oh.make_node("Neg", ["Y"], ["ny"]),
                    oh.make_node("Mul", ["xy", "ny"], ["a"]),
                    oh.make_node("Mul", ["a", "Y"], ["Z"]),
                ],
                "-nd-",
                [
                    oh.make_tensor_value_info("X", TFLOAT, ["a", "b", "c"]),
                    oh.make_tensor_value_info("Y", TFLOAT, ["a", "b", "c"]),
                ],
                [oh.make_tensor_value_info("Z", TFLOAT, ["a", "b", "c"])],
            ),
            opset_imports=[oh.make_opsetid("", 18)],
            ir_version=9,
        )
        graph = GraphRendering(proto)
        text = textwrap.dedent(graph.text_rendering()).strip("\\n")
        print(text)
    """
    # Inputs and outputs become pseudo-nodes so they occupy rows of their
    # own above and below the real nodes.
    nodes = [
        *[oh.make_node(i, ["BEGIN"], [i]) for i in self.input_names],
        *self.nodes,
        *[oh.make_node(i, [i], ["END"]) for i in self.output_names],
    ]
    # initializers are available from the start; BEGIN feeds the input rows
    exist = set(self.start_names) - set(self.input_names)
    exist |= {"BEGIN"}
    existing = sorted(exist)
    order = self.computation_order(nodes, existing)
    positions = self.graph_positions(nodes, order, existing)
    text_pos = self.text_positions(nodes, positions)
    edges = self.build_node_edges(nodes)
    max_len = max(col for _, col in text_pos) + max(len(n.op_type) for n in nodes)
    # message fixed: it used to hard-code a stale value ("max_len=45,384")
    # instead of reporting the actual max_len
    assert max_len < 1e6, f"max_len={max_len}, text_pos=\n{pprint.pformat(text_pos)}"
    max_row = max(row for row, _ in text_pos) + 2
    grid = [[" " for i in range(max_len + 1)] for _ in range(max_row + 1)]

    # draw edges first so node labels overwrite them afterwards
    for n1, n2 in edges:
        self.text_edge(grid, text_pos[n1], text_pos[n2])
    assert len(set(len(g) for g in grid)) == 1, f"lengths={[len(g) for g in grid]}"
    for node, pos in zip(nodes, text_pos):
        self.text_grid(grid, pos, node.op_type)
    assert len(set(len(g) for g in grid)) == 1, f"lengths={[len(g) for g in grid]}"

    return "\n".join(
        f"{prefix}{line.rstrip()}" for line in ["".join(line) for line in grid]
    )