optimum_rbln-0.8.1a0-py3-none-any.whl → optimum_rbln-0.8.1a1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- optimum/rbln/__init__.py +2 -0
- optimum/rbln/__version__.py +2 -2
- optimum/rbln/configuration_utils.py +53 -33
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +9 -2
- optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +4 -2
- optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +9 -2
- optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +4 -2
- optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +9 -2
- optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +9 -2
- optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +33 -9
- optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +30 -12
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +22 -6
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +16 -6
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +16 -6
- optimum/rbln/diffusers/modeling_diffusers.py +16 -26
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +11 -0
- optimum/rbln/diffusers/models/autoencoders/vae.py +1 -8
- optimum/rbln/diffusers/models/autoencoders/vq_model.py +11 -0
- optimum/rbln/diffusers/models/controlnet.py +13 -7
- optimum/rbln/diffusers/models/transformers/prior_transformer.py +10 -0
- optimum/rbln/diffusers/models/transformers/transformer_sd3.py +2 -0
- optimum/rbln/diffusers/models/unets/unet_2d_condition.py +7 -0
- optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +1 -4
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +7 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +7 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +7 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +7 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +7 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +48 -27
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +7 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpaint.py +7 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +7 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +7 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +7 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +7 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +7 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +7 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +7 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +7 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +7 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +7 -0
- optimum/rbln/modeling.py +33 -35
- optimum/rbln/modeling_base.py +45 -107
- optimum/rbln/transformers/__init__.py +39 -47
- optimum/rbln/transformers/configuration_generic.py +16 -13
- optimum/rbln/transformers/modeling_generic.py +18 -19
- optimum/rbln/transformers/modeling_rope_utils.py +1 -1
- optimum/rbln/transformers/models/__init__.py +46 -4
- optimum/rbln/transformers/models/audio_spectrogram_transformer/__init__.py +17 -0
- optimum/rbln/transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py +21 -0
- optimum/rbln/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +28 -0
- optimum/rbln/transformers/models/auto/auto_factory.py +30 -12
- optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +35 -4
- optimum/rbln/transformers/models/clip/configuration_clip.py +3 -3
- optimum/rbln/transformers/models/clip/modeling_clip.py +11 -12
- optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +111 -14
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +102 -35
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +229 -175
- optimum/rbln/transformers/models/distilbert/__init__.py +19 -0
- optimum/rbln/transformers/models/distilbert/configuration_distilbert.py +19 -0
- optimum/rbln/transformers/models/distilbert/modeling_distilbert.py +19 -0
- optimum/rbln/transformers/models/exaone/configuration_exaone.py +24 -1
- optimum/rbln/transformers/models/exaone/exaone_architecture.py +5 -1
- optimum/rbln/transformers/models/exaone/modeling_exaone.py +51 -5
- optimum/rbln/transformers/models/gemma/configuration_gemma.py +24 -1
- optimum/rbln/transformers/models/gemma/gemma_architecture.py +5 -1
- optimum/rbln/transformers/models/gemma/modeling_gemma.py +49 -0
- optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +3 -3
- optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +18 -250
- optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +88 -236
- optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +4 -1
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +6 -1
- optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +12 -2
- optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +41 -4
- optimum/rbln/transformers/models/llama/configuration_llama.py +24 -1
- optimum/rbln/transformers/models/llama/modeling_llama.py +49 -0
- optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +2 -2
- optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +33 -4
- optimum/rbln/transformers/models/midm/configuration_midm.py +24 -1
- optimum/rbln/transformers/models/midm/midm_architecture.py +6 -1
- optimum/rbln/transformers/models/midm/modeling_midm.py +51 -5
- optimum/rbln/transformers/models/mistral/configuration_mistral.py +24 -1
- optimum/rbln/transformers/models/mistral/modeling_mistral.py +62 -4
- optimum/rbln/transformers/models/opt/configuration_opt.py +4 -1
- optimum/rbln/transformers/models/opt/modeling_opt.py +10 -0
- optimum/rbln/transformers/models/opt/opt_architecture.py +7 -1
- optimum/rbln/transformers/models/phi/configuration_phi.py +24 -1
- optimum/rbln/transformers/models/phi/modeling_phi.py +49 -0
- optimum/rbln/transformers/models/phi/phi_architecture.py +1 -1
- optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +24 -1
- optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +67 -4
- optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +15 -3
- optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +47 -26
- optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +4 -2
- optimum/rbln/transformers/models/resnet/__init__.py +23 -0
- optimum/rbln/transformers/models/resnet/configuration_resnet.py +20 -0
- optimum/rbln/transformers/models/resnet/modeling_resnet.py +22 -0
- optimum/rbln/transformers/models/roberta/__init__.py +24 -0
- optimum/rbln/transformers/{configuration_alias.py → models/roberta/configuration_roberta.py} +4 -30
- optimum/rbln/transformers/{modeling_alias.py → models/roberta/modeling_roberta.py} +2 -32
- optimum/rbln/transformers/models/seq2seq/__init__.py +1 -1
- optimum/rbln/transformers/models/seq2seq/{configuration_seq2seq2.py → configuration_seq2seq.py} +2 -2
- optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +1 -1
- optimum/rbln/transformers/models/siglip/configuration_siglip.py +3 -0
- optimum/rbln/transformers/models/siglip/modeling_siglip.py +62 -21
- optimum/rbln/transformers/models/t5/modeling_t5.py +46 -4
- optimum/rbln/transformers/models/{time_series_transformers → time_series_transformer}/__init__.py +1 -1
- optimum/rbln/transformers/models/{time_series_transformers → time_series_transformer}/configuration_time_series_transformer.py +2 -2
- optimum/rbln/transformers/models/{time_series_transformers/modeling_time_series_transformers.py → time_series_transformer/modeling_time_series_transformer.py} +14 -9
- optimum/rbln/transformers/models/vit/__init__.py +19 -0
- optimum/rbln/transformers/models/vit/configuration_vit.py +19 -0
- optimum/rbln/transformers/models/vit/modeling_vit.py +19 -0
- optimum/rbln/transformers/models/wav2vec2/__init__.py +1 -1
- optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +1 -1
- optimum/rbln/transformers/models/whisper/configuration_whisper.py +3 -1
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +35 -15
- optimum/rbln/transformers/models/xlm_roberta/__init__.py +16 -2
- optimum/rbln/transformers/models/xlm_roberta/configuration_xlm_roberta.py +15 -2
- optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +12 -3
- optimum/rbln/utils/model_utils.py +20 -0
- optimum/rbln/utils/submodule.py +6 -8
- {optimum_rbln-0.8.1a0.dist-info → optimum_rbln-0.8.1a1.dist-info}/METADATA +1 -1
- {optimum_rbln-0.8.1a0.dist-info → optimum_rbln-0.8.1a1.dist-info}/RECORD +127 -114
- /optimum/rbln/transformers/models/{time_series_transformers → time_series_transformer}/time_series_transformers_architecture.py +0 -0
- /optimum/rbln/transformers/models/wav2vec2/{configuration_wav2vec.py → configuration_wav2vec2.py} +0 -0
- {optimum_rbln-0.8.1a0.dist-info → optimum_rbln-0.8.1a1.dist-info}/WHEEL +0 -0
- {optimum_rbln-0.8.1a0.dist-info → optimum_rbln-0.8.1a1.dist-info}/licenses/LICENSE +0 -0
```diff
--- a/optimum/rbln/transformers/models/gemma3/modeling_gemma3.py
+++ b/optimum/rbln/transformers/models/gemma3/modeling_gemma3.py
@@ -32,10 +32,6 @@ from transformers.models.gemma3.modeling_gemma3 import Gemma3TextScaledWordEmbedding
 from ....configuration_utils import RBLNCompileConfig, RBLNModelConfig
 from ....modeling import RBLNModel
 from ....utils.logging import get_logger
-from ..decoderonly.decoderonly_architecture import (
-    set_default_values,
-    validate_attention_method,
-)
 from ..decoderonly.modeling_decoderonly import RBLNDecoderOnlyModelForCausalLM, RBLNDecoderOnlyOutput, RBLNRuntimeModel
 from .configuration_gemma3 import RBLNGemma3ForCausalLMConfig
 from .gemma3_architecture import Gemma3ForCausalLMWrapper
@@ -215,15 +211,16 @@ class RBLNGemma3ForConditionalGeneration(RBLNModel):
 
         return model_kwargs
 
-    def get_image_features(self, pixel_values: torch.Tensor):
+    def get_image_features(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """
         Projects the last hidden state from the vision model into language model space.
 
         Args:
-            pixel_values (`torch.FloatTensor
-
+            pixel_values: (`torch.FloatTensor` of shape `(batch_size, channels, height, width)`)
+                The tensors corresponding to the input images.
+
         Returns:
-
+            Image feature tensor of shape `(num_images, image_length, embed_dim)`.
         """
         vision_outputs = self.vision_tower(pixel_values).last_hidden_state
         image_features = self.multi_modal_projector(vision_outputs)
@@ -272,7 +269,7 @@ class RBLNGemma3ForConditionalGeneration(RBLNModel):
         padded_cache_lengths: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.Tensor] = None,
         token_type_ids: Optional[torch.Tensor] = None,
-        **lm_kwargs,
+        **lm_kwargs: Dict[str, Any],
     ) -> Union[Tuple, RBLNDecoderOnlyOutput]:
         # prefill
         if cache_position is None:
@@ -352,16 +349,17 @@ class RBLNGemma3RuntimeModel(RBLNRuntimeModel):
         # Find image start positions
         image_starts = [
             s
-            for s in range(seq_len - self.prefill_chunk_size + 1)
-            if torch.all(token_type_ids[:, s : s + self.prefill_chunk_size] == 1)
+            for s in range(seq_len - self.rbln_config.prefill_chunk_size + 1)
+            if torch.all(token_type_ids[:, s : s + self.rbln_config.prefill_chunk_size] == 1)
         ]
 
         # Initialize padded tensors
         padded_input_len = seq_len
         for image_start in image_starts:
             pad_needed = (
-                self.prefill_chunk_size - (image_start + padded_input_len - seq_len) % self.prefill_chunk_size
-            ) % self.prefill_chunk_size
+                self.rbln_config.prefill_chunk_size
+                - (image_start + padded_input_len - seq_len) % self.rbln_config.prefill_chunk_size
+            ) % self.rbln_config.prefill_chunk_size
             padded_input_len += pad_needed
         total_padding = padded_input_len - seq_len
 
@@ -390,7 +388,9 @@ class RBLNGemma3RuntimeModel(RBLNRuntimeModel):
             src_pos = image_start
 
             # Padding
-            pad_needed = (self.prefill_chunk_size - dest_pos % self.prefill_chunk_size) % self.prefill_chunk_size
+            pad_needed = (
+                self.rbln_config.prefill_chunk_size - dest_pos % self.rbln_config.prefill_chunk_size
+            ) % self.rbln_config.prefill_chunk_size
             if pad_needed and dest_pos < padded_input_len:
                 position_ids_padded[:, dest_pos : dest_pos + pad_needed] = torch.arange(
                     last_pos_id + 1, last_pos_id + pad_needed + 1, dtype=position_ids.dtype
@@ -399,21 +399,21 @@ class RBLNGemma3RuntimeModel(RBLNRuntimeModel):
 
             # Image segment
             if src_pos < seq_len and src_pos == image_start:
-                inputs_padded[:, dest_pos : dest_pos + self.prefill_chunk_size] = inputs[
-                    :, src_pos : src_pos + self.prefill_chunk_size
+                inputs_padded[:, dest_pos : dest_pos + self.rbln_config.prefill_chunk_size] = inputs[
+                    :, src_pos : src_pos + self.rbln_config.prefill_chunk_size
                 ]
-                attention_mask_padded[:, dest_pos : dest_pos + self.prefill_chunk_size] = attention_mask[
-                    :, src_pos : src_pos + self.prefill_chunk_size
+                attention_mask_padded[:, dest_pos : dest_pos + self.rbln_config.prefill_chunk_size] = attention_mask[
+                    :, src_pos : src_pos + self.rbln_config.prefill_chunk_size
                 ]
-                position_ids_padded[:, dest_pos : dest_pos + self.prefill_chunk_size] = position_ids[
-                    :, src_pos : src_pos + self.prefill_chunk_size
+                position_ids_padded[:, dest_pos : dest_pos + self.rbln_config.prefill_chunk_size] = position_ids[
+                    :, src_pos : src_pos + self.rbln_config.prefill_chunk_size
                 ]
-                token_type_ids_padded[:, dest_pos : dest_pos + self.prefill_chunk_size] = token_type_ids[
-                    :, src_pos : src_pos + self.prefill_chunk_size
+                token_type_ids_padded[:, dest_pos : dest_pos + self.rbln_config.prefill_chunk_size] = token_type_ids[
+                    :, src_pos : src_pos + self.rbln_config.prefill_chunk_size
                 ]
-                dest_pos += self.prefill_chunk_size
-                src_pos += self.prefill_chunk_size
-                last_pos_id = position_ids[0, image_start + self.prefill_chunk_size - 1].item()
+                dest_pos += self.rbln_config.prefill_chunk_size
+                src_pos += self.rbln_config.prefill_chunk_size
+                last_pos_id = position_ids[0, image_start + self.rbln_config.prefill_chunk_size - 1].item()
 
         return inputs_padded, attention_mask_padded, position_ids_padded, total_padding, token_type_ids_padded
 
@@ -444,11 +444,13 @@ class RBLNGemma3RuntimeModel(RBLNRuntimeModel):
 
         seq_len = inputs.shape[1]
         # Initialize attention mask for chunked processing
-        if self.use_attention_mask:
+        if self.rbln_config.use_attention_mask:
             chunked_attention_mask = (
                 torch.ones(1, seq_len, dtype=torch.float32)
-                if self.use_position_ids
-                else torch.zeros(1, 1, self.prefill_chunk_size, self.max_seq_len, dtype=torch.float32)
+                if self.rbln_config.use_position_ids
+                else torch.zeros(
+                    1, 1, self.rbln_config.prefill_chunk_size, self.rbln_config.max_seq_len, dtype=torch.float32
+                )
             )
         else:
             chunked_attention_mask = None
@@ -467,20 +469,21 @@ class RBLNGemma3RuntimeModel(RBLNRuntimeModel):
         )
 
         query_length = inputs.shape[1]
-        if query_length > self.max_seq_len:
+        if query_length > self.rbln_config.max_seq_len:
             raise ValueError(
-                f"Input length ({query_length}) exceeds the maximum allowed sequence length ({self.max_seq_len})."
+                f"Input length ({query_length}) exceeds the maximum allowed sequence length ({self.rbln_config.max_seq_len})."
             )
 
         # Align attention_mask to compiled shape
-        if self.use_position_ids:
+        if self.rbln_config.use_position_ids:
             chunked_attention_mask = torch.nn.functional.pad(
-                chunked_attention_mask, (0, self.max_seq_len - query_length)
+                chunked_attention_mask, (0, self.rbln_config.max_seq_len - query_length)
             )
 
         # Pad input and cache_position if the last chunk is smaller than `prefill_chunk_size`
-        if query_length % self.prefill_chunk_size != 0:
-            padding_size = (self.prefill_chunk_size - query_length) % self.prefill_chunk_size
+        padding_size = 0
+        if query_length % self.rbln_config.prefill_chunk_size != 0:
+            padding_size = (self.rbln_config.prefill_chunk_size - query_length) % self.rbln_config.prefill_chunk_size
         # inputs_embeds
         if inputs.dim() == 3:
             inputs = torch.nn.functional.pad(inputs, (0, 0, 0, padding_size))
@@ -548,65 +551,71 @@ class RBLNGemma3RuntimeModel(RBLNRuntimeModel):
         ) = self._prepare_prefill_inputs(
             inputs, cache_position, attention_mask, position_embed, token_type_ids=token_type_ids
         )
-        self.dec_attn_mask[batch_idx : batch_idx + 1] = chunked_attention_mask[:1]
         if not is_external_block_tables:
             local_block_tables = torch.tensor([batch_idx], dtype=torch.int16)
+            self.dec_attn_mask[batch_idx : batch_idx + 1] = chunked_attention_mask[:1]
 
-        if self.use_attention_mask and self.use_position_ids:
-            chunked_attention_mask = torch.zeros(1, self.max_seq_len, dtype=torch.float32)
+        if self.rbln_config.use_attention_mask and self.rbln_config.use_position_ids:
+            chunked_attention_mask = torch.zeros(1, self.rbln_config.max_seq_len, dtype=torch.float32)
 
         # Process input in chunks of size `prefill_chunk_size`
-        for step in range(0, query_length, self.prefill_chunk_size):
+        for step in range(0, query_length, self.rbln_config.prefill_chunk_size):
             # Extract the current chunk of inputs and cache positions
-            input_chunk = inputs[:, step : step + self.prefill_chunk_size]
-            cache_pos_chunk = cache_position[:, step : step + self.prefill_chunk_size]
+            input_chunk = inputs[:, step : step + self.rbln_config.prefill_chunk_size]
+            cache_pos_chunk = cache_position[:, step : step + self.rbln_config.prefill_chunk_size]
             position_ids_chunk = (
-                position_ids[:, step : step + self.prefill_chunk_size] if position_ids is not None else None
+                position_ids[:, step : step + self.rbln_config.prefill_chunk_size]
+                if position_ids is not None
+                else None
             )
 
             # Not used in Gemma3 yet.
-            if self.use_attention_mask:
-                if self.use_position_ids:
-                    chunked_attention_mask[0, step : step + self.prefill_chunk_size] = self.dec_attn_mask[
-                        batch_idx, step : step + self.prefill_chunk_size
+            if self.rbln_config.use_attention_mask:
+                if self.rbln_config.use_position_ids:
+                    chunked_attention_mask[0, step : step + self.rbln_config.prefill_chunk_size] = self.dec_attn_mask[
+                        batch_idx, step : step + self.rbln_config.prefill_chunk_size
                     ]
                 else:
                     # Update attention mask to ensure proper causal behavior
-                    if step >= self.prefill_chunk_size:
-                        chunked_attention_mask[:, :, :, step - self.prefill_chunk_size : step] = 1
-                    chunked_attention_mask[:, :, :, step : step + self.prefill_chunk_size] = self.causal_mask
+                    if step >= self.rbln_config.prefill_chunk_size:
+                        chunked_attention_mask[:, :, :, step - self.rbln_config.prefill_chunk_size : step] = 1
+                    chunked_attention_mask[:, :, :, step : step + self.rbln_config.prefill_chunk_size] = (
+                        self.causal_mask
+                    )
 
             # Define query position
             query_position = (
                 torch.sum(
-                    chunked_attention_mask[0][step : step + self.prefill_chunk_size], dim=-1, dtype=torch.int16
+                    chunked_attention_mask[0][step : step + self.rbln_config.prefill_chunk_size],
+                    dim=-1,
+                    dtype=torch.int16,
                 ).squeeze(0)
                 - 1
             )
             if token_type_ids_padded[:, step] == 1:
-                if torch.any(token_type_ids_padded[:, step : step + self.prefill_chunk_size] == 0):
+                if torch.any(token_type_ids_padded[:, step : step + self.rbln_config.prefill_chunk_size] == 0):
                     raise ValueError("All tokens of image_prefill should be the same image.")
                 else:
                     logits = self.image_prefill(
                         input_chunk,
-                        chunked_attention_mask,
                         cache_pos_chunk,
-                        position_ids_chunk,
-                        query_position,
                         block_tables,
                         local_block_tables,
+                        query_position,
+                        chunked_attention_mask,
+                        position_ids_chunk,
                         out=out_buffers,
                     )
             else:
                 # Forward pass for the current chunk
                 logits = self.prefill(
                     input_chunk,
-                    chunked_attention_mask,
                     cache_pos_chunk,
-                    position_ids_chunk,
-                    query_position,
                     block_tables,
                     local_block_tables,
+                    query_position,
+                    chunked_attention_mask,
+                    position_ids_chunk,
                     out=out_buffers,
                 )
 
@@ -646,7 +655,7 @@ class RBLNGemma3RuntimeModel(RBLNRuntimeModel):
             if local_block_tables is not None
             else torch.arange(0, self.batch_size, dtype=torch.int16).view(self.batch_size, -1)
         )
-        if self.use_attention_mask and attention_mask is None:
+        if self.rbln_config.use_attention_mask and attention_mask is None:
             for b_idx in range(batch_size):
                 decoding_step = cache_position[b_idx].item()
                 if not (0 <= decoding_step < self.dec_attn_mask.shape[-1]):
@@ -663,7 +672,7 @@ class RBLNGemma3RuntimeModel(RBLNRuntimeModel):
         if attention_mask is not None and self.batch_size < attention_mask.shape[0]:
             attention_mask = attention_mask[: self.batch_size]
 
-        logits = self.decode(inputs,
+        logits = self.decode(inputs, cache_position, block_tables, local_block_tables, attention_mask, position_ids)
 
         return RBLNDecoderOnlyOutput(logits=logits)
 
@@ -700,7 +709,6 @@ class RBLNGemma3ForCausalLM(RBLNDecoderOnlyModelForCausalLM):
             dtype=torch.int16,
         ).fill_(-1)
         free_block_pool = deque(x for x in range(self.rbln_config.kvcache_num_blocks))
-
         self.prefill_decoder = RBLNGemma3RuntimeModel(
             runtime=self.model[0],
             image_prefill=self.model[1],
@@ -710,14 +718,9 @@ class RBLNGemma3ForCausalLM(RBLNDecoderOnlyModelForCausalLM):
             batch_size=self.rbln_config.batch_size,
             dec_attn_mask=dec_attn_mask,
             block_tables=block_tables,
-            free_block_pool=free_block_pool,
-            kvcache_block_size=self.rbln_config.kvcache_block_size,
             vocab_size=self.config.vocab_size,
-
-
-            use_attention_mask=self.rbln_config.use_attention_mask,
-            attn_impl=self.rbln_config.attn_impl,
-            use_position_ids=self.rbln_config.use_position_ids,
+            free_block_pool=free_block_pool,
+            rbln_config=self.rbln_config,
         )
 
         self.decoders = {}
@@ -731,10 +734,7 @@ class RBLNGemma3ForCausalLM(RBLNDecoderOnlyModelForCausalLM):
                 dec_attn_mask=dec_attn_mask,
                 block_tables=block_tables,
                 free_block_pool=free_block_pool,
-
-                use_attention_mask=self.rbln_config.use_attention_mask,
-                attn_impl=self.rbln_config.attn_impl,
-                use_position_ids=self.rbln_config.use_position_ids,
+                rbln_config=self.rbln_config,
             )
 
         # NOTE(eunji): Use a decoder whose batch size matches the model's main batch size for compatibility.
@@ -751,81 +751,17 @@ class RBLNGemma3ForCausalLM(RBLNDecoderOnlyModelForCausalLM):
         return embed_tokens
 
     @classmethod
-    def get_input_info(
-        cls,
-        batch_size: int,
-        query_length: int,
-        use_inputs_embeds: bool,
-        use_attention_mask: bool,
-        use_position_ids: bool,
-        max_seq_len: int,
-        kvcache_block_size: int,
-        kvcache_num_blocks: int,
-        num_key_value_heads: int,
-        num_hidden_layers: int,
-        hidden_size: int,
-        head_dim: int,
-        sliding_window: int,
-        sliding_window_pattern: int,
-        dec_batch_size: int,
-    ):
-        if use_inputs_embeds:
-            main_input = ("inputs_embeds", [batch_size, query_length, hidden_size], "float32")
-        else:
-            main_input = ("input_ids", [batch_size, query_length], "int64")
-
-        input_info = [
-            main_input,
-            (
-                "attention_mask",
-                [batch_size, 1, query_length, max_seq_len] if not use_position_ids else [batch_size, max_seq_len],
-                "float32",
-            ),
-            (
-                "cache_position",
-                [batch_size, query_length],
-                "int32",
-            ),
-            (
-                "position_ids",
-                [batch_size, query_length],
-                "int32",
-            ),
-        ]
-
-        if query_length > 1:
-            input_info.extend(
-                [
-                    ("query_position", [], "int16"),
-                ]
-            )
-
-        max_block_cnt = max_seq_len // kvcache_block_size
-
-        if query_length > 1:
-            input_info.extend([("global_block_tables", [max_block_cnt], "int16")])
-            input_info.extend([("local_block_tables", [1], "int16")])
-        else:
-            input_info.extend([("global_block_tables", [batch_size, max_block_cnt], "int16")])
-            input_info.extend([("local_block_tables", [batch_size, 1], "int16")])
-
-        def is_sliding(layer_idx: int) -> bool:
-            return bool((layer_idx + 1) % sliding_window_pattern)
-
-        local_kvcache_shape = [dec_batch_size, num_key_value_heads, sliding_window, head_dim]
-        global_kvcache_shape = [kvcache_num_blocks, num_key_value_heads, kvcache_block_size, head_dim]
-        input_info.extend(
-            [
-                (
-                    f"past_key_values_{i}",
-                    local_kvcache_shape if is_sliding(i // 2) else global_kvcache_shape,
-                    "float32",
-                )
-                for i in range(num_hidden_layers * 2)
+    def _update_sliding_window_config(cls, model_config: PretrainedConfig, rbln_config: RBLNGemma3ForCausalLMConfig):
+        sliding_window = getattr(model_config, "sliding_window", None)
+        sliding_window_pattern = getattr(model_config, "sliding_window_pattern", None)
+        if sliding_window_pattern <= model_config.num_hidden_layers:
+            rbln_config.cache_impl = "hybrid"
+            rbln_config.sliding_window = sliding_window
+            rbln_config.sliding_window_layers = [
+                i for i in range(model_config.num_hidden_layers) if (i + 1) % sliding_window_pattern > 0
             ]
-        )
 
-        return input_info
+        return rbln_config
 
     @classmethod
     def _update_submodule_config(cls, model: "PreTrainedModel", rbln_config: RBLNModelConfig):
@@ -846,102 +782,18 @@ class RBLNGemma3ForCausalLM(RBLNDecoderOnlyModelForCausalLM):
         model_config: Optional["PretrainedConfig"] = None,
         rbln_config: Optional[RBLNGemma3ForCausalLMConfig] = None,
     ) -> RBLNGemma3ForCausalLMConfig:
-
-
-        if rbln_config.max_seq_len is None:
-            raise ValueError("`max_seq_len` should be specified.")
-
-        rbln_config.attn_impl, rbln_config.kvcache_partition_len, rbln_config.kvcache_block_size = set_default_values(
-            attn_impl=rbln_config.attn_impl,
-            kvcache_partition_len=rbln_config.kvcache_partition_len,
-            kvcache_block_size=rbln_config.kvcache_block_size,
-            max_seq_len=rbln_config.max_seq_len,
-        )
-
-        validate_attention_method(
-            attn_impl=rbln_config.attn_impl,
-            kvcache_partition_len=rbln_config.kvcache_partition_len,
-            kvcache_block_size=rbln_config.kvcache_block_size,
-            max_seq_len=rbln_config.max_seq_len,
-        )
-
-        required_num_blocks = (rbln_config.max_seq_len // rbln_config.kvcache_block_size) * rbln_config.batch_size
-        max_num_blocks = required_num_blocks
+        # Update rbln_config with super class
+        rbln_config = super()._update_rbln_config(preprocessors, model, model_config, rbln_config)
 
-
-
-        if max_num_blocks < flash_min_blocks:
-            max_num_blocks = flash_min_blocks
-
-        if max_num_blocks < rbln_config.batch_size:
-            raise RuntimeError(
-                f"Batch size ({rbln_config.batch_size}) exceeds available KV cache blocks ({max_num_blocks}). "
-                "Ensure the number of blocks is at least equal to the batch size."
-            )
-
-        if rbln_config.kvcache_num_blocks is None:
-            rbln_config.kvcache_num_blocks = max_num_blocks
-        elif rbln_config.kvcache_num_blocks > max_num_blocks:
-            logger.warning(
-                f"The set `kvcache_num_blocks` ({rbln_config.kvcache_num_blocks}) is greater"
-                f" than the estimated maximum number of blocks ({max_num_blocks})."
-                "This can cause a failure during model compilation."
-            )
-        logger.info(f"[KVCache] Compiling with num_blocks: {rbln_config.kvcache_num_blocks}")
-
-        num_attention_heads = getattr(model_config, "n_head", None) or getattr(model_config, "num_attention_heads")
-        num_key_value_heads = getattr(model_config, "num_key_value_heads", None) or num_attention_heads
-        num_hidden_layers = getattr(model_config, "n_layer", None) or getattr(model_config, "num_hidden_layers")
-        hidden_size = getattr(model_config, "n_embd", None) or getattr(model_config, "hidden_size")
-        head_dim = getattr(model_config, "head_dim", None) or hidden_size // num_attention_heads
-        sliding_window = getattr(model_config, "sliding_window", None)
-        sliding_window_pattern = getattr(model_config, "sliding_window_pattern", None)
-
-        prefill_input_info = cls.get_input_info(
-            batch_size=1,
-            query_length=rbln_config.prefill_chunk_size,
-            use_inputs_embeds=rbln_config.use_inputs_embeds,
-            use_attention_mask=rbln_config.use_attention_mask,
-            use_position_ids=rbln_config.use_position_ids,
-            max_seq_len=rbln_config.max_seq_len,
-            kvcache_block_size=rbln_config.kvcache_block_size,
-            kvcache_num_blocks=rbln_config.kvcache_num_blocks,
-            num_key_value_heads=num_key_value_heads,
-            num_hidden_layers=num_hidden_layers,
-            hidden_size=hidden_size,
-            head_dim=head_dim,
-            sliding_window=sliding_window,
-            sliding_window_pattern=sliding_window_pattern,
-            dec_batch_size=max(rbln_config.decoder_batch_sizes),
-        )
-        prefill_compile_config = RBLNCompileConfig(compiled_model_name="prefill", input_info=prefill_input_info)
+        # Assume that prefill compile config is at index 0
+        compile_cfgs = rbln_config.compile_cfgs
         image_prefill_compile_config = RBLNCompileConfig(
-            compiled_model_name="image_prefill", input_info=prefill_input_info
+            compiled_model_name="image_prefill", input_info=compile_cfgs[0].input_info
         )
-
-        dec_compile_configs = []
-        for batch_size in rbln_config.decoder_batch_sizes:
-            dec_input_info = cls.get_input_info(
-                batch_size=batch_size,
-                query_length=1,
-                use_inputs_embeds=rbln_config.use_inputs_embeds,
-                use_attention_mask=rbln_config.use_attention_mask,
-                use_position_ids=rbln_config.use_position_ids,
-                max_seq_len=rbln_config.max_seq_len,
-                kvcache_block_size=rbln_config.kvcache_block_size,
-                kvcache_num_blocks=rbln_config.kvcache_num_blocks,
-                num_key_value_heads=num_key_value_heads,
-                num_hidden_layers=num_hidden_layers,
-                hidden_size=hidden_size,
-                head_dim=head_dim,
-                sliding_window=sliding_window,
-                sliding_window_pattern=sliding_window_pattern,
-                dec_batch_size=batch_size,
-            )
-            dec_compile_configs.append(
-                RBLNCompileConfig(compiled_model_name=f"decoder_batch_{batch_size}", input_info=dec_input_info)
-            )
-        rbln_config.set_compile_cfgs([prefill_compile_config, image_prefill_compile_config, *dec_compile_configs])
+        # Insert image_prefill compile config at index 1
+        image_idx = 1
+        compile_cfgs.insert(image_idx, image_prefill_compile_config)
+        rbln_config.set_compile_cfgs(compile_cfgs)
 
         return rbln_config
 
```
```diff
--- a/optimum/rbln/transformers/models/gpt2/configuration_gpt2.py
+++ b/optimum/rbln/transformers/models/gpt2/configuration_gpt2.py
@@ -16,4 +16,7 @@ from ..decoderonly.configuration_decoderonly import RBLNDecoderOnlyModelForCausalLMConfig
 
 
 class RBLNGPT2LMHeadModelConfig(RBLNDecoderOnlyModelForCausalLMConfig):
-    pass
+    """
+    Configuration class for GPT-2 causal language model.
+    Inherits from RBLNDecoderOnlyModelForCausalLMConfig with no additional parameters.
+    """
```
```diff
--- a/optimum/rbln/transformers/models/gpt2/gpt2_architecture.py
+++ b/optimum/rbln/transformers/models/gpt2/gpt2_architecture.py
@@ -45,7 +45,12 @@ class GPT2Wrapper(DecoderOnlyWrapper):
             )
             new_layer = GPT2Layer(layer, new_self_attn)
             new_layers.append(new_layer)
-        new_model = GPT2Model(causal_lm.transformer, new_layers, max_seq_len=max_seq_len)
+        new_model = GPT2Model(
+            causal_lm.transformer,
+            new_layers,
+            max_seq_len=max_seq_len,
+            sliding_window_layers=self.sliding_window_layers,
+        )
         new_causal_lm = DecoderOnlyForCausalLM(causal_lm, new_model)
         return new_causal_lm
 
```
```diff
--- a/optimum/rbln/transformers/models/idefics3/configuration_idefics3.py
+++ b/optimum/rbln/transformers/models/idefics3/configuration_idefics3.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Optional
+from typing import Any, Dict, Optional
 
 from ....configuration_utils import RBLNModelConfig
 
@@ -22,6 +22,16 @@ class RBLNIdefics3VisionTransformerConfig(RBLNModelConfig):
 
 
 class RBLNIdefics3ForConditionalGenerationConfig(RBLNModelConfig):
+    """
+    Configuration class for RBLNIdefics3ForConditionalGeneration models.
+
+    This class extends `RBLNModelConfig` to include settings specific to the Idefics3 vision-language model optimized for RBLN devices.
+    It allows configuration of the batch size and separate configurations for the vision and text submodules.
+
+    Attributes:
+        submodules (List[str]): List of submodules included in the model. Defaults to `["vision_model", "text_model"]`.
+    """
+
     submodules = ["vision_model", "text_model"]
 
     def __init__(
@@ -29,7 +39,7 @@ class RBLNIdefics3ForConditionalGenerationConfig(RBLNModelConfig):
         batch_size: Optional[int] = None,
         vision_model: Optional[RBLNModelConfig] = None,
         text_model: Optional[RBLNModelConfig] = None,
-        **kwargs,
+        **kwargs: Dict[str, Any],
     ):
         """
         Args:
```
```diff
--- a/optimum/rbln/transformers/models/idefics3/modeling_idefics3.py
+++ b/optimum/rbln/transformers/models/idefics3/modeling_idefics3.py
@@ -102,10 +102,9 @@ class RBLNIdefics3VisionTransformer(RBLNModel):
         subfolder: str,
         rbln_config: RBLNModelConfig,
     ):
-        """
-        If you are unavoidably running on a CPU rather than an RBLN device,
-        store the torch tensor, weight, etc. in this function.
-        """
+        # If you are unavoidably running on a CPU rather than an RBLN device,
+        # store the torch tensor, weight, etc. in this function.
+
         save_dict = {}
         save_dict["embeddings"] = model.get_input_embeddings().state_dict()
         torch.save(save_dict, save_dir_path / subfolder / "torch_artifacts.pth")
@@ -190,6 +189,44 @@ class RBLNIdefics3VisionTransformer(RBLNModel):
 
 
 class RBLNIdefics3ForConditionalGeneration(RBLNModel):
+    """
+    RBLNIdefics3ForConditionalGeneration is a multi-modal model that integrates vision and language processing capabilities,
+    optimized for RBLN NPUs. It is designed for conditional generation tasks that involve both image and text inputs.
+
+    This model inherits from [`RBLNModel`]. Check the superclass documentation for the generic methods the library implements for all its models.
+
+    Important Note:
+        This model includes a Large Language Model (LLM) as a submodule. For optimal performance, it is highly recommended to use
+        tensor parallelism for the language model. This can be achieved by using the `rbln_config` parameter in the
+        `from_pretrained` method. Refer to the `from_pretrained` documentation and the RBLNIdefics3ForConditionalGenerationConfig class for details.
+
+    Examples:
+        ```python
+        from optimum.rbln import RBLNIdefics3ForConditionalGeneration
+
+        model = RBLNIdefics3ForConditionalGeneration.from_pretrained(
+            "HuggingFaceM4/idefics3-8b",
+            export=True,
+            rbln_config={
+                "vision_model": {
+                    "device": 0,
+                },
+                "text_model": {
+                    "batch_size": 1,
+                    "max_seq_len": 131_072,
+                    "tensor_parallel_size": 8,
+                    "use_inputs_embeds": True,
+                    "attn_impl": "flash_attn",
+                    "kvcache_partition_len": 16_384,
+                    "device": [0, 1, 2, 3, 4, 5, 6, 7],
+                },
+            },
+        )
+
+        model.save_pretrained("compiled-idefics3-8b")
+        ```
+    """
+
     auto_model_class = AutoModelForVision2Seq
     _rbln_submodules = [{"name": "vision_model"}, {"name": "text_model"}]
     _rbln_submodule_prefix = "model"
```
```diff
--- a/optimum/rbln/transformers/models/llama/configuration_llama.py
+++ b/optimum/rbln/transformers/models/llama/configuration_llama.py
@@ -16,4 +16,27 @@ from ..decoderonly.configuration_decoderonly import RBLNDecoderOnlyModelForCausalLMConfig
 
 
 class RBLNLlamaForCausalLMConfig(RBLNDecoderOnlyModelForCausalLMConfig):
-    pass
+    """
+    Configuration class for RBLN Llama models.
+
+    This class is an alias of RBLNDecoderOnlyModelForCausalLMConfig.
+
+    Example usage:
+    ```python
+    from optimum.rbln import RBLNLlamaForCausalLM, RBLNLlamaForCausalLMConfig
+
+    # Create a configuration object
+    config = RBLNLlamaForCausalLMConfig(
+        batch_size=1,
+        max_seq_len=4096,
+        tensor_parallel_size=4
+    )
+
+    # Use the configuration with from_pretrained
+    model = RBLNLlamaForCausalLM.from_pretrained(
+        "meta-llama/Llama-2-7b-hf",
+        export=True,
+        rbln_config=config
+    )
+    ```
+    """
```