optimum-rbln 0.8.2a0__py3-none-any.whl → 0.9.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- optimum/rbln/__init__.py +116 -9
- optimum/rbln/__version__.py +16 -3
- optimum/rbln/cli.py +660 -0
- optimum/rbln/configuration_utils.py +171 -43
- optimum/rbln/diffusers/__init__.py +19 -0
- optimum/rbln/diffusers/configurations/__init__.py +3 -0
- optimum/rbln/diffusers/configurations/models/__init__.py +2 -0
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +3 -3
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_cosmos.py +1 -1
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_temporal_decoder.py +67 -0
- optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +3 -3
- optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +4 -4
- optimum/rbln/diffusers/configurations/models/configuration_transformer_cosmos.py +12 -4
- optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +9 -4
- optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +3 -3
- optimum/rbln/diffusers/configurations/models/configuration_unet_spatio_temporal_condition.py +59 -0
- optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +3 -3
- optimum/rbln/diffusers/configurations/pipelines/__init__.py +3 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +35 -19
- optimum/rbln/diffusers/configurations/pipelines/configuration_cosmos.py +14 -11
- optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +30 -20
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +13 -9
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +17 -13
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +17 -10
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_video_diffusion.py +114 -0
- optimum/rbln/diffusers/modeling_diffusers.py +33 -18
- optimum/rbln/diffusers/models/__init__.py +4 -0
- optimum/rbln/diffusers/models/autoencoders/__init__.py +1 -0
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +32 -3
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py +32 -6
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +275 -0
- optimum/rbln/diffusers/models/autoencoders/vae.py +27 -8
- optimum/rbln/diffusers/models/autoencoders/vq_model.py +32 -3
- optimum/rbln/diffusers/models/controlnet.py +16 -1
- optimum/rbln/diffusers/models/transformers/prior_transformer.py +17 -3
- optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +26 -3
- optimum/rbln/diffusers/models/transformers/transformer_sd3.py +23 -2
- optimum/rbln/diffusers/models/unets/__init__.py +1 -0
- optimum/rbln/diffusers/models/unets/unet_2d_condition.py +23 -4
- optimum/rbln/diffusers/models/unets/unet_spatio_temporal_condition.py +201 -0
- optimum/rbln/diffusers/pipelines/__init__.py +15 -5
- optimum/rbln/diffusers/pipelines/auto_pipeline.py +307 -0
- optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +20 -0
- optimum/rbln/diffusers/pipelines/cosmos/configuration_cosmos_guardrail.py +23 -12
- optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +16 -46
- optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +31 -1
- optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +31 -1
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +1 -6
- optimum/rbln/diffusers/pipelines/stable_video_diffusion/__init__.py +15 -0
- optimum/rbln/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +46 -0
- optimum/rbln/modeling.py +50 -24
- optimum/rbln/modeling_base.py +116 -35
- optimum/rbln/ops/attn.py +158 -0
- optimum/rbln/ops/flash_attn.py +166 -0
- optimum/rbln/ops/kv_cache_update.py +5 -0
- optimum/rbln/ops/linear.py +7 -0
- optimum/rbln/transformers/__init__.py +100 -0
- optimum/rbln/transformers/configuration_generic.py +7 -32
- optimum/rbln/transformers/modeling_attention_utils.py +385 -0
- optimum/rbln/transformers/modeling_generic.py +48 -65
- optimum/rbln/transformers/modeling_outputs.py +37 -0
- optimum/rbln/transformers/models/__init__.py +93 -30
- optimum/rbln/transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py +28 -2
- optimum/rbln/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +68 -5
- optimum/rbln/transformers/models/auto/__init__.py +2 -0
- optimum/rbln/transformers/models/auto/auto_factory.py +92 -17
- optimum/rbln/transformers/models/auto/modeling_auto.py +45 -0
- optimum/rbln/transformers/models/bart/bart_architecture.py +2 -7
- optimum/rbln/transformers/models/bart/configuration_bart.py +2 -0
- optimum/rbln/transformers/models/bart/modeling_bart.py +23 -2
- optimum/rbln/transformers/models/bert/bert_architecture.py +16 -0
- optimum/rbln/transformers/models/bert/modeling_bert.py +93 -4
- optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +42 -11
- optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +135 -44
- optimum/rbln/transformers/models/clip/configuration_clip.py +21 -7
- optimum/rbln/transformers/models/clip/modeling_clip.py +183 -27
- optimum/rbln/transformers/models/colpali/colpali_architecture.py +3 -6
- optimum/rbln/transformers/models/colpali/configuration_colpali.py +37 -21
- optimum/rbln/transformers/models/colpali/modeling_colpali.py +82 -104
- optimum/rbln/transformers/models/colqwen2/__init__.py +2 -0
- optimum/rbln/transformers/models/colqwen2/colqwen2_architecture.py +233 -0
- optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +74 -0
- optimum/rbln/transformers/models/colqwen2/modeling_colqwen2.py +446 -0
- optimum/rbln/transformers/models/decoderonly/__init__.py +3 -2
- optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +114 -37
- optimum/rbln/transformers/models/decoderonly/configuration_lora.py +411 -0
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +323 -316
- optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +508 -0
- optimum/rbln/transformers/models/decoderonly/generation_decoderonly.py +119 -0
- optimum/rbln/transformers/models/decoderonly/lora_architecture.py +204 -0
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +486 -892
- optimum/rbln/transformers/models/depth_anything/__init__.py +16 -0
- optimum/rbln/transformers/models/depth_anything/configuration_depth_anything.py +24 -0
- optimum/rbln/transformers/models/depth_anything/modeling_depth_anything.py +42 -0
- optimum/rbln/transformers/models/distilbert/modeling_distilbert.py +24 -0
- optimum/rbln/transformers/models/dpt/modeling_dpt.py +17 -0
- optimum/rbln/transformers/models/exaone/modeling_exaone.py +42 -4
- optimum/rbln/transformers/models/gemma/__init__.py +2 -2
- optimum/rbln/transformers/models/gemma/configuration_gemma.py +9 -1
- optimum/rbln/transformers/models/gemma/gemma_architecture.py +1 -4
- optimum/rbln/transformers/models/gemma/modeling_gemma.py +22 -1
- optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +49 -14
- optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +12 -2
- optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py +245 -0
- optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +212 -504
- optimum/rbln/transformers/models/gpt2/__init__.py +2 -2
- optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +31 -3
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +10 -8
- optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +18 -1
- optimum/rbln/transformers/models/grounding_dino/__init__.py +10 -0
- optimum/rbln/transformers/models/grounding_dino/configuration_grounding_dino.py +92 -0
- optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +599 -0
- optimum/rbln/transformers/models/grounding_dino/modeling_grounding_dino.py +1048 -0
- optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +35 -7
- optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +29 -32
- optimum/rbln/transformers/models/llama/__init__.py +2 -2
- optimum/rbln/transformers/models/llama/configuration_llama.py +9 -1
- optimum/rbln/transformers/models/llama/modeling_llama.py +22 -1
- optimum/rbln/transformers/models/llava/__init__.py +16 -0
- optimum/rbln/transformers/models/llava/configuration_llava.py +72 -0
- optimum/rbln/transformers/models/llava/modeling_llava.py +490 -0
- optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +21 -6
- optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +234 -376
- optimum/rbln/transformers/models/midm/midm_architecture.py +4 -1
- optimum/rbln/transformers/models/midm/modeling_midm.py +42 -4
- optimum/rbln/transformers/models/mistral/__init__.py +2 -2
- optimum/rbln/transformers/models/mistral/configuration_mistral.py +9 -1
- optimum/rbln/transformers/models/mistral/mistral_architecture.py +1 -1
- optimum/rbln/transformers/models/mistral/modeling_mistral.py +26 -3
- optimum/rbln/transformers/models/opt/__init__.py +2 -2
- optimum/rbln/transformers/models/opt/configuration_opt.py +8 -1
- optimum/rbln/transformers/models/opt/modeling_opt.py +29 -17
- optimum/rbln/transformers/models/opt/opt_architecture.py +4 -4
- optimum/rbln/transformers/models/pegasus/__init__.py +17 -0
- optimum/rbln/transformers/models/pegasus/configuration_pegasus.py +38 -0
- optimum/rbln/transformers/models/pegasus/modeling_pegasus.py +71 -0
- optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +161 -0
- optimum/rbln/transformers/models/phi/__init__.py +2 -2
- optimum/rbln/transformers/models/phi/configuration_phi.py +9 -1
- optimum/rbln/transformers/models/phi/modeling_phi.py +10 -1
- optimum/rbln/transformers/models/phi/phi_architecture.py +11 -7
- optimum/rbln/transformers/models/pixtral/__init__.py +16 -0
- optimum/rbln/transformers/models/pixtral/configuration_pixtral.py +43 -0
- optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +322 -0
- optimum/rbln/transformers/models/pixtral/pixtral_architecture.py +73 -0
- optimum/rbln/transformers/models/qwen2/__init__.py +2 -2
- optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +9 -1
- optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +27 -1
- optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +21 -6
- optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +15 -22
- optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +28 -7
- optimum/rbln/transformers/models/qwen2_vl/__init__.py +19 -0
- optimum/rbln/transformers/models/qwen2_vl/configuration_qwen2_vl.py +88 -0
- optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py +513 -0
- optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +165 -0
- optimum/rbln/transformers/models/qwen3/__init__.py +16 -0
- optimum/rbln/transformers/models/qwen3/configuration_qwen3.py +71 -0
- optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +133 -0
- optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +31 -0
- optimum/rbln/transformers/models/resnet/configuration_resnet.py +17 -0
- optimum/rbln/transformers/models/resnet/modeling_resnet.py +73 -0
- optimum/rbln/transformers/models/roberta/modeling_roberta.py +33 -0
- optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py +21 -16
- optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +60 -13
- optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +2 -2
- optimum/rbln/transformers/models/siglip/__init__.py +2 -6
- optimum/rbln/transformers/models/siglip/configuration_siglip.py +1 -1
- optimum/rbln/transformers/models/siglip/modeling_siglip.py +21 -16
- optimum/rbln/transformers/models/swin/__init__.py +16 -0
- optimum/rbln/transformers/models/swin/configuration_swin.py +42 -0
- optimum/rbln/transformers/models/swin/modeling_swin.py +354 -0
- optimum/rbln/transformers/models/t5/configuration_t5.py +2 -0
- optimum/rbln/transformers/models/t5/modeling_t5.py +2 -2
- optimum/rbln/transformers/models/t5/t5_architecture.py +8 -1
- optimum/rbln/transformers/models/time_series_transformer/configuration_time_series_transformer.py +3 -3
- optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +22 -16
- optimum/rbln/transformers/models/time_series_transformer/time_series_transformers_architecture.py +7 -1
- optimum/rbln/transformers/models/vit/modeling_vit.py +19 -0
- optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec2.py +15 -3
- optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +61 -8
- optimum/rbln/transformers/models/whisper/configuration_whisper.py +12 -13
- optimum/rbln/transformers/models/whisper/generation_whisper.py +62 -6
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +32 -5
- optimum/rbln/transformers/models/xlm_roberta/__init__.py +2 -8
- optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +43 -0
- optimum/rbln/transformers/utils/rbln_quantization.py +400 -75
- optimum/rbln/transformers/utils/rbln_runtime_wrapper.py +79 -0
- optimum/rbln/utils/deprecation.py +213 -0
- optimum/rbln/utils/hub.py +22 -50
- optimum/rbln/utils/runtime_utils.py +85 -17
- optimum/rbln/utils/submodule.py +31 -9
- {optimum_rbln-0.8.2a0.dist-info → optimum_rbln-0.9.3.dist-info}/METADATA +8 -7
- optimum_rbln-0.9.3.dist-info/RECORD +264 -0
- {optimum_rbln-0.8.2a0.dist-info → optimum_rbln-0.9.3.dist-info}/WHEEL +1 -1
- optimum_rbln-0.9.3.dist-info/entry_points.txt +2 -0
- optimum_rbln-0.8.2a0.dist-info/RECORD +0 -211
- {optimum_rbln-0.8.2a0.dist-info → optimum_rbln-0.9.3.dist-info}/licenses/LICENSE +0 -0
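Much of the churn in the `modeling_gemma3.py` diff below comes from replacing hand-rolled per-sample loops (the old `LoopVisionTower` / `LoopProjector` helpers) with subclasses of the new `LoopProcessor` utility added in `optimum/rbln/transformers/utils/rbln_runtime_wrapper.py`: a batch is fed through a batch-1 compiled runtime one sample at a time, and each result is written into a caller-provided `out` buffer instead of being concatenated afterwards. The sketch below only illustrates that pattern; the hook names mirror the diff, but the `LoopProcessor` internals and the toy runtime here are assumptions for illustration, not the packaged implementation.

```python
import torch


class LoopProcessor:
    """Minimal stand-in for the LoopProcessor utility referenced in the diff (assumed interface)."""

    def __init__(self, model):
        self.model = model

    def __call__(self, *args, **kwargs):
        batch_size = self._get_batch_size(*args, **kwargs)
        outputs = []
        for index in range(batch_size):
            # Each iteration runs one sample through the batch-1 runtime, writing straight
            # into the caller-provided slice of the preallocated `out` buffer.
            iter_args, iter_kwargs = self._prepare_inputs_for_iteration(index, {}, *args, **kwargs)
            outputs.append(self.model(*iter_args, **iter_kwargs))
        return self._process_outputs(outputs, **kwargs)


class LoopProjector(LoopProcessor):
    def _get_batch_size(self, image_feature, **kwargs):
        return image_feature.shape[0]

    def _prepare_inputs_for_iteration(self, index, common_inputs, image_feature, **kwargs):
        out_buffer = [tensor[index : index + 1] for tensor in kwargs["out"]]
        return ([image_feature[index : index + 1]], {"out": out_buffer})

    def _process_outputs(self, outputs, **kwargs):
        # The real results live in the preallocated buffer, not in `outputs`.
        return kwargs["out"][0]


def fake_runtime(x, out):
    # Imitates a runtime's `out=` API: write the result into the given buffer slice.
    out[0].copy_(x * 2)
    return out[0]


projector = LoopProjector(fake_runtime)
features = torch.arange(6.0).reshape(3, 2)
buffer = [torch.empty(3, 2)]
print(projector(features, out=buffer))
```

Preallocating the buffer once per batch avoids the `torch.cat` of per-sample outputs that the removed code flagged with a FIXME.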
optimum/rbln/transformers/models/gemma3/modeling_gemma3.py

@@ -11,99 +11,74 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import importlib
 import inspect
-from collections import deque
-from dataclasses import dataclass
 from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union

 import rebel
 import torch
 from rebel.compile_context import CompileContext
-from transformers import (
-    AutoModelForImageTextToText,
-    Gemma3ForConditionalGeneration,
-    PretrainedConfig,
-    PreTrainedModel,
-)
+from transformers import AutoModelForImageTextToText, Gemma3ForConditionalGeneration, PretrainedConfig, PreTrainedModel
 from transformers.modeling_outputs import BaseModelOutputWithPooling
 from transformers.modeling_utils import no_init_weights
 from transformers.models.gemma3.modeling_gemma3 import Gemma3TextScaledWordEmbedding

 from ....configuration_utils import RBLNCompileConfig, RBLNModelConfig
 from ....modeling import RBLNModel
-from
-from
+from ...modeling_outputs import RBLNDecoderOnlyOutput
+from ...utils.rbln_runtime_wrapper import LoopProcessor
+from ..decoderonly.decoderonly_runtime_utils import RBLNPageTableManager
+from ..decoderonly.generation_decoderonly import RBLNDecoderOnlyGenerationMixin
+from ..decoderonly.modeling_decoderonly import (
+    RBLNDecoderOnlyModelForCausalLM,
+)
 from .configuration_gemma3 import RBLNGemma3ForCausalLMConfig
 from .gemma3_architecture import Gemma3ForCausalLMWrapper
-
-
-logger = get_logger()
+from .gemma3_runtime_utils import RBLNGemma3RuntimeModel


 if TYPE_CHECKING:
     from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer, Gemma3ForConditionalGeneration


-
-
-
-
+class LoopVisionTower(LoopProcessor):
+    def __init__(self, vision_tower: "RBLNModel"):
+        super().__init__(model=vision_tower)

-
-
-        self.vision_tower = vision_tower
+    def _get_batch_size(self, pixel_values, **kwargs):
+        return pixel_values.shape[0]

-    def
-
-
-
+    def _prepare_inputs_for_iteration(self, index, common_inputs, pixel_values, **kwargs):
+        pixel_values_item = pixel_values[index : index + 1]
+        out_buffer = [tensor[index : index + 1] for tensor in kwargs["out"]]
+        return ([pixel_values_item], {"out": out_buffer})

-
-
-        for i in range(batch_size):
-            outputs.append(self.vision_tower(pixel_values=pixel_values[i : i + 1], return_dict=True))
-
-        last_hidden_states = [output.last_hidden_state for output in outputs]
-
-        # FIXME:: This can be optimized using out= API of rbln runtime.
-        last_hidden_states = torch.cat(last_hidden_states, dim=0)
+    def _process_outputs(self, outputs: list, **kwargs) -> "BaseModelOutputWithPooling":
+        output = kwargs["out"]

         return BaseModelOutputWithPooling(
-            last_hidden_state=
+            last_hidden_state=output[0],
         )

-    def __call__(self, *args: Any, **kwds: Any) -> Any:
-        return self.forward(*args, **kwds)
-
-    def __repr__(self) -> str:
-        return repr(self.vision_tower)
-

-class LoopProjector:
-    def __init__(self, multi_modal_projector)
-
+class LoopProjector(LoopProcessor):
+    def __init__(self, multi_modal_projector: "RBLNModel"):
+        super().__init__(model=multi_modal_projector)

-    def
-
-        image_feature = args[0]
+    def _get_batch_size(self, image_feature, **kwargs):
+        return image_feature.shape[0]

-
-
-        for
-
+    def _prepare_inputs_for_iteration(self, index, common_inputs, image_feature, **kwargs):
+        image_feature_item = image_feature[index : index + 1]
+        out_buffer = [tensor[index : index + 1] for tensor in kwargs["out"]]
+        return ([image_feature_item], {"out": out_buffer})

-
-
-        return
+    def _process_outputs(self, outputs: list, **kwargs):
+        output = kwargs["out"]
+        return output[0]

-    def __call__(self, *args: Any, **kwds: Any) -> Any:
-        return self.forward(*args, **kwds)

-
-        return repr(self.multi_modal_projector)
-
-
-class RBLNGemma3ForConditionalGeneration(RBLNModel):
+class RBLNGemma3ForConditionalGeneration(RBLNModel, RBLNDecoderOnlyGenerationMixin):
     auto_model_class = AutoModelForImageTextToText
     _rbln_submodules = [
         {"name": "vision_tower"},
@@ -123,6 +98,21 @@ class RBLNGemma3ForConditionalGeneration(RBLNModel):
     def can_generate(self):
         return True

+    @classmethod
+    def _reconstruct_model_if_needed(cls, model: "PreTrainedModel"):
+        with no_init_weights():
+            model_cls_name = model.model.language_model.__class__.__name__
+            causal_model_cls_name = model_cls_name.replace("TextModel", "ForCausalLM")
+            causal_model_cls = getattr(importlib.import_module("transformers"), causal_model_cls_name)
+            new_language_model = causal_model_cls(model.model.language_model.config)
+
+        new_language_model.lm_head = model.lm_head
+        new_language_model.model = model.model.language_model
+        model.model.language_model = new_language_model
+        model.lm_head = None
+        del model.lm_head
+        return model
+
     def __post_init__(self, **kwargs):
         self.vision_tower = LoopVisionTower(self.rbln_submodules[0])
         self.language_model = self.rbln_submodules[1]
@@ -143,7 +133,7 @@ class RBLNGemma3ForConditionalGeneration(RBLNModel):
         return self.language_model.get_input_embeddings()

     @classmethod
-    def
+    def _wrap_model_if_needed(cls, model: "PreTrainedModel", rbln_config: RBLNModelConfig):
         return model.multi_modal_projector

     @classmethod
@@ -212,18 +202,30 @@ class RBLNGemma3ForConditionalGeneration(RBLNModel):
         return model_kwargs

     def get_image_features(self, pixel_values: torch.Tensor) -> torch.Tensor:
-
-
-
-
-
-
-
-
-
-
-
-
+        # Projects the last hidden state from the vision model into language model space.
+
+        # Args:
+        #     pixel_values: (`torch.FloatTensor` of shape `(batch_size, channels, height, width)`)
+        #         The tensors corresponding to the input images.
+
+        # Returns:
+        #     Image feature tensor of shape `(num_images, image_length, embed_dim)`.
+
+        vision_out_buffer = []
+        vision_out_size = [
+            pixel_values.shape[0],
+            (self.config.vision_config.image_size // self.config.vision_config.patch_size) ** 2,
+            self.config.vision_config.hidden_size,
+        ]
+        projector_out_size = [
+            pixel_values.shape[0],
+            self.config.mm_tokens_per_image,
+            self.config.text_config.hidden_size,
+        ]
+        vision_out_buffer.append(torch.empty(size=vision_out_size, dtype=torch.float32, device="cpu"))
+        projector_out_buffer = [torch.empty(size=projector_out_size, dtype=torch.float32, device="cpu")]
+        vision_outputs = self.vision_tower(pixel_values, out=vision_out_buffer).last_hidden_state
+        image_features = self.multi_modal_projector(vision_outputs, out=projector_out_buffer)
         return image_features

     def _preprocess_prefill(
@@ -258,17 +260,45 @@ class RBLNGemma3ForConditionalGeneration(RBLNModel):

         return inputs_embeds

+    def get_padded_cache_position(
+        self,
+        cache_position: torch.Tensor,  # shape: [1, seq_len]
+        token_type_ids: torch.Tensor,  # shape: [1, seq_len]
+    ) -> torch.Tensor:
+        seq_len = cache_position[0][-1].item() + 1
+
+        # Find image start positions
+        image_starts = [
+            s
+            for s in torch.where(token_type_ids == 1)[1]
+            if torch.all(token_type_ids[:, s : s + self.rbln_config.image_prefill_chunk_size] == 1)
+        ]
+
+        # Initialize padded tensors
+        padded_input_len = seq_len
+        for image_start in image_starts:
+            pad_needed = (
+                self.rbln_config.image_prefill_chunk_size
+                - (image_start + padded_input_len - seq_len) % self.rbln_config.image_prefill_chunk_size
+            ) % self.rbln_config.image_prefill_chunk_size
+            padded_input_len += pad_needed
+
+        return torch.cat(
+            [cache_position, torch.arange(seq_len, padded_input_len, dtype=torch.int32).unsqueeze(0)],
+            dim=1,
+        )
+
     def forward(
         self,
         input_ids: torch.LongTensor = None,
+        attention_mask: torch.Tensor = None,
+        token_type_ids: torch.Tensor = None,
         pixel_values: torch.FloatTensor = None,
-        attention_mask: Optional[torch.Tensor] = None,
         cache_position: Optional[torch.LongTensor] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
         generate_idx: Optional[torch.Tensor] = None,
         padded_cache_lengths: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.Tensor] = None,
-        token_type_ids: Optional[torch.Tensor] = None,
         **lm_kwargs: Dict[str, Any],
     ) -> Union[Tuple, RBLNDecoderOnlyOutput]:
         # prefill
@@ -279,12 +309,15 @@ class RBLNGemma3ForConditionalGeneration(RBLNModel):

         for b_idx in range(batch_size):
             cache_position = torch.arange(0, generate_idx[b_idx].item(), dtype=torch.int32).unsqueeze(0)
+            token_type_id = token_type_ids[b_idx : b_idx + 1, attention_mask[b_idx].bool()]
+            cache_position = self.get_padded_cache_position(cache_position, token_type_id)
+
             output = self.language_model.prefill_decoder(
                 inputs_embeds=inputs_embeds[b_idx : b_idx + 1],
                 attention_mask=attention_mask[b_idx],
                 cache_position=cache_position,
                 batch_idx=b_idx,
-                token_type_ids=token_type_ids[b_idx : b_idx + 1]
+                token_type_ids=token_type_ids[b_idx : b_idx + 1],  # do not pass token_type_id
             )
             padded_cache_lengths[b_idx] += output.padded_cache_lengths
             logits.append(output.logits)
@@ -313,362 +346,6 @@ class RBLNGemma3ForConditionalGeneration(RBLNModel):
         )


-class RBLNGemma3RuntimeModel(RBLNRuntimeModel):
-    def __init__(self, *args, image_prefill: Optional[rebel.Runtime] = None, **kwargs):
-        super().__init__(*args, **kwargs)
-        self.image_prefill = image_prefill  # FIXME(taehoon)
-        self.prefill = self.runtime if self.phase == "prefill" else None  # FIXME
-        self.decode = self.runtime if self.phase == "decode" else None
-
-    def pad_for_chunked_images(
-        self,
-        inputs: torch.Tensor,
-        attention_mask: torch.Tensor,
-        position_ids: torch.Tensor,
-        token_type_ids: Optional[torch.Tensor] = None,
-    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, torch.Tensor]:
-        """
-        Pads inputs, attention_mask, and position_ids so image token groups (256 tokens with token_type_ids == 1)
-        start at multiples of prefill_chunk_size (256). Returns padded tensors and total padded length.
-
-        Args:
-            inputs: (1, seq_len, hidden_size) tensor.
-            attention_mask: (1, seq_len) tensor, 1 for valid, 0 for masked.
-            position_ids: (1, seq_len) tensor for RoPE.
-            token_type_ids: (1, seq_len) tensor, 0 for text, 1 for image.
-
-        Returns:
-            (inputs_padded, attention_mask_padded, position_ids_padded, padded_len, token_type_ids_padded).
-        """
-
-        if token_type_ids is None:
-            return inputs, attention_mask, position_ids, 0, torch.zeros(inputs.shape[:2], dtype=torch.long)
-
-        seq_len = inputs.shape[1]
-
-        # Find image start positions
-        image_starts = [
-            s
-            for s in range(seq_len - self.rbln_config.prefill_chunk_size + 1)
-            if torch.all(token_type_ids[:, s : s + self.rbln_config.prefill_chunk_size] == 1)
-        ]
-
-        # Initialize padded tensors
-        padded_input_len = seq_len
-        for image_start in image_starts:
-            pad_needed = (
-                self.rbln_config.prefill_chunk_size
-                - (image_start + padded_input_len - seq_len) % self.rbln_config.prefill_chunk_size
-            ) % self.rbln_config.prefill_chunk_size
-            padded_input_len += pad_needed
-        total_padding = padded_input_len - seq_len
-
-        if inputs.dim() == 3:
-            inputs_padded = torch.zeros(1, padded_input_len, inputs.shape[2], dtype=inputs.dtype)
-        else:
-            inputs_padded = torch.zeros(1, padded_input_len, dtype=inputs.dtype)
-        attention_mask_padded = torch.zeros(1, padded_input_len, dtype=attention_mask.dtype)
-        position_ids_padded = torch.zeros(1, padded_input_len, dtype=position_ids.dtype)
-        token_type_ids_padded = torch.zeros(1, padded_input_len, dtype=token_type_ids.dtype)
-
-        # Fill padded tensors
-        dest_pos = 0
-        src_pos = 0
-        last_pos_id = -1
-        for image_start in image_starts + [seq_len]:
-            # Text segment
-            if src_pos < image_start:
-                length = image_start - src_pos
-                inputs_padded[:, dest_pos : dest_pos + length] = inputs[:, src_pos:image_start]
-                attention_mask_padded[:, dest_pos : dest_pos + length] = attention_mask[:, src_pos:image_start]
-                position_ids_padded[:, dest_pos : dest_pos + length] = position_ids[:, src_pos:image_start]
-                token_type_ids_padded[:, dest_pos : dest_pos + length] = token_type_ids[:, src_pos:image_start]
-                dest_pos += length
-                last_pos_id = position_ids[0, image_start - 1].item()
-                src_pos = image_start
-
-            # Padding
-            pad_needed = (
-                self.rbln_config.prefill_chunk_size - dest_pos % self.rbln_config.prefill_chunk_size
-            ) % self.rbln_config.prefill_chunk_size
-            if pad_needed and dest_pos < padded_input_len:
-                position_ids_padded[:, dest_pos : dest_pos + pad_needed] = torch.arange(
-                    last_pos_id + 1, last_pos_id + pad_needed + 1, dtype=position_ids.dtype
-                ).unsqueeze(0)
-                dest_pos += pad_needed
-
-            # Image segment
-            if src_pos < seq_len and src_pos == image_start:
-                inputs_padded[:, dest_pos : dest_pos + self.rbln_config.prefill_chunk_size] = inputs[
-                    :, src_pos : src_pos + self.rbln_config.prefill_chunk_size
-                ]
-                attention_mask_padded[:, dest_pos : dest_pos + self.rbln_config.prefill_chunk_size] = attention_mask[
-                    :, src_pos : src_pos + self.rbln_config.prefill_chunk_size
-                ]
-                position_ids_padded[:, dest_pos : dest_pos + self.rbln_config.prefill_chunk_size] = position_ids[
-                    :, src_pos : src_pos + self.rbln_config.prefill_chunk_size
-                ]
-                token_type_ids_padded[:, dest_pos : dest_pos + self.rbln_config.prefill_chunk_size] = token_type_ids[
-                    :, src_pos : src_pos + self.rbln_config.prefill_chunk_size
-                ]
-                dest_pos += self.rbln_config.prefill_chunk_size
-                src_pos += self.rbln_config.prefill_chunk_size
-                last_pos_id = position_ids[0, image_start + self.rbln_config.prefill_chunk_size - 1].item()
-
-        return inputs_padded, attention_mask_padded, position_ids_padded, total_padding, token_type_ids_padded
-
-    def _prepare_prefill_inputs(
-        self,
-        inputs: torch.Tensor,
-        cache_position: torch.Tensor,
-        attention_mask: Optional[torch.Tensor] = None,
-        position_embed: Optional[torch.Tensor] = None,
-        token_type_ids: Optional[torch.Tensor] = None,
-    ):
-        """
-        Prepare inputs for prefill phase.
-        """
-        # Handle continuous batching in a compiled graph by extracting valid inputs
-        # If an attention mask is provided, select only the valid (non-masked) inputs
-        inputs = inputs[:, attention_mask.bool()] if attention_mask is not None else inputs
-        token_type_ids = (
-            token_type_ids[:, attention_mask.bool()]
-            if attention_mask is not None and token_type_ids is not None
-            else token_type_ids
-        )
-
-        if position_embed is not None:
-            position_embed = (
-                position_embed[:, :, :, attention_mask.bool(), :] if attention_mask is not None else position_embed
-            )
-
-        seq_len = inputs.shape[1]
-        # Initialize attention mask for chunked processing
-        if self.rbln_config.use_attention_mask:
-            chunked_attention_mask = (
-                torch.ones(1, seq_len, dtype=torch.float32)
-                if self.rbln_config.use_position_ids
-                else torch.zeros(
-                    1, 1, self.rbln_config.prefill_chunk_size, self.rbln_config.max_seq_len, dtype=torch.float32
-                )
-            )
-        else:
-            chunked_attention_mask = None
-
-        # Buffer for storing output logits
-        out_buffers = [
-            torch.empty(
-                size=self.output_size,
-                dtype=torch.float32,
-                device="cpu",
-            )
-        ]
-
-        inputs, chunked_attention_mask, position_ids, padded_cache_lengths, token_type_ids_padded = (
-            self.pad_for_chunked_images(inputs, chunked_attention_mask, cache_position, token_type_ids)
-        )
-
-        query_length = inputs.shape[1]
-        if query_length > self.rbln_config.max_seq_len:
-            raise ValueError(
-                f"Input length ({query_length}) exceeds the maximum allowed sequence length ({self.rbln_config.max_seq_len})."
-            )
-
-        # Align attention_mask to compiled shape
-        if self.rbln_config.use_position_ids:
-            chunked_attention_mask = torch.nn.functional.pad(
-                chunked_attention_mask, (0, self.rbln_config.max_seq_len - query_length)
-            )
-
-        # Pad input and cache_position if the last chunk is smaller than `prefill_chunk_size`
-        padding_size = 0
-        if query_length % self.rbln_config.prefill_chunk_size != 0:
-            padding_size = (self.rbln_config.prefill_chunk_size - query_length) % self.rbln_config.prefill_chunk_size
-            # inputs_embeds
-            if inputs.dim() == 3:
-                inputs = torch.nn.functional.pad(inputs, (0, 0, 0, padding_size))
-            # inputs_ids
-            else:
-                inputs = torch.nn.functional.pad(inputs, (0, padding_size))
-
-            position_ids = torch.cat(
-                [
-                    position_ids,
-                    torch.arange(
-                        query_length,
-                        query_length + padding_size,
-                        dtype=torch.int32,
-                    ).unsqueeze(0),
-                ],
-                dim=-1,
-            )
-            token_type_ids_padded = torch.nn.functional.pad(token_type_ids_padded, (0, padding_size))
-
-            if position_embed is not None:
-                position_embed = torch.nn.functional.pad(position_embed, (0, 0, 0, padding_size))
-
-        cache_position = torch.arange(0, query_length + padding_size, dtype=torch.int32).unsqueeze(0)
-
-        return (
-            inputs,
-            cache_position,
-            chunked_attention_mask,
-            out_buffers,
-            position_ids,
-            position_embed,
-            padded_cache_lengths,
-            query_length,
-            token_type_ids_padded,
-        )
-
-    def prefill_forward(
-        self,
-        inputs: torch.Tensor,
-        cache_position: torch.Tensor = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        batch_idx: int = None,
-        block_tables: torch.Tensor = None,
-        is_external_block_tables: bool = None,
-        position_embed: Optional[torch.Tensor] = None,
-        token_type_ids: Optional[torch.Tensor] = None,
-        local_block_tables: Optional[torch.Tensor] = None,
-    ) -> torch.FloatTensor:
-        """
-        Performs chunked prefill for efficient KV-cache updates and memory optimization.
-        Instead of processing the entire sequence at once, the input is divided into chunks of size `prefill_chunk_size`,
-        and each chunk is processed sequentially. This allows for better memory utilization and compatibility with continuous batching.
-        """
-        (
-            inputs,
-            cache_position,
-            padded_attention_mask,
-            out_buffers,
-            position_ids,
-            position_embed,
-            padded_cache_lengths,
-            query_length,
-            token_type_ids_padded,
-        ) = self._prepare_prefill_inputs(
-            inputs, cache_position, attention_mask, position_embed, token_type_ids=token_type_ids
-        )
-        if not is_external_block_tables:
-            local_block_tables = torch.tensor([batch_idx], dtype=torch.int16)
-            self.dec_attn_mask[batch_idx : batch_idx + 1] = padded_attention_mask[:1]
-
-        if self.rbln_config.use_attention_mask and self.rbln_config.use_position_ids:
-            chunked_attention_mask = torch.zeros(1, self.rbln_config.max_seq_len, dtype=torch.float32)
-
-        # Process input in chunks of size `prefill_chunk_size`
-        for step in range(0, query_length, self.rbln_config.prefill_chunk_size):
-            # Extract the current chunk of inputs and cache positions
-            input_chunk = inputs[:, step : step + self.rbln_config.prefill_chunk_size]
-            cache_pos_chunk = cache_position[:, step : step + self.rbln_config.prefill_chunk_size]
-            position_ids_chunk = (
-                position_ids[:, step : step + self.rbln_config.prefill_chunk_size]
-                if position_ids is not None
-                else None
-            )
-
-            if self.rbln_config.use_attention_mask:
-                if self.rbln_config.use_position_ids:
-                    chunked_attention_mask[0, step : step + self.rbln_config.prefill_chunk_size] = (
-                        padded_attention_mask[0, step : step + self.rbln_config.prefill_chunk_size]
-                    )
-
-            # Define query position
-            query_position = (
-                torch.sum(
-                    chunked_attention_mask[0][step : step + self.rbln_config.prefill_chunk_size],
-                    dim=-1,
-                    dtype=torch.int16,
-                ).squeeze(0)
-                - 1
-            )
-            if token_type_ids_padded[:, step] == 1:
-                if torch.any(token_type_ids_padded[:, step : step + self.rbln_config.prefill_chunk_size] == 0):
-                    raise ValueError("All tokens of image_prefill should be the same image.")
-                else:
-                    logits = self.image_prefill(
-                        input_chunk,
-                        cache_pos_chunk,
-                        block_tables,
-                        local_block_tables,
-                        query_position,
-                        chunked_attention_mask,
-                        position_ids_chunk,
-                        out=out_buffers,
-                    )
-            else:
-                # Forward pass for the current chunk
-                logits = self.prefill(
-                    input_chunk,
-                    cache_pos_chunk,
-                    block_tables,
-                    local_block_tables,
-                    query_position,
-                    chunked_attention_mask,
-                    position_ids_chunk,
-                    out=out_buffers,
-                )
-
-        return RBLNGemma3ForCausalLMOutput(
-            logits=logits, padded_cache_lengths=padded_cache_lengths, attention_mask=chunked_attention_mask
-        )
-
-    def decode_forward(
-        self,
-        inputs: torch.Tensor,
-        cache_position: torch.Tensor = None,
-        block_tables: torch.Tensor = None,
-        is_external_block_tables: bool = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        position_embed: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.Tensor] = None,
-        local_block_tables: Optional[torch.Tensor] = None,
-    ) -> torch.FloatTensor:
-        batch_size = inputs.shape[0]
-        if batch_size != self.batch_size:
-            raise RuntimeError(
-                f"Batch size mismatch: got {batch_size}, expected {self.batch_size} (compiled batch size)."
-            )
-
-        if batch_size != cache_position.shape[0]:
-            raise RuntimeError(f"Cache position size mismatch: got {cache_position.shape[0]}, expected {batch_size}.")
-
-        # FIXME(taehoon): how to handle pos_attn_mask with external block tables
-        if is_external_block_tables:
-            if attention_mask is None:
-                raise ValueError("attention_mask should be provided with external block tables.")
-            if local_block_tables is None:
-                raise ValueError("local_block_tables should be provided with external block tables.")
-        else:
-            local_block_tables = (
-                local_block_tables
-                if local_block_tables is not None
-                else torch.arange(0, self.batch_size, dtype=torch.int16).view(self.batch_size, -1)
-            )
-            if self.rbln_config.use_attention_mask and attention_mask is None:
-                for b_idx in range(batch_size):
-                    decoding_step = cache_position[b_idx].item()
-                    if not (0 <= decoding_step < self.dec_attn_mask.shape[-1]):
-                        raise ValueError(
-                            f"Decoding step {decoding_step} out of bounds for attention mask with shape {self.dec_attn_mask.shape}."
-                        )
-                    self.dec_attn_mask[b_idx, decoding_step] = 1
-
-                attention_mask = self.dec_attn_mask
-
-        if self.batch_size < block_tables.shape[0]:
-            block_tables = block_tables[: self.batch_size]
-
-        if attention_mask is not None and self.batch_size < attention_mask.shape[0]:
-            attention_mask = attention_mask[: self.batch_size]
-
-        logits = self.decode(inputs, cache_position, block_tables, local_block_tables, attention_mask, position_ids)
-
-        return RBLNDecoderOnlyOutput(logits=logits)
-
-
 class RBLNGemma3ForCausalLM(RBLNDecoderOnlyModelForCausalLM):
     """
     The Gemma3 Model transformer with a language modeling head (linear layer) on top.
@@ -681,52 +358,36 @@ class RBLNGemma3ForCausalLM(RBLNDecoderOnlyModelForCausalLM):
     """

     _decoder_wrapper_cls = Gemma3ForCausalLMWrapper
+    _supports_non_fp32 = False

-    def
-        main_input_name = self.main_input_name
-
-        if self.rbln_config.use_inputs_embeds:
-            main_input_name = "inputs_embeds"
-            artifacts = torch.load(self.model_save_dir / self.subfolder / "torch_artifacts.pth", weights_only=False)
-            self.embed_tokens = self._create_embedding_layer()
-            self.embed_tokens.load_state_dict(artifacts["embed_tokens"])
-        else:
-            self.embed_tokens = None
-
+    def setup_runtime(self):
         # Initialize shared resources to be used across Runtime instances (prefill and decode phases)
         dec_attn_mask = torch.zeros(self.rbln_config.batch_size, self.rbln_config.max_seq_len, dtype=torch.float32)
-
-
-
-
-
-
+        page_table_manager = RBLNPageTableManager(self.rbln_config)
+
+        common_kwargs = {
+            "main_input_name": "inputs_embeds" if self.rbln_config.use_inputs_embeds else "input_ids",
+            "embed_tokens": self.embed_tokens,
+            "dec_attn_mask": dec_attn_mask,
+            "page_table_manager": page_table_manager,
+            "rbln_config": self.rbln_config,
+        }
+
         self.prefill_decoder = RBLNGemma3RuntimeModel(
             runtime=self.model[0],
-            image_prefill=self.model[1],
-            main_input_name=main_input_name,
-            embed_tokens=self.embed_tokens,
+            image_prefill=self.model[1] if self.rbln_config.use_image_prefill else None,
             phase="prefill",
             batch_size=self.rbln_config.batch_size,
-
-            block_tables=block_tables,
-            vocab_size=self.config.vocab_size,
-            free_block_pool=free_block_pool,
-            rbln_config=self.rbln_config,
+            **common_kwargs,
         )

         self.decoders = {}
         for i, batch_size in enumerate(self.rbln_config.decoder_batch_sizes):
            self.decoders[batch_size] = RBLNGemma3RuntimeModel(
-                runtime=self.model[i +
-                main_input_name=main_input_name,
-                embed_tokens=self.embed_tokens,
+                runtime=self.model[i + self.rbln_config.decoder_runtime_idx],
                 phase="decode",
                 batch_size=batch_size,
-
-                block_tables=block_tables,
-                free_block_pool=free_block_pool,
-                rbln_config=self.rbln_config,
+                **common_kwargs,
             )

         # NOTE(eunji): Use a decoder whose batch size matches the model's main batch size for compatibility.
@@ -746,6 +407,13 @@ class RBLNGemma3ForCausalLM(RBLNDecoderOnlyModelForCausalLM):
     def _update_sliding_window_config(cls, model_config: PretrainedConfig, rbln_config: RBLNGemma3ForCausalLMConfig):
         sliding_window = getattr(model_config, "sliding_window", None)
         sliding_window_pattern = getattr(model_config, "sliding_window_pattern", None)
+        if sliding_window_pattern is None:
+            if hasattr(model_config, "layer_types"):
+                first_full_attention_index = model_config.layer_types.index("full_attention")
+                sliding_window_pattern = first_full_attention_index + 1
+            else:
+                raise ValueError("Cannot determine sliding_window_pattern from model_config")
+
         if sliding_window_pattern <= model_config.num_hidden_layers:
             rbln_config.cache_impl = "hybrid"
             rbln_config.sliding_window = sliding_window
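The hunk above covers Gemma3 configs that publish `layer_types` but no `sliding_window_pattern`: the pattern is recovered as the index of the first `"full_attention"` entry plus one. A standalone illustration of that arithmetic (the layer layout here is invented for the example, not read from a real checkpoint):

```python
# Invented layer layout: five sliding-window layers, then one full-attention layer, repeated.
layer_types = (["sliding_attention"] * 5 + ["full_attention"]) * 4

# Same derivation as the hunk above: 1-based position of the first full-attention layer.
first_full_attention_index = layer_types.index("full_attention")  # 5
sliding_window_pattern = first_full_attention_index + 1           # 6

print(sliding_window_pattern)  # every 6th layer uses global attention; the rest use the sliding window
```

With that value in hand, the surrounding code switches the cache implementation to "hybrid" whenever the pattern does not exceed `num_hidden_layers`.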
@@ -756,14 +424,20 @@ class RBLNGemma3ForCausalLM(RBLNDecoderOnlyModelForCausalLM):
         return rbln_config

     @classmethod
-    def _update_submodule_config(
-
-
+    def _update_submodule_config(
+        cls,
+        model: "PreTrainedModel",
+        rbln_config: RBLNModelConfig,
+        preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]],
+    ):
+        if rbln_config.image_prefill_chunk_size is None:
+            rbln_config.image_prefill_chunk_size = model.config.mm_tokens_per_image

-        if rbln_config.
-
-                f"
+        if rbln_config.image_prefill_chunk_size != model.config.mm_tokens_per_image:
+            raise ValueError(
+                f"Image prefill chunk size is different from mm_tokens_per_image: {rbln_config.image_prefill_chunk_size} != {model.config.mm_tokens_per_image}"
             )
+
         return rbln_config

     @classmethod
@@ -777,22 +451,36 @@ class RBLNGemma3ForCausalLM(RBLNDecoderOnlyModelForCausalLM):
         # Update rbln_config with super class
         rbln_config = super()._update_rbln_config(preprocessors, model, model_config, rbln_config)

-
-
-
-
-
-
-
-
-
+        if not (rbln_config.use_attention_mask and rbln_config.use_position_ids):
+            raise ValueError("use_attention_mask and use_position_ids must be True for RBLNGemma3ForCausalLM")
+
+        if rbln_config.use_image_prefill:
+            if rbln_config.prefill_chunk_size != rbln_config.image_prefill_chunk_size:
+                raise NotImplementedError(
+                    "Not implemented for different prefill chunk sizes between text and image prefill."
+                )
+
+            # Update image prefill compile config
+            img_prefill_input_info = cls.get_input_info(
+                batch_size=1,
+                query_length=rbln_config.image_prefill_chunk_size,
+                rbln_config=rbln_config,
+                model_config=model_config,
+            )
+            image_prefill_compile_config = RBLNCompileConfig(
+                compiled_model_name="image_prefill", input_info=img_prefill_input_info
+            )
+            # Insert image_prefill compile config at index 1
+            compile_cfgs = rbln_config.compile_cfgs
+            compile_cfgs.insert(1, image_prefill_compile_config)
+            rbln_config.set_compile_cfgs(compile_cfgs)

         return rbln_config

     @classmethod
     @torch.inference_mode()
     def get_compiled_model(cls, model: "PreTrainedModel", rbln_config: RBLNGemma3ForCausalLMConfig):
-        wrapped_model = cls.
+        wrapped_model = cls._wrap_model_if_needed(model, rbln_config)

         rbln_compile_configs = rbln_config.compile_cfgs
         prefill_compile_config = rbln_compile_configs[0]
@@ -838,20 +526,27 @@ class RBLNGemma3ForCausalLM(RBLNDecoderOnlyModelForCausalLM):
             context,
             rbln_config.quantization,
         )
+        compiled_models = {"prefill": compiled_prefill}

-
-
-
-
-
-
-
-
-
+        if rbln_config.use_image_prefill:
+            image_prefill_compile_config = rbln_compile_configs[1]
+            image_prefill_example_inputs = image_prefill_compile_config.get_dummy_inputs(
+                fill=0, static_tensors=static_tensors
+            )
+            wrapped_model.phase = "image_prefill"
+            compiled_image_prefill = compile_model(
+                wrapped_model,
+                image_prefill_compile_config,
+                image_prefill_example_inputs,
+                context,
+                rbln_config.quantization,
+            )
+            compiled_models["image_prefill"] = compiled_image_prefill

-        compiled_models = {"prefill": compiled_prefill, "image_prefill": compiled_image_prefill}
         wrapped_model.phase = "decode"
-        for batch_size, dec_compile_config in zip(
+        for batch_size, dec_compile_config in zip(
+            rbln_config.decoder_batch_sizes, rbln_compile_configs[rbln_config.decoder_runtime_idx :]
+        ):
             dec_example_inputs = dec_compile_config.get_dummy_inputs(fill=0, static_tensors=static_tensors)
             compiled_decoder = compile_model(
                 wrapped_model,
@@ -872,32 +567,45 @@ class RBLNGemma3ForCausalLM(RBLNDecoderOnlyModelForCausalLM):
     ) -> List[rebel.Runtime]:
         expected_model_names = [
             "prefill",
-            "image_prefill",
             *[f"decoder_batch_{batch_size}" for batch_size in rbln_config.decoder_batch_sizes],
         ]
+        if rbln_config.use_image_prefill:
+            expected_model_names.insert(1, "image_prefill")
+
         if any(model_name not in rbln_config.device_map for model_name in expected_model_names):
             cls._raise_missing_compiled_file_error(expected_model_names)

-
+        ret_val = [
             rebel.Runtime(
                 compiled_models[0],
                 tensor_type="pt",
                 device=rbln_config.device_map["prefill"],
                 activate_profiler=rbln_config.activate_profiler,
-
-
-
-
-
-                activate_profiler=rbln_config.activate_profiler,
-            ),
-            *[
+                timeout=rbln_config.timeout,
+            )
+        ]
+        if rbln_config.use_image_prefill:
+            ret_val.append(
                 rebel.Runtime(
-                    compiled_models[
+                    compiled_models[1],
+                    tensor_type="pt",
+                    device=rbln_config.device_map["image_prefill"],
+                    activate_profiler=rbln_config.activate_profiler,
+                    timeout=rbln_config.timeout,
+                ),
+            )
+
+        ret_val.extend(
+            [
+                rebel.Runtime(
+                    compiled_models[i + rbln_config.decoder_runtime_idx],
                     tensor_type="pt",
                     device=rbln_config.device_map[f"decoder_batch_{batch_size}"],
                     activate_profiler=rbln_config.activate_profiler,
+                    timeout=rbln_config.timeout,
                 )
                 for i, batch_size in enumerate(rbln_config.decoder_batch_sizes)
-            ]
-
+            ]
+        )
+
+        return ret_val