optimum-rbln 0.8.2a4__py3-none-any.whl → 0.9.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (196)
  1. optimum/rbln/__init__.py +108 -9
  2. optimum/rbln/__version__.py +16 -3
  3. optimum/rbln/cli.py +660 -0
  4. optimum/rbln/configuration_utils.py +156 -43
  5. optimum/rbln/diffusers/__init__.py +19 -0
  6. optimum/rbln/diffusers/configurations/__init__.py +3 -0
  7. optimum/rbln/diffusers/configurations/models/__init__.py +2 -0
  8. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +3 -3
  9. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_cosmos.py +1 -1
  10. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_temporal_decoder.py +67 -0
  11. optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +3 -3
  12. optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +4 -4
  13. optimum/rbln/diffusers/configurations/models/configuration_transformer_cosmos.py +9 -4
  14. optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +9 -4
  15. optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +3 -3
  16. optimum/rbln/diffusers/configurations/models/configuration_unet_spatio_temporal_condition.py +59 -0
  17. optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +3 -3
  18. optimum/rbln/diffusers/configurations/pipelines/__init__.py +3 -0
  19. optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +35 -19
  20. optimum/rbln/diffusers/configurations/pipelines/configuration_cosmos.py +14 -11
  21. optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +30 -20
  22. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +13 -9
  23. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +17 -13
  24. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +17 -10
  25. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_video_diffusion.py +114 -0
  26. optimum/rbln/diffusers/modeling_diffusers.py +30 -14
  27. optimum/rbln/diffusers/models/__init__.py +4 -0
  28. optimum/rbln/diffusers/models/autoencoders/__init__.py +1 -0
  29. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +31 -3
  30. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py +31 -6
  31. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +275 -0
  32. optimum/rbln/diffusers/models/autoencoders/vae.py +27 -8
  33. optimum/rbln/diffusers/models/autoencoders/vq_model.py +31 -3
  34. optimum/rbln/diffusers/models/controlnet.py +16 -1
  35. optimum/rbln/diffusers/models/transformers/prior_transformer.py +17 -3
  36. optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +25 -2
  37. optimum/rbln/diffusers/models/transformers/transformer_sd3.py +23 -2
  38. optimum/rbln/diffusers/models/unets/__init__.py +1 -0
  39. optimum/rbln/diffusers/models/unets/unet_2d_condition.py +23 -4
  40. optimum/rbln/diffusers/models/unets/unet_spatio_temporal_condition.py +201 -0
  41. optimum/rbln/diffusers/pipelines/__init__.py +15 -5
  42. optimum/rbln/diffusers/pipelines/auto_pipeline.py +307 -0
  43. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +20 -0
  44. optimum/rbln/diffusers/pipelines/cosmos/configuration_cosmos_guardrail.py +19 -16
  45. optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +14 -18
  46. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +31 -1
  47. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +31 -1
  48. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +1 -6
  49. optimum/rbln/diffusers/pipelines/stable_video_diffusion/__init__.py +15 -0
  50. optimum/rbln/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +46 -0
  51. optimum/rbln/modeling.py +48 -21
  52. optimum/rbln/modeling_base.py +99 -22
  53. optimum/rbln/ops/attn.py +158 -0
  54. optimum/rbln/ops/flash_attn.py +166 -0
  55. optimum/rbln/ops/kv_cache_update.py +5 -0
  56. optimum/rbln/ops/linear.py +7 -0
  57. optimum/rbln/transformers/__init__.py +92 -0
  58. optimum/rbln/transformers/configuration_generic.py +7 -32
  59. optimum/rbln/transformers/modeling_attention_utils.py +385 -0
  60. optimum/rbln/transformers/modeling_generic.py +48 -65
  61. optimum/rbln/transformers/modeling_outputs.py +37 -0
  62. optimum/rbln/transformers/models/__init__.py +91 -30
  63. optimum/rbln/transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py +28 -2
  64. optimum/rbln/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +68 -5
  65. optimum/rbln/transformers/models/auto/__init__.py +2 -0
  66. optimum/rbln/transformers/models/auto/auto_factory.py +92 -17
  67. optimum/rbln/transformers/models/auto/modeling_auto.py +45 -0
  68. optimum/rbln/transformers/models/bart/bart_architecture.py +1 -3
  69. optimum/rbln/transformers/models/bart/configuration_bart.py +2 -0
  70. optimum/rbln/transformers/models/bart/modeling_bart.py +23 -2
  71. optimum/rbln/transformers/models/bert/bert_architecture.py +16 -0
  72. optimum/rbln/transformers/models/bert/modeling_bert.py +93 -4
  73. optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +42 -11
  74. optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +135 -44
  75. optimum/rbln/transformers/models/clip/configuration_clip.py +10 -7
  76. optimum/rbln/transformers/models/clip/modeling_clip.py +67 -6
  77. optimum/rbln/transformers/models/colpali/colpali_architecture.py +3 -6
  78. optimum/rbln/transformers/models/colpali/configuration_colpali.py +37 -21
  79. optimum/rbln/transformers/models/colpali/modeling_colpali.py +82 -104
  80. optimum/rbln/transformers/models/colqwen2/__init__.py +2 -0
  81. optimum/rbln/transformers/models/colqwen2/colqwen2_architecture.py +233 -0
  82. optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +74 -0
  83. optimum/rbln/transformers/models/colqwen2/modeling_colqwen2.py +446 -0
  84. optimum/rbln/transformers/models/decoderonly/__init__.py +3 -2
  85. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +114 -37
  86. optimum/rbln/transformers/models/decoderonly/configuration_lora.py +411 -0
  87. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +318 -309
  88. optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +508 -0
  89. optimum/rbln/transformers/models/decoderonly/generation_decoderonly.py +119 -0
  90. optimum/rbln/transformers/models/decoderonly/lora_architecture.py +204 -0
  91. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +485 -905
  92. optimum/rbln/transformers/models/depth_anything/__init__.py +16 -0
  93. optimum/rbln/transformers/models/depth_anything/configuration_depth_anything.py +24 -0
  94. optimum/rbln/transformers/models/depth_anything/modeling_depth_anything.py +42 -0
  95. optimum/rbln/transformers/models/distilbert/modeling_distilbert.py +24 -0
  96. optimum/rbln/transformers/models/dpt/modeling_dpt.py +17 -0
  97. optimum/rbln/transformers/models/exaone/modeling_exaone.py +42 -4
  98. optimum/rbln/transformers/models/gemma/__init__.py +2 -2
  99. optimum/rbln/transformers/models/gemma/configuration_gemma.py +9 -1
  100. optimum/rbln/transformers/models/gemma/gemma_architecture.py +1 -4
  101. optimum/rbln/transformers/models/gemma/modeling_gemma.py +22 -1
  102. optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +49 -13
  103. optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +12 -2
  104. optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py +245 -0
  105. optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +201 -351
  106. optimum/rbln/transformers/models/gpt2/__init__.py +2 -2
  107. optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +31 -3
  108. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +10 -8
  109. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +18 -1
  110. optimum/rbln/transformers/models/grounding_dino/__init__.py +10 -0
  111. optimum/rbln/transformers/models/grounding_dino/configuration_grounding_dino.py +92 -0
  112. optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +599 -0
  113. optimum/rbln/transformers/models/grounding_dino/modeling_grounding_dino.py +1048 -0
  114. optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +35 -7
  115. optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +29 -32
  116. optimum/rbln/transformers/models/llama/__init__.py +2 -2
  117. optimum/rbln/transformers/models/llama/configuration_llama.py +9 -1
  118. optimum/rbln/transformers/models/llama/modeling_llama.py +22 -1
  119. optimum/rbln/transformers/models/llava/__init__.py +16 -0
  120. optimum/rbln/transformers/models/llava/configuration_llava.py +72 -0
  121. optimum/rbln/transformers/models/llava/modeling_llava.py +490 -0
  122. optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +15 -17
  123. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +234 -376
  124. optimum/rbln/transformers/models/midm/midm_architecture.py +4 -1
  125. optimum/rbln/transformers/models/midm/modeling_midm.py +42 -4
  126. optimum/rbln/transformers/models/mistral/__init__.py +2 -2
  127. optimum/rbln/transformers/models/mistral/configuration_mistral.py +9 -1
  128. optimum/rbln/transformers/models/mistral/mistral_architecture.py +1 -1
  129. optimum/rbln/transformers/models/mistral/modeling_mistral.py +26 -3
  130. optimum/rbln/transformers/models/opt/__init__.py +2 -2
  131. optimum/rbln/transformers/models/opt/configuration_opt.py +8 -1
  132. optimum/rbln/transformers/models/opt/modeling_opt.py +29 -17
  133. optimum/rbln/transformers/models/opt/opt_architecture.py +4 -4
  134. optimum/rbln/transformers/models/pegasus/__init__.py +17 -0
  135. optimum/rbln/transformers/models/pegasus/configuration_pegasus.py +38 -0
  136. optimum/rbln/transformers/models/pegasus/modeling_pegasus.py +71 -0
  137. optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +161 -0
  138. optimum/rbln/transformers/models/phi/__init__.py +2 -2
  139. optimum/rbln/transformers/models/phi/configuration_phi.py +9 -1
  140. optimum/rbln/transformers/models/phi/modeling_phi.py +10 -1
  141. optimum/rbln/transformers/models/phi/phi_architecture.py +11 -7
  142. optimum/rbln/transformers/models/pixtral/__init__.py +16 -0
  143. optimum/rbln/transformers/models/pixtral/configuration_pixtral.py +43 -0
  144. optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +322 -0
  145. optimum/rbln/transformers/models/pixtral/pixtral_architecture.py +73 -0
  146. optimum/rbln/transformers/models/qwen2/__init__.py +2 -2
  147. optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +9 -1
  148. optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +27 -1
  149. optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +21 -6
  150. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +15 -22
  151. optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +28 -7
  152. optimum/rbln/transformers/models/qwen2_vl/__init__.py +19 -0
  153. optimum/rbln/transformers/models/qwen2_vl/configuration_qwen2_vl.py +88 -0
  154. optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py +513 -0
  155. optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +165 -0
  156. optimum/rbln/transformers/models/qwen3/configuration_qwen3.py +2 -2
  157. optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +86 -330
  158. optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +1 -245
  159. optimum/rbln/transformers/models/resnet/configuration_resnet.py +17 -0
  160. optimum/rbln/transformers/models/resnet/modeling_resnet.py +73 -0
  161. optimum/rbln/transformers/models/roberta/modeling_roberta.py +33 -0
  162. optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py +21 -16
  163. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +58 -13
  164. optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +2 -2
  165. optimum/rbln/transformers/models/siglip/__init__.py +2 -6
  166. optimum/rbln/transformers/models/siglip/configuration_siglip.py +1 -1
  167. optimum/rbln/transformers/models/siglip/modeling_siglip.py +21 -16
  168. optimum/rbln/transformers/models/swin/__init__.py +16 -0
  169. optimum/rbln/transformers/models/swin/configuration_swin.py +42 -0
  170. optimum/rbln/transformers/models/swin/modeling_swin.py +354 -0
  171. optimum/rbln/transformers/models/t5/configuration_t5.py +2 -0
  172. optimum/rbln/transformers/models/t5/modeling_t5.py +2 -2
  173. optimum/rbln/transformers/models/t5/t5_architecture.py +8 -1
  174. optimum/rbln/transformers/models/time_series_transformer/configuration_time_series_transformer.py +3 -3
  175. optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +20 -16
  176. optimum/rbln/transformers/models/time_series_transformer/time_series_transformers_architecture.py +7 -1
  177. optimum/rbln/transformers/models/vit/modeling_vit.py +19 -0
  178. optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec2.py +15 -3
  179. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +61 -8
  180. optimum/rbln/transformers/models/whisper/configuration_whisper.py +12 -13
  181. optimum/rbln/transformers/models/whisper/generation_whisper.py +62 -6
  182. optimum/rbln/transformers/models/whisper/modeling_whisper.py +30 -5
  183. optimum/rbln/transformers/models/xlm_roberta/__init__.py +2 -8
  184. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +43 -0
  185. optimum/rbln/transformers/utils/rbln_quantization.py +400 -75
  186. optimum/rbln/transformers/utils/rbln_runtime_wrapper.py +79 -0
  187. optimum/rbln/utils/deprecation.py +213 -0
  188. optimum/rbln/utils/hub.py +14 -3
  189. optimum/rbln/utils/runtime_utils.py +60 -18
  190. optimum/rbln/utils/submodule.py +31 -9
  191. {optimum_rbln-0.8.2a4.dist-info → optimum_rbln-0.9.3.dist-info}/METADATA +8 -7
  192. optimum_rbln-0.9.3.dist-info/RECORD +264 -0
  193. {optimum_rbln-0.8.2a4.dist-info → optimum_rbln-0.9.3.dist-info}/WHEEL +1 -1
  194. optimum_rbln-0.9.3.dist-info/entry_points.txt +2 -0
  195. optimum_rbln-0.8.2a4.dist-info/RECORD +0 -215
  196. {optimum_rbln-0.8.2a4.dist-info → optimum_rbln-0.9.3.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/transformers/models/colqwen2/modeling_colqwen2.py (new file)
@@ -0,0 +1,446 @@
+ # Copyright 2025 Rebellions Inc. All rights reserved.
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at:
+
+ # http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from typing import TYPE_CHECKING, Optional, Union
+
+ import torch
+ from transformers import (
+     PretrainedConfig,
+     PreTrainedModel,
+ )
+ from transformers.modeling_utils import no_init_weights
+ from transformers.models.colqwen2.modeling_colqwen2 import ColQwen2ForRetrievalOutput
+ from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
+     Qwen2_5_VLModel,
+     Qwen2_5_VLRotaryEmbedding,
+ )
+ from transformers.models.qwen2_vl.modeling_qwen2_vl import (
+     Qwen2VLModel,
+     Qwen2VLRotaryEmbedding,
+ )
+
+ from optimum.rbln.transformers.models.decoderonly.modeling_decoderonly import (
+     RBLNDecoderOnlyModel,
+ )
+
+ from .configuration_colqwen2 import (
+     RBLNColQwen2ForRetrievalConfig,
+ )
+
+
+ if TYPE_CHECKING:
+     from transformers import (
+         AutoFeatureExtractor,
+         AutoProcessor,
+         AutoTokenizer,
+         PretrainedConfig,
+     )
+
+ from .colqwen2_architecture import ColQwen2LanguageModelWrapper
+
+
+ class RBLNColQwen2ForRetrieval(RBLNDecoderOnlyModel):
+     """
+     The ColQwen Model transformer for document retrieval using vision-language models.
+     This model inherits from [`RBLNDecoderOnlyModel`]. Check the superclass documentation for the generic methods the library implements for all its models.
+
+     A class to convert and run pre-trained transformers based `ColQwen2ForRetrieval` model on RBLN devices.
+     It implements the methods to convert a pre-trained transformers `ColQwen2ForRetrieval` model into a RBLN transformer model by:
+
+     - transferring the checkpoint weights of the original into an optimized RBLN graph,
+     - compiling the resulting graph using the RBLN compiler.
+
+     **Configuration:**
+     This model uses [`RBLNColQwen2ForRetrievalConfig`] for configuration. When calling methods like `from_pretrained` or `from_model`,
+     the `rbln_config` parameter should be an instance of [`RBLNColQwen2ForRetrievalConfig`] or a dictionary conforming to its structure.
+
+     See the [`RBLNColQwen2ForRetrievalConfig`] class for all available configuration options.
+
+     Examples:
+         ```python
+         from optimum.rbln import RBLNColQwen2ForRetrieval
+
+         # Using a config dictionary
+         rbln_config = {
+             "visual": {
+                 "max_seq_lens": 6400,
+             },
+             "max_seq_len": 32_768,
+             "tensor_parallel_size": 4,
+             "device": [0, 1, 2, 3],
+             "output_hidden_states": False,
+         }
+         model = RBLNColQwen2ForRetrieval.from_pretrained(
+             "vidore/colqwen2-v1.0-hf",
+             export=True,
+             rbln_config=rbln_config
+         )
+
+         # Using a RBLNColQwen2ForRetrievalConfig instance (recommended for type checking)
+         from optimum.rbln import RBLNColQwen2ForRetrievalConfig
+
+         config = RBLNColQwen2ForRetrievalConfig(
+             visual={
+                 "max_seq_lens": 6400,
+                 "device": 0,
+             },
+             max_seq_len=32_768,
+             tensor_parallel_size=4,
+             device=[0, 1, 2, 3],
+             output_hidden_states=False,
+         )
+         model = RBLNColQwen2ForRetrieval.from_pretrained(
+             "vidore/colqwen2-v1.0-hf",
+             export=True,
+             rbln_config=config
+         )
+         ```
+     """
+
+     main_input_name = "inputs_embeds"
+     auto_model_class = None
+     _rbln_submodules = [
+         {"name": "visual"},
+     ]
+     _decoder_wrapper_cls = ColQwen2LanguageModelWrapper
+     _use_rotary_emb = False
+
+     def __post_init__(self, **kwargs):
+         self.config = self.config.vlm_config if hasattr(self.config, "vlm_config") else self.config
+
+         artifacts = torch.load(
+             self.model_save_dir / self.subfolder / "torch_artifacts.pth",
+             weights_only=False,
+         )
+         self.embed_tokens = self._create_embedding_layer()
+         self.embed_tokens.load_state_dict(artifacts["embed_tokens"])
+         self.visual = self.rbln_submodules[0]
+         self.prefill_runtime = self.model[0]
+         self.mrope_section = self.config.text_config.rope_scaling["mrope_section"]
+         self.is_colqwen2_5 = "qwen2_5_vl" in self.config.model_type
+
+         if self.is_colqwen2_5:
+             self.rotary_emb = Qwen2_5_VLRotaryEmbedding(self.config.text_config)
+         else:
+             self.rotary_emb = Qwen2VLRotaryEmbedding(self.config.text_config)
+         self.block_tables = torch.arange(self.rbln_config.kvcache_num_blocks, dtype=torch.int16)
+
+     @classmethod
+     def _reconstruct_model_if_needed(cls, model: "PreTrainedModel"):
+         if hasattr(model, "vlm"):
+             model.visual = model.vlm.visual
+             model.language_model = model.vlm.language_model
+
+         # FIXME: temporary fix for ColQwen2ForRetrieval dtype issue
+         return model.to(torch.float32)
+
+     def _create_embedding_layer(self):
+         with no_init_weights():
+             embed_tokens = torch.nn.Embedding(
+                 self.config.text_config.vocab_size,
+                 self.config.text_config.hidden_size,
+                 self.config.text_config.pad_token_id,
+             )
+         return embed_tokens
+
+     @classmethod
+     def get_input_info(
+         cls,
+         batch_size: int,
+         query_length: int,
+         rbln_config: RBLNColQwen2ForRetrievalConfig,
+         model_config: PretrainedConfig,
+     ):
+         text_config = model_config.text_config
+         input_info = super().get_input_info(
+             batch_size,
+             query_length,
+             rbln_config=rbln_config,
+             model_config=text_config,
+         )
+
+         pos_idx = 3
+         input_info.insert(
+             pos_idx,
+             (
+                 "position_emb",
+                 [
+                     2,
+                     batch_size,
+                     1,
+                     query_length,
+                     text_config.hidden_size // text_config.num_attention_heads,
+                 ],
+                 rbln_config.torch_dtype,
+             ),
+         )
+
+         # remove query postion from input_info
+         if "query_position" in input_info:
+             query_position = input_info.pop(4)
+             assert query_position[0] == "query_position", print(query_position[0], "is deleted.")
+         return input_info
+
+     @classmethod
+     def _update_rbln_config(
+         cls,
+         preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]] = None,
+         model: Optional["PreTrainedModel"] = None,
+         model_config: Optional["PretrainedConfig"] = None,
+         rbln_config: Optional[RBLNColQwen2ForRetrievalConfig] = None,
+     ) -> RBLNColQwen2ForRetrievalConfig:
+         model_config = model_config.vlm_config if hasattr(model_config, "vlm_config") else model_config
+         if rbln_config.output_hidden_states is None:
+             rbln_config.output_hidden_states = getattr(model_config.text_config, "output_hidden_states", False)
+
+         return super()._update_rbln_config(
+             preprocessors=preprocessors, model=model, model_config=model_config, rbln_config=rbln_config
+         )
+
+     def _get_position_embeddings(self, hidden_states, position_ids):
+         cos, sin = self.rotary_emb(hidden_states, position_ids)
+         mrope_section = self.mrope_section * 2
+         cos = torch.cat([m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))], dim=-1).unsqueeze(1)
+         sin = torch.cat([m[i % 3] for i, m in enumerate(sin.split(mrope_section, dim=-1))], dim=-1).unsqueeze(1)
+         return torch.stack([cos, sin])
+
+     def get_rope_index(self, *args, **kwargs):
+         if self.is_colqwen2_5:
+             return Qwen2_5_VLModel.get_rope_index(self, *args, **kwargs)
+         else:
+             return Qwen2VLModel.get_rope_index(self, *args, **kwargs)
+
+     def _preprocess_visual(
+         self,
+         input_ids: torch.LongTensor = None,
+         attention_mask: torch.Tensor = None,
+         pixel_values: torch.Tensor = None,
+         pixel_values_videos: torch.FloatTensor = None,
+         image_grid_thw: torch.LongTensor = None,
+         video_grid_thw: torch.LongTensor = None,
+         second_per_grid_ts: torch.Tensor = None,
+     ):
+         batch_size = input_ids.shape[0]
+         inputs_embeds = self.embed_tokens(input_ids)
+
+         if pixel_values is not None:
+             image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
+             n_image_tokens = (input_ids == self.config.image_token_id).sum().item()
+             n_image_features = image_embeds.shape[0]
+             if n_image_tokens != n_image_features:
+                 raise ValueError(
+                     f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
+                 )
+
+             mask = input_ids == self.config.image_token_id
+             mask_unsqueezed = mask.unsqueeze(-1)
+             mask_expanded = mask_unsqueezed.expand_as(inputs_embeds)
+
+             image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
+             inputs_embeds = inputs_embeds.masked_scatter(mask_expanded, image_embeds)
+
+         if pixel_values_videos is not None:
+             video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw)
+             n_video_tokens = (input_ids == self.config.video_token_id).sum().item()
+             n_video_features = video_embeds.shape[0]
+             if n_video_tokens != n_video_features:
+                 raise ValueError(
+                     f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
+                 )
+
+             mask = input_ids == self.config.video_token_id
+             mask_unsqueezed = mask.unsqueeze(-1)
+             mask_expanded = mask_unsqueezed.expand_as(inputs_embeds)
+             inputs_embeds = inputs_embeds.masked_scatter(mask_expanded, video_embeds)
+
+         max_inputs_len = input_ids.shape[1]
+         head_dim = self.config.text_config.hidden_size // self.config.text_config.num_attention_heads
+         all_position_embeds = torch.zeros(2, batch_size, 1, max_inputs_len, head_dim)
+         all_rope_deltas = []
+
+         image_token_id = self.config.image_token_id
+         video_token_id = self.config.video_token_id
+         vision_start_token_id = self.config.vision_start_token_id
+         image_idx, video_idx = 0, 0
+
+         for b_idx in range(batch_size):
+             input_id = input_ids[b_idx : b_idx + 1][:, attention_mask[b_idx].bool()]
+             vision_start_indices = torch.argwhere(input_id == vision_start_token_id).squeeze(1)
+             vision_tokens = input_id[0][vision_start_indices + 1]
+             image_nums = (vision_tokens == image_token_id).sum()
+             video_nums = (vision_tokens == video_token_id).sum()
+             args = [
+                 input_id,
+                 image_grid_thw[image_idx : image_idx + image_nums] if image_grid_thw is not None else None,
+                 video_grid_thw[video_idx : video_idx + video_nums] if video_grid_thw is not None else None,
+             ]
+             if self.config.model_type == "qwen2_5_vl":
+                 args.append(
+                     second_per_grid_ts[video_idx : video_idx + video_nums] if second_per_grid_ts is not None else None
+                 )
+             position_ids, rope_deltas = self.get_rope_index(*args)
+             image_idx += image_nums
+             video_idx += video_nums
+
+             position_embed = self._get_position_embeddings(inputs_embeds, position_ids)
+             mask_indices = torch.nonzero(attention_mask[b_idx], as_tuple=True)[0]
+             all_position_embeds[:, b_idx : b_idx + 1].index_copy_(dim=-2, index=mask_indices, source=position_embed)
+             all_rope_deltas.append(rope_deltas)
+
+         rope_deltas = torch.stack(all_rope_deltas)
+
+         return inputs_embeds, all_position_embeds, rope_deltas
+
+     def _preprocess_chunked_prefill(self, inputs_embeds, attention_mask, position_embed):
+         # valid sequence length of inputs_embeds
+         query_length = inputs_embeds.shape[1] if attention_mask is None else torch.sum(attention_mask.view(-1)).item()
+
+         # extract valid inputs
+         inputs_embeds = inputs_embeds[:, attention_mask.bool()] if attention_mask is not None else inputs_embeds
+         position_embed = (
+             position_embed[:, :, :, attention_mask.bool(), :] if attention_mask is not None else position_embed
+         )
+
+         # add padding for chunked prefill
+         padding_size = (
+             self.rbln_config.prefill_chunk_size - (query_length % self.rbln_config.prefill_chunk_size)
+         ) % self.rbln_config.prefill_chunk_size
+         padded_len = query_length + padding_size
+
+         inputs_embeds = torch.nn.functional.pad(inputs_embeds, (0, 0, 0, padding_size))
+         position_embed = torch.nn.functional.pad(position_embed, (0, 0, 0, padding_size))
+         cache_position = torch.arange(padded_len, dtype=torch.int32).unsqueeze(0)
+
+         return inputs_embeds, position_embed, cache_position, query_length
+
+     def _chunked_prefill_forward(
+         self,
+         inputs_embeds: torch.Tensor,
+         attention_mask: Optional[torch.Tensor] = None,
+         position_embed: Optional[torch.Tensor] = None,
+         output_hidden_states: Optional[bool] = False,
+     ):
+         padded_inputs_embeds, padded_position_embed, cache_position, query_length = self._preprocess_chunked_prefill(
+             inputs_embeds, attention_mask, position_embed
+         )
+
+         # Chunked prefill
+         projs = []
+         all_hidden_states = [] if output_hidden_states else None
+         for step in range(0, query_length, self.rbln_config.prefill_chunk_size):
+             # Extract the current chunk of inputs and cache positions
+             input_chunk = padded_inputs_embeds[:, step : step + self.rbln_config.prefill_chunk_size]
+             cache_pos_chunk = cache_position[:, step : step + self.rbln_config.prefill_chunk_size]
+             position_embed_chunk = padded_position_embed[:, :, :, step : step + self.rbln_config.prefill_chunk_size, :]
+
+             # Forward pass for the current chunk
+             proj = self.prefill_runtime(
+                 inputs_embeds=input_chunk,
+                 cache_position=cache_pos_chunk,
+                 block_tables=self.block_tables,
+                 position_emb=position_embed_chunk,
+             )
+
+             if output_hidden_states:
+                 projs.append(proj[0])
+                 all_hidden_states.append(proj[1:])
+             else:
+                 projs.append(proj)
+
+         projs = torch.concat(projs, dim=-2)[:, :query_length]
+         if output_hidden_states:
+             # Concatenate chunks for each layer
+             concatenated_hidden_states = [
+                 torch.concat(hs_chunks, dim=-2)[:, :query_length] for hs_chunks in list(zip(*all_hidden_states))
+             ]
+             all_hidden_states = tuple(concatenated_hidden_states)
+
+         return self._postprocess_chunked_prefill(projs, attention_mask), all_hidden_states
+
+     def _postprocess_chunked_prefill(self, projs, attention_mask):
+         # index copy for attention mask
+         if attention_mask is not None:
+             embedding = torch.full(
+                 (1, attention_mask.shape[-1], projs.shape[-1]),
+                 fill_value=1e-10,
+                 dtype=projs.dtype,
+             )
+             mask_indices = torch.nonzero(attention_mask, as_tuple=True)[0]
+             embedding.index_copy_(dim=-2, index=mask_indices, source=projs)
+         else:
+             embedding = projs
+         return embedding
+
+     def forward(
+         self,
+         input_ids: Optional[torch.LongTensor] = None,
+         inputs_embeds: Optional[torch.FloatTensor] = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         pixel_values: Optional[torch.Tensor] = None,
+         pixel_values_videos: Optional[torch.FloatTensor] = None,
+         image_grid_thw: Optional[torch.LongTensor] = None,
+         video_grid_thw: Optional[torch.LongTensor] = None,
+         second_per_grid_ts: Optional[torch.Tensor] = None,
+         output_hidden_states: Optional[bool] = None,
+         **kwargs,
+     ) -> torch.Tensor:
+         output_hidden_states = (
+             output_hidden_states if output_hidden_states is not None else self.rbln_config.output_hidden_states
+         )
+
+         if output_hidden_states != self.rbln_config.output_hidden_states:
+             raise ValueError(
+                 f"Variable output_hidden_states {output_hidden_states} is not equal to rbln_config.output_hidden_states {self.rbln_config.output_hidden_states} "
+                 f"Please compile again with the correct argument."
+             )
+
+         # Handle the custom "pixel_values" input obtained with `ColQwen2Processor` through unpadding
+         if pixel_values is not None and image_grid_thw is not None:
+             offsets = image_grid_thw[:, 1] * image_grid_thw[:, 2]  # (batch_size,)
+             pixel_values = torch.cat(
+                 [pixel_sequence[:offset] for pixel_sequence, offset in zip(pixel_values, offsets)],
+                 dim=0,
+             )
+         # visual preprocessing
+         inputs_embeds, position_embed, _ = self._preprocess_visual(
+             input_ids,
+             attention_mask,
+             pixel_values,
+             pixel_values_videos,
+             image_grid_thw,
+             video_grid_thw,
+             second_per_grid_ts,
+         )
+         batch_size = inputs_embeds.shape[0]
+
+         projs = []
+         for b_idx in range(batch_size):
+             proj = self._chunked_prefill_forward(
+                 inputs_embeds[b_idx : b_idx + 1],
+                 attention_mask[b_idx] if attention_mask is not None else None,
+                 position_embed[:, b_idx : b_idx + 1],
+                 output_hidden_states=output_hidden_states,
+             )
+             projs.append(proj[0])
+             all_hidden_states = proj[1] if output_hidden_states else ()
+
+         # postprocess
+         projs = torch.cat(projs, dim=0)
+         projs = projs / projs.norm(dim=-1, keepdim=True)
+         projs = projs * attention_mask.unsqueeze(-1)
+
+         return ColQwen2ForRetrievalOutput(
+             embeddings=projs,
+             hidden_states=all_hidden_states,
+         )
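
Note: `forward` above returns per-token, L2-normalized embeddings (`ColQwen2ForRetrievalOutput.embeddings`) rather than logits, so retrieval scores are computed outside the model with late interaction (MaxSim). A minimal sketch follows; it assumes an already-compiled model directory (`./colqwen2-rbln`), a local page image (`page.png`), and that `AutoProcessor` for this checkpoint accepts `images=`/`text=` and returns the `input_ids`, `attention_mask`, `pixel_values`, and `image_grid_thw` tensors consumed by `forward`.

```python
import torch
from PIL import Image
from transformers import AutoProcessor

from optimum.rbln import RBLNColQwen2ForRetrieval

processor = AutoProcessor.from_pretrained("vidore/colqwen2-v1.0-hf")
model = RBLNColQwen2ForRetrieval.from_pretrained("./colqwen2-rbln", export=False)  # pre-compiled dir (assumption)

# Embed one document page (image) and one text query with the same forward pass.
doc_inputs = processor(images=[Image.open("page.png")], return_tensors="pt")
query_inputs = processor(text=["How is chunked prefill padded?"], return_tensors="pt")

doc_emb = model(**doc_inputs).embeddings      # (1, doc_len, dim), normalized per token
query_emb = model(**query_inputs).embeddings  # (1, query_len, dim)

# MaxSim late interaction: best-matching document token per query token, summed.
score = torch.einsum("qd,pd->qp", query_emb[0], doc_emb[0]).max(dim=1).values.sum()
print(float(score))
```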

optimum/rbln/transformers/models/decoderonly/__init__.py
@@ -22,5 +22,6 @@ from ....ops import (
      paged_flash_causal_attn_decode,
      paged_flash_causal_attn_prefill,
  )
- from .configuration_decoderonly import RBLNDecoderOnlyModelForCausalLMConfig
- from .modeling_decoderonly import RBLNDecoderOnlyModelForCausalLM
+ from .configuration_decoderonly import RBLNDecoderOnlyModelConfig, RBLNDecoderOnlyModelForCausalLMConfig
+ from .configuration_lora import RBLNLoRAAdapterConfig, RBLNLoRAConfig
+ from .modeling_decoderonly import RBLNDecoderOnlyModel, RBLNDecoderOnlyModelForCausalLM
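
The updated subpackage `__init__` above re-exports the new base class and LoRA configuration objects next to the existing causal-LM names. A small import sketch (the fully qualified path is inferred from the file location in the file list):

```python
from optimum.rbln.transformers.models.decoderonly import (
    RBLNDecoderOnlyModel,
    RBLNDecoderOnlyModelConfig,
    RBLNDecoderOnlyModelForCausalLM,
    RBLNDecoderOnlyModelForCausalLMConfig,
    RBLNLoRAAdapterConfig,
    RBLNLoRAConfig,
)
```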

optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py
@@ -12,29 +12,32 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.

- from typing import Any, Dict, List, Literal, Optional, Union
-
- import rebel
+ from typing import Any, Dict, List, Literal, Optional, Union, get_args

  from ....configuration_utils import RBLNModelConfig
  from ....utils.logging import get_logger
  from ...utils.rbln_quantization import RBLNQuantizationConfig
+ from .configuration_lora import RBLNLoRAConfig


  logger = get_logger()

  CacheImplType = Literal["static", "sliding_window", "hybrid"]
+ PhaseType = Literal["prefill", "image_prefill", "decode"]


- class RBLNDecoderOnlyModelForCausalLMConfig(RBLNModelConfig):
+ class RBLNDecoderOnlyModelConfig(RBLNModelConfig):
      """
-     Configuration class for RBLN decoder-only models for Causal Language Modeling.
+     Configuration class for RBLN decoder-only models.

      This class extends RBLNModelConfig with parameters specific to decoder-only transformer
      architectures optimized for RBLN devices. It controls aspects like attention implementation,
      KV cache management, and batching for inference.
      """

+     _default_phases = ["prefill"]
+     _default_logits_to_keep = 0
+
      def __init__(
          self,
          batch_size: Optional[int] = None,
@@ -46,12 +49,15 @@ class RBLNDecoderOnlyModelForCausalLMConfig(RBLNModelConfig):
          kvcache_partition_len: Optional[int] = None,
          kvcache_block_size: Optional[int] = None,
          quantization: Optional[Union[Dict[str, Any], RBLNQuantizationConfig]] = None,
+         lora_config: Optional[Union[Dict[str, Any], RBLNLoRAConfig]] = None,
          prefill_chunk_size: Optional[int] = None,
          kvcache_num_blocks: Optional[int] = None,
          decoder_batch_sizes: Optional[List[int]] = None,
          cache_impl: Optional[CacheImplType] = None,
          sliding_window: Optional[int] = None,
          sliding_window_layers: Optional[List[int]] = None,
+         phases: Optional[List[PhaseType]] = None,
+         logits_to_keep: Optional[int] = None,
          **kwargs,
      ):
          """
@@ -78,6 +84,10 @@ class RBLNDecoderOnlyModelForCausalLMConfig(RBLNModelConfig):
              section below for details.
          quantization (Optional[Dict[str, Any]]): Configuration dictionary for applying model
              quantization. Specifies format, etc.
+         lora_config (Optional[Union[Dict[str, Any], RBLNLoRAConfig]]): Configuration for LoRA
+             (Low-Rank Adaptation) settings when using (multi-)LoRA support. Can be provided as
+             a dictionary or an RBLNLoRAConfig instance. When provided, enables LoRA functionality
+             for the model compilation. Defaults to None (no LoRA).
          prefill_chunk_size (Optional[int]): The chunk size used during the prefill phase for
              processing input sequences. Defaults to 128. Must be a positive integer
              divisible by 64. Affects prefill performance and memory usage.
@@ -98,7 +108,11 @@ class RBLNDecoderOnlyModelForCausalLMConfig(RBLNModelConfig):
              you must specify the `sliding_window` size and optionally `sliding_window_layers` for hybrid mode.
          sliding_window (Optional[int]): The size of the sliding window. Defaults to None.
          sliding_window_layers (Optional[List[int]]): The layers to use for the sliding window used in the hybrid model. Defaults to None.
-         **kwargs: Additional arguments passed to the parent RBLNModelConfig.
+         phases (Optional[List[PhaseType]]): The phases to compile the model for. Defaults to ["prefill"] if DecoderOnlyModel is used,
+             ["prefill", "decode"] if DecoderOnlyModelForCausalLM is used.
+         logits_to_keep (Optional[int]): The number of logits to keep for the decoder. If set to 0, the decoder will keep all logits.
+             Defaults to 0 if DecoderOnlyModel is used, 1 if DecoderOnlyModelForCausalLM is used.
+         kwargs: Additional arguments passed to the parent RBLNModelConfig.

          Raises:
              ValueError: If `batch_size` is not a positive integer.
@@ -170,54 +184,117 @@ class RBLNDecoderOnlyModelForCausalLMConfig(RBLNModelConfig):
          self.max_seq_len = max_seq_len
          self.use_inputs_embeds = use_inputs_embeds or False
          self.use_position_ids = use_position_ids or False
-         self.use_attention_mask = use_attention_mask
-
-         npu = self.npu or rebel.get_npu_name()
-         if npu == "RBLN-CA02":
-             if self.use_attention_mask is False:
-                 logger.warning("Attention mask should be used with RBLN-CA02. Setting use_attention_mask to True.")
-                 self.use_attention_mask = True
-         else:
-             self.use_attention_mask = self.use_attention_mask or False
+         self.use_attention_mask = use_attention_mask or False

          if self.use_position_ids and not self.use_attention_mask:
              raise ValueError("Position IDs should be used with attention mask.")

-         self.attn_impl = attn_impl
-         self.kvcache_partition_len = kvcache_partition_len
-         self.kvcache_block_size = kvcache_block_size
          self.quantization = quantization or {}
          if self.quantization and isinstance(self.quantization, dict):
              self.quantization = RBLNQuantizationConfig(**self.quantization)

-         self.prefill_chunk_size = prefill_chunk_size or 128
-         if self.prefill_chunk_size % 64 != 0 or self.prefill_chunk_size <= 0:
-             raise ValueError("`prefill_chunk_size` must be a positive integer divisible by 64.")
+         self.lora_config = lora_config
+         if self.lora_config and isinstance(self.lora_config, dict):
+             self.lora_config = RBLNLoRAConfig(**self.lora_config)

-         self.kvcache_num_blocks = kvcache_num_blocks
-         self.decoder_batch_sizes = decoder_batch_sizes
-         if self.decoder_batch_sizes is None:
-             self.decoder_batch_sizes = [self.batch_size]
+         # Validate LoRA adapters if LoRA is enabled
+         if self.lora_config is not None:
+             validation_results = self.lora_config.validate_adapter_weights()
+             failed_adapters = [adapter_id for adapter_id, is_valid in validation_results.items() if not is_valid]

-         if self.use_multiple_decoder:
-             if max(self.decoder_batch_sizes) > self.batch_size:
+             if failed_adapters:
                  raise ValueError(
-                     f"Decoder batch size ({max(self.decoder_batch_sizes)}) must be less than or equal to the runtime batch size ({self.batch_size})."
+                     f"Some LoRA adapters failed validation and may not be accessible at compile time: {failed_adapters}. "
+                     "Please ensure all adapter weights are available and properly formatted."
                  )
-             if max(self.decoder_batch_sizes) < self.batch_size:
-                 logger.warning(
-                     f"Maximum decoder batch size ({max(self.decoder_batch_sizes)}) is less than the model's batch size ({self.batch_size}). "
-                     "Appending the model's batch size to the decoder batch size."
-                 )
-                 self.decoder_batch_sizes.append(self.batch_size)

-         # Larger batch size should be at the beginning of the list.
-         self.decoder_batch_sizes.sort(reverse=True)
+             logger.info(
+                 f"LoRA configuration initialized with {self.lora_config.num_adapters} adapters: "
+                 f"{self.lora_config.adapter_ids}. Max rank: {self.lora_config.max_lora_rank}"
+             )
+
+         self.attn_impl = attn_impl
+         self.kvcache_partition_len = kvcache_partition_len
+         self.kvcache_block_size = kvcache_block_size
+         self.prefill_chunk_size = prefill_chunk_size or 128
+         if self.prefill_chunk_size % 64 != 0 or self.prefill_chunk_size <= 0:
+             raise ValueError("`prefill_chunk_size` must be a positive integer divisible by 64.")

+         self.kvcache_num_blocks = kvcache_num_blocks
          self.cache_impl = cache_impl or "static"
          self.sliding_window = sliding_window
          self.sliding_window_layers = sliding_window_layers or []

+         if phases is not None:
+             self.validate_phases_type(phases)
+         self.phases = phases or self._default_phases
+         self.logits_to_keep = logits_to_keep or self._default_logits_to_keep
+         if self.logits_to_keep is not None and self.logits_to_keep > 1:
+             raise NotImplementedError("`logits_to_keep` > 1 is currently not supported for RBLN models.")
+
+         self.decoder_batch_sizes = None
+         if "decode" in self.phases:
+             self.decoder_batch_sizes = decoder_batch_sizes
+             if self.decoder_batch_sizes is None:
+                 self.decoder_batch_sizes = [self.batch_size]
+
+             if self.use_multiple_decoder:
+                 if max(self.decoder_batch_sizes) > self.batch_size:
+                     raise ValueError(
+                         f"Decoder batch size ({max(self.decoder_batch_sizes)}) must be less than or equal to the runtime batch size ({self.batch_size})."
+                     )
+                 if max(self.decoder_batch_sizes) < self.batch_size:
+                     logger.warning(
+                         f"Maximum decoder batch size ({max(self.decoder_batch_sizes)}) is less than the model's batch size ({self.batch_size}). "
+                         "Appending the model's batch size to the decoder batch size."
+                     )
+                     self.decoder_batch_sizes.append(self.batch_size)
+
+             # Larger batch size should be at the beginning of the list.
+             self.decoder_batch_sizes.sort(reverse=True)
+
+     @staticmethod
+     def validate_phases_type(phases: List[PhaseType]):
+         if not isinstance(phases, list):
+             raise ValueError("`phases` must be a list.")
+         if not all(phase in get_args(PhaseType) for phase in phases):
+             raise ValueError(f"All elements in `phases` must be of type `PhaseType`({get_args(PhaseType)}).")
+
      @property
-     def use_multiple_decoder(self):
+     def use_global_attention(self) -> bool:
+         return self.cache_impl in ["static", "hybrid"]
+
+     @property
+     def use_local_attention(self) -> bool:
+         return self.cache_impl in ["sliding_window", "hybrid"]
+
+     @property
+     def use_multiple_decoder(self) -> bool:
          return isinstance(self.decoder_batch_sizes, list) and len(self.decoder_batch_sizes) > 1
+
+     @property
+     def use_lora(self):
+         return self.lora_config is not None
+
+     @property
+     def can_generate(self) -> bool:
+         return "decode" in self.phases
+
+     @property
+     def nbits_per_param(self) -> int:
+         if self.quantization:
+             return self.quantization.nbits_per_param
+         return 16
+
+
+ class RBLNDecoderOnlyModelForCausalLMConfig(RBLNDecoderOnlyModelConfig):
+     """
+     Configuration class for RBLN decoder-only models for Causal Language Modeling.
+
+     This class extends RBLNModelConfig with parameters specific to decoder-only transformer
+     architectures optimized for RBLN devices. It controls aspects like attention implementation,
+     KV cache management, and batching for inference.
+     """
+
+     _default_phases = ["prefill", "decode"]
+     _default_logits_to_keep = 1
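
The new `phases` and `logits_to_keep` options are consumed through the same `rbln_config` mechanism used elsewhere in the library. A hedged sketch for a causal LM follows; the checkpoint id and sizes are placeholders, and the parameter names come from the `__init__` shown above.

```python
from optimum.rbln import RBLNLlamaForCausalLM

model = RBLNLlamaForCausalLM.from_pretrained(
    "meta-llama/Llama-3.1-8B-Instruct",  # placeholder checkpoint
    export=True,
    rbln_config={
        "batch_size": 4,
        "max_seq_len": 8192,
        "tensor_parallel_size": 4,
        # Phase control: causal LMs default to ["prefill", "decode"];
        # the plain RBLNDecoderOnlyModel base defaults to ["prefill"].
        "phases": ["prefill", "decode"],
        # Keep only the last position's logits (the causal-LM default; values > 1 raise).
        "logits_to_keep": 1,
        # Additional decoder graphs compiled for smaller runtime batch sizes.
        "decoder_batch_sizes": [4, 1],
    },
)
```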