onnx-diagnostic 0.8.5__py3-none-any.whl → 0.8.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (35)
  1. onnx_diagnostic/__init__.py +1 -1
  2. onnx_diagnostic/_command_lines_parser.py +154 -3
  3. onnx_diagnostic/ci_models/__init__.py +0 -0
  4. onnx_diagnostic/ci_models/ci_helpers.py +435 -0
  5. onnx_diagnostic/ci_models/export_phi4_mm.py +1062 -0
  6. onnx_diagnostic/ci_models/export_qwen25_vl.py +568 -0
  7. onnx_diagnostic/export/api.py +1 -0
  8. onnx_diagnostic/export/cf_simple_loop_for.py +537 -0
  9. onnx_diagnostic/export/control_flow_onnx.py +23 -17
  10. onnx_diagnostic/ext_test_case.py +23 -2
  11. onnx_diagnostic/helpers/bench_run.py +1 -1
  12. onnx_diagnostic/helpers/log_helper.py +1 -3
  13. onnx_diagnostic/helpers/optim_helper.py +116 -0
  14. onnx_diagnostic/tasks/image_text_to_text.py +15 -5
  15. onnx_diagnostic/tasks/text2text_generation.py +84 -48
  16. onnx_diagnostic/tasks/text_generation.py +3 -0
  17. onnx_diagnostic/torch_export_patches/onnx_export_errors.py +44 -2
  18. onnx_diagnostic/torch_export_patches/patch_expressions.py +4 -1
  19. onnx_diagnostic/torch_export_patches/patch_module.py +31 -23
  20. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_funnel.py +80 -0
  21. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2_5.py +86 -3
  22. onnx_diagnostic/torch_export_patches/patches/patch_torch.py +15 -0
  23. onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +23 -24
  24. onnx_diagnostic/torch_models/hghub/hub_api.py +11 -0
  25. onnx_diagnostic/torch_models/hghub/hub_data.py +9 -1
  26. onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +29 -8
  27. onnx_diagnostic/torch_models/hghub/model_inputs.py +24 -19
  28. onnx_diagnostic/torch_onnx/compare.py +357 -0
  29. {onnx_diagnostic-0.8.5.dist-info → onnx_diagnostic-0.8.7.dist-info}/METADATA +1 -1
  30. {onnx_diagnostic-0.8.5.dist-info → onnx_diagnostic-0.8.7.dist-info}/RECORD +33 -27
  31. onnx_diagnostic/export/control_flow.py +0 -214
  32. onnx_diagnostic/export/control_flow_research.py +0 -140
  33. {onnx_diagnostic-0.8.5.dist-info → onnx_diagnostic-0.8.7.dist-info}/WHEEL +0 -0
  34. {onnx_diagnostic-0.8.5.dist-info → onnx_diagnostic-0.8.7.dist-info}/licenses/LICENSE.txt +0 -0
  35. {onnx_diagnostic-0.8.5.dist-info → onnx_diagnostic-0.8.7.dist-info}/top_level.txt +0 -0
@@ -700,6 +700,19 @@ def requires_onnx(version: str, msg: str = "") -> Callable:
     return lambda x: x


+def requires_experimental_experiment(version: str, msg: str = "") -> Callable:
+    """Skips a unit test if :epkg:`experimental-experiment` is not recent enough."""
+    import packaging.version as pv
+    import experimental_experiment
+
+    if pv.Version(experimental_experiment.__version__) < pv.Version(version):
+        msg = (
+            f"experimental-experiment version {experimental_experiment.__version__} < {version}: {msg}"
+        )
+        return unittest.skip(msg)
+    return lambda x: x
+
+
 def requires_onnx_array_api(version: str, msg: str = "") -> Callable:
     """Skips a unit test if :epkg:`onnx-array-api` is not recent enough."""
     import packaging.version as pv
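A minimal usage sketch for the new decorator (the version string and test body are illustrative, not from the package):

import unittest

from onnx_diagnostic.ext_test_case import ExtTestCase, requires_experimental_experiment


class TestWithOptimizer(ExtTestCase):
    # skipped unless experimental-experiment >= 0.5.0 is installed
    @requires_experimental_experiment("0.5.0", "needs a recent pattern optimizer")
    def test_something(self):
        self.assertTrue(True)


if __name__ == "__main__":
    unittest.main()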
@@ -774,6 +787,7 @@ class ExtTestCase(unittest.TestCase):
     def setUpClass(cls):
         logger = logging.getLogger("onnxscript.optimizer.constant_folding")
         logger.setLevel(logging.ERROR)
+        warnings.filterwarnings("ignore", category=DeprecationWarning)
         unittest.TestCase.setUpClass()

     @classmethod
@@ -1253,6 +1267,7 @@ class ExtTestCase(unittest.TestCase):
     :class:`onnx_diagnostic.helpers.ort_session.InferenceSessionForTorch`
     """
     from .helpers import string_type, string_diff, max_diff
+    from .helpers.torch_helper import torch_deepcopy
     from .helpers.rt_helper import make_feeds
     from .helpers.ort_session import InferenceSessionForTorch

@@ -1269,6 +1284,12 @@ class ExtTestCase(unittest.TestCase):
             model_file = proto
             name = proto
             proto = onnx.load(name)
+        elif hasattr(proto, "save"):
+            name = f"{test_name}.onnx"
+            proto.save(name)
+            proto = onnx.load(name)
+        elif hasattr(proto, "model_proto"):
+            proto = proto.model_proto
         elif not self.unit_test_going():
             assert isinstance(
                 proto, onnx.ModelProto
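The new branches let the test helper accept exporter results directly, via duck typing on save/model_proto. A sketch with a hypothetical wrapper object (names are illustrative):

import onnx
from onnx import TensorProto, helper


class ExportResult:  # hypothetical stand-in for an exporter's result object
    def __init__(self, model_proto: onnx.ModelProto):
        self.model_proto = model_proto


node = helper.make_node("Identity", ["x"], ["y"])
graph = helper.make_graph(
    [node],
    "g",
    [helper.make_tensor_value_info("x", TensorProto.FLOAT, [1])],
    [helper.make_tensor_value_info("y", TensorProto.FLOAT, [1])],
)
wrapped = ExportResult(helper.make_model(graph))
# the helper now unwraps wrapped.model_proto before running the comparison
proto = wrapped.model_proto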
@@ -1327,9 +1348,9 @@ class ExtTestCase(unittest.TestCase):
         if copy_inputs:
             expected = [
                 (
-                    model(*copy.deepcopy(inp))
+                    model(*torch_deepcopy(inp))
                     if isinstance(inp, tuple)
-                    else model(**copy.deepcopy(inp))
+                    else model(**torch_deepcopy(inp))
                 )
                 for inp in inputs
            ]
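copy.deepcopy is replaced with the package's tensor-aware torch_deepcopy. A minimal sketch of the substitution (the inputs are illustrative):

import copy

import torch
from onnx_diagnostic.helpers.torch_helper import torch_deepcopy

inp = {"input_ids": torch.tensor([[1, 2, 3]])}
a = copy.deepcopy(inp)    # generic deep copy, the old behavior
b = torch_deepcopy(inp)   # tensor-aware deep copy, the new behavior
assert torch.equal(a["input_ids"], b["input_ids"])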
@@ -20,7 +20,7 @@ class BenchmarkError(RuntimeError):


 def _clean_string(s: str) -> str:
-    cleaned = [c for c in s if 32 <= ord(c) < 127 and c not in {","}]
+    cleaned = [c for c in s if 32 <= ord(c) < 127 and c not in {",", ":"}]
     return "".join(cleaned)


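The effect of the change, sketched with an illustrative string: colons are now stripped in addition to commas and characters outside the printable ASCII range.

from onnx_diagnostic.helpers.bench_run import _clean_string

# ',' and ':' are dropped; '\n' falls outside 32 <= ord(c) < 127
assert _clean_string("a,b:c\n") == "abc"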
@@ -1921,9 +1921,7 @@ class CubeLogsPerformance(CubeLogs):
         return lambdas[formula]

     if formula == "onnx_n_nodes_no_cst":
-        return lambda df: gdf(df, "onnx_n_nodes", 0) - gdf(
-            df, "op_onnx__Constant", 0
-        ).fillna(0)
+        return lambda df: gdf(df, "onnx_n_nodes", 0) - gdf(df, "op_onnx__Constant", 0)
     if formula == "peak_gpu_torch":
         return lambda df: gdf(df, "mema_gpu_5_after_export") - gdf(df, "mema_gpu_4_reset")
     if formula == "peak_gpu_nvidia":
@@ -0,0 +1,116 @@
+from typing import Optional, Union
+import pprint
+import onnx
+
+
+def optimize_model(
+    algorithm: str,
+    model: Union[onnx.ModelProto, str],
+    output: Optional[str] = None,
+    processor: Optional[str] = None,
+    infer_shapes: bool = True,
+    remove_shape_info: bool = False,
+    verbose: int = 1,
+):
+    """
+    Optimizes an onnx model by fusing nodes. It looks for patterns in the graph
+    and replaces them with the corresponding fused nodes. It also does basic
+    optimization such as removing identity nodes or unused nodes.
+
+    :param algorithm: algorithm to choose
+    :param model: model to optimize as a proto or a filename
+    :param output: if not empty, the optimized model is saved
+    :param processor: optimizations are done for this processor
+    :param infer_shapes: infer shapes before optimizing, this might not be
+        available for all algorithms
+    :param remove_shape_info: remove shape information before saving the model
+    :param verbose: verbosity level
+    :return: optimized model
+
+    The goal is to make the model faster.
+    The algorithm defines the set of patterns to apply.
+    It is possible to show statistics or to disable a particular pattern.
+    Here are some environment variables which can be used to trigger
+    these displays.
+
+    Available options for algorithms ``default`` and ``default+onnxruntime``:
+
+    - ``DROPPATTERN=<pattern1,pattern2,...>``: do not apply
+      those patterns when optimizing a model
+    - ``DUMPPATTERNS=<folder>``: dumps all matched and applied nodes when a pattern is applied
+    - ``PATTERN=<pattern1,pattern2,...>``: increases verbosity
+      for specific patterns to understand why one pattern was not applied,
+      this shows which line is rejecting a pattern if it seems one pattern was missed
+    """
+    if isinstance(model, str):
+        if verbose:
+            print(f"[optimize_model] load {model!r}")
+        proto = onnx.load(model)
+        if verbose:
+            print("[optimize_model] done loading.")
+    else:
+        proto = model
+
+    if verbose:
+        print(f"[optimize_model] optimize with {algorithm!r}")
+    if algorithm in {"default", "default+onnxruntime"}:
+        from experimental_experiment.xoptim import get_pattern_list
+        from experimental_experiment.xbuilder import GraphBuilder, OptimizationOptions
+
+        pats = get_pattern_list(algorithm)
+
+        gr = GraphBuilder(
+            proto,
+            infer_shapes_options=infer_shapes,
+            optimization_options=OptimizationOptions(
+                patterns=pats,
+                verbose=verbose,
+                remove_unused=True,
+                constant_folding=True,
+                remove_identity=True,
+                max_iter=max(100, len(proto.graph.node) // 2),
+                processor=processor or "CPU",
+            ),
+        )
+        if verbose:
+            print(f"[optimize_model] starts optimizing with {len(pats)} patterns")
+            print(f"[optimize_model] model has {len(proto.graph.node)} nodes")
+        opt_onx, report = gr.to_onnx(optimize=True, return_optimize_report=True)
+        if verbose:
+            print("[optimize_model] optimization report")
+            pprint.pprint(report)
+            print("[optimize_model] done")
+
+    elif algorithm == "slim":
+        import onnxslim
+
+        opt_onx = onnxslim.slim(proto, no_shape_infer=not infer_shapes)
+    elif algorithm in {"ir", "os_ort"}:
+        import onnx_ir
+        import onnxscript.optimizer
+        from onnxscript.rewriter.ort_fusions import optimize_for_ort
+
+        model_ir = onnx_ir.from_proto(proto)
+        if algorithm == "ir":
+            onnxscript.optimizer.optimize(model_ir)
+        else:
+            optimize_for_ort(model_ir)
+        opt_onx = onnx_ir.serde.serialize_model(model_ir)
+
+    del proto
+    if verbose:
+        print(f"[optimize_model] done optimizing, model has {len(opt_onx.graph.node)} nodes")
+    if remove_shape_info:
+        if verbose:
+            print(f"[optimize_model] remove shape information {len(opt_onx.graph.value_info)}")
+        del opt_onx.graph.value_info[:]
+        if verbose:
+            print("[optimize_model] done removing shape info")
+
+    if output:
+        if verbose:
+            print(f"[optimize_model] save file into {output!r}")
+        onnx.save(opt_onx, output, save_as_external_data=True)
+        if verbose:
+            print("[optimize_model] done saving")
+    return opt_onx
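A usage sketch for the new helper, assuming the optional dependency of the chosen algorithm is installed (onnxslim here) and using placeholder file names:

from onnx_diagnostic.helpers.optim_helper import optimize_model

# "slim" needs onnxslim; "default" and "default+onnxruntime" need
# experimental-experiment; "ir" and "os_ort" need onnxscript and onnx_ir.
optimized = optimize_model(
    "slim",
    "model.onnx",              # placeholder input path
    output="model.opt.onnx",   # optimized model saved with external data
    verbose=0,
)
print(f"{len(optimized.graph.node)} nodes after optimization")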
@@ -13,6 +13,10 @@ from .data import get_data
 __TASK__ = "image-text-to-text"


+def should_have_vision_config(config):
+    return config.architectures != ["FuyuForCausalLM"]
+
+
 def reduce_model_config(config: Any) -> Dict[str, Any]:
     """Reduces a model size."""
     kwargs: Dict[str, Any] = {}
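A quick sketch of the guard with illustrative config stubs: the function treats FuyuForCausalLM as the one architecture whose configuration carries no vision_config.

from types import SimpleNamespace

from onnx_diagnostic.tasks.image_text_to_text import should_have_vision_config

fuyu = SimpleNamespace(architectures=["FuyuForCausalLM"])
llava = SimpleNamespace(architectures=["LlavaForConditionalGeneration"])
assert not should_have_vision_config(fuyu)
assert should_have_vision_config(llava)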
@@ -477,7 +481,8 @@ def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]:
             "hidden_size",
             "pad_token_id",
         )
-        check_hasattr(config, "vision_config", ("image_token_index", "image_token_id"))
+        if should_have_vision_config(config):
+            check_hasattr(config, "vision_config", ("image_token_index", "image_token_id"))
         text_config = True
     else:
         check_hasattr(
@@ -491,7 +496,8 @@ def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]:
             "vision_config",
         )
         text_config = False
-        check_hasattr(config.vision_config, ("num_channels", "in_chans", "in_channels"))
+        if should_have_vision_config(config):
+            check_hasattr(config.vision_config, ("num_channels", "in_chans", "in_channels"))
     kwargs = dict(
         head_dim=(
             16
@@ -552,17 +558,21 @@ def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]:
         ),
         width=(
             224
-            if config is None or not hasattr(config.vision_config, "image_size")
+            if config is None
+            or not should_have_vision_config(config)
+            or not hasattr(config.vision_config, "image_size")
             else config.vision_config.image_size
         ),
         height=(
             224
-            if config is None or not hasattr(config.vision_config, "image_size")
+            if config is None
+            or not should_have_vision_config(config)
+            or not hasattr(config.vision_config, "image_size")
             else config.vision_config.image_size
         ),
         num_channels=(
             3
-            if config is None
+            if config is None or not should_have_vision_config(config)
             else _pick(config.vision_config, "num_channels", "in_chans", "in_channels")
         ),
         pad_token_id=(
@@ -18,6 +18,22 @@ def reduce_model_config(config: Any) -> Dict[str, Any]:
         config.num_decoder_layers = min(config.num_decoder_layers, 2)
     if hasattr(config, "num_hidden_layers"):
         config.num_hidden_layers = min(config.num_hidden_layers, nhl())
+    if hasattr(config, "encoder") and hasattr(config.encoder, "layer_types"):
+        default_layer_types = [
+            "sliding_attention",
+            "full_attention",
+            "sliding_attention",
+            "full_attention",
+        ]
+        config.encoder.num_hidden_layers = 4
+        config.encoder.layer_types = (
+            default_layer_types if config is None else config.encoder.layer_types[:4]
+        )
+        config.decoder.num_hidden_layers = 4
+        config.decoder.layer_types = (
+            default_layer_types if config is None else config.decoder.layer_types[:4]
+        )
+
     update_config(config, kwargs)
     return kwargs

@@ -177,55 +193,75 @@ def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]:

     If the configuration is None, the function selects typical dimensions.
     """
+    path = 1
     if config is not None:
-        check_hasattr(
-            config,
-            "vocab_size",
-            "hidden_size",
-            "num_attention_heads",
-            ("num_hidden_layers", "num_layers"),
-            ("n_positions", "d_model"),
-            (
-                "num_key_value_heads",
-                "num_heads",
-                ("decoder_attention_heads", "encoder_attention_heads"),
-            ),
-        )
-    # exceptions = {
-    #     "PLBartForConditionalGeneration": (
-    #         lambda c: c.encoder_attention_heads + c.decoder_attention_heads
-    #     )
-    # }
-    kwargs = dict(
-        batch_size=2,
-        sequence_length=30,
-        sequence_length2=3,
-        head_dim_encoder=16 if config is None else _pick(config, "d_kv", "encoder_ffn_dim"),
-        head_dim_decoder=16 if config is None else _pick(config, "d_kv", "decoder_ffn_dim"),
-        dummy_max_token_id=31999 if config is None else config.vocab_size - 1,
-        num_hidden_layers=(
-            8 if config is None else _pick(config, "num_hidden_layers", "num_layers")
-        ),
-        num_key_value_heads_encoder=(
-            16
-            if config is None
-            else _pick(
-                config,
-                "encoder_attention_heads",
-                "num_key_value_heads",
-                "num_heads",
-            )
-        ),
-        num_key_value_heads_decoder=(
-            16
-            if config is None
-            else _pick(
-                config,
-                "decoder_attention_heads",
-                "num_key_value_heads",
-                "num_heads",
-            )
-        ),
-        encoder_dim=512 if config is None else _pick(config, "n_positions", "d_model"),
-    )
+        if hasattr(config, "num_attention_heads"):
+            check_hasattr(
+                config,
+                "vocab_size",
+                "hidden_size",
+                "num_attention_heads",
+                ("num_hidden_layers", "num_layers"),
+                ("n_positions", "d_model"),
+                (
+                    "num_key_value_heads",
+                    "num_heads",
+                    ("decoder_attention_heads", "encoder_attention_heads"),
+                ),
+            )
+        else:
+            check_hasattr(config, "encoder", "decoder")
+            path = 2
+
+    if path == 1:
+        kwargs = dict(
+            batch_size=2,
+            sequence_length=30,
+            sequence_length2=3,
+            head_dim_encoder=(
+                16 if config is None else _pick(config, "d_kv", "encoder_ffn_dim")
+            ),
+            head_dim_decoder=(
+                16 if config is None else _pick(config, "d_kv", "decoder_ffn_dim")
+            ),
+            dummy_max_token_id=31999 if config is None else config.vocab_size - 1,
+            num_hidden_layers=(
+                8 if config is None else _pick(config, "num_hidden_layers", "num_layers")
+            ),
+            num_key_value_heads_encoder=(
+                16
+                if config is None
+                else _pick(
+                    config,
+                    "encoder_attention_heads",
+                    "num_key_value_heads",
+                    "num_heads",
+                )
+            ),
+            num_key_value_heads_decoder=(
+                16
+                if config is None
+                else _pick(
+                    config,
+                    "decoder_attention_heads",
+                    "num_key_value_heads",
+                    "num_heads",
+                )
+            ),
+            encoder_dim=512 if config is None else _pick(config, "n_positions", "d_model"),
+        )
+    else:
+        kwargs = dict(
+            batch_size=2,
+            sequence_length=30,
+            sequence_length2=3,
+            dummy_max_token_id=config.encoder.vocab_size - 1,
+            num_key_value_heads_encoder=config.encoder.num_key_value_heads,
+            num_key_value_heads_decoder=config.decoder.num_key_value_heads,
+            num_hidden_layers=len(config.encoder.layer_types),
+            head_dim_encoder=config.encoder.head_dim,
+            head_dim_decoder=config.decoder.head_dim,
+            encoder_dim=256,
+        )
+
     return kwargs, get_inputs
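Path 2 covers composite configurations that expose nested encoder and decoder configs instead of flat attributes. A sketch with an illustrative stub (a real transformers config would be used in practice):

from types import SimpleNamespace

from onnx_diagnostic.tasks.text2text_generation import random_input_kwargs

sub = SimpleNamespace(
    vocab_size=32000,
    num_key_value_heads=4,
    layer_types=["sliding_attention", "full_attention"] * 2,
    head_dim=64,
)
config = SimpleNamespace(encoder=sub, decoder=sub)
# no flat num_attention_heads -> check_hasattr(config, "encoder", "decoder"), path 2
kwargs, get_inputs = random_input_kwargs(config)
assert kwargs["num_hidden_layers"] == 4  # len(config.encoder.layer_types)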
@@ -40,6 +40,9 @@ def reduce_model_config(config: Any) -> Dict[str, Any]:
             state_size=8 if config is None else getattr(config, "state_size", None),
             conv_kernel=4 if config is None else getattr(config, "conv_kernel", None),
         )
+    elif config.__class__.__name__ == "FunnelConfig":
+        # does not support num_hidden_layers
+        kwargs = dict()
     else:
         kwargs = dict(
             head_dim=getattr(
@@ -221,6 +221,7 @@ def _patch_torch(
     catch_constraints: bool,
     stop_if_static: int,
 ) -> Tuple[Optional[Callable], ...]:
+    import packaging.version as pv
     import torch
     import torch.jit
     import torch._export.non_strict_utils  # produce_guards_and_solve_constraints
@@ -238,6 +239,11 @@ def _patch_torch(
         patched_ShapeEnv,
     )

+    if pv.Version(torch.__version__) >= pv.Version("2.9.99"):
+        from .patches.patch_torch import patched_DynamicDimConstraintPrinter
+    else:
+        patched_DynamicDimConstraintPrinter = None
+
     f___constrain_user_specified_dimhint_range = None
     f__broadcast_in_dim_meta = None
     f__broadcast_shapes = None
@@ -259,6 +265,17 @@ def _patch_torch(
         print(f"[torch_export_patches] stop_if_static={stop_if_static!r}")
         print("[torch_export_patches] patch pytorch")

+    # torch.fx.experimental.symbolic_shapes.DynamicDimConstraintPrinter._print_Symbol
+    if patched_DynamicDimConstraintPrinter is not None:
+        f__print_symbol = (
+            torch.fx.experimental.symbolic_shapes.DynamicDimConstraintPrinter._print_Symbol
+        )
+        torch.fx.experimental.symbolic_shapes.DynamicDimConstraintPrinter._print_Symbol = (
+            patched_DynamicDimConstraintPrinter._print_Symbol
+        )
+    else:
+        f__print_symbol = None
+
     # torch.vmap
     f_vmap = torch.vmap
     torch.vmap = patched_vmap
@@ -392,6 +409,7 @@ def _patch_torch(
         f_shape_env__log_guard,
         f_shape_env__set_replacement,
         f_vmap,
+        f__print_symbol,
     )

@@ -416,6 +434,7 @@ def _unpatch_torch(
     f_shape_env__log_guard: Optional[Callable],
     f_shape_env__set_replacement: Optional[Callable],
     f_vmap: Optional[Callable],
+    f__print_symbol: Optional[Callable],
 ):
     import torch
     import torch.jit
@@ -423,6 +442,10 @@ def _unpatch_torch(
     from torch.fx.experimental.symbolic_shapes import ShapeEnv

     # this should disappear when torch.jit is removed
+    if f__print_symbol is not None:
+        torch.fx.experimental.symbolic_shapes.DynamicDimConstraintPrinter._print_Symbol = (
+            f__print_symbol
+        )
     torch.vmap = f_vmap
     torch.jit.isinstance = f_jit_isinstance
     torch._dynamo.mark_static_address = f_mark_static_address
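Both _patch_torch and _unpatch_torch follow the same save-patch-restore convention as the existing f_vmap handling. A minimal, torch-free sketch of that pattern:

class Target:
    def method(self):
        return "original"


def patched(self):
    return "patched"


saved = Target.method      # save the original, as f__print_symbol does
Target.method = patched    # install the patch
try:
    assert Target().method() == "patched"
finally:
    Target.method = saved  # restore on exit, as _unpatch_torch does
assert Target().method() == "original"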
@@ -818,6 +841,7 @@ def torch_export_patches(
     rewrite: Optional[List[Callable]] = None,
     dump_rewriting: Optional[str] = None,
     patch_details: Optional[PatchDetails] = None,
+    profile: Optional[str] = None,
 ) -> Callable:
     """
     Tries to bypass some situations :func:`torch.export.export` does not support.
@@ -847,9 +871,12 @@ def torch_export_patches(
         this is done by function :func:`transform_method
         <onnx_diagnostic.torch_export_patches.patch_module.transform_method>`,
         its documentation provides possible values
-    :param dump_rewriting: dumps rewriting information in file beginning with that prefix
-    :param patch_details: if specified, this class is used to stored every rewritten done.
+    :param dump_rewriting: dumps rewriting information in files beginning with that prefix,
+        this only applies to the automated rewritings
+    :param patch_details: if specified, this class is used to store every applied rewriting.
     :param verbose: to show which patches are applied
+    :param profile: starts profiling whatever is called inside the context manager,
+        writes the profiling report (pyinstrument HTML) into that file

     The list of available patches.

@@ -989,6 +1016,7 @@ def torch_export_patches(
             f_shape_env__log_guard,
             f_shape_env__set_replacement,
             f_vmap,
+            f__print_Symbol,
         ) = _patch_torch(
             verbose, patch_details, patch_torch, catch_constraints, stop_if_static
         )
@@ -1017,10 +1045,23 @@ def torch_export_patches(
     if verbose:
         print("[torch_export_patches] done patching")

+    if profile:
+        from pyinstrument import Profiler
+
+        profiler = Profiler()
+        profiler.start()
+    else:
+        profiler = None
+
     try:
         yield fct_callable
     finally:
+        if profiler:
+            profiler.stop()
+            with open(profile, "w") as f:
+                f.write(profiler.output_html())
+
         # unpatch

         if verbose:
@@ -1051,6 +1092,7 @@ def torch_export_patches(
             f_shape_env__log_guard,
             f_shape_env__set_replacement,
             f_vmap,
+            f__print_Symbol,
         )

         if patch_transformers:
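A usage sketch for the new profile argument (requires pyinstrument; the model and file name are placeholders):

import torch

from onnx_diagnostic.torch_export_patches import torch_export_patches

model = torch.nn.Linear(4, 4)  # placeholder model
inputs = (torch.randn(2, 4),)
# on exit, the pyinstrument HTML report is written to the given file
with torch_export_patches(profile="export_profile.html"):
    ep = torch.export.export(model, inputs)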
@@ -101,7 +101,10 @@ def patched_selector(fct: Callable, patched_fct: Callable) -> Callable:


 def patched_float_arange(start, end, step):
-    """Patched arange when start, end, step are floats."""
+    """
+    Patched arange when start, end, step are floats.
+    This patch should not be needed after torch 2.10.
+    """
     if is_torchdynamo_exporting():
         return torch.ops.patched.float_arange(start, end, step)
     else:
@@ -596,33 +596,41 @@ class RewriteControlFlow(ast.NodeTransformer):
                 elts=[
                     *[
                         ast.Call(
-                            ast.Attribute(
-                                value=ast.Name(id="torch", ctx=ast.Load()),
-                                attr="arange",
-                                ctx=ast.Load(),
-                            ),
-                            args=[
-                                ast.Subscript(
-                                    value=ast.Attribute(
-                                        value=ast.Name(id=v, ctx=ast.Load()),
-                                        attr="shape",
-                                        ctx=ast.Load(),
-                                    ),
-                                    slice=ast.Constant(value=0, ctx=ast.Load()),
-                                    ctx=ast.Load(),
-                                ),
-                            ],
-                            keywords=[
-                                ast.keyword(
-                                    arg="dtype",
-                                    value=ast.Attribute(
-                                        value=ast.Name(id="torch", ctx=ast.Load()),
-                                        attr="int64",
-                                        ctx=ast.Load(),
-                                    ),
-                                )
-                            ],
-                            ctx=ast.Load(),
+                            func=ast.Attribute(
+                                value=ast.Call(
+                                    ast.Attribute(
+                                        value=ast.Name(id="torch", ctx=ast.Load()),
+                                        attr="arange",
+                                        ctx=ast.Load(),
+                                    ),
+                                    args=[
+                                        ast.Subscript(
+                                            value=ast.Attribute(
+                                                value=ast.Name(id=v, ctx=ast.Load()),
+                                                attr="shape",
+                                                ctx=ast.Load(),
+                                            ),
+                                            slice=ast.Constant(value=0, ctx=ast.Load()),
+                                            ctx=ast.Load(),
+                                        ),
+                                    ],
+                                    keywords=[
+                                        ast.keyword(
+                                            arg="dtype",
+                                            value=ast.Attribute(
+                                                value=ast.Name(id="torch", ctx=ast.Load()),
+                                                attr="int64",
+                                                ctx=ast.Load(),
+                                            ),
+                                        )
+                                    ],
+                                    ctx=ast.Load(),
+                                ),
+                                attr="unsqueeze",
+                                ctx=ast.Load(),
+                            ),
+                            args=[ast.Constant(value=1)],
+                            keywords=[],
                         )
                         for v in scan_shape_vars
                     ],
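The net effect of the rewrite: each shape-derived index range is now turned into a column vector. A sketch of the expression the transformer now emits for a variable v (name illustrative), checked with the ast module:

import ast

# equivalent source for the new nested ast.Call built above
src = "torch.arange(v.shape[0], dtype=torch.int64).unsqueeze(1)"
tree = ast.parse(src, mode="eval")
print(ast.dump(tree.body, indent=2))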