optimum-rbln 0.8.2a0__py3-none-any.whl → 0.9.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (197)
  1. optimum/rbln/__init__.py +116 -9
  2. optimum/rbln/__version__.py +16 -3
  3. optimum/rbln/cli.py +660 -0
  4. optimum/rbln/configuration_utils.py +171 -43
  5. optimum/rbln/diffusers/__init__.py +19 -0
  6. optimum/rbln/diffusers/configurations/__init__.py +3 -0
  7. optimum/rbln/diffusers/configurations/models/__init__.py +2 -0
  8. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +3 -3
  9. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_cosmos.py +1 -1
  10. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_temporal_decoder.py +67 -0
  11. optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +3 -3
  12. optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +4 -4
  13. optimum/rbln/diffusers/configurations/models/configuration_transformer_cosmos.py +12 -4
  14. optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +9 -4
  15. optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +3 -3
  16. optimum/rbln/diffusers/configurations/models/configuration_unet_spatio_temporal_condition.py +59 -0
  17. optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +3 -3
  18. optimum/rbln/diffusers/configurations/pipelines/__init__.py +3 -0
  19. optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +35 -19
  20. optimum/rbln/diffusers/configurations/pipelines/configuration_cosmos.py +14 -11
  21. optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +30 -20
  22. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +13 -9
  23. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +17 -13
  24. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +17 -10
  25. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_video_diffusion.py +114 -0
  26. optimum/rbln/diffusers/modeling_diffusers.py +33 -18
  27. optimum/rbln/diffusers/models/__init__.py +4 -0
  28. optimum/rbln/diffusers/models/autoencoders/__init__.py +1 -0
  29. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +32 -3
  30. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py +32 -6
  31. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +275 -0
  32. optimum/rbln/diffusers/models/autoencoders/vae.py +27 -8
  33. optimum/rbln/diffusers/models/autoencoders/vq_model.py +32 -3
  34. optimum/rbln/diffusers/models/controlnet.py +16 -1
  35. optimum/rbln/diffusers/models/transformers/prior_transformer.py +17 -3
  36. optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +26 -3
  37. optimum/rbln/diffusers/models/transformers/transformer_sd3.py +23 -2
  38. optimum/rbln/diffusers/models/unets/__init__.py +1 -0
  39. optimum/rbln/diffusers/models/unets/unet_2d_condition.py +23 -4
  40. optimum/rbln/diffusers/models/unets/unet_spatio_temporal_condition.py +201 -0
  41. optimum/rbln/diffusers/pipelines/__init__.py +15 -5
  42. optimum/rbln/diffusers/pipelines/auto_pipeline.py +307 -0
  43. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +20 -0
  44. optimum/rbln/diffusers/pipelines/cosmos/configuration_cosmos_guardrail.py +23 -12
  45. optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +16 -46
  46. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +31 -1
  47. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +31 -1
  48. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +1 -6
  49. optimum/rbln/diffusers/pipelines/stable_video_diffusion/__init__.py +15 -0
  50. optimum/rbln/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +46 -0
  51. optimum/rbln/modeling.py +50 -24
  52. optimum/rbln/modeling_base.py +116 -35
  53. optimum/rbln/ops/attn.py +158 -0
  54. optimum/rbln/ops/flash_attn.py +166 -0
  55. optimum/rbln/ops/kv_cache_update.py +5 -0
  56. optimum/rbln/ops/linear.py +7 -0
  57. optimum/rbln/transformers/__init__.py +100 -0
  58. optimum/rbln/transformers/configuration_generic.py +7 -32
  59. optimum/rbln/transformers/modeling_attention_utils.py +385 -0
  60. optimum/rbln/transformers/modeling_generic.py +48 -65
  61. optimum/rbln/transformers/modeling_outputs.py +37 -0
  62. optimum/rbln/transformers/models/__init__.py +93 -30
  63. optimum/rbln/transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py +28 -2
  64. optimum/rbln/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +68 -5
  65. optimum/rbln/transformers/models/auto/__init__.py +2 -0
  66. optimum/rbln/transformers/models/auto/auto_factory.py +92 -17
  67. optimum/rbln/transformers/models/auto/modeling_auto.py +45 -0
  68. optimum/rbln/transformers/models/bart/bart_architecture.py +2 -7
  69. optimum/rbln/transformers/models/bart/configuration_bart.py +2 -0
  70. optimum/rbln/transformers/models/bart/modeling_bart.py +23 -2
  71. optimum/rbln/transformers/models/bert/bert_architecture.py +16 -0
  72. optimum/rbln/transformers/models/bert/modeling_bert.py +93 -4
  73. optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +42 -11
  74. optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +135 -44
  75. optimum/rbln/transformers/models/clip/configuration_clip.py +21 -7
  76. optimum/rbln/transformers/models/clip/modeling_clip.py +183 -27
  77. optimum/rbln/transformers/models/colpali/colpali_architecture.py +3 -6
  78. optimum/rbln/transformers/models/colpali/configuration_colpali.py +37 -21
  79. optimum/rbln/transformers/models/colpali/modeling_colpali.py +82 -104
  80. optimum/rbln/transformers/models/colqwen2/__init__.py +2 -0
  81. optimum/rbln/transformers/models/colqwen2/colqwen2_architecture.py +233 -0
  82. optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +74 -0
  83. optimum/rbln/transformers/models/colqwen2/modeling_colqwen2.py +446 -0
  84. optimum/rbln/transformers/models/decoderonly/__init__.py +3 -2
  85. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +114 -37
  86. optimum/rbln/transformers/models/decoderonly/configuration_lora.py +411 -0
  87. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +323 -316
  88. optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +508 -0
  89. optimum/rbln/transformers/models/decoderonly/generation_decoderonly.py +119 -0
  90. optimum/rbln/transformers/models/decoderonly/lora_architecture.py +204 -0
  91. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +486 -892
  92. optimum/rbln/transformers/models/depth_anything/__init__.py +16 -0
  93. optimum/rbln/transformers/models/depth_anything/configuration_depth_anything.py +24 -0
  94. optimum/rbln/transformers/models/depth_anything/modeling_depth_anything.py +42 -0
  95. optimum/rbln/transformers/models/distilbert/modeling_distilbert.py +24 -0
  96. optimum/rbln/transformers/models/dpt/modeling_dpt.py +17 -0
  97. optimum/rbln/transformers/models/exaone/modeling_exaone.py +42 -4
  98. optimum/rbln/transformers/models/gemma/__init__.py +2 -2
  99. optimum/rbln/transformers/models/gemma/configuration_gemma.py +9 -1
  100. optimum/rbln/transformers/models/gemma/gemma_architecture.py +1 -4
  101. optimum/rbln/transformers/models/gemma/modeling_gemma.py +22 -1
  102. optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +49 -14
  103. optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +12 -2
  104. optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py +245 -0
  105. optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +212 -504
  106. optimum/rbln/transformers/models/gpt2/__init__.py +2 -2
  107. optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +31 -3
  108. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +10 -8
  109. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +18 -1
  110. optimum/rbln/transformers/models/grounding_dino/__init__.py +10 -0
  111. optimum/rbln/transformers/models/grounding_dino/configuration_grounding_dino.py +92 -0
  112. optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +599 -0
  113. optimum/rbln/transformers/models/grounding_dino/modeling_grounding_dino.py +1048 -0
  114. optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +35 -7
  115. optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +29 -32
  116. optimum/rbln/transformers/models/llama/__init__.py +2 -2
  117. optimum/rbln/transformers/models/llama/configuration_llama.py +9 -1
  118. optimum/rbln/transformers/models/llama/modeling_llama.py +22 -1
  119. optimum/rbln/transformers/models/llava/__init__.py +16 -0
  120. optimum/rbln/transformers/models/llava/configuration_llava.py +72 -0
  121. optimum/rbln/transformers/models/llava/modeling_llava.py +490 -0
  122. optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +21 -6
  123. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +234 -376
  124. optimum/rbln/transformers/models/midm/midm_architecture.py +4 -1
  125. optimum/rbln/transformers/models/midm/modeling_midm.py +42 -4
  126. optimum/rbln/transformers/models/mistral/__init__.py +2 -2
  127. optimum/rbln/transformers/models/mistral/configuration_mistral.py +9 -1
  128. optimum/rbln/transformers/models/mistral/mistral_architecture.py +1 -1
  129. optimum/rbln/transformers/models/mistral/modeling_mistral.py +26 -3
  130. optimum/rbln/transformers/models/opt/__init__.py +2 -2
  131. optimum/rbln/transformers/models/opt/configuration_opt.py +8 -1
  132. optimum/rbln/transformers/models/opt/modeling_opt.py +29 -17
  133. optimum/rbln/transformers/models/opt/opt_architecture.py +4 -4
  134. optimum/rbln/transformers/models/pegasus/__init__.py +17 -0
  135. optimum/rbln/transformers/models/pegasus/configuration_pegasus.py +38 -0
  136. optimum/rbln/transformers/models/pegasus/modeling_pegasus.py +71 -0
  137. optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +161 -0
  138. optimum/rbln/transformers/models/phi/__init__.py +2 -2
  139. optimum/rbln/transformers/models/phi/configuration_phi.py +9 -1
  140. optimum/rbln/transformers/models/phi/modeling_phi.py +10 -1
  141. optimum/rbln/transformers/models/phi/phi_architecture.py +11 -7
  142. optimum/rbln/transformers/models/pixtral/__init__.py +16 -0
  143. optimum/rbln/transformers/models/pixtral/configuration_pixtral.py +43 -0
  144. optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +322 -0
  145. optimum/rbln/transformers/models/pixtral/pixtral_architecture.py +73 -0
  146. optimum/rbln/transformers/models/qwen2/__init__.py +2 -2
  147. optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +9 -1
  148. optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +27 -1
  149. optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +21 -6
  150. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +15 -22
  151. optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +28 -7
  152. optimum/rbln/transformers/models/qwen2_vl/__init__.py +19 -0
  153. optimum/rbln/transformers/models/qwen2_vl/configuration_qwen2_vl.py +88 -0
  154. optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py +513 -0
  155. optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +165 -0
  156. optimum/rbln/transformers/models/qwen3/__init__.py +16 -0
  157. optimum/rbln/transformers/models/qwen3/configuration_qwen3.py +71 -0
  158. optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +133 -0
  159. optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +31 -0
  160. optimum/rbln/transformers/models/resnet/configuration_resnet.py +17 -0
  161. optimum/rbln/transformers/models/resnet/modeling_resnet.py +73 -0
  162. optimum/rbln/transformers/models/roberta/modeling_roberta.py +33 -0
  163. optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py +21 -16
  164. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +60 -13
  165. optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +2 -2
  166. optimum/rbln/transformers/models/siglip/__init__.py +2 -6
  167. optimum/rbln/transformers/models/siglip/configuration_siglip.py +1 -1
  168. optimum/rbln/transformers/models/siglip/modeling_siglip.py +21 -16
  169. optimum/rbln/transformers/models/swin/__init__.py +16 -0
  170. optimum/rbln/transformers/models/swin/configuration_swin.py +42 -0
  171. optimum/rbln/transformers/models/swin/modeling_swin.py +354 -0
  172. optimum/rbln/transformers/models/t5/configuration_t5.py +2 -0
  173. optimum/rbln/transformers/models/t5/modeling_t5.py +2 -2
  174. optimum/rbln/transformers/models/t5/t5_architecture.py +8 -1
  175. optimum/rbln/transformers/models/time_series_transformer/configuration_time_series_transformer.py +3 -3
  176. optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +22 -16
  177. optimum/rbln/transformers/models/time_series_transformer/time_series_transformers_architecture.py +7 -1
  178. optimum/rbln/transformers/models/vit/modeling_vit.py +19 -0
  179. optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec2.py +15 -3
  180. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +61 -8
  181. optimum/rbln/transformers/models/whisper/configuration_whisper.py +12 -13
  182. optimum/rbln/transformers/models/whisper/generation_whisper.py +62 -6
  183. optimum/rbln/transformers/models/whisper/modeling_whisper.py +32 -5
  184. optimum/rbln/transformers/models/xlm_roberta/__init__.py +2 -8
  185. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +43 -0
  186. optimum/rbln/transformers/utils/rbln_quantization.py +400 -75
  187. optimum/rbln/transformers/utils/rbln_runtime_wrapper.py +79 -0
  188. optimum/rbln/utils/deprecation.py +213 -0
  189. optimum/rbln/utils/hub.py +22 -50
  190. optimum/rbln/utils/runtime_utils.py +85 -17
  191. optimum/rbln/utils/submodule.py +31 -9
  192. {optimum_rbln-0.8.2a0.dist-info → optimum_rbln-0.9.3.dist-info}/METADATA +8 -7
  193. optimum_rbln-0.9.3.dist-info/RECORD +264 -0
  194. {optimum_rbln-0.8.2a0.dist-info → optimum_rbln-0.9.3.dist-info}/WHEEL +1 -1
  195. optimum_rbln-0.9.3.dist-info/entry_points.txt +2 -0
  196. optimum_rbln-0.8.2a0.dist-info/RECORD +0 -211
  197. {optimum_rbln-0.8.2a0.dist-info → optimum_rbln-0.9.3.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/transformers/models/gemma3/modeling_gemma3.py
@@ -11,99 +11,74 @@
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  # See the License for the specific language governing permissions and
  # limitations under the License.
+ import importlib
  import inspect
- from collections import deque
- from dataclasses import dataclass
  from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union

  import rebel
  import torch
  from rebel.compile_context import CompileContext
- from transformers import (
-     AutoModelForImageTextToText,
-     Gemma3ForConditionalGeneration,
-     PretrainedConfig,
-     PreTrainedModel,
- )
+ from transformers import AutoModelForImageTextToText, Gemma3ForConditionalGeneration, PretrainedConfig, PreTrainedModel
  from transformers.modeling_outputs import BaseModelOutputWithPooling
  from transformers.modeling_utils import no_init_weights
  from transformers.models.gemma3.modeling_gemma3 import Gemma3TextScaledWordEmbedding

  from ....configuration_utils import RBLNCompileConfig, RBLNModelConfig
  from ....modeling import RBLNModel
- from ....utils.logging import get_logger
- from ..decoderonly.modeling_decoderonly import RBLNDecoderOnlyModelForCausalLM, RBLNDecoderOnlyOutput, RBLNRuntimeModel
+ from ...modeling_outputs import RBLNDecoderOnlyOutput
+ from ...utils.rbln_runtime_wrapper import LoopProcessor
+ from ..decoderonly.decoderonly_runtime_utils import RBLNPageTableManager
+ from ..decoderonly.generation_decoderonly import RBLNDecoderOnlyGenerationMixin
+ from ..decoderonly.modeling_decoderonly import (
+     RBLNDecoderOnlyModelForCausalLM,
+ )
  from .configuration_gemma3 import RBLNGemma3ForCausalLMConfig
  from .gemma3_architecture import Gemma3ForCausalLMWrapper
-
-
- logger = get_logger()
+ from .gemma3_runtime_utils import RBLNGemma3RuntimeModel


  if TYPE_CHECKING:
      from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer, Gemma3ForConditionalGeneration


- @dataclass
- class RBLNGemma3ForCausalLMOutput(RBLNDecoderOnlyOutput):
-     attention_mask: Optional[torch.Tensor] = None
-
+ class LoopVisionTower(LoopProcessor):
+     def __init__(self, vision_tower: "RBLNModel"):
+         super().__init__(model=vision_tower)

- class LoopVisionTower:
-     def __init__(self, vision_tower: RBLNModel) -> None:
-         self.vision_tower = vision_tower
+     def _get_batch_size(self, pixel_values, **kwargs):
+         return pixel_values.shape[0]

-     def forward(self, *args, **kwargs):
-         # Loop instead of batch
-         # shape of pixel_values : [batch, num_channel, height, width]
-         pixel_values = args[0]
+     def _prepare_inputs_for_iteration(self, index, common_inputs, pixel_values, **kwargs):
+         pixel_values_item = pixel_values[index : index + 1]
+         out_buffer = [tensor[index : index + 1] for tensor in kwargs["out"]]
+         return ([pixel_values_item], {"out": out_buffer})

-         batch_size = pixel_values.shape[0]
-         outputs = []
-         for i in range(batch_size):
-             outputs.append(self.vision_tower(pixel_values=pixel_values[i : i + 1], return_dict=True))
-
-         last_hidden_states = [output.last_hidden_state for output in outputs]
-
-         # FIXME:: This can be optimized using out= API of rbln runtime.
-         last_hidden_states = torch.cat(last_hidden_states, dim=0)
+     def _process_outputs(self, outputs: list, **kwargs) -> "BaseModelOutputWithPooling":
+         output = kwargs["out"]

          return BaseModelOutputWithPooling(
-             last_hidden_state=last_hidden_states,
+             last_hidden_state=output[0],
          )

-     def __call__(self, *args: Any, **kwds: Any) -> Any:
-         return self.forward(*args, **kwds)
-
-     def __repr__(self) -> str:
-         return repr(self.vision_tower)
-

- class LoopProjector:
-     def __init__(self, multi_modal_projector) -> None:
-         self.multi_modal_projector = multi_modal_projector
+ class LoopProjector(LoopProcessor):
+     def __init__(self, multi_modal_projector: "RBLNModel"):
+         super().__init__(model=multi_modal_projector)

-     def forward(self, *args, **kwargs):
-         # Loop instead of batch
-         image_feature = args[0]
+     def _get_batch_size(self, image_feature, **kwargs):
+         return image_feature.shape[0]

-         batch_size = image_feature.shape[0]
-         outputs = []
-         for i in range(batch_size):
-             outputs.append(self.multi_modal_projector(image_feature[i : i + 1]))
+     def _prepare_inputs_for_iteration(self, index, common_inputs, image_feature, **kwargs):
+         image_feature_item = image_feature[index : index + 1]
+         out_buffer = [tensor[index : index + 1] for tensor in kwargs["out"]]
+         return ([image_feature_item], {"out": out_buffer})

-         # FIXME:: This can be optimized using out= API of rbln runtime.
-         outputs = torch.cat(outputs, dim=0)
-         return outputs
+     def _process_outputs(self, outputs: list, **kwargs):
+         output = kwargs["out"]
+         return output[0]

-     def __call__(self, *args: Any, **kwds: Any) -> Any:
-         return self.forward(*args, **kwds)

-     def __repr__(self) -> str:
-         return repr(self.multi_modal_projector)
-
-
- class RBLNGemma3ForConditionalGeneration(RBLNModel):
+ class RBLNGemma3ForConditionalGeneration(RBLNModel, RBLNDecoderOnlyGenerationMixin):
      auto_model_class = AutoModelForImageTextToText
      _rbln_submodules = [
          {"name": "vision_tower"},
@@ -123,6 +98,21 @@ class RBLNGemma3ForConditionalGeneration(RBLNModel):
      def can_generate(self):
          return True

+     @classmethod
+     def _reconstruct_model_if_needed(cls, model: "PreTrainedModel"):
+         with no_init_weights():
+             model_cls_name = model.model.language_model.__class__.__name__
+             causal_model_cls_name = model_cls_name.replace("TextModel", "ForCausalLM")
+             causal_model_cls = getattr(importlib.import_module("transformers"), causal_model_cls_name)
+             new_language_model = causal_model_cls(model.model.language_model.config)
+
+         new_language_model.lm_head = model.lm_head
+         new_language_model.model = model.model.language_model
+         model.model.language_model = new_language_model
+         model.lm_head = None
+         del model.lm_head
+         return model
+
      def __post_init__(self, **kwargs):
          self.vision_tower = LoopVisionTower(self.rbln_submodules[0])
          self.language_model = self.rbln_submodules[1]
@@ -143,7 +133,7 @@ class RBLNGemma3ForConditionalGeneration(RBLNModel):
          return self.language_model.get_input_embeddings()

      @classmethod
-     def wrap_model_if_needed(cls, model: "PreTrainedModel", rbln_config: RBLNModelConfig):
+     def _wrap_model_if_needed(cls, model: "PreTrainedModel", rbln_config: RBLNModelConfig):
          return model.multi_modal_projector

      @classmethod
@@ -212,18 +202,30 @@ class RBLNGemma3ForConditionalGeneration(RBLNModel):
          return model_kwargs

      def get_image_features(self, pixel_values: torch.Tensor) -> torch.Tensor:
-         """
-         Projects the last hidden state from the vision model into language model space.
-
-         Args:
-             pixel_values: (`torch.FloatTensor` of shape `(batch_size, channels, height, width)`)
-                 The tensors corresponding to the input images.
-
-         Returns:
-             Image feature tensor of shape `(num_images, image_length, embed_dim)`.
-         """
-         vision_outputs = self.vision_tower(pixel_values).last_hidden_state
-         image_features = self.multi_modal_projector(vision_outputs)
+         # Projects the last hidden state from the vision model into language model space.
+
+         # Args:
+         #     pixel_values: (`torch.FloatTensor` of shape `(batch_size, channels, height, width)`)
+         #         The tensors corresponding to the input images.
+
+         # Returns:
+         #     Image feature tensor of shape `(num_images, image_length, embed_dim)`.
+
+         vision_out_buffer = []
+         vision_out_size = [
+             pixel_values.shape[0],
+             (self.config.vision_config.image_size // self.config.vision_config.patch_size) ** 2,
+             self.config.vision_config.hidden_size,
+         ]
+         projector_out_size = [
+             pixel_values.shape[0],
+             self.config.mm_tokens_per_image,
+             self.config.text_config.hidden_size,
+         ]
+         vision_out_buffer.append(torch.empty(size=vision_out_size, dtype=torch.float32, device="cpu"))
+         projector_out_buffer = [torch.empty(size=projector_out_size, dtype=torch.float32, device="cpu")]
+         vision_outputs = self.vision_tower(pixel_values, out=vision_out_buffer).last_hidden_state
+         image_features = self.multi_modal_projector(vision_outputs, out=projector_out_buffer)
          return image_features

      def _preprocess_prefill(
@@ -258,17 +260,45 @@ class RBLNGemma3ForConditionalGeneration(RBLNModel):

          return inputs_embeds

+     def get_padded_cache_position(
+         self,
+         cache_position: torch.Tensor,  # shape: [1, seq_len]
+         token_type_ids: torch.Tensor,  # shape: [1, seq_len]
+     ) -> torch.Tensor:
+         seq_len = cache_position[0][-1].item() + 1
+
+         # Find image start positions
+         image_starts = [
+             s
+             for s in torch.where(token_type_ids == 1)[1]
+             if torch.all(token_type_ids[:, s : s + self.rbln_config.image_prefill_chunk_size] == 1)
+         ]
+
+         # Initialize padded tensors
+         padded_input_len = seq_len
+         for image_start in image_starts:
+             pad_needed = (
+                 self.rbln_config.image_prefill_chunk_size
+                 - (image_start + padded_input_len - seq_len) % self.rbln_config.image_prefill_chunk_size
+             ) % self.rbln_config.image_prefill_chunk_size
+             padded_input_len += pad_needed
+
+         return torch.cat(
+             [cache_position, torch.arange(seq_len, padded_input_len, dtype=torch.int32).unsqueeze(0)],
+             dim=1,
+         )
+
      def forward(
          self,
          input_ids: torch.LongTensor = None,
+         attention_mask: torch.Tensor = None,
+         token_type_ids: torch.Tensor = None,
          pixel_values: torch.FloatTensor = None,
-         attention_mask: Optional[torch.Tensor] = None,
          cache_position: Optional[torch.LongTensor] = None,
          inputs_embeds: Optional[torch.FloatTensor] = None,
          generate_idx: Optional[torch.Tensor] = None,
          padded_cache_lengths: Optional[torch.Tensor] = None,
          position_ids: Optional[torch.Tensor] = None,
-         token_type_ids: Optional[torch.Tensor] = None,
          **lm_kwargs: Dict[str, Any],
      ) -> Union[Tuple, RBLNDecoderOnlyOutput]:
          # prefill
@@ -279,12 +309,15 @@ class RBLNGemma3ForConditionalGeneration(RBLNModel):

              for b_idx in range(batch_size):
                  cache_position = torch.arange(0, generate_idx[b_idx].item(), dtype=torch.int32).unsqueeze(0)
+                 token_type_id = token_type_ids[b_idx : b_idx + 1, attention_mask[b_idx].bool()]
+                 cache_position = self.get_padded_cache_position(cache_position, token_type_id)
+
                  output = self.language_model.prefill_decoder(
                      inputs_embeds=inputs_embeds[b_idx : b_idx + 1],
                      attention_mask=attention_mask[b_idx],
                      cache_position=cache_position,
                      batch_idx=b_idx,
-                     token_type_ids=token_type_ids[b_idx : b_idx + 1] if token_type_ids is not None else None,
+                     token_type_ids=token_type_ids[b_idx : b_idx + 1],  # do not pass token_type_id
                  )
                  padded_cache_lengths[b_idx] += output.padded_cache_lengths
                  logits.append(output.logits)
@@ -313,362 +346,6 @@ class RBLNGemma3ForConditionalGeneration(RBLNModel):
          )


- class RBLNGemma3RuntimeModel(RBLNRuntimeModel):
-     def __init__(self, *args, image_prefill: Optional[rebel.Runtime] = None, **kwargs):
-         super().__init__(*args, **kwargs)
-         self.image_prefill = image_prefill  # FIXME(taehoon)
-         self.prefill = self.runtime if self.phase == "prefill" else None  # FIXME
-         self.decode = self.runtime if self.phase == "decode" else None
-
-     def pad_for_chunked_images(
-         self,
-         inputs: torch.Tensor,
-         attention_mask: torch.Tensor,
-         position_ids: torch.Tensor,
-         token_type_ids: Optional[torch.Tensor] = None,
-     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, torch.Tensor]:
-         """
-         Pads inputs, attention_mask, and position_ids so image token groups (256 tokens with token_type_ids == 1)
-         start at multiples of prefill_chunk_size (256). Returns padded tensors and total padded length.
-
-         Args:
-             inputs: (1, seq_len, hidden_size) tensor.
-             attention_mask: (1, seq_len) tensor, 1 for valid, 0 for masked.
-             position_ids: (1, seq_len) tensor for RoPE.
-             token_type_ids: (1, seq_len) tensor, 0 for text, 1 for image.
-
-         Returns:
-             (inputs_padded, attention_mask_padded, position_ids_padded, padded_len, token_type_ids_padded).
-         """
-
-         if token_type_ids is None:
-             return inputs, attention_mask, position_ids, 0, torch.zeros(inputs.shape[:2], dtype=torch.long)
-
-         seq_len = inputs.shape[1]
-
-         # Find image start positions
-         image_starts = [
-             s
-             for s in range(seq_len - self.rbln_config.prefill_chunk_size + 1)
-             if torch.all(token_type_ids[:, s : s + self.rbln_config.prefill_chunk_size] == 1)
-         ]
-
-         # Initialize padded tensors
-         padded_input_len = seq_len
-         for image_start in image_starts:
-             pad_needed = (
-                 self.rbln_config.prefill_chunk_size
-                 - (image_start + padded_input_len - seq_len) % self.rbln_config.prefill_chunk_size
-             ) % self.rbln_config.prefill_chunk_size
-             padded_input_len += pad_needed
-         total_padding = padded_input_len - seq_len
-
-         if inputs.dim() == 3:
-             inputs_padded = torch.zeros(1, padded_input_len, inputs.shape[2], dtype=inputs.dtype)
-         else:
-             inputs_padded = torch.zeros(1, padded_input_len, dtype=inputs.dtype)
-         attention_mask_padded = torch.zeros(1, padded_input_len, dtype=attention_mask.dtype)
-         position_ids_padded = torch.zeros(1, padded_input_len, dtype=position_ids.dtype)
-         token_type_ids_padded = torch.zeros(1, padded_input_len, dtype=token_type_ids.dtype)
-
-         # Fill padded tensors
-         dest_pos = 0
-         src_pos = 0
-         last_pos_id = -1
-         for image_start in image_starts + [seq_len]:
-             # Text segment
-             if src_pos < image_start:
-                 length = image_start - src_pos
-                 inputs_padded[:, dest_pos : dest_pos + length] = inputs[:, src_pos:image_start]
-                 attention_mask_padded[:, dest_pos : dest_pos + length] = attention_mask[:, src_pos:image_start]
-                 position_ids_padded[:, dest_pos : dest_pos + length] = position_ids[:, src_pos:image_start]
-                 token_type_ids_padded[:, dest_pos : dest_pos + length] = token_type_ids[:, src_pos:image_start]
-                 dest_pos += length
-                 last_pos_id = position_ids[0, image_start - 1].item()
-                 src_pos = image_start
-
-             # Padding
-             pad_needed = (
-                 self.rbln_config.prefill_chunk_size - dest_pos % self.rbln_config.prefill_chunk_size
-             ) % self.rbln_config.prefill_chunk_size
-             if pad_needed and dest_pos < padded_input_len:
-                 position_ids_padded[:, dest_pos : dest_pos + pad_needed] = torch.arange(
-                     last_pos_id + 1, last_pos_id + pad_needed + 1, dtype=position_ids.dtype
-                 ).unsqueeze(0)
-                 dest_pos += pad_needed
-
-             # Image segment
-             if src_pos < seq_len and src_pos == image_start:
-                 inputs_padded[:, dest_pos : dest_pos + self.rbln_config.prefill_chunk_size] = inputs[
-                     :, src_pos : src_pos + self.rbln_config.prefill_chunk_size
-                 ]
-                 attention_mask_padded[:, dest_pos : dest_pos + self.rbln_config.prefill_chunk_size] = attention_mask[
-                     :, src_pos : src_pos + self.rbln_config.prefill_chunk_size
-                 ]
-                 position_ids_padded[:, dest_pos : dest_pos + self.rbln_config.prefill_chunk_size] = position_ids[
-                     :, src_pos : src_pos + self.rbln_config.prefill_chunk_size
-                 ]
-                 token_type_ids_padded[:, dest_pos : dest_pos + self.rbln_config.prefill_chunk_size] = token_type_ids[
-                     :, src_pos : src_pos + self.rbln_config.prefill_chunk_size
-                 ]
-                 dest_pos += self.rbln_config.prefill_chunk_size
-                 src_pos += self.rbln_config.prefill_chunk_size
-                 last_pos_id = position_ids[0, image_start + self.rbln_config.prefill_chunk_size - 1].item()
-
-         return inputs_padded, attention_mask_padded, position_ids_padded, total_padding, token_type_ids_padded
-
-     def _prepare_prefill_inputs(
-         self,
-         inputs: torch.Tensor,
-         cache_position: torch.Tensor,
-         attention_mask: Optional[torch.Tensor] = None,
-         position_embed: Optional[torch.Tensor] = None,
-         token_type_ids: Optional[torch.Tensor] = None,
-     ):
-         """
-         Prepare inputs for prefill phase.
-         """
-         # Handle continuous batching in a compiled graph by extracting valid inputs
-         # If an attention mask is provided, select only the valid (non-masked) inputs
-         inputs = inputs[:, attention_mask.bool()] if attention_mask is not None else inputs
-         token_type_ids = (
-             token_type_ids[:, attention_mask.bool()]
-             if attention_mask is not None and token_type_ids is not None
-             else token_type_ids
-         )
-
-         if position_embed is not None:
-             position_embed = (
-                 position_embed[:, :, :, attention_mask.bool(), :] if attention_mask is not None else position_embed
-             )
-
-         seq_len = inputs.shape[1]
-         # Initialize attention mask for chunked processing
-         if self.rbln_config.use_attention_mask:
-             chunked_attention_mask = (
-                 torch.ones(1, seq_len, dtype=torch.float32)
-                 if self.rbln_config.use_position_ids
-                 else torch.zeros(
-                     1, 1, self.rbln_config.prefill_chunk_size, self.rbln_config.max_seq_len, dtype=torch.float32
-                 )
-             )
-         else:
-             chunked_attention_mask = None
-
-         # Buffer for storing output logits
-         out_buffers = [
-             torch.empty(
-                 size=self.output_size,
-                 dtype=torch.float32,
-                 device="cpu",
-             )
-         ]
-
-         inputs, chunked_attention_mask, position_ids, padded_cache_lengths, token_type_ids_padded = (
-             self.pad_for_chunked_images(inputs, chunked_attention_mask, cache_position, token_type_ids)
-         )
-
-         query_length = inputs.shape[1]
-         if query_length > self.rbln_config.max_seq_len:
-             raise ValueError(
-                 f"Input length ({query_length}) exceeds the maximum allowed sequence length ({self.rbln_config.max_seq_len})."
-             )
-
-         # Align attention_mask to compiled shape
-         if self.rbln_config.use_position_ids:
-             chunked_attention_mask = torch.nn.functional.pad(
-                 chunked_attention_mask, (0, self.rbln_config.max_seq_len - query_length)
-             )
-
-         # Pad input and cache_position if the last chunk is smaller than `prefill_chunk_size`
-         padding_size = 0
-         if query_length % self.rbln_config.prefill_chunk_size != 0:
-             padding_size = (self.rbln_config.prefill_chunk_size - query_length) % self.rbln_config.prefill_chunk_size
-             # inputs_embeds
-             if inputs.dim() == 3:
-                 inputs = torch.nn.functional.pad(inputs, (0, 0, 0, padding_size))
-             # inputs_ids
-             else:
-                 inputs = torch.nn.functional.pad(inputs, (0, padding_size))
-
-             position_ids = torch.cat(
-                 [
-                     position_ids,
-                     torch.arange(
-                         query_length,
-                         query_length + padding_size,
-                         dtype=torch.int32,
-                     ).unsqueeze(0),
-                 ],
-                 dim=-1,
-             )
-             token_type_ids_padded = torch.nn.functional.pad(token_type_ids_padded, (0, padding_size))
-
-         if position_embed is not None:
-             position_embed = torch.nn.functional.pad(position_embed, (0, 0, 0, padding_size))
-
-         cache_position = torch.arange(0, query_length + padding_size, dtype=torch.int32).unsqueeze(0)
-
-         return (
-             inputs,
-             cache_position,
-             chunked_attention_mask,
-             out_buffers,
-             position_ids,
-             position_embed,
-             padded_cache_lengths,
-             query_length,
-             token_type_ids_padded,
-         )
-
-     def prefill_forward(
-         self,
-         inputs: torch.Tensor,
-         cache_position: torch.Tensor = None,
-         attention_mask: Optional[torch.Tensor] = None,
-         batch_idx: int = None,
-         block_tables: torch.Tensor = None,
-         is_external_block_tables: bool = None,
-         position_embed: Optional[torch.Tensor] = None,
-         token_type_ids: Optional[torch.Tensor] = None,
-         local_block_tables: Optional[torch.Tensor] = None,
-     ) -> torch.FloatTensor:
-         """
-         Performs chunked prefill for efficient KV-cache updates and memory optimization.
-         Instead of processing the entire sequence at once, the input is divided into chunks of size `prefill_chunk_size`,
-         and each chunk is processed sequentially. This allows for better memory utilization and compatibility with continuous batching.
-         """
-         (
-             inputs,
-             cache_position,
-             padded_attention_mask,
-             out_buffers,
-             position_ids,
-             position_embed,
-             padded_cache_lengths,
-             query_length,
-             token_type_ids_padded,
-         ) = self._prepare_prefill_inputs(
-             inputs, cache_position, attention_mask, position_embed, token_type_ids=token_type_ids
-         )
-         if not is_external_block_tables:
-             local_block_tables = torch.tensor([batch_idx], dtype=torch.int16)
-             self.dec_attn_mask[batch_idx : batch_idx + 1] = padded_attention_mask[:1]
-
-         if self.rbln_config.use_attention_mask and self.rbln_config.use_position_ids:
-             chunked_attention_mask = torch.zeros(1, self.rbln_config.max_seq_len, dtype=torch.float32)
-
-         # Process input in chunks of size `prefill_chunk_size`
-         for step in range(0, query_length, self.rbln_config.prefill_chunk_size):
-             # Extract the current chunk of inputs and cache positions
-             input_chunk = inputs[:, step : step + self.rbln_config.prefill_chunk_size]
-             cache_pos_chunk = cache_position[:, step : step + self.rbln_config.prefill_chunk_size]
-             position_ids_chunk = (
-                 position_ids[:, step : step + self.rbln_config.prefill_chunk_size]
-                 if position_ids is not None
-                 else None
-             )
-
-             if self.rbln_config.use_attention_mask:
-                 if self.rbln_config.use_position_ids:
-                     chunked_attention_mask[0, step : step + self.rbln_config.prefill_chunk_size] = (
-                         padded_attention_mask[0, step : step + self.rbln_config.prefill_chunk_size]
-                     )
-
-             # Define query position
-             query_position = (
-                 torch.sum(
-                     chunked_attention_mask[0][step : step + self.rbln_config.prefill_chunk_size],
-                     dim=-1,
-                     dtype=torch.int16,
-                 ).squeeze(0)
-                 - 1
-             )
-             if token_type_ids_padded[:, step] == 1:
-                 if torch.any(token_type_ids_padded[:, step : step + self.rbln_config.prefill_chunk_size] == 0):
-                     raise ValueError("All tokens of image_prefill should be the same image.")
-                 else:
-                     logits = self.image_prefill(
-                         input_chunk,
-                         cache_pos_chunk,
-                         block_tables,
-                         local_block_tables,
-                         query_position,
-                         chunked_attention_mask,
-                         position_ids_chunk,
-                         out=out_buffers,
-                     )
-             else:
-                 # Forward pass for the current chunk
-                 logits = self.prefill(
-                     input_chunk,
-                     cache_pos_chunk,
-                     block_tables,
-                     local_block_tables,
-                     query_position,
-                     chunked_attention_mask,
-                     position_ids_chunk,
-                     out=out_buffers,
-                 )
-
-         return RBLNGemma3ForCausalLMOutput(
-             logits=logits, padded_cache_lengths=padded_cache_lengths, attention_mask=chunked_attention_mask
-         )
-
-     def decode_forward(
-         self,
-         inputs: torch.Tensor,
-         cache_position: torch.Tensor = None,
-         block_tables: torch.Tensor = None,
-         is_external_block_tables: bool = None,
-         attention_mask: Optional[torch.Tensor] = None,
-         position_embed: Optional[torch.Tensor] = None,
-         position_ids: Optional[torch.Tensor] = None,
-         local_block_tables: Optional[torch.Tensor] = None,
-     ) -> torch.FloatTensor:
-         batch_size = inputs.shape[0]
-         if batch_size != self.batch_size:
-             raise RuntimeError(
-                 f"Batch size mismatch: got {batch_size}, expected {self.batch_size} (compiled batch size)."
-             )
-
-         if batch_size != cache_position.shape[0]:
-             raise RuntimeError(f"Cache position size mismatch: got {cache_position.shape[0]}, expected {batch_size}.")
-
-         # FIXME(taehoon): how to handle pos_attn_mask with external block tables
-         if is_external_block_tables:
-             if attention_mask is None:
-                 raise ValueError("attention_mask should be provided with external block tables.")
-             if local_block_tables is None:
-                 raise ValueError("local_block_tables should be provided with external block tables.")
-         else:
-             local_block_tables = (
-                 local_block_tables
-                 if local_block_tables is not None
-                 else torch.arange(0, self.batch_size, dtype=torch.int16).view(self.batch_size, -1)
-             )
-             if self.rbln_config.use_attention_mask and attention_mask is None:
-                 for b_idx in range(batch_size):
-                     decoding_step = cache_position[b_idx].item()
-                     if not (0 <= decoding_step < self.dec_attn_mask.shape[-1]):
-                         raise ValueError(
-                             f"Decoding step {decoding_step} out of bounds for attention mask with shape {self.dec_attn_mask.shape}."
-                         )
-                     self.dec_attn_mask[b_idx, decoding_step] = 1
-
-                 attention_mask = self.dec_attn_mask
-
-         if self.batch_size < block_tables.shape[0]:
-             block_tables = block_tables[: self.batch_size]
-
-         if attention_mask is not None and self.batch_size < attention_mask.shape[0]:
-             attention_mask = attention_mask[: self.batch_size]
-
-         logits = self.decode(inputs, cache_position, block_tables, local_block_tables, attention_mask, position_ids)
-
-         return RBLNDecoderOnlyOutput(logits=logits)
-
-
  class RBLNGemma3ForCausalLM(RBLNDecoderOnlyModelForCausalLM):
      """
      The Gemma3 Model transformer with a language modeling head (linear layer) on top.
@@ -681,52 +358,36 @@ class RBLNGemma3ForCausalLM(RBLNDecoderOnlyModelForCausalLM):
      """

      _decoder_wrapper_cls = Gemma3ForCausalLMWrapper
+     _supports_non_fp32 = False

-     def __post_init__(self, **kwargs):
-         main_input_name = self.main_input_name
-
-         if self.rbln_config.use_inputs_embeds:
-             main_input_name = "inputs_embeds"
-             artifacts = torch.load(self.model_save_dir / self.subfolder / "torch_artifacts.pth", weights_only=False)
-             self.embed_tokens = self._create_embedding_layer()
-             self.embed_tokens.load_state_dict(artifacts["embed_tokens"])
-         else:
-             self.embed_tokens = None
-
+     def setup_runtime(self):
          # Initialize shared resources to be used across Runtime instances (prefill and decode phases)
          dec_attn_mask = torch.zeros(self.rbln_config.batch_size, self.rbln_config.max_seq_len, dtype=torch.float32)
-         block_tables = torch.zeros(
-             self.rbln_config.batch_size,
-             self.rbln_config.max_seq_len // self.rbln_config.kvcache_block_size,
-             dtype=torch.int16,
-         ).fill_(-1)
-         free_block_pool = deque(x for x in range(self.rbln_config.kvcache_num_blocks))
+         page_table_manager = RBLNPageTableManager(self.rbln_config)
+
+         common_kwargs = {
+             "main_input_name": "inputs_embeds" if self.rbln_config.use_inputs_embeds else "input_ids",
+             "embed_tokens": self.embed_tokens,
+             "dec_attn_mask": dec_attn_mask,
+             "page_table_manager": page_table_manager,
+             "rbln_config": self.rbln_config,
+         }
+
         self.prefill_decoder = RBLNGemma3RuntimeModel(
             runtime=self.model[0],
-             image_prefill=self.model[1],
-             main_input_name=main_input_name,
-             embed_tokens=self.embed_tokens,
+             image_prefill=self.model[1] if self.rbln_config.use_image_prefill else None,
             phase="prefill",
             batch_size=self.rbln_config.batch_size,
-             dec_attn_mask=dec_attn_mask,
-             block_tables=block_tables,
-             vocab_size=self.config.vocab_size,
-             free_block_pool=free_block_pool,
-             rbln_config=self.rbln_config,
+             **common_kwargs,
         )

         self.decoders = {}
         for i, batch_size in enumerate(self.rbln_config.decoder_batch_sizes):
             self.decoders[batch_size] = RBLNGemma3RuntimeModel(
-                 runtime=self.model[i + 2],
-                 main_input_name=main_input_name,
-                 embed_tokens=self.embed_tokens,
+                 runtime=self.model[i + self.rbln_config.decoder_runtime_idx],
                 phase="decode",
                 batch_size=batch_size,
-                 dec_attn_mask=dec_attn_mask,
-                 block_tables=block_tables,
-                 free_block_pool=free_block_pool,
-                 rbln_config=self.rbln_config,
+                 **common_kwargs,
             )

         # NOTE(eunji): Use a decoder whose batch size matches the model's main batch size for compatibility.
@@ -746,6 +407,13 @@ class RBLNGemma3ForCausalLM(RBLNDecoderOnlyModelForCausalLM):
      def _update_sliding_window_config(cls, model_config: PretrainedConfig, rbln_config: RBLNGemma3ForCausalLMConfig):
          sliding_window = getattr(model_config, "sliding_window", None)
          sliding_window_pattern = getattr(model_config, "sliding_window_pattern", None)
+         if sliding_window_pattern is None:
+             if hasattr(model_config, "layer_types"):
+                 first_full_attention_index = model_config.layer_types.index("full_attention")
+                 sliding_window_pattern = first_full_attention_index + 1
+             else:
+                 raise ValueError("Cannot determine sliding_window_pattern from model_config")
+
          if sliding_window_pattern <= model_config.num_hidden_layers:
              rbln_config.cache_impl = "hybrid"
              rbln_config.sliding_window = sliding_window
@@ -756,14 +424,20 @@ class RBLNGemma3ForCausalLM(RBLNDecoderOnlyModelForCausalLM):
          return rbln_config

      @classmethod
-     def _update_submodule_config(cls, model: "PreTrainedModel", rbln_config: RBLNModelConfig):
-         if rbln_config.prefill_chunk_size is None:
-             rbln_config.prefill_chunk_size = model.config.mm_tokens_per_image
+     def _update_submodule_config(
+         cls,
+         model: "PreTrainedModel",
+         rbln_config: RBLNModelConfig,
+         preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]],
+     ):
+         if rbln_config.image_prefill_chunk_size is None:
+             rbln_config.image_prefill_chunk_size = model.config.mm_tokens_per_image

-         if rbln_config.prefill_chunk_size != model.config.mm_tokens_per_image:
-             logger.warning(
-                 f"Prefill chunk size is different from mm_tokens_per_image: {rbln_config.prefill_chunk_size} != {model.config.mm_tokens_per_image}"
+         if rbln_config.image_prefill_chunk_size != model.config.mm_tokens_per_image:
+             raise ValueError(
+                 f"Image prefill chunk size is different from mm_tokens_per_image: {rbln_config.image_prefill_chunk_size} != {model.config.mm_tokens_per_image}"
              )
+
          return rbln_config

      @classmethod
@@ -777,22 +451,36 @@ class RBLNGemma3ForCausalLM(RBLNDecoderOnlyModelForCausalLM):
          # Update rbln_config with super class
          rbln_config = super()._update_rbln_config(preprocessors, model, model_config, rbln_config)

-         # Assume that prefill compile config is at index 0
-         compile_cfgs = rbln_config.compile_cfgs
-         image_prefill_compile_config = RBLNCompileConfig(
-             compiled_model_name="image_prefill", input_info=compile_cfgs[0].input_info
-         )
-         # Insert image_prefill compile config at index 1
-         image_idx = 1
-         compile_cfgs.insert(image_idx, image_prefill_compile_config)
-         rbln_config.set_compile_cfgs(compile_cfgs)
+         if not (rbln_config.use_attention_mask and rbln_config.use_position_ids):
+             raise ValueError("use_attention_mask and use_position_ids must be True for RBLNGemma3ForCausalLM")
+
+         if rbln_config.use_image_prefill:
+             if rbln_config.prefill_chunk_size != rbln_config.image_prefill_chunk_size:
+                 raise NotImplementedError(
+                     "Not implemented for different prefill chunk sizes between text and image prefill."
+                 )
+
+             # Update image prefill compile config
+             img_prefill_input_info = cls.get_input_info(
+                 batch_size=1,
+                 query_length=rbln_config.image_prefill_chunk_size,
+                 rbln_config=rbln_config,
+                 model_config=model_config,
+             )
+             image_prefill_compile_config = RBLNCompileConfig(
+                 compiled_model_name="image_prefill", input_info=img_prefill_input_info
+             )
+             # Insert image_prefill compile config at index 1
+             compile_cfgs = rbln_config.compile_cfgs
+             compile_cfgs.insert(1, image_prefill_compile_config)
+             rbln_config.set_compile_cfgs(compile_cfgs)

          return rbln_config

      @classmethod
      @torch.inference_mode()
      def get_compiled_model(cls, model: "PreTrainedModel", rbln_config: RBLNGemma3ForCausalLMConfig):
-         wrapped_model = cls.wrap_model_if_needed(model, rbln_config)
+         wrapped_model = cls._wrap_model_if_needed(model, rbln_config)

          rbln_compile_configs = rbln_config.compile_cfgs
          prefill_compile_config = rbln_compile_configs[0]
@@ -838,20 +526,27 @@ class RBLNGemma3ForCausalLM(RBLNDecoderOnlyModelForCausalLM):
              context,
              rbln_config.quantization,
          )
+         compiled_models = {"prefill": compiled_prefill}

-         image_prefill_compile_config = rbln_compile_configs[1]
-         wrapped_model.phase = "image_prefill"
-         compiled_image_prefill = compile_model(
-             wrapped_model,
-             image_prefill_compile_config,
-             prefill_example_inputs,
-             context,
-             rbln_config.quantization,
-         )
+         if rbln_config.use_image_prefill:
+             image_prefill_compile_config = rbln_compile_configs[1]
+             image_prefill_example_inputs = image_prefill_compile_config.get_dummy_inputs(
+                 fill=0, static_tensors=static_tensors
+             )
+             wrapped_model.phase = "image_prefill"
+             compiled_image_prefill = compile_model(
+                 wrapped_model,
+                 image_prefill_compile_config,
+                 image_prefill_example_inputs,
+                 context,
+                 rbln_config.quantization,
+             )
+             compiled_models["image_prefill"] = compiled_image_prefill

-         compiled_models = {"prefill": compiled_prefill, "image_prefill": compiled_image_prefill}
          wrapped_model.phase = "decode"
-         for batch_size, dec_compile_config in zip(rbln_config.decoder_batch_sizes, rbln_compile_configs[2:]):
+         for batch_size, dec_compile_config in zip(
+             rbln_config.decoder_batch_sizes, rbln_compile_configs[rbln_config.decoder_runtime_idx :]
+         ):
              dec_example_inputs = dec_compile_config.get_dummy_inputs(fill=0, static_tensors=static_tensors)
              compiled_decoder = compile_model(
                  wrapped_model,
@@ -872,32 +567,45 @@ class RBLNGemma3ForCausalLM(RBLNDecoderOnlyModelForCausalLM):
      ) -> List[rebel.Runtime]:
          expected_model_names = [
              "prefill",
-             "image_prefill",
              *[f"decoder_batch_{batch_size}" for batch_size in rbln_config.decoder_batch_sizes],
          ]
+         if rbln_config.use_image_prefill:
+             expected_model_names.insert(1, "image_prefill")
+
          if any(model_name not in rbln_config.device_map for model_name in expected_model_names):
              cls._raise_missing_compiled_file_error(expected_model_names)

-         return [
+         ret_val = [
              rebel.Runtime(
                  compiled_models[0],
                  tensor_type="pt",
                  device=rbln_config.device_map["prefill"],
                  activate_profiler=rbln_config.activate_profiler,
-             ),
-             rebel.Runtime(
-                 compiled_models[1],
-                 tensor_type="pt",
-                 device=rbln_config.device_map["image_prefill"],
-                 activate_profiler=rbln_config.activate_profiler,
-             ),
-             *[
+                 timeout=rbln_config.timeout,
+             )
+         ]
+         if rbln_config.use_image_prefill:
+             ret_val.append(
                  rebel.Runtime(
-                     compiled_models[i + 2],
+                     compiled_models[1],
+                     tensor_type="pt",
+                     device=rbln_config.device_map["image_prefill"],
+                     activate_profiler=rbln_config.activate_profiler,
+                     timeout=rbln_config.timeout,
+                 ),
+             )
+
+         ret_val.extend(
+             [
+                 rebel.Runtime(
+                     compiled_models[i + rbln_config.decoder_runtime_idx],
                      tensor_type="pt",
                      device=rbln_config.device_map[f"decoder_batch_{batch_size}"],
                      activate_profiler=rbln_config.activate_profiler,
+                     timeout=rbln_config.timeout,
                  )
                  for i, batch_size in enumerate(rbln_config.decoder_batch_sizes)
-             ],
-         ]
+             ]
+         )
+
+         return ret_val
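
A note on the chunk-alignment logic touched above: both the removed pad_for_chunked_images and the new get_padded_cache_position pad the prompt so that every image block starts at a multiple of the image prefill chunk size, letting the dedicated image_prefill runtime consume each image in a single chunk. The following sketch is illustrative only and not part of optimum-rbln; padded_length is a hypothetical standalone helper, and the 256-token default mirrors Gemma3's mm_tokens_per_image.

# Minimal sketch (not part of the package) of the alignment arithmetic used by
# pad_for_chunked_images (removed) and get_padded_cache_position (added) above:
# every image block of `chunk` tokens must start at a multiple of `chunk`, so
# padding slots are inserted in front of each block as needed.
def padded_length(seq_len: int, image_starts: list, chunk: int = 256) -> int:
    """Return the sequence length after inserting alignment padding."""
    padded_len = seq_len
    for start in image_starts:
        # The block's start shifts by the padding already inserted before it.
        shifted = start + padded_len - seq_len
        pad_needed = (chunk - shifted % chunk) % chunk
        padded_len += pad_needed
    return padded_len

# Example: a 300-token prompt whose 256 image tokens begin at position 10.
# 246 padding slots push the image block to position 256, giving length 546.
assert padded_length(300, [10]) == 546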