optimum-rbln 0.9.3rc0__py3-none-any.whl → 0.9.5a4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- optimum/rbln/__init__.py +48 -0
- optimum/rbln/__version__.py +2 -2
- optimum/rbln/configuration_utils.py +50 -21
- optimum/rbln/diffusers/__init__.py +12 -0
- optimum/rbln/diffusers/configurations/__init__.py +3 -0
- optimum/rbln/diffusers/configurations/models/__init__.py +2 -0
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_temporal_decoder.py +67 -0
- optimum/rbln/diffusers/configurations/models/configuration_unet_spatio_temporal_condition.py +59 -0
- optimum/rbln/diffusers/configurations/pipelines/__init__.py +3 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_video_diffusion.py +114 -0
- optimum/rbln/diffusers/modeling_diffusers.py +1 -1
- optimum/rbln/diffusers/models/__init__.py +17 -3
- optimum/rbln/diffusers/models/autoencoders/__init__.py +1 -0
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py +3 -3
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +275 -0
- optimum/rbln/diffusers/models/autoencoders/vae.py +27 -8
- optimum/rbln/diffusers/models/controlnet.py +17 -2
- optimum/rbln/diffusers/models/transformers/prior_transformer.py +16 -2
- optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +16 -1
- optimum/rbln/diffusers/models/transformers/transformer_sd3.py +14 -1
- optimum/rbln/diffusers/models/unets/__init__.py +1 -0
- optimum/rbln/diffusers/models/unets/unet_2d_condition.py +18 -2
- optimum/rbln/diffusers/models/unets/unet_spatio_temporal_condition.py +201 -0
- optimum/rbln/diffusers/pipelines/__init__.py +4 -0
- optimum/rbln/diffusers/pipelines/auto_pipeline.py +2 -2
- optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +20 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +13 -4
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +13 -4
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +13 -4
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +13 -4
- optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +1 -1
- optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +1 -1
- optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +1 -2
- optimum/rbln/diffusers/pipelines/stable_video_diffusion/__init__.py +15 -0
- optimum/rbln/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +46 -0
- optimum/rbln/modeling.py +20 -45
- optimum/rbln/modeling_base.py +18 -14
- optimum/rbln/ops/__init__.py +1 -0
- optimum/rbln/ops/attn.py +10 -0
- optimum/rbln/ops/flash_attn.py +8 -0
- optimum/rbln/ops/moe.py +180 -0
- optimum/rbln/ops/sliding_window_attn.py +9 -0
- optimum/rbln/transformers/__init__.py +36 -0
- optimum/rbln/transformers/configuration_generic.py +0 -27
- optimum/rbln/transformers/modeling_attention_utils.py +156 -127
- optimum/rbln/transformers/modeling_generic.py +2 -61
- optimum/rbln/transformers/modeling_outputs.py +26 -0
- optimum/rbln/transformers/modeling_rope_utils.py +78 -42
- optimum/rbln/transformers/models/__init__.py +28 -0
- optimum/rbln/transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py +28 -2
- optimum/rbln/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +68 -5
- optimum/rbln/transformers/models/auto/auto_factory.py +1 -0
- optimum/rbln/transformers/models/bart/bart_architecture.py +24 -24
- optimum/rbln/transformers/models/bart/modeling_bart.py +23 -2
- optimum/rbln/transformers/models/bert/modeling_bert.py +86 -1
- optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +42 -15
- optimum/rbln/transformers/models/clip/modeling_clip.py +40 -2
- optimum/rbln/transformers/models/colpali/colpali_architecture.py +14 -20
- optimum/rbln/transformers/models/colpali/configuration_colpali.py +12 -17
- optimum/rbln/transformers/models/colpali/modeling_colpali.py +66 -221
- optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +38 -23
- optimum/rbln/transformers/models/colqwen2/modeling_colqwen2.py +107 -371
- optimum/rbln/transformers/models/decoderonly/__init__.py +2 -0
- optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +128 -17
- optimum/rbln/transformers/models/decoderonly/configuration_lora.py +2 -2
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +211 -89
- optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +205 -64
- optimum/rbln/transformers/models/decoderonly/generation_decoderonly.py +17 -9
- optimum/rbln/transformers/models/decoderonly/lora_architecture.py +1 -1
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +194 -132
- optimum/rbln/transformers/models/depth_anything/modeling_depth_anything.py +17 -0
- optimum/rbln/transformers/models/distilbert/modeling_distilbert.py +24 -0
- optimum/rbln/transformers/models/dpt/modeling_dpt.py +17 -0
- optimum/rbln/transformers/models/exaone/exaone_architecture.py +0 -36
- optimum/rbln/transformers/models/gemma/gemma_architecture.py +1 -1
- optimum/rbln/transformers/models/gemma2/__init__.py +16 -0
- optimum/rbln/transformers/models/gemma2/configuration_gemma2.py +45 -0
- optimum/rbln/transformers/models/gemma2/gemma2_architecture.py +83 -0
- optimum/rbln/transformers/models/gemma2/modeling_gemma2.py +101 -0
- optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +23 -19
- optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py +42 -70
- optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +46 -31
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +8 -34
- optimum/rbln/transformers/models/gpt_oss/__init__.py +16 -0
- optimum/rbln/transformers/models/gpt_oss/configuration_gpt_oss.py +41 -0
- optimum/rbln/transformers/models/gpt_oss/gpt_oss_architecture.py +122 -0
- optimum/rbln/transformers/models/gpt_oss/modeling_gpt_oss.py +165 -0
- optimum/rbln/transformers/models/grounding_dino/configuration_grounding_dino.py +8 -5
- optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +7 -5
- optimum/rbln/transformers/models/grounding_dino/modeling_grounding_dino.py +24 -9
- optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +3 -5
- optimum/rbln/transformers/models/llava/modeling_llava.py +37 -26
- optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +3 -5
- optimum/rbln/transformers/models/midm/midm_architecture.py +29 -22
- optimum/rbln/transformers/models/mistral/modeling_mistral.py +0 -22
- optimum/rbln/transformers/models/opt/modeling_opt.py +2 -2
- optimum/rbln/transformers/models/opt/opt_architecture.py +1 -44
- optimum/rbln/transformers/models/paligemma/__init__.py +16 -0
- optimum/rbln/transformers/models/paligemma/configuration_paligemma.py +129 -0
- optimum/rbln/transformers/models/paligemma/modeling_paligemma.py +564 -0
- optimum/rbln/transformers/models/pegasus/modeling_pegasus.py +1 -1
- optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +24 -24
- optimum/rbln/transformers/models/phi/phi_architecture.py +13 -21
- optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +13 -1
- optimum/rbln/transformers/models/pixtral/pixtral_architecture.py +2 -2
- optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +0 -28
- optimum/rbln/transformers/models/qwen2_5_vl/__init__.py +6 -1
- optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +11 -1
- optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +278 -130
- optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +43 -39
- optimum/rbln/transformers/models/qwen2_moe/__init__.py +16 -0
- optimum/rbln/transformers/models/qwen2_moe/configuration_qwen2_moe.py +38 -0
- optimum/rbln/transformers/models/qwen2_moe/modeling_qwen2_moe.py +68 -0
- optimum/rbln/transformers/models/qwen2_moe/qwen2_moe_architecture.py +94 -0
- optimum/rbln/transformers/models/qwen2_vl/__init__.py +6 -1
- optimum/rbln/transformers/models/qwen2_vl/configuration_qwen2_vl.py +11 -1
- optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py +268 -111
- optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +27 -35
- optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +0 -20
- optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +7 -7
- optimum/rbln/transformers/models/qwen3_moe/__init__.py +16 -0
- optimum/rbln/transformers/models/qwen3_moe/configuration_qwen3_moe.py +38 -0
- optimum/rbln/transformers/models/qwen3_moe/modeling_qwen3_moe.py +68 -0
- optimum/rbln/transformers/models/qwen3_moe/qwen3_moe_architecture.py +100 -0
- optimum/rbln/transformers/models/resnet/configuration_resnet.py +17 -0
- optimum/rbln/transformers/models/resnet/modeling_resnet.py +73 -0
- optimum/rbln/transformers/models/roberta/modeling_roberta.py +33 -0
- optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py +2 -4
- optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +36 -12
- optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +14 -12
- optimum/rbln/transformers/models/siglip/modeling_siglip.py +21 -19
- optimum/rbln/transformers/models/swin/configuration_swin.py +1 -6
- optimum/rbln/transformers/models/swin/modeling_swin.py +17 -4
- optimum/rbln/transformers/models/t5/modeling_t5.py +2 -2
- optimum/rbln/transformers/models/t5/t5_architecture.py +16 -17
- optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +25 -10
- optimum/rbln/transformers/models/time_series_transformer/time_series_transformers_architecture.py +0 -3
- optimum/rbln/transformers/models/vit/modeling_vit.py +19 -0
- optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec2.py +15 -3
- optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +60 -8
- optimum/rbln/transformers/models/whisper/generation_whisper.py +48 -14
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +2 -2
- optimum/rbln/transformers/models/whisper/whisper_architecture.py +0 -3
- optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +53 -0
- optimum/rbln/transformers/utils/rbln_quantization.py +29 -12
- optimum/rbln/utils/deprecation.py +213 -0
- optimum/rbln/utils/hub.py +14 -3
- optimum/rbln/utils/import_utils.py +23 -2
- optimum/rbln/utils/runtime_utils.py +42 -6
- optimum/rbln/utils/submodule.py +27 -1
- {optimum_rbln-0.9.3rc0.dist-info → optimum_rbln-0.9.5a4.dist-info}/METADATA +6 -6
- {optimum_rbln-0.9.3rc0.dist-info → optimum_rbln-0.9.5a4.dist-info}/RECORD +155 -129
- {optimum_rbln-0.9.3rc0.dist-info → optimum_rbln-0.9.5a4.dist-info}/WHEEL +1 -1
- optimum/rbln/transformers/models/colqwen2/colqwen2_architecture.py +0 -233
- optimum/rbln/utils/depreacate_utils.py +0 -16
- {optimum_rbln-0.9.3rc0.dist-info → optimum_rbln-0.9.5a4.dist-info}/entry_points.txt +0 -0
- {optimum_rbln-0.9.3rc0.dist-info → optimum_rbln-0.9.5a4.dist-info}/licenses/LICENSE +0 -0
Detailed diff for optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py:

```diff
--- a/optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py
+++ b/optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py
@@ -21,7 +21,6 @@ from transformers import PretrainedConfig, PreTrainedModel
 
 from ....utils import logging
 from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
-from ...utils.rbln_quantization import RBLNQuantizationConfig
 from .configuration_lora import RBLNLoRAConfig
 from .lora_architecture import LoRALinear
 
@@ -76,8 +75,8 @@ class DecoderOnlyWrapper(nn.Module):
                 f" or equal to max_seq_len({rbln_config.max_seq_len})!"
             )
 
-        self.model = self.convert_to_rbln_class(model, rbln_config.max_seq_len)
-        self.num_hidden_layers = getattr(self.config, "num_hidden_layers", None) or
+        self.model = self.convert_to_rbln_class(model, rbln_config.max_seq_len, use_rotary_emb)
+        self.num_hidden_layers = getattr(self.config, "num_hidden_layers", None) or self.config.n_layer
         self._phase = "prefill"
 
     def get_rotary_emb(self, max_seq_len):
@@ -104,7 +103,7 @@ class DecoderOnlyWrapper(nn.Module):
     def get_rbln_causal_lm_class(self):
         return DecoderOnlyForCausalLM
 
-    def convert_to_rbln_class(self, model: PreTrainedModel, max_seq_len: int):
+    def convert_to_rbln_class(self, model: PreTrainedModel, max_seq_len: int, use_rotary_emb: bool):
         new_layers = []
         for layer_idx, layer in enumerate(self.get_decoder_layers(model)):
             is_sliding = layer_idx in self.rbln_config.sliding_window_layers
@@ -119,6 +118,7 @@ class DecoderOnlyWrapper(nn.Module):
             new_layers,
             self.rbln_config,
             use_learned_pos_emb=self.__class__._use_learned_pos_emb,
+            use_rotary_emb=use_rotary_emb,
         )
 
         if self.is_causal_lm:
@@ -145,8 +145,11 @@ class DecoderOnlyWrapper(nn.Module):
         local_block_tables = args.pop(0) if self.rbln_config.use_local_attention else None
         query_position = (
             args.pop(0)
-            # query_position usage:
-            if (
+            # query_position usage: prefill & (logits_to_keep == 1 or use_local_attention)
+            if (
+                "prefill" in self.phase
+                and (self.rbln_config.logits_to_keep == 1 or self.rbln_config.use_local_attention)
+            )
             else None
         )
         attention_mask = args.pop(0) if self.rbln_config.use_attention_mask else None
@@ -203,7 +206,7 @@ class DecoderOnlyWrapper(nn.Module):
             rotary_emb,
         ) = self.prepare_forward_args(*args)
 
-
+        logits, all_hidden_states = self.model(
             input_ids=input_ids,
             inputs_embeds=inputs_embeds,
             attention_mask=attention_mask,
@@ -215,9 +218,13 @@ class DecoderOnlyWrapper(nn.Module):
             global_block_tables=global_block_tables,
             local_block_tables=local_block_tables,
             lora_int_id=lora_int_id,
+            output_hidden_states=self.rbln_config.output_hidden_states,
         )
 
-
+        if self.rbln_config.output_hidden_states:
+            return logits, all_hidden_states
+        else:
+            return logits
 
 
 class DecoderOnlyForCausalLM(nn.Module):
@@ -237,7 +244,6 @@ class DecoderOnlyForCausalLM(nn.Module):
 
     Attributes:
         config: Configuration from the original causal language model
-        _original_mod: Reference to the original model for components like lm_head
         model: RBLN-optimized decoder model instance
        _phase: Current processing phase ("prefill" or "decode")
     """
@@ -245,10 +251,9 @@ class DecoderOnlyForCausalLM(nn.Module):
     def __init__(self, causal_lm: PreTrainedModel, model: nn.Module):
         super().__init__()
         self.config = causal_lm.config
-        self._original_mod = causal_lm
         self.model = model
         self._phase = "prefill"
-        self.lm_head =
+        self.lm_head = causal_lm.lm_head
 
     @property
     def phase(self):
@@ -272,9 +277,10 @@ class DecoderOnlyForCausalLM(nn.Module):
         global_block_tables: Optional[torch.Tensor] = None,
         local_block_tables: Optional[torch.Tensor] = None,
         lora_int_id: Optional[torch.Tensor] = None,
+        output_hidden_states: Optional[bool] = None,
     ):
         # outputs
-        hidden_states = self.model(
+        hidden_states, all_hidden_states = self.model(
             input_ids=input_ids,
             inputs_embeds=inputs_embeds,
             attention_mask=attention_mask,
@@ -286,9 +292,10 @@ class DecoderOnlyForCausalLM(nn.Module):
             global_block_tables=global_block_tables,
             local_block_tables=local_block_tables,
             lora_int_id=lora_int_id,
+            output_hidden_states=output_hidden_states,
         )
 
-        if "prefill" in self.phase:
+        if "prefill" in self.phase and query_position is not None:
             hidden_states = hidden_states[:, query_position.to(torch.int).unsqueeze(0)]
 
         logits = self.lm_head(hidden_states)
@@ -299,7 +306,7 @@ class DecoderOnlyForCausalLM(nn.Module):
             logits = torch.tanh(logits)
             logits = logits * self.config.final_logit_softcapping
 
-        return logits
+        return logits, all_hidden_states
 
 
 class DecoderOnlyModel(nn.Module):
@@ -312,20 +319,35 @@ class DecoderOnlyModel(nn.Module):
         use_learned_pos_emb: Whether to use learned position embeddings (class-specific override)
 
     Attributes:
-        _original_mod: Reference to original Huggingface model
         layers: ModuleList of RBLN-optimized transformer layers
         _phase: Current processing phase ("prefill" or "decode")
     """
 
+    _EMBEDDING_ATTRS = ["embed_tokens", "wte"]
+    _POSITION_ATTRS = ["embed_positions", "wpe"]
+    _LAYERNORM_ATTRS = ["norm", "final_layer_norm", "final_layernorm", "ln_f", "layer_norm"]
+    _PRE_FF_LAYERNORM_ATTRS = None
+    _POST_FF_LAYERNORM_ATTRS = None
+
     def __init__(
         self,
         model,
         layers: List["DecoderOnlyLayer"],
         rbln_config: "RBLNDecoderOnlyModelConfig",
         use_learned_pos_emb=None,
+        use_rotary_emb=True,
     ):
         super().__init__()
-        self.
+        self.config = model.config
+        # Keep commonly-used original submodules registered on this wrapper so their weights
+        # are preserved in state_dict even if the original model object is not kept.
+        # Different HF model families use different attribute names; we register what we can
+        # and allow subclasses to override getters when needed.
+        self.embed_tokens = _get_attr_from_candidates(model, self._EMBEDDING_ATTRS)
+        # hasattr(model, "rotary_emb") is workaround for Qwen2VL
+        if not (use_rotary_emb or hasattr(model, "rotary_emb")):
+            self.embed_positions = _get_attr_from_candidates(model, self._POSITION_ATTRS)
+        self.norm = _get_attr_from_candidates(model, self._LAYERNORM_ATTRS)
         self.layers = nn.ModuleList(layers)
         self.rbln_config = rbln_config
         self._phase = "prefill"
@@ -364,26 +386,28 @@ class DecoderOnlyModel(nn.Module):
         cache_pos_for_partitions = torch.clamp(cs - pidx * partition_len, 0, partition_len)
         return cache_pos_for_partitions
 
-    def
-        max_cache_len = self.
+    def get_swa_custom_op_args(self, position_ids, query_position):
+        max_cache_len = self.config.sliding_window
         valid_input_len = 1 if query_position is None else query_position + 1
-        cache_seq_len = torch.clamp(position_ids, max=max_cache_len)[:, :1]  # past seen tokens
+        cache_seq_len = torch.clamp(position_ids.to(torch.int32), max=max_cache_len)[:, :1]  # past seen tokens
         cache_offset = (
             torch.clamp(position_ids, max=max_cache_len)[:, :1] + valid_input_len
         )  # cache offset for next steps
 
-
+        # Causal mask for sliding window attention
+        attn_mask = torch.arange(max_cache_len)[None, :] - cache_seq_len
+        attn_mask = torch.where(attn_mask > 0, 0.0, 1.0)[:, None, None, :]
+
+        return cache_seq_len, cache_offset, attn_mask
 
     def get_last_layernorm(self) -> nn.LayerNorm:
-        return self.
+        return self.norm
 
     def get_embedding(self) -> nn.Embedding:
-        return self.
+        return self.embed_tokens
 
     def get_pos_embedding(self) -> nn.Embedding:
-
-            "The 'get_pos_embedding' method is not implemented. Please define this method in a subclass."
-        )
+        return self.embed_positions
 
     def forward(
         self,
@@ -398,6 +422,7 @@ class DecoderOnlyModel(nn.Module):
         global_block_tables: Optional[torch.Tensor] = None,
         local_block_tables: Optional[torch.Tensor] = None,
         lora_int_id: Optional[torch.Tensor] = None,
+        output_hidden_states: Optional[bool] = None,
     ):
         # retrieve input_ids and inputs_embeds
         if (input_ids is None) ^ (inputs_embeds is not None):
@@ -458,13 +483,19 @@ class DecoderOnlyModel(nn.Module):
 
         # Get local cache positions for sliding window layers
         if len(self.sliding_window_layers) > 0:
-
+            cache_seq_len, cache_offset, swa_attn_mask = self.get_swa_custom_op_args(position_ids, query_position)
+            sliding_cache_pos = (cache_seq_len, cache_offset)
 
+        all_hidden_states = () if output_hidden_states else None
         for layer_idx, layer in enumerate(self.layers):
+            if output_hidden_states:
+                all_hidden_states += (hidden_states,)
+
             is_sliding = True if layer_idx in self.sliding_window_layers else False
+            is_sliding_decode = is_sliding and self.phase == "decode"
             hidden_states = layer(
                 hidden_states=hidden_states,
-                attention_mask=attention_mask,
+                attention_mask=swa_attn_mask if is_sliding_decode else attention_mask,
                 seq_positions=sliding_cache_pos if is_sliding else seq_positions,
                 past_key_values=past_key_values,
                 cos=cos,
@@ -474,7 +505,10 @@ class DecoderOnlyModel(nn.Module):
             )
 
         hidden_states = self.get_last_layernorm()(hidden_states)
-
+        if output_hidden_states:
+            all_hidden_states += (hidden_states,)
+
+        return hidden_states, all_hidden_states
 
 
 class DecoderOnlyLayer(nn.Module):
@@ -497,14 +531,23 @@ class DecoderOnlyLayer(nn.Module):
         self_attn (DecoderOnlyAttention): Modified attention module optimized for RBLN
 
     Attributes:
-        _original_mod: Reference to original layer for accessing components
         self_attn: Modified attention mechanism mapped to RBLN ops at compile time
         phase: Current operation phase ("prefill" or "decode")
     """
 
+    _PRE_ATTN_LAYERNORM = ["input_layernorm", "ln_1", "self_attn_layer_norm", "pre_feedforward_layernorm"]
+    _POST_ATTN_LAYERNORM = ["post_attention_layernorm", "ln_2", "final_layer_norm", "post_feedforward_layernorm"]
+    _PRE_FF_LAYERNORM_ATTRS = None
+    _POST_FF_LAYERNORM_ATTRS = None
+
     def __init__(self, layer, self_attn: "DecoderOnlyAttention", lora_config: Optional[RBLNLoRAConfig] = None):
         super().__init__()
-
+
+        self.pre_attention_layernorm = _get_attr_from_candidates(layer, self._PRE_ATTN_LAYERNORM)
+        self.post_attention_layernorm = _get_attr_from_candidates(layer, self._POST_ATTN_LAYERNORM)
+        self.pre_feedforward_layernorm = _get_attr_from_candidates(layer, self._PRE_FF_LAYERNORM_ATTRS)
+        self.post_feedforward_layernorm = _get_attr_from_candidates(layer, self._POST_FF_LAYERNORM_ATTRS)
+        self.mlp = layer.mlp
         self.self_attn = self_attn
         self._phase = "prefill"
         self.lora_config = lora_config
@@ -534,13 +577,19 @@ class DecoderOnlyLayer(nn.Module):
         self.self_attn.phase = phase
 
     def get_pre_attention_layernorm(self) -> nn.LayerNorm:
-        return self.
+        return self.pre_attention_layernorm
 
     def get_post_attention_layernorm(self) -> nn.LayerNorm:
-        return self.
+        return self.post_attention_layernorm
+
+    def get_pre_feedforward_layernorm(self) -> nn.LayerNorm:
+        return self.pre_feedforward_layernorm
+
+    def get_post_feedforward_layernorm(self) -> nn.LayerNorm:
+        return self.post_feedforward_layernorm
 
     def get_mlp(self) -> nn.Module:
-        return self.
+        return self.mlp
 
     def forward_mlp(self, hidden_states: torch.Tensor, lora_int_id: Optional[torch.Tensor] = None) -> torch.Tensor:
         mlp = self.get_mlp()
@@ -606,6 +655,8 @@ class DecoderOnlyAttention(nn.Module):
         is_sliding: Whether this is sliding window attention
     """
 
+    _O_PROJ_ATTRS = ["o_proj", "out_proj", "dense"]
+
     def __init__(
         self,
         self_attn,
@@ -613,39 +664,37 @@ class DecoderOnlyAttention(nn.Module):
         is_sliding=False,
     ):
         super().__init__()
-        self.
+        self.config = getattr(self_attn, "config", None)
         self.rbln_config = rbln_config
         self.layer_idx = self_attn.layer_idx
-        self.num_heads = getattr(
-
-        )
-        self.head_dim = self._original_mod.head_dim
+        self.num_heads = getattr(self_attn, "num_heads", None) or self_attn.config.num_attention_heads
+        self.head_dim = self_attn.head_dim
         self._phase = "prefill"
-        self.scale = torch.nn.Parameter(torch.tensor(self.get_attn_scale()))
-        self.quantization = rbln_config.quantization
+        self.scale = torch.nn.Parameter(torch.tensor(self.get_attn_scale(self_attn)))
 
-        if hasattr(
-            self.num_key_value_heads =
-        elif hasattr(
-            self.num_key_value_heads =
+        if hasattr(self_attn, "num_key_value_heads"):
+            self.num_key_value_heads = self_attn.num_key_value_heads
+        elif hasattr(self_attn, "config") and hasattr(self_attn.config, "num_key_value_heads"):
+            self.num_key_value_heads = self_attn.config.num_key_value_heads
         else:
             self.num_key_value_heads = self.num_heads
 
-        self.use_attention_mask = rbln_config.use_attention_mask if not is_sliding else True
-        self.use_position_ids = rbln_config.use_position_ids
         self.is_sliding = is_sliding
         self.attn_impl = rbln_config.attn_impl if not is_sliding else "eager"
         self.kvcache_partition_len = getattr(rbln_config, "kvcache_partition_len", None)
         self.kvcache_block_size = rbln_config.sliding_window if is_sliding else rbln_config.kvcache_block_size
         self.lora_config = rbln_config.lora_config
 
+        if hasattr(self_attn, "sinks"):
+            self.sinks = self_attn.sinks.data[:, None]
+
         setattr(self, self.get_attention_name(), self.create_attention_op())
-        self.__post_init__()
+        self.__post_init__(self_attn)
 
     def _init_lora_weights(self):
         """Initialize LoRA adapter weights by replacing linear layers with LoRALinear."""
         for proj_name in ["q_proj", "k_proj", "v_proj", "o_proj"]:
-            original_linear = getattr(self
+            original_linear = getattr(self, proj_name)
             lora_linear = LoRALinear(
                 original_linear=original_linear,
                 lora_config=self.lora_config,
@@ -680,8 +729,7 @@ class DecoderOnlyAttention(nn.Module):
                 self.num_heads,
                 self.head_dim,
                 self.num_key_value_heads,
-                self.
-                self.use_position_ids,
+                rbln_config=self.rbln_config,
             )
         elif self.attn_impl == "flash_attn":
             return FlashAttentionOp(
@@ -689,32 +737,29 @@ class DecoderOnlyAttention(nn.Module):
                 self.head_dim,
                 self.num_key_value_heads,
                 self.kvcache_partition_len,
-                self.
-
-                self.quantization,
+                rbln_config=self.rbln_config,
+                is_sliding=False,
             )
         elif self.attn_impl == "eager":
             return AttentionOp(
                 self.num_heads,
                 self.head_dim,
                 self.num_key_value_heads,
-                self.
-
-                self.quantization,
+                rbln_config=self.rbln_config,
+                is_sliding=False,
             )
         else:
             raise NotImplementedError(f"Unknown attention implementation: {self.attn_impl}")
 
-    def __post_init__(self):
+    def __post_init__(self, self_attn=None):
+        self.q_proj = self_attn.q_proj
+        self.k_proj = self_attn.k_proj
+        self.v_proj = self_attn.v_proj
+        self.o_proj = _get_attr_from_candidates(self_attn, self._O_PROJ_ATTRS)
+
         # Initialize LoRA weights if configured, which will replace linear layers
         if self.lora_config:
             self._init_lora_weights()
-        else:
-            # Use original linear layers if no LoRA
-            self.q_proj = self._original_mod.q_proj
-            self.k_proj = self._original_mod.k_proj
-            self.v_proj = self._original_mod.v_proj
-            self.o_proj = self._original_mod.o_proj
 
     def projection(
         self, hidden_states, lora_int_id: Optional[torch.Tensor] = None
@@ -745,8 +790,8 @@ class DecoderOnlyAttention(nn.Module):
     def apply_rotary_pos_embed(self, query_states, key_states, cos, sin):
         return apply_rotary_pos_emb(query_states, key_states, cos, sin)
 
-    def get_attn_scale(self):
-        return 1 / math.sqrt(
+    def get_attn_scale(self, self_attn):
+        return 1 / math.sqrt(self_attn.head_dim)
 
     def maybe_get_kvcache_scale(self) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
         if hasattr(self, "k_proj") and hasattr(self, "v_proj"):
@@ -803,6 +848,7 @@ class DecoderOnlyAttention(nn.Module):
             block_size=self.kvcache_block_size,
             k_scale=k_scale,
             v_scale=v_scale,
+            s_aux=getattr(self, "sinks", None),
         )
 
         # Check if using LoRALinear (which accepts lora_int_id) or standard linear layers
@@ -830,23 +876,27 @@ class AttentionOp(nn.Module):
         num_heads: int,
         head_dim: int,
         num_key_value_heads: int,
-
-
-        quantization: Optional[RBLNQuantizationConfig] = None,
+        rbln_config: Optional["RBLNDecoderOnlyModelConfig"] = None,
+        is_sliding: bool = False,
    ):
         super().__init__()
         self.num_heads = num_heads
         self.head_dim = head_dim
         self.num_key_value_heads = num_key_value_heads
         self.phase = "prefill"
-        self.
-        self.
-        self.
+        self.rbln_config = rbln_config
+        self.use_attention_mask = True if is_sliding else rbln_config.use_attention_mask
+        self.use_position_ids = rbln_config.use_position_ids
+        self.quantization = rbln_config.quantization
 
     def get_attn_op_name(self):
         phase = "decode" if self.phase == "decode" else "prefill"
-
-
+
+        if self.use_attention_mask:
+            if self.rbln_config.use_position_ids:
+                attn_op_name = "paged_causal_attn_"
+            else:
+                attn_op_name = "paged_attn_"
         else:
             attn_op_name = "paged_causal_attn_"
 
@@ -871,6 +921,7 @@ class AttentionOp(nn.Module):
         block_size: int,
         k_scale: Optional[torch.Tensor] = None,
         v_scale: Optional[torch.Tensor] = None,
+        s_aux: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         """Compute attention with static shapes and explicit cache management.
 
@@ -887,6 +938,7 @@ class AttentionOp(nn.Module):
             block_size: Block size for paged attention
             k_scale: Scale applied to key
             v_scale: Scale applied to value
+            s_aux: Auxiliary states for attention sinks
 
         Returns:
             Tensor: attention_output: [batch, num_heads, seq_len, head_dim]
@@ -895,7 +947,7 @@ class AttentionOp(nn.Module):
         key_state = key_state.unsqueeze(2)  # 1, 32, 1, 128, 128
         value_state = value_state.unsqueeze(2)
 
-        if self.use_attention_mask and not self.use_position_ids:
+        if self.use_attention_mask and not self.rbln_config.use_position_ids:
             attn_mask = attn_mask.unsqueeze(2)
 
         if self.phase == "decode":
@@ -927,8 +979,14 @@ class AttentionOp(nn.Module):
             op_args["mask"] = attn_mask
 
         if self.phase == "prefill" or self.phase == "image_prefill":
-
-
+            use_image_prefill = getattr(self.rbln_config, "use_image_prefill", False)
+            if use_image_prefill:
+                op_args["is_bidirectional"] = self.phase == "image_prefill"
+            else:
+                if not self.use_attention_mask:
+                    op_args["is_bidirectional"] = False
+                elif self.use_attention_mask and self.rbln_config.use_position_ids:
+                    op_args["is_bidirectional"] = True
 
         if self.quantization and self.quantization.kv_caches == "fp8":
             if past_key_state.dtype != torch.float8_e4m3fn:
@@ -936,6 +994,9 @@ class AttentionOp(nn.Module):
             op_args["k_scale"] = k_scale
             op_args["v_scale"] = v_scale
 
+        if s_aux is not None:
+            op_args["s_aux"] = s_aux
+
         attn_op_name = self.get_attn_op_name()
         attn_op = getattr(torch.ops.rbln_custom_ops, attn_op_name, None)
         if attn_op is None:
@@ -956,24 +1017,26 @@ class FlashAttentionOp(AttentionOp):
         head_dim: int,
         num_key_value_heads: int,
         kvcache_partition_len: int,
-
-
-        quantization: Optional[RBLNQuantizationConfig] = None,
+        rbln_config: Optional["RBLNDecoderOnlyModelConfig"] = None,
+        is_sliding: bool = False,
     ):
         super().__init__(
             num_heads=num_heads,
             head_dim=head_dim,
             num_key_value_heads=num_key_value_heads,
-
-
-            quantization=quantization,
+            rbln_config=rbln_config,
+            is_sliding=is_sliding,
         )
         self.kvcache_partition_size = kvcache_partition_len
 
     def get_attn_op_name(self):
         phase = "decode" if self.phase == "decode" else "prefill"
-
-
+
+        if self.use_attention_mask:
+            if self.rbln_config.use_position_ids:
+                attn_op_name = "paged_flash_causal_attn_"
+            else:
+                attn_op_name = "paged_flash_attn_"
         else:
             attn_op_name = "paged_flash_causal_attn_"
 
@@ -998,11 +1061,13 @@ class FlashAttentionOp(AttentionOp):
         block_size,
         k_scale=None,
         v_scale=None,
+        s_aux=None,
     ):
         # reshape for removing repeat_kv (batch=1 , num_head, 1, q_len=1, head_dim)
         key_state = key_state.unsqueeze(2)
         value_state = value_state.unsqueeze(2)
-
+
+        if self.use_attention_mask and not self.rbln_config.use_position_ids:
             attn_mask = attn_mask.unsqueeze(2)
 
         if self.phase == "decode":
@@ -1035,8 +1100,14 @@ class FlashAttentionOp(AttentionOp):
             op_args["mask"] = attn_mask
 
         if self.phase == "prefill" or self.phase == "image_prefill":
-
-
+            use_image_prefill = getattr(self.rbln_config, "use_image_prefill", False)
+            if use_image_prefill:
+                op_args["is_bidirectional"] = self.phase == "image_prefill"
+            else:
+                if not self.use_attention_mask:
+                    op_args["is_bidirectional"] = False
+                elif self.use_attention_mask and self.rbln_config.use_position_ids:
+                    op_args["is_bidirectional"] = True
 
         if self.quantization and self.quantization.kv_caches == "fp8":
             if past_key_state.dtype != torch.float8_e4m3fn:
@@ -1044,6 +1115,9 @@ class FlashAttentionOp(AttentionOp):
             op_args["k_scale"] = k_scale
             op_args["v_scale"] = v_scale
 
+        if s_aux is not None:
+            op_args["s_aux"] = s_aux
+
         attn_op_name = self.get_attn_op_name()
         attn_op = getattr(torch.ops.rbln_custom_ops, attn_op_name, None)
         if attn_op is None:
@@ -1058,6 +1132,22 @@ class FlashAttentionOp(AttentionOp):
 
 
 class SlidingWindowAttentionOp(AttentionOp):
+    def __init__(
+        self,
+        num_heads: int,
+        head_dim: int,
+        num_key_value_heads: int,
+        rbln_config: Optional["RBLNDecoderOnlyModelConfig"] = None,
+    ):
+        super().__init__(
+            num_heads=num_heads,
+            head_dim=head_dim,
+            num_key_value_heads=num_key_value_heads,
+            rbln_config=rbln_config,
+            is_sliding=True,
+        )
+        self.quantization = None  # Sliding window attention does not support quantization
+
     def get_attn_op_name(self):
         phase = "decode" if self.phase == "decode" else "prefill"
         if not self.use_attention_mask:
@@ -1080,6 +1170,7 @@ class SlidingWindowAttentionOp(AttentionOp):
         block_size: int,
         k_scale: Optional[torch.Tensor] = None,
         v_scale: Optional[torch.Tensor] = None,
+        s_aux: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         assert self.quantization is None, "Sliding window attention does not support quantization"
         assert k_scale is None and v_scale is None, "Sliding window attention does not support quantization"
@@ -1115,7 +1206,19 @@ class SlidingWindowAttentionOp(AttentionOp):
         }
 
         if self.phase == "prefill" or self.phase == "image_prefill":
-
+            use_image_prefill = getattr(self.rbln_config, "use_image_prefill", False)
+            if use_image_prefill:
+                op_args["is_bidirectional"] = self.phase == "image_prefill"
+            else:
+                if self.use_attention_mask and self.rbln_config.use_position_ids:
+                    op_args["is_bidirectional"] = True
+                else:
+                    op_args["is_bidirectional"] = False
+        elif self.phase == "decode":
+            op_args["attn_mask"] = attn_mask
+
+        if s_aux is not None:
+            op_args["s_aux"] = s_aux
 
         attn_op_name = self.get_attn_op_name()
         attn_op = getattr(torch.ops.rbln_custom_ops, attn_op_name, None)
@@ -1145,7 +1248,7 @@ class RotaryEmbedding(nn.Module):
         else:
             rope_type = "default"
 
-        inv_freq, attention_scaling = ROPE_INIT_FUNCTIONS[rope_type](config, max_seq_len_cached)
+        inv_freq, attention_scaling = ROPE_INIT_FUNCTIONS[rope_type](config, "cpu", max_seq_len_cached)
         cache_position = torch.arange(0, max_seq_len_cached)
         cache_position_expanded = cache_position[:, None]
 
@@ -1222,3 +1325,22 @@ def apply_rotary_pos_emb_partial(query_states, key_states, cos, sin, ndim) -> Tu
     query_states = torch.cat((query_rot, query_pass), dim=-1)
     key_states = torch.cat((key_rot, key_pass), dim=-1)
     return query_states, key_states
+
+
+def _get_attr_from_candidates(
+    src: object,
+    candidates: Optional[List[str]] = None,
+):
+    """
+    Get an attribute from a list of candidate names.
+
+    - If `candidates` is None, this attribute is treated as optional and returns None.
+    - Otherwise, returns `getattr(src, name)` for the first `name` in `candidates` that exists on `src`.
+    - Raises AttributeError if `candidates` is provided but none of the names exist on `src`.
+    """
+    if candidates is None:
+        return None
+    for name in candidates:
+        if hasattr(src, name):
+            return getattr(src, name)
+    raise AttributeError(f"None of the attributes {candidates} exist in {src}")
```