optimum_rbln-0.8.1rc0-py3-none-any.whl → optimum_rbln-0.8.2-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


Files changed (120)
  1. optimum/rbln/__init__.py +58 -9
  2. optimum/rbln/__version__.py +2 -2
  3. optimum/rbln/configuration_utils.py +24 -5
  4. optimum/rbln/diffusers/configurations/models/__init__.py +1 -1
  5. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +2 -2
  6. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_cosmos.py +5 -3
  7. optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +2 -2
  8. optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +2 -2
  9. optimum/rbln/diffusers/configurations/models/{configuration_cosmos_transformer.py → configuration_transformer_cosmos.py} +7 -2
  10. optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +2 -2
  11. optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +2 -2
  12. optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +2 -2
  13. optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +3 -3
  14. optimum/rbln/diffusers/configurations/pipelines/configuration_cosmos.py +10 -6
  15. optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +4 -4
  16. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +2 -2
  17. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +2 -2
  18. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +2 -2
  19. optimum/rbln/diffusers/modeling_diffusers.py +4 -5
  20. optimum/rbln/diffusers/models/__init__.py +3 -13
  21. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +1 -0
  22. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py +1 -0
  23. optimum/rbln/diffusers/models/autoencoders/vq_model.py +1 -0
  24. optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +1 -1
  25. optimum/rbln/diffusers/pipelines/__init__.py +1 -5
  26. optimum/rbln/diffusers/pipelines/cosmos/configuration_cosmos_guardrail.py +12 -4
  27. optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +4 -26
  28. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +2 -2
  29. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +2 -2
  30. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +1 -6
  31. optimum/rbln/modeling.py +4 -5
  32. optimum/rbln/modeling_base.py +18 -14
  33. optimum/rbln/ops/kv_cache_update.py +5 -0
  34. optimum/rbln/ops/linear.py +7 -0
  35. optimum/rbln/transformers/__init__.py +60 -0
  36. optimum/rbln/transformers/configuration_generic.py +4 -4
  37. optimum/rbln/transformers/modeling_attention_utils.py +252 -0
  38. optimum/rbln/transformers/modeling_generic.py +1 -4
  39. optimum/rbln/transformers/models/__init__.py +45 -30
  40. optimum/rbln/transformers/models/bart/bart_architecture.py +2 -7
  41. optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +2 -2
  42. optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +1 -5
  43. optimum/rbln/transformers/models/clip/configuration_clip.py +14 -3
  44. optimum/rbln/transformers/models/clip/modeling_clip.py +123 -28
  45. optimum/rbln/transformers/models/colpali/colpali_architecture.py +1 -4
  46. optimum/rbln/transformers/models/colpali/configuration_colpali.py +2 -2
  47. optimum/rbln/transformers/models/colpali/modeling_colpali.py +2 -10
  48. optimum/rbln/transformers/models/decoderonly/__init__.py +2 -2
  49. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +214 -45
  50. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +323 -454
  51. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +579 -362
  52. optimum/rbln/transformers/models/exaone/exaone_architecture.py +17 -42
  53. optimum/rbln/transformers/models/gemma/__init__.py +2 -2
  54. optimum/rbln/transformers/models/gemma/configuration_gemma.py +9 -1
  55. optimum/rbln/transformers/models/gemma/gemma_architecture.py +3 -44
  56. optimum/rbln/transformers/models/gemma/modeling_gemma.py +22 -1
  57. optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +21 -9
  58. optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +9 -63
  59. optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +200 -292
  60. optimum/rbln/transformers/models/gpt2/__init__.py +2 -2
  61. optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +31 -3
  62. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +19 -24
  63. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +18 -1
  64. optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +2 -2
  65. optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +3 -9
  66. optimum/rbln/transformers/models/llama/__init__.py +2 -2
  67. optimum/rbln/transformers/models/llama/configuration_llama.py +9 -1
  68. optimum/rbln/transformers/models/llama/modeling_llama.py +22 -1
  69. optimum/rbln/transformers/models/llava/__init__.py +16 -0
  70. optimum/rbln/transformers/models/llava/configuration_llava.py +54 -0
  71. optimum/rbln/transformers/models/llava/modeling_llava.py +419 -0
  72. optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +20 -3
  73. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +6 -16
  74. optimum/rbln/transformers/models/midm/midm_architecture.py +14 -22
  75. optimum/rbln/transformers/models/mistral/__init__.py +2 -2
  76. optimum/rbln/transformers/models/mistral/configuration_mistral.py +9 -1
  77. optimum/rbln/transformers/models/mistral/mistral_architecture.py +1 -1
  78. optimum/rbln/transformers/models/mistral/modeling_mistral.py +26 -3
  79. optimum/rbln/transformers/models/opt/__init__.py +2 -2
  80. optimum/rbln/transformers/models/opt/configuration_opt.py +8 -1
  81. optimum/rbln/transformers/models/opt/modeling_opt.py +41 -1
  82. optimum/rbln/transformers/models/opt/opt_architecture.py +16 -25
  83. optimum/rbln/transformers/models/pegasus/__init__.py +17 -0
  84. optimum/rbln/transformers/models/pegasus/configuration_pegasus.py +34 -0
  85. optimum/rbln/transformers/models/pegasus/modeling_pegasus.py +69 -0
  86. optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +161 -0
  87. optimum/rbln/transformers/models/phi/__init__.py +2 -2
  88. optimum/rbln/transformers/models/phi/configuration_phi.py +9 -1
  89. optimum/rbln/transformers/models/phi/modeling_phi.py +10 -1
  90. optimum/rbln/transformers/models/phi/phi_architecture.py +16 -22
  91. optimum/rbln/transformers/models/pixtral/__init__.py +16 -0
  92. optimum/rbln/transformers/models/pixtral/configuration_pixtral.py +43 -0
  93. optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +315 -0
  94. optimum/rbln/transformers/models/pixtral/pixtral_architecture.py +73 -0
  95. optimum/rbln/transformers/models/qwen2/__init__.py +2 -2
  96. optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +9 -1
  97. optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +27 -1
  98. optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +3 -3
  99. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +5 -15
  100. optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +1 -4
  101. optimum/rbln/transformers/models/qwen3/__init__.py +16 -0
  102. optimum/rbln/transformers/models/qwen3/configuration_qwen3.py +71 -0
  103. optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +133 -0
  104. optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +31 -0
  105. optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py +2 -12
  106. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +3 -1
  107. optimum/rbln/transformers/models/siglip/__init__.py +2 -6
  108. optimum/rbln/transformers/models/siglip/modeling_siglip.py +2 -2
  109. optimum/rbln/transformers/models/time_series_transformer/configuration_time_series_transformer.py +2 -2
  110. optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +3 -5
  111. optimum/rbln/transformers/models/whisper/configuration_whisper.py +3 -12
  112. optimum/rbln/transformers/models/whisper/modeling_whisper.py +8 -2
  113. optimum/rbln/transformers/models/xlm_roberta/__init__.py +2 -8
  114. optimum/rbln/utils/depreacate_utils.py +16 -0
  115. optimum/rbln/utils/hub.py +8 -47
  116. optimum/rbln/utils/runtime_utils.py +31 -5
  117. {optimum_rbln-0.8.1rc0.dist-info → optimum_rbln-0.8.2.dist-info}/METADATA +1 -1
  118. {optimum_rbln-0.8.1rc0.dist-info → optimum_rbln-0.8.2.dist-info}/RECORD +120 -103
  119. {optimum_rbln-0.8.1rc0.dist-info → optimum_rbln-0.8.2.dist-info}/WHEEL +0 -0
  120. {optimum_rbln-0.8.1rc0.dist-info → optimum_rbln-0.8.2.dist-info}/licenses/LICENSE +0 -0

optimum/rbln/transformers/models/exaone/exaone_architecture.py

@@ -19,8 +19,6 @@ import torch.nn as nn
 from ....utils import logging
 from ...models.decoderonly.decoderonly_architecture import (
     DecoderOnlyAttention,
-    DecoderOnlyFlashAttention,
-    DecoderOnlyForCausalLM,
     DecoderOnlyLayer,
     DecoderOnlyModel,
     DecoderOnlyWrapper,
@@ -36,38 +34,23 @@ logger = logging.get_logger(__name__)
 class ExaoneForCausalLMWrapper(DecoderOnlyWrapper):
     """A wrapper class for the Exaone model with a language modeling head."""
 
-    def convert_to_rbln_causal_lm(self, causal_lm: "ExaoneForCausalLM", max_seq_len: int):
-        new_layers = []
-        for layer in causal_lm.transformer.h:
-            if self.attn_impl == "eager":
-                new_self_attn = ExaoneAttention(
-                    layer.attn.attention,
-                    self.use_attention_mask,
-                    kvcache_block_size=self.kvcache_block_size,
-                    use_position_ids=self.use_position_ids,
-                )
-            elif self.attn_impl == "flash_attn":
-                new_self_attn = ExaoneFlashAttention(
-                    layer.attn.attention,
-                    kvcache_partition_len=self.kvcache_partition_len,
-                    use_attention_mask=self.use_attention_mask,
-                    kvcache_block_size=self.kvcache_block_size,
-                    use_position_ids=self.use_position_ids,
-                )
-            else:
-                raise NotImplementedError(f"Unknwon attn : {self.attn_impl}")
-
-            new_layer = ExaoneLayer(layer, new_self_attn)
-            new_layers.append(new_layer)
-        new_model = ExaoneModel(
-            causal_lm.transformer,
-            new_layers,
-            partition_len=self.kvcache_partition_len,
-            max_seq_len=max_seq_len,
-            sliding_window_layers=self.sliding_window_layers,
-        )
-        new_causal_lm = DecoderOnlyForCausalLM(causal_lm, new_model)
-        return new_causal_lm
+    def get_decoder_layers(self, causal_lm: "ExaoneForCausalLM"):
+        return causal_lm.transformer.h
+
+    def get_attn_layer(self, layer: nn.Module):
+        return layer.attn.attention
+
+    def get_model_layer(self, causal_lm: "ExaoneForCausalLM"):
+        return causal_lm.transformer
+
+    def get_rbln_attn_class(self):
+        return ExaoneAttention
+
+    def get_rbln_layer_class(self):
+        return ExaoneLayer
+
+    def get_rbln_model_class(self):
+        return ExaoneModel
 
 
 class ExaoneModel(DecoderOnlyModel):
@@ -92,11 +75,3 @@ class ExaoneAttention(DecoderOnlyAttention):
         self.k_proj = self._original_mod.k_proj
         self.v_proj = self._original_mod.v_proj
         self.o_proj = self._original_mod.out_proj
-
-
-class ExaoneFlashAttention(DecoderOnlyFlashAttention):
-    def __post_init__(self):
-        self.q_proj = self._original_mod.q_proj
-        self.k_proj = self._original_mod.k_proj
-        self.v_proj = self._original_mod.v_proj
-        self.o_proj = self._original_mod.out_proj
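
The same refactor repeats across Exaone, Gemma, and Gemma3 below: each model-specific `convert_to_rbln_causal_lm` override is deleted in favor of small accessor hooks, and the conversion loop (including the eager vs. flash-attention dispatch that used to be copy-pasted per model) moves into the shared `DecoderOnlyWrapper` base class in decoderonly_architecture.py. The base-class side is outside this diff; the following is only a minimal sketch of the template-method pattern the hooks imply, with simplified constructor signatures.

# Sketch only -- not the actual optimum-rbln base class. Assumes the base
# wrapper now drives conversion through the six hooks shown above.
from typing import Any, List


class SketchDecoderOnlyWrapper:
    # Defaults cover the common transformers layout; models like Exaone
    # override these to point at their own attribute names.
    def get_decoder_layers(self, causal_lm: Any) -> List[Any]:
        return causal_lm.model.layers

    def get_attn_layer(self, layer: Any) -> Any:
        return layer.self_attn

    def get_model_layer(self, causal_lm: Any) -> Any:
        return causal_lm.model

    def get_rbln_attn_class(self) -> type:
        raise NotImplementedError

    def get_rbln_layer_class(self) -> type:
        raise NotImplementedError

    def get_rbln_model_class(self) -> type:
        raise NotImplementedError

    def convert_to_rbln_causal_lm(self, causal_lm: Any, max_seq_len: int) -> Any:
        # One shared loop replaces every per-model copy of this method.
        new_layers = [
            self.get_rbln_layer_class()(layer, self.get_rbln_attn_class()(self.get_attn_layer(layer)))
            for layer in self.get_decoder_layers(causal_lm)
        ]
        return self.get_rbln_model_class()(
            self.get_model_layer(causal_lm), new_layers, max_seq_len=max_seq_len
        )

Under this pattern, integrating a new decoder-only model reduces to declaring where its layers and projections live, as the Exaone hooks above do.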

optimum/rbln/transformers/models/gemma/__init__.py

@@ -12,5 +12,5 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from .configuration_gemma import RBLNGemmaForCausalLMConfig
-from .modeling_gemma import RBLNGemmaForCausalLM
+from .configuration_gemma import RBLNGemmaForCausalLMConfig, RBLNGemmaModelConfig
+from .modeling_gemma import RBLNGemmaForCausalLM, RBLNGemmaModel

optimum/rbln/transformers/models/gemma/configuration_gemma.py

@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from ..decoderonly.configuration_decoderonly import RBLNDecoderOnlyModelForCausalLMConfig
+from ..decoderonly.configuration_decoderonly import RBLNDecoderOnlyModelConfig, RBLNDecoderOnlyModelForCausalLMConfig
 
 
 class RBLNGemmaForCausalLMConfig(RBLNDecoderOnlyModelForCausalLMConfig):
@@ -40,3 +40,11 @@ class RBLNGemmaForCausalLMConfig(RBLNDecoderOnlyModelForCausalLMConfig):
     )
     ```
     """
+
+
+class RBLNGemmaModelConfig(RBLNDecoderOnlyModelConfig):
+    """
+    Configuration class for RBLN Gemma models.
+
+    This class is an alias of RBLNDecoderOnlyModelConfig.
+    """

optimum/rbln/transformers/models/gemma/gemma_architecture.py

@@ -12,54 +12,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import TYPE_CHECKING
 
-from ...models.decoderonly.decoderonly_architecture import (
-    DecoderOnlyAttention,
-    DecoderOnlyFlashAttention,
-    DecoderOnlyForCausalLM,
-    DecoderOnlyLayer,
-    DecoderOnlyModel,
-    DecoderOnlyWrapper,
-)
-
-
-if TYPE_CHECKING:
-    from transformers import GemmaForCausalLM
+from ...models.decoderonly.decoderonly_architecture import DecoderOnlyModel, DecoderOnlyWrapper
 
 
 class GemmaWrapper(DecoderOnlyWrapper):
-    def convert_to_rbln_causal_lm(self, causal_lm: "GemmaForCausalLM", max_seq_len: int):
-        new_layers = []
-        for layer in causal_lm.model.layers:
-            if self.attn_impl == "eager":
-                new_self_attn = DecoderOnlyAttention(
-                    layer.self_attn,
-                    self.use_attention_mask,
-                    kvcache_block_size=self.kvcache_block_size,
-                    use_position_ids=self.use_position_ids,
-                )
-            elif self.attn_impl == "flash_attn":
-                new_self_attn = DecoderOnlyFlashAttention(
-                    layer.self_attn,
-                    kvcache_partition_len=self.kvcache_partition_len,
-                    use_attention_mask=self.use_attention_mask,
-                    kvcache_block_size=self.kvcache_block_size,
-                    use_position_ids=self.use_position_ids,
-                )
-            else:
-                raise NotImplementedError(f"Unknwon attn : {self.attn_impl}")
-            new_layer = DecoderOnlyLayer(layer, new_self_attn)
-            new_layers.append(new_layer)
-        new_model = GemmaModel(
-            causal_lm.model,
-            new_layers,
-            partition_len=self.kvcache_partition_len,
-            max_seq_len=max_seq_len,
-            sliding_window_layers=self.sliding_window_layers,
-        )
-        new_causal_lm = DecoderOnlyForCausalLM(causal_lm, new_model)
-        return new_causal_lm
+    def get_rbln_model_class(self):
+        return GemmaModel
 
 
 class GemmaModel(DecoderOnlyModel):

optimum/rbln/transformers/models/gemma/modeling_gemma.py

@@ -13,7 +13,7 @@
 # limitations under the License.
 
 from ....utils import logging
-from ...models.decoderonly import RBLNDecoderOnlyModelForCausalLM
+from ...models.decoderonly import RBLNDecoderOnlyModel, RBLNDecoderOnlyModelForCausalLM
 from .gemma_architecture import GemmaWrapper
 
 
@@ -81,3 +81,24 @@ class RBLNGemmaForCausalLM(RBLNDecoderOnlyModelForCausalLM):
     """
 
     _decoder_wrapper_cls = GemmaWrapper
+
+
+class RBLNGemmaModel(RBLNDecoderOnlyModel):
+    """
+    The Gemma Model transformer without a language modeling head.
+    This model inherits from [`RBLNDecoderOnlyModel`]. Check the superclass documentation for the generic methods the library implements for all its models.
+
+    A class to convert and run pre-trained transformers based GemmaModel model on RBLN devices.
+    It implements the methods to convert a pre-trained transformers GemmaModel model into a RBLN transformer model by:
+
+    - transferring the checkpoint weights of the original into an optimized RBLN graph,
+    - compiling the resulting graph using the RBLN compiler.
+
+    **Configuration:**
+    This model uses [`RBLNGemmaModelConfig`] for configuration. When calling methods like `from_pretrained` or `from_model`,
+    the `rbln_config` parameter should be an instance of [`RBLNGemmaModelConfig`] or a dictionary conforming to its structure.
+
+    See the [`RBLNGemmaModelConfig`] class for all available configuration options.
+    """
+
+    _decoder_wrapper_cls = GemmaWrapper
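
Given the docstring above, compiling the headless backbone should mirror the existing `RBLNGemmaForCausalLM` flow. A hedged sketch (checkpoint name and option values are illustrative; `export=True` follows the usual optimum compile-on-load convention):

from optimum.rbln import RBLNGemmaModel, RBLNGemmaModelConfig

model = RBLNGemmaModel.from_pretrained(
    "google/gemma-2b",  # illustrative checkpoint
    export=True,        # compile the transformers weights into an RBLN graph
    rbln_config=RBLNGemmaModelConfig(batch_size=1, max_seq_len=8192),
)
model.save_pretrained("gemma-2b-rbln")  # reuse the compiled artifacts later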

optimum/rbln/transformers/models/gemma3/configuration_gemma3.py

@@ -11,9 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any, Dict, Optional
-
-import rebel
+from typing import Any, Optional
 
 from ....configuration_utils import RBLNModelConfig
 from ..decoderonly.configuration_decoderonly import RBLNDecoderOnlyModelForCausalLMConfig
@@ -23,10 +21,11 @@ from ..siglip.configuration_siglip import RBLNSiglipVisionModelConfig
 class RBLNGemma3ForCausalLMConfig(RBLNDecoderOnlyModelForCausalLMConfig):
     def __init__(
         self,
-        prefill_chunk_size: Optional[int] = None,
         use_position_ids: Optional[bool] = None,
         use_attention_mask: Optional[bool] = None,
-        **kwargs: Dict[str, Any],
+        prefill_chunk_size: Optional[int] = None,
+        image_prefill_chunk_size: Optional[int] = None,
+        **kwargs: Any,
     ):
         # use_attention_mask and use_position_ids are always True for Gemma3
         use_attention_mask = use_attention_mask or True
@@ -39,10 +38,15 @@ class RBLNGemma3ForCausalLMConfig(RBLNDecoderOnlyModelForCausalLMConfig):
             use_position_ids=use_position_ids,
             **kwargs,
         )
+        self.image_prefill_chunk_size = image_prefill_chunk_size
 
-        npu = self.npu or rebel.get_npu_name()
-        if npu == "RBLN-CA02":
-            raise NotImplementedError("Gemma3 is currently not supported on RBLN-CA02")
+    @property
+    def use_image_prefill(self):
+        return self.image_prefill_chunk_size is not None
+
+    @property
+    def decoder_runtime_idx(self):
+        return 2 if self.use_image_prefill else 1
 
 
 class RBLNGemma3ForConditionalGenerationConfig(RBLNModelConfig):
@@ -53,7 +57,7 @@ class RBLNGemma3ForConditionalGenerationConfig(RBLNModelConfig):
         batch_size: Optional[int] = None,
         vision_tower: Optional[RBLNModelConfig] = None,
         language_model: Optional[RBLNModelConfig] = None,
-        **kwargs: Dict[str, Any],
+        **kwargs: Any,
     ):
         """
         Args:
@@ -72,3 +76,11 @@ class RBLNGemma3ForConditionalGenerationConfig(RBLNModelConfig):
 
         self.vision_tower = self.init_submodule_config(RBLNSiglipVisionModelConfig, vision_tower)
        self.language_model = self.init_submodule_config(RBLNGemma3ForCausalLMConfig, language_model)
+
+    @property
+    def image_prefill_chunk_size(self):
+        return self.language_model.image_prefill_chunk_size
+
+    @property
+    def prefill_chunk_size(self):
+        return self.language_model.prefill_chunk_size
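
The new `image_prefill_chunk_size` knob drives both derived properties: when it is set, the runtime list is assumed to gain a separate image-prefill entry between prefill and decode, which is why `decoder_runtime_idx` shifts from 1 to 2. A small check of that logic (import path assumed; only arguments introduced in this diff are passed, values illustrative):

from optimum.rbln import RBLNGemma3ForCausalLMConfig

text_only = RBLNGemma3ForCausalLMConfig()
assert not text_only.use_image_prefill
assert text_only.decoder_runtime_idx == 1  # runtimes: [prefill, decoder]

multimodal = RBLNGemma3ForCausalLMConfig(image_prefill_chunk_size=256)
assert multimodal.use_image_prefill
assert multimodal.decoder_runtime_idx == 2  # runtimes: [prefill, image_prefill, decoder]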

optimum/rbln/transformers/models/gemma3/gemma3_architecture.py

@@ -13,15 +13,13 @@
 # limitations under the License.
 
 import copy
-from typing import TYPE_CHECKING, Optional, Tuple, Union
+from typing import Optional, Tuple, Union
 
 import torch
 from transformers.models.gemma3.modeling_gemma3 import Gemma3RMSNorm
 
 from ..decoderonly.decoderonly_architecture import (
     DecoderOnlyAttention,
-    DecoderOnlyFlashAttention,
-    DecoderOnlyForCausalLM,
     DecoderOnlyLayer,
     DecoderOnlyModel,
     DecoderOnlyWrapper,
@@ -30,10 +28,6 @@ from ..decoderonly.decoderonly_architecture import (
 )
 
 
-if TYPE_CHECKING:
-    from transformers import Gemma3ForCausalLM
-
-
 class Gemma3ForCausalLMWrapper(DecoderOnlyWrapper):
     def get_rotary_emb(self, max_seq_len):
         rotary_emb_global = RotaryEmbedding(config=self.config, max_seq_len_cached=max_seq_len)
@@ -45,49 +39,14 @@ class Gemma3ForCausalLMWrapper(DecoderOnlyWrapper):
 
         return (rotary_emb_global, rotary_emb_local)
 
-    def convert_to_rbln_causal_lm(self, causal_lm: "Gemma3ForCausalLM", max_seq_len: int):
-        new_layers = []
-        for layer_idx, layer in enumerate(causal_lm.model.layers):
-            if layer_idx in self.sliding_window_layers:
-                new_self_attn = Gemma3Attention(
-                    layer.self_attn,
-                    use_attention_mask=None,  # FIXME: no use in SWA
-                    use_position_ids=self.use_position_ids,
-                    kvcache_block_size=self.config.sliding_window,
-                    is_sliding=True,
-                )
-            else:
-                if self.attn_impl == "eager":
-                    new_self_attn = Gemma3Attention(
-                        layer.self_attn,
-                        use_attention_mask=self.use_attention_mask,
-                        use_position_ids=self.use_position_ids,
-                        kvcache_block_size=self.kvcache_block_size,
-                        is_sliding=False,
-                    )
-                elif self.attn_impl == "flash_attn":
-                    new_self_attn = Gemma3FlashAttention(
-                        layer.self_attn,
-                        kvcache_partition_len=self.kvcache_partition_len,
-                        use_attention_mask=self.use_attention_mask,
-                        kvcache_block_size=self.kvcache_block_size,
-                        use_position_ids=self.use_position_ids,
-                    )
-                else:
-                    raise NotImplementedError(f"Unknwon attn : {self.attn_impl}")
-
-            new_layer = Gemma3DecoderLayer(layer, new_self_attn)
-            new_layers.append(new_layer)
-
-        new_model = Gemma3TextModel(
-            causal_lm.model,
-            new_layers,
-            partition_len=self.kvcache_partition_len,
-            max_seq_len=max_seq_len,
-            sliding_window_layers=self.sliding_window_layers,
-        )
-        new_causal_lm = DecoderOnlyForCausalLM(causal_lm, new_model)
-        return new_causal_lm
+    def get_rbln_attn_class(self):
+        return Gemma3Attention
+
+    def get_rbln_layer_class(self):
+        return Gemma3DecoderLayer
+
+    def get_rbln_model_class(self):
+        return Gemma3TextModel
 
 
 class Gemma3TextModel(DecoderOnlyModel):
@@ -199,16 +158,3 @@ class Gemma3Attention(DecoderOnlyAttention):
 
     def get_attn_scale(self):
         return self._original_mod.config.query_pre_attn_scalar**-0.5
-
-
-class Gemma3FlashAttention(DecoderOnlyFlashAttention):
-    def __post_init__(self):
-        self.q_proj = self._original_mod.q_proj
-        self.k_proj = self._original_mod.k_proj
-        self.v_proj = self._original_mod.v_proj
-        self.o_proj = self._original_mod.o_proj
-        self.q_norm = self._original_mod.q_norm
-        self.k_norm = self._original_mod.k_norm
-
-    def get_attn_scale(self):
-        return self._original_mod.config.query_pre_attn_scalar**-0.5