optimum-rbln 0.8.2a4__py3-none-any.whl → 0.9.3rc0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (167)
  1. optimum/rbln/__init__.py +96 -9
  2. optimum/rbln/__version__.py +16 -3
  3. optimum/rbln/cli.py +660 -0
  4. optimum/rbln/configuration_utils.py +153 -42
  5. optimum/rbln/diffusers/__init__.py +7 -0
  6. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +3 -3
  7. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_cosmos.py +1 -1
  8. optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +3 -3
  9. optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +4 -4
  10. optimum/rbln/diffusers/configurations/models/configuration_transformer_cosmos.py +9 -4
  11. optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +9 -4
  12. optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +3 -3
  13. optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +3 -3
  14. optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +35 -19
  15. optimum/rbln/diffusers/configurations/pipelines/configuration_cosmos.py +14 -11
  16. optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +30 -20
  17. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +13 -9
  18. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +17 -13
  19. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +17 -10
  20. optimum/rbln/diffusers/modeling_diffusers.py +30 -14
  21. optimum/rbln/diffusers/models/__init__.py +3 -13
  22. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +31 -3
  23. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py +28 -3
  24. optimum/rbln/diffusers/models/autoencoders/vq_model.py +31 -3
  25. optimum/rbln/diffusers/models/transformers/prior_transformer.py +1 -1
  26. optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +9 -1
  27. optimum/rbln/diffusers/models/transformers/transformer_sd3.py +9 -1
  28. optimum/rbln/diffusers/models/unets/unet_2d_condition.py +6 -3
  29. optimum/rbln/diffusers/pipelines/__init__.py +11 -5
  30. optimum/rbln/diffusers/pipelines/auto_pipeline.py +307 -0
  31. optimum/rbln/diffusers/pipelines/cosmos/configuration_cosmos_guardrail.py +19 -16
  32. optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +14 -18
  33. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +31 -1
  34. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +31 -1
  35. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +1 -6
  36. optimum/rbln/modeling.py +71 -19
  37. optimum/rbln/modeling_base.py +99 -21
  38. optimum/rbln/ops/attn.py +158 -0
  39. optimum/rbln/ops/flash_attn.py +166 -0
  40. optimum/rbln/ops/kv_cache_update.py +5 -0
  41. optimum/rbln/ops/linear.py +7 -0
  42. optimum/rbln/transformers/__init__.py +92 -0
  43. optimum/rbln/transformers/configuration_generic.py +9 -7
  44. optimum/rbln/transformers/modeling_attention_utils.py +252 -0
  45. optimum/rbln/transformers/modeling_generic.py +51 -9
  46. optimum/rbln/transformers/modeling_outputs.py +37 -0
  47. optimum/rbln/transformers/models/__init__.py +91 -30
  48. optimum/rbln/transformers/models/auto/__init__.py +2 -0
  49. optimum/rbln/transformers/models/auto/auto_factory.py +92 -17
  50. optimum/rbln/transformers/models/auto/modeling_auto.py +45 -0
  51. optimum/rbln/transformers/models/bart/bart_architecture.py +1 -3
  52. optimum/rbln/transformers/models/bart/configuration_bart.py +2 -0
  53. optimum/rbln/transformers/models/bert/bert_architecture.py +16 -0
  54. optimum/rbln/transformers/models/bert/modeling_bert.py +8 -4
  55. optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +42 -11
  56. optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +94 -30
  57. optimum/rbln/transformers/models/clip/configuration_clip.py +10 -7
  58. optimum/rbln/transformers/models/clip/modeling_clip.py +27 -4
  59. optimum/rbln/transformers/models/colpali/colpali_architecture.py +3 -6
  60. optimum/rbln/transformers/models/colpali/configuration_colpali.py +37 -21
  61. optimum/rbln/transformers/models/colpali/modeling_colpali.py +113 -96
  62. optimum/rbln/transformers/models/colqwen2/__init__.py +2 -0
  63. optimum/rbln/transformers/models/colqwen2/colqwen2_architecture.py +233 -0
  64. optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +74 -0
  65. optimum/rbln/transformers/models/colqwen2/modeling_colqwen2.py +446 -0
  66. optimum/rbln/transformers/models/decoderonly/__init__.py +3 -2
  67. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +109 -37
  68. optimum/rbln/transformers/models/decoderonly/configuration_lora.py +411 -0
  69. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +318 -309
  70. optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +504 -0
  71. optimum/rbln/transformers/models/decoderonly/generation_decoderonly.py +111 -0
  72. optimum/rbln/transformers/models/decoderonly/lora_architecture.py +204 -0
  73. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +453 -897
  74. optimum/rbln/transformers/models/depth_anything/__init__.py +16 -0
  75. optimum/rbln/transformers/models/depth_anything/configuration_depth_anything.py +24 -0
  76. optimum/rbln/transformers/models/depth_anything/modeling_depth_anything.py +25 -0
  77. optimum/rbln/transformers/models/exaone/modeling_exaone.py +42 -4
  78. optimum/rbln/transformers/models/gemma/__init__.py +2 -2
  79. optimum/rbln/transformers/models/gemma/configuration_gemma.py +9 -1
  80. optimum/rbln/transformers/models/gemma/gemma_architecture.py +1 -4
  81. optimum/rbln/transformers/models/gemma/modeling_gemma.py +22 -1
  82. optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +49 -13
  83. optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +12 -2
  84. optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py +245 -0
  85. optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +201 -349
  86. optimum/rbln/transformers/models/gpt2/__init__.py +2 -2
  87. optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +31 -3
  88. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +10 -8
  89. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +18 -1
  90. optimum/rbln/transformers/models/grounding_dino/__init__.py +10 -0
  91. optimum/rbln/transformers/models/grounding_dino/configuration_grounding_dino.py +92 -0
  92. optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +599 -0
  93. optimum/rbln/transformers/models/grounding_dino/modeling_grounding_dino.py +1032 -0
  94. optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +35 -7
  95. optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +26 -27
  96. optimum/rbln/transformers/models/llama/__init__.py +2 -2
  97. optimum/rbln/transformers/models/llama/configuration_llama.py +9 -1
  98. optimum/rbln/transformers/models/llama/modeling_llama.py +22 -1
  99. optimum/rbln/transformers/models/llava/__init__.py +16 -0
  100. optimum/rbln/transformers/models/llava/configuration_llava.py +72 -0
  101. optimum/rbln/transformers/models/llava/modeling_llava.py +478 -0
  102. optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +15 -17
  103. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +235 -375
  104. optimum/rbln/transformers/models/midm/midm_architecture.py +4 -1
  105. optimum/rbln/transformers/models/midm/modeling_midm.py +42 -4
  106. optimum/rbln/transformers/models/mistral/__init__.py +2 -2
  107. optimum/rbln/transformers/models/mistral/configuration_mistral.py +9 -1
  108. optimum/rbln/transformers/models/mistral/mistral_architecture.py +1 -1
  109. optimum/rbln/transformers/models/mistral/modeling_mistral.py +26 -3
  110. optimum/rbln/transformers/models/opt/__init__.py +2 -2
  111. optimum/rbln/transformers/models/opt/configuration_opt.py +8 -1
  112. optimum/rbln/transformers/models/opt/modeling_opt.py +28 -16
  113. optimum/rbln/transformers/models/opt/opt_architecture.py +4 -4
  114. optimum/rbln/transformers/models/pegasus/__init__.py +17 -0
  115. optimum/rbln/transformers/models/pegasus/configuration_pegasus.py +38 -0
  116. optimum/rbln/transformers/models/pegasus/modeling_pegasus.py +71 -0
  117. optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +161 -0
  118. optimum/rbln/transformers/models/phi/__init__.py +2 -2
  119. optimum/rbln/transformers/models/phi/configuration_phi.py +9 -1
  120. optimum/rbln/transformers/models/phi/modeling_phi.py +10 -1
  121. optimum/rbln/transformers/models/phi/phi_architecture.py +11 -7
  122. optimum/rbln/transformers/models/pixtral/__init__.py +16 -0
  123. optimum/rbln/transformers/models/pixtral/configuration_pixtral.py +43 -0
  124. optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +310 -0
  125. optimum/rbln/transformers/models/pixtral/pixtral_architecture.py +73 -0
  126. optimum/rbln/transformers/models/qwen2/__init__.py +2 -2
  127. optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +9 -1
  128. optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +27 -1
  129. optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +21 -6
  130. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +15 -21
  131. optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +28 -7
  132. optimum/rbln/transformers/models/qwen2_vl/__init__.py +19 -0
  133. optimum/rbln/transformers/models/qwen2_vl/configuration_qwen2_vl.py +88 -0
  134. optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py +514 -0
  135. optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +165 -0
  136. optimum/rbln/transformers/models/qwen3/configuration_qwen3.py +2 -2
  137. optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +86 -330
  138. optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +1 -245
  139. optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py +20 -13
  140. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +24 -3
  141. optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +2 -2
  142. optimum/rbln/transformers/models/siglip/__init__.py +2 -6
  143. optimum/rbln/transformers/models/siglip/configuration_siglip.py +1 -1
  144. optimum/rbln/transformers/models/siglip/modeling_siglip.py +5 -16
  145. optimum/rbln/transformers/models/swin/__init__.py +16 -0
  146. optimum/rbln/transformers/models/swin/configuration_swin.py +42 -0
  147. optimum/rbln/transformers/models/swin/modeling_swin.py +341 -0
  148. optimum/rbln/transformers/models/t5/configuration_t5.py +2 -0
  149. optimum/rbln/transformers/models/t5/t5_architecture.py +8 -1
  150. optimum/rbln/transformers/models/time_series_transformer/configuration_time_series_transformer.py +3 -3
  151. optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +4 -14
  152. optimum/rbln/transformers/models/time_series_transformer/time_series_transformers_architecture.py +7 -1
  153. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +1 -0
  154. optimum/rbln/transformers/models/whisper/configuration_whisper.py +12 -13
  155. optimum/rbln/transformers/models/whisper/generation_whisper.py +28 -6
  156. optimum/rbln/transformers/models/whisper/modeling_whisper.py +28 -3
  157. optimum/rbln/transformers/models/xlm_roberta/__init__.py +2 -8
  158. optimum/rbln/transformers/utils/rbln_quantization.py +391 -75
  159. optimum/rbln/transformers/utils/rbln_runtime_wrapper.py +79 -0
  160. optimum/rbln/utils/depreacate_utils.py +16 -0
  161. optimum/rbln/utils/runtime_utils.py +28 -18
  162. optimum/rbln/utils/submodule.py +31 -9
  163. {optimum_rbln-0.8.2a4.dist-info → optimum_rbln-0.9.3rc0.dist-info}/METADATA +8 -7
  164. {optimum_rbln-0.8.2a4.dist-info → optimum_rbln-0.9.3rc0.dist-info}/RECORD +167 -125
  165. optimum_rbln-0.9.3rc0.dist-info/entry_points.txt +2 -0
  166. {optimum_rbln-0.8.2a4.dist-info → optimum_rbln-0.9.3rc0.dist-info}/WHEEL +0 -0
  167. {optimum_rbln-0.8.2a4.dist-info → optimum_rbln-0.9.3rc0.dist-info}/licenses/LICENSE +0 -0

optimum/rbln/transformers/models/qwen3/qwen3_architecture.py
@@ -12,16 +12,8 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.

- import torch
- import torch.nn as nn
- from transformers import PreTrainedModel

- from ..decoderonly.decoderonly_architecture import (
-     DecoderOnlyAttention,
-     DecoderOnlyLayer,
-     DecoderOnlyWrapper,
-     RotaryEmbedding,
- )
+ from ..decoderonly.decoderonly_architecture import DecoderOnlyAttention, DecoderOnlyWrapper


  class Qwen3Wrapper(DecoderOnlyWrapper):
@@ -37,239 +29,3 @@ class Qwen3Attention(DecoderOnlyAttention):
          self.o_proj = self._original_mod.o_proj
          self.q_norm = self._original_mod.q_norm
          self.k_norm = self._original_mod.k_norm
-
-
- class Qwen3ModelWrapper(nn.Module):
-     def __init__(
-         self,
-         model,
-         attn_impl=None,
-         use_inputs_embeds=None,
-         use_attention_mask=None,
-         use_rotary_emb=None,
-         cache_impl=None,
-         kvcache_partition_len=None,
-         max_seq_len=None,
-         kvcache_block_size=None,
-         sliding_window=None,
-         sliding_window_layers=None,
-     ):
-         super().__init__()
-         self.config = model.config
-
-         if use_rotary_emb:
-             rotary_embs = self.get_rotary_emb(max_seq_len=max_seq_len)
-             if isinstance(rotary_embs, tuple):
-                 self.rotary_emb_global, self.rotary_emb_local = rotary_embs
-             else:
-                 self.rotary_emb = rotary_embs
-         else:
-             self.rotary_emb = None
-
-         self._original_mod = model
-         self.use_inputs_embeds = use_inputs_embeds
-         self.attn_impl = attn_impl
-         self.cache_impl = cache_impl
-         self.use_attention_mask = use_attention_mask
-         self.kvcache_partition_len = kvcache_partition_len
-         self.kvcache_block_size = kvcache_block_size
-         self.max_seq_len = max_seq_len
-         self.sliding_window = sliding_window
-         self.sliding_window_layers = sliding_window_layers
-         self.model = self.convert_to_rbln_model(model)
-
-     def get_rotary_emb(self, max_seq_len):
-         return RotaryEmbedding(config=self.config, max_seq_len_cached=max_seq_len)
-
-     def convert_to_rbln_model(self, base_model: PreTrainedModel):
-         for layer_idx, layer in enumerate(base_model.layers):
-             is_sliding = layer_idx in self.sliding_window_layers
-             new_self_attn = Qwen3Attention(
-                 layer.self_attn,
-                 self.use_attention_mask if not is_sliding else True,
-                 use_position_ids=None,
-                 kvcache_block_size=self.sliding_window
-                 if layer_idx in self.sliding_window_layers
-                 else self.kvcache_block_size,
-                 is_sliding=is_sliding,
-                 attn_impl=self.attn_impl if not is_sliding else "eager",
-                 kvcache_partition_len=self.kvcache_partition_len,
-             )
-             base_model.layers[layer_idx] = DecoderOnlyLayer(layer, new_self_attn)
-
-         return base_model
-
-     @property
-     def hidden_multiplier(self):
-         return 1
-
-     def get_last_layernorm(self) -> nn.LayerNorm:
-         return self._original_mod.norm
-
-     def get_embedding(self) -> nn.Embedding:
-         return self._original_mod.embed_tokens
-
-     def get_pos_embedding(self) -> nn.Embedding:
-         raise NotImplementedError(
-             "The 'get_pos_embedding' method is not implemented. Please define this method in a subclass."
-         )
-
-     def convert_sequence_positions_for_flash_attn(self, seq_positions, max_seq_len):
-         if self.attn_impl not in ["flash_attn"]:
-             raise NotImplementedError(f"Unknown attn_impl ({self.attn_impl}).")
-         partition_len = self.kvcache_partition_len
-         num_partition = max_seq_len // partition_len
-
-         cs = seq_positions.repeat(num_partition, 1).transpose(0, 1)
-         pidx = torch.arange(num_partition)
-         cache_pos_for_partitions = torch.clamp(cs - pidx * partition_len, 0, partition_len)
-         return cache_pos_for_partitions
-
-     def get_local_cache_positions(self, position_ids, query_position):
-         max_cache_len = self.model.config.sliding_window
-         valid_input_len = 1 if query_position is None else query_position + 1
-         cache_seq_len = torch.clamp(position_ids, max=max_cache_len)[:, :1]  # past seen tokens
-         cache_offset = (
-             torch.clamp(position_ids, max=max_cache_len)[:, :1] + valid_input_len
-         )  # cache offset for next steps
-
-         return cache_seq_len, cache_offset
-
-     def prepare_forward_args(self, *args):
-         args = list(args)
-         input_ids = None if self.use_inputs_embeds else args.pop(0)
-         inputs_embeds = args.pop(0) if self.use_inputs_embeds else None
-         cache_position = args.pop(0)
-         global_block_tables = args.pop(0) if self.cache_impl in ["hybrid", "static"] else None
-         local_block_tables = args.pop(0) if self.cache_impl in ["hybrid", "sliding_window"] else None
-         query_position = args.pop(0) if self.sliding_window else None
-         attention_mask = args.pop(0) if self.use_attention_mask else None
-         position_ids = None
-         past_key_values = args
-
-         if len(past_key_values) != 2 * self.config.num_hidden_layers:
-             raise ValueError(
-                 f"Different past_key_values to model's config. {len(past_key_values)} != {2 * self.config.num_hidden_layers}"
-             )
-
-         # [key, value] * n_layer -> ( (key, value) ) * n_layer
-         # cache shape : batch, n_heads, 1, max_seq_len, head_dim
-         _past_key_values = []
-         for i in range(self.config.num_hidden_layers):
-             key_states = past_key_values[i * 2]
-             value_states = past_key_values[i * 2 + 1]
-             past_key_value = [key_states, value_states]
-             _past_key_values.append(past_key_value)
-         past_key_values = _past_key_values
-
-         if hasattr(self, "rotary_emb_global") and hasattr(self, "rotary_emb_local"):
-             rotary_emb = (self.rotary_emb_global, self.rotary_emb_local)
-         else:
-             rotary_emb = self.rotary_emb
-
-         return (
-             input_ids,
-             inputs_embeds,
-             cache_position,
-             global_block_tables,
-             local_block_tables,
-             attention_mask,
-             position_ids,
-             query_position,
-             past_key_values,
-             rotary_emb,
-         )
-
-     def forward(self, *args):
-         (
-             input_ids,
-             inputs_embeds,
-             cache_position,
-             global_block_tables,
-             local_block_tables,
-             attention_mask,
-             position_ids,
-             query_position,
-             past_key_values,
-             rotary_emb,
-         ) = self.prepare_forward_args(*args)
-
-         # retrieve input_ids and inputs_embeds
-         if (input_ids is None) ^ (inputs_embeds is not None):
-             raise ValueError(
-                 "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
-             )
-
-         # embed positions
-         if inputs_embeds is None:
-             inputs_embeds = self.get_embedding()(input_ids)
-
-         hidden_states = inputs_embeds * self.hidden_multiplier
-
-         # get cos,sin vector if needed
-         position_ids = position_ids if position_ids is not None else cache_position
-         if rotary_emb is not None:
-             if isinstance(rotary_emb, torch.Tensor):
-                 cos = rotary_emb[0]
-                 sin = rotary_emb[1]
-             else:
-                 cos, sin = rotary_emb(hidden_states, self.max_seq_len)  # dtype carrier, max_seq_len
-             cos, sin = slice_and_unsqueeze_cos_sin(cos, sin, position_ids)
-         else:
-             batch_size = inputs_embeds.shape[0]
-             if position_ids.shape[0] > 1:
-                 position_embeds = []
-                 for b_idx in range(batch_size):
-                     position_embed = self.get_pos_embedding()(position_ids[b_idx])
-                     position_embeds.append(position_embed)
-
-                 position_embeds = torch.cat(position_embeds, dim=0).unsqueeze(1)
-             else:
-                 position_embeds = self.get_pos_embedding()(position_ids)
-             hidden_states = hidden_states + position_embeds
-             cos, sin = None, None
-
-         # Get sequence positions for flash attention
-         if self.attn_impl == "flash_attn":
-             seq_positions = cache_position[:, 0]
-             seq_positions = self.convert_sequence_positions_for_flash_attn(
-                 seq_positions=seq_positions, max_seq_len=self.max_seq_len
-             )
-         else:
-             seq_positions = cache_position[:, :1]
-
-         # Get local cache positions for sliding window layers
-         if len(self.sliding_window_layers) > 0:
-             sliding_cache_pos = self.get_local_cache_positions(position_ids, query_position)
-
-         for layer_idx, layer in enumerate(self.model.layers):
-             is_sliding = True if layer_idx in self.sliding_window_layers else False
-             hidden_states = layer(
-                 hidden_states=hidden_states,
-                 attention_mask=attention_mask,
-                 seq_positions=sliding_cache_pos if is_sliding else seq_positions,
-                 past_key_values=past_key_values,
-                 cos=cos,
-                 sin=sin,
-                 block_tables=local_block_tables if is_sliding else global_block_tables,
-             )
-
-         hidden_states = self.get_last_layernorm()(hidden_states)
-         return hidden_states
-
-
- def slice_and_unsqueeze_cos_sin(cos, sin, cache_position, unsqueeze_dim=1):
-     """Slice cos[cache_position], sin[cache_position] vector for the query."""
-     if cache_position.shape[0] > 1:
-         cos_all = []
-         sin_all = []
-         for i in range(cache_position.shape[0]):
-             cos_all.append(cos[cache_position[i : i + 1]].unsqueeze(unsqueeze_dim))
-             sin_all.append(sin[cache_position[i : i + 1]].unsqueeze(unsqueeze_dim))
-         cos = torch.cat(cos_all, dim=0)
-         sin = torch.cat(sin_all, dim=0)
-     else:
-         cos = cos[cache_position].unsqueeze(unsqueeze_dim)
-         sin = sin[cache_position].unsqueeze(unsqueeze_dim)
-
-     return cos, sin

optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py
@@ -12,9 +12,7 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.

- from typing import Any, Dict, Optional
-
- import rebel
+ from typing import Any, Optional

  from ....configuration_utils import RBLNModelConfig
  from ....utils.logging import get_logger
@@ -24,6 +22,8 @@ logger = get_logger()


  class RBLNModelForSeq2SeqLMConfig(RBLNModelConfig):
+     support_paged_attention = None
+
      def __init__(
          self,
          batch_size: Optional[int] = None,
@@ -31,7 +31,9 @@ class RBLNModelForSeq2SeqLMConfig(RBLNModelConfig):
          dec_max_seq_len: Optional[int] = None,
          use_attention_mask: Optional[bool] = None,
          pad_token_id: Optional[int] = None,
-         **kwargs: Dict[str, Any],
+         kvcache_num_blocks: Optional[int] = None,
+         kvcache_block_size: Optional[int] = None,
+         **kwargs: Any,
      ):
          """
          Args:
@@ -39,9 +41,12 @@ class RBLNModelForSeq2SeqLMConfig(RBLNModelConfig):
              enc_max_seq_len (Optional[int]): Maximum sequence length for the encoder.
              dec_max_seq_len (Optional[int]): Maximum sequence length for the decoder.
              use_attention_mask (Optional[bool]): Whether to use attention masks during inference.
-                 This is automatically set to True for RBLN-CA02 devices.
              pad_token_id (Optional[int]): The ID of the padding token in the vocabulary.
-             **kwargs: Additional arguments passed to the parent RBLNModelConfig.
+             kvcache_num_blocks (Optional[int]): The total number of blocks to allocate for the
+                 PagedAttention KV cache for the SelfAttention. Defaults to batch_size.
+             kvcache_block_size (Optional[int]): Sets the size (in number of tokens) of each block
+                 in the PagedAttention KV cache for the SelfAttention. Defaults to dec_max_seq_len.
+             kwargs: Additional arguments passed to the parent RBLNModelConfig.

          Raises:
              ValueError: If batch_size is not a positive integer.
@@ -55,12 +60,14 @@ class RBLNModelForSeq2SeqLMConfig(RBLNModelConfig):
          self.dec_max_seq_len = dec_max_seq_len

          self.use_attention_mask = use_attention_mask
-         npu = self.npu or rebel.get_npu_name()
-         if npu == "RBLN-CA02":
-             if self.use_attention_mask is False:
-                 logger.warning("Attention mask should be used with RBLN-CA02. Setting use_attention_mask to True.")
-             self.use_attention_mask = True
-         else:
-             self.use_attention_mask = self.use_attention_mask or False

          self.pad_token_id = pad_token_id
+
+         if self.support_paged_attention:
+             self.kvcache_num_blocks = kvcache_num_blocks
+             self.kvcache_block_size = kvcache_block_size
+         else:
+             if kvcache_num_blocks is not None or kvcache_block_size is not None:
+                 raise ValueError(
+                     "You cannot set kvcache_num_blocks or kvcache_block_size as paged attention is not supported for the model."
+                 )
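
Note on the configuration change above: the new kvcache_num_blocks and kvcache_block_size arguments are only accepted when a subclass opts in by setting the class attribute support_paged_attention; otherwise passing them raises a ValueError. A minimal, self-contained sketch of that gating pattern (the class names below are illustrative only, not optimum-rbln classes):

from typing import Any, Optional


class BaseSeq2SeqConfigSketch:
    # Subclasses opt in by overriding this flag, mirroring
    # RBLNModelForSeq2SeqLMConfig.support_paged_attention in the diff above.
    support_paged_attention = None

    def __init__(
        self,
        batch_size: Optional[int] = None,
        dec_max_seq_len: Optional[int] = None,
        kvcache_num_blocks: Optional[int] = None,
        kvcache_block_size: Optional[int] = None,
        **kwargs: Any,
    ):
        self.batch_size = batch_size or 1
        self.dec_max_seq_len = dec_max_seq_len
        if self.support_paged_attention:
            self.kvcache_num_blocks = kvcache_num_blocks
            self.kvcache_block_size = kvcache_block_size
        elif kvcache_num_blocks is not None or kvcache_block_size is not None:
            # Same guard as the diff: reject cache options without paged attention.
            raise ValueError("kvcache options require paged attention support")


class PagedSeq2SeqConfigSketch(BaseSeq2SeqConfigSketch):
    support_paged_attention = True


# The paged subclass accepts the options; the base class would raise ValueError.
cfg = PagedSeq2SeqConfigSketch(batch_size=1, dec_max_seq_len=256, kvcache_num_blocks=1, kvcache_block_size=256)
print(cfg.kvcache_num_blocks, cfg.kvcache_block_size)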

optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py
@@ -20,6 +20,7 @@ import rebel
  import torch
  from rebel.compile_context import CompileContext
  from transformers import AutoModelForSeq2SeqLM, PretrainedConfig, PreTrainedModel
+ from transformers.generation.utils import GenerationMixin
  from transformers.modeling_outputs import BaseModelOutput, Seq2SeqLMOutput

  from ....configuration_utils import RBLNCompileConfig
@@ -38,7 +39,7 @@ if TYPE_CHECKING:
  class RBLNRuntimeEncoder(RBLNPytorchRuntime):
      mandatory_members = ["main_input_name"]

-     def forward(self, *args: List[torch.Tensor], **kwargs: Dict[str, torch.Tensor]):
+     def forward(self, *args: List[torch.Tensor], **kwargs: torch.Tensor):
          output = super().forward(*args, **kwargs)
          return BaseModelOutput(last_hidden_state=output)

@@ -83,7 +84,7 @@ class RBLNRuntimeDecoder(RBLNPytorchRuntime):
                  decoding_step = cache_position[b_idx].item()
                  if not (0 <= decoding_step < self.dec_max_seq_len):
                      raise ValueError(
-                         f"Decoding step {decoding_step} out of bounds for attention mask with shape {self.dec_attn_mask.shape}."
+                         f"Decoding step {decoding_step} out of bounds for decoder_max_seq_len ({self.dec_max_seq_len})."
                      )
                  decoder_attention_mask[b_idx, : decoding_step + 1] = 1

@@ -101,7 +102,7 @@ class RBLNRuntimeDecoder(RBLNPytorchRuntime):
          return Seq2SeqLMOutput(logits=lm_logits)


- class RBLNModelForSeq2SeqLM(RBLNModel, ABC):
+ class RBLNModelForSeq2SeqLM(RBLNModel, GenerationMixin, ABC):
      """
      This is a generic model class that will be instantiated as one of the model classes of the library (with a sequence-to-sequence language modeling head) when created with the from_pretrained() class method.
      This model inherits from [`RBLNModel`]. Check the superclass documentation for the generic methods the library implements for all its models.
@@ -117,6 +118,7 @@ class RBLNModelForSeq2SeqLM(RBLNModel, ABC):
      main_input_name = "input_ids"
      auto_model_class = AutoModelForSeq2SeqLM
      support_causal_attn = None
+     _is_stateful = False

      def __post_init__(self, **kwargs):
          batch_size = self.rbln_config.batch_size
@@ -181,6 +183,21 @@ class RBLNModelForSeq2SeqLM(RBLNModel, ABC):

          return {"encoder": compiled_encoder, "decoder": compiled_decoder}

+     @classmethod
+     def _update_paged_attention_config(cls, model_config: PretrainedConfig, rbln_config: RBLNModelForSeq2SeqLMConfig):
+         rbln_config.kvcache_num_blocks = rbln_config.kvcache_num_blocks or rbln_config.batch_size
+         rbln_config.kvcache_block_size = rbln_config.kvcache_block_size or rbln_config.dec_max_seq_len
+
+         if rbln_config.kvcache_num_blocks != rbln_config.batch_size:
+             raise NotImplementedError(
+                 f"kvcache_num_blocks ({rbln_config.kvcache_num_blocks}) must be equal to batch_size ({rbln_config.batch_size}) as flash attention is not supported yet."
+             )
+
+         if rbln_config.kvcache_block_size != rbln_config.dec_max_seq_len:
+             raise NotImplementedError(
+                 f"kvcache_block_size ({rbln_config.kvcache_block_size}) must be equal to dec_max_seq_len ({rbln_config.dec_max_seq_len}) as flash attention is not supported yet."
+             )
+
      @classmethod
      def _update_rbln_config(
          cls,
@@ -238,6 +255,9 @@ class RBLNModelForSeq2SeqLM(RBLNModel, ABC):
          if max_position_embeddings is not None and rbln_config.dec_max_seq_len > max_position_embeddings:
              raise ValueError("`dec_max_seq_len` should be less or equal than max_position_embeddings!")

+         if rbln_config.support_paged_attention:
+             cls._update_paged_attention_config(model_config, rbln_config)
+
          # model input info
          enc_input_info = [
              ("input_ids", [1, rbln_config.enc_max_seq_len], "int64"),
@@ -310,6 +330,7 @@ class RBLNModelForSeq2SeqLM(RBLNModel, ABC):
          dec_compile_config = RBLNCompileConfig(compiled_model_name="decoder", input_info=dec_input_info)

          rbln_config.set_compile_cfgs([enc_compile_config, dec_compile_config])
+
          return rbln_config

      @classmethod
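
The _update_paged_attention_config hook added above fills defaults and then pins them: kvcache_num_blocks falls back to batch_size, kvcache_block_size falls back to dec_max_seq_len, and any other value is rejected until flash attention is supported for seq2seq decoders. A hypothetical standalone mirror of that logic (resolve_paged_kv_cache is not an optimum-rbln function):

from typing import Optional, Tuple


def resolve_paged_kv_cache(
    batch_size: int,
    dec_max_seq_len: int,
    kvcache_num_blocks: Optional[int] = None,
    kvcache_block_size: Optional[int] = None,
) -> Tuple[int, int]:
    # Defaults mirror _update_paged_attention_config in the hunk above.
    kvcache_num_blocks = kvcache_num_blocks or batch_size
    kvcache_block_size = kvcache_block_size or dec_max_seq_len
    # Current constraint: one block per sequence, sized to the full decoder length.
    if kvcache_num_blocks != batch_size:
        raise NotImplementedError("kvcache_num_blocks must equal batch_size for now")
    if kvcache_block_size != dec_max_seq_len:
        raise NotImplementedError("kvcache_block_size must equal dec_max_seq_len for now")
    return kvcache_num_blocks, kvcache_block_size


print(resolve_paged_kv_cache(batch_size=2, dec_max_seq_len=512))  # (2, 512)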

optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py
@@ -31,7 +31,7 @@ class Seq2SeqWrapper:
      Args:
          model (nn.Module): The Seq2Seq model to wrap.
          enc_max_seq_len (int): Maximum sequence length for the encoder's position embeddings and cache sizes.
-         **kwargs: Additional arguments to pass to the decoder wrapper.
+         kwargs: Additional arguments to pass to the decoder wrapper.
      """

      def __init__(self, model: nn.Module, enc_max_seq_len: int, **kwargs):
@@ -125,7 +125,7 @@ class Seq2SeqDecoderWrapper(nn.Module):

      Args:
          model (nn.Module): The Seq2Seq model containing the decoder.
-         **kwargs: Additional arguments for decoder configuration.
+         kwargs: Additional arguments for decoder configuration.
      """

      def __init__(self, model: nn.Module, use_attention_mask: bool = True, **kwargs):

optimum/rbln/transformers/models/siglip/__init__.py
@@ -12,9 +12,5 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.

- from .configuration_siglip import (
-     RBLNSiglipVisionModelConfig,
- )
- from .modeling_siglip import (
-     RBLNSiglipVisionModel,
- )
+ from .configuration_siglip import RBLNSiglipVisionModelConfig
+ from .modeling_siglip import RBLNSiglipVisionModel

optimum/rbln/transformers/models/siglip/configuration_siglip.py
@@ -42,7 +42,7 @@ class RBLNSiglipVisionModelConfig(RBLNModelConfig):
              interpolate_pos_encoding (Optional[bool]): Whether to interpolate the position encoding.
              output_hidden_states: (Optional[bool]): Whether to return hidden states.
              output_attentions: (Optional[bool]): Whether to return attentions.
-             **kwargs: Additional arguments passed to the parent RBLNModelConfig.
+             kwargs: Additional arguments passed to the parent RBLNModelConfig.

          Raises:
              ValueError: If batch_size is not a positive integer.

optimum/rbln/transformers/models/siglip/modeling_siglip.py
@@ -12,7 +12,7 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.

- from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
+ from typing import TYPE_CHECKING, Any, Optional, Tuple, Union

  import torch
  from transformers import SiglipVisionConfig, SiglipVisionModel
@@ -29,8 +29,6 @@ logger = get_logger(__name__)
  if TYPE_CHECKING:
      from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer, PreTrainedModel

-     from ....diffusers.modeling_diffusers import RBLNDiffusionMixin, RBLNDiffusionMixinConfig
-


  class _SiglipVisionModel(torch.nn.Module):
@@ -65,6 +63,8 @@ class RBLNSiglipVisionModel(RBLNModel):
      on RBLN devices, supporting image encoding for multimodal vision-language tasks.
      """

+     _tp_support = False
+
      @classmethod
      def wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNSiglipVisionModelConfig) -> torch.nn.Module:
          wrapper_cfg = {
@@ -74,12 +74,6 @@ class RBLNSiglipVisionModel(RBLNModel):
          }
          return _SiglipVisionModel(model, **wrapper_cfg).eval()

-     @classmethod
-     def update_rbln_config_using_pipe(
-         cls, pipe: "RBLNDiffusionMixin", rbln_config: "RBLNDiffusionMixinConfig", submodule_name: str
-     ) -> "RBLNDiffusionMixinConfig":
-         return rbln_config
-

      @classmethod
      def _update_rbln_config(
@@ -126,13 +120,8 @@ class RBLNSiglipVisionModel(RBLNModel):
          output_attentions: bool = None,
          output_hidden_states: bool = None,
          interpolate_pos_encoding: bool = False,
-         **kwargs: Dict[str, Any],
+         **kwargs: Any,
      ) -> Union[Tuple, BaseModelOutputWithPooling]:
-         if len(kwargs) > 0 and any(value is not None for value in kwargs.values()):
-             logger.warning(
-                 f"Currently, optimum-rbln does not support kwargs {kwargs.keys()} for {self.__class__.__name__}."
-             )
-
          output_attentions = output_attentions if output_attentions is not None else self.rbln_config.output_attentions
          output_hidden_states = (
              output_hidden_states if output_hidden_states is not None else self.rbln_config.output_hidden_states
@@ -156,7 +145,7 @@ class RBLNSiglipVisionModel(RBLNModel):
                  f"Please compile again with the correct argument."
              )

-         output = super().forward(pixel_values, return_dict=return_dict)
+         output = super().forward(pixel_values, return_dict=return_dict, **kwargs)
          return output

      def _prepare_output(self, output, return_dict):

optimum/rbln/transformers/models/swin/__init__.py
@@ -0,0 +1,16 @@
+ # Copyright 2025 Rebellions Inc. All rights reserved.
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at:
+
+ # http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from .configuration_swin import RBLNSwinBackboneConfig
+ from .modeling_swin import RBLNSwinBackbone

optimum/rbln/transformers/models/swin/configuration_swin.py
@@ -0,0 +1,42 @@
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at:
+
+ # http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from typing import Any, Optional, Tuple, Union
+
+ from ...configuration_generic import RBLNModelForImageClassificationConfig
+
+
+ class RBLNSwinBackboneConfig(RBLNModelForImageClassificationConfig):
+     def __init__(
+         self,
+         image_size: Optional[Union[int, Tuple[int, int]]] = None,
+         batch_size: Optional[int] = None,
+         output_hidden_states: Optional[bool] = None,
+         output_attentions: Optional[bool] = None,
+         **kwargs: Any,
+     ):
+         """
+         Args:
+             batch_size (Optional[int]): The batch size for text processing. Defaults to 1.
+             kwargs: Additional arguments passed to the parent RBLNModelConfig.
+
+         Raises:
+             ValueError: If batch_size is not a positive integer.
+         """
+         super().__init__(**kwargs)
+         self.batch_size = batch_size or 1
+         if not isinstance(self.batch_size, int) or self.batch_size < 0:
+             raise ValueError(f"batch_size must be a positive integer, got {self.batch_size}")
+
+         self.image_size = image_size
+         self.output_hidden_states = output_hidden_states
+         self.output_attentions = output_attentions
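
For reference, the new Swin backbone config can be constructed directly from the signature shown above. A minimal sketch, assuming the parent RBLNModelForImageClassificationConfig accepts an otherwise-empty keyword set (compile-time options such as device or npu selection are omitted here):

from optimum.rbln.transformers.models.swin import RBLNSwinBackboneConfig

cfg = RBLNSwinBackboneConfig(
    image_size=(224, 224),      # int or (height, width), stored as-is
    batch_size=1,               # validated against the batch_size check above
    output_hidden_states=True,  # optional; None defers to runtime defaults
)
print(cfg.image_size, cfg.batch_size, cfg.output_hidden_states)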