onnx-diagnostic 0.8.1__py3-none-any.whl → 0.8.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnx_diagnostic/__init__.py +1 -1
- onnx_diagnostic/_command_lines_parser.py +387 -12
- onnx_diagnostic/export/api.py +118 -5
- onnx_diagnostic/export/control_flow.py +214 -0
- onnx_diagnostic/export/control_flow_onnx.py +528 -0
- onnx_diagnostic/export/control_flow_research.py +135 -0
- onnx_diagnostic/export/onnx_plug.py +396 -0
- onnx_diagnostic/ext_test_case.py +118 -25
- onnx_diagnostic/helpers/cache_helper.py +218 -204
- onnx_diagnostic/helpers/dot_helper.py +210 -0
- onnx_diagnostic/helpers/helper.py +92 -26
- onnx_diagnostic/helpers/log_helper.py +26 -4
- onnx_diagnostic/helpers/mini_onnx_builder.py +57 -3
- onnx_diagnostic/helpers/model_builder_helper.py +27 -0
- onnx_diagnostic/helpers/onnx_helper.py +115 -16
- onnx_diagnostic/helpers/ort_session.py +37 -11
- onnx_diagnostic/helpers/rt_helper.py +547 -0
- onnx_diagnostic/helpers/torch_fx_graph_helper.py +164 -0
- onnx_diagnostic/helpers/torch_helper.py +108 -6
- onnx_diagnostic/reference/ort_evaluator.py +233 -28
- onnx_diagnostic/tasks/feature_extraction.py +15 -14
- onnx_diagnostic/tasks/image_text_to_text.py +5 -1
- onnx_diagnostic/tasks/summarization.py +72 -137
- onnx_diagnostic/torch_export_patches/eval/model_cases.py +28 -0
- onnx_diagnostic/torch_export_patches/onnx_export_errors.py +1 -1
- onnx_diagnostic/torch_export_patches/onnx_export_serialization.py +11 -7
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_attention.py +235 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_cache_utils.py +50 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_causal_mask.py +89 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_dynamic_cache.py +177 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_gemma3.py +54 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_generation_mixin.py +486 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_idefics.py +156 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_masking_utils.py +173 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2.py +99 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2_5.py +680 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen3.py +106 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_rotary_embedding.py +412 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_sam_mask_decoder.py +132 -0
- onnx_diagnostic/torch_export_patches/patches/patch_helper.py +28 -0
- onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +65 -2107
- onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +53 -0
- onnx_diagnostic/torch_models/hghub/model_inputs.py +15 -2
- onnx_diagnostic/torch_models/validate.py +50 -1
- onnx_diagnostic/torch_onnx/sbs.py +963 -312
- onnx_diagnostic/torch_onnx/sbs_dataclasses.py +491 -0
- {onnx_diagnostic-0.8.1.dist-info → onnx_diagnostic-0.8.3.dist-info}/METADATA +1 -1
- {onnx_diagnostic-0.8.1.dist-info → onnx_diagnostic-0.8.3.dist-info}/RECORD +51 -30
- {onnx_diagnostic-0.8.1.dist-info → onnx_diagnostic-0.8.3.dist-info}/WHEEL +0 -0
- {onnx_diagnostic-0.8.1.dist-info → onnx_diagnostic-0.8.3.dist-info}/licenses/LICENSE.txt +0 -0
- {onnx_diagnostic-0.8.1.dist-info → onnx_diagnostic-0.8.3.dist-info}/top_level.txt +0 -0

onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen3.py
@@ -0,0 +1,106 @@
import torch

try:
    import transformers.models.qwen3_moe

    patch_qwen3 = True
except ImportError:
    patch_qwen3 = False

if patch_qwen3:

    class patched_Qwen3MoeSparseMoeBlock(torch.nn.Module):
        _PATCHES_ = ["forward", "_forward_expert_loop"]
        _PATCHED_CLASS_ = (
            transformers.models.qwen3_moe.modeling_qwen3_moe.Qwen3MoeSparseMoeBlock
        )

        def _forward_expert_loop(
            self,
            final_hidden_states,
            expert_mask_idx,
            hidden_states,
            routing_weights,
            expert_idx: int,
        ):
            # idx, top_x = torch.where(expert_mask_idx.squeeze(0))
            idx, top_x = torch.nonzero(expert_mask_idx, as_tuple=True)
            hidden_dim = hidden_states.shape[-1]
            current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
            expert_current_state = self.experts[expert_idx](current_state)
            current_hidden_states = expert_current_state * routing_weights[top_x, idx, None]
            return final_hidden_states.index_add(
                0, top_x, current_hidden_states.to(hidden_states.dtype)
            )

        def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
            """ """
            batch_size, sequence_length, hidden_dim = hidden_states.shape
            hidden_states = hidden_states.view(-1, hidden_dim)
            # router_logits: (batch * sequence_length, n_experts)
            router_logits = self.gate(hidden_states)

            routing_weights = torch.nn.functional.softmax(
                router_logits, dim=1, dtype=torch.float
            )
            routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
            if self.norm_topk_prob:  # only diff with mixtral sparse moe block!
                routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
            # we cast back to the input dtype
            routing_weights = routing_weights.to(hidden_states.dtype)

            final_hidden_states = torch.zeros(
                (batch_size * sequence_length, hidden_dim),
                dtype=hidden_states.dtype,
                device=hidden_states.device,
            )

            # One hot encode the selected experts to create an expert mask
            # this will be used to easily index which expert is going to be solicited
            expert_mask = torch.nn.functional.one_hot(
                selected_experts, num_classes=self.num_experts
            ).permute(2, 1, 0)

            # Loop over all available experts in the model
            # and perform the computation on each expert
            expert_sum = expert_mask.sum(dim=(-1, -2))
            # expert_hit = torch.greater(expert_sum, 0).nonzero()
            # for expert_idx in expert_hit:
            for expert_idx in range(self.num_experts):
                # initial code has a squeeze but it is not possible to do that.
                # expert_mask_idx = expert_mask[expert_idx].squeeze(0)
                expert_mask_idx = expert_mask[expert_idx]
                final_hidden_states = torch.cond(
                    (expert_sum[expert_idx] > 0).item(),
                    lambda final_hidden_states, expert_mask, hidden_states, routing_weights, _i=expert_idx: self._forward_expert_loop(  # noqa: E501
                        final_hidden_states,
                        expert_mask,
                        hidden_states,
                        routing_weights,
                        expert_idx=_i,
                    ),
                    lambda final_hidden_states, *args: final_hidden_states.clone(),
                    [final_hidden_states, expert_mask_idx, hidden_states, routing_weights],
                )

                # if expert_sum[expert_idx] > 0:
                # idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0))

                # Index the correct hidden states and compute the expert hidden state for
                # the current expert. We need to make sure to multiply the output hidden
                # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
                # current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
                # current_hidden_states = (
                # expert_layer(current_state) * routing_weights[top_x, idx, None]
                # )

                # However `index_add_` only support torch tensors for indexing so we'll use
                # the `top_x` tensor here.
                # final_hidden_states.index_add_(
                # 0, top_x, current_hidden_states.to(hidden_states.dtype)
                # )

            final_hidden_states = final_hidden_states.reshape(
                batch_size, sequence_length, hidden_dim
            )
            return final_hidden_states, router_logits
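
Note on the pattern above: the patched MoE block replaces the data-dependent Python branch `if expert_sum[expert_idx] > 0:` with `torch.cond`, so both branches stay in the exported graph. Below is a minimal, self-contained sketch of that pattern; it is illustrative only (module and tensor names are made up, not code from the package) and assumes a recent PyTorch where `torch.cond` and `torch.export` are available.

```python
import torch


class TinyGatedAdd(torch.nn.Module):
    """Adds a mask to x only when the mask has at least one nonzero entry."""

    def forward(self, x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        hit = mask.sum() > 0  # data-dependent scalar; a Python `if` on it would break export

        def then_branch(x, mask):
            return x + mask.to(x.dtype)

        def else_branch(x, mask):
            # both branches must return tensors with matching shape and dtype
            return x.clone()

        return torch.cond(hit, then_branch, else_branch, [x, mask])


mod = TinyGatedAdd()
x, mask = torch.randn(4, 8), (torch.rand(4, 8) > 0.5).to(torch.int64)
print(mod(x, mask).shape)  # torch.Size([4, 8])

# Under torch.export, the same call keeps both branches in the graph:
ep = torch.export.export(mod, (x, mask))
```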

onnx_diagnostic/torch_export_patches/patches/_patch_transformers_rotary_embedding.py
@@ -0,0 +1,412 @@
from functools import wraps
from typing import Callable, Optional, Tuple
import packaging.version as pv
import torch
import transformers


def patched__compute_dynamic_ntk_parameters(
    config: Optional[transformers.PretrainedConfig] = None,
    device: Optional["torch.device"] = None,
    seq_len: Optional[int] = None,
    **rope_kwargs,
) -> Tuple["torch.Tensor", float]:
    """
    manual patch:
    ``[patch:transformers.modeling_rope_utils._compute_dynamic_ntk_parameters]``

    Computes the inverse frequencies with NTK scaling.
    Credits to the Reddit users /u/bloc97 and /u/emozilla

    Args:
        config ([`~transformers.PretrainedConfig`]):
            The model configuration.
        device (`torch.device`):
            The device to use for initialization of the inverse frequencies.
        seq_len (`int`, *optional*):
            The current sequence length,
            used to update the dynamic RoPE at inference time.
        rope_kwargs (`Dict`, *optional*):
            BC compatibility with the previous
            RoPE class instantiation, will be removed in v4.45.

    Returns:
        Tuple of (`torch.Tensor`, `float`),
        containing the inverse frequencies for the RoPE embeddings and the
        post-processing scaling factor applied to the
        computed cos/sin (unused in this type of RoPE).
    """
    if config is not None and len(rope_kwargs) > 0:
        raise ValueError(
            "Unexpected arguments: `**rope_kwargs` and `config` are mutually exclusive in "
            f"`_compute_dynamic_ntk_parameters`, got "
            f"`rope_kwargs`={rope_kwargs} and `config`={config}"
        )
    if len(rope_kwargs) > 0:
        base = rope_kwargs["base"]
        dim = rope_kwargs["dim"]
        max_position_embeddings = rope_kwargs["max_position_embeddings"]
        factor = rope_kwargs["factor"]
    elif config is not None:
        if hasattr(config, "rope_theta"):
            # transformers<5
            base = config.rope_theta
            partial_rotary_factor = (
                config.partial_rotary_factor
                if hasattr(config, "partial_rotary_factor")
                else 1.0
            )
            factor = config.rope_scaling["factor"]
        else:
            print("-----")
            print(config)
            base = config.rope_parameters["rope_theta"]
            partial_rotary_factor = config.rope_parameters["partial_rotary_factor"]
            factor = config.rope_parameters["factor"]
        head_dim = getattr(
            config, "head_dim", config.hidden_size // config.num_attention_heads
        )
        dim = int(head_dim * partial_rotary_factor)
        max_position_embeddings = config.max_position_embeddings

    attention_factor = 1.0  # Unused in this type of RoPE

    # seq_len: default to max_position_embeddings, e.g. at init time
    # seq_len = seq_len if seq_len is not None and
    # seq_len > max_position_embeddings else max_position_embeddings
    if seq_len is None:
        seq_len = max_position_embeddings
    else:
        # PATCHED: remove the line using max
        torch._check(isinstance(seq_len, torch.Tensor))
        seq_len = torch.maximum(
            seq_len,
            torch.tensor(max_position_embeddings, dtype=seq_len.dtype, device=seq_len.device),
        )

    # Compute the inverse frequencies
    base = base * ((factor * seq_len / max_position_embeddings) - (factor - 1)) ** (
        dim / (dim - 2)
    )
    inv_freq = 1.0 / (
        base
        ** (
            torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float)
            / dim
        )
    )
    return inv_freq, attention_factor


def _get_rope_init_fn(self, layer_type=None) -> Callable:
    if hasattr(self, "rope_init_fn"):
        # transformers<=5.0
        rope_init_fn = (
            patched__compute_dynamic_ntk_parameters
            if self.rope_init_fn
            is transformers.modeling_rope_utils._compute_dynamic_ntk_parameters
            else self.rope_init_fn
        )
        return rope_init_fn

    rope_type = self.rope_type if layer_type is None else self.rope_type[layer_type]
    rope_init_fn = self.compute_default_rope_parameters
    if rope_type != "default":
        rope_init_fn = transformers.modeling_rope_utils.ROPE_INIT_FUNCTIONS[self.rope_type]
    if rope_init_fn is transformers.modeling_rope_utils._compute_dynamic_ntk_parameters:
        return patched__compute_dynamic_ntk_parameters
    return rope_init_fn


def patched_dynamic_rope_update(rope_forward):
    """manual patch: ``[patch:transformers.modeling_rope_utils.dynamic_rope_update]``

    ``rope_type`` is determined in the constructor of class
    :class:`transformers.models.phi3.modeling_phi3.Phi3RotaryEmbedding`.

    .. code-block:: python

        if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
            self.rope_type = config.rope_scaling.get(
                "rope_type", config.rope_scaling.get("type"))
        else:
            self.rope_type = "default"

    The original code of the patched function:

    .. code-block:: python

        def dynamic_rope_update(rope_forward):
            def longrope_frequency_update(self, position_ids, device):
                seq_len = torch.max(position_ids) + 1
                if hasattr(self.config, "original_max_position_embeddings"):
                    original_max_position_embeddings =
                        self.config.original_max_position_embeddings
                else:
                    original_max_position_embeddings =
                        self.config.max_position_embeddings
                if seq_len > original_max_position_embeddings:
                    if not hasattr(self, "long_inv_freq"):
                        self.long_inv_freq, _ = self.rope_init_fn(
                            self.config, device, seq_len=original_max_position_embeddings + 1
                        )
                    self.register_buffer("inv_freq", self.long_inv_freq, persistent=False)
                else:
                    self.original_inv_freq = self.original_inv_freq.to(device)
                    self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)

            def dynamic_frequency_update(self, position_ids, device):
                seq_len = torch.max(position_ids) + 1
                if seq_len > self.max_seq_len_cached:  # growth
                    inv_freq, self.attention_scaling = self.rope_init_fn(
                        self.config, device, seq_len=seq_len)
                    self.register_buffer("inv_freq", inv_freq, persistent=False)
                    self.max_seq_len_cached = seq_len

                if seq_len < self.original_max_seq_len and
                    self.max_seq_len_cached > self.original_max_seq_len:
                    self.original_inv_freq = self.original_inv_freq.to(device)
                    self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
                    self.max_seq_len_cached = self.original_max_seq_len

            @wraps(rope_forward)
            def wrapper(self, x, position_ids):
                if "dynamic" in self.rope_type:
                    dynamic_frequency_update(self, position_ids, device=x.device)
                elif self.rope_type == "longrope":
                    longrope_frequency_update(self, position_ids, device=x.device)
                return rope_forward(self, x, position_ids)

            return wrapper

    """

    def longrope_frequency_update(self, position_ids, device, layer_type=None):
        # It is no use to patch the function after the model is created
        # as rope_init_fn is an attribute set to one function when the model
        # is created and when no patch is applied yet.
        # So we select the patched version here.
        rope_init_fn = _get_rope_init_fn(self, layer_type=layer_type)
        seq_len = torch.max(position_ids) + 1
        if hasattr(self.config, "original_max_position_embeddings"):
            original_max_position_embeddings = self.config.original_max_position_embeddings
        else:
            original_max_position_embeddings = self.config.max_position_embeddings

        if layer_type is None:
            # rope_type = self.rope_type
            original_inv_freq = self.original_inv_freq
            prefix = ""
        else:
            # rope_type = self.rope_type[layer_type]
            original_inv_freq = getattr(self, f"{layer_type}_original_inv_freq")
            prefix = f"{layer_type}_"

        # At export time, seq_len is unknown.
        long_inv_freq, _ = rope_init_fn(
            self.config, device, seq_len=original_max_position_embeddings + 1
        )
        original_inv_freq = self.original_inv_freq.to(device)

        # PATCHED: uses torch.cond instead of a test
        cond = (seq_len > original_max_position_embeddings).item()
        inv_freq = torch.cond(
            cond,
            (lambda x, y: x.clone()),
            (lambda x, y: y.clone()),
            [long_inv_freq, original_inv_freq],
        )
        setattr(self, f"{prefix}inv_freq", inv_freq)
        # if seq_len > original_max_position_embeddings:
        # self.inv_freq = self.long_inv_freq
        # else:
        # self.inv_freq = self.original_inv_freq

    def dynamic_frequency_update(self, position_ids, device, layer_type=None):
        # constructor:
        # - self.max_seq_len_cached = config.max_position_embeddings
        # - self.original_max_seq_len = config.max_position_embeddings
        # - inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)

        # It is no use to patch the function after the model is created
        # as rope_init_fn is an attribute set to one function when the model
        # is created and when no patch is applied yet.
        # So we select the patched version here.
        rope_init_fn = _get_rope_init_fn(self, layer_type=layer_type)

        # This behaviour is difficult to translate.
        # The sequence always grows.
        # The test should always be True.
        # So: self.max_seq_len_cached = max(self.max_seq_len_cached, seq_len) --> seq_len
        #
        # if seq_len > self.max_seq_len_cached:  # growth
        # inv_freq, self.attention_scaling = self.rope_init_fn(
        # self.config, device, seq_len=seq_len
        # )
        # self.register_buffer("inv_freq", inv_freq, persistent=False)
        # self.max_seq_len_cached = seq_len
        #
        # So we should not need what follows.
        #
        # cond = (seq_len > self.max_seq_len_cached).item()
        # self.attention_scaling = torch.cond(
        # cond,
        # (lambda x, y: x.clone()),
        # (lambda x, y: y.clone()),
        # [attention_scaling, self.attention_scaling],
        # )

        seq_len = torch.max(position_ids) + 1
        long_inv_freq, self.attention_scaling = rope_init_fn(
            self.config, device, seq_len=seq_len
        )

        if layer_type is None:
            # rope_type = self.rope_type
            # max_seq_len_cached = self.max_seq_len_cached
            original_inv_freq = self.original_inv_freq
            prefix = ""
        else:
            # rope_type = self.rope_type[layer_type]
            # max_seq_len_cached = getattr(
            # self, f"{layer_type}_max_seq_len_cached", self.max_seq_len_cached
            # )
            original_inv_freq = getattr(self, f"{layer_type}_original_inv_freq")
            prefix = f"{layer_type}_"

        # Second test to translate.
        # Let's keep in mind, self.max_seq_len_cached = seq_len is likely to be True.
        # But in that case the following condition is a way to restore the original cache.

        # if (
        # seq_len < self.original_max_seq_len
        # and self.max_seq_len_cached > self.original_max_seq_len
        # ):
        # self.original_inv_freq = self.original_inv_freq.to(device)
        # self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
        # self.max_seq_len_cached = self.original_max_seq_len

        original_inv_freq = self.original_inv_freq.to(device)
        cond = (seq_len >= self.original_max_seq_len).item()
        # PATCHED: uses torch.cond instead of a test
        inv_freq = torch.cond(
            cond,
            (lambda x, y: x.clone()),
            (lambda x, y: y.clone()),
            [long_inv_freq, original_inv_freq],
        )
        setattr(self, f"{prefix}inv_freq", inv_freq)

    @wraps(rope_forward)
    def wrapper(self, x, position_ids, layer_type=None):
        if layer_type is None:
            if "dynamic" in self.rope_type:
                dynamic_frequency_update(self, position_ids, device=x.device)
            elif self.rope_type == "longrope":
                longrope_frequency_update(self, position_ids, device=x.device)
            return rope_forward(self, x, position_ids)

        if "dynamic" in self.rope_type:
            dynamic_frequency_update(
                self, position_ids, device=x.device, layer_type=layer_type
            )
        elif self.rope_type == "longrope":
            longrope_frequency_update(
                self, position_ids, device=x.device, layer_type=layer_type
            )
        return rope_forward(self, x, position_ids, layer_type=layer_type)

    return wrapper


class common_RotaryEmbedding(torch.nn.Module):
    # This may cause some issues.
    # @torch.no_grad()
    # PATCHED: the decorator
    @patched_dynamic_rope_update
    def forward(self, x, position_ids, layer_type=None):
        if layer_type is not None:
            # transformers>=5.0
            inv_freq = getattr(self, f"{layer_type}_inv_freq")
            attention_scaling = getattr(self, f"{layer_type}_attention_scaling")
        else:
            # transformers<5.0
            inv_freq = self.inv_freq
            attention_scaling = self.attention_scaling

        inv_freq_expanded = (
            inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
        )
        position_ids_expanded = position_ids[:, None, :].float()

        device_type = (
            x.device.type
            if isinstance(x.device.type, str) and x.device.type != "mps"
            else "cpu"
        )
        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos() * attention_scaling
            sin = emb.sin() * attention_scaling

        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)


class patched_GemmaRotaryEmbedding(common_RotaryEmbedding):
    _PATCHES_ = ["forward"]
    _PATCHED_CLASS_ = transformers.models.gemma.modeling_gemma.GemmaRotaryEmbedding


if pv.Version(transformers.__version__) >= pv.Version("4.52"):

    class patched_Gemma2RotaryEmbedding(common_RotaryEmbedding):
        _PATCHES_ = ["forward"]
        _PATCHED_CLASS_ = transformers.models.gemma2.modeling_gemma2.Gemma2RotaryEmbedding

    class patched_Gemma3RotaryEmbedding(common_RotaryEmbedding):
        _PATCHES_ = ["forward"]
        _PATCHED_CLASS_ = transformers.models.gemma3.modeling_gemma3.Gemma3RotaryEmbedding


class patched_LlamaRotaryEmbedding(common_RotaryEmbedding):
    _PATCHES_ = ["forward"]
    _PATCHED_CLASS_ = transformers.models.llama.modeling_llama.LlamaRotaryEmbedding


class patched_MistralRotaryEmbedding(common_RotaryEmbedding):
    _PATCHES_ = ["forward"]
    _PATCHED_CLASS_ = transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding


class patched_MixtralRotaryEmbedding(common_RotaryEmbedding):
    _PATCHES_ = ["forward"]
    _PATCHED_CLASS_ = transformers.models.mixtral.modeling_mixtral.MixtralRotaryEmbedding


class patched_PhiRotaryEmbedding(common_RotaryEmbedding):
    _PATCHES_ = ["forward"]
    _PATCHED_CLASS_ = transformers.models.phi.modeling_phi.PhiRotaryEmbedding


if pv.Version(transformers.__version__) >= pv.Version("4.51"):

    class patched_Phi3RotaryEmbedding(common_RotaryEmbedding):
        _PATCHES_ = ["forward"]
        _PATCHED_CLASS_ = transformers.models.phi3.modeling_phi3.Phi3RotaryEmbedding


if pv.Version(transformers.__version__) >= pv.Version("4.52"):

    class patched_Phi4MultimodalRotaryEmbedding(common_RotaryEmbedding):
        _PATCHES_ = ["forward"]
        _PATCHED_CLASS_ = (
            transformers.models.phi4_multimodal.modeling_phi4_multimodal.Phi4MultimodalRotaryEmbedding
        )


if pv.Version(transformers.__version__) >= pv.Version("4.53"):

    class patched_SmolLM3RotaryEmbedding(common_RotaryEmbedding):
        _PATCHES_ = ["forward"]
        _PATCHED_CLASS_ = transformers.models.smollm3.modeling_smollm3.SmolLM3RotaryEmbedding
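
The classes above only declare what to patch: `_PATCHED_CLASS_` names the target transformers class and `_PATCHES_` lists the methods to substitute. The applier itself lives elsewhere in the package and is not shown in this diff; the following is a hypothetical sketch of how such descriptors could be consumed (the function names and the save/restore scheme are assumptions, not the library's API).

```python
def apply_patch(patch_cls):
    """Copy the methods listed in _PATCHES_ onto _PATCHED_CLASS_, keeping the originals."""
    target = patch_cls._PATCHED_CLASS_
    saved = {name: getattr(target, name) for name in patch_cls._PATCHES_}
    for name in patch_cls._PATCHES_:
        setattr(target, name, getattr(patch_cls, name))
    return saved


def revert_patch(patch_cls, saved):
    """Restore the original methods saved by apply_patch."""
    for name, method in saved.items():
        setattr(patch_cls._PATCHED_CLASS_, name, method)
```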

onnx_diagnostic/torch_export_patches/patches/_patch_transformers_sam_mask_decoder.py
@@ -0,0 +1,132 @@
from typing import Optional
import torch
import transformers


class patched_SamMaskDecoder(torch.nn.Module):
    _PATCHES_ = ["forward"]
    _PATCHED_CLASS_ = transformers.models.sam.modeling_sam.SamMaskDecoder

    def forward(
        self,
        image_embeddings: torch.Tensor,
        image_positional_embeddings: torch.Tensor,
        sparse_prompt_embeddings: torch.Tensor,
        dense_prompt_embeddings: torch.Tensor,
        multimask_output: bool,
        output_attentions: Optional[bool] = None,
        attention_similarity: Optional[torch.Tensor] = None,
        target_embedding: Optional[torch.Tensor] = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Predict masks given image and prompt embeddings.

        Args:
            image_embeddings (`torch.Tensor`):
                the embeddings from the image encoder
            image_positional_embedding (`torch.Tensor`):
                positional encoding with the shape of image_embeddings
            sparse_prompt_embeddings (`torch.Tensor`):
                The embeddings of the points and boxes
            dense_prompt_embeddings (`torch.Tensor`):
                the embeddings of the mask inputs
            multimask_output (bool):
                Whether to return multiple masks or a single mask.
            output_attentions (bool, *optional*):
                Whether or not to return the attentions tensors of all attention layers.
        """
        batch_size, num_channels, height, width = image_embeddings.shape
        point_batch_size = sparse_prompt_embeddings.shape[1]
        # Concatenate output tokens
        output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0)
        output_tokens = output_tokens.repeat(batch_size, point_batch_size, 1, 1)

        # torch.cond rewrites the if-else logic to handle empty sparse_prompt_embeddings
        # torch.any is needed to avoid data-dependent control flow
        # with sparse_prompt_embeddings.sum().item() != 0
        def sparse_prompt_embeddings_is_not_empty(output_tokens, sparse_prompt_embeddings):
            return torch.cat((output_tokens, sparse_prompt_embeddings), dim=2)

        def sparse_prompt_embeddings_is_empty(output_tokens, sparse_prompt_embeddings):
            return output_tokens.clone()

        tokens = torch.cond(
            torch.any(sparse_prompt_embeddings != 0),
            sparse_prompt_embeddings_is_not_empty,
            sparse_prompt_embeddings_is_empty,
            [output_tokens, sparse_prompt_embeddings],
        )

        point_embeddings = tokens.to(self.iou_token.weight.dtype)

        # Expand per-image data in batch direction to be per-point
        image_embeddings = image_embeddings + dense_prompt_embeddings
        image_embeddings = image_embeddings.repeat_interleave(point_batch_size, 0)
        image_positional_embeddings = image_positional_embeddings.repeat_interleave(
            point_batch_size, 0
        )

        # Run the transformer, image_positional_embedding are consumed
        torch._check(point_embeddings.shape[0] != 0)
        torch._check(point_embeddings.shape[1] != 0)
        torch._check(point_embeddings.shape[2] != 0)
        torch._check(point_embeddings.shape[3] != 0)
        embeddings_attentions = self.transformer(
            point_embeddings=point_embeddings,
            image_embeddings=image_embeddings,
            image_positional_embeddings=image_positional_embeddings,
            attention_similarity=attention_similarity,
            target_embedding=target_embedding,
            output_attentions=output_attentions,
        )
        point_embedding, image_embeddings = embeddings_attentions[:2]
        iou_token_out = torch.select(point_embedding, dim=2, index=0)
        mask_tokens_out = torch.narrow(
            point_embedding, dim=2, start=1, length=self.num_mask_tokens
        )

        # Upscale mask embeddings and predict masks using the mask tokens
        image_embeddings = image_embeddings.transpose(2, 3).reshape(
            batch_size * point_batch_size, num_channels, height, width
        )

        upscaled_embedding = self.upscale_conv1(image_embeddings)
        upscaled_embedding = self.activation(self.upscale_layer_norm(upscaled_embedding))
        upscaled_embedding = self.activation(self.upscale_conv2(upscaled_embedding))

        hyper_in_list = []
        for i in range(self.num_mask_tokens):
            current_mlp = self.output_hypernetworks_mlps[i]
            hyper_in_list += [current_mlp(mask_tokens_out[:, :, i, :])]
        hyper_in = torch.stack(hyper_in_list, dim=2)

        _, num_channels, height, width = upscaled_embedding.shape
        upscaled_embedding = upscaled_embedding.reshape(
            batch_size, point_batch_size, num_channels, height * width
        )
        masks = (hyper_in @ upscaled_embedding).reshape(
            batch_size, point_batch_size, -1, height, width
        )

        # Generate mask quality predictions
        iou_pred = self.iou_prediction_head(iou_token_out)

        # Select the correct mask or masks for output
        if multimask_output:
            mask_slice = slice(1, None)
        else:
            mask_slice = slice(0, 1)
        masks = masks[:, :, mask_slice, :, :]
        iou_pred = iou_pred[:, :, mask_slice]

        outputs = (masks, iou_pred)

        if len(embeddings_attentions) == 2:
            # transformers==4.54
            return outputs

        if output_attentions and len(embeddings_attentions) > 2:
            outputs = outputs + (embeddings_attentions[2],)  # noqa: RUF005
        else:
            outputs = outputs + (None,)  # noqa: RUF005
        return outputs
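
A small standalone illustration of the indexing rewrite used in this patch: `torch.select` and `torch.narrow` express `point_embedding[:, :, 0, :]` and `point_embedding[:, :, 1:1 + num_mask_tokens, :]` as explicit ops with a fixed length, which tends to be friendlier to the exporter than Python slicing. The shapes below are made up for the example.

```python
import torch

point_embedding = torch.randn(2, 3, 5, 16)  # (batch, point_batch, tokens, hidden)
num_mask_tokens = 4

iou_token_out = torch.select(point_embedding, dim=2, index=0)  # drops dim 2
mask_tokens_out = torch.narrow(point_embedding, dim=2, start=1, length=num_mask_tokens)

assert iou_token_out.shape == (2, 3, 16)
assert mask_tokens_out.shape == (2, 3, 4, 16)
```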

onnx_diagnostic/torch_export_patches/patches/patch_helper.py
@@ -0,0 +1,28 @@
import torch


def _has_transformers(version: str) -> bool:
    import packaging.version as pv
    import transformers

    return pv.Version(transformers.__version__) >= pv.Version(version)


def _is_torchdynamo_exporting() -> bool:
    """
    Tells if :epkg:`torch` is exporting a model.
    Relies on ``torch.compiler.is_exporting()``.
    """
    if not hasattr(torch.compiler, "is_exporting"):
        # torch.compiler.is_exporting requires torch>=2.7
        return False

    try:
        return torch.compiler.is_exporting()
    except Exception:
        try:
            import torch._dynamo as dynamo

            return dynamo.is_exporting()  # type: ignore
        except Exception:
            return False
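
Illustrative usage of the helper above (the import path follows the file list at the top of this diff; the helper is internal, so treat this as a sketch rather than a documented entry point): it lets patched code take an export-friendly path only while torch.export / dynamo is tracing and keep the plain eager behaviour otherwise.

```python
import torch
from onnx_diagnostic.torch_export_patches.patches.patch_helper import (
    _is_torchdynamo_exporting,
)


class Example(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if _is_torchdynamo_exporting():
            # export path: keep both branches in the graph with torch.cond
            return torch.cond(x.sum() > 0, lambda x: x + 1, lambda x: x - 1, [x])
        # eager path: a plain Python branch is fine
        return x + 1 if x.sum() > 0 else x - 1
```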