onnx-diagnostic 0.6.0__py3-none-any.whl → 0.6.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (39)
  1. onnx_diagnostic/__init__.py +1 -1
  2. onnx_diagnostic/_command_lines_parser.py +18 -0
  3. onnx_diagnostic/api.py +15 -0
  4. onnx_diagnostic/ext_test_case.py +3 -1
  5. onnx_diagnostic/helpers/args_helper.py +1 -1
  6. onnx_diagnostic/helpers/doc_helper.py +143 -0
  7. onnx_diagnostic/helpers/helper.py +6 -5
  8. onnx_diagnostic/helpers/model_builder_helper.py +24 -8
  9. onnx_diagnostic/helpers/rt_helper.py +5 -1
  10. onnx_diagnostic/helpers/torch_helper.py +2 -0
  11. onnx_diagnostic/reference/__init__.py +1 -0
  12. onnx_diagnostic/reference/torch_evaluator.py +648 -0
  13. onnx_diagnostic/reference/torch_ops/__init__.py +55 -0
  14. onnx_diagnostic/reference/torch_ops/_op_run.py +335 -0
  15. onnx_diagnostic/reference/torch_ops/access_ops.py +94 -0
  16. onnx_diagnostic/reference/torch_ops/binary_ops.py +108 -0
  17. onnx_diagnostic/reference/torch_ops/controlflow_ops.py +121 -0
  18. onnx_diagnostic/reference/torch_ops/generator_ops.py +36 -0
  19. onnx_diagnostic/reference/torch_ops/nn_ops.py +196 -0
  20. onnx_diagnostic/reference/torch_ops/other_ops.py +106 -0
  21. onnx_diagnostic/reference/torch_ops/reduce_ops.py +130 -0
  22. onnx_diagnostic/reference/torch_ops/sequence_ops.py +65 -0
  23. onnx_diagnostic/reference/torch_ops/shape_ops.py +121 -0
  24. onnx_diagnostic/reference/torch_ops/unary_ops.py +86 -0
  25. onnx_diagnostic/tasks/__init__.py +22 -1
  26. onnx_diagnostic/tasks/image_classification.py +2 -2
  27. onnx_diagnostic/tasks/text_generation.py +3 -3
  28. onnx_diagnostic/torch_export_patches/eval/__init__.py +106 -37
  29. onnx_diagnostic/torch_export_patches/eval/model_cases.py +12 -25
  30. onnx_diagnostic/torch_export_patches/patch_module_helper.py +130 -16
  31. onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +88 -0
  32. onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +142 -0
  33. onnx_diagnostic/torch_models/test_helper.py +133 -16
  34. onnx_diagnostic/torch_onnx/runtime_info.py +289 -0
  35. {onnx_diagnostic-0.6.0.dist-info → onnx_diagnostic-0.6.2.dist-info}/METADATA +1 -1
  36. {onnx_diagnostic-0.6.0.dist-info → onnx_diagnostic-0.6.2.dist-info}/RECORD +39 -23
  37. {onnx_diagnostic-0.6.0.dist-info → onnx_diagnostic-0.6.2.dist-info}/WHEEL +1 -1
  38. {onnx_diagnostic-0.6.0.dist-info → onnx_diagnostic-0.6.2.dist-info}/licenses/LICENSE.txt +0 -0
  39. {onnx_diagnostic-0.6.0.dist-info → onnx_diagnostic-0.6.2.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,86 @@
1
+ import torch
2
+ from . import OpRunKernel, OpRunTensor
3
+
4
+
5
class Abs_1(OpRunKernel):
    """ONNX ``Abs`` (opset 1): elementwise absolute value."""

    def run(self, x: OpRunTensor) -> OpRunTensor:
        # |x| computed via the tensor method rather than the torch free function.
        return OpRunTensor(x.tensor.abs())
10
+
11
+
12
class Cos_1(OpRunKernel):
    """ONNX ``Cos`` (opset 1): elementwise cosine."""

    def run(self, x: OpRunTensor) -> OpRunTensor:
        # cos(x) via the torch free function (equivalent to the tensor method).
        return OpRunTensor(torch.cos(x.tensor))
17
+
18
+
19
class Erf_9(OpRunKernel):
    """ONNX ``Erf`` (opset 9): elementwise Gauss error function."""

    def run(self, x: OpRunTensor) -> OpRunTensor:
        # erf(x) via the torch free function (equivalent to the tensor method).
        return OpRunTensor(torch.erf(x.tensor))
24
+
25
+
26
class Exp_1(OpRunKernel):
    """ONNX ``Exp`` (opset 1): elementwise natural exponential."""

    def run(self, x: OpRunTensor) -> OpRunTensor:
        # e**x via the torch free function (equivalent to the tensor method).
        return OpRunTensor(torch.exp(x.tensor))
31
+
32
+
33
class Identity_1(OpRunKernel):
    """ONNX ``Identity`` (opset 1): pass the input through unchanged."""

    def run(self, x: OpRunTensor) -> OpRunTensor:
        # Wraps the same underlying tensor; no copy is made.
        same = x.tensor
        return OpRunTensor(same)
38
+
39
+
40
class Log_1(OpRunKernel):
    """ONNX ``Log`` (opset 1): elementwise natural logarithm."""

    def run(self, x: OpRunTensor) -> OpRunTensor:
        # ln(x) via the torch free function (equivalent to the tensor method).
        return OpRunTensor(torch.log(x.tensor))
45
+
46
+
47
class Neg_1(OpRunKernel):
    """ONNX ``Neg`` (opset 1): elementwise negation."""

    def run(self, x: OpRunTensor) -> OpRunTensor:
        # -x via the tensor method (equivalent to the unary minus operator).
        return OpRunTensor(x.tensor.neg())
52
+
53
+
54
class Not_1(OpRunKernel):
    """ONNX ``Not`` (opset 1): elementwise boolean negation."""

    def run(self, x: OpRunTensor) -> OpRunTensor:
        # torch.bitwise_not is exactly what the ``~`` operator dispatches to
        # (logical NOT on bool tensors, bitwise NOT on integer tensors).
        return OpRunTensor(torch.bitwise_not(x.tensor))
59
+
60
+
61
class Reciprocal_1(OpRunKernel):
    """ONNX ``Reciprocal`` (opset 1): elementwise ``1 / x``."""

    def run(self, x: OpRunTensor) -> OpRunTensor:
        # True division, so an integer input yields a floating-point result.
        return OpRunTensor(1 / x.tensor)
66
+
67
+
68
class Sigmoid_6(OpRunKernel):
    """ONNX ``Sigmoid`` (opset 6): elementwise ``1 / (1 + exp(-x))``."""

    def run(self, x: OpRunTensor) -> OpRunTensor:
        return OpRunTensor(torch.sigmoid(x.tensor))
73
+
74
+
75
class Sin_1(OpRunKernel):
    """ONNX ``Sin`` (opset 1): elementwise sine."""

    def run(self, x: OpRunTensor) -> OpRunTensor:
        # sin(x) via the torch free function (equivalent to the tensor method).
        return OpRunTensor(torch.sin(x.tensor))
80
+
81
+
82
class Sqrt_1(OpRunKernel):
    """ONNX ``Sqrt`` (opset 1): elementwise square root."""

    def run(self, x: OpRunTensor) -> OpRunTensor:
        # sqrt(x) via the torch free function (equivalent to the tensor method).
        return OpRunTensor(torch.sqrt(x.tensor))
@@ -39,9 +39,30 @@ def supported_tasks() -> List[str]:
39
39
 
40
40
  def reduce_model_config(config: Any, task: str) -> Dict[str, Any]:
41
41
  """Reduces a model size."""
42
+ head_size0 = (
43
+ config.head_dim
44
+ if hasattr(config, "head_dim") and config.head_dim
45
+ else (
46
+ config.hidden_size // config.num_attention_heads
47
+ if hasattr(config, "hidden_size") and hasattr(config, "num_attention_heads")
48
+ else None
49
+ )
50
+ )
42
51
  tasks = {mod.__TASK__: mod.reduce_model_config for mod in __TASKS__}
43
52
  assert task in tasks, f"Task {task!r} not found in {sorted(tasks)}"
44
- return tasks[task](config)
53
+ res = tasks[task](config)
54
+ if head_size0 and "head_dim" in res:
55
+ head_size = (
56
+ config.head_dim
57
+ if hasattr(config, "head_dim") and config.head_dim
58
+ else config.hidden_size // config.num_attention_heads
59
+ )
60
+ assert head_size0 == head_size or head_size % 16 == 0, (
61
+ f"head_size should be a multiple of 16 "
62
+ f"(head_size0={head_size0}), res={res}, "
63
+ f"config=\n{config}"
64
+ )
65
+ return res
45
66
 
46
67
 
47
68
  def random_input_kwargs(config: Any, task: str) -> Tuple[Dict[str, Any], Callable]:
@@ -58,8 +58,8 @@ def get_inputs(
58
58
  shapes = {
59
59
  "pixel_values": {
60
60
  0: torch.export.Dim("batch", min=1, max=1024),
61
- 2: torch.export.Dim("width", min=1, max=4096),
62
- 3: torch.export.Dim("height", min=1, max=4096),
61
+ 2: "width",
62
+ 3: "height",
63
63
  },
64
64
  }
65
65
  inputs = dict(
@@ -27,7 +27,7 @@ def reduce_model_config(config: Any) -> Dict[str, Any]:
27
27
  kwargs = dict(
28
28
  num_hidden_layers=min(config.num_hidden_layers, 2),
29
29
  intermediate_size=256 if config is None else min(512, config.intermediate_size),
30
- hidden_size=256 if config is None else min(256, config.hidden_size),
30
+ hidden_size=512 if config is None else min(512, config.hidden_size),
31
31
  cls_cache="MambaCache",
32
32
  state_size=8 if config is None else getattr(config, "state_size", None),
33
33
  conv_kernel=4 if config is None else getattr(config, "conv_kernel", None),
@@ -44,8 +44,8 @@ def reduce_model_config(config: Any) -> Dict[str, Any]:
44
44
  else config.num_attention_heads
45
45
  ),
46
46
  hidden_size=(
47
- min(config.hidden_size, 3072 // 4)
48
- if config.hidden_size % 4 == 0
47
+ min(config.hidden_size, 4096 // 4)
48
+ if config.hidden_size % 64 == 0
49
49
  else config.hidden_size
50
50
  ),
51
51
  )
@@ -185,9 +185,17 @@ def _make_exporter_export(
185
185
 
186
186
  if exporter == "export-strict":
187
187
  try:
188
- exported = torch.export.export(
189
- model, inputs, dynamic_shapes=dynamic_shapes, strict=True
190
- )
188
+ if verbose >= 2:
189
+ exported = torch.export.export(
190
+ model, inputs, dynamic_shapes=dynamic_shapes, strict=True
191
+ )
192
+ else:
193
+ with contextlib.redirect_stdout(io.StringIO()), contextlib.redirect_stderr(
194
+ io.StringIO()
195
+ ):
196
+ exported = torch.export.export(
197
+ model, inputs, dynamic_shapes=dynamic_shapes, strict=True
198
+ )
191
199
  except Exception as e:
192
200
  if not quiet:
193
201
  raise
@@ -198,17 +206,33 @@ def _make_exporter_export(
198
206
  return exported.module()
199
207
  if exporter in ("export-strict-dec", "export-strict-decall"):
200
208
  try:
201
- exported = torch.export.export(
202
- model, inputs, dynamic_shapes=dynamic_shapes, strict=True
203
- )
204
- if verbose >= 9:
205
- print("-- graph before decomposition")
206
- print(exported.graph)
207
- exported = (
208
- exported.run_decompositions()
209
- if "decall" in exporter
210
- else exported.run_decompositions({})
211
- )
209
+ if verbose >= 2:
210
+ exported = torch.export.export(
211
+ model, inputs, dynamic_shapes=dynamic_shapes, strict=True
212
+ )
213
+ if verbose >= 9:
214
+ print("-- graph before decomposition")
215
+ print(exported.graph)
216
+ exported = (
217
+ exported.run_decompositions()
218
+ if "decall" in exporter
219
+ else exported.run_decompositions({})
220
+ )
221
+ else:
222
+ with contextlib.redirect_stdout(io.StringIO()), contextlib.redirect_stderr(
223
+ io.StringIO()
224
+ ):
225
+ exported = torch.export.export(
226
+ model, inputs, dynamic_shapes=dynamic_shapes, strict=True
227
+ )
228
+ if verbose >= 9:
229
+ print("-- graph before decomposition")
230
+ print(exported.graph)
231
+ exported = (
232
+ exported.run_decompositions()
233
+ if "decall" in exporter
234
+ else exported.run_decompositions({})
235
+ )
212
236
  except Exception as e:
213
237
  if not quiet:
214
238
  raise
@@ -219,9 +243,17 @@ def _make_exporter_export(
219
243
  return exported.module()
220
244
  if exporter == "export-nostrict":
221
245
  try:
222
- exported = torch.export.export(
223
- model, inputs, dynamic_shapes=dynamic_shapes, strict=False
224
- )
246
+ if verbose >= 2:
247
+ exported = torch.export.export(
248
+ model, inputs, dynamic_shapes=dynamic_shapes, strict=False
249
+ )
250
+ else:
251
+ with contextlib.redirect_stdout(io.StringIO()), contextlib.redirect_stderr(
252
+ io.StringIO()
253
+ ):
254
+ exported = torch.export.export(
255
+ model, inputs, dynamic_shapes=dynamic_shapes, strict=False
256
+ )
225
257
  except Exception as e:
226
258
  if not quiet:
227
259
  raise
@@ -232,17 +264,33 @@ def _make_exporter_export(
232
264
  return exported.module()
233
265
  if exporter in ("export-nostrict-dec", "export-nostrict-decall"):
234
266
  try:
235
- exported = torch.export.export(
236
- model, inputs, dynamic_shapes=dynamic_shapes, strict=False
237
- )
238
- if verbose >= 9:
239
- print("-- graph before decomposition")
240
- print(exported.graph)
241
- exported = (
242
- exported.run_decompositions()
243
- if "decall" in exporter
244
- else exported.run_decompositions({})
245
- )
267
+ if verbose >= 2:
268
+ exported = torch.export.export(
269
+ model, inputs, dynamic_shapes=dynamic_shapes, strict=False
270
+ )
271
+ if verbose >= 9:
272
+ print("-- graph before decomposition")
273
+ print(exported.graph)
274
+ exported = (
275
+ exported.run_decompositions()
276
+ if "decall" in exporter
277
+ else exported.run_decompositions({})
278
+ )
279
+ else:
280
+ with contextlib.redirect_stdout(io.StringIO()), contextlib.redirect_stderr(
281
+ io.StringIO()
282
+ ):
283
+ exported = torch.export.export(
284
+ model, inputs, dynamic_shapes=dynamic_shapes, strict=False
285
+ )
286
+ if verbose >= 9:
287
+ print("-- graph before decomposition")
288
+ print(exported.graph)
289
+ exported = (
290
+ exported.run_decompositions()
291
+ if "decall" in exporter
292
+ else exported.run_decompositions({})
293
+ )
246
294
  except Exception as e:
247
295
  if not quiet:
248
296
  raise
@@ -255,8 +303,15 @@ def _make_exporter_export(
255
303
  from experimental_experiment.torch_interpreter.tracing import CustomTracer
256
304
 
257
305
  try:
258
- graph = CustomTracer().trace(model)
259
- mod = torch.fx.GraphModule(model, graph)
306
+ if verbose >= 2:
307
+ graph = CustomTracer().trace(model)
308
+ mod = torch.fx.GraphModule(model, graph)
309
+ else:
310
+ with contextlib.redirect_stdout(io.StringIO()), contextlib.redirect_stderr(
311
+ io.StringIO()
312
+ ):
313
+ graph = CustomTracer().trace(model)
314
+ mod = torch.fx.GraphModule(model, graph)
260
315
  except Exception as e:
261
316
  if not quiet:
262
317
  raise
@@ -289,13 +344,25 @@ def _make_exporter_onnx(
289
344
  if "-dec" in exporter:
290
345
  opts["decomposition_table"] = "all" if "-decall" in exporter else "default"
291
346
  try:
292
- onx, builder = to_onnx(
293
- model,
294
- inputs,
295
- dynamic_shapes=dynamic_shapes,
296
- export_options=ExportOptions(**opts),
297
- return_builder=True,
298
- )
347
+ if verbose >= 2:
348
+ onx, builder = to_onnx(
349
+ model,
350
+ inputs,
351
+ dynamic_shapes=dynamic_shapes,
352
+ export_options=ExportOptions(**opts),
353
+ return_builder=True,
354
+ )
355
+ else:
356
+ with contextlib.redirect_stdout(io.StringIO()), contextlib.redirect_stderr(
357
+ io.StringIO()
358
+ ):
359
+ onx, builder = to_onnx(
360
+ model,
361
+ inputs,
362
+ dynamic_shapes=dynamic_shapes,
363
+ export_options=ExportOptions(**opts),
364
+ return_builder=True,
365
+ )
299
366
  except Exception as e:
300
367
  if not quiet:
301
368
  raise RuntimeError(
@@ -306,6 +373,7 @@ def _make_exporter_onnx(
306
373
  ) from e
307
374
  return dict(error=str(e), success=0, error_step="export")
308
375
  return onx, builder
376
+
309
377
  if exporter == "dynamo":
310
378
  import torch
311
379
 
@@ -338,6 +406,7 @@ def _make_exporter_onnx(
338
406
  ) from e
339
407
  return dict(error=str(e), success=0, error_step="export")
340
408
  return onx, None
409
+
341
410
  if exporter == "dynamo-ir":
342
411
  import torch
343
412
 
@@ -23,7 +23,6 @@ class AtenRollPos(torch.nn.Module):
23
23
 
24
24
 
25
25
  class InplaceAdd(torch.nn.Module):
26
-
27
26
  def __init__(self):
28
27
  super().__init__()
29
28
  self.bias = torch.ones((1, 4), dtype=torch.float32)
@@ -37,7 +36,6 @@ class InplaceAdd(torch.nn.Module):
37
36
 
38
37
 
39
38
  class InplaceAdd2(torch.nn.Module):
40
-
41
39
  def __init__(self):
42
40
  super().__init__()
43
41
  self.bias = torch.ones((1, 4), dtype=torch.float32)
@@ -51,7 +49,6 @@ class InplaceAdd2(torch.nn.Module):
51
49
 
52
50
 
53
51
  class InplaceAdd_Mul(torch.nn.Module):
54
-
55
52
  def __init__(self):
56
53
  super().__init__()
57
54
  self.bias = torch.ones((1, 4), dtype=torch.float32)
@@ -65,7 +62,6 @@ class InplaceAdd_Mul(torch.nn.Module):
65
62
 
66
63
 
67
64
  class InplaceCloneAdd_(torch.nn.Module):
68
-
69
65
  def __init__(self):
70
66
  super().__init__()
71
67
  self.bias = torch.ones((1, 4), dtype=torch.float32)
@@ -80,7 +76,6 @@ class InplaceCloneAdd_(torch.nn.Module):
80
76
 
81
77
 
82
78
  class InplaceSetItemSquare(torch.nn.Module):
83
-
84
79
  def forward(self, x):
85
80
  x[:2, :3] = 1
86
81
  return x
@@ -90,7 +85,6 @@ class InplaceSetItemSquare(torch.nn.Module):
90
85
 
91
86
 
92
87
  class InplaceSetItemSquareAdd(torch.nn.Module):
93
-
94
88
  def forward(self, x):
95
89
  x[:2, :3] = 1
96
90
  return x + 2
@@ -100,7 +94,6 @@ class InplaceSetItemSquareAdd(torch.nn.Module):
100
94
 
101
95
 
102
96
  class InplaceSetItemSquareAdd2(torch.nn.Module):
103
-
104
97
  def forward(self, x):
105
98
  x[:2, :3] = 1
106
99
  return x + 2, x + 3
@@ -110,7 +103,6 @@ class InplaceSetItemSquareAdd2(torch.nn.Module):
110
103
 
111
104
 
112
105
  class InplaceSetItemEllipsis_1(torch.nn.Module):
113
-
114
106
  def __init__(self):
115
107
  super().__init__()
116
108
  self.params = torch.zeros((1, 8192, 4), dtype=torch.float32)
@@ -124,11 +116,10 @@ class InplaceSetItemEllipsis_1(torch.nn.Module):
124
116
  (torch.from_numpy(np.array([0, 3, 2, 1])).to(torch.int64)),
125
117
  (torch.arange(4 * 8192) + 10).reshape((-1, 4)).to(torch.float32),
126
118
  )
127
- _dynamic = {"x": {0: DIM("batch")}}
119
+ _dynamic = {"index": {0: DIM("batch")}, "update": {0: DIM("batch"), 1: DYN}}
128
120
 
129
121
 
130
122
  class InplaceSetItemEllipsis_2(torch.nn.Module):
131
-
132
123
  def __init__(self):
133
124
  super().__init__()
134
125
  self.params = torch.zeros((1, 8192, 6), dtype=torch.float32)
@@ -142,7 +133,7 @@ class InplaceSetItemEllipsis_2(torch.nn.Module):
142
133
  torch.from_numpy(np.array([0, 3, 2, 5])).to(torch.int64),
143
134
  (torch.arange(4 * 8192) + 10).reshape((-1, 4)).to(torch.float32),
144
135
  )
145
- _dynamic = {"x": {0: DIM("batch")}}
136
+ _dynamic = {"index": {0: DIM("batch")}, "update": {0: DIM("batch"), 1: DYN}}
146
137
 
147
138
 
148
139
  class InplaceSetItemMask(torch.nn.Module):
@@ -156,7 +147,6 @@ class InplaceSetItemMask(torch.nn.Module):
156
147
 
157
148
 
158
149
  class AtenInterpolate(torch.nn.Module):
159
-
160
150
  def forward(self, x):
161
151
  y = torch.nn.functional.interpolate(
162
152
  x,
@@ -171,7 +161,6 @@ class AtenInterpolate(torch.nn.Module):
171
161
 
172
162
 
173
163
  class AtenNonZero(torch.nn.Module):
174
-
175
164
  def forward(self, x):
176
165
  y = torch.nonzero(x)
177
166
  return y
@@ -181,7 +170,6 @@ class AtenNonZero(torch.nn.Module):
181
170
 
182
171
 
183
172
  class AtenNonZeroTuple(torch.nn.Module):
184
-
185
173
  def forward(self, x):
186
174
  y = torch.nonzero(x, as_tuple=True)
187
175
  return y[0], y[1]
@@ -191,7 +179,6 @@ class AtenNonZeroTuple(torch.nn.Module):
191
179
 
192
180
 
193
181
  class AtenAsStrided(torch.nn.Module):
194
-
195
182
  def __init__(self):
196
183
  super().__init__()
197
184
 
@@ -288,7 +275,6 @@ class ControlFlowCondConstant(torch.nn.Module):
288
275
 
289
276
 
290
277
  class ControlFlowCondNestedModule(torch.nn.Module):
291
-
292
278
  class Submodule(torch.nn.Module):
293
279
  def __init__(self):
294
280
  super().__init__()
@@ -367,7 +353,9 @@ class ControlFlowCondNonZero(torch.nn.Module):
367
353
 
368
354
 
369
355
  class ControlFlowCondIdentity_153832(torch.nn.Module):
370
- """`#153832 <https://github.com/pytorch/pytorch/issues/153832>`_"""
356
+ """
357
+ `#153832 <https://github.com/pytorch/pytorch/issues/153832>`_
358
+ """
371
359
 
372
360
  def forward(self, x, y):
373
361
 
@@ -389,7 +377,6 @@ class ControlFlowCondIdentity_153832(torch.nn.Module):
389
377
 
390
378
 
391
379
  class ControlFlowScan(torch.nn.Module):
392
-
393
380
  @staticmethod
394
381
  def add(carry: torch.Tensor, y: torch.Tensor):
395
382
  next_carry = carry + y
@@ -486,7 +473,6 @@ class ControlFlowScanCDist2(torch.nn.Module):
486
473
 
487
474
 
488
475
  class ControlFlowScanCDistXY(torch.nn.Module):
489
-
490
476
  @staticmethod
491
477
  def dist(y: torch.Tensor, scanned_x: torch.Tensor):
492
478
  sub = y - scanned_x.reshape((1, -1))
@@ -517,7 +503,9 @@ class ControlFlowScanCDistXY(torch.nn.Module):
517
503
 
518
504
 
519
505
  class ControlFlowScanInplace_153705(torch.nn.Module):
520
- """`#153705 <https://github.com/pytorch/pytorch/issues/153705>`_"""
506
+ """
507
+ `#153705 <https://github.com/pytorch/pytorch/issues/153705>`_
508
+ """
521
509
 
522
510
  def forward(self, x, y):
523
511
  def loop_body_1(z, iv, x, y):
@@ -540,7 +528,9 @@ class ControlFlowScanInplace_153705(torch.nn.Module):
540
528
 
541
529
 
542
530
  class ControlFlowScanDecomposition_151564(torch.nn.Module):
543
- """`#151564 <https://github.com/pytorch/pytorch/issues/151564>`_"""
531
+ """
532
+ `#151564 <https://github.com/pytorch/pytorch/issues/151564>`_
533
+ """
544
534
 
545
535
  @classmethod
546
536
  def dummy_loop(cls, padded: torch.Tensor, pos: torch.Tensor):
@@ -790,7 +780,6 @@ class SignatureShapeAsIndex(torch.nn.Module):
790
780
 
791
781
 
792
782
  class TypeBFloat16(torch.nn.Module):
793
-
794
783
  def forward(self, x):
795
784
  xb = x.to(torch.bfloat16)
796
785
  return (xb + xb).to(torch.float32)
@@ -825,7 +814,6 @@ class CropLastDimensionWithTensorShape(torch.nn.Module):
825
814
 
826
815
 
827
816
  class CropLastDimensionWithTensorContent(torch.nn.Module):
828
-
829
817
  def forward(self, x, shape):
830
818
  return x[..., : shape[0]]
831
819
 
@@ -833,11 +821,10 @@ class CropLastDimensionWithTensorContent(torch.nn.Module):
833
821
  (torch.rand(3, 4, 4).to(torch.float32), torch.tensor([2], dtype=torch.int64)),
834
822
  (torch.rand(6, 4, 4).to(torch.float32), torch.tensor([3], dtype=torch.int64)),
835
823
  ]
836
- _dynamic = {"x": {0: DIM("batch")}}
824
+ _dynamic = {"x": {0: DIM("batch")}, "shape": {}}
837
825
 
838
826
 
839
827
  class SignatureListFixedWithNone(torch.nn.Module):
840
-
841
828
  def forward(self, lx):
842
829
  x = lx[0]
843
830
  if lx[1] is not None: