optimum-rbln 0.8.1a0__py3-none-any.whl → 0.8.1a2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- optimum/rbln/__init__.py +2 -0
- optimum/rbln/__version__.py +2 -2
- optimum/rbln/configuration_utils.py +53 -33
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +9 -2
- optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +4 -2
- optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +9 -2
- optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +4 -2
- optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +9 -2
- optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +9 -2
- optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +33 -9
- optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +30 -12
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +22 -6
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +16 -6
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +16 -6
- optimum/rbln/diffusers/modeling_diffusers.py +16 -26
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +11 -0
- optimum/rbln/diffusers/models/autoencoders/vae.py +1 -8
- optimum/rbln/diffusers/models/autoencoders/vq_model.py +11 -0
- optimum/rbln/diffusers/models/controlnet.py +13 -7
- optimum/rbln/diffusers/models/transformers/prior_transformer.py +10 -0
- optimum/rbln/diffusers/models/transformers/transformer_sd3.py +2 -0
- optimum/rbln/diffusers/models/unets/unet_2d_condition.py +7 -0
- optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +1 -4
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +7 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +7 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +7 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +7 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +7 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +48 -27
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +7 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpaint.py +7 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +7 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +7 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +7 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +7 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +7 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +7 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +7 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +7 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +7 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +7 -0
- optimum/rbln/modeling.py +33 -35
- optimum/rbln/modeling_base.py +45 -107
- optimum/rbln/transformers/__init__.py +39 -47
- optimum/rbln/transformers/configuration_generic.py +16 -13
- optimum/rbln/transformers/modeling_generic.py +18 -19
- optimum/rbln/transformers/modeling_rope_utils.py +5 -2
- optimum/rbln/transformers/models/__init__.py +46 -4
- optimum/rbln/transformers/models/audio_spectrogram_transformer/__init__.py +17 -0
- optimum/rbln/transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py +21 -0
- optimum/rbln/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +28 -0
- optimum/rbln/transformers/models/auto/auto_factory.py +35 -12
- optimum/rbln/transformers/models/bart/bart_architecture.py +14 -1
- optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +35 -4
- optimum/rbln/transformers/models/clip/configuration_clip.py +3 -3
- optimum/rbln/transformers/models/clip/modeling_clip.py +11 -12
- optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +111 -14
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +102 -35
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +229 -175
- optimum/rbln/transformers/models/distilbert/__init__.py +19 -0
- optimum/rbln/transformers/models/distilbert/configuration_distilbert.py +19 -0
- optimum/rbln/transformers/models/distilbert/modeling_distilbert.py +19 -0
- optimum/rbln/transformers/models/exaone/configuration_exaone.py +24 -1
- optimum/rbln/transformers/models/exaone/exaone_architecture.py +5 -1
- optimum/rbln/transformers/models/exaone/modeling_exaone.py +66 -5
- optimum/rbln/transformers/models/gemma/configuration_gemma.py +24 -1
- optimum/rbln/transformers/models/gemma/gemma_architecture.py +5 -1
- optimum/rbln/transformers/models/gemma/modeling_gemma.py +49 -0
- optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +3 -3
- optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +18 -250
- optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +106 -236
- optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +4 -1
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +6 -1
- optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +12 -2
- optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +41 -4
- optimum/rbln/transformers/models/llama/configuration_llama.py +24 -1
- optimum/rbln/transformers/models/llama/modeling_llama.py +49 -0
- optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +2 -2
- optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +32 -4
- optimum/rbln/transformers/models/midm/configuration_midm.py +24 -1
- optimum/rbln/transformers/models/midm/midm_architecture.py +6 -1
- optimum/rbln/transformers/models/midm/modeling_midm.py +66 -5
- optimum/rbln/transformers/models/mistral/configuration_mistral.py +24 -1
- optimum/rbln/transformers/models/mistral/modeling_mistral.py +62 -4
- optimum/rbln/transformers/models/opt/configuration_opt.py +4 -1
- optimum/rbln/transformers/models/opt/modeling_opt.py +10 -0
- optimum/rbln/transformers/models/opt/opt_architecture.py +7 -1
- optimum/rbln/transformers/models/phi/configuration_phi.py +24 -1
- optimum/rbln/transformers/models/phi/modeling_phi.py +49 -0
- optimum/rbln/transformers/models/phi/phi_architecture.py +1 -1
- optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +24 -1
- optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +67 -4
- optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +15 -3
- optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +58 -27
- optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +47 -2
- optimum/rbln/transformers/models/resnet/__init__.py +23 -0
- optimum/rbln/transformers/models/resnet/configuration_resnet.py +20 -0
- optimum/rbln/transformers/models/resnet/modeling_resnet.py +22 -0
- optimum/rbln/transformers/models/roberta/__init__.py +24 -0
- optimum/rbln/transformers/{configuration_alias.py → models/roberta/configuration_roberta.py} +4 -30
- optimum/rbln/transformers/{modeling_alias.py → models/roberta/modeling_roberta.py} +2 -32
- optimum/rbln/transformers/models/seq2seq/__init__.py +1 -1
- optimum/rbln/transformers/models/seq2seq/{configuration_seq2seq2.py → configuration_seq2seq.py} +2 -2
- optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +1 -1
- optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +41 -3
- optimum/rbln/transformers/models/siglip/configuration_siglip.py +3 -0
- optimum/rbln/transformers/models/siglip/modeling_siglip.py +62 -21
- optimum/rbln/transformers/models/t5/modeling_t5.py +46 -4
- optimum/rbln/transformers/models/t5/t5_architecture.py +5 -1
- optimum/rbln/transformers/models/{time_series_transformers → time_series_transformer}/__init__.py +1 -1
- optimum/rbln/transformers/models/{time_series_transformers → time_series_transformer}/configuration_time_series_transformer.py +2 -2
- optimum/rbln/transformers/models/{time_series_transformers/modeling_time_series_transformers.py → time_series_transformer/modeling_time_series_transformer.py} +14 -9
- optimum/rbln/transformers/models/vit/__init__.py +19 -0
- optimum/rbln/transformers/models/vit/configuration_vit.py +19 -0
- optimum/rbln/transformers/models/vit/modeling_vit.py +19 -0
- optimum/rbln/transformers/models/wav2vec2/__init__.py +1 -1
- optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +1 -1
- optimum/rbln/transformers/models/whisper/configuration_whisper.py +3 -1
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +35 -15
- optimum/rbln/transformers/models/xlm_roberta/__init__.py +16 -2
- optimum/rbln/transformers/models/xlm_roberta/configuration_xlm_roberta.py +15 -2
- optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +12 -3
- optimum/rbln/utils/model_utils.py +20 -0
- optimum/rbln/utils/submodule.py +6 -8
- {optimum_rbln-0.8.1a0.dist-info → optimum_rbln-0.8.1a2.dist-info}/METADATA +2 -2
- {optimum_rbln-0.8.1a0.dist-info → optimum_rbln-0.8.1a2.dist-info}/RECORD +130 -117
- /optimum/rbln/transformers/models/{time_series_transformers → time_series_transformer}/time_series_transformers_architecture.py +0 -0
- /optimum/rbln/transformers/models/wav2vec2/{configuration_wav2vec.py → configuration_wav2vec2.py} +0 -0
- {optimum_rbln-0.8.1a0.dist-info → optimum_rbln-0.8.1a2.dist-info}/WHEEL +0 -0
- {optimum_rbln-0.8.1a0.dist-info → optimum_rbln-0.8.1a2.dist-info}/licenses/LICENSE +0 -0
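Beyond the Gemma3 and Idefics3 reworks shown in the hunks below, the file list adds several new model families (audio spectrogram transformer, DistilBERT, ResNet, RoBERTa, ViT). A hedged usage sketch for one of them, following the `from_pretrained(..., export=True)` / `save_pretrained(...)` pattern this package documents elsewhere in this diff; the class name `RBLNResNetForImageClassification` is inferred from the `RBLN`-prefix naming convention and the new `models/resnet` files, not confirmed by this diff:

```python
# Hedged sketch: assumes the new models/resnet package exports
# RBLNResNetForImageClassification (name inferred, not shown in this diff).
from optimum.rbln import RBLNResNetForImageClassification

# export=True compiles the Hugging Face checkpoint for an RBLN NPU.
model = RBLNResNetForImageClassification.from_pretrained(
    "microsoft/resnet-50",  # assumption: any ResNet classification checkpoint
    export=True,
)
model.save_pretrained("compiled-resnet-50")  # reuse later without recompiling
```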
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import importlib
 import inspect
 from collections import deque
 from dataclasses import dataclass
@@ -32,10 +33,6 @@ from transformers.models.gemma3.modeling_gemma3 import Gemma3TextScaledWordEmbed
 from ....configuration_utils import RBLNCompileConfig, RBLNModelConfig
 from ....modeling import RBLNModel
 from ....utils.logging import get_logger
-from ..decoderonly.decoderonly_architecture import (
-    set_default_values,
-    validate_attention_method,
-)
 from ..decoderonly.modeling_decoderonly import RBLNDecoderOnlyModelForCausalLM, RBLNDecoderOnlyOutput, RBLNRuntimeModel
 from .configuration_gemma3 import RBLNGemma3ForCausalLMConfig
 from .gemma3_architecture import Gemma3ForCausalLMWrapper
@@ -127,6 +124,23 @@ class RBLNGemma3ForConditionalGeneration(RBLNModel):
     def can_generate(self):
         return True

+    @classmethod
+    def get_pytorch_model(cls, *args, **kwargs):
+        model = super().get_pytorch_model(*args, **kwargs)
+
+        with no_init_weights():
+            model_cls_name = model.model.language_model.__class__.__name__
+            causal_model_cls_name = model_cls_name.replace("TextModel", "ForCausalLM")
+            causal_model_cls = getattr(importlib.import_module("transformers"), causal_model_cls_name)
+            new_language_model = causal_model_cls(model.model.language_model.config)
+
+        new_language_model.lm_head = model.lm_head
+        new_language_model.model = model.model.language_model
+        model.model.language_model = new_language_model
+        model.lm_head = None
+        del model.lm_head
+        return model
+
     def __post_init__(self, **kwargs):
         self.vision_tower = LoopVisionTower(self.rbln_submodules[0])
         self.language_model = self.rbln_submodules[1]
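The new `get_pytorch_model` override in the hunk above rewraps the checkpoint so the text submodule is a full `*ForCausalLM` module (with the `lm_head` attached) instead of a bare `*TextModel`. A standalone sketch of that rewrap, assuming only `transformers` and a model laid out like the Gemma3 checkpoint above (`model.model.language_model` plus a top-level `lm_head`):

```python
# Hedged sketch of the rewrap performed above; not the optimum-rbln API itself.
import importlib

from transformers.modeling_utils import no_init_weights

def rewrap_language_model(model):
    # Resolve the causal-LM counterpart of the inner text model by name,
    # e.g. "Gemma3TextModel" -> "Gemma3ForCausalLM".
    with no_init_weights():
        text_model = model.model.language_model
        cls_name = text_model.__class__.__name__.replace("TextModel", "ForCausalLM")
        causal_cls = getattr(importlib.import_module("transformers"), cls_name)
        wrapper = causal_cls(text_model.config)  # empty shell, weights not initialized

    # Reuse the existing modules instead of copying weights.
    wrapper.lm_head = model.lm_head
    wrapper.model = model.model.language_model
    model.model.language_model = wrapper
    del model.lm_head  # the head now lives inside the wrapper
    return model
```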
@@ -215,15 +229,16 @@ class RBLNGemma3ForConditionalGeneration(RBLNModel):

         return model_kwargs

-    def get_image_features(self, pixel_values: torch.Tensor):
+    def get_image_features(self, pixel_values: torch.Tensor) -> torch.Tensor:
         """
         Projects the last hidden state from the vision model into language model space.

         Args:
-            pixel_values (`torch.FloatTensor`):
-
+            pixel_values: (`torch.FloatTensor` of shape `(batch_size, channels, height, width)`)
+                The tensors corresponding to the input images.
+
         Returns:
-
+            Image feature tensor of shape `(num_images, image_length, embed_dim)`.
         """
         vision_outputs = self.vision_tower(pixel_values).last_hidden_state
         image_features = self.multi_modal_projector(vision_outputs)
@@ -272,7 +287,7 @@ class RBLNGemma3ForConditionalGeneration(RBLNModel):
         padded_cache_lengths: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.Tensor] = None,
         token_type_ids: Optional[torch.Tensor] = None,
-        **lm_kwargs,
+        **lm_kwargs: Dict[str, Any],
     ) -> Union[Tuple, RBLNDecoderOnlyOutput]:
         # prefill
         if cache_position is None:
@@ -352,16 +367,17 @@ class RBLNGemma3RuntimeModel(RBLNRuntimeModel):
         # Find image start positions
         image_starts = [
             s
-            for s in range(seq_len - self.prefill_chunk_size + 1)
-            if torch.all(token_type_ids[:, s : s + self.prefill_chunk_size] == 1)
+            for s in range(seq_len - self.rbln_config.prefill_chunk_size + 1)
+            if torch.all(token_type_ids[:, s : s + self.rbln_config.prefill_chunk_size] == 1)
         ]

         # Initialize padded tensors
         padded_input_len = seq_len
         for image_start in image_starts:
             pad_needed = (
-                self.prefill_chunk_size - (image_start + padded_input_len - seq_len) % self.prefill_chunk_size
-            ) % self.prefill_chunk_size
+                self.rbln_config.prefill_chunk_size
+                - (image_start + padded_input_len - seq_len) % self.rbln_config.prefill_chunk_size
+            ) % self.rbln_config.prefill_chunk_size
             padded_input_len += pad_needed
         total_padding = padded_input_len - seq_len

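The alignment loop above pads so that every image segment begins on a `prefill_chunk_size` boundary. A toy check of the same arithmetic in plain Python (chunk size and offsets are illustrative):

```python
# Toy check of the chunk-alignment formula above; no RBLN runtime required.
prefill_chunk_size = 128
seq_len = 300
padded_input_len = seq_len
for image_start in (70, 210):  # hypothetical image-token start offsets
    shifted_start = image_start + padded_input_len - seq_len  # position after earlier pads
    pad_needed = (prefill_chunk_size - shifted_start % prefill_chunk_size) % prefill_chunk_size
    padded_input_len += pad_needed
    print(shifted_start + pad_needed)  # 128, then 384: both multiples of 128
```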
@@ -390,7 +406,9 @@ class RBLNGemma3RuntimeModel(RBLNRuntimeModel):
             src_pos = image_start

             # Padding
-            pad_needed = (self.prefill_chunk_size - dest_pos % self.prefill_chunk_size) % self.prefill_chunk_size
+            pad_needed = (
+                self.rbln_config.prefill_chunk_size - dest_pos % self.rbln_config.prefill_chunk_size
+            ) % self.rbln_config.prefill_chunk_size
             if pad_needed and dest_pos < padded_input_len:
                 position_ids_padded[:, dest_pos : dest_pos + pad_needed] = torch.arange(
                     last_pos_id + 1, last_pos_id + pad_needed + 1, dtype=position_ids.dtype
@@ -399,21 +417,21 @@ class RBLNGemma3RuntimeModel(RBLNRuntimeModel):

             # Image segment
             if src_pos < seq_len and src_pos == image_start:
-                inputs_padded[:, dest_pos : dest_pos + self.prefill_chunk_size] = inputs[
-                    :, src_pos : src_pos + self.prefill_chunk_size
+                inputs_padded[:, dest_pos : dest_pos + self.rbln_config.prefill_chunk_size] = inputs[
+                    :, src_pos : src_pos + self.rbln_config.prefill_chunk_size
                 ]
-                attention_mask_padded[:, dest_pos : dest_pos + self.prefill_chunk_size] = attention_mask[
-                    :, src_pos : src_pos + self.prefill_chunk_size
+                attention_mask_padded[:, dest_pos : dest_pos + self.rbln_config.prefill_chunk_size] = attention_mask[
+                    :, src_pos : src_pos + self.rbln_config.prefill_chunk_size
                 ]
-                position_ids_padded[:, dest_pos : dest_pos + self.prefill_chunk_size] = position_ids[
-                    :, src_pos : src_pos + self.prefill_chunk_size
+                position_ids_padded[:, dest_pos : dest_pos + self.rbln_config.prefill_chunk_size] = position_ids[
+                    :, src_pos : src_pos + self.rbln_config.prefill_chunk_size
                 ]
-                token_type_ids_padded[:, dest_pos : dest_pos + self.prefill_chunk_size] = token_type_ids[
-                    :, src_pos : src_pos + self.prefill_chunk_size
+                token_type_ids_padded[:, dest_pos : dest_pos + self.rbln_config.prefill_chunk_size] = token_type_ids[
+                    :, src_pos : src_pos + self.rbln_config.prefill_chunk_size
                 ]
-                dest_pos += self.prefill_chunk_size
-                src_pos += self.prefill_chunk_size
-                last_pos_id = position_ids[0, image_start + self.prefill_chunk_size - 1].item()
+                dest_pos += self.rbln_config.prefill_chunk_size
+                src_pos += self.rbln_config.prefill_chunk_size
+                last_pos_id = position_ids[0, image_start + self.rbln_config.prefill_chunk_size - 1].item()

         return inputs_padded, attention_mask_padded, position_ids_padded, total_padding, token_type_ids_padded

@@ -444,11 +462,13 @@ class RBLNGemma3RuntimeModel(RBLNRuntimeModel):

         seq_len = inputs.shape[1]
         # Initialize attention mask for chunked processing
-        if self.use_attention_mask:
+        if self.rbln_config.use_attention_mask:
             chunked_attention_mask = (
                 torch.ones(1, seq_len, dtype=torch.float32)
-                if self.use_position_ids
-                else torch.zeros(1, 1, self.prefill_chunk_size, self.max_seq_len, dtype=torch.float32)
+                if self.rbln_config.use_position_ids
+                else torch.zeros(
+                    1, 1, self.rbln_config.prefill_chunk_size, self.rbln_config.max_seq_len, dtype=torch.float32
+                )
             )
         else:
             chunked_attention_mask = None
@@ -467,20 +487,21 @@ class RBLNGemma3RuntimeModel(RBLNRuntimeModel):
         )

         query_length = inputs.shape[1]
-        if query_length > self.max_seq_len:
+        if query_length > self.rbln_config.max_seq_len:
             raise ValueError(
-                f"Input length ({query_length}) exceeds the maximum allowed sequence length ({self.max_seq_len})."
+                f"Input length ({query_length}) exceeds the maximum allowed sequence length ({self.rbln_config.max_seq_len})."
             )

         # Align attention_mask to compiled shape
-        if self.use_position_ids:
+        if self.rbln_config.use_position_ids:
             chunked_attention_mask = torch.nn.functional.pad(
-                chunked_attention_mask, (0, self.max_seq_len - query_length)
+                chunked_attention_mask, (0, self.rbln_config.max_seq_len - query_length)
             )

         # Pad input and cache_position if the last chunk is smaller than `prefill_chunk_size`
-        if query_length % self.prefill_chunk_size != 0:
-            padding_size = (self.prefill_chunk_size - query_length) % self.prefill_chunk_size
+        padding_size = 0
+        if query_length % self.rbln_config.prefill_chunk_size != 0:
+            padding_size = (self.rbln_config.prefill_chunk_size - query_length) % self.rbln_config.prefill_chunk_size
         # inputs_embeds
         if inputs.dim() == 3:
             inputs = torch.nn.functional.pad(inputs, (0, 0, 0, padding_size))
@@ -548,65 +569,71 @@ class RBLNGemma3RuntimeModel(RBLNRuntimeModel):
         ) = self._prepare_prefill_inputs(
             inputs, cache_position, attention_mask, position_embed, token_type_ids=token_type_ids
         )
-        self.dec_attn_mask[batch_idx : batch_idx + 1] = chunked_attention_mask[:1]
         if not is_external_block_tables:
             local_block_tables = torch.tensor([batch_idx], dtype=torch.int16)
+            self.dec_attn_mask[batch_idx : batch_idx + 1] = chunked_attention_mask[:1]

-        if self.use_attention_mask and self.use_position_ids:
-            chunked_attention_mask = torch.zeros(1, self.max_seq_len, dtype=torch.float32)
+        if self.rbln_config.use_attention_mask and self.rbln_config.use_position_ids:
+            chunked_attention_mask = torch.zeros(1, self.rbln_config.max_seq_len, dtype=torch.float32)

         # Process input in chunks of size `prefill_chunk_size`
-        for step in range(0, query_length, self.prefill_chunk_size):
+        for step in range(0, query_length, self.rbln_config.prefill_chunk_size):
             # Extract the current chunk of inputs and cache positions
-            input_chunk = inputs[:, step : step + self.prefill_chunk_size]
-            cache_pos_chunk = cache_position[:, step : step + self.prefill_chunk_size]
+            input_chunk = inputs[:, step : step + self.rbln_config.prefill_chunk_size]
+            cache_pos_chunk = cache_position[:, step : step + self.rbln_config.prefill_chunk_size]
             position_ids_chunk = (
-                position_ids[:, step : step + self.prefill_chunk_size] if position_ids is not None else None
+                position_ids[:, step : step + self.rbln_config.prefill_chunk_size]
+                if position_ids is not None
+                else None
             )

             # Not used in Gemma3 yet.
-            if self.use_attention_mask:
-                if self.use_position_ids:
-                    chunked_attention_mask[0, step : step + self.prefill_chunk_size] = self.dec_attn_mask[
-                        batch_idx, step : step + self.prefill_chunk_size
+            if self.rbln_config.use_attention_mask:
+                if self.rbln_config.use_position_ids:
+                    chunked_attention_mask[0, step : step + self.rbln_config.prefill_chunk_size] = self.dec_attn_mask[
+                        batch_idx, step : step + self.rbln_config.prefill_chunk_size
                     ]
                 else:
                     # Update attention mask to ensure proper causal behavior
-                    if step >= self.prefill_chunk_size:
-                        chunked_attention_mask[:, :, :, step - self.prefill_chunk_size : step] = 1
-                    chunked_attention_mask[:, :, :, step : step + self.prefill_chunk_size] = self.causal_mask
+                    if step >= self.rbln_config.prefill_chunk_size:
+                        chunked_attention_mask[:, :, :, step - self.rbln_config.prefill_chunk_size : step] = 1
+                    chunked_attention_mask[:, :, :, step : step + self.rbln_config.prefill_chunk_size] = (
+                        self.causal_mask
+                    )

             # Define query position
             query_position = (
                 torch.sum(
-                    chunked_attention_mask[0][step : step + self.prefill_chunk_size],
+                    chunked_attention_mask[0][step : step + self.rbln_config.prefill_chunk_size],
+                    dim=-1,
+                    dtype=torch.int16,
                 ).squeeze(0)
                 - 1
             )
             if token_type_ids_padded[:, step] == 1:
-                if torch.any(token_type_ids_padded[:, step : step + self.prefill_chunk_size] == 0):
+                if torch.any(token_type_ids_padded[:, step : step + self.rbln_config.prefill_chunk_size] == 0):
                     raise ValueError("All tokens of image_prefill should be the same image.")
                 else:
                     logits = self.image_prefill(
                         input_chunk,
-                        chunked_attention_mask,
                         cache_pos_chunk,
-                        position_ids_chunk,
-                        query_position,
                         block_tables,
                         local_block_tables,
+                        query_position,
+                        chunked_attention_mask,
+                        position_ids_chunk,
                         out=out_buffers,
                     )
             else:
                 # Forward pass for the current chunk
                 logits = self.prefill(
                     input_chunk,
-                    chunked_attention_mask,
                     cache_pos_chunk,
-                    position_ids_chunk,
-                    query_position,
                     block_tables,
                     local_block_tables,
+                    query_position,
+                    chunked_attention_mask,
+                    position_ids_chunk,
                     out=out_buffers,
                 )

@@ -646,7 +673,7 @@ class RBLNGemma3RuntimeModel(RBLNRuntimeModel):
             if local_block_tables is not None
             else torch.arange(0, self.batch_size, dtype=torch.int16).view(self.batch_size, -1)
         )
-        if self.use_attention_mask and attention_mask is None:
+        if self.rbln_config.use_attention_mask and attention_mask is None:
             for b_idx in range(batch_size):
                 decoding_step = cache_position[b_idx].item()
                 if not (0 <= decoding_step < self.dec_attn_mask.shape[-1]):
@@ -663,7 +690,7 @@ class RBLNGemma3RuntimeModel(RBLNRuntimeModel):
         if attention_mask is not None and self.batch_size < attention_mask.shape[0]:
             attention_mask = attention_mask[: self.batch_size]

-        logits = self.decode(inputs,
+        logits = self.decode(inputs, cache_position, block_tables, local_block_tables, attention_mask, position_ids)

         return RBLNDecoderOnlyOutput(logits=logits)

@@ -700,7 +727,6 @@ class RBLNGemma3ForCausalLM(RBLNDecoderOnlyModelForCausalLM):
             dtype=torch.int16,
         ).fill_(-1)
         free_block_pool = deque(x for x in range(self.rbln_config.kvcache_num_blocks))
-
         self.prefill_decoder = RBLNGemma3RuntimeModel(
             runtime=self.model[0],
             image_prefill=self.model[1],
@@ -710,14 +736,9 @@ class RBLNGemma3ForCausalLM(RBLNDecoderOnlyModelForCausalLM):
             batch_size=self.rbln_config.batch_size,
             dec_attn_mask=dec_attn_mask,
             block_tables=block_tables,
-            free_block_pool=free_block_pool,
-            kvcache_block_size=self.rbln_config.kvcache_block_size,
             vocab_size=self.config.vocab_size,
-
-
-            use_attention_mask=self.rbln_config.use_attention_mask,
-            attn_impl=self.rbln_config.attn_impl,
-            use_position_ids=self.rbln_config.use_position_ids,
+            free_block_pool=free_block_pool,
+            rbln_config=self.rbln_config,
         )

         self.decoders = {}
@@ -731,10 +752,7 @@ class RBLNGemma3ForCausalLM(RBLNDecoderOnlyModelForCausalLM):
                 dec_attn_mask=dec_attn_mask,
                 block_tables=block_tables,
                 free_block_pool=free_block_pool,
-
-                use_attention_mask=self.rbln_config.use_attention_mask,
-                attn_impl=self.rbln_config.attn_impl,
-                use_position_ids=self.rbln_config.use_position_ids,
+                rbln_config=self.rbln_config,
             )

         # NOTE(eunji): Use a decoder whose batch size matches the model's main batch size for compatibility.
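The constructors above now receive the whole `rbln_config` instead of a copy of each field (`use_attention_mask`, `attn_impl`, `use_position_ids`, ...), so the runtimes read every knob from one object. A minimal sketch of that pattern; the names here are illustrative stand-ins, not the optimum-rbln API:

```python
# Illustrative stand-in for the config-threading refactor above.
from dataclasses import dataclass

@dataclass
class RuntimeConfig:  # hypothetical stand-in for RBLNGemma3ForCausalLMConfig
    max_seq_len: int
    prefill_chunk_size: int
    use_attention_mask: bool = True
    use_position_ids: bool = True

class Runtime:
    def __init__(self, rbln_config: RuntimeConfig):
        self.rbln_config = rbln_config  # single source of truth, no per-field copies

    def chunk_starts(self, query_length: int):
        # Mirrors the prefill loop: walk the query in prefill_chunk_size steps.
        return list(range(0, query_length, self.rbln_config.prefill_chunk_size))

rt = Runtime(RuntimeConfig(max_seq_len=8192, prefill_chunk_size=128))
print(rt.chunk_starts(300))  # [0, 128, 256]
```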
@@ -751,81 +769,17 @@ class RBLNGemma3ForCausalLM(RBLNDecoderOnlyModelForCausalLM):
         return embed_tokens

     @classmethod
-    def get_input_info(
-        cls,
-        batch_size: int,
-        query_length: int,
-        use_inputs_embeds: bool,
-        use_attention_mask: bool,
-        use_position_ids: bool,
-        max_seq_len: int,
-        kvcache_block_size: int,
-        kvcache_num_blocks: int,
-        num_key_value_heads: int,
-        num_hidden_layers: int,
-        hidden_size: int,
-        head_dim: int,
-        sliding_window: int,
-        sliding_window_pattern: int,
-        dec_batch_size: int,
-    ):
-        if use_inputs_embeds:
-            main_input = ("inputs_embeds", [batch_size, query_length, hidden_size], "float32")
-        else:
-            main_input = ("input_ids", [batch_size, query_length], "int64")
-
-        input_info = [
-            main_input,
-            (
-                "attention_mask",
-                [batch_size, 1, query_length, max_seq_len] if not use_position_ids else [batch_size, max_seq_len],
-                "float32",
-            ),
-            (
-                "cache_position",
-                [batch_size, query_length],
-                "int32",
-            ),
-            (
-                "position_ids",
-                [batch_size, query_length],
-                "int32",
-            ),
-        ]
-
-        if query_length > 1:
-            input_info.extend(
-                [
-                    ("query_position", [], "int16"),
-                ]
-            )
-
-        max_block_cnt = max_seq_len // kvcache_block_size
-
-        if query_length > 1:
-            input_info.extend([("global_block_tables", [max_block_cnt], "int16")])
-            input_info.extend([("local_block_tables", [1], "int16")])
-        else:
-            input_info.extend([("global_block_tables", [batch_size, max_block_cnt], "int16")])
-            input_info.extend([("local_block_tables", [batch_size, 1], "int16")])
-
-        def is_sliding(layer_idx: int) -> bool:
-            return bool((layer_idx + 1) % sliding_window_pattern)
-
-        local_kvcache_shape = [dec_batch_size, num_key_value_heads, sliding_window, head_dim]
-        global_kvcache_shape = [kvcache_num_blocks, num_key_value_heads, kvcache_block_size, head_dim]
-        input_info.extend(
-            [
-                (
-                    f"past_key_values_{i}",
-                    local_kvcache_shape if is_sliding(i // 2) else global_kvcache_shape,
-                    "float32",
-                )
-                for i in range(num_hidden_layers * 2)
+    def _update_sliding_window_config(cls, model_config: PretrainedConfig, rbln_config: RBLNGemma3ForCausalLMConfig):
+        sliding_window = getattr(model_config, "sliding_window", None)
+        sliding_window_pattern = getattr(model_config, "sliding_window_pattern", None)
+        if sliding_window_pattern <= model_config.num_hidden_layers:
+            rbln_config.cache_impl = "hybrid"
+            rbln_config.sliding_window = sliding_window
+            rbln_config.sliding_window_layers = [
+                i for i in range(model_config.num_hidden_layers) if (i + 1) % sliding_window_pattern > 0
             ]
-        )

-        return input_info
+        return rbln_config

     @classmethod
     def _update_submodule_config(cls, model: "PreTrainedModel", rbln_config: RBLNModelConfig):
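`_update_sliding_window_config` above switches the KV cache to hybrid mode when the model interleaves sliding-window and global attention, marking every layer whose 1-based index is not a multiple of `sliding_window_pattern` as sliding. A quick check of that selection (the layer count and pattern are illustrative):

```python
# Which layers get a sliding-window cache under the hybrid scheme above.
num_hidden_layers = 12
sliding_window_pattern = 6  # hypothetical: every 6th layer uses global attention
sliding_window_layers = [
    i for i in range(num_hidden_layers) if (i + 1) % sliding_window_pattern > 0
]
print(sliding_window_layers)  # [0, 1, 2, 3, 4, 6, 7, 8, 9, 10]; layers 5 and 11 stay global
```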
@@ -846,102 +800,18 @@ class RBLNGemma3ForCausalLM(RBLNDecoderOnlyModelForCausalLM):
         model_config: Optional["PretrainedConfig"] = None,
         rbln_config: Optional[RBLNGemma3ForCausalLMConfig] = None,
     ) -> RBLNGemma3ForCausalLMConfig:
-
-
-        if rbln_config.max_seq_len is None:
-            raise ValueError("`max_seq_len` should be specified.")
-
-        rbln_config.attn_impl, rbln_config.kvcache_partition_len, rbln_config.kvcache_block_size = set_default_values(
-            attn_impl=rbln_config.attn_impl,
-            kvcache_partition_len=rbln_config.kvcache_partition_len,
-            kvcache_block_size=rbln_config.kvcache_block_size,
-            max_seq_len=rbln_config.max_seq_len,
-        )
-
-        validate_attention_method(
-            attn_impl=rbln_config.attn_impl,
-            kvcache_partition_len=rbln_config.kvcache_partition_len,
-            kvcache_block_size=rbln_config.kvcache_block_size,
-            max_seq_len=rbln_config.max_seq_len,
-        )
+        # Update rbln_config with super class
+        rbln_config = super()._update_rbln_config(preprocessors, model, model_config, rbln_config)

-
-
-
-        if rbln_config.attn_impl == "flash_attn":
-            flash_min_blocks = rbln_config.max_seq_len // rbln_config.kvcache_block_size + 1
-            if max_num_blocks < flash_min_blocks:
-                max_num_blocks = flash_min_blocks
-
-        if max_num_blocks < rbln_config.batch_size:
-            raise RuntimeError(
-                f"Batch size ({rbln_config.batch_size}) exceeds available KV cache blocks ({max_num_blocks}). "
-                "Ensure the number of blocks is at least equal to the batch size."
-            )
-
-        if rbln_config.kvcache_num_blocks is None:
-            rbln_config.kvcache_num_blocks = max_num_blocks
-        elif rbln_config.kvcache_num_blocks > max_num_blocks:
-            logger.warning(
-                f"The set `kvcache_num_blocks` ({rbln_config.kvcache_num_blocks}) is greater"
-                f" than the estimated maximum number of blocks ({max_num_blocks})."
-                "This can cause a failure during model compilation."
-            )
-        logger.info(f"[KVCache] Compiling with num_blocks: {rbln_config.kvcache_num_blocks}")
-
-        num_attention_heads = getattr(model_config, "n_head", None) or getattr(model_config, "num_attention_heads")
-        num_key_value_heads = getattr(model_config, "num_key_value_heads", None) or num_attention_heads
-        num_hidden_layers = getattr(model_config, "n_layer", None) or getattr(model_config, "num_hidden_layers")
-        hidden_size = getattr(model_config, "n_embd", None) or getattr(model_config, "hidden_size")
-        head_dim = getattr(model_config, "head_dim", None) or hidden_size // num_attention_heads
-        sliding_window = getattr(model_config, "sliding_window", None)
-        sliding_window_pattern = getattr(model_config, "sliding_window_pattern", None)
-
-        prefill_input_info = cls.get_input_info(
-            batch_size=1,
-            query_length=rbln_config.prefill_chunk_size,
-            use_inputs_embeds=rbln_config.use_inputs_embeds,
-            use_attention_mask=rbln_config.use_attention_mask,
-            use_position_ids=rbln_config.use_position_ids,
-            max_seq_len=rbln_config.max_seq_len,
-            kvcache_block_size=rbln_config.kvcache_block_size,
-            kvcache_num_blocks=rbln_config.kvcache_num_blocks,
-            num_key_value_heads=num_key_value_heads,
-            num_hidden_layers=num_hidden_layers,
-            hidden_size=hidden_size,
-            head_dim=head_dim,
-            sliding_window=sliding_window,
-            sliding_window_pattern=sliding_window_pattern,
-            dec_batch_size=max(rbln_config.decoder_batch_sizes),
-        )
-        prefill_compile_config = RBLNCompileConfig(compiled_model_name="prefill", input_info=prefill_input_info)
+        # Assume that prefill compile config is at index 0
+        compile_cfgs = rbln_config.compile_cfgs
         image_prefill_compile_config = RBLNCompileConfig(
-            compiled_model_name="image_prefill", input_info=prefill_input_info
+            compiled_model_name="image_prefill", input_info=compile_cfgs[0].input_info
         )
-
-        dec_compile_configs = []
-        for batch_size in rbln_config.decoder_batch_sizes:
-            dec_input_info = cls.get_input_info(
-                batch_size=batch_size,
-                query_length=1,
-                use_inputs_embeds=rbln_config.use_inputs_embeds,
-                use_attention_mask=rbln_config.use_attention_mask,
-                use_position_ids=rbln_config.use_position_ids,
-                max_seq_len=rbln_config.max_seq_len,
-                kvcache_block_size=rbln_config.kvcache_block_size,
-                kvcache_num_blocks=rbln_config.kvcache_num_blocks,
-                num_key_value_heads=num_key_value_heads,
-                num_hidden_layers=num_hidden_layers,
-                hidden_size=hidden_size,
-                head_dim=head_dim,
-                sliding_window=sliding_window,
-                sliding_window_pattern=sliding_window_pattern,
-                dec_batch_size=batch_size,
-            )
-            dec_compile_configs.append(
-                RBLNCompileConfig(compiled_model_name=f"decoder_batch_{batch_size}", input_info=dec_input_info)
-            )
-        rbln_config.set_compile_cfgs([prefill_compile_config, image_prefill_compile_config, *dec_compile_configs])
+        # Insert image_prefill compile config at index 1
+        image_idx = 1
+        compile_cfgs.insert(image_idx, image_prefill_compile_config)
+        rbln_config.set_compile_cfgs(compile_cfgs)

         return rbln_config

@@ -16,4 +16,7 @@ from ..decoderonly.configuration_decoderonly import RBLNDecoderOnlyModelForCausa


 class RBLNGPT2LMHeadModelConfig(RBLNDecoderOnlyModelForCausalLMConfig):
-    pass
+    """
+    Configuration class for GPT-2 causal language model.
+    Inherits from RBLNDecoderOnlyModelForCausalLMConfig with no additional parameters.
+    """
@@ -45,7 +45,12 @@ class GPT2Wrapper(DecoderOnlyWrapper):
             )
             new_layer = GPT2Layer(layer, new_self_attn)
             new_layers.append(new_layer)
-        new_model = GPT2Model(causal_lm.transformer, new_layers, max_seq_len=max_seq_len)
+        new_model = GPT2Model(
+            causal_lm.transformer,
+            new_layers,
+            max_seq_len=max_seq_len,
+            sliding_window_layers=self.sliding_window_layers,
+        )
         new_causal_lm = DecoderOnlyForCausalLM(causal_lm, new_model)
         return new_causal_lm

@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from typing import Optional
+from typing import Any, Dict, Optional

 from ....configuration_utils import RBLNModelConfig

@@ -22,6 +22,16 @@ class RBLNIdefics3VisionTransformerConfig(RBLNModelConfig):


 class RBLNIdefics3ForConditionalGenerationConfig(RBLNModelConfig):
+    """
+    Configuration class for RBLNIdefics3ForConditionalGeneration models.
+
+    This class extends `RBLNModelConfig` to include settings specific to the Idefics3 vision-language model optimized for RBLN devices.
+    It allows configuration of the batch size and separate configurations for the vision and text submodules.
+
+    Attributes:
+        submodules (List[str]): List of submodules included in the model. Defaults to `["vision_model", "text_model"]`.
+    """
+
     submodules = ["vision_model", "text_model"]

     def __init__(
@@ -29,7 +39,7 @@ class RBLNIdefics3ForConditionalGenerationConfig(RBLNModelConfig):
         batch_size: Optional[int] = None,
         vision_model: Optional[RBLNModelConfig] = None,
         text_model: Optional[RBLNModelConfig] = None,
-        **kwargs,
+        **kwargs: Dict[str, Any],
     ):
         """
         Args:
@@ -102,10 +102,9 @@ class RBLNIdefics3VisionTransformer(RBLNModel):
         subfolder: str,
         rbln_config: RBLNModelConfig,
     ):
-        """
-
-
-        """
+        # If you are unavoidably running on a CPU rather than an RBLN device,
+        # store the torch tensor, weight, etc. in this function.
+
         save_dict = {}
         save_dict["embeddings"] = model.get_input_embeddings().state_dict()
         torch.save(save_dict, save_dir_path / subfolder / "torch_artifacts.pth")
@@ -190,6 +189,44 @@ class RBLNIdefics3VisionTransformer(RBLNModel):


 class RBLNIdefics3ForConditionalGeneration(RBLNModel):
+    """
+    RBLNIdefics3ForConditionalGeneration is a multi-modal model that integrates vision and language processing capabilities,
+    optimized for RBLN NPUs. It is designed for conditional generation tasks that involve both image and text inputs.
+
+    This model inherits from [`RBLNModel`]. Check the superclass documentation for the generic methods the library implements for all its models.
+
+    Important Note:
+        This model includes a Large Language Model (LLM) as a submodule. For optimal performance, it is highly recommended to use
+        tensor parallelism for the language model. This can be achieved by using the `rbln_config` parameter in the
+        `from_pretrained` method. Refer to the `from_pretrained` documentation and the RBLNIdefics3ForConditionalGenerationConfig class for details.
+
+    Examples:
+        ```python
+        from optimum.rbln import RBLNIdefics3ForConditionalGeneration
+
+        model = RBLNIdefics3ForConditionalGeneration.from_pretrained(
+            "HuggingFaceM4/idefics3-8b",
+            export=True,
+            rbln_config={
+                "vision_model": {
+                    "device": 0,
+                },
+                "text_model": {
+                    "batch_size": 1,
+                    "max_seq_len": 131_072,
+                    "tensor_parallel_size": 8,
+                    "use_inputs_embeds": True,
+                    "attn_impl": "flash_attn",
+                    "kvcache_partition_len": 16_384,
+                    "device": [0, 1, 2, 3, 4, 5, 6, 7],
+                },
+            },
+        )

+        model.save_pretrained("compiled-idefics3-8b")
+        ```
+    """
+
     auto_model_class = AutoModelForVision2Seq
     _rbln_submodules = [{"name": "vision_model"}, {"name": "text_model"}]
     _rbln_submodule_prefix = "model"