optimum-rbln 0.8.2a4__py3-none-any.whl → 0.9.3rc0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (167)
  1. optimum/rbln/__init__.py +96 -9
  2. optimum/rbln/__version__.py +16 -3
  3. optimum/rbln/cli.py +660 -0
  4. optimum/rbln/configuration_utils.py +153 -42
  5. optimum/rbln/diffusers/__init__.py +7 -0
  6. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +3 -3
  7. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_cosmos.py +1 -1
  8. optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +3 -3
  9. optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +4 -4
  10. optimum/rbln/diffusers/configurations/models/configuration_transformer_cosmos.py +9 -4
  11. optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +9 -4
  12. optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +3 -3
  13. optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +3 -3
  14. optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +35 -19
  15. optimum/rbln/diffusers/configurations/pipelines/configuration_cosmos.py +14 -11
  16. optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +30 -20
  17. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +13 -9
  18. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +17 -13
  19. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +17 -10
  20. optimum/rbln/diffusers/modeling_diffusers.py +30 -14
  21. optimum/rbln/diffusers/models/__init__.py +3 -13
  22. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +31 -3
  23. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py +28 -3
  24. optimum/rbln/diffusers/models/autoencoders/vq_model.py +31 -3
  25. optimum/rbln/diffusers/models/transformers/prior_transformer.py +1 -1
  26. optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +9 -1
  27. optimum/rbln/diffusers/models/transformers/transformer_sd3.py +9 -1
  28. optimum/rbln/diffusers/models/unets/unet_2d_condition.py +6 -3
  29. optimum/rbln/diffusers/pipelines/__init__.py +11 -5
  30. optimum/rbln/diffusers/pipelines/auto_pipeline.py +307 -0
  31. optimum/rbln/diffusers/pipelines/cosmos/configuration_cosmos_guardrail.py +19 -16
  32. optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +14 -18
  33. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +31 -1
  34. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +31 -1
  35. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +1 -6
  36. optimum/rbln/modeling.py +71 -19
  37. optimum/rbln/modeling_base.py +99 -21
  38. optimum/rbln/ops/attn.py +158 -0
  39. optimum/rbln/ops/flash_attn.py +166 -0
  40. optimum/rbln/ops/kv_cache_update.py +5 -0
  41. optimum/rbln/ops/linear.py +7 -0
  42. optimum/rbln/transformers/__init__.py +92 -0
  43. optimum/rbln/transformers/configuration_generic.py +9 -7
  44. optimum/rbln/transformers/modeling_attention_utils.py +252 -0
  45. optimum/rbln/transformers/modeling_generic.py +51 -9
  46. optimum/rbln/transformers/modeling_outputs.py +37 -0
  47. optimum/rbln/transformers/models/__init__.py +91 -30
  48. optimum/rbln/transformers/models/auto/__init__.py +2 -0
  49. optimum/rbln/transformers/models/auto/auto_factory.py +92 -17
  50. optimum/rbln/transformers/models/auto/modeling_auto.py +45 -0
  51. optimum/rbln/transformers/models/bart/bart_architecture.py +1 -3
  52. optimum/rbln/transformers/models/bart/configuration_bart.py +2 -0
  53. optimum/rbln/transformers/models/bert/bert_architecture.py +16 -0
  54. optimum/rbln/transformers/models/bert/modeling_bert.py +8 -4
  55. optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +42 -11
  56. optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +94 -30
  57. optimum/rbln/transformers/models/clip/configuration_clip.py +10 -7
  58. optimum/rbln/transformers/models/clip/modeling_clip.py +27 -4
  59. optimum/rbln/transformers/models/colpali/colpali_architecture.py +3 -6
  60. optimum/rbln/transformers/models/colpali/configuration_colpali.py +37 -21
  61. optimum/rbln/transformers/models/colpali/modeling_colpali.py +113 -96
  62. optimum/rbln/transformers/models/colqwen2/__init__.py +2 -0
  63. optimum/rbln/transformers/models/colqwen2/colqwen2_architecture.py +233 -0
  64. optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +74 -0
  65. optimum/rbln/transformers/models/colqwen2/modeling_colqwen2.py +446 -0
  66. optimum/rbln/transformers/models/decoderonly/__init__.py +3 -2
  67. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +109 -37
  68. optimum/rbln/transformers/models/decoderonly/configuration_lora.py +411 -0
  69. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +318 -309
  70. optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +504 -0
  71. optimum/rbln/transformers/models/decoderonly/generation_decoderonly.py +111 -0
  72. optimum/rbln/transformers/models/decoderonly/lora_architecture.py +204 -0
  73. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +453 -897
  74. optimum/rbln/transformers/models/depth_anything/__init__.py +16 -0
  75. optimum/rbln/transformers/models/depth_anything/configuration_depth_anything.py +24 -0
  76. optimum/rbln/transformers/models/depth_anything/modeling_depth_anything.py +25 -0
  77. optimum/rbln/transformers/models/exaone/modeling_exaone.py +42 -4
  78. optimum/rbln/transformers/models/gemma/__init__.py +2 -2
  79. optimum/rbln/transformers/models/gemma/configuration_gemma.py +9 -1
  80. optimum/rbln/transformers/models/gemma/gemma_architecture.py +1 -4
  81. optimum/rbln/transformers/models/gemma/modeling_gemma.py +22 -1
  82. optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +49 -13
  83. optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +12 -2
  84. optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py +245 -0
  85. optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +201 -349
  86. optimum/rbln/transformers/models/gpt2/__init__.py +2 -2
  87. optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +31 -3
  88. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +10 -8
  89. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +18 -1
  90. optimum/rbln/transformers/models/grounding_dino/__init__.py +10 -0
  91. optimum/rbln/transformers/models/grounding_dino/configuration_grounding_dino.py +92 -0
  92. optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +599 -0
  93. optimum/rbln/transformers/models/grounding_dino/modeling_grounding_dino.py +1032 -0
  94. optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +35 -7
  95. optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +26 -27
  96. optimum/rbln/transformers/models/llama/__init__.py +2 -2
  97. optimum/rbln/transformers/models/llama/configuration_llama.py +9 -1
  98. optimum/rbln/transformers/models/llama/modeling_llama.py +22 -1
  99. optimum/rbln/transformers/models/llava/__init__.py +16 -0
  100. optimum/rbln/transformers/models/llava/configuration_llava.py +72 -0
  101. optimum/rbln/transformers/models/llava/modeling_llava.py +478 -0
  102. optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +15 -17
  103. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +235 -375
  104. optimum/rbln/transformers/models/midm/midm_architecture.py +4 -1
  105. optimum/rbln/transformers/models/midm/modeling_midm.py +42 -4
  106. optimum/rbln/transformers/models/mistral/__init__.py +2 -2
  107. optimum/rbln/transformers/models/mistral/configuration_mistral.py +9 -1
  108. optimum/rbln/transformers/models/mistral/mistral_architecture.py +1 -1
  109. optimum/rbln/transformers/models/mistral/modeling_mistral.py +26 -3
  110. optimum/rbln/transformers/models/opt/__init__.py +2 -2
  111. optimum/rbln/transformers/models/opt/configuration_opt.py +8 -1
  112. optimum/rbln/transformers/models/opt/modeling_opt.py +28 -16
  113. optimum/rbln/transformers/models/opt/opt_architecture.py +4 -4
  114. optimum/rbln/transformers/models/pegasus/__init__.py +17 -0
  115. optimum/rbln/transformers/models/pegasus/configuration_pegasus.py +38 -0
  116. optimum/rbln/transformers/models/pegasus/modeling_pegasus.py +71 -0
  117. optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +161 -0
  118. optimum/rbln/transformers/models/phi/__init__.py +2 -2
  119. optimum/rbln/transformers/models/phi/configuration_phi.py +9 -1
  120. optimum/rbln/transformers/models/phi/modeling_phi.py +10 -1
  121. optimum/rbln/transformers/models/phi/phi_architecture.py +11 -7
  122. optimum/rbln/transformers/models/pixtral/__init__.py +16 -0
  123. optimum/rbln/transformers/models/pixtral/configuration_pixtral.py +43 -0
  124. optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +310 -0
  125. optimum/rbln/transformers/models/pixtral/pixtral_architecture.py +73 -0
  126. optimum/rbln/transformers/models/qwen2/__init__.py +2 -2
  127. optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +9 -1
  128. optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +27 -1
  129. optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +21 -6
  130. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +15 -21
  131. optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +28 -7
  132. optimum/rbln/transformers/models/qwen2_vl/__init__.py +19 -0
  133. optimum/rbln/transformers/models/qwen2_vl/configuration_qwen2_vl.py +88 -0
  134. optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py +514 -0
  135. optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +165 -0
  136. optimum/rbln/transformers/models/qwen3/configuration_qwen3.py +2 -2
  137. optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +86 -330
  138. optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +1 -245
  139. optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py +20 -13
  140. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +24 -3
  141. optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +2 -2
  142. optimum/rbln/transformers/models/siglip/__init__.py +2 -6
  143. optimum/rbln/transformers/models/siglip/configuration_siglip.py +1 -1
  144. optimum/rbln/transformers/models/siglip/modeling_siglip.py +5 -16
  145. optimum/rbln/transformers/models/swin/__init__.py +16 -0
  146. optimum/rbln/transformers/models/swin/configuration_swin.py +42 -0
  147. optimum/rbln/transformers/models/swin/modeling_swin.py +341 -0
  148. optimum/rbln/transformers/models/t5/configuration_t5.py +2 -0
  149. optimum/rbln/transformers/models/t5/t5_architecture.py +8 -1
  150. optimum/rbln/transformers/models/time_series_transformer/configuration_time_series_transformer.py +3 -3
  151. optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +4 -14
  152. optimum/rbln/transformers/models/time_series_transformer/time_series_transformers_architecture.py +7 -1
  153. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +1 -0
  154. optimum/rbln/transformers/models/whisper/configuration_whisper.py +12 -13
  155. optimum/rbln/transformers/models/whisper/generation_whisper.py +28 -6
  156. optimum/rbln/transformers/models/whisper/modeling_whisper.py +28 -3
  157. optimum/rbln/transformers/models/xlm_roberta/__init__.py +2 -8
  158. optimum/rbln/transformers/utils/rbln_quantization.py +391 -75
  159. optimum/rbln/transformers/utils/rbln_runtime_wrapper.py +79 -0
  160. optimum/rbln/utils/depreacate_utils.py +16 -0
  161. optimum/rbln/utils/runtime_utils.py +28 -18
  162. optimum/rbln/utils/submodule.py +31 -9
  163. {optimum_rbln-0.8.2a4.dist-info → optimum_rbln-0.9.3rc0.dist-info}/METADATA +8 -7
  164. {optimum_rbln-0.8.2a4.dist-info → optimum_rbln-0.9.3rc0.dist-info}/RECORD +167 -125
  165. optimum_rbln-0.9.3rc0.dist-info/entry_points.txt +2 -0
  166. {optimum_rbln-0.8.2a4.dist-info → optimum_rbln-0.9.3rc0.dist-info}/WHEEL +0 -0
  167. {optimum_rbln-0.8.2a4.dist-info → optimum_rbln-0.9.3rc0.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/transformers/models/gemma3/modeling_gemma3.py

@@ -11,95 +11,74 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import importlib
 import inspect
-from collections import deque
-from dataclasses import dataclass
 from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union

 import rebel
 import torch
 from rebel.compile_context import CompileContext
-from transformers import (
-    AutoModelForImageTextToText,
-    Gemma3ForConditionalGeneration,
-    PretrainedConfig,
-    PreTrainedModel,
-)
+from transformers import AutoModelForImageTextToText, Gemma3ForConditionalGeneration, PretrainedConfig, PreTrainedModel
 from transformers.modeling_outputs import BaseModelOutputWithPooling
 from transformers.modeling_utils import no_init_weights
 from transformers.models.gemma3.modeling_gemma3 import Gemma3TextScaledWordEmbedding

 from ....configuration_utils import RBLNCompileConfig, RBLNModelConfig
 from ....modeling import RBLNModel
-from ..decoderonly.modeling_decoderonly import RBLNDecoderOnlyModelForCausalLM, RBLNDecoderOnlyOutput, RBLNRuntimeModel
+from ...modeling_outputs import RBLNDecoderOnlyOutput
+from ...utils.rbln_runtime_wrapper import LoopProcessor
+from ..decoderonly.decoderonly_runtime_utils import RBLNPageTableManager
+from ..decoderonly.generation_decoderonly import RBLNDecoderOnlyGenerationMixin
+from ..decoderonly.modeling_decoderonly import (
+    RBLNDecoderOnlyModelForCausalLM,
+)
 from .configuration_gemma3 import RBLNGemma3ForCausalLMConfig
 from .gemma3_architecture import Gemma3ForCausalLMWrapper
+from .gemma3_runtime_utils import RBLNGemma3RuntimeModel


 if TYPE_CHECKING:
     from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer, Gemma3ForConditionalGeneration


-@dataclass
-class RBLNGemma3ForCausalLMOutput(RBLNDecoderOnlyOutput):
-    attention_mask: Optional[torch.Tensor] = None
-
-
-class LoopVisionTower:
-    def __init__(self, vision_tower: RBLNModel) -> None:
-        self.vision_tower = vision_tower
+class LoopVisionTower(LoopProcessor):
+    def __init__(self, vision_tower: "RBLNModel"):
+        super().__init__(model=vision_tower)

-    def forward(self, *args, **kwargs):
-        # Loop instead of batch
-        # shape of pixel_values : [batch, num_channel, height, width]
-        pixel_values = args[0]
+    def _get_batch_size(self, pixel_values, **kwargs):
+        return pixel_values.shape[0]

-        batch_size = pixel_values.shape[0]
-        outputs = []
-        for i in range(batch_size):
-            outputs.append(self.vision_tower(pixel_values=pixel_values[i : i + 1], return_dict=True))
+    def _prepare_inputs_for_iteration(self, index, common_inputs, pixel_values, **kwargs):
+        pixel_values_item = pixel_values[index : index + 1]
+        out_buffer = [tensor[index : index + 1] for tensor in kwargs["out"]]
+        return ([pixel_values_item], {"out": out_buffer})

-        last_hidden_states = [output.last_hidden_state for output in outputs]
-
-        # FIXME:: This can be optimized using out= API of rbln runtime.
-        last_hidden_states = torch.cat(last_hidden_states, dim=0)
+    def _process_outputs(self, outputs: list, **kwargs) -> "BaseModelOutputWithPooling":
+        output = kwargs["out"]

         return BaseModelOutputWithPooling(
-            last_hidden_state=last_hidden_states,
+            last_hidden_state=output[0],
         )

-    def __call__(self, *args: Any, **kwds: Any) -> Any:
-        return self.forward(*args, **kwds)
-
-    def __repr__(self) -> str:
-        return repr(self.vision_tower)
-

-class LoopProjector:
-    def __init__(self, multi_modal_projector) -> None:
-        self.multi_modal_projector = multi_modal_projector
+class LoopProjector(LoopProcessor):
+    def __init__(self, multi_modal_projector: "RBLNModel"):
+        super().__init__(model=multi_modal_projector)

-    def forward(self, *args, **kwargs):
-        # Loop instead of batch
-        image_feature = args[0]
+    def _get_batch_size(self, image_feature, **kwargs):
+        return image_feature.shape[0]

-        batch_size = image_feature.shape[0]
-        outputs = []
-        for i in range(batch_size):
-            outputs.append(self.multi_modal_projector(image_feature[i : i + 1]))
+    def _prepare_inputs_for_iteration(self, index, common_inputs, image_feature, **kwargs):
+        image_feature_item = image_feature[index : index + 1]
+        out_buffer = [tensor[index : index + 1] for tensor in kwargs["out"]]
+        return ([image_feature_item], {"out": out_buffer})

-        # FIXME:: This can be optimized using out= API of rbln runtime.
-        outputs = torch.cat(outputs, dim=0)
-        return outputs
+    def _process_outputs(self, outputs: list, **kwargs):
+        output = kwargs["out"]
+        return output[0]

-    def __call__(self, *args: Any, **kwds: Any) -> Any:
-        return self.forward(*args, **kwds)

-    def __repr__(self) -> str:
-        return repr(self.multi_modal_projector)
-
-
-class RBLNGemma3ForConditionalGeneration(RBLNModel):
+class RBLNGemma3ForConditionalGeneration(RBLNModel, RBLNDecoderOnlyGenerationMixin):
     auto_model_class = AutoModelForImageTextToText
     _rbln_submodules = [
         {"name": "vision_tower"},
@@ -119,6 +98,23 @@ class RBLNGemma3ForConditionalGeneration(RBLNModel):
     def can_generate(self):
         return True

+    @classmethod
+    def get_pytorch_model(cls, *args, **kwargs):
+        model = super().get_pytorch_model(*args, **kwargs)
+
+        with no_init_weights():
+            model_cls_name = model.model.language_model.__class__.__name__
+            causal_model_cls_name = model_cls_name.replace("TextModel", "ForCausalLM")
+            causal_model_cls = getattr(importlib.import_module("transformers"), causal_model_cls_name)
+            new_language_model = causal_model_cls(model.model.language_model.config)
+
+        new_language_model.lm_head = model.lm_head
+        new_language_model.model = model.model.language_model
+        model.model.language_model = new_language_model
+        model.lm_head = None
+        del model.lm_head
+        return model
+
     def __post_init__(self, **kwargs):
         self.vision_tower = LoopVisionTower(self.rbln_submodules[0])
         self.language_model = self.rbln_submodules[1]
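The get_pytorch_model override above appears to exist because recent transformers releases keep lm_head on the outer Gemma3ForConditionalGeneration while language_model is a bare text model; the override rebuilds a standalone *ForCausalLM so the decoder-only export path sees the usual model + lm_head layout. A toy illustration of that rewrapping with plain torch.nn modules (hypothetical classes, showing only the shape of the transformation):

import torch.nn as nn


class ToyTextModel(nn.Module):
    def __init__(self, hidden=8):
        super().__init__()
        self.layer = nn.Linear(hidden, hidden)


class ToyTextForCausalLM(nn.Module):
    def __init__(self, hidden=8, vocab=16):
        super().__init__()
        self.model = ToyTextModel(hidden)
        self.lm_head = nn.Linear(hidden, vocab, bias=False)


class ToyConditionalGeneration(nn.Module):
    def __init__(self, hidden=8, vocab=16):
        super().__init__()
        self.model = nn.Module()
        self.model.language_model = ToyTextModel(hidden)      # text backbone only
        self.lm_head = nn.Linear(hidden, vocab, bias=False)   # head lives at the top level


outer = ToyConditionalGeneration()
causal = ToyTextForCausalLM()
causal.lm_head = outer.lm_head             # reuse the already-loaded head
causal.model = outer.model.language_model  # reuse the already-loaded backbone
outer.model.language_model = causal        # language_model is now a causal LM
del outer.lm_head                          # the outer model no longer owns the head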
@@ -208,18 +204,30 @@ class RBLNGemma3ForConditionalGeneration(RBLNModel):
         return model_kwargs

     def get_image_features(self, pixel_values: torch.Tensor) -> torch.Tensor:
-        """
-        Projects the last hidden state from the vision model into language model space.
-
-        Args:
-            pixel_values: (`torch.FloatTensor` of shape `(batch_size, channels, height, width)`)
-                The tensors corresponding to the input images.
-
-        Returns:
-            Image feature tensor of shape `(num_images, image_length, embed_dim)`.
-        """
-        vision_outputs = self.vision_tower(pixel_values).last_hidden_state
-        image_features = self.multi_modal_projector(vision_outputs)
+        # Projects the last hidden state from the vision model into language model space.
+
+        # Args:
+        #     pixel_values: (`torch.FloatTensor` of shape `(batch_size, channels, height, width)`)
+        #         The tensors corresponding to the input images.
+
+        # Returns:
+        #     Image feature tensor of shape `(num_images, image_length, embed_dim)`.
+
+        vision_out_buffer = []
+        vision_out_size = [
+            pixel_values.shape[0],
+            (self.config.vision_config.image_size // self.config.vision_config.patch_size) ** 2,
+            self.config.vision_config.hidden_size,
+        ]
+        projector_out_size = [
+            pixel_values.shape[0],
+            self.config.mm_tokens_per_image,
+            self.config.text_config.hidden_size,
+        ]
+        vision_out_buffer.append(torch.empty(size=vision_out_size, dtype=torch.float32, device="cpu"))
+        projector_out_buffer = [torch.empty(size=projector_out_size, dtype=torch.float32, device="cpu")]
+        vision_outputs = self.vision_tower(pixel_values, out=vision_out_buffer).last_hidden_state
+        image_features = self.multi_modal_projector(vision_outputs, out=projector_out_buffer)
         return image_features

     def _preprocess_prefill(
@@ -254,17 +262,45 @@ class RBLNGemma3ForConditionalGeneration(RBLNModel):

         return inputs_embeds

+    def get_padded_cache_position(
+        self,
+        cache_position: torch.Tensor,  # shape: [1, seq_len]
+        token_type_ids: torch.Tensor,  # shape: [1, seq_len]
+    ) -> torch.Tensor:
+        seq_len = cache_position[0][-1].item() + 1
+
+        # Find image start positions
+        image_starts = [
+            s
+            for s in torch.where(token_type_ids == 1)[1]
+            if torch.all(token_type_ids[:, s : s + self.rbln_config.image_prefill_chunk_size] == 1)
+        ]
+
+        # Initialize padded tensors
+        padded_input_len = seq_len
+        for image_start in image_starts:
+            pad_needed = (
+                self.rbln_config.image_prefill_chunk_size
+                - (image_start + padded_input_len - seq_len) % self.rbln_config.image_prefill_chunk_size
+            ) % self.rbln_config.image_prefill_chunk_size
+            padded_input_len += pad_needed
+
+        return torch.cat(
+            [cache_position, torch.arange(seq_len, padded_input_len, dtype=torch.int32).unsqueeze(0)],
+            dim=1,
+        )
+
     def forward(
         self,
         input_ids: torch.LongTensor = None,
+        attention_mask: torch.Tensor = None,
+        token_type_ids: torch.Tensor = None,
         pixel_values: torch.FloatTensor = None,
-        attention_mask: Optional[torch.Tensor] = None,
         cache_position: Optional[torch.LongTensor] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
         generate_idx: Optional[torch.Tensor] = None,
         padded_cache_lengths: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.Tensor] = None,
-        token_type_ids: Optional[torch.Tensor] = None,
         **lm_kwargs: Dict[str, Any],
     ) -> Union[Tuple, RBLNDecoderOnlyOutput]:
         # prefill
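get_padded_cache_position reserves extra cache positions so that every run of image tokens can be aligned to an image_prefill_chunk_size boundary during chunked prefill. A small standalone replay of the arithmetic, using a hypothetical chunk size of 4 (the real value defaults to mm_tokens_per_image):

import torch

image_prefill_chunk_size = 4  # assumed tiny value for illustration only
token_type_ids = torch.tensor([[0, 0, 1, 1, 1, 1, 0, 0]])  # one 4-token image starting at index 2
cache_position = torch.arange(8, dtype=torch.int32).unsqueeze(0)

seq_len = cache_position[0][-1].item() + 1  # 8
image_starts = [
    int(s)
    for s in torch.where(token_type_ids == 1)[1]
    if torch.all(token_type_ids[:, s : s + image_prefill_chunk_size] == 1)
]  # [2]

padded_input_len = seq_len
for image_start in image_starts:
    pad_needed = (
        image_prefill_chunk_size - (image_start + padded_input_len - seq_len) % image_prefill_chunk_size
    ) % image_prefill_chunk_size  # (4 - 2 % 4) % 4 = 2
    padded_input_len += pad_needed  # 8 -> 10

padded = torch.cat(
    [cache_position, torch.arange(seq_len, padded_input_len, dtype=torch.int32).unsqueeze(0)], dim=1
)
print(padded.shape)  # torch.Size([1, 10]): two positions reserved so the image run can be chunk-aligned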
@@ -275,12 +311,15 @@ class RBLNGemma3ForConditionalGeneration(RBLNModel):

             for b_idx in range(batch_size):
                 cache_position = torch.arange(0, generate_idx[b_idx].item(), dtype=torch.int32).unsqueeze(0)
+                token_type_id = token_type_ids[b_idx : b_idx + 1, attention_mask[b_idx].bool()]
+                cache_position = self.get_padded_cache_position(cache_position, token_type_id)
+
                 output = self.language_model.prefill_decoder(
                     inputs_embeds=inputs_embeds[b_idx : b_idx + 1],
                     attention_mask=attention_mask[b_idx],
                     cache_position=cache_position,
                     batch_idx=b_idx,
-                    token_type_ids=token_type_ids[b_idx : b_idx + 1] if token_type_ids is not None else None,
+                    token_type_ids=token_type_ids[b_idx : b_idx + 1],  # do not pass token_type_id
                 )
                 padded_cache_lengths[b_idx] += output.padded_cache_lengths
                 logits.append(output.logits)
@@ -309,209 +348,6 @@ class RBLNGemma3ForConditionalGeneration(RBLNModel):
         )


-class RBLNGemma3RuntimeModel(RBLNRuntimeModel):
-    def __init__(self, *args, image_prefill: Optional[rebel.Runtime] = None, **kwargs):
-        super().__init__(*args, **kwargs)
-        self.image_prefill = image_prefill  # FIXME(taehoon)
-        self.prefill = self.runtime if self.phase == "prefill" else None  # FIXME
-        self.decode = self.runtime if self.phase == "decode" else None
-
-    def _prepare_prefill_inputs(self, *args, **kwargs):
-        (
-            inputs,
-            cache_position,
-            chunked_attention_mask,
-            out_buffers,
-            position_ids,
-            position_embed,
-            padded_cache_lengths,
-            query_length,
-            token_type_ids,
-        ) = super()._prepare_prefill_inputs(*args, **kwargs)
-
-        # chunked_attention_mask shape
-        chunked_attention_mask = torch.zeros(1, chunked_attention_mask.shape[-1], dtype=torch.float32)
-
-        # as gemma3 has different prefill chunk size for image and text, we need to pad the inputs to the max of the two.
-        padding_size = max(self.rbln_config.prefill_chunk_size, self.rbln_config.image_prefill_chunk_size)
-        inputs = torch.nn.functional.pad(inputs, (0, 0, 0, padding_size))
-        cache_position = torch.nn.functional.pad(cache_position, (0, padding_size))
-        position_ids = torch.nn.functional.pad(position_ids, (0, padding_size))
-        token_type_ids = torch.nn.functional.pad(token_type_ids, (0, padding_size), value=-1)
-
-        return (
-            inputs,
-            cache_position,
-            chunked_attention_mask,
-            out_buffers,
-            position_ids,
-            position_embed,
-            padded_cache_lengths,
-            query_length,
-            token_type_ids,
-        )
-
-    def prefill_forward(
-        self,
-        inputs: torch.Tensor,
-        cache_position: torch.Tensor = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        batch_idx: int = None,
-        block_tables: torch.Tensor = None,
-        is_external_block_tables: bool = None,
-        position_embed: Optional[torch.Tensor] = None,
-        token_type_ids: Optional[torch.Tensor] = None,
-        local_block_tables: Optional[torch.Tensor] = None,
-    ) -> torch.FloatTensor:
-        """
-        Performs chunked prefill for efficient KV-cache updates and memory optimization.
-        Instead of processing the entire sequence at once, the input is divided into chunks of size `prefill_chunk_size`,
-        and each chunk is processed sequentially. This allows for better memory utilization and compatibility with continuous batching.
-        """
-        (
-            inputs,
-            cache_position,
-            chunked_attention_mask,
-            out_buffers,
-            position_ids,
-            position_embed,
-            padded_cache_lengths,
-            query_length,
-            token_type_ids,
-        ) = self._prepare_prefill_inputs(
-            inputs, cache_position, attention_mask, position_embed, token_type_ids=token_type_ids
-        )
-
-        step = 0
-        while step < query_length:
-            # Check if the prefill chunk is an image prefill
-            is_image_prefill = torch.all(
-                token_type_ids[:, step : step + self.rbln_config.image_prefill_chunk_size] == 1
-            )
-            prefill_chunk_size = (
-                self.rbln_config.image_prefill_chunk_size if is_image_prefill else self.rbln_config.prefill_chunk_size
-            )
-
-            # Check if the prefill chunk is a text prefill which have image_tokens in it.
-            is_text_prefill_with_image_tokens = not is_image_prefill and torch.any(
-                token_type_ids[:, step : step + prefill_chunk_size] == 1
-            )
-
-            # Check if the prefill chunk crosses a block boundary, requiring padding to align with block boundaries
-            is_cross_block_boundary = (
-                step // self.rbln_config.kvcache_block_size
-                != (step + prefill_chunk_size) // self.rbln_config.kvcache_block_size
-            )
-
-            # Check if the prefill chunk is the last chunk
-            is_last_chunk = step + prefill_chunk_size >= query_length
-
-            if is_cross_block_boundary:
-                padding_size = prefill_chunk_size - (step + prefill_chunk_size) % self.rbln_config.kvcache_block_size
-                padded_cache_lengths += padding_size
-
-            # if text_prefill end with image_tokens, we only treat the text part.
-            num_processed_tokens = prefill_chunk_size
-            if is_text_prefill_with_image_tokens:
-                first_image_token_idx = torch.where(token_type_ids[:, step : step + prefill_chunk_size] == 1)[1][0]
-                num_processed_tokens = first_image_token_idx
-            if is_last_chunk:
-                num_processed_tokens = query_length - step
-
-            input_chunk = inputs[:, step : step + prefill_chunk_size]
-            cache_pos_chunk = cache_position[:, step : step + prefill_chunk_size].clone() + padded_cache_lengths
-            position_ids_chunk = position_ids[:, step : step + prefill_chunk_size].clone()
-            chunked_attention_mask[
-                :, step + padded_cache_lengths : step + num_processed_tokens + padded_cache_lengths
-            ] = 1
-            query_position = torch.tensor(num_processed_tokens - 1, dtype=torch.int16)
-
-            if is_image_prefill:
-                logits = self.image_prefill(
-                    input_chunk,
-                    cache_pos_chunk,
-                    block_tables,
-                    local_block_tables,
-                    query_position,
-                    chunked_attention_mask,
-                    position_ids_chunk,
-                    out=out_buffers,
-                )
-            else:
-                logits = self.prefill(
-                    input_chunk,
-                    cache_pos_chunk,
-                    block_tables,
-                    local_block_tables,
-                    query_position,
-                    chunked_attention_mask,
-                    position_ids_chunk,
-                    out=out_buffers,
-                )
-
-            step += num_processed_tokens
-
-        if not is_external_block_tables:
-            self.dec_attn_mask[batch_idx : batch_idx + 1] = chunked_attention_mask
-
-        return RBLNGemma3ForCausalLMOutput(
-            logits=logits, padded_cache_lengths=padded_cache_lengths, attention_mask=chunked_attention_mask
-        )
-
-    def decode_forward(
-        self,
-        inputs: torch.Tensor,
-        cache_position: torch.Tensor = None,
-        block_tables: torch.Tensor = None,
-        is_external_block_tables: bool = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        position_embed: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.Tensor] = None,
-        local_block_tables: Optional[torch.Tensor] = None,
-    ) -> torch.FloatTensor:
-        batch_size = inputs.shape[0]
-        if batch_size != self.batch_size:
-            raise RuntimeError(
-                f"Batch size mismatch: got {batch_size}, expected {self.batch_size} (compiled batch size)."
-            )
-
-        if batch_size != cache_position.shape[0]:
-            raise RuntimeError(f"Cache position size mismatch: got {cache_position.shape[0]}, expected {batch_size}.")
-
-        # FIXME(taehoon): how to handle pos_attn_mask with external block tables
-        if is_external_block_tables:
-            if attention_mask is None:
-                raise ValueError("attention_mask should be provided with external block tables.")
-            if local_block_tables is None:
-                raise ValueError("local_block_tables should be provided with external block tables.")
-        else:
-            local_block_tables = (
-                local_block_tables
-                if local_block_tables is not None
-                else torch.arange(0, self.batch_size, dtype=torch.int16).view(self.batch_size, -1)
-            )
-            if self.rbln_config.use_attention_mask and attention_mask is None:
-                for b_idx in range(batch_size):
-                    decoding_step = cache_position[b_idx].item()
-                    if not (0 <= decoding_step < self.dec_attn_mask.shape[-1]):
-                        raise ValueError(
-                            f"Decoding step {decoding_step} out of bounds for attention mask with shape {self.dec_attn_mask.shape}."
-                        )
-                    self.dec_attn_mask[b_idx, decoding_step] = 1
-
-                attention_mask = self.dec_attn_mask
-
-        if self.batch_size < block_tables.shape[0]:
-            block_tables = block_tables[: self.batch_size]
-
-        if attention_mask is not None and self.batch_size < attention_mask.shape[0]:
-            attention_mask = attention_mask[: self.batch_size]
-
-        logits = self.decode(inputs, cache_position, block_tables, local_block_tables, attention_mask, position_ids)
-
-        return RBLNDecoderOnlyOutput(logits=logits)
-
-
 class RBLNGemma3ForCausalLM(RBLNDecoderOnlyModelForCausalLM):
     """
     The Gemma3 Model transformer with a language modeling head (linear layer) on top.
@@ -524,52 +360,36 @@ class RBLNGemma3ForCausalLM(RBLNDecoderOnlyModelForCausalLM):
     """

     _decoder_wrapper_cls = Gemma3ForCausalLMWrapper
+    _supports_non_fp32 = False

-    def __post_init__(self, **kwargs):
-        main_input_name = self.main_input_name
-
-        if self.rbln_config.use_inputs_embeds:
-            main_input_name = "inputs_embeds"
-            artifacts = torch.load(self.model_save_dir / self.subfolder / "torch_artifacts.pth", weights_only=False)
-            self.embed_tokens = self._create_embedding_layer()
-            self.embed_tokens.load_state_dict(artifacts["embed_tokens"])
-        else:
-            self.embed_tokens = None
-
+    def setup_runtime(self):
         # Initialize shared resources to be used across Runtime instances (prefill and decode phases)
         dec_attn_mask = torch.zeros(self.rbln_config.batch_size, self.rbln_config.max_seq_len, dtype=torch.float32)
-        block_tables = torch.zeros(
-            self.rbln_config.batch_size,
-            self.rbln_config.max_seq_len // self.rbln_config.kvcache_block_size,
-            dtype=torch.int16,
-        ).fill_(-1)
-        free_block_pool = deque(x for x in range(self.rbln_config.kvcache_num_blocks))
+        page_table_manager = RBLNPageTableManager(self.rbln_config)
+
+        common_kwargs = {
+            "main_input_name": "inputs_embeds" if self.rbln_config.use_inputs_embeds else "input_ids",
+            "embed_tokens": self.embed_tokens,
+            "dec_attn_mask": dec_attn_mask,
+            "page_table_manager": page_table_manager,
+            "rbln_config": self.rbln_config,
+        }
+
         self.prefill_decoder = RBLNGemma3RuntimeModel(
             runtime=self.model[0],
-            image_prefill=self.model[1],
-            main_input_name=main_input_name,
-            embed_tokens=self.embed_tokens,
+            image_prefill=self.model[1] if self.rbln_config.use_image_prefill else None,
             phase="prefill",
             batch_size=self.rbln_config.batch_size,
-            dec_attn_mask=dec_attn_mask,
-            block_tables=block_tables,
-            vocab_size=self.config.vocab_size,
-            free_block_pool=free_block_pool,
-            rbln_config=self.rbln_config,
+            **common_kwargs,
         )

         self.decoders = {}
         for i, batch_size in enumerate(self.rbln_config.decoder_batch_sizes):
             self.decoders[batch_size] = RBLNGemma3RuntimeModel(
-                runtime=self.model[i + 2],
-                main_input_name=main_input_name,
-                embed_tokens=self.embed_tokens,
+                runtime=self.model[i + self.rbln_config.decoder_runtime_idx],
                 phase="decode",
                 batch_size=batch_size,
-                dec_attn_mask=dec_attn_mask,
-                block_tables=block_tables,
-                free_block_pool=free_block_pool,
-                rbln_config=self.rbln_config,
+                **common_kwargs,
             )

         # NOTE(eunji): Use a decoder whose batch size matches the model's main batch size for compatibility.
@@ -589,6 +409,13 @@ class RBLNGemma3ForCausalLM(RBLNDecoderOnlyModelForCausalLM):
     def _update_sliding_window_config(cls, model_config: PretrainedConfig, rbln_config: RBLNGemma3ForCausalLMConfig):
         sliding_window = getattr(model_config, "sliding_window", None)
         sliding_window_pattern = getattr(model_config, "sliding_window_pattern", None)
+        if sliding_window_pattern is None:
+            if hasattr(model_config, "layer_types"):
+                first_full_attention_index = model_config.layer_types.index("full_attention")
+                sliding_window_pattern = first_full_attention_index + 1
+            else:
+                raise ValueError("Cannot determine sliding_window_pattern from model_config")
+
         if sliding_window_pattern <= model_config.num_hidden_layers:
             rbln_config.cache_impl = "hybrid"
             rbln_config.sliding_window = sliding_window
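The new fallback covers transformers configs that expose layer_types instead of sliding_window_pattern. For the usual Gemma3 schedule of five sliding-attention layers followed by one full-attention layer, the derived pattern works out as follows (toy stand-in for the config attribute):

# Hypothetical layer_types list mirroring Gemma3's 5-local / 1-global attention schedule.
layer_types = (["sliding_attention"] * 5 + ["full_attention"]) * 4

first_full_attention_index = layer_types.index("full_attention")  # 5
sliding_window_pattern = first_full_attention_index + 1           # 6
print(sliding_window_pattern)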
@@ -599,7 +426,12 @@ class RBLNGemma3ForCausalLM(RBLNDecoderOnlyModelForCausalLM):
         return rbln_config

     @classmethod
-    def _update_submodule_config(cls, model: "PreTrainedModel", rbln_config: RBLNModelConfig):
+    def _update_submodule_config(
+        cls,
+        model: "PreTrainedModel",
+        rbln_config: RBLNModelConfig,
+        preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]],
+    ):
         if rbln_config.image_prefill_chunk_size is None:
             rbln_config.image_prefill_chunk_size = model.config.mm_tokens_per_image

@@ -624,20 +456,26 @@ class RBLNGemma3ForCausalLM(RBLNDecoderOnlyModelForCausalLM):
         if not (rbln_config.use_attention_mask and rbln_config.use_position_ids):
             raise ValueError("use_attention_mask and use_position_ids must be True for RBLNGemma3ForCausalLM")

-        # Update image prefill compile config
-        img_prefill_input_info = cls.get_input_info(
-            batch_size=1,
-            query_length=rbln_config.image_prefill_chunk_size,
-            rbln_config=rbln_config,
-            model_config=model_config,
-        )
-        image_prefill_compile_config = RBLNCompileConfig(
-            compiled_model_name="image_prefill", input_info=img_prefill_input_info
-        )
-        # Insert image_prefill compile config at index 1
-        compile_cfgs = rbln_config.compile_cfgs
-        compile_cfgs.insert(1, image_prefill_compile_config)
-        rbln_config.set_compile_cfgs(compile_cfgs)
+        if rbln_config.use_image_prefill:
+            if rbln_config.prefill_chunk_size != rbln_config.image_prefill_chunk_size:
+                raise NotImplementedError(
+                    "Not implemented for different prefill chunk sizes between text and image prefill."
+                )
+
+            # Update image prefill compile config
+            img_prefill_input_info = cls.get_input_info(
+                batch_size=1,
+                query_length=rbln_config.image_prefill_chunk_size,
+                rbln_config=rbln_config,
+                model_config=model_config,
+            )
+            image_prefill_compile_config = RBLNCompileConfig(
+                compiled_model_name="image_prefill", input_info=img_prefill_input_info
+            )
+            # Insert image_prefill compile config at index 1
+            compile_cfgs = rbln_config.compile_cfgs
+            compile_cfgs.insert(1, image_prefill_compile_config)
+            rbln_config.set_compile_cfgs(compile_cfgs)

         return rbln_config

@@ -690,23 +528,27 @@ class RBLNGemma3ForCausalLM(RBLNDecoderOnlyModelForCausalLM):
             context,
             rbln_config.quantization,
         )
+        compiled_models = {"prefill": compiled_prefill}

-        image_prefill_compile_config = rbln_compile_configs[1]
-        image_prefill_example_inputs = image_prefill_compile_config.get_dummy_inputs(
-            fill=0, static_tensors=static_tensors
-        )
-        wrapped_model.phase = "image_prefill"
-        compiled_image_prefill = compile_model(
-            wrapped_model,
-            image_prefill_compile_config,
-            image_prefill_example_inputs,
-            context,
-            rbln_config.quantization,
-        )
+        if rbln_config.use_image_prefill:
+            image_prefill_compile_config = rbln_compile_configs[1]
+            image_prefill_example_inputs = image_prefill_compile_config.get_dummy_inputs(
+                fill=0, static_tensors=static_tensors
+            )
+            wrapped_model.phase = "image_prefill"
+            compiled_image_prefill = compile_model(
+                wrapped_model,
+                image_prefill_compile_config,
+                image_prefill_example_inputs,
+                context,
+                rbln_config.quantization,
+            )
+            compiled_models["image_prefill"] = compiled_image_prefill

-        compiled_models = {"prefill": compiled_prefill, "image_prefill": compiled_image_prefill}
         wrapped_model.phase = "decode"
-        for batch_size, dec_compile_config in zip(rbln_config.decoder_batch_sizes, rbln_compile_configs[2:]):
+        for batch_size, dec_compile_config in zip(
+            rbln_config.decoder_batch_sizes, rbln_compile_configs[rbln_config.decoder_runtime_idx :]
+        ):
             dec_example_inputs = dec_compile_config.get_dummy_inputs(fill=0, static_tensors=static_tensors)
             compiled_decoder = compile_model(
                 wrapped_model,
@@ -727,35 +569,45 @@ class RBLNGemma3ForCausalLM(RBLNDecoderOnlyModelForCausalLM):
     ) -> List[rebel.Runtime]:
         expected_model_names = [
             "prefill",
-            "image_prefill",
             *[f"decoder_batch_{batch_size}" for batch_size in rbln_config.decoder_batch_sizes],
         ]
+        if rbln_config.use_image_prefill:
+            expected_model_names.insert(1, "image_prefill")
+
         if any(model_name not in rbln_config.device_map for model_name in expected_model_names):
             cls._raise_missing_compiled_file_error(expected_model_names)

-        return [
+        ret_val = [
             rebel.Runtime(
                 compiled_models[0],
                 tensor_type="pt",
                 device=rbln_config.device_map["prefill"],
                 activate_profiler=rbln_config.activate_profiler,
                 timeout=rbln_config.timeout,
-            ),
-            rebel.Runtime(
-                compiled_models[1],
-                tensor_type="pt",
-                device=rbln_config.device_map["image_prefill"],
-                activate_profiler=rbln_config.activate_profiler,
-                timeout=rbln_config.timeout,
-            ),
-            *[
+            )
+        ]
+        if rbln_config.use_image_prefill:
+            ret_val.append(
+                rebel.Runtime(
+                    compiled_models[1],
+                    tensor_type="pt",
+                    device=rbln_config.device_map["image_prefill"],
+                    activate_profiler=rbln_config.activate_profiler,
+                    timeout=rbln_config.timeout,
+                ),
+            )
+
+        ret_val.extend(
+            [
                 rebel.Runtime(
-                    compiled_models[i + 2],
+                    compiled_models[i + rbln_config.decoder_runtime_idx],
                     tensor_type="pt",
                     device=rbln_config.device_map[f"decoder_batch_{batch_size}"],
                     activate_profiler=rbln_config.activate_profiler,
                     timeout=rbln_config.timeout,
                 )
                 for i, batch_size in enumerate(rbln_config.decoder_batch_sizes)
-            ],
-        ]
+            ]
+        )
+
+        return ret_val
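With image prefill now optional, the compiled-model list is laid out as prefill, then (optionally) image_prefill, then one decoder per compiled batch size, and decoder_runtime_idx appears to point at the first decoder slot. A minimal sketch of that indexing, with the config attribute's exact semantics assumed from this diff:

from typing import List


def decoder_runtime_idx(use_image_prefill: bool) -> int:
    # Decoders start after prefill (index 0) and, when enabled, image_prefill (index 1).
    return 2 if use_image_prefill else 1


def expected_model_names(use_image_prefill: bool, decoder_batch_sizes: List[int]) -> List[str]:
    names = ["prefill"]
    if use_image_prefill:
        names.insert(1, "image_prefill")
    names.extend(f"decoder_batch_{b}" for b in decoder_batch_sizes)
    return names


print(expected_model_names(True, [1, 4]))   # ['prefill', 'image_prefill', 'decoder_batch_1', 'decoder_batch_4']
print(expected_model_names(False, [1, 4]))  # ['prefill', 'decoder_batch_1', 'decoder_batch_4']
print(decoder_runtime_idx(True), decoder_runtime_idx(False))  # 2 1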