optimum-rbln 0.8.2a4__py3-none-any.whl → 0.9.3__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the package versions exactly as they appear in the public registry.
Files changed (196)
  1. optimum/rbln/__init__.py +108 -9
  2. optimum/rbln/__version__.py +16 -3
  3. optimum/rbln/cli.py +660 -0
  4. optimum/rbln/configuration_utils.py +156 -43
  5. optimum/rbln/diffusers/__init__.py +19 -0
  6. optimum/rbln/diffusers/configurations/__init__.py +3 -0
  7. optimum/rbln/diffusers/configurations/models/__init__.py +2 -0
  8. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +3 -3
  9. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_cosmos.py +1 -1
  10. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_temporal_decoder.py +67 -0
  11. optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +3 -3
  12. optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +4 -4
  13. optimum/rbln/diffusers/configurations/models/configuration_transformer_cosmos.py +9 -4
  14. optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +9 -4
  15. optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +3 -3
  16. optimum/rbln/diffusers/configurations/models/configuration_unet_spatio_temporal_condition.py +59 -0
  17. optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +3 -3
  18. optimum/rbln/diffusers/configurations/pipelines/__init__.py +3 -0
  19. optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +35 -19
  20. optimum/rbln/diffusers/configurations/pipelines/configuration_cosmos.py +14 -11
  21. optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +30 -20
  22. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +13 -9
  23. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +17 -13
  24. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +17 -10
  25. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_video_diffusion.py +114 -0
  26. optimum/rbln/diffusers/modeling_diffusers.py +30 -14
  27. optimum/rbln/diffusers/models/__init__.py +4 -0
  28. optimum/rbln/diffusers/models/autoencoders/__init__.py +1 -0
  29. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +31 -3
  30. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py +31 -6
  31. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +275 -0
  32. optimum/rbln/diffusers/models/autoencoders/vae.py +27 -8
  33. optimum/rbln/diffusers/models/autoencoders/vq_model.py +31 -3
  34. optimum/rbln/diffusers/models/controlnet.py +16 -1
  35. optimum/rbln/diffusers/models/transformers/prior_transformer.py +17 -3
  36. optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +25 -2
  37. optimum/rbln/diffusers/models/transformers/transformer_sd3.py +23 -2
  38. optimum/rbln/diffusers/models/unets/__init__.py +1 -0
  39. optimum/rbln/diffusers/models/unets/unet_2d_condition.py +23 -4
  40. optimum/rbln/diffusers/models/unets/unet_spatio_temporal_condition.py +201 -0
  41. optimum/rbln/diffusers/pipelines/__init__.py +15 -5
  42. optimum/rbln/diffusers/pipelines/auto_pipeline.py +307 -0
  43. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +20 -0
  44. optimum/rbln/diffusers/pipelines/cosmos/configuration_cosmos_guardrail.py +19 -16
  45. optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +14 -18
  46. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +31 -1
  47. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +31 -1
  48. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +1 -6
  49. optimum/rbln/diffusers/pipelines/stable_video_diffusion/__init__.py +15 -0
  50. optimum/rbln/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +46 -0
  51. optimum/rbln/modeling.py +48 -21
  52. optimum/rbln/modeling_base.py +99 -22
  53. optimum/rbln/ops/attn.py +158 -0
  54. optimum/rbln/ops/flash_attn.py +166 -0
  55. optimum/rbln/ops/kv_cache_update.py +5 -0
  56. optimum/rbln/ops/linear.py +7 -0
  57. optimum/rbln/transformers/__init__.py +92 -0
  58. optimum/rbln/transformers/configuration_generic.py +7 -32
  59. optimum/rbln/transformers/modeling_attention_utils.py +385 -0
  60. optimum/rbln/transformers/modeling_generic.py +48 -65
  61. optimum/rbln/transformers/modeling_outputs.py +37 -0
  62. optimum/rbln/transformers/models/__init__.py +91 -30
  63. optimum/rbln/transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py +28 -2
  64. optimum/rbln/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +68 -5
  65. optimum/rbln/transformers/models/auto/__init__.py +2 -0
  66. optimum/rbln/transformers/models/auto/auto_factory.py +92 -17
  67. optimum/rbln/transformers/models/auto/modeling_auto.py +45 -0
  68. optimum/rbln/transformers/models/bart/bart_architecture.py +1 -3
  69. optimum/rbln/transformers/models/bart/configuration_bart.py +2 -0
  70. optimum/rbln/transformers/models/bart/modeling_bart.py +23 -2
  71. optimum/rbln/transformers/models/bert/bert_architecture.py +16 -0
  72. optimum/rbln/transformers/models/bert/modeling_bert.py +93 -4
  73. optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +42 -11
  74. optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +135 -44
  75. optimum/rbln/transformers/models/clip/configuration_clip.py +10 -7
  76. optimum/rbln/transformers/models/clip/modeling_clip.py +67 -6
  77. optimum/rbln/transformers/models/colpali/colpali_architecture.py +3 -6
  78. optimum/rbln/transformers/models/colpali/configuration_colpali.py +37 -21
  79. optimum/rbln/transformers/models/colpali/modeling_colpali.py +82 -104
  80. optimum/rbln/transformers/models/colqwen2/__init__.py +2 -0
  81. optimum/rbln/transformers/models/colqwen2/colqwen2_architecture.py +233 -0
  82. optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +74 -0
  83. optimum/rbln/transformers/models/colqwen2/modeling_colqwen2.py +446 -0
  84. optimum/rbln/transformers/models/decoderonly/__init__.py +3 -2
  85. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +114 -37
  86. optimum/rbln/transformers/models/decoderonly/configuration_lora.py +411 -0
  87. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +318 -309
  88. optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +508 -0
  89. optimum/rbln/transformers/models/decoderonly/generation_decoderonly.py +119 -0
  90. optimum/rbln/transformers/models/decoderonly/lora_architecture.py +204 -0
  91. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +485 -905
  92. optimum/rbln/transformers/models/depth_anything/__init__.py +16 -0
  93. optimum/rbln/transformers/models/depth_anything/configuration_depth_anything.py +24 -0
  94. optimum/rbln/transformers/models/depth_anything/modeling_depth_anything.py +42 -0
  95. optimum/rbln/transformers/models/distilbert/modeling_distilbert.py +24 -0
  96. optimum/rbln/transformers/models/dpt/modeling_dpt.py +17 -0
  97. optimum/rbln/transformers/models/exaone/modeling_exaone.py +42 -4
  98. optimum/rbln/transformers/models/gemma/__init__.py +2 -2
  99. optimum/rbln/transformers/models/gemma/configuration_gemma.py +9 -1
  100. optimum/rbln/transformers/models/gemma/gemma_architecture.py +1 -4
  101. optimum/rbln/transformers/models/gemma/modeling_gemma.py +22 -1
  102. optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +49 -13
  103. optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +12 -2
  104. optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py +245 -0
  105. optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +201 -351
  106. optimum/rbln/transformers/models/gpt2/__init__.py +2 -2
  107. optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +31 -3
  108. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +10 -8
  109. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +18 -1
  110. optimum/rbln/transformers/models/grounding_dino/__init__.py +10 -0
  111. optimum/rbln/transformers/models/grounding_dino/configuration_grounding_dino.py +92 -0
  112. optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +599 -0
  113. optimum/rbln/transformers/models/grounding_dino/modeling_grounding_dino.py +1048 -0
  114. optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +35 -7
  115. optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +29 -32
  116. optimum/rbln/transformers/models/llama/__init__.py +2 -2
  117. optimum/rbln/transformers/models/llama/configuration_llama.py +9 -1
  118. optimum/rbln/transformers/models/llama/modeling_llama.py +22 -1
  119. optimum/rbln/transformers/models/llava/__init__.py +16 -0
  120. optimum/rbln/transformers/models/llava/configuration_llava.py +72 -0
  121. optimum/rbln/transformers/models/llava/modeling_llava.py +490 -0
  122. optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +15 -17
  123. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +234 -376
  124. optimum/rbln/transformers/models/midm/midm_architecture.py +4 -1
  125. optimum/rbln/transformers/models/midm/modeling_midm.py +42 -4
  126. optimum/rbln/transformers/models/mistral/__init__.py +2 -2
  127. optimum/rbln/transformers/models/mistral/configuration_mistral.py +9 -1
  128. optimum/rbln/transformers/models/mistral/mistral_architecture.py +1 -1
  129. optimum/rbln/transformers/models/mistral/modeling_mistral.py +26 -3
  130. optimum/rbln/transformers/models/opt/__init__.py +2 -2
  131. optimum/rbln/transformers/models/opt/configuration_opt.py +8 -1
  132. optimum/rbln/transformers/models/opt/modeling_opt.py +29 -17
  133. optimum/rbln/transformers/models/opt/opt_architecture.py +4 -4
  134. optimum/rbln/transformers/models/pegasus/__init__.py +17 -0
  135. optimum/rbln/transformers/models/pegasus/configuration_pegasus.py +38 -0
  136. optimum/rbln/transformers/models/pegasus/modeling_pegasus.py +71 -0
  137. optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +161 -0
  138. optimum/rbln/transformers/models/phi/__init__.py +2 -2
  139. optimum/rbln/transformers/models/phi/configuration_phi.py +9 -1
  140. optimum/rbln/transformers/models/phi/modeling_phi.py +10 -1
  141. optimum/rbln/transformers/models/phi/phi_architecture.py +11 -7
  142. optimum/rbln/transformers/models/pixtral/__init__.py +16 -0
  143. optimum/rbln/transformers/models/pixtral/configuration_pixtral.py +43 -0
  144. optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +322 -0
  145. optimum/rbln/transformers/models/pixtral/pixtral_architecture.py +73 -0
  146. optimum/rbln/transformers/models/qwen2/__init__.py +2 -2
  147. optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +9 -1
  148. optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +27 -1
  149. optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +21 -6
  150. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +15 -22
  151. optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +28 -7
  152. optimum/rbln/transformers/models/qwen2_vl/__init__.py +19 -0
  153. optimum/rbln/transformers/models/qwen2_vl/configuration_qwen2_vl.py +88 -0
  154. optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py +513 -0
  155. optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +165 -0
  156. optimum/rbln/transformers/models/qwen3/configuration_qwen3.py +2 -2
  157. optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +86 -330
  158. optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +1 -245
  159. optimum/rbln/transformers/models/resnet/configuration_resnet.py +17 -0
  160. optimum/rbln/transformers/models/resnet/modeling_resnet.py +73 -0
  161. optimum/rbln/transformers/models/roberta/modeling_roberta.py +33 -0
  162. optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py +21 -16
  163. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +58 -13
  164. optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +2 -2
  165. optimum/rbln/transformers/models/siglip/__init__.py +2 -6
  166. optimum/rbln/transformers/models/siglip/configuration_siglip.py +1 -1
  167. optimum/rbln/transformers/models/siglip/modeling_siglip.py +21 -16
  168. optimum/rbln/transformers/models/swin/__init__.py +16 -0
  169. optimum/rbln/transformers/models/swin/configuration_swin.py +42 -0
  170. optimum/rbln/transformers/models/swin/modeling_swin.py +354 -0
  171. optimum/rbln/transformers/models/t5/configuration_t5.py +2 -0
  172. optimum/rbln/transformers/models/t5/modeling_t5.py +2 -2
  173. optimum/rbln/transformers/models/t5/t5_architecture.py +8 -1
  174. optimum/rbln/transformers/models/time_series_transformer/configuration_time_series_transformer.py +3 -3
  175. optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +20 -16
  176. optimum/rbln/transformers/models/time_series_transformer/time_series_transformers_architecture.py +7 -1
  177. optimum/rbln/transformers/models/vit/modeling_vit.py +19 -0
  178. optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec2.py +15 -3
  179. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +61 -8
  180. optimum/rbln/transformers/models/whisper/configuration_whisper.py +12 -13
  181. optimum/rbln/transformers/models/whisper/generation_whisper.py +62 -6
  182. optimum/rbln/transformers/models/whisper/modeling_whisper.py +30 -5
  183. optimum/rbln/transformers/models/xlm_roberta/__init__.py +2 -8
  184. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +43 -0
  185. optimum/rbln/transformers/utils/rbln_quantization.py +400 -75
  186. optimum/rbln/transformers/utils/rbln_runtime_wrapper.py +79 -0
  187. optimum/rbln/utils/deprecation.py +213 -0
  188. optimum/rbln/utils/hub.py +14 -3
  189. optimum/rbln/utils/runtime_utils.py +60 -18
  190. optimum/rbln/utils/submodule.py +31 -9
  191. {optimum_rbln-0.8.2a4.dist-info → optimum_rbln-0.9.3.dist-info}/METADATA +8 -7
  192. optimum_rbln-0.9.3.dist-info/RECORD +264 -0
  193. {optimum_rbln-0.8.2a4.dist-info → optimum_rbln-0.9.3.dist-info}/WHEEL +1 -1
  194. optimum_rbln-0.9.3.dist-info/entry_points.txt +2 -0
  195. optimum_rbln-0.8.2a4.dist-info/RECORD +0 -215
  196. {optimum_rbln-0.8.2a4.dist-info → optimum_rbln-0.9.3.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/transformers/models/midm/midm_architecture.py
@@ -123,7 +123,10 @@ class MidmAttention(DecoderOnlyAttention):
         self.split_size = self._original_mod.split_size
         self.num_key_value_heads = self._original_mod.num_heads

-    def projection(self, hidden_states) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    def projection(self, hidden_states, lora_int_id) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        if lora_int_id is not None:
+            raise NotImplementedError("LoRA is not supported for MidmAttention")
+
         query_states, key_states, value_states = self.c_attn(hidden_states).split(self.split_size, dim=2)
         return query_states, key_states, value_states

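The 0.9.x line threads a `lora_int_id` argument through the decoder-only attention stack (see the new `configuration_lora.py` and `lora_architecture.py` entries in the file list); backends such as Mi:dm that do not implement LoRA now reject it explicitly rather than ignoring it. Below is a minimal, self-contained sketch of that contract using a toy module; it is not the optimum-rbln implementation, and the fused `c_attn` split just mirrors the GPT-2-style attention shown above.

```python
# Toy illustration of the projection(hidden_states, lora_int_id) contract.
from typing import Optional, Tuple

import torch
import torch.nn as nn


class ToyFusedQKVAttention(nn.Module):
    def __init__(self, hidden_size: int):
        super().__init__()
        self.split_size = hidden_size
        # Single fused projection producing Q, K and V, as in Mi:dm's c_attn.
        self.c_attn = nn.Linear(hidden_size, 3 * hidden_size)

    def projection(
        self, hidden_states: torch.Tensor, lora_int_id: Optional[torch.Tensor]
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        # Backends without per-request LoRA must reject a LoRA id explicitly.
        if lora_int_id is not None:
            raise NotImplementedError("LoRA is not supported for this attention")
        return self.c_attn(hidden_states).split(self.split_size, dim=2)


attn = ToyFusedQKVAttention(hidden_size=64)
q, k, v = attn.projection(torch.randn(1, 8, 64), lora_int_id=None)
print(q.shape, k.shape, v.shape)  # torch.Size([1, 8, 64]) for each of Q, K, V
```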
optimum/rbln/transformers/models/midm/modeling_midm.py
@@ -13,11 +13,13 @@
 # limitations under the License.

 import inspect
-from typing import Any, Callable
+from pathlib import Path
+from typing import Any, Callable, Dict, Optional, Union

 from transformers import AutoModelForCausalLM
 from transformers.generation.utils import GenerationMixin

+from ....configuration_utils import RBLNModelConfig
 from ....utils import logging
 from ..decoderonly import RBLNDecoderOnlyModelForCausalLM
 from .midm_architecture import MidmLMHeadModelWrapper
@@ -91,9 +93,45 @@ class RBLNMidmLMHeadModel(RBLNDecoderOnlyModelForCausalLM):
     _supports_cache_class = True

     @classmethod
-    def from_pretrained(cls, *args, **kwargs):
-        kwargs.setdefault("trust_remote_code", True)
-        return super().from_pretrained(*args, **kwargs)
+    def from_pretrained(
+        cls,
+        model_id: Union[str, Path],
+        *,
+        export: Optional[bool] = None,
+        rbln_config: Optional[Union[Dict, RBLNModelConfig]] = None,
+        trust_remote_code: Optional[bool] = None,
+        **kwargs: Any,
+    ):
+        """
+        The `from_pretrained()` function is utilized in its standard form as in the HuggingFace transformers library.
+        User can use this function to load a pre-trained model from the HuggingFace library and convert it to a RBLN model to be run on RBLN NPUs.
+
+        Args:
+            model_id (Union[str, Path]): The model id of the pre-trained model to be loaded.
+                It can be downloaded from the HuggingFace model hub or a local path, or a model id of a compiled model using the RBLN Compiler.
+            export (Optional[bool]): A boolean flag to indicate whether the model should be compiled.
+                If None, it will be determined based on the existence of the compiled model files in the model_id.
+            rbln_config (Optional[Union[Dict, RBLNModelConfig]]): Configuration for RBLN model compilation and runtime.
+                This can be provided as a dictionary or an instance of the model's configuration class (e.g., `RBLNMidmLMHeadModelConfig` for Mi:dm models).
+                For detailed configuration options, see the specific model's configuration class documentation.
+            trust_remote_code (bool): Whether or not to trust the remote code when loading a model from the Hub.
+            kwargs: Additional keyword arguments. Arguments with the prefix `rbln_` are passed to rbln_config, while the remaining arguments are passed to the HuggingFace library.
+
+        Returns:
+            (RBLNModel): A RBLN model instance ready for inference on RBLN NPU devices.
+        """
+
+        if trust_remote_code is not None:
+            kwargs["trust_remote_code"] = trust_remote_code
+        elif "trust_remote_code" not in kwargs:
+            kwargs["trust_remote_code"] = True
+
+        return super().from_pretrained(
+            model_id=model_id,
+            export=export,
+            rbln_config=rbln_config,
+            **kwargs,
+        )

     def __getattr__(self, __name: str) -> Any:
         def redirect(func):
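The reworked signature keeps the familiar optimum-style loading flow while still defaulting `trust_remote_code` to True for Mi:dm. A hedged usage sketch follows; the checkpoint id and the `rbln_config` keys are illustrative only, and compilation requires the RBLN SDK and NPU hardware.

```python
# Minimal sketch of the from_pretrained() signature documented above.
from optimum.rbln import RBLNMidmLMHeadModel

model = RBLNMidmLMHeadModel.from_pretrained(
    "KT-AI/midm-bitext-S-7B-inst-v1",   # assumed Hub id; any Mi:dm checkpoint works
    export=True,                         # compile from the HuggingFace weights
    rbln_config={"batch_size": 1, "max_seq_len": 4096},  # assumed config keys
    # trust_remote_code defaults to True for Mi:dm, so it can be omitted
)
model.save_pretrained("midm-rbln")       # reload later with export=False
```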
optimum/rbln/transformers/models/mistral/__init__.py
@@ -12,5 +12,5 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from .configuration_mistral import RBLNMistralForCausalLMConfig
-from .modeling_mistral import RBLNMistralForCausalLM
+from .configuration_mistral import RBLNMistralForCausalLMConfig, RBLNMistralModelConfig
+from .modeling_mistral import RBLNMistralForCausalLM, RBLNMistralModel
optimum/rbln/transformers/models/mistral/configuration_mistral.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from ..decoderonly.configuration_decoderonly import RBLNDecoderOnlyModelForCausalLMConfig
+from ..decoderonly.configuration_decoderonly import RBLNDecoderOnlyModelConfig, RBLNDecoderOnlyModelForCausalLMConfig


 class RBLNMistralForCausalLMConfig(RBLNDecoderOnlyModelForCausalLMConfig):
@@ -40,3 +40,11 @@ class RBLNMistralForCausalLMConfig(RBLNDecoderOnlyModelForCausalLMConfig):
     )
     ```
     """
+
+
+class RBLNMistralModelConfig(RBLNDecoderOnlyModelConfig):
+    """
+    Configuration class for RBLN Mistral models.
+
+    This class is an alias of RBLNDecoderOnlyModelConfig.
+    """
optimum/rbln/transformers/models/mistral/mistral_architecture.py
@@ -15,5 +15,5 @@
 from ..decoderonly.decoderonly_architecture import DecoderOnlyWrapper


-class MistralForCausalLMWrapper(DecoderOnlyWrapper):
+class MistralWrapper(DecoderOnlyWrapper):
     pass
optimum/rbln/transformers/models/mistral/modeling_mistral.py
@@ -15,8 +15,12 @@
 from transformers import PretrainedConfig

 from ....utils import logging
-from ...models.decoderonly import RBLNDecoderOnlyModelForCausalLM, RBLNDecoderOnlyModelForCausalLMConfig
-from .mistral_architecture import MistralForCausalLMWrapper
+from ...models.decoderonly import (
+    RBLNDecoderOnlyModel,
+    RBLNDecoderOnlyModelForCausalLM,
+    RBLNDecoderOnlyModelForCausalLMConfig,
+)
+from .mistral_architecture import MistralWrapper


 logger = logging.get_logger(__name__)
@@ -79,7 +83,26 @@ class RBLNMistralForCausalLM(RBLNDecoderOnlyModelForCausalLM):
     ```
     """

-    _decoder_wrapper_cls = MistralForCausalLMWrapper
+    _decoder_wrapper_cls = MistralWrapper
+
+    @classmethod
+    def _update_sliding_window_config(
+        cls, model_config: PretrainedConfig, rbln_config: RBLNDecoderOnlyModelForCausalLMConfig
+    ):
+        rbln_config.cache_impl = "sliding_window"
+        rbln_config.sliding_window = model_config.sliding_window
+        rbln_config.sliding_window_layers = list(range(model_config.num_hidden_layers))
+
+        return rbln_config
+
+
+class RBLNMistralModel(RBLNDecoderOnlyModel):
+    """
+    The Mistral Model transformer without a language modeling head.
+    This model inherits from [`RBLNDecoderOnlyModel`]. Check the superclass documentation for the generic methods the library implements for all its models.
+    """
+
+    _decoder_wrapper_cls = MistralWrapper

     @classmethod
     def _update_sliding_window_config(
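For checkpoints that enable sliding-window attention, `_update_sliding_window_config` switches the cache implementation at compile time, so no extra user configuration is needed. A hedged loading sketch for the two Mistral entry points follows; the checkpoint id and config keys are illustrative, and compilation requires the RBLN SDK and NPU hardware.

```python
# Minimal sketch, assuming the decoder-only config accepts batch_size/max_seq_len
# as in other RBLN causal-LM configs.
from optimum.rbln import RBLNMistralForCausalLM, RBLNMistralModel

# Causal LM: sliding-window settings are derived from the HF config at compile time.
lm = RBLNMistralForCausalLM.from_pretrained(
    "mistralai/Mistral-7B-v0.1",  # assumed checkpoint (uses sliding-window attention)
    export=True,
    rbln_config={"batch_size": 1, "max_seq_len": 8192},
)

# New in 0.9.x: the backbone without the LM head, e.g. for hidden-state extraction.
backbone = RBLNMistralModel.from_pretrained("mistralai/Mistral-7B-v0.1", export=True)
```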
optimum/rbln/transformers/models/opt/__init__.py
@@ -12,5 +12,5 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from .configuration_opt import RBLNOPTForCausalLMConfig
-from .modeling_opt import RBLNOPTForCausalLM
+from .configuration_opt import RBLNOPTForCausalLMConfig, RBLNOPTModelConfig
+from .modeling_opt import RBLNOPTForCausalLM, RBLNOPTModel
optimum/rbln/transformers/models/opt/configuration_opt.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from ..decoderonly.configuration_decoderonly import RBLNDecoderOnlyModelForCausalLMConfig
+from ..decoderonly.configuration_decoderonly import RBLNDecoderOnlyModelConfig, RBLNDecoderOnlyModelForCausalLMConfig


 class RBLNOPTForCausalLMConfig(RBLNDecoderOnlyModelForCausalLMConfig):
@@ -20,3 +20,10 @@ class RBLNOPTForCausalLMConfig(RBLNDecoderOnlyModelForCausalLMConfig):
     Configuration class for OPT causal language model.
     Inherits from RBLNDecoderOnlyModelForCausalLMConfig with no additional parameters.
     """
+
+
+class RBLNOPTModelConfig(RBLNDecoderOnlyModelConfig):
+    """
+    Configuration class for OPT model.
+    Inherits from RBLNDecoderOnlyModelConfig with no additional parameters.
+    """
optimum/rbln/transformers/models/opt/modeling_opt.py
@@ -16,7 +16,7 @@ import torch.nn as nn
 from transformers import PreTrainedModel

 from ....utils import logging
-from ...models.decoderonly import RBLNDecoderOnlyModelForCausalLM
+from ...models.decoderonly import RBLNDecoderOnlyModel, RBLNDecoderOnlyModelForCausalLM
 from ...models.decoderonly.configuration_decoderonly import RBLNDecoderOnlyModelForCausalLMConfig
 from .opt_architecture import OPTWrapper

@@ -69,22 +69,34 @@ class RBLNOPTForCausalLM(RBLNDecoderOnlyModelForCausalLM):
         return layer

     @classmethod
-    def wrap_model_if_needed(cls, model: PreTrainedModel, rbln_config: RBLNDecoderOnlyModelForCausalLMConfig):
-        wrapper_cfg = {
-            "max_seq_len": rbln_config.max_seq_len,
-            "attn_impl": rbln_config.attn_impl,
-            "kvcache_partition_len": rbln_config.kvcache_partition_len,
-            "kvcache_block_size": rbln_config.kvcache_block_size,
-            "use_rotary_emb": cls._use_rotary_emb,
-            "use_attention_mask": rbln_config.use_attention_mask,
-            "use_position_ids": rbln_config.use_position_ids,
-            "use_inputs_embeds": rbln_config.use_inputs_embeds,
-            "cache_impl": rbln_config.cache_impl,
-            "sliding_window": rbln_config.sliding_window,
-            "sliding_window_layers": rbln_config.sliding_window_layers,
-        }
-
+    def _wrap_model_if_needed(cls, model: PreTrainedModel, rbln_config: RBLNDecoderOnlyModelForCausalLMConfig):
         for i in range(len(model.model.decoder.layers)):
             model.model.decoder.layers[i] = cls.modify_opt_decoder_layer(model.model.decoder.layers[i])

-        return cls._decoder_wrapper_cls(model, **wrapper_cfg).eval()
+        return cls._decoder_wrapper_cls(model, rbln_config=rbln_config, use_rotary_emb=cls._use_rotary_emb).eval()
+
+
+class RBLNOPTModel(RBLNDecoderOnlyModel):
+    """
+    The OPT Model transformer without a language modeling head.
+    This model inherits from [`RBLNDecoderOnlyModel`]. Check the superclass documentation for the generic methods the library implements for all its models.
+    """
+
+    _decoder_wrapper_cls = OPTWrapper
+    _use_rotary_emb = False
+
+    def modify_opt_decoder_layer(layer):
+        mlp = MLP(layer.fc1, layer.fc2, layer.activation_fn)
+        layer.mlp = mlp
+        del layer.fc1
+        del layer.fc2
+        del layer.activation_fn
+
+        return layer
+
+    @classmethod
+    def _wrap_model_if_needed(cls, model: PreTrainedModel, rbln_config: RBLNDecoderOnlyModelForCausalLMConfig):
+        for i in range(len(model.decoder.layers)):
+            model.decoder.layers[i] = cls.modify_opt_decoder_layer(model.decoder.layers[i])
+
+        return cls._decoder_wrapper_cls(model, rbln_config=rbln_config, use_rotary_emb=cls._use_rotary_emb).eval()
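The layer rewrite performed by `modify_opt_decoder_layer` exists because OPT keeps `fc1`, `fc2`, and `activation_fn` directly on the decoder layer, while the shared decoder-only wrapper expects a single `.mlp` submodule. The following is an illustrative toy sketch of that restructuring; `ToyMLP` stands in for the library's `MLP` helper, whose definition is not shown in this diff.

```python
import torch
import torch.nn as nn


class ToyMLP(nn.Module):
    def __init__(self, fc1: nn.Linear, fc2: nn.Linear, act: nn.Module):
        super().__init__()
        self.fc1, self.fc2, self.act = fc1, fc2, act

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.fc2(self.act(self.fc1(x)))


class ToyOPTLayer(nn.Module):
    def __init__(self, d: int):
        super().__init__()
        self.fc1 = nn.Linear(d, 4 * d)
        self.fc2 = nn.Linear(4 * d, d)
        self.activation_fn = nn.ReLU()


layer = ToyOPTLayer(d=16)
layer.mlp = ToyMLP(layer.fc1, layer.fc2, layer.activation_fn)  # bundle as one module
del layer.fc1, layer.fc2, layer.activation_fn                  # drop the loose references
print(layer.mlp(torch.randn(2, 16)).shape)                     # torch.Size([2, 16])
```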
optimum/rbln/transformers/models/opt/opt_architecture.py
@@ -40,11 +40,11 @@ class OPTWrapper(DecoderOnlyWrapper):
     def get_rbln_model_class(self):
         return OPTModel

-    def get_model_layer(self, causal_lm: "OPTForCausalLM"):
-        return causal_lm.model.decoder
+    def get_model_layer(self, model: "OPTForCausalLM"):
+        return model.model.decoder if self.is_causal_lm else model.decoder

-    def get_decoder_layers(self, causal_lm: "OPTForCausalLM"):
-        return causal_lm.model.decoder.layers
+    def get_decoder_layers(self, model: "OPTForCausalLM"):
+        return model.model.decoder.layers if self.is_causal_lm else model.decoder.layers


 class OPTAttention(DecoderOnlyAttention):
optimum/rbln/transformers/models/pegasus/__init__.py (new file)
@@ -0,0 +1,17 @@
+# Copyright 2025 Rebellions Inc. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ....ops import paged_attn_decode, paged_causal_attn_decode
+from .configuration_pegasus import RBLNPegasusForConditionalGenerationConfig, RBLNPegasusModelConfig
+from .modeling_pegasus import RBLNPegasusForConditionalGeneration, RBLNPegasusModel
optimum/rbln/transformers/models/pegasus/configuration_pegasus.py (new file)
@@ -0,0 +1,38 @@
+# Copyright 2025 Rebellions Inc. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ...configuration_generic import RBLNTransformerEncoderForFeatureExtractionConfig
+from ..seq2seq import RBLNModelForSeq2SeqLMConfig
+
+
+class RBLNPegasusModelConfig(RBLNTransformerEncoderForFeatureExtractionConfig):
+    """
+    Configuration class for RBLNPegasusModel.
+
+    This configuration class stores the configuration parameters specific to
+    RBLN-optimized PEGASUS models for feature extraction tasks.
+    """
+
+    rbln_model_input_names = ["input_ids", "attention_mask", "decoder_input_ids", "decoder_attention_mask"]
+
+
+class RBLNPegasusForConditionalGenerationConfig(RBLNModelForSeq2SeqLMConfig):
+    """
+    Configuration class for RBLNPegasusForConditionalGeneration.
+
+    This configuration class stores the configuration parameters specific to
+    RBLN-optimized PEGASUS models for conditional text generation tasks.
+    """
+
+    support_paged_attention = True
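A hedged construction sketch for the new config classes follows. The import path matches the new `pegasus/__init__.py` above; the constructor keywords are assumptions inferred from the attributes `modeling_pegasus.py` reads (`enc_max_seq_len`, `use_attention_mask`), so check `RBLNModelForSeq2SeqLMConfig` for the authoritative parameter list. The RBLN SDK must be installed for the import to succeed.

```python
# Sketch only: the keyword names below are assumptions, not confirmed by this diff.
from optimum.rbln.transformers.models.pegasus import RBLNPegasusForConditionalGenerationConfig

gen_config = RBLNPegasusForConditionalGenerationConfig(
    batch_size=1,          # assumed: common to RBLN model configs
    enc_max_seq_len=1024,  # assumed: read back as rbln_config.enc_max_seq_len when wrapping
)
```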
optimum/rbln/transformers/models/pegasus/modeling_pegasus.py (new file)
@@ -0,0 +1,71 @@
+# Copyright 2025 Rebellions Inc. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+from typing import TYPE_CHECKING, Any, Callable
+
+from transformers import PegasusForConditionalGeneration, PreTrainedModel
+
+from ....utils.logging import get_logger
+from ...modeling_generic import RBLNTransformerEncoderForFeatureExtraction
+from ...models.seq2seq import RBLNModelForSeq2SeqLM
+from .configuration_pegasus import RBLNPegasusForConditionalGenerationConfig
+from .pegasus_architecture import PegasusWrapper
+
+
+logger = get_logger()
+
+
+if TYPE_CHECKING:
+    from transformers import PreTrainedModel
+
+
+class RBLNPegasusModel(RBLNTransformerEncoderForFeatureExtraction):
+    """
+    RBLN optimized PEGASUS model for feature extraction tasks.
+
+    This class provides hardware-accelerated inference for PEGASUS encoder models
+    on RBLN devices, optimized for feature extraction use cases.
+    """
+
+    rbln_model_input_names = ["input_ids", "attention_mask", "decoder_input_ids", "decoder_attention_mask"]
+
+
+class RBLNPegasusForConditionalGeneration(RBLNModelForSeq2SeqLM):
+    """
+    RBLN optimized PEGASUS model for conditional text generation tasks.
+
+    This class provides hardware-accelerated inference for PEGASUS models
+    on RBLN devices, supporting sequence-to-sequence generation tasks
+    such as summarization, translation, and text generation.
+    """
+
+    support_causal_attn = True
+
+    @classmethod
+    def _wrap_model_if_needed(self, model: "PreTrainedModel", rbln_config: RBLNPegasusForConditionalGenerationConfig):
+        return PegasusWrapper(
+            model, enc_max_seq_len=rbln_config.enc_max_seq_len, use_attention_mask=rbln_config.use_attention_mask
+        )
+
+    def __getattr__(self, __name: str) -> Any:
+        def redirect(func):
+            return lambda *pargs, **kwargs: func(self, *pargs, **kwargs)
+
+        val = getattr(PegasusForConditionalGeneration, __name)
+
+        if isinstance(val, Callable) and "self" in set(inspect.signature(val).parameters):
+            return redirect(val)
+
+        return val
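A hedged end-to-end sketch for the new PEGASUS support follows, assuming `RBLNPegasusForConditionalGeneration` is re-exported at the package root like the other RBLN model classes. The checkpoint id and `rbln_`-prefixed options are illustrative, and generation requires RBLN hardware at runtime.

```python
from transformers import AutoTokenizer
from optimum.rbln import RBLNPegasusForConditionalGeneration

model_id = "google/pegasus-xsum"  # assumed summarization checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = RBLNPegasusForConditionalGeneration.from_pretrained(
    model_id,
    export=True,
    rbln_batch_size=1,  # rbln_-prefixed kwargs are folded into rbln_config
)

inputs = tokenizer("PEGASUS was pre-trained with gap-sentence generation ...", return_tensors="pt")
summary_ids = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.batch_decode(summary_ids, skip_special_tokens=True))
```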
optimum/rbln/transformers/models/pegasus/pegasus_architecture.py (new file)
@@ -0,0 +1,161 @@
+# Copyright 2025 Rebellions Inc. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Tuple
+
+import torch
+from torch import nn
+from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask
+from transformers.utils import logging
+
+from ..seq2seq.seq2seq_architecture import (
+    Seq2SeqCrossAttention,
+    Seq2SeqDecoder,
+    Seq2SeqDecoderLayer,
+    Seq2SeqDecoderWrapper,
+    Seq2SeqEncoderWrapper,
+    Seq2SeqForConditionalGeneration,
+    Seq2SeqSelfAttention,
+)
+
+
+logger = logging.get_logger(__name__)
+
+
+class PegasusWrapper:
+    def __init__(self, model: nn.Module, enc_max_seq_len: int, use_attention_mask: bool):
+        self.encoder = Seq2SeqEncoderWrapper(model, enc_max_seq_len)
+        self.decoder = PegasusDecoderWrapper(model, use_attention_mask=use_attention_mask)
+
+
+class PegasusDecoderWrapper(Seq2SeqDecoderWrapper):
+    def convert_to_rbln_conditional_generation(self, model: nn.Module):
+        new_layers = []
+        for layer in model.get_decoder().layers:
+            self_attn = PegasusSelfAttention(layer.self_attn, use_attention_mask=self.use_attention_mask)
+            cross_attn = PegasusCrossAttention(layer.encoder_attn)
+            new_layers.append(PegasusDecoderLayer(layer, self_attn, cross_attn))
+
+        decoder_model = PegasusDecoder(model.get_decoder(), new_layers)
+        new_model = PegasusForConditionalGeneration(model, decoder_model)
+
+        return new_model
+
+
+class PegasusForConditionalGeneration(Seq2SeqForConditionalGeneration):
+    pass
+
+
+class PegasusDecoder(Seq2SeqDecoder):
+    has_pos_emb = True
+
+    def __post_init__(self):
+        self.embed_positions = self._original_mod.embed_positions
+        self.embed_scale = getattr(self._original_mod, "embed_scale", None)
+        self.final_layer_norm = getattr(self._original_mod, "layer_norm", None)
+
+    def prepare_attn_mask(self, attention_mask, encoder_attention_mask, **kwargs):
+        if attention_mask is not None:
+            attention_mask = attention_mask[:, None, None, :]
+        encoder_attention_mask = _prepare_4d_attention_mask(encoder_attention_mask, torch.float32, tgt_len=1)
+
+        return attention_mask, encoder_attention_mask
+
+    def apply_position_embedding(self, inputs_embeds, cache_position):
+        hidden_all = []
+        for i in range(inputs_embeds.shape[0]):
+            positions_idx = cache_position[i]
+            position_weight = self.embed_positions.weight
+            position = position_weight[positions_idx]
+            batch_hidden = position + inputs_embeds[i]
+            hidden_all.append(batch_hidden)
+        hidden_states = torch.stack(hidden_all, dim=0)
+
+        return hidden_states
+
+    def get_embedding(self):
+        if self.embed_scale is not None:
+            return lambda x: self.embed_tokens(x) * self.embed_scale
+        else:
+            return self.embed_tokens
+
+
+class PegasusLayerFF(nn.Module):
+    def __init__(self, decoder_layer):
+        super().__init__()
+        self.fc1 = decoder_layer.fc1
+        self.fc2 = decoder_layer.fc2
+        self.activation_fn = decoder_layer.activation_fn
+        self.layer_norm = decoder_layer.final_layer_norm
+
+    def forward(self, hidden_states):
+        # Residual Connection
+        residual = hidden_states
+        hidden_states = self.layer_norm(hidden_states)
+        hidden_states = self.activation_fn(self.fc1(hidden_states))
+        hidden_states = self.fc2(hidden_states)
+        hidden_states = residual + hidden_states
+        return hidden_states
+
+
+class PegasusDecoderLayer(Seq2SeqDecoderLayer):
+    def __post_init__(self):
+        self.self_attn_layer_norm = self._original_mod.self_attn_layer_norm
+        self.encoder_attn = self._original_mod.encoder_attn
+        self.encoder_attn_layer_norm = self._original_mod.encoder_attn_layer_norm
+        self.ff_layer = PegasusLayerFF(self._original_mod)
+
+    def pre_self_attn_layer_norm(self, hidden_states):
+        return self.self_attn_layer_norm(hidden_states)
+
+    def post_self_attn_layer_norm(self, hidden_states):
+        return hidden_states
+
+    def pre_cross_attn_layer_norm(self, hidden_states):
+        return self.encoder_attn_layer_norm(hidden_states)
+
+    def post_cross_attn_layer_norm(self, hidden_states):
+        return hidden_states
+
+
+class PegasusSelfAttention(Seq2SeqSelfAttention):
+    def __post_init__(self, use_attention_mask: bool = True):
+        self.q_proj = self._original_mod.q_proj
+        self.k_proj = self._original_mod.k_proj
+        self.v_proj = self._original_mod.v_proj
+        self.out_proj = self._original_mod.out_proj
+        self.num_heads = self._original_mod.num_heads
+        self.head_dim = self._original_mod.embed_dim // self._original_mod.num_heads
+        self.scaling = self.head_dim**-0.5
+        if use_attention_mask:
+            self.attn_decode = torch.ops.rbln_custom_ops.paged_attn_decode
+        else:
+            self.attn_decode = torch.ops.rbln_custom_ops.paged_causal_attn_decode
+
+    def projection(self, hidden_states) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        query_states = self.q_proj(hidden_states) * self.scaling
+        key_states = self.k_proj(hidden_states)
+        value_states = self.v_proj(hidden_states)
+        return query_states, key_states, value_states
+
+
+class PegasusCrossAttention(Seq2SeqCrossAttention):
+    def __post_init__(self):
+        self.q_proj = self._original_mod.q_proj
+        self.k_proj = self._original_mod.k_proj
+        self.v_proj = self._original_mod.v_proj
+        self.out_proj = self._original_mod.out_proj
+        self.num_heads = self._original_mod.num_heads
+        self.head_dim = self._original_mod.embed_dim // self._original_mod.num_heads
+        self.embed_dim = self._original_mod.embed_dim
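To clarify the mask handling in `PegasusDecoder.prepare_attn_mask` above: the decoder-side padding mask is only broadcast to `[batch, 1, 1, src_len]`, while the encoder (cross-attention) mask is expanded into an additive 4-D float mask with `tgt_len=1`, matching single-step decoding. The following is a small runnable illustration using the same `transformers` utility; the mask values are made up for the example.

```python
import torch
from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask

attention_mask = torch.tensor([[1, 1, 1, 0]])          # decoder-side padding mask
encoder_attention_mask = torch.tensor([[1, 1, 0, 0]])  # encoder-side padding mask

self_attn_mask = attention_mask[:, None, None, :]
cross_attn_mask = _prepare_4d_attention_mask(encoder_attention_mask, torch.float32, tgt_len=1)

print(self_attn_mask.shape)   # torch.Size([1, 1, 1, 4])
print(cross_attn_mask.shape)  # torch.Size([1, 1, 1, 4]); 0 where attended, dtype min where masked
```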
optimum/rbln/transformers/models/phi/__init__.py
@@ -12,5 +12,5 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from .configuration_phi import RBLNPhiForCausalLMConfig
-from .modeling_phi import RBLNPhiForCausalLM
+from .configuration_phi import RBLNPhiForCausalLMConfig, RBLNPhiModelConfig
+from .modeling_phi import RBLNPhiForCausalLM, RBLNPhiModel
optimum/rbln/transformers/models/phi/configuration_phi.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from ..decoderonly.configuration_decoderonly import RBLNDecoderOnlyModelForCausalLMConfig
+from ..decoderonly.configuration_decoderonly import RBLNDecoderOnlyModelConfig, RBLNDecoderOnlyModelForCausalLMConfig


 class RBLNPhiForCausalLMConfig(RBLNDecoderOnlyModelForCausalLMConfig):
@@ -40,3 +40,11 @@ class RBLNPhiForCausalLMConfig(RBLNDecoderOnlyModelForCausalLMConfig):
     )
     ```
     """
+
+
+class RBLNPhiModelConfig(RBLNDecoderOnlyModelConfig):
+    """
+    Configuration class for RBLN Phi models.
+
+    This class is an alias of RBLNDecoderOnlyModelConfig.
+    """
optimum/rbln/transformers/models/phi/modeling_phi.py
@@ -13,7 +13,7 @@
 # limitations under the License.

 from ....utils import logging
-from ...models.decoderonly import RBLNDecoderOnlyModelForCausalLM
+from ...models.decoderonly import RBLNDecoderOnlyModel, RBLNDecoderOnlyModelForCausalLM
 from .phi_architecture import PhiWrapper


@@ -81,3 +81,12 @@ class RBLNPhiForCausalLM(RBLNDecoderOnlyModelForCausalLM):
     """

     _decoder_wrapper_cls = PhiWrapper
+
+
+class RBLNPhiModel(RBLNDecoderOnlyModel):
+    """
+    The Phi Model transformer without a language modeling head.
+    This model inherits from [`RBLNDecoderOnlyModel`]. Check the superclass documentation for the generic methods the library implements for all its models.
+    """
+
+    _decoder_wrapper_cls = PhiWrapper
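As with Mistral and OPT above, Phi now ships a headless `RBLNPhiModel` alongside the causal LM. A hedged compile-and-save sketch follows; the checkpoint id and `rbln_max_seq_len` value are placeholders, `rbln_`-prefixed kwargs are folded into `rbln_config` as described in the Mi:dm docstring earlier, and the class is assumed to be re-exported from the package root like the other RBLN models.

```python
from optimum.rbln import RBLNPhiModel

model = RBLNPhiModel.from_pretrained(
    "microsoft/phi-2",      # assumed checkpoint
    export=True,            # compile the backbone (no LM head) for the NPU
    rbln_max_seq_len=2048,  # rbln_-prefixed kwargs populate rbln_config
)
model.save_pretrained("phi-2-rbln")  # later: from_pretrained("phi-2-rbln", export=False)
```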
optimum/rbln/transformers/models/phi/phi_architecture.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from typing import TYPE_CHECKING, Optional, Tuple
+from typing import TYPE_CHECKING, Optional, Tuple, Union

 import torch
 from transformers import PhiForCausalLM
@@ -27,7 +27,7 @@ from ..decoderonly.decoderonly_architecture import (


 if TYPE_CHECKING:
-    from transformers import PhiForCausalLM
+    from transformers import PhiForCausalLM, PhiModel


 class PhiWrapper(DecoderOnlyWrapper):
@@ -40,11 +40,11 @@ class PhiWrapper(DecoderOnlyWrapper):
     def get_rbln_model_class(self):
         return PhiModel

-    def get_model_layer(self, causal_lm: "PhiForCausalLM"):
-        return causal_lm.model
+    def get_model_layer(self, model: Union["PhiForCausalLM", "PhiModel"]):
+        return model.model if self.is_causal_lm else model

-    def get_decoder_layers(self, causal_lm: "PhiForCausalLM"):
-        return causal_lm.model.layers
+    def get_decoder_layers(self, model: Union["PhiForCausalLM", "PhiModel"]):
+        return model.model.layers if self.is_causal_lm else model.layers


 class PhiAttention(DecoderOnlyAttention):
@@ -56,7 +56,10 @@ class PhiAttention(DecoderOnlyAttention):
         self.qk_layernorm = self._original_mod.qk_layernorm
         self.rotary_ndims = self._original_mod.rotary_ndims

-    def projection(self, hidden_states) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    def projection(self, hidden_states, lora_int_id) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        if lora_int_id is not None:
+            raise NotImplementedError("LoRA is not supported for PhiAttention")
+
         query_states = self.q_proj(hidden_states)
         key_states = self.k_proj(hidden_states)
         value_states = self.v_proj(hidden_states)
@@ -84,6 +87,7 @@ class PhiLayer(DecoderOnlyLayer):
         cos: Optional[torch.Tensor] = None,
         sin: Optional[torch.Tensor] = None,
         block_tables: Optional[torch.Tensor] = None,
+        lora_int_id: Optional[torch.Tensor] = None,
     ):
         residual = hidden_states