optimum-rbln 0.9.3rc0__py3-none-any.whl → 0.9.5a4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- optimum/rbln/__init__.py +48 -0
- optimum/rbln/__version__.py +2 -2
- optimum/rbln/configuration_utils.py +50 -21
- optimum/rbln/diffusers/__init__.py +12 -0
- optimum/rbln/diffusers/configurations/__init__.py +3 -0
- optimum/rbln/diffusers/configurations/models/__init__.py +2 -0
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_temporal_decoder.py +67 -0
- optimum/rbln/diffusers/configurations/models/configuration_unet_spatio_temporal_condition.py +59 -0
- optimum/rbln/diffusers/configurations/pipelines/__init__.py +3 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_video_diffusion.py +114 -0
- optimum/rbln/diffusers/modeling_diffusers.py +1 -1
- optimum/rbln/diffusers/models/__init__.py +17 -3
- optimum/rbln/diffusers/models/autoencoders/__init__.py +1 -0
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py +3 -3
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +275 -0
- optimum/rbln/diffusers/models/autoencoders/vae.py +27 -8
- optimum/rbln/diffusers/models/controlnet.py +17 -2
- optimum/rbln/diffusers/models/transformers/prior_transformer.py +16 -2
- optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +16 -1
- optimum/rbln/diffusers/models/transformers/transformer_sd3.py +14 -1
- optimum/rbln/diffusers/models/unets/__init__.py +1 -0
- optimum/rbln/diffusers/models/unets/unet_2d_condition.py +18 -2
- optimum/rbln/diffusers/models/unets/unet_spatio_temporal_condition.py +201 -0
- optimum/rbln/diffusers/pipelines/__init__.py +4 -0
- optimum/rbln/diffusers/pipelines/auto_pipeline.py +2 -2
- optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +20 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +13 -4
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +13 -4
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +13 -4
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +13 -4
- optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +1 -1
- optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +1 -1
- optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +1 -2
- optimum/rbln/diffusers/pipelines/stable_video_diffusion/__init__.py +15 -0
- optimum/rbln/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +46 -0
- optimum/rbln/modeling.py +20 -45
- optimum/rbln/modeling_base.py +18 -14
- optimum/rbln/ops/__init__.py +1 -0
- optimum/rbln/ops/attn.py +10 -0
- optimum/rbln/ops/flash_attn.py +8 -0
- optimum/rbln/ops/moe.py +180 -0
- optimum/rbln/ops/sliding_window_attn.py +9 -0
- optimum/rbln/transformers/__init__.py +36 -0
- optimum/rbln/transformers/configuration_generic.py +0 -27
- optimum/rbln/transformers/modeling_attention_utils.py +156 -127
- optimum/rbln/transformers/modeling_generic.py +2 -61
- optimum/rbln/transformers/modeling_outputs.py +26 -0
- optimum/rbln/transformers/modeling_rope_utils.py +78 -42
- optimum/rbln/transformers/models/__init__.py +28 -0
- optimum/rbln/transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py +28 -2
- optimum/rbln/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +68 -5
- optimum/rbln/transformers/models/auto/auto_factory.py +1 -0
- optimum/rbln/transformers/models/bart/bart_architecture.py +24 -24
- optimum/rbln/transformers/models/bart/modeling_bart.py +23 -2
- optimum/rbln/transformers/models/bert/modeling_bert.py +86 -1
- optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +42 -15
- optimum/rbln/transformers/models/clip/modeling_clip.py +40 -2
- optimum/rbln/transformers/models/colpali/colpali_architecture.py +14 -20
- optimum/rbln/transformers/models/colpali/configuration_colpali.py +12 -17
- optimum/rbln/transformers/models/colpali/modeling_colpali.py +66 -221
- optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +38 -23
- optimum/rbln/transformers/models/colqwen2/modeling_colqwen2.py +107 -371
- optimum/rbln/transformers/models/decoderonly/__init__.py +2 -0
- optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +128 -17
- optimum/rbln/transformers/models/decoderonly/configuration_lora.py +2 -2
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +211 -89
- optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +205 -64
- optimum/rbln/transformers/models/decoderonly/generation_decoderonly.py +17 -9
- optimum/rbln/transformers/models/decoderonly/lora_architecture.py +1 -1
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +194 -132
- optimum/rbln/transformers/models/depth_anything/modeling_depth_anything.py +17 -0
- optimum/rbln/transformers/models/distilbert/modeling_distilbert.py +24 -0
- optimum/rbln/transformers/models/dpt/modeling_dpt.py +17 -0
- optimum/rbln/transformers/models/exaone/exaone_architecture.py +0 -36
- optimum/rbln/transformers/models/gemma/gemma_architecture.py +1 -1
- optimum/rbln/transformers/models/gemma2/__init__.py +16 -0
- optimum/rbln/transformers/models/gemma2/configuration_gemma2.py +45 -0
- optimum/rbln/transformers/models/gemma2/gemma2_architecture.py +83 -0
- optimum/rbln/transformers/models/gemma2/modeling_gemma2.py +101 -0
- optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +23 -19
- optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py +42 -70
- optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +46 -31
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +8 -34
- optimum/rbln/transformers/models/gpt_oss/__init__.py +16 -0
- optimum/rbln/transformers/models/gpt_oss/configuration_gpt_oss.py +41 -0
- optimum/rbln/transformers/models/gpt_oss/gpt_oss_architecture.py +122 -0
- optimum/rbln/transformers/models/gpt_oss/modeling_gpt_oss.py +165 -0
- optimum/rbln/transformers/models/grounding_dino/configuration_grounding_dino.py +8 -5
- optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +7 -5
- optimum/rbln/transformers/models/grounding_dino/modeling_grounding_dino.py +24 -9
- optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +3 -5
- optimum/rbln/transformers/models/llava/modeling_llava.py +37 -26
- optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +3 -5
- optimum/rbln/transformers/models/midm/midm_architecture.py +29 -22
- optimum/rbln/transformers/models/mistral/modeling_mistral.py +0 -22
- optimum/rbln/transformers/models/opt/modeling_opt.py +2 -2
- optimum/rbln/transformers/models/opt/opt_architecture.py +1 -44
- optimum/rbln/transformers/models/paligemma/__init__.py +16 -0
- optimum/rbln/transformers/models/paligemma/configuration_paligemma.py +129 -0
- optimum/rbln/transformers/models/paligemma/modeling_paligemma.py +564 -0
- optimum/rbln/transformers/models/pegasus/modeling_pegasus.py +1 -1
- optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +24 -24
- optimum/rbln/transformers/models/phi/phi_architecture.py +13 -21
- optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +13 -1
- optimum/rbln/transformers/models/pixtral/pixtral_architecture.py +2 -2
- optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +0 -28
- optimum/rbln/transformers/models/qwen2_5_vl/__init__.py +6 -1
- optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +11 -1
- optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +278 -130
- optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +43 -39
- optimum/rbln/transformers/models/qwen2_moe/__init__.py +16 -0
- optimum/rbln/transformers/models/qwen2_moe/configuration_qwen2_moe.py +38 -0
- optimum/rbln/transformers/models/qwen2_moe/modeling_qwen2_moe.py +68 -0
- optimum/rbln/transformers/models/qwen2_moe/qwen2_moe_architecture.py +94 -0
- optimum/rbln/transformers/models/qwen2_vl/__init__.py +6 -1
- optimum/rbln/transformers/models/qwen2_vl/configuration_qwen2_vl.py +11 -1
- optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py +268 -111
- optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +27 -35
- optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +0 -20
- optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +7 -7
- optimum/rbln/transformers/models/qwen3_moe/__init__.py +16 -0
- optimum/rbln/transformers/models/qwen3_moe/configuration_qwen3_moe.py +38 -0
- optimum/rbln/transformers/models/qwen3_moe/modeling_qwen3_moe.py +68 -0
- optimum/rbln/transformers/models/qwen3_moe/qwen3_moe_architecture.py +100 -0
- optimum/rbln/transformers/models/resnet/configuration_resnet.py +17 -0
- optimum/rbln/transformers/models/resnet/modeling_resnet.py +73 -0
- optimum/rbln/transformers/models/roberta/modeling_roberta.py +33 -0
- optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py +2 -4
- optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +36 -12
- optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +14 -12
- optimum/rbln/transformers/models/siglip/modeling_siglip.py +21 -19
- optimum/rbln/transformers/models/swin/configuration_swin.py +1 -6
- optimum/rbln/transformers/models/swin/modeling_swin.py +17 -4
- optimum/rbln/transformers/models/t5/modeling_t5.py +2 -2
- optimum/rbln/transformers/models/t5/t5_architecture.py +16 -17
- optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +25 -10
- optimum/rbln/transformers/models/time_series_transformer/time_series_transformers_architecture.py +0 -3
- optimum/rbln/transformers/models/vit/modeling_vit.py +19 -0
- optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec2.py +15 -3
- optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +60 -8
- optimum/rbln/transformers/models/whisper/generation_whisper.py +48 -14
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +2 -2
- optimum/rbln/transformers/models/whisper/whisper_architecture.py +0 -3
- optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +53 -0
- optimum/rbln/transformers/utils/rbln_quantization.py +29 -12
- optimum/rbln/utils/deprecation.py +213 -0
- optimum/rbln/utils/hub.py +14 -3
- optimum/rbln/utils/import_utils.py +23 -2
- optimum/rbln/utils/runtime_utils.py +42 -6
- optimum/rbln/utils/submodule.py +27 -1
- {optimum_rbln-0.9.3rc0.dist-info → optimum_rbln-0.9.5a4.dist-info}/METADATA +6 -6
- {optimum_rbln-0.9.3rc0.dist-info → optimum_rbln-0.9.5a4.dist-info}/RECORD +155 -129
- {optimum_rbln-0.9.3rc0.dist-info → optimum_rbln-0.9.5a4.dist-info}/WHEEL +1 -1
- optimum/rbln/transformers/models/colqwen2/colqwen2_architecture.py +0 -233
- optimum/rbln/utils/depreacate_utils.py +0 -16
- {optimum_rbln-0.9.3rc0.dist-info → optimum_rbln-0.9.5a4.dist-info}/entry_points.txt +0 -0
- {optimum_rbln-0.9.3rc0.dist-info → optimum_rbln-0.9.5a4.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/transformers/models/gemma3/modeling_gemma3.py

````diff
@@ -99,9 +99,7 @@ class RBLNGemma3ForConditionalGeneration(RBLNModel, RBLNDecoderOnlyGenerationMix
         return True
 
     @classmethod
-    def
-        model = super().get_pytorch_model(*args, **kwargs)
-
+    def _reconstruct_model_if_needed(cls, model: "PreTrainedModel"):
         with no_init_weights():
             model_cls_name = model.model.language_model.__class__.__name__
             causal_model_cls_name = model_cls_name.replace("TextModel", "ForCausalLM")
````
````diff
@@ -135,7 +133,7 @@ class RBLNGemma3ForConditionalGeneration(RBLNModel, RBLNDecoderOnlyGenerationMix
         return self.language_model.get_input_embeddings()
 
     @classmethod
-    def
+    def _wrap_model_if_needed(cls, model: "PreTrainedModel", rbln_config: RBLNModelConfig):
         return model.multi_modal_projector
 
     @classmethod
````
````diff
@@ -301,28 +299,60 @@ class RBLNGemma3ForConditionalGeneration(RBLNModel, RBLNDecoderOnlyGenerationMix
         generate_idx: Optional[torch.Tensor] = None,
         padded_cache_lengths: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.Tensor] = None,
+        output_hidden_states: Optional[bool] = None,
         **lm_kwargs: Dict[str, Any],
     ) -> Union[Tuple, RBLNDecoderOnlyOutput]:
+        output_hidden_states = (
+            output_hidden_states
+            if output_hidden_states is not None
+            else self.rbln_config.language_model.output_hidden_states
+        )
+        if output_hidden_states != self.rbln_config.language_model.output_hidden_states:
+            raise ValueError(
+                f"Variable output_hidden_states {output_hidden_states} is not equal to rbln_config.language_model.output_hidden_states {self.rbln_config.language_model.output_hidden_states} "
+                f"Please compile again with the correct argument."
+            )
+
         # prefill
         if cache_position is None:
             logits = []
             inputs_embeds = self._preprocess_prefill(input_ids, inputs_embeds, pixel_values)
             batch_size = inputs_embeds.shape[0]
 
+            all_hidden_states = (
+                tuple(
+                    torch.zeros(
+                        batch_size,
+                        inputs_embeds.shape[1],
+                        self.config.text_config.hidden_size,
+                        dtype=self.rbln_config.dtype,
+                    )
+                    for _ in range(self.config.text_config.num_hidden_layers + 1)
+                )
+                if self.rbln_config.language_model.output_hidden_states
+                else None
+            )
+
             for b_idx in range(batch_size):
                 cache_position = torch.arange(0, generate_idx[b_idx].item(), dtype=torch.int32).unsqueeze(0)
                 token_type_id = token_type_ids[b_idx : b_idx + 1, attention_mask[b_idx].bool()]
                 cache_position = self.get_padded_cache_position(cache_position, token_type_id)
 
-
+                outputs = self.language_model.prefill_decoder(
                     inputs_embeds=inputs_embeds[b_idx : b_idx + 1],
                     attention_mask=attention_mask[b_idx],
                     cache_position=cache_position,
                     batch_idx=b_idx,
                     token_type_ids=token_type_ids[b_idx : b_idx + 1],  # do not pass token_type_id
                 )
-                padded_cache_lengths[b_idx] +=
-                logits.append(
+                padded_cache_lengths[b_idx] += outputs.padded_cache_lengths
+                logits.append(outputs.logits)
+                if self.rbln_config.language_model.output_hidden_states:
+                    for l_idx in range(self.config.text_config.num_hidden_layers + 1):
+                        mask_indices = torch.nonzero(attention_mask[b_idx], as_tuple=True)[0]
+                        all_hidden_states[l_idx][b_idx].index_copy_(
+                            dim=0, index=mask_indices, source=outputs.hidden_states[l_idx][0]
+                        )
 
             logits = torch.cat(logits, dim=0)
         # decoder
````
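The prefill loop above writes each sequence's hidden states back only at unmasked positions. A runnable sketch of that `index_copy_` scatter, with illustrative sizes rather than the model's real shapes:

```python
import torch

# Prefill runs on the valid (unmasked) tokens only, so outputs are scattered
# back into a zero-initialized buffer at the positions the attention mask marks.
attention_mask = torch.tensor([1, 1, 0, 1])
dest = torch.zeros(4, 2)                                   # (seq_len, hidden)
src = torch.arange(6, dtype=torch.float32).reshape(3, 2)   # 3 valid tokens
mask_indices = torch.nonzero(attention_mask, as_tuple=True)[0]
dest.index_copy_(0, mask_indices, src)
print(dest)  # rows 0, 1, 3 are filled; the masked row 2 stays zero
```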
````diff
@@ -336,15 +366,20 @@ class RBLNGemma3ForConditionalGeneration(RBLNModel, RBLNDecoderOnlyGenerationMix
                     f"Please run your model with one of these batch sizes or add support for batch size {batch_size}."
                 )
 
-
+            outputs = self.language_model.decoders[batch_size](
                 input_ids=input_ids,
                 inputs_embeds=inputs_embeds,
                 cache_position=cache_position,
                 position_ids=position_ids if self.rbln_config.language_model.use_position_ids else None,
-            )
+            )
+            logits = outputs.logits
+            all_hidden_states = outputs.hidden_states
 
         return RBLNDecoderOnlyOutput(
-            logits=logits,
+            logits=logits,
+            generate_idx=generate_idx,
+            padded_cache_lengths=padded_cache_lengths,
+            hidden_states=all_hidden_states,
         )
 
 
````
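Because the compiled RBLN graph is static, whether hidden states are returned is fixed at compile time; the new runtime flag can only restate that choice. A minimal sketch of the resolve-and-validate pattern the hunks above introduce (the function name is illustrative, not part of the API):

```python
# `compiled_flag` stands in for rbln_config.language_model.output_hidden_states.
def resolve_output_hidden_states(requested, compiled_flag):
    requested = requested if requested is not None else compiled_flag
    if requested != compiled_flag:
        raise ValueError(
            f"output_hidden_states={requested} does not match the compiled setting "
            f"{compiled_flag}; recompile the model with the desired value."
        )
    return requested

assert resolve_output_hidden_states(None, True) is True     # falls back to compiled value
assert resolve_output_hidden_states(False, False) is False  # explicit match is fine
```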
````diff
@@ -405,26 +440,6 @@ class RBLNGemma3ForCausalLM(RBLNDecoderOnlyModelForCausalLM):
         )
         return embed_tokens
 
-    @classmethod
-    def _update_sliding_window_config(cls, model_config: PretrainedConfig, rbln_config: RBLNGemma3ForCausalLMConfig):
-        sliding_window = getattr(model_config, "sliding_window", None)
-        sliding_window_pattern = getattr(model_config, "sliding_window_pattern", None)
-        if sliding_window_pattern is None:
-            if hasattr(model_config, "layer_types"):
-                first_full_attention_index = model_config.layer_types.index("full_attention")
-                sliding_window_pattern = first_full_attention_index + 1
-            else:
-                raise ValueError("Cannot determine sliding_window_pattern from model_config")
-
-        if sliding_window_pattern <= model_config.num_hidden_layers:
-            rbln_config.cache_impl = "hybrid"
-            rbln_config.sliding_window = sliding_window
-            rbln_config.sliding_window_layers = [
-                i for i in range(model_config.num_hidden_layers) if (i + 1) % sliding_window_pattern > 0
-            ]
-
-        return rbln_config
-
     @classmethod
     def _update_submodule_config(
         cls,
````
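For reference, the removed helper classified every layer except each pattern-final one as a sliding-window layer. A quick worked example of that index rule (values are illustrative):

```python
# With sliding_window_pattern = 6, every 6th layer (indices 5, 11, ...) keeps
# full attention and all other layers use the sliding window.
num_hidden_layers = 12
sliding_window_pattern = 6
sliding_window_layers = [
    i for i in range(num_hidden_layers) if (i + 1) % sliding_window_pattern > 0
]
print(sliding_window_layers)  # [0, 1, 2, 3, 4, 6, 7, 8, 9, 10]
```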
````diff
@@ -482,7 +497,7 @@ class RBLNGemma3ForCausalLM(RBLNDecoderOnlyModelForCausalLM):
     @classmethod
     @torch.inference_mode()
     def get_compiled_model(cls, model: "PreTrainedModel", rbln_config: RBLNGemma3ForCausalLMConfig):
-        wrapped_model = cls.
+        wrapped_model = cls._wrap_model_if_needed(model, rbln_config)
 
         rbln_compile_configs = rbln_config.compile_cfgs
         prefill_compile_config = rbln_compile_configs[0]
````
optimum/rbln/transformers/models/gpt2/gpt2_architecture.py

````diff
@@ -20,8 +20,6 @@ import torch.nn as nn
 
 from ..decoderonly.decoderonly_architecture import (
     DecoderOnlyAttention,
-    DecoderOnlyLayer,
-    DecoderOnlyModel,
     DecoderOnlyWrapper,
 )
 
````
````diff
@@ -34,12 +32,6 @@ class GPT2Wrapper(DecoderOnlyWrapper):
     def get_rbln_attn_class(self):
         return GPT2Attention
 
-    def get_rbln_layer_class(self):
-        return GPT2Layer
-
-    def get_rbln_model_class(self):
-        return GPT2Model
-
     def get_attn_layer(self, layer: nn.Module):
         return layer.attn
 
````
````diff
@@ -50,30 +42,12 @@ class GPT2Wrapper(DecoderOnlyWrapper):
         return model.transformer.h if self.is_causal_lm else model.h
 
 
-class GPT2Model(DecoderOnlyModel):
-    def get_last_layernorm(self) -> nn.LayerNorm:
-        return self._original_mod.ln_f
-
-    def get_embedding(self) -> nn.Embedding:
-        return self._original_mod.wte
-
-    def get_pos_embedding(self) -> nn.Embedding:
-        return self._original_mod.wpe
-
-
-class GPT2Layer(DecoderOnlyLayer):
-    def get_pre_attention_layernorm(self) -> nn.LayerNorm:
-        return self._original_mod.ln_1
-
-    def get_post_attention_layernorm(self) -> nn.LayerNorm:
-        return self._original_mod.ln_2
-
-
 class GPT2Attention(DecoderOnlyAttention):
-    def __post_init__(self):
-        self.c_attn =
-        self.o_proj =
-        self.split_size =
+    def __post_init__(self, self_attn):
+        self.c_attn = self_attn.c_attn
+        self.o_proj = self_attn.c_proj
+        self.split_size = self_attn.split_size
+        self.num_key_value_heads = self_attn.num_heads
 
     def projection(self, hidden_states, lora_int_id) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         if lora_int_id is not None:
````
````diff
@@ -82,12 +56,12 @@ class GPT2Attention(DecoderOnlyAttention):
             query_states, key_states, value_states = self.c_attn(hidden_states).split(self.split_size, dim=2)
         return query_states, key_states, value_states
 
-    def get_attn_scale(self):
+    def get_attn_scale(self, self_attn):
         scale = 1.0
-        if
+        if self_attn.scale_attn_weights:
             scale /= math.sqrt(self.head_dim)
 
-        if
+        if self_attn.scale_attn_by_inverse_layer_idx:
             scale /= 1 + self.layer_idx
 
         return scale
````
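The two flags mirror Hugging Face GPT-2's `scale_attn_weights` and `scale_attn_by_inverse_layer_idx` options, now read from the wrapped `self_attn` module instead of instance state. A worked example of the resulting scale (values are illustrative):

```python
import math

# GPT-2 scales attention by 1/sqrt(head_dim), and optionally by an extra
# 1/(layer_idx + 1) when the checkpoint enables scale_attn_by_inverse_layer_idx.
head_dim, layer_idx = 64, 3
scale = 1.0
scale /= math.sqrt(head_dim)  # -> 0.125
scale /= 1 + layer_idx        # -> 0.03125
print(scale)
```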
optimum/rbln/transformers/models/gpt_oss/__init__.py (new file)

````diff
@@ -0,0 +1,16 @@
+# Copyright 2025 Rebellions Inc. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+#     http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .configuration_gpt_oss import RBLNGptOssForCausalLMConfig
+from .modeling_gpt_oss import RBLNGptOssForCausalLM
````
optimum/rbln/transformers/models/gpt_oss/configuration_gpt_oss.py (new file)

````diff
@@ -0,0 +1,41 @@
+# Copyright 2025 Rebellions Inc. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+#     http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ..decoderonly.configuration_decoderonly import RBLNDecoderOnlyModelForCausalLMConfig
+
+
+class RBLNGptOssForCausalLMConfig(RBLNDecoderOnlyModelForCausalLMConfig):
+    """
+    Configuration class for RBLN GPT-OSS models.
+
+    This class is an alias of RBLNDecoderOnlyModelForCausalLMConfig.
+
+    Example usage:
+    ```python
+    from optimum.rbln import RBLNGptOssForCausalLM, RBLNGptOssForCausalLMConfig
+
+    # Create a configuration object
+    config = RBLNGptOssForCausalLMConfig(
+        batch_size=1,
+        tensor_parallel_size=4
+    )
+
+    # Use the configuration with from_pretrained
+    model = RBLNGptOssForCausalLM.from_pretrained(
+        "openai/gpt-oss-20b",
+        export=True,
+        rbln_config=config
+    )
+    ```
+    """
````
optimum/rbln/transformers/models/gpt_oss/gpt_oss_architecture.py (new file)

````diff
@@ -0,0 +1,122 @@
+# Copyright 2025 Rebellions Inc. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+#     http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from typing import Optional
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+from ..decoderonly.configuration_decoderonly import RBLNLoRAConfig
+from ..decoderonly.decoderonly_architecture import (
+    DecoderOnlyAttention,
+    DecoderOnlyLayer,
+    DecoderOnlyWrapper,
+)
+
+
+class RBLNGptOssWrapper(DecoderOnlyWrapper):
+    def get_rbln_layer_class(self):
+        return RBLNGptOssLayer
+
+
+class RBLNGptOssLayer(DecoderOnlyLayer):
+    def __init__(self, layer, self_attn: DecoderOnlyAttention, lora_config: Optional[RBLNLoRAConfig] = None):
+        super().__init__(layer, self_attn, lora_config)
+        self.mlp = RBLNGptOssMLP(layer.mlp)
+
+    def get_mlp(self) -> nn.Module:
+        return self.mlp
+
+
+class RBLNGptOssTopKRouter(nn.Module):
+    def __init__(self, model):
+        super().__init__()
+        self.weight = model.weight
+        self.bias = model.bias
+
+    def forward(self, hidden_states):
+        return F.linear(hidden_states, self.weight, self.bias)  # (seq_len, num_experts)
+
+
+class RBLNGptOssExperts(nn.Module):
+    def __init__(self, model, top_k: Optional[int] = None):
+        super().__init__()
+        self.intermediate_size = model.intermediate_size
+        self.num_experts = model.num_experts
+        self.hidden_size = model.hidden_size
+
+        self.register_buffer(
+            "gate_proj_blocks",
+            model.gate_up_proj_blocks.data[:, ::2, :, :].reshape(self.num_experts, self.intermediate_size, -1),
+        )
+        self.register_buffer("gate_proj_scales", model.gate_up_proj_scales.data[:, ::2, :])
+        self.register_buffer(
+            "gate_proj_bias",
+            model.gate_up_proj_bias.data[:, ::2].reshape(self.num_experts, self.intermediate_size),
+        )
+
+        self.register_buffer(
+            "up_proj_blocks",
+            model.gate_up_proj_blocks.data[:, 1::2, :, :].reshape(self.num_experts, self.intermediate_size, -1),
+        )
+        self.register_buffer("up_proj_scales", model.gate_up_proj_scales.data[:, 1::2, :])
+        self.register_buffer(
+            "up_proj_bias", model.gate_up_proj_bias.data[:, 1::2].reshape(self.num_experts, self.intermediate_size)
+        )
+
+        self.register_buffer(
+            "down_proj_blocks", model.down_proj_blocks.data.reshape(self.num_experts, self.hidden_size, -1)
+        )
+        self.register_buffer("down_proj_scales", model.down_proj_scales.data)
+        self.register_buffer("down_proj_bias", model.down_proj_bias.data)
+
+        self.alpha = model.alpha  # 1.702
+        self.limit = model.limit  # 7.0
+        self.top_k = top_k
+
+    def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor) -> torch.Tensor:
+        return torch.ops.rbln_custom_ops.custom_moe_glu_mxfp4(
+            hidden_states,
+            self.gate_proj_blocks,
+            self.gate_proj_scales,
+            self.gate_proj_bias,
+            self.up_proj_blocks,
+            self.up_proj_scales,
+            self.up_proj_bias,
+            self.down_proj_blocks,
+            self.down_proj_scales,
+            self.down_proj_bias,
+            router_logits,
+            torch.tensor(self.alpha, dtype=hidden_states.dtype),
+            torch.tensor(self.limit, dtype=hidden_states.dtype),
+            k=self.top_k,
+            post_norm=True,
+        )
+
+
+class RBLNGptOssMLP(nn.Module):
+    def __init__(self, model):
+        super().__init__()
+        self.router = RBLNGptOssTopKRouter(model.router)
+        self.experts = RBLNGptOssExperts(model.experts, top_k=model.router.top_k)
+
+    def forward(self, hidden_states):
+        batch_size, sequence_length, hidden_dim = hidden_states.shape
+        hidden_states = hidden_states.view(-1, hidden_dim)
+        router_logits = self.router(hidden_states)
+        routed_out = self.experts(hidden_states, router_logits=router_logits)
+        routed_out = routed_out.reshape(batch_size, sequence_length, hidden_dim)
+        return routed_out
````
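`RBLNGptOssExperts` assumes the GPT-OSS mxfp4 checkpoint layout, in which gate and up projection rows are interleaved along the second dimension. A small sketch of that even/odd de-interleave with toy shapes (not the real block/scale layout):

```python
import torch

# Rows 0, 2, 4, ... of gate_up belong to gate_proj; rows 1, 3, 5, ... to up_proj.
num_experts, intermediate_size = 2, 4
gate_up = torch.arange(num_experts * 2 * intermediate_size * 3, dtype=torch.float32)
gate_up = gate_up.reshape(num_experts, 2 * intermediate_size, 3, 1)

gate_blocks = gate_up[:, ::2, :, :].reshape(num_experts, intermediate_size, -1)
up_blocks = gate_up[:, 1::2, :, :].reshape(num_experts, intermediate_size, -1)
print(gate_blocks.shape, up_blocks.shape)  # torch.Size([2, 4, 3]) torch.Size([2, 4, 3])
```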
optimum/rbln/transformers/models/gpt_oss/modeling_gpt_oss.py (new file)

````diff
@@ -0,0 +1,165 @@
+# Copyright 2025 Rebellions Inc. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+#     http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import TYPE_CHECKING, Optional, Union
+
+import torch
+from safetensors.torch import load_file
+from transformers import AutoConfig, AutoModelForCausalLM, PretrainedConfig
+from transformers.integrations.mxfp4 import Mxfp4GptOssExperts
+from transformers.modeling_utils import PreTrainedModel, no_init_weights
+
+from ....utils.logging import get_logger
+from ...models.decoderonly import (
+    RBLNDecoderOnlyModelConfig,
+    RBLNDecoderOnlyModelForCausalLM,
+    RBLNDecoderOnlyModelForCausalLMConfig,
+)
+from ...utils.rbln_quantization import load_weight_files
+from .gpt_oss_architecture import RBLNGptOssWrapper
+
+
+if TYPE_CHECKING:
+    from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer, PreTrainedModel
+
+logger = get_logger(__name__)
+
+
+class RBLNGptOssForCausalLM(RBLNDecoderOnlyModelForCausalLM):
+    """
+    The GPT-OSS Model transformer with a language modeling head (linear layer) on top.
+    This model inherits from [`RBLNDecoderOnlyModelForCausalLM`]. Check the superclass documentation for the generic methods the library implements for all its models.
+
+    A class to convert and run pre-trained transformers based GPT-OSSForCausalLM model on RBLN devices.
+    It implements the methods to convert a pre-trained transformers GPT-OSSForCausalLM model into a RBLN transformer model by:
+    - transferring the checkpoint weights of the original into an optimized RBLN graph,
+    - compiling the resulting graph using the RBLN compiler.
+
+    **Configuration:**
+    This model uses [`RBLNGptOssForCausalLMConfig`] for configuration. When calling methods like `from_pretrained` or `from_model`,
+    the `rbln_config` parameter should be an instance of [`RBLNGptOssForCausalLMConfig`] or a dictionary conforming to its structure.
+
+    See the [`RBLNGptOssForCausalLMConfig`] class for all available configuration options.
+
+    Examples:
+    ```python
+    from optimum.rbln import RBLNGptOssForCausalLM
+
+    # Simple usage using rbln_* arguments
+    # `max_seq_len` is automatically inferred from the model config
+    model = RBLNGptOssForCausalLM.from_pretrained(
+        "openai/gpt-oss-20b",
+        export=True,
+        rbln_batch_size=1,
+        rbln_tensor_parallel_size=4,
+    )
+
+
+    # Using a config dictionary
+    rbln_config = {
+        "batch_size": 1,
+        "tensor_parallel_size": 4,
+    }
+    model = RBLNGptOssForCausalLM.from_pretrained(
+        "openai/gpt-oss-20b",
+        export=True,
+        rbln_config=rbln_config
+    )
+
+
+    # Using a RBLNGptOssForCausalLMConfig instance (recommended for type checking)
+    from optimum.rbln import RBLNGptOssForCausalLMConfig
+
+    config = RBLNGptOssForCausalLMConfig(
+        batch_size=1,
+        tensor_parallel_size=4
+    )
+    model = RBLNGptOssForCausalLM.from_pretrained(
+        "openai/gpt-oss-20b",
+        export=True,
+        rbln_config=config
+    )
+    ```
+    """
+
+    _decoder_wrapper_cls = RBLNGptOssWrapper
+
+    @staticmethod
+    def _get_dtype(dtype: Union[str, torch.dtype] = None, torch_dtype: Union[str, torch.dtype] = None):
+        # For BC on torch_dtype argument
+        if torch_dtype is not None:
+            logger.warning_once("`torch_dtype` is deprecated! Use `dtype` instead!")
+            # If both kwargs are provided, use `dtype`
+            dtype = dtype if dtype is not None else torch_dtype
+
+        # As mxfp4_quantizer's default dtype
+        if dtype is None or dtype == "auto":
+            dtype = torch.bfloat16
+
+        return dtype
+
+    @classmethod
+    def get_pytorch_model(
+        cls,
+        model_id: str,
+        *args,
+        rbln_config: Optional[RBLNDecoderOnlyModelConfig] = None,
+        dtype: Union[str, torch.dtype] = None,
+        torch_dtype: Union[str, torch.dtype] = None,
+        config: Optional[PretrainedConfig] = None,
+        **kwargs,
+    ) -> PreTrainedModel:
+        safetensor_files = load_weight_files(model_id, exception_keywords=["original"])
+        state_dict = {k: v for f in safetensor_files for k, v in load_file(f).items()}
+
+        if config is None:
+            config, kwargs = AutoConfig.from_pretrained(model_id, return_unused_kwargs=True)
+
+        dtype = cls._get_dtype(dtype, torch_dtype)
+
+        with no_init_weights():
+            model = AutoModelForCausalLM.from_config(config, dtype=dtype, **kwargs)
+
+        _replace_with_mxfp4_linear(model, config)
+        model.load_state_dict(state_dict, strict=False)
+
+        return model
+
+    @classmethod
+    def _update_rbln_config(
+        cls,
+        preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]] = None,
+        model: Optional["PreTrainedModel"] = None,
+        model_config: Optional["PretrainedConfig"] = None,
+        rbln_config: Optional[RBLNDecoderOnlyModelForCausalLMConfig] = None,
+    ) -> RBLNDecoderOnlyModelForCausalLMConfig:
+        rbln_config = super()._update_rbln_config(preprocessors, model, model_config, rbln_config)
+
+        if rbln_config.use_attention_mask:
+            raise ValueError(
+                "use_attention_mask is not supported for GPT-OSS because custom attention does not support attention sink for masked attention"
+            )
+
+        return rbln_config
+
+
+def _replace_with_mxfp4_linear(
+    model,
+    config,
+):
+    for name, module in model.named_children():
+        if module.__class__.__name__ == "GptOssExperts":
+            model._modules[name] = Mxfp4GptOssExperts(config)
+        if len(list(module.children())) > 0:
+            _replace_with_mxfp4_linear(module, config)
````
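Assuming the usual optimum-rbln export flow (`save_pretrained` writes the compiled artifacts, and a later `from_pretrained` with `export=False` reloads them without recompiling), a hypothetical end-to-end sketch:

```python
# Paths and IDs are illustrative: compile once, save, then reload for inference.
from optimum.rbln import RBLNGptOssForCausalLM

model = RBLNGptOssForCausalLM.from_pretrained(
    "openai/gpt-oss-20b", export=True, rbln_tensor_parallel_size=4
)
model.save_pretrained("gpt-oss-20b-rbln")

# Later, load the precompiled model directly (no re-export).
model = RBLNGptOssForCausalLM.from_pretrained("gpt-oss-20b-rbln", export=False)
```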
optimum/rbln/transformers/models/grounding_dino/configuration_grounding_dino.py

````diff
@@ -50,11 +50,14 @@ class RBLNGroundingDinoForObjectDetectionConfig(RBLNImageModelConfig):
         Raises:
             ValueError: If batch_size is not a positive integer.
         """
-
-
-        self.
-        self.
-        self.
+
+        super().__init__(batch_size=batch_size, **kwargs)
+        self.encoder = self.initialize_submodule_config(submodule_config=encoder, batch_size=self.batch_size)
+        self.decoder = self.initialize_submodule_config(submodule_config=decoder, batch_size=self.batch_size)
+        self.text_backbone = self.initialize_submodule_config(
+            submodule_config=text_backbone, batch_size=self.batch_size
+        )
+        self.backbone = self.initialize_submodule_config(submodule_config=backbone, batch_size=self.batch_size)
         self.output_attentions = output_attentions if output_attentions is not None else False
         self.output_hidden_states = output_hidden_states if output_hidden_states is not None else False
 
````
optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py

````diff
@@ -150,7 +150,7 @@ class _GroundingDinoEncoder(torch.nn.Module):
         all_attn_fused_vision = () if output_attentions else None
         all_attn_enhanced_text = () if output_attentions else None
         all_attn_deformable = () if output_attentions else None
-        for
+        for _, encoder_layer in enumerate(self.layers):
             if output_hidden_states:
                 encoder_vision_states += (vision_features,)
                 encoder_text_states += (text_features,)
````
````diff
@@ -509,10 +509,12 @@ class _GroundingDinoBiMultiHeadAttention(torch.nn.Module):
 
         # mask vision for language
         if vision_attention_mask is not None:
-            # RBLN FIX: bool tensor to float tensor
-            mask = vision_attention_mask
-
-
+            # RBLN FIX: bool tensor to float tensor, broadcast across heads and src_len
+            mask = vision_attention_mask
+            if mask.dim() == 3:
+                mask = mask[..., 0]
+            mask = mask[:, None, None, :].repeat(1, self.num_heads, 1, 1).flatten(0, 1)
+            text_attn_weights = text_attn_weights + mask * torch.finfo(text_attn_weights.dtype).min
 
         text_attn_weights = text_attn_weights.softmax(dim=-1)
 
````
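A shape walkthrough of the broadcast above, with illustrative sizes: a per-image vision mask of shape (batch, src_len) is expanded to one row per attention head so it lines up with attention weights of shape (batch * num_heads, tgt_len, src_len):

```python
import torch

# (batch, src_len) -> (batch, 1, 1, src_len) -> (batch, num_heads, 1, src_len)
# -> (batch * num_heads, 1, src_len), which broadcasts over tgt_len.
batch, num_heads, src_len = 2, 8, 10
mask = torch.zeros(batch, src_len)
mask = mask[:, None, None, :].repeat(1, num_heads, 1, 1).flatten(0, 1)
print(mask.shape)  # torch.Size([16, 1, 10])
```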