optimum-rbln 0.8.2a0__py3-none-any.whl → 0.9.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (197)
  1. optimum/rbln/__init__.py +116 -9
  2. optimum/rbln/__version__.py +16 -3
  3. optimum/rbln/cli.py +660 -0
  4. optimum/rbln/configuration_utils.py +171 -43
  5. optimum/rbln/diffusers/__init__.py +19 -0
  6. optimum/rbln/diffusers/configurations/__init__.py +3 -0
  7. optimum/rbln/diffusers/configurations/models/__init__.py +2 -0
  8. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +3 -3
  9. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_cosmos.py +1 -1
  10. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_temporal_decoder.py +67 -0
  11. optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +3 -3
  12. optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +4 -4
  13. optimum/rbln/diffusers/configurations/models/configuration_transformer_cosmos.py +12 -4
  14. optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +9 -4
  15. optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +3 -3
  16. optimum/rbln/diffusers/configurations/models/configuration_unet_spatio_temporal_condition.py +59 -0
  17. optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +3 -3
  18. optimum/rbln/diffusers/configurations/pipelines/__init__.py +3 -0
  19. optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +35 -19
  20. optimum/rbln/diffusers/configurations/pipelines/configuration_cosmos.py +14 -11
  21. optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +30 -20
  22. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +13 -9
  23. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +17 -13
  24. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +17 -10
  25. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_video_diffusion.py +114 -0
  26. optimum/rbln/diffusers/modeling_diffusers.py +33 -18
  27. optimum/rbln/diffusers/models/__init__.py +4 -0
  28. optimum/rbln/diffusers/models/autoencoders/__init__.py +1 -0
  29. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +32 -3
  30. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py +32 -6
  31. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +275 -0
  32. optimum/rbln/diffusers/models/autoencoders/vae.py +27 -8
  33. optimum/rbln/diffusers/models/autoencoders/vq_model.py +32 -3
  34. optimum/rbln/diffusers/models/controlnet.py +16 -1
  35. optimum/rbln/diffusers/models/transformers/prior_transformer.py +17 -3
  36. optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +26 -3
  37. optimum/rbln/diffusers/models/transformers/transformer_sd3.py +23 -2
  38. optimum/rbln/diffusers/models/unets/__init__.py +1 -0
  39. optimum/rbln/diffusers/models/unets/unet_2d_condition.py +23 -4
  40. optimum/rbln/diffusers/models/unets/unet_spatio_temporal_condition.py +201 -0
  41. optimum/rbln/diffusers/pipelines/__init__.py +15 -5
  42. optimum/rbln/diffusers/pipelines/auto_pipeline.py +307 -0
  43. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +20 -0
  44. optimum/rbln/diffusers/pipelines/cosmos/configuration_cosmos_guardrail.py +23 -12
  45. optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +16 -46
  46. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +31 -1
  47. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +31 -1
  48. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +1 -6
  49. optimum/rbln/diffusers/pipelines/stable_video_diffusion/__init__.py +15 -0
  50. optimum/rbln/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +46 -0
  51. optimum/rbln/modeling.py +50 -24
  52. optimum/rbln/modeling_base.py +116 -35
  53. optimum/rbln/ops/attn.py +158 -0
  54. optimum/rbln/ops/flash_attn.py +166 -0
  55. optimum/rbln/ops/kv_cache_update.py +5 -0
  56. optimum/rbln/ops/linear.py +7 -0
  57. optimum/rbln/transformers/__init__.py +100 -0
  58. optimum/rbln/transformers/configuration_generic.py +7 -32
  59. optimum/rbln/transformers/modeling_attention_utils.py +385 -0
  60. optimum/rbln/transformers/modeling_generic.py +48 -65
  61. optimum/rbln/transformers/modeling_outputs.py +37 -0
  62. optimum/rbln/transformers/models/__init__.py +93 -30
  63. optimum/rbln/transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py +28 -2
  64. optimum/rbln/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +68 -5
  65. optimum/rbln/transformers/models/auto/__init__.py +2 -0
  66. optimum/rbln/transformers/models/auto/auto_factory.py +92 -17
  67. optimum/rbln/transformers/models/auto/modeling_auto.py +45 -0
  68. optimum/rbln/transformers/models/bart/bart_architecture.py +2 -7
  69. optimum/rbln/transformers/models/bart/configuration_bart.py +2 -0
  70. optimum/rbln/transformers/models/bart/modeling_bart.py +23 -2
  71. optimum/rbln/transformers/models/bert/bert_architecture.py +16 -0
  72. optimum/rbln/transformers/models/bert/modeling_bert.py +93 -4
  73. optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +42 -11
  74. optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +135 -44
  75. optimum/rbln/transformers/models/clip/configuration_clip.py +21 -7
  76. optimum/rbln/transformers/models/clip/modeling_clip.py +183 -27
  77. optimum/rbln/transformers/models/colpali/colpali_architecture.py +3 -6
  78. optimum/rbln/transformers/models/colpali/configuration_colpali.py +37 -21
  79. optimum/rbln/transformers/models/colpali/modeling_colpali.py +82 -104
  80. optimum/rbln/transformers/models/colqwen2/__init__.py +2 -0
  81. optimum/rbln/transformers/models/colqwen2/colqwen2_architecture.py +233 -0
  82. optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +74 -0
  83. optimum/rbln/transformers/models/colqwen2/modeling_colqwen2.py +446 -0
  84. optimum/rbln/transformers/models/decoderonly/__init__.py +3 -2
  85. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +114 -37
  86. optimum/rbln/transformers/models/decoderonly/configuration_lora.py +411 -0
  87. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +323 -316
  88. optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +508 -0
  89. optimum/rbln/transformers/models/decoderonly/generation_decoderonly.py +119 -0
  90. optimum/rbln/transformers/models/decoderonly/lora_architecture.py +204 -0
  91. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +486 -892
  92. optimum/rbln/transformers/models/depth_anything/__init__.py +16 -0
  93. optimum/rbln/transformers/models/depth_anything/configuration_depth_anything.py +24 -0
  94. optimum/rbln/transformers/models/depth_anything/modeling_depth_anything.py +42 -0
  95. optimum/rbln/transformers/models/distilbert/modeling_distilbert.py +24 -0
  96. optimum/rbln/transformers/models/dpt/modeling_dpt.py +17 -0
  97. optimum/rbln/transformers/models/exaone/modeling_exaone.py +42 -4
  98. optimum/rbln/transformers/models/gemma/__init__.py +2 -2
  99. optimum/rbln/transformers/models/gemma/configuration_gemma.py +9 -1
  100. optimum/rbln/transformers/models/gemma/gemma_architecture.py +1 -4
  101. optimum/rbln/transformers/models/gemma/modeling_gemma.py +22 -1
  102. optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +49 -14
  103. optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +12 -2
  104. optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py +245 -0
  105. optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +212 -504
  106. optimum/rbln/transformers/models/gpt2/__init__.py +2 -2
  107. optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +31 -3
  108. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +10 -8
  109. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +18 -1
  110. optimum/rbln/transformers/models/grounding_dino/__init__.py +10 -0
  111. optimum/rbln/transformers/models/grounding_dino/configuration_grounding_dino.py +92 -0
  112. optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +599 -0
  113. optimum/rbln/transformers/models/grounding_dino/modeling_grounding_dino.py +1048 -0
  114. optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +35 -7
  115. optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +29 -32
  116. optimum/rbln/transformers/models/llama/__init__.py +2 -2
  117. optimum/rbln/transformers/models/llama/configuration_llama.py +9 -1
  118. optimum/rbln/transformers/models/llama/modeling_llama.py +22 -1
  119. optimum/rbln/transformers/models/llava/__init__.py +16 -0
  120. optimum/rbln/transformers/models/llava/configuration_llava.py +72 -0
  121. optimum/rbln/transformers/models/llava/modeling_llava.py +490 -0
  122. optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +21 -6
  123. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +234 -376
  124. optimum/rbln/transformers/models/midm/midm_architecture.py +4 -1
  125. optimum/rbln/transformers/models/midm/modeling_midm.py +42 -4
  126. optimum/rbln/transformers/models/mistral/__init__.py +2 -2
  127. optimum/rbln/transformers/models/mistral/configuration_mistral.py +9 -1
  128. optimum/rbln/transformers/models/mistral/mistral_architecture.py +1 -1
  129. optimum/rbln/transformers/models/mistral/modeling_mistral.py +26 -3
  130. optimum/rbln/transformers/models/opt/__init__.py +2 -2
  131. optimum/rbln/transformers/models/opt/configuration_opt.py +8 -1
  132. optimum/rbln/transformers/models/opt/modeling_opt.py +29 -17
  133. optimum/rbln/transformers/models/opt/opt_architecture.py +4 -4
  134. optimum/rbln/transformers/models/pegasus/__init__.py +17 -0
  135. optimum/rbln/transformers/models/pegasus/configuration_pegasus.py +38 -0
  136. optimum/rbln/transformers/models/pegasus/modeling_pegasus.py +71 -0
  137. optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +161 -0
  138. optimum/rbln/transformers/models/phi/__init__.py +2 -2
  139. optimum/rbln/transformers/models/phi/configuration_phi.py +9 -1
  140. optimum/rbln/transformers/models/phi/modeling_phi.py +10 -1
  141. optimum/rbln/transformers/models/phi/phi_architecture.py +11 -7
  142. optimum/rbln/transformers/models/pixtral/__init__.py +16 -0
  143. optimum/rbln/transformers/models/pixtral/configuration_pixtral.py +43 -0
  144. optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +322 -0
  145. optimum/rbln/transformers/models/pixtral/pixtral_architecture.py +73 -0
  146. optimum/rbln/transformers/models/qwen2/__init__.py +2 -2
  147. optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +9 -1
  148. optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +27 -1
  149. optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +21 -6
  150. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +15 -22
  151. optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +28 -7
  152. optimum/rbln/transformers/models/qwen2_vl/__init__.py +19 -0
  153. optimum/rbln/transformers/models/qwen2_vl/configuration_qwen2_vl.py +88 -0
  154. optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py +513 -0
  155. optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +165 -0
  156. optimum/rbln/transformers/models/qwen3/__init__.py +16 -0
  157. optimum/rbln/transformers/models/qwen3/configuration_qwen3.py +71 -0
  158. optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +133 -0
  159. optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +31 -0
  160. optimum/rbln/transformers/models/resnet/configuration_resnet.py +17 -0
  161. optimum/rbln/transformers/models/resnet/modeling_resnet.py +73 -0
  162. optimum/rbln/transformers/models/roberta/modeling_roberta.py +33 -0
  163. optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py +21 -16
  164. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +60 -13
  165. optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +2 -2
  166. optimum/rbln/transformers/models/siglip/__init__.py +2 -6
  167. optimum/rbln/transformers/models/siglip/configuration_siglip.py +1 -1
  168. optimum/rbln/transformers/models/siglip/modeling_siglip.py +21 -16
  169. optimum/rbln/transformers/models/swin/__init__.py +16 -0
  170. optimum/rbln/transformers/models/swin/configuration_swin.py +42 -0
  171. optimum/rbln/transformers/models/swin/modeling_swin.py +354 -0
  172. optimum/rbln/transformers/models/t5/configuration_t5.py +2 -0
  173. optimum/rbln/transformers/models/t5/modeling_t5.py +2 -2
  174. optimum/rbln/transformers/models/t5/t5_architecture.py +8 -1
  175. optimum/rbln/transformers/models/time_series_transformer/configuration_time_series_transformer.py +3 -3
  176. optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +22 -16
  177. optimum/rbln/transformers/models/time_series_transformer/time_series_transformers_architecture.py +7 -1
  178. optimum/rbln/transformers/models/vit/modeling_vit.py +19 -0
  179. optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec2.py +15 -3
  180. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +61 -8
  181. optimum/rbln/transformers/models/whisper/configuration_whisper.py +12 -13
  182. optimum/rbln/transformers/models/whisper/generation_whisper.py +62 -6
  183. optimum/rbln/transformers/models/whisper/modeling_whisper.py +32 -5
  184. optimum/rbln/transformers/models/xlm_roberta/__init__.py +2 -8
  185. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +43 -0
  186. optimum/rbln/transformers/utils/rbln_quantization.py +400 -75
  187. optimum/rbln/transformers/utils/rbln_runtime_wrapper.py +79 -0
  188. optimum/rbln/utils/deprecation.py +213 -0
  189. optimum/rbln/utils/hub.py +22 -50
  190. optimum/rbln/utils/runtime_utils.py +85 -17
  191. optimum/rbln/utils/submodule.py +31 -9
  192. {optimum_rbln-0.8.2a0.dist-info → optimum_rbln-0.9.3.dist-info}/METADATA +8 -7
  193. optimum_rbln-0.9.3.dist-info/RECORD +264 -0
  194. {optimum_rbln-0.8.2a0.dist-info → optimum_rbln-0.9.3.dist-info}/WHEEL +1 -1
  195. optimum_rbln-0.9.3.dist-info/entry_points.txt +2 -0
  196. optimum_rbln-0.8.2a0.dist-info/RECORD +0 -211
  197. {optimum_rbln-0.8.2a0.dist-info → optimum_rbln-0.9.3.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,165 @@
1
+ import math
2
+ from typing import Tuple
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ from transformers import PreTrainedModel
7
+
8
+ from ..decoderonly.decoderonly_architecture import (
9
+ DecoderOnlyWrapper,
10
+ apply_rotary_pos_emb,
11
+ )
12
+
13
+
14
class Qwen2VisionTransformerWrapper(nn.Module):
    """Trace-friendly wrapper around a Qwen2-VL vision transformer.

    Every block of the wrapped model is re-wrapped with ``Qwen2VLVisionBlock``,
    and the original ``merger`` is reused for the final projection.
    """

    def __init__(self, model: torch.nn.Module):
        super().__init__()
        self._original_mod = model
        self.merger = model.merger
        self.blocks = self.wrap_vision_blocks(model.blocks)

    def wrap_vision_blocks(self, blocks: torch.nn.ModuleList):
        """Return a ``nn.ModuleList`` with each block wrapped in ``Qwen2VLVisionBlock``."""
        return nn.ModuleList([Qwen2VLVisionBlock(block) for block in blocks])

    def forward(
        self,
        hidden_states: torch.Tensor,
        full_attn_masks: torch.Tensor,
        cos: torch.Tensor,
        sin: torch.Tensor,
    ):
        """Run all wrapped vision blocks, then the merger.

        Args:
            hidden_states: Input fed to the first block.
            full_attn_masks: Multiplicative mask; assumes 1 = attend, 0 = masked
                (TODO confirm against the caller).
            cos: Rotary cosine table, forwarded to every block.
            sin: Rotary sine table, forwarded to every block.

        Returns:
            The output of ``self.merger`` applied to the last block's output.
        """
        # Convert to an additive mask: entries equal to 1 become 0, entries equal
        # to 0 become the most negative float32 value.
        full_attn_masks = (1 - full_attn_masks) * torch.finfo(torch.float32).min

        for block in self.blocks:
            hidden_states = block(hidden_states, full_attn_masks, [cos, sin])

        return self.merger(hidden_states)
40
+
41
+
42
class Qwen2VLVisionBlock(torch.nn.Module):
    """Pre-norm vision block: residual attention followed by a residual MLP.

    ``norm1``, ``norm2`` and ``mlp`` are reused from the wrapped block; its
    attention is re-wrapped with ``VisionAttention``.
    """

    def __init__(self, model: torch.nn.Module):
        super().__init__()
        self._origin_model = model
        self.norm1 = model.norm1
        self.norm2 = model.norm2

        self.attn = VisionAttention(model.attn)
        self.mlp = model.mlp

    def forward(
        self,
        hidden_states: torch.Tensor,
        attn_masks: torch.Tensor,
        position_embeddings: Tuple[torch.Tensor, torch.Tensor],
    ) -> torch.Tensor:
        """Apply ``x + attn(norm1(x))`` and then ``x + mlp(norm2(x))``."""
        after_attn = hidden_states + self.attn(self.norm1(hidden_states), attn_masks, position_embeddings)
        mlp_out = self.mlp(self.norm2(after_attn))
        return after_attn + mlp_out
65
+
66
+
67
class VisionAttention(nn.Module):
    """Eager multi-head self-attention for a Qwen2-VL vision block.

    Reuses the fused ``qkv`` projection and output ``proj`` of the wrapped
    attention module, and applies rotary embeddings with ``apply_rotary_pos_emb``.
    """

    def __init__(self, model: nn.Module) -> None:
        super().__init__()
        self._origin_model = model
        self.num_heads = model.num_heads
        # Derive head_dim from the output projection when the wrapped module
        # does not expose it directly.
        self.head_dim = getattr(model, "head_dim", model.proj.in_features // model.num_heads)
        self.qkv = model.qkv
        self.proj = model.proj

    def forward(
        self,
        hidden_states: torch.Tensor,
        attn_masks: torch.Tensor,
        position_embeddings: Tuple[torch.Tensor, torch.Tensor],
    ) -> torch.Tensor:
        """Compute attention over a single (un-batched) sequence.

        Args:
            hidden_states: ``(seq_len, hidden)``; a batch dim of 1 is added internally.
            attn_masks: Additive mask, broadcastable to ``(1, num_heads, seq_len, seq_len)``.
            position_embeddings: ``(cos, sin)`` rotary tables.

        Returns:
            ``(seq_len, proj_out_features)`` after the output projection.
        """
        seq_length = hidden_states.shape[0]
        hidden_states = hidden_states.unsqueeze(0)
        # (1, seq, 3*H*D) -> (1, seq, 3, H, D) -> (3, 1, H, seq, D); unbind gives q, k, v.
        q, k, v = (
            self.qkv(hidden_states).reshape(1, seq_length, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4).unbind(0)
        )

        cos, sin = position_embeddings
        q, k = apply_rotary_pos_emb(q, k, cos, sin)

        attn_weights = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(self.head_dim)
        attn_weights = attn_weights + attn_masks
        # Softmax in float32 for numerical stability, then cast back to v's dtype.
        # Without the cast, torch.matmul would mix dtypes and fail for fp16/bf16
        # models. For float32 models the cast is a no-op.
        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(v.dtype)
        attn_output = torch.matmul(attn_weights, v)
        attn_output = attn_output.transpose(1, 2)
        attn_output = attn_output.reshape(1, seq_length, -1)
        attn_output = self.proj(attn_output).squeeze(0)

        return attn_output
100
+
101
+
102
class Qwen2VL_LanguageModelWrapper(DecoderOnlyWrapper):
    """``DecoderOnlyWrapper`` specialization for the Qwen2-VL language model.

    It differs from the generic wrapper in two ways:
    - The traced graph takes an extra ``position_embeds`` input (precomputed
      rotary tables; presumably multimodal RoPE, TODO confirm).
    - The decoder layers are read from ``model.model.language_model``.
    """

    def prepare_forward_args(self, *args):
        """Unpack the flat positional inputs of the traced graph.

        Expected order:
            [input_ids | inputs_embeds], cache_position, global_block_tables,
            position_embeds, [query_position (prefill only)],
            [attention_mask (if use_attention_mask)], [lora_int_id (if lora_config)],
            then 2 * num_hidden_layers KV tensors: key0, value0, key1, value1, ...

        Returns:
            Tuple ``(input_ids, inputs_embeds, cache_position, global_block_tables,
            local_block_tables, query_position, attention_mask, position_ids,
            lora_int_id, past_key_values, position_embeds)``. Here
            ``local_block_tables`` and ``position_ids`` are always ``None``, and
            ``past_key_values`` is a list of ``[key, value]`` pairs, one per layer.

        Raises:
            ValueError: If the number of KV tensors is not ``2 * num_hidden_layers``.
        """
        args = list(args)
        input_ids = None if self.rbln_config.use_inputs_embeds else args.pop(0)
        inputs_embeds = args.pop(0) if self.rbln_config.use_inputs_embeds else None
        cache_position = args.pop(0)
        global_block_tables = args.pop(0)
        local_block_tables = None  # not an input of this graph
        position_embeds = args.pop(0)
        query_position = args.pop(0) if self.phase == "prefill" else None
        position_ids = None
        attention_mask = args.pop(0) if self.rbln_config.use_attention_mask else None
        lora_int_id = args.pop(0) if self.rbln_config.lora_config else None
        past_key_values = args

        if len(past_key_values) != 2 * self.num_hidden_layers:
            raise ValueError(
                f"Different past_key_values to model's config. {len(past_key_values)} != {2 * self.num_hidden_layers}"
            )

        # [key, value] * n_layer -> [[key, value], ...], one pair per layer.
        # Pair over the list validated above so the layer count always matches the
        # check, instead of re-reading a possibly different config field.
        # Original note on cache shape: batch, n_heads, 1, max_seq_len, head_dim.
        past_key_values = [
            [past_key_values[i], past_key_values[i + 1]] for i in range(0, len(past_key_values), 2)
        ]

        return (
            input_ids,
            inputs_embeds,
            cache_position,
            global_block_tables,
            local_block_tables,
            query_position,
            attention_mask,
            position_ids,
            lora_int_id,
            past_key_values,
            position_embeds,
        )

    def convert_to_rbln_class(self, model: PreTrainedModel, max_seq_len: int):
        """Rebuild the language model using the wrapper's RBLN attention/layer/model classes.

        Args:
            model: Qwen2-VL model; decoder layers are read from
                ``model.model.language_model.layers``.
            max_seq_len: Accepted for interface compatibility; not used by this override.

        Returns:
            The object built by ``get_rbln_causal_lm_class()`` around ``model.model``
            and the rebuilt language model.
        """
        new_layers = []

        for layer_idx, layer in enumerate(model.model.language_model.layers):
            is_sliding = layer_idx in self.rbln_config.sliding_window_layers
            new_self_attn = self.get_rbln_attn_class()(
                self.get_attn_layer(layer), self.rbln_config, is_sliding=is_sliding
            )
            new_layer = self.get_rbln_layer_class()(layer, new_self_attn)
            new_layers.append(new_layer)

        new_model = self.get_rbln_model_class()(
            model.model.language_model,
            new_layers,
            self.rbln_config,
            use_learned_pos_emb=self.__class__._use_learned_pos_emb,
        )

        new_model = self.get_rbln_causal_lm_class()(model.model, new_model)
        return new_model
@@ -0,0 +1,16 @@
1
+ # Copyright 2025 Rebellions Inc. All rights reserved.
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at:
6
+
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from .configuration_qwen3 import RBLNQwen3ForCausalLMConfig, RBLNQwen3ModelConfig
16
+ from .modeling_qwen3 import RBLNQwen3ForCausalLM, RBLNQwen3Model
@@ -0,0 +1,71 @@
1
+ # Copyright 2025 Rebellions Inc. All rights reserved.
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at:
6
+
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from ..decoderonly.configuration_decoderonly import RBLNDecoderOnlyModelConfig, RBLNDecoderOnlyModelForCausalLMConfig
16
+
17
+
18
class RBLNQwen3ForCausalLMConfig(RBLNDecoderOnlyModelForCausalLMConfig):
    """
    Compilation/runtime configuration for RBLN Qwen3 causal language models.

    Adds no options of its own: every setting is inherited unchanged from
    RBLNDecoderOnlyModelForCausalLMConfig.

    Example usage:
    ```python
    from optimum.rbln import RBLNQwen3ForCausalLM, RBLNQwen3ForCausalLMConfig

    # Build the configuration explicitly
    config = RBLNQwen3ForCausalLMConfig(
        batch_size=1,
        max_seq_len=40960,
        tensor_parallel_size=4,
        kvcache_partition_len=16384
    )

    # Hand it to from_pretrained
    model = RBLNQwen3ForCausalLM.from_pretrained(
        "Qwen/Qwen3-4B",
        export=True,
        rbln_config=config
    )
    ```
    """
44
+
45
+
46
class RBLNQwen3ModelConfig(RBLNDecoderOnlyModelConfig):
    """
    Configuration class for RBLN Qwen3 base models (``RBLNQwen3Model``, no LM head).

    This class is an alias of RBLNDecoderOnlyModelConfig.

    Example usage:
    ```python
    from optimum.rbln import RBLNQwen3Model, RBLNQwen3ModelConfig

    # Create a configuration object
    config = RBLNQwen3ModelConfig(
        batch_size=1,
        max_seq_len=40960,
        tensor_parallel_size=4,
        kvcache_partition_len=16384
    )

    # Use the configuration with from_pretrained
    model = RBLNQwen3Model.from_pretrained(
        "Qwen/Qwen3-Embedding-4B",
        export=True,
        rbln_config=config
    )
    ```
    """
@@ -0,0 +1,133 @@
1
+ # Copyright 2025 Rebellions Inc. All rights reserved.
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at:
6
+
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import TYPE_CHECKING
16
+
17
+ from transformers import PretrainedConfig
18
+
19
+ from ....utils import logging
20
+ from ...models.decoderonly import (
21
+ RBLNDecoderOnlyModel,
22
+ RBLNDecoderOnlyModelForCausalLM,
23
+ RBLNDecoderOnlyModelForCausalLMConfig,
24
+ )
25
+ from .qwen3_architecture import Qwen3Wrapper
26
+
27
+
28
+ # Module-level logger for this file (not referenced by the code visible in this diff).
+ logger = logging.get_logger(__name__)
29
+
30
+ if TYPE_CHECKING:
31
+ from transformers import PretrainedConfig
32
+
33
+
34
class RBLNQwen3ForCausalLM(RBLNDecoderOnlyModelForCausalLM):
    """
    RBLN-compiled Qwen3 transformer with a language-modeling head (linear layer) on top.
    Subclass of [`RBLNDecoderOnlyModelForCausalLM`]; see the superclass documentation for the generic methods shared by every model in the library.
    Runs a pre-trained transformers `Qwen3ForCausalLM` model on RBLN devices.
    Conversion of a transformers `Qwen3ForCausalLM` checkpoint consists of:
    - transferring the original checkpoint weights into an optimized RBLN graph,
    - compiling that graph with the RBLN compiler.
    **Configuration:**
    Configured through [`RBLNQwen3ForCausalLMConfig`]. Methods such as `from_pretrained` and `from_model`
    accept `rbln_config` either as an instance of [`RBLNQwen3ForCausalLMConfig`] or as a dictionary with the same structure.
    Every available option is listed on [`RBLNQwen3ForCausalLMConfig`].
    Examples:
    ```python
    from optimum.rbln import RBLNQwen3ForCausalLM
    # Simple usage using rbln_* arguments
    # `max_seq_len` is automatically inferred from the model config
    model = RBLNQwen3ForCausalLM.from_pretrained(
        "Qwen/Qwen3-4B",
        export=True,
        rbln_batch_size=1,
        rbln_tensor_parallel_size=4,
    )
    # Using a config dictionary
    rbln_config = {
        "batch_size": 1,
        "max_seq_len": 40_960,
        "tensor_parallel_size": 4,
        "kvcache_partition_len": 8192,
    }
    model = RBLNQwen3ForCausalLM.from_pretrained(
        "Qwen/Qwen3-4B",
        export=True,
        rbln_config=rbln_config
    )
    # Using a RBLNQwen3ForCausalLMConfig instance (recommended for type checking)
    from optimum.rbln import RBLNQwen3ForCausalLMConfig
    config = RBLNQwen3ForCausalLMConfig(
        batch_size=1,
        max_seq_len=40_960,
        tensor_parallel_size=4,
        kvcache_partition_len=8192,
    )
    model = RBLNQwen3ForCausalLM.from_pretrained(
        "Qwen/Qwen3-4B",
        export=True,
        rbln_config=config
    )
    ```
    """

    _decoder_wrapper_cls = Qwen3Wrapper

    @classmethod
    def _update_sliding_window_config(
        cls, model_config: PretrainedConfig, rbln_config: RBLNDecoderOnlyModelForCausalLMConfig
    ):
        """Configure a sliding-window KV cache covering every decoder layer.

        Workaround for https://github.com/huggingface/transformers/issues/35896.
        The authors observed a bug in transformers v4.52.4, so, as in the eager
        attention case, every layer is marked as sliding-window. Revisit once the
        upstream bug is fixed.

        NOTE(review): the window size is copied verbatim from
        ``model_config.sliding_window``. If a checkpoint sets it to None, confirm
        how the downstream sliding-window cache handles that.

        Returns:
            The same ``rbln_config`` object, mutated in place.
        """
        rbln_config.cache_impl = "sliding_window"
        rbln_config.sliding_window = model_config.sliding_window
        layer_count = model_config.num_hidden_layers
        rbln_config.sliding_window_layers = [*range(layer_count)]
        return rbln_config

    def forward(self, *args, **kwargs):
        """Delegate to the superclass ``forward`` with ``return_dict`` forced to True.

        A ``return_dict`` value supplied by the caller is overridden.
        """
        forced_kwargs = {**kwargs, "return_dict": True}
        return super().forward(*args, **forced_kwargs)
103
+
104
+
105
class RBLNQwen3Model(RBLNDecoderOnlyModel):
    """
    The bare Qwen3 Model outputting raw hidden-states without any specific head on top.
    This model inherits from [`RBLNDecoderOnlyModel`]. Check the superclass documentation for the generic methods the library implements for all its models.
    A class to convert and run pre-trained transformers based Qwen3Model on RBLN devices.
    It implements the methods to convert a pre-trained transformers Qwen3Model into a RBLN transformer model by:
    - transferring the checkpoint weights of the original into an optimized RBLN graph,
    - compiling the resulting graph using the RBLN compiler.
    **Configuration:**
    This model uses [`RBLNQwen3ModelConfig`] for configuration. When calling methods like `from_pretrained` or `from_model`,
    the `rbln_config` parameter should be an instance of [`RBLNQwen3ModelConfig`] or a dictionary conforming to its structure.
    See the [`RBLNQwen3ModelConfig`] class for all available configuration options.
    Examples:
    ```python
    from optimum.rbln import RBLNQwen3Model
    # Simple usage, passing options as rbln_* keyword arguments
    model = RBLNQwen3Model.from_pretrained(
        "Qwen/Qwen3-Embedding-4B",
        export=True,
        rbln_batch_size=1,
        rbln_max_seq_len=40_960,
        rbln_tensor_parallel_size=4,
        rbln_kvcache_partition_len=8192,
    )
    ```
    """

    _decoder_wrapper_cls = Qwen3Wrapper
    # Flag consumed by RBLNDecoderOnlyModel (not visible here); presumably enables
    # rotary position embeddings for this model. TODO confirm.
    _use_rotary_emb = True
@@ -0,0 +1,31 @@
1
+ # Copyright 2025 Rebellions Inc. All rights reserved.
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at:
6
+
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ from ..decoderonly.decoderonly_architecture import DecoderOnlyAttention, DecoderOnlyWrapper
17
+
18
+
19
class Qwen3Wrapper(DecoderOnlyWrapper):
    """Decoder-only wrapper whose attention layers are built with ``Qwen3Attention``."""

    def get_rbln_attn_class(self):
        # Qwen3 attention also carries the q_norm / k_norm modules of the HF layer.
        attn_cls = Qwen3Attention
        return attn_cls
22
+
23
+
24
class Qwen3Attention(DecoderOnlyAttention):
    """``DecoderOnlyAttention`` variant that also exposes ``q_norm`` and ``k_norm``
    from the wrapped HF attention module."""

    def __post_init__(self):
        # Presumably invoked by DecoderOnlyAttention.__init__ (not visible here); confirm.
        # Copy the projections and norms from the wrapped module, assigning them in
        # the same order as before.
        source = self._original_mod
        for attr_name in ("k_proj", "v_proj", "q_proj", "o_proj", "q_norm", "k_norm"):
            setattr(self, attr_name, getattr(source, attr_name))
@@ -13,6 +13,8 @@
13
13
  # limitations under the License.
14
14
 
15
15
 
16
+ from typing import Optional
17
+
16
18
  from ...configuration_generic import RBLNModelForImageClassificationConfig
17
19
 
18
20
 
@@ -23,3 +25,18 @@ class RBLNResNetForImageClassificationConfig(RBLNModelForImageClassificationConf
23
25
  This configuration class stores the configuration parameters specific to
24
26
  RBLN-optimized ResNet models for image classification tasks.
25
27
  """
28
+
29
def __init__(self, output_hidden_states: Optional[bool] = None, **kwargs):
    """
    Args:
        output_hidden_states (Optional[bool]): Whether to return the hidden states of all layers.
            When left as None, RBLNResNetForImageClassification._update_rbln_config fills it in
            from the model config's ``output_hidden_states`` (False if that attribute is absent).
        kwargs: Additional arguments forwarded to the parent RBLNModelForImageClassificationConfig,
            for example ``image_size`` (int for square images, or a (height, width) tuple) and
            ``batch_size`` (batch size for inference, defaults to 1).

    Raises:
        ValueError: Propagated from the parent constructor, e.g. if batch_size is not a positive integer.
    """
    super().__init__(**kwargs)
    self.output_hidden_states = output_hidden_states
@@ -13,7 +13,17 @@
13
13
  # limitations under the License.
14
14
 
15
15
 
16
+ from typing import TYPE_CHECKING, Optional, Tuple, Union
17
+
18
+ import torch
19
+ from transformers.modeling_outputs import ImageClassifierOutputWithNoAttention
20
+
16
21
  from ...modeling_generic import RBLNModelForImageClassification
22
+ from .configuration_resnet import RBLNResNetForImageClassificationConfig
23
+
24
+
25
+ if TYPE_CHECKING:
26
+ from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer, PretrainedConfig, PreTrainedModel
17
27
 
18
28
 
19
29
  class RBLNResNetForImageClassification(RBLNModelForImageClassification):
@@ -24,3 +34,66 @@ class RBLNResNetForImageClassification(RBLNModelForImageClassification):
24
34
  on RBLN devices, supporting image classification with convolutional neural networks
25
35
  designed for computer vision tasks.
26
36
  """
37
+
38
+ @classmethod
39
+ def _update_rbln_config(
40
+ cls,
41
+ preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]] = None,
42
+ model: Optional["PreTrainedModel"] = None,
43
+ model_config: Optional["PretrainedConfig"] = None,
44
+ rbln_config: Optional["RBLNResNetForImageClassificationConfig"] = None,
45
+ ) -> "RBLNResNetForImageClassificationConfig":
46
+ if rbln_config.output_hidden_states is None:
47
+ rbln_config.output_hidden_states = getattr(model_config, "output_hidden_states", False)
48
+
49
+ rbln_config = super()._update_rbln_config(
50
+ preprocessors=preprocessors,
51
+ model=model,
52
+ model_config=model_config,
53
+ rbln_config=rbln_config,
54
+ )
55
+
56
+ return rbln_config
57
+
58
+ @classmethod
59
+ def _wrap_model_if_needed(
60
+ cls, model: torch.nn.Module, rbln_config: "RBLNResNetForImageClassificationConfig"
61
+ ) -> torch.nn.Module:
62
+ class _ResNetForImageClassification(torch.nn.Module):
63
+ def __init__(self, model: torch.nn.Module, output_hidden_states: bool):
64
+ super().__init__()
65
+ self.model = model
66
+ self.output_hidden_states = output_hidden_states
67
+
68
+ def forward(self, *args, **kwargs):
69
+ output = self.model(*args, output_hidden_states=self.output_hidden_states, **kwargs)
70
+ return output
71
+
72
+ return _ResNetForImageClassification(model, rbln_config.output_hidden_states)
73
+
74
+ def forward(
75
+ self, pixel_values: torch.Tensor, output_hidden_states: bool = None, return_dict: bool = None, **kwargs
76
+ ) -> Union[Tuple, ImageClassifierOutputWithNoAttention]:
77
+ """
78
+ Foward pass for the RBLN-optimized ResNet model for image classification.
79
+
80
+ Args:
81
+ pixel_values (torch.FloatTensor of shape (batch_size, channels, height, width)): The tensors corresponding to the input images.
82
+ output_hidden_states (bool, *optional*, defaults to False): Whether or not to return the hidden states of all layers.
83
+ See hidden_states under returned tensors for more details.
84
+ return_dict (bool, *optional*, defaults to True): Whether to return a dictionary of outputs.
85
+
86
+ Returns:
87
+ The model outputs. If return_dict=False is passed, returns a tuple of tensors. Otherwise, returns a ImageClassifierOutputWithNoAttention object.
88
+ """
89
+ output_hidden_states = (
90
+ output_hidden_states if output_hidden_states is not None else self.rbln_config.output_hidden_states
91
+ )
92
+
93
+ if output_hidden_states != self.rbln_config.output_hidden_states:
94
+ raise ValueError(
95
+ f"Variable output_hidden_states {output_hidden_states} is not equal to rbln_config.output_hidden_states {self.rbln_config.output_hidden_states} "
96
+ f"Please compile again with the correct argument."
97
+ )
98
+
99
+ return super().forward(pixel_values=pixel_values, return_dict=return_dict, **kwargs)
@@ -12,6 +12,11 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
+ from typing import Tuple, Union
16
+
17
+ import torch
18
+ from transformers.modeling_outputs import MaskedLMOutput, SequenceClassifierOutput
19
+
15
20
  from ...modeling_generic import RBLNModelForMaskedLM, RBLNModelForSequenceClassification
16
21
 
17
22
 
@@ -26,6 +31,19 @@ class RBLNRobertaForMaskedLM(RBLNModelForMaskedLM):
26
31
 
27
32
  rbln_model_input_names = ["input_ids", "attention_mask"]
28
33
 
34
+ def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs) -> Union[Tuple, MaskedLMOutput]:
35
+ """
36
+ Forward pass for the RBLN-optimized RoBERTa model for masked language modeling tasks.
37
+
38
+ Args:
39
+ input_ids (torch.LongTensor of shape (batch_size, sequence_length), optional): Indices of input sequence tokens in the vocabulary.
40
+ attention_mask (torch.FloatTensor of shape (batch_size, sequence_length), optional): Mask to avoid performing attention on padding token indices.
41
+
42
+ Returns:
43
+ The model outputs. If return_dict=False is passed, returns a tuple of tensors. Otherwise, returns a MaskedLMOutput object.
44
+ """
45
+ return super().forward(input_ids, attention_mask, **kwargs)
46
+
29
47
 
30
48
  class RBLNRobertaForSequenceClassification(RBLNModelForSequenceClassification):
31
49
  """
@@ -37,3 +55,18 @@ class RBLNRobertaForSequenceClassification(RBLNModelForSequenceClassification):
37
55
  """
38
56
 
39
57
  rbln_model_input_names = ["input_ids", "attention_mask"]
58
+
59
+ def forward(
60
+ self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs
61
+ ) -> Union[Tuple, SequenceClassifierOutput]:
62
+ """
63
+ Forward pass for the RBLN-optimized RoBERTa model for sequence classification tasks.
64
+
65
+ Args:
66
+ input_ids (torch.LongTensor of shape (batch_size, sequence_length), optional): Indices of input sequence tokens in the vocabulary.
67
+ attention_mask (torch.FloatTensor of shape (batch_size, sequence_length), optional): Mask to avoid performing attention on padding token indices.
68
+
69
+ Returns:
70
+ The model outputs. If return_dict=False is passed, returns a tuple of tensors. Otherwise, returns a SequenceClassifierOutput object.
71
+ """
72
+ return super().forward(input_ids, attention_mask, **kwargs)
@@ -12,11 +12,10 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
- from typing import Any, Dict, Optional
16
-
17
- import rebel
15
+ from typing import Any, Optional
18
16
 
19
17
  from ....configuration_utils import RBLNModelConfig
18
+ from ....utils.deprecation import deprecate_kwarg
20
19
  from ....utils.logging import get_logger
21
20
 
22
21
 
@@ -24,14 +23,18 @@ logger = get_logger()
24
23
 
25
24
 
26
25
  class RBLNModelForSeq2SeqLMConfig(RBLNModelConfig):
26
+ support_paged_attention = None
27
+
28
+ @deprecate_kwarg(old_name="pad_token_id", version="0.10.0")
27
29
  def __init__(
28
30
  self,
29
31
  batch_size: Optional[int] = None,
30
32
  enc_max_seq_len: Optional[int] = None,
31
33
  dec_max_seq_len: Optional[int] = None,
32
34
  use_attention_mask: Optional[bool] = None,
33
- pad_token_id: Optional[int] = None,
34
- **kwargs: Dict[str, Any],
35
+ kvcache_num_blocks: Optional[int] = None,
36
+ kvcache_block_size: Optional[int] = None,
37
+ **kwargs: Any,
35
38
  ):
36
39
  """
37
40
  Args:
@@ -39,9 +42,11 @@ class RBLNModelForSeq2SeqLMConfig(RBLNModelConfig):
39
42
  enc_max_seq_len (Optional[int]): Maximum sequence length for the encoder.
40
43
  dec_max_seq_len (Optional[int]): Maximum sequence length for the decoder.
41
44
  use_attention_mask (Optional[bool]): Whether to use attention masks during inference.
42
- This is automatically set to True for RBLN-CA02 devices.
43
- pad_token_id (Optional[int]): The ID of the padding token in the vocabulary.
44
- **kwargs: Additional arguments passed to the parent RBLNModelConfig.
45
+ kvcache_num_blocks (Optional[int]): The total number of blocks to allocate for the
46
+ PagedAttention KV cache for the SelfAttention. Defaults to batch_size.
47
+ kvcache_block_size (Optional[int]): Sets the size (in number of tokens) of each block
48
+ in the PagedAttention KV cache for the SelfAttention. Defaults to dec_max_seq_len.
49
+ kwargs: Additional arguments passed to the parent RBLNModelConfig.
45
50
 
46
51
  Raises:
47
52
  ValueError: If batch_size is not a positive integer.
@@ -55,12 +60,12 @@ class RBLNModelForSeq2SeqLMConfig(RBLNModelConfig):
55
60
  self.dec_max_seq_len = dec_max_seq_len
56
61
 
57
62
  self.use_attention_mask = use_attention_mask
58
- npu = self.npu or rebel.get_npu_name()
59
- if npu == "RBLN-CA02":
60
- if self.use_attention_mask is False:
61
- logger.warning("Attention mask should be used with RBLN-CA02. Setting use_attention_mask to True.")
62
- self.use_attention_mask = True
63
- else:
64
- self.use_attention_mask = self.use_attention_mask or False
65
63
 
66
- self.pad_token_id = pad_token_id
64
+ if self.support_paged_attention:
65
+ self.kvcache_num_blocks = kvcache_num_blocks
66
+ self.kvcache_block_size = kvcache_block_size
67
+ else:
68
+ if kvcache_num_blocks is not None or kvcache_block_size is not None:
69
+ raise ValueError(
70
+ "You cannot set kvcache_num_blocks or kvcache_block_size as paged attention is not supported for the model."
71
+ )