optimum-rbln 0.8.2a4__py3-none-any.whl → 0.9.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- optimum/rbln/__init__.py +108 -9
- optimum/rbln/__version__.py +16 -3
- optimum/rbln/cli.py +660 -0
- optimum/rbln/configuration_utils.py +156 -43
- optimum/rbln/diffusers/__init__.py +19 -0
- optimum/rbln/diffusers/configurations/__init__.py +3 -0
- optimum/rbln/diffusers/configurations/models/__init__.py +2 -0
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +3 -3
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_cosmos.py +1 -1
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_temporal_decoder.py +67 -0
- optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +3 -3
- optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +4 -4
- optimum/rbln/diffusers/configurations/models/configuration_transformer_cosmos.py +9 -4
- optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +9 -4
- optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +3 -3
- optimum/rbln/diffusers/configurations/models/configuration_unet_spatio_temporal_condition.py +59 -0
- optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +3 -3
- optimum/rbln/diffusers/configurations/pipelines/__init__.py +3 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +35 -19
- optimum/rbln/diffusers/configurations/pipelines/configuration_cosmos.py +14 -11
- optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +30 -20
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +13 -9
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +17 -13
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +17 -10
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_video_diffusion.py +114 -0
- optimum/rbln/diffusers/modeling_diffusers.py +30 -14
- optimum/rbln/diffusers/models/__init__.py +4 -0
- optimum/rbln/diffusers/models/autoencoders/__init__.py +1 -0
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +31 -3
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py +31 -6
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +275 -0
- optimum/rbln/diffusers/models/autoencoders/vae.py +27 -8
- optimum/rbln/diffusers/models/autoencoders/vq_model.py +31 -3
- optimum/rbln/diffusers/models/controlnet.py +16 -1
- optimum/rbln/diffusers/models/transformers/prior_transformer.py +17 -3
- optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +25 -2
- optimum/rbln/diffusers/models/transformers/transformer_sd3.py +23 -2
- optimum/rbln/diffusers/models/unets/__init__.py +1 -0
- optimum/rbln/diffusers/models/unets/unet_2d_condition.py +23 -4
- optimum/rbln/diffusers/models/unets/unet_spatio_temporal_condition.py +201 -0
- optimum/rbln/diffusers/pipelines/__init__.py +15 -5
- optimum/rbln/diffusers/pipelines/auto_pipeline.py +307 -0
- optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +20 -0
- optimum/rbln/diffusers/pipelines/cosmos/configuration_cosmos_guardrail.py +19 -16
- optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +14 -18
- optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +31 -1
- optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +31 -1
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +1 -6
- optimum/rbln/diffusers/pipelines/stable_video_diffusion/__init__.py +15 -0
- optimum/rbln/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +46 -0
- optimum/rbln/modeling.py +48 -21
- optimum/rbln/modeling_base.py +99 -22
- optimum/rbln/ops/attn.py +158 -0
- optimum/rbln/ops/flash_attn.py +166 -0
- optimum/rbln/ops/kv_cache_update.py +5 -0
- optimum/rbln/ops/linear.py +7 -0
- optimum/rbln/transformers/__init__.py +92 -0
- optimum/rbln/transformers/configuration_generic.py +7 -32
- optimum/rbln/transformers/modeling_attention_utils.py +385 -0
- optimum/rbln/transformers/modeling_generic.py +48 -65
- optimum/rbln/transformers/modeling_outputs.py +37 -0
- optimum/rbln/transformers/models/__init__.py +91 -30
- optimum/rbln/transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py +28 -2
- optimum/rbln/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +68 -5
- optimum/rbln/transformers/models/auto/__init__.py +2 -0
- optimum/rbln/transformers/models/auto/auto_factory.py +92 -17
- optimum/rbln/transformers/models/auto/modeling_auto.py +45 -0
- optimum/rbln/transformers/models/bart/bart_architecture.py +1 -3
- optimum/rbln/transformers/models/bart/configuration_bart.py +2 -0
- optimum/rbln/transformers/models/bart/modeling_bart.py +23 -2
- optimum/rbln/transformers/models/bert/bert_architecture.py +16 -0
- optimum/rbln/transformers/models/bert/modeling_bert.py +93 -4
- optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +42 -11
- optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +135 -44
- optimum/rbln/transformers/models/clip/configuration_clip.py +10 -7
- optimum/rbln/transformers/models/clip/modeling_clip.py +67 -6
- optimum/rbln/transformers/models/colpali/colpali_architecture.py +3 -6
- optimum/rbln/transformers/models/colpali/configuration_colpali.py +37 -21
- optimum/rbln/transformers/models/colpali/modeling_colpali.py +82 -104
- optimum/rbln/transformers/models/colqwen2/__init__.py +2 -0
- optimum/rbln/transformers/models/colqwen2/colqwen2_architecture.py +233 -0
- optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +74 -0
- optimum/rbln/transformers/models/colqwen2/modeling_colqwen2.py +446 -0
- optimum/rbln/transformers/models/decoderonly/__init__.py +3 -2
- optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +114 -37
- optimum/rbln/transformers/models/decoderonly/configuration_lora.py +411 -0
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +318 -309
- optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +508 -0
- optimum/rbln/transformers/models/decoderonly/generation_decoderonly.py +119 -0
- optimum/rbln/transformers/models/decoderonly/lora_architecture.py +204 -0
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +485 -905
- optimum/rbln/transformers/models/depth_anything/__init__.py +16 -0
- optimum/rbln/transformers/models/depth_anything/configuration_depth_anything.py +24 -0
- optimum/rbln/transformers/models/depth_anything/modeling_depth_anything.py +42 -0
- optimum/rbln/transformers/models/distilbert/modeling_distilbert.py +24 -0
- optimum/rbln/transformers/models/dpt/modeling_dpt.py +17 -0
- optimum/rbln/transformers/models/exaone/modeling_exaone.py +42 -4
- optimum/rbln/transformers/models/gemma/__init__.py +2 -2
- optimum/rbln/transformers/models/gemma/configuration_gemma.py +9 -1
- optimum/rbln/transformers/models/gemma/gemma_architecture.py +1 -4
- optimum/rbln/transformers/models/gemma/modeling_gemma.py +22 -1
- optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +49 -13
- optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +12 -2
- optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py +245 -0
- optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +201 -351
- optimum/rbln/transformers/models/gpt2/__init__.py +2 -2
- optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +31 -3
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +10 -8
- optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +18 -1
- optimum/rbln/transformers/models/grounding_dino/__init__.py +10 -0
- optimum/rbln/transformers/models/grounding_dino/configuration_grounding_dino.py +92 -0
- optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +599 -0
- optimum/rbln/transformers/models/grounding_dino/modeling_grounding_dino.py +1048 -0
- optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +35 -7
- optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +29 -32
- optimum/rbln/transformers/models/llama/__init__.py +2 -2
- optimum/rbln/transformers/models/llama/configuration_llama.py +9 -1
- optimum/rbln/transformers/models/llama/modeling_llama.py +22 -1
- optimum/rbln/transformers/models/llava/__init__.py +16 -0
- optimum/rbln/transformers/models/llava/configuration_llava.py +72 -0
- optimum/rbln/transformers/models/llava/modeling_llava.py +490 -0
- optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +15 -17
- optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +234 -376
- optimum/rbln/transformers/models/midm/midm_architecture.py +4 -1
- optimum/rbln/transformers/models/midm/modeling_midm.py +42 -4
- optimum/rbln/transformers/models/mistral/__init__.py +2 -2
- optimum/rbln/transformers/models/mistral/configuration_mistral.py +9 -1
- optimum/rbln/transformers/models/mistral/mistral_architecture.py +1 -1
- optimum/rbln/transformers/models/mistral/modeling_mistral.py +26 -3
- optimum/rbln/transformers/models/opt/__init__.py +2 -2
- optimum/rbln/transformers/models/opt/configuration_opt.py +8 -1
- optimum/rbln/transformers/models/opt/modeling_opt.py +29 -17
- optimum/rbln/transformers/models/opt/opt_architecture.py +4 -4
- optimum/rbln/transformers/models/pegasus/__init__.py +17 -0
- optimum/rbln/transformers/models/pegasus/configuration_pegasus.py +38 -0
- optimum/rbln/transformers/models/pegasus/modeling_pegasus.py +71 -0
- optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +161 -0
- optimum/rbln/transformers/models/phi/__init__.py +2 -2
- optimum/rbln/transformers/models/phi/configuration_phi.py +9 -1
- optimum/rbln/transformers/models/phi/modeling_phi.py +10 -1
- optimum/rbln/transformers/models/phi/phi_architecture.py +11 -7
- optimum/rbln/transformers/models/pixtral/__init__.py +16 -0
- optimum/rbln/transformers/models/pixtral/configuration_pixtral.py +43 -0
- optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +322 -0
- optimum/rbln/transformers/models/pixtral/pixtral_architecture.py +73 -0
- optimum/rbln/transformers/models/qwen2/__init__.py +2 -2
- optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +9 -1
- optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +27 -1
- optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +21 -6
- optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +15 -22
- optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +28 -7
- optimum/rbln/transformers/models/qwen2_vl/__init__.py +19 -0
- optimum/rbln/transformers/models/qwen2_vl/configuration_qwen2_vl.py +88 -0
- optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py +513 -0
- optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +165 -0
- optimum/rbln/transformers/models/qwen3/configuration_qwen3.py +2 -2
- optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +86 -330
- optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +1 -245
- optimum/rbln/transformers/models/resnet/configuration_resnet.py +17 -0
- optimum/rbln/transformers/models/resnet/modeling_resnet.py +73 -0
- optimum/rbln/transformers/models/roberta/modeling_roberta.py +33 -0
- optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py +21 -16
- optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +58 -13
- optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +2 -2
- optimum/rbln/transformers/models/siglip/__init__.py +2 -6
- optimum/rbln/transformers/models/siglip/configuration_siglip.py +1 -1
- optimum/rbln/transformers/models/siglip/modeling_siglip.py +21 -16
- optimum/rbln/transformers/models/swin/__init__.py +16 -0
- optimum/rbln/transformers/models/swin/configuration_swin.py +42 -0
- optimum/rbln/transformers/models/swin/modeling_swin.py +354 -0
- optimum/rbln/transformers/models/t5/configuration_t5.py +2 -0
- optimum/rbln/transformers/models/t5/modeling_t5.py +2 -2
- optimum/rbln/transformers/models/t5/t5_architecture.py +8 -1
- optimum/rbln/transformers/models/time_series_transformer/configuration_time_series_transformer.py +3 -3
- optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +20 -16
- optimum/rbln/transformers/models/time_series_transformer/time_series_transformers_architecture.py +7 -1
- optimum/rbln/transformers/models/vit/modeling_vit.py +19 -0
- optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec2.py +15 -3
- optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +61 -8
- optimum/rbln/transformers/models/whisper/configuration_whisper.py +12 -13
- optimum/rbln/transformers/models/whisper/generation_whisper.py +62 -6
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +30 -5
- optimum/rbln/transformers/models/xlm_roberta/__init__.py +2 -8
- optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +43 -0
- optimum/rbln/transformers/utils/rbln_quantization.py +400 -75
- optimum/rbln/transformers/utils/rbln_runtime_wrapper.py +79 -0
- optimum/rbln/utils/deprecation.py +213 -0
- optimum/rbln/utils/hub.py +14 -3
- optimum/rbln/utils/runtime_utils.py +60 -18
- optimum/rbln/utils/submodule.py +31 -9
- {optimum_rbln-0.8.2a4.dist-info → optimum_rbln-0.9.3.dist-info}/METADATA +8 -7
- optimum_rbln-0.9.3.dist-info/RECORD +264 -0
- {optimum_rbln-0.8.2a4.dist-info → optimum_rbln-0.9.3.dist-info}/WHEEL +1 -1
- optimum_rbln-0.9.3.dist-info/entry_points.txt +2 -0
- optimum_rbln-0.8.2a4.dist-info/RECORD +0 -215
- {optimum_rbln-0.8.2a4.dist-info → optimum_rbln-0.9.3.dist-info}/licenses/LICENSE +0 -0
--- a/optimum/rbln/transformers/models/gemma3/modeling_gemma3.py
+++ b/optimum/rbln/transformers/models/gemma3/modeling_gemma3.py
@@ -11,95 +11,74 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import importlib
 import inspect
-from collections import deque
-from dataclasses import dataclass
 from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
 
 import rebel
 import torch
 from rebel.compile_context import CompileContext
-from transformers import (
-    AutoModelForImageTextToText,
-    Gemma3ForConditionalGeneration,
-    PretrainedConfig,
-    PreTrainedModel,
-)
+from transformers import AutoModelForImageTextToText, Gemma3ForConditionalGeneration, PretrainedConfig, PreTrainedModel
 from transformers.modeling_outputs import BaseModelOutputWithPooling
 from transformers.modeling_utils import no_init_weights
 from transformers.models.gemma3.modeling_gemma3 import Gemma3TextScaledWordEmbedding
 
 from ....configuration_utils import RBLNCompileConfig, RBLNModelConfig
 from ....modeling import RBLNModel
-from …
+from ...modeling_outputs import RBLNDecoderOnlyOutput
+from ...utils.rbln_runtime_wrapper import LoopProcessor
+from ..decoderonly.decoderonly_runtime_utils import RBLNPageTableManager
+from ..decoderonly.generation_decoderonly import RBLNDecoderOnlyGenerationMixin
+from ..decoderonly.modeling_decoderonly import (
+    RBLNDecoderOnlyModelForCausalLM,
+)
 from .configuration_gemma3 import RBLNGemma3ForCausalLMConfig
 from .gemma3_architecture import Gemma3ForCausalLMWrapper
+from .gemma3_runtime_utils import RBLNGemma3RuntimeModel
 
 
 if TYPE_CHECKING:
     from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer, Gemma3ForConditionalGeneration
 
 
-
-
-
-
-
-class LoopVisionTower:
-    def __init__(self, vision_tower: RBLNModel) -> None:
-        self.vision_tower = vision_tower
+class LoopVisionTower(LoopProcessor):
+    def __init__(self, vision_tower: "RBLNModel"):
+        super().__init__(model=vision_tower)
 
-    def …
-
-        # shape of pixel_values : [batch, num_channel, height, width]
-        pixel_values = args[0]
+    def _get_batch_size(self, pixel_values, **kwargs):
+        return pixel_values.shape[0]
 
-
-
-        for …
-
+    def _prepare_inputs_for_iteration(self, index, common_inputs, pixel_values, **kwargs):
+        pixel_values_item = pixel_values[index : index + 1]
+        out_buffer = [tensor[index : index + 1] for tensor in kwargs["out"]]
+        return ([pixel_values_item], {"out": out_buffer})
 
-
-
-        # FIXME:: This can be optimized using out= API of rbln runtime.
-        last_hidden_states = torch.cat(last_hidden_states, dim=0)
+    def _process_outputs(self, outputs: list, **kwargs) -> "BaseModelOutputWithPooling":
+        output = kwargs["out"]
 
         return BaseModelOutputWithPooling(
-            last_hidden_state=…
+            last_hidden_state=output[0],
         )
 
-    def __call__(self, *args: Any, **kwds: Any) -> Any:
-        return self.forward(*args, **kwds)
-
-    def __repr__(self) -> str:
-        return repr(self.vision_tower)
-
 
-class LoopProjector:
-    def __init__(self, multi_modal_projector) …
-
+class LoopProjector(LoopProcessor):
+    def __init__(self, multi_modal_projector: "RBLNModel"):
+        super().__init__(model=multi_modal_projector)
 
-    def …
-
-        image_feature = args[0]
+    def _get_batch_size(self, image_feature, **kwargs):
+        return image_feature.shape[0]
 
-
-
-        for …
-
+    def _prepare_inputs_for_iteration(self, index, common_inputs, image_feature, **kwargs):
+        image_feature_item = image_feature[index : index + 1]
+        out_buffer = [tensor[index : index + 1] for tensor in kwargs["out"]]
+        return ([image_feature_item], {"out": out_buffer})
 
-
-
-        return …
+    def _process_outputs(self, outputs: list, **kwargs):
+        output = kwargs["out"]
+        return output[0]
 
-    def __call__(self, *args: Any, **kwds: Any) -> Any:
-        return self.forward(*args, **kwds)
 
-
-        return repr(self.multi_modal_projector)
-
-
-class RBLNGemma3ForConditionalGeneration(RBLNModel):
+class RBLNGemma3ForConditionalGeneration(RBLNModel, RBLNDecoderOnlyGenerationMixin):
     auto_model_class = AutoModelForImageTextToText
     _rbln_submodules = [
         {"name": "vision_tower"},
@@ -119,6 +98,21 @@ class RBLNGemma3ForConditionalGeneration(RBLNModel):
     def can_generate(self):
         return True
 
+    @classmethod
+    def _reconstruct_model_if_needed(cls, model: "PreTrainedModel"):
+        with no_init_weights():
+            model_cls_name = model.model.language_model.__class__.__name__
+            causal_model_cls_name = model_cls_name.replace("TextModel", "ForCausalLM")
+            causal_model_cls = getattr(importlib.import_module("transformers"), causal_model_cls_name)
+            new_language_model = causal_model_cls(model.model.language_model.config)
+
+        new_language_model.lm_head = model.lm_head
+        new_language_model.model = model.model.language_model
+        model.model.language_model = new_language_model
+        model.lm_head = None
+        del model.lm_head
+        return model
+
     def __post_init__(self, **kwargs):
         self.vision_tower = LoopVisionTower(self.rbln_submodules[0])
         self.language_model = self.rbln_submodules[1]
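The `_reconstruct_model_if_needed` hook added above rebuilds a standalone causal LM around the `language_model` submodule: recent transformers releases keep `lm_head` on the outer `Gemma3ForConditionalGeneration` and expose the inner text model only as a `Gemma3TextModel`, which cannot be compiled as a causal LM on its own. A minimal sketch of the same rewiring, assuming a placeholder Gemma3 checkpoint id:

```python
import importlib

from transformers import Gemma3ForConditionalGeneration
from transformers.modeling_utils import no_init_weights

# Placeholder checkpoint id; any Gemma3 image-text-to-text checkpoint would do.
model = Gemma3ForConditionalGeneration.from_pretrained("google/gemma-3-4b-it")

with no_init_weights():
    text_model = model.model.language_model  # Gemma3TextModel, no lm_head
    cls_name = text_model.__class__.__name__.replace("TextModel", "ForCausalLM")
    causal_cls = getattr(importlib.import_module("transformers"), cls_name)
    causal_lm = causal_cls(text_model.config)  # uninitialized shell, no weight init cost

causal_lm.model = text_model       # reuse the already-loaded text weights
causal_lm.lm_head = model.lm_head  # reattach the shared head
model.model.language_model = causal_lm
```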
@@ -139,7 +133,7 @@ class RBLNGemma3ForConditionalGeneration(RBLNModel):
         return self.language_model.get_input_embeddings()
 
     @classmethod
-    def …
+    def _wrap_model_if_needed(cls, model: "PreTrainedModel", rbln_config: RBLNModelConfig):
         return model.multi_modal_projector
 
     @classmethod
@@ -208,18 +202,30 @@ class RBLNGemma3ForConditionalGeneration(RBLNModel):
         return model_kwargs
 
     def get_image_features(self, pixel_values: torch.Tensor) -> torch.Tensor:
-
-
-
-
-
-
-
-
-
-
-
-
+        # Projects the last hidden state from the vision model into language model space.
+
+        # Args:
+        #     pixel_values: (`torch.FloatTensor` of shape `(batch_size, channels, height, width)`)
+        #         The tensors corresponding to the input images.
+
+        # Returns:
+        #     Image feature tensor of shape `(num_images, image_length, embed_dim)`.
+
+        vision_out_buffer = []
+        vision_out_size = [
+            pixel_values.shape[0],
+            (self.config.vision_config.image_size // self.config.vision_config.patch_size) ** 2,
+            self.config.vision_config.hidden_size,
+        ]
+        projector_out_size = [
+            pixel_values.shape[0],
+            self.config.mm_tokens_per_image,
+            self.config.text_config.hidden_size,
+        ]
+        vision_out_buffer.append(torch.empty(size=vision_out_size, dtype=torch.float32, device="cpu"))
+        projector_out_buffer = [torch.empty(size=projector_out_size, dtype=torch.float32, device="cpu")]
+        vision_outputs = self.vision_tower(pixel_values, out=vision_out_buffer).last_hidden_state
+        image_features = self.multi_modal_projector(vision_outputs, out=projector_out_buffer)
         return image_features
 
     def _preprocess_prefill(
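`get_image_features` now preallocates the vision and projector outputs and hands them to the runtimes through their `out=` argument; the `LoopProcessor` subclasses above then write each per-image result into its slice of those buffers, replacing the old `torch.cat` over per-image results. A minimal sketch of the pattern, with a stand-in for the compiled runtime (`fake_runtime` and all sizes are illustrative, not part of optimum-rbln):

```python
import torch

def fake_runtime(x: torch.Tensor, out: list) -> None:
    # stand-in for a compiled RBLN runtime that writes its result into `out` in place
    out[0].copy_(x.mean() + torch.zeros_like(out[0]))

batch, tokens, hidden = 2, 256, 1152  # illustrative sizes only
pixel_values = torch.randn(batch, 3, 224, 224)
out = [torch.empty(batch, tokens, hidden)]
for i in range(batch):
    # each per-image call fills its own slice of the shared buffer,
    # so no concatenation over per-image outputs is needed afterwards
    fake_runtime(pixel_values[i : i + 1], out=[t[i : i + 1] for t in out])
```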
@@ -254,17 +260,45 @@ class RBLNGemma3ForConditionalGeneration(RBLNModel):
 
         return inputs_embeds
 
+    def get_padded_cache_position(
+        self,
+        cache_position: torch.Tensor,  # shape: [1, seq_len]
+        token_type_ids: torch.Tensor,  # shape: [1, seq_len]
+    ) -> torch.Tensor:
+        seq_len = cache_position[0][-1].item() + 1
+
+        # Find image start positions
+        image_starts = [
+            s
+            for s in torch.where(token_type_ids == 1)[1]
+            if torch.all(token_type_ids[:, s : s + self.rbln_config.image_prefill_chunk_size] == 1)
+        ]
+
+        # Initialize padded tensors
+        padded_input_len = seq_len
+        for image_start in image_starts:
+            pad_needed = (
+                self.rbln_config.image_prefill_chunk_size
+                - (image_start + padded_input_len - seq_len) % self.rbln_config.image_prefill_chunk_size
+            ) % self.rbln_config.image_prefill_chunk_size
+            padded_input_len += pad_needed
+
+        return torch.cat(
+            [cache_position, torch.arange(seq_len, padded_input_len, dtype=torch.int32).unsqueeze(0)],
+            dim=1,
+        )
+
     def forward(
         self,
         input_ids: torch.LongTensor = None,
+        attention_mask: torch.Tensor = None,
+        token_type_ids: torch.Tensor = None,
         pixel_values: torch.FloatTensor = None,
-        attention_mask: Optional[torch.Tensor] = None,
         cache_position: Optional[torch.LongTensor] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
         generate_idx: Optional[torch.Tensor] = None,
         padded_cache_lengths: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.Tensor] = None,
-        token_type_ids: Optional[torch.Tensor] = None,
         **lm_kwargs: Dict[str, Any],
     ) -> Union[Tuple, RBLNDecoderOnlyOutput]:
         # prefill
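`get_padded_cache_position` extends `cache_position` so that every full image span starts on an `image_prefill_chunk_size` boundary; the double-modulo expression computes how much padding each image start needs. A worked example with illustrative numbers:

```python
chunk = 256                      # image_prefill_chunk_size
image_start, seq_len = 300, 812  # image tokens begin at position 300
padded_len = seq_len
pad_needed = (chunk - (image_start + padded_len - seq_len) % chunk) % chunk
padded_len += pad_needed         # image start 300 -> next 256-boundary is 512
print(pad_needed, padded_len)    # 212 1024
```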
@@ -275,12 +309,15 @@ class RBLNGemma3ForConditionalGeneration(RBLNModel):
 
             for b_idx in range(batch_size):
                 cache_position = torch.arange(0, generate_idx[b_idx].item(), dtype=torch.int32).unsqueeze(0)
+                token_type_id = token_type_ids[b_idx : b_idx + 1, attention_mask[b_idx].bool()]
+                cache_position = self.get_padded_cache_position(cache_position, token_type_id)
+
                 output = self.language_model.prefill_decoder(
                     inputs_embeds=inputs_embeds[b_idx : b_idx + 1],
                     attention_mask=attention_mask[b_idx],
                     cache_position=cache_position,
                     batch_idx=b_idx,
-                    token_type_ids=token_type_ids[b_idx : b_idx + 1]
+                    token_type_ids=token_type_ids[b_idx : b_idx + 1],  # do not pass token_type_id
                 )
                 padded_cache_lengths[b_idx] += output.padded_cache_lengths
                 logits.append(output.logits)
@@ -309,209 +346,6 @@ class RBLNGemma3ForConditionalGeneration(RBLNModel):
         )
 
 
-class RBLNGemma3RuntimeModel(RBLNRuntimeModel):
-    def __init__(self, *args, image_prefill: Optional[rebel.Runtime] = None, **kwargs):
-        super().__init__(*args, **kwargs)
-        self.image_prefill = image_prefill  # FIXME(taehoon)
-        self.prefill = self.runtime if self.phase == "prefill" else None  # FIXME
-        self.decode = self.runtime if self.phase == "decode" else None
-
-    def _prepare_prefill_inputs(self, *args, **kwargs):
-        (
-            inputs,
-            cache_position,
-            chunked_attention_mask,
-            out_buffers,
-            position_ids,
-            position_embed,
-            padded_cache_lengths,
-            query_length,
-            token_type_ids,
-        ) = super()._prepare_prefill_inputs(*args, **kwargs)
-
-        # chunked_attention_mask shape
-        chunked_attention_mask = torch.zeros(1, chunked_attention_mask.shape[-1], dtype=torch.float32)
-
-        # as gemma3 has different prefill chunk size for image and text, we need to pad the inputs to the max of the two.
-        padding_size = max(self.rbln_config.prefill_chunk_size, self.rbln_config.image_prefill_chunk_size)
-        inputs = torch.nn.functional.pad(inputs, (0, 0, 0, padding_size))
-        cache_position = torch.nn.functional.pad(cache_position, (0, padding_size))
-        position_ids = torch.nn.functional.pad(position_ids, (0, padding_size))
-        token_type_ids = torch.nn.functional.pad(token_type_ids, (0, padding_size), value=-1)
-
-        return (
-            inputs,
-            cache_position,
-            chunked_attention_mask,
-            out_buffers,
-            position_ids,
-            position_embed,
-            padded_cache_lengths,
-            query_length,
-            token_type_ids,
-        )
-
-    def prefill_forward(
-        self,
-        inputs: torch.Tensor,
-        cache_position: torch.Tensor = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        batch_idx: int = None,
-        block_tables: torch.Tensor = None,
-        is_external_block_tables: bool = None,
-        position_embed: Optional[torch.Tensor] = None,
-        token_type_ids: Optional[torch.Tensor] = None,
-        local_block_tables: Optional[torch.Tensor] = None,
-    ) -> torch.FloatTensor:
-        """
-        Performs chunked prefill for efficient KV-cache updates and memory optimization.
-        Instead of processing the entire sequence at once, the input is divided into chunks of size `prefill_chunk_size`,
-        and each chunk is processed sequentially. This allows for better memory utilization and compatibility with continuous batching.
-        """
-        (
-            inputs,
-            cache_position,
-            chunked_attention_mask,
-            out_buffers,
-            position_ids,
-            position_embed,
-            padded_cache_lengths,
-            query_length,
-            token_type_ids,
-        ) = self._prepare_prefill_inputs(
-            inputs, cache_position, attention_mask, position_embed, token_type_ids=token_type_ids
-        )
-
-        step = 0
-        while step < query_length:
-            # Check if the prefill chunk is an image prefill
-            is_image_prefill = torch.all(
-                token_type_ids[:, step : step + self.rbln_config.image_prefill_chunk_size] == 1
-            )
-            prefill_chunk_size = (
-                self.rbln_config.image_prefill_chunk_size if is_image_prefill else self.rbln_config.prefill_chunk_size
-            )
-
-            # Check if the prefill chunk is a text prefill which have image_tokens in it.
-            is_text_prefill_with_image_tokens = not is_image_prefill and torch.any(
-                token_type_ids[:, step : step + prefill_chunk_size] == 1
-            )
-
-            # Check if the prefill chunk crosses a block boundary, requiring padding to align with block boundaries
-            is_cross_block_boundary = (
-                step // self.rbln_config.kvcache_block_size
-                != (step + prefill_chunk_size) // self.rbln_config.kvcache_block_size
-            )
-
-            # Check if the prefill chunk is the last chunk
-            is_last_chunk = step + prefill_chunk_size >= query_length
-
-            if is_cross_block_boundary:
-                padding_size = prefill_chunk_size - (step + prefill_chunk_size) % self.rbln_config.kvcache_block_size
-                padded_cache_lengths += padding_size
-
-            # if text_prefill end with image_tokens, we only treat the text part.
-            num_processed_tokens = prefill_chunk_size
-            if is_text_prefill_with_image_tokens:
-                first_image_token_idx = torch.where(token_type_ids[:, step : step + prefill_chunk_size] == 1)[1][0]
-                num_processed_tokens = first_image_token_idx
-            if is_last_chunk:
-                num_processed_tokens = query_length - step
-
-            input_chunk = inputs[:, step : step + prefill_chunk_size]
-            cache_pos_chunk = cache_position[:, step : step + prefill_chunk_size].clone() + padded_cache_lengths
-            position_ids_chunk = position_ids[:, step : step + prefill_chunk_size].clone()
-            chunked_attention_mask[
-                :, step + padded_cache_lengths : step + num_processed_tokens + padded_cache_lengths
-            ] = 1
-            query_position = torch.tensor(num_processed_tokens - 1, dtype=torch.int16)
-
-            if is_image_prefill:
-                logits = self.image_prefill(
-                    input_chunk,
-                    cache_pos_chunk,
-                    block_tables,
-                    local_block_tables,
-                    query_position,
-                    chunked_attention_mask,
-                    position_ids_chunk,
-                    out=out_buffers,
-                )
-            else:
-                logits = self.prefill(
-                    input_chunk,
-                    cache_pos_chunk,
-                    block_tables,
-                    local_block_tables,
-                    query_position,
-                    chunked_attention_mask,
-                    position_ids_chunk,
-                    out=out_buffers,
-                )
-
-            step += num_processed_tokens
-
-        if not is_external_block_tables:
-            self.dec_attn_mask[batch_idx : batch_idx + 1] = chunked_attention_mask
-
-        return RBLNGemma3ForCausalLMOutput(
-            logits=logits, padded_cache_lengths=padded_cache_lengths, attention_mask=chunked_attention_mask
-        )
-
-    def decode_forward(
-        self,
-        inputs: torch.Tensor,
-        cache_position: torch.Tensor = None,
-        block_tables: torch.Tensor = None,
-        is_external_block_tables: bool = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        position_embed: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.Tensor] = None,
-        local_block_tables: Optional[torch.Tensor] = None,
-    ) -> torch.FloatTensor:
-        batch_size = inputs.shape[0]
-        if batch_size != self.batch_size:
-            raise RuntimeError(
-                f"Batch size mismatch: got {batch_size}, expected {self.batch_size} (compiled batch size)."
-            )
-
-        if batch_size != cache_position.shape[0]:
-            raise RuntimeError(f"Cache position size mismatch: got {cache_position.shape[0]}, expected {batch_size}.")
-
-        # FIXME(taehoon): how to handle pos_attn_mask with external block tables
-        if is_external_block_tables:
-            if attention_mask is None:
-                raise ValueError("attention_mask should be provided with external block tables.")
-            if local_block_tables is None:
-                raise ValueError("local_block_tables should be provided with external block tables.")
-        else:
-            local_block_tables = (
-                local_block_tables
-                if local_block_tables is not None
-                else torch.arange(0, self.batch_size, dtype=torch.int16).view(self.batch_size, -1)
-            )
-            if self.rbln_config.use_attention_mask and attention_mask is None:
-                for b_idx in range(batch_size):
-                    decoding_step = cache_position[b_idx].item()
-                    if not (0 <= decoding_step < self.dec_attn_mask.shape[-1]):
-                        raise ValueError(
-                            f"Decoding step {decoding_step} out of bounds for attention mask with shape {self.dec_attn_mask.shape}."
-                        )
-                    self.dec_attn_mask[b_idx, decoding_step] = 1
-
-                attention_mask = self.dec_attn_mask
-
-        if self.batch_size < block_tables.shape[0]:
-            block_tables = block_tables[: self.batch_size]
-
-        if attention_mask is not None and self.batch_size < attention_mask.shape[0]:
-            attention_mask = attention_mask[: self.batch_size]
-
-        logits = self.decode(inputs, cache_position, block_tables, local_block_tables, attention_mask, position_ids)
-
-        return RBLNDecoderOnlyOutput(logits=logits)
-
-
 class RBLNGemma3ForCausalLM(RBLNDecoderOnlyModelForCausalLM):
     """
     The Gemma3 Model transformer with a language modeling head (linear layer) on top.
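The class deleted here (now provided by `gemma3_runtime_utils.RBLNGemma3RuntimeModel`, imported at the top of the file) documents the chunked-prefill scheduling: an all-image chunk goes to the `image_prefill` runtime, while a text chunk that runs into image tokens is cut short at the first image token so the next chunk starts image-aligned. A minimal, self-contained sketch of just that scheduling decision (chunk sizes and the token layout are illustrative, and the runtime calls are reduced to a `print`):

```python
import torch

text_chunk, image_chunk = 128, 256
token_type_ids = torch.cat(
    [torch.zeros(1, 100), torch.ones(1, 256), torch.zeros(1, 50)], dim=1
).int()  # 100 text tokens, one 256-token image, 50 text tokens

step, query_length = 0, token_type_ids.shape[1]
while step < query_length:
    # an all-ones window of image_chunk length means this chunk is pure image
    is_image = bool(torch.all(token_type_ids[:, step : step + image_chunk] == 1))
    size = image_chunk if is_image else text_chunk
    window = token_type_ids[:, step : step + size]
    processed = size
    if not is_image and torch.any(window == 1):
        # text chunk runs into image tokens: stop before them
        processed = int(torch.where(window == 1)[1][0])
    processed = min(processed, query_length - step)  # last chunk
    print("image" if is_image else "text", "step", step, "tokens", processed)
    step += processed
```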
@@ -524,52 +358,36 @@ class RBLNGemma3ForCausalLM(RBLNDecoderOnlyModelForCausalLM):
     """
 
     _decoder_wrapper_cls = Gemma3ForCausalLMWrapper
+    _supports_non_fp32 = False
 
-    def …
-        main_input_name = self.main_input_name
-
-        if self.rbln_config.use_inputs_embeds:
-            main_input_name = "inputs_embeds"
-            artifacts = torch.load(self.model_save_dir / self.subfolder / "torch_artifacts.pth", weights_only=False)
-            self.embed_tokens = self._create_embedding_layer()
-            self.embed_tokens.load_state_dict(artifacts["embed_tokens"])
-        else:
-            self.embed_tokens = None
-
+    def setup_runtime(self):
         # Initialize shared resources to be used across Runtime instances (prefill and decode phases)
         dec_attn_mask = torch.zeros(self.rbln_config.batch_size, self.rbln_config.max_seq_len, dtype=torch.float32)
-
-
-
-
-
-
+        page_table_manager = RBLNPageTableManager(self.rbln_config)
+
+        common_kwargs = {
+            "main_input_name": "inputs_embeds" if self.rbln_config.use_inputs_embeds else "input_ids",
+            "embed_tokens": self.embed_tokens,
+            "dec_attn_mask": dec_attn_mask,
+            "page_table_manager": page_table_manager,
+            "rbln_config": self.rbln_config,
+        }
+
         self.prefill_decoder = RBLNGemma3RuntimeModel(
             runtime=self.model[0],
-            image_prefill=self.model[1],
-            main_input_name=main_input_name,
-            embed_tokens=self.embed_tokens,
+            image_prefill=self.model[1] if self.rbln_config.use_image_prefill else None,
             phase="prefill",
             batch_size=self.rbln_config.batch_size,
-
-            block_tables=block_tables,
-            vocab_size=self.config.vocab_size,
-            free_block_pool=free_block_pool,
-            rbln_config=self.rbln_config,
+            **common_kwargs,
         )
 
         self.decoders = {}
         for i, batch_size in enumerate(self.rbln_config.decoder_batch_sizes):
             self.decoders[batch_size] = RBLNGemma3RuntimeModel(
-                runtime=self.model[i + …
-                main_input_name=main_input_name,
-                embed_tokens=self.embed_tokens,
+                runtime=self.model[i + self.rbln_config.decoder_runtime_idx],
                 phase="decode",
                 batch_size=batch_size,
-
-                block_tables=block_tables,
-                free_block_pool=free_block_pool,
-                rbln_config=self.rbln_config,
+                **common_kwargs,
             )
 
         # NOTE(eunji): Use a decoder whose batch size matches the model's main batch size for compatibility.
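`setup_runtime` now builds one `dec_attn_mask` buffer and one `RBLNPageTableManager` and hands the same objects to the prefill runtime and to every decoder via `common_kwargs`, so all phases observe each other's cache state. A toy sketch of why sharing (rather than copying) matters; `Runtime` here is a stand-in class, not the rebel API:

```python
import torch

class Runtime:
    def __init__(self, phase: str, dec_attn_mask: torch.Tensor, **_):
        self.phase = phase
        self.dec_attn_mask = dec_attn_mask  # shared reference, not a copy

dec_attn_mask = torch.zeros(2, 1024)
common = {"dec_attn_mask": dec_attn_mask}
prefill = Runtime("prefill", **common)
decode = Runtime("decode", **common)

prefill.dec_attn_mask[0, 0] = 1          # prefill marks a filled position...
assert decode.dec_attn_mask[0, 0] == 1   # ...and the decoder sees it immediately
```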
@@ -589,6 +407,13 @@ class RBLNGemma3ForCausalLM(RBLNDecoderOnlyModelForCausalLM):
     def _update_sliding_window_config(cls, model_config: PretrainedConfig, rbln_config: RBLNGemma3ForCausalLMConfig):
         sliding_window = getattr(model_config, "sliding_window", None)
         sliding_window_pattern = getattr(model_config, "sliding_window_pattern", None)
+        if sliding_window_pattern is None:
+            if hasattr(model_config, "layer_types"):
+                first_full_attention_index = model_config.layer_types.index("full_attention")
+                sliding_window_pattern = first_full_attention_index + 1
+            else:
+                raise ValueError("Cannot determine sliding_window_pattern from model_config")
+
         if sliding_window_pattern <= model_config.num_hidden_layers:
             rbln_config.cache_impl = "hybrid"
             rbln_config.sliding_window = sliding_window
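The fallback added above covers configs that omit `sliding_window_pattern`: newer Gemma3 configs publish per-layer `layer_types` instead, and the pattern length equals the index of the first `"full_attention"` layer plus one. For example:

```python
# Illustrative layer layout: five sliding-attention layers, then a full one.
layer_types = ["sliding_attention"] * 5 + ["full_attention"]
sliding_window_pattern = layer_types.index("full_attention") + 1
assert sliding_window_pattern == 6
```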
@@ -599,7 +424,12 @@ class RBLNGemma3ForCausalLM(RBLNDecoderOnlyModelForCausalLM):
         return rbln_config
 
     @classmethod
-    def _update_submodule_config(…
+    def _update_submodule_config(
+        cls,
+        model: "PreTrainedModel",
+        rbln_config: RBLNModelConfig,
+        preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]],
+    ):
         if rbln_config.image_prefill_chunk_size is None:
             rbln_config.image_prefill_chunk_size = model.config.mm_tokens_per_image
 
@@ -624,27 +454,33 @@ class RBLNGemma3ForCausalLM(RBLNDecoderOnlyModelForCausalLM):
         if not (rbln_config.use_attention_mask and rbln_config.use_position_ids):
             raise ValueError("use_attention_mask and use_position_ids must be True for RBLNGemma3ForCausalLM")
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        if rbln_config.use_image_prefill:
+            if rbln_config.prefill_chunk_size != rbln_config.image_prefill_chunk_size:
+                raise NotImplementedError(
+                    "Not implemented for different prefill chunk sizes between text and image prefill."
+                )
+
+            # Update image prefill compile config
+            img_prefill_input_info = cls.get_input_info(
+                batch_size=1,
+                query_length=rbln_config.image_prefill_chunk_size,
+                rbln_config=rbln_config,
+                model_config=model_config,
+            )
+            image_prefill_compile_config = RBLNCompileConfig(
+                compiled_model_name="image_prefill", input_info=img_prefill_input_info
+            )
+            # Insert image_prefill compile config at index 1
+            compile_cfgs = rbln_config.compile_cfgs
+            compile_cfgs.insert(1, image_prefill_compile_config)
+            rbln_config.set_compile_cfgs(compile_cfgs)
 
         return rbln_config
 
     @classmethod
     @torch.inference_mode()
     def get_compiled_model(cls, model: "PreTrainedModel", rbln_config: RBLNGemma3ForCausalLMConfig):
-        wrapped_model = cls.…
+        wrapped_model = cls._wrap_model_if_needed(model, rbln_config)
 
         rbln_compile_configs = rbln_config.compile_cfgs
         prefill_compile_config = rbln_compile_configs[0]
@@ -690,23 +526,27 @@ class RBLNGemma3ForCausalLM(RBLNDecoderOnlyModelForCausalLM):
             context,
             rbln_config.quantization,
         )
+        compiled_models = {"prefill": compiled_prefill}
 
-
-
-
-
-
-
-
-
-
-
-
-
+        if rbln_config.use_image_prefill:
+            image_prefill_compile_config = rbln_compile_configs[1]
+            image_prefill_example_inputs = image_prefill_compile_config.get_dummy_inputs(
+                fill=0, static_tensors=static_tensors
+            )
+            wrapped_model.phase = "image_prefill"
+            compiled_image_prefill = compile_model(
+                wrapped_model,
+                image_prefill_compile_config,
+                image_prefill_example_inputs,
+                context,
+                rbln_config.quantization,
+            )
+            compiled_models["image_prefill"] = compiled_image_prefill
 
-        compiled_models = {"prefill": compiled_prefill, "image_prefill": compiled_image_prefill}
         wrapped_model.phase = "decode"
-        for batch_size, dec_compile_config in zip(…
+        for batch_size, dec_compile_config in zip(
+            rbln_config.decoder_batch_sizes, rbln_compile_configs[rbln_config.decoder_runtime_idx :]
+        ):
             dec_example_inputs = dec_compile_config.get_dummy_inputs(fill=0, static_tensors=static_tensors)
             compiled_decoder = compile_model(
                 wrapped_model,
@@ -727,35 +567,45 @@ class RBLNGemma3ForCausalLM(RBLNDecoderOnlyModelForCausalLM):
     ) -> List[rebel.Runtime]:
         expected_model_names = [
             "prefill",
-            "image_prefill",
             *[f"decoder_batch_{batch_size}" for batch_size in rbln_config.decoder_batch_sizes],
         ]
+        if rbln_config.use_image_prefill:
+            expected_model_names.insert(1, "image_prefill")
+
         if any(model_name not in rbln_config.device_map for model_name in expected_model_names):
            cls._raise_missing_compiled_file_error(expected_model_names)
 
-
+        ret_val = [
            rebel.Runtime(
                compiled_models[0],
                tensor_type="pt",
                device=rbln_config.device_map["prefill"],
                activate_profiler=rbln_config.activate_profiler,
                timeout=rbln_config.timeout,
-            )
-
-
-
-
-
-
-
-
+            )
+        ]
+        if rbln_config.use_image_prefill:
+            ret_val.append(
+                rebel.Runtime(
+                    compiled_models[1],
+                    tensor_type="pt",
+                    device=rbln_config.device_map["image_prefill"],
+                    activate_profiler=rbln_config.activate_profiler,
+                    timeout=rbln_config.timeout,
+                ),
+            )
+
+        ret_val.extend(
+            [
                rebel.Runtime(
-                    compiled_models[i + …
+                    compiled_models[i + rbln_config.decoder_runtime_idx],
                    tensor_type="pt",
                    device=rbln_config.device_map[f"decoder_batch_{batch_size}"],
                    activate_profiler=rbln_config.activate_profiler,
                    timeout=rbln_config.timeout,
                )
                for i, batch_size in enumerate(rbln_config.decoder_batch_sizes)
-            ]
-
+            ]
+        )
+
+        return ret_val
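`create_runtimes` now assembles the runtime list conditionally: index 0 is always the text prefill, `image_prefill` follows only when enabled, and the decoders start at `decoder_runtime_idx`. A small sketch of the resulting naming and ordering contract (the helper is illustrative; it assumes `decoder_runtime_idx` resolves to 2 with image prefill and 1 without, which matches how the indices are used in this diff but is not confirmed by it):

```python
def runtime_names(use_image_prefill: bool, decoder_batch_sizes: list) -> list:
    # mirrors expected_model_names / device_map keys in create_runtimes
    names = ["prefill"]
    if use_image_prefill:
        names.append("image_prefill")  # assumed decoder_runtime_idx == 2 here
    names += [f"decoder_batch_{b}" for b in decoder_batch_sizes]
    return names

assert runtime_names(True, [1, 4]) == ["prefill", "image_prefill", "decoder_batch_1", "decoder_batch_4"]
assert runtime_names(False, [1]) == ["prefill", "decoder_batch_1"]
```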