optimum-rbln 0.9.3rc0__py3-none-any.whl → 0.9.5a4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (157)
  1. optimum/rbln/__init__.py +48 -0
  2. optimum/rbln/__version__.py +2 -2
  3. optimum/rbln/configuration_utils.py +50 -21
  4. optimum/rbln/diffusers/__init__.py +12 -0
  5. optimum/rbln/diffusers/configurations/__init__.py +3 -0
  6. optimum/rbln/diffusers/configurations/models/__init__.py +2 -0
  7. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_temporal_decoder.py +67 -0
  8. optimum/rbln/diffusers/configurations/models/configuration_unet_spatio_temporal_condition.py +59 -0
  9. optimum/rbln/diffusers/configurations/pipelines/__init__.py +3 -0
  10. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_video_diffusion.py +114 -0
  11. optimum/rbln/diffusers/modeling_diffusers.py +1 -1
  12. optimum/rbln/diffusers/models/__init__.py +17 -3
  13. optimum/rbln/diffusers/models/autoencoders/__init__.py +1 -0
  14. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py +3 -3
  15. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +275 -0
  16. optimum/rbln/diffusers/models/autoencoders/vae.py +27 -8
  17. optimum/rbln/diffusers/models/controlnet.py +17 -2
  18. optimum/rbln/diffusers/models/transformers/prior_transformer.py +16 -2
  19. optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +16 -1
  20. optimum/rbln/diffusers/models/transformers/transformer_sd3.py +14 -1
  21. optimum/rbln/diffusers/models/unets/__init__.py +1 -0
  22. optimum/rbln/diffusers/models/unets/unet_2d_condition.py +18 -2
  23. optimum/rbln/diffusers/models/unets/unet_spatio_temporal_condition.py +201 -0
  24. optimum/rbln/diffusers/pipelines/__init__.py +4 -0
  25. optimum/rbln/diffusers/pipelines/auto_pipeline.py +2 -2
  26. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +20 -0
  27. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +13 -4
  28. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +13 -4
  29. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +13 -4
  30. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +13 -4
  31. optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +1 -1
  32. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +1 -1
  33. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +1 -2
  34. optimum/rbln/diffusers/pipelines/stable_video_diffusion/__init__.py +15 -0
  35. optimum/rbln/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +46 -0
  36. optimum/rbln/modeling.py +20 -45
  37. optimum/rbln/modeling_base.py +18 -14
  38. optimum/rbln/ops/__init__.py +1 -0
  39. optimum/rbln/ops/attn.py +10 -0
  40. optimum/rbln/ops/flash_attn.py +8 -0
  41. optimum/rbln/ops/moe.py +180 -0
  42. optimum/rbln/ops/sliding_window_attn.py +9 -0
  43. optimum/rbln/transformers/__init__.py +36 -0
  44. optimum/rbln/transformers/configuration_generic.py +0 -27
  45. optimum/rbln/transformers/modeling_attention_utils.py +156 -127
  46. optimum/rbln/transformers/modeling_generic.py +2 -61
  47. optimum/rbln/transformers/modeling_outputs.py +26 -0
  48. optimum/rbln/transformers/modeling_rope_utils.py +78 -42
  49. optimum/rbln/transformers/models/__init__.py +28 -0
  50. optimum/rbln/transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py +28 -2
  51. optimum/rbln/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +68 -5
  52. optimum/rbln/transformers/models/auto/auto_factory.py +1 -0
  53. optimum/rbln/transformers/models/bart/bart_architecture.py +24 -24
  54. optimum/rbln/transformers/models/bart/modeling_bart.py +23 -2
  55. optimum/rbln/transformers/models/bert/modeling_bert.py +86 -1
  56. optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +42 -15
  57. optimum/rbln/transformers/models/clip/modeling_clip.py +40 -2
  58. optimum/rbln/transformers/models/colpali/colpali_architecture.py +14 -20
  59. optimum/rbln/transformers/models/colpali/configuration_colpali.py +12 -17
  60. optimum/rbln/transformers/models/colpali/modeling_colpali.py +66 -221
  61. optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +38 -23
  62. optimum/rbln/transformers/models/colqwen2/modeling_colqwen2.py +107 -371
  63. optimum/rbln/transformers/models/decoderonly/__init__.py +2 -0
  64. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +128 -17
  65. optimum/rbln/transformers/models/decoderonly/configuration_lora.py +2 -2
  66. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +211 -89
  67. optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +205 -64
  68. optimum/rbln/transformers/models/decoderonly/generation_decoderonly.py +17 -9
  69. optimum/rbln/transformers/models/decoderonly/lora_architecture.py +1 -1
  70. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +194 -132
  71. optimum/rbln/transformers/models/depth_anything/modeling_depth_anything.py +17 -0
  72. optimum/rbln/transformers/models/distilbert/modeling_distilbert.py +24 -0
  73. optimum/rbln/transformers/models/dpt/modeling_dpt.py +17 -0
  74. optimum/rbln/transformers/models/exaone/exaone_architecture.py +0 -36
  75. optimum/rbln/transformers/models/gemma/gemma_architecture.py +1 -1
  76. optimum/rbln/transformers/models/gemma2/__init__.py +16 -0
  77. optimum/rbln/transformers/models/gemma2/configuration_gemma2.py +45 -0
  78. optimum/rbln/transformers/models/gemma2/gemma2_architecture.py +83 -0
  79. optimum/rbln/transformers/models/gemma2/modeling_gemma2.py +101 -0
  80. optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +23 -19
  81. optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py +42 -70
  82. optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +46 -31
  83. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +8 -34
  84. optimum/rbln/transformers/models/gpt_oss/__init__.py +16 -0
  85. optimum/rbln/transformers/models/gpt_oss/configuration_gpt_oss.py +41 -0
  86. optimum/rbln/transformers/models/gpt_oss/gpt_oss_architecture.py +122 -0
  87. optimum/rbln/transformers/models/gpt_oss/modeling_gpt_oss.py +165 -0
  88. optimum/rbln/transformers/models/grounding_dino/configuration_grounding_dino.py +8 -5
  89. optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +7 -5
  90. optimum/rbln/transformers/models/grounding_dino/modeling_grounding_dino.py +24 -9
  91. optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +3 -5
  92. optimum/rbln/transformers/models/llava/modeling_llava.py +37 -26
  93. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +3 -5
  94. optimum/rbln/transformers/models/midm/midm_architecture.py +29 -22
  95. optimum/rbln/transformers/models/mistral/modeling_mistral.py +0 -22
  96. optimum/rbln/transformers/models/opt/modeling_opt.py +2 -2
  97. optimum/rbln/transformers/models/opt/opt_architecture.py +1 -44
  98. optimum/rbln/transformers/models/paligemma/__init__.py +16 -0
  99. optimum/rbln/transformers/models/paligemma/configuration_paligemma.py +129 -0
  100. optimum/rbln/transformers/models/paligemma/modeling_paligemma.py +564 -0
  101. optimum/rbln/transformers/models/pegasus/modeling_pegasus.py +1 -1
  102. optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +24 -24
  103. optimum/rbln/transformers/models/phi/phi_architecture.py +13 -21
  104. optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +13 -1
  105. optimum/rbln/transformers/models/pixtral/pixtral_architecture.py +2 -2
  106. optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +0 -28
  107. optimum/rbln/transformers/models/qwen2_5_vl/__init__.py +6 -1
  108. optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +11 -1
  109. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +278 -130
  110. optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +43 -39
  111. optimum/rbln/transformers/models/qwen2_moe/__init__.py +16 -0
  112. optimum/rbln/transformers/models/qwen2_moe/configuration_qwen2_moe.py +38 -0
  113. optimum/rbln/transformers/models/qwen2_moe/modeling_qwen2_moe.py +68 -0
  114. optimum/rbln/transformers/models/qwen2_moe/qwen2_moe_architecture.py +94 -0
  115. optimum/rbln/transformers/models/qwen2_vl/__init__.py +6 -1
  116. optimum/rbln/transformers/models/qwen2_vl/configuration_qwen2_vl.py +11 -1
  117. optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py +268 -111
  118. optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +27 -35
  119. optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +0 -20
  120. optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +7 -7
  121. optimum/rbln/transformers/models/qwen3_moe/__init__.py +16 -0
  122. optimum/rbln/transformers/models/qwen3_moe/configuration_qwen3_moe.py +38 -0
  123. optimum/rbln/transformers/models/qwen3_moe/modeling_qwen3_moe.py +68 -0
  124. optimum/rbln/transformers/models/qwen3_moe/qwen3_moe_architecture.py +100 -0
  125. optimum/rbln/transformers/models/resnet/configuration_resnet.py +17 -0
  126. optimum/rbln/transformers/models/resnet/modeling_resnet.py +73 -0
  127. optimum/rbln/transformers/models/roberta/modeling_roberta.py +33 -0
  128. optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py +2 -4
  129. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +36 -12
  130. optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +14 -12
  131. optimum/rbln/transformers/models/siglip/modeling_siglip.py +21 -19
  132. optimum/rbln/transformers/models/swin/configuration_swin.py +1 -6
  133. optimum/rbln/transformers/models/swin/modeling_swin.py +17 -4
  134. optimum/rbln/transformers/models/t5/modeling_t5.py +2 -2
  135. optimum/rbln/transformers/models/t5/t5_architecture.py +16 -17
  136. optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +25 -10
  137. optimum/rbln/transformers/models/time_series_transformer/time_series_transformers_architecture.py +0 -3
  138. optimum/rbln/transformers/models/vit/modeling_vit.py +19 -0
  139. optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec2.py +15 -3
  140. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +60 -8
  141. optimum/rbln/transformers/models/whisper/generation_whisper.py +48 -14
  142. optimum/rbln/transformers/models/whisper/modeling_whisper.py +2 -2
  143. optimum/rbln/transformers/models/whisper/whisper_architecture.py +0 -3
  144. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +53 -0
  145. optimum/rbln/transformers/utils/rbln_quantization.py +29 -12
  146. optimum/rbln/utils/deprecation.py +213 -0
  147. optimum/rbln/utils/hub.py +14 -3
  148. optimum/rbln/utils/import_utils.py +23 -2
  149. optimum/rbln/utils/runtime_utils.py +42 -6
  150. optimum/rbln/utils/submodule.py +27 -1
  151. {optimum_rbln-0.9.3rc0.dist-info → optimum_rbln-0.9.5a4.dist-info}/METADATA +6 -6
  152. {optimum_rbln-0.9.3rc0.dist-info → optimum_rbln-0.9.5a4.dist-info}/RECORD +155 -129
  153. {optimum_rbln-0.9.3rc0.dist-info → optimum_rbln-0.9.5a4.dist-info}/WHEEL +1 -1
  154. optimum/rbln/transformers/models/colqwen2/colqwen2_architecture.py +0 -233
  155. optimum/rbln/utils/depreacate_utils.py +0 -16
  156. {optimum_rbln-0.9.3rc0.dist-info → optimum_rbln-0.9.5a4.dist-info}/entry_points.txt +0 -0
  157. {optimum_rbln-0.9.3rc0.dist-info → optimum_rbln-0.9.5a4.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py
@@ -6,22 +6,28 @@ import torch.nn as nn
 from transformers import PreTrainedModel
 
 from ..decoderonly.decoderonly_architecture import DecoderOnlyWrapper, apply_rotary_pos_emb
+from .configuration_qwen2_5_vl import RBLNQwen2_5_VisionTransformerPretrainedModelConfig
 
 
 class Qwen2_5_VisionTransformerWrapper(nn.Module):
-    def __init__(self, model: torch.nn.Module):
+    def __init__(self, model: torch.nn.Module, rbln_config: RBLNQwen2_5_VisionTransformerPretrainedModelConfig):
         super().__init__()
-        self._original_mod = model
         self.fullatt_block_indexes = model.fullatt_block_indexes
         self.merger = model.merger
+        self.rbln_config = rbln_config
         window_seq_len = (model.window_size // model.patch_size) ** 2
-        self.blocks = self.wrap_vision_blocks(model.blocks, window_seq_len)
+        self.blocks = self.wrap_vision_blocks(model.blocks, window_seq_len, rbln_config)
 
-    def wrap_vision_blocks(self, blocks: torch.nn.ModuleList, window_seq_len: int):
+    def wrap_vision_blocks(
+        self,
+        blocks: torch.nn.ModuleList,
+        window_seq_len: int,
+        rbln_config: RBLNQwen2_5_VisionTransformerPretrainedModelConfig,
+    ):
         wrapped_blocks = []
         for i, block in enumerate(blocks):
             is_full_attn = True if i in self.fullatt_block_indexes else False
-            wrapped_blocks.append(Qwen2_5_VLVisionBlock(block, is_full_attn, window_seq_len))
+            wrapped_blocks.append(Qwen2_5_VLVisionBlock(block, is_full_attn, window_seq_len, rbln_config))
         return nn.ModuleList(wrapped_blocks)
 
     def forward(
@@ -32,8 +38,8 @@ class Qwen2_5_VisionTransformerWrapper(nn.Module):
         cos: torch.Tensor,
         sin: torch.Tensor,
     ):
-        full_attn_masks = (1 - full_attn_masks) * torch.finfo(torch.float32).min
-        window_attn_masks = (1 - window_attn_masks) * torch.finfo(torch.float32).min
+        full_attn_masks = (1.0 - full_attn_masks) * torch.finfo(hidden_states.dtype).min
+        window_attn_masks = (1.0 - window_attn_masks) * torch.finfo(hidden_states.dtype).min
 
         for i, block in enumerate(self.blocks):
             attn_masks = full_attn_masks if i in self.fullatt_block_indexes else window_attn_masks
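Note the precision change above: the additive bias is now built from `hidden_states.dtype` instead of a hardcoded `torch.float32`, so half-precision graphs never pull in a float32 constant. A minimal sketch of the pattern (illustrative shapes and values, not the library's code):

```python
import torch

# Turn a 0/1 "keep" mask into an additive attention bias in the activation
# dtype; masked positions get the most negative representable value, which
# softmax drives to ~0.
hidden_states = torch.randn(1, 4, 8, dtype=torch.bfloat16)
keep = torch.tensor([1.0, 1.0, 0.0, 0.0], dtype=hidden_states.dtype)
bias = (1.0 - keep) * torch.finfo(hidden_states.dtype).min
scores = torch.randn(1, 4, dtype=hidden_states.dtype) + bias
probs = torch.softmax(scores, dim=-1)  # last two positions receive ~0 weight
```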
@@ -45,16 +51,23 @@ class Qwen2_5_VisionTransformerWrapper(nn.Module):
 
 
 class Qwen2_5_VLVisionBlock(torch.nn.Module):
-    def __init__(self, model: torch.nn.Module, is_full_attn: bool, window_seq_len: int):
+    def __init__(
+        self,
+        model: torch.nn.Module,
+        is_full_attn: bool,
+        window_seq_len: int,
+        rbln_config: RBLNQwen2_5_VisionTransformerPretrainedModelConfig,
+    ):
         super().__init__()
         self._origin_model = model
+        self.rbln_config = rbln_config
         self.norm1 = model.norm1
         self.norm2 = model.norm2
 
         if is_full_attn:
-            self.attn = Qwen2_5_VLVisionFullAttention(model.attn)
+            self.attn = Qwen2_5_VLVisionFullAttention(model.attn, rbln_config)
         else:
-            self.attn = Qwen2_5_VLVisionWindowAttention(model.attn, window_seq_len)
+            self.attn = Qwen2_5_VLVisionWindowAttention(model.attn, window_seq_len, rbln_config)
         self.mlp = model.mlp
 
     def forward(
@@ -73,13 +86,15 @@ class Qwen2_5_VLVisionBlock(torch.nn.Module):
 
 
 class Qwen2_5_VLVisionFullAttention(nn.Module):
-    def __init__(self, model: nn.Module) -> None:
+    def __init__(self, model: nn.Module, rbln_config: RBLNQwen2_5_VisionTransformerPretrainedModelConfig) -> None:
         super().__init__()
         self._origin_model = model
+        self.rbln_config = rbln_config
         self.num_heads = model.num_heads
         self.head_dim = getattr(model, "head_dim", model.proj.in_features // model.num_heads)
         self.qkv = model.qkv
         self.proj = model.proj
+        self.scale = torch.tensor(1 / math.sqrt(self.head_dim), dtype=rbln_config.dtype)
 
     def forward(
         self,
@@ -96,9 +111,9 @@ class Qwen2_5_VLVisionFullAttention(nn.Module):
         cos, sin = position_embeddings
         q, k = apply_rotary_pos_emb(q, k, cos, sin)
 
-        attn_weights = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(self.head_dim)
+        attn_weights = torch.matmul(q, k.transpose(2, 3)) * self.scale
         attn_weights = attn_weights + attn_masks
-        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32)
+        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=hidden_states.dtype)
         attn_output = torch.matmul(attn_weights, v)
         attn_output = attn_output.transpose(1, 2)
         attn_output = attn_output.reshape(1, seq_length, -1)
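Two related changes here: the `1/sqrt(head_dim)` scale is now a tensor built once in `rbln_config.dtype` (a multiply instead of a per-call division by a Python float), and the softmax runs in the activations' dtype rather than forcing a float32 round-trip, presumably to keep the whole traced graph in one precision for the RBLN compiler. A sketch of the resulting compute path, assuming bf16 activations (illustrative, not the library's code):

```python
import math
import torch

head_dim, dtype = 64, torch.bfloat16
scale = torch.tensor(1 / math.sqrt(head_dim), dtype=dtype)  # built once at wrap time

q = torch.randn(1, 8, 16, head_dim, dtype=dtype)
k = torch.randn(1, 8, 16, head_dim, dtype=dtype)
v = torch.randn(1, 8, 16, head_dim, dtype=dtype)

attn = torch.matmul(q, k.transpose(2, 3)) * scale       # stays in bf16
probs = torch.softmax(attn, dim=-1, dtype=attn.dtype)   # no float32 upcast
out = torch.matmul(probs, v)
```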
@@ -108,14 +123,18 @@ class Qwen2_5_VLVisionFullAttention(nn.Module):
 
 
 class Qwen2_5_VLVisionWindowAttention(nn.Module):
-    def __init__(self, model: nn.Module, window_seq_len: int) -> None:
+    def __init__(
+        self, model: nn.Module, window_seq_len: int, rbln_config: RBLNQwen2_5_VisionTransformerPretrainedModelConfig
+    ) -> None:
         super().__init__()
         self._origin_model = model
+        self.rbln_config = rbln_config
         self.num_heads = model.num_heads
         self.head_dim = getattr(model, "head_dim", model.proj.in_features // model.num_heads)
         self.qkv = model.qkv
         self.proj = model.proj
         self.window_seq_len = window_seq_len
+        self.scale = torch.tensor(1 / math.sqrt(self.head_dim), dtype=rbln_config.dtype)
 
     def forward(
         self,
@@ -142,10 +161,10 @@ class Qwen2_5_VLVisionWindowAttention(nn.Module):
         sin = sin.reshape(num_windows, 1, seq_length // num_windows, -1)
         q, k = apply_rotary_pos_emb(q, k, cos, sin)
 
-        attn_weights = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(self.head_dim)
+        attn_weights = torch.matmul(q, k.transpose(2, 3)) * self.scale
 
         attn_weights = attn_weights + attn_masks
-        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32)
+        attn_weights = nn.functional.softmax(attn_weights, dim=-1)
         attn_output = torch.matmul(attn_weights, v)
         attn_output = attn_output.transpose(1, 2)
         attn_output = attn_output.reshape(1, seq_length, -1)
@@ -155,6 +174,12 @@
 
 
 class Qwen2_5_VL_LanguageModelWrapper(DecoderOnlyWrapper):
+    def get_decoder_layers(self, model: PreTrainedModel):
+        return model.model.language_model.layers if hasattr(model, "model") else model.language_model.layers
+
+    def get_model_layer(self, model: PreTrainedModel):
+        return model.model.language_model if hasattr(model, "model") else model.language_model
+
     def prepare_forward_args(self, *args):
         args = list(args)
         input_ids = None if self.rbln_config.use_inputs_embeds else args.pop(0)
@@ -163,10 +188,10 @@ class Qwen2_5_VL_LanguageModelWrapper(DecoderOnlyWrapper):
         global_block_tables = args.pop(0)
         local_block_tables = None
         position_embeds = args.pop(0)
-        query_position = args.pop(0) if self.phase == "prefill" else None
+        query_position = args.pop(0) if self.phase == "prefill" and self.rbln_config.logits_to_keep > 0 else None
         position_ids = None
-        lora_int_id = None
         attention_mask = args.pop(0) if self.rbln_config.use_attention_mask else None
+        lora_int_id = args.pop(0) if self.rbln_config.lora_config else None
         past_key_values = args
 
         if len(past_key_values) != 2 * self.num_hidden_layers:
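The ordering contract in `prepare_forward_args` matters: the compiled runtime passes a flat tensor list, so an optional input is popped only when the same config flag added it to the signature at compile time. `query_position` is now consumed only when prefill actually keeps logits, and `lora_int_id` is consumed (after the attention mask) whenever a LoRA config is present. A condensed sketch of the guarded-pop pattern, with a hypothetical helper and a simplified argument order:

```python
# Hypothetical helper, for illustration only: each optional input exists in
# the flat arg list only if its config flag enabled it at compile time, so
# every pop must be guarded by the same condition.
def unpack(args, cfg, phase):
    args = list(args)
    input_ids = None if cfg.use_inputs_embeds else args.pop(0)
    inputs_embeds = args.pop(0) if cfg.use_inputs_embeds else None
    cache_position = args.pop(0)
    query_position = args.pop(0) if phase == "prefill" and cfg.logits_to_keep > 0 else None
    attention_mask = args.pop(0) if cfg.use_attention_mask else None
    lora_int_id = args.pop(0) if cfg.lora_config else None
    past_key_values = args  # remainder: per-layer K/V caches
    return input_ids, inputs_embeds, cache_position, query_position, attention_mask, lora_int_id, past_key_values
```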
@@ -197,24 +222,3 @@
             past_key_values,
             position_embeds,
         )
-
-    def convert_to_rbln_class(self, model: PreTrainedModel, max_seq_len: int):
-        new_layers = []
-
-        for layer_idx, layer in enumerate(model.model.language_model.layers):
-            is_sliding = layer_idx in self.rbln_config.sliding_window_layers
-            new_self_attn = self.get_rbln_attn_class()(
-                self.get_attn_layer(layer), self.rbln_config, is_sliding=is_sliding
-            )
-            new_layer = self.get_rbln_layer_class()(layer, new_self_attn)
-            new_layers.append(new_layer)
-
-        new_model = self.get_rbln_model_class()(
-            model.model.language_model,
-            new_layers,
-            self.rbln_config,
-            use_learned_pos_emb=self.__class__._use_learned_pos_emb,
-        )
-
-        new_model = self.get_rbln_causal_lm_class()(model.model, new_model)
-        return new_model
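The deleted `convert_to_rbln_class` override is made redundant by the two small accessors added earlier in this hunk set: `get_decoder_layers` and `get_model_layer` tell the base `DecoderOnlyWrapper` where the language model lives (covering both the `model.model.language_model` and bare `model.language_model` layouts), so the generic conversion loop in the base class can do the rest. Roughly, the base class presumably follows this shape (hypothetical reconstruction, for illustration only):

```python
# Hypothetical sketch of the base-class conversion loop these accessors feed;
# the real implementation lives in decoderonly_architecture.DecoderOnlyWrapper.
class DecoderOnlyWrapperSketch:
    def get_model_layer(self, model):
        return model.model  # default layout for text-only causal LMs

    def get_decoder_layers(self, model):
        return self.get_model_layer(model).layers

    def convert_to_rbln_class(self, model, max_seq_len):
        new_layers = []
        for layer_idx, layer in enumerate(self.get_decoder_layers(model)):
            is_sliding = layer_idx in self.rbln_config.sliding_window_layers
            attn = self.get_rbln_attn_class()(self.get_attn_layer(layer), self.rbln_config, is_sliding=is_sliding)
            new_layers.append(self.get_rbln_layer_class()(layer, attn))
        ...  # wrap the layers into the RBLN model / causal-LM classes as before
```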
optimum/rbln/transformers/models/qwen2_moe/__init__.py (new file)
@@ -0,0 +1,16 @@
+# Copyright 2025 Rebellions Inc. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .configuration_qwen2_moe import RBLNQwen2MoeForCausalLMConfig
+from .modeling_qwen2_moe import RBLNQwen2MoeForCausalLM
optimum/rbln/transformers/models/qwen2_moe/configuration_qwen2_moe.py (new file)
@@ -0,0 +1,38 @@
+# Copyright 2025 Rebellions Inc. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ..decoderonly.configuration_decoderonly import RBLNDecoderOnlyModelForCausalLMConfig
+
+
+class RBLNQwen2MoeForCausalLMConfig(RBLNDecoderOnlyModelForCausalLMConfig):
+    """
+    Configuration class for RBLN Qwen2MoE models.
+    This class is an alias of RBLNDecoderOnlyModelForCausalLMConfig.
+    Example usage:
+    ```python
+    from optimum.rbln import RBLNQwen2MoeForCausalLM, RBLNQwen2MoeForCausalLMConfig
+    # Create a configuration object
+    config = RBLNQwen2MoeForCausalLMConfig(
+        batch_size=1,
+        max_seq_len=8192,
+        tensor_parallel_size=4
+    )
+    # Use the configuration with from_pretrained
+    model = RBLNQwen2MoeForCausalLM.from_pretrained(
+        "Qwen/Qwen1.5-MoE-A2.7B",
+        export=True,
+        rbln_config=config
+    )
+    ```
+    """
optimum/rbln/transformers/models/qwen2_moe/modeling_qwen2_moe.py (new file)
@@ -0,0 +1,68 @@
+# Copyright 2025 Rebellions Inc. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ...models.decoderonly import RBLNDecoderOnlyModelForCausalLM
+from .qwen2_moe_architecture import Qwen2MoeWrapper
+
+
+class RBLNQwen2MoeForCausalLM(RBLNDecoderOnlyModelForCausalLM):
+    """
+    Qwen2MoE is a Mixture-of-Experts (MoE) variant of Qwen2, available as a base model and an aligned chat model.
+    This model inherits from [`RBLNDecoderOnlyModelForCausalLM`]. Check the superclass documentation for the generic methods the library implements for all its models.
+    A class to convert and run the pre-trained transformers-based Qwen2MoeForCausalLM model on RBLN devices.
+    It implements the methods to convert a pre-trained transformers Qwen2MoeForCausalLM model into an RBLN transformer model by:
+    - transferring the checkpoint weights of the original into an optimized RBLN graph,
+    - compiling the resulting graph using the RBLN compiler.
+    **Configuration:**
+    This model uses [`RBLNQwen2MoeForCausalLMConfig`] for configuration. When calling methods like `from_pretrained` or `from_model`,
+    the `rbln_config` parameter should be an instance of [`RBLNQwen2MoeForCausalLMConfig`] or a dictionary conforming to its structure.
+    See the [`RBLNQwen2MoeForCausalLMConfig`] class for all available configuration options.
+    Examples:
+    ```python
+    from optimum.rbln import RBLNQwen2MoeForCausalLM
+    # Simple usage using rbln_* arguments
+    # `max_seq_len` is automatically inferred from the model config
+    model = RBLNQwen2MoeForCausalLM.from_pretrained(
+        "Qwen/Qwen1.5-MoE-A2.7B",
+        export=True,
+        rbln_batch_size=1,
+        rbln_tensor_parallel_size=4,
+    )
+    # Using a config dictionary
+    rbln_config = {
+        "batch_size": 1,
+        "max_seq_len": 8192,
+        "tensor_parallel_size": 4,
+    }
+    model = RBLNQwen2MoeForCausalLM.from_pretrained(
+        "Qwen/Qwen1.5-MoE-A2.7B",
+        export=True,
+        rbln_config=rbln_config
+    )
+    # Using a RBLNQwen2MoeForCausalLMConfig instance (recommended for type checking)
+    from optimum.rbln import RBLNQwen2MoeForCausalLMConfig
+    config = RBLNQwen2MoeForCausalLMConfig(
+        batch_size=1,
+        max_seq_len=8192,
+        tensor_parallel_size=4
+    )
+    model = RBLNQwen2MoeForCausalLM.from_pretrained(
+        "Qwen/Qwen1.5-MoE-A2.7B",
+        export=True,
+        rbln_config=config
+    )
+    ```
+    """
+
+    _decoder_wrapper_cls = Qwen2MoeWrapper
optimum/rbln/transformers/models/qwen2_moe/qwen2_moe_architecture.py (new file)
@@ -0,0 +1,94 @@
+# Copyright 2025 Rebellions Inc. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional
+
+import torch
+from torch import nn
+
+from ..decoderonly.configuration_decoderonly import RBLNLoRAConfig
+from ..decoderonly.decoderonly_architecture import DecoderOnlyAttention, DecoderOnlyLayer, DecoderOnlyWrapper
+
+
+class Qwen2MoeWrapper(DecoderOnlyWrapper):
+    def get_rbln_layer_class(self):
+        return Qwen2MoeLayer
+
+
+class Qwen2MoeLayer(DecoderOnlyLayer):
+    def __init__(self, layer, self_attn: DecoderOnlyAttention, lora_config: Optional[RBLNLoRAConfig] = None):
+        super().__init__(layer, self_attn, lora_config)
+        self.mlp = (
+            Qwen2MoeSparseMoeBlock(layer.mlp)
+            if layer.mlp.__class__.__name__ == "Qwen2MoeSparseMoeBlock"
+            else layer.mlp
+        )
+
+    def get_mlp(self) -> nn.Module:
+        return self.mlp
+
+
+class Qwen2MoeSparseMoeBlock(nn.Module):
+    def __init__(self, model: nn.Module):
+        super().__init__()
+        self.num_experts = model.num_experts
+        self.top_k = model.top_k
+        self.norm_topk_prob = model.norm_topk_prob
+        self.gate = model.gate
+        self.shared_expert = model.shared_expert
+        self.shared_expert_gate = model.shared_expert_gate
+        self.experts = Qwen2MoeMLP(model.experts, self.top_k, self.norm_topk_prob)
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        batch_size, sequence_length, hidden_dim = hidden_states.shape
+        hidden_states = hidden_states.view(-1, hidden_dim)
+
+        # router_logits: (batch * sequence_length, n_experts)
+        router_logits = self.gate(hidden_states)
+        final_hidden_states = self.experts(hidden_states, router_logits)
+        shared_expert_output = self.shared_expert(hidden_states)
+        shared_expert_output = (
+            torch.nn.functional.sigmoid(self.shared_expert_gate(hidden_states)) * shared_expert_output
+        )
+        final_hidden_states = final_hidden_states + shared_expert_output
+        final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
+        return final_hidden_states
+
+
+class Qwen2MoeMLP(nn.Module):
+    def __init__(self, expert_list, top_k, norm_topk_prob):
+        super().__init__()
+        self.hidden_size = expert_list[0].hidden_size
+        self.intermediate_size = expert_list[0].intermediate_size
+        self.top_k = top_k
+        self.norm_topk_prob = norm_topk_prob
+
+        self.num_experts = len(expert_list)
+        self.gate_proj = nn.Linear(self.hidden_size, self.num_experts * self.intermediate_size, bias=False)
+        self.up_proj = nn.Linear(self.hidden_size, self.num_experts * self.intermediate_size, bias=False)
+        self.down_proj = nn.Linear(self.num_experts * self.intermediate_size, self.hidden_size, bias=False)
+        self.gate_proj.weight.data = torch.stack([expert.gate_proj.weight.data for expert in expert_list], dim=0)
+        self.up_proj.weight.data = torch.stack([expert.up_proj.weight.data for expert in expert_list], dim=0)
+        self.down_proj.weight.data = torch.stack([expert.down_proj.weight.data for expert in expert_list], dim=0)
+
+    def forward(self, x, router_logits):
+        return torch.ops.rbln_custom_ops.custom_moe_glu(
+            x,
+            self.gate_proj.weight,
+            self.up_proj.weight,
+            self.down_proj.weight,
+            router_logits,
+            self.top_k,
+            self.norm_topk_prob,
+        )
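`custom_moe_glu` is one of the custom ops registered in the new `optimum/rbln/ops/moe.py` (file 41 above); it fuses routing, the per-expert GLU MLPs, and the weighted combine over the stacked expert weights. For intuition, a dense PyTorch equivalent might look like the sketch below, assuming Qwen2MoE's standard softmax top-k routing with optional probability renormalization (assumed semantics, not the RBLN kernel):

```python
import torch
import torch.nn.functional as F

def moe_glu_reference(x, gate_w, up_w, down_w, router_logits, top_k, norm_topk_prob):
    """Dense sketch of a routed GLU MoE. Shapes: x (tokens, hidden);
    gate_w/up_w (experts, inter, hidden); down_w (experts, hidden, inter)."""
    probs = F.softmax(router_logits, dim=-1, dtype=torch.float)
    topk_p, topk_i = torch.topk(probs, top_k, dim=-1)
    if norm_topk_prob:
        topk_p = topk_p / topk_p.sum(dim=-1, keepdim=True)

    out = torch.zeros_like(x)
    for e in range(gate_w.shape[0]):
        token_idx, slot = torch.where(topk_i == e)  # tokens routed to expert e
        if token_idx.numel() == 0:
            continue
        xe = x[token_idx]
        h = F.silu(xe @ gate_w[e].T) * (xe @ up_w[e].T)  # GLU: silu(gate) * up
        out[token_idx] += topk_p[token_idx, slot].to(x.dtype).unsqueeze(-1) * (h @ down_w[e].T)
    return out
```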
optimum/rbln/transformers/models/qwen2_vl/__init__.py
@@ -15,5 +15,10 @@
 from .configuration_qwen2_vl import (
     RBLNQwen2VisionTransformerPretrainedModelConfig,
     RBLNQwen2VLForConditionalGenerationConfig,
+    RBLNQwen2VLModelConfig,
+)
+from .modeling_qwen2_vl import (
+    RBLNQwen2VisionTransformerPretrainedModel,
+    RBLNQwen2VLForConditionalGeneration,
+    RBLNQwen2VLModel,
 )
-from .modeling_qwen2_vl import RBLNQwen2VisionTransformerPretrainedModel, RBLNQwen2VLForConditionalGeneration
optimum/rbln/transformers/models/qwen2_vl/configuration_qwen2_vl.py
@@ -15,7 +15,7 @@
 from typing import Any, Dict, List, Optional, Union
 
 from ....configuration_utils import RBLNModelConfig
-from ..decoderonly.configuration_decoderonly import RBLNDecoderOnlyModelForCausalLMConfig
+from ..decoderonly.configuration_decoderonly import RBLNDecoderOnlyModelConfig, RBLNDecoderOnlyModelForCausalLMConfig
 
 
 class RBLNQwen2VLForConditionalGenerationConfig(RBLNDecoderOnlyModelForCausalLMConfig):
@@ -48,6 +48,16 @@ class RBLNQwen2VLForConditionalGenerationConfig(RBLNDecoderOnlyModelForCausalLMConfig):
         self.visual = visual
 
 
+class RBLNQwen2VLModelConfig(RBLNDecoderOnlyModelConfig):
+    """
+    Configuration class for RBLNQwen2VLModel.
+    """
+
+    def __init__(self, visual: Optional[RBLNModelConfig] = None, **kwargs: Dict[str, Any]):
+        super().__init__(**kwargs)
+        self.visual = self.initialize_submodule_config(submodule_config=visual)
+
+
 class RBLNQwen2VisionTransformerPretrainedModelConfig(RBLNModelConfig):
     def __init__(self, max_seq_lens: Union[int, List[int]] = None, **kwargs: Dict[str, Any]):
         """