optimum-rbln 0.8.2a4__py3-none-any.whl → 0.9.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (196)
  1. optimum/rbln/__init__.py +108 -9
  2. optimum/rbln/__version__.py +16 -3
  3. optimum/rbln/cli.py +660 -0
  4. optimum/rbln/configuration_utils.py +156 -43
  5. optimum/rbln/diffusers/__init__.py +19 -0
  6. optimum/rbln/diffusers/configurations/__init__.py +3 -0
  7. optimum/rbln/diffusers/configurations/models/__init__.py +2 -0
  8. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +3 -3
  9. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_cosmos.py +1 -1
  10. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_temporal_decoder.py +67 -0
  11. optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +3 -3
  12. optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +4 -4
  13. optimum/rbln/diffusers/configurations/models/configuration_transformer_cosmos.py +9 -4
  14. optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +9 -4
  15. optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +3 -3
  16. optimum/rbln/diffusers/configurations/models/configuration_unet_spatio_temporal_condition.py +59 -0
  17. optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +3 -3
  18. optimum/rbln/diffusers/configurations/pipelines/__init__.py +3 -0
  19. optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +35 -19
  20. optimum/rbln/diffusers/configurations/pipelines/configuration_cosmos.py +14 -11
  21. optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +30 -20
  22. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +13 -9
  23. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +17 -13
  24. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +17 -10
  25. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_video_diffusion.py +114 -0
  26. optimum/rbln/diffusers/modeling_diffusers.py +30 -14
  27. optimum/rbln/diffusers/models/__init__.py +4 -0
  28. optimum/rbln/diffusers/models/autoencoders/__init__.py +1 -0
  29. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +31 -3
  30. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py +31 -6
  31. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +275 -0
  32. optimum/rbln/diffusers/models/autoencoders/vae.py +27 -8
  33. optimum/rbln/diffusers/models/autoencoders/vq_model.py +31 -3
  34. optimum/rbln/diffusers/models/controlnet.py +16 -1
  35. optimum/rbln/diffusers/models/transformers/prior_transformer.py +17 -3
  36. optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +25 -2
  37. optimum/rbln/diffusers/models/transformers/transformer_sd3.py +23 -2
  38. optimum/rbln/diffusers/models/unets/__init__.py +1 -0
  39. optimum/rbln/diffusers/models/unets/unet_2d_condition.py +23 -4
  40. optimum/rbln/diffusers/models/unets/unet_spatio_temporal_condition.py +201 -0
  41. optimum/rbln/diffusers/pipelines/__init__.py +15 -5
  42. optimum/rbln/diffusers/pipelines/auto_pipeline.py +307 -0
  43. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +20 -0
  44. optimum/rbln/diffusers/pipelines/cosmos/configuration_cosmos_guardrail.py +19 -16
  45. optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +14 -18
  46. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +31 -1
  47. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +31 -1
  48. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +1 -6
  49. optimum/rbln/diffusers/pipelines/stable_video_diffusion/__init__.py +15 -0
  50. optimum/rbln/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +46 -0
  51. optimum/rbln/modeling.py +48 -21
  52. optimum/rbln/modeling_base.py +99 -22
  53. optimum/rbln/ops/attn.py +158 -0
  54. optimum/rbln/ops/flash_attn.py +166 -0
  55. optimum/rbln/ops/kv_cache_update.py +5 -0
  56. optimum/rbln/ops/linear.py +7 -0
  57. optimum/rbln/transformers/__init__.py +92 -0
  58. optimum/rbln/transformers/configuration_generic.py +7 -32
  59. optimum/rbln/transformers/modeling_attention_utils.py +385 -0
  60. optimum/rbln/transformers/modeling_generic.py +48 -65
  61. optimum/rbln/transformers/modeling_outputs.py +37 -0
  62. optimum/rbln/transformers/models/__init__.py +91 -30
  63. optimum/rbln/transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py +28 -2
  64. optimum/rbln/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +68 -5
  65. optimum/rbln/transformers/models/auto/__init__.py +2 -0
  66. optimum/rbln/transformers/models/auto/auto_factory.py +92 -17
  67. optimum/rbln/transformers/models/auto/modeling_auto.py +45 -0
  68. optimum/rbln/transformers/models/bart/bart_architecture.py +1 -3
  69. optimum/rbln/transformers/models/bart/configuration_bart.py +2 -0
  70. optimum/rbln/transformers/models/bart/modeling_bart.py +23 -2
  71. optimum/rbln/transformers/models/bert/bert_architecture.py +16 -0
  72. optimum/rbln/transformers/models/bert/modeling_bert.py +93 -4
  73. optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +42 -11
  74. optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +135 -44
  75. optimum/rbln/transformers/models/clip/configuration_clip.py +10 -7
  76. optimum/rbln/transformers/models/clip/modeling_clip.py +67 -6
  77. optimum/rbln/transformers/models/colpali/colpali_architecture.py +3 -6
  78. optimum/rbln/transformers/models/colpali/configuration_colpali.py +37 -21
  79. optimum/rbln/transformers/models/colpali/modeling_colpali.py +82 -104
  80. optimum/rbln/transformers/models/colqwen2/__init__.py +2 -0
  81. optimum/rbln/transformers/models/colqwen2/colqwen2_architecture.py +233 -0
  82. optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +74 -0
  83. optimum/rbln/transformers/models/colqwen2/modeling_colqwen2.py +446 -0
  84. optimum/rbln/transformers/models/decoderonly/__init__.py +3 -2
  85. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +114 -37
  86. optimum/rbln/transformers/models/decoderonly/configuration_lora.py +411 -0
  87. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +318 -309
  88. optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +508 -0
  89. optimum/rbln/transformers/models/decoderonly/generation_decoderonly.py +119 -0
  90. optimum/rbln/transformers/models/decoderonly/lora_architecture.py +204 -0
  91. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +485 -905
  92. optimum/rbln/transformers/models/depth_anything/__init__.py +16 -0
  93. optimum/rbln/transformers/models/depth_anything/configuration_depth_anything.py +24 -0
  94. optimum/rbln/transformers/models/depth_anything/modeling_depth_anything.py +42 -0
  95. optimum/rbln/transformers/models/distilbert/modeling_distilbert.py +24 -0
  96. optimum/rbln/transformers/models/dpt/modeling_dpt.py +17 -0
  97. optimum/rbln/transformers/models/exaone/modeling_exaone.py +42 -4
  98. optimum/rbln/transformers/models/gemma/__init__.py +2 -2
  99. optimum/rbln/transformers/models/gemma/configuration_gemma.py +9 -1
  100. optimum/rbln/transformers/models/gemma/gemma_architecture.py +1 -4
  101. optimum/rbln/transformers/models/gemma/modeling_gemma.py +22 -1
  102. optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +49 -13
  103. optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +12 -2
  104. optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py +245 -0
  105. optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +201 -351
  106. optimum/rbln/transformers/models/gpt2/__init__.py +2 -2
  107. optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +31 -3
  108. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +10 -8
  109. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +18 -1
  110. optimum/rbln/transformers/models/grounding_dino/__init__.py +10 -0
  111. optimum/rbln/transformers/models/grounding_dino/configuration_grounding_dino.py +92 -0
  112. optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +599 -0
  113. optimum/rbln/transformers/models/grounding_dino/modeling_grounding_dino.py +1048 -0
  114. optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +35 -7
  115. optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +29 -32
  116. optimum/rbln/transformers/models/llama/__init__.py +2 -2
  117. optimum/rbln/transformers/models/llama/configuration_llama.py +9 -1
  118. optimum/rbln/transformers/models/llama/modeling_llama.py +22 -1
  119. optimum/rbln/transformers/models/llava/__init__.py +16 -0
  120. optimum/rbln/transformers/models/llava/configuration_llava.py +72 -0
  121. optimum/rbln/transformers/models/llava/modeling_llava.py +490 -0
  122. optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +15 -17
  123. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +234 -376
  124. optimum/rbln/transformers/models/midm/midm_architecture.py +4 -1
  125. optimum/rbln/transformers/models/midm/modeling_midm.py +42 -4
  126. optimum/rbln/transformers/models/mistral/__init__.py +2 -2
  127. optimum/rbln/transformers/models/mistral/configuration_mistral.py +9 -1
  128. optimum/rbln/transformers/models/mistral/mistral_architecture.py +1 -1
  129. optimum/rbln/transformers/models/mistral/modeling_mistral.py +26 -3
  130. optimum/rbln/transformers/models/opt/__init__.py +2 -2
  131. optimum/rbln/transformers/models/opt/configuration_opt.py +8 -1
  132. optimum/rbln/transformers/models/opt/modeling_opt.py +29 -17
  133. optimum/rbln/transformers/models/opt/opt_architecture.py +4 -4
  134. optimum/rbln/transformers/models/pegasus/__init__.py +17 -0
  135. optimum/rbln/transformers/models/pegasus/configuration_pegasus.py +38 -0
  136. optimum/rbln/transformers/models/pegasus/modeling_pegasus.py +71 -0
  137. optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +161 -0
  138. optimum/rbln/transformers/models/phi/__init__.py +2 -2
  139. optimum/rbln/transformers/models/phi/configuration_phi.py +9 -1
  140. optimum/rbln/transformers/models/phi/modeling_phi.py +10 -1
  141. optimum/rbln/transformers/models/phi/phi_architecture.py +11 -7
  142. optimum/rbln/transformers/models/pixtral/__init__.py +16 -0
  143. optimum/rbln/transformers/models/pixtral/configuration_pixtral.py +43 -0
  144. optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +322 -0
  145. optimum/rbln/transformers/models/pixtral/pixtral_architecture.py +73 -0
  146. optimum/rbln/transformers/models/qwen2/__init__.py +2 -2
  147. optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +9 -1
  148. optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +27 -1
  149. optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +21 -6
  150. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +15 -22
  151. optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +28 -7
  152. optimum/rbln/transformers/models/qwen2_vl/__init__.py +19 -0
  153. optimum/rbln/transformers/models/qwen2_vl/configuration_qwen2_vl.py +88 -0
  154. optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py +513 -0
  155. optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +165 -0
  156. optimum/rbln/transformers/models/qwen3/configuration_qwen3.py +2 -2
  157. optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +86 -330
  158. optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +1 -245
  159. optimum/rbln/transformers/models/resnet/configuration_resnet.py +17 -0
  160. optimum/rbln/transformers/models/resnet/modeling_resnet.py +73 -0
  161. optimum/rbln/transformers/models/roberta/modeling_roberta.py +33 -0
  162. optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py +21 -16
  163. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +58 -13
  164. optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +2 -2
  165. optimum/rbln/transformers/models/siglip/__init__.py +2 -6
  166. optimum/rbln/transformers/models/siglip/configuration_siglip.py +1 -1
  167. optimum/rbln/transformers/models/siglip/modeling_siglip.py +21 -16
  168. optimum/rbln/transformers/models/swin/__init__.py +16 -0
  169. optimum/rbln/transformers/models/swin/configuration_swin.py +42 -0
  170. optimum/rbln/transformers/models/swin/modeling_swin.py +354 -0
  171. optimum/rbln/transformers/models/t5/configuration_t5.py +2 -0
  172. optimum/rbln/transformers/models/t5/modeling_t5.py +2 -2
  173. optimum/rbln/transformers/models/t5/t5_architecture.py +8 -1
  174. optimum/rbln/transformers/models/time_series_transformer/configuration_time_series_transformer.py +3 -3
  175. optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +20 -16
  176. optimum/rbln/transformers/models/time_series_transformer/time_series_transformers_architecture.py +7 -1
  177. optimum/rbln/transformers/models/vit/modeling_vit.py +19 -0
  178. optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec2.py +15 -3
  179. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +61 -8
  180. optimum/rbln/transformers/models/whisper/configuration_whisper.py +12 -13
  181. optimum/rbln/transformers/models/whisper/generation_whisper.py +62 -6
  182. optimum/rbln/transformers/models/whisper/modeling_whisper.py +30 -5
  183. optimum/rbln/transformers/models/xlm_roberta/__init__.py +2 -8
  184. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +43 -0
  185. optimum/rbln/transformers/utils/rbln_quantization.py +400 -75
  186. optimum/rbln/transformers/utils/rbln_runtime_wrapper.py +79 -0
  187. optimum/rbln/utils/deprecation.py +213 -0
  188. optimum/rbln/utils/hub.py +14 -3
  189. optimum/rbln/utils/runtime_utils.py +60 -18
  190. optimum/rbln/utils/submodule.py +31 -9
  191. {optimum_rbln-0.8.2a4.dist-info → optimum_rbln-0.9.3.dist-info}/METADATA +8 -7
  192. optimum_rbln-0.9.3.dist-info/RECORD +264 -0
  193. {optimum_rbln-0.8.2a4.dist-info → optimum_rbln-0.9.3.dist-info}/WHEEL +1 -1
  194. optimum_rbln-0.9.3.dist-info/entry_points.txt +2 -0
  195. optimum_rbln-0.8.2a4.dist-info/RECORD +0 -215
  196. {optimum_rbln-0.8.2a4.dist-info → optimum_rbln-0.9.3.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/transformers/models/gemma3/modeling_gemma3.py
@@ -11,95 +11,74 @@
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  # See the License for the specific language governing permissions and
  # limitations under the License.
+ import importlib
  import inspect
- from collections import deque
- from dataclasses import dataclass
  from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union

  import rebel
  import torch
  from rebel.compile_context import CompileContext
- from transformers import (
-     AutoModelForImageTextToText,
-     Gemma3ForConditionalGeneration,
-     PretrainedConfig,
-     PreTrainedModel,
- )
+ from transformers import AutoModelForImageTextToText, Gemma3ForConditionalGeneration, PretrainedConfig, PreTrainedModel
  from transformers.modeling_outputs import BaseModelOutputWithPooling
  from transformers.modeling_utils import no_init_weights
  from transformers.models.gemma3.modeling_gemma3 import Gemma3TextScaledWordEmbedding

  from ....configuration_utils import RBLNCompileConfig, RBLNModelConfig
  from ....modeling import RBLNModel
- from ..decoderonly.modeling_decoderonly import RBLNDecoderOnlyModelForCausalLM, RBLNDecoderOnlyOutput, RBLNRuntimeModel
+ from ...modeling_outputs import RBLNDecoderOnlyOutput
+ from ...utils.rbln_runtime_wrapper import LoopProcessor
+ from ..decoderonly.decoderonly_runtime_utils import RBLNPageTableManager
+ from ..decoderonly.generation_decoderonly import RBLNDecoderOnlyGenerationMixin
+ from ..decoderonly.modeling_decoderonly import (
+     RBLNDecoderOnlyModelForCausalLM,
+ )
  from .configuration_gemma3 import RBLNGemma3ForCausalLMConfig
  from .gemma3_architecture import Gemma3ForCausalLMWrapper
+ from .gemma3_runtime_utils import RBLNGemma3RuntimeModel


  if TYPE_CHECKING:
      from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer, Gemma3ForConditionalGeneration


- @dataclass
- class RBLNGemma3ForCausalLMOutput(RBLNDecoderOnlyOutput):
-     attention_mask: Optional[torch.Tensor] = None
-
-
- class LoopVisionTower:
-     def __init__(self, vision_tower: RBLNModel) -> None:
-         self.vision_tower = vision_tower
+ class LoopVisionTower(LoopProcessor):
+     def __init__(self, vision_tower: "RBLNModel"):
+         super().__init__(model=vision_tower)

-     def forward(self, *args, **kwargs):
-         # Loop instead of batch
-         # shape of pixel_values : [batch, num_channel, height, width]
-         pixel_values = args[0]
+     def _get_batch_size(self, pixel_values, **kwargs):
+         return pixel_values.shape[0]

-         batch_size = pixel_values.shape[0]
-         outputs = []
-         for i in range(batch_size):
-             outputs.append(self.vision_tower(pixel_values=pixel_values[i : i + 1], return_dict=True))
+     def _prepare_inputs_for_iteration(self, index, common_inputs, pixel_values, **kwargs):
+         pixel_values_item = pixel_values[index : index + 1]
+         out_buffer = [tensor[index : index + 1] for tensor in kwargs["out"]]
+         return ([pixel_values_item], {"out": out_buffer})

-         last_hidden_states = [output.last_hidden_state for output in outputs]
-
-         # FIXME:: This can be optimized using out= API of rbln runtime.
-         last_hidden_states = torch.cat(last_hidden_states, dim=0)
+     def _process_outputs(self, outputs: list, **kwargs) -> "BaseModelOutputWithPooling":
+         output = kwargs["out"]

          return BaseModelOutputWithPooling(
-             last_hidden_state=last_hidden_states,
+             last_hidden_state=output[0],
          )

-     def __call__(self, *args: Any, **kwds: Any) -> Any:
-         return self.forward(*args, **kwds)
-
-     def __repr__(self) -> str:
-         return repr(self.vision_tower)
-

- class LoopProjector:
-     def __init__(self, multi_modal_projector) -> None:
-         self.multi_modal_projector = multi_modal_projector
+ class LoopProjector(LoopProcessor):
+     def __init__(self, multi_modal_projector: "RBLNModel"):
+         super().__init__(model=multi_modal_projector)

-     def forward(self, *args, **kwargs):
-         # Loop instead of batch
-         image_feature = args[0]
+     def _get_batch_size(self, image_feature, **kwargs):
+         return image_feature.shape[0]

-         batch_size = image_feature.shape[0]
-         outputs = []
-         for i in range(batch_size):
-             outputs.append(self.multi_modal_projector(image_feature[i : i + 1]))
+     def _prepare_inputs_for_iteration(self, index, common_inputs, image_feature, **kwargs):
+         image_feature_item = image_feature[index : index + 1]
+         out_buffer = [tensor[index : index + 1] for tensor in kwargs["out"]]
+         return ([image_feature_item], {"out": out_buffer})

-         # FIXME:: This can be optimized using out= API of rbln runtime.
-         outputs = torch.cat(outputs, dim=0)
-         return outputs
+     def _process_outputs(self, outputs: list, **kwargs):
+         output = kwargs["out"]
+         return output[0]

-     def __call__(self, *args: Any, **kwds: Any) -> Any:
-         return self.forward(*args, **kwds)

-     def __repr__(self) -> str:
-         return repr(self.multi_modal_projector)
-
-
- class RBLNGemma3ForConditionalGeneration(RBLNModel):
+ class RBLNGemma3ForConditionalGeneration(RBLNModel, RBLNDecoderOnlyGenerationMixin):
      auto_model_class = AutoModelForImageTextToText
      _rbln_submodules = [
          {"name": "vision_tower"},
@@ -119,6 +98,21 @@ class RBLNGemma3ForConditionalGeneration(RBLNModel):
      def can_generate(self):
          return True

+     @classmethod
+     def _reconstruct_model_if_needed(cls, model: "PreTrainedModel"):
+         with no_init_weights():
+             model_cls_name = model.model.language_model.__class__.__name__
+             causal_model_cls_name = model_cls_name.replace("TextModel", "ForCausalLM")
+             causal_model_cls = getattr(importlib.import_module("transformers"), causal_model_cls_name)
+             new_language_model = causal_model_cls(model.model.language_model.config)
+
+             new_language_model.lm_head = model.lm_head
+             new_language_model.model = model.model.language_model
+             model.model.language_model = new_language_model
+             model.lm_head = None
+             del model.lm_head
+         return model
+
      def __post_init__(self, **kwargs):
          self.vision_tower = LoopVisionTower(self.rbln_submodules[0])
          self.language_model = self.rbln_submodules[1]
@@ -139,7 +133,7 @@ class RBLNGemma3ForConditionalGeneration(RBLNModel):
          return self.language_model.get_input_embeddings()

      @classmethod
-     def wrap_model_if_needed(cls, model: "PreTrainedModel", rbln_config: RBLNModelConfig):
+     def _wrap_model_if_needed(cls, model: "PreTrainedModel", rbln_config: RBLNModelConfig):
          return model.multi_modal_projector

      @classmethod
@@ -208,18 +202,30 @@ class RBLNGemma3ForConditionalGeneration(RBLNModel):
          return model_kwargs

      def get_image_features(self, pixel_values: torch.Tensor) -> torch.Tensor:
-         """
-         Projects the last hidden state from the vision model into language model space.
-
-         Args:
-             pixel_values: (`torch.FloatTensor` of shape `(batch_size, channels, height, width)`)
-                 The tensors corresponding to the input images.
-
-         Returns:
-             Image feature tensor of shape `(num_images, image_length, embed_dim)`.
-         """
-         vision_outputs = self.vision_tower(pixel_values).last_hidden_state
-         image_features = self.multi_modal_projector(vision_outputs)
+         # Projects the last hidden state from the vision model into language model space.
+
+         # Args:
+         #     pixel_values: (`torch.FloatTensor` of shape `(batch_size, channels, height, width)`)
+         #         The tensors corresponding to the input images.
+
+         # Returns:
+         #     Image feature tensor of shape `(num_images, image_length, embed_dim)`.
+
+         vision_out_buffer = []
+         vision_out_size = [
+             pixel_values.shape[0],
+             (self.config.vision_config.image_size // self.config.vision_config.patch_size) ** 2,
+             self.config.vision_config.hidden_size,
+         ]
+         projector_out_size = [
+             pixel_values.shape[0],
+             self.config.mm_tokens_per_image,
+             self.config.text_config.hidden_size,
+         ]
+         vision_out_buffer.append(torch.empty(size=vision_out_size, dtype=torch.float32, device="cpu"))
+         projector_out_buffer = [torch.empty(size=projector_out_size, dtype=torch.float32, device="cpu")]
+         vision_outputs = self.vision_tower(pixel_values, out=vision_out_buffer).last_hidden_state
+         image_features = self.multi_modal_projector(vision_outputs, out=projector_out_buffer)
          return image_features

      def _preprocess_prefill(
@@ -254,17 +260,45 @@ class RBLNGemma3ForConditionalGeneration(RBLNModel):

          return inputs_embeds

+     def get_padded_cache_position(
+         self,
+         cache_position: torch.Tensor,  # shape: [1, seq_len]
+         token_type_ids: torch.Tensor,  # shape: [1, seq_len]
+     ) -> torch.Tensor:
+         seq_len = cache_position[0][-1].item() + 1
+
+         # Find image start positions
+         image_starts = [
+             s
+             for s in torch.where(token_type_ids == 1)[1]
+             if torch.all(token_type_ids[:, s : s + self.rbln_config.image_prefill_chunk_size] == 1)
+         ]
+
+         # Initialize padded tensors
+         padded_input_len = seq_len
+         for image_start in image_starts:
+             pad_needed = (
+                 self.rbln_config.image_prefill_chunk_size
+                 - (image_start + padded_input_len - seq_len) % self.rbln_config.image_prefill_chunk_size
+             ) % self.rbln_config.image_prefill_chunk_size
+             padded_input_len += pad_needed
+
+         return torch.cat(
+             [cache_position, torch.arange(seq_len, padded_input_len, dtype=torch.int32).unsqueeze(0)],
+             dim=1,
+         )
+
      def forward(
          self,
          input_ids: torch.LongTensor = None,
+         attention_mask: torch.Tensor = None,
+         token_type_ids: torch.Tensor = None,
          pixel_values: torch.FloatTensor = None,
-         attention_mask: Optional[torch.Tensor] = None,
          cache_position: Optional[torch.LongTensor] = None,
          inputs_embeds: Optional[torch.FloatTensor] = None,
          generate_idx: Optional[torch.Tensor] = None,
          padded_cache_lengths: Optional[torch.Tensor] = None,
          position_ids: Optional[torch.Tensor] = None,
-         token_type_ids: Optional[torch.Tensor] = None,
          **lm_kwargs: Dict[str, Any],
      ) -> Union[Tuple, RBLNDecoderOnlyOutput]:
          # prefill
@@ -275,12 +309,15 @@ class RBLNGemma3ForConditionalGeneration(RBLNModel):

              for b_idx in range(batch_size):
                  cache_position = torch.arange(0, generate_idx[b_idx].item(), dtype=torch.int32).unsqueeze(0)
+                 token_type_id = token_type_ids[b_idx : b_idx + 1, attention_mask[b_idx].bool()]
+                 cache_position = self.get_padded_cache_position(cache_position, token_type_id)
+
                  output = self.language_model.prefill_decoder(
                      inputs_embeds=inputs_embeds[b_idx : b_idx + 1],
                      attention_mask=attention_mask[b_idx],
                      cache_position=cache_position,
                      batch_idx=b_idx,
-                     token_type_ids=token_type_ids[b_idx : b_idx + 1] if token_type_ids is not None else None,
+                     token_type_ids=token_type_ids[b_idx : b_idx + 1],  # do not pass token_type_id
                  )
                  padded_cache_lengths[b_idx] += output.padded_cache_lengths
                  logits.append(output.logits)
@@ -309,209 +346,6 @@ class RBLNGemma3ForConditionalGeneration(RBLNModel):
          )


- class RBLNGemma3RuntimeModel(RBLNRuntimeModel):
-     def __init__(self, *args, image_prefill: Optional[rebel.Runtime] = None, **kwargs):
-         super().__init__(*args, **kwargs)
-         self.image_prefill = image_prefill  # FIXME(taehoon)
-         self.prefill = self.runtime if self.phase == "prefill" else None  # FIXME
-         self.decode = self.runtime if self.phase == "decode" else None
-
-     def _prepare_prefill_inputs(self, *args, **kwargs):
-         (
-             inputs,
-             cache_position,
-             chunked_attention_mask,
-             out_buffers,
-             position_ids,
-             position_embed,
-             padded_cache_lengths,
-             query_length,
-             token_type_ids,
-         ) = super()._prepare_prefill_inputs(*args, **kwargs)
-
-         # chunked_attention_mask shape
-         chunked_attention_mask = torch.zeros(1, chunked_attention_mask.shape[-1], dtype=torch.float32)
-
-         # as gemma3 has different prefill chunk size for image and text, we need to pad the inputs to the max of the two.
-         padding_size = max(self.rbln_config.prefill_chunk_size, self.rbln_config.image_prefill_chunk_size)
-         inputs = torch.nn.functional.pad(inputs, (0, 0, 0, padding_size))
-         cache_position = torch.nn.functional.pad(cache_position, (0, padding_size))
-         position_ids = torch.nn.functional.pad(position_ids, (0, padding_size))
-         token_type_ids = torch.nn.functional.pad(token_type_ids, (0, padding_size), value=-1)
-
-         return (
-             inputs,
-             cache_position,
-             chunked_attention_mask,
-             out_buffers,
-             position_ids,
-             position_embed,
-             padded_cache_lengths,
-             query_length,
-             token_type_ids,
-         )
-
-     def prefill_forward(
-         self,
-         inputs: torch.Tensor,
-         cache_position: torch.Tensor = None,
-         attention_mask: Optional[torch.Tensor] = None,
-         batch_idx: int = None,
-         block_tables: torch.Tensor = None,
-         is_external_block_tables: bool = None,
-         position_embed: Optional[torch.Tensor] = None,
-         token_type_ids: Optional[torch.Tensor] = None,
-         local_block_tables: Optional[torch.Tensor] = None,
-     ) -> torch.FloatTensor:
-         """
-         Performs chunked prefill for efficient KV-cache updates and memory optimization.
-         Instead of processing the entire sequence at once, the input is divided into chunks of size `prefill_chunk_size`,
-         and each chunk is processed sequentially. This allows for better memory utilization and compatibility with continuous batching.
-         """
-         (
-             inputs,
-             cache_position,
-             chunked_attention_mask,
-             out_buffers,
-             position_ids,
-             position_embed,
-             padded_cache_lengths,
-             query_length,
-             token_type_ids,
-         ) = self._prepare_prefill_inputs(
-             inputs, cache_position, attention_mask, position_embed, token_type_ids=token_type_ids
-         )
-
-         step = 0
-         while step < query_length:
-             # Check if the prefill chunk is an image prefill
-             is_image_prefill = torch.all(
-                 token_type_ids[:, step : step + self.rbln_config.image_prefill_chunk_size] == 1
-             )
-             prefill_chunk_size = (
-                 self.rbln_config.image_prefill_chunk_size if is_image_prefill else self.rbln_config.prefill_chunk_size
-             )
-
-             # Check if the prefill chunk is a text prefill which have image_tokens in it.
-             is_text_prefill_with_image_tokens = not is_image_prefill and torch.any(
-                 token_type_ids[:, step : step + prefill_chunk_size] == 1
-             )
-
-             # Check if the prefill chunk crosses a block boundary, requiring padding to align with block boundaries
-             is_cross_block_boundary = (
-                 step // self.rbln_config.kvcache_block_size
-                 != (step + prefill_chunk_size) // self.rbln_config.kvcache_block_size
-             )
-
-             # Check if the prefill chunk is the last chunk
-             is_last_chunk = step + prefill_chunk_size >= query_length
-
-             if is_cross_block_boundary:
-                 padding_size = prefill_chunk_size - (step + prefill_chunk_size) % self.rbln_config.kvcache_block_size
-                 padded_cache_lengths += padding_size
-
-             # if text_prefill end with image_tokens, we only treat the text part.
-             num_processed_tokens = prefill_chunk_size
-             if is_text_prefill_with_image_tokens:
-                 first_image_token_idx = torch.where(token_type_ids[:, step : step + prefill_chunk_size] == 1)[1][0]
-                 num_processed_tokens = first_image_token_idx
-                 if is_last_chunk:
-                     num_processed_tokens = query_length - step
-
-             input_chunk = inputs[:, step : step + prefill_chunk_size]
-             cache_pos_chunk = cache_position[:, step : step + prefill_chunk_size].clone() + padded_cache_lengths
-             position_ids_chunk = position_ids[:, step : step + prefill_chunk_size].clone()
-             chunked_attention_mask[
-                 :, step + padded_cache_lengths : step + num_processed_tokens + padded_cache_lengths
-             ] = 1
-             query_position = torch.tensor(num_processed_tokens - 1, dtype=torch.int16)
-
-             if is_image_prefill:
-                 logits = self.image_prefill(
-                     input_chunk,
-                     cache_pos_chunk,
-                     block_tables,
-                     local_block_tables,
-                     query_position,
-                     chunked_attention_mask,
-                     position_ids_chunk,
-                     out=out_buffers,
-                 )
-             else:
-                 logits = self.prefill(
-                     input_chunk,
-                     cache_pos_chunk,
-                     block_tables,
-                     local_block_tables,
-                     query_position,
-                     chunked_attention_mask,
-                     position_ids_chunk,
-                     out=out_buffers,
-                 )
-
-             step += num_processed_tokens
-
-         if not is_external_block_tables:
-             self.dec_attn_mask[batch_idx : batch_idx + 1] = chunked_attention_mask
-
-         return RBLNGemma3ForCausalLMOutput(
-             logits=logits, padded_cache_lengths=padded_cache_lengths, attention_mask=chunked_attention_mask
-         )
-
-     def decode_forward(
-         self,
-         inputs: torch.Tensor,
-         cache_position: torch.Tensor = None,
-         block_tables: torch.Tensor = None,
-         is_external_block_tables: bool = None,
-         attention_mask: Optional[torch.Tensor] = None,
-         position_embed: Optional[torch.Tensor] = None,
-         position_ids: Optional[torch.Tensor] = None,
-         local_block_tables: Optional[torch.Tensor] = None,
-     ) -> torch.FloatTensor:
-         batch_size = inputs.shape[0]
-         if batch_size != self.batch_size:
-             raise RuntimeError(
-                 f"Batch size mismatch: got {batch_size}, expected {self.batch_size} (compiled batch size)."
-             )
-
-         if batch_size != cache_position.shape[0]:
-             raise RuntimeError(f"Cache position size mismatch: got {cache_position.shape[0]}, expected {batch_size}.")
-
-         # FIXME(taehoon): how to handle pos_attn_mask with external block tables
-         if is_external_block_tables:
-             if attention_mask is None:
-                 raise ValueError("attention_mask should be provided with external block tables.")
-             if local_block_tables is None:
-                 raise ValueError("local_block_tables should be provided with external block tables.")
-         else:
-             local_block_tables = (
-                 local_block_tables
-                 if local_block_tables is not None
-                 else torch.arange(0, self.batch_size, dtype=torch.int16).view(self.batch_size, -1)
-             )
-             if self.rbln_config.use_attention_mask and attention_mask is None:
-                 for b_idx in range(batch_size):
-                     decoding_step = cache_position[b_idx].item()
-                     if not (0 <= decoding_step < self.dec_attn_mask.shape[-1]):
-                         raise ValueError(
-                             f"Decoding step {decoding_step} out of bounds for attention mask with shape {self.dec_attn_mask.shape}."
-                         )
-                     self.dec_attn_mask[b_idx, decoding_step] = 1
-
-                 attention_mask = self.dec_attn_mask
-
-         if self.batch_size < block_tables.shape[0]:
-             block_tables = block_tables[: self.batch_size]
-
-         if attention_mask is not None and self.batch_size < attention_mask.shape[0]:
-             attention_mask = attention_mask[: self.batch_size]
-
-         logits = self.decode(inputs, cache_position, block_tables, local_block_tables, attention_mask, position_ids)
-
-         return RBLNDecoderOnlyOutput(logits=logits)
-
-
  class RBLNGemma3ForCausalLM(RBLNDecoderOnlyModelForCausalLM):
      """
      The Gemma3 Model transformer with a language modeling head (linear layer) on top.
@@ -524,52 +358,36 @@ class RBLNGemma3ForCausalLM(RBLNDecoderOnlyModelForCausalLM):
      """

      _decoder_wrapper_cls = Gemma3ForCausalLMWrapper
+     _supports_non_fp32 = False

-     def __post_init__(self, **kwargs):
-         main_input_name = self.main_input_name
-
-         if self.rbln_config.use_inputs_embeds:
-             main_input_name = "inputs_embeds"
-             artifacts = torch.load(self.model_save_dir / self.subfolder / "torch_artifacts.pth", weights_only=False)
-             self.embed_tokens = self._create_embedding_layer()
-             self.embed_tokens.load_state_dict(artifacts["embed_tokens"])
-         else:
-             self.embed_tokens = None
-
+     def setup_runtime(self):
          # Initialize shared resources to be used across Runtime instances (prefill and decode phases)
          dec_attn_mask = torch.zeros(self.rbln_config.batch_size, self.rbln_config.max_seq_len, dtype=torch.float32)
-         block_tables = torch.zeros(
-             self.rbln_config.batch_size,
-             self.rbln_config.max_seq_len // self.rbln_config.kvcache_block_size,
-             dtype=torch.int16,
-         ).fill_(-1)
-         free_block_pool = deque(x for x in range(self.rbln_config.kvcache_num_blocks))
+         page_table_manager = RBLNPageTableManager(self.rbln_config)
+
+         common_kwargs = {
+             "main_input_name": "inputs_embeds" if self.rbln_config.use_inputs_embeds else "input_ids",
+             "embed_tokens": self.embed_tokens,
+             "dec_attn_mask": dec_attn_mask,
+             "page_table_manager": page_table_manager,
+             "rbln_config": self.rbln_config,
+         }
+
          self.prefill_decoder = RBLNGemma3RuntimeModel(
              runtime=self.model[0],
-             image_prefill=self.model[1],
-             main_input_name=main_input_name,
-             embed_tokens=self.embed_tokens,
+             image_prefill=self.model[1] if self.rbln_config.use_image_prefill else None,
              phase="prefill",
              batch_size=self.rbln_config.batch_size,
-             dec_attn_mask=dec_attn_mask,
-             block_tables=block_tables,
-             vocab_size=self.config.vocab_size,
-             free_block_pool=free_block_pool,
-             rbln_config=self.rbln_config,
+             **common_kwargs,
          )

          self.decoders = {}
          for i, batch_size in enumerate(self.rbln_config.decoder_batch_sizes):
              self.decoders[batch_size] = RBLNGemma3RuntimeModel(
-                 runtime=self.model[i + 2],
-                 main_input_name=main_input_name,
-                 embed_tokens=self.embed_tokens,
+                 runtime=self.model[i + self.rbln_config.decoder_runtime_idx],
                  phase="decode",
                  batch_size=batch_size,
-                 dec_attn_mask=dec_attn_mask,
-                 block_tables=block_tables,
-                 free_block_pool=free_block_pool,
-                 rbln_config=self.rbln_config,
+                 **common_kwargs,
              )

          # NOTE(eunji): Use a decoder whose batch size matches the model's main batch size for compatibility.
@@ -589,6 +407,13 @@ class RBLNGemma3ForCausalLM(RBLNDecoderOnlyModelForCausalLM):
      def _update_sliding_window_config(cls, model_config: PretrainedConfig, rbln_config: RBLNGemma3ForCausalLMConfig):
          sliding_window = getattr(model_config, "sliding_window", None)
          sliding_window_pattern = getattr(model_config, "sliding_window_pattern", None)
+         if sliding_window_pattern is None:
+             if hasattr(model_config, "layer_types"):
+                 first_full_attention_index = model_config.layer_types.index("full_attention")
+                 sliding_window_pattern = first_full_attention_index + 1
+             else:
+                 raise ValueError("Cannot determine sliding_window_pattern from model_config")
+
          if sliding_window_pattern <= model_config.num_hidden_layers:
              rbln_config.cache_impl = "hybrid"
              rbln_config.sliding_window = sliding_window
@@ -599,7 +424,12 @@ class RBLNGemma3ForCausalLM(RBLNDecoderOnlyModelForCausalLM):
          return rbln_config

      @classmethod
-     def _update_submodule_config(cls, model: "PreTrainedModel", rbln_config: RBLNModelConfig):
+     def _update_submodule_config(
+         cls,
+         model: "PreTrainedModel",
+         rbln_config: RBLNModelConfig,
+         preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]],
+     ):
          if rbln_config.image_prefill_chunk_size is None:
              rbln_config.image_prefill_chunk_size = model.config.mm_tokens_per_image

@@ -624,27 +454,33 @@ class RBLNGemma3ForCausalLM(RBLNDecoderOnlyModelForCausalLM):
          if not (rbln_config.use_attention_mask and rbln_config.use_position_ids):
              raise ValueError("use_attention_mask and use_position_ids must be True for RBLNGemma3ForCausalLM")

-         # Update image prefill compile config
-         img_prefill_input_info = cls.get_input_info(
-             batch_size=1,
-             query_length=rbln_config.image_prefill_chunk_size,
-             rbln_config=rbln_config,
-             model_config=model_config,
-         )
-         image_prefill_compile_config = RBLNCompileConfig(
-             compiled_model_name="image_prefill", input_info=img_prefill_input_info
-         )
-         # Insert image_prefill compile config at index 1
-         compile_cfgs = rbln_config.compile_cfgs
-         compile_cfgs.insert(1, image_prefill_compile_config)
-         rbln_config.set_compile_cfgs(compile_cfgs)
+         if rbln_config.use_image_prefill:
+             if rbln_config.prefill_chunk_size != rbln_config.image_prefill_chunk_size:
+                 raise NotImplementedError(
+                     "Not implemented for different prefill chunk sizes between text and image prefill."
+                 )
+
+             # Update image prefill compile config
+             img_prefill_input_info = cls.get_input_info(
+                 batch_size=1,
+                 query_length=rbln_config.image_prefill_chunk_size,
+                 rbln_config=rbln_config,
+                 model_config=model_config,
+             )
+             image_prefill_compile_config = RBLNCompileConfig(
+                 compiled_model_name="image_prefill", input_info=img_prefill_input_info
+             )
+             # Insert image_prefill compile config at index 1
+             compile_cfgs = rbln_config.compile_cfgs
+             compile_cfgs.insert(1, image_prefill_compile_config)
+             rbln_config.set_compile_cfgs(compile_cfgs)

          return rbln_config

      @classmethod
      @torch.inference_mode()
      def get_compiled_model(cls, model: "PreTrainedModel", rbln_config: RBLNGemma3ForCausalLMConfig):
-         wrapped_model = cls.wrap_model_if_needed(model, rbln_config)
+         wrapped_model = cls._wrap_model_if_needed(model, rbln_config)

          rbln_compile_configs = rbln_config.compile_cfgs
          prefill_compile_config = rbln_compile_configs[0]
@@ -690,23 +526,27 @@ class RBLNGemma3ForCausalLM(RBLNDecoderOnlyModelForCausalLM):
              context,
              rbln_config.quantization,
          )
+         compiled_models = {"prefill": compiled_prefill}

-         image_prefill_compile_config = rbln_compile_configs[1]
-         image_prefill_example_inputs = image_prefill_compile_config.get_dummy_inputs(
-             fill=0, static_tensors=static_tensors
-         )
-         wrapped_model.phase = "image_prefill"
-         compiled_image_prefill = compile_model(
-             wrapped_model,
-             image_prefill_compile_config,
-             image_prefill_example_inputs,
-             context,
-             rbln_config.quantization,
-         )
+         if rbln_config.use_image_prefill:
+             image_prefill_compile_config = rbln_compile_configs[1]
+             image_prefill_example_inputs = image_prefill_compile_config.get_dummy_inputs(
+                 fill=0, static_tensors=static_tensors
+             )
+             wrapped_model.phase = "image_prefill"
+             compiled_image_prefill = compile_model(
+                 wrapped_model,
+                 image_prefill_compile_config,
+                 image_prefill_example_inputs,
+                 context,
+                 rbln_config.quantization,
+             )
+             compiled_models["image_prefill"] = compiled_image_prefill

-         compiled_models = {"prefill": compiled_prefill, "image_prefill": compiled_image_prefill}
          wrapped_model.phase = "decode"
-         for batch_size, dec_compile_config in zip(rbln_config.decoder_batch_sizes, rbln_compile_configs[2:]):
+         for batch_size, dec_compile_config in zip(
+             rbln_config.decoder_batch_sizes, rbln_compile_configs[rbln_config.decoder_runtime_idx :]
+         ):
              dec_example_inputs = dec_compile_config.get_dummy_inputs(fill=0, static_tensors=static_tensors)
              compiled_decoder = compile_model(
                  wrapped_model,
@@ -727,35 +567,45 @@
      ) -> List[rebel.Runtime]:
          expected_model_names = [
              "prefill",
-             "image_prefill",
              *[f"decoder_batch_{batch_size}" for batch_size in rbln_config.decoder_batch_sizes],
          ]
+         if rbln_config.use_image_prefill:
+             expected_model_names.insert(1, "image_prefill")
+
          if any(model_name not in rbln_config.device_map for model_name in expected_model_names):
              cls._raise_missing_compiled_file_error(expected_model_names)

-         return [
+         ret_val = [
              rebel.Runtime(
                  compiled_models[0],
                  tensor_type="pt",
                  device=rbln_config.device_map["prefill"],
                  activate_profiler=rbln_config.activate_profiler,
                  timeout=rbln_config.timeout,
-             ),
-             rebel.Runtime(
-                 compiled_models[1],
-                 tensor_type="pt",
-                 device=rbln_config.device_map["image_prefill"],
-                 activate_profiler=rbln_config.activate_profiler,
-                 timeout=rbln_config.timeout,
-             ),
-             *[
+             )
+         ]
+         if rbln_config.use_image_prefill:
+             ret_val.append(
+                 rebel.Runtime(
+                     compiled_models[1],
+                     tensor_type="pt",
+                     device=rbln_config.device_map["image_prefill"],
+                     activate_profiler=rbln_config.activate_profiler,
+                     timeout=rbln_config.timeout,
+                 ),
+             )
+
+         ret_val.extend(
+             [
                  rebel.Runtime(
-                     compiled_models[i + 2],
+                     compiled_models[i + rbln_config.decoder_runtime_idx],
                      tensor_type="pt",
                      device=rbln_config.device_map[f"decoder_batch_{batch_size}"],
                      activate_profiler=rbln_config.activate_profiler,
                      timeout=rbln_config.timeout,
                  )
                  for i, batch_size in enumerate(rbln_config.decoder_batch_sizes)
-             ],
-         ]
+             ]
+         )
+
+         return ret_val
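Note on the refactor in the first hunk: the per-sample looping that LoopVisionTower and LoopProjector previously implemented inline is now inherited from a shared LoopProcessor base, added in optimum/rbln/transformers/utils/rbln_runtime_wrapper.py (file 186 above). The sketch below is only a rough illustration of how such a base could drive the hooks the subclasses override (_get_batch_size, _prepare_inputs_for_iteration, _process_outputs) together with the preallocated out= buffers; it is an assumption written for readability, not the actual optimum-rbln implementation.

# Hypothetical sketch; the real LoopProcessor in rbln_runtime_wrapper.py may differ.
from typing import Any, List


class LoopProcessor:
    """Drives a compiled model one sample at a time instead of as a full batch."""

    def __init__(self, model) -> None:
        self.model = model

    def _get_batch_size(self, *args, **kwargs) -> int:
        # Subclasses return the loop length, e.g. pixel_values.shape[0].
        raise NotImplementedError

    def _prepare_inputs_for_iteration(self, index: int, common_inputs: dict, *args, **kwargs):
        # Subclasses return (positional_args, keyword_args) for one iteration,
        # typically a one-sample slice of the inputs and of the out= buffers.
        raise NotImplementedError

    def _process_outputs(self, outputs: List[Any], **kwargs):
        # Subclasses assemble the final result, e.g. by reading back kwargs["out"].
        raise NotImplementedError

    def forward(self, *args, **kwargs):
        batch_size = self._get_batch_size(*args, **kwargs)
        outputs = []
        for index in range(batch_size):
            iter_args, iter_kwargs = self._prepare_inputs_for_iteration(index, {}, *args, **kwargs)
            # The runtime writes directly into the sliced out= buffers, which is what
            # replaces the torch.cat assembly flagged by the removed FIXME comments.
            outputs.append(self.model(*iter_args, **iter_kwargs))
        return self._process_outputs(outputs, **kwargs)

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        return self.forward(*args, **kwargs)

    def __repr__(self) -> str:
        return repr(self.model)

Under that reading, the new get_image_features in the diff preallocates vision_out_buffer and projector_out_buffer once and lets each per-sample runtime call fill its own slice, rather than concatenating per-sample outputs after the fact.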