optimum-rbln 0.9.3rc0__py3-none-any.whl → 0.9.5a4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (157)
  1. optimum/rbln/__init__.py +48 -0
  2. optimum/rbln/__version__.py +2 -2
  3. optimum/rbln/configuration_utils.py +50 -21
  4. optimum/rbln/diffusers/__init__.py +12 -0
  5. optimum/rbln/diffusers/configurations/__init__.py +3 -0
  6. optimum/rbln/diffusers/configurations/models/__init__.py +2 -0
  7. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_temporal_decoder.py +67 -0
  8. optimum/rbln/diffusers/configurations/models/configuration_unet_spatio_temporal_condition.py +59 -0
  9. optimum/rbln/diffusers/configurations/pipelines/__init__.py +3 -0
  10. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_video_diffusion.py +114 -0
  11. optimum/rbln/diffusers/modeling_diffusers.py +1 -1
  12. optimum/rbln/diffusers/models/__init__.py +17 -3
  13. optimum/rbln/diffusers/models/autoencoders/__init__.py +1 -0
  14. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py +3 -3
  15. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +275 -0
  16. optimum/rbln/diffusers/models/autoencoders/vae.py +27 -8
  17. optimum/rbln/diffusers/models/controlnet.py +17 -2
  18. optimum/rbln/diffusers/models/transformers/prior_transformer.py +16 -2
  19. optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +16 -1
  20. optimum/rbln/diffusers/models/transformers/transformer_sd3.py +14 -1
  21. optimum/rbln/diffusers/models/unets/__init__.py +1 -0
  22. optimum/rbln/diffusers/models/unets/unet_2d_condition.py +18 -2
  23. optimum/rbln/diffusers/models/unets/unet_spatio_temporal_condition.py +201 -0
  24. optimum/rbln/diffusers/pipelines/__init__.py +4 -0
  25. optimum/rbln/diffusers/pipelines/auto_pipeline.py +2 -2
  26. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +20 -0
  27. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +13 -4
  28. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +13 -4
  29. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +13 -4
  30. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +13 -4
  31. optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +1 -1
  32. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +1 -1
  33. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +1 -2
  34. optimum/rbln/diffusers/pipelines/stable_video_diffusion/__init__.py +15 -0
  35. optimum/rbln/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +46 -0
  36. optimum/rbln/modeling.py +20 -45
  37. optimum/rbln/modeling_base.py +18 -14
  38. optimum/rbln/ops/__init__.py +1 -0
  39. optimum/rbln/ops/attn.py +10 -0
  40. optimum/rbln/ops/flash_attn.py +8 -0
  41. optimum/rbln/ops/moe.py +180 -0
  42. optimum/rbln/ops/sliding_window_attn.py +9 -0
  43. optimum/rbln/transformers/__init__.py +36 -0
  44. optimum/rbln/transformers/configuration_generic.py +0 -27
  45. optimum/rbln/transformers/modeling_attention_utils.py +156 -127
  46. optimum/rbln/transformers/modeling_generic.py +2 -61
  47. optimum/rbln/transformers/modeling_outputs.py +26 -0
  48. optimum/rbln/transformers/modeling_rope_utils.py +78 -42
  49. optimum/rbln/transformers/models/__init__.py +28 -0
  50. optimum/rbln/transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py +28 -2
  51. optimum/rbln/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +68 -5
  52. optimum/rbln/transformers/models/auto/auto_factory.py +1 -0
  53. optimum/rbln/transformers/models/bart/bart_architecture.py +24 -24
  54. optimum/rbln/transformers/models/bart/modeling_bart.py +23 -2
  55. optimum/rbln/transformers/models/bert/modeling_bert.py +86 -1
  56. optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +42 -15
  57. optimum/rbln/transformers/models/clip/modeling_clip.py +40 -2
  58. optimum/rbln/transformers/models/colpali/colpali_architecture.py +14 -20
  59. optimum/rbln/transformers/models/colpali/configuration_colpali.py +12 -17
  60. optimum/rbln/transformers/models/colpali/modeling_colpali.py +66 -221
  61. optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +38 -23
  62. optimum/rbln/transformers/models/colqwen2/modeling_colqwen2.py +107 -371
  63. optimum/rbln/transformers/models/decoderonly/__init__.py +2 -0
  64. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +128 -17
  65. optimum/rbln/transformers/models/decoderonly/configuration_lora.py +2 -2
  66. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +211 -89
  67. optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +205 -64
  68. optimum/rbln/transformers/models/decoderonly/generation_decoderonly.py +17 -9
  69. optimum/rbln/transformers/models/decoderonly/lora_architecture.py +1 -1
  70. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +194 -132
  71. optimum/rbln/transformers/models/depth_anything/modeling_depth_anything.py +17 -0
  72. optimum/rbln/transformers/models/distilbert/modeling_distilbert.py +24 -0
  73. optimum/rbln/transformers/models/dpt/modeling_dpt.py +17 -0
  74. optimum/rbln/transformers/models/exaone/exaone_architecture.py +0 -36
  75. optimum/rbln/transformers/models/gemma/gemma_architecture.py +1 -1
  76. optimum/rbln/transformers/models/gemma2/__init__.py +16 -0
  77. optimum/rbln/transformers/models/gemma2/configuration_gemma2.py +45 -0
  78. optimum/rbln/transformers/models/gemma2/gemma2_architecture.py +83 -0
  79. optimum/rbln/transformers/models/gemma2/modeling_gemma2.py +101 -0
  80. optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +23 -19
  81. optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py +42 -70
  82. optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +46 -31
  83. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +8 -34
  84. optimum/rbln/transformers/models/gpt_oss/__init__.py +16 -0
  85. optimum/rbln/transformers/models/gpt_oss/configuration_gpt_oss.py +41 -0
  86. optimum/rbln/transformers/models/gpt_oss/gpt_oss_architecture.py +122 -0
  87. optimum/rbln/transformers/models/gpt_oss/modeling_gpt_oss.py +165 -0
  88. optimum/rbln/transformers/models/grounding_dino/configuration_grounding_dino.py +8 -5
  89. optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +7 -5
  90. optimum/rbln/transformers/models/grounding_dino/modeling_grounding_dino.py +24 -9
  91. optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +3 -5
  92. optimum/rbln/transformers/models/llava/modeling_llava.py +37 -26
  93. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +3 -5
  94. optimum/rbln/transformers/models/midm/midm_architecture.py +29 -22
  95. optimum/rbln/transformers/models/mistral/modeling_mistral.py +0 -22
  96. optimum/rbln/transformers/models/opt/modeling_opt.py +2 -2
  97. optimum/rbln/transformers/models/opt/opt_architecture.py +1 -44
  98. optimum/rbln/transformers/models/paligemma/__init__.py +16 -0
  99. optimum/rbln/transformers/models/paligemma/configuration_paligemma.py +129 -0
  100. optimum/rbln/transformers/models/paligemma/modeling_paligemma.py +564 -0
  101. optimum/rbln/transformers/models/pegasus/modeling_pegasus.py +1 -1
  102. optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +24 -24
  103. optimum/rbln/transformers/models/phi/phi_architecture.py +13 -21
  104. optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +13 -1
  105. optimum/rbln/transformers/models/pixtral/pixtral_architecture.py +2 -2
  106. optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +0 -28
  107. optimum/rbln/transformers/models/qwen2_5_vl/__init__.py +6 -1
  108. optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +11 -1
  109. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +278 -130
  110. optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +43 -39
  111. optimum/rbln/transformers/models/qwen2_moe/__init__.py +16 -0
  112. optimum/rbln/transformers/models/qwen2_moe/configuration_qwen2_moe.py +38 -0
  113. optimum/rbln/transformers/models/qwen2_moe/modeling_qwen2_moe.py +68 -0
  114. optimum/rbln/transformers/models/qwen2_moe/qwen2_moe_architecture.py +94 -0
  115. optimum/rbln/transformers/models/qwen2_vl/__init__.py +6 -1
  116. optimum/rbln/transformers/models/qwen2_vl/configuration_qwen2_vl.py +11 -1
  117. optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py +268 -111
  118. optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +27 -35
  119. optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +0 -20
  120. optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +7 -7
  121. optimum/rbln/transformers/models/qwen3_moe/__init__.py +16 -0
  122. optimum/rbln/transformers/models/qwen3_moe/configuration_qwen3_moe.py +38 -0
  123. optimum/rbln/transformers/models/qwen3_moe/modeling_qwen3_moe.py +68 -0
  124. optimum/rbln/transformers/models/qwen3_moe/qwen3_moe_architecture.py +100 -0
  125. optimum/rbln/transformers/models/resnet/configuration_resnet.py +17 -0
  126. optimum/rbln/transformers/models/resnet/modeling_resnet.py +73 -0
  127. optimum/rbln/transformers/models/roberta/modeling_roberta.py +33 -0
  128. optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py +2 -4
  129. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +36 -12
  130. optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +14 -12
  131. optimum/rbln/transformers/models/siglip/modeling_siglip.py +21 -19
  132. optimum/rbln/transformers/models/swin/configuration_swin.py +1 -6
  133. optimum/rbln/transformers/models/swin/modeling_swin.py +17 -4
  134. optimum/rbln/transformers/models/t5/modeling_t5.py +2 -2
  135. optimum/rbln/transformers/models/t5/t5_architecture.py +16 -17
  136. optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +25 -10
  137. optimum/rbln/transformers/models/time_series_transformer/time_series_transformers_architecture.py +0 -3
  138. optimum/rbln/transformers/models/vit/modeling_vit.py +19 -0
  139. optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec2.py +15 -3
  140. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +60 -8
  141. optimum/rbln/transformers/models/whisper/generation_whisper.py +48 -14
  142. optimum/rbln/transformers/models/whisper/modeling_whisper.py +2 -2
  143. optimum/rbln/transformers/models/whisper/whisper_architecture.py +0 -3
  144. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +53 -0
  145. optimum/rbln/transformers/utils/rbln_quantization.py +29 -12
  146. optimum/rbln/utils/deprecation.py +213 -0
  147. optimum/rbln/utils/hub.py +14 -3
  148. optimum/rbln/utils/import_utils.py +23 -2
  149. optimum/rbln/utils/runtime_utils.py +42 -6
  150. optimum/rbln/utils/submodule.py +27 -1
  151. {optimum_rbln-0.9.3rc0.dist-info → optimum_rbln-0.9.5a4.dist-info}/METADATA +6 -6
  152. {optimum_rbln-0.9.3rc0.dist-info → optimum_rbln-0.9.5a4.dist-info}/RECORD +155 -129
  153. {optimum_rbln-0.9.3rc0.dist-info → optimum_rbln-0.9.5a4.dist-info}/WHEEL +1 -1
  154. optimum/rbln/transformers/models/colqwen2/colqwen2_architecture.py +0 -233
  155. optimum/rbln/utils/depreacate_utils.py +0 -16
  156. {optimum_rbln-0.9.3rc0.dist-info → optimum_rbln-0.9.5a4.dist-info}/entry_points.txt +0 -0
  157. {optimum_rbln-0.9.3rc0.dist-info → optimum_rbln-0.9.5a4.dist-info}/licenses/LICENSE +0 -0
@@ -99,9 +99,7 @@ class RBLNGemma3ForConditionalGeneration(RBLNModel, RBLNDecoderOnlyGenerationMix
99
99
  return True
100
100
 
101
101
  @classmethod
102
- def get_pytorch_model(cls, *args, **kwargs):
103
- model = super().get_pytorch_model(*args, **kwargs)
104
-
102
+ def _reconstruct_model_if_needed(cls, model: "PreTrainedModel"):
105
103
  with no_init_weights():
106
104
  model_cls_name = model.model.language_model.__class__.__name__
107
105
  causal_model_cls_name = model_cls_name.replace("TextModel", "ForCausalLM")
@@ -135,7 +133,7 @@ class RBLNGemma3ForConditionalGeneration(RBLNModel, RBLNDecoderOnlyGenerationMix
135
133
  return self.language_model.get_input_embeddings()
136
134
 
137
135
  @classmethod
138
- def wrap_model_if_needed(cls, model: "PreTrainedModel", rbln_config: RBLNModelConfig):
136
+ def _wrap_model_if_needed(cls, model: "PreTrainedModel", rbln_config: RBLNModelConfig):
139
137
  return model.multi_modal_projector
140
138
 
141
139
  @classmethod
@@ -301,28 +299,60 @@ class RBLNGemma3ForConditionalGeneration(RBLNModel, RBLNDecoderOnlyGenerationMix
301
299
  generate_idx: Optional[torch.Tensor] = None,
302
300
  padded_cache_lengths: Optional[torch.Tensor] = None,
303
301
  position_ids: Optional[torch.Tensor] = None,
302
+ output_hidden_states: Optional[bool] = None,
304
303
  **lm_kwargs: Dict[str, Any],
305
304
  ) -> Union[Tuple, RBLNDecoderOnlyOutput]:
305
+ output_hidden_states = (
306
+ output_hidden_states
307
+ if output_hidden_states is not None
308
+ else self.rbln_config.language_model.output_hidden_states
309
+ )
310
+ if output_hidden_states != self.rbln_config.language_model.output_hidden_states:
311
+ raise ValueError(
312
+ f"Variable output_hidden_states {output_hidden_states} is not equal to rbln_config.language_model.output_hidden_states {self.rbln_config.language_model.output_hidden_states} "
313
+ f"Please compile again with the correct argument."
314
+ )
315
+
306
316
  # prefill
307
317
  if cache_position is None:
308
318
  logits = []
309
319
  inputs_embeds = self._preprocess_prefill(input_ids, inputs_embeds, pixel_values)
310
320
  batch_size = inputs_embeds.shape[0]
311
321
 
322
+ all_hidden_states = (
323
+ tuple(
324
+ torch.zeros(
325
+ batch_size,
326
+ inputs_embeds.shape[1],
327
+ self.config.text_config.hidden_size,
328
+ dtype=self.rbln_config.dtype,
329
+ )
330
+ for _ in range(self.config.text_config.num_hidden_layers + 1)
331
+ )
332
+ if self.rbln_config.language_model.output_hidden_states
333
+ else None
334
+ )
335
+
312
336
  for b_idx in range(batch_size):
313
337
  cache_position = torch.arange(0, generate_idx[b_idx].item(), dtype=torch.int32).unsqueeze(0)
314
338
  token_type_id = token_type_ids[b_idx : b_idx + 1, attention_mask[b_idx].bool()]
315
339
  cache_position = self.get_padded_cache_position(cache_position, token_type_id)
316
340
 
317
- output = self.language_model.prefill_decoder(
341
+ outputs = self.language_model.prefill_decoder(
318
342
  inputs_embeds=inputs_embeds[b_idx : b_idx + 1],
319
343
  attention_mask=attention_mask[b_idx],
320
344
  cache_position=cache_position,
321
345
  batch_idx=b_idx,
322
346
  token_type_ids=token_type_ids[b_idx : b_idx + 1], # do not pass token_type_id
323
347
  )
324
- padded_cache_lengths[b_idx] += output.padded_cache_lengths
325
- logits.append(output.logits)
348
+ padded_cache_lengths[b_idx] += outputs.padded_cache_lengths
349
+ logits.append(outputs.logits)
350
+ if self.rbln_config.language_model.output_hidden_states:
351
+ for l_idx in range(self.config.text_config.num_hidden_layers + 1):
352
+ mask_indices = torch.nonzero(attention_mask[b_idx], as_tuple=True)[0]
353
+ all_hidden_states[l_idx][b_idx].index_copy_(
354
+ dim=0, index=mask_indices, source=outputs.hidden_states[l_idx][0]
355
+ )
326
356
 
327
357
  logits = torch.cat(logits, dim=0)
328
358
  # decoder
@@ -336,15 +366,20 @@ class RBLNGemma3ForConditionalGeneration(RBLNModel, RBLNDecoderOnlyGenerationMix
336
366
  f"Please run your model with one of these batch sizes or add support for batch size {batch_size}."
337
367
  )
338
368
 
339
- logits = self.language_model.decoders[batch_size](
369
+ outputs = self.language_model.decoders[batch_size](
340
370
  input_ids=input_ids,
341
371
  inputs_embeds=inputs_embeds,
342
372
  cache_position=cache_position,
343
373
  position_ids=position_ids if self.rbln_config.language_model.use_position_ids else None,
344
- ).logits
374
+ )
375
+ logits = outputs.logits
376
+ all_hidden_states = outputs.hidden_states
345
377
 
346
378
  return RBLNDecoderOnlyOutput(
347
- logits=logits, generate_idx=generate_idx, padded_cache_lengths=padded_cache_lengths
379
+ logits=logits,
380
+ generate_idx=generate_idx,
381
+ padded_cache_lengths=padded_cache_lengths,
382
+ hidden_states=all_hidden_states,
348
383
  )
349
384
 
350
385
 
@@ -405,26 +440,6 @@ class RBLNGemma3ForCausalLM(RBLNDecoderOnlyModelForCausalLM):
405
440
  )
406
441
  return embed_tokens
407
442
 
408
- @classmethod
409
- def _update_sliding_window_config(cls, model_config: PretrainedConfig, rbln_config: RBLNGemma3ForCausalLMConfig):
410
- sliding_window = getattr(model_config, "sliding_window", None)
411
- sliding_window_pattern = getattr(model_config, "sliding_window_pattern", None)
412
- if sliding_window_pattern is None:
413
- if hasattr(model_config, "layer_types"):
414
- first_full_attention_index = model_config.layer_types.index("full_attention")
415
- sliding_window_pattern = first_full_attention_index + 1
416
- else:
417
- raise ValueError("Cannot determine sliding_window_pattern from model_config")
418
-
419
- if sliding_window_pattern <= model_config.num_hidden_layers:
420
- rbln_config.cache_impl = "hybrid"
421
- rbln_config.sliding_window = sliding_window
422
- rbln_config.sliding_window_layers = [
423
- i for i in range(model_config.num_hidden_layers) if (i + 1) % sliding_window_pattern > 0
424
- ]
425
-
426
- return rbln_config
427
-
428
443
  @classmethod
429
444
  def _update_submodule_config(
430
445
  cls,
@@ -482,7 +497,7 @@ class RBLNGemma3ForCausalLM(RBLNDecoderOnlyModelForCausalLM):
482
497
  @classmethod
483
498
  @torch.inference_mode()
484
499
  def get_compiled_model(cls, model: "PreTrainedModel", rbln_config: RBLNGemma3ForCausalLMConfig):
485
- wrapped_model = cls.wrap_model_if_needed(model, rbln_config)
500
+ wrapped_model = cls._wrap_model_if_needed(model, rbln_config)
486
501
 
487
502
  rbln_compile_configs = rbln_config.compile_cfgs
488
503
  prefill_compile_config = rbln_compile_configs[0]
@@ -20,8 +20,6 @@ import torch.nn as nn
20
20
 
21
21
  from ..decoderonly.decoderonly_architecture import (
22
22
  DecoderOnlyAttention,
23
- DecoderOnlyLayer,
24
- DecoderOnlyModel,
25
23
  DecoderOnlyWrapper,
26
24
  )
27
25
 
@@ -34,12 +32,6 @@ class GPT2Wrapper(DecoderOnlyWrapper):
34
32
  def get_rbln_attn_class(self):
35
33
  return GPT2Attention
36
34
 
37
- def get_rbln_layer_class(self):
38
- return GPT2Layer
39
-
40
- def get_rbln_model_class(self):
41
- return GPT2Model
42
-
43
35
  def get_attn_layer(self, layer: nn.Module):
44
36
  return layer.attn
45
37
 
@@ -50,30 +42,12 @@ class GPT2Wrapper(DecoderOnlyWrapper):
50
42
  return model.transformer.h if self.is_causal_lm else model.h
51
43
 
52
44
 
53
- class GPT2Model(DecoderOnlyModel):
54
- def get_last_layernorm(self) -> nn.LayerNorm:
55
- return self._original_mod.ln_f
56
-
57
- def get_embedding(self) -> nn.Embedding:
58
- return self._original_mod.wte
59
-
60
- def get_pos_embedding(self) -> nn.Embedding:
61
- return self._original_mod.wpe
62
-
63
-
64
- class GPT2Layer(DecoderOnlyLayer):
65
- def get_pre_attention_layernorm(self) -> nn.LayerNorm:
66
- return self._original_mod.ln_1
67
-
68
- def get_post_attention_layernorm(self) -> nn.LayerNorm:
69
- return self._original_mod.ln_2
70
-
71
-
72
45
  class GPT2Attention(DecoderOnlyAttention):
73
- def __post_init__(self):
74
- self.c_attn = self._original_mod.c_attn
75
- self.o_proj = self._original_mod.c_proj
76
- self.split_size = self._original_mod.split_size
46
+ def __post_init__(self, self_attn):
47
+ self.c_attn = self_attn.c_attn
48
+ self.o_proj = self_attn.c_proj
49
+ self.split_size = self_attn.split_size
50
+ self.num_key_value_heads = self_attn.num_heads
77
51
 
78
52
  def projection(self, hidden_states, lora_int_id) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
79
53
  if lora_int_id is not None:
@@ -82,12 +56,12 @@ class GPT2Attention(DecoderOnlyAttention):
82
56
  query_states, key_states, value_states = self.c_attn(hidden_states).split(self.split_size, dim=2)
83
57
  return query_states, key_states, value_states
84
58
 
85
- def get_attn_scale(self):
59
+ def get_attn_scale(self, self_attn):
86
60
  scale = 1.0
87
- if self._original_mod.scale_attn_weights:
61
+ if self_attn.scale_attn_weights:
88
62
  scale /= math.sqrt(self.head_dim)
89
63
 
90
- if self._original_mod.scale_attn_by_inverse_layer_idx:
64
+ if self_attn.scale_attn_by_inverse_layer_idx:
91
65
  scale /= 1 + self.layer_idx
92
66
 
93
67
  return scale
@@ -0,0 +1,16 @@
1
+ # Copyright 2025 Rebellions Inc. All rights reserved.
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at:
6
+
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from .configuration_gpt_oss import RBLNGptOssForCausalLMConfig
16
+ from .modeling_gpt_oss import RBLNGptOssForCausalLM
@@ -0,0 +1,41 @@
1
+ # Copyright 2025 Rebellions Inc. All rights reserved.
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at:
6
+
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from ..decoderonly.configuration_decoderonly import RBLNDecoderOnlyModelForCausalLMConfig
16
+
17
+
18
class RBLNGptOssForCausalLMConfig(RBLNDecoderOnlyModelForCausalLMConfig):
    """
    RBLN configuration for GPT-OSS causal language models.

    No options are added on top of RBLNDecoderOnlyModelForCausalLMConfig; this
    subclass only gives GPT-OSS its own configuration type.

    Example usage:
    ```python
    from optimum.rbln import RBLNGptOssForCausalLM, RBLNGptOssForCausalLMConfig

    # Build the configuration
    config = RBLNGptOssForCausalLMConfig(
        batch_size=1,
        tensor_parallel_size=4
    )

    # Compile and load the model with it
    model = RBLNGptOssForCausalLM.from_pretrained(
        "openai/gpt-oss-20b",
        export=True,
        rbln_config=config
    )
    ```
    """
@@ -0,0 +1,122 @@
1
+ # Copyright 2025 Rebellions Inc. All rights reserved.
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at:
6
+
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ from typing import Optional
17
+
18
+ import torch
19
+ import torch.nn.functional as F
20
+ from torch import nn
21
+
22
+ from ..decoderonly.configuration_decoderonly import RBLNLoRAConfig
23
+ from ..decoderonly.decoderonly_architecture import (
24
+ DecoderOnlyAttention,
25
+ DecoderOnlyLayer,
26
+ DecoderOnlyWrapper,
27
+ )
28
+
29
+
30
class RBLNGptOssWrapper(DecoderOnlyWrapper):
    """Decoder-only wrapper whose decoder layers are built with ``RBLNGptOssLayer``."""

    def get_rbln_layer_class(self):
        """Return the class used to wrap each decoder layer."""
        layer_cls = RBLNGptOssLayer
        return layer_cls
33
+
34
+
35
class RBLNGptOssLayer(DecoderOnlyLayer):
    """Decoder layer whose MLP is replaced by ``RBLNGptOssMLP``.

    Args:
        layer: Original decoder layer; its ``mlp`` attribute is wrapped.
        self_attn: Attention module, forwarded to ``DecoderOnlyLayer``.
        lora_config: Optional LoRA configuration, forwarded to ``DecoderOnlyLayer``.
    """

    def __init__(self, layer, self_attn: DecoderOnlyAttention, lora_config: Optional[RBLNLoRAConfig] = None):
        super().__init__(layer, self_attn, lora_config)
        # Overrides whatever ``mlp`` the base class may have set with the RBLN MoE MLP.
        rbln_mlp = RBLNGptOssMLP(layer.mlp)
        self.mlp = rbln_mlp

    def get_mlp(self) -> nn.Module:
        """Return the RBLN MoE MLP attached in ``__init__``."""
        return self.mlp
42
+
43
+
44
class RBLNGptOssTopKRouter(nn.Module):
    """Router that projects hidden states to per-expert logits.

    The weight and bias are shared (not copied) from the wrapped router module.
    """

    def __init__(self, model):
        super().__init__()
        # Keep the assignment order: it fixes the parameter registration order.
        self.weight = model.weight
        self.bias = model.bias

    def forward(self, hidden_states):
        """Apply ``hidden_states @ weight.T + bias``; the output's last dim is ``weight.shape[0]``."""
        logits = F.linear(hidden_states, self.weight, bias=self.bias)
        return logits
52
+
53
+
54
class RBLNGptOssExperts(nn.Module):
    """MXFP4 expert weights split into gate/up/down buffers for the RBLN MoE kernel.

    The fused ``gate_up_proj_*`` tensors of the source module are de-interleaved:
    even rows along dim 1 become the gate projection, odd rows the up projection.

    Args:
        model: Source experts module exposing ``intermediate_size``, ``num_experts``,
            ``hidden_size``, the ``gate_up_proj_*`` / ``down_proj_*`` tensors,
            ``alpha`` and ``limit``.
        top_k: Number of experts selected per token, passed as ``k`` to the kernel.
    """

    def __init__(self, model, top_k: Optional[int] = None):
        super().__init__()
        self.intermediate_size = model.intermediate_size
        self.num_experts = model.num_experts
        self.hidden_size = model.hidden_size

        n_exp = self.num_experts
        inter = self.intermediate_size
        fused_blocks = model.gate_up_proj_blocks.data
        fused_scales = model.gate_up_proj_scales.data
        fused_bias = model.gate_up_proj_bias.data

        # (row offset, buffer names) -- gate rows are even, up rows are odd.
        # Registration order is preserved: gate_* buffers first, then up_*.
        halves = (
            (0, "gate_proj_blocks", "gate_proj_scales", "gate_proj_bias"),
            (1, "up_proj_blocks", "up_proj_scales", "up_proj_bias"),
        )
        for offset, blocks_name, scales_name, bias_name in halves:
            self.register_buffer(blocks_name, fused_blocks[:, offset::2, :, :].reshape(n_exp, inter, -1))
            self.register_buffer(scales_name, fused_scales[:, offset::2, :])
            self.register_buffer(bias_name, fused_bias[:, offset::2].reshape(n_exp, inter))

        self.register_buffer("down_proj_blocks", model.down_proj_blocks.data.reshape(n_exp, self.hidden_size, -1))
        self.register_buffer("down_proj_scales", model.down_proj_scales.data)
        self.register_buffer("down_proj_bias", model.down_proj_bias.data)

        self.alpha = model.alpha  # 1.702
        self.limit = model.limit  # 7.0
        self.top_k = top_k

    def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor) -> torch.Tensor:
        """Run the fused RBLN MXFP4 MoE GLU kernel over the routed experts."""
        act_dtype = hidden_states.dtype
        alpha = torch.tensor(self.alpha, dtype=act_dtype)
        limit = torch.tensor(self.limit, dtype=act_dtype)
        return torch.ops.rbln_custom_ops.custom_moe_glu_mxfp4(
            hidden_states,
            self.gate_proj_blocks,
            self.gate_proj_scales,
            self.gate_proj_bias,
            self.up_proj_blocks,
            self.up_proj_scales,
            self.up_proj_bias,
            self.down_proj_blocks,
            self.down_proj_scales,
            self.down_proj_bias,
            router_logits,
            alpha,
            limit,
            k=self.top_k,
            post_norm=True,
        )
108
+
109
+
110
class RBLNGptOssMLP(nn.Module):
    """MoE MLP: a top-k router followed by the MXFP4 experts kernel.

    Args:
        model: Original MLP exposing ``router`` (with ``weight``, ``bias``, ``top_k``)
            and ``experts``.
    """

    def __init__(self, model):
        super().__init__()
        src_router = model.router
        self.router = RBLNGptOssTopKRouter(src_router)
        self.experts = RBLNGptOssExperts(model.experts, top_k=src_router.top_k)

    def forward(self, hidden_states):
        """Route and mix a 3-D ``(batch, seq, hidden)`` input; the output has the same shape."""
        batch_size, sequence_length, hidden_dim = hidden_states.shape
        # The router and experts operate on a flat (tokens, hidden) view.
        tokens = hidden_states.view(-1, hidden_dim)
        logits = self.router(tokens)
        mixed = self.experts(tokens, router_logits=logits)
        return mixed.reshape(batch_size, sequence_length, hidden_dim)
@@ -0,0 +1,165 @@
1
+ # Copyright 2025 Rebellions Inc. All rights reserved.
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at:
6
+
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import TYPE_CHECKING, Optional, Union
16
+
17
+ import torch
18
+ from safetensors.torch import load_file
19
+ from transformers import AutoConfig, AutoModelForCausalLM, PretrainedConfig
20
+ from transformers.integrations.mxfp4 import Mxfp4GptOssExperts
21
+ from transformers.modeling_utils import PreTrainedModel, no_init_weights
22
+
23
+ from ....utils.logging import get_logger
24
+ from ...models.decoderonly import (
25
+ RBLNDecoderOnlyModelConfig,
26
+ RBLNDecoderOnlyModelForCausalLM,
27
+ RBLNDecoderOnlyModelForCausalLMConfig,
28
+ )
29
+ from ...utils.rbln_quantization import load_weight_files
30
+ from .gpt_oss_architecture import RBLNGptOssWrapper
31
+
32
+
33
+ if TYPE_CHECKING:
34
+ from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer, PreTrainedModel
35
+
36
+ logger = get_logger(__name__)
37
+
38
+
39
class RBLNGptOssForCausalLM(RBLNDecoderOnlyModelForCausalLM):
    """
    The GPT-OSS Model transformer with a language modeling head (linear layer) on top.
    This model inherits from [`RBLNDecoderOnlyModelForCausalLM`]. Check the superclass documentation for the generic methods the library implements for all its models.

    A class to convert and run pre-trained transformers based GptOssForCausalLM model on RBLN devices.
    It implements the methods to convert a pre-trained transformers GptOssForCausalLM model into an RBLN transformer model by:
    - transferring the checkpoint weights of the original into an optimized RBLN graph,
    - compiling the resulting graph using the RBLN compiler.

    **Configuration:**
    This model uses [`RBLNGptOssForCausalLMConfig`] for configuration. When calling methods like `from_pretrained` or `from_model`,
    the `rbln_config` parameter should be an instance of [`RBLNGptOssForCausalLMConfig`] or a dictionary conforming to its structure.

    See the [`RBLNGptOssForCausalLMConfig`] class for all available configuration options.

    Examples:
    ```python
    from optimum.rbln import RBLNGptOssForCausalLM

    # Simple usage using rbln_* arguments
    # `max_seq_len` is automatically inferred from the model config
    model = RBLNGptOssForCausalLM.from_pretrained(
        "openai/gpt-oss-20b",
        export=True,
        rbln_batch_size=1,
        rbln_tensor_parallel_size=4,
    )


    # Using a config dictionary
    rbln_config = {
        "batch_size": 1,
        "tensor_parallel_size": 4,
    }
    model = RBLNGptOssForCausalLM.from_pretrained(
        "openai/gpt-oss-20b",
        export=True,
        rbln_config=rbln_config
    )


    # Using a RBLNGptOssForCausalLMConfig instance (recommended for type checking)
    from optimum.rbln import RBLNGptOssForCausalLMConfig

    config = RBLNGptOssForCausalLMConfig(
        batch_size=1,
        tensor_parallel_size=4
    )
    model = RBLNGptOssForCausalLM.from_pretrained(
        "openai/gpt-oss-20b",
        export=True,
        rbln_config=config
    )
    ```
    """

    _decoder_wrapper_cls = RBLNGptOssWrapper

    @staticmethod
    def _get_dtype(
        dtype: Optional[Union[str, torch.dtype]] = None,
        torch_dtype: Optional[Union[str, torch.dtype]] = None,
    ) -> Union[str, torch.dtype]:
        """Resolve the dtype used to instantiate the PyTorch model.

        Args:
            dtype: Requested dtype; takes precedence when both arguments are given.
            torch_dtype: Deprecated alias of ``dtype`` kept for backward compatibility.
                Passing it logs a one-time deprecation warning.

        Returns:
            The given dtype, or ``torch.bfloat16`` when neither argument is set or
            the value is ``"auto"``.
        """
        if torch_dtype is not None:
            logger.warning_once("`torch_dtype` is deprecated! Use `dtype` instead!")
        dtype = dtype if dtype is not None else torch_dtype

        # Fall back to bfloat16, the mxfp4 quantizer's default dtype.
        if dtype is None or dtype == "auto":
            dtype = torch.bfloat16

        return dtype

    @classmethod
    def get_pytorch_model(
        cls,
        model_id: str,
        *args,
        rbln_config: Optional[RBLNDecoderOnlyModelConfig] = None,
        dtype: Optional[Union[str, torch.dtype]] = None,
        torch_dtype: Optional[Union[str, torch.dtype]] = None,
        config: Optional[PretrainedConfig] = None,
        **kwargs,
    ) -> PreTrainedModel:
        """Build the PyTorch GPT-OSS model with MXFP4 experts and load its checkpoint.

        Steps:
        1. Read the safetensors files of ``model_id``, skipping those matched by the
           ``"original"`` exception keyword.
        2. Instantiate the model from its config under ``no_init_weights()``.
        3. Swap every ``GptOssExperts`` for ``Mxfp4GptOssExperts``.
        4. Load the state dict non-strictly.

        Args:
            model_id: Hub id or local path of the checkpoint.
            *args: Accepted for interface compatibility; unused.
            rbln_config: Accepted for interface compatibility; unused.
            dtype: Model dtype, resolved by ``_get_dtype``.
            torch_dtype: Deprecated alias of ``dtype``.
            config: Pre-loaded model config. When None, it is loaded with
                ``AutoConfig.from_pretrained(model_id)``.
            **kwargs: Forwarded to ``AutoModelForCausalLM.from_config`` when
                ``config`` is given (see the note below for the other path).

        Returns:
            The PyTorch model with checkpoint weights loaded.
        """
        safetensor_files = load_weight_files(model_id, exception_keywords=["original"])
        state_dict = {k: v for f in safetensor_files for k, v in load_file(f).items()}

        if config is None:
            # NOTE(review): ``kwargs`` is rebound to the unused kwargs of this call,
            # which receives none of the caller's kwargs. Caller kwargs are therefore
            # dropped on this path. Confirm this is intended before changing it.
            config, kwargs = AutoConfig.from_pretrained(model_id, return_unused_kwargs=True)

        dtype = cls._get_dtype(dtype, torch_dtype)

        with no_init_weights():
            model = AutoModelForCausalLM.from_config(config, dtype=dtype, **kwargs)

        _replace_with_mxfp4_linear(model, config)

        # The model was built under ``no_init_weights()``, so parameters the
        # checkpoint does not cover may be left uninitialized. Report key
        # mismatches instead of discarding the load result silently.
        load_result = model.load_state_dict(state_dict, strict=False)
        if load_result.missing_keys:
            logger.warning(
                "Missing keys when loading weights for %s (may be left uninitialized): %s",
                model_id,
                load_result.missing_keys,
            )
        if load_result.unexpected_keys:
            logger.warning(
                "Unexpected keys ignored when loading weights for %s: %s",
                model_id,
                load_result.unexpected_keys,
            )

        return model

    @classmethod
    def _update_rbln_config(
        cls,
        preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]] = None,
        model: Optional["PreTrainedModel"] = None,
        model_config: Optional["PretrainedConfig"] = None,
        rbln_config: Optional[RBLNDecoderOnlyModelForCausalLMConfig] = None,
    ) -> RBLNDecoderOnlyModelForCausalLMConfig:
        """Run the base config update, then reject settings GPT-OSS cannot support.

        Raises:
            ValueError: If ``rbln_config.use_attention_mask`` is enabled.
        """
        rbln_config = super()._update_rbln_config(preprocessors, model, model_config, rbln_config)

        if rbln_config.use_attention_mask:
            raise ValueError(
                "use_attention_mask is not supported for GPT-OSS because custom attention does not support attention sink for masked attention"
            )

        return rbln_config
155
+
156
+
157
def _replace_with_mxfp4_linear(
    model,
    config,
):
    """Recursively replace ``GptOssExperts`` submodules with ``Mxfp4GptOssExperts``.

    Modules are matched by class name. The replacement is done in place on
    ``model``, and the function returns None.

    Args:
        model: Module whose children are scanned recursively.
        config: Model config passed to the ``Mxfp4GptOssExperts`` constructor.
    """
    # Snapshot the children so entries can be swapped safely while iterating.
    for name, module in list(model.named_children()):
        if module.__class__.__name__ == "GptOssExperts":
            model._modules[name] = Mxfp4GptOssExperts(config)
            # The original module is discarded, so don't descend into it.
            continue
        # Recursing into a leaf module is a no-op, so no children check is needed.
        _replace_with_mxfp4_linear(module, config)
@@ -50,11 +50,14 @@ class RBLNGroundingDinoForObjectDetectionConfig(RBLNImageModelConfig):
50
50
  Raises:
51
51
  ValueError: If batch_size is not a positive integer.
52
52
  """
53
- super().__init__(**kwargs)
54
- self.encoder = encoder
55
- self.decoder = decoder
56
- self.text_backbone = text_backbone
57
- self.backbone = backbone
53
+
54
+ super().__init__(batch_size=batch_size, **kwargs)
55
+ self.encoder = self.initialize_submodule_config(submodule_config=encoder, batch_size=self.batch_size)
56
+ self.decoder = self.initialize_submodule_config(submodule_config=decoder, batch_size=self.batch_size)
57
+ self.text_backbone = self.initialize_submodule_config(
58
+ submodule_config=text_backbone, batch_size=self.batch_size
59
+ )
60
+ self.backbone = self.initialize_submodule_config(submodule_config=backbone, batch_size=self.batch_size)
58
61
  self.output_attentions = output_attentions if output_attentions is not None else False
59
62
  self.output_hidden_states = output_hidden_states if output_hidden_states is not None else False
60
63
 
@@ -150,7 +150,7 @@ class _GroundingDinoEncoder(torch.nn.Module):
150
150
  all_attn_fused_vision = () if output_attentions else None
151
151
  all_attn_enhanced_text = () if output_attentions else None
152
152
  all_attn_deformable = () if output_attentions else None
153
- for i, encoder_layer in enumerate(self.layers):
153
+ for _, encoder_layer in enumerate(self.layers):
154
154
  if output_hidden_states:
155
155
  encoder_vision_states += (vision_features,)
156
156
  encoder_text_states += (text_features,)
@@ -509,10 +509,12 @@ class _GroundingDinoBiMultiHeadAttention(torch.nn.Module):
509
509
 
510
510
  # mask vision for language
511
511
  if vision_attention_mask is not None:
512
- # RBLN FIX: bool tensor to float tensor
513
- mask = vision_attention_mask * torch.finfo(torch.float16).min
514
- text_attn_weights = text_attn_weights.transpose(1, 2) + mask
515
- text_attn_weights = text_attn_weights.transpose(1, 2)
512
+ # RBLN FIX: bool tensor to float tensor, broadcast across heads and src_len
513
+ mask = vision_attention_mask
514
+ if mask.dim() == 3:
515
+ mask = mask[..., 0]
516
+ mask = mask[:, None, None, :].repeat(1, self.num_heads, 1, 1).flatten(0, 1)
517
+ text_attn_weights = text_attn_weights + mask * torch.finfo(text_attn_weights.dtype).min
516
518
 
517
519
  text_attn_weights = text_attn_weights.softmax(dim=-1)
518
520