onnx-diagnostic 0.8.6__py3-none-any.whl → 0.8.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnx_diagnostic/__init__.py +1 -1
- onnx_diagnostic/_command_lines_parser.py +108 -3
- onnx_diagnostic/ci_models/ci_helpers.py +12 -7
- onnx_diagnostic/ci_models/export_phi4_mm.py +1062 -0
- onnx_diagnostic/ci_models/export_qwen25_vl.py +12 -4
- onnx_diagnostic/export/api.py +295 -5
- onnx_diagnostic/export/cf_simple_loop_for.py +195 -10
- onnx_diagnostic/export/dynamic_shapes.py +45 -3
- onnx_diagnostic/export/shape_helper.py +1 -0
- onnx_diagnostic/ext_test_case.py +9 -2
- onnx_diagnostic/helpers/bench_run.py +1 -1
- onnx_diagnostic/helpers/cache_helper.py +0 -8
- onnx_diagnostic/helpers/fake_tensor_helper.py +26 -5
- onnx_diagnostic/helpers/helper.py +30 -1
- onnx_diagnostic/helpers/log_helper.py +1 -3
- onnx_diagnostic/helpers/optim_helper.py +116 -0
- onnx_diagnostic/helpers/ort_session.py +5 -0
- onnx_diagnostic/tasks/image_text_to_text.py +19 -9
- onnx_diagnostic/tasks/text2text_generation.py +84 -48
- onnx_diagnostic/tasks/text_generation.py +3 -0
- onnx_diagnostic/torch_export_patches/onnx_export_errors.py +28 -2
- onnx_diagnostic/torch_export_patches/patch_details.py +3 -3
- onnx_diagnostic/torch_export_patches/patch_expressions.py +4 -1
- onnx_diagnostic/torch_export_patches/patch_module.py +31 -23
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_dynamic_cache.py +14 -5
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_funnel.py +80 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2_5.py +12 -1
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_rotary_embedding.py +2 -2
- onnx_diagnostic/torch_export_patches/patches/patch_torch.py +15 -0
- onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +22 -24
- onnx_diagnostic/torch_models/hghub/hub_api.py +11 -0
- onnx_diagnostic/torch_models/hghub/hub_data.py +9 -1
- onnx_diagnostic/torch_models/hghub/model_inputs.py +24 -19
- onnx_diagnostic/torch_models/validate.py +48 -0
- {onnx_diagnostic-0.8.6.dist-info → onnx_diagnostic-0.8.8.dist-info}/METADATA +3 -1
- {onnx_diagnostic-0.8.6.dist-info → onnx_diagnostic-0.8.8.dist-info}/RECORD +39 -36
- {onnx_diagnostic-0.8.6.dist-info → onnx_diagnostic-0.8.8.dist-info}/WHEEL +0 -0
- {onnx_diagnostic-0.8.6.dist-info → onnx_diagnostic-0.8.8.dist-info}/licenses/LICENSE.txt +0 -0
- {onnx_diagnostic-0.8.6.dist-info → onnx_diagnostic-0.8.8.dist-info}/top_level.txt +0 -0
onnx_diagnostic/torch_export_patches/onnx_export_errors.py

```diff
@@ -221,6 +221,7 @@ def _patch_torch(
     catch_constraints: bool,
     stop_if_static: int,
 ) -> Tuple[Optional[Callable], ...]:
+    import packaging.version as pv
     import torch
     import torch.jit
     import torch._export.non_strict_utils  # produce_guards_and_solve_constraints
@@ -238,6 +239,11 @@ def _patch_torch(
         patched_ShapeEnv,
     )

+    if pv.Version(torch.__version__) >= pv.Version("2.9.99"):
+        from .patches.patch_torch import patched_DynamicDimConstraintPrinter
+    else:
+        patched_DynamicDimConstraintPrinter = None
+
     f___constrain_user_specified_dimhint_range = None
     f__broadcast_in_dim_meta = None
     f__broadcast_shapes = None
@@ -259,6 +265,17 @@ def _patch_torch(
         print(f"[torch_export_patches] stop_if_static={stop_if_static!r}")
         print("[torch_export_patches] patch pytorch")

+    # torch.tx.experimental.symbolic_shapes.DynamicDimConstraintPrinter._print_Symbol
+    if patched_DynamicDimConstraintPrinter is not None:
+        f__print_symbol = (
+            torch.fx.experimental.symbolic_shapes.DynamicDimConstraintPrinter._print_Symbol
+        )
+        torch.fx.experimental.symbolic_shapes.DynamicDimConstraintPrinter._print_Symbol = (
+            patched_DynamicDimConstraintPrinter._print_Symbol
+        )
+    else:
+        f__print_symbol = None
+
     # torch.vmap
     f_vmap = torch.vmap
     torch.vmap = patched_vmap
@@ -392,6 +409,7 @@ def _patch_torch(
         f_shape_env__log_guard,
         f_shape_env__set_replacement,
         f_vmap,
+        f__print_symbol,
     )


@@ -416,6 +434,7 @@ def _unpatch_torch(
     f_shape_env__log_guard: Optional[Callable],
     f_shape_env__set_replacement: Optional[Callable],
     f_vmap: Optional[Callable],
+    f__print_symbol: Optional[Callable],
 ):
     import torch
     import torch.jit
@@ -423,6 +442,10 @@ def _unpatch_torch(
     from torch.fx.experimental.symbolic_shapes import ShapeEnv

     # this should disappear when torch.jit is removed
+    if f__print_symbol is not None:
+        torch.fx.experimental.symbolic_shapes.DynamicDimConstraintPrinter._print_Symbol = (
+            f__print_symbol
+        )
     torch.vmap = f_vmap
     torch.jit.isinstance = f_jit_isinstance
     torch._dynamo.mark_static_address = f_mark_static_address
```
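The hunks above thread one more saved callable (`f__print_symbol`) through the usual save/patch/restore pattern: `_patch_torch` stashes the original attribute, installs the replacement behind a `packaging.version` gate, and `_unpatch_torch` restores it. A minimal, self-contained sketch of that pattern (not the package's code; `patched_vmap` is a trivial stand-in and the version bound is arbitrary):

```python
import packaging.version as pv
import torch

_original_vmap = torch.vmap  # keep a reference so the wrapper can delegate


def patched_vmap(*args, **kwargs):
    # trivial stand-in; the real replacement in the package does more work
    return _original_vmap(*args, **kwargs)


def patch_torch_minimal():
    """Install the replacement and return what is needed to undo it."""
    saved = torch.vmap
    # version gate in the same spirit as the pv.Version check added above
    if pv.Version(torch.__version__) >= pv.Version("2.0"):
        torch.vmap = patched_vmap
    return saved


def unpatch_torch_minimal(saved):
    torch.vmap = saved  # restore the original symbol


saved = patch_torch_minimal()
try:
    print(torch.vmap(torch.sin)(torch.randn(3)))
finally:
    unpatch_torch_minimal(saved)
```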
onnx_diagnostic/torch_export_patches/onnx_export_errors.py

```diff
@@ -848,8 +871,9 @@ def torch_export_patches(
         this is done by function :func:`transform_method
         <onnx_diagnostic.torch_export_patches.patch_module.transform_method>`,
         its documentation provides possible values
-    :param dump_rewriting: dumps rewriting information in file beginning with that prefix
-
+    :param dump_rewriting: dumps rewriting information in file beginning with that prefix,
+        this only applied on the automated rewritings
+    :param patch_details: if specified, this class is used to stored every applied rewriting.
     :param verbose: to show which patches is applied
     :param profile: starts profiling whatever is called inside the context manager,
         output the profiling into a text file
@@ -992,6 +1016,7 @@ def torch_export_patches(
             f_shape_env__log_guard,
             f_shape_env__set_replacement,
             f_vmap,
+            f__print_Symbol,
         ) = _patch_torch(
             verbose, patch_details, patch_torch, catch_constraints, stop_if_static
         )
@@ -1067,6 +1092,7 @@ def torch_export_patches(
                 f_shape_env__log_guard,
                 f_shape_env__set_replacement,
                 f_vmap,
+                f__print_Symbol,
             )

         if patch_transformers:
```
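For context, these docstring and call-site changes belong to the `torch_export_patches` context manager. A usage sketch based only on names visible in this diff, assuming the function is importable from `onnx_diagnostic.torch_export_patches` (check the package documentation for the exact signature and defaults):

```python
import torch
from onnx_diagnostic.torch_export_patches import torch_export_patches


class Tiny(torch.nn.Module):
    def forward(self, x):
        return x.sigmoid() + 1


# patch_torch and verbose appear in the hunks above; other keyword
# arguments and their defaults are not shown in this diff.
with torch_export_patches(patch_torch=True, verbose=1):
    ep = torch.export.export(Tiny(), (torch.randn(2, 3),))
print(ep.graph)
```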
onnx_diagnostic/torch_export_patches/patch_details.py

```diff
@@ -191,7 +191,7 @@ class PatchDetails:
         ep = torch.export.export(
             model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds)
         )
-        patches = details.
+        patches = details.patches_involved_in_graph(ep.graph)
         report = details.make_report(patches, format="rst")
         print(report)
     """
@@ -235,7 +235,7 @@ class PatchDetails:
         """Returns the data for a dataframe."""
         return [p.to_dict() for p in self.patched]

-    def
+    def patches_involved_in_graph(
         self, graph: "torch.fx.Graph"  # noqa: F821
     ) -> List[Tuple[PatchInfo, List["torch.fx.Node"]]]:  # noqa: F821
         """
@@ -322,7 +322,7 @@ class PatchDetails:
         """
         Creates a report based on the involved patches.

-        :param patches: from method :meth:`
+        :param patches: from method :meth:`patches_involved_in_graph`
         :param format: format of the report
         :return: report
         """
```
onnx_diagnostic/torch_export_patches/patch_expressions.py

```diff
@@ -101,7 +101,10 @@ def patched_selector(fct: Callable, patched_fct: Callable) -> Callable:


 def patched_float_arange(start, end, step):
-    """
+    """
+    Patched arange when start, end, step are floats.
+    This patch should not be needed after 2.10.
+    """
     if is_torchdynamo_exporting():
         return torch.ops.patched.float_arange(start, end, step)
     else:
```
onnx_diagnostic/torch_export_patches/patch_module.py

```diff
@@ -596,33 +596,41 @@ class RewriteControlFlow(ast.NodeTransformer):
             elts=[
                 *[
                     ast.Call(
-                        ast.Attribute(
-                            value=ast.
-
-
-
-                            args=[
-                                ast.Subscript(
-                                    value=ast.Attribute(
-                                        value=ast.Name(id=v, ctx=ast.Load()),
-                                        attr="shape",
+                        func=ast.Attribute(
+                            value=ast.Call(
+                                ast.Attribute(
+                                    value=ast.Name(id="torch", ctx=ast.Load()),
+                                    attr="arange",
                                     ctx=ast.Load(),
                                 ),
-
+                                args=[
+                                    ast.Subscript(
+                                        value=ast.Attribute(
+                                            value=ast.Name(id=v, ctx=ast.Load()),
+                                            attr="shape",
+                                            ctx=ast.Load(),
+                                        ),
+                                        slice=ast.Constant(value=0, ctx=ast.Load()),
+                                        ctx=ast.Load(),
+                                    ),
+                                ],
+                                keywords=[
+                                    ast.keyword(
+                                        arg="dtype",
+                                        value=ast.Attribute(
+                                            value=ast.Name(id="torch", ctx=ast.Load()),
+                                            attr="int64",
+                                            ctx=ast.Load(),
+                                        ),
+                                    )
+                                ],
                             ctx=ast.Load(),
                         ),
-
-
-
-
-
-                            value=ast.Name(id="torch", ctx=ast.Load()),
-                            attr="int64",
-                            ctx=ast.Load(),
-                        ),
-                    )
-                    ],
-                    ctx=ast.Load(),
+                            attr="unsqueeze",
+                            ctx=ast.Load(),
+                        ),
+                        args=[ast.Constant(value=1)],
+                        keywords=[],
                     )
                     for v in scan_shape_vars
                 ],
```
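The rewritten block builds, node by node, the call that `RewriteControlFlow` emits for each entry of `scan_shape_vars`. A standalone check of what that tree unparses to, using only the standard library (`ast.unparse` needs Python 3.9+; the variable name `v` is fixed here, whereas the transformer substitutes each shape variable):

```python
import ast

call = ast.Call(
    func=ast.Attribute(
        value=ast.Call(
            func=ast.Attribute(
                value=ast.Name(id="torch", ctx=ast.Load()),
                attr="arange",
                ctx=ast.Load(),
            ),
            args=[
                ast.Subscript(
                    value=ast.Attribute(
                        value=ast.Name(id="v", ctx=ast.Load()),
                        attr="shape",
                        ctx=ast.Load(),
                    ),
                    slice=ast.Constant(value=0),
                    ctx=ast.Load(),
                )
            ],
            keywords=[
                ast.keyword(
                    arg="dtype",
                    value=ast.Attribute(
                        value=ast.Name(id="torch", ctx=ast.Load()),
                        attr="int64",
                        ctx=ast.Load(),
                    ),
                )
            ],
        ),
        attr="unsqueeze",
        ctx=ast.Load(),
    ),
    args=[ast.Constant(value=1)],
    keywords=[],
)
print(ast.unparse(call))
# torch.arange(v.shape[0], dtype=torch.int64).unsqueeze(1)
```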
onnx_diagnostic/torch_export_patches/patches/_patch_transformers_dynamic_cache.py

```diff
@@ -22,13 +22,22 @@ if patch_DynamicLayer:
         _PATCHES_ = ["lazy_initialization"]
         _PATCHED_CLASS_ = DynamicLayer

-        def lazy_initialization(
+        def lazy_initialization(
+            self, key_states: torch.Tensor, value_states: torch.Tensor = None
+        ):
             self.dtype, self.device = key_states.dtype, key_states.device
-
-
+            assert (
+                hasattr(key_states, "shape") and key_states is not None
+            ), f"Attribute 'shape' is wrong for type {type(key_states)}"
+            like = torch.narrow(key_states, dim=-2, start=0, length=0)
             # PATCHED: used a tensor with an empty shape and not en empty list to initialize
-
-
+            if isinstance(key_states, torch._subclasses.fake_tensor.FakeTensor):
+                with key_states.fake_mode:
+                    self.keys = torch.empty_like(like, dtype=self.dtype, device=self.device)
+                    self.values = torch.empty_like(like, dtype=self.dtype, device=self.device)
+            else:
+                self.keys = torch.empty_like(like, dtype=self.dtype, device=self.device)
+                self.values = torch.empty_like(like, dtype=self.dtype, device=self.device)
             if patch_is_initialized:
                 self.is_initialized = True

```
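The patched `lazy_initialization` seeds the cache from a zero-length slice of `key_states` instead of an empty list. A standalone illustration of what `torch.narrow(..., length=0)` followed by `torch.empty_like` produces (plain tensors only; the `FakeTensor` branch does the same thing under the tensor's fake mode):

```python
import torch

key_states = torch.randn(2, 4, 7, 16)  # (batch, heads, seq_len, head_dim)

# a zero-length slice along the sequence axis keeps every other dimension
like = torch.narrow(key_states, dim=-2, start=0, length=0)
keys = torch.empty_like(like, dtype=key_states.dtype, device=key_states.device)
values = torch.empty_like(like, dtype=key_states.dtype, device=key_states.device)

print(like.shape, keys.shape)  # torch.Size([2, 4, 0, 16]) twice
```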
onnx_diagnostic/torch_export_patches/patches/_patch_transformers_funnel.py (new file)

```diff
@@ -0,0 +1,80 @@
+import torch
+
+try:
+    import transformers.models.funnel.modeling_funnel
+
+    patch_funnel = True
+except ImportError:
+    patch_funnel = False
+
+if patch_funnel:
+    from transformers.models.funnel.modeling_funnel import _relative_shift_gather
+
+    class patched_FunnelAttentionStructure(torch.nn.Module):
+        _PATCHES_ = ["relative_pos"]
+        _PATCHED_CLASS_ = transformers.models.funnel.modeling_funnel.FunnelAttentionStructure
+
+        def relative_pos(
+            self, pos: torch.Tensor, stride: int, pooled_pos=None, shift: int = 1
+        ) -> torch.Tensor:
+            if pooled_pos is None:
+                pooled_pos = pos
+            ref_point = pooled_pos[0] - pos[0]
+            # PATCHED
+            num_remove = shift * pooled_pos.shape[0]
+            max_dist = ref_point + num_remove * stride
+            min_dist = pooled_pos[0] - pos[-1]
+            return torch.arange(
+                max_dist.to(torch.long),
+                (min_dist - 1).to(torch.long),
+                torch.tensor(-stride, dtype=torch.long),
+                dtype=torch.long,
+                device=pos.device,
+            )
+
+    class patched_FunnelRelMultiheadAttention(torch.nn.Module):
+        _PATCHES_ = ["relative_positional_attention"]
+        _PATCHED_CLASS_ = (
+            transformers.models.funnel.modeling_funnel.FunnelRelMultiheadAttention
+        )
+
+        def relative_positional_attention(
+            self, position_embeds, q_head, context_len, cls_mask=None
+        ):
+            """Relative attention score for the positional encodings"""
+            # q_head has shape batch_size x sea_len x n_head x d_head
+            if self.config.attention_type == "factorized":
+                phi, pi, psi, omega = position_embeds
+                # Shape n_head x d_head
+                u = self.r_r_bias * self.scale
+                # Shape d_model x n_head x d_head
+                w_r = self.r_kernel
+
+                # Shape batch_size x sea_len x n_head x d_model
+                q_r_attention = torch.einsum("binh,dnh->bind", q_head + u, w_r)
+                q_r_attention_1 = q_r_attention * phi[:, None]
+                q_r_attention_2 = q_r_attention * pi[:, None]
+
+                # Shape batch_size x n_head x seq_len x context_len
+                positional_attn = torch.einsum(
+                    "bind,jd->bnij", q_r_attention_1, psi
+                ) + torch.einsum("bind,jd->bnij", q_r_attention_2, omega)
+            else:
+                shift = 2 if q_head.shape[1] != context_len else 1
+                r = position_embeds[self.block_index][shift - 1]
+                # Shape n_head x d_head
+                v = self.r_r_bias * self.scale
+                # Shape d_model x n_head x d_head
+                w_r = self.r_kernel
+
+                # Shape max_rel_len x n_head x d_model
+                r_head = torch.einsum("td,dnh->tnh", r, w_r)
+                # Shape batch_size x n_head x seq_len x max_rel_len
+                positional_attn = torch.einsum("binh,tnh->bnit", q_head + v, r_head)
+                # Shape batch_size x n_head x seq_len x context_len
+                positional_attn = _relative_shift_gather(positional_attn, context_len, shift)
+
+            if cls_mask is not None:
+                # PATCHED
+                positional_attn = positional_attn * cls_mask
+            return positional_attn
```
onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2_5.py

```diff
@@ -256,10 +256,21 @@ if patch_qwen2_5:
         return attn_output

     def qwen_version_selector(opset: int, *args: torch.Tensor) -> Tuple[str, torch.dtype]:
+        import onnx_ir
+
         first_float_tensor = next(
             a
             for a in args
-            if a is not None
+            if a is not None
+            and a.dtype
+            in {
+                torch.float16,
+                torch.float32,
+                torch.bfloat16,
+                onnx_ir.DataType.BFLOAT16,
+                onnx_ir.DataType.FLOAT16,
+                onnx_ir.DataType.FLOAT,
+            }
         )
         dtype = first_float_tensor.dtype
         strategy = patched_Qwen2_5_VLVisionAttention.STRATEGY_FOR_ATTENTION()
```
onnx_diagnostic/torch_export_patches/patches/_patch_transformers_rotary_embedding.py

```diff
@@ -214,7 +214,7 @@ def patched_dynamic_rope_update(rope_forward):
             cond,
             (lambda x, y: x.clone()),
             (lambda x, y: y.clone()),
-            [long_inv_freq, original_inv_freq],
+            [long_inv_freq.to(original_inv_freq.dtype), original_inv_freq],
         )
         setattr(self, f"{prefix}inv_freq", inv_freq)
         # if seq_len > original_max_position_embeddings:
@@ -293,7 +293,7 @@ def patched_dynamic_rope_update(rope_forward):
             cond,
             (lambda x, y: x.clone()),
             (lambda x, y: y.clone()),
-            [long_inv_freq, original_inv_freq],
+            [long_inv_freq.to(original_inv_freq.dtype), original_inv_freq],
         )
         setattr(self, f"{prefix}inv_freq", inv_freq)

```
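Both hunks make the same one-line fix: the two branches return `x.clone()` and `y.clone()`, so the operands of the conditional must share a dtype for the traced branches to produce identical tensor metadata. A minimal sketch of the pattern using `torch.cond` (names are illustrative, not the module's; recent PyTorch exposes `torch.cond` at the top level):

```python
import torch


def pick_inv_freq(cond, long_inv_freq, original_inv_freq):
    # align dtypes so both branches return tensors with the same metadata
    long_inv_freq = long_inv_freq.to(original_inv_freq.dtype)
    return torch.cond(
        cond,
        (lambda x, y: x.clone()),
        (lambda x, y: y.clone()),
        [long_inv_freq, original_inv_freq],
    )


out = pick_inv_freq(
    torch.tensor(True),
    torch.arange(4, dtype=torch.float64),
    torch.arange(4, dtype=torch.float32),
)
print(out.dtype)  # torch.float32
```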
onnx_diagnostic/torch_export_patches/patches/patch_torch.py

```diff
@@ -5,6 +5,7 @@ import os
 import traceback
 from functools import reduce
 from typing import Any, Callable, cast, Dict, List, Optional, Sequence, Tuple, Union
+import sympy
 import torch
 from torch._subclasses.fake_tensor import FakeTensorMode

@@ -1091,3 +1092,17 @@ def patched__broadcast_in_dim_meta_level_2(
         new_strides.append(a.stride()[original_idx] * a.size()[original_idx])

     return a.as_strided(shape, new_strides, a.storage_offset())
+
+
+class patched_DynamicDimConstraintPrinter:
+    """
+    Patches
+    ``torch.tx.experimental.symbolic_shapes.DynamicDimConstraintPrinter._print_Symbol``.
+    Valid for ``torch>=2.10``.
+    """
+
+    def _print_Symbol(self, expr: sympy.Symbol) -> str:
+        assert isinstance(expr, sympy.Symbol), str(type(expr))
+        if self.symbol_to_source.get(expr):
+            return self.symbol_to_source[expr][0].name
+        return str(expr)
```
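`patched_DynamicDimConstraintPrinter._print_Symbol`, installed by the `onnx_export_errors.py` hunks earlier in this diff, resolves a shape symbol to the name of the source that produced it and falls back to `str(expr)`. A self-contained sketch of that lookup, with a plain dict standing in for the printer's `symbol_to_source` attribute and strings standing in for source objects that normally expose `.name`:

```python
import sympy

# stand-in for DynamicDimConstraintPrinter.symbol_to_source
symbol_to_source = {sympy.Symbol("s0"): ["batch"]}


def print_symbol(expr: sympy.Symbol) -> str:
    assert isinstance(expr, sympy.Symbol), str(type(expr))
    if symbol_to_source.get(expr):
        return symbol_to_source[expr][0]  # the real printer returns source.name
    return str(expr)


print(print_symbol(sympy.Symbol("s0")))  # batch
print(print_symbol(sympy.Symbol("s1")))  # s1
```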
onnx_diagnostic/torch_export_patches/patches/patch_transformers.py

```diff
@@ -1,29 +1,37 @@
 # transformers
 from typing import List
 from .patch_helper import _has_transformers
-
 from ._patch_transformers_attention import (
     patched_sdpa_attention_forward,
     patched_model_bart_eager_attention_forward,
     patched_modeling_marian_eager_attention_forward,
 )
+from ._patch_transformers_generation_mixin import patched_GenerationMixin
+from ._patch_transformers_causal_mask import patched_AttentionMaskConverter
+from ._patch_transformers_rotary_embedding import (
+    patched__compute_dynamic_ntk_parameters,
+    patched_dynamic_rope_update,
+    patched_GemmaRotaryEmbedding,
+    patched_LlamaRotaryEmbedding,
+    patched_MistralRotaryEmbedding,
+    patched_MixtralRotaryEmbedding,
+    patched_PhiRotaryEmbedding,
+)
+from ._patch_transformers_idefics import patched_IdeficsEmbedding, patched_IdeficsAttention
+from ._patch_transformers_sam_mask_decoder import patched_SamMaskDecoder
+
+# transformers dependent patches

 from ._patch_transformers_cache_utils import patch_parse_processor_args

 if patch_parse_processor_args:
     from ._patch_transformers_cache_utils import patched_parse_processor_args
-
-from ._patch_transformers_causal_mask import patched_AttentionMaskConverter
-
 from ._patch_transformers_dynamic_cache import patch_DynamicLayer, patch_DynamicCache

 if patch_DynamicLayer:
     from ._patch_transformers_dynamic_cache import patched_DynamicLayer
 if patch_DynamicCache:
     from ._patch_transformers_dynamic_cache import patched_DynamicCache
-
-from ._patch_transformers_generation_mixin import patched_GenerationMixin
-
 from ._patch_transformers_masking_utils import patch_masking_utils

 if patch_masking_utils:
@@ -33,15 +41,7 @@ if patch_masking_utils:
         patched_sdpa_mask_recent_torch,
     )

-
-    patched__compute_dynamic_ntk_parameters,
-    patched_dynamic_rope_update,
-    patched_GemmaRotaryEmbedding,
-    patched_LlamaRotaryEmbedding,
-    patched_MistralRotaryEmbedding,
-    patched_MixtralRotaryEmbedding,
-    patched_PhiRotaryEmbedding,
-)
+# transformers models dependent patches

 if _has_transformers("4.51"):
     from ._patch_transformers_rotary_embedding import patched_Phi3RotaryEmbedding
@@ -54,16 +54,11 @@ if _has_transformers("4.52"):
 if _has_transformers("4.53"):
     from ._patch_transformers_rotary_embedding import patched_SmolLM3RotaryEmbedding

-# Models
-
 from ._patch_transformers_gemma3 import patch_gemma3

 if patch_gemma3:
     from ._patch_transformers_gemma3 import patched_Gemma3Model

-from ._patch_transformers_idefics import patched_IdeficsEmbedding, patched_IdeficsAttention
-
-
 from ._patch_transformers_qwen2 import patch_qwen2

 if patch_qwen2:
@@ -80,14 +75,17 @@ if patch_qwen2_5:
         patched_Qwen2_5_VLModel,
         PLUGS as PLUGS_Qwen25,
     )
-
 from ._patch_transformers_qwen3 import patch_qwen3

 if patch_qwen3:
     from ._patch_transformers_qwen3 import patched_Qwen3MoeSparseMoeBlock
+from ._patch_transformers_funnel import patch_funnel

-
-from .
+if patch_funnel:
+    from ._patch_transformers_funnel import (
+        patched_FunnelAttentionStructure,
+        patched_FunnelRelMultiheadAttention,
+    )


 def get_transformers_plugs() -> List["EagerDirectReplacementWithOnnx"]:  # noqa: F821
```
onnx_diagnostic/torch_models/hghub/hub_api.py

```diff
@@ -184,7 +184,18 @@ def _trygetattr(config, attname):
         return None


+def rewrite_architecture_name(name: Optional[str]) -> Optional[str]:
+    if name == "ConditionalDETRForObjectDetection":
+        return "ConditionalDetrForObjectDetection"
+    return name
+
+
 def architecture_from_config(config) -> Optional[str]:
+    """Guesses the architecture (class) of the model described by this config."""
+    return rewrite_architecture_name(_architecture_from_config(config))
+
+
+def _architecture_from_config(config) -> Optional[str]:
     """Guesses the architecture (class) of the model described by this config."""
     if isinstance(config, dict):
         if "_class_name" in config:
```
onnx_diagnostic/torch_models/hghub/hub_data.py

```diff
@@ -5,7 +5,10 @@ from typing import Dict, List

 __date__ = "2025-06-21"

-__data_arch_values__ = {
+__data_arch_values__ = {
+    "ConditionalDETRForObjectDetection": dict(image_size=224),
+    "ResNetForImageClassification": dict(image_size=224),
+}

 __data_arch__ = textwrap.dedent(
     """
@@ -32,6 +35,7 @@ __data_arch__ = textwrap.dedent(
     ConvNextV2Model,image-feature-extraction
     CosmosTransformer3DModel,image-to-video
     CvtModel,feature-extraction
+    ClvpModelForConditionalGeneration,audio-feature-extraction
     DPTModel,image-feature-extraction
     Data2VecAudioModel,feature-extraction
     Data2VecTextModel,feature-extraction
@@ -49,6 +53,8 @@ __data_arch__ = textwrap.dedent(
     ElectraModel,feature-extraction
     EsmModel,feature-extraction
     FalconMambaForCausalLM,text-generation
+    FunnelBaseModel,feature-extraction
+    FuyuForCausalLM,image-text-to-text
     GLPNModel,image-feature-extraction
     GPT2LMHeadModel,text-generation
     GPTBigCodeModel,feature-extraction
@@ -63,6 +69,7 @@ __data_arch__ = textwrap.dedent(
     Glm4vMoeForConditionalGeneration,image-text-to-text
     GraniteForCausalLM,text-generation
     GroupViTModel,feature-extraction
+    HeliumForCausalLM,text-generation
     HieraForImageClassification,image-classification
     HubertModel,feature-extraction
     IBertModel,feature-extraction
@@ -136,6 +143,7 @@ __data_arch__ = textwrap.dedent(
     SwinModel,image-feature-extraction
     Swinv2Model,image-feature-extraction
     T5ForConditionalGeneration,text2text-generation
+    T5GemmaForConditionalGeneration,text2text-generation
     TableTransformerModel,image-feature-extraction
     TableTransformerForObjectDetection,object-detection
     UNet2DConditionModel,text-to-image
```
onnx_diagnostic/torch_models/hghub/model_inputs.py

```diff
@@ -64,6 +64,7 @@ def get_untrained_model_with_inputs(
     use_only_preinstalled: bool = False,
     config_reduction: Optional[Callable[[Any, str], Dict]] = None,
     submodule: Optional[str] = None,
+    skip_inputs: bool = False,
 ) -> Dict[str, Any]:
     """
     Gets a non initialized model similar to the original model
@@ -93,6 +94,7 @@ def get_untrained_model_with_inputs(
         this function takes a configuration and a task (string)
         as arguments
     :param submodule: use a submodule instead of the main model
+    :param skip_inputs: do not generate the inputs
     :return: dictionary with a model, inputs, dynamic shapes, and the configuration,
         some necessary rewriting as well

@@ -332,13 +334,12 @@ def get_untrained_model_with_inputs(
             f"[get_untrained_model_with_inputs] "
             f"instantiate_specific_model(2) {cls_model}"
         )
-
     try:
         if type(config) is dict:
             model = cls_model(**config)
         else:
             model = cls_model(config)
-    except RuntimeError as e:
+    except (RuntimeError, AttributeError, ValueError) as e:
         raise RuntimeError(
             f"Unable to instantiate class {cls_model.__name__} with\n{config}"
         ) from e
@@ -350,23 +351,27 @@ def get_untrained_model_with_inputs(
     )

     # input kwargs
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+    if not skip_inputs:
+        seed = int(os.environ.get("SEED", "17")) + 1
+        torch.manual_seed(seed)
+        kwargs, fct = random_input_kwargs(config, task)  # type: ignore[arg-type]
+        if verbose:
+            print(f"[get_untrained_model_with_inputs] use fct={fct}")
+        if os.environ.get("PRINT_CONFIG") in (1, "1"):
+            print(f"-- input kwargs for task {task!r}")
+            pprint.pprint(kwargs)
+        if inputs_kwargs:
+            kwargs.update(inputs_kwargs)
+
+        # This line is important. Some models may produce different
+        # outputs even with the same inputs in training mode.
+        model.eval()  # type: ignore[union-attr]
+        res = fct(model, config, add_second_input=add_second_input, **kwargs)
+
+        res["input_kwargs"] = kwargs
+    else:
+        res = {}
+
     res["model_kwargs"] = mkwargs
     if diff_config is not None:
         res["dump_info"] = dict(config_diff=diff_config)
```