onnx-diagnostic 0.8.6__py3-none-any.whl → 0.8.8__py3-none-any.whl

Files changed (39)
  1. onnx_diagnostic/__init__.py +1 -1
  2. onnx_diagnostic/_command_lines_parser.py +108 -3
  3. onnx_diagnostic/ci_models/ci_helpers.py +12 -7
  4. onnx_diagnostic/ci_models/export_phi4_mm.py +1062 -0
  5. onnx_diagnostic/ci_models/export_qwen25_vl.py +12 -4
  6. onnx_diagnostic/export/api.py +295 -5
  7. onnx_diagnostic/export/cf_simple_loop_for.py +195 -10
  8. onnx_diagnostic/export/dynamic_shapes.py +45 -3
  9. onnx_diagnostic/export/shape_helper.py +1 -0
  10. onnx_diagnostic/ext_test_case.py +9 -2
  11. onnx_diagnostic/helpers/bench_run.py +1 -1
  12. onnx_diagnostic/helpers/cache_helper.py +0 -8
  13. onnx_diagnostic/helpers/fake_tensor_helper.py +26 -5
  14. onnx_diagnostic/helpers/helper.py +30 -1
  15. onnx_diagnostic/helpers/log_helper.py +1 -3
  16. onnx_diagnostic/helpers/optim_helper.py +116 -0
  17. onnx_diagnostic/helpers/ort_session.py +5 -0
  18. onnx_diagnostic/tasks/image_text_to_text.py +19 -9
  19. onnx_diagnostic/tasks/text2text_generation.py +84 -48
  20. onnx_diagnostic/tasks/text_generation.py +3 -0
  21. onnx_diagnostic/torch_export_patches/onnx_export_errors.py +28 -2
  22. onnx_diagnostic/torch_export_patches/patch_details.py +3 -3
  23. onnx_diagnostic/torch_export_patches/patch_expressions.py +4 -1
  24. onnx_diagnostic/torch_export_patches/patch_module.py +31 -23
  25. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_dynamic_cache.py +14 -5
  26. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_funnel.py +80 -0
  27. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2_5.py +12 -1
  28. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_rotary_embedding.py +2 -2
  29. onnx_diagnostic/torch_export_patches/patches/patch_torch.py +15 -0
  30. onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +22 -24
  31. onnx_diagnostic/torch_models/hghub/hub_api.py +11 -0
  32. onnx_diagnostic/torch_models/hghub/hub_data.py +9 -1
  33. onnx_diagnostic/torch_models/hghub/model_inputs.py +24 -19
  34. onnx_diagnostic/torch_models/validate.py +48 -0
  35. {onnx_diagnostic-0.8.6.dist-info → onnx_diagnostic-0.8.8.dist-info}/METADATA +3 -1
  36. {onnx_diagnostic-0.8.6.dist-info → onnx_diagnostic-0.8.8.dist-info}/RECORD +39 -36
  37. {onnx_diagnostic-0.8.6.dist-info → onnx_diagnostic-0.8.8.dist-info}/WHEEL +0 -0
  38. {onnx_diagnostic-0.8.6.dist-info → onnx_diagnostic-0.8.8.dist-info}/licenses/LICENSE.txt +0 -0
  39. {onnx_diagnostic-0.8.6.dist-info → onnx_diagnostic-0.8.8.dist-info}/top_level.txt +0 -0
@@ -8,10 +8,10 @@ Requirements
  ::
 
  git+https://github.com/sdpython/experimental-experiment.git # optional
- huggingface_hub>=1.2.1
+ huggingface_hub
  onnx-diagnostic>=0.8.6
  onnxruntime>=1.23
- torch>=2.9 # weekly is better
+ torch>=2.10 # weekly is better
  tqdm
  transformers>=4.57
 
@@ -59,6 +59,7 @@ It is possible to overwrite this by by setting environment variable
  import os
  import sys
  import time
+ import warnings
  from typing import Any, Dict, List, Tuple
  from .ci_helpers import (
  check_for_discrepancies_and_log_everything_into_a_json_file,
@@ -97,7 +98,6 @@ def get_untrained_model(model_id: str, second_input: bool, verbose: int) -> Dict
  },
  # "_attn_implementation": "flash_attention_2",
  "_attn_implementation": "sdpa",
- "dtype": "float16",
  }
 
  config_reduction = _config_reduction
@@ -281,6 +281,10 @@ def main(
  ).eval()
  data = dict(model=model)
  config = model.config
+ if not hasattr(config, "bos_token_id") or not config.bos_token_id:
+ config.bos_token_id = 151643
+ if not hasattr(config, "eos_token_id") or not config.eos_token_id:
+ config.eos_token_id = 151645
  else:
  print("-- random model")
  data = get_untrained_model(model_id, second_input=second_input, verbose=1)
@@ -298,7 +302,11 @@ def main(
  print(f"-- config._attn_implementation={model.config._attn_implementation}")
  print(f"-- model.dtype={model.dtype}")
  print(f"-- model.device={model.device}")
- processor = AutoProcessor.from_pretrained(model_id, use_fast=True)
+ try:
+ processor = AutoProcessor.from_pretrained(model_id, use_fast=True)
+ except OSError as e:
+ warnings.warn(f"Unable to access internet due to {e!r}", ResourceWarning, stacklevel=0)
+ return
  print(f"-- processor={type(processor)}")
 
  export_inputs, other_inputs = None, None
@@ -1,6 +1,11 @@
- from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
+ import inspect
+ import os
+ import textwrap
+ from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple, Union
  import torch
+ from .dynamic_shapes import ModelInputs
  from .onnx_plug import EagerDirectReplacementWithOnnx
+ from ..helpers import string_type
 
 
  def get_main_dispatcher(
@@ -70,6 +75,7 @@ to_onnx(
  inline: bool = True,
  ) -> Any:
  """
+ Exports one model into ONNX.
  Common API for exporters. By default, the models are optimized to use the
  most efficient kernels implemented in :epkg:`onnxruntime`.
 
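A minimal usage sketch for the entry point documented above, assuming the keyword
names visible elsewhere in this diff (``args``, ``dynamic_shapes``, ``filename``,
``exporter``); defaults and the exact return value depend on the selected exporter::

    import torch
    from onnx_diagnostic.export.api import to_onnx

    class TinyModel(torch.nn.Module):
        def forward(self, x):
            return x * 2 + 1

    # hypothetical call: exports TinyModel and writes the result to tiny.onnx
    to_onnx(
        TinyModel(),
        args=(torch.randn(2, 4),),
        dynamic_shapes=({0: "batch"},),
        filename="tiny.onnx",
        exporter="onnx-dynamo",
    )
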
@@ -126,8 +132,12 @@ to_onnx(
  from experimental_experiment.xbuilder import OptimizationOptions
 
  options = None
+ export_options = None
  if exporter_kwargs is not None:
  options = exporter_kwargs.pop("options", None)
+ export_options = exporter_kwargs.pop("export_options", None)
+ if export_options is None:
+ export_options = ExportOptions(save_ep=save_ep)
  if options is None and optimize:
  options = OptimizationOptions(
  patterns="default+onnxruntime" if optimizer_for_ort else "default"
@@ -138,7 +148,7 @@ to_onnx(
  else None
  )
 
- return _to_onnx(
+ proto, opt_stats = _to_onnx(
  mod,
  args=args,
  kwargs=kwargs,
@@ -150,15 +160,52 @@ to_onnx(
  dynamic_shapes=dynamic_shapes,
  large_model=True,
  output_dynamic_shapes=output_dynamic_shapes,
- export_options=ExportOptions(save_ep=save_ep),
+ export_options=export_options,
  options=options,
  inline=inline,
  dispatcher=main_dispatcher,
+ optimize=optimize,
+ return_optimize_report=True,
  **(exporter_kwargs or {}),
  )
+ if opt_stats and filename and os.path.exists(filename):
+ import pandas
+
+ stat_filename = f"{os.path.splitext(filename)[0]}.opt.xlsx"
+ pattern_stats = []
+ for k, v in opt_stats.items():
+ if "time" in k:
+ pattern_stats.append(dict(level="main", pattern=k, time_in=v))
+ pattern_stats.extend(
+ [{**obs, "level": "detailed"} for obs in opt_stats["optimization"]]
+ )
+ df = pandas.DataFrame(pattern_stats)
+ df.to_excel(stat_filename, index=False)
+ cols = [
+ c
+ for c in [
+ "level",
+ "pattern",
+ "time_in",
+ "iteration",
+ "inlined",
+ "removed",
+ "added",
+ "instances",
+ "changed",
+ "scale",
+ ]
+ if c in df.columns
+ ]
+ agg = {k: "sum" for k in cols if k not in ("level", "pattern")}
+ agg.update(dict(iteration="max", instances="mean"))
+ agg = {k: v for k, v in agg.items() if k in df.columns}
+ stat_filename = f"{os.path.splitext(filename)[0]}.opt.agg.xlsx"
+ df[cols].groupby(["level", "pattern"]).agg(agg).to_excel(stat_filename)
+
+ return proto
 
  if exporter in ("dynamo", "onnx-dynamo"):
- import os
  from ..helpers import flatten_object
  import onnxscript.rewriter.ort_fusions as ort_fusions
 
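The added block above writes the optimization report next to the exported model.
A small sketch of how the two workbooks could be read back, assuming the model was
exported with ``filename="model.onnx"`` (the names follow the ``.opt.xlsx`` and
``.opt.agg.xlsx`` pattern built above)::

    import pandas

    detailed = pandas.read_excel("model.opt.xlsx")        # one row per optimization pattern
    aggregated = pandas.read_excel("model.opt.agg.xlsx")  # grouped by level and pattern
    print(detailed.columns.tolist())
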
@@ -225,7 +272,6 @@ to_onnx(
  return epo
 
  if exporter == "modelbuilder":
- import os
  from ..helpers import flatten_object, string_type
  from ..helpers.model_builder_helper import create_model_builder, save_model_builder
 
@@ -266,3 +312,247 @@ to_onnx(
  return onx
 
  raise ValueError(f"Unknown exporter={exporter!r}")
+
+
+ class _WrapperToExportMethodToOnnx(torch.nn.Module):
+ """
+ Wraps an existing model in order to spy on inputs.
+ This is used by :func:`onnx_diagnostic.export.api.method_to_onnx`.
+ """
+
+ def __init__(
+ self,
+ mod: "torch.nn.Module",
+ method_name: str = "forward",
+ input_names: Optional[Sequence[str]] = None,
+ target_opset: Optional[Union[int, Dict[str, int]]] = None,
+ verbose: int = 0,
+ filename: Optional[str] = None,
+ output_names: Optional[List[str]] = None,
+ output_dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
+ exporter: str = "onnx-dynamo",
+ exporter_kwargs: Optional[Dict[str, Any]] = None,
+ save_ep: Optional[str] = None,
+ optimize: bool = True,
+ optimizer_for_ort: bool = True,
+ use_control_flow_dispatcher: bool = False,
+ onnx_plugs: Optional[List[EagerDirectReplacementWithOnnx]] = None,
+ inline: bool = True,
+ convert_after_n_calls: int = 2,
+ patch_kwargs: Optional[Dict[str, Any]] = None,
+ skip_kwargs_names: Optional[Set[str]] = None,
+ dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
+ ):
+ super().__init__()
+ self._model_to_call = mod
+ self._method_name = method_name
+ self._method_call = (
+ self._model_to_call.forward
+ if method_name == "forward"
+ else getattr(mod, method_name)
+ )
+ self._inputs: List[Tuple[Tuple[Any, ...], Dict[str, Any]]] = []
+ self._convert_after_n_calls = convert_after_n_calls
+ self._patch_kwargs = patch_kwargs
+ self._method_src = None
+ self.verbose = verbose
+ self.skip_kwargs_names = skip_kwargs_names
+ self.dynamic_shapes = dynamic_shapes
+ self._to_onnx_kwargs = dict(
+ input_names=input_names,
+ target_opset=target_opset,
+ verbose=verbose,
+ filename=filename,
+ output_names=output_names,
+ output_dynamic_shapes=output_dynamic_shapes,
+ exporter=exporter,
+ exporter_kwargs=exporter_kwargs,
+ save_ep=save_ep,
+ optimize=optimize,
+ optimizer_for_ort=optimizer_for_ort,
+ use_control_flow_dispatcher=use_control_flow_dispatcher,
+ onnx_plugs=onnx_plugs,
+ inline=inline,
+ )
+ self._export_done = False
+
+ def __str__(self) -> str:
+ return self.__repr__()
+
+ def __repr__(self) -> str:
+ return (
+ f"{self.__class__.__name__}({self._model_to_call.__class__.__name__}."
+ f"{self._method_name})"
+ )
+
+ def forward(self, *args, **kwargs):
+ if not self._export_done:
+ self._inputs.append(
+ (
+ args,
+ (
+ kwargs
+ if not kwargs or not self.skip_kwargs_names
+ else {
+ k: v for k, v in kwargs.items() if k not in self.skip_kwargs_names
+ }
+ ),
+ )
+ )
+ if self.verbose:
+ print(
+ f"[method_to_onnx] input[{len(self._inputs)-1}]: "
+ f"{string_type(self._inputs[-1], with_shape=True)}"
+ )
+ if len(self._inputs) >= self._convert_after_n_calls:
+ self._convert_method_to_onnx()
+ del self._inputs[:]
+ self._export_done = True
+ return self._method_call(*args, **kwargs)
+
+ def _convert_method_to_onnx(self):
+
+ def make_method(self):
+ inner_sig = inspect.signature(self._method_call)
+ params = [
+ p.replace(annotation=inspect._empty) for p in inner_sig.parameters.values()
+ ]
+ simple_sig = inspect.Signature(params, return_annotation=inspect._empty)
+ args = str(simple_sig)[1:-1]
+ calls_args = ", ".join(f"{p}={p}" for p in simple_sig.parameters)
+ src = textwrap.dedent(
+ f"""
+ def f(self, {args}):
+ return self._method_call({calls_args})
+ """
+ )
+ self._method_src = src
+ ns = {}
+ try:
+ exec(src, ns)
+ except NameError as e:
+ raise NameError(f"Unable to compile due to {e}\n{src}") from e
+ return ns["f"]
+
+ class WrapWithExactSignature(torch.nn.Module):
+ def __init__(self, parent):
+ super().__init__()
+ self._model_to_call = parent._model_to_call
+ self._method_call = parent._method_call
+
+ forward = make_method(self)
+
+ compiled_model = WrapWithExactSignature(self)
+
+ if self.dynamic_shapes is None:
+ mi = ModelInputs(compiled_model, self._inputs)
+ ds = mi.guess_dynamic_shapes()
+ if self.verbose:
+ print(f"[method_to_onnx] guess_dynamic_shapes={string_type(ds)}")
+ a, kw, nds = mi.move_to_kwargs(*self._inputs[-1], ds)
+ else:
+ a, kw = self._inputs[-1]
+ nds = [self.dynamic_shapes]
+ if self.verbose:
+ print(f"[method_to_onnx] export args={string_type(a, with_shape=True)}")
+ print(f"[method_to_onnx] export kwargs={string_type(kw, with_shape=True)}")
+ print(f"[method_to_onnx] dynamic_shapes={string_type(nds)}")
+ if self._patch_kwargs is None:
+ to_onnx(
+ compiled_model,
+ args=a,
+ kwargs=kw,
+ dynamic_shapes=nds[-1],
+ **self._to_onnx_kwargs,
+ )
+ return
+ from ..torch_export_patches import torch_export_patches
+
+ with torch_export_patches(**self._patch_kwargs):
+ to_onnx(
+ compiled_model,
+ args=a,
+ kwargs=kw,
+ dynamic_shapes=nds[-1],
+ **self._to_onnx_kwargs,
+ )
+
+
+ def method_to_onnx(
+ mod: "torch.nn.Module",
+ method_name: str = "forward",
+ input_names: Optional[Sequence[str]] = None,
+ target_opset: Optional[Union[int, Dict[str, int]]] = None,
+ verbose: int = 0,
+ filename: Optional[str] = None,
+ output_names: Optional[List[str]] = None,
+ output_dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
+ exporter: str = "onnx-dynamo",
+ exporter_kwargs: Optional[Dict[str, Any]] = None,
+ save_ep: Optional[str] = None,
+ optimize: bool = True,
+ optimizer_for_ort: bool = True,
+ use_control_flow_dispatcher: bool = False,
+ onnx_plugs: Optional[List[EagerDirectReplacementWithOnnx]] = None,
+ inline: bool = True,
+ convert_after_n_calls: int = 2,
+ patch_kwargs: Optional[Dict[str, Any]] = None,
+ skip_kwargs_names: Optional[Set[str]] = None,
+ dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
+ ) -> Callable:
+ """
+ Exports one method of a module into ONNX.
+ It returns a new callable which must be called by the user
+ at least twice with different values for the dynamic dimension
+ before triggering the conversion into ONNX.
+
+ :param mod: module holding the method to export into ONNX
+ :param method_name: name of the method to export, ``"forward"`` by default
+ :param input_names: input names for the onnx model (optional)
+ :param target_opset: opset to target, if not specified, each converter
+ keeps its default value
+ :param verbose: verbosity level
+ :param filename: output filename, mandatory, the onnx model is saved on disk
+ :param output_names: to change the output of the onnx model
+ :param output_dynamic_shapes: to overwrite the dynamic shapes names
+ :param exporter: exporter to use (``onnx-dynamo``, ``modelbuilder``, ``custom``)
+ :param exporter_kwargs: additional parameters sent to the exporter
+ :param save_ep: saves the exported program
+ :param optimize: optimizes the model
+ :param optimizer_for_ort: optimizes the model for onnxruntime
+ :param use_control_flow_dispatcher: use the dispatcher created to support
+ custom loops (see :func:`onnx_diagnostic.export.control_flow_onnx.loop_for_onnx`)
+ :param onnx_plugs: the code was modified to replace some parts with onnx translation
+ :param inline: inline local functions
+ :param convert_after_n_calls: converts the model after this number of calls.
+ :param patch_kwargs: patch arguments
+ :param skip_kwargs_names: use default values for these parameters, which are part of
+ the signature of the method to export
+ :param dynamic_shapes: dynamic shapes to use if the guessed ones are not right
+ :return: a callable wrapping the method; once enough calls were made,
+ the conversion is triggered and the onnx model is saved to ``filename``
+
+ See :ref:`l-plot-tiny-llm-export-method-generate` for an example.
+ """
+ wrapped_model = _WrapperToExportMethodToOnnx(
+ mod=mod,
+ method_name=method_name,
+ input_names=input_names,
+ target_opset=target_opset,
+ verbose=verbose,
+ filename=filename,
+ output_names=output_names,
+ output_dynamic_shapes=output_dynamic_shapes,
+ exporter=exporter,
+ exporter_kwargs=exporter_kwargs,
+ save_ep=save_ep,
+ optimize=optimize,
+ optimizer_for_ort=optimizer_for_ort,
+ use_control_flow_dispatcher=use_control_flow_dispatcher,
+ onnx_plugs=onnx_plugs,
+ inline=inline,
+ convert_after_n_calls=convert_after_n_calls,
+ patch_kwargs=patch_kwargs,
+ skip_kwargs_names=skip_kwargs_names,
+ dynamic_shapes=dynamic_shapes,
+ )
+ return wrapped_model
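
A usage sketch for ``method_to_onnx``: the wrapper records the inputs of every call and
triggers the export once ``convert_after_n_calls`` calls were made, so the calls should
differ on the dynamic dimension. The model and shapes below are illustrative only::

    import torch
    from onnx_diagnostic.export.api import method_to_onnx

    class Model(torch.nn.Module):
        def forward(self, x):
            return x.cumsum(dim=1)

    wrapped = method_to_onnx(Model(), filename="model.forward.onnx", convert_after_n_calls=2)
    wrapped(torch.randn(2, 8))   # first call, inputs are recorded
    wrapped(torch.randn(3, 16))  # second call, dynamic shapes are guessed and the export runs
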
@@ -11,6 +11,7 @@ from torch._higher_order_ops.utils import (
  unique_graph_id,
  validate_subgraph_args_types,
  )
+ import torch._dynamo.variables.higher_order_ops as hop
  from torch.fx.experimental.proxy_tensor import ProxyTorchDispatchMode, track_tensor_tree
  from torch.utils._python_dispatch import _get_current_dispatch_mode
 
@@ -97,14 +98,18 @@ def _simple_loop_for_fn(
  f"Unexpected number of results {len(r)} for function {body_fn}, "
  f"expected {len(res[-1])}"
  )
+ assert all(isinstance(t, torch.Tensor) for t in r), (
+ f"Unexpected type {[type(_) for _ in r]} for returned by function {body_fn}, "
+ f"it must be a tuple of Tensor or a Tensor."
+ )
  res.append(r)
  else:
  assert isinstance(r, torch.Tensor), (
- f"Unexpected type {r} for function {body_fn}, "
- f"it must be a tuple or a Tensor."
+ f"Unexpected type {type(r)} coming from function {body_fn}, "
+ f"it must be a tuple of Tensor or a Tensor."
  )
  assert not res or len(res[-1]) == 1, (
- f"Unexpected number of results {len(r)} for function {body_fn}, "
+ f"Unexpected number of results {len(r)} coming from function {body_fn}, "
  f"expected {len(res[-1])}"
  )
  res.append((r,))
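
The assertions above spell out the contract of the loop body: each iteration must return
a ``torch.Tensor`` or a tuple of tensors, and the number of results must stay the same
across iterations. A hedged sketch of a conforming body; the assumption that the body
receives the loop index followed by the operands is illustrative, not taken from this diff::

    import torch

    def body_fn(i, x):
        # returns a tuple of tensors and keeps the same arity at every iteration
        return (x * (i + 1),)
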
@@ -126,8 +131,6 @@ def _simple_loop_for_fn(
  )
 
 
- # from torch._functorch.utils import exposed_in
- # @exposed_in("torch")
  def _simple_loop_for(
  n_iter: Union[int, torch.Tensor],
  body_fn: Callable,
@@ -159,7 +162,7 @@ def _simple_loop_for(
 
  if torch.compiler.is_dynamo_compiling():
  return simple_loop_for_op(
- n_iter, body_fn, (n_iter, *operands), concatenation_dims=concatenation_dims
+ n_iter, body_fn, operands, concatenation_dims=concatenation_dims
  )
 
  if isinstance(n_iter, (bool, int, float)):
@@ -181,8 +184,10 @@
 
  with setup_compilation_env() as _backend:
  return _loop_for_op_wrapper(n_iter, body_fn, operands, concatenation_dims)
- # return torch.compile(_loop_for_op_wrapper, backend=backend, fullgraph=True)(
- # n_iter, body_fn, operands, concatenation_dims)
+ # This is needed to support function body using module weights or function body
+ # defined as a class method. This is yet to be implemented.
+ # cpl = torch.compile(_loop_for_op_wrapper, backend=_backend, fullgraph=True)
+ # return cpl(n_iter, body_fn, operands, concatenation_dims)
 
 
  def trace_simple_loop_for(
@@ -236,9 +241,15 @@ def loop_for_op_dense(n_iter, body_fn, operands, concatenation_dims=None):
  )
  mode = _get_current_dispatch_mode()
  assert mode is None, "Mode should never be enabled for CPU/CUDA key"
- return _simple_loop_for_fn(
- n_iter, body_fn, operands, concatenation_dims=concatenation_dims
+ is_fake = isinstance(n_iter, torch._subclasses.fake_tensor.FakeTensor)
+ res = _simple_loop_for_fn(n_iter, body_fn, operands, concatenation_dims=concatenation_dims)
+ assert is_fake or not any(
+ isinstance(r, torch._subclasses.fake_tensor.FakeTensor) for r in res
+ ), (
+ f"One result is a fake tensor but the inputs were not, type(n_iter)={type(n_iter)}, "
+ f"operands: {[type(_) for _ in operands]}, res: {[type(_) for _ in res]}"
  )
+ return res
 
 
  @simple_loop_for_op.py_impl(ProxyTorchDispatchMode)
@@ -267,6 +278,180 @@ simple_loop_for_op.fallthrough(torch._C.DispatchKey.AutogradCPU)
  simple_loop_for_op.fallthrough(torch._C.DispatchKey.AutogradCUDA)
 
 
+ class SimpleLoopForHigherOrderVariable(hop.TorchHigherOrderOperatorVariable):
+ """
+ Replicates the same pattern found for other higher order operators.
+ This enables recursive compilation and the use of modules inside a function.
+ """
+
+ _HOP_NAME = "simple_loop_for"
+ _ALLOW_FALLBACK_TO_EAGER = False
+ supports_input_mutation = False
+ supports_aliasing = False
+
+ def _call_function(
+ self,
+ tx: torch._dynamo.symbolic_convert.InstructionTranslator,
+ args: list[hop.VariableTracker],
+ kwargs: dict[str, hop.VariableTracker],
+ ) -> hop.VariableTracker:
+ """Main function."""
+ args, kwargs = hop.LazyVariableTracker.realize_all((args, kwargs))
+
+ for i, k in enumerate(["n_iter", "body_fn", "operands", "concatenated_dims"]):
+ if v := kwargs.pop(k, None):
+ assert i == len(args), "did not provide the right number of non-keyword args"
+ args.append(v)
+
+ if len(args) != 4 or kwargs:
+ hop.unimplemented(
+ gb_type="simple_loop_for: improper args/kwargs",
+ context=f"args: {args}, kwargs: {kwargs}",
+ explanation=f"torch.cond expects 4 positional arguments (got {len(args)}) "
+ f"and no keyword arguments (got {len(kwargs)})",
+ hints=[*hop.graph_break_hints.USER_ERROR],
+ )
+
+ # Specialize into one of the branches since pred is constant
+ n_iter, body_fn, operands, _concatenated_dims = args
+ assert type(n_iter) is not hop.ConstantVariable, (
+ f"n_iter is a {type(n_iter)}. When used simple_loop_for, "
+ f"it unrolls the loop. A SymInt should be used."
+ )
+
+ # predicate
+ if type(n_iter.realize()) not in (
+ hop.ConstantVariable,
+ hop.TensorVariable,
+ hop.SymNodeVariable,
+ ):
+ hop.unimplemented(
+ gb_type="simple_loop_for: improper predicate",
+ context=str(n_iter),
+ explanation=(
+ f"Expected `n_iter` to be an int or a integer "
+ f"tensor with a single item "
+ f"but got {str(type(n_iter))} with original python type "
+ f"{str(n_iter.python_type())}."
+ ),
+ hints=[*hop.graph_break_hints.USER_ERROR],
+ )
+
+ # operands
+ if not isinstance(operands, (hop.ListVariable, hop.TupleVariable)):
+ hop.unimplemented(
+ gb_type="simple_loop_for: improper operands",
+ context=str(operands),
+ explanation="Expected `operands` to be a list/tuple "
+ f"but got {operands.python_type()}.",
+ hints=[*hop.graph_break_hints.USER_ERROR],
+ )
+
+ operands_seq = operands.unpack_var_sequence(tx)
+ if not hop.only_consist_of(
+ operands, (hop.TensorVariable, hop.ConstantVariable, hop.SymNodeVariable)
+ ):
+ hop.unimplemented(
+ gb_type="simple_loop_for: improper operands contents",
+ context=str(operands),
+ explanation=(
+ "Expected `operands` to be a list/tuple of pytrees "
+ "that only consists of tensor leaves."
+ ),
+ hints=[*hop.graph_break_hints.USER_ERROR],
+ )
+
+ # branches
+ hop._check_supported_callable_arg(tx, body_fn, "body_fn")
+
+ def speculate_body():
+ (
+ (ret_val, ret_spec),
+ ret_graph,
+ ret_lifted_freevars,
+ ) = hop.speculate_subgraph(
+ tx,
+ args[1],
+ (args[0], *operands_seq),
+ {},
+ self._HOP_NAME,
+ source_target=self.value,
+ should_flatten_outputs=True,
+ # TODO - removing consts from control flow ops need more work
+ remove_consts_from_outputs=False,
+ supports_input_mutation=self.supports_input_mutation,
+ supports_aliasing=self.supports_aliasing,
+ )
+
+ # need to ensure we increase epoch so we don't memoize unbacked bindings
+ # across different subgraphs which can interfere with runtime assertion
+ # generation.
+ tx.fake_mode.epoch += 1
+
+ if not hop.only_consist_of(ret_val, (hop.TensorVariable, hop.ConstantVariable)):
+ hop.unimplemented(
+ gb_type="simple_loop_for: unsupported branch return type",
+ context=str(ret_val),
+ explanation=(
+ "Expected branches to return a possibly nested "
+ "pytree of tensors or constant ints."
+ ),
+ hints=[*hop.graph_break_hints.USER_ERROR],
+ )
+ for ret in ret_val.unpack_var_sequence(tx):
+ if ret.is_python_constant() and not isinstance(ret.as_python_constant(), int):
+ hop.unimplemented(
+ gb_type=(
+ "simple_loop_for: unsupported branch return type "
+ "(constant non-int)"
+ ),
+ context=str(ret_val),
+ explanation="Constants returned from branches must be ints.",
+ hints=[*hop.graph_break_hints.USER_ERROR],
+ )
+ return ret_val, ret_spec, ret_graph, ret_lifted_freevars
+
+ body_r, body_spec, body_graph, body_lifted_freevars = speculate_body()
+ body_nn_modules = dict(tx.output.nn_modules)
+
+ same_spec = body_spec.treespec.as_python_constant()
+ if same_spec is not NotImplemented and not same_spec:
+ hop.unimplemented(
+ gb_type="simple_loop_for: differing branch outputs",
+ context=(
+ f"body_spec: {body_spec.treespec}, false_spec: "
+ f"{body_spec.treespec}, same_spec: {same_spec}"
+ ),
+ explanation="Expected branches to return the same pytree structure.",
+ hints=[*hop.graph_break_hints.USER_ERROR],
+ )
+
+ body_name = tx.output.install_subgraph(
+ "loop_body", torch.fx.GraphModule(body_nn_modules, body_graph)
+ )
+ body_node = hop.make_attr(tx, body_name)
+ p_args = (
+ n_iter.as_proxy(),
+ body_node,
+ # We pick true_shared but it shouldn't matter
+ operands.as_proxy() + tuple(body_lifted_freevars.keys()),
+ )
+
+ return hop._call_function_and_unflatten_output(
+ tx,
+ simple_loop_for,
+ p_args,
+ {},
+ None,
+ body_spec,
+ body_r,
+ )
+
+
+ hop._hop_name_to_variable_class["simple_loop_for"] = SimpleLoopForHigherOrderVariable
+
+
+ # @torch._functorch.utils.exposed_in("torch")
  def simple_loop_for(
  n_iter: Union[int, torch.Tensor],
  body_fn: Callable,