onnx-diagnostic 0.8.4__py3-none-any.whl → 0.8.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,357 @@
+ import enum
+ from dataclasses import dataclass
+ from typing import Dict, List, Optional, Tuple, Union
+ import numpy as np
+ import onnx
+ from ..helpers.onnx_helper import onnx_dtype_name
+
+
+ # Families of operators considered close to each other when aligning nodes.
+ _NOT_SO_FAR_OPS = [
+     {"MatMul", "Gemm", "FusedMatMul"},
+     {"Conv", "FusedConv"},
+     {"MaxPool"},
+ ]
+
+
+ def _sum_sets(sets):
+     t = set()
+     for s in sets:
+         t |= s
+     return t
+
+
+ _ALL_NOT_SO_FAR_OPS = _sum_sets(_NOT_SO_FAR_OPS)
+
+
+ def _align(res: str, limit: int) -> str:
+     if len(res) == limit:
+         return res
+     if len(res) > limit:
+         return res[:limit]
+     return res + " " * (limit - len(res))
+
+
+ class ObsType(enum.IntEnum):
+     """Observation kind."""
+
+     RESULT = 1
+     INITIALIZER = 2
+     SPARSE_INITIALIZER = 4
+     INPUT = 8
+     OUTPUT = 16
+     NODE = 32
+
+     def __repr__(self):
+         return f"{self.__class__.__name__}.{self._name_}"
+
+
+ @dataclass
+ class ObsCompare:
+     """
+     The description of an observation: a node, an input, an output, or an initializer.
+
+     :param position: index of this observation in the original model
+     :param kind: observation kind, see :class:`ObsType`
+     :param name_or_outputs: name of an initializer or the outputs of a node
+     :param itype: onnx element type
+     :param index: index of an input or output
+     :param shape: shape
+     :param op_type: node op_type
+     :param comment: comment, unused
+     """
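+     # A hypothetical instance, for illustration only: an observation for a
+     # Gemm node producing an output named "dense" with a dynamic batch:
+     #     ObsCompare(
+     #         position=3,
+     #         kind=ObsType.NODE,
+     #         name_or_outputs=("dense",),
+     #         itype=1,  # onnx.TensorProto.FLOAT
+     #         shape=("batch", 10),
+     #         op_type="Gemm",
+     #     )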
+
+     position: int
+     kind: ObsType
+     name_or_outputs: Tuple[str, ...]
+     itype: int = 0
+     index: int = 0
+     shape: Optional[Tuple[Union[int, str], ...]] = None
+     op_type: str = ""
+     comment: str = ""
+
+     def __str__(self) -> str:
+         "usual"
+         els = [
+             _align(f"{self.position:04d}", 4),
+             _align(self.kind._name_, 6),
+             _align(onnx_dtype_name(self.itype) if self.itype else "?", 8),
+             _align("?" if self.shape is None else "x".join(map(str, self.shape)), 18),
+             _align(self.op_type or "", 15),
+             _align(", ".join(self.name_or_outputs), 35),
+         ]
+         return " ".join(els)
+
+     @classmethod
+     def to_str(cls, obs: Optional["ObsCompare"]) -> str:
+         assert not obs or isinstance(obs, ObsCompare), f"unexpected type {type(obs)}"
+         if obs:
+             return str(obs)
+         return " " * (4 + 6 + 8 + 18 + 15 + 35 + 5)
+
+     def distance(self, obs: "ObsCompare") -> float:
+         """Computes a cost between two observations."""
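+         # Rough scale of the costs below (illustrative, with hypothetical
+         # observations): same single output name and op_types from the same
+         # family (e.g. MatMul vs Gemm) -> about 1e2; unrelated op_types
+         # -> about 9e2; observations of different kinds -> always 1e6.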
+         if self.kind != obs.kind:
+             return 1e6
+         d: float = 0
+         if self.itype != obs.itype:
+             d += 1e5
+         if self.kind == ObsType.NODE:
+             cost = 9997
+             d = 0
+             if self.op_type != obs.op_type:
+                 if self.op_type in _ALL_NOT_SO_FAR_OPS or obs.op_type in _ALL_NOT_SO_FAR_OPS:
+                     d += 1e2
+                     for aset in _NOT_SO_FAR_OPS:
+                         if self.op_type in aset and obs.op_type in aset:
+                             cost = 97
+                         elif self.op_type in aset or obs.op_type in aset:
+                             d += 5e4
+                 else:
+                     d += 9e2
+             if len(self.name_or_outputs) == 1 and len(obs.name_or_outputs) == 1:
+                 if self.name_or_outputs[0] != obs.name_or_outputs[0]:
+                     n1 = self.name_or_outputs[0]
+                     n2 = obs.name_or_outputs[0]
+                     n1 = n1.replace("_", "")
+                     n2 = n2.replace("_", "")
+                     if n1 == n2:
+                         d += 1
+                     elif (n1.startswith(("val_", "_onx_")) or "::" in n1 or "--" in n1) and (
+                         n2.startswith(("val_", "_onx_")) or "::" in n2 or "--" in n2
+                     ):
+                         # These are names given by the exporter,
+                         # not inspired by the model itself.
+                         d += cost / 100
+                     else:
+                         d += cost
+             else:
+                 a = set(self.name_or_outputs) & set(obs.name_or_outputs)
+                 b = set(self.name_or_outputs) | set(obs.name_or_outputs)
+                 d += cost * (len(b) - len(a))
+             return d
+         if self.kind == ObsType.INPUT:
+             return (
+                 999.7
+                 if self.itype != obs.itype
+                 or self.shape != obs.shape
+                 or self.index != obs.index
+                 else 0
+             )
+         if self.kind == ObsType.INITIALIZER or self.kind == ObsType.SPARSE_INITIALIZER:
+             return 1e3 if self.itype != obs.itype or self.shape != obs.shape else 0
+         if self.kind == ObsType.OUTPUT:
+             return (
+                 999.1
+                 if self.itype != obs.itype
+                 or self.shape != obs.shape
+                 or self.index != obs.index
+                 else 0
+             )
+         return 1e8
+
+     @classmethod
+     def obs_sequence_from_model(
+         cls,
+         model: Union[onnx.ModelProto, onnx.GraphProto],
+     ) -> List["ObsCompare"]:
+         """
+         Creates a sequence of observations based on a model.
+
+         :param model: model
+         :return: sequence of observations
+         """
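+         # Usage sketch, assuming ``m`` is any onnx.ModelProto:
+         #     seq = ObsCompare.obs_sequence_from_model(m)
+         # The sequence lists initializers, then inputs, nodes, and outputs.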
+         graph = model if isinstance(model, onnx.GraphProto) else model.graph
+
+         shapes = {}
+         types = {}
+         for info in [*graph.value_info, *graph.input, *graph.output]:
+             if info.type.tensor_type:
+                 t = info.type.tensor_type
+                 shapes[info.name] = tuple((d.dim_param or d.dim_value) for d in t.shape.dim)
+                 types[info.name] = t.elem_type
+
+         seq: List[ObsCompare] = []
+         for init in graph.initializer:
+             obs = ObsCompare(
+                 position=len(seq),
+                 kind=ObsType.INITIALIZER,
+                 itype=init.data_type,
+                 shape=tuple(init.dims),
+                 name_or_outputs=(init.name,),
+             )
+             seq.append(obs)
+         for i, inp in enumerate(graph.input):
+             obs = ObsCompare(
+                 position=len(seq),
+                 kind=ObsType.INPUT,
+                 itype=inp.type.tensor_type.elem_type,
+                 index=i,
+                 shape=tuple(
+                     (d.dim_param or d.dim_value) for d in inp.type.tensor_type.shape.dim
+                 ),
+                 name_or_outputs=(inp.name,),
+             )
+             seq.append(obs)
+         for i, node in enumerate(graph.node):
+             obs = ObsCompare(
+                 position=len(seq),
+                 kind=ObsType.NODE,
+                 itype=types.get(node.output[0], 0),
+                 index=i,
+                 shape=shapes.get(node.output[0], None),
+                 name_or_outputs=tuple(node.output),
+                 op_type=node.op_type,
+             )
+             seq.append(obs)
+         for i, inp in enumerate(graph.output):
+             obs = ObsCompare(
+                 position=len(seq),
+                 kind=ObsType.OUTPUT,
+                 itype=inp.type.tensor_type.elem_type,
+                 index=i,
+                 shape=tuple(
+                     (d.dim_param or d.dim_value) for d in inp.type.tensor_type.shape.dim
+                 ),
+                 name_or_outputs=(inp.name,),
+             )
+             seq.append(obs)
+         return seq
+
+
+ @dataclass
222
+ class ObsComparePair:
223
+ """
224
+ Defines a pair of comparison objects
225
+
226
+ :param side1: object from first side
227
+ :param side2: object from first side
228
+ :param distance: distance
229
+ """
230
+
231
+ side1: Optional[ObsCompare]
232
+ side2: Optional[ObsCompare]
233
+ distance: float
234
+
235
+ def __str__(self) -> str:
236
+ "nice display"
237
+ return (
238
+ f"{self.distance:.4e} | "
239
+ f"{ObsCompare.to_str(self.side1)} | {ObsCompare.to_str(self.side2)}"
240
+ )
241
+
242
+ @classmethod
243
+ def to_str(cls, seq: List["ObsComparePair"]) -> str:
244
+ """Displays every pair in text."""
245
+ return "\n".join([f"{str(pair)}" for pair in seq])
246
+
247
+ @classmethod
248
+ def distance_sequence(cls, s1: List["ObsCompare"], s2: List["ObsCompare"]) -> Tuple[
249
+ float,
250
+ List[Tuple[int, int]],
251
+ List["ObsComparePair"],
252
+ ]:
253
+ """
254
+ Computes the distance between two sequences of results.
255
+
256
+ :param s1: first sequence
257
+ :param s2: second sequence
258
+ :return: distance and alignment
259
+
260
+ An example:
261
+
262
+ .. runpython::
263
+ :showcode:
264
+
265
+ import torch
266
+ from onnx_diagnostic.export.api import to_onnx
267
+ from onnx_diagnostic.torch_onnx.compare import ObsComparePair, ObsCompare
268
+
269
+
270
+ class Model(torch.nn.Module):
271
+ def __init__(self):
272
+ super().__init__()
273
+ self.conv1 = torch.nn.Conv2d(3, 16, 5)
274
+ self.fc1 = torch.nn.Linear(144, 64)
275
+ self.fc2 = torch.nn.Linear(64, 128)
276
+ self.fc3 = torch.nn.Linear(128, 10)
277
+
278
+ def forward(self, x):
279
+ x = torch.nn.functional.max_pool2d(
280
+ torch.nn.functional.relu(self.conv1(x)),
281
+ (4, 4),
282
+ )
283
+ # x = F.max_pool2d(F.relu(self.conv2(x)), 2)
284
+ x = torch.flatten(x, 1)
285
+ x = torch.nn.functional.relu(self.fc1(x))
286
+ x = torch.nn.functional.relu(self.fc2(x))
287
+ y = self.fc3(x)
288
+ return y
289
+
290
+
291
+ model = Model()
292
+ x = torch.randn((2, 3, 16, 17), dtype=torch.float32)
293
+ dynamic_shapes = ({0: "batch", 3: "dim"},)
294
+ onnx_optimized = to_onnx(
295
+ model, (x,), dynamic_shapes=dynamic_shapes, exporter="custom", optimize=True
296
+ ).model_proto
297
+ onnx_not_optimized = to_onnx(
298
+ model, (x,), dynamic_shapes=dynamic_shapes, exporter="custom", optimize=False
299
+ ).model_proto
300
+ seq1 = ObsCompare.obs_sequence_from_model(onnx_not_optimized)
301
+ seq2 = ObsCompare.obs_sequence_from_model(onnx_optimized)
302
+ _dist, _path, pair_cmp = ObsComparePair.distance_sequence(seq1, seq2)
303
+ text = ObsComparePair.to_str(pair_cmp)
304
+ print(text)
305
+ """
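+         # The alignment is an edit distance computed by dynamic programming
+         # on a band of width ``delay`` around the diagonal: a diagonal move
+         # pays the pairwise observation distance, a vertical or horizontal
+         # move pays ``insert_cost`` plus a small bias that keeps the two
+         # directions distinguishable during backtracking.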
+         delay = max(50, abs(len(s2) - len(s1)) + 1)
+         distance: Dict[Tuple[int, int], Union[int, float]] = {(-1, -1): 0}
+         predecessor: Dict[Tuple[int, int], Optional[Tuple[int, int]]] = {(-1, -1): None}
+         insert_cost = 1e3
+         for i in range(len(s1)):
+             for j in range(max(0, i - delay), min(len(s2), i + delay)):
+                 best = distance.get((i, j), 1e100)
+                 pred = None
+                 ki, kj = i - 1, j - 1
+                 if (ki, kj) in distance:
+                     d = distance[ki, kj] + s1[i].distance(s2[j])
+                     if d < best:
+                         best = d
+                         pred = (ki, kj)
+                 ki, kj = i - 1, j
+                 if (ki, kj) in distance:
+                     d = distance[ki, kj] + insert_cost + 1
+                     if d < best:
+                         best = d
+                         pred = (ki, kj)
+                 ki, kj = i, j - 1
+                 if (ki, kj) in distance:
+                     d = distance[ki, kj] + insert_cost + 0.1
+                     if d < best:
+                         best = d
+                         pred = (ki, kj)
+                 distance[i, j] = best
+                 predecessor[i, j] = pred
+
+         # reverse
+         way = []
+         last: Optional[Tuple[int, int]] = len(s1) - 1, len(s2) - 1
+         while last is not None:
+             way.append(last)
+             last = predecessor[last]
+         indices = list(reversed(way))[1:]
+         obs_path: List[ObsComparePair] = []
+         last = -1, -1
+         for i, j in indices:
+             di = i - last[0]
+             dj = j - last[1]
+             cost = distance.get((i, j), np.nan)
+             if di == dj == 1:
+                 obs_path.append(ObsComparePair(s1[i], s2[j], distance=cost))
+             elif di == 0:
+                 obs_path.append(ObsComparePair(None, s2[j], distance=cost))
+             elif dj == 0:
+                 obs_path.append(ObsComparePair(s1[i], None, distance=cost))
+             else:
+                 raise RuntimeError(f"issue with di={di}, dj={dj}")
+             last = i, j
+         return distance[len(s1) - 1, len(s2) - 1], indices, obs_path
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: onnx-diagnostic
- Version: 0.8.4
+ Version: 0.8.6
  Summary: Tools to help converting pytorch models into ONNX.
  Home-page: https://github.com/sdpython/onnx-diagnostic
  Author: Xavier Dupré
@@ -1,16 +1,18 @@
- onnx_diagnostic/__init__.py,sha256=BuvSD4fYmz8ZVadwiG4GCHeb46p5sNX7-_GM16OtKW0,173
+ onnx_diagnostic/__init__.py,sha256=YQit5D2idhb9-wNQZzvWLT_qwRrKWBoTqMpNlBaWsGw,173
  onnx_diagnostic/__main__.py,sha256=YmyV_Aq_ianDlHyKLHMa6h8YK3ZmFPpLVHLKjM91aCk,79
- onnx_diagnostic/_command_lines_parser.py,sha256=Xkh_7fIDbT5ghTpLqlVhx4cIxAUXqv6zvPdnN3aCdOY,52254
+ onnx_diagnostic/_command_lines_parser.py,sha256=AWT6XrphbR0C0w9J846jPcRWkoUtnSSAX7gdR-JavQ4,54258
  onnx_diagnostic/api.py,sha256=BhCl_yCd78N7TlVtPOHjeYv1QBEy39TjZ647rcHqLh0,345
  onnx_diagnostic/doc.py,sha256=t3RELgfooYnVMAi0JSpggWkQEgUsREz8NmRvn0TnLI8,2829
- onnx_diagnostic/ext_test_case.py,sha256=rVZWqFEfnvwnsD3wF4jeDblh5uj5ckZ8C6DZQ0RGb_E,49599
+ onnx_diagnostic/ext_test_case.py,sha256=KxRC6s9107hYvNgU9x2B85rj8_EhAtymPIlMpmkUNu8,50154
+ onnx_diagnostic/ci_models/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ onnx_diagnostic/ci_models/ci_helpers.py,sha256=6CKQ4dVHBHeF6rN_Q3Y_0ZFeLYbYQbGQO3YyW3PQyAc,15341
+ onnx_diagnostic/ci_models/export_qwen25_vl.py,sha256=pyxmMIps9aDNkXzDAFrG8Q9DDOsEyKRHjoVvAggjFdU,20050
  onnx_diagnostic/export/__init__.py,sha256=yEIoWiOeTwBsDhyYt2fTKuhtA0Ya1J9u9ZzMTOTWaWs,101
- onnx_diagnostic/export/api.py,sha256=c3ZASq2upUAoiQ4aymm8vOAEySN_Yk6l0o1hWf6Ailo,10065
- onnx_diagnostic/export/control_flow.py,sha256=zU5n_QYhNcBllyMsl1_i6ohZt2CshqG2MokJghrvA60,7751
- onnx_diagnostic/export/control_flow_onnx.py,sha256=sODOD4v7EJj6LWhrfcdCW68r9nYKsRM4SRnqDw4TrSI,18049
- onnx_diagnostic/export/control_flow_research.py,sha256=RuYz9_eM42Bk6TKSiPV6dS68LIMZu-6WBCFCKoSvjrk,5422
+ onnx_diagnostic/export/api.py,sha256=BX4c99gMlRYsBWk3P15FMRogArxjP4dXYXP5gILjgIk,10626
+ onnx_diagnostic/export/cf_simple_loop_for.py,sha256=0I1tRAwhmmqA-6Qaq8AiUL0Ci-HODuRAVcI9azNcxAQ,13345
+ onnx_diagnostic/export/control_flow_onnx.py,sha256=izGlctqQANrHzSxPMbT7hoauNbnIBdx6hb8ry7HtVmM,18263
  onnx_diagnostic/export/dynamic_shapes.py,sha256=M2hlpHSTbkzZwGKAbrpQXng5HQrwjF5Z6wGGxEgnp74,42061
- onnx_diagnostic/export/onnx_plug.py,sha256=WqqdTBk2pV26AplNQysIdhR9y3ZFdQ-5KXu5ogTNcgI,21053
+ onnx_diagnostic/export/onnx_plug.py,sha256=U13fL0BjnhMzcDGxaAOqM4TQte5Z4zKDg4ESS0iktjM,22704
  onnx_diagnostic/export/shape_helper.py,sha256=m628y0oRCQbeZkeh8JDHIfWMsSjoJoeX-IPiPGDHT-w,11273
  onnx_diagnostic/export/validate.py,sha256=_PGUql2DJhIgGKo0WjTGUc5AgsZUx8fEs00MePy-w98,6043
  onnx_diagnostic/helpers/__init__.py,sha256=GJ2GT7cgnlIveVUwMZhuvUwidbTJaKv8CsSIOpZDsJg,83
@@ -23,7 +25,7 @@ onnx_diagnostic/helpers/doc_helper.py,sha256=pl5MZd3_FaE8BqQnqoBuSBxoNCFcd2OJd3e
  onnx_diagnostic/helpers/dot_helper.py,sha256=hwgTJsbsUv0qq7euyPDnc1NsBZDGOwv32JXSZxIHJkE,8118
  onnx_diagnostic/helpers/fake_tensor_helper.py,sha256=J7wnK3WTuVKnYiMzLVTAPkdJr3hQfIfMC9ZlOu7oGmI,11024
  onnx_diagnostic/helpers/graph_helper.py,sha256=hevQT5a7_QuriVPQcbT5qe18n99Doyl5h3-qshx1-uk,14093
- onnx_diagnostic/helpers/helper.py,sha256=f5w53QR0IO1-zqAMacgeGpZNA8uAo0c2k_ZYXP_BRhE,65840
+ onnx_diagnostic/helpers/helper.py,sha256=x8EYQmgrz_G5QS_IsbeFIoDcN_sUs-CslJMHseBj1Fw,65482
  onnx_diagnostic/helpers/log_helper.py,sha256=0lJiTF87lliI-LmgpUH_V2N8NuzJ0LryH0mSYpkRaL8,93272
  onnx_diagnostic/helpers/memory_peak.py,sha256=M3m4_thWFIwP5HytbJYEqaijXIv5v5BW_vlcJowIYI4,6434
  onnx_diagnostic/helpers/mini_onnx_builder.py,sha256=jR2lkRZEQ0N30H0FqeBwaxJd_w_6kyxFagrnulqFjhE,23883
@@ -100,7 +102,7 @@ onnx_diagnostic/tasks/zero_shot_image_classification.py,sha256=jJCMWuOqGv5ahCfjr
  onnx_diagnostic/tasks/data/__init__.py,sha256=uJoemrWgEjI6oA-tMX7r3__x-b3siPmkgqaY7bgIles,401
  onnx_diagnostic/tasks/data/dummies_imagetext2text_generation_gemma3.onnx,sha256=UbtvmWMqcZOKJ-I-HXWI1A6YR6QDaFS5u_yXm5C3ZBw,10299
  onnx_diagnostic/torch_export_patches/__init__.py,sha256=0SaZedwznm1hQUCvXZsGZORV5vby954wEExr5faepGg,720
- onnx_diagnostic/torch_export_patches/onnx_export_errors.py,sha256=WPb8Ku643UIV8kDyt9JUpaJBIVXth9UbteCNctd_yis,41863
+ onnx_diagnostic/torch_export_patches/onnx_export_errors.py,sha256=OpZHNWiA0iU-6WCFZcVCj06_MopYiZQ6c6CbAuSQ8Ms,42357
  onnx_diagnostic/torch_export_patches/onnx_export_serialization.py,sha256=0HdubI06EGpxOICqDWZoVmZkVO9gAaFADEmby197EyM,11935
  onnx_diagnostic/torch_export_patches/patch_details.py,sha256=MSraVo5ngBhihi8ssPMXSY9B4fJ17J-GAADaw3dT-rc,11794
  onnx_diagnostic/torch_export_patches/patch_expressions.py,sha256=vr4tt61cbDnaaaduzMj4UBZ8OUtr6GfDpIWwOYqjWzs,3213
@@ -117,15 +119,15 @@ onnx_diagnostic/torch_export_patches/patches/_patch_transformers_dynamic_cache.p
  onnx_diagnostic/torch_export_patches/patches/_patch_transformers_gemma3.py,sha256=nVgYQk0xXpHiictN1wOHVMN2lTH9b0vfIJ4ie-uKopg,1999
  onnx_diagnostic/torch_export_patches/patches/_patch_transformers_generation_mixin.py,sha256=VIZsVHgR8NmAcBQalPl5I6ZzNgcBxjGb6ars31m9gRg,21936
  onnx_diagnostic/torch_export_patches/patches/_patch_transformers_idefics.py,sha256=kTjuTRsfkGGGhspJnMxAMQSchZgGC_IruJzpHh_FmI8,6348
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_masking_utils.py,sha256=R4YwnN9ktxjjImiJtLRxiKtKLr9LuFlwkPXkTJ6BTIo,6895
+ onnx_diagnostic/torch_export_patches/patches/_patch_transformers_masking_utils.py,sha256=HE3fovyvMiYe9EPz1UjdD9AWopX3H188SMwPb8w5mzM,7111
  onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2.py,sha256=OxYdlLrwtd_KGHt3E17poduxvWFg-CfGS57-yN1i6gI,3827
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2_5.py,sha256=icjSrI3LFrShtV_AYQ8F2qiMFHZ74Qg5I2c-V23uEgg,31601
+ onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2_5.py,sha256=yALbXWi3ysJ6nzQD-rxTdxdNJiBsTbYEBIj4TdksDOA,34598
  onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen3.py,sha256=cND9Iqo1aKdlX-BXGr9Qlq_Y4EW1L5VWSwZfqYTVazU,4888
  onnx_diagnostic/torch_export_patches/patches/_patch_transformers_rotary_embedding.py,sha256=4bJ_z2gizZQla_fcCVt0dmuhzO9Vu-D7CCMWdxMlrKM,16893
  onnx_diagnostic/torch_export_patches/patches/_patch_transformers_sam_mask_decoder.py,sha256=-6TuBm3sLAFEGuW3vRfOTtE5uP6aINFfu7xMnl27Dws,5703
  onnx_diagnostic/torch_export_patches/patches/patch_helper.py,sha256=kK_CGW643iVXxa-m6pttDBS7HTyMQaPypza7iqIInn4,721
  onnx_diagnostic/torch_export_patches/patches/patch_torch.py,sha256=FfES0WWiWxmuQbGTlQ7IJS0YBG7km3IQbnMYwk_lPPU,44667
- onnx_diagnostic/torch_export_patches/patches/patch_transformers.py,sha256=VAfZj0xu3D4CG71SWv-9sYPUK4ZQTSz2-x4qxP4DxGE,3079
+ onnx_diagnostic/torch_export_patches/patches/patch_transformers.py,sha256=Mvq8q1Lz3l3GyCD6j8WQjbrPk_V2dnc4iKm3cC_o1OA,3112
  onnx_diagnostic/torch_export_patches/serialization/__init__.py,sha256=BHLdRPtNAtNPAS-bPKEj3-foGSPvwAbZXrHzGGPDLEw,1876
  onnx_diagnostic/torch_export_patches/serialization/diffusers_impl.py,sha256=drq3EH_yjcSuIWYsVeUWm8Cx6YCZFU6bP_1PLtPfY5I,945
  onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py,sha256=sIHFvUQoMK8ytXQYB-k7OL62z8A3f5uDaq-S5R5uN-M,10034
@@ -136,18 +138,19 @@ onnx_diagnostic/torch_models/validate.py,sha256=fnbTl5v1n5nM2MpmCgCMaWa6c7DGpb5m
  onnx_diagnostic/torch_models/hghub/__init__.py,sha256=vi1Q7YHdddj1soiBN42MSvJdFqe2_KUoWafHISjwOu8,58
  onnx_diagnostic/torch_models/hghub/hub_api.py,sha256=rFbiPNLET-KdBpnv-p0nKgwHX6d7C_Z0s9zZ86_92kQ,14307
  onnx_diagnostic/torch_models/hghub/hub_data.py,sha256=8V_pAgACPLPsLRYUododg7MSL6str-T3tBEGY4OaeYQ,8724
- onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py,sha256=GimzkI8W3guATkDx7RQ-w2xNGVaFDVegfTnnmNxf4iE,292068
+ onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py,sha256=Dxa13rsnTQ8eH_BcQvbY2bp1AYFtzuFrJ-J_urrSmeQ,292694
  onnx_diagnostic/torch_models/hghub/model_inputs.py,sha256=tCGqigRyY1omxm2rczRUvCTsweZGbF1MccWI3MmCH20,17423
  onnx_diagnostic/torch_models/hghub/model_specific.py,sha256=j50Nu7wddJMoqmD4QzMbNdFDUUgUmSBKRzPDH55TlUQ,2498
  onnx_diagnostic/torch_models/untrained/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
  onnx_diagnostic/torch_models/untrained/llm_phi2.py,sha256=y_akbdApi136qHcEQgykwIAYVw0Yfi0lbjb3DNuafaU,3948
  onnx_diagnostic/torch_models/untrained/llm_tiny_llm.py,sha256=QXw_Bs2SzfeiQMf-tmtVl83SmVOL4-Um7Qy-f0E48QI,2507
  onnx_diagnostic/torch_onnx/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ onnx_diagnostic/torch_onnx/compare.py,sha256=O0lws4kzn8WAXr8-x-YMPr7oyBC9DtSIs4OfOr4S5-E,12305
  onnx_diagnostic/torch_onnx/runtime_info.py,sha256=u1bD6VXqzBCRmqmbzQtDswaPs1PH_ygr1r-CrcfXpNU,8562
  onnx_diagnostic/torch_onnx/sbs.py,sha256=8okBEIupMgw7TtKc80YFimMtwnY3GchdY05FsA9ooa0,40749
  onnx_diagnostic/torch_onnx/sbs_dataclasses.py,sha256=UctdBjzoPTQG1LS0tZ8A6E9hpoq5HWUYaJLPOPJc9FI,20299
- onnx_diagnostic-0.8.4.dist-info/licenses/LICENSE.txt,sha256=Vv6TXglX6Rc0d-f8aREhayhT-6PMQXEyOmI2NKlUCMc,1045
- onnx_diagnostic-0.8.4.dist-info/METADATA,sha256=8S9bFx2lTef7dTeL_dTCtGj1MalIyoUvs5dMzrMffNg,6734
- onnx_diagnostic-0.8.4.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
- onnx_diagnostic-0.8.4.dist-info/top_level.txt,sha256=KwNkXewmcobM3ZT1DJLVWH6ebJzA5qKg7cWqKfpGNT4,16
- onnx_diagnostic-0.8.4.dist-info/RECORD,,
+ onnx_diagnostic-0.8.6.dist-info/licenses/LICENSE.txt,sha256=Vv6TXglX6Rc0d-f8aREhayhT-6PMQXEyOmI2NKlUCMc,1045
+ onnx_diagnostic-0.8.6.dist-info/METADATA,sha256=9xPlJ9UHYSSIyEMqxN14mqZg31Rq9jqzHaczfqWhu-4,6734
+ onnx_diagnostic-0.8.6.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+ onnx_diagnostic-0.8.6.dist-info/top_level.txt,sha256=KwNkXewmcobM3ZT1DJLVWH6ebJzA5qKg7cWqKfpGNT4,16
+ onnx_diagnostic-0.8.6.dist-info/RECORD,,
@@ -1,214 +0,0 @@
- import contextlib
- from typing import Any, Callable, List, Optional, Sequence, Tuple, Union
- import torch
- from torch._higher_order_ops.utils import (
-     materialize_as_graph,
-     check_input_alias_and_mutation_return_outputs,
-     # _maybe_reenter_make_fx,
- )
-
- _TEST_EXPORT = False
-
-
- @contextlib.contextmanager
- def enable_code_export_control_flow():
-     """Enables the code meant to be exported."""
-     global _TEST_EXPORT
-     old = _TEST_EXPORT
-     _TEST_EXPORT = True
-     try:
-         yield
-     finally:
-         _TEST_EXPORT = old
-
-
- def is_exporting() -> bool:
-     """
-     Returns :func:`torch.compiler.is_exporting` or
-     :func:`torch.compiler.is_compiling`.
-     Setting ``_TEST_EXPORT`` forces it to return True.
-     """
-     return _TEST_EXPORT or torch.compiler.is_exporting() or torch.compiler.is_compiling()
-
-
- def _loop_for_fn(n_iter, body_fn, reduction_dim, args):
-     """
-     Python implementation of the loop.
-
-     :param n_iter: number of iterations
-     :param body_fn: function implementing the body
-     :param reduction_dim: dimension used to reduce the list produced by the loop
-     :param args: arguments to the loop body
-     :return: results
-     """
-     res = []
-     for i in torch.arange(n_iter, dtype=n_iter.dtype):
-         r = body_fn(i, *args)
-         if isinstance(r, tuple):
-             assert not res or len(r) == len(res[-1]), (
-                 f"Unexpected number of results {len(r)} for function {body_fn}, "
-                 f"expected {len(res[-1])}"
-             )
-             res.append(r)
-         else:
-             assert isinstance(r, torch.Tensor), (
-                 f"Unexpected type {r} for function {body_fn}, "
-                 f"it must be a tuple or a Tensor."
-             )
-             assert not res or len(res[-1]) == 1, (
-                 f"Unexpected number of results {len(r)} for function {body_fn}, "
-                 f"expected {len(res[-1])}"
-             )
-             res.append((r,))
-
-     if not res:
-         return torch.empty(tuple(), dtype=torch.float32, device=args[0].device)
-     if len(res) == 1:
-         final = res[0]
-     else:
-         n_res = len(res[0])
-         final = [
-             torch.cat(
-                 [r[i] for r in res],
-                 dim=(
-                     0 if reduction_dim is None or i >= len(reduction_dim) else reduction_dim[i]
-                 ),
-             )
-             for i in range(n_res)
-         ]
-     return tuple(final) if len(final) > 1 else final[0]
-
-
- def make_custom_loop_for(
-     n_iter: torch.Tensor,
-     body_fn: Callable,
-     reduction_dim: Optional[Sequence[int]],
-     args: Sequence[torch.Tensor],
-     body_gm: Optional[torch.fx.GraphModule] = None,
-     body_mutated_inputs: Optional[List[Any]] = None,
-     body_outputs: Optional[List[Any]] = None,
- ) -> Tuple[str, torch.library.CustomOpDef]:
-     """
-     Defines a custom operator for a loop in order to avoid
-     :func:`torch.export.export` digging into it.
-     It registers the custom op and a custom conversion
-     to ONNX.
-
-     :param n_iter: number of iterations defined by a tensor with no dimension
-     :param body_fn: the loop body defined as a function
-     :param reduction_dim: dimension used to concatenate the results
-     :param args: list of tensors, input to the body
-     :param body_gm: torch.fx.GraphModule equivalent to *body_fn*
-     :param body_mutated_inputs: inputs to *body_gm*
-     :param body_outputs: outputs to *body_gm*
-     :return: a name and the custom op definition, the name
-         is used to cache the custom op
-     """
-     assert body_gm is not None, "body_gm cannot be None"
-     assert body_mutated_inputs is not None, "body_mutated_inputs cannot be None"
-     assert body_outputs is not None, "body_outputs cannot be None"
-
-     srank = "_".join("x".join(map(str, s.shape)) for s in body_outputs)
-     sred = "x".join(map(str, reduction_dim)) if reduction_dim else ""
-     full_name = (
-         body_fn.__qualname__.replace("<locals>", "L")
-         .replace("<lambda>", "l")
-         .replace(".", "_")
-     )
-     name = f"loop_for_onnx_{full_name}_{srank}_{sred}"
-
-     schema = "(str body_fn, Tensor n_iter, Tensor[] body_inputs) -> Tensor"
-     if len(body_outputs) > 1:
-         schema += "[]"
-     custom_def = torch.library.CustomOpDef("onnx_higher_ops", "loop_for", schema, body_fn)
-     custom_def.register_kernel("cpu")(body_fn)
-
-     custom_def._abstract_fn = lambda _fn_id, *_args, _o=body_outputs: (
-         tuple([torch.empty_like(s) for s in _o]) if len(_o) > 1 else torch.empty_like(_o[0])
-     )
-     return name, custom_def
-
-
- def loop_for(
-     n_iter: Union[torch.SymInt, torch.Tensor],
-     body_fn: Callable[..., Tuple[torch.Tensor]],
-     args: Sequence[torch.Tensor],
-     reduction_dim: Optional[Sequence[int]] = None,
- ) -> Tuple[torch.Tensor, ...]:
-     """
-     Higher-order operator used to easily export a loop to ONNX.
-     It does not fully work with :func:`torch.export.export`;
-     it replaces a custom op with a loop operator afterwards.
-     Every iteration produces tensors, all of them are gathered
-     into lists, and all these lists are concatenated into tensors.
-
-     :param n_iter: number of iterations, it can be fixed or
-         variable, in which case it should be a tensor with no dimension
-     :param body_fn: function body, takes only tensors and returns
-         only tensors, the first argument is the iteration number
-         in a tensor with no dimension, all the others
-         are not changed during the loop
-     :param args: the available tensors at every loop
-     :param reduction_dim: the loop aggregates the results into lists,
-         one for each output, each of them is concatenated into one
-         tensor along one dimension, by default, it is the first
-         dimension, but it can be defined otherwise
-     """
-     assert args, "The function should have at least one arg."
-     assert (
-         isinstance(n_iter, torch.Tensor)
-         and n_iter.dtype == torch.int64
-         and len(n_iter.shape) == 0
-     ), f"Only a tensor of one int64 is allowed for n_iter but it is equal to {n_iter}."
-     if is_exporting():
-         from torch.fx.experimental.proxy_tensor import _CURRENT_MAKE_FX_TRACER
-
-         # tracer = _CURRENT_MAKE_FX_TRACER.fx_tracer
-         root = _CURRENT_MAKE_FX_TRACER.fx_tracer.root
-         # graph = _CURRENT_MAKE_FX_TRACER.fx_tracer.graph
-
-         body_gm: torch.fx.GraphModule = materialize_as_graph(
-             body_fn, (torch.tensor(0, dtype=torch.int64), *args)
-         )
-         (
-             _1,
-             _2,
-             _3,
-             body_mutated_inputs,
-             body_outputs,
-         ) = check_input_alias_and_mutation_return_outputs(body_gm)
-         name, _custom_ops = make_custom_loop_for(
-             n_iter,
-             body_fn,
-             reduction_dim,
-             args,
-             body_gm=body_gm,
-             body_mutated_inputs=body_mutated_inputs,
-             body_outputs=body_outputs,
-         )
-         root.register_module(name, body_gm)
-         # body_graph = _maybe_reenter_make_fx(body_fn)(n_iter, *args)
-         return torch.ops.onnx_higher_ops.loop_for(name, n_iter, args)
-
-     return _loop_for_fn(n_iter, body_fn, reduction_dim, args)
-
-
- """
- proxy_mode.tracer.root.register_module(cond_graph_name, cond_graph)
- proxy_mode.tracer.root.register_module(body_graph_name, body_graph)
-
- args = (cond_graph, body_graph, carried_inputs, additional_inputs)
-
- proxy_args = pytree.tree_map(proxy_mode.tracer.unwrap_proxy, args)
-
- out_proxy = proxy_mode.tracer.create_proxy(
-     "call_function", op, proxy_args, {}, name=op._name
- )
-
- out = op(
-     cond_graph, body_graph, unspecialized_carried_inputs, additional_inputs
- )
- return track_tensor_tree(
-     out, out_proxy, constant=None, tracer=proxy_mode.tracer
- )
- """