onnx-diagnostic 0.8.4__py3-none-any.whl → 0.8.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnx_diagnostic/__init__.py +1 -1
- onnx_diagnostic/_command_lines_parser.py +67 -9
- onnx_diagnostic/ci_models/__init__.py +0 -0
- onnx_diagnostic/ci_models/ci_helpers.py +430 -0
- onnx_diagnostic/ci_models/export_qwen25_vl.py +560 -0
- onnx_diagnostic/export/api.py +15 -4
- onnx_diagnostic/export/cf_simple_loop_for.py +352 -0
- onnx_diagnostic/export/control_flow_onnx.py +23 -17
- onnx_diagnostic/export/onnx_plug.py +60 -6
- onnx_diagnostic/ext_test_case.py +14 -0
- onnx_diagnostic/helpers/helper.py +26 -27
- onnx_diagnostic/torch_export_patches/onnx_export_errors.py +16 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_masking_utils.py +10 -1
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2_5.py +103 -31
- onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +1 -0
- onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +29 -8
- onnx_diagnostic/torch_onnx/compare.py +357 -0
- {onnx_diagnostic-0.8.4.dist-info → onnx_diagnostic-0.8.6.dist-info}/METADATA +1 -1
- {onnx_diagnostic-0.8.4.dist-info → onnx_diagnostic-0.8.6.dist-info}/RECORD +22 -19
- onnx_diagnostic/export/control_flow.py +0 -214
- onnx_diagnostic/export/control_flow_research.py +0 -140
- {onnx_diagnostic-0.8.4.dist-info → onnx_diagnostic-0.8.6.dist-info}/WHEEL +0 -0
- {onnx_diagnostic-0.8.4.dist-info → onnx_diagnostic-0.8.6.dist-info}/licenses/LICENSE.txt +0 -0
- {onnx_diagnostic-0.8.4.dist-info → onnx_diagnostic-0.8.6.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,357 @@
|
|
|
1
|
+
import enum
|
|
2
|
+
from dataclasses import dataclass
|
|
3
|
+
from typing import Dict, List, Optional, Tuple, Union
|
|
4
|
+
import numpy as np
|
|
5
|
+
import onnx
|
|
6
|
+
from ..helpers.onnx_helper import onnx_dtype_name
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
# Groups of operator types treated as "close" to each other by
# ObsCompare.distance: a node op_type mismatch inside one group gets a
# reduced name-mismatch cost compared with an unrelated op_type mismatch.
_NOT_SO_FAR_OPS = [
    {"MatMul", "Gemm", "FusedMatMul"},
    {"Conv", "FusedConv"},
    {"MaxPool"},
]


def _sum_sets(sets):
    """
    Returns the union of every set in *sets* as a new set,
    the input sets are left untouched.
    """
    merged: set = set()
    for group in sets:
        merged |= group
    return merged


# Every operator type appearing in any group of _NOT_SO_FAR_OPS.
_ALL_NOT_SO_FAR_OPS = _sum_sets(_NOT_SO_FAR_OPS)
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
def _align(res: str, limit: int) -> str:
    """
    Fits *res* into exactly *limit* characters: longer strings are
    truncated, shorter ones are right-padded with spaces.
    """
    truncated = res[:limit]
    return truncated.ljust(limit)
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
class ObsType(enum.IntEnum):
    """Observation kind, every value is a distinct power of two."""

    RESULT = 1
    INITIALIZER = 2
    SPARSE_INITIALIZER = 4
    INPUT = 8
    OUTPUT = 16
    NODE = 32

    def __repr__(self):
        # Shows ``ObsType.NODE`` rather than the default ``<ObsType.NODE: 32>``.
        owner = type(self).__name__
        return f"{owner}.{self.name}"
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
@dataclass
class ObsCompare:
    """
    The description of an observation, a node, an input, an output, an initializer.

    :param position: index of this observation in the sequence of observations
        built from the model (see :meth:`obs_sequence_from_model`)
    :param kind: observation kind, see :class:`ObsType`
    :param name_or_outputs: name of an input, an output or an initializer,
        or the outputs of a node
    :param itype: onnx element type, 0 when unknown
    :param index: index of an input, an output, or of a node in the graph
    :param shape: shape, None when unknown
    :param op_type: node op_type
    :param comment: comment, unused
    """

    position: int
    kind: ObsType
    name_or_outputs: Tuple[str, ...]
    itype: int = 0
    index: int = 0
    shape: Optional[Tuple[Union[int, str], ...]] = None
    op_type: str = ""
    comment: str = ""

    def __str__(self) -> str:
        """Displays the observation as one line of fixed-width columns."""
        els = [
            _align(f"{self.position:04d}", 4),
            _align(self.kind._name_, 6),
            _align(onnx_dtype_name(self.itype) if self.itype else "?", 8),
            _align("?" if self.shape is None else "x".join(map(str, self.shape)), 18),
            _align(self.op_type or "", 15),
            _align(", ".join(self.name_or_outputs), 35),
        ]
        return " ".join(els)

    @classmethod
    def to_str(cls, obs: Optional["ObsCompare"]) -> str:
        """Displays *obs*, or blanks of the same width when *obs* is None."""
        assert not obs or isinstance(obs, ObsCompare), f"unexpected type {type(obs)}"
        if obs:
            return str(obs)
        # Column widths used by __str__ plus the 5 separating spaces.
        return " " * (4 + 6 + 8 + 18 + 15 + 35 + 5)

    @staticmethod
    def _is_exporter_name(name: str) -> bool:
        """
        Tells if *name* looks like a name created by an exporter
        rather than a name inspired from the model itself.
        """
        return name.startswith(("val_", "_onx_")) or "::" in name or "--" in name

    def _node_distance(self, obs: "ObsCompare") -> float:
        """Cost between two nodes based on their op_type and output names."""
        # Cost of a name mismatch, reduced when both op_types
        # belong to the same group of _NOT_SO_FAR_OPS.
        # NOTE(review): the element type does not contribute to the node cost,
        # confirm this is intended.
        cost = 9997
        d: float = 0
        if self.op_type != obs.op_type:
            if self.op_type in _ALL_NOT_SO_FAR_OPS or obs.op_type in _ALL_NOT_SO_FAR_OPS:
                d += 1e2
                for aset in _NOT_SO_FAR_OPS:
                    if self.op_type in aset and obs.op_type in aset:
                        cost = 97
                    elif self.op_type in aset or obs.op_type in aset:
                        d += 5e4
            else:
                d += 9e2
        if len(self.name_or_outputs) == 1 and len(obs.name_or_outputs) == 1:
            n1 = self.name_or_outputs[0]
            n2 = obs.name_or_outputs[0]
            if n1 != n2:
                if n1.replace("_", "") == n2.replace("_", ""):
                    d += 1
                elif self._is_exporter_name(n1) and self._is_exporter_name(n2):
                    # Both names were produced by exporters, a mismatch
                    # on such names is not very informative.
                    d += cost / 100
                else:
                    d += cost
        else:
            common = set(self.name_or_outputs) & set(obs.name_or_outputs)
            union = set(self.name_or_outputs) | set(obs.name_or_outputs)
            d += cost * (len(union) - len(common))
        return d

    def distance(self, obs: "ObsCompare") -> float:
        """
        Computes a cost between two observations.

        Observations of different kinds cost ``1e6``. Inputs, outputs and
        initializers cost a constant when their type or shape (and index for
        inputs and outputs) differ, 0 otherwise. Nodes are compared on their
        op_type and their output names.

        :param obs: the other observation
        :return: a non-negative cost, 0 when the compared attributes match
        """
        if self.kind != obs.kind:
            return 1e6
        if self.kind == ObsType.NODE:
            return self._node_distance(obs)
        if self.kind == ObsType.INPUT:
            return (
                999.7
                if self.itype != obs.itype
                or self.shape != obs.shape
                or self.index != obs.index
                else 0
            )
        if self.kind in (ObsType.INITIALIZER, ObsType.SPARSE_INITIALIZER):
            return 1e3 if self.itype != obs.itype or self.shape != obs.shape else 0
        if self.kind == ObsType.OUTPUT:
            return (
                999.1
                if self.itype != obs.itype
                or self.shape != obs.shape
                or self.index != obs.index
                else 0
            )
        return 1e8

    @classmethod
    def obs_sequence_from_model(
        cls,
        model: Union[onnx.ModelProto, onnx.GraphProto],
    ) -> List["ObsCompare"]:
        """
        Creates a sequence of observations based on a model.

        The sequence contains, in this order, the initializers, the inputs,
        the nodes and the outputs of the main graph.

        :param model: model or graph
        :return: sequence of observations
        """
        graph = model if isinstance(model, onnx.GraphProto) else model.graph

        def _dims(tensor_type) -> Tuple[Union[int, str], ...]:
            # A symbolic dimension (dim_param) is used when set,
            # the numeric one (dim_value) otherwise.
            return tuple((d.dim_param or d.dim_value) for d in tensor_type.shape.dim)

        shapes: Dict[str, Tuple[Union[int, str], ...]] = {}
        types: Dict[str, int] = {}
        for info in [*graph.value_info, *graph.input, *graph.output]:
            # Only tensor types carry an element type and a shape.
            if info.type.HasField("tensor_type"):
                t = info.type.tensor_type
                shapes[info.name] = _dims(t)
                types[info.name] = t.elem_type

        seq: List[ObsCompare] = []
        for init in graph.initializer:
            seq.append(
                ObsCompare(
                    position=len(seq),
                    kind=ObsType.INITIALIZER,
                    itype=init.data_type,
                    shape=tuple(init.dims),
                    name_or_outputs=(init.name,),
                )
            )
        for i, inp in enumerate(graph.input):
            seq.append(
                ObsCompare(
                    position=len(seq),
                    kind=ObsType.INPUT,
                    itype=inp.type.tensor_type.elem_type,
                    index=i,
                    shape=_dims(inp.type.tensor_type),
                    name_or_outputs=(inp.name,),
                )
            )
        for i, node in enumerate(graph.node):
            # The type and shape of a node are the ones of its first output.
            first_output = node.output[0] if node.output else ""
            seq.append(
                ObsCompare(
                    position=len(seq),
                    kind=ObsType.NODE,
                    itype=types.get(first_output, 0),
                    index=i,
                    shape=shapes.get(first_output, None),
                    name_or_outputs=tuple(node.output),
                    op_type=node.op_type,
                )
            )
        for i, out in enumerate(graph.output):
            seq.append(
                ObsCompare(
                    position=len(seq),
                    kind=ObsType.OUTPUT,
                    itype=out.type.tensor_type.elem_type,
                    index=i,
                    shape=_dims(out.type.tensor_type),
                    name_or_outputs=(out.name,),
                )
            )
        return seq
|
|
219
|
+
|
|
220
|
+
|
|
221
|
+
@dataclass
class ObsComparePair:
    """
    Defines a pair of comparison objects, one step of an alignment
    between two sequences of observations.

    :param side1: observation from the first sequence, None when this step
        only takes an observation from the second sequence
    :param side2: observation from the second sequence, None when this step
        only takes an observation from the first sequence
    :param distance: distance, :meth:`distance_sequence` stores the
        cumulative distance of the alignment up to this step
    """

    side1: Optional[ObsCompare]
    side2: Optional[ObsCompare]
    distance: float

    def __str__(self) -> str:
        """Displays the distance and both observations side by side."""
        return (
            f"{self.distance:.4e} | "
            f"{ObsCompare.to_str(self.side1)} | {ObsCompare.to_str(self.side2)}"
        )

    @classmethod
    def to_str(cls, seq: List["ObsComparePair"]) -> str:
        """Displays every pair in text."""
        return "\n".join([f"{str(pair)}" for pair in seq])

    @classmethod
    def distance_sequence(cls, s1: List["ObsCompare"], s2: List["ObsCompare"]) -> Tuple[
        float,
        List[Tuple[int, int]],
        List["ObsComparePair"],
    ]:
        """
        Computes the distance between two sequences of results.

        :param s1: first sequence
        :param s2: second sequence
        :return: the total distance, the alignment as a list of index pairs
            ``(i, j)`` (``-1`` means no element of that sequence has been
            consumed yet), and the alignment as a list of
            :class:`ObsComparePair`

        An example:

        .. runpython::
            :showcode:

            import torch
            from onnx_diagnostic.export.api import to_onnx
            from onnx_diagnostic.torch_onnx.compare import ObsComparePair, ObsCompare


            class Model(torch.nn.Module):
                def __init__(self):
                    super().__init__()
                    self.conv1 = torch.nn.Conv2d(3, 16, 5)
                    self.fc1 = torch.nn.Linear(144, 64)
                    self.fc2 = torch.nn.Linear(64, 128)
                    self.fc3 = torch.nn.Linear(128, 10)

                def forward(self, x):
                    x = torch.nn.functional.max_pool2d(
                        torch.nn.functional.relu(self.conv1(x)),
                        (4, 4),
                    )
                    # x = F.max_pool2d(F.relu(self.conv2(x)), 2)
                    x = torch.flatten(x, 1)
                    x = torch.nn.functional.relu(self.fc1(x))
                    x = torch.nn.functional.relu(self.fc2(x))
                    y = self.fc3(x)
                    return y


            model = Model()
            x = torch.randn((2, 3, 16, 17), dtype=torch.float32)
            dynamic_shapes = ({0: "batch", 3: "dim"},)
            onnx_optimized = to_onnx(
                model, (x,), dynamic_shapes=dynamic_shapes, exporter="custom", optimize=True
            ).model_proto
            onnx_not_optimized = to_onnx(
                model, (x,), dynamic_shapes=dynamic_shapes, exporter="custom", optimize=False
            ).model_proto
            seq1 = ObsCompare.obs_sequence_from_model(onnx_not_optimized)
            seq2 = ObsCompare.obs_sequence_from_model(onnx_optimized)
            _dist, _path, pair_cmp = ObsComparePair.distance_sequence(seq1, seq2)
            text = ObsComparePair.to_str(pair_cmp)
            print(text)
        """
        # Only pairs (i, j) with |i - j| < delay are explored.
        delay = max(50, abs(len(s2) - len(s1)) + 1)
        insert_cost = 1e3
        # Skipping an element of s1 costs slightly more than skipping one of s2.
        skip1 = insert_cost + 1
        skip2 = insert_cost + 0.1
        distance: Dict[Tuple[int, int], Union[int, float]] = {(-1, -1): 0}
        predecessor: Dict[Tuple[int, int], Optional[Tuple[int, int]]] = {(-1, -1): None}

        # Boundary states: (i, -1) skips s1[0..i], (-1, j) skips s2[0..j].
        # They allow the alignment to start with skipped elements instead of
        # always pairing s1[0] with s2[0], and handle empty sequences.
        for i in range(len(s1)):
            distance[i, -1] = distance[i - 1, -1] + skip1
            predecessor[i, -1] = (i - 1, -1)
        for j in range(len(s2)):
            distance[-1, j] = distance[-1, j - 1] + skip2
            predecessor[-1, j] = (-1, j - 1)

        for i in range(len(s1)):
            for j in range(max(0, i - delay), min(len(s2), i + delay)):
                best: Union[int, float] = 1e100
                pred: Optional[Tuple[int, int]] = None
                # pairs s1[i] with s2[j]
                if (i - 1, j - 1) in distance:
                    d = distance[i - 1, j - 1] + s1[i].distance(s2[j])
                    if d < best:
                        best, pred = d, (i - 1, j - 1)
                # skips s1[i]
                if (i - 1, j) in distance:
                    d = distance[i - 1, j] + skip1
                    if d < best:
                        best, pred = d, (i - 1, j)
                # skips s2[j]
                if (i, j - 1) in distance:
                    d = distance[i, j - 1] + skip2
                    if d < best:
                        best, pred = d, (i, j - 1)
                distance[i, j] = best
                predecessor[i, j] = pred

        # Walks back from the last state to (-1, -1).
        way: List[Tuple[int, int]] = []
        last: Optional[Tuple[int, int]] = len(s1) - 1, len(s2) - 1
        while last is not None:
            way.append(last)
            last = predecessor[last]
        # Drops the starting state (-1, -1).
        indices = list(reversed(way))[1:]

        obs_path: List[ObsComparePair] = []
        last = -1, -1
        for i, j in indices:
            di = i - last[0]
            dj = j - last[1]
            cost = distance.get((i, j), np.nan)
            if di == dj == 1:
                obs_path.append(ObsComparePair(s1[i], s2[j], distance=cost))
            elif di == 0:
                obs_path.append(ObsComparePair(None, s2[j], distance=cost))
            elif dj == 0:
                obs_path.append(ObsComparePair(s1[i], None, distance=cost))
            else:
                raise RuntimeError(f"issue with di={di}, dj={dj}")
            last = i, j
        return distance[len(s1) - 1, len(s2) - 1], indices, obs_path
|
|
@@ -1,16 +1,18 @@
|
|
|
1
|
-
onnx_diagnostic/__init__.py,sha256=
|
|
1
|
+
onnx_diagnostic/__init__.py,sha256=YQit5D2idhb9-wNQZzvWLT_qwRrKWBoTqMpNlBaWsGw,173
|
|
2
2
|
onnx_diagnostic/__main__.py,sha256=YmyV_Aq_ianDlHyKLHMa6h8YK3ZmFPpLVHLKjM91aCk,79
|
|
3
|
-
onnx_diagnostic/_command_lines_parser.py,sha256=
|
|
3
|
+
onnx_diagnostic/_command_lines_parser.py,sha256=AWT6XrphbR0C0w9J846jPcRWkoUtnSSAX7gdR-JavQ4,54258
|
|
4
4
|
onnx_diagnostic/api.py,sha256=BhCl_yCd78N7TlVtPOHjeYv1QBEy39TjZ647rcHqLh0,345
|
|
5
5
|
onnx_diagnostic/doc.py,sha256=t3RELgfooYnVMAi0JSpggWkQEgUsREz8NmRvn0TnLI8,2829
|
|
6
|
-
onnx_diagnostic/ext_test_case.py,sha256=
|
|
6
|
+
onnx_diagnostic/ext_test_case.py,sha256=KxRC6s9107hYvNgU9x2B85rj8_EhAtymPIlMpmkUNu8,50154
|
|
7
|
+
onnx_diagnostic/ci_models/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
8
|
+
onnx_diagnostic/ci_models/ci_helpers.py,sha256=6CKQ4dVHBHeF6rN_Q3Y_0ZFeLYbYQbGQO3YyW3PQyAc,15341
|
|
9
|
+
onnx_diagnostic/ci_models/export_qwen25_vl.py,sha256=pyxmMIps9aDNkXzDAFrG8Q9DDOsEyKRHjoVvAggjFdU,20050
|
|
7
10
|
onnx_diagnostic/export/__init__.py,sha256=yEIoWiOeTwBsDhyYt2fTKuhtA0Ya1J9u9ZzMTOTWaWs,101
|
|
8
|
-
onnx_diagnostic/export/api.py,sha256=
|
|
9
|
-
onnx_diagnostic/export/
|
|
10
|
-
onnx_diagnostic/export/control_flow_onnx.py,sha256=
|
|
11
|
-
onnx_diagnostic/export/control_flow_research.py,sha256=RuYz9_eM42Bk6TKSiPV6dS68LIMZu-6WBCFCKoSvjrk,5422
|
|
11
|
+
onnx_diagnostic/export/api.py,sha256=BX4c99gMlRYsBWk3P15FMRogArxjP4dXYXP5gILjgIk,10626
|
|
12
|
+
onnx_diagnostic/export/cf_simple_loop_for.py,sha256=0I1tRAwhmmqA-6Qaq8AiUL0Ci-HODuRAVcI9azNcxAQ,13345
|
|
13
|
+
onnx_diagnostic/export/control_flow_onnx.py,sha256=izGlctqQANrHzSxPMbT7hoauNbnIBdx6hb8ry7HtVmM,18263
|
|
12
14
|
onnx_diagnostic/export/dynamic_shapes.py,sha256=M2hlpHSTbkzZwGKAbrpQXng5HQrwjF5Z6wGGxEgnp74,42061
|
|
13
|
-
onnx_diagnostic/export/onnx_plug.py,sha256=
|
|
15
|
+
onnx_diagnostic/export/onnx_plug.py,sha256=U13fL0BjnhMzcDGxaAOqM4TQte5Z4zKDg4ESS0iktjM,22704
|
|
14
16
|
onnx_diagnostic/export/shape_helper.py,sha256=m628y0oRCQbeZkeh8JDHIfWMsSjoJoeX-IPiPGDHT-w,11273
|
|
15
17
|
onnx_diagnostic/export/validate.py,sha256=_PGUql2DJhIgGKo0WjTGUc5AgsZUx8fEs00MePy-w98,6043
|
|
16
18
|
onnx_diagnostic/helpers/__init__.py,sha256=GJ2GT7cgnlIveVUwMZhuvUwidbTJaKv8CsSIOpZDsJg,83
|
|
@@ -23,7 +25,7 @@ onnx_diagnostic/helpers/doc_helper.py,sha256=pl5MZd3_FaE8BqQnqoBuSBxoNCFcd2OJd3e
|
|
|
23
25
|
onnx_diagnostic/helpers/dot_helper.py,sha256=hwgTJsbsUv0qq7euyPDnc1NsBZDGOwv32JXSZxIHJkE,8118
|
|
24
26
|
onnx_diagnostic/helpers/fake_tensor_helper.py,sha256=J7wnK3WTuVKnYiMzLVTAPkdJr3hQfIfMC9ZlOu7oGmI,11024
|
|
25
27
|
onnx_diagnostic/helpers/graph_helper.py,sha256=hevQT5a7_QuriVPQcbT5qe18n99Doyl5h3-qshx1-uk,14093
|
|
26
|
-
onnx_diagnostic/helpers/helper.py,sha256=
|
|
28
|
+
onnx_diagnostic/helpers/helper.py,sha256=x8EYQmgrz_G5QS_IsbeFIoDcN_sUs-CslJMHseBj1Fw,65482
|
|
27
29
|
onnx_diagnostic/helpers/log_helper.py,sha256=0lJiTF87lliI-LmgpUH_V2N8NuzJ0LryH0mSYpkRaL8,93272
|
|
28
30
|
onnx_diagnostic/helpers/memory_peak.py,sha256=M3m4_thWFIwP5HytbJYEqaijXIv5v5BW_vlcJowIYI4,6434
|
|
29
31
|
onnx_diagnostic/helpers/mini_onnx_builder.py,sha256=jR2lkRZEQ0N30H0FqeBwaxJd_w_6kyxFagrnulqFjhE,23883
|
|
@@ -100,7 +102,7 @@ onnx_diagnostic/tasks/zero_shot_image_classification.py,sha256=jJCMWuOqGv5ahCfjr
|
|
|
100
102
|
onnx_diagnostic/tasks/data/__init__.py,sha256=uJoemrWgEjI6oA-tMX7r3__x-b3siPmkgqaY7bgIles,401
|
|
101
103
|
onnx_diagnostic/tasks/data/dummies_imagetext2text_generation_gemma3.onnx,sha256=UbtvmWMqcZOKJ-I-HXWI1A6YR6QDaFS5u_yXm5C3ZBw,10299
|
|
102
104
|
onnx_diagnostic/torch_export_patches/__init__.py,sha256=0SaZedwznm1hQUCvXZsGZORV5vby954wEExr5faepGg,720
|
|
103
|
-
onnx_diagnostic/torch_export_patches/onnx_export_errors.py,sha256=
|
|
105
|
+
onnx_diagnostic/torch_export_patches/onnx_export_errors.py,sha256=OpZHNWiA0iU-6WCFZcVCj06_MopYiZQ6c6CbAuSQ8Ms,42357
|
|
104
106
|
onnx_diagnostic/torch_export_patches/onnx_export_serialization.py,sha256=0HdubI06EGpxOICqDWZoVmZkVO9gAaFADEmby197EyM,11935
|
|
105
107
|
onnx_diagnostic/torch_export_patches/patch_details.py,sha256=MSraVo5ngBhihi8ssPMXSY9B4fJ17J-GAADaw3dT-rc,11794
|
|
106
108
|
onnx_diagnostic/torch_export_patches/patch_expressions.py,sha256=vr4tt61cbDnaaaduzMj4UBZ8OUtr6GfDpIWwOYqjWzs,3213
|
|
@@ -117,15 +119,15 @@ onnx_diagnostic/torch_export_patches/patches/_patch_transformers_dynamic_cache.p
|
|
|
117
119
|
onnx_diagnostic/torch_export_patches/patches/_patch_transformers_gemma3.py,sha256=nVgYQk0xXpHiictN1wOHVMN2lTH9b0vfIJ4ie-uKopg,1999
|
|
118
120
|
onnx_diagnostic/torch_export_patches/patches/_patch_transformers_generation_mixin.py,sha256=VIZsVHgR8NmAcBQalPl5I6ZzNgcBxjGb6ars31m9gRg,21936
|
|
119
121
|
onnx_diagnostic/torch_export_patches/patches/_patch_transformers_idefics.py,sha256=kTjuTRsfkGGGhspJnMxAMQSchZgGC_IruJzpHh_FmI8,6348
|
|
120
|
-
onnx_diagnostic/torch_export_patches/patches/_patch_transformers_masking_utils.py,sha256=
|
|
122
|
+
onnx_diagnostic/torch_export_patches/patches/_patch_transformers_masking_utils.py,sha256=HE3fovyvMiYe9EPz1UjdD9AWopX3H188SMwPb8w5mzM,7111
|
|
121
123
|
onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2.py,sha256=OxYdlLrwtd_KGHt3E17poduxvWFg-CfGS57-yN1i6gI,3827
|
|
122
|
-
onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2_5.py,sha256=
|
|
124
|
+
onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2_5.py,sha256=yALbXWi3ysJ6nzQD-rxTdxdNJiBsTbYEBIj4TdksDOA,34598
|
|
123
125
|
onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen3.py,sha256=cND9Iqo1aKdlX-BXGr9Qlq_Y4EW1L5VWSwZfqYTVazU,4888
|
|
124
126
|
onnx_diagnostic/torch_export_patches/patches/_patch_transformers_rotary_embedding.py,sha256=4bJ_z2gizZQla_fcCVt0dmuhzO9Vu-D7CCMWdxMlrKM,16893
|
|
125
127
|
onnx_diagnostic/torch_export_patches/patches/_patch_transformers_sam_mask_decoder.py,sha256=-6TuBm3sLAFEGuW3vRfOTtE5uP6aINFfu7xMnl27Dws,5703
|
|
126
128
|
onnx_diagnostic/torch_export_patches/patches/patch_helper.py,sha256=kK_CGW643iVXxa-m6pttDBS7HTyMQaPypza7iqIInn4,721
|
|
127
129
|
onnx_diagnostic/torch_export_patches/patches/patch_torch.py,sha256=FfES0WWiWxmuQbGTlQ7IJS0YBG7km3IQbnMYwk_lPPU,44667
|
|
128
|
-
onnx_diagnostic/torch_export_patches/patches/patch_transformers.py,sha256=
|
|
130
|
+
onnx_diagnostic/torch_export_patches/patches/patch_transformers.py,sha256=Mvq8q1Lz3l3GyCD6j8WQjbrPk_V2dnc4iKm3cC_o1OA,3112
|
|
129
131
|
onnx_diagnostic/torch_export_patches/serialization/__init__.py,sha256=BHLdRPtNAtNPAS-bPKEj3-foGSPvwAbZXrHzGGPDLEw,1876
|
|
130
132
|
onnx_diagnostic/torch_export_patches/serialization/diffusers_impl.py,sha256=drq3EH_yjcSuIWYsVeUWm8Cx6YCZFU6bP_1PLtPfY5I,945
|
|
131
133
|
onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py,sha256=sIHFvUQoMK8ytXQYB-k7OL62z8A3f5uDaq-S5R5uN-M,10034
|
|
@@ -136,18 +138,19 @@ onnx_diagnostic/torch_models/validate.py,sha256=fnbTl5v1n5nM2MpmCgCMaWa6c7DGpb5m
|
|
|
136
138
|
onnx_diagnostic/torch_models/hghub/__init__.py,sha256=vi1Q7YHdddj1soiBN42MSvJdFqe2_KUoWafHISjwOu8,58
|
|
137
139
|
onnx_diagnostic/torch_models/hghub/hub_api.py,sha256=rFbiPNLET-KdBpnv-p0nKgwHX6d7C_Z0s9zZ86_92kQ,14307
|
|
138
140
|
onnx_diagnostic/torch_models/hghub/hub_data.py,sha256=8V_pAgACPLPsLRYUododg7MSL6str-T3tBEGY4OaeYQ,8724
|
|
139
|
-
onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py,sha256=
|
|
141
|
+
onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py,sha256=Dxa13rsnTQ8eH_BcQvbY2bp1AYFtzuFrJ-J_urrSmeQ,292694
|
|
140
142
|
onnx_diagnostic/torch_models/hghub/model_inputs.py,sha256=tCGqigRyY1omxm2rczRUvCTsweZGbF1MccWI3MmCH20,17423
|
|
141
143
|
onnx_diagnostic/torch_models/hghub/model_specific.py,sha256=j50Nu7wddJMoqmD4QzMbNdFDUUgUmSBKRzPDH55TlUQ,2498
|
|
142
144
|
onnx_diagnostic/torch_models/untrained/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
143
145
|
onnx_diagnostic/torch_models/untrained/llm_phi2.py,sha256=y_akbdApi136qHcEQgykwIAYVw0Yfi0lbjb3DNuafaU,3948
|
|
144
146
|
onnx_diagnostic/torch_models/untrained/llm_tiny_llm.py,sha256=QXw_Bs2SzfeiQMf-tmtVl83SmVOL4-Um7Qy-f0E48QI,2507
|
|
145
147
|
onnx_diagnostic/torch_onnx/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
148
|
+
onnx_diagnostic/torch_onnx/compare.py,sha256=O0lws4kzn8WAXr8-x-YMPr7oyBC9DtSIs4OfOr4S5-E,12305
|
|
146
149
|
onnx_diagnostic/torch_onnx/runtime_info.py,sha256=u1bD6VXqzBCRmqmbzQtDswaPs1PH_ygr1r-CrcfXpNU,8562
|
|
147
150
|
onnx_diagnostic/torch_onnx/sbs.py,sha256=8okBEIupMgw7TtKc80YFimMtwnY3GchdY05FsA9ooa0,40749
|
|
148
151
|
onnx_diagnostic/torch_onnx/sbs_dataclasses.py,sha256=UctdBjzoPTQG1LS0tZ8A6E9hpoq5HWUYaJLPOPJc9FI,20299
|
|
149
|
-
onnx_diagnostic-0.8.
|
|
150
|
-
onnx_diagnostic-0.8.
|
|
151
|
-
onnx_diagnostic-0.8.
|
|
152
|
-
onnx_diagnostic-0.8.
|
|
153
|
-
onnx_diagnostic-0.8.
|
|
152
|
+
onnx_diagnostic-0.8.6.dist-info/licenses/LICENSE.txt,sha256=Vv6TXglX6Rc0d-f8aREhayhT-6PMQXEyOmI2NKlUCMc,1045
|
|
153
|
+
onnx_diagnostic-0.8.6.dist-info/METADATA,sha256=9xPlJ9UHYSSIyEMqxN14mqZg31Rq9jqzHaczfqWhu-4,6734
|
|
154
|
+
onnx_diagnostic-0.8.6.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
155
|
+
onnx_diagnostic-0.8.6.dist-info/top_level.txt,sha256=KwNkXewmcobM3ZT1DJLVWH6ebJzA5qKg7cWqKfpGNT4,16
|
|
156
|
+
onnx_diagnostic-0.8.6.dist-info/RECORD,,
|
|
@@ -1,214 +0,0 @@
|
|
|
1
|
-
import contextlib
|
|
2
|
-
from typing import Any, Callable, List, Optional, Sequence, Tuple, Union
|
|
3
|
-
import torch
|
|
4
|
-
from torch._higher_order_ops.utils import (
|
|
5
|
-
materialize_as_graph,
|
|
6
|
-
check_input_alias_and_mutation_return_outputs,
|
|
7
|
-
# _maybe_reenter_make_fx,
|
|
8
|
-
)
|
|
9
|
-
|
|
10
|
-
# When True, is_exporting() reports True without asking torch.
_TEST_EXPORT = False


@contextlib.contextmanager
def enable_code_export_control_flow():
    """
    Makes :func:`is_exporting` return True while the ``with`` block runs.
    The previous value is restored on exit, even if an exception is raised.
    """
    global _TEST_EXPORT
    previous, _TEST_EXPORT = _TEST_EXPORT, True
    try:
        yield
    finally:
        _TEST_EXPORT = previous


def is_exporting() -> bool:
    """
    Tells whether the code must behave as under export: the result is
    ``_TEST_EXPORT`` when it is set (see :func:`enable_code_export_control_flow`),
    otherwise :func:`torch.compiler.is_exporting` or
    :func:`torch.compiler.is_compiling`.
    """
    if _TEST_EXPORT:
        return _TEST_EXPORT
    return torch.compiler.is_exporting() or torch.compiler.is_compiling()
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
def _loop_for_fn(n_iter, body_fn, reduction_dim, args):
    """
    Python implementation of the loop.

    :param n_iter: number of iterations, a tensor with no dimension
    :param body_fn: function implementing the body, called as
        ``body_fn(i, *args)`` where ``i`` is the iteration number
        (an element of ``torch.arange(n_iter)``); it must return
        a tensor or a tuple of tensors
    :param reduction_dim: dimension used to concatenate, across iterations,
        every output of the body; an output gets dimension 0 when
        *reduction_dim* is None or has no entry for it
    :param args: arguments to the loop body
    :return: a tensor if the body produces one output, a tuple of tensors
        otherwise; an empty float32 tensor with no dimension if there
        is no iteration
    """
    res = []
    for i in torch.arange(n_iter, dtype=n_iter.dtype):
        r = body_fn(i, *args)
        if isinstance(r, tuple):
            assert not res or len(r) == len(res[-1]), (
                f"Unexpected number of results {len(r)} for function {body_fn}, "
                f"expected {len(res[-1])}"
            )
            res.append(r)
        else:
            assert isinstance(r, torch.Tensor), (
                f"Unexpected type {r} for function {body_fn}, "
                f"it must be a tuple or a Tensor."
            )
            # A single tensor is one result; len(r) would be the size of its
            # first dimension (and fails on a tensor with no dimension).
            assert not res or len(res[-1]) == 1, (
                f"Unexpected number of results 1 for function {body_fn}, "
                f"expected {len(res[-1])}"
            )
            res.append((r,))

    if not res:
        return torch.empty(tuple(), dtype=torch.float32, device=args[0].device)
    if len(res) == 1:
        final = res[0]
    else:
        n_res = len(res[0])
        final = [
            torch.cat(
                [r[i] for r in res],
                dim=(
                    0 if reduction_dim is None or i >= len(reduction_dim) else reduction_dim[i]
                ),
            )
            for i in range(n_res)
        ]
    return tuple(final) if len(final) > 1 else final[0]
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
def make_custom_loop_for(
    n_iter: torch.Tensor,
    body_fn: Callable,
    reduction_dim: Optional[Sequence[int]],
    args: Sequence[torch.Tensor],
    body_gm: Optional[torch.fx.GraphModule] = None,
    body_mutated_inputs: Optional[List[Any]] = None,
    body_outputs: Optional[List[Any]] = None,
) -> Tuple[str, torch.library.CustomOpDef]:
    """
    Defines a custom operator for a loop in order to avoid
    :func:`torch.export.export` digging into it.
    It registers the custom op and a custom conversion
    to ONNX.

    :param n_iter: number of iterations defined by a tensor of no dimension
        (not used by this function)
    :param body_fn: the loop body defined as a function
    :param reduction_dim: dimension used to concatenate the results
    :param args: list of tensors, input to the body (not used by this function)
    :param body_gm: torch.fx.GraphModule equivalent to *body_fn*
    :param body_mutated_inputs: inputs to *body_gm*, only checked to be not None
    :param body_outputs: outputs to *body_gm*, their shapes are part of the
        returned name and they are used as templates for the fake outputs
    :return: a name and the custom op definition, the name
        is used to cache the custom op
    """
    # All three are declared optional but are required.
    assert body_gm is not None, "body_gm cannot be None"
    assert body_mutated_inputs is not None, "body_mutated_inputs cannot be None"
    assert body_outputs is not None, "body_outputs cannot be None"

    # The name encodes the body function, the output shapes and the reduction dims.
    srank = "_".join("x".join(map(str, s.shape)) for s in body_outputs)
    sred = "x".join(map(str, reduction_dim)) if reduction_dim else ""
    full_name = (
        body_fn.__qualname__.replace("<locals>", "L")
        .replace("<lambda>", "l")
        .replace(".", "_")
    )
    name = f"loop_for_onnx_{full_name}_{srank}_{sred}"

    # One output returns Tensor, several return Tensor[].
    schema = "(str body_fn, Tensor n_iter, Tensor[] body_inputs) -> Tensor"
    if len(body_outputs) > 1:
        schema += "[]"
    # NOTE(review): the op is always defined as onnx_higher_ops::loop_for,
    # whatever *name* is; *name* is only returned to the caller.
    custom_def = torch.library.CustomOpDef("onnx_higher_ops", "loop_for", schema, body_fn)
    custom_def.register_kernel("cpu")(body_fn)

    # Fake implementation returning tensors shaped like body_outputs.
    # NOTE(review): _abstract_fn is a private attribute of CustomOpDef,
    # confirm it is still honored by the installed torch version.
    custom_def._abstract_fn = lambda _fn_id, *_args, _o=body_outputs: (
        tuple([torch.empty_like(s) for s in _o]) if len(_o) > 1 else torch.empty_like(_o[0])
    )
    return name, custom_def
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
def loop_for(
    n_iter: Union[torch.SymInt, torch.Tensor],
    body_fn: Callable[..., Tuple[torch.Tensor]],
    args: Sequence[torch.Tensor],
    reduction_dim: Optional[Sequence[int]] = None,
) -> Tuple[torch.Tensor, ...]:
    """
    High operators used to easily export a loop in ONNX.
    Does not fully work with :func:`torch.export.export`,
    it replaces a custom op with a loop operator afterwards.
    Every iteration produces tensors, all of them are gathered
    into lists, all these lists are concatenated into tensors.

    :param n_iter: number of iterations, it can be fixed or
        variable; it must be an int64 tensor with no dimension
        (checked by an assertion, a SymInt is rejected)
    :param body_fn: function body, takes only tensors and returns
        only tensors, the first argument is the iteration number
        in a tensor with no dimension, all the others
        are not changed during the loop
    :param args: the available tensors at every loop, at least one
    :param reduction_dim: the loop aggregates the results into lists,
        one for each output, each of them is concatenated into one
        tensor along one dimension, by default, it is the first
        dimension, but it can be defined otherwise
    """
    assert args, "The function should have at least one arg."
    assert (
        isinstance(n_iter, torch.Tensor)
        and n_iter.dtype == torch.int64
        and len(n_iter.shape) == 0
    ), f"Only a tensor for one int64 is allowed for n_iter but it equal to {n_iter}."
    if is_exporting():
        # Tracing path: the body is traced into a GraphModule registered on the
        # root module and the loop is replaced by a call to the custom op.
        # NOTE(review): relies on the private _CURRENT_MAKE_FX_TRACER, which
        # presumably is only set while make_fx is tracing.
        from torch.fx.experimental.proxy_tensor import _CURRENT_MAKE_FX_TRACER

        # tracer = _CURRENT_MAKE_FX_TRACER.fx_tracer
        root = _CURRENT_MAKE_FX_TRACER.fx_tracer.root
        # graph = _CURRENT_MAKE_FX_TRACER.fx_tracer.graph

        # The body is traced with iteration number 0 as a sample input.
        body_gm: torch.fx.GraphModule = materialize_as_graph(
            body_fn, (torch.tensor(0, dtype=torch.int64), *args)
        )
        (
            _1,
            _2,
            _3,
            body_mutated_inputs,
            body_outputs,
        ) = check_input_alias_and_mutation_return_outputs(body_gm)
        name, _custom_ops = make_custom_loop_for(
            n_iter,
            body_fn,
            reduction_dim,
            args,
            body_gm=body_gm,
            body_mutated_inputs=body_mutated_inputs,
            body_outputs=body_outputs,
        )
        root.register_module(name, body_gm)
        # body_graph = _maybe_reenter_make_fx(body_fn)(n_iter, *args)
        # NOTE(review): reduction_dim is not passed to the custom op.
        return torch.ops.onnx_higher_ops.loop_for(name, n_iter, args)

    # Eager path: plain python loop.
    return _loop_for_fn(n_iter, body_fn, reduction_dim, args)
|
|
194
|
-
|
|
195
|
-
|
|
196
|
-
# NOTE(review): the string below is a bare string literal holding draft code
# (proxy_mode, cond_graph, pytree, track_tensor_tree, ... are not defined in
# this module); it is never executed and has no effect at runtime.
"""
proxy_mode.tracer.root.register_module(cond_graph_name, cond_graph)
proxy_mode.tracer.root.register_module(body_graph_name, body_graph)

args = (cond_graph, body_graph, carried_inputs, additional_inputs)

proxy_args = pytree.tree_map(proxy_mode.tracer.unwrap_proxy, args)

out_proxy = proxy_mode.tracer.create_proxy(
    "call_function", op, proxy_args, {}, name=op._name
)

out = op(
    cond_graph, body_graph, unspecialized_carried_inputs, additional_inputs
)
return track_tensor_tree(
    out, out_proxy, constant=None, tracer=proxy_mode.tracer
)
"""
|