onnx-diagnostic 0.6.0__py3-none-any.whl → 0.6.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnx_diagnostic/__init__.py +1 -1
- onnx_diagnostic/_command_lines_parser.py +18 -0
- onnx_diagnostic/api.py +15 -0
- onnx_diagnostic/ext_test_case.py +3 -1
- onnx_diagnostic/helpers/args_helper.py +1 -1
- onnx_diagnostic/helpers/doc_helper.py +143 -0
- onnx_diagnostic/helpers/helper.py +6 -5
- onnx_diagnostic/helpers/model_builder_helper.py +24 -8
- onnx_diagnostic/helpers/rt_helper.py +5 -1
- onnx_diagnostic/helpers/torch_helper.py +2 -0
- onnx_diagnostic/reference/__init__.py +1 -0
- onnx_diagnostic/reference/torch_evaluator.py +648 -0
- onnx_diagnostic/reference/torch_ops/__init__.py +55 -0
- onnx_diagnostic/reference/torch_ops/_op_run.py +335 -0
- onnx_diagnostic/reference/torch_ops/access_ops.py +94 -0
- onnx_diagnostic/reference/torch_ops/binary_ops.py +108 -0
- onnx_diagnostic/reference/torch_ops/controlflow_ops.py +121 -0
- onnx_diagnostic/reference/torch_ops/generator_ops.py +36 -0
- onnx_diagnostic/reference/torch_ops/nn_ops.py +196 -0
- onnx_diagnostic/reference/torch_ops/other_ops.py +106 -0
- onnx_diagnostic/reference/torch_ops/reduce_ops.py +130 -0
- onnx_diagnostic/reference/torch_ops/sequence_ops.py +65 -0
- onnx_diagnostic/reference/torch_ops/shape_ops.py +121 -0
- onnx_diagnostic/reference/torch_ops/unary_ops.py +86 -0
- onnx_diagnostic/tasks/__init__.py +22 -1
- onnx_diagnostic/tasks/image_classification.py +2 -2
- onnx_diagnostic/tasks/text_generation.py +3 -3
- onnx_diagnostic/torch_export_patches/eval/__init__.py +106 -37
- onnx_diagnostic/torch_export_patches/eval/model_cases.py +12 -25
- onnx_diagnostic/torch_export_patches/patch_module_helper.py +130 -16
- onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +88 -0
- onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +142 -0
- onnx_diagnostic/torch_models/test_helper.py +133 -16
- onnx_diagnostic/torch_onnx/runtime_info.py +289 -0
- {onnx_diagnostic-0.6.0.dist-info → onnx_diagnostic-0.6.2.dist-info}/METADATA +1 -1
- {onnx_diagnostic-0.6.0.dist-info → onnx_diagnostic-0.6.2.dist-info}/RECORD +39 -23
- {onnx_diagnostic-0.6.0.dist-info → onnx_diagnostic-0.6.2.dist-info}/WHEEL +1 -1
- {onnx_diagnostic-0.6.0.dist-info → onnx_diagnostic-0.6.2.dist-info}/licenses/LICENSE.txt +0 -0
- {onnx_diagnostic-0.6.0.dist-info → onnx_diagnostic-0.6.2.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,86 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
from . import OpRunKernel, OpRunTensor
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
class Abs_1(OpRunKernel):
    """Abs: element-wise absolute value."""

    def run(self, x: OpRunTensor) -> OpRunTensor:
        # Tensor-method form of torch.abs; produces a new tensor.
        magnitude = x.tensor.abs()
        return OpRunTensor(magnitude)
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class Cos_1(OpRunKernel):
    """Cos: element-wise cosine."""

    def run(self, x: OpRunTensor) -> OpRunTensor:
        # Functional form, equivalent to ``x.tensor.cos()``.
        cosine = torch.cos(x.tensor)
        return OpRunTensor(cosine)
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class Erf_9(OpRunKernel):
    """Erf: element-wise Gauss error function."""

    def run(self, x: OpRunTensor) -> OpRunTensor:
        # Functional form, equivalent to ``x.tensor.erf()``.
        erf_values = torch.erf(x.tensor)
        return OpRunTensor(erf_values)
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class Exp_1(OpRunKernel):
    """Exp: element-wise natural exponential."""

    def run(self, x: OpRunTensor) -> OpRunTensor:
        # Functional form, equivalent to ``x.tensor.exp()``.
        exponentiated = torch.exp(x.tensor)
        return OpRunTensor(exponentiated)
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
class Identity_1(OpRunKernel):
    """Identity: forwards the input tensor unchanged."""

    def run(self, x: OpRunTensor) -> OpRunTensor:
        # No copy is made: the same underlying torch tensor is re-wrapped
        # in a fresh OpRunTensor.
        same = x.tensor
        return OpRunTensor(same)
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
class Log_1(OpRunKernel):
    """Log: element-wise natural logarithm."""

    def run(self, x: OpRunTensor) -> OpRunTensor:
        # Functional form, equivalent to ``x.tensor.log()``.
        logarithm = torch.log(x.tensor)
        return OpRunTensor(logarithm)
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
class Neg_1(OpRunKernel):
    """Neg: element-wise numerical negation."""

    def run(self, x: OpRunTensor) -> OpRunTensor:
        # torch.neg is the functional spelling of unary minus.
        negated = torch.neg(x.tensor)
        return OpRunTensor(negated)
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
class Not_1(OpRunKernel):
    """Not: element-wise inversion of the input tensor."""

    def run(self, x: OpRunTensor) -> OpRunTensor:
        # torch.bitwise_not is what the ``~`` operator dispatches to; on a
        # boolean tensor it is a logical NOT.
        inverted = torch.bitwise_not(x.tensor)
        return OpRunTensor(inverted)
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
class Reciprocal_1(OpRunKernel):
    """Reciprocal"""

    def run(self, x: OpRunTensor) -> OpRunTensor:
        """Returns ``1 / x`` computed element-wise, wrapped in a new OpRunTensor."""
        # Python true division on a torch tensor, so an integer input would
        # yield a floating-point result rather than truncating.
        return OpRunTensor(1 / x.tensor)
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
class Sigmoid_6(OpRunKernel):
    """Sigmoid"""

    def run(self, x: OpRunTensor) -> OpRunTensor:
        """Returns ``1 / (1 + exp(-x))`` computed element-wise via torch.sigmoid."""
        return OpRunTensor(torch.sigmoid(x.tensor))
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
class Sin_1(OpRunKernel):
    """Sin: element-wise sine."""

    def run(self, x: OpRunTensor) -> OpRunTensor:
        # Functional form, equivalent to ``x.tensor.sin()``.
        sine = torch.sin(x.tensor)
        return OpRunTensor(sine)
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
class Sqrt_1(OpRunKernel):
    """Sqrt: element-wise square root."""

    def run(self, x: OpRunTensor) -> OpRunTensor:
        # Functional form, equivalent to ``x.tensor.sqrt()``.
        root = torch.sqrt(x.tensor)
        return OpRunTensor(root)
|
|
@@ -39,9 +39,30 @@ def supported_tasks() -> List[str]:
|
|
|
39
39
|
|
|
40
40
|
def reduce_model_config(config: Any, task: str) -> Dict[str, Any]:
|
|
41
41
|
"""Reduces a model size."""
|
|
42
|
+
head_size0 = (
|
|
43
|
+
config.head_dim
|
|
44
|
+
if hasattr(config, "head_dim") and config.head_dim
|
|
45
|
+
else (
|
|
46
|
+
config.hidden_size // config.num_attention_heads
|
|
47
|
+
if hasattr(config, "hidden_size") and hasattr(config, "num_attention_heads")
|
|
48
|
+
else None
|
|
49
|
+
)
|
|
50
|
+
)
|
|
42
51
|
tasks = {mod.__TASK__: mod.reduce_model_config for mod in __TASKS__}
|
|
43
52
|
assert task in tasks, f"Task {task!r} not found in {sorted(tasks)}"
|
|
44
|
-
|
|
53
|
+
res = tasks[task](config)
|
|
54
|
+
if head_size0 and "head_dim" in res:
|
|
55
|
+
head_size = (
|
|
56
|
+
config.head_dim
|
|
57
|
+
if hasattr(config, "head_dim") and config.head_dim
|
|
58
|
+
else config.hidden_size // config.num_attention_heads
|
|
59
|
+
)
|
|
60
|
+
assert head_size0 == head_size or head_size % 16 == 0, (
|
|
61
|
+
f"head_size should be a multiple of 16 "
|
|
62
|
+
f"(head_size0={head_size0}), res={res}, "
|
|
63
|
+
f"config=\n{config}"
|
|
64
|
+
)
|
|
65
|
+
return res
|
|
45
66
|
|
|
46
67
|
|
|
47
68
|
def random_input_kwargs(config: Any, task: str) -> Tuple[Dict[str, Any], Callable]:
|
|
@@ -58,8 +58,8 @@ def get_inputs(
|
|
|
58
58
|
shapes = {
|
|
59
59
|
"pixel_values": {
|
|
60
60
|
0: torch.export.Dim("batch", min=1, max=1024),
|
|
61
|
-
2:
|
|
62
|
-
3:
|
|
61
|
+
2: "width",
|
|
62
|
+
3: "height",
|
|
63
63
|
},
|
|
64
64
|
}
|
|
65
65
|
inputs = dict(
|
|
@@ -27,7 +27,7 @@ def reduce_model_config(config: Any) -> Dict[str, Any]:
|
|
|
27
27
|
kwargs = dict(
|
|
28
28
|
num_hidden_layers=min(config.num_hidden_layers, 2),
|
|
29
29
|
intermediate_size=256 if config is None else min(512, config.intermediate_size),
|
|
30
|
-
hidden_size=
|
|
30
|
+
hidden_size=512 if config is None else min(512, config.hidden_size),
|
|
31
31
|
cls_cache="MambaCache",
|
|
32
32
|
state_size=8 if config is None else getattr(config, "state_size", None),
|
|
33
33
|
conv_kernel=4 if config is None else getattr(config, "conv_kernel", None),
|
|
@@ -44,8 +44,8 @@ def reduce_model_config(config: Any) -> Dict[str, Any]:
|
|
|
44
44
|
else config.num_attention_heads
|
|
45
45
|
),
|
|
46
46
|
hidden_size=(
|
|
47
|
-
min(config.hidden_size,
|
|
48
|
-
if config.hidden_size %
|
|
47
|
+
min(config.hidden_size, 4096 // 4)
|
|
48
|
+
if config.hidden_size % 64 == 0
|
|
49
49
|
else config.hidden_size
|
|
50
50
|
),
|
|
51
51
|
)
|
|
@@ -185,9 +185,17 @@ def _make_exporter_export(
|
|
|
185
185
|
|
|
186
186
|
if exporter == "export-strict":
|
|
187
187
|
try:
|
|
188
|
-
|
|
189
|
-
|
|
190
|
-
|
|
188
|
+
if verbose >= 2:
|
|
189
|
+
exported = torch.export.export(
|
|
190
|
+
model, inputs, dynamic_shapes=dynamic_shapes, strict=True
|
|
191
|
+
)
|
|
192
|
+
else:
|
|
193
|
+
with contextlib.redirect_stdout(io.StringIO()), contextlib.redirect_stderr(
|
|
194
|
+
io.StringIO()
|
|
195
|
+
):
|
|
196
|
+
exported = torch.export.export(
|
|
197
|
+
model, inputs, dynamic_shapes=dynamic_shapes, strict=True
|
|
198
|
+
)
|
|
191
199
|
except Exception as e:
|
|
192
200
|
if not quiet:
|
|
193
201
|
raise
|
|
@@ -198,17 +206,33 @@ def _make_exporter_export(
|
|
|
198
206
|
return exported.module()
|
|
199
207
|
if exporter in ("export-strict-dec", "export-strict-decall"):
|
|
200
208
|
try:
|
|
201
|
-
|
|
202
|
-
|
|
203
|
-
|
|
204
|
-
|
|
205
|
-
|
|
206
|
-
|
|
207
|
-
|
|
208
|
-
exported
|
|
209
|
-
|
|
210
|
-
|
|
211
|
-
|
|
209
|
+
if verbose >= 2:
|
|
210
|
+
exported = torch.export.export(
|
|
211
|
+
model, inputs, dynamic_shapes=dynamic_shapes, strict=True
|
|
212
|
+
)
|
|
213
|
+
if verbose >= 9:
|
|
214
|
+
print("-- graph before decomposition")
|
|
215
|
+
print(exported.graph)
|
|
216
|
+
exported = (
|
|
217
|
+
exported.run_decompositions()
|
|
218
|
+
if "decall" in exporter
|
|
219
|
+
else exported.run_decompositions({})
|
|
220
|
+
)
|
|
221
|
+
else:
|
|
222
|
+
with contextlib.redirect_stdout(io.StringIO()), contextlib.redirect_stderr(
|
|
223
|
+
io.StringIO()
|
|
224
|
+
):
|
|
225
|
+
exported = torch.export.export(
|
|
226
|
+
model, inputs, dynamic_shapes=dynamic_shapes, strict=True
|
|
227
|
+
)
|
|
228
|
+
if verbose >= 9:
|
|
229
|
+
print("-- graph before decomposition")
|
|
230
|
+
print(exported.graph)
|
|
231
|
+
exported = (
|
|
232
|
+
exported.run_decompositions()
|
|
233
|
+
if "decall" in exporter
|
|
234
|
+
else exported.run_decompositions({})
|
|
235
|
+
)
|
|
212
236
|
except Exception as e:
|
|
213
237
|
if not quiet:
|
|
214
238
|
raise
|
|
@@ -219,9 +243,17 @@ def _make_exporter_export(
|
|
|
219
243
|
return exported.module()
|
|
220
244
|
if exporter == "export-nostrict":
|
|
221
245
|
try:
|
|
222
|
-
|
|
223
|
-
|
|
224
|
-
|
|
246
|
+
if verbose >= 2:
|
|
247
|
+
exported = torch.export.export(
|
|
248
|
+
model, inputs, dynamic_shapes=dynamic_shapes, strict=False
|
|
249
|
+
)
|
|
250
|
+
else:
|
|
251
|
+
with contextlib.redirect_stdout(io.StringIO()), contextlib.redirect_stderr(
|
|
252
|
+
io.StringIO()
|
|
253
|
+
):
|
|
254
|
+
exported = torch.export.export(
|
|
255
|
+
model, inputs, dynamic_shapes=dynamic_shapes, strict=False
|
|
256
|
+
)
|
|
225
257
|
except Exception as e:
|
|
226
258
|
if not quiet:
|
|
227
259
|
raise
|
|
@@ -232,17 +264,33 @@ def _make_exporter_export(
|
|
|
232
264
|
return exported.module()
|
|
233
265
|
if exporter in ("export-nostrict-dec", "export-nostrict-decall"):
|
|
234
266
|
try:
|
|
235
|
-
|
|
236
|
-
|
|
237
|
-
|
|
238
|
-
|
|
239
|
-
|
|
240
|
-
|
|
241
|
-
|
|
242
|
-
exported
|
|
243
|
-
|
|
244
|
-
|
|
245
|
-
|
|
267
|
+
if verbose >= 2:
|
|
268
|
+
exported = torch.export.export(
|
|
269
|
+
model, inputs, dynamic_shapes=dynamic_shapes, strict=False
|
|
270
|
+
)
|
|
271
|
+
if verbose >= 9:
|
|
272
|
+
print("-- graph before decomposition")
|
|
273
|
+
print(exported.graph)
|
|
274
|
+
exported = (
|
|
275
|
+
exported.run_decompositions()
|
|
276
|
+
if "decall" in exporter
|
|
277
|
+
else exported.run_decompositions({})
|
|
278
|
+
)
|
|
279
|
+
else:
|
|
280
|
+
with contextlib.redirect_stdout(io.StringIO()), contextlib.redirect_stderr(
|
|
281
|
+
io.StringIO()
|
|
282
|
+
):
|
|
283
|
+
exported = torch.export.export(
|
|
284
|
+
model, inputs, dynamic_shapes=dynamic_shapes, strict=False
|
|
285
|
+
)
|
|
286
|
+
if verbose >= 9:
|
|
287
|
+
print("-- graph before decomposition")
|
|
288
|
+
print(exported.graph)
|
|
289
|
+
exported = (
|
|
290
|
+
exported.run_decompositions()
|
|
291
|
+
if "decall" in exporter
|
|
292
|
+
else exported.run_decompositions({})
|
|
293
|
+
)
|
|
246
294
|
except Exception as e:
|
|
247
295
|
if not quiet:
|
|
248
296
|
raise
|
|
@@ -255,8 +303,15 @@ def _make_exporter_export(
|
|
|
255
303
|
from experimental_experiment.torch_interpreter.tracing import CustomTracer
|
|
256
304
|
|
|
257
305
|
try:
|
|
258
|
-
|
|
259
|
-
|
|
306
|
+
if verbose >= 2:
|
|
307
|
+
graph = CustomTracer().trace(model)
|
|
308
|
+
mod = torch.fx.GraphModule(model, graph)
|
|
309
|
+
else:
|
|
310
|
+
with contextlib.redirect_stdout(io.StringIO()), contextlib.redirect_stderr(
|
|
311
|
+
io.StringIO()
|
|
312
|
+
):
|
|
313
|
+
graph = CustomTracer().trace(model)
|
|
314
|
+
mod = torch.fx.GraphModule(model, graph)
|
|
260
315
|
except Exception as e:
|
|
261
316
|
if not quiet:
|
|
262
317
|
raise
|
|
@@ -289,13 +344,25 @@ def _make_exporter_onnx(
|
|
|
289
344
|
if "-dec" in exporter:
|
|
290
345
|
opts["decomposition_table"] = "all" if "-decall" in exporter else "default"
|
|
291
346
|
try:
|
|
292
|
-
|
|
293
|
-
|
|
294
|
-
|
|
295
|
-
|
|
296
|
-
|
|
297
|
-
|
|
298
|
-
|
|
347
|
+
if verbose >= 2:
|
|
348
|
+
onx, builder = to_onnx(
|
|
349
|
+
model,
|
|
350
|
+
inputs,
|
|
351
|
+
dynamic_shapes=dynamic_shapes,
|
|
352
|
+
export_options=ExportOptions(**opts),
|
|
353
|
+
return_builder=True,
|
|
354
|
+
)
|
|
355
|
+
else:
|
|
356
|
+
with contextlib.redirect_stdout(io.StringIO()), contextlib.redirect_stderr(
|
|
357
|
+
io.StringIO()
|
|
358
|
+
):
|
|
359
|
+
onx, builder = to_onnx(
|
|
360
|
+
model,
|
|
361
|
+
inputs,
|
|
362
|
+
dynamic_shapes=dynamic_shapes,
|
|
363
|
+
export_options=ExportOptions(**opts),
|
|
364
|
+
return_builder=True,
|
|
365
|
+
)
|
|
299
366
|
except Exception as e:
|
|
300
367
|
if not quiet:
|
|
301
368
|
raise RuntimeError(
|
|
@@ -306,6 +373,7 @@ def _make_exporter_onnx(
|
|
|
306
373
|
) from e
|
|
307
374
|
return dict(error=str(e), success=0, error_step="export")
|
|
308
375
|
return onx, builder
|
|
376
|
+
|
|
309
377
|
if exporter == "dynamo":
|
|
310
378
|
import torch
|
|
311
379
|
|
|
@@ -338,6 +406,7 @@ def _make_exporter_onnx(
|
|
|
338
406
|
) from e
|
|
339
407
|
return dict(error=str(e), success=0, error_step="export")
|
|
340
408
|
return onx, None
|
|
409
|
+
|
|
341
410
|
if exporter == "dynamo-ir":
|
|
342
411
|
import torch
|
|
343
412
|
|
|
@@ -23,7 +23,6 @@ class AtenRollPos(torch.nn.Module):
|
|
|
23
23
|
|
|
24
24
|
|
|
25
25
|
class InplaceAdd(torch.nn.Module):
|
|
26
|
-
|
|
27
26
|
def __init__(self):
|
|
28
27
|
super().__init__()
|
|
29
28
|
self.bias = torch.ones((1, 4), dtype=torch.float32)
|
|
@@ -37,7 +36,6 @@ class InplaceAdd(torch.nn.Module):
|
|
|
37
36
|
|
|
38
37
|
|
|
39
38
|
class InplaceAdd2(torch.nn.Module):
|
|
40
|
-
|
|
41
39
|
def __init__(self):
|
|
42
40
|
super().__init__()
|
|
43
41
|
self.bias = torch.ones((1, 4), dtype=torch.float32)
|
|
@@ -51,7 +49,6 @@ class InplaceAdd2(torch.nn.Module):
|
|
|
51
49
|
|
|
52
50
|
|
|
53
51
|
class InplaceAdd_Mul(torch.nn.Module):
|
|
54
|
-
|
|
55
52
|
def __init__(self):
|
|
56
53
|
super().__init__()
|
|
57
54
|
self.bias = torch.ones((1, 4), dtype=torch.float32)
|
|
@@ -65,7 +62,6 @@ class InplaceAdd_Mul(torch.nn.Module):
|
|
|
65
62
|
|
|
66
63
|
|
|
67
64
|
class InplaceCloneAdd_(torch.nn.Module):
|
|
68
|
-
|
|
69
65
|
def __init__(self):
|
|
70
66
|
super().__init__()
|
|
71
67
|
self.bias = torch.ones((1, 4), dtype=torch.float32)
|
|
@@ -80,7 +76,6 @@ class InplaceCloneAdd_(torch.nn.Module):
|
|
|
80
76
|
|
|
81
77
|
|
|
82
78
|
class InplaceSetItemSquare(torch.nn.Module):
|
|
83
|
-
|
|
84
79
|
def forward(self, x):
|
|
85
80
|
x[:2, :3] = 1
|
|
86
81
|
return x
|
|
@@ -90,7 +85,6 @@ class InplaceSetItemSquare(torch.nn.Module):
|
|
|
90
85
|
|
|
91
86
|
|
|
92
87
|
class InplaceSetItemSquareAdd(torch.nn.Module):
|
|
93
|
-
|
|
94
88
|
def forward(self, x):
|
|
95
89
|
x[:2, :3] = 1
|
|
96
90
|
return x + 2
|
|
@@ -100,7 +94,6 @@ class InplaceSetItemSquareAdd(torch.nn.Module):
|
|
|
100
94
|
|
|
101
95
|
|
|
102
96
|
class InplaceSetItemSquareAdd2(torch.nn.Module):
|
|
103
|
-
|
|
104
97
|
def forward(self, x):
|
|
105
98
|
x[:2, :3] = 1
|
|
106
99
|
return x + 2, x + 3
|
|
@@ -110,7 +103,6 @@ class InplaceSetItemSquareAdd2(torch.nn.Module):
|
|
|
110
103
|
|
|
111
104
|
|
|
112
105
|
class InplaceSetItemEllipsis_1(torch.nn.Module):
|
|
113
|
-
|
|
114
106
|
def __init__(self):
|
|
115
107
|
super().__init__()
|
|
116
108
|
self.params = torch.zeros((1, 8192, 4), dtype=torch.float32)
|
|
@@ -124,11 +116,10 @@ class InplaceSetItemEllipsis_1(torch.nn.Module):
|
|
|
124
116
|
(torch.from_numpy(np.array([0, 3, 2, 1])).to(torch.int64)),
|
|
125
117
|
(torch.arange(4 * 8192) + 10).reshape((-1, 4)).to(torch.float32),
|
|
126
118
|
)
|
|
127
|
-
_dynamic = {"
|
|
119
|
+
_dynamic = {"index": {0: DIM("batch")}, "update": {0: DIM("batch"), 1: DYN}}
|
|
128
120
|
|
|
129
121
|
|
|
130
122
|
class InplaceSetItemEllipsis_2(torch.nn.Module):
|
|
131
|
-
|
|
132
123
|
def __init__(self):
|
|
133
124
|
super().__init__()
|
|
134
125
|
self.params = torch.zeros((1, 8192, 6), dtype=torch.float32)
|
|
@@ -142,7 +133,7 @@ class InplaceSetItemEllipsis_2(torch.nn.Module):
|
|
|
142
133
|
torch.from_numpy(np.array([0, 3, 2, 5])).to(torch.int64),
|
|
143
134
|
(torch.arange(4 * 8192) + 10).reshape((-1, 4)).to(torch.float32),
|
|
144
135
|
)
|
|
145
|
-
_dynamic = {"
|
|
136
|
+
_dynamic = {"index": {0: DIM("batch")}, "update": {0: DIM("batch"), 1: DYN}}
|
|
146
137
|
|
|
147
138
|
|
|
148
139
|
class InplaceSetItemMask(torch.nn.Module):
|
|
@@ -156,7 +147,6 @@ class InplaceSetItemMask(torch.nn.Module):
|
|
|
156
147
|
|
|
157
148
|
|
|
158
149
|
class AtenInterpolate(torch.nn.Module):
|
|
159
|
-
|
|
160
150
|
def forward(self, x):
|
|
161
151
|
y = torch.nn.functional.interpolate(
|
|
162
152
|
x,
|
|
@@ -171,7 +161,6 @@ class AtenInterpolate(torch.nn.Module):
|
|
|
171
161
|
|
|
172
162
|
|
|
173
163
|
class AtenNonZero(torch.nn.Module):
|
|
174
|
-
|
|
175
164
|
def forward(self, x):
|
|
176
165
|
y = torch.nonzero(x)
|
|
177
166
|
return y
|
|
@@ -181,7 +170,6 @@ class AtenNonZero(torch.nn.Module):
|
|
|
181
170
|
|
|
182
171
|
|
|
183
172
|
class AtenNonZeroTuple(torch.nn.Module):
|
|
184
|
-
|
|
185
173
|
def forward(self, x):
|
|
186
174
|
y = torch.nonzero(x, as_tuple=True)
|
|
187
175
|
return y[0], y[1]
|
|
@@ -191,7 +179,6 @@ class AtenNonZeroTuple(torch.nn.Module):
|
|
|
191
179
|
|
|
192
180
|
|
|
193
181
|
class AtenAsStrided(torch.nn.Module):
|
|
194
|
-
|
|
195
182
|
def __init__(self):
|
|
196
183
|
super().__init__()
|
|
197
184
|
|
|
@@ -288,7 +275,6 @@ class ControlFlowCondConstant(torch.nn.Module):
|
|
|
288
275
|
|
|
289
276
|
|
|
290
277
|
class ControlFlowCondNestedModule(torch.nn.Module):
|
|
291
|
-
|
|
292
278
|
class Submodule(torch.nn.Module):
|
|
293
279
|
def __init__(self):
|
|
294
280
|
super().__init__()
|
|
@@ -367,7 +353,9 @@ class ControlFlowCondNonZero(torch.nn.Module):
|
|
|
367
353
|
|
|
368
354
|
|
|
369
355
|
class ControlFlowCondIdentity_153832(torch.nn.Module):
|
|
370
|
-
"""
|
|
356
|
+
"""
|
|
357
|
+
`#153832 <https://github.com/pytorch/pytorch/issues/153832>`_
|
|
358
|
+
"""
|
|
371
359
|
|
|
372
360
|
def forward(self, x, y):
|
|
373
361
|
|
|
@@ -389,7 +377,6 @@ class ControlFlowCondIdentity_153832(torch.nn.Module):
|
|
|
389
377
|
|
|
390
378
|
|
|
391
379
|
class ControlFlowScan(torch.nn.Module):
|
|
392
|
-
|
|
393
380
|
@staticmethod
|
|
394
381
|
def add(carry: torch.Tensor, y: torch.Tensor):
|
|
395
382
|
next_carry = carry + y
|
|
@@ -486,7 +473,6 @@ class ControlFlowScanCDist2(torch.nn.Module):
|
|
|
486
473
|
|
|
487
474
|
|
|
488
475
|
class ControlFlowScanCDistXY(torch.nn.Module):
|
|
489
|
-
|
|
490
476
|
@staticmethod
|
|
491
477
|
def dist(y: torch.Tensor, scanned_x: torch.Tensor):
|
|
492
478
|
sub = y - scanned_x.reshape((1, -1))
|
|
@@ -517,7 +503,9 @@ class ControlFlowScanCDistXY(torch.nn.Module):
|
|
|
517
503
|
|
|
518
504
|
|
|
519
505
|
class ControlFlowScanInplace_153705(torch.nn.Module):
|
|
520
|
-
"""
|
|
506
|
+
"""
|
|
507
|
+
`#153705 <https://github.com/pytorch/pytorch/issues/153705>`_
|
|
508
|
+
"""
|
|
521
509
|
|
|
522
510
|
def forward(self, x, y):
|
|
523
511
|
def loop_body_1(z, iv, x, y):
|
|
@@ -540,7 +528,9 @@ class ControlFlowScanInplace_153705(torch.nn.Module):
|
|
|
540
528
|
|
|
541
529
|
|
|
542
530
|
class ControlFlowScanDecomposition_151564(torch.nn.Module):
|
|
543
|
-
"""
|
|
531
|
+
"""
|
|
532
|
+
`#151564 <https://github.com/pytorch/pytorch/issues/151564>`_
|
|
533
|
+
"""
|
|
544
534
|
|
|
545
535
|
@classmethod
|
|
546
536
|
def dummy_loop(cls, padded: torch.Tensor, pos: torch.Tensor):
|
|
@@ -790,7 +780,6 @@ class SignatureShapeAsIndex(torch.nn.Module):
|
|
|
790
780
|
|
|
791
781
|
|
|
792
782
|
class TypeBFloat16(torch.nn.Module):
|
|
793
|
-
|
|
794
783
|
def forward(self, x):
|
|
795
784
|
xb = x.to(torch.bfloat16)
|
|
796
785
|
return (xb + xb).to(torch.float32)
|
|
@@ -825,7 +814,6 @@ class CropLastDimensionWithTensorShape(torch.nn.Module):
|
|
|
825
814
|
|
|
826
815
|
|
|
827
816
|
class CropLastDimensionWithTensorContent(torch.nn.Module):
|
|
828
|
-
|
|
829
817
|
def forward(self, x, shape):
|
|
830
818
|
return x[..., : shape[0]]
|
|
831
819
|
|
|
@@ -833,11 +821,10 @@ class CropLastDimensionWithTensorContent(torch.nn.Module):
|
|
|
833
821
|
(torch.rand(3, 4, 4).to(torch.float32), torch.tensor([2], dtype=torch.int64)),
|
|
834
822
|
(torch.rand(6, 4, 4).to(torch.float32), torch.tensor([3], dtype=torch.int64)),
|
|
835
823
|
]
|
|
836
|
-
_dynamic = {"x": {0: DIM("batch")}}
|
|
824
|
+
_dynamic = {"x": {0: DIM("batch")}, "shape": {}}
|
|
837
825
|
|
|
838
826
|
|
|
839
827
|
class SignatureListFixedWithNone(torch.nn.Module):
|
|
840
|
-
|
|
841
828
|
def forward(self, lx):
|
|
842
829
|
x = lx[0]
|
|
843
830
|
if lx[1] is not None:
|