optimum-rbln 0.9.4a2__py3-none-any.whl → 0.10.0.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- optimum/rbln/__init__.py +44 -0
- optimum/rbln/__version__.py +2 -2
- optimum/rbln/configuration_utils.py +230 -67
- optimum/rbln/diffusers/models/controlnet.py +2 -2
- optimum/rbln/diffusers/models/transformers/prior_transformer.py +2 -2
- optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +2 -2
- optimum/rbln/diffusers/models/transformers/transformer_sd3.py +2 -2
- optimum/rbln/diffusers/pipelines/auto_pipeline.py +2 -3
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +3 -12
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +2 -4
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +1 -3
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +1 -3
- optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +2 -2
- optimum/rbln/modeling_base.py +11 -10
- optimum/rbln/ops/__init__.py +1 -0
- optimum/rbln/ops/attn.py +10 -0
- optimum/rbln/ops/flash_attn.py +8 -0
- optimum/rbln/ops/moe.py +180 -0
- optimum/rbln/ops/sliding_window_attn.py +9 -0
- optimum/rbln/transformers/__init__.py +44 -0
- optimum/rbln/transformers/modeling_attention_utils.py +124 -222
- optimum/rbln/transformers/modeling_outputs.py +25 -0
- optimum/rbln/transformers/modeling_rope_utils.py +78 -42
- optimum/rbln/transformers/models/__init__.py +38 -0
- optimum/rbln/transformers/models/auto/auto_factory.py +3 -3
- optimum/rbln/transformers/models/bart/bart_architecture.py +24 -24
- optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +7 -2
- optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +1 -1
- optimum/rbln/transformers/models/colpali/colpali_architecture.py +14 -20
- optimum/rbln/transformers/models/colpali/configuration_colpali.py +12 -17
- optimum/rbln/transformers/models/colpali/modeling_colpali.py +66 -182
- optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +40 -23
- optimum/rbln/transformers/models/colqwen2/modeling_colqwen2.py +107 -371
- optimum/rbln/transformers/models/decoderonly/__init__.py +2 -0
- optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +144 -17
- optimum/rbln/transformers/models/decoderonly/configuration_lora.py +1 -1
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +122 -48
- optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +5 -7
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +120 -128
- optimum/rbln/transformers/models/detr/__init__.py +23 -0
- optimum/rbln/transformers/models/detr/configuration_detr.py +38 -0
- optimum/rbln/transformers/models/detr/modeling_detr.py +53 -0
- optimum/rbln/transformers/models/exaone/exaone_architecture.py +0 -36
- optimum/rbln/transformers/models/gemma/gemma_architecture.py +1 -1
- optimum/rbln/transformers/models/gemma2/__init__.py +16 -0
- optimum/rbln/transformers/models/gemma2/configuration_gemma2.py +45 -0
- optimum/rbln/transformers/models/gemma2/gemma2_architecture.py +83 -0
- optimum/rbln/transformers/models/gemma2/modeling_gemma2.py +101 -0
- optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +2 -7
- optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +16 -18
- optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +5 -177
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +8 -34
- optimum/rbln/transformers/models/gpt_oss/__init__.py +16 -0
- optimum/rbln/transformers/models/gpt_oss/configuration_gpt_oss.py +42 -0
- optimum/rbln/transformers/models/gpt_oss/gpt_oss_architecture.py +122 -0
- optimum/rbln/transformers/models/gpt_oss/modeling_gpt_oss.py +168 -0
- optimum/rbln/transformers/models/grounding_dino/configuration_grounding_dino.py +8 -5
- optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +6 -4
- optimum/rbln/transformers/models/llava/modeling_llava.py +0 -1
- optimum/rbln/transformers/models/midm/midm_architecture.py +29 -22
- optimum/rbln/transformers/models/mixtral/__init__.py +16 -0
- optimum/rbln/transformers/models/mixtral/configuration_mixtral.py +38 -0
- optimum/rbln/transformers/models/mixtral/mixtral_architecture.py +76 -0
- optimum/rbln/transformers/models/mixtral/modeling_mixtral.py +68 -0
- optimum/rbln/transformers/models/opt/opt_architecture.py +1 -44
- optimum/rbln/transformers/models/paligemma/__init__.py +16 -0
- optimum/rbln/transformers/models/paligemma/configuration_paligemma.py +129 -0
- optimum/rbln/transformers/models/paligemma/modeling_paligemma.py +564 -0
- optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +24 -24
- optimum/rbln/transformers/models/phi/phi_architecture.py +13 -21
- optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +9 -5
- optimum/rbln/transformers/models/qwen2_5_vl/__init__.py +6 -1
- optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +13 -1
- optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +271 -122
- optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +43 -39
- optimum/rbln/transformers/models/qwen2_moe/__init__.py +16 -0
- optimum/rbln/transformers/models/qwen2_moe/configuration_qwen2_moe.py +38 -0
- optimum/rbln/transformers/models/qwen2_moe/modeling_qwen2_moe.py +68 -0
- optimum/rbln/transformers/models/qwen2_moe/qwen2_moe_architecture.py +94 -0
- optimum/rbln/transformers/models/qwen2_vl/__init__.py +6 -1
- optimum/rbln/transformers/models/qwen2_vl/configuration_qwen2_vl.py +13 -1
- optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py +263 -105
- optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +26 -34
- optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +7 -7
- optimum/rbln/transformers/models/qwen3_moe/__init__.py +16 -0
- optimum/rbln/transformers/models/qwen3_moe/configuration_qwen3_moe.py +38 -0
- optimum/rbln/transformers/models/qwen3_moe/modeling_qwen3_moe.py +68 -0
- optimum/rbln/transformers/models/qwen3_moe/qwen3_moe_architecture.py +100 -0
- optimum/rbln/transformers/models/resnet/configuration_resnet.py +10 -4
- optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +14 -12
- optimum/rbln/transformers/models/siglip/modeling_siglip.py +4 -18
- optimum/rbln/transformers/models/swin/configuration_swin.py +1 -6
- optimum/rbln/transformers/models/t5/t5_architecture.py +15 -16
- optimum/rbln/transformers/models/time_series_transformer/time_series_transformers_architecture.py +0 -3
- optimum/rbln/transformers/models/whisper/generation_whisper.py +8 -8
- optimum/rbln/transformers/models/whisper/whisper_architecture.py +0 -3
- optimum/rbln/transformers/utils/rbln_quantization.py +20 -12
- optimum/rbln/utils/deprecation.py +78 -1
- optimum/rbln/utils/hub.py +93 -2
- optimum/rbln/utils/import_utils.py +16 -1
- optimum/rbln/utils/runtime_utils.py +12 -8
- optimum/rbln/utils/submodule.py +24 -0
- {optimum_rbln-0.9.4a2.dist-info → optimum_rbln-0.10.0.post1.dist-info}/METADATA +6 -6
- {optimum_rbln-0.9.4a2.dist-info → optimum_rbln-0.10.0.post1.dist-info}/RECORD +107 -81
- optimum/rbln/transformers/models/colqwen2/colqwen2_architecture.py +0 -233
- {optimum_rbln-0.9.4a2.dist-info → optimum_rbln-0.10.0.post1.dist-info}/WHEEL +0 -0
- {optimum_rbln-0.9.4a2.dist-info → optimum_rbln-0.10.0.post1.dist-info}/entry_points.txt +0 -0
- {optimum_rbln-0.9.4a2.dist-info → optimum_rbln-0.10.0.post1.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/transformers/modeling_rope_utils.py

@@ -27,7 +27,7 @@
 # limitations under the License.
 
 import math
-from typing import Optional
+from typing import Optional
 
 import torch
 from transformers import PretrainedConfig
@@ -35,13 +35,16 @@ from transformers import PretrainedConfig
 
 def _compute_default_rope_parameters(
     config: Optional[PretrainedConfig] = None,
+    device: Optional["torch.device"] = None,
     seq_len: Optional[int] = None,
-) ->
+) -> tuple["torch.Tensor", float]:
     """
     Computes the inverse frequencies according to the original RoPE implementation
     Args:
         config ([`~transformers.PretrainedConfig`]):
             The model configuration.
+        device (`torch.device`):
+            The device to use for initialization of the inverse frequencies.
         seq_len (`int`, *optional*):
             The current sequence length. Unused for this type of RoPE.
     Returns:
@@ -50,40 +53,38 @@ def _compute_default_rope_parameters(
     """
     base = config.rope_theta
     partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
-    head_dim = (
-        config.head_dim
-        if hasattr(config, "head_dim") and config.head_dim is not None
-        else config.hidden_size // config.num_attention_heads
-    )
+    head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
     dim = int(head_dim * partial_rotary_factor)
 
     attention_factor = 1.0  # Unused in this type of RoPE
 
     # Compute the inverse frequencies
-    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).float
+    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim))
     return inv_freq, attention_factor
 
 
 def _compute_linear_scaling_rope_parameters(
     config: Optional[PretrainedConfig] = None,
+    device: Optional["torch.device"] = None,
     seq_len: Optional[int] = None,
-) ->
+) -> tuple["torch.Tensor", float]:
     """
     Computes the inverse frequencies with linear scaling. Credits to the Reddit user /u/kaiokendev
     Args:
         config ([`~transformers.PretrainedConfig`]):
             The model configuration.
+        device (`torch.device`):
+            The device to use for initialization of the inverse frequencies.
         seq_len (`int`, *optional*):
             The current sequence length. Unused for this type of RoPE.
     Returns:
         Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
         post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
     """
-
     factor = config.rope_scaling["factor"]
 
     # Gets the default RoPE parameters
-    inv_freq, attention_factor = _compute_default_rope_parameters(config, seq_len)
+    inv_freq, attention_factor = _compute_default_rope_parameters(config, device, seq_len)
 
     # Then applies linear scaling to the frequencies.
     # NOTE: originally, scaling was applied to the position_ids. However, we get `embs = inv_freq @ position_ids`, so
@@ -94,20 +95,23 @@ def _compute_linear_scaling_rope_parameters(
 
 def _compute_dynamic_ntk_parameters(
     config: Optional[PretrainedConfig] = None,
+    device: Optional["torch.device"] = None,
     seq_len: Optional[int] = None,
-) ->
+) -> tuple["torch.Tensor", float]:
     """
     Computes the inverse frequencies with NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla
     Args:
         config ([`~transformers.PretrainedConfig`]):
             The model configuration.
+        device (`torch.device`):
+            The device to use for initialization of the inverse frequencies.
         seq_len (`int`, *optional*):
             The current sequence length, used to update the dynamic RoPE at inference time.
     Returns:
         Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
         post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
     """
-
+    # TODO (joao): use the new `original_max_position_embeddings` from rope_scaling
     base = config.rope_theta
     partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
     head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
@@ -117,6 +121,17 @@ def _compute_dynamic_ntk_parameters(
 
     attention_factor = 1.0  # Unused in this type of RoPE
 
+    # seq_len: default to max_position_embeddings, e.g. at init time
+    if seq_len is None:
+        seq_len = max_position_embeddings
+    elif isinstance(seq_len, torch.Tensor):
+        seq_len = torch.maximum(
+            seq_len,
+            torch.tensor(max_position_embeddings, dtype=seq_len.dtype, device=seq_len.device),
+        )
+    else:
+        seq_len = max(seq_len, max_position_embeddings)
+
     # Process with chunk_size to reduce precesion error
     chunk_size = 4096
     chunks = (seq_len + chunk_size - 1) // chunk_size
@@ -140,13 +155,17 @@
     return final_inv_freq, attention_factor
 
 
-def _compute_yarn_parameters(
+def _compute_yarn_parameters(
+    config: PretrainedConfig, device: "torch.device", seq_len: Optional[int] = None
+) -> tuple["torch.Tensor", float]:
     """
     Computes the inverse frequencies with NTK scaling. Please refer to the
-    [original paper](https://
+    [original paper](https://huggingface.co/papers/2309.00071)
     Args:
         config ([`~transformers.PretrainedConfig`]):
             The model configuration.
+        device (`torch.device`):
+            The device to use for initialization of the inverse frequencies.
         seq_len (`int`, *optional*):
             The current sequence length. Unused for this type of RoPE.
     Returns:
@@ -158,13 +177,25 @@ def _compute_yarn_parameters(config: PretrainedConfig, seq_len: Optional[int] =
     partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
     head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
     dim = int(head_dim * partial_rotary_factor)
-    max_position_embeddings = config.max_position_embeddings
     factor = config.rope_scaling["factor"]
+    attention_factor = config.rope_scaling.get("attention_factor")
+    mscale = config.rope_scaling.get("mscale")
+    mscale_all_dim = config.rope_scaling.get("mscale_all_dim")
+    original_max_position_embeddings = (
+        config.rope_scaling.get("original_max_position_embeddings") or config.max_position_embeddings
+    )
+
+    def get_mscale(scale, mscale=1):
+        if scale <= 1:
+            return 1.0
+        return 0.1 * mscale * math.log(scale) + 1.0
 
     # Sets the attention factor as suggested in the paper
-    attention_factor = config.rope_scaling.get("attention_factor")
     if attention_factor is None:
-
+        if mscale and mscale_all_dim:
+            attention_factor = float(get_mscale(factor, mscale) / get_mscale(factor, mscale_all_dim))
+        else:
+            attention_factor = get_mscale(factor)
 
     # Optional config options
     # beta_fast/beta_slow: as suggested in the paper, default to 32/1 (correspondingly)
@@ -176,10 +207,13 @@ def _compute_yarn_parameters(config: PretrainedConfig, seq_len: Optional[int] =
         """Inverse dimension formula to find the dimension based on the number of rotations"""
         return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (2 * math.log(base))
 
-    def find_correction_range(low_rot, high_rot, dim, base, max_position_embeddings):
+    def find_correction_range(low_rot, high_rot, dim, base, max_position_embeddings, truncate):
         """Find dimension range bounds based on rotations"""
-        low =
-        high =
+        low = find_correction_dim(low_rot, dim, base, max_position_embeddings)
+        high = find_correction_dim(high_rot, dim, base, max_position_embeddings)
+        if truncate:
+            low = math.floor(low)
+            high = math.ceil(high)
         return max(low, 0), min(high, dim - 1)
 
     def linear_ramp_factor(min, max, dim):
@@ -192,38 +226,40 @@ def _compute_yarn_parameters(config: PretrainedConfig, seq_len: Optional[int] =
 
     # Note on variable naming: "interpolation" comes from the original technique, where we interpolate the position IDs
     # to expand the possible context length. In other words, interpolation = apply scaling factor.
-    pos_freqs = base ** (torch.arange(0, dim, 2).float
+    pos_freqs = base ** (torch.arange(0, dim, 2).to(device=device, dtype=torch.float) / dim)
     inv_freq_extrapolation = 1.0 / pos_freqs
     inv_freq_interpolation = 1.0 / (factor * pos_freqs)
 
-
+    truncate = config.rope_scaling.get("truncate", True)
+    low, high = find_correction_range(beta_fast, beta_slow, dim, base, original_max_position_embeddings, truncate)
 
     # Get n-dimensional rotational scaling corrected for extrapolation
-    inv_freq_extrapolation_factor = 1 - linear_ramp_factor(low, high, dim // 2).float
+    inv_freq_extrapolation_factor = 1 - linear_ramp_factor(low, high, dim // 2).to(device=device, dtype=torch.float)
     inv_freq = (
         inv_freq_interpolation * (1 - inv_freq_extrapolation_factor)
         + inv_freq_extrapolation * inv_freq_extrapolation_factor
     )
-
     return inv_freq, attention_factor
 
 
 def _compute_longrope_parameters(
-    config: PretrainedConfig, seq_len: Optional[int] = None
-) ->
+    config: PretrainedConfig, device: "torch.device", seq_len: Optional[int] = None
+) -> tuple["torch.Tensor", float]:
     """
     Computes the inverse frequencies with LongRoPE scaling. Please refer to the
     [original implementation](https://github.com/microsoft/LongRoPE)
     Args:
         config ([`~transformers.PretrainedConfig`]):
             The model configuration.
+        device (`torch.device`):
+            The device to use for initialization of the inverse frequencies.
         seq_len (`int`, *optional*):
-            The current sequence length.
+            The current sequence length.
     Returns:
         Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
         post-processing scaling factor applied to the computed cos/sin.
     """
-
+    # TODO (joao): use the new `original_max_position_embeddings` from rope_scaling
     base = config.rope_theta
     partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
     head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
@@ -237,40 +273,40 @@ def _compute_longrope_parameters(
     # `original_max_position_embeddings` field containing the pretrained value. They use the ratio between these two
     # values to compute the default attention scaling factor, instead of using `factor`.
     if hasattr(config, "original_max_position_embeddings"):
-
-
-        factor = expanded_max_position_embeddings / max_position_embeddings
+        original_max_position_embeddings = config.original_max_position_embeddings
+        factor = config.max_position_embeddings / config.original_max_position_embeddings
     else:
-
-        expanded_max_position_embeddings = max_position_embeddings * factor
+        original_max_position_embeddings = config.max_position_embeddings
 
     # Sets the attention factor as suggested in the paper
     if attention_factor is None:
         if factor <= 1.0:
             attention_factor = 1.0
         else:
-            attention_factor = math.sqrt(1 + math.log(factor) / math.log(
+            attention_factor = math.sqrt(1 + math.log(factor) / math.log(original_max_position_embeddings))
 
     # Compute the inverse frequencies -- scaled based on the target sequence length
-    if
-        ext_factors = torch.tensor(long_factor, dtype=torch.float32)
+    if seq_len and seq_len > original_max_position_embeddings:
+        ext_factors = torch.tensor(long_factor, dtype=torch.float32, device=device)
     else:
-        ext_factors = torch.tensor(short_factor, dtype=torch.float32)
-    inv_freq_shape = torch.arange(0, dim, 2, dtype=torch.int64).float() / dim
+        ext_factors = torch.tensor(short_factor, dtype=torch.float32, device=device)
+    inv_freq_shape = torch.arange(0, dim, 2, dtype=torch.int64, device=device).float() / dim
     inv_freq = 1.0 / (ext_factors * base**inv_freq_shape)
 
     return inv_freq, attention_factor
 
 
 def _compute_llama3_parameters(
-    config: PretrainedConfig, seq_len: Optional[int] = None
-) ->
+    config: PretrainedConfig, device: "torch.device", seq_len: Optional[int] = None
+) -> tuple["torch.Tensor", float]:
     """
     Computes the inverse frequencies for llama 3.1.
 
     Args:
         config ([`~transformers.PretrainedConfig`]):
             The model configuration.
+        device (`torch.device`):
+            The device to use for initialization of the inverse frequencies.
         seq_len (`int`, *optional*):
            The current sequence length. Unused for this type of RoPE.
     Returns:
@@ -278,7 +314,7 @@ def _compute_llama3_parameters(
         post-processing scaling factor applied to the computed cos/sin.
     """
     # Gets the default RoPE parameters
-    inv_freq, attention_factor = _compute_default_rope_parameters(config, seq_len)
+    inv_freq, attention_factor = _compute_default_rope_parameters(config, device, seq_len)
 
     factor = config.rope_scaling["factor"]  # `8` in the original implementation
     low_freq_factor = config.rope_scaling["low_freq_factor"]  # `1` in the original implementation
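For reference, the YaRN changes above boil down to an mscale-aware attention factor and a truncate switch on the correction range. A minimal standalone sketch of those helpers, lifted from the hunks above (the example values at the end are illustrative, not taken from any model config):

import math

def get_mscale(scale, mscale=1):
    # Attention scaling suggested by the YaRN paper; 1.0 when no context extension is applied.
    if scale <= 1:
        return 1.0
    return 0.1 * mscale * math.log(scale) + 1.0

def find_correction_dim(num_rotations, dim, base, max_position_embeddings):
    # Inverse dimension formula: which rotary dimension completes `num_rotations` full rotations.
    return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (2 * math.log(base))

def find_correction_range(low_rot, high_rot, dim, base, max_position_embeddings, truncate=True):
    # Dimension bounds between which interpolated and extrapolated frequencies are blended.
    low = find_correction_dim(low_rot, dim, base, max_position_embeddings)
    high = find_correction_dim(high_rot, dim, base, max_position_embeddings)
    if truncate:
        low = math.floor(low)
        high = math.ceil(high)
    return max(low, 0), min(high, dim - 1)

print(get_mscale(4.0))                                   # ~1.139 for a 4x context extension
print(find_correction_range(32, 1, 128, 10000.0, 4096))  # e.g. beta_fast=32, beta_slow=1, head_dim=128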
optimum/rbln/transformers/models/__init__.py

@@ -79,6 +79,10 @@ _import_structure = {
         "RBLNColQwen2ForRetrieval",
         "RBLNColQwen2ForRetrievalConfig",
     ],
+    "detr": [
+        "RBLNDetrForObjectDetection",
+        "RBLNDetrForObjectDetectionConfig",
+    ],
     "distilbert": [
         "RBLNDistilBertForQuestionAnswering",
         "RBLNDistilBertForQuestionAnsweringConfig",
@@ -88,12 +92,16 @@ _import_structure = {
         "RBLNQwen2_5_VisionTransformerPretrainedModelConfig",
         "RBLNQwen2_5_VLForConditionalGeneration",
         "RBLNQwen2_5_VLForConditionalGenerationConfig",
+        "RBLNQwen2_5_VLModel",
+        "RBLNQwen2_5_VLModelConfig",
     ],
     "qwen2_vl": [
         "RBLNQwen2VisionTransformerPretrainedModel",
         "RBLNQwen2VisionTransformerPretrainedModelConfig",
         "RBLNQwen2VLForConditionalGeneration",
         "RBLNQwen2VLForConditionalGenerationConfig",
+        "RBLNQwen2VLModel",
+        "RBLNQwen2VLModelConfig",
     ],
     "decoderonly": [
         "RBLNDecoderOnlyModelConfig",
@@ -110,12 +118,14 @@ _import_structure = {
     ],
     "exaone": ["RBLNExaoneForCausalLM", "RBLNExaoneForCausalLMConfig"],
     "gemma": ["RBLNGemmaForCausalLM", "RBLNGemmaForCausalLMConfig", "RBLNGemmaModel", "RBLNGemmaModelConfig"],
+    "gemma2": ["RBLNGemma2ForCausalLM", "RBLNGemma2ForCausalLMConfig", "RBLNGemma2Model", "RBLNGemma2ModelConfig"],
     "gemma3": [
         "RBLNGemma3ForCausalLM",
         "RBLNGemma3ForCausalLMConfig",
         "RBLNGemma3ForConditionalGeneration",
         "RBLNGemma3ForConditionalGenerationConfig",
     ],
+    "gpt_oss": ["RBLNGptOssForCausalLM", "RBLNGptOssForCausalLMConfig"],
     "gpt2": ["RBLNGPT2LMHeadModel", "RBLNGPT2LMHeadModelConfig", "RBLNGPT2Model", "RBLNGPT2ModelConfig"],
     "idefics3": [
         "RBLNIdefics3VisionTransformer",
@@ -132,6 +142,12 @@ _import_structure = {
         "RBLNPegasusForConditionalGenerationConfig",
         "RBLNPegasusModelConfig",
     ],
+    "paligemma": [
+        "RBLNPaliGemmaForConditionalGeneration",
+        "RBLNPaliGemmaForConditionalGenerationConfig",
+        "RBLNPaliGemmaModel",
+        "RBLNPaliGemmaModelConfig",
+    ],
     "llava_next": ["RBLNLlavaNextForConditionalGeneration", "RBLNLlavaNextForConditionalGenerationConfig"],
     "midm": ["RBLNMidmLMHeadModel", "RBLNMidmLMHeadModelConfig"],
     "pixtral": ["RBLNPixtralVisionModel", "RBLNPixtralVisionModelConfig"],
@@ -143,7 +159,9 @@ _import_structure = {
     ],
     "phi": ["RBLNPhiForCausalLM", "RBLNPhiForCausalLMConfig", "RBLNPhiModel", "RBLNPhiModelConfig"],
     "qwen2": ["RBLNQwen2ForCausalLM", "RBLNQwen2ForCausalLMConfig", "RBLNQwen2Model", "RBLNQwen2ModelConfig"],
+    "qwen2_moe": ["RBLNQwen2MoeForCausalLM", "RBLNQwen2MoeForCausalLMConfig"],
     "qwen3": ["RBLNQwen3ForCausalLM", "RBLNQwen3ForCausalLMConfig", "RBLNQwen3Model", "RBLNQwen3ModelConfig"],
+    "qwen3_moe": ["RBLNQwen3MoeForCausalLM", "RBLNQwen3MoeForCausalLMConfig"],
     "resnet": ["RBLNResNetForImageClassification", "RBLNResNetForImageClassificationConfig"],
     "roberta": [
         "RBLNRobertaForMaskedLM",
@@ -155,6 +173,10 @@ _import_structure = {
         "RBLNSiglipVisionModel",
         "RBLNSiglipVisionModelConfig",
     ],
+    "mixtral": [
+        "RBLNMixtralForCausalLM",
+        "RBLNMixtralForCausalLMConfig",
+    ],
     "swin": [
         "RBLNSwinBackbone",
         "RBLNSwinBackboneConfig",
@@ -250,10 +272,12 @@ if TYPE_CHECKING:
         RBLNLoRAConfig,
     )
     from .depth_anything import RBLNDepthAnythingForDepthEstimation, RBLNDepthAnythingForDepthEstimationConfig
+    from .detr import RBLNDetrForObjectDetection, RBLNDetrForObjectDetectionConfig
     from .distilbert import RBLNDistilBertForQuestionAnswering, RBLNDistilBertForQuestionAnsweringConfig
     from .dpt import RBLNDPTForDepthEstimation, RBLNDPTForDepthEstimationConfig
     from .exaone import RBLNExaoneForCausalLM, RBLNExaoneForCausalLMConfig
     from .gemma import RBLNGemmaForCausalLM, RBLNGemmaForCausalLMConfig, RBLNGemmaModel, RBLNGemmaModelConfig
+    from .gemma2 import RBLNGemma2ForCausalLM, RBLNGemma2ForCausalLMConfig, RBLNGemma2Model, RBLNGemma2ModelConfig
     from .gemma3 import (
         RBLNGemma3ForCausalLM,
         RBLNGemma3ForCausalLMConfig,
@@ -261,6 +285,7 @@ if TYPE_CHECKING:
         RBLNGemma3ForConditionalGenerationConfig,
     )
     from .gpt2 import RBLNGPT2LMHeadModel, RBLNGPT2LMHeadModelConfig, RBLNGPT2Model, RBLNGPT2ModelConfig
+    from .gpt_oss import RBLNGptOssForCausalLM, RBLNGptOssForCausalLMConfig
     from .grounding_dino import (
         RBLNGroundingDinoDecoder,
         RBLNGroundingDinoDecoderConfig,
@@ -280,7 +305,14 @@ if TYPE_CHECKING:
     from .llava_next import RBLNLlavaNextForConditionalGeneration, RBLNLlavaNextForConditionalGenerationConfig
     from .midm import RBLNMidmLMHeadModel, RBLNMidmLMHeadModelConfig
     from .mistral import RBLNMistralForCausalLM, RBLNMistralForCausalLMConfig, RBLNMistralModel, RBLNMistralModelConfig
+    from .mixtral import RBLNMixtralForCausalLM, RBLNMixtralForCausalLMConfig
     from .opt import RBLNOPTForCausalLM, RBLNOPTForCausalLMConfig, RBLNOPTModel, RBLNOPTModelConfig
+    from .paligemma import (
+        RBLNPaliGemmaForConditionalGeneration,
+        RBLNPaliGemmaForConditionalGenerationConfig,
+        RBLNPaliGemmaModel,
+        RBLNPaliGemmaModelConfig,
+    )
     from .pegasus import (
         RBLNPegasusForConditionalGeneration,
         RBLNPegasusForConditionalGenerationConfig,
@@ -295,14 +327,20 @@ if TYPE_CHECKING:
         RBLNQwen2_5_VisionTransformerPretrainedModelConfig,
         RBLNQwen2_5_VLForConditionalGeneration,
         RBLNQwen2_5_VLForConditionalGenerationConfig,
+        RBLNQwen2_5_VLModel,
+        RBLNQwen2_5_VLModelConfig,
     )
+    from .qwen2_moe import RBLNQwen2MoeForCausalLM, RBLNQwen2MoeForCausalLMConfig
     from .qwen2_vl import (
         RBLNQwen2VisionTransformerPretrainedModel,
         RBLNQwen2VisionTransformerPretrainedModelConfig,
         RBLNQwen2VLForConditionalGeneration,
         RBLNQwen2VLForConditionalGenerationConfig,
+        RBLNQwen2VLModel,
+        RBLNQwen2VLModelConfig,
     )
     from .qwen3 import RBLNQwen3ForCausalLM, RBLNQwen3ForCausalLMConfig, RBLNQwen3Model, RBLNQwen3ModelConfig
+    from .qwen3_moe import RBLNQwen3MoeForCausalLM, RBLNQwen3MoeForCausalLMConfig
     from .resnet import RBLNResNetForImageClassification, RBLNResNetForImageClassificationConfig
     from .roberta import (
         RBLNRobertaForMaskedLM,
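The hunks above register the new model families (DETR, Gemma-2, GPT-OSS, Mixtral, PaliGemma, Qwen2-MoE, Qwen3-MoE, plus the Qwen2-VL / Qwen2.5-VL base models) in the lazy _import_structure and under TYPE_CHECKING. Assuming the top-level package re-exports them like the existing classes (the diffstat shows optimum/rbln/__init__.py also gaining 44 lines), compiling one of the new models would presumably look like this sketch; the checkpoint id and config values are illustrative:

# Hypothetical usage sketch; assumes RBLNGemma2ForCausalLM is re-exported at the package
# top level like the existing RBLN model classes.
from optimum.rbln import RBLNGemma2ForCausalLM

model = RBLNGemma2ForCausalLM.from_pretrained(
    "google/gemma-2-2b-it",          # illustrative checkpoint id
    export=True,                     # compile the PyTorch checkpoint into an RBLN artifact
    rbln_config={"batch_size": 1},   # dict form of RBLNModelConfig, per the from_pretrained signature below
)
model.save_pretrained("gemma-2-2b-it-rbln")  # assumed to follow the usual optimum save convention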
optimum/rbln/transformers/models/auto/auto_factory.py

@@ -184,8 +184,8 @@ class _BaseAutoModelClass:
         model_id: Union[str, Path],
         export: bool = None,
         rbln_config: Optional[Union[Dict, RBLNModelConfig]] = None,
-        **kwargs,
-    ):
+        **kwargs: Optional[Dict[str, Any]],
+    ) -> RBLNBaseModel:
         """
         Load an RBLN-accelerated model from a pretrained checkpoint or a compiled RBLN artifact.
 
@@ -213,7 +213,7 @@ class _BaseAutoModelClass:
             `token`, `trust_remote_code`, `cache_dir`, `subfolder`, `local_files_only`).
 
         Returns:
-            An instantiated RBLN model ready for inference on RBLN NPUs.
+            RBLNBaseModel: An instantiated RBLN model ready for inference on RBLN NPUs.
         """
         rbln_cls = cls.get_rbln_cls(model_id, export=export, **kwargs)
         return rbln_cls.from_pretrained(model_id, export=export, rbln_config=rbln_config, **kwargs)
optimum/rbln/transformers/models/bart/bart_architecture.py

@@ -60,10 +60,10 @@ class BartForConditionalGeneration(Seq2SeqForConditionalGeneration):
 class BartDecoder(Seq2SeqDecoder):
     has_pos_emb = True
 
-    def __post_init__(self):
-        self.embed_positions =
-        self.layernorm_embedding =
-        self.embed_scale = getattr(
+    def __post_init__(self, model: nn.Module):
+        self.embed_positions = model.embed_positions
+        self.layernorm_embedding = model.layernorm_embedding
+        self.embed_scale = getattr(model, "embed_scale", None)
 
     def prepare_attn_mask(self, attention_mask, encoder_attention_mask, **kwargs):
         if attention_mask is not None:
@@ -112,11 +112,11 @@ class BartLayerFF(nn.Module):
 
 
 class BartDecoderLayer(Seq2SeqDecoderLayer):
-    def __post_init__(self):
-        self.self_attn_layer_norm =
-        self.encoder_attn =
-        self.encoder_attn_layer_norm =
-        self.ff_layer = BartLayerFF(
+    def __post_init__(self, decoder_layer: nn.Module):
+        self.self_attn_layer_norm = decoder_layer.self_attn_layer_norm
+        self.encoder_attn = decoder_layer.encoder_attn
+        self.encoder_attn_layer_norm = decoder_layer.encoder_attn_layer_norm
+        self.ff_layer = BartLayerFF(decoder_layer)
 
     def pre_self_attn_layer_norm(self, hidden_states):
         return hidden_states
@@ -132,13 +132,13 @@ class BartDecoderLayer(Seq2SeqDecoderLayer):
 
 
 class BartSelfAttention(Seq2SeqSelfAttention):
-    def __post_init__(self, use_attention_mask: bool = True):
-        self.q_proj =
-        self.k_proj =
-        self.v_proj =
-        self.out_proj =
-        self.num_heads =
-        self.head_dim =
+    def __post_init__(self, attn: nn.Module, use_attention_mask: bool = True):
+        self.q_proj = attn.q_proj
+        self.k_proj = attn.k_proj
+        self.v_proj = attn.v_proj
+        self.out_proj = attn.out_proj
+        self.num_heads = attn.num_heads
+        self.head_dim = attn.embed_dim // attn.num_heads
         self.scaling = self.head_dim**-0.5
         if use_attention_mask:
             self.attn_decode = torch.ops.rbln_custom_ops.paged_attn_decode
@@ -153,11 +153,11 @@ class BartSelfAttention(Seq2SeqSelfAttention):
 
 
 class BartCrossAttention(Seq2SeqCrossAttention):
-    def __post_init__(self):
-        self.q_proj =
-        self.k_proj =
-        self.v_proj =
-        self.out_proj =
-        self.num_heads =
-        self.head_dim =
-        self.embed_dim =
+    def __post_init__(self, attn: nn.Module):
+        self.q_proj = attn.q_proj
+        self.k_proj = attn.k_proj
+        self.v_proj = attn.v_proj
+        self.out_proj = attn.out_proj
+        self.num_heads = attn.num_heads
+        self.head_dim = attn.embed_dim // attn.num_heads
+        self.embed_dim = attn.embed_dim
optimum/rbln/transformers/models/blip_2/configuration_blip_2.py

@@ -32,8 +32,13 @@ class RBLNBlip2VisionModelConfig(RBLNModelConfig):
     def __init__(
         self,
         batch_size: Optional[int] = None,
-        **kwargs,
+        **kwargs: Any,
     ):
+        """
+        Args:
+            batch_size (Optional[int]): The batch size for inference. Defaults to 1.
+            kwargs: Additional arguments passed to the parent RBLNModelConfig.
+        """
         super().__init__(**kwargs)
         self.batch_size = batch_size or 1
         if not isinstance(self.batch_size, int) or self.batch_size < 0:
@@ -53,7 +58,7 @@ class RBLNBlip2QFormerModelConfig(RBLNModelConfig):
         batch_size: Optional[int] = None,
         num_query_tokens: Optional[int] = None,
         image_text_hidden_size: Optional[int] = None,
-        **kwargs,
+        **kwargs: Any,
     ):
         """
         Args:
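The configuration hunks above only add type hints and the missing Args: docstring; constructing the config itself is unchanged. A minimal sketch, assuming the config class is importable from the package top level like the other RBLN configs:

# Minimal sketch; batch_size defaults to 1 and is validated as an integer, per the hunk above.
from optimum.rbln import RBLNBlip2VisionModelConfig

vision_config = RBLNBlip2VisionModelConfig(batch_size=2)
print(vision_config.batch_size)  # 2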
optimum/rbln/transformers/models/blip_2/modeling_blip_2.py

@@ -468,7 +468,7 @@ class RBLNBlip2ForConditionalGeneration(RBLNModel, RBLNDecoderOnlyGenerationMixi
             input_ids (torch.LongTensor, optional): The sequence used as a prompt for the generation.
             attention_mask (torch.LongTensor, optional): Mask to avoid performing attention on padding token indices
             inputs_embeds (torch.FloatTensor, optional): Embedded representation of the inputs. Should be float, not int tokens.
-            interpolate_pos_encoding (bool, optional, defaults to False)
+            interpolate_pos_encoding (bool, optional, defaults to False): Whether to interpolate the positional encoding of the image embeddings.
         Returns:
             A list of strings of length batch_size * num_captions.
         """
optimum/rbln/transformers/models/colpali/colpali_architecture.py

@@ -77,11 +77,11 @@ class ColPaliModel(nn.Module):
         self, model, layers: List["ColPaliLayer"], output_hidden_states: bool = False, max_seq_len: int = 2048
     ):
         super().__init__()
-        self._original_mod = model
         self.layers = nn.ModuleList(layers)
         self.output_hidden_states = output_hidden_states
-        self.
-        self.
+        self.config = model.config
+        self.norm = model.norm
+        self.hidden_size = self.config.hidden_size
         self.max_seq_len = max_seq_len
 
     def forward(
@@ -118,7 +118,6 @@ class ColPaliModel(nn.Module):
 class ColPaliLayer(nn.Module):
     def __init__(self, layer, self_attn: "ColPaliAttention"):
         super().__init__()
-        self._original_mod = layer
         self.self_attn = self_attn
         self.mlp = layer.mlp
         self.input_layernorm = layer.input_layernorm
@@ -155,27 +154,22 @@ class ColPaliLayer(nn.Module):
 class ColPaliAttention(nn.Module):
     def __init__(self, self_attn):
         super().__init__()
-        self.
-        self.num_heads = (
-
-        )
-        self.head_dim = self._original_mod.head_dim
+        self.config = self_attn.config
+        self.num_heads = getattr(self_attn, "num_heads", None) or self_attn.config.num_attention_heads
+        self.head_dim = self_attn.head_dim
         self.scaling = self.head_dim**-0.5
 
-        if hasattr(
-            self.num_key_value_heads =
-        elif hasattr(
-            self.num_key_value_heads =
+        if hasattr(self_attn, "num_key_value_heads"):
+            self.num_key_value_heads = self_attn.num_key_value_heads
+        elif hasattr(self_attn, "config") and hasattr(self_attn.config, "num_key_value_heads"):
+            self.num_key_value_heads = self_attn.config.num_key_value_heads
         else:
             self.num_key_value_heads = self.num_heads
 
-        self.
-
-
-        self.
-        self.k_proj = self._original_mod.k_proj
-        self.v_proj = self._original_mod.v_proj
-        self.o_proj = self._original_mod.o_proj
+        self.q_proj = self_attn.q_proj
+        self.k_proj = self_attn.k_proj
+        self.v_proj = self_attn.v_proj
+        self.o_proj = self_attn.o_proj
 
     def projection(self, hidden_states) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         query_states = self.q_proj(hidden_states)
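The BART and ColPali hunks above follow the same refactor: the wrapper modules stop keeping a _original_mod back-reference and instead receive the wrapped Hugging Face module explicitly, copying out only the submodules and attributes they need. A minimal sketch of that pattern with hypothetical names (not the library's actual classes):

import torch.nn as nn

class DummyAttention(nn.Module):
    # Stand-in for a Hugging Face attention block, just to make the sketch runnable.
    def __init__(self, embed_dim=64, num_heads=4):
        super().__init__()
        self.embed_dim, self.num_heads = embed_dim, num_heads
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.o_proj = nn.Linear(embed_dim, embed_dim)

class AttentionWrapper(nn.Module):
    # The wrapped module is passed in explicitly and only the needed pieces are copied,
    # rather than holding the whole module as `self._original_mod`.
    def __init__(self, attn: nn.Module):
        super().__init__()
        self.q_proj, self.k_proj = attn.q_proj, attn.k_proj
        self.v_proj, self.o_proj = attn.v_proj, attn.o_proj
        self.num_heads = attn.num_heads
        self.head_dim = attn.embed_dim // attn.num_heads

wrapper = AttentionWrapper(DummyAttention())
print(wrapper.head_dim)  # 64 // 4 = 16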