onnx-diagnostic 0.6.0__py3-none-any.whl → 0.6.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (38) hide show
  1. onnx_diagnostic/__init__.py +1 -1
  2. onnx_diagnostic/_command_lines_parser.py +18 -0
  3. onnx_diagnostic/api.py +15 -0
  4. onnx_diagnostic/ext_test_case.py +3 -1
  5. onnx_diagnostic/helpers/args_helper.py +1 -1
  6. onnx_diagnostic/helpers/helper.py +6 -5
  7. onnx_diagnostic/helpers/model_builder_helper.py +24 -8
  8. onnx_diagnostic/helpers/rt_helper.py +5 -1
  9. onnx_diagnostic/helpers/torch_helper.py +2 -0
  10. onnx_diagnostic/reference/__init__.py +1 -0
  11. onnx_diagnostic/reference/torch_evaluator.py +518 -0
  12. onnx_diagnostic/reference/torch_ops/__init__.py +55 -0
  13. onnx_diagnostic/reference/torch_ops/_op_run.py +326 -0
  14. onnx_diagnostic/reference/torch_ops/access_ops.py +84 -0
  15. onnx_diagnostic/reference/torch_ops/binary_ops.py +108 -0
  16. onnx_diagnostic/reference/torch_ops/controlflow_ops.py +118 -0
  17. onnx_diagnostic/reference/torch_ops/generator_ops.py +35 -0
  18. onnx_diagnostic/reference/torch_ops/nn_ops.py +176 -0
  19. onnx_diagnostic/reference/torch_ops/other_ops.py +106 -0
  20. onnx_diagnostic/reference/torch_ops/reduce_ops.py +130 -0
  21. onnx_diagnostic/reference/torch_ops/sequence_ops.py +65 -0
  22. onnx_diagnostic/reference/torch_ops/shape_ops.py +120 -0
  23. onnx_diagnostic/reference/torch_ops/unary_ops.py +86 -0
  24. onnx_diagnostic/tasks/__init__.py +22 -1
  25. onnx_diagnostic/tasks/image_classification.py +2 -2
  26. onnx_diagnostic/tasks/text_generation.py +3 -3
  27. onnx_diagnostic/torch_export_patches/eval/__init__.py +106 -37
  28. onnx_diagnostic/torch_export_patches/eval/model_cases.py +12 -25
  29. onnx_diagnostic/torch_export_patches/patch_module_helper.py +130 -16
  30. onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +88 -0
  31. onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +142 -0
  32. onnx_diagnostic/torch_models/test_helper.py +115 -15
  33. onnx_diagnostic/torch_onnx/runtime_info.py +289 -0
  34. {onnx_diagnostic-0.6.0.dist-info → onnx_diagnostic-0.6.1.dist-info}/METADATA +1 -1
  35. {onnx_diagnostic-0.6.0.dist-info → onnx_diagnostic-0.6.1.dist-info}/RECORD +38 -23
  36. {onnx_diagnostic-0.6.0.dist-info → onnx_diagnostic-0.6.1.dist-info}/WHEEL +1 -1
  37. {onnx_diagnostic-0.6.0.dist-info → onnx_diagnostic-0.6.1.dist-info}/licenses/LICENSE.txt +0 -0
  38. {onnx_diagnostic-0.6.0.dist-info → onnx_diagnostic-0.6.1.dist-info}/top_level.txt +0 -0
@@ -1,5 +1,6 @@
1
1
  import inspect
2
2
  from dataclasses import dataclass
3
+ from functools import wraps
3
4
  from typing import Any, Callable, Dict, List, Optional, Tuple
4
5
  import torch
5
6
  import transformers
@@ -531,3 +532,90 @@ class patched_GenerationMixin:
531
532
  # 8. Remove unexpected `generate` inputs (TODO @joao: fix trainer and examples)
532
533
  model_inputs.pop("labels", None)
533
534
  return model_inputs
535
+
536
+
537
def patched_dynamic_rope_update(rope_forward):
    """
    patch:transformers.modeling_rope_utils.dynamic_rope_update
    """
    # NOTE(review): the docstring above is kept verbatim; its ``patch:`` prefix
    # looks machine-readable -- confirm before rewording it.
    #
    # Decorator for a rotary-embedding ``forward(self, x, position_ids)``.
    # The returned wrapper refreshes ``self.inv_freq`` when ``self.rope_type``
    # contains "dynamic" or equals "longrope", then delegates to
    # ``rope_forward``. Any other rope type goes straight to ``rope_forward``.

    def longrope_frequency_update(self, position_ids, device):
        # Select between the long and the original inverse frequencies with
        # ``torch.cond`` instead of a Python ``if`` on ``seq_len``.
        seq_len = torch.max(position_ids) + 1
        config = self.config
        threshold = (
            config.original_max_position_embeddings
            if hasattr(config, "original_max_position_embeddings")
            else config.max_position_embeddings
        )
        # At export time, seq_len is unknown: rope_init_fn is therefore called
        # with a fixed length one above the threshold rather than with seq_len.
        long_inv_freq, _ = self.rope_init_fn(config, device, seq_len=threshold + 1)
        original_inv_freq = self.original_inv_freq.to(device)

        def use_long(long_freq, original_freq):
            # Branch taken when seq_len exceeds the threshold.
            return long_freq.clone()

        def use_original(long_freq, original_freq):
            # Branch taken otherwise.
            return original_freq.clone()

        exceeds_threshold = (seq_len > threshold).item()
        self.inv_freq = torch.cond(
            exceeds_threshold,
            use_long,
            use_original,
            [long_inv_freq, original_inv_freq],
        )

    def dynamic_frequency_update(self, position_ids, device):
        # Recompute inv_freq when the sequence grows past the cached length,
        # and restore the original frequencies once it is short again.
        seq_len = torch.max(position_ids) + 1

        if seq_len > self.max_seq_len_cached:  # growth
            inv_freq, self.attention_scaling = self.rope_init_fn(
                self.config, device, seq_len=seq_len
            )
            self.register_buffer("inv_freq", inv_freq, persistent=False)
            self.max_seq_len_cached = seq_len

        if seq_len < self.original_max_seq_len:
            if self.max_seq_len_cached > self.original_max_seq_len:  # reset
                self.original_inv_freq = self.original_inv_freq.to(device)
                self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
                self.max_seq_len_cached = self.original_max_seq_len

    @wraps(rope_forward)
    def wrapper(self, x, position_ids):
        # Pick the refresh routine first; ``x.device`` is only read when one
        # of the two rope types actually needs an update.
        if "dynamic" in self.rope_type:
            refresh = dynamic_frequency_update
        elif self.rope_type == "longrope":
            refresh = longrope_frequency_update
        else:
            refresh = None

        if refresh is not None:
            refresh(self, position_ids, device=x.device)
        return rope_forward(self, x, position_ids)

    return wrapper
593
+
594
+
595
class patched_Phi3RotaryEmbedding(torch.nn.Module):
    # Carries a replacement ``forward`` for the transformers class referenced
    # by ``_PATCHED_CLASS_``.
    # NOTE(review): ``_PATCHES_`` / ``_PATCHED_CLASS_`` are presumably read by
    # the patching machinery defined elsewhere -- confirm.
    _PATCHES_ = ["forward"]
    _PATCHED_CLASS_ = transformers.models.phi3.modeling_phi3.Phi3RotaryEmbedding

    @torch.no_grad()
    @patched_dynamic_rope_update
    def forward(self, x, position_ids):
        """
        Computes the rotary embedding tables.

        :param x: tensor providing the target device and dtype
        :param position_ids: positions, indexed below as a 2D tensor
        :return: tuple ``(cos, sin)``, both cast to ``x.dtype``
        """
        # inv_freq is broadcast to one row per entry of position_ids' first axis.
        inv_freq = (
            self.inv_freq[None, :, None]
            .float()
            .expand(position_ids.shape[0], -1, 1)
            .to(x.device)
        )
        positions = position_ids[:, None, :].float()

        # Autocast is given "cpu" whenever the device type is "mps" or not a string.
        device_type = x.device.type
        if not isinstance(device_type, str) or device_type == "mps":
            device_type = "cpu"

        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
            angles = (inv_freq.float() @ positions.float()).transpose(1, 2)
            emb = torch.cat((angles, angles), dim=-1)
            cos = emb.cos() * self.attention_scaling
            sin = emb.sin() * self.attention_scaling

        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
@@ -3951,3 +3951,145 @@ def _ccached_facebook_bart_large_cnn():
3951
3951
  "vocab_size": 50264,
3952
3952
  }
3953
3953
  )
3954
+
3955
+
3956
def _ccached_microsoft_phi4_reasoning():
    "microsoft/Phi-4-mini-reasoning"
    # NOTE(review): the docstring above is left byte-identical; it may be used
    # as a lookup key by other code in this module -- confirm before editing.
    # NOTE(review): ``true`` / ``false`` are not Python literals; they are
    # expected to be names defined elsewhere in this module -- confirm.

    # The long and short rope factors are the same 48 values in this
    # configuration; each entry of ``rope_scaling`` receives its own copy.
    rope_factors = [
        1,
        1.118320672,
        1.250641126,
        1.398617824,
        1.564103225,
        1.74916897,
        1.956131817,
        2.187582649,
        2.446418898,
        2.735880826,
        3.059592084,
        3.421605075,
        3.826451687,
        4.279200023,
        4.785517845,
        5.351743533,
        5.984965424,
        6.693110555,
        7.485043894,
        8.370679318,
        9.36110372,
        10.4687158,
        11.70738129,
        13.09260651,
        14.64173252,
        16.37415215,
        18.31155283,
        20.47818807,
        22.90118105,
        25.61086418,
        28.64115884,
        32.03,
        32.1,
        32.13,
        32.23,
        32.6,
        32.61,
        32.64,
        32.66,
        32.7,
        32.71,
        32.93,
        32.97,
        33.28,
        33.49,
        33.5,
        44.16,
        47.77,
    ]
    rope_scaling = {
        "long_factor": list(rope_factors),
        "short_factor": list(rope_factors),
        "type": "longrope",
    }
    config_kwargs = {
        "architectures": ["Phi3ForCausalLM"],
        "attention_bias": false,
        "attention_dropout": 0.0,
        "bos_token_id": 199999,
        "embd_pdrop": 0.0,
        "eos_token_id": 199999,
        "full_attn_mod": 1,
        "hidden_act": "silu",
        "hidden_size": 3072,
        "initializer_range": 0.02,
        "intermediate_size": 8192,
        "interpolate_factor": 1,
        "lm_head_bias": false,
        "max_position_embeddings": 131072,
        "mlp_bias": false,
        "model_type": "phi3",
        "num_attention_heads": 24,
        "num_hidden_layers": 32,
        "num_key_value_heads": 8,
        "original_max_position_embeddings": 4096,
        "pad_token_id": 199999,
        "partial_rotary_factor": 0.75,
        "resid_pdrop": 0.0,
        "rms_norm_eps": 1e-05,
        "rope_scaling": rope_scaling,
        "rope_theta": 10000.0,
        "sliding_window": 262144,
        "tie_word_embeddings": true,
        "torch_dtype": "bfloat16",
        "transformers_version": "4.50.0",
        "use_cache": true,
        "vocab_size": 200064,
    }
    return transformers.Phi3Config(**config_kwargs)
@@ -4,6 +4,7 @@ import os
4
4
  import sys
5
5
  from typing import Any, Callable, Dict, List, Optional, Tuple, Union
6
6
  import time
7
+ import numpy as np
7
8
  import onnx
8
9
  import onnxscript
9
10
  import onnxscript.rewriter.ort_fusions as ort_fusions
@@ -17,6 +18,7 @@ from ..helpers.cache_helper import flatten_unflatten_for_dynamic_shapes
17
18
  from ..tasks import random_input_kwargs
18
19
  from ..torch_export_patches import torch_export_patches
19
20
  from ..torch_export_patches.patch_inputs import use_dyn_not_str
21
+ from ..reference import TorchOnnxEvaluator
20
22
  from .hghub import get_untrained_model_with_inputs
21
23
 
22
24
 
@@ -192,11 +194,16 @@ def _quiet_or_not_quiet(
192
194
  summary: Dict[str, Any],
193
195
  data: Optional[Dict[str, Any]],
194
196
  fct: Callable,
197
+ repeat: int = 1,
198
+ warmup: int = 0,
195
199
  ) -> Any:
196
200
  begin = time.perf_counter()
197
201
  if quiet:
198
202
  try:
199
- return fct()
203
+ res = fct()
204
+ summary[f"time_{suffix}"] = time.perf_counter() - begin
205
+ if warmup + repeat == 1:
206
+ return res
200
207
  except Exception as e:
201
208
  summary[f"ERR_{suffix}"] = str(e)
202
209
  summary[f"time_{suffix}"] = time.perf_counter() - begin
@@ -204,11 +211,45 @@ def _quiet_or_not_quiet(
204
211
  return {f"ERR_{suffix}": e}
205
212
  data[f"ERR_{suffix}"] = e
206
213
  return None
207
- res = fct()
214
+ else:
215
+ res = fct()
208
216
  summary[f"time_{suffix}"] = time.perf_counter() - begin
217
+ if warmup + repeat > 1:
218
+ if suffix == "run":
219
+ res = torch_deepcopy(res)
220
+ summary[f"{suffix}_output"] = string_type(res, with_shape=True, with_min_max=True)
221
+ summary[f"{suffix}_warmup"] = warmup
222
+ summary[f"{suffix}_repeat"] = repeat
223
+ for _w in range(max(0, warmup - 1)):
224
+ t = fct()
225
+ summary[f"io_{suffix}_{_w+1}"] = string_type(t, with_shape=True, with_min_max=True)
226
+ summary[f"time_{suffix}_warmup"] = time.perf_counter() - begin
227
+ times = []
228
+ for _r in range(repeat):
229
+ begin = time.perf_counter()
230
+ t = fct()
231
+ times.append(time.perf_counter() - begin)
232
+ a = np.array(times)
233
+ summary[f"time_{suffix}_latency"] = a.mean()
234
+ summary[f"time_{suffix}_latency_std"] = a.std()
235
+ summary[f"time_{suffix}_latency_min"] = a.min()
236
+ summary[f"time_{suffix}_latency_min"] = a.max()
209
237
  return res
210
238
 
211
239
 
240
def shrink_config(cfg: Dict[str, Any]) -> Dict[str, Any]:
    """
    Returns a copy of a configuration where large containers are replaced
    by a placeholder, so that the result stays short once logged.

    :param cfg: configuration as a dictionary
    :return: new dictionary with the same keys; values which are not a
        list, tuple, set or dict, or which hold fewer than 50 items, are
        kept as they are, any other list or tuple becomes an instance of
        the same class built from the string ``"..."`` (three one-character
        items), any other set or dict becomes the string ``"..."``
    """

    def _placeholder_if_large(value: Any) -> Any:
        if not isinstance(value, (list, tuple, set, dict)):
            return value
        if len(value) < 50:
            return value
        if isinstance(value, (list, tuple)):
            # Same container class, built from the characters of "...".
            return value.__class__("...")
        return "..."

    return {key: _placeholder_if_large(value) for key, value in cfg.items()}
251
+
252
+
212
253
  def validate_model(
213
254
  model_id: str,
214
255
  task: Optional[str] = None,
@@ -231,6 +272,9 @@ def validate_model(
231
272
  model_options: Optional[Dict[str, Any]] = None,
232
273
  subfolder: Optional[str] = None,
233
274
  opset: Optional[int] = None,
275
+ runtime: str = "onnxruntime",
276
+ repeat: int = 1,
277
+ warmup: int = 0,
234
278
  ) -> Tuple[Dict[str, Union[int, float, str]], Dict[str, Any]]:
235
279
  """
236
280
  Validates a model.
@@ -267,6 +311,10 @@ def validate_model(
267
311
  ``num_hidden_layers`` or ``attn_implementation``
268
312
  :param subfolder: version or subfolders to uses when retrieving a model id
269
313
  :param opset: onnx opset to use for the conversion
314
+ :param runtime: onnx runtime to use to check about discrepancies,
315
+ only if `do_run` is true
316
+ :param repeat: number of times to measure the model
317
+ :param warmup: warmup the model first
270
318
  :return: two dictionaries, one with some metrics,
271
319
  another one with whatever the function produces
272
320
 
@@ -295,6 +343,7 @@ def validate_model(
295
343
  version_ortfusiontype=ortfusiontype or "",
296
344
  version_stop_if_static=str(stop_if_static),
297
345
  version_exporter=exporter or "",
346
+ version_runtime=runtime,
298
347
  )
299
348
  )
300
349
  if opset:
@@ -436,7 +485,9 @@ def validate_model(
436
485
  if summary["model_module"] in sys.modules:
437
486
  summary["model_file"] = str(sys.modules[summary["model_module"]].__file__) # type: ignore[index]
438
487
  summary["model_config_class"] = data["configuration"].__class__.__name__
439
- summary["model_config"] = str(data["configuration"].to_dict()).replace(" ", "")
488
+ summary["model_config"] = str(shrink_config(data["configuration"].to_dict())).replace(
489
+ " ", ""
490
+ )
440
491
  summary["model_id"] = model_id
441
492
 
442
493
  if verbose:
@@ -460,7 +511,13 @@ def validate_model(
460
511
  model = data["model"]
461
512
 
462
513
  expected = _quiet_or_not_quiet(
463
- quiet, "run", summary, data, (lambda m=model, inp=inputs: m(**inp))
514
+ quiet,
515
+ "run",
516
+ summary,
517
+ data,
518
+ (lambda m=model, inp=inputs: m(**torch_deepcopy(inp))),
519
+ repeat=repeat,
520
+ warmup=warmup,
464
521
  )
465
522
  if "ERR_run" in summary:
466
523
  return summary, data
@@ -522,7 +579,7 @@ def validate_model(
522
579
 
523
580
  disc = max_diff(data["expected"], expected)
524
581
  for k, v in disc.items():
525
- summary[f"disc_patched_{k}"] = v
582
+ summary[f"disc_patched_{k}"] = str(v)
526
583
  if verbose:
527
584
  print("[validate_model] done (patched run)")
528
585
  print(f"[validate_model] patched discrepancies={string_diff(disc)}")
@@ -618,7 +675,14 @@ def validate_model(
618
675
  return summary, data
619
676
 
620
677
  if do_run:
621
- summary_valid, data = validate_onnx_model(data=data, quiet=quiet, verbose=verbose)
678
+ summary_valid, data = validate_onnx_model(
679
+ data=data,
680
+ quiet=quiet,
681
+ verbose=verbose,
682
+ runtime=runtime,
683
+ repeat=repeat,
684
+ warmup=warmup,
685
+ )
622
686
  summary.update(summary_valid)
623
687
 
624
688
  if ortfusiontype and "onnx_filename" in data:
@@ -671,7 +735,13 @@ def validate_model(
671
735
 
672
736
  if do_run:
673
737
  summary_valid, data = validate_onnx_model(
674
- data=data, quiet=quiet, verbose=verbose, flavour=flavour
738
+ data=data,
739
+ quiet=quiet,
740
+ verbose=verbose,
741
+ flavour=flavour,
742
+ runtime=runtime,
743
+ repeat=repeat,
744
+ warmup=warmup,
675
745
  )
676
746
  summary.update(summary_valid)
677
747
 
@@ -883,6 +953,9 @@ def validate_onnx_model(
883
953
  quiet: bool = False,
884
954
  verbose: int = 0,
885
955
  flavour: Optional[str] = None,
956
+ runtime: str = "onnxruntime",
957
+ repeat: int = 1,
958
+ warmup: int = 0,
886
959
  ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
887
960
  """
888
961
  Verifies that an onnx model produces the same
@@ -895,6 +968,9 @@ def validate_onnx_model(
895
968
  :param quiet: catch exception or not
896
969
  :param verbose: verbosity
897
970
  :param flavour: use a different version of the inputs
971
+ :param runtime: onnx runtime to use, onnxruntime or torch
972
+ :param repeat: number of times to run the model
973
+ :param warmup: warmup the model
898
974
  :return: two dictionaries, one with some metrics,
899
975
  another one with whatever the function produces
900
976
  """
@@ -936,18 +1012,28 @@ def validate_onnx_model(
936
1012
  f"{providers}..., flavour={flavour!r}"
937
1013
  )
938
1014
 
1015
+ cls_runtime = (
1016
+ (
1017
+ lambda model, providers: onnxruntime.InferenceSession(
1018
+ (model.SerializeToString() if isinstance(model, onnx.ModelProto) else model),
1019
+ providers=providers,
1020
+ )
1021
+ )
1022
+ if runtime == "onnxruntime"
1023
+ else (
1024
+ lambda model, providers: TorchOnnxEvaluator(
1025
+ model, providers=providers, verbose=max(verbose - 1, 0)
1026
+ )
1027
+ )
1028
+ )
939
1029
  sess = _quiet_or_not_quiet(
940
1030
  quiet,
941
- _mk("time_onnx_ort_create"),
1031
+ _mk("onnx_ort_create"),
942
1032
  summary,
943
1033
  data,
944
- (
945
- lambda source=source, providers=providers: onnxruntime.InferenceSession(
946
- source, providers=providers
947
- )
948
- ),
1034
+ (lambda source=source, providers=providers: cls_runtime(source, providers)),
949
1035
  )
950
- if f"ERR_{_mk('time_onnx_ort_create')}" in summary:
1036
+ if f"ERR_{_mk('onnx_ort_create')}" in summary:
951
1037
  return summary, data
952
1038
 
953
1039
  data[_mk("onnx_ort_sess")] = sess
@@ -975,6 +1061,8 @@ def validate_onnx_model(
975
1061
  summary,
976
1062
  data,
977
1063
  (lambda sess=sess, feeds=feeds: sess.run(None, feeds)),
1064
+ repeat=repeat,
1065
+ warmup=warmup,
978
1066
  )
979
1067
  if f"ERR_{_mk('time_onnx_ort_run')}" in summary:
980
1068
  return summary, data
@@ -1051,7 +1139,7 @@ def call_torch_export_onnx(
1051
1139
  dynamo=False,
1052
1140
  dynamic_axes={
1053
1141
  k: v
1054
- for k, v in CoupleInputsDynamicShapes(args, kwargs, ds)
1142
+ for k, v in CoupleInputsDynamicShapes(args, kwargs, ds) # type: ignore[arg-type]
1055
1143
  .replace_by_string()
1056
1144
  .items()
1057
1145
  if isinstance(v, dict)
@@ -1229,6 +1317,13 @@ def call_torch_export_custom(
1229
1317
  "custom-nostrict",
1230
1318
  "custom-nostrict-default",
1231
1319
  "custom-nostrict-all",
1320
+ "custom-inline",
1321
+ "custom-strict-inline",
1322
+ "custom-strict-default-inline",
1323
+ "custom-strict-all-inline",
1324
+ "custom-nostrict-inline",
1325
+ "custom-nostrict-default-inline",
1326
+ "custom-nostrict-all-inline",
1232
1327
  }
1233
1328
  assert exporter in available, f"Unexpected value for exporter={exporter!r} in {available}"
1234
1329
  assert "model" in data, f"model is missing from data: {sorted(data)}"
@@ -1269,6 +1364,10 @@ def call_torch_export_custom(
1269
1364
  ),
1270
1365
  save_ep=(os.path.join(dump_folder, f"{exporter}.ep") if dump_folder else None),
1271
1366
  )
1367
+ inline = "-inline" in exporter
1368
+ if inline:
1369
+ export_options.aten_as_function = set()
1370
+
1272
1371
  options = OptimizationOptions(patterns=optimization) if optimization else None
1273
1372
  model = data["model"]
1274
1373
  kws = dict(
@@ -1279,6 +1378,7 @@ def call_torch_export_custom(
1279
1378
  large_model=True,
1280
1379
  return_optimize_report=True,
1281
1380
  verbose=max(verbose - 2, 0),
1381
+ inline=inline,
1282
1382
  )
1283
1383
  if opset:
1284
1384
  kws["target_opset"] = opset