onnx-diagnostic 0.8.5__py3-none-any.whl → 0.8.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,352 @@
1
+ import contextlib
2
+ from typing import Callable, List, Optional, Sequence, Tuple, Union
3
+ import torch
4
+ from torch._C import DispatchKey
5
+ from torch._ops import HigherOrderOperator
6
+ from torch._subclasses.fake_tensor import FakeTensorMode
7
+ import torch.utils._pytree as pytree
8
+ from torch._higher_order_ops.utils import (
9
+ check_input_alias_and_mutation_return_outputs,
10
+ reenter_make_fx,
11
+ unique_graph_id,
12
+ validate_subgraph_args_types,
13
+ )
14
+ from torch.fx.experimental.proxy_tensor import ProxyTorchDispatchMode, track_tensor_tree
15
+ from torch.utils._python_dispatch import _get_current_dispatch_mode
16
+
17
+
18
+ class SimpleLoopForOp(HigherOrderOperator):
19
+ """Higher order op for :func:`simple_loop_for`."""
20
+
21
+ def __init__(self):
22
+ super().__init__("simple_loop_for")
23
+
24
+ def __call__(self, n_iter, body_fn, operands, concatenation_dims=None):
25
+ validate_subgraph_args_types(operands)
26
+ return super().__call__(n_iter, body_fn, operands, concatenation_dims)
27
+
28
+ def gen_schema(self, n_iter, body_fn, operands, concatenation_dims):
29
+ from torch._higher_order_ops.schema import HopSchemaGenerator
30
+ from torch._higher_order_ops.utils import materialize_as_graph
31
+
32
+ body_gm: torch.fx.GraphModule = materialize_as_graph( # type: ignore[annotation-unchecked]
33
+ body_fn, (torch.tensor(0, dtype=torch.int64), *operands)
34
+ )
35
+ (
36
+ _,
37
+ _,
38
+ _,
39
+ body_mutated_inputs,
40
+ body_outputs,
41
+ ) = check_input_alias_and_mutation_return_outputs(body_gm)
42
+ mutated_inputs = body_mutated_inputs
43
+
44
+ schema_gen = HopSchemaGenerator(self)
45
+ schema_gen.add_arg("n_iter", n_iter)
46
+ schema_gen.add_arg("body_fn", body_gm)
47
+ for idx, arg in enumerate(operands):
48
+ schema_gen.add_arg(f"operand{idx}", arg, is_mutated=idx in mutated_inputs)
49
+
50
+ for out in body_outputs:
51
+ schema_gen.add_output(out)
52
+ assert concatenation_dims is None or len(concatenation_dims) == len(body_outputs), (
53
+ f"concatenation_dims={concatenation_dims} but its length should be equal to "
54
+ f"the number of outputs ({len(body_outputs)})"
55
+ )
56
+ schema_gen.add_schema_tree_spec(n_iter, body_fn, operands, concatenation_dims)
57
+ return schema_gen.gen_schema()
58
+
59
+
60
+ simple_loop_for_op = SimpleLoopForOp()
61
+
62
+
63
+ def _simple_loop_for_fn(
64
+ n_iter: torch.Tensor,
65
+ body_fn: Callable,
66
+ operands: Tuple[torch.Tensor, ...] = (),
67
+ concatenation_dims: Optional[Sequence[int]] = None,
68
+ ) -> Tuple[torch.Tensor, ...]:
69
+ """
70
+ Python implementation of the loop.
71
+
72
+ :param n_iter: number of iterations
73
+ :param body_fn: function implementing the body
74
+ :param operands: arguments to the loop body
75
+ :param concatenation_dims: dimensions used to concatenate the lists produced by the loop
76
+ :return: results
77
+ """
78
+ torch._check(
79
+ isinstance(n_iter, (int, torch.Tensor)),
80
+ lambda: f"Unexpected type {type(n_iter)} for n_iter",
81
+ )
82
+ torch._check(callable(body_fn), lambda: f"Unexpected type {type(body_fn)} for body_fn")
83
+ torch._check(
84
+ concatenation_dims is None or isinstance(concatenation_dims, (list, tuple)),
85
+ lambda: f"Unexpected type {type(concatenation_dims)} for concatenation_dims",
86
+ )
87
+ torch._check(
88
+ isinstance(operands, tuple), lambda: f"Unexpected type {type(operands)} for operands"
89
+ )
90
+ res: List[Union[torch.Tensor, Tuple[torch.Tensor, ...]]] = []
91
+ for i in torch.arange(
92
+ n_iter, dtype=torch.int64 if isinstance(n_iter, int) else n_iter.dtype
93
+ ):
94
+ r = body_fn(i, *operands)
95
+ if isinstance(r, tuple):
96
+ assert not res or len(r) == len(res[-1]), (
97
+ f"Unexpected number of results {len(r)} for function {body_fn}, "
98
+ f"expected {len(res[-1])}"
99
+ )
100
+ res.append(r)
101
+ else:
102
+ assert isinstance(r, torch.Tensor), (
103
+ f"Unexpected type {r} for function {body_fn}, "
104
+ f"it must be a tuple or a Tensor."
105
+ )
106
+ assert not res or len(res[-1]) == 1, (
107
+ f"Unexpected number of results {len(r)} for function {body_fn}, "
108
+ f"expected {len(res[-1])}"
109
+ )
110
+ res.append((r,))
111
+
112
+ if not res:
113
+ return torch.empty(tuple(), dtype=torch.float32, device=operands[0].device)
114
+
115
+ n_res = len(res[0])
116
+ return tuple(
117
+ torch.cat(
118
+ [r[i] for r in res],
119
+ dim=(
120
+ 0
121
+ if concatenation_dims is None or i >= len(concatenation_dims)
122
+ else concatenation_dims[i]
123
+ ),
124
+ )
125
+ for i in range(n_res)
126
+ )
127
+
128
+
129
+ # from torch._functorch.utils import exposed_in
130
+ # @exposed_in("torch")
131
+ def _simple_loop_for(
132
+ n_iter: Union[int, torch.Tensor],
133
+ body_fn: Callable,
134
+ operands: Tuple[torch.Tensor, ...] = (),
135
+ concatenation_dims: Optional[Sequence[int]] = None,
136
+ ) -> Tuple[torch.Tensor, ...]:
137
+ def _validate_input(n_iter, body_fn, operands, concatenation_dims):
138
+ assert isinstance(
139
+ n_iter, (int, torch.Tensor, torch.SymInt)
140
+ ), f"Expected pred to be bool or tensor, but got {n_iter}."
141
+ assert (
142
+ not isinstance(n_iter, torch.Tensor) or n_iter.numel() == 1
143
+ ), f"Expected pred to be bool or single-element tensor, but got {n_iter}."
144
+ assert callable(body_fn), "Expect both branches to be callable."
145
+ assert isinstance(operands, (tuple, list)) and pytree.tree_all(
146
+ lambda t: isinstance(t, torch.Tensor), operands
147
+ ), (
148
+ "Expect operands to be a tuple of possibly nested dict/list/tuple that only "
149
+ f"consists of tensor leaves, but got {operands}."
150
+ )
151
+ assert concatenation_dims is None or (
152
+ isinstance(concatenation_dims, (list, tuple))
153
+ and all(isinstance(i, int) for i in concatenation_dims)
154
+ ), (
155
+ f"concatenation_dims should be None or a list of integers but it is "
156
+ f"{concatenation_dims}. Its length should be equal to the number of outputs."
157
+ )
158
+ assert torch._dynamo.is_dynamo_supported(), "simple_loop_for requires dynamo support."
159
+
160
+ if torch.compiler.is_dynamo_compiling():
161
+ return simple_loop_for_op(
162
+ n_iter, body_fn, (n_iter, *operands), concatenation_dims=concatenation_dims
163
+ )
164
+
165
+ if isinstance(n_iter, (bool, int, float)):
166
+ torch._check(
167
+ isinstance(n_iter, int),
168
+ lambda: f"n_iter must be an integer or a tensor not {type(n_iter)}",
169
+ )
170
+ return _simple_loop_for_fn(
171
+ n_iter, body_fn, operands, concatenation_dims=concatenation_dims
172
+ )
173
+
174
+ def _loop_for_op_wrapper(n_iter, body_fn, operands, concatenation_dims):
175
+ return simple_loop_for_op(n_iter, body_fn, operands, concatenation_dims)
176
+
177
+ _validate_input(n_iter, body_fn, operands, concatenation_dims)
178
+
179
+ # This requires torch>=2.10.
180
+ from torch._higher_order_ops.utils import setup_compilation_env
181
+
182
+ with setup_compilation_env() as _backend:
183
+ return _loop_for_op_wrapper(n_iter, body_fn, operands, concatenation_dims)
184
+ # return torch.compile(_loop_for_op_wrapper, backend=backend, fullgraph=True)(
185
+ # n_iter, body_fn, operands, concatenation_dims)
186
+
187
+
188
+ def trace_simple_loop_for(
189
+ proxy_mode, func_overload, n_iter, body_fn, operands, concatenation_dims
190
+ ):
191
+ """See function ``simple_loop_for``."""
192
+ assert isinstance(operands, (list, tuple)) and (
193
+ concatenation_dims is None
194
+ or (
195
+ isinstance(concatenation_dims, (list, tuple))
196
+ and all(isinstance(i, int) for i in concatenation_dims)
197
+ )
198
+ ), (
199
+ f"simple_loop_for operands must be a list or tuple of tensors and SymInts and "
200
+ f"concatenation_dims must be None or a list of integer, "
201
+ f"operands={[type(o) for o in operands]}, "
202
+ f"concatenation_dims={concatenation_dims}"
203
+ )
204
+
205
+ body_graph = reenter_make_fx(body_fn)(n_iter, *operands)
206
+
207
+ body_outs = []
208
+ for node in body_graph.graph.nodes:
209
+ if node.op == "output":
210
+ body_outs.extend(node.args)
211
+
212
+ # flat_body_outs = pytree.arg_tree_leaves(*body_outs)
213
+ _i, body_name = unique_graph_id(proxy_mode, prefix="body_graph")
214
+ proxy_mode.tracer.root.register_module(body_name, body_graph)
215
+ args = (n_iter, body_graph, operands, concatenation_dims)
216
+ proxy_args = pytree.tree_map(proxy_mode.tracer.unwrap_proxy, args)
217
+ out_proxy = proxy_mode.tracer.create_proxy("call_function", func_overload, proxy_args, {})
218
+ out = func_overload(n_iter, body_graph, operands, concatenation_dims)
219
+ return track_tensor_tree(out, out_proxy, constant=None, tracer=proxy_mode.tracer)
220
+
221
+
222
+ @simple_loop_for_op.py_impl(DispatchKey.CompositeExplicitAutograd)
223
+ def loop_for_op_dense(n_iter, body_fn, operands, concatenation_dims=None):
224
+ """Registered eager mode implementation."""
225
+ assert all(isinstance(o, torch.Tensor) for o in operands) and (
226
+ concatenation_dims is None
227
+ or (
228
+ isinstance(concatenation_dims, (list, tuple))
229
+ and all(isinstance(i, int) for i in concatenation_dims)
230
+ )
231
+ ), (
232
+ f"simple_loop_for operands must be a list or tuple of tensors and SymInts and "
233
+ f"concatenation_dims must be None or a list of integer, "
234
+ f"operands={[type(o) for o in operands]}, "
235
+ f"concatenation_dims={concatenation_dims}"
236
+ )
237
+ mode = _get_current_dispatch_mode()
238
+ assert mode is None, "Mode should never be enabled for CPU/CUDA key"
239
+ return _simple_loop_for_fn(
240
+ n_iter, body_fn, operands, concatenation_dims=concatenation_dims
241
+ )
242
+
243
+
244
+ @simple_loop_for_op.py_impl(ProxyTorchDispatchMode)
245
+ def inner(mode, n_iter, body_fn, operands, concatenation_dims=None):
246
+ """Registered tracing implementation."""
247
+ return trace_simple_loop_for(
248
+ mode, simple_loop_for_op, n_iter, body_fn, operands, concatenation_dims
249
+ )
250
+
251
+
252
+ @simple_loop_for_op.py_impl(FakeTensorMode)
253
+ def simple_loop_for_fake_tensor_mode(mode, n_iter, body_fn, operands, concatenation_dims=None):
254
+ """Registered FakeMode implementation."""
255
+ ignore_fresh_unbacked = contextlib.nullcontext()
256
+ if mode.shape_env:
257
+ ignore_fresh_unbacked = mode.shape_env.ignore_fresh_unbacked_symbols()
258
+
259
+ with mode, ignore_fresh_unbacked:
260
+ flat_body_outs, true_body_spec = pytree.tree_flatten(body_fn(n_iter, *operands))
261
+
262
+ return pytree.tree_unflatten(flat_body_outs, true_body_spec)
263
+
264
+
265
+ # Registration for autograd.
266
+ simple_loop_for_op.fallthrough(torch._C.DispatchKey.AutogradCPU)
267
+ simple_loop_for_op.fallthrough(torch._C.DispatchKey.AutogradCUDA)
268
+
269
+
270
+ def simple_loop_for(
271
+ n_iter: Union[int, torch.Tensor],
272
+ body_fn: Callable,
273
+ operands: Tuple[torch.Tensor, ...] = (),
274
+ concatenation_dims: Optional[Union[int, Sequence[int]]] = None,
275
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]:
276
+ """
277
+ Implements a simple for loop. The body is defined by a function which takes the
278
+ iteration number stored in a tensor, plus other tensors.
279
+ It returns one or several tensors in a tuple. All of them
280
+ are finally concatenated along the first dimension (or along ``concatenation_dims``).
281
+
282
+ :param n_iter: number of iterations
283
+ :param body_fn: function implementing the loop body
284
+ :param operands: body arguments
285
+ :param concatenation_dims: dimension or dimensions used to concatenate the output sequences
286
+ :return: concatenated outputs, a single tensor if the body returns one tensor, a tuple otherwise
287
+
288
+ An example with one output:
289
+
290
+ .. runpython::
291
+ :showcode:
292
+
293
+ import torch
294
+ from onnx_diagnostic.export.cf_simple_loop_for import simple_loop_for
295
+
296
+
297
+ class Model(torch.nn.Module):
298
+ def forward(self, n_iter, x):
299
+ def body(i, x):
300
+ return (x[: i.item() + 1].unsqueeze(1),)
301
+
302
+ return simple_loop_for(n_iter, body, (x,))
303
+
304
+
305
+ model = Model()
306
+ n_iter = torch.tensor(4, dtype=torch.int64)
307
+ x = torch.arange(10, dtype=torch.float32)
308
+ ep = torch.export.export(
309
+ model, (n_iter, x), dynamic_shapes=({}, ({0: torch.export.Dim.DYNAMIC}))
310
+ )
311
+ print(ep)
312
+
313
+ Another example with two outputs and a final concatenation on different axes.
314
+
315
+ .. runpython::
316
+ :showcode:
317
+
318
+ import torch
319
+ from onnx_diagnostic.export.cf_simple_loop_for import simple_loop_for
320
+
321
+
322
+ class Model(torch.nn.Module):
323
+ def forward(self, n_iter, x):
324
+ def body(i, x):
325
+ return (x[: i.item() + 1].unsqueeze(1), x[i.item() + 1 :].unsqueeze(0))
326
+
327
+ return simple_loop_for(n_iter, body, (x,), (0, 1))
328
+
329
+
330
+ model = Model()
331
+ n_iter = torch.tensor(4, dtype=torch.int64)
332
+ x = torch.arange(10, dtype=torch.float32)
333
+ ep = torch.export.export(
334
+ model, (n_iter, x), dynamic_shapes=({}, ({0: torch.export.Dim.DYNAMIC}))
335
+ )
336
+ print(ep)
337
+ """
338
+ res = _simple_loop_for(
339
+ n_iter,
340
+ body_fn,
341
+ operands,
342
+ concatenation_dims=(
343
+ (concatenation_dims,)
344
+ if isinstance(concatenation_dims, int)
345
+ else concatenation_dims
346
+ ),
347
+ )
348
+ torch._check(
349
+ isinstance(res, tuple),
350
+ lambda: f"Output of the loop should be a tuple not {type(res)}.",
351
+ )
352
+ return res[0] if len(res) == 1 else res
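
For context, a minimal eager-mode sketch (not taken from the package documentation) of how the new helper can be called outside of export; it assumes a torch version recent enough to provide ``setup_compilation_env`` (the module itself notes torch>=2.10), and the body and shapes are illustrative:

import torch
from onnx_diagnostic.export.cf_simple_loop_for import simple_loop_for

x = torch.arange(6, dtype=torch.float32)

def body(i, x):
    # one output per iteration: a growing column made of the first i+1 values
    return (x[: i.item() + 1].unsqueeze(1),)

# iterations 0, 1, 2 produce 1 + 2 + 3 = 6 rows, concatenated on dim 0
out = simple_loop_for(torch.tensor(3, dtype=torch.int64), body, (x,))
print(out.shape)  # expected torch.Size([6, 1])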
@@ -55,13 +55,13 @@ def is_exporting() -> bool:
55
55
  return _TEST_EXPORT or torch.compiler.is_exporting() or torch.compiler.is_compiling()
56
56
 
57
57
 
58
- def _loop_for_onnx_fn(n_iter, body_fn, reduction_dim, args):
58
+ def _loop_for_onnx_fn(n_iter, body_fn, concatenation_dims, args):
59
59
  """
60
60
  Python implementation of the loop.
61
61
 
62
62
 :param n_iter: number of iterations
63
63
  :param body_fn: function implementing the body
64
- :param reduction_dim: dimension used to reduce the list produced by the loop
64
+ :param concatenation_dims: dimensions used to concatenate the lists produced by the loop
65
65
  :param args: arguments to the loop body
66
66
  :return: results
67
67
  """
@@ -95,7 +95,9 @@ def _loop_for_onnx_fn(n_iter, body_fn, reduction_dim, args):
95
95
  torch.cat(
96
96
  [r[i] for r in res],
97
97
  dim=(
98
- 0 if reduction_dim is None or i >= len(reduction_dim) else reduction_dim[i]
98
+ 0
99
+ if concatenation_dims is None or i >= len(concatenation_dims)
100
+ else concatenation_dims[i]
99
101
  ),
100
102
  )
101
103
  for i in range(n_res)
@@ -106,7 +108,7 @@ def _loop_for_onnx_fn(n_iter, body_fn, reduction_dim, args):
106
108
  def make_custom_loop_for_onnx(
107
109
  n_iter: torch.Tensor,
108
110
  body_fn: Callable,
109
- reduction_dim: Optional[Sequence[int]],
111
+ concatenation_dims: Optional[Sequence[int]],
110
112
  args: Sequence[torch.Tensor],
111
113
  body_gm: Optional[torch.fx.GraphModule] = None,
112
114
  body_mutated_inputs: Optional[List[Any]] = None,
@@ -120,7 +122,7 @@ def make_custom_loop_for_onnx(
120
122
 
121
123
  :param n_iter: number of iterations defined by a tensor of no dimension
122
124
  :param body_fn: the loop body defined as a function
123
- :param reduction_dim: dimension used to concatenated the results
125
+ :param concatenation_dims: dimensions used to concatenate the results
124
126
  :param args: list of tensors, input to the body
125
127
 :param body_gm: torch.fx.GraphModule equivalent to *body_fn*
126
128
  :param body_mutated_inputs: inputs to *body_gm*
@@ -133,7 +135,7 @@ def make_custom_loop_for_onnx(
133
135
  assert body_mutated_inputs is not None, "body_mutated_inputs cannot be None"
134
136
  assert body_outputs is not None, "body_outputs cannot be None"
135
137
  srank = "_".join("x".join(map(str, s.shape)) for s in body_outputs)
136
- sred = "x".join(map(str, reduction_dim)) if reduction_dim else ""
138
+ sred = "x".join(map(str, concatenation_dims)) if concatenation_dims else ""
137
139
  full_name = (
138
140
  body_fn.__qualname__.replace("<locals>", "L")
139
141
  .replace("<lambda>", "l")
@@ -169,14 +171,14 @@ def make_custom_loop_for_onnx(
169
171
  custom_def,
170
172
  _make_onx,
171
173
  (
172
- lambda g, sts, outputs, *args, bc=_make_onx, rd=reduction_dim, name=name: (
174
+ lambda g, sts, outputs, *args, bc=_make_onx, rd=concatenation_dims, name=name: (
173
175
  convert_custom_loop_into_onnx(
174
176
  g,
175
177
  sts,
176
178
  outputs,
177
179
  *args,
178
180
  body_callable=bc,
179
- reduction_dim=rd,
181
+ concatenation_dims=rd,
180
182
  name=name,
181
183
  )
182
184
  )
@@ -196,7 +198,7 @@ def convert_custom_loop_into_onnx(
196
198
  outputs: List[str],
197
199
  *args: str,
198
200
  body_callable: Callable[..., onnx.ModelProto],
199
- reduction_dim: Optional[Sequence[int]] = None,
201
+ concatenation_dims: Optional[Sequence[int]] = None,
200
202
  name: str = "loop_for_onnx",
201
203
  ) -> Union[str, List[str]]:
202
204
  """
@@ -207,7 +209,7 @@ def convert_custom_loop_into_onnx(
207
209
  :param outputs: output names
208
210
  :param args: input argument known at export time
209
211
  :param body: GraphProto, the loop body
210
- :param reduction_dim: the dimension to follow when aggregating the
212
+ :param concatenation_dims: the dimension to follow when aggregating the
211
213
  list of tensors after the loop ran
212
214
  :param name: to give the onnx nodes a name
213
215
  :return: output names
@@ -289,7 +291,11 @@ def convert_custom_loop_into_onnx(
289
291
  out,
290
292
  outputs=[o],
291
293
  name=name,
292
- axis=0 if not reduction_dim or i >= len(reduction_dim) else reduction_dim[i],
294
+ axis=(
295
+ 0
296
+ if not concatenation_dims or i >= len(concatenation_dims)
297
+ else concatenation_dims[i]
298
+ ),
293
299
  )
294
300
  for i, (out, o) in enumerate(zip(outloop, outputs))
295
301
  ]
@@ -337,7 +343,7 @@ def loop_for_onnx(
337
343
  n_iter: Union[torch.SymInt, torch.Tensor],
338
344
  body_fn: Callable[..., Tuple[torch.Tensor]],
339
345
  args: Sequence[torch.Tensor],
340
- reduction_dim: Optional[Sequence[int]] = None,
346
+ concatenation_dims: Optional[Sequence[int]] = None,
341
347
  ) -> Tuple[torch.Tensor, ...]:
342
348
  """
343
349
 Higher-order operator used to easily export a loop in ONNX.
@@ -353,7 +359,7 @@ def loop_for_onnx(
353
359
  in a tensor with no dimension, all the others
354
360
  are not changed during the loop
355
361
  :param args: the available tensors at every loop
356
- :param reduction_dim: the loop aggregated the results into list,
362
+ :param concatenation_dims: the loop aggregates the results into lists,
357
363
  one of each output, each of them is concatenated into one
358
364
  tensor along one dimension, by default, it is the first
359
365
  dimension, but it can be defined otherwise
@@ -449,7 +455,7 @@ def loop_for_onnx(
449
455
  )
450
456
  print(ep)
451
457
 
452
- A last example with ``reduction_dim``:
458
+ A last example with ``concatenation_dims``:
453
459
 
454
460
  .. runpython::
455
461
  :showcode:
@@ -465,7 +471,7 @@ def loop_for_onnx(
465
471
  def body(i, x):
466
472
  return x[: i.item() + 1].unsqueeze(1), x[: i.item() + 1].unsqueeze(0) + 1
467
473
 
468
- two = loop_for_onnx(n_iter, body, (x,), reduction_dim=[0, 1])
474
+ two = loop_for_onnx(n_iter, body, (x,), concatenation_dims=[0, 1])
469
475
  return two[0] + two[1].T
470
476
 
471
477
 
@@ -516,7 +522,7 @@ def loop_for_onnx(
516
522
  name, _custom_ops = make_custom_loop_for_onnx(
517
523
  n_iter,
518
524
  body_fn,
519
- reduction_dim,
525
+ concatenation_dims,
520
526
  args,
521
527
  body_gm=body_gm,
522
528
  body_mutated_inputs=body_mutated_inputs,
@@ -525,4 +531,4 @@ def loop_for_onnx(
525
531
  fct = getattr(torch.ops.onnx_higher_ops, name)
526
532
  return fct(n_iter, *args)
527
533
 
528
- return _loop_for_onnx_fn(n_iter, body_fn, reduction_dim, args)
534
+ return _loop_for_onnx_fn(n_iter, body_fn, concatenation_dims, args)
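
To make the renamed argument concrete, here is a short plain-torch sketch (illustrative only) of the concatenation that ``_loop_for_onnx_fn`` performs on the per-iteration results, mirroring the ``dim=`` expression shown above:

import torch

# results collected by the loop body: one tuple of outputs per iteration
res = [
    (torch.ones(1, 1), torch.ones(1, 2)),
    (torch.ones(2, 1), torch.ones(1, 3)),
]
concatenation_dims = [0, 1]  # per-output concatenation axis, defaults to 0

outputs = tuple(
    torch.cat(
        [r[i] for r in res],
        dim=(
            0
            if concatenation_dims is None or i >= len(concatenation_dims)
            else concatenation_dims[i]
        ),
    )
    for i in range(len(res[0]))
)
print([o.shape for o in outputs])  # [torch.Size([3, 1]), torch.Size([1, 5])]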
@@ -700,6 +700,19 @@ def requires_onnx(version: str, msg: str = "") -> Callable:
700
700
  return lambda x: x
701
701
 
702
702
 
703
+ def requires_experimental_experiment(version: str, msg: str = "") -> Callable:
704
+ """Skips a unit test if :epkg:`onnx-array-api` is not recent enough."""
705
+ import packaging.version as pv
706
+ import experimental_experiment
707
+
708
+ if pv.Version(experimental_experiment.__version__) < pv.Version(version):
709
+ msg = (
710
+ f"onnx-array-api version {experimental_experiment.__version__} < {version}: {msg}"
711
+ )
712
+ return unittest.skip(msg)
713
+ return lambda x: x
714
+
715
+
703
716
  def requires_onnx_array_api(version: str, msg: str = "") -> Callable:
704
717
  """Skips a unit test if :epkg:`onnx-array-api` is not recent enough."""
705
718
  import packaging.version as pv
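
A hedged usage sketch of the new decorator (the version string, message and test are placeholders; the imports of ``ExtTestCase`` and ``requires_experimental_experiment`` are elided since they live in the same test helper module as the other ``requires_*`` decorators):

class TestLoopExport(ExtTestCase):  # ExtTestCase appears in the next hunk of this diff
    @requires_experimental_experiment("0.1.0", "needs a recent experimental-experiment")
    def test_loop_export(self):
        self.assertTrue(True)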
@@ -774,6 +787,7 @@ class ExtTestCase(unittest.TestCase):
774
787
  def setUpClass(cls):
775
788
  logger = logging.getLogger("onnxscript.optimizer.constant_folding")
776
789
  logger.setLevel(logging.ERROR)
790
+ warnings.filterwarnings("ignore", category=DeprecationWarning)
777
791
  unittest.TestCase.setUpClass()
778
792
 
779
793
  @classmethod
@@ -818,6 +818,7 @@ def torch_export_patches(
818
818
  rewrite: Optional[List[Callable]] = None,
819
819
  dump_rewriting: Optional[str] = None,
820
820
  patch_details: Optional[PatchDetails] = None,
821
+ profile: Optional[str] = None,
821
822
  ) -> Callable:
822
823
  """
823
824
  Tries to bypass some situations :func:`torch.export.export` does not support.
@@ -850,6 +851,8 @@ def torch_export_patches(
850
851
 :param dump_rewriting: dumps rewriting information in files beginning with that prefix
851
852
  :param patch_details: if specified, this class is used to stored every rewritten done.
852
853
 :param verbose: to show which patches are applied
854
+ :param profile: starts profiling whatever is called inside the context manager and
855
+ writes the profiling report (HTML produced by pyinstrument) into the file with that name
853
856
 
854
857
  The list of available patches.
855
858
 
@@ -1017,10 +1020,23 @@ def torch_export_patches(
1017
1020
  if verbose:
1018
1021
  print("[torch_export_patches] done patching")
1019
1022
 
1023
+ if profile:
1024
+ from pyinstrument import Profiler
1025
+
1026
+ profiler = Profiler()
1027
+ profiler.start()
1028
+ else:
1029
+ profiler = None
1030
+
1020
1031
  try:
1021
1032
  yield fct_callable
1022
1033
  finally:
1023
1034
 
1035
+ if profiler:
1036
+ profiler.stop()
1037
+ with open(profile, "w") as f:
1038
+ f.write(profiler.output_html())
1039
+
1024
1040
  # unpatch
1025
1041
 
1026
1042
  if verbose:
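
A minimal usage sketch of the new ``profile`` argument (the import path, model and file name are assumptions for illustration, and pyinstrument must be installed; the HTML report is written when the context manager exits):

import torch
# assumed import path for the context manager patched in this hunk
from onnx_diagnostic.torch_export_patches import torch_export_patches

model = torch.nn.Linear(4, 4)
x = torch.randn(2, 4)

with torch_export_patches(profile="export_profile.html"):
    ep = torch.export.export(model, (x,))
# export_profile.html now contains the pyinstrument report of everything run in the block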
@@ -256,8 +256,12 @@ if patch_qwen2_5:
256
256
  return attn_output
257
257
 
258
258
  def qwen_version_selector(opset: int, *args: torch.Tensor) -> Tuple[str, torch.dtype]:
259
- first_tensor = next(a for a in args if a is not None)
260
- dtype = first_tensor.dtype
259
+ first_float_tensor = next(
260
+ a
261
+ for a in args
262
+ if a is not None and a.dtype in {torch.float16, torch.float32, torch.bfloat16}
263
+ )
264
+ dtype = first_float_tensor.dtype
261
265
  strategy = patched_Qwen2_5_VLVisionAttention.STRATEGY_FOR_ATTENTION()
262
266
  itype = torch_dtype_to_onnx_dtype(dtype)
263
267
  if strategy is not None:
@@ -269,7 +273,7 @@ if patch_qwen2_5:
269
273
  if dtype == torch.float16 or itype == onnx.TensorProto.FLOAT16:
270
274
  # first_tensor may be a SymbolicTensor (onnx).
271
275
  # is_cuda is not available.
272
- if hasattr(first_tensor, "is_cuda") and first_tensor.is_cuda:
276
+ if hasattr(first_float_tensor, "is_cuda") and first_float_tensor.is_cuda:
273
277
  return "PACKED", itype
274
278
  return "LOOPMHA", itype
275
279
  raise AssertionError(
@@ -733,3 +737,71 @@ if patch_qwen2_5:
733
737
  attn_output = attn_output.reshape(seq_length, -1).contiguous()
734
738
  attn_output = self.proj(attn_output)
735
739
  return attn_output
740
+
741
+ class patched_Qwen2_5_VLModel:
742
+ _PATCHES_ = ["get_placeholder_mask"]
743
+ _PATCHED_CLASS_ = transformers.models.qwen2_5_vl.modeling_qwen2_5_vl.Qwen2_5_VLModel
744
+
745
+ def get_placeholder_mask(
746
+ self,
747
+ input_ids: torch.LongTensor,
748
+ inputs_embeds: torch.FloatTensor,
749
+ image_features: Optional[torch.FloatTensor] = None,
750
+ video_features: Optional[torch.FloatTensor] = None,
751
+ ):
752
+ if input_ids is None:
753
+ special_image_mask = inputs_embeds == self.get_input_embeddings()(
754
+ torch.tensor(
755
+ self.config.image_token_id,
756
+ dtype=torch.long,
757
+ device=inputs_embeds.device,
758
+ )
759
+ )
760
+ special_image_mask = special_image_mask.all(-1)
761
+ special_video_mask = inputs_embeds == self.get_input_embeddings()(
762
+ torch.tensor(
763
+ self.config.video_token_id,
764
+ dtype=torch.long,
765
+ device=inputs_embeds.device,
766
+ )
767
+ )
768
+ special_video_mask = special_video_mask.all(-1)
769
+ else:
770
+ special_image_mask = input_ids == self.config.image_token_id
771
+ special_video_mask = input_ids == self.config.video_token_id
772
+
773
+ special_image_mask = (
774
+ special_image_mask.unsqueeze(-1)
775
+ .expand_as(inputs_embeds)
776
+ .to(inputs_embeds.device)
777
+ )
778
+
779
+ # PATCHED: we should use torch._check
780
+ # but this fails for compilation. It cannot be verified with FakeTensors
781
+ # torch._check(
782
+ # image_features is None
783
+ # or inputs_embeds[special_image_mask].numel() == image_features.numel(),
784
+ # lambda: (
785
+ # f"Image features and image tokens do not match: tokens: "
786
+ # f"{special_image_mask.sum()}, features {image_features.shape[0]}"
787
+ # ),
788
+ # )
789
+
790
+ special_video_mask = (
791
+ special_video_mask.unsqueeze(-1)
792
+ .expand_as(inputs_embeds)
793
+ .to(inputs_embeds.device)
794
+ )
795
+
796
+ # PATCHED: we should use torch._check
797
+ # but this fails for compilation. It cannot be verified with FakeTensors
798
+ # torch._check(
799
+ # video_features is None
800
+ # or inputs_embeds[special_video_mask].numel() == video_features.numel(),
801
+ # lambda: (
802
+ # f"Videos features and video tokens do not match: tokens: "
803
+ # f"{special_video_mask.sum()}, features {video_features.shape[0]}"
804
+ # ),
805
+ # )
806
+
807
+ return special_image_mask, special_video_mask
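
The new class follows the package's patch convention: ``_PATCHES_`` lists the method names to replace and ``_PATCHED_CLASS_`` names the target class. Below is a hypothetical sketch (not the package's actual mechanism) of how such a patch could be applied and reverted:

def apply_patch(patch_cls):
    # swap each listed method onto the target class, keeping the originals around
    target = patch_cls._PATCHED_CLASS_
    originals = {name: getattr(target, name) for name in patch_cls._PATCHES_}
    for name in patch_cls._PATCHES_:
        setattr(target, name, getattr(patch_cls, name))
    return originals

def revert_patch(patch_cls, originals):
    # restore the original methods once the patching context ends
    for name, method in originals.items():
        setattr(patch_cls._PATCHED_CLASS_, name, method)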
@@ -77,6 +77,7 @@ if patch_qwen2_5:
77
77
  patched_Qwen2_5_VisionTransformerPretrainedModel,
78
78
  patched_Qwen2_5_VLVisionAttentionOneIteration,
79
79
  patched_Qwen2_5_VLVisionAttention,
80
+ patched_Qwen2_5_VLModel,
80
81
  PLUGS as PLUGS_Qwen25,
81
82
  )
82
83