onnx-diagnostic 0.8.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (132) hide show
  1. onnx_diagnostic/__init__.py +7 -0
  2. onnx_diagnostic/__main__.py +4 -0
  3. onnx_diagnostic/_command_lines_parser.py +1141 -0
  4. onnx_diagnostic/api.py +15 -0
  5. onnx_diagnostic/doc.py +100 -0
  6. onnx_diagnostic/export/__init__.py +2 -0
  7. onnx_diagnostic/export/api.py +124 -0
  8. onnx_diagnostic/export/dynamic_shapes.py +1083 -0
  9. onnx_diagnostic/export/shape_helper.py +296 -0
  10. onnx_diagnostic/export/validate.py +173 -0
  11. onnx_diagnostic/ext_test_case.py +1290 -0
  12. onnx_diagnostic/helpers/__init__.py +1 -0
  13. onnx_diagnostic/helpers/_log_helper.py +463 -0
  14. onnx_diagnostic/helpers/args_helper.py +132 -0
  15. onnx_diagnostic/helpers/bench_run.py +450 -0
  16. onnx_diagnostic/helpers/cache_helper.py +687 -0
  17. onnx_diagnostic/helpers/config_helper.py +170 -0
  18. onnx_diagnostic/helpers/doc_helper.py +163 -0
  19. onnx_diagnostic/helpers/fake_tensor_helper.py +273 -0
  20. onnx_diagnostic/helpers/graph_helper.py +386 -0
  21. onnx_diagnostic/helpers/helper.py +1707 -0
  22. onnx_diagnostic/helpers/log_helper.py +2245 -0
  23. onnx_diagnostic/helpers/memory_peak.py +249 -0
  24. onnx_diagnostic/helpers/mini_onnx_builder.py +600 -0
  25. onnx_diagnostic/helpers/model_builder_helper.py +469 -0
  26. onnx_diagnostic/helpers/onnx_helper.py +1200 -0
  27. onnx_diagnostic/helpers/ort_session.py +736 -0
  28. onnx_diagnostic/helpers/rt_helper.py +476 -0
  29. onnx_diagnostic/helpers/torch_helper.py +987 -0
  30. onnx_diagnostic/reference/__init__.py +4 -0
  31. onnx_diagnostic/reference/evaluator.py +254 -0
  32. onnx_diagnostic/reference/ops/__init__.py +1 -0
  33. onnx_diagnostic/reference/ops/op_add_add_mul_mul.py +68 -0
  34. onnx_diagnostic/reference/ops/op_attention.py +60 -0
  35. onnx_diagnostic/reference/ops/op_average_pool_grad.py +63 -0
  36. onnx_diagnostic/reference/ops/op_bias_softmax.py +16 -0
  37. onnx_diagnostic/reference/ops/op_cast_like.py +46 -0
  38. onnx_diagnostic/reference/ops/op_complex.py +26 -0
  39. onnx_diagnostic/reference/ops/op_concat.py +15 -0
  40. onnx_diagnostic/reference/ops/op_constant_of_shape.py +67 -0
  41. onnx_diagnostic/reference/ops/op_fused_matmul.py +31 -0
  42. onnx_diagnostic/reference/ops/op_gather.py +29 -0
  43. onnx_diagnostic/reference/ops/op_gather_elements.py +45 -0
  44. onnx_diagnostic/reference/ops/op_gather_grad.py +12 -0
  45. onnx_diagnostic/reference/ops/op_memcpy_host.py +11 -0
  46. onnx_diagnostic/reference/ops/op_mul_sigmoid.py +23 -0
  47. onnx_diagnostic/reference/ops/op_negxplus1.py +8 -0
  48. onnx_diagnostic/reference/ops/op_qlinear_average_pool.py +40 -0
  49. onnx_diagnostic/reference/ops/op_qlinear_conv.py +102 -0
  50. onnx_diagnostic/reference/ops/op_quick_gelu.py +23 -0
  51. onnx_diagnostic/reference/ops/op_replace_zero.py +13 -0
  52. onnx_diagnostic/reference/ops/op_rotary.py +19 -0
  53. onnx_diagnostic/reference/ops/op_scan.py +65 -0
  54. onnx_diagnostic/reference/ops/op_scatter_elements.py +107 -0
  55. onnx_diagnostic/reference/ops/op_scatternd_of_shape.py +22 -0
  56. onnx_diagnostic/reference/ops/op_simplified_layer_normalization.py +8 -0
  57. onnx_diagnostic/reference/ops/op_skip_layer_normalization.py +13 -0
  58. onnx_diagnostic/reference/ops/op_slice.py +20 -0
  59. onnx_diagnostic/reference/ops/op_transpose_cast.py +16 -0
  60. onnx_diagnostic/reference/ops/op_tri_matrix.py +17 -0
  61. onnx_diagnostic/reference/ort_evaluator.py +652 -0
  62. onnx_diagnostic/reference/quantized_tensor.py +46 -0
  63. onnx_diagnostic/reference/report_results_comparison.py +95 -0
  64. onnx_diagnostic/reference/torch_evaluator.py +669 -0
  65. onnx_diagnostic/reference/torch_ops/__init__.py +56 -0
  66. onnx_diagnostic/reference/torch_ops/_op_run.py +335 -0
  67. onnx_diagnostic/reference/torch_ops/access_ops.py +94 -0
  68. onnx_diagnostic/reference/torch_ops/binary_ops.py +108 -0
  69. onnx_diagnostic/reference/torch_ops/controlflow_ops.py +121 -0
  70. onnx_diagnostic/reference/torch_ops/generator_ops.py +36 -0
  71. onnx_diagnostic/reference/torch_ops/nn_ops.py +196 -0
  72. onnx_diagnostic/reference/torch_ops/other_ops.py +106 -0
  73. onnx_diagnostic/reference/torch_ops/reduce_ops.py +130 -0
  74. onnx_diagnostic/reference/torch_ops/sequence_ops.py +65 -0
  75. onnx_diagnostic/reference/torch_ops/shape_ops.py +121 -0
  76. onnx_diagnostic/reference/torch_ops/unary_ops.py +93 -0
  77. onnx_diagnostic/tasks/__init__.py +90 -0
  78. onnx_diagnostic/tasks/automatic_speech_recognition.py +188 -0
  79. onnx_diagnostic/tasks/data/__init__.py +13 -0
  80. onnx_diagnostic/tasks/data/dummies_imagetext2text_generation_gemma3.onnx +0 -0
  81. onnx_diagnostic/tasks/feature_extraction.py +162 -0
  82. onnx_diagnostic/tasks/fill_mask.py +89 -0
  83. onnx_diagnostic/tasks/image_classification.py +144 -0
  84. onnx_diagnostic/tasks/image_text_to_text.py +581 -0
  85. onnx_diagnostic/tasks/image_to_video.py +127 -0
  86. onnx_diagnostic/tasks/mask_generation.py +143 -0
  87. onnx_diagnostic/tasks/mixture_of_expert.py +79 -0
  88. onnx_diagnostic/tasks/object_detection.py +134 -0
  89. onnx_diagnostic/tasks/sentence_similarity.py +89 -0
  90. onnx_diagnostic/tasks/summarization.py +227 -0
  91. onnx_diagnostic/tasks/text2text_generation.py +230 -0
  92. onnx_diagnostic/tasks/text_classification.py +89 -0
  93. onnx_diagnostic/tasks/text_generation.py +352 -0
  94. onnx_diagnostic/tasks/text_to_image.py +95 -0
  95. onnx_diagnostic/tasks/zero_shot_image_classification.py +128 -0
  96. onnx_diagnostic/torch_export_patches/__init__.py +21 -0
  97. onnx_diagnostic/torch_export_patches/eval/__init__.py +725 -0
  98. onnx_diagnostic/torch_export_patches/eval/model_cases.py +898 -0
  99. onnx_diagnostic/torch_export_patches/onnx_export_errors.py +1098 -0
  100. onnx_diagnostic/torch_export_patches/onnx_export_serialization.py +311 -0
  101. onnx_diagnostic/torch_export_patches/patch_details.py +340 -0
  102. onnx_diagnostic/torch_export_patches/patch_expressions.py +108 -0
  103. onnx_diagnostic/torch_export_patches/patch_inputs.py +211 -0
  104. onnx_diagnostic/torch_export_patches/patch_module.py +1047 -0
  105. onnx_diagnostic/torch_export_patches/patch_module_helper.py +184 -0
  106. onnx_diagnostic/torch_export_patches/patches/__init__.py +0 -0
  107. onnx_diagnostic/torch_export_patches/patches/patch_torch.py +1090 -0
  108. onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +2139 -0
  109. onnx_diagnostic/torch_export_patches/serialization/__init__.py +46 -0
  110. onnx_diagnostic/torch_export_patches/serialization/diffusers_impl.py +34 -0
  111. onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py +313 -0
  112. onnx_diagnostic/torch_models/__init__.py +0 -0
  113. onnx_diagnostic/torch_models/code_sample.py +343 -0
  114. onnx_diagnostic/torch_models/hghub/__init__.py +1 -0
  115. onnx_diagnostic/torch_models/hghub/hub_api.py +422 -0
  116. onnx_diagnostic/torch_models/hghub/hub_data.py +234 -0
  117. onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +4905 -0
  118. onnx_diagnostic/torch_models/hghub/model_inputs.py +388 -0
  119. onnx_diagnostic/torch_models/hghub/model_specific.py +76 -0
  120. onnx_diagnostic/torch_models/llms.py +2 -0
  121. onnx_diagnostic/torch_models/untrained/__init__.py +0 -0
  122. onnx_diagnostic/torch_models/untrained/llm_phi2.py +113 -0
  123. onnx_diagnostic/torch_models/untrained/llm_tiny_llm.py +76 -0
  124. onnx_diagnostic/torch_models/validate.py +2124 -0
  125. onnx_diagnostic/torch_onnx/__init__.py +0 -0
  126. onnx_diagnostic/torch_onnx/runtime_info.py +289 -0
  127. onnx_diagnostic/torch_onnx/sbs.py +440 -0
  128. onnx_diagnostic-0.8.0.dist-info/METADATA +213 -0
  129. onnx_diagnostic-0.8.0.dist-info/RECORD +132 -0
  130. onnx_diagnostic-0.8.0.dist-info/WHEEL +5 -0
  131. onnx_diagnostic-0.8.0.dist-info/licenses/LICENSE.txt +19 -0
  132. onnx_diagnostic-0.8.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,386 @@
1
+ import pprint
2
+ from typing import Dict, List, Optional, Sequence, Set, Tuple, Union
3
+ import onnx
4
+ import onnx.helper as oh
5
+
6
+
7
class GraphRendering:
    """
    Helpers to render a graph as text.

    :param proto: model, graph or function to render.
    """

    def __init__(
        self, proto: "Union[onnx.FunctionProto, onnx.GraphProto, onnx.ModelProto]"
    ):
        self.proto = proto

    def __repr__(self) -> str:
        "usual"
        return f"{type(self).__name__}(<{type(self.proto).__name__}>)"

    @classmethod
    def computation_order(
        cls,
        nodes: "Sequence[onnx.NodeProto]",
        existing: Optional[List[str]] = None,
        start: int = 1,
    ) -> List[int]:
        """
        Returns the soonest a node can be computed,
        every node can assume all results with a lower number exist.
        A node with inputs is numbered one more than its latest input.

        :param nodes: list of nodes, every input must be either in *existing*
            or produced by a node placed earlier in the list
        :param existing: names available before any computation starts,
            they are numbered ``start - 1``
        :param start: number given to the nodes which can be computed first
        :return: computation order, one integer per node
        """
        op_types = {n.op_type for n in nodes}
        assert not ({"If", "Scan", "Loop", "SequenceMap"} & op_types), (
            f"This algorithme is not yet implemented if the sequence contains "
            f"a control flow, types={sorted(op_types)}"
        )
        # number[name] is the order of the node producing ``name``
        number = dict.fromkeys(existing or [], start - 1)
        results = [start] * len(nodes)
        for i_node, node in enumerate(nodes):
            missing = [i for i in node.input if i not in number]
            assert (
                not missing
            ), f"Missing input in node {i_node} type={node.op_type}: {missing}"
            if node.input:
                level = max(number[i] for i in node.input) + 1
                results[i_node] = level
            else:
                # A constant: the node itself keeps ``start`` but its outputs
                # are tagged with the highest number seen so far.
                level = max(number.values()) if number else 0
            for name in node.output:
                number[name] = level
        return results

    @classmethod
    def graph_positions(
        cls,
        nodes: "Sequence[onnx.NodeProto]",
        order: List[int],
        existing: Optional[List[str]] = None,
    ) -> List[Tuple[int, int]]:
        """
        Returns positions on a plan for every node in a graph.
        The function tries to limit the number of lines crossing each others.
        It goes forward, every row is ordered depending on the columns
        given to the rows above. It could be improved with more iterations.

        :param nodes: list of nodes
        :param order: computation order returned by
            :meth:`onnx_diagnostic.helpers.graph_helper.GraphRendering.computation_order`
        :param existing: existing names, currently unused
        :return: list of tuple(row, column), empty if *order* is empty
        """
        if not order:
            return []
        first_row = min(order)
        last_row = max(order)

        # groups node indices by row once, the original node order is kept
        by_row: Dict[int, List[int]] = {}
        for index, level in enumerate(order):
            by_row.setdefault(level, []).append(index)

        # column given to the node producing every result seen so far
        column_of: Dict[str, int] = {}
        positions = [(first_row, i) for i in range(len(order))]
        for row in range(first_row, last_row + 1):
            indices = by_row.get(row, [])
            assert indices, f"indices cannot be empty for row={row}, order={order}"
            # a node is anchored to the right-most column among its inputs
            anchors = [
                max((column_of.get(name, 0) for name in nodes[i].input), default=0)
                for i in indices
            ]
            # stable sort: ties keep the order of the nodes in the list
            ranking = sorted(range(len(indices)), key=anchors.__getitem__)
            for column, k in enumerate(ranking):
                positions[indices[k]] = (row, column)
                for name in nodes[indices[k]].output:
                    column_of[name] = column

        return positions

    @classmethod
    def text_positions(
        cls, nodes: "Sequence[onnx.NodeProto]", positions: List[Tuple[int, int]]
    ) -> List[Tuple[int, int]]:
        """
        Returns positions for the nodes assuming it is rendered into text.
        Rows are spaced by four lines, every text column is wide enough
        to hold the longest operator type it contains.

        :param nodes: list of nodes
        :param positions: positions returned by
            :meth:`onnx_diagnostic.helpers.graph_helper.GraphRendering.graph_positions`
        :return: text positions (row, column of the center of the text),
            empty if *positions* is empty
        """
        if not positions:
            return []
        # every row is shifted by one text column to limit vertical overlaps
        new_positions = [(row * 4, col * 2 + row) for row, col in positions]
        column_size: Dict[int, int] = {}
        for i, (_row, col) in enumerate(new_positions):
            column_size[col] = max(column_size.get(col, 3), len(nodes[i].op_type) + 5)
            assert column_size[col] < 200, (
                f"column_size[{col}]={column_size[col]}, this is quite big, i={i}, "
                f"nodes[i].op_type={nodes[i].op_type}"
            )

        # cumulated widths: the center of a column is the total width of the
        # columns on its left plus half of its own width
        centers: Dict[int, int] = {}
        total = 0
        for col, size in sorted(column_size.items()):
            centers[col] = total + size // 2
            total += size
        return [(row, centers[col]) for row, col in new_positions]

    @property
    def nodes(self) -> "List[onnx.NodeProto]":
        "Returns the list of nodes"
        holder = self.proto.graph if isinstance(self.proto, onnx.ModelProto) else self.proto
        return holder.node

    @property
    def start_names(self) -> List[str]:
        "Returns the list of known names, inputs and initializer"
        graph = self.proto.graph if isinstance(self.proto, onnx.ModelProto) else self.proto
        if isinstance(graph, onnx.FunctionProto):
            # a function stores its inputs as names and has no initializer
            return list(graph.input)
        return [
            *[i.name for i in graph.input],
            *[i.name for i in graph.initializer],
            *[i.name for i in graph.sparse_initializer],
        ]

    @property
    def input_names(self) -> List[str]:
        "Returns the list of input names."
        if isinstance(self.proto, onnx.FunctionProto):
            return list(self.proto.input)
        graph = self.proto if isinstance(self.proto, onnx.GraphProto) else self.proto.graph
        return [i.name for i in graph.input]

    @property
    def output_names(self) -> List[str]:
        "Returns the list of output names."
        if isinstance(self.proto, onnx.FunctionProto):
            return list(self.proto.output)
        graph = self.proto if isinstance(self.proto, onnx.GraphProto) else self.proto.graph
        return [i.name for i in graph.output]

    @classmethod
    def build_node_edges(cls, nodes: "Sequence[onnx.NodeProto]") -> Set[Tuple[int, int]]:
        """
        Builds the set of edges between nodes.
        An edge ``(i, j)`` means node *j* consumes a result produced by node *i*.
        Inputs not produced by any node of the list create no edge.
        """
        # if a name is produced twice, the last producer wins
        produced = {name: i for i, node in enumerate(nodes) for name in node.output}
        return {
            (produced[name], i)
            for i, node in enumerate(nodes)
            for name in node.input
            if name in produced
        }

    # ADD_RULES[s1, s2] is the symbol to draw when symbol s2 is drawn
    # over a cell already containing symbol s1.
    ADD_RULES = {
        ("┴", "┘"): "┴",
        ("┴", "└"): "┴",
        ("┬", "┐"): "┬",
        ("┬", "┌"): "┬",
        ("-", "└"): "┴",
        ("-", "|"): "┼",
        ("-", "┐"): "┬",
        ("┐", "-"): "┬",
        ("┘", "-"): "┴",
        ("┴", "-"): "┴",
        ("-", "┘"): "┴",
        ("┌", "-"): "┬",
        ("┬", "-"): "┬",
        ("-", "┌"): "┬",
        ("|", "-"): "┼",
        ("└", "-"): "┴",
        ("|", "└"): "├",
        ("|", "┘"): "┤",
        ("┐", "|"): "┤",
        ("┬", "|"): "┼",
        ("|", "┐"): "┤",
        ("|", "┌"): "├",
        ("├", "-"): "┼",
        ("└", "|"): "├",
        ("┤", "┐"): "┤",
        ("┤", "|"): "┤",
        ("├", "|"): "├",
        ("┴", "┌"): "┼",
        ("┐", "┌"): "┬",
        ("┌", "┐"): "┬",
        ("┌", "|"): "┼",
        ("┴", "┐"): "┼",
        ("┐", "└"): "┼",
        ("┬", "┘"): "┼",
        ("├", "└"): "├",
        ("┤", "┌"): "┼",
        ("┘", "|"): "┤",
        ("┴", "|"): "┼",
        ("┤", "-"): "┼",
        ("┘", "└"): "┴",
    }

    @classmethod
    def text_grid(cls, grid: List[List[str]], position: Tuple[int, int], text: str):
        """
        Prints inplace a text in a grid. The text is centered.

        :param grid: grid, a list of rows, every row is a list of characters
        :param position: position (row, column) of the center of the text
        :param text: text to print
        """
        row, col = position
        begin = col - len(text) // 2
        # the caller must leave enough room on both sides, the slice
        # assignment changes the length of the row otherwise
        grid[row][begin : begin + len(text)] = list(text)

    @classmethod
    def text_edge(
        cls,
        grid: List[List[str]],
        p1: Tuple[int, int],
        p2: Tuple[int, int],
        mode: str = "square",
    ):
        """
        Prints inplace an edge in a grid.
        The edge leaves *p1* downwards, turns on row ``p1[0] + 2``, and reaches
        *p2* from above after a turn on row ``p2[0] - 2``. When both rows are
        different, the edge goes down along the column in the middle
        of *p1* and *p2*. When both positions share the same column,
        the edge is a straight vertical line.

        :param grid: grid, modified inplace
        :param p1: first position (row, column), the upper one
        :param p2: second position (row, column), at least four rows below *p1*
        :param mode: ``'square'`` is the only supported value
        :raises NotImplementedError: if a symbol must be drawn over another one
            and :attr:`ADD_RULES` does not tell how to merge them
        """
        assert mode == "square", f"mode={mode!r} not supported"
        assert p1[0] < p2[0], f"Unexpected edge p1={p1}, p2={p2}"
        assert p1[0] + 2 <= p2[0] - 2, f"Unexpected edge p1={p1}, p2={p2}"
        # removes this when the algorithm is ready
        assert 0 <= p1[0] < len(grid) - 3, f"p1={p1}, grid:{len(grid)},{len(grid[0])}"
        assert 2 <= p2[0] < len(grid) - 1, f"p2={p2}, grid:{len(grid)},{len(grid[0])}"
        assert (
            0 <= p1[1] < min(len(g) for g in grid)
        ), f"p1={p1}, sizes={[len(g) for g in grid]}"
        assert (
            0 <= p2[1] < min(len(g) for g in grid)
        ), f"p2={p2}, sizes={[len(g) for g in grid]}"

        def add(s1, s2):
            # merges the new symbol s2 with the symbol s1 already in the cell
            assert s2 != " ", f"s1={s1!r}, s2={s2!r}"
            if s1 == " " or s1 == s2:
                return s2
            if s1 == "┼" or s2 == "┼":
                return "┼"
            merged = cls.ADD_RULES.get((s1, s2))
            if merged is None:
                raise NotImplementedError(f"Unable to add: ({s1!r},{s2!r}): '',")
            return merged

        def place(grid, x, y, symbol):
            grid[x][y] = add(grid[x][y], symbol)

        def hline(row, c1, c2):
            # horizontal line strictly between columns c1 and c2
            for c in range(min(c1, c2) + 1, max(c1, c2)):
                place(grid, row, c, "-")

        if p1[1] == p2[1]:
            # Same column: no turn is needed. Drawing the two corners in the
            # same cell has no merge rule and used to fail.
            for r in range(p1[0] + 1, p2[0]):
                place(grid, r, p1[1], "|")
            return

        to_right = p1[1] < p2[1]
        top, bottom = p1[0] + 2, p2[0] - 2

        place(grid, p1[0] + 1, p1[1], "|")
        place(grid, top, p1[1], "└" if to_right else "┘")

        if top == bottom:
            # both turns are on the same row, one horizontal line joins them
            hline(top, p1[1], p2[1])
        else:
            middle = (p1[1] + p2[1]) // 2
            # upper horizontal part, from p1 to the middle column
            hline(top, p1[1], middle)
            place(grid, top, middle, "┐" if to_right else "┌")
            for r in range(top + 1, bottom):
                place(grid, r, middle, "|")
            place(grid, bottom, middle, "└" if to_right else "┘")
            # lower horizontal part, from the middle column to p2
            hline(bottom, middle, p2[1])

        place(grid, bottom, p2[1], "┐" if to_right else "┌")
        place(grid, p2[0] - 1, p2[1], "|")

    def text_rendering(self, prefix="") -> str:
        """
        Renders a model in text.
        Inputs and outputs are rendered as nodes named after them.

        .. runpython::
            :showcode:

            import textwrap
            import onnx
            import onnx.helper as oh
            from onnx_diagnostic.helpers.graph_helper import GraphRendering

            TFLOAT = onnx.TensorProto.FLOAT

            proto = oh.make_model(
                oh.make_graph(
                    [
                        oh.make_node("Add", ["X", "Y"], ["xy"]),
                        oh.make_node("Neg", ["Y"], ["ny"]),
                        oh.make_node("Mul", ["xy", "ny"], ["a"]),
                        oh.make_node("Mul", ["a", "Y"], ["Z"]),
                    ],
                    "-nd-",
                    [
                        oh.make_tensor_value_info("X", TFLOAT, ["a", "b", "c"]),
                        oh.make_tensor_value_info("Y", TFLOAT, ["a", "b", "c"]),
                    ],
                    [oh.make_tensor_value_info("Z", TFLOAT, ["a", "b", "c"])],
                ),
                opset_imports=[oh.make_opsetid("", 18)],
                ir_version=9,
            )
            graph = GraphRendering(proto)
            text = textwrap.dedent(graph.text_rendering()).strip("\\n")
            print(text)

        :param prefix: string added at the beginning of every line
        :return: the rendering, an empty string if there is nothing to draw
        """
        input_names = self.input_names
        # inputs and outputs become pseudo nodes linked to BEGIN and END
        nodes = [
            *[oh.make_node(i, ["BEGIN"], [i]) for i in input_names],
            *self.nodes,
            *[oh.make_node(i, [i], ["END"]) for i in self.output_names],
        ]
        if not nodes:
            return ""
        # initializers are known from the start, inputs are produced
        # by the pseudo nodes added above
        exist = set(self.start_names) - set(input_names)
        exist |= {"BEGIN"}
        existing = sorted(exist)
        order = self.computation_order(nodes, existing)
        positions = self.graph_positions(nodes, order, existing)
        text_pos = self.text_positions(nodes, positions)
        edges = self.build_node_edges(nodes)
        max_len = max(col for _, col in text_pos) + max(len(n.op_type) for n in nodes)
        assert max_len < 1e6, f"max_len={max_len}, text_pos=\n{pprint.pformat(text_pos)}"
        max_row = max(row for row, _ in text_pos) + 2
        grid = [[" "] * (max_len + 1) for _ in range(max_row + 1)]

        # edges first, the names of the nodes are written over them
        for n1, n2 in edges:
            self.text_edge(grid, text_pos[n1], text_pos[n2])
        assert len(set(len(g) for g in grid)) == 1, f"lengths={[len(g) for g in grid]}"
        for node, pos in zip(nodes, text_pos):
            self.text_grid(grid, pos, node.op_type)
        assert len(set(len(g) for g in grid)) == 1, f"lengths={[len(g) for g in grid]}"

        return "\n".join(f"{prefix}{''.join(line).rstrip()}" for line in grid)