optimum-rbln 0.7.5a1__py3-none-any.whl → 0.7.5rc0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (24)
  1. optimum/rbln/__init__.py +10 -0
  2. optimum/rbln/__version__.py +2 -2
  3. optimum/rbln/transformers/__init__.py +10 -0
  4. optimum/rbln/transformers/models/__init__.py +14 -0
  5. optimum/rbln/transformers/models/auto/__init__.py +1 -0
  6. optimum/rbln/transformers/models/auto/modeling_auto.py +7 -0
  7. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +114 -19
  8. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +29 -10
  9. optimum/rbln/transformers/models/exaone/exaone_architecture.py +5 -1
  10. optimum/rbln/transformers/models/gemma/gemma_architecture.py +5 -1
  11. optimum/rbln/transformers/models/gemma3/__init__.py +16 -0
  12. optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +69 -0
  13. optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +446 -0
  14. optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +1057 -0
  15. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +4 -1
  16. optimum/rbln/transformers/models/midm/midm_architecture.py +4 -1
  17. optimum/rbln/transformers/models/opt/modeling_opt.py +2 -0
  18. optimum/rbln/transformers/models/opt/opt_architecture.py +4 -1
  19. optimum/rbln/transformers/models/phi/phi_architecture.py +4 -1
  20. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +3 -2
  21. {optimum_rbln-0.7.5a1.dist-info → optimum_rbln-0.7.5rc0.dist-info}/METADATA +1 -1
  22. {optimum_rbln-0.7.5a1.dist-info → optimum_rbln-0.7.5rc0.dist-info}/RECORD +24 -20
  23. {optimum_rbln-0.7.5a1.dist-info → optimum_rbln-0.7.5rc0.dist-info}/WHEEL +0 -0
  24. {optimum_rbln-0.7.5a1.dist-info → optimum_rbln-0.7.5rc0.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,69 @@
+ # Copyright 2025 Rebellions Inc. All rights reserved.
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at:
+
+ # http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from typing import Optional
+
+ from ....configuration_utils import RBLNModelConfig
+ from ..decoderonly.configuration_decoderonly import RBLNDecoderOnlyModelForCausalLMConfig
+ from ..siglip.configuration_siglip import RBLNSiglipVisionModelConfig
+
+
+ class RBLNGemma3ForCausalLMConfig(RBLNDecoderOnlyModelForCausalLMConfig):
+     def __init__(
+         self,
+         prefill_chunk_size: Optional[int] = None,
+         use_position_ids: Optional[bool] = None,
+         use_attention_mask: Optional[bool] = None,
+         **kwargs,
+     ):
+         # use_attention_mask and use_position_ids are always True for Gemma3
+         use_attention_mask = use_attention_mask or True
+         use_position_ids = use_position_ids or True
+         prefill_chunk_size = prefill_chunk_size or 256
+
+         super().__init__(
+             prefill_chunk_size=prefill_chunk_size,
+             use_attention_mask=use_attention_mask,
+             use_position_ids=use_position_ids,
+             **kwargs,
+         )
+
+
+ class RBLNGemma3ForConditionalGenerationConfig(RBLNModelConfig):
+     submodules = ["vision_tower", "language_model"]
+
+     def __init__(
+         self,
+         batch_size: Optional[int] = None,
+         vision_tower: Optional[RBLNModelConfig] = None,
+         language_model: Optional[RBLNModelConfig] = None,
+         **kwargs,
+     ):
+         """
+         Args:
+             batch_size (Optional[int]): The batch size for inference. Defaults to 1.
+             vision_tower (Optional[RBLNModelConfig]): Configuration for the vision encoder component.
+             language_model (Optional[RBLNModelConfig]): Configuration for the language model component.
+             **kwargs: Additional arguments passed to the parent RBLNModelConfig.
+
+         Raises:
+             ValueError: If batch_size is not a positive integer.
+         """
+         super().__init__(**kwargs)
+         self.batch_size = batch_size or 1
+         if not isinstance(self.batch_size, int) or self.batch_size < 0:
+             raise ValueError(f"batch_size must be a positive integer, got {self.batch_size}")
+
+         self.vision_tower = self.init_submodule_config(RBLNSiglipVisionModelConfig, vision_tower)
+         self.language_model = self.init_submodule_config(RBLNGemma3ForCausalLMConfig, language_model)
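A minimal usage sketch of the two configuration classes added above (illustrative only, not part of the diff; it assumes the module path shown in the file list is importable as released and uses only the parameters defined in this file):

from optimum.rbln.transformers.models.gemma3.configuration_gemma3 import (
    RBLNGemma3ForConditionalGenerationConfig,
)

# batch_size defaults to 1; vision_tower and language_model are filled in by
# init_submodule_config with RBLNSiglipVisionModelConfig and
# RBLNGemma3ForCausalLMConfig defaults when not supplied.
config = RBLNGemma3ForConditionalGenerationConfig(batch_size=1)

# The language-model submodule forces prefill_chunk_size=256 and
# use_attention_mask / use_position_ids to True in its __init__.
assert config.language_model.prefill_chunk_size == 256
assert config.language_model.use_attention_mask is True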
@@ -0,0 +1,446 @@
+ # Copyright 2025 Rebellions Inc. All rights reserved.
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at:
+
+ # http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import copy
+ from typing import TYPE_CHECKING, Optional, Tuple, Union
+
+ import torch
+ from torch import nn
+ from transformers.models.gemma3.modeling_gemma3 import Gemma3RMSNorm
+
+ from ..decoderonly.decoderonly_architecture import (
+     AttentionOp,
+     DecoderOnlyAttention,
+     DecoderOnlyFlashAttention,
+     DecoderOnlyForCausalLM,
+     DecoderOnlyLayer,
+     DecoderOnlyModel,
+     DecoderOnlyWrapper,
+     RotaryEmbedding,
+     SlidingWindowAttentionOp,
+     slice_and_unsqueeze_cos_sin,
+ )
+
+
+ if TYPE_CHECKING:
+     from transformers import Gemma3ForCausalLM
+
+
+ class Gemma3ForCausalLMWrapper(DecoderOnlyWrapper):
+     def get_rotary_emb(self, max_seq_len):
+         rotary_emb_global = RotaryEmbedding(config=self.config, max_seq_len_cached=max_seq_len)
+
+         config = copy.deepcopy(self.config)
+         config.rope_theta = config.rope_local_base_freq
+         config.rope_scaling = {"rope_type": "default"}
+         rotary_emb_local = RotaryEmbedding(config=config, max_seq_len_cached=max_seq_len)
+
+         return (rotary_emb_global, rotary_emb_local)
+
+     def convert_to_rbln_causal_lm(self, causal_lm: "Gemma3ForCausalLM", max_seq_len: int):
+         new_layers = []
+         for layer in causal_lm.model.layers:
+             if layer.is_sliding:
+                 new_self_attn = Gemma3Attention(
+                     layer.self_attn,
+                     use_attention_mask=None,  # FIXME: no use in SWA
+                     use_position_ids=self.use_position_ids,
+                     kvcache_block_size=self.config.sliding_window,
+                 )
+             else:
+                 if self.attn_impl == "eager":
+                     new_self_attn = Gemma3Attention(
+                         layer.self_attn,
+                         use_attention_mask=self.use_attention_mask,
+                         use_position_ids=self.use_position_ids,
+                         kvcache_block_size=self.kvcache_block_size,
+                     )
+                 elif self.attn_impl == "flash_attn":
+                     new_self_attn = Gemma3FlashAttention(
+                         layer.self_attn,
+                         kvcache_partition_len=self.kvcache_partition_len,
+                         use_attention_mask=self.use_attention_mask,
+                         kvcache_block_size=self.kvcache_block_size,
+                         use_position_ids=self.use_position_ids,
+                     )
+                 else:
+                     raise NotImplementedError(f"Unknwon attn : {self.attn_impl}")
+
+             new_layer = Gemma3DecoderLayer(layer, new_self_attn)
+             new_layers.append(new_layer)
+
+         new_model = Gemma3TextModel(
+             causal_lm.model,
+             new_layers,
+             partition_len=self.kvcache_partition_len,
+             max_seq_len=max_seq_len,
+         )
+         new_causal_lm = Gemma3ForCausalLM(causal_lm, new_model)
+         return new_causal_lm
+
+     def forward(self, *args):
+         if self.phase == "decode":
+             (
+                 input_ids_or_inputs_embeds,
+                 attention_mask,  # used in global layer, 2D attn_mask for padded KVcache.
+                 cache_position,
+                 position_ids,
+                 golbal_block_tables,
+                 local_block_tables,
+                 *past_key_values,
+             ) = args
+             query_position = None
+
+         elif "prefill" in self.phase:
+             (
+                 input_ids_or_inputs_embeds,
+                 attention_mask,
+                 cache_position,
+                 position_ids,
+                 query_position,
+                 golbal_block_tables,
+                 local_block_tables,
+                 *past_key_values,
+             ) = args
+
+         else:
+             raise ValueError(f"Unknown phase: {self.phase}")
+
+         if input_ids_or_inputs_embeds.ndim == 2:
+             input_ids = input_ids_or_inputs_embeds
+             inputs_embeds = None
+         elif input_ids_or_inputs_embeds.ndim == 3:
+             input_ids = None
+             inputs_embeds = input_ids_or_inputs_embeds
+         else:
+             raise NotImplementedError(f"Unknown ndim of input : {input_ids_or_inputs_embeds.ndim}")
+
+         if len(past_key_values) != 2 * self.num_hidden_layers:
+             raise ValueError(
+                 f"Different past_key_values to model's config. {len(past_key_values)} != {2 * self.num_hidden_layers}"
+             )
+
+         # [key, value] * n_layer -> ( (key, value) ) * n_layer
+         # cache shape : batch, n_heads, 1, max_seq_len, head_dim
+         _past_key_values = []
+         for i in range(self.config.num_hidden_layers):
+             key_states = past_key_values[i * 2]
+             value_states = past_key_values[i * 2 + 1]
+             past_key_value = [key_states, value_states]
+             _past_key_values.append(past_key_value)
+         past_key_values = _past_key_values
+
+         logit = self.causal_lm(
+             input_ids=input_ids,
+             inputs_embeds=inputs_embeds,
+             attention_mask=attention_mask,
+             cache_position=cache_position,
+             position_ids=position_ids,
+             query_position=query_position,
+             past_key_values=past_key_values,
+             rotary_emb=(self.rotary_emb_global, self.rotary_emb_local),
+             global_block_tables=golbal_block_tables,
+             local_block_tables=local_block_tables,
+         )
+
+         return logit
+
+
+ class Gemma3ForCausalLM(DecoderOnlyForCausalLM):
+     def forward(
+         self,
+         input_ids: torch.Tensor = None,
+         inputs_embeds: torch.Tensor = None,
+         attention_mask: torch.Tensor = None,
+         cache_position: torch.Tensor = None,
+         position_ids: torch.Tensor = None,
+         query_position: torch.Tensor = None,
+         past_key_values: Tuple[Tuple[torch.Tensor]] = None,
+         rotary_emb: nn.Module = None,
+         global_block_tables: Optional[torch.Tensor] = None,
+         local_block_tables: Optional[torch.Tensor] = None,
+     ):
+         # outputs
+         hidden_states = self.model(
+             input_ids=input_ids,
+             inputs_embeds=inputs_embeds,
+             attention_mask=attention_mask,
+             cache_position=cache_position,
+             position_ids=position_ids,
+             query_position=query_position,
+             past_key_values=past_key_values,
+             rotary_emb=rotary_emb,
+             global_block_tables=global_block_tables,
+             local_block_tables=local_block_tables,
+         )
+
+         if "prefill" in self.phase:
+             hidden_states = hidden_states[:, query_position.to(torch.int).unsqueeze(0)]
+
+         logits = self.lm_head(hidden_states)
+
+         # Apply final logit softmaxing if configured, e.g. for Gemma2
+         if getattr(self.config, "final_logit_softcapping", None) is not None:
+             logits = logits / self.config.final_logit_softcapping
+             logits = torch.tanh(logits)
+             logits = logits * self.config.final_logit_softcapping
+
+         return logits
+
+
+ class Gemma3TextModel(DecoderOnlyModel):
+     def get_local_cache_positions(self, position_ids, query_position):
+         max_cache_len = self._original_mod.config.sliding_window
+         valid_input_len = 1 if query_position is None else query_position + 1
+         cache_seq_len = torch.clamp(position_ids, max=max_cache_len)[:, :1]  # past seen tokens
+         cache_offset = (
+             torch.clamp(position_ids, max=max_cache_len)[:, :1] + valid_input_len
+         )  # cache offset for next steps
+
+         return cache_seq_len, cache_offset
+
+     def forward(
+         self,
+         input_ids: torch.Tensor = None,
+         inputs_embeds: torch.Tensor = None,
+         attention_mask: torch.Tensor = None,
+         cache_position: torch.Tensor = None,
+         position_ids: torch.Tensor = None,
+         query_position: torch.Tensor = None,
+         past_key_values: Tuple[Tuple[torch.Tensor]] = None,
+         rotary_emb: torch.nn.Module = None,
+         global_block_tables: Optional[torch.Tensor] = None,
+         local_block_tables: Optional[torch.Tensor] = None,
+     ):
+         # retrieve input_ids and inputs_embeds
+         if (input_ids is None) ^ (inputs_embeds is not None):
+             raise ValueError(
+                 "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
+             )
+
+         # embed positions
+         if inputs_embeds is None:
+             inputs_embeds = self.get_embedding()(input_ids)
+
+         hidden_states = inputs_embeds
+
+         # Global Position Embeddings
+         cos_global, sin_global = rotary_emb[0](hidden_states, self.max_seq_len)
+         cos_global, sin_global = slice_and_unsqueeze_cos_sin(cos_global, sin_global, position_ids)
+
+         # Local Position Embeddings
+         cos_local, sin_local = rotary_emb[1](hidden_states, self.max_seq_len)
+         cos_local, sin_local = slice_and_unsqueeze_cos_sin(cos_local, sin_local, position_ids)
+
+         # (batch, seq_len) -> (batch,)
+         if self.attn_impl == "flash_attn":
+             seq_positions = cache_position[:, 0]
+             seq_positions = self.convert_sequence_positions_for_flash_attn(
+                 seq_positions=seq_positions, max_seq_len=self.max_seq_len
+             )
+         else:
+             seq_positions = cache_position[:, :1]
+
+         sliding_cache_pos = self.get_local_cache_positions(position_ids, query_position)
+
+         for layer in self.layers:
+             if layer.is_sliding:
+                 hidden_states = layer(
+                     hidden_states=hidden_states,
+                     attention_mask=attention_mask,
+                     seq_positions=sliding_cache_pos,
+                     past_key_values=past_key_values,
+                     cos=cos_local,
+                     sin=sin_local,
+                     block_tables=local_block_tables,
+                 )
+             else:
+                 hidden_states = layer(
+                     hidden_states=hidden_states,
+                     attention_mask=attention_mask,
+                     seq_positions=seq_positions,
+                     past_key_values=past_key_values,
+                     cos=cos_global,
+                     sin=sin_global,
+                     block_tables=global_block_tables,
+                 )
+
+         hidden_states = self.get_last_layernorm()(hidden_states)
+         return hidden_states
+
+
+ class Gemma3DecoderLayer(DecoderOnlyLayer):
+     def __init__(self, layer, self_attn: "DecoderOnlyAttention"):
+         super().__init__(layer, self_attn)
+         self.is_sliding = self._original_mod.is_sliding
+
+     def get_pre_feedforward_layernorm(self) -> Gemma3RMSNorm:
+         return self._original_mod.pre_feedforward_layernorm
+
+     def get_post_feedforward_layernorm(self) -> Gemma3RMSNorm:
+         return self._original_mod.post_feedforward_layernorm
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         attention_mask: torch.Tensor,
+         seq_positions: Union[torch.LongTensor, Tuple[torch.LongTensor]],
+         past_key_values: Tuple[Tuple[torch.Tensor]],
+         cos: Optional[torch.Tensor] = None,
+         sin: Optional[torch.Tensor] = None,
+         block_tables: Optional[torch.Tensor] = None,
+     ):
+         residual = hidden_states
+         hidden_states = self.get_pre_attention_layernorm()(hidden_states)
+
+         hidden_states = self.self_attn(
+             hidden_states, attention_mask, seq_positions, past_key_values, cos, sin, block_tables
+         )
+         hidden_states = self.get_post_attention_layernorm()(hidden_states)
+         hidden_states = residual + hidden_states
+
+         # Fully Connected
+         residual = hidden_states
+         hidden_states = self.get_pre_feedforward_layernorm()(hidden_states)
+         hidden_states = self._original_mod.mlp(hidden_states)
+         hidden_states = self.get_post_feedforward_layernorm()(hidden_states)
+         hidden_states = residual + hidden_states
+
+         return hidden_states
+
+
+ class Gemma3Attention(DecoderOnlyAttention):
+     def __post_init__(self):
+         self.q_proj = self._original_mod.q_proj
+         self.k_proj = self._original_mod.k_proj
+         self.v_proj = self._original_mod.v_proj
+         self.o_proj = self._original_mod.o_proj
+         self.q_norm = self._original_mod.q_norm
+         self.k_norm = self._original_mod.k_norm
+         self.is_sliding = self._original_mod.is_sliding
+
+     def get_attn_scale(self):
+         return self._original_mod.config.query_pre_attn_scalar**-0.5
+
+     def get_attention(self):
+         if self._original_mod.is_sliding:
+             return SlidingWindowAttentionOp(
+                 self.num_heads,
+                 self.head_dim,
+                 self.num_key_value_heads,
+                 self.use_attention_mask,
+                 self.use_position_ids,
+             )
+         else:
+             return AttentionOp(
+                 self.num_heads, self.head_dim, self.num_key_value_heads, self.use_attention_mask, self.use_position_ids
+             )
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         attention_mask: torch.Tensor,
+         seq_positions: torch.LongTensor,
+         past_key_values: Tuple[Tuple[torch.Tensor]],
+         cos: Optional[torch.Tensor] = None,
+         sin: Optional[torch.Tensor] = None,
+         block_tables: Optional[torch.Tensor] = None,
+     ):
+         batch_size, query_length, _ = hidden_states.size()
+
+         query_states, key_states, value_states = self.projection(hidden_states=hidden_states)
+
+         query_states = query_states.view(batch_size, query_length, self.num_heads, self.head_dim).transpose(1, 2)
+         key_states = key_states.view(batch_size, query_length, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+         value_states = value_states.view(batch_size, query_length, self.num_key_value_heads, self.head_dim).transpose(
+             1, 2
+         )
+
+         query_states = self.q_norm(query_states)
+         key_states = self.k_norm(key_states)
+         query_states, key_states = self.apply_rotary_pos_embed(query_states, key_states, cos, sin)
+
+         batch_size = query_states.shape[0]
+         if batch_size > 1 and "prefill" in self.phase:
+             raise NotImplementedError(f"batch size should be 1 if prefill phase, but got {batch_size}.")
+
+         attn_output = self.attention(
+             query_states,
+             key_states,
+             value_states,
+             attention_mask,
+             past_key_state=past_key_values[self.layer_idx][0],
+             past_value_state=past_key_values[self.layer_idx][1],
+             seq_position=seq_positions,
+             scale=self.scale,
+             block_tables=block_tables,
+             block_size=self.kvcache_block_size,
+         )
+
+         attn_outputs = self.o_proj(attn_output)
+         return attn_outputs
+
+
+ class Gemma3FlashAttention(DecoderOnlyFlashAttention):
+     def __post_init__(self):
+         self.q_proj = self._original_mod.q_proj
+         self.k_proj = self._original_mod.k_proj
+         self.v_proj = self._original_mod.v_proj
+         self.o_proj = self._original_mod.o_proj
+         self.q_norm = self._original_mod.q_norm
+         self.k_norm = self._original_mod.k_norm
+         self.is_sliding = self._original_mod.is_sliding
+
+     def get_attn_scale(self):
+         return self._original_mod.config.query_pre_attn_scalar**-0.5
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         attention_mask: torch.Tensor,
+         seq_positions: torch.LongTensor,
+         past_key_values: Tuple[Tuple[torch.Tensor]],
+         cos: Optional[torch.Tensor] = None,
+         sin: Optional[torch.Tensor] = None,
+         block_tables: Optional[torch.Tensor] = None,
+     ):
+         batch_size, query_length, _ = hidden_states.size()
+
+         query_states, key_states, value_states = self.projection(hidden_states=hidden_states)
+
+         query_states = query_states.view(batch_size, query_length, self.num_heads, self.head_dim).transpose(1, 2)
+         key_states = key_states.view(batch_size, query_length, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+         value_states = value_states.view(batch_size, query_length, self.num_key_value_heads, self.head_dim).transpose(
+             1, 2
+         )
+
+         query_states = self.q_norm(query_states)
+         key_states = self.k_norm(key_states)
+         query_states, key_states = self.apply_rotary_pos_embed(query_states, key_states, cos, sin)
+
+         attn_output = self.attention(
+             query_states,
+             key_states,
+             value_states,
+             attention_mask,
+             past_key_state=past_key_values[self.layer_idx][0],
+             past_value_state=past_key_values[self.layer_idx][1],
+             seq_position=seq_positions,
+             scale=self.scale,
+             block_tables=block_tables,
+             kvcache_block_size=self.kvcache_block_size,
+         )
+
+         attn_outputs = self.o_proj(attn_output)
+         return attn_outputs
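The wrapper above builds two rotary embeddings because Gemma3 alternates sliding-window (local) and global attention layers: global layers keep the config's rope_theta, while local layers use rope_local_base_freq with default scaling. A minimal standalone sketch of that pattern (not part of the diff; it assumes the Hugging Face Gemma3TextConfig field names and reuses the RotaryEmbedding helper imported at the top of this file):

import copy

from transformers.models.gemma3.configuration_gemma3 import Gemma3TextConfig

from optimum.rbln.transformers.models.decoderonly.decoderonly_architecture import RotaryEmbedding

max_seq_len = 8192
config = Gemma3TextConfig()  # assumed to expose rope_theta and rope_local_base_freq

# Global-attention layers keep the original RoPE base frequency.
rotary_emb_global = RotaryEmbedding(config=config, max_seq_len_cached=max_seq_len)

# Sliding-window layers reuse a copied config with the local base frequency and
# default (unscaled) RoPE, mirroring Gemma3ForCausalLMWrapper.get_rotary_emb.
local_config = copy.deepcopy(config)
local_config.rope_theta = local_config.rope_local_base_freq
local_config.rope_scaling = {"rope_type": "default"}
rotary_emb_local = RotaryEmbedding(config=local_config, max_seq_len_cached=max_seq_len)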