onnx-diagnostic 0.8.6__py3-none-any.whl → 0.8.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (29)
  1. onnx_diagnostic/__init__.py +1 -1
  2. onnx_diagnostic/_command_lines_parser.py +108 -3
  3. onnx_diagnostic/ci_models/ci_helpers.py +12 -7
  4. onnx_diagnostic/ci_models/export_phi4_mm.py +1062 -0
  5. onnx_diagnostic/ci_models/export_qwen25_vl.py +12 -4
  6. onnx_diagnostic/export/api.py +1 -0
  7. onnx_diagnostic/export/cf_simple_loop_for.py +195 -10
  8. onnx_diagnostic/ext_test_case.py +9 -2
  9. onnx_diagnostic/helpers/bench_run.py +1 -1
  10. onnx_diagnostic/helpers/log_helper.py +1 -3
  11. onnx_diagnostic/helpers/optim_helper.py +116 -0
  12. onnx_diagnostic/tasks/image_text_to_text.py +15 -5
  13. onnx_diagnostic/tasks/text2text_generation.py +84 -48
  14. onnx_diagnostic/tasks/text_generation.py +3 -0
  15. onnx_diagnostic/torch_export_patches/onnx_export_errors.py +28 -2
  16. onnx_diagnostic/torch_export_patches/patch_expressions.py +4 -1
  17. onnx_diagnostic/torch_export_patches/patch_module.py +31 -23
  18. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_funnel.py +80 -0
  19. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2_5.py +12 -1
  20. onnx_diagnostic/torch_export_patches/patches/patch_torch.py +15 -0
  21. onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +22 -24
  22. onnx_diagnostic/torch_models/hghub/hub_api.py +11 -0
  23. onnx_diagnostic/torch_models/hghub/hub_data.py +9 -1
  24. onnx_diagnostic/torch_models/hghub/model_inputs.py +24 -19
  25. {onnx_diagnostic-0.8.6.dist-info → onnx_diagnostic-0.8.7.dist-info}/METADATA +1 -1
  26. {onnx_diagnostic-0.8.6.dist-info → onnx_diagnostic-0.8.7.dist-info}/RECORD +29 -26
  27. {onnx_diagnostic-0.8.6.dist-info → onnx_diagnostic-0.8.7.dist-info}/WHEEL +0 -0
  28. {onnx_diagnostic-0.8.6.dist-info → onnx_diagnostic-0.8.7.dist-info}/licenses/LICENSE.txt +0 -0
  29. {onnx_diagnostic-0.8.6.dist-info → onnx_diagnostic-0.8.7.dist-info}/top_level.txt +0 -0
@@ -8,10 +8,10 @@ Requirements
  ::

  git+https://github.com/sdpython/experimental-experiment.git # optional
- huggingface_hub>=1.2.1
+ huggingface_hub
  onnx-diagnostic>=0.8.6
  onnxruntime>=1.23
- torch>=2.9 # weekly is better
+ torch>=2.10 # weekly is better
  tqdm
  transformers>=4.57

@@ -59,6 +59,7 @@ It is possible to overwrite this by setting environment variable
  import os
  import sys
  import time
+ import warnings
  from typing import Any, Dict, List, Tuple
  from .ci_helpers import (
      check_for_discrepancies_and_log_everything_into_a_json_file,
@@ -97,7 +98,6 @@ def get_untrained_model(model_id: str, second_input: bool, verbose: int) -> Dict
      },
      # "_attn_implementation": "flash_attention_2",
      "_attn_implementation": "sdpa",
-     "dtype": "float16",
  }

  config_reduction = _config_reduction
@@ -281,6 +281,10 @@ def main(
      ).eval()
      data = dict(model=model)
      config = model.config
+     if not hasattr(config, "bos_token_id") or not config.bos_token_id:
+         config.bos_token_id = 151643
+     if not hasattr(config, "eos_token_id") or not config.eos_token_id:
+         config.eos_token_id = 151645
  else:
      print("-- random model")
      data = get_untrained_model(model_id, second_input=second_input, verbose=1)
@@ -298,7 +302,11 @@ def main(
  print(f"-- config._attn_implementation={model.config._attn_implementation}")
  print(f"-- model.dtype={model.dtype}")
  print(f"-- model.device={model.device}")
- processor = AutoProcessor.from_pretrained(model_id, use_fast=True)
+ try:
+     processor = AutoProcessor.from_pretrained(model_id, use_fast=True)
+ except OSError as e:
+     warnings.warn(f"Unable to access internet due to {e!r}", ResourceWarning, stacklevel=0)
+     return
  print(f"-- processor={type(processor)}")

  export_inputs, other_inputs = None, None
@@ -154,6 +154,7 @@ def to_onnx(
      options=options,
      inline=inline,
      dispatcher=main_dispatcher,
+     optimize=optimize,
      **(exporter_kwargs or {}),
  )

@@ -11,6 +11,7 @@ from torch._higher_order_ops.utils import (
      unique_graph_id,
      validate_subgraph_args_types,
  )
+ import torch._dynamo.variables.higher_order_ops as hop
  from torch.fx.experimental.proxy_tensor import ProxyTorchDispatchMode, track_tensor_tree
  from torch.utils._python_dispatch import _get_current_dispatch_mode

@@ -97,14 +98,18 @@ def _simple_loop_for_fn(
          f"Unexpected number of results {len(r)} for function {body_fn}, "
          f"expected {len(res[-1])}"
      )
+     assert all(isinstance(t, torch.Tensor) for t in r), (
+         f"Unexpected type {[type(_) for _ in r]} for returned by function {body_fn}, "
+         f"it must be a tuple of Tensor or a Tensor."
+     )
      res.append(r)
  else:
      assert isinstance(r, torch.Tensor), (
-         f"Unexpected type {r} for function {body_fn}, "
-         f"it must be a tuple or a Tensor."
+         f"Unexpected type {type(r)} coming from function {body_fn}, "
+         f"it must be a tuple of Tensor or a Tensor."
      )
      assert not res or len(res[-1]) == 1, (
-         f"Unexpected number of results {len(r)} for function {body_fn}, "
+         f"Unexpected number of results {len(r)} coming from function {body_fn}, "
          f"expected {len(res[-1])}"
      )
      res.append((r,))
@@ -126,8 +131,6 @@ def _simple_loop_for_fn(
      )


- # from torch._functorch.utils import exposed_in
- # @exposed_in("torch")
  def _simple_loop_for(
      n_iter: Union[int, torch.Tensor],
      body_fn: Callable,
@@ -159,7 +162,7 @@ def _simple_loop_for(

  if torch.compiler.is_dynamo_compiling():
      return simple_loop_for_op(
-         n_iter, body_fn, (n_iter, *operands), concatenation_dims=concatenation_dims
+         n_iter, body_fn, operands, concatenation_dims=concatenation_dims
      )

  if isinstance(n_iter, (bool, int, float)):
@@ -181,8 +184,10 @@ def _simple_loop_for(

  with setup_compilation_env() as _backend:
      return _loop_for_op_wrapper(n_iter, body_fn, operands, concatenation_dims)
-     # return torch.compile(_loop_for_op_wrapper, backend=backend, fullgraph=True)(
-     #     n_iter, body_fn, operands, concatenation_dims)
+     # This is needed to support function body using module weights or function body
+     # defined as a class method. This is yet to be implemented.
+     # cpl = torch.compile(_loop_for_op_wrapper, backend=_backend, fullgraph=True)
+     # return cpl(n_iter, body_fn, operands, concatenation_dims)


  def trace_simple_loop_for(
@@ -236,9 +241,15 @@ def loop_for_op_dense(n_iter, body_fn, operands, concatenation_dims=None):
      )
      mode = _get_current_dispatch_mode()
      assert mode is None, "Mode should never be enabled for CPU/CUDA key"
-     return _simple_loop_for_fn(
-         n_iter, body_fn, operands, concatenation_dims=concatenation_dims
+     is_fake = isinstance(n_iter, torch._subclasses.fake_tensor.FakeTensor)
+     res = _simple_loop_for_fn(n_iter, body_fn, operands, concatenation_dims=concatenation_dims)
+     assert is_fake or not any(
+         isinstance(r, torch._subclasses.fake_tensor.FakeTensor) for r in res
+     ), (
+         f"One result is a fake tensor but the inputs were not, type(n_iter)={type(n_iter)}, "
+         f"operands: {[type(_) for _ in operands]}, res: {[type(_) for _ in res]}"
      )
+     return res


  @simple_loop_for_op.py_impl(ProxyTorchDispatchMode)
@@ -267,6 +278,180 @@ simple_loop_for_op.fallthrough(torch._C.DispatchKey.AutogradCPU)
  simple_loop_for_op.fallthrough(torch._C.DispatchKey.AutogradCUDA)


+ class SimpleLoopForHigherOrderVariable(hop.TorchHigherOrderOperatorVariable):
+     """
+     Replicates the same pattern found for other higher order operators.
+     This enables recursive compilation and the use of modules inside a function.
+     """
+
+     _HOP_NAME = "simple_loop_for"
+     _ALLOW_FALLBACK_TO_EAGER = False
+     supports_input_mutation = False
+     supports_aliasing = False
+
+     def _call_function(
+         self,
+         tx: torch._dynamo.symbolic_convert.InstructionTranslator,
+         args: list[hop.VariableTracker],
+         kwargs: dict[str, hop.VariableTracker],
+     ) -> hop.VariableTracker:
+         """Main function."""
+         args, kwargs = hop.LazyVariableTracker.realize_all((args, kwargs))
+
+         for i, k in enumerate(["n_iter", "body_fn", "operands", "concatenated_dims"]):
+             if v := kwargs.pop(k, None):
+                 assert i == len(args), "did not provide the right number of non-keyword args"
+                 args.append(v)
+
+         if len(args) != 4 or kwargs:
+             hop.unimplemented(
+                 gb_type="simple_loop_for: improper args/kwargs",
+                 context=f"args: {args}, kwargs: {kwargs}",
+                 explanation=f"torch.cond expects 4 positional arguments (got {len(args)}) "
+                 f"and no keyword arguments (got {len(kwargs)})",
+                 hints=[*hop.graph_break_hints.USER_ERROR],
+             )
+
+         # Specialize into one of the branches since pred is constant
+         n_iter, body_fn, operands, _concatenated_dims = args
+         assert type(n_iter) is not hop.ConstantVariable, (
+             f"n_iter is a {type(n_iter)}. When used simple_loop_for, "
+             f"it unrolls the loop. A SymInt should be used."
+         )
+
+         # predicate
+         if type(n_iter.realize()) not in (
+             hop.ConstantVariable,
+             hop.TensorVariable,
+             hop.SymNodeVariable,
+         ):
+             hop.unimplemented(
+                 gb_type="simple_loop_for: improper predicate",
+                 context=str(n_iter),
+                 explanation=(
+                     f"Expected `n_iter` to be an int or a integer "
+                     f"tensor with a single item "
+                     f"but got {str(type(n_iter))} with original python type "
+                     f"{str(n_iter.python_type())}."
+                 ),
+                 hints=[*hop.graph_break_hints.USER_ERROR],
+             )
+
+         # operands
+         if not isinstance(operands, (hop.ListVariable, hop.TupleVariable)):
+             hop.unimplemented(
+                 gb_type="simple_loop_for: improper operands",
+                 context=str(operands),
+                 explanation="Expected `operands` to be a list/tuple "
+                 f"but got {operands.python_type()}.",
+                 hints=[*hop.graph_break_hints.USER_ERROR],
+             )
+
+         operands_seq = operands.unpack_var_sequence(tx)
+         if not hop.only_consist_of(
+             operands, (hop.TensorVariable, hop.ConstantVariable, hop.SymNodeVariable)
+         ):
+             hop.unimplemented(
+                 gb_type="simple_loop_for: improper operands contents",
+                 context=str(operands),
+                 explanation=(
+                     "Expected `operands` to be a list/tuple of pytrees "
+                     "that only consists of tensor leaves."
+                 ),
+                 hints=[*hop.graph_break_hints.USER_ERROR],
+             )
+
+         # branches
+         hop._check_supported_callable_arg(tx, body_fn, "body_fn")
+
+         def speculate_body():
+             (
+                 (ret_val, ret_spec),
+                 ret_graph,
+                 ret_lifted_freevars,
+             ) = hop.speculate_subgraph(
+                 tx,
+                 args[1],
+                 (args[0], *operands_seq),
+                 {},
+                 self._HOP_NAME,
+                 source_target=self.value,
+                 should_flatten_outputs=True,
+                 # TODO - removing consts from control flow ops need more work
+                 remove_consts_from_outputs=False,
+                 supports_input_mutation=self.supports_input_mutation,
+                 supports_aliasing=self.supports_aliasing,
+             )
+
+             # need to ensure we increase epoch so we don't memoize unbacked bindings
+             # across different subgraphs which can interfere with runtime assertion
+             # generation.
+             tx.fake_mode.epoch += 1
+
+             if not hop.only_consist_of(ret_val, (hop.TensorVariable, hop.ConstantVariable)):
+                 hop.unimplemented(
+                     gb_type="simple_loop_for: unsupported branch return type",
+                     context=str(ret_val),
+                     explanation=(
+                         "Expected branches to return a possibly nested "
+                         "pytree of tensors or constant ints."
+                     ),
+                     hints=[*hop.graph_break_hints.USER_ERROR],
+                 )
+             for ret in ret_val.unpack_var_sequence(tx):
+                 if ret.is_python_constant() and not isinstance(ret.as_python_constant(), int):
+                     hop.unimplemented(
+                         gb_type=(
+                             "simple_loop_for: unsupported branch return type "
+                             "(constant non-int)"
+                         ),
+                         context=str(ret_val),
+                         explanation="Constants returned from branches must be ints.",
+                         hints=[*hop.graph_break_hints.USER_ERROR],
+                     )
+             return ret_val, ret_spec, ret_graph, ret_lifted_freevars
+
+         body_r, body_spec, body_graph, body_lifted_freevars = speculate_body()
+         body_nn_modules = dict(tx.output.nn_modules)
+
+         same_spec = body_spec.treespec.as_python_constant()
+         if same_spec is not NotImplemented and not same_spec:
+             hop.unimplemented(
+                 gb_type="simple_loop_for: differing branch outputs",
+                 context=(
+                     f"body_spec: {body_spec.treespec}, false_spec: "
+                     f"{body_spec.treespec}, same_spec: {same_spec}"
+                 ),
+                 explanation="Expected branches to return the same pytree structure.",
+                 hints=[*hop.graph_break_hints.USER_ERROR],
+             )
+
+         body_name = tx.output.install_subgraph(
+             "loop_body", torch.fx.GraphModule(body_nn_modules, body_graph)
+         )
+         body_node = hop.make_attr(tx, body_name)
+         p_args = (
+             n_iter.as_proxy(),
+             body_node,
+             # We pick true_shared but it shouldn't matter
+             operands.as_proxy() + tuple(body_lifted_freevars.keys()),
+         )
+
+         return hop._call_function_and_unflatten_output(
+             tx,
+             simple_loop_for,
+             p_args,
+             {},
+             None,
+             body_spec,
+             body_r,
+         )
+
+
+ hop._hop_name_to_variable_class["simple_loop_for"] = SimpleLoopForHigherOrderVariable
+
+
+ # @torch._functorch.utils.exposed_in("torch")
  def simple_loop_for(
      n_iter: Union[int, torch.Tensor],
      body_fn: Callable,
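
Note: a minimal usage sketch of the simple_loop_for HOP that the hunk above
wires into dynamo. It assumes the body receives the iteration counter followed
by the operands and returns a tensor (or tuple of tensors) whose per-iteration
results are concatenated along `concatenation_dims`; the exact calling
convention lives in onnx_diagnostic/export/cf_simple_loop_for.py::

    import torch
    from onnx_diagnostic.export.cf_simple_loop_for import simple_loop_for

    def body(i, x):
        # one (1, 4) slice per iteration, scaled by the loop counter
        return (x * i).sum(dim=0, keepdim=True)

    x = torch.arange(12, dtype=torch.float32).reshape(3, 4)
    # three iterations concatenated along axis 0 -> shape (3, 4)
    res = simple_loop_for(3, body, (x,), concatenation_dims=[0])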
@@ -1267,6 +1267,7 @@ class ExtTestCase(unittest.TestCase):
      :class:`onnx_diagnostic.helpers.ort_session.InferenceSessionForTorch`
      """
      from .helpers import string_type, string_diff, max_diff
+     from .helpers.torch_helper import torch_deepcopy
      from .helpers.rt_helper import make_feeds
      from .helpers.ort_session import InferenceSessionForTorch

@@ -1283,6 +1284,12 @@ class ExtTestCase(unittest.TestCase):
      model_file = proto
      name = proto
      proto = onnx.load(name)
+ elif hasattr(proto, "save"):
+     name = f"{test_name}.onnx"
+     proto.save(name)
+     proto = onnx.load(name)
+ elif hasattr(proto, "model_proto"):
+     proto = proto.model_proto
  elif not self.unit_test_going():
      assert isinstance(
          proto, onnx.ModelProto
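
Note: the two new elif branches accept exporter outputs by duck typing:
anything with a `save` method is serialized to disk and reloaded, and anything
exposing a `model_proto` attribute is unwrapped directly. A hypothetical
wrapper showing the shape the second branch accepts::

    import onnx

    class ExportResult:
        """Hypothetical container such as an exporter might return."""

        def __init__(self, model_proto: onnx.ModelProto):
            # picked up by the new `elif hasattr(proto, "model_proto")` branch
            self.model_proto = model_proto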
@@ -1341,9 +1348,9 @@ class ExtTestCase(unittest.TestCase):
  if copy_inputs:
      expected = [
          (
-             model(*copy.deepcopy(inp))
+             model(*torch_deepcopy(inp))
              if isinstance(inp, tuple)
-             else model(**copy.deepcopy(inp))
+             else model(**torch_deepcopy(inp))
          )
          for inp in inputs
      ]
@@ -20,7 +20,7 @@ class BenchmarkError(RuntimeError):


  def _clean_string(s: str) -> str:
-     cleaned = [c for c in s if 32 <= ord(c) < 127 and c not in {","}]
+     cleaned = [c for c in s if 32 <= ord(c) < 127 and c not in {",", ":"}]
      return "".join(cleaned)


@@ -1921,9 +1921,7 @@ class CubeLogsPerformance(CubeLogs):
          return lambdas[formula]

      if formula == "onnx_n_nodes_no_cst":
-         return lambda df: gdf(df, "onnx_n_nodes", 0) - gdf(
-             df, "op_onnx__Constant", 0
-         ).fillna(0)
+         return lambda df: gdf(df, "onnx_n_nodes", 0) - gdf(df, "op_onnx__Constant", 0)
      if formula == "peak_gpu_torch":
          return lambda df: gdf(df, "mema_gpu_5_after_export") - gdf(df, "mema_gpu_4_reset")
      if formula == "peak_gpu_nvidia":
@@ -0,0 +1,116 @@
+ from typing import Optional, Union
+ import pprint
+ import onnx
+
+
+ def optimize_model(
+     algorithm: str,
+     model: Union[onnx.ModelProto, str],
+     output: Optional[str] = None,
+     processor: Optional[str] = None,
+     infer_shapes: bool = True,
+     remove_shape_info: bool = False,
+     verbose: int = 1,
+ ):
+     """
+     Optimizes an onnx model by fusing nodes. It looks for patterns in the graphs
+     and replaces them by the corresponding nodes. It also does basic optimization
+     such as removing identity nodes or unused nodes.
+
+     :param algorithm: algorithm to choose
+     :param model: model to optimize as a proto or a filename
+     :param output: if not empty, the optimized model is saved
+     :param processor: optimization are done for the processor
+     :param infer_shapes: infer shapes before optimizing, this might not be
+         available for all algorithm
+     :param remove_shape_info: remove shape information before saving the model
+     :param verbose: verbosity level
+     :return: optimized model
+
+     The goal is to make the model faster.
+     Argument patterns defines the patterns to apply or the set of patterns.
+     It is possible to show statistics or to remove a particular pattern.
+     Here are some environment variables which can be used to trigger
+     these displays.
+
+     Available options algorithms, default and default+runtime:
+
+     - ``DROPPATTERN=<pattern1,patterns2,...>``: do not apply
+       those patterns when optimizing a model
+     - ``DUMPPATTERNS=<folder>``: dumps all matched and applied nodes when a pattern is applied
+     - ``PATTERN=<pattern1,pattern2,...>``: increase verbosity
+       for specific patterns to understand why one pattern was not applied,
+       this shows which line is rejecting a pattern if it seems one pattern was missed
+     """
+     if isinstance(model, str):
+         if verbose:
+             print(f"[optimize_model] load {model!r}")
+         proto = onnx.load(model)
+         if verbose:
+             print("[optimize_model] done loading.")
+     else:
+         proto = model
+
+     if verbose:
+         print(f"[optimize_model] optimize with {algorithm!r}")
+     if algorithm in {"default", "default+onnxruntime"}:
+         from experimental_experiment.xoptim import get_pattern_list
+         from experimental_experiment.xbuilder import GraphBuilder, OptimizationOptions
+
+         pats = get_pattern_list(algorithm)
+
+         gr = GraphBuilder(
+             proto,
+             infer_shapes_options=infer_shapes,
+             optimization_options=OptimizationOptions(
+                 patterns=pats,
+                 verbose=verbose,
+                 remove_unused=True,
+                 constant_folding=True,
+                 remove_identity=True,
+                 max_iter=max(100, len(proto.graph.node) // 2),
+                 processor=processor or "CPU",
+             ),
+         )
+         if verbose:
+             print(f"[optimize_model] starts optimizing with {len(pats)} patterns")
+             print(f"[optimize_model] model has {len(proto.graph.node)} nodes")
+         opt_onx, report = gr.to_onnx(optimize=True, return_optimize_report=True)
+         if verbose:
+             print("[optimize_model] optimization report")
+             pprint.pprint(report)
+             print("[optimize_model] done")
+
+     elif algorithm == "slim":
+         import onnxslim
+
+         opt_onx = onnxslim.slim(proto, no_shape_infer=not infer_shapes)
+     elif algorithm in {"ir", "os_ort"}:
+         import onnx_ir
+         import onnxscript.optimizer
+         from onnxscript.rewriter.ort_fusions import optimize_for_ort
+
+         model_ir = onnx_ir.from_proto(proto)
+         if algorithm == "ir":
+             onnxscript.optimizer.optimize(model_ir)
+         else:
+             optimize_for_ort(model_ir)
+         opt_onx = onnx_ir.serde.serialize_model(model_ir)
+
+     del proto
+     if verbose:
+         print(f"[optimize_model] done optimizing, model has {len(opt_onx.graph.node)} nodes")
+     if remove_shape_info:
+         if verbose:
+             print(f"[optimize_model] remove shape information {len(opt_onx.graph.value_info)}")
+         del opt_onx.graph.value_info[:]
+         if verbose:
+             print("[optimize_model] done removing shape info")
+
+     if output:
+         if verbose:
+             print(f"[optimize_model] save file into {output!r}")
+         onnx.save(opt_onx, output, save_as_external_data=True)
+         if verbose:
+             print("[optimize_model] done saving")
+     return opt_onx
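
Note: a minimal usage sketch for the new helper, following the signature above;
the file names are placeholders and the "default" algorithm needs the optional
experimental-experiment package::

    from onnx_diagnostic.helpers.optim_helper import optimize_model

    optimized = optimize_model(
        "default",                # pattern-based fusion via experimental-experiment
        "model.onnx",             # input model: a ModelProto or a filename
        output="model.opt.onnx",  # optimized model saved with external data
        processor="CPU",
        verbose=1,
    )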
@@ -13,6 +13,10 @@ from .data import get_data
  __TASK__ = "image-text-to-text"


+ def should_have_vision_config(config):
+     return config.architectures != ["FuyuForCausalLM"]
+
+
  def reduce_model_config(config: Any) -> Dict[str, Any]:
      """Reduces a model size."""
      kwargs: Dict[str, Any] = {}
@@ -477,7 +481,8 @@ def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]:
          "hidden_size",
          "pad_token_id",
      )
-     check_hasattr(config, "vision_config", ("image_token_index", "image_token_id"))
+     if should_have_vision_config(config):
+         check_hasattr(config, "vision_config", ("image_token_index", "image_token_id"))
      text_config = True
  else:
      check_hasattr(
@@ -491,7 +496,8 @@ def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]:
          "vision_config",
      )
      text_config = False
-     check_hasattr(config.vision_config, ("num_channels", "in_chans", "in_channels"))
+     if should_have_vision_config(config):
+         check_hasattr(config.vision_config, ("num_channels", "in_chans", "in_channels"))
  kwargs = dict(
      head_dim=(
          16
@@ -552,17 +558,21 @@ def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]:
      ),
      width=(
          224
-         if config is None or not hasattr(config.vision_config, "image_size")
+         if config is None
+         or not should_have_vision_config(config)
+         or not hasattr(config.vision_config, "image_size")
          else config.vision_config.image_size
      ),
      height=(
          224
-         if config is None or not hasattr(config.vision_config, "image_size")
+         if config is None
+         or not should_have_vision_config(config)
+         or not hasattr(config.vision_config, "image_size")
          else config.vision_config.image_size
      ),
      num_channels=(
          3
-         if config is None
+         if config is None or not should_have_vision_config(config)
          else _pick(config.vision_config, "num_channels", "in_chans", "in_channels")
      ),
      pad_token_id=(