optimum-rbln 0.8.2a4__py3-none-any.whl → 0.9.3rc0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- optimum/rbln/__init__.py +96 -9
- optimum/rbln/__version__.py +16 -3
- optimum/rbln/cli.py +660 -0
- optimum/rbln/configuration_utils.py +153 -42
- optimum/rbln/diffusers/__init__.py +7 -0
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +3 -3
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_cosmos.py +1 -1
- optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +3 -3
- optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +4 -4
- optimum/rbln/diffusers/configurations/models/configuration_transformer_cosmos.py +9 -4
- optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +9 -4
- optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +3 -3
- optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +3 -3
- optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +35 -19
- optimum/rbln/diffusers/configurations/pipelines/configuration_cosmos.py +14 -11
- optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +30 -20
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +13 -9
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +17 -13
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +17 -10
- optimum/rbln/diffusers/modeling_diffusers.py +30 -14
- optimum/rbln/diffusers/models/__init__.py +3 -13
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +31 -3
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py +28 -3
- optimum/rbln/diffusers/models/autoencoders/vq_model.py +31 -3
- optimum/rbln/diffusers/models/transformers/prior_transformer.py +1 -1
- optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +9 -1
- optimum/rbln/diffusers/models/transformers/transformer_sd3.py +9 -1
- optimum/rbln/diffusers/models/unets/unet_2d_condition.py +6 -3
- optimum/rbln/diffusers/pipelines/__init__.py +11 -5
- optimum/rbln/diffusers/pipelines/auto_pipeline.py +307 -0
- optimum/rbln/diffusers/pipelines/cosmos/configuration_cosmos_guardrail.py +19 -16
- optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +14 -18
- optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +31 -1
- optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +31 -1
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +1 -6
- optimum/rbln/modeling.py +71 -19
- optimum/rbln/modeling_base.py +99 -21
- optimum/rbln/ops/attn.py +158 -0
- optimum/rbln/ops/flash_attn.py +166 -0
- optimum/rbln/ops/kv_cache_update.py +5 -0
- optimum/rbln/ops/linear.py +7 -0
- optimum/rbln/transformers/__init__.py +92 -0
- optimum/rbln/transformers/configuration_generic.py +9 -7
- optimum/rbln/transformers/modeling_attention_utils.py +252 -0
- optimum/rbln/transformers/modeling_generic.py +51 -9
- optimum/rbln/transformers/modeling_outputs.py +37 -0
- optimum/rbln/transformers/models/__init__.py +91 -30
- optimum/rbln/transformers/models/auto/__init__.py +2 -0
- optimum/rbln/transformers/models/auto/auto_factory.py +92 -17
- optimum/rbln/transformers/models/auto/modeling_auto.py +45 -0
- optimum/rbln/transformers/models/bart/bart_architecture.py +1 -3
- optimum/rbln/transformers/models/bart/configuration_bart.py +2 -0
- optimum/rbln/transformers/models/bert/bert_architecture.py +16 -0
- optimum/rbln/transformers/models/bert/modeling_bert.py +8 -4
- optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +42 -11
- optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +94 -30
- optimum/rbln/transformers/models/clip/configuration_clip.py +10 -7
- optimum/rbln/transformers/models/clip/modeling_clip.py +27 -4
- optimum/rbln/transformers/models/colpali/colpali_architecture.py +3 -6
- optimum/rbln/transformers/models/colpali/configuration_colpali.py +37 -21
- optimum/rbln/transformers/models/colpali/modeling_colpali.py +113 -96
- optimum/rbln/transformers/models/colqwen2/__init__.py +2 -0
- optimum/rbln/transformers/models/colqwen2/colqwen2_architecture.py +233 -0
- optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +74 -0
- optimum/rbln/transformers/models/colqwen2/modeling_colqwen2.py +446 -0
- optimum/rbln/transformers/models/decoderonly/__init__.py +3 -2
- optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +109 -37
- optimum/rbln/transformers/models/decoderonly/configuration_lora.py +411 -0
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +318 -309
- optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +504 -0
- optimum/rbln/transformers/models/decoderonly/generation_decoderonly.py +111 -0
- optimum/rbln/transformers/models/decoderonly/lora_architecture.py +204 -0
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +453 -897
- optimum/rbln/transformers/models/depth_anything/__init__.py +16 -0
- optimum/rbln/transformers/models/depth_anything/configuration_depth_anything.py +24 -0
- optimum/rbln/transformers/models/depth_anything/modeling_depth_anything.py +25 -0
- optimum/rbln/transformers/models/exaone/modeling_exaone.py +42 -4
- optimum/rbln/transformers/models/gemma/__init__.py +2 -2
- optimum/rbln/transformers/models/gemma/configuration_gemma.py +9 -1
- optimum/rbln/transformers/models/gemma/gemma_architecture.py +1 -4
- optimum/rbln/transformers/models/gemma/modeling_gemma.py +22 -1
- optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +49 -13
- optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +12 -2
- optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py +245 -0
- optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +201 -349
- optimum/rbln/transformers/models/gpt2/__init__.py +2 -2
- optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +31 -3
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +10 -8
- optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +18 -1
- optimum/rbln/transformers/models/grounding_dino/__init__.py +10 -0
- optimum/rbln/transformers/models/grounding_dino/configuration_grounding_dino.py +92 -0
- optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +599 -0
- optimum/rbln/transformers/models/grounding_dino/modeling_grounding_dino.py +1032 -0
- optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +35 -7
- optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +26 -27
- optimum/rbln/transformers/models/llama/__init__.py +2 -2
- optimum/rbln/transformers/models/llama/configuration_llama.py +9 -1
- optimum/rbln/transformers/models/llama/modeling_llama.py +22 -1
- optimum/rbln/transformers/models/llava/__init__.py +16 -0
- optimum/rbln/transformers/models/llava/configuration_llava.py +72 -0
- optimum/rbln/transformers/models/llava/modeling_llava.py +478 -0
- optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +15 -17
- optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +235 -375
- optimum/rbln/transformers/models/midm/midm_architecture.py +4 -1
- optimum/rbln/transformers/models/midm/modeling_midm.py +42 -4
- optimum/rbln/transformers/models/mistral/__init__.py +2 -2
- optimum/rbln/transformers/models/mistral/configuration_mistral.py +9 -1
- optimum/rbln/transformers/models/mistral/mistral_architecture.py +1 -1
- optimum/rbln/transformers/models/mistral/modeling_mistral.py +26 -3
- optimum/rbln/transformers/models/opt/__init__.py +2 -2
- optimum/rbln/transformers/models/opt/configuration_opt.py +8 -1
- optimum/rbln/transformers/models/opt/modeling_opt.py +28 -16
- optimum/rbln/transformers/models/opt/opt_architecture.py +4 -4
- optimum/rbln/transformers/models/pegasus/__init__.py +17 -0
- optimum/rbln/transformers/models/pegasus/configuration_pegasus.py +38 -0
- optimum/rbln/transformers/models/pegasus/modeling_pegasus.py +71 -0
- optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +161 -0
- optimum/rbln/transformers/models/phi/__init__.py +2 -2
- optimum/rbln/transformers/models/phi/configuration_phi.py +9 -1
- optimum/rbln/transformers/models/phi/modeling_phi.py +10 -1
- optimum/rbln/transformers/models/phi/phi_architecture.py +11 -7
- optimum/rbln/transformers/models/pixtral/__init__.py +16 -0
- optimum/rbln/transformers/models/pixtral/configuration_pixtral.py +43 -0
- optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +310 -0
- optimum/rbln/transformers/models/pixtral/pixtral_architecture.py +73 -0
- optimum/rbln/transformers/models/qwen2/__init__.py +2 -2
- optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +9 -1
- optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +27 -1
- optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +21 -6
- optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +15 -21
- optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +28 -7
- optimum/rbln/transformers/models/qwen2_vl/__init__.py +19 -0
- optimum/rbln/transformers/models/qwen2_vl/configuration_qwen2_vl.py +88 -0
- optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py +514 -0
- optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +165 -0
- optimum/rbln/transformers/models/qwen3/configuration_qwen3.py +2 -2
- optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +86 -330
- optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +1 -245
- optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py +20 -13
- optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +24 -3
- optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +2 -2
- optimum/rbln/transformers/models/siglip/__init__.py +2 -6
- optimum/rbln/transformers/models/siglip/configuration_siglip.py +1 -1
- optimum/rbln/transformers/models/siglip/modeling_siglip.py +5 -16
- optimum/rbln/transformers/models/swin/__init__.py +16 -0
- optimum/rbln/transformers/models/swin/configuration_swin.py +42 -0
- optimum/rbln/transformers/models/swin/modeling_swin.py +341 -0
- optimum/rbln/transformers/models/t5/configuration_t5.py +2 -0
- optimum/rbln/transformers/models/t5/t5_architecture.py +8 -1
- optimum/rbln/transformers/models/time_series_transformer/configuration_time_series_transformer.py +3 -3
- optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +4 -14
- optimum/rbln/transformers/models/time_series_transformer/time_series_transformers_architecture.py +7 -1
- optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +1 -0
- optimum/rbln/transformers/models/whisper/configuration_whisper.py +12 -13
- optimum/rbln/transformers/models/whisper/generation_whisper.py +28 -6
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +28 -3
- optimum/rbln/transformers/models/xlm_roberta/__init__.py +2 -8
- optimum/rbln/transformers/utils/rbln_quantization.py +391 -75
- optimum/rbln/transformers/utils/rbln_runtime_wrapper.py +79 -0
- optimum/rbln/utils/depreacate_utils.py +16 -0
- optimum/rbln/utils/runtime_utils.py +28 -18
- optimum/rbln/utils/submodule.py +31 -9
- {optimum_rbln-0.8.2a4.dist-info → optimum_rbln-0.9.3rc0.dist-info}/METADATA +8 -7
- {optimum_rbln-0.8.2a4.dist-info → optimum_rbln-0.9.3rc0.dist-info}/RECORD +167 -125
- optimum_rbln-0.9.3rc0.dist-info/entry_points.txt +2 -0
- {optimum_rbln-0.8.2a4.dist-info → optimum_rbln-0.9.3rc0.dist-info}/WHEEL +0 -0
- {optimum_rbln-0.8.2a4.dist-info → optimum_rbln-0.9.3rc0.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/transformers/models/gemma3/modeling_gemma3.py
@@ -11,95 +11,74 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import importlib
 import inspect
-from collections import deque
-from dataclasses import dataclass
 from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union

 import rebel
 import torch
 from rebel.compile_context import CompileContext
-from transformers import (
-    AutoModelForImageTextToText,
-    Gemma3ForConditionalGeneration,
-    PretrainedConfig,
-    PreTrainedModel,
-)
+from transformers import AutoModelForImageTextToText, Gemma3ForConditionalGeneration, PretrainedConfig, PreTrainedModel
 from transformers.modeling_outputs import BaseModelOutputWithPooling
 from transformers.modeling_utils import no_init_weights
 from transformers.models.gemma3.modeling_gemma3 import Gemma3TextScaledWordEmbedding

 from ....configuration_utils import RBLNCompileConfig, RBLNModelConfig
 from ....modeling import RBLNModel
-from ...
+from ...modeling_outputs import RBLNDecoderOnlyOutput
+from ...utils.rbln_runtime_wrapper import LoopProcessor
+from ..decoderonly.decoderonly_runtime_utils import RBLNPageTableManager
+from ..decoderonly.generation_decoderonly import RBLNDecoderOnlyGenerationMixin
+from ..decoderonly.modeling_decoderonly import (
+    RBLNDecoderOnlyModelForCausalLM,
+)
 from .configuration_gemma3 import RBLNGemma3ForCausalLMConfig
 from .gemma3_architecture import Gemma3ForCausalLMWrapper
+from .gemma3_runtime_utils import RBLNGemma3RuntimeModel


 if TYPE_CHECKING:
     from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer, Gemma3ForConditionalGeneration


-
-
-
-
-
-class LoopVisionTower:
-    def __init__(self, vision_tower: RBLNModel) -> None:
-        self.vision_tower = vision_tower
+class LoopVisionTower(LoopProcessor):
+    def __init__(self, vision_tower: "RBLNModel"):
+        super().__init__(model=vision_tower)

-    def
-
-        # shape of pixel_values : [batch, num_channel, height, width]
-        pixel_values = args[0]
+    def _get_batch_size(self, pixel_values, **kwargs):
+        return pixel_values.shape[0]

-
-
-        for
-
+    def _prepare_inputs_for_iteration(self, index, common_inputs, pixel_values, **kwargs):
+        pixel_values_item = pixel_values[index : index + 1]
+        out_buffer = [tensor[index : index + 1] for tensor in kwargs["out"]]
+        return ([pixel_values_item], {"out": out_buffer})

-
-
-        # FIXME:: This can be optimized using out= API of rbln runtime.
-        last_hidden_states = torch.cat(last_hidden_states, dim=0)
+    def _process_outputs(self, outputs: list, **kwargs) -> "BaseModelOutputWithPooling":
+        output = kwargs["out"]

         return BaseModelOutputWithPooling(
-            last_hidden_state=
+            last_hidden_state=output[0],
         )

-    def __call__(self, *args: Any, **kwds: Any) -> Any:
-        return self.forward(*args, **kwds)
-
-    def __repr__(self) -> str:
-        return repr(self.vision_tower)
-

-class LoopProjector:
-    def __init__(self, multi_modal_projector)
-
+class LoopProjector(LoopProcessor):
+    def __init__(self, multi_modal_projector: "RBLNModel"):
+        super().__init__(model=multi_modal_projector)

-    def
-
-        image_feature = args[0]
+    def _get_batch_size(self, image_feature, **kwargs):
+        return image_feature.shape[0]

-
-
-        for
-
+    def _prepare_inputs_for_iteration(self, index, common_inputs, image_feature, **kwargs):
+        image_feature_item = image_feature[index : index + 1]
+        out_buffer = [tensor[index : index + 1] for tensor in kwargs["out"]]
+        return ([image_feature_item], {"out": out_buffer})

-
-
-        return
+    def _process_outputs(self, outputs: list, **kwargs):
+        output = kwargs["out"]
+        return output[0]

-    def __call__(self, *args: Any, **kwds: Any) -> Any:
-        return self.forward(*args, **kwds)

-
-        return repr(self.multi_modal_projector)
-
-
-class RBLNGemma3ForConditionalGeneration(RBLNModel):
+class RBLNGemma3ForConditionalGeneration(RBLNModel, RBLNDecoderOnlyGenerationMixin):
     auto_model_class = AutoModelForImageTextToText
     _rbln_submodules = [
         {"name": "vision_tower"},
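Both wrappers now subclass `LoopProcessor` (imported above from `...utils.rbln_runtime_wrapper`) and only implement three hooks. The base class is not part of this hunk, so the following is a minimal sketch of the template-method contract those hooks imply, not the actual implementation:

```python
# Sketch only: the real LoopProcessor lives in optimum/rbln/transformers/utils/rbln_runtime_wrapper.py,
# which is not shown in this diff; the loop below illustrates the assumed contract.
class LoopProcessor:
    def __init__(self, model):
        self.model = model

    def __call__(self, *args, **kwargs):
        batch_size = self._get_batch_size(*args, **kwargs)
        outputs = []
        for index in range(batch_size):
            # Subclasses slice one batch item and the matching slice of the preallocated "out" buffers.
            item_args, item_kwargs = self._prepare_inputs_for_iteration(index, None, *args, **kwargs)
            outputs.append(self.model(*item_args, **item_kwargs))
        # Subclasses assemble the final result, typically straight from the "out" buffers.
        return self._process_outputs(outputs, **kwargs)

    def __repr__(self):
        return repr(self.model)
```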
@@ -119,6 +98,23 @@ class RBLNGemma3ForConditionalGeneration(RBLNModel):
     def can_generate(self):
         return True

+    @classmethod
+    def get_pytorch_model(cls, *args, **kwargs):
+        model = super().get_pytorch_model(*args, **kwargs)
+
+        with no_init_weights():
+            model_cls_name = model.model.language_model.__class__.__name__
+            causal_model_cls_name = model_cls_name.replace("TextModel", "ForCausalLM")
+            causal_model_cls = getattr(importlib.import_module("transformers"), causal_model_cls_name)
+            new_language_model = causal_model_cls(model.model.language_model.config)
+
+        new_language_model.lm_head = model.lm_head
+        new_language_model.model = model.model.language_model
+        model.model.language_model = new_language_model
+        model.lm_head = None
+        del model.lm_head
+        return model
+
     def __post_init__(self, **kwargs):
         self.vision_tower = LoopVisionTower(self.rbln_submodules[0])
         self.language_model = self.rbln_submodules[1]
@@ -208,18 +204,30 @@ class RBLNGemma3ForConditionalGeneration(RBLNModel):
         return model_kwargs

     def get_image_features(self, pixel_values: torch.Tensor) -> torch.Tensor:
-
-
-
-
-
-
-
-
-
-
-
-
+        # Projects the last hidden state from the vision model into language model space.
+
+        # Args:
+        #     pixel_values: (`torch.FloatTensor` of shape `(batch_size, channels, height, width)`)
+        #         The tensors corresponding to the input images.
+
+        # Returns:
+        #     Image feature tensor of shape `(num_images, image_length, embed_dim)`.
+
+        vision_out_buffer = []
+        vision_out_size = [
+            pixel_values.shape[0],
+            (self.config.vision_config.image_size // self.config.vision_config.patch_size) ** 2,
+            self.config.vision_config.hidden_size,
+        ]
+        projector_out_size = [
+            pixel_values.shape[0],
+            self.config.mm_tokens_per_image,
+            self.config.text_config.hidden_size,
+        ]
+        vision_out_buffer.append(torch.empty(size=vision_out_size, dtype=torch.float32, device="cpu"))
+        projector_out_buffer = [torch.empty(size=projector_out_size, dtype=torch.float32, device="cpu")]
+        vision_outputs = self.vision_tower(pixel_values, out=vision_out_buffer).last_hidden_state
+        image_features = self.multi_modal_projector(vision_outputs, out=projector_out_buffer)
         return image_features

     def _preprocess_prefill(
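The rewritten `get_image_features` preallocates CPU output buffers and hands them to the looped runtimes through `out=`, instead of concatenating per-image results afterwards. For a sense of the buffer shapes, with illustrative Gemma3-style config values (not taken from any specific checkpoint):

```python
# Illustrative values only: vision image_size=896, patch_size=14, vision hidden_size=1152,
# mm_tokens_per_image=256, text hidden_size=2560, and a batch of 2 images.
batch = 2
vision_out_size = [batch, (896 // 14) ** 2, 1152]   # [2, 4096, 1152] from the vision tower
projector_out_size = [batch, 256, 2560]             # [2, 256, 2560] after the projector
print(vision_out_size, projector_out_size)
```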
@@ -254,17 +262,45 @@ class RBLNGemma3ForConditionalGeneration(RBLNModel):
 
         return inputs_embeds

+    def get_padded_cache_position(
+        self,
+        cache_position: torch.Tensor,  # shape: [1, seq_len]
+        token_type_ids: torch.Tensor,  # shape: [1, seq_len]
+    ) -> torch.Tensor:
+        seq_len = cache_position[0][-1].item() + 1
+
+        # Find image start positions
+        image_starts = [
+            s
+            for s in torch.where(token_type_ids == 1)[1]
+            if torch.all(token_type_ids[:, s : s + self.rbln_config.image_prefill_chunk_size] == 1)
+        ]
+
+        # Initialize padded tensors
+        padded_input_len = seq_len
+        for image_start in image_starts:
+            pad_needed = (
+                self.rbln_config.image_prefill_chunk_size
+                - (image_start + padded_input_len - seq_len) % self.rbln_config.image_prefill_chunk_size
+            ) % self.rbln_config.image_prefill_chunk_size
+            padded_input_len += pad_needed
+
+        return torch.cat(
+            [cache_position, torch.arange(seq_len, padded_input_len, dtype=torch.int32).unsqueeze(0)],
+            dim=1,
+        )
+
     def forward(
         self,
         input_ids: torch.LongTensor = None,
+        attention_mask: torch.Tensor = None,
+        token_type_ids: torch.Tensor = None,
         pixel_values: torch.FloatTensor = None,
-        attention_mask: Optional[torch.Tensor] = None,
         cache_position: Optional[torch.LongTensor] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
         generate_idx: Optional[torch.Tensor] = None,
         padded_cache_lengths: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.Tensor] = None,
-        token_type_ids: Optional[torch.Tensor] = None,
         **lm_kwargs: Dict[str, Any],
     ) -> Union[Tuple, RBLNDecoderOnlyOutput]:
         # prefill
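`get_padded_cache_position` extends the prefill cache positions so that each image-token run can be aligned to an `image_prefill_chunk_size` boundary during chunked prefill. A small worked example of the padding arithmetic, with illustrative numbers (a 300-token prompt, one image run starting at offset 20, and an assumed chunk size of 256, matching Gemma3's `mm_tokens_per_image`):

```python
import torch

# Illustrative values only.
image_prefill_chunk_size = 256
seq_len = 300
image_starts = [20]

padded_input_len = seq_len
for image_start in image_starts:
    pad_needed = (
        image_prefill_chunk_size
        - (image_start + padded_input_len - seq_len) % image_prefill_chunk_size
    ) % image_prefill_chunk_size
    padded_input_len += pad_needed  # 20 + 236 == 256, so the image run lands on a chunk boundary

cache_position = torch.arange(0, seq_len, dtype=torch.int32).unsqueeze(0)
padded = torch.cat(
    [cache_position, torch.arange(seq_len, padded_input_len, dtype=torch.int32).unsqueeze(0)], dim=1
)
print(padded.shape)  # torch.Size([1, 536])
```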
@@ -275,12 +311,15 @@ class RBLNGemma3ForConditionalGeneration(RBLNModel):
 
             for b_idx in range(batch_size):
                 cache_position = torch.arange(0, generate_idx[b_idx].item(), dtype=torch.int32).unsqueeze(0)
+                token_type_id = token_type_ids[b_idx : b_idx + 1, attention_mask[b_idx].bool()]
+                cache_position = self.get_padded_cache_position(cache_position, token_type_id)
+
                 output = self.language_model.prefill_decoder(
                     inputs_embeds=inputs_embeds[b_idx : b_idx + 1],
                     attention_mask=attention_mask[b_idx],
                     cache_position=cache_position,
                     batch_idx=b_idx,
-                    token_type_ids=token_type_ids[b_idx : b_idx + 1]
+                    token_type_ids=token_type_ids[b_idx : b_idx + 1],  # do not pass token_type_id
                 )
                 padded_cache_lengths[b_idx] += output.padded_cache_lengths
                 logits.append(output.logits)
@@ -309,209 +348,6 @@ class RBLNGemma3ForConditionalGeneration(RBLNModel):
         )


-class RBLNGemma3RuntimeModel(RBLNRuntimeModel):
-    def __init__(self, *args, image_prefill: Optional[rebel.Runtime] = None, **kwargs):
-        super().__init__(*args, **kwargs)
-        self.image_prefill = image_prefill  # FIXME(taehoon)
-        self.prefill = self.runtime if self.phase == "prefill" else None  # FIXME
-        self.decode = self.runtime if self.phase == "decode" else None
-
-    def _prepare_prefill_inputs(self, *args, **kwargs):
-        (
-            inputs,
-            cache_position,
-            chunked_attention_mask,
-            out_buffers,
-            position_ids,
-            position_embed,
-            padded_cache_lengths,
-            query_length,
-            token_type_ids,
-        ) = super()._prepare_prefill_inputs(*args, **kwargs)
-
-        # chunked_attention_mask shape
-        chunked_attention_mask = torch.zeros(1, chunked_attention_mask.shape[-1], dtype=torch.float32)
-
-        # as gemma3 has different prefill chunk size for image and text, we need to pad the inputs to the max of the two.
-        padding_size = max(self.rbln_config.prefill_chunk_size, self.rbln_config.image_prefill_chunk_size)
-        inputs = torch.nn.functional.pad(inputs, (0, 0, 0, padding_size))
-        cache_position = torch.nn.functional.pad(cache_position, (0, padding_size))
-        position_ids = torch.nn.functional.pad(position_ids, (0, padding_size))
-        token_type_ids = torch.nn.functional.pad(token_type_ids, (0, padding_size), value=-1)
-
-        return (
-            inputs,
-            cache_position,
-            chunked_attention_mask,
-            out_buffers,
-            position_ids,
-            position_embed,
-            padded_cache_lengths,
-            query_length,
-            token_type_ids,
-        )
-
-    def prefill_forward(
-        self,
-        inputs: torch.Tensor,
-        cache_position: torch.Tensor = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        batch_idx: int = None,
-        block_tables: torch.Tensor = None,
-        is_external_block_tables: bool = None,
-        position_embed: Optional[torch.Tensor] = None,
-        token_type_ids: Optional[torch.Tensor] = None,
-        local_block_tables: Optional[torch.Tensor] = None,
-    ) -> torch.FloatTensor:
-        """
-        Performs chunked prefill for efficient KV-cache updates and memory optimization.
-        Instead of processing the entire sequence at once, the input is divided into chunks of size `prefill_chunk_size`,
-        and each chunk is processed sequentially. This allows for better memory utilization and compatibility with continuous batching.
-        """
-        (
-            inputs,
-            cache_position,
-            chunked_attention_mask,
-            out_buffers,
-            position_ids,
-            position_embed,
-            padded_cache_lengths,
-            query_length,
-            token_type_ids,
-        ) = self._prepare_prefill_inputs(
-            inputs, cache_position, attention_mask, position_embed, token_type_ids=token_type_ids
-        )
-
-        step = 0
-        while step < query_length:
-            # Check if the prefill chunk is an image prefill
-            is_image_prefill = torch.all(
-                token_type_ids[:, step : step + self.rbln_config.image_prefill_chunk_size] == 1
-            )
-            prefill_chunk_size = (
-                self.rbln_config.image_prefill_chunk_size if is_image_prefill else self.rbln_config.prefill_chunk_size
-            )
-
-            # Check if the prefill chunk is a text prefill which have image_tokens in it.
-            is_text_prefill_with_image_tokens = not is_image_prefill and torch.any(
-                token_type_ids[:, step : step + prefill_chunk_size] == 1
-            )
-
-            # Check if the prefill chunk crosses a block boundary, requiring padding to align with block boundaries
-            is_cross_block_boundary = (
-                step // self.rbln_config.kvcache_block_size
-                != (step + prefill_chunk_size) // self.rbln_config.kvcache_block_size
-            )
-
-            # Check if the prefill chunk is the last chunk
-            is_last_chunk = step + prefill_chunk_size >= query_length
-
-            if is_cross_block_boundary:
-                padding_size = prefill_chunk_size - (step + prefill_chunk_size) % self.rbln_config.kvcache_block_size
-                padded_cache_lengths += padding_size
-
-            # if text_prefill end with image_tokens, we only treat the text part.
-            num_processed_tokens = prefill_chunk_size
-            if is_text_prefill_with_image_tokens:
-                first_image_token_idx = torch.where(token_type_ids[:, step : step + prefill_chunk_size] == 1)[1][0]
-                num_processed_tokens = first_image_token_idx
-            if is_last_chunk:
-                num_processed_tokens = query_length - step
-
-            input_chunk = inputs[:, step : step + prefill_chunk_size]
-            cache_pos_chunk = cache_position[:, step : step + prefill_chunk_size].clone() + padded_cache_lengths
-            position_ids_chunk = position_ids[:, step : step + prefill_chunk_size].clone()
-            chunked_attention_mask[
-                :, step + padded_cache_lengths : step + num_processed_tokens + padded_cache_lengths
-            ] = 1
-            query_position = torch.tensor(num_processed_tokens - 1, dtype=torch.int16)
-
-            if is_image_prefill:
-                logits = self.image_prefill(
-                    input_chunk,
-                    cache_pos_chunk,
-                    block_tables,
-                    local_block_tables,
-                    query_position,
-                    chunked_attention_mask,
-                    position_ids_chunk,
-                    out=out_buffers,
-                )
-            else:
-                logits = self.prefill(
-                    input_chunk,
-                    cache_pos_chunk,
-                    block_tables,
-                    local_block_tables,
-                    query_position,
-                    chunked_attention_mask,
-                    position_ids_chunk,
-                    out=out_buffers,
-                )
-
-            step += num_processed_tokens
-
-        if not is_external_block_tables:
-            self.dec_attn_mask[batch_idx : batch_idx + 1] = chunked_attention_mask
-
-        return RBLNGemma3ForCausalLMOutput(
-            logits=logits, padded_cache_lengths=padded_cache_lengths, attention_mask=chunked_attention_mask
-        )
-
-    def decode_forward(
-        self,
-        inputs: torch.Tensor,
-        cache_position: torch.Tensor = None,
-        block_tables: torch.Tensor = None,
-        is_external_block_tables: bool = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        position_embed: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.Tensor] = None,
-        local_block_tables: Optional[torch.Tensor] = None,
-    ) -> torch.FloatTensor:
-        batch_size = inputs.shape[0]
-        if batch_size != self.batch_size:
-            raise RuntimeError(
-                f"Batch size mismatch: got {batch_size}, expected {self.batch_size} (compiled batch size)."
-            )
-
-        if batch_size != cache_position.shape[0]:
-            raise RuntimeError(f"Cache position size mismatch: got {cache_position.shape[0]}, expected {batch_size}.")
-
-        # FIXME(taehoon): how to handle pos_attn_mask with external block tables
-        if is_external_block_tables:
-            if attention_mask is None:
-                raise ValueError("attention_mask should be provided with external block tables.")
-            if local_block_tables is None:
-                raise ValueError("local_block_tables should be provided with external block tables.")
-        else:
-            local_block_tables = (
-                local_block_tables
-                if local_block_tables is not None
-                else torch.arange(0, self.batch_size, dtype=torch.int16).view(self.batch_size, -1)
-            )
-            if self.rbln_config.use_attention_mask and attention_mask is None:
-                for b_idx in range(batch_size):
-                    decoding_step = cache_position[b_idx].item()
-                    if not (0 <= decoding_step < self.dec_attn_mask.shape[-1]):
-                        raise ValueError(
-                            f"Decoding step {decoding_step} out of bounds for attention mask with shape {self.dec_attn_mask.shape}."
-                        )
-                    self.dec_attn_mask[b_idx, decoding_step] = 1
-
-                attention_mask = self.dec_attn_mask
-
-        if self.batch_size < block_tables.shape[0]:
-            block_tables = block_tables[: self.batch_size]
-
-        if attention_mask is not None and self.batch_size < attention_mask.shape[0]:
-            attention_mask = attention_mask[: self.batch_size]
-
-        logits = self.decode(inputs, cache_position, block_tables, local_block_tables, attention_mask, position_ids)
-
-        return RBLNDecoderOnlyOutput(logits=logits)
-
-
 class RBLNGemma3ForCausalLM(RBLNDecoderOnlyModelForCausalLM):
     """
     The Gemma3 Model transformer with a language modeling head (linear layer) on top.
@@ -524,52 +360,36 @@ class RBLNGemma3ForCausalLM(RBLNDecoderOnlyModelForCausalLM):
     """

     _decoder_wrapper_cls = Gemma3ForCausalLMWrapper
+    _supports_non_fp32 = False

-    def
-        main_input_name = self.main_input_name
-
-        if self.rbln_config.use_inputs_embeds:
-            main_input_name = "inputs_embeds"
-            artifacts = torch.load(self.model_save_dir / self.subfolder / "torch_artifacts.pth", weights_only=False)
-            self.embed_tokens = self._create_embedding_layer()
-            self.embed_tokens.load_state_dict(artifacts["embed_tokens"])
-        else:
-            self.embed_tokens = None
-
+    def setup_runtime(self):
         # Initialize shared resources to be used across Runtime instances (prefill and decode phases)
         dec_attn_mask = torch.zeros(self.rbln_config.batch_size, self.rbln_config.max_seq_len, dtype=torch.float32)
-
-
-
-
-
-
+        page_table_manager = RBLNPageTableManager(self.rbln_config)
+
+        common_kwargs = {
+            "main_input_name": "inputs_embeds" if self.rbln_config.use_inputs_embeds else "input_ids",
+            "embed_tokens": self.embed_tokens,
+            "dec_attn_mask": dec_attn_mask,
+            "page_table_manager": page_table_manager,
+            "rbln_config": self.rbln_config,
+        }
+
         self.prefill_decoder = RBLNGemma3RuntimeModel(
             runtime=self.model[0],
-            image_prefill=self.model[1],
-            main_input_name=main_input_name,
-            embed_tokens=self.embed_tokens,
+            image_prefill=self.model[1] if self.rbln_config.use_image_prefill else None,
             phase="prefill",
             batch_size=self.rbln_config.batch_size,
-
-            block_tables=block_tables,
-            vocab_size=self.config.vocab_size,
-            free_block_pool=free_block_pool,
-            rbln_config=self.rbln_config,
+            **common_kwargs,
         )

         self.decoders = {}
         for i, batch_size in enumerate(self.rbln_config.decoder_batch_sizes):
             self.decoders[batch_size] = RBLNGemma3RuntimeModel(
-                runtime=self.model[i +
-                main_input_name=main_input_name,
-                embed_tokens=self.embed_tokens,
+                runtime=self.model[i + self.rbln_config.decoder_runtime_idx],
                 phase="decode",
                 batch_size=batch_size,
-
-                block_tables=block_tables,
-                free_block_pool=free_block_pool,
-                rbln_config=self.rbln_config,
+                **common_kwargs,
             )

         # NOTE(eunji): Use a decoder whose batch size matches the model's main batch size for compatibility.
@@ -589,6 +409,13 @@ class RBLNGemma3ForCausalLM(RBLNDecoderOnlyModelForCausalLM):
     def _update_sliding_window_config(cls, model_config: PretrainedConfig, rbln_config: RBLNGemma3ForCausalLMConfig):
         sliding_window = getattr(model_config, "sliding_window", None)
         sliding_window_pattern = getattr(model_config, "sliding_window_pattern", None)
+        if sliding_window_pattern is None:
+            if hasattr(model_config, "layer_types"):
+                first_full_attention_index = model_config.layer_types.index("full_attention")
+                sliding_window_pattern = first_full_attention_index + 1
+            else:
+                raise ValueError("Cannot determine sliding_window_pattern from model_config")
+
         if sliding_window_pattern <= model_config.num_hidden_layers:
             rbln_config.cache_impl = "hybrid"
             rbln_config.sliding_window = sliding_window
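The new fallback covers configs that no longer carry `sliding_window_pattern` and instead describe the attention layout with a per-layer `layer_types` list (as recent transformers releases do); the pattern is recovered as the 1-based index of the first full-attention layer. A tiny illustration with a made-up layout:

```python
# Hypothetical layout: five sliding-window layers, then a full-attention layer (repeating).
layer_types = ["sliding_attention"] * 5 + ["full_attention"]
sliding_window_pattern = layer_types.index("full_attention") + 1
print(sliding_window_pattern)  # 6
```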
@@ -599,7 +426,12 @@ class RBLNGemma3ForCausalLM(RBLNDecoderOnlyModelForCausalLM):
         return rbln_config

     @classmethod
-    def _update_submodule_config(
+    def _update_submodule_config(
+        cls,
+        model: "PreTrainedModel",
+        rbln_config: RBLNModelConfig,
+        preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]],
+    ):
         if rbln_config.image_prefill_chunk_size is None:
             rbln_config.image_prefill_chunk_size = model.config.mm_tokens_per_image
 
@@ -624,20 +456,26 @@ class RBLNGemma3ForCausalLM(RBLNDecoderOnlyModelForCausalLM):
         if not (rbln_config.use_attention_mask and rbln_config.use_position_ids):
             raise ValueError("use_attention_mask and use_position_ids must be True for RBLNGemma3ForCausalLM")

-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        if rbln_config.use_image_prefill:
+            if rbln_config.prefill_chunk_size != rbln_config.image_prefill_chunk_size:
+                raise NotImplementedError(
+                    "Not implemented for different prefill chunk sizes between text and image prefill."
+                )
+
+            # Update image prefill compile config
+            img_prefill_input_info = cls.get_input_info(
+                batch_size=1,
+                query_length=rbln_config.image_prefill_chunk_size,
+                rbln_config=rbln_config,
+                model_config=model_config,
+            )
+            image_prefill_compile_config = RBLNCompileConfig(
+                compiled_model_name="image_prefill", input_info=img_prefill_input_info
+            )
+            # Insert image_prefill compile config at index 1
+            compile_cfgs = rbln_config.compile_cfgs
+            compile_cfgs.insert(1, image_prefill_compile_config)
+            rbln_config.set_compile_cfgs(compile_cfgs)

         return rbln_config
 
@@ -690,23 +528,27 @@ class RBLNGemma3ForCausalLM(RBLNDecoderOnlyModelForCausalLM):
             context,
             rbln_config.quantization,
         )
+        compiled_models = {"prefill": compiled_prefill}

-
-
-
-
-
-
-
-
-
-
-
-
+        if rbln_config.use_image_prefill:
+            image_prefill_compile_config = rbln_compile_configs[1]
+            image_prefill_example_inputs = image_prefill_compile_config.get_dummy_inputs(
+                fill=0, static_tensors=static_tensors
+            )
+            wrapped_model.phase = "image_prefill"
+            compiled_image_prefill = compile_model(
+                wrapped_model,
+                image_prefill_compile_config,
+                image_prefill_example_inputs,
+                context,
+                rbln_config.quantization,
+            )
+            compiled_models["image_prefill"] = compiled_image_prefill

-        compiled_models = {"prefill": compiled_prefill, "image_prefill": compiled_image_prefill}
         wrapped_model.phase = "decode"
-        for batch_size, dec_compile_config in zip(
+        for batch_size, dec_compile_config in zip(
+            rbln_config.decoder_batch_sizes, rbln_compile_configs[rbln_config.decoder_runtime_idx :]
+        ):
             dec_example_inputs = dec_compile_config.get_dummy_inputs(fill=0, static_tensors=static_tensors)
             compiled_decoder = compile_model(
                 wrapped_model,
@@ -727,35 +569,45 @@ class RBLNGemma3ForCausalLM(RBLNDecoderOnlyModelForCausalLM):
     ) -> List[rebel.Runtime]:
         expected_model_names = [
             "prefill",
-            "image_prefill",
             *[f"decoder_batch_{batch_size}" for batch_size in rbln_config.decoder_batch_sizes],
         ]
+        if rbln_config.use_image_prefill:
+            expected_model_names.insert(1, "image_prefill")
+
         if any(model_name not in rbln_config.device_map for model_name in expected_model_names):
             cls._raise_missing_compiled_file_error(expected_model_names)

-
+        ret_val = [
             rebel.Runtime(
                 compiled_models[0],
                 tensor_type="pt",
                 device=rbln_config.device_map["prefill"],
                 activate_profiler=rbln_config.activate_profiler,
                 timeout=rbln_config.timeout,
-            )
-
-
-
-
-
-
-
-
+            )
+        ]
+        if rbln_config.use_image_prefill:
+            ret_val.append(
+                rebel.Runtime(
+                    compiled_models[1],
+                    tensor_type="pt",
+                    device=rbln_config.device_map["image_prefill"],
+                    activate_profiler=rbln_config.activate_profiler,
+                    timeout=rbln_config.timeout,
+                ),
+            )
+
+        ret_val.extend(
+            [
                 rebel.Runtime(
-                    compiled_models[i +
+                    compiled_models[i + rbln_config.decoder_runtime_idx],
                     tensor_type="pt",
                     device=rbln_config.device_map[f"decoder_batch_{batch_size}"],
                     activate_profiler=rbln_config.activate_profiler,
                     timeout=rbln_config.timeout,
                 )
                 for i, batch_size in enumerate(rbln_config.decoder_batch_sizes)
-        ]
-
+            ]
+        )
+
+        return ret_val