optimum-rbln 0.8.1rc0__py3-none-any.whl → 0.8.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of optimum-rbln might be problematic.

Files changed (120)
  1. optimum/rbln/__init__.py +58 -9
  2. optimum/rbln/__version__.py +2 -2
  3. optimum/rbln/configuration_utils.py +24 -5
  4. optimum/rbln/diffusers/configurations/models/__init__.py +1 -1
  5. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +2 -2
  6. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_cosmos.py +5 -3
  7. optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +2 -2
  8. optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +2 -2
  9. optimum/rbln/diffusers/configurations/models/{configuration_cosmos_transformer.py → configuration_transformer_cosmos.py} +7 -2
  10. optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +2 -2
  11. optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +2 -2
  12. optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +2 -2
  13. optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +3 -3
  14. optimum/rbln/diffusers/configurations/pipelines/configuration_cosmos.py +10 -6
  15. optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +4 -4
  16. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +2 -2
  17. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +2 -2
  18. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +2 -2
  19. optimum/rbln/diffusers/modeling_diffusers.py +4 -5
  20. optimum/rbln/diffusers/models/__init__.py +3 -13
  21. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +1 -0
  22. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py +1 -0
  23. optimum/rbln/diffusers/models/autoencoders/vq_model.py +1 -0
  24. optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +1 -1
  25. optimum/rbln/diffusers/pipelines/__init__.py +1 -5
  26. optimum/rbln/diffusers/pipelines/cosmos/configuration_cosmos_guardrail.py +12 -4
  27. optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +4 -26
  28. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +2 -2
  29. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +2 -2
  30. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +1 -6
  31. optimum/rbln/modeling.py +4 -5
  32. optimum/rbln/modeling_base.py +18 -14
  33. optimum/rbln/ops/kv_cache_update.py +5 -0
  34. optimum/rbln/ops/linear.py +7 -0
  35. optimum/rbln/transformers/__init__.py +60 -0
  36. optimum/rbln/transformers/configuration_generic.py +4 -4
  37. optimum/rbln/transformers/modeling_attention_utils.py +252 -0
  38. optimum/rbln/transformers/modeling_generic.py +1 -4
  39. optimum/rbln/transformers/models/__init__.py +45 -30
  40. optimum/rbln/transformers/models/bart/bart_architecture.py +2 -7
  41. optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +2 -2
  42. optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +1 -5
  43. optimum/rbln/transformers/models/clip/configuration_clip.py +14 -3
  44. optimum/rbln/transformers/models/clip/modeling_clip.py +123 -28
  45. optimum/rbln/transformers/models/colpali/colpali_architecture.py +1 -4
  46. optimum/rbln/transformers/models/colpali/configuration_colpali.py +2 -2
  47. optimum/rbln/transformers/models/colpali/modeling_colpali.py +2 -10
  48. optimum/rbln/transformers/models/decoderonly/__init__.py +2 -2
  49. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +214 -45
  50. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +323 -454
  51. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +579 -362
  52. optimum/rbln/transformers/models/exaone/exaone_architecture.py +17 -42
  53. optimum/rbln/transformers/models/gemma/__init__.py +2 -2
  54. optimum/rbln/transformers/models/gemma/configuration_gemma.py +9 -1
  55. optimum/rbln/transformers/models/gemma/gemma_architecture.py +3 -44
  56. optimum/rbln/transformers/models/gemma/modeling_gemma.py +22 -1
  57. optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +21 -9
  58. optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +9 -63
  59. optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +200 -292
  60. optimum/rbln/transformers/models/gpt2/__init__.py +2 -2
  61. optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +31 -3
  62. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +19 -24
  63. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +18 -1
  64. optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +2 -2
  65. optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +3 -9
  66. optimum/rbln/transformers/models/llama/__init__.py +2 -2
  67. optimum/rbln/transformers/models/llama/configuration_llama.py +9 -1
  68. optimum/rbln/transformers/models/llama/modeling_llama.py +22 -1
  69. optimum/rbln/transformers/models/llava/__init__.py +16 -0
  70. optimum/rbln/transformers/models/llava/configuration_llava.py +54 -0
  71. optimum/rbln/transformers/models/llava/modeling_llava.py +419 -0
  72. optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +20 -3
  73. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +6 -16
  74. optimum/rbln/transformers/models/midm/midm_architecture.py +14 -22
  75. optimum/rbln/transformers/models/mistral/__init__.py +2 -2
  76. optimum/rbln/transformers/models/mistral/configuration_mistral.py +9 -1
  77. optimum/rbln/transformers/models/mistral/mistral_architecture.py +1 -1
  78. optimum/rbln/transformers/models/mistral/modeling_mistral.py +26 -3
  79. optimum/rbln/transformers/models/opt/__init__.py +2 -2
  80. optimum/rbln/transformers/models/opt/configuration_opt.py +8 -1
  81. optimum/rbln/transformers/models/opt/modeling_opt.py +41 -1
  82. optimum/rbln/transformers/models/opt/opt_architecture.py +16 -25
  83. optimum/rbln/transformers/models/pegasus/__init__.py +17 -0
  84. optimum/rbln/transformers/models/pegasus/configuration_pegasus.py +34 -0
  85. optimum/rbln/transformers/models/pegasus/modeling_pegasus.py +69 -0
  86. optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +161 -0
  87. optimum/rbln/transformers/models/phi/__init__.py +2 -2
  88. optimum/rbln/transformers/models/phi/configuration_phi.py +9 -1
  89. optimum/rbln/transformers/models/phi/modeling_phi.py +10 -1
  90. optimum/rbln/transformers/models/phi/phi_architecture.py +16 -22
  91. optimum/rbln/transformers/models/pixtral/__init__.py +16 -0
  92. optimum/rbln/transformers/models/pixtral/configuration_pixtral.py +43 -0
  93. optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +315 -0
  94. optimum/rbln/transformers/models/pixtral/pixtral_architecture.py +73 -0
  95. optimum/rbln/transformers/models/qwen2/__init__.py +2 -2
  96. optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +9 -1
  97. optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +27 -1
  98. optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +3 -3
  99. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +5 -15
  100. optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +1 -4
  101. optimum/rbln/transformers/models/qwen3/__init__.py +16 -0
  102. optimum/rbln/transformers/models/qwen3/configuration_qwen3.py +71 -0
  103. optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +133 -0
  104. optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +31 -0
  105. optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py +2 -12
  106. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +3 -1
  107. optimum/rbln/transformers/models/siglip/__init__.py +2 -6
  108. optimum/rbln/transformers/models/siglip/modeling_siglip.py +2 -2
  109. optimum/rbln/transformers/models/time_series_transformer/configuration_time_series_transformer.py +2 -2
  110. optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +3 -5
  111. optimum/rbln/transformers/models/whisper/configuration_whisper.py +3 -12
  112. optimum/rbln/transformers/models/whisper/modeling_whisper.py +8 -2
  113. optimum/rbln/transformers/models/xlm_roberta/__init__.py +2 -8
  114. optimum/rbln/utils/depreacate_utils.py +16 -0
  115. optimum/rbln/utils/hub.py +8 -47
  116. optimum/rbln/utils/runtime_utils.py +31 -5
  117. {optimum_rbln-0.8.1rc0.dist-info → optimum_rbln-0.8.2.dist-info}/METADATA +1 -1
  118. {optimum_rbln-0.8.1rc0.dist-info → optimum_rbln-0.8.2.dist-info}/RECORD +120 -103
  119. {optimum_rbln-0.8.1rc0.dist-info → optimum_rbln-0.8.2.dist-info}/WHEEL +0 -0
  120. {optimum_rbln-0.8.1rc0.dist-info → optimum_rbln-0.8.2.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/transformers/models/opt/modeling_opt.py
@@ -16,7 +16,7 @@ import torch.nn as nn
  from transformers import PreTrainedModel

  from ....utils import logging
- from ...models.decoderonly import RBLNDecoderOnlyModelForCausalLM
+ from ...models.decoderonly import RBLNDecoderOnlyModel, RBLNDecoderOnlyModelForCausalLM
  from ...models.decoderonly.configuration_decoderonly import RBLNDecoderOnlyModelForCausalLMConfig
  from .opt_architecture import OPTWrapper

@@ -88,3 +88,43 @@ class RBLNOPTForCausalLM(RBLNDecoderOnlyModelForCausalLM):
              model.model.decoder.layers[i] = cls.modify_opt_decoder_layer(model.model.decoder.layers[i])

          return cls._decoder_wrapper_cls(model, **wrapper_cfg).eval()
+
+
+ class RBLNOPTModel(RBLNDecoderOnlyModel):
+     """
+     The OPT Model transformer without a language modeling head.
+     This model inherits from [`RBLNDecoderOnlyModel`]. Check the superclass documentation for the generic methods the library implements for all its models.
+     """
+
+     _decoder_wrapper_cls = OPTWrapper
+     _use_rotary_emb = False
+
+     def modify_opt_decoder_layer(layer):
+         mlp = MLP(layer.fc1, layer.fc2, layer.activation_fn)
+         layer.mlp = mlp
+         del layer.fc1
+         del layer.fc2
+         del layer.activation_fn
+
+         return layer
+
+     @classmethod
+     def wrap_model_if_needed(cls, model: PreTrainedModel, rbln_config: RBLNDecoderOnlyModelForCausalLMConfig):
+         wrapper_cfg = {
+             "max_seq_len": rbln_config.max_seq_len,
+             "attn_impl": rbln_config.attn_impl,
+             "kvcache_partition_len": rbln_config.kvcache_partition_len,
+             "kvcache_block_size": rbln_config.kvcache_block_size,
+             "use_rotary_emb": cls._use_rotary_emb,
+             "use_attention_mask": rbln_config.use_attention_mask,
+             "use_position_ids": rbln_config.use_position_ids,
+             "use_inputs_embeds": rbln_config.use_inputs_embeds,
+             "cache_impl": rbln_config.cache_impl,
+             "sliding_window": rbln_config.sliding_window,
+             "sliding_window_layers": rbln_config.sliding_window_layers,
+         }
+
+         for i in range(len(model.decoder.layers)):
+             model.decoder.layers[i] = cls.modify_opt_decoder_layer(model.decoder.layers[i])
+
+         return cls._decoder_wrapper_cls(model, **wrapper_cfg).eval()
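This release adds a headless `RBLNOPTModel` backbone alongside the existing `RBLNOPTForCausalLM` (the Phi hunks further down add an analogous `RBLNPhiModel`). Below is a minimal sketch of compiling and reusing such a backbone, assuming the class is re-exported from the package root like the other RBLN classes and that the usual optimum-rbln export flow (`from_pretrained(..., export=True)` plus `rbln_*` options) applies; the checkpoint and sizes are illustrative, not prescribed.

```python
# Hedged sketch: compile the new headless OPT backbone and keep the artifact.
# The checkpoint, batch size, and sequence length are illustrative.
from optimum.rbln import RBLNOPTModel

model = RBLNOPTModel.from_pretrained(
    "facebook/opt-1.3b",   # any OPT checkpoint
    export=True,           # trace and compile for the RBLN NPU
    rbln_batch_size=1,
    rbln_max_seq_len=2048,
)
model.save_pretrained("opt-1.3b-rbln")  # store the compiled artifact

# Later: load the precompiled artifact directly, no recompilation needed.
model = RBLNOPTModel.from_pretrained("opt-1.3b-rbln", export=False)
```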
optimum/rbln/transformers/models/opt/opt_architecture.py
@@ -18,7 +18,6 @@ import torch.nn as nn

  from ...models.decoderonly.decoderonly_architecture import (
      DecoderOnlyAttention,
-     DecoderOnlyForCausalLM,
      DecoderOnlyLayer,
      DecoderOnlyModel,
      DecoderOnlyWrapper,
@@ -30,30 +29,22 @@ if TYPE_CHECKING:


  class OPTWrapper(DecoderOnlyWrapper):
-     def convert_to_rbln_causal_lm(self, causal_lm: "OPTForCausalLM", max_seq_len: int):
-         if self.attn_impl != "eager":
-             raise NotImplementedError(f"flash attention ({self.attn_impl}) is not implemented for {self.__class__}")
-
-         new_layers = []
-
-         for layer in causal_lm.model.decoder.layers:
-             new_self_attn = OPTAttention(
-                 layer.self_attn,
-                 self.use_attention_mask,
-                 kvcache_block_size=self.kvcache_block_size,
-                 use_position_ids=self.use_position_ids,
-             )
-             new_layer = OPTDecoderLayer(layer, new_self_attn)
-             new_layers.append(new_layer)
-         new_model = OPTModel(
-             causal_lm.model.decoder,
-             new_layers,
-             max_seq_len=max_seq_len,
-             use_learned_pos_emb=True,
-             sliding_window_layers=self.sliding_window_layers,
-         )
-         new_causal_lm = DecoderOnlyForCausalLM(causal_lm, new_model)
-         return new_causal_lm
+     _use_learned_pos_emb = True
+
+     def get_rbln_attn_class(self):
+         return OPTAttention
+
+     def get_rbln_layer_class(self):
+         return OPTDecoderLayer
+
+     def get_rbln_model_class(self):
+         return OPTModel
+
+     def get_model_layer(self, model: "OPTForCausalLM"):
+         return model.model.decoder if self.is_causal_lm else model.decoder
+
+     def get_decoder_layers(self, model: "OPTForCausalLM"):
+         return model.model.decoder.layers if self.is_causal_lm else model.decoder.layers


  class OPTAttention(DecoderOnlyAttention):
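The `OPTWrapper` change illustrates the wider refactor of `decoderonly_architecture.py` (note the -454/+323 counts above): wrappers no longer override `convert_to_rbln_causal_lm`; the base `DecoderOnlyWrapper` now drives the conversion and each model only supplies small hooks. A sketch of the resulting contract, using hypothetical `My*` class names (everything else mirrors what the OPT and Phi hunks show):

```python
from optimum.rbln.transformers.models.decoderonly.decoderonly_architecture import (
    DecoderOnlyAttention,
    DecoderOnlyLayer,
    DecoderOnlyModel,
    DecoderOnlyWrapper,
)


# Hypothetical per-model building blocks; a real port would customize these.
class MyAttention(DecoderOnlyAttention):
    pass


class MyDecoderLayer(DecoderOnlyLayer):
    pass


class MyModel(DecoderOnlyModel):
    pass


class MyWrapper(DecoderOnlyWrapper):
    _use_learned_pos_emb = True  # class-level flag replaces the old constructor argument

    def get_rbln_attn_class(self):
        return MyAttention

    def get_rbln_layer_class(self):
        return MyDecoderLayer

    def get_rbln_model_class(self):
        return MyModel

    def get_model_layer(self, model):
        # is_causal_lm distinguishes a *ForCausalLM wrapper (backbone at model.model)
        # from the bare backbone itself
        return model.model if self.is_causal_lm else model

    def get_decoder_layers(self, model):
        return model.model.layers if self.is_causal_lm else model.layers
```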
optimum/rbln/transformers/models/pegasus/__init__.py
@@ -0,0 +1,17 @@
+ # Copyright 2025 Rebellions Inc. All rights reserved.
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at:
+
+ # http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from ....ops import paged_attn_decode, paged_causal_attn_decode
+ from .configuration_pegasus import RBLNPegasusForConditionalGenerationConfig, RBLNPegasusModelConfig
+ from .modeling_pegasus import RBLNPegasusForConditionalGeneration, RBLNPegasusModel
optimum/rbln/transformers/models/pegasus/configuration_pegasus.py
@@ -0,0 +1,34 @@
+ # Copyright 2025 Rebellions Inc. All rights reserved.
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at:
+
+ # http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from ...configuration_generic import RBLNTransformerEncoderForFeatureExtractionConfig
+ from ..seq2seq import RBLNModelForSeq2SeqLMConfig
+
+
+ class RBLNPegasusModelConfig(RBLNTransformerEncoderForFeatureExtractionConfig):
+     """
+     Configuration class for RBLNPegasusModel.
+
+     This configuration class stores the configuration parameters specific to
+     RBLN-optimized PEGASUS models for feature extraction tasks.
+     """
+
+
+ class RBLNPegasusForConditionalGenerationConfig(RBLNModelForSeq2SeqLMConfig):
+     """
+     Configuration class for RBLNPegasusForConditionalGeneration.
+
+     This configuration class stores the configuration parameters specific to
+     RBLN-optimized PEGASUS models for conditional text generation tasks.
+     """
optimum/rbln/transformers/models/pegasus/modeling_pegasus.py
@@ -0,0 +1,69 @@
+ # Copyright 2025 Rebellions Inc. All rights reserved.
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at:
+
+ # http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import inspect
+ from typing import TYPE_CHECKING, Any, Callable
+
+ from transformers import PegasusForConditionalGeneration, PreTrainedModel
+
+ from ....utils.logging import get_logger
+ from ...modeling_generic import RBLNTransformerEncoderForFeatureExtraction
+ from ...models.seq2seq import RBLNModelForSeq2SeqLM
+ from .configuration_pegasus import RBLNPegasusForConditionalGenerationConfig
+ from .pegasus_architecture import PegasusWrapper
+
+
+ logger = get_logger()
+
+
+ if TYPE_CHECKING:
+     from transformers import PreTrainedModel
+
+
+ class RBLNPegasusModel(RBLNTransformerEncoderForFeatureExtraction):
+     """
+     RBLN optimized PEGASUS model for feature extraction tasks.
+
+     This class provides hardware-accelerated inference for PEGASUS encoder models
+     on RBLN devices, optimized for feature extraction use cases.
+     """
+
+
+ class RBLNPegasusForConditionalGeneration(RBLNModelForSeq2SeqLM):
+     """
+     RBLN optimized PEGASUS model for conditional text generation tasks.
+
+     This class provides hardware-accelerated inference for PEGASUS models
+     on RBLN devices, supporting sequence-to-sequence generation tasks
+     such as summarization, translation, and text generation.
+     """
+
+     support_causal_attn = True
+
+     @classmethod
+     def wrap_model_if_needed(self, model: "PreTrainedModel", rbln_config: RBLNPegasusForConditionalGenerationConfig):
+         return PegasusWrapper(
+             model, enc_max_seq_len=rbln_config.enc_max_seq_len, use_attention_mask=rbln_config.use_attention_mask
+         )
+
+     def __getattr__(self, __name: str) -> Any:
+         def redirect(func):
+             return lambda *pargs, **kwargs: func(self, *pargs, **kwargs)
+
+         val = getattr(PegasusForConditionalGeneration, __name)
+
+         if isinstance(val, Callable) and "self" in set(inspect.signature(val).parameters):
+             return redirect(val)
+
+         return val
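PEGASUS is a new architecture in 0.8.2 (`RBLNPegasusModel` for feature extraction and `RBLNPegasusForConditionalGeneration` on top of the generic seq2seq stack). A minimal summarization sketch, assuming the same compile-then-generate flow as the existing RBLN seq2seq models; the checkpoint and generation settings are illustrative only:

```python
# Hedged sketch: summarization with the new PEGASUS support.
from transformers import AutoTokenizer
from optimum.rbln import RBLNPegasusForConditionalGeneration

model_id = "google/pegasus-xsum"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = RBLNPegasusForConditionalGeneration.from_pretrained(
    model_id,
    export=True,        # compile encoder and decoder for the RBLN NPU
    rbln_batch_size=1,
)

article = "Long article text to summarize ..."
inputs = tokenizer(article, return_tensors="pt", truncation=True)
summary_ids = model.generate(**inputs, max_new_tokens=64)
print(tokenizer.decode(summary_ids[0], skip_special_tokens=True))
```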
optimum/rbln/transformers/models/pegasus/pegasus_architecture.py
@@ -0,0 +1,161 @@
+ # Copyright 2025 Rebellions Inc. All rights reserved.
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at:
+
+ # http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from typing import Tuple
+
+ import torch
+ from torch import nn
+ from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask
+ from transformers.utils import logging
+
+ from ..seq2seq.seq2seq_architecture import (
+     Seq2SeqCrossAttention,
+     Seq2SeqDecoder,
+     Seq2SeqDecoderLayer,
+     Seq2SeqDecoderWrapper,
+     Seq2SeqEncoderWrapper,
+     Seq2SeqForConditionalGeneration,
+     Seq2SeqSelfAttention,
+ )
+
+
+ logger = logging.get_logger(__name__)
+
+
+ class PegasusWrapper:
+     def __init__(self, model: nn.Module, enc_max_seq_len: int, use_attention_mask: bool):
+         self.encoder = Seq2SeqEncoderWrapper(model, enc_max_seq_len)
+         self.decoder = PegasusDecoderWrapper(model, use_attention_mask=use_attention_mask)
+
+
+ class PegasusDecoderWrapper(Seq2SeqDecoderWrapper):
+     def convert_to_rbln_conditional_generation(self, model: nn.Module):
+         new_layers = []
+         for layer in model.get_decoder().layers:
+             self_attn = PegasusSelfAttention(layer.self_attn, use_attention_mask=self.use_attention_mask)
+             cross_attn = PegasusCrossAttention(layer.encoder_attn)
+             new_layers.append(PegasusDecoderLayer(layer, self_attn, cross_attn))
+
+         decoder_model = PegasusDecoder(model.get_decoder(), new_layers)
+         new_model = PegasusForConditionalGeneration(model, decoder_model)
+
+         return new_model
+
+
+ class PegasusForConditionalGeneration(Seq2SeqForConditionalGeneration):
+     pass
+
+
+ class PegasusDecoder(Seq2SeqDecoder):
+     has_pos_emb = True
+
+     def __post_init__(self):
+         self.embed_positions = self._original_mod.embed_positions
+         self.embed_scale = getattr(self._original_mod, "embed_scale", None)
+         self.final_layer_norm = getattr(self._original_mod, "layer_norm", None)
+
+     def prepare_attn_mask(self, attention_mask, encoder_attention_mask, **kwargs):
+         if attention_mask is not None:
+             attention_mask = attention_mask[:, None, None, :]
+         encoder_attention_mask = _prepare_4d_attention_mask(encoder_attention_mask, torch.float32, tgt_len=1)
+
+         return attention_mask, encoder_attention_mask
+
+     def apply_position_embedding(self, inputs_embeds, cache_position):
+         hidden_all = []
+         for i in range(inputs_embeds.shape[0]):
+             positions_idx = cache_position[i]
+             position_weight = self.embed_positions.weight
+             position = position_weight[positions_idx]
+             batch_hidden = position + inputs_embeds[i]
+             hidden_all.append(batch_hidden)
+         hidden_states = torch.stack(hidden_all, dim=0)
+
+         return hidden_states
+
+     def get_embedding(self):
+         if self.embed_scale is not None:
+             return lambda x: self.embed_tokens(x) * self.embed_scale
+         else:
+             return self.embed_tokens
+
+
+ class PegasusLayerFF(nn.Module):
+     def __init__(self, decoder_layer):
+         super().__init__()
+         self.fc1 = decoder_layer.fc1
+         self.fc2 = decoder_layer.fc2
+         self.activation_fn = decoder_layer.activation_fn
+         self.layer_norm = decoder_layer.final_layer_norm
+
+     def forward(self, hidden_states):
+         # Residual Connection
+         residual = hidden_states
+         hidden_states = self.layer_norm(hidden_states)
+         hidden_states = self.activation_fn(self.fc1(hidden_states))
+         hidden_states = self.fc2(hidden_states)
+         hidden_states = residual + hidden_states
+         return hidden_states
+
+
+ class PegasusDecoderLayer(Seq2SeqDecoderLayer):
+     def __post_init__(self):
+         self.self_attn_layer_norm = self._original_mod.self_attn_layer_norm
+         self.encoder_attn = self._original_mod.encoder_attn
+         self.encoder_attn_layer_norm = self._original_mod.encoder_attn_layer_norm
+         self.ff_layer = PegasusLayerFF(self._original_mod)
+
+     def pre_self_attn_layer_norm(self, hidden_states):
+         return self.self_attn_layer_norm(hidden_states)
+
+     def post_self_attn_layer_norm(self, hidden_states):
+         return hidden_states
+
+     def pre_cross_attn_layer_norm(self, hidden_states):
+         return self.encoder_attn_layer_norm(hidden_states)
+
+     def post_cross_attn_layer_norm(self, hidden_states):
+         return hidden_states
+
+
+ class PegasusSelfAttention(Seq2SeqSelfAttention):
+     def __post_init__(self, use_attention_mask: bool = True):
+         self.q_proj = self._original_mod.q_proj
+         self.k_proj = self._original_mod.k_proj
+         self.v_proj = self._original_mod.v_proj
+         self.out_proj = self._original_mod.out_proj
+         self.num_heads = self._original_mod.num_heads
+         self.head_dim = self._original_mod.embed_dim // self._original_mod.num_heads
+         self.scaling = self.head_dim**-0.5
+         if use_attention_mask:
+             self.attn_decode = torch.ops.rbln_custom_ops.paged_attn_decode
+         else:
+             self.attn_decode = torch.ops.rbln_custom_ops.paged_causal_attn_decode
+
+     def projection(self, hidden_states) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+         query_states = self.q_proj(hidden_states) * self.scaling
+         key_states = self.k_proj(hidden_states)
+         value_states = self.v_proj(hidden_states)
+         return query_states, key_states, value_states
+
+
+ class PegasusCrossAttention(Seq2SeqCrossAttention):
+     def __post_init__(self):
+         self.q_proj = self._original_mod.q_proj
+         self.k_proj = self._original_mod.k_proj
+         self.v_proj = self._original_mod.v_proj
+         self.out_proj = self._original_mod.out_proj
+         self.num_heads = self._original_mod.num_heads
+         self.head_dim = self._original_mod.embed_dim // self._original_mod.num_heads
+         self.embed_dim = self._original_mod.embed_dim
optimum/rbln/transformers/models/phi/__init__.py
@@ -12,5 +12,5 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.

- from .configuration_phi import RBLNPhiForCausalLMConfig
- from .modeling_phi import RBLNPhiForCausalLM
+ from .configuration_phi import RBLNPhiForCausalLMConfig, RBLNPhiModelConfig
+ from .modeling_phi import RBLNPhiForCausalLM, RBLNPhiModel
optimum/rbln/transformers/models/phi/configuration_phi.py
@@ -12,7 +12,7 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.

- from ..decoderonly.configuration_decoderonly import RBLNDecoderOnlyModelForCausalLMConfig
+ from ..decoderonly.configuration_decoderonly import RBLNDecoderOnlyModelConfig, RBLNDecoderOnlyModelForCausalLMConfig


  class RBLNPhiForCausalLMConfig(RBLNDecoderOnlyModelForCausalLMConfig):
@@ -40,3 +40,11 @@ class RBLNPhiForCausalLMConfig(RBLNDecoderOnlyModelForCausalLMConfig):
      )
      ```
      """
+
+
+ class RBLNPhiModelConfig(RBLNDecoderOnlyModelConfig):
+     """
+     Configuration class for RBLN Phi models.
+
+     This class is an alias of RBLNDecoderOnlyModelConfig.
+     """
optimum/rbln/transformers/models/phi/modeling_phi.py
@@ -13,7 +13,7 @@
  # limitations under the License.

  from ....utils import logging
- from ...models.decoderonly import RBLNDecoderOnlyModelForCausalLM
+ from ...models.decoderonly import RBLNDecoderOnlyModel, RBLNDecoderOnlyModelForCausalLM
  from .phi_architecture import PhiWrapper


@@ -81,3 +81,12 @@ class RBLNPhiForCausalLM(RBLNDecoderOnlyModelForCausalLM):
      """

      _decoder_wrapper_cls = PhiWrapper
+
+
+ class RBLNPhiModel(RBLNDecoderOnlyModel):
+     """
+     The Phi Model transformer without a language modeling head.
+     This model inherits from [`RBLNDecoderOnlyModel`]. Check the superclass documentation for the generic methods the library implements for all its models.
+     """
+
+     _decoder_wrapper_cls = PhiWrapper
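As with OPT, Phi gains a headless `RBLNPhiModel` that reuses the same `PhiWrapper`. A sketch of hidden-state extraction with it, assuming the compiled backbone mirrors the forward interface and outputs of transformers' `PhiModel` (a `last_hidden_state` tensor); the checkpoint and options are illustrative:

```python
# Hedged sketch: hidden-state extraction with the new headless Phi backbone.
from transformers import AutoTokenizer
from optimum.rbln import RBLNPhiModel

model_id = "microsoft/phi-2"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = RBLNPhiModel.from_pretrained(model_id, export=True, rbln_max_seq_len=2048)

inputs = tokenizer("Rebellions NPUs can serve bare transformer backbones too.", return_tensors="pt")
outputs = model(**inputs)
print(outputs.last_hidden_state.shape)  # (batch, seq_len, hidden_size)
```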
optimum/rbln/transformers/models/phi/phi_architecture.py
@@ -12,14 +12,13 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.

- from typing import TYPE_CHECKING, Optional, Tuple
+ from typing import TYPE_CHECKING, Optional, Tuple, Union

  import torch
  from transformers import PhiForCausalLM

  from ..decoderonly.decoderonly_architecture import (
      DecoderOnlyAttention,
-     DecoderOnlyForCausalLM,
      DecoderOnlyLayer,
      DecoderOnlyModel,
      DecoderOnlyWrapper,
@@ -28,29 +27,24 @@ from ..decoderonly.decoderonly_architecture import (


  if TYPE_CHECKING:
-     from transformers import PhiForCausalLM
+     from transformers import PhiForCausalLM, PhiModel


  class PhiWrapper(DecoderOnlyWrapper):
-     def convert_to_rbln_causal_lm(self, causal_lm: "PhiForCausalLM", max_seq_len: int):
-         new_layers = []
-         for layer in causal_lm.model.layers:
-             if self.attn_impl == "eager":
-                 new_self_attn = PhiAttention(
-                     layer.self_attn,
-                     self.use_attention_mask,
-                     kvcache_block_size=self.kvcache_block_size,
-                     use_position_ids=self.use_position_ids,
-                 )
-             elif self.attn_impl == "flash_attn":
-                 raise NotImplementedError(f"flash attn for {self.__class__} is not implemented yet.")
-             else:
-                 raise NotImplementedError(f"Unknwon attn : {self.attn_impl}")
-             new_layer = PhiLayer(layer, new_self_attn)
-             new_layers.append(new_layer)
-         new_model = PhiModel(causal_lm.model, new_layers, sliding_window_layers=self.sliding_window_layers)
-         new_causal_lm = DecoderOnlyForCausalLM(causal_lm, new_model)
-         return new_causal_lm
+     def get_rbln_attn_class(self):
+         return PhiAttention
+
+     def get_rbln_layer_class(self):
+         return PhiLayer
+
+     def get_rbln_model_class(self):
+         return PhiModel
+
+     def get_model_layer(self, model: Union["PhiForCausalLM", "PhiModel"]):
+         return model.model if self.is_causal_lm else model
+
+     def get_decoder_layers(self, model: Union["PhiForCausalLM", "PhiModel"]):
+         return model.model.layers if self.is_causal_lm else model.layers


  class PhiAttention(DecoderOnlyAttention):
optimum/rbln/transformers/models/pixtral/__init__.py
@@ -0,0 +1,16 @@
+ # Copyright 2025 Rebellions Inc. All rights reserved.
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at:
+
+ # http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from .configuration_pixtral import RBLNPixtralVisionModelConfig
+ from .modeling_pixtral import RBLNPixtralVisionModel
optimum/rbln/transformers/models/pixtral/configuration_pixtral.py
@@ -0,0 +1,43 @@
+ # Copyright 2025 Rebellions Inc. All rights reserved.
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at:
+
+ # http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from typing import Any, Optional, Tuple
+
+ from ....configuration_utils import RBLNModelConfig
+
+
+ class RBLNPixtralVisionModelConfig(RBLNModelConfig):
+     def __init__(
+         self,
+         max_image_size: Tuple = None,
+         batch_size: Optional[int] = None,
+         output_hidden_states: Optional[bool] = None,
+         **kwargs: Any,
+     ):
+         """
+         Args:
+             max_image_size (Tuple): The size of max input images. A tuple (max_height, max_width)
+             batch_size (Optional[int]): The batch size for image processing. Defaults to 1.
+             **kwargs: Additional arguments passed to the parent RBLNModelConfig.
+
+         Raises:
+             ValueError: If batch_size is not a positive integer.
+         """
+         super().__init__(**kwargs)
+         self.batch_size = batch_size or 1
+         if not isinstance(self.batch_size, int) or self.batch_size < 0:
+             raise ValueError(f"batch_size must be a positive integer, got {self.batch_size}")
+
+         self.max_image_size = max_image_size
+         self.output_hidden_states = output_hidden_states
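The new `RBLNPixtralVisionModelConfig` pins a maximum input image size at compile time (`max_image_size` as `(max_height, max_width)`), along with `batch_size` and `output_hidden_states`. A sketch of building the config and passing it at export, assuming the `rbln_config=` pattern used by the other RBLNModelConfig subclasses in this package; the checkpoint and sizes are illustrative:

```python
# Hedged sketch: compile the new Pixtral vision encoder with an explicit config.
from optimum.rbln import RBLNPixtralVisionModel, RBLNPixtralVisionModelConfig

vision_config = RBLNPixtralVisionModelConfig(
    max_image_size=(1024, 1024),   # (max_height, max_width) baked into the compiled graph
    batch_size=1,
    output_hidden_states=False,
)

model = RBLNPixtralVisionModel.from_pretrained(
    "mistral-community/pixtral-12b",  # illustrative checkpoint
    export=True,
    rbln_config=vision_config,
)
```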