optimum-rbln 0.8.0.post2__py3-none-any.whl → 0.8.1__py3-none-any.whl

This diff shows the contents of publicly available package versions as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registries.
Files changed (162)
  1. optimum/rbln/__init__.py +24 -0
  2. optimum/rbln/__version__.py +2 -2
  3. optimum/rbln/configuration_utils.py +45 -33
  4. optimum/rbln/diffusers/__init__.py +21 -1
  5. optimum/rbln/diffusers/configurations/__init__.py +4 -0
  6. optimum/rbln/diffusers/configurations/models/__init__.py +2 -0
  7. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +9 -2
  8. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_cosmos.py +84 -0
  9. optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +4 -2
  10. optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +9 -2
  11. optimum/rbln/diffusers/configurations/models/configuration_transformer_cosmos.py +70 -0
  12. optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +4 -2
  13. optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +9 -2
  14. optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +9 -2
  15. optimum/rbln/diffusers/configurations/pipelines/__init__.py +1 -0
  16. optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +29 -9
  17. optimum/rbln/diffusers/configurations/pipelines/configuration_cosmos.py +114 -0
  18. optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +28 -12
  19. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +18 -6
  20. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +13 -6
  21. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +12 -6
  22. optimum/rbln/diffusers/modeling_diffusers.py +72 -65
  23. optimum/rbln/diffusers/models/__init__.py +4 -0
  24. optimum/rbln/diffusers/models/autoencoders/__init__.py +1 -0
  25. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +17 -1
  26. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py +219 -0
  27. optimum/rbln/diffusers/models/autoencoders/vae.py +45 -8
  28. optimum/rbln/diffusers/models/autoencoders/vq_model.py +17 -1
  29. optimum/rbln/diffusers/models/controlnet.py +14 -8
  30. optimum/rbln/diffusers/models/transformers/__init__.py +1 -0
  31. optimum/rbln/diffusers/models/transformers/prior_transformer.py +10 -0
  32. optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +321 -0
  33. optimum/rbln/diffusers/models/transformers/transformer_sd3.py +2 -0
  34. optimum/rbln/diffusers/models/unets/unet_2d_condition.py +11 -1
  35. optimum/rbln/diffusers/pipelines/__init__.py +10 -0
  36. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +1 -4
  37. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +7 -0
  38. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +7 -0
  39. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +7 -0
  40. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +7 -0
  41. optimum/rbln/diffusers/pipelines/cosmos/__init__.py +17 -0
  42. optimum/rbln/diffusers/pipelines/cosmos/configuration_cosmos_guardrail.py +102 -0
  43. optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +455 -0
  44. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +98 -0
  45. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +98 -0
  46. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +7 -0
  47. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +48 -27
  48. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +7 -0
  49. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpaint.py +7 -0
  50. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +7 -0
  51. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +7 -0
  52. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +7 -0
  53. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +7 -0
  54. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +7 -0
  55. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +7 -0
  56. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +7 -0
  57. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +7 -0
  58. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +7 -0
  59. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +7 -0
  60. optimum/rbln/modeling.py +71 -37
  61. optimum/rbln/modeling_base.py +63 -109
  62. optimum/rbln/transformers/__init__.py +41 -47
  63. optimum/rbln/transformers/configuration_generic.py +16 -13
  64. optimum/rbln/transformers/modeling_generic.py +21 -22
  65. optimum/rbln/transformers/modeling_rope_utils.py +5 -2
  66. optimum/rbln/transformers/models/__init__.py +54 -4
  67. optimum/rbln/transformers/models/{wav2vec2/configuration_wav2vec.py → audio_spectrogram_transformer/__init__.py} +2 -4
  68. optimum/rbln/transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py +21 -0
  69. optimum/rbln/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +28 -0
  70. optimum/rbln/transformers/models/auto/auto_factory.py +35 -12
  71. optimum/rbln/transformers/models/bart/bart_architecture.py +14 -1
  72. optimum/rbln/transformers/models/bart/configuration_bart.py +12 -2
  73. optimum/rbln/transformers/models/bart/modeling_bart.py +16 -7
  74. optimum/rbln/transformers/models/bert/configuration_bert.py +18 -3
  75. optimum/rbln/transformers/models/bert/modeling_bert.py +24 -0
  76. optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +15 -3
  77. optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +50 -4
  78. optimum/rbln/transformers/models/clip/configuration_clip.py +15 -5
  79. optimum/rbln/transformers/models/clip/modeling_clip.py +38 -13
  80. optimum/rbln/transformers/models/colpali/__init__.py +2 -0
  81. optimum/rbln/transformers/models/colpali/colpali_architecture.py +221 -0
  82. optimum/rbln/transformers/models/colpali/configuration_colpali.py +68 -0
  83. optimum/rbln/transformers/models/colpali/modeling_colpali.py +383 -0
  84. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +111 -14
  85. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +102 -35
  86. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +253 -195
  87. optimum/rbln/transformers/models/distilbert/__init__.py +19 -0
  88. optimum/rbln/transformers/models/distilbert/configuration_distilbert.py +24 -0
  89. optimum/rbln/transformers/models/distilbert/modeling_distilbert.py +27 -0
  90. optimum/rbln/transformers/models/dpt/configuration_dpt.py +6 -1
  91. optimum/rbln/transformers/models/dpt/modeling_dpt.py +6 -1
  92. optimum/rbln/transformers/models/exaone/configuration_exaone.py +24 -1
  93. optimum/rbln/transformers/models/exaone/exaone_architecture.py +5 -1
  94. optimum/rbln/transformers/models/exaone/modeling_exaone.py +66 -5
  95. optimum/rbln/transformers/models/gemma/configuration_gemma.py +24 -1
  96. optimum/rbln/transformers/models/gemma/gemma_architecture.py +5 -1
  97. optimum/rbln/transformers/models/gemma/modeling_gemma.py +49 -0
  98. optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +3 -3
  99. optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +18 -250
  100. optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +89 -244
  101. optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +4 -1
  102. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +6 -1
  103. optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +12 -2
  104. optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +41 -4
  105. optimum/rbln/transformers/models/llama/configuration_llama.py +24 -1
  106. optimum/rbln/transformers/models/llama/modeling_llama.py +49 -0
  107. optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +10 -2
  108. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +32 -4
  109. optimum/rbln/transformers/models/midm/configuration_midm.py +24 -1
  110. optimum/rbln/transformers/models/midm/midm_architecture.py +6 -1
  111. optimum/rbln/transformers/models/midm/modeling_midm.py +66 -5
  112. optimum/rbln/transformers/models/mistral/configuration_mistral.py +24 -1
  113. optimum/rbln/transformers/models/mistral/modeling_mistral.py +62 -4
  114. optimum/rbln/transformers/models/opt/configuration_opt.py +4 -1
  115. optimum/rbln/transformers/models/opt/modeling_opt.py +10 -0
  116. optimum/rbln/transformers/models/opt/opt_architecture.py +7 -1
  117. optimum/rbln/transformers/models/phi/configuration_phi.py +24 -1
  118. optimum/rbln/transformers/models/phi/modeling_phi.py +49 -0
  119. optimum/rbln/transformers/models/phi/phi_architecture.py +1 -1
  120. optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +24 -1
  121. optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +67 -4
  122. optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +31 -3
  123. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +54 -25
  124. optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +6 -4
  125. optimum/rbln/transformers/models/resnet/__init__.py +23 -0
  126. optimum/rbln/transformers/models/resnet/configuration_resnet.py +25 -0
  127. optimum/rbln/transformers/models/resnet/modeling_resnet.py +26 -0
  128. optimum/rbln/transformers/models/roberta/__init__.py +24 -0
  129. optimum/rbln/transformers/{configuration_alias.py → models/roberta/configuration_roberta.py} +12 -28
  130. optimum/rbln/transformers/{modeling_alias.py → models/roberta/modeling_roberta.py} +14 -28
  131. optimum/rbln/transformers/models/seq2seq/__init__.py +1 -1
  132. optimum/rbln/transformers/models/seq2seq/{configuration_seq2seq2.py → configuration_seq2seq.py} +2 -2
  133. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +7 -3
  134. optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +41 -3
  135. optimum/rbln/transformers/models/siglip/configuration_siglip.py +10 -0
  136. optimum/rbln/transformers/models/siglip/modeling_siglip.py +69 -21
  137. optimum/rbln/transformers/models/t5/configuration_t5.py +12 -2
  138. optimum/rbln/transformers/models/t5/modeling_t5.py +56 -8
  139. optimum/rbln/transformers/models/t5/t5_architecture.py +5 -1
  140. optimum/rbln/transformers/models/{time_series_transformers → time_series_transformer}/__init__.py +1 -1
  141. optimum/rbln/transformers/models/{time_series_transformers → time_series_transformer}/configuration_time_series_transformer.py +9 -2
  142. optimum/rbln/transformers/models/{time_series_transformers/modeling_time_series_transformers.py → time_series_transformer/modeling_time_series_transformer.py} +20 -11
  143. optimum/rbln/transformers/models/vit/__init__.py +19 -0
  144. optimum/rbln/transformers/models/vit/configuration_vit.py +24 -0
  145. optimum/rbln/transformers/models/vit/modeling_vit.py +25 -0
  146. optimum/rbln/transformers/models/wav2vec2/__init__.py +1 -1
  147. optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec2.py +26 -0
  148. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +1 -1
  149. optimum/rbln/transformers/models/whisper/configuration_whisper.py +10 -1
  150. optimum/rbln/transformers/models/whisper/modeling_whisper.py +41 -17
  151. optimum/rbln/transformers/models/xlm_roberta/__init__.py +16 -2
  152. optimum/rbln/transformers/models/xlm_roberta/configuration_xlm_roberta.py +15 -2
  153. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +12 -3
  154. optimum/rbln/utils/model_utils.py +20 -0
  155. optimum/rbln/utils/runtime_utils.py +49 -1
  156. optimum/rbln/utils/submodule.py +6 -8
  157. {optimum_rbln-0.8.0.post2.dist-info → optimum_rbln-0.8.1.dist-info}/METADATA +6 -6
  158. optimum_rbln-0.8.1.dist-info/RECORD +211 -0
  159. optimum_rbln-0.8.0.post2.dist-info/RECORD +0 -184
  160. /optimum/rbln/transformers/models/{time_series_transformers → time_series_transformer}/time_series_transformers_architecture.py +0 -0
  161. {optimum_rbln-0.8.0.post2.dist-info → optimum_rbln-0.8.1.dist-info}/WHEEL +0 -0
  162. {optimum_rbln-0.8.0.post2.dist-info → optimum_rbln-0.8.1.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/transformers/models/gemma3/gemma3_architecture.py
@@ -16,11 +16,9 @@ import copy
 from typing import TYPE_CHECKING, Optional, Tuple, Union
 
 import torch
-from torch import nn
 from transformers.models.gemma3.modeling_gemma3 import Gemma3RMSNorm
 
 from ..decoderonly.decoderonly_architecture import (
-    AttentionOp,
     DecoderOnlyAttention,
     DecoderOnlyFlashAttention,
     DecoderOnlyForCausalLM,
@@ -28,7 +26,6 @@ from ..decoderonly.decoderonly_architecture import (
     DecoderOnlyModel,
     DecoderOnlyWrapper,
     RotaryEmbedding,
-    SlidingWindowAttentionOp,
     slice_and_unsqueeze_cos_sin,
 )
 
@@ -50,13 +47,14 @@ class Gemma3ForCausalLMWrapper(DecoderOnlyWrapper):
 
     def convert_to_rbln_causal_lm(self, causal_lm: "Gemma3ForCausalLM", max_seq_len: int):
         new_layers = []
-        for layer in causal_lm.model.layers:
-            if layer.is_sliding:
+        for layer_idx, layer in enumerate(causal_lm.model.layers):
+            if layer_idx in self.sliding_window_layers:
                 new_self_attn = Gemma3Attention(
                     layer.self_attn,
                     use_attention_mask=None,  # FIXME: no use in SWA
                     use_position_ids=self.use_position_ids,
                     kvcache_block_size=self.config.sliding_window,
+                    is_sliding=True,
                 )
             else:
                 if self.attn_impl == "eager":
@@ -65,6 +63,7 @@ class Gemma3ForCausalLMWrapper(DecoderOnlyWrapper):
                         use_attention_mask=self.use_attention_mask,
                         use_position_ids=self.use_position_ids,
                         kvcache_block_size=self.kvcache_block_size,
+                        is_sliding=False,
                     )
                 elif self.attn_impl == "flash_attn":
                     new_self_attn = Gemma3FlashAttention(
@@ -85,131 +84,14 @@ class Gemma3ForCausalLMWrapper(DecoderOnlyWrapper):
             new_layers,
             partition_len=self.kvcache_partition_len,
             max_seq_len=max_seq_len,
+            sliding_window_layers=self.sliding_window_layers,
         )
-        new_causal_lm = Gemma3ForCausalLM(causal_lm, new_model)
+        new_causal_lm = DecoderOnlyForCausalLM(causal_lm, new_model)
         return new_causal_lm
 
-    def forward(self, *args):
-        if self.phase == "decode":
-            (
-                input_ids_or_inputs_embeds,
-                attention_mask,  # used in global layer, 2D attn_mask for padded KVcache.
-                cache_position,
-                position_ids,
-                golbal_block_tables,
-                local_block_tables,
-                *past_key_values,
-            ) = args
-            query_position = None
-
-        elif "prefill" in self.phase:
-            (
-                input_ids_or_inputs_embeds,
-                attention_mask,
-                cache_position,
-                position_ids,
-                query_position,
-                golbal_block_tables,
-                local_block_tables,
-                *past_key_values,
-            ) = args
-
-        else:
-            raise ValueError(f"Unknown phase: {self.phase}")
-
-        if input_ids_or_inputs_embeds.ndim == 2:
-            input_ids = input_ids_or_inputs_embeds
-            inputs_embeds = None
-        elif input_ids_or_inputs_embeds.ndim == 3:
-            input_ids = None
-            inputs_embeds = input_ids_or_inputs_embeds
-        else:
-            raise NotImplementedError(f"Unknown ndim of input : {input_ids_or_inputs_embeds.ndim}")
-
-        if len(past_key_values) != 2 * self.num_hidden_layers:
-            raise ValueError(
-                f"Different past_key_values to model's config. {len(past_key_values)} != {2 * self.num_hidden_layers}"
-            )
-
-        # [key, value] * n_layer -> ( (key, value) ) * n_layer
-        # cache shape : batch, n_heads, 1, max_seq_len, head_dim
-        _past_key_values = []
-        for i in range(self.config.num_hidden_layers):
-            key_states = past_key_values[i * 2]
-            value_states = past_key_values[i * 2 + 1]
-            past_key_value = [key_states, value_states]
-            _past_key_values.append(past_key_value)
-        past_key_values = _past_key_values
-
-        logit = self.causal_lm(
-            input_ids=input_ids,
-            inputs_embeds=inputs_embeds,
-            attention_mask=attention_mask,
-            cache_position=cache_position,
-            position_ids=position_ids,
-            query_position=query_position,
-            past_key_values=past_key_values,
-            rotary_emb=(self.rotary_emb_global, self.rotary_emb_local),
-            global_block_tables=golbal_block_tables,
-            local_block_tables=local_block_tables,
-        )
-
-        return logit
-
-
-class Gemma3ForCausalLM(DecoderOnlyForCausalLM):
-    def forward(
-        self,
-        input_ids: torch.Tensor = None,
-        inputs_embeds: torch.Tensor = None,
-        attention_mask: torch.Tensor = None,
-        cache_position: torch.Tensor = None,
-        position_ids: torch.Tensor = None,
-        query_position: torch.Tensor = None,
-        past_key_values: Tuple[Tuple[torch.Tensor]] = None,
-        rotary_emb: nn.Module = None,
-        global_block_tables: Optional[torch.Tensor] = None,
-        local_block_tables: Optional[torch.Tensor] = None,
-    ):
-        # outputs
-        hidden_states = self.model(
-            input_ids=input_ids,
-            inputs_embeds=inputs_embeds,
-            attention_mask=attention_mask,
-            cache_position=cache_position,
-            position_ids=position_ids,
-            query_position=query_position,
-            past_key_values=past_key_values,
-            rotary_emb=rotary_emb,
-            global_block_tables=global_block_tables,
-            local_block_tables=local_block_tables,
-        )
-
-        if "prefill" in self.phase:
-            hidden_states = hidden_states[:, query_position.to(torch.int).unsqueeze(0)]
-
-        logits = self.lm_head(hidden_states)
-
-        # Apply final logit softmaxing if configured, e.g. for Gemma2
-        if getattr(self.config, "final_logit_softcapping", None) is not None:
-            logits = logits / self.config.final_logit_softcapping
-            logits = torch.tanh(logits)
-            logits = logits * self.config.final_logit_softcapping
-
-        return logits
-
 
 class Gemma3TextModel(DecoderOnlyModel):
-    def get_local_cache_positions(self, position_ids, query_position):
-        max_cache_len = self._original_mod.config.sliding_window
-        valid_input_len = 1 if query_position is None else query_position + 1
-        cache_seq_len = torch.clamp(position_ids, max=max_cache_len)[:, :1]  # past seen tokens
-        cache_offset = (
-            torch.clamp(position_ids, max=max_cache_len)[:, :1] + valid_input_len
-        )  # cache offset for next steps
-
-        return cache_seq_len, cache_offset
-
+    # Different from DecoderOnlyModel, this model has global and local rotary embeddings.
     def forward(
         self,
         input_ids: torch.Tensor = None,
@@ -254,37 +136,23 @@ class Gemma3TextModel(DecoderOnlyModel):
 
         sliding_cache_pos = self.get_local_cache_positions(position_ids, query_position)
 
-        for layer in self.layers:
-            if layer.is_sliding:
-                hidden_states = layer(
-                    hidden_states=hidden_states,
-                    attention_mask=attention_mask,
-                    seq_positions=sliding_cache_pos,
-                    past_key_values=past_key_values,
-                    cos=cos_local,
-                    sin=sin_local,
-                    block_tables=local_block_tables,
-                )
-            else:
-                hidden_states = layer(
-                    hidden_states=hidden_states,
-                    attention_mask=attention_mask,
-                    seq_positions=seq_positions,
-                    past_key_values=past_key_values,
-                    cos=cos_global,
-                    sin=sin_global,
-                    block_tables=global_block_tables,
-                )
+        for layer_idx, layer in enumerate(self.layers):
+            is_sliding = True if layer_idx in self.sliding_window_layers else False
+            hidden_states = layer(
+                hidden_states=hidden_states,
+                attention_mask=attention_mask,
+                seq_positions=sliding_cache_pos if is_sliding else seq_positions,
+                past_key_values=past_key_values,
+                cos=cos_local if is_sliding else cos_global,
+                sin=sin_local if is_sliding else sin_global,
+                block_tables=local_block_tables if is_sliding else global_block_tables,
+            )
 
         hidden_states = self.get_last_layernorm()(hidden_states)
         return hidden_states
 
 
 class Gemma3DecoderLayer(DecoderOnlyLayer):
-    def __init__(self, layer, self_attn: "DecoderOnlyAttention"):
-        super().__init__(layer, self_attn)
-        self.is_sliding = self._original_mod.is_sliding
-
     def get_pre_feedforward_layernorm(self) -> Gemma3RMSNorm:
         return self._original_mod.pre_feedforward_layernorm
 
@@ -328,69 +196,10 @@ class Gemma3Attention(DecoderOnlyAttention):
         self.o_proj = self._original_mod.o_proj
         self.q_norm = self._original_mod.q_norm
         self.k_norm = self._original_mod.k_norm
-        self.is_sliding = self._original_mod.is_sliding
 
     def get_attn_scale(self):
         return self._original_mod.config.query_pre_attn_scalar**-0.5
 
-    def get_attention(self):
-        if self._original_mod.is_sliding:
-            return SlidingWindowAttentionOp(
-                self.num_heads,
-                self.head_dim,
-                self.num_key_value_heads,
-                self.use_attention_mask,
-                self.use_position_ids,
-            )
-        else:
-            return AttentionOp(
-                self.num_heads, self.head_dim, self.num_key_value_heads, self.use_attention_mask, self.use_position_ids
-            )
-
-    def forward(
-        self,
-        hidden_states: torch.Tensor,
-        attention_mask: torch.Tensor,
-        seq_positions: torch.LongTensor,
-        past_key_values: Tuple[Tuple[torch.Tensor]],
-        cos: Optional[torch.Tensor] = None,
-        sin: Optional[torch.Tensor] = None,
-        block_tables: Optional[torch.Tensor] = None,
-    ):
-        batch_size, query_length, _ = hidden_states.size()
-
-        query_states, key_states, value_states = self.projection(hidden_states=hidden_states)
-
-        query_states = query_states.view(batch_size, query_length, self.num_heads, self.head_dim).transpose(1, 2)
-        key_states = key_states.view(batch_size, query_length, self.num_key_value_heads, self.head_dim).transpose(1, 2)
-        value_states = value_states.view(batch_size, query_length, self.num_key_value_heads, self.head_dim).transpose(
-            1, 2
-        )
-
-        query_states = self.q_norm(query_states)
-        key_states = self.k_norm(key_states)
-        query_states, key_states = self.apply_rotary_pos_embed(query_states, key_states, cos, sin)
-
-        batch_size = query_states.shape[0]
-        if batch_size > 1 and "prefill" in self.phase:
-            raise NotImplementedError(f"batch size should be 1 if prefill phase, but got {batch_size}.")
-
-        attn_output = self.attention(
-            query_states,
-            key_states,
-            value_states,
-            attention_mask,
-            past_key_state=past_key_values[self.layer_idx][0],
-            past_value_state=past_key_values[self.layer_idx][1],
-            seq_position=seq_positions,
-            scale=self.scale,
-            block_tables=block_tables,
-            block_size=self.kvcache_block_size,
-        )
-
-        attn_outputs = self.o_proj(attn_output)
-        return attn_outputs
-
 
 class Gemma3FlashAttention(DecoderOnlyFlashAttention):
     def __post_init__(self):
@@ -400,47 +209,6 @@ class Gemma3FlashAttention(DecoderOnlyFlashAttention):
         self.o_proj = self._original_mod.o_proj
         self.q_norm = self._original_mod.q_norm
         self.k_norm = self._original_mod.k_norm
-        self.is_sliding = self._original_mod.is_sliding
 
     def get_attn_scale(self):
         return self._original_mod.config.query_pre_attn_scalar**-0.5
-
-    def forward(
-        self,
-        hidden_states: torch.Tensor,
-        attention_mask: torch.Tensor,
-        seq_positions: torch.LongTensor,
-        past_key_values: Tuple[Tuple[torch.Tensor]],
-        cos: Optional[torch.Tensor] = None,
-        sin: Optional[torch.Tensor] = None,
-        block_tables: Optional[torch.Tensor] = None,
-    ):
-        batch_size, query_length, _ = hidden_states.size()
-
-        query_states, key_states, value_states = self.projection(hidden_states=hidden_states)
-
-        query_states = query_states.view(batch_size, query_length, self.num_heads, self.head_dim).transpose(1, 2)
-        key_states = key_states.view(batch_size, query_length, self.num_key_value_heads, self.head_dim).transpose(1, 2)
-        value_states = value_states.view(batch_size, query_length, self.num_key_value_heads, self.head_dim).transpose(
-            1, 2
-        )
-
-        query_states = self.q_norm(query_states)
-        key_states = self.k_norm(key_states)
-        query_states, key_states = self.apply_rotary_pos_embed(query_states, key_states, cos, sin)
-
-        attn_output = self.attention(
-            query_states,
-            key_states,
-            value_states,
-            attention_mask,
-            past_key_state=past_key_values[self.layer_idx][0],
-            past_value_state=past_key_values[self.layer_idx][1],
-            seq_position=seq_positions,
-            scale=self.scale,
-            block_tables=block_tables,
-            kvcache_block_size=self.kvcache_block_size,
-        )
-
-        attn_outputs = self.o_proj(attn_output)
-        return attn_outputs
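
Note on the Gemma3 refactor shown above: 0.8.1 drops the per-layer is_sliding attribute (and the duplicated attention/forward code that branched on it) in favor of an explicit sliding_window_layers index set that the wrapper passes down, so each decoder layer selects its cache positions, rotary tables, and block tables by index. The snippet below is a minimal, self-contained sketch of that dispatch pattern only; it is not the optimum-rbln implementation, and the names LayerInputs and run_layers are hypothetical.

    from dataclasses import dataclass
    from typing import Callable, List, Set

    import torch


    @dataclass
    class LayerInputs:
        # Inputs that differ between the global path and the sliding-window ("local") path.
        seq_positions: torch.Tensor
        cos: torch.Tensor
        sin: torch.Tensor
        block_tables: torch.Tensor


    def run_layers(
        hidden_states: torch.Tensor,
        layers: List[Callable[..., torch.Tensor]],
        sliding_window_layers: Set[int],
        global_inputs: LayerInputs,
        local_inputs: LayerInputs,
    ) -> torch.Tensor:
        # Single call site: each layer is routed by its index, mirroring the
        # refactored Gemma3TextModel.forward loop in the hunk at old lines 254-290.
        for layer_idx, layer in enumerate(layers):
            inputs = local_inputs if layer_idx in sliding_window_layers else global_inputs
            hidden_states = layer(
                hidden_states,
                seq_positions=inputs.seq_positions,
                cos=inputs.cos,
                sin=inputs.sin,
                block_tables=inputs.block_tables,
            )
        return hidden_states

Compared with 0.8.0.post2, which branched on layer.is_sliding and duplicated the layer call and both attention forward passes, selecting the inputs once per iteration keeps one call site and lets the sliding/global split be driven entirely by configuration.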