optimum-rbln 0.8.2a4__py3-none-any.whl → 0.9.3rc0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (167)
  1. optimum/rbln/__init__.py +96 -9
  2. optimum/rbln/__version__.py +16 -3
  3. optimum/rbln/cli.py +660 -0
  4. optimum/rbln/configuration_utils.py +153 -42
  5. optimum/rbln/diffusers/__init__.py +7 -0
  6. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +3 -3
  7. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_cosmos.py +1 -1
  8. optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +3 -3
  9. optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +4 -4
  10. optimum/rbln/diffusers/configurations/models/configuration_transformer_cosmos.py +9 -4
  11. optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +9 -4
  12. optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +3 -3
  13. optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +3 -3
  14. optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +35 -19
  15. optimum/rbln/diffusers/configurations/pipelines/configuration_cosmos.py +14 -11
  16. optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +30 -20
  17. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +13 -9
  18. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +17 -13
  19. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +17 -10
  20. optimum/rbln/diffusers/modeling_diffusers.py +30 -14
  21. optimum/rbln/diffusers/models/__init__.py +3 -13
  22. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +31 -3
  23. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py +28 -3
  24. optimum/rbln/diffusers/models/autoencoders/vq_model.py +31 -3
  25. optimum/rbln/diffusers/models/transformers/prior_transformer.py +1 -1
  26. optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +9 -1
  27. optimum/rbln/diffusers/models/transformers/transformer_sd3.py +9 -1
  28. optimum/rbln/diffusers/models/unets/unet_2d_condition.py +6 -3
  29. optimum/rbln/diffusers/pipelines/__init__.py +11 -5
  30. optimum/rbln/diffusers/pipelines/auto_pipeline.py +307 -0
  31. optimum/rbln/diffusers/pipelines/cosmos/configuration_cosmos_guardrail.py +19 -16
  32. optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +14 -18
  33. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +31 -1
  34. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +31 -1
  35. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +1 -6
  36. optimum/rbln/modeling.py +71 -19
  37. optimum/rbln/modeling_base.py +99 -21
  38. optimum/rbln/ops/attn.py +158 -0
  39. optimum/rbln/ops/flash_attn.py +166 -0
  40. optimum/rbln/ops/kv_cache_update.py +5 -0
  41. optimum/rbln/ops/linear.py +7 -0
  42. optimum/rbln/transformers/__init__.py +92 -0
  43. optimum/rbln/transformers/configuration_generic.py +9 -7
  44. optimum/rbln/transformers/modeling_attention_utils.py +252 -0
  45. optimum/rbln/transformers/modeling_generic.py +51 -9
  46. optimum/rbln/transformers/modeling_outputs.py +37 -0
  47. optimum/rbln/transformers/models/__init__.py +91 -30
  48. optimum/rbln/transformers/models/auto/__init__.py +2 -0
  49. optimum/rbln/transformers/models/auto/auto_factory.py +92 -17
  50. optimum/rbln/transformers/models/auto/modeling_auto.py +45 -0
  51. optimum/rbln/transformers/models/bart/bart_architecture.py +1 -3
  52. optimum/rbln/transformers/models/bart/configuration_bart.py +2 -0
  53. optimum/rbln/transformers/models/bert/bert_architecture.py +16 -0
  54. optimum/rbln/transformers/models/bert/modeling_bert.py +8 -4
  55. optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +42 -11
  56. optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +94 -30
  57. optimum/rbln/transformers/models/clip/configuration_clip.py +10 -7
  58. optimum/rbln/transformers/models/clip/modeling_clip.py +27 -4
  59. optimum/rbln/transformers/models/colpali/colpali_architecture.py +3 -6
  60. optimum/rbln/transformers/models/colpali/configuration_colpali.py +37 -21
  61. optimum/rbln/transformers/models/colpali/modeling_colpali.py +113 -96
  62. optimum/rbln/transformers/models/colqwen2/__init__.py +2 -0
  63. optimum/rbln/transformers/models/colqwen2/colqwen2_architecture.py +233 -0
  64. optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +74 -0
  65. optimum/rbln/transformers/models/colqwen2/modeling_colqwen2.py +446 -0
  66. optimum/rbln/transformers/models/decoderonly/__init__.py +3 -2
  67. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +109 -37
  68. optimum/rbln/transformers/models/decoderonly/configuration_lora.py +411 -0
  69. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +318 -309
  70. optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +504 -0
  71. optimum/rbln/transformers/models/decoderonly/generation_decoderonly.py +111 -0
  72. optimum/rbln/transformers/models/decoderonly/lora_architecture.py +204 -0
  73. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +453 -897
  74. optimum/rbln/transformers/models/depth_anything/__init__.py +16 -0
  75. optimum/rbln/transformers/models/depth_anything/configuration_depth_anything.py +24 -0
  76. optimum/rbln/transformers/models/depth_anything/modeling_depth_anything.py +25 -0
  77. optimum/rbln/transformers/models/exaone/modeling_exaone.py +42 -4
  78. optimum/rbln/transformers/models/gemma/__init__.py +2 -2
  79. optimum/rbln/transformers/models/gemma/configuration_gemma.py +9 -1
  80. optimum/rbln/transformers/models/gemma/gemma_architecture.py +1 -4
  81. optimum/rbln/transformers/models/gemma/modeling_gemma.py +22 -1
  82. optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +49 -13
  83. optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +12 -2
  84. optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py +245 -0
  85. optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +201 -349
  86. optimum/rbln/transformers/models/gpt2/__init__.py +2 -2
  87. optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +31 -3
  88. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +10 -8
  89. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +18 -1
  90. optimum/rbln/transformers/models/grounding_dino/__init__.py +10 -0
  91. optimum/rbln/transformers/models/grounding_dino/configuration_grounding_dino.py +92 -0
  92. optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +599 -0
  93. optimum/rbln/transformers/models/grounding_dino/modeling_grounding_dino.py +1032 -0
  94. optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +35 -7
  95. optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +26 -27
  96. optimum/rbln/transformers/models/llama/__init__.py +2 -2
  97. optimum/rbln/transformers/models/llama/configuration_llama.py +9 -1
  98. optimum/rbln/transformers/models/llama/modeling_llama.py +22 -1
  99. optimum/rbln/transformers/models/llava/__init__.py +16 -0
  100. optimum/rbln/transformers/models/llava/configuration_llava.py +72 -0
  101. optimum/rbln/transformers/models/llava/modeling_llava.py +478 -0
  102. optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +15 -17
  103. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +235 -375
  104. optimum/rbln/transformers/models/midm/midm_architecture.py +4 -1
  105. optimum/rbln/transformers/models/midm/modeling_midm.py +42 -4
  106. optimum/rbln/transformers/models/mistral/__init__.py +2 -2
  107. optimum/rbln/transformers/models/mistral/configuration_mistral.py +9 -1
  108. optimum/rbln/transformers/models/mistral/mistral_architecture.py +1 -1
  109. optimum/rbln/transformers/models/mistral/modeling_mistral.py +26 -3
  110. optimum/rbln/transformers/models/opt/__init__.py +2 -2
  111. optimum/rbln/transformers/models/opt/configuration_opt.py +8 -1
  112. optimum/rbln/transformers/models/opt/modeling_opt.py +28 -16
  113. optimum/rbln/transformers/models/opt/opt_architecture.py +4 -4
  114. optimum/rbln/transformers/models/pegasus/__init__.py +17 -0
  115. optimum/rbln/transformers/models/pegasus/configuration_pegasus.py +38 -0
  116. optimum/rbln/transformers/models/pegasus/modeling_pegasus.py +71 -0
  117. optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +161 -0
  118. optimum/rbln/transformers/models/phi/__init__.py +2 -2
  119. optimum/rbln/transformers/models/phi/configuration_phi.py +9 -1
  120. optimum/rbln/transformers/models/phi/modeling_phi.py +10 -1
  121. optimum/rbln/transformers/models/phi/phi_architecture.py +11 -7
  122. optimum/rbln/transformers/models/pixtral/__init__.py +16 -0
  123. optimum/rbln/transformers/models/pixtral/configuration_pixtral.py +43 -0
  124. optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +310 -0
  125. optimum/rbln/transformers/models/pixtral/pixtral_architecture.py +73 -0
  126. optimum/rbln/transformers/models/qwen2/__init__.py +2 -2
  127. optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +9 -1
  128. optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +27 -1
  129. optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +21 -6
  130. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +15 -21
  131. optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +28 -7
  132. optimum/rbln/transformers/models/qwen2_vl/__init__.py +19 -0
  133. optimum/rbln/transformers/models/qwen2_vl/configuration_qwen2_vl.py +88 -0
  134. optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py +514 -0
  135. optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +165 -0
  136. optimum/rbln/transformers/models/qwen3/configuration_qwen3.py +2 -2
  137. optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +86 -330
  138. optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +1 -245
  139. optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py +20 -13
  140. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +24 -3
  141. optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +2 -2
  142. optimum/rbln/transformers/models/siglip/__init__.py +2 -6
  143. optimum/rbln/transformers/models/siglip/configuration_siglip.py +1 -1
  144. optimum/rbln/transformers/models/siglip/modeling_siglip.py +5 -16
  145. optimum/rbln/transformers/models/swin/__init__.py +16 -0
  146. optimum/rbln/transformers/models/swin/configuration_swin.py +42 -0
  147. optimum/rbln/transformers/models/swin/modeling_swin.py +341 -0
  148. optimum/rbln/transformers/models/t5/configuration_t5.py +2 -0
  149. optimum/rbln/transformers/models/t5/t5_architecture.py +8 -1
  150. optimum/rbln/transformers/models/time_series_transformer/configuration_time_series_transformer.py +3 -3
  151. optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +4 -14
  152. optimum/rbln/transformers/models/time_series_transformer/time_series_transformers_architecture.py +7 -1
  153. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +1 -0
  154. optimum/rbln/transformers/models/whisper/configuration_whisper.py +12 -13
  155. optimum/rbln/transformers/models/whisper/generation_whisper.py +28 -6
  156. optimum/rbln/transformers/models/whisper/modeling_whisper.py +28 -3
  157. optimum/rbln/transformers/models/xlm_roberta/__init__.py +2 -8
  158. optimum/rbln/transformers/utils/rbln_quantization.py +391 -75
  159. optimum/rbln/transformers/utils/rbln_runtime_wrapper.py +79 -0
  160. optimum/rbln/utils/depreacate_utils.py +16 -0
  161. optimum/rbln/utils/runtime_utils.py +28 -18
  162. optimum/rbln/utils/submodule.py +31 -9
  163. {optimum_rbln-0.8.2a4.dist-info → optimum_rbln-0.9.3rc0.dist-info}/METADATA +8 -7
  164. {optimum_rbln-0.8.2a4.dist-info → optimum_rbln-0.9.3rc0.dist-info}/RECORD +167 -125
  165. optimum_rbln-0.9.3rc0.dist-info/entry_points.txt +2 -0
  166. {optimum_rbln-0.8.2a4.dist-info → optimum_rbln-0.9.3rc0.dist-info}/WHEEL +0 -0
  167. {optimum_rbln-0.8.2a4.dist-info → optimum_rbln-0.9.3rc0.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,165 @@
+import math
+from typing import Tuple
+
+import torch
+import torch.nn as nn
+from transformers import PreTrainedModel
+
+from ..decoderonly.decoderonly_architecture import (
+    DecoderOnlyWrapper,
+    apply_rotary_pos_emb,
+)
+
+
+class Qwen2VisionTransformerWrapper(nn.Module):
+    def __init__(self, model: torch.nn.Module):
+        super().__init__()
+        self._original_mod = model
+        self.merger = model.merger
+        self.blocks = self.wrap_vision_blocks(model.blocks)
+
+    def wrap_vision_blocks(self, blocks: torch.nn.ModuleList):
+        wrapped_blocks = []
+        for i, block in enumerate(blocks):
+            wrapped_blocks.append(Qwen2VLVisionBlock(block))
+        return nn.ModuleList(wrapped_blocks)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        full_attn_masks: torch.Tensor,
+        cos: torch.Tensor,
+        sin: torch.Tensor,
+    ):
+        full_attn_masks = (1 - full_attn_masks) * torch.finfo(torch.float32).min
+
+        for block in self.blocks:
+            hidden_states = block(hidden_states, full_attn_masks, [cos, sin])
+
+        return self.merger(hidden_states)
+
+
+class Qwen2VLVisionBlock(torch.nn.Module):
+    def __init__(self, model: torch.nn.Module):
+        super().__init__()
+        self._origin_model = model
+        self.norm1 = model.norm1
+        self.norm2 = model.norm2
+
+        self.attn = VisionAttention(model.attn)
+        self.mlp = model.mlp
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attn_masks: torch.Tensor,
+        position_embeddings: Tuple[torch.Tensor, torch.Tensor],
+    ) -> torch.Tensor:
+        hidden_states = hidden_states + self.attn(
+            self.norm1(hidden_states),
+            attn_masks,
+            position_embeddings,
+        )
+        hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))
+        return hidden_states
+
+
+class VisionAttention(nn.Module):
+    def __init__(self, model: nn.Module) -> None:
+        super().__init__()
+        self._origin_model = model
+        self.num_heads = model.num_heads
+        self.head_dim = getattr(model, "head_dim", model.proj.in_features // model.num_heads)
+        self.qkv = model.qkv
+        self.proj = model.proj
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attn_masks: torch.Tensor,
+        position_embeddings: Tuple[torch.Tensor, torch.Tensor],
+    ) -> torch.Tensor:
+        seq_length = hidden_states.shape[0]
+        hidden_states = hidden_states.unsqueeze(0)
+        q, k, v = (
+            self.qkv(hidden_states).reshape(1, seq_length, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4).unbind(0)
+        )
+
+        cos, sin = position_embeddings
+        q, k = apply_rotary_pos_emb(q, k, cos, sin)
+
+        attn_weights = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(self.head_dim)
+        attn_weights = attn_weights + attn_masks
+        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32)
+        attn_output = torch.matmul(attn_weights, v)
+        attn_output = attn_output.transpose(1, 2)
+        attn_output = attn_output.reshape(1, seq_length, -1)
+        attn_output = self.proj(attn_output).squeeze(0)
+
+        return attn_output
+
+
+class Qwen2VL_LanguageModelWrapper(DecoderOnlyWrapper):
+    def prepare_forward_args(self, *args):
+        args = list(args)
+        input_ids = None if self.rbln_config.use_inputs_embeds else args.pop(0)
+        inputs_embeds = args.pop(0) if self.rbln_config.use_inputs_embeds else None
+        cache_position = args.pop(0)
+        global_block_tables = args.pop(0)
+        local_block_tables = None
+        position_embeds = args.pop(0)
+        query_position = args.pop(0) if self.phase == "prefill" else None
+        position_ids = None
+        attention_mask = args.pop(0) if self.rbln_config.use_attention_mask else None
+        lora_int_id = args.pop(0) if self.rbln_config.lora_config else None
+        past_key_values = args
+
+        if len(past_key_values) != 2 * self.num_hidden_layers:
+            raise ValueError(
+                f"Different past_key_values to model's config. {len(past_key_values)} != {2 * self.num_hidden_layers}"
+            )
+
+        # [key, value] * n_layer -> ( (key, value) ) * n_layer
+        # cache shape : batch, n_heads, 1, max_seq_len, head_dim
+        _past_key_values = []
+        for i in range(self.config.num_hidden_layers):
+            key_states = past_key_values[i * 2]
+            value_states = past_key_values[i * 2 + 1]
+            past_key_value = [key_states, value_states]
+            _past_key_values.append(past_key_value)
+        past_key_values = _past_key_values
+
+        return (
+            input_ids,
+            inputs_embeds,
+            cache_position,
+            global_block_tables,
+            local_block_tables,
+            query_position,
+            attention_mask,
+            position_ids,
+            lora_int_id,
+            past_key_values,
+            position_embeds,
+        )
+
+    def convert_to_rbln_class(self, model: PreTrainedModel, max_seq_len: int):
+        new_layers = []
+
+        for layer_idx, layer in enumerate(model.model.language_model.layers):
+            is_sliding = layer_idx in self.rbln_config.sliding_window_layers
+            new_self_attn = self.get_rbln_attn_class()(
+                self.get_attn_layer(layer), self.rbln_config, is_sliding=is_sliding
+            )
+            new_layer = self.get_rbln_layer_class()(layer, new_self_attn)
+            new_layers.append(new_layer)
+
+        new_model = self.get_rbln_model_class()(
+            model.model.language_model,
+            new_layers,
+            self.rbln_config,
+            use_learned_pos_emb=self.__class__._use_learned_pos_emb,
+        )
+
+        new_model = self.get_rbln_causal_lm_class()(model.model, new_model)
+        return new_model
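
The hunk above (which, going by the file list, appears to be the new `optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py`) computes vision attention with an additive mask: a 0/1 mask is converted via `(1 - mask) * torch.finfo(torch.float32).min` and added to the attention scores before the softmax, so blocked positions end up with effectively zero weight. A minimal, self-contained sketch of that masking pattern, using toy tensors rather than the RBLN wrapper itself:

```python
import math

import torch
import torch.nn as nn

# Toy shapes; in the wrapper above the mask comes from the vision patch grouping.
num_heads, seq_len, head_dim = 2, 6, 8
q = torch.randn(1, num_heads, seq_len, head_dim)
k = torch.randn(1, num_heads, seq_len, head_dim)
v = torch.randn(1, num_heads, seq_len, head_dim)

# 0/1 mask: 1 = attend, 0 = blocked (e.g. patches belonging to a different image).
full_attn_masks = torch.ones(1, 1, seq_len, seq_len)
full_attn_masks[..., 3:, :3] = 0
full_attn_masks[..., :3, 3:] = 0

# Same conversion as Qwen2VisionTransformerWrapper.forward: blocked positions
# become a large negative bias, so softmax assigns them ~0 probability.
additive_mask = (1 - full_attn_masks) * torch.finfo(torch.float32).min

attn_weights = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(head_dim)
attn_weights = attn_weights + additive_mask
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32)
attn_output = torch.matmul(attn_weights, v)
print(attn_output.shape)  # torch.Size([1, 2, 6, 8])
```

Using `torch.finfo(torch.float32).min` instead of `-inf` keeps the softmax finite even if an entire row is masked, which would otherwise produce NaNs.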
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from ..decoderonly.configuration_decoderonly import RBLNDecoderOnlyModelForCausalLMConfig
+from ..decoderonly.configuration_decoderonly import RBLNDecoderOnlyModelConfig, RBLNDecoderOnlyModelForCausalLMConfig


 class RBLNQwen3ForCausalLMConfig(RBLNDecoderOnlyModelForCausalLMConfig):
@@ -43,7 +43,7 @@ class RBLNQwen3ForCausalLMConfig(RBLNDecoderOnlyModelForCausalLMConfig):
     """


-class RBLNQwen3ModelConfig(RBLNDecoderOnlyModelForCausalLMConfig):
+class RBLNQwen3ModelConfig(RBLNDecoderOnlyModelConfig):
     """
     Configuration class for RBLN Qwen3 models.

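
This configuration hunk (per the file list, `configuration_qwen3.py`) re-bases `RBLNQwen3ModelConfig` on `RBLNDecoderOnlyModelConfig`, so the bare-model config no longer inherits from the causal-LM config. A hedged usage sketch follows; it assumes `RBLNQwen3ModelConfig` is exported from `optimum.rbln` and accepts the same common arguments (`batch_size`, `max_seq_len`, `tensor_parallel_size`) shown in the docstring examples elsewhere in this diff, which has not been verified against the released wheel:

```python
# Sketch only: RBLNQwen3ModelConfig's exact constructor arguments are an
# assumption based on the docstring examples in this diff.
from optimum.rbln import RBLNQwen3Model, RBLNQwen3ModelConfig

config = RBLNQwen3ModelConfig(
    batch_size=1,
    max_seq_len=40_960,
    tensor_parallel_size=4,
)

model = RBLNQwen3Model.from_pretrained(
    "Qwen/Qwen3-Embedding-4B",
    export=True,
    rbln_config=config,
)
```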
@@ -12,37 +12,76 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from pathlib import Path
-from typing import TYPE_CHECKING, List, Optional, Union
+from typing import TYPE_CHECKING

-import rebel
-import torch
-from rebel.compile_context import CompileContext
-from transformers import PretrainedConfig, PreTrainedModel
-from transformers.modeling_outputs import BaseModelOutputWithPast
-from transformers.modeling_utils import no_init_weights
+from transformers import PretrainedConfig

-from ....configuration_utils import RBLNCompileConfig
-from ....modeling import RBLNModel
 from ....utils import logging
-from ...models.decoderonly import RBLNDecoderOnlyModelForCausalLM, RBLNDecoderOnlyModelForCausalLMConfig
-from ..decoderonly.modeling_decoderonly import set_default_values, validate_attention_method
-from .configuration_qwen3 import RBLNQwen3ModelConfig
-from .qwen3_architecture import Qwen3ModelWrapper, Qwen3Wrapper
+from ...models.decoderonly import (
+    RBLNDecoderOnlyModel,
+    RBLNDecoderOnlyModelForCausalLM,
+    RBLNDecoderOnlyModelForCausalLMConfig,
+)
+from .qwen3_architecture import Qwen3Wrapper


 logger = logging.get_logger(__name__)

 if TYPE_CHECKING:
-    from transformers import (
-        AutoFeatureExtractor,
-        AutoProcessor,
-        AutoTokenizer,
-        PretrainedConfig,
-    )
+    from transformers import PretrainedConfig


 class RBLNQwen3ForCausalLM(RBLNDecoderOnlyModelForCausalLM):
+    """
+    The Qwen3 Model transformer with a language modeling head (linear layer) on top.
+    This model inherits from [`RBLNDecoderOnlyModelForCausalLM`]. Check the superclass documentation for the generic methods the library implements for all its models.
+    A class to convert and run pre-trained transformers based Qwen3ForCausalLM model on RBLN devices.
+    It implements the methods to convert a pre-trained transformers Qwen3ForCausalLM model into a RBLN transformer model by:
+    - transferring the checkpoint weights of the original into an optimized RBLN graph,
+    - compiling the resulting graph using the RBLN compiler.
+    **Configuration:**
+    This model uses [`RBLNQwen3ForCausalLMConfig`] for configuration. When calling methods like `from_pretrained` or `from_model`,
+    the `rbln_config` parameter should be an instance of [`RBLNQwen3ForCausalLMConfig`] or a dictionary conforming to its structure.
+    See the [`RBLNQwen3ForCausalLMConfig`] class for all available configuration options.
+    Examples:
+    ```python
+    from optimum.rbln import RBLNQwen3ForCausalLM
+    # Simple usage using rbln_* arguments
+    # `max_seq_len` is automatically inferred from the model config
+    model = RBLNQwen3ForCausalLM.from_pretrained(
+        "Qwen/Qwen3-4B",
+        export=True,
+        rbln_batch_size=1,
+        rbln_tensor_parallel_size=4,
+    )
+    # Using a config dictionary
+    rbln_config = {
+        "batch_size": 1,
+        "max_seq_len": 40_960,
+        "tensor_parallel_size": 4,
+        "kvcache_partition_len": 8192,
+    }
+    model = RBLNQwen3ForCausalLM.from_pretrained(
+        "Qwen/Qwen3-4B",
+        export=True,
+        rbln_config=rbln_config
+    )
+    # Using a RBLNQwen3ForCausalLMConfig instance (recommended for type checking)
+    from optimum.rbln import RBLNQwen3ForCausalLMConfig
+    config = RBLNQwen3ForCausalLMConfig(
+        batch_size=1,
+        max_seq_len=40_960,
+        tensor_parallel_size=4,
+        kvcache_partition_len=8192,
+    )
+    model = RBLNQwen3ForCausalLM.from_pretrained(
+        "Qwen/Qwen3-4B",
+        export=True,
+        rbln_config=config
+    )
+    ```
+    """
+
     _decoder_wrapper_cls = Qwen3Wrapper

     @classmethod
@@ -63,315 +102,32 @@ class RBLNQwen3ForCausalLM(RBLNDecoderOnlyModelForCausalLM):
         return super().forward(*args, **kwargs)


-class RBLNQwen3Model(RBLNModel):
-    _decoder_wrapper_cls = Qwen3ModelWrapper
-    _use_rotary_emb = True
-
-    def __post_init__(self, **kwargs):
-        artifacts = torch.load(self.model_save_dir / self.subfolder / "torch_artifacts.pth", weights_only=False)
-        self.embed_tokens = self._create_embedding_layer()
-        self.embed_tokens.load_state_dict(artifacts["embed_tokens"])
-        self.block_tables = torch.arange(
-            self.rbln_config.max_seq_len / self.rbln_config.kvcache_block_size, dtype=torch.int16
-        )
-        self.causal_mask = 1 - torch.triu(
-            torch.ones(1, 1, self.rbln_config.prefill_chunk_size, self.rbln_config.prefill_chunk_size), diagonal=1
-        )
-
-    @classmethod
-    def save_torch_artifacts(
-        cls,
-        model: PreTrainedModel,
-        save_dir_path: Path,
-        subfolder: str,
-        rbln_config: RBLNQwen3ModelConfig,
-    ):
-        save_dict = {}
-        save_dict["embed_tokens"] = model.get_input_embeddings().state_dict()
-        torch.save(save_dict, save_dir_path / subfolder / "torch_artifacts.pth")
-
-    def _create_embedding_layer(self):
-        with no_init_weights():
-            embed_tokens = torch.nn.Embedding(
-                self.config.vocab_size,
-                self.config.hidden_size,
-                self.config.pad_token_id,
-            )
-        return embed_tokens
-
-    def get_input_embeddings(self):
-        return self.embed_tokens
-
-    @classmethod
-    def wrap_model_if_needed(cls, model: PreTrainedModel, rbln_config: "RBLNQwen3ModelConfig"):
-        wrapper_cfg = {
-            "max_seq_len": rbln_config.max_seq_len,
-            "attn_impl": rbln_config.attn_impl,
-            "kvcache_partition_len": rbln_config.kvcache_partition_len,
-            "kvcache_block_size": rbln_config.kvcache_block_size,
-            "use_rotary_emb": cls._use_rotary_emb,
-            "use_attention_mask": rbln_config.use_attention_mask,
-            "cache_impl": rbln_config.cache_impl,
-            "sliding_window": rbln_config.sliding_window,
-            "sliding_window_layers": rbln_config.sliding_window_layers,
-        }
-        return cls._decoder_wrapper_cls(model, **wrapper_cfg).eval()
-
-    @classmethod
-    @torch.inference_mode()
-    def get_compiled_model(cls, model: "PreTrainedModel", rbln_config: RBLNQwen3ModelConfig):
-        wrapped_model = cls.wrap_model_if_needed(model, rbln_config)
-
-        rbln_compile_configs = rbln_config.compile_cfgs
-        prefill_compile_config = rbln_compile_configs[0]
-
-        context = CompileContext(use_weight_sharing=False)
-
-        meta_tensor_names = [name for name, _, _ in prefill_compile_config.input_info if "past_key_values" in name]
-        prefill_example_inputs = prefill_compile_config.get_dummy_inputs(fill=0, meta_tensor_names=meta_tensor_names)
-
-        static_tensors = {}
-        for (name, _, _), tensor in zip(prefill_compile_config.input_info, prefill_example_inputs):
-            if "past_key_values" in name:
-                static_tensors[name] = tensor
-                context.mark_static_address(tensor)
-
-        def compile_model(wrapped_model, compile_config, example_inputs, compile_context):
-            try:
-                original_linear = torch.nn.functional.linear
-                torch.nn.functional.linear = torch.ops.rbln_custom_ops.linear
-                compiled_model = RBLNModel.compile(
-                    wrapped_model,
-                    compile_config,
-                    example_inputs=example_inputs,
-                    compile_context=compile_context,
-                    create_runtimes=rbln_config.create_runtimes,
-                    device=rbln_config.device,
-                )
-                return compiled_model
-            finally:
-                torch.nn.functional.linear = original_linear
-
-        wrapped_model.phase = "prefill"
-        compiled_prefill = compile_model(wrapped_model, prefill_compile_config, prefill_example_inputs, context)
-
-        compiled_models = {"prefill": compiled_prefill}
-        return compiled_models
-
-    @classmethod
-    def get_input_info(
-        cls,
-        batch_size: int,
-        query_length: int,
-        rbln_config: RBLNQwen3ModelConfig,
-        model_config: PretrainedConfig,
-    ):
-        input_info = RBLNDecoderOnlyModelForCausalLM.get_input_info(
-            batch_size,
-            query_length,
-            rbln_config=rbln_config,
-            model_config=model_config,
-        )
-
-        if rbln_config.sliding_window is None:
-            # remove query position
-            input_info.pop(3)
-
-        return input_info
-
-    @classmethod
-    def _update_rbln_config(
-        cls,
-        preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]] = None,
-        model: Optional["PreTrainedModel"] = None,
-        model_config: Optional["PretrainedConfig"] = None,
-        rbln_config: Optional[RBLNQwen3ModelConfig] = None,
-    ) -> RBLNQwen3ModelConfig:
-        if rbln_config.max_seq_len is None:
-            rbln_config.max_seq_len = getattr(model_config, "max_position_embeddings", None) or getattr(
-                model_config, "n_positions", None
-            )
-        if rbln_config.max_seq_len is None:
-            raise ValueError("`max_seq_len` should be specified.")
-
-        rbln_config.attn_impl, rbln_config.kvcache_partition_len, rbln_config.kvcache_block_size = set_default_values(
-            attn_impl=rbln_config.attn_impl,
-            kvcache_partition_len=rbln_config.kvcache_partition_len,
-            kvcache_block_size=rbln_config.kvcache_block_size,
-            max_seq_len=rbln_config.max_seq_len,
-        )
-
-        validate_attention_method(
-            attn_impl=rbln_config.attn_impl,
-            kvcache_partition_len=rbln_config.kvcache_partition_len,
-            kvcache_block_size=rbln_config.kvcache_block_size,
-            max_seq_len=rbln_config.max_seq_len,
-        )
-
-        # only compile prefill cb -> always batch_size 1
-        required_num_blocks = rbln_config.max_seq_len // rbln_config.kvcache_block_size
-        max_num_blocks = required_num_blocks
-
-        if rbln_config.attn_impl == "flash_attn":
-            estimated_max_num_blocks = RBLNDecoderOnlyModelForCausalLM.get_maximum_num_blocks(
-                config=model_config,
-                tensor_parallel_size=rbln_config.tensor_parallel_size or 1,
-                kvcache_block_size=rbln_config.kvcache_block_size,
-                nbits_per_param=16 if not rbln_config.quantization else 4,
-                n_model_params=sum(p.numel() for p in model.parameters()),
-                num_runtimes=1 + len(rbln_config.decoder_batch_sizes),
-            )
-
-            max_num_blocks = min(max_num_blocks, estimated_max_num_blocks)
-
-            flash_min_blocks = rbln_config.max_seq_len // rbln_config.kvcache_block_size + 1
-            if max_num_blocks < flash_min_blocks:
-                max_num_blocks = flash_min_blocks
-
-        if rbln_config.kvcache_num_blocks is None:
-            rbln_config.kvcache_num_blocks = max_num_blocks
-
-        prefill_input_info = cls.get_input_info(
-            batch_size=1,
-            query_length=rbln_config.prefill_chunk_size,
-            rbln_config=rbln_config,
-            model_config=model_config,
-        )
-
-        prefill_compile_config = RBLNCompileConfig(compiled_model_name="prefill", input_info=prefill_input_info)
-        rbln_config.set_compile_cfgs([prefill_compile_config])
-
-        return rbln_config
-
-    @classmethod
-    def _create_runtimes(
-        cls,
-        compiled_models: List[rebel.RBLNCompiledModel],
-        rbln_config: RBLNQwen3ModelConfig,
-    ) -> List[rebel.Runtime]:
-        expected_model_names = ["prefill"]
-        if any(model_name not in rbln_config.device_map for model_name in expected_model_names):
-            cls._raise_missing_compiled_file_error(expected_model_names)
-
-        return [
-            rebel.Runtime(
-                compiled_models[0],
-                tensor_type="pt",
-                device=rbln_config.device_map["prefill"],
-                activate_profiler=rbln_config.activate_profiler,
-            ),
-        ]
-
-    def _preprocess_chunked_prefill(
-        self,
-        inputs: torch.Tensor,
-        attention_mask: Optional[torch.Tensor] = None,
-        position_embed: Optional[torch.Tensor] = None,
-    ):
-        # valid sequence length of inputs_embeds
-        query_length = inputs.shape[1] if attention_mask is None else torch.sum(attention_mask.view(-1)).item()
-
-        # extract valid inputs
-        inputs = inputs[:, attention_mask.bool()] if attention_mask is not None else inputs
-        if position_embed is not None:
-            position_embed = (
-                position_embed[:, :, :, attention_mask.bool(), :] if attention_mask is not None else position_embed
-            )
-
-        if self.rbln_config.use_attention_mask:
-            chunked_attention_mask = (
-                torch.zeros(
-                    1, 1, self.rbln_config.prefill_chunk_size, self.rbln_config.max_seq_len, dtype=torch.float32
-                )
-                if self.rbln_config.use_attention_mask
-                else None
-            )
-        else:
-            chunked_attention_mask = None
-
-        # padding for chunked prefill
-        padding_size = (
-            self.rbln_config.prefill_chunk_size - (query_length % self.rbln_config.prefill_chunk_size)
-        ) % self.rbln_config.prefill_chunk_size
-        padded_len = query_length + padding_size
-
-        inputs = torch.nn.functional.pad(inputs, (0, padding_size))
-        position_embed = (
-            None if position_embed is None else torch.nn.functional.pad(position_embed, (0, 0, 0, padding_size))
-        )
-        cache_position = torch.arange(padded_len, dtype=torch.int32).unsqueeze(0)
-
-        return inputs, chunked_attention_mask, position_embed, cache_position, query_length
-
-    def _chunked_prefill_forward(
-        self,
-        inputs_embeds: torch.Tensor,
-        attention_mask: Optional[torch.Tensor] = None,
-        position_embed: Optional[torch.Tensor] = None,
-    ):
-        padded_input, chunked_attention_mask, padded_position_embed, cache_position, query_length = (
-            self._preprocess_chunked_prefill(inputs_embeds, attention_mask, position_embed)
+class RBLNQwen3Model(RBLNDecoderOnlyModel):
+    """
+    The bare Qwen3 Model outputting raw hidden-states without any specific head on top.
+    This model inherits from [`RBLNDecoderOnlyModel`]. Check the superclass documentation for the generic methods the library implements for all its models.
+    A class to convert and run pre-trained transformers based Qwen3Model on RBLN devices.
+    It implements the methods to convert a pre-trained transformers Qwen3Model into a RBLN transformer model by:
+    - transferring the checkpoint weights of the original into an optimized RBLN graph,
+    - compiling the resulting graph using the RBLN compiler.
+    **Configuration:**
+    This model uses [`RBLNQwen3ModelConfig`] for configuration. When calling methods like `from_pretrained` or `from_model`,
+    the `rbln_config` parameter should be an instance of [`RBLNQwen3ModelConfig`] or a dictionary conforming to its structure.
+    See the [`RBLNQwen3ModelConfig`] class for all available configuration options.
+    Examples:
+    ```python
+    from optimum.rbln import RBLNQwen3Model
+    # Simple usage using rbln_* arguments
+    # `max_seq_len` is automatically inferred from the model config
+    model = RBLNQwen3Model.from_pretrained(
+        "Qwen/Qwen3-Embedding-4B",
+        export=True,
+        rbln_batch_size=1,
+        rbln_max_seq_len=40_960,
+        rbln_tensor_parallel_size=4,
+        rbln_kvcache_partition_len=8192,
     )
+    """

-        # chunked prefill
-        last_hidden_states = []
-        for step in range(0, query_length, self.rbln_config.prefill_chunk_size):
-            # Extract the current chunk of inputs and cache positions
-            input_chunk = padded_input[:, step : step + self.rbln_config.prefill_chunk_size]
-            cache_pos_chunk = cache_position[:, step : step + self.rbln_config.prefill_chunk_size]
-
-            model_args = {
-                "input_ids": input_chunk,
-                "cache_position": cache_pos_chunk,
-                "block_tables": self.block_tables,
-            }
-
-            if chunked_attention_mask is not None:
-                if step >= self.rbln_config.prefill_chunk_size:
-                    chunked_attention_mask[:, :, :, step - self.rbln_config.prefill_chunk_size : step] = 1
-                chunked_attention_mask[:, :, :, step : step + self.rbln_config.prefill_chunk_size] = self.causal_mask
-                model_args["attention_mask"] = chunked_attention_mask
-
-            last_hidden_states_chunk = self.model[0](**model_args)
-            last_hidden_states.append(last_hidden_states_chunk)
-
-        last_hidden_states = torch.concat(last_hidden_states, dim=-2)[:, :query_length]
-
-        return self._postprocess_chunked_prefill(last_hidden_states, attention_mask)
-
-    def _postprocess_chunked_prefill(
-        self, last_hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None
-    ):
-        # index copy for attention mask
-        if attention_mask is not None:
-            new_last_hidden_states = torch.full(
-                (1, attention_mask.shape[-1], last_hidden_states.shape[-1]),
-                fill_value=1e-10,
-                dtype=last_hidden_states.dtype,
-            )
-            mask_indices = torch.nonzero(attention_mask, as_tuple=True)[0]
-            new_last_hidden_states.index_copy_(dim=-2, index=mask_indices, source=last_hidden_states)
-        else:
-            new_last_hidden_states = last_hidden_states
-        return new_last_hidden_states
-
-    def forward(
-        self,
-        input_ids: Optional[torch.LongTensor] = None,
-        inputs_embeds: Optional[torch.Tensor] = None,
-        attention_mask: Optional[torch.LongTensor] = None,
-        position_embed: Optional[torch.Tensor] = None,
-        **kwargs,
-    ):
-        inputs = inputs_embeds if inputs_embeds is not None else input_ids
-        batch_size = inputs.shape[0]
-        all_last_hidden_states = []
-        for b_idx in range(batch_size):
-            last_hidden_states = self._chunked_prefill_forward(
-                inputs[b_idx : b_idx + 1],
-                attention_mask[b_idx] if attention_mask is not None else None,
-                position_embed[b_idx : b_idx + 1] if position_embed is not None else None,
-            )
-            all_last_hidden_states.append(last_hidden_states)
-
-        return BaseModelOutputWithPast(last_hidden_state=torch.concat(all_last_hidden_states, dim=0))
+    _decoder_wrapper_cls = Qwen3Wrapper
+    _use_rotary_emb = True
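
The modeling hunk above (per the file list, `modeling_qwen3.py`) deletes the hand-rolled chunked-prefill implementation of `RBLNQwen3Model` and re-bases the class on `RBLNDecoderOnlyModel`, keeping only the docstring, the wrapper class reference, and the rotary-embedding flag. A hedged sketch of how the resulting hidden states might be turned into sentence embeddings for an embedding checkpoint: the `last_hidden_state` attribute and the `input_ids`/`attention_mask` call signature are suggested by the removed implementation, and the last-token pooling plus L2 normalization is an illustrative choice, not something specified in this diff:

```python
# Sketch only: assumes the model's output exposes `last_hidden_state` and that
# forward accepts input_ids / attention_mask, mirroring the removed
# implementation; last-token pooling is one possible post-processing step.
import torch
from transformers import AutoTokenizer

from optimum.rbln import RBLNQwen3Model

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-Embedding-4B")
model = RBLNQwen3Model.from_pretrained(
    "Qwen/Qwen3-Embedding-4B",
    export=True,
    rbln_max_seq_len=40_960,
    rbln_tensor_parallel_size=4,
)

inputs = tokenizer("What is the capital of France?", return_tensors="pt")
with torch.no_grad():
    outputs = model(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"])

# Pool the hidden state of the last non-padded token and L2-normalize it.
last_index = inputs["attention_mask"].sum(dim=1) - 1
embedding = outputs.last_hidden_state[torch.arange(last_index.shape[0]), last_index]
embedding = torch.nn.functional.normalize(embedding, p=2, dim=-1)
print(embedding.shape)  # (1, hidden_size)
```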