optimum-rbln 0.8.0.post2__py3-none-any.whl → 0.8.1a1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- optimum/rbln/__init__.py +2 -0
- optimum/rbln/__version__.py +2 -2
- optimum/rbln/configuration_utils.py +45 -33
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +9 -2
- optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +4 -2
- optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +9 -2
- optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +4 -2
- optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +9 -2
- optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +9 -2
- optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +33 -9
- optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +30 -12
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +22 -6
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +16 -6
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +16 -6
- optimum/rbln/diffusers/modeling_diffusers.py +16 -26
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +11 -0
- optimum/rbln/diffusers/models/autoencoders/vae.py +1 -8
- optimum/rbln/diffusers/models/autoencoders/vq_model.py +11 -0
- optimum/rbln/diffusers/models/controlnet.py +13 -7
- optimum/rbln/diffusers/models/transformers/prior_transformer.py +10 -0
- optimum/rbln/diffusers/models/transformers/transformer_sd3.py +2 -0
- optimum/rbln/diffusers/models/unets/unet_2d_condition.py +7 -0
- optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +1 -4
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +7 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +7 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +7 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +7 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +7 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +48 -27
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +7 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpaint.py +7 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +7 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +7 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +7 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +7 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +7 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +7 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +7 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +7 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +7 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +7 -0
- optimum/rbln/modeling.py +33 -35
- optimum/rbln/modeling_base.py +45 -107
- optimum/rbln/transformers/__init__.py +39 -47
- optimum/rbln/transformers/configuration_generic.py +16 -13
- optimum/rbln/transformers/modeling_generic.py +18 -19
- optimum/rbln/transformers/modeling_rope_utils.py +1 -1
- optimum/rbln/transformers/models/__init__.py +46 -4
- optimum/rbln/transformers/models/audio_spectrogram_transformer/__init__.py +17 -0
- optimum/rbln/transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py +21 -0
- optimum/rbln/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +28 -0
- optimum/rbln/transformers/models/auto/auto_factory.py +30 -12
- optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +35 -4
- optimum/rbln/transformers/models/clip/configuration_clip.py +3 -3
- optimum/rbln/transformers/models/clip/modeling_clip.py +11 -12
- optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +111 -14
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +102 -35
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +231 -175
- optimum/rbln/transformers/models/distilbert/__init__.py +19 -0
- optimum/rbln/transformers/models/distilbert/configuration_distilbert.py +19 -0
- optimum/rbln/transformers/models/distilbert/modeling_distilbert.py +19 -0
- optimum/rbln/transformers/models/exaone/configuration_exaone.py +24 -1
- optimum/rbln/transformers/models/exaone/exaone_architecture.py +5 -1
- optimum/rbln/transformers/models/exaone/modeling_exaone.py +51 -5
- optimum/rbln/transformers/models/gemma/configuration_gemma.py +24 -1
- optimum/rbln/transformers/models/gemma/gemma_architecture.py +5 -1
- optimum/rbln/transformers/models/gemma/modeling_gemma.py +49 -0
- optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +3 -3
- optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +18 -250
- optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +87 -236
- optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +4 -1
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +6 -1
- optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +12 -2
- optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +41 -4
- optimum/rbln/transformers/models/llama/configuration_llama.py +24 -1
- optimum/rbln/transformers/models/llama/modeling_llama.py +49 -0
- optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +2 -2
- optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +33 -4
- optimum/rbln/transformers/models/midm/configuration_midm.py +24 -1
- optimum/rbln/transformers/models/midm/midm_architecture.py +6 -1
- optimum/rbln/transformers/models/midm/modeling_midm.py +51 -5
- optimum/rbln/transformers/models/mistral/configuration_mistral.py +24 -1
- optimum/rbln/transformers/models/mistral/modeling_mistral.py +62 -4
- optimum/rbln/transformers/models/opt/configuration_opt.py +4 -1
- optimum/rbln/transformers/models/opt/modeling_opt.py +10 -0
- optimum/rbln/transformers/models/opt/opt_architecture.py +7 -1
- optimum/rbln/transformers/models/phi/configuration_phi.py +24 -1
- optimum/rbln/transformers/models/phi/modeling_phi.py +49 -0
- optimum/rbln/transformers/models/phi/phi_architecture.py +1 -1
- optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +24 -1
- optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +67 -4
- optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +15 -3
- optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +46 -25
- optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +4 -2
- optimum/rbln/transformers/models/resnet/__init__.py +23 -0
- optimum/rbln/transformers/models/resnet/configuration_resnet.py +20 -0
- optimum/rbln/transformers/models/resnet/modeling_resnet.py +22 -0
- optimum/rbln/transformers/models/roberta/__init__.py +24 -0
- optimum/rbln/transformers/{configuration_alias.py → models/roberta/configuration_roberta.py} +4 -30
- optimum/rbln/transformers/{modeling_alias.py → models/roberta/modeling_roberta.py} +2 -32
- optimum/rbln/transformers/models/seq2seq/__init__.py +1 -1
- optimum/rbln/transformers/models/seq2seq/{configuration_seq2seq2.py → configuration_seq2seq.py} +2 -2
- optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +1 -1
- optimum/rbln/transformers/models/siglip/configuration_siglip.py +3 -0
- optimum/rbln/transformers/models/siglip/modeling_siglip.py +62 -21
- optimum/rbln/transformers/models/t5/modeling_t5.py +46 -4
- optimum/rbln/transformers/models/{time_series_transformers → time_series_transformer}/__init__.py +1 -1
- optimum/rbln/transformers/models/{time_series_transformers → time_series_transformer}/configuration_time_series_transformer.py +2 -2
- optimum/rbln/transformers/models/{time_series_transformers/modeling_time_series_transformers.py → time_series_transformer/modeling_time_series_transformer.py} +14 -9
- optimum/rbln/transformers/models/vit/__init__.py +19 -0
- optimum/rbln/transformers/models/vit/configuration_vit.py +19 -0
- optimum/rbln/transformers/models/vit/modeling_vit.py +19 -0
- optimum/rbln/transformers/models/wav2vec2/__init__.py +1 -1
- optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +1 -1
- optimum/rbln/transformers/models/whisper/configuration_whisper.py +3 -1
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +35 -15
- optimum/rbln/transformers/models/xlm_roberta/__init__.py +16 -2
- optimum/rbln/transformers/models/xlm_roberta/configuration_xlm_roberta.py +15 -2
- optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +12 -3
- optimum/rbln/utils/model_utils.py +20 -0
- optimum/rbln/utils/submodule.py +6 -8
- {optimum_rbln-0.8.0.post2.dist-info → optimum_rbln-0.8.1a1.dist-info}/METADATA +1 -1
- {optimum_rbln-0.8.0.post2.dist-info → optimum_rbln-0.8.1a1.dist-info}/RECORD +127 -114
- /optimum/rbln/transformers/models/{time_series_transformers → time_series_transformer}/time_series_transformers_architecture.py +0 -0
- /optimum/rbln/transformers/models/wav2vec2/{configuration_wav2vec.py → configuration_wav2vec2.py} +0 -0
- {optimum_rbln-0.8.0.post2.dist-info → optimum_rbln-0.8.1a1.dist-info}/WHEEL +0 -0
- {optimum_rbln-0.8.0.post2.dist-info → optimum_rbln-0.8.1a1.dist-info}/licenses/LICENSE +0 -0
```diff
--- a/optimum/rbln/transformers/models/gemma3/modeling_gemma3.py
+++ b/optimum/rbln/transformers/models/gemma3/modeling_gemma3.py
@@ -32,10 +32,6 @@ from transformers.models.gemma3.modeling_gemma3 import Gemma3TextScaledWordEmbed
 from ....configuration_utils import RBLNCompileConfig, RBLNModelConfig
 from ....modeling import RBLNModel
 from ....utils.logging import get_logger
-from ..decoderonly.decoderonly_architecture import (
-    set_default_values,
-    validate_attention_method,
-)
 from ..decoderonly.modeling_decoderonly import RBLNDecoderOnlyModelForCausalLM, RBLNDecoderOnlyOutput, RBLNRuntimeModel
 from .configuration_gemma3 import RBLNGemma3ForCausalLMConfig
 from .gemma3_architecture import Gemma3ForCausalLMWrapper
```
```diff
@@ -215,15 +211,16 @@ class RBLNGemma3ForConditionalGeneration(RBLNModel):
 
         return model_kwargs
 
-    def get_image_features(self, pixel_values: torch.Tensor):
+    def get_image_features(self, pixel_values: torch.Tensor) -> torch.Tensor:
         """
         Projects the last hidden state from the vision model into language model space.
 
         Args:
-            pixel_values (`torch.FloatTensor
-
+            pixel_values: (`torch.FloatTensor` of shape `(batch_size, channels, height, width)`)
+                The tensors corresponding to the input images.
+
         Returns:
-
+            Image feature tensor of shape `(num_images, image_length, embed_dim)`.
         """
         vision_outputs = self.vision_tower(pixel_values).last_hidden_state
         image_features = self.multi_modal_projector(vision_outputs)
```
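The updated docstring pins down the tensor shapes. A runnable stand-in (dummy modules and dimensions, not Gemma3's real ones) tracing the same shape flow:

```python
import torch

# Stand-ins for the shape flow documented above. The vision hidden size
# (1152) and embed_dim (2048) are illustrative placeholders, not Gemma3's
# actual dimensions.
num_images, image_length, vision_dim, embed_dim = 2, 256, 1152, 2048

vision_outputs = torch.randn(num_images, image_length, vision_dim)  # last_hidden_state
multi_modal_projector = torch.nn.Linear(vision_dim, embed_dim)      # stand-in projector
image_features = multi_modal_projector(vision_outputs)
print(image_features.shape)  # torch.Size([2, 256, 2048])
```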
```diff
@@ -272,7 +269,7 @@ class RBLNGemma3ForConditionalGeneration(RBLNModel):
         padded_cache_lengths: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.Tensor] = None,
         token_type_ids: Optional[torch.Tensor] = None,
-        **lm_kwargs,
+        **lm_kwargs: Dict[str, Any],
     ) -> Union[Tuple, RBLNDecoderOnlyOutput]:
         # prefill
         if cache_position is None:
```
```diff
@@ -352,16 +349,17 @@ class RBLNGemma3RuntimeModel(RBLNRuntimeModel):
         # Find image start positions
         image_starts = [
             s
-            for s in range(seq_len - self.prefill_chunk_size + 1)
-            if torch.all(token_type_ids[:, s : s + self.prefill_chunk_size] == 1)
+            for s in range(seq_len - self.rbln_config.prefill_chunk_size + 1)
+            if torch.all(token_type_ids[:, s : s + self.rbln_config.prefill_chunk_size] == 1)
         ]
 
         # Initialize padded tensors
         padded_input_len = seq_len
         for image_start in image_starts:
             pad_needed = (
-                self.
-
+                self.rbln_config.prefill_chunk_size
+                - (image_start + padded_input_len - seq_len) % self.rbln_config.prefill_chunk_size
+            ) % self.rbln_config.prefill_chunk_size
             padded_input_len += pad_needed
         total_padding = padded_input_len - seq_len
 
```
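The replacement computes how much padding aligns each image segment to a `prefill_chunk_size` boundary. A self-contained sketch of that arithmetic, with `chunk` standing in for `rbln_config.prefill_chunk_size`:

```python
# Round-up-to-chunk-boundary arithmetic used above: the outer % chunk makes
# the padding zero when the offset is already aligned.
def pad_to_chunk_boundary(offset: int, chunk: int) -> int:
    """Padding needed so `offset` lands on the next multiple of `chunk`."""
    return (chunk - offset % chunk) % chunk

assert pad_to_chunk_boundary(0, 128) == 0      # already aligned
assert pad_to_chunk_boundary(100, 128) == 28   # pad up to 128
assert pad_to_chunk_boundary(129, 128) == 127  # just past a boundary
```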
```diff
@@ -390,7 +388,9 @@ class RBLNGemma3RuntimeModel(RBLNRuntimeModel):
             src_pos = image_start
 
             # Padding
-            pad_needed = (
+            pad_needed = (
+                self.rbln_config.prefill_chunk_size - dest_pos % self.rbln_config.prefill_chunk_size
+            ) % self.rbln_config.prefill_chunk_size
             if pad_needed and dest_pos < padded_input_len:
                 position_ids_padded[:, dest_pos : dest_pos + pad_needed] = torch.arange(
                     last_pos_id + 1, last_pos_id + pad_needed + 1, dtype=position_ids.dtype
@@ -399,21 +399,21 @@ class RBLNGemma3RuntimeModel(RBLNRuntimeModel):
 
             # Image segment
             if src_pos < seq_len and src_pos == image_start:
-                inputs_padded[:, dest_pos : dest_pos + self.prefill_chunk_size] = inputs[
-                    :, src_pos : src_pos + self.prefill_chunk_size
+                inputs_padded[:, dest_pos : dest_pos + self.rbln_config.prefill_chunk_size] = inputs[
+                    :, src_pos : src_pos + self.rbln_config.prefill_chunk_size
                 ]
-                attention_mask_padded[:, dest_pos : dest_pos + self.prefill_chunk_size] = attention_mask[
-                    :, src_pos : src_pos + self.prefill_chunk_size
+                attention_mask_padded[:, dest_pos : dest_pos + self.rbln_config.prefill_chunk_size] = attention_mask[
+                    :, src_pos : src_pos + self.rbln_config.prefill_chunk_size
                 ]
-                position_ids_padded[:, dest_pos : dest_pos + self.prefill_chunk_size] = position_ids[
-                    :, src_pos : src_pos + self.prefill_chunk_size
+                position_ids_padded[:, dest_pos : dest_pos + self.rbln_config.prefill_chunk_size] = position_ids[
+                    :, src_pos : src_pos + self.rbln_config.prefill_chunk_size
                 ]
-                token_type_ids_padded[:, dest_pos : dest_pos + self.prefill_chunk_size] = token_type_ids[
-                    :, src_pos : src_pos + self.prefill_chunk_size
+                token_type_ids_padded[:, dest_pos : dest_pos + self.rbln_config.prefill_chunk_size] = token_type_ids[
+                    :, src_pos : src_pos + self.rbln_config.prefill_chunk_size
                 ]
-                dest_pos += self.prefill_chunk_size
-                src_pos += self.prefill_chunk_size
-                last_pos_id = position_ids[0, image_start + self.prefill_chunk_size - 1].item()
+                dest_pos += self.rbln_config.prefill_chunk_size
+                src_pos += self.rbln_config.prefill_chunk_size
+                last_pos_id = position_ids[0, image_start + self.rbln_config.prefill_chunk_size - 1].item()
 
         return inputs_padded, attention_mask_padded, position_ids_padded, total_padding, token_type_ids_padded
 
@@ -444,11 +444,13 @@ class RBLNGemma3RuntimeModel(RBLNRuntimeModel):
 
         seq_len = inputs.shape[1]
         # Initialize attention mask for chunked processing
-        if self.use_attention_mask:
+        if self.rbln_config.use_attention_mask:
             chunked_attention_mask = (
                 torch.ones(1, seq_len, dtype=torch.float32)
-                if self.use_position_ids
-                else torch.zeros(
+                if self.rbln_config.use_position_ids
+                else torch.zeros(
+                    1, 1, self.rbln_config.prefill_chunk_size, self.rbln_config.max_seq_len, dtype=torch.float32
+                )
             )
         else:
             chunked_attention_mask = None
@@ -467,21 +469,21 @@ class RBLNGemma3RuntimeModel(RBLNRuntimeModel):
         )
 
         query_length = inputs.shape[1]
-        if query_length > self.max_seq_len:
+        if query_length > self.rbln_config.max_seq_len:
             raise ValueError(
-                f"Input length ({query_length}) exceeds the maximum allowed sequence length ({self.max_seq_len})."
+                f"Input length ({query_length}) exceeds the maximum allowed sequence length ({self.rbln_config.max_seq_len})."
             )
 
         # Align attention_mask to compiled shape
-        if self.use_position_ids:
+        if self.rbln_config.use_position_ids:
             chunked_attention_mask = torch.nn.functional.pad(
-                chunked_attention_mask, (0, self.max_seq_len - query_length)
+                chunked_attention_mask, (0, self.rbln_config.max_seq_len - query_length)
             )
 
         # Pad input and cache_position if the last chunk is smaller than `prefill_chunk_size`
         padding_size = 0
-        if query_length % self.prefill_chunk_size != 0:
-            padding_size = (self.prefill_chunk_size - query_length) % self.prefill_chunk_size
+        if query_length % self.rbln_config.prefill_chunk_size != 0:
+            padding_size = (self.rbln_config.prefill_chunk_size - query_length) % self.rbln_config.prefill_chunk_size
             # inputs_embeds
             if inputs.dim() == 3:
                 inputs = torch.nn.functional.pad(inputs, (0, 0, 0, padding_size))
```
```diff
@@ -549,65 +551,71 @@ class RBLNGemma3RuntimeModel(RBLNRuntimeModel):
         ) = self._prepare_prefill_inputs(
             inputs, cache_position, attention_mask, position_embed, token_type_ids=token_type_ids
         )
-        self.dec_attn_mask[batch_idx : batch_idx + 1] = chunked_attention_mask[:1]
         if not is_external_block_tables:
             local_block_tables = torch.tensor([batch_idx], dtype=torch.int16)
+            self.dec_attn_mask[batch_idx : batch_idx + 1] = chunked_attention_mask[:1]
 
-        if self.use_attention_mask and self.use_position_ids:
-            chunked_attention_mask = torch.zeros(1, self.max_seq_len, dtype=torch.float32)
+        if self.rbln_config.use_attention_mask and self.rbln_config.use_position_ids:
+            chunked_attention_mask = torch.zeros(1, self.rbln_config.max_seq_len, dtype=torch.float32)
 
         # Process input in chunks of size `prefill_chunk_size`
-        for step in range(0, query_length, self.prefill_chunk_size):
+        for step in range(0, query_length, self.rbln_config.prefill_chunk_size):
             # Extract the current chunk of inputs and cache positions
-            input_chunk = inputs[:, step : step + self.prefill_chunk_size]
-            cache_pos_chunk = cache_position[:, step : step + self.prefill_chunk_size]
+            input_chunk = inputs[:, step : step + self.rbln_config.prefill_chunk_size]
+            cache_pos_chunk = cache_position[:, step : step + self.rbln_config.prefill_chunk_size]
             position_ids_chunk = (
-                position_ids[:, step : step + self.prefill_chunk_size]
+                position_ids[:, step : step + self.rbln_config.prefill_chunk_size]
+                if position_ids is not None
+                else None
             )
 
             # Not used in Gemma3 yet.
-            if self.use_attention_mask:
-                if self.use_position_ids:
-                    chunked_attention_mask[0, step : step + self.prefill_chunk_size] = self.dec_attn_mask[
-                        batch_idx, step : step + self.prefill_chunk_size
+            if self.rbln_config.use_attention_mask:
+                if self.rbln_config.use_position_ids:
+                    chunked_attention_mask[0, step : step + self.rbln_config.prefill_chunk_size] = self.dec_attn_mask[
+                        batch_idx, step : step + self.rbln_config.prefill_chunk_size
                     ]
                 else:
                     # Update attention mask to ensure proper causal behavior
-                    if step >= self.prefill_chunk_size:
-                        chunked_attention_mask[:, :, :, step - self.prefill_chunk_size : step] = 1
-                    chunked_attention_mask[:, :, :, step : step + self.prefill_chunk_size] =
+                    if step >= self.rbln_config.prefill_chunk_size:
+                        chunked_attention_mask[:, :, :, step - self.rbln_config.prefill_chunk_size : step] = 1
+                    chunked_attention_mask[:, :, :, step : step + self.rbln_config.prefill_chunk_size] = (
+                        self.causal_mask
+                    )
 
             # Define query position
             query_position = (
                 torch.sum(
-                    chunked_attention_mask[0][step : step + self.prefill_chunk_size],
+                    chunked_attention_mask[0][step : step + self.rbln_config.prefill_chunk_size],
+                    dim=-1,
+                    dtype=torch.int16,
                 ).squeeze(0)
                 - 1
             )
             if token_type_ids_padded[:, step] == 1:
-                if torch.any(token_type_ids_padded[:, step : step + self.prefill_chunk_size] == 0):
+                if torch.any(token_type_ids_padded[:, step : step + self.rbln_config.prefill_chunk_size] == 0):
                     raise ValueError("All tokens of image_prefill should be the same image.")
                 else:
                     logits = self.image_prefill(
                         input_chunk,
-                        chunked_attention_mask,
                         cache_pos_chunk,
-                        position_ids_chunk,
-                        query_position,
                         block_tables,
                         local_block_tables,
+                        query_position,
+                        chunked_attention_mask,
+                        position_ids_chunk,
                         out=out_buffers,
                     )
             else:
                 # Forward pass for the current chunk
                 logits = self.prefill(
                     input_chunk,
-                    chunked_attention_mask,
                     cache_pos_chunk,
-                    position_ids_chunk,
-                    query_position,
                     block_tables,
                     local_block_tables,
+                    query_position,
+                    chunked_attention_mask,
+                    position_ids_chunk,
                     out=out_buffers,
                 )
 
```
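All the chunk bookkeeping above follows one pattern: pad the prompt to a multiple of the chunk size, then feed fixed-shape slices to the compiled graph. A minimal sketch of that pattern (not the library's API):

```python
import torch

# Pad to a multiple of `chunk`, then yield static-shape slices; the compiled
# prefill graph always sees the same input shape this way.
def chunked_prefill(inputs: torch.Tensor, chunk: int):
    query_length = inputs.shape[1]
    padding = (chunk - query_length % chunk) % chunk       # pad only the last chunk
    inputs = torch.nn.functional.pad(inputs, (0, padding))
    for step in range(0, query_length, chunk):
        yield inputs[:, step : step + chunk]               # fixed-size chunk

prompt = torch.zeros(1, 300, dtype=torch.long)             # (batch=1, seq=300)
print([c.shape for c in chunked_prefill(prompt, chunk=128)])
# [torch.Size([1, 128]), torch.Size([1, 128]), torch.Size([1, 128])]
```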
```diff
@@ -647,7 +655,7 @@ class RBLNGemma3RuntimeModel(RBLNRuntimeModel):
             if local_block_tables is not None
             else torch.arange(0, self.batch_size, dtype=torch.int16).view(self.batch_size, -1)
         )
-        if self.use_attention_mask and attention_mask is None:
+        if self.rbln_config.use_attention_mask and attention_mask is None:
             for b_idx in range(batch_size):
                 decoding_step = cache_position[b_idx].item()
                 if not (0 <= decoding_step < self.dec_attn_mask.shape[-1]):
@@ -664,7 +672,7 @@ class RBLNGemma3RuntimeModel(RBLNRuntimeModel):
         if attention_mask is not None and self.batch_size < attention_mask.shape[0]:
             attention_mask = attention_mask[: self.batch_size]
 
-        logits = self.decode(inputs,
+        logits = self.decode(inputs, cache_position, block_tables, local_block_tables, attention_mask, position_ids)
 
         return RBLNDecoderOnlyOutput(logits=logits)
 
@@ -701,7 +709,6 @@ class RBLNGemma3ForCausalLM(RBLNDecoderOnlyModelForCausalLM):
             dtype=torch.int16,
         ).fill_(-1)
         free_block_pool = deque(x for x in range(self.rbln_config.kvcache_num_blocks))
-
         self.prefill_decoder = RBLNGemma3RuntimeModel(
             runtime=self.model[0],
             image_prefill=self.model[1],
@@ -711,14 +718,9 @@ class RBLNGemma3ForCausalLM(RBLNDecoderOnlyModelForCausalLM):
             batch_size=self.rbln_config.batch_size,
             dec_attn_mask=dec_attn_mask,
             block_tables=block_tables,
-            free_block_pool=free_block_pool,
-            kvcache_block_size=self.rbln_config.kvcache_block_size,
             vocab_size=self.config.vocab_size,
-
-
-            use_attention_mask=self.rbln_config.use_attention_mask,
-            attn_impl=self.rbln_config.attn_impl,
-            use_position_ids=self.rbln_config.use_position_ids,
+            free_block_pool=free_block_pool,
+            rbln_config=self.rbln_config,
         )
 
         self.decoders = {}
@@ -732,10 +734,7 @@ class RBLNGemma3ForCausalLM(RBLNDecoderOnlyModelForCausalLM):
                 dec_attn_mask=dec_attn_mask,
                 block_tables=block_tables,
                 free_block_pool=free_block_pool,
-
-                use_attention_mask=self.rbln_config.use_attention_mask,
-                attn_impl=self.rbln_config.attn_impl,
-                use_position_ids=self.rbln_config.use_position_ids,
+                rbln_config=self.rbln_config,
             )
 
         # NOTE(eunji): Use a decoder whose batch size matches the model's main batch size for compatibility.
```
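The two constructor hunks above replace the individually forwarded flags (`use_attention_mask`, `attn_impl`, `use_position_ids`, `kvcache_block_size`) with a single `rbln_config=` argument. An illustrative sketch of the before/after pattern (not the library's actual classes):

```python
# Before: every flag was a separate constructor parameter, so each new
# option rippled through every runtime signature.
class RuntimeBefore:
    def __init__(self, use_attention_mask: bool, attn_impl: str, use_position_ids: bool):
        self.use_attention_mask = use_attention_mask
        self.attn_impl = attn_impl
        self.use_position_ids = use_position_ids

# After: the whole config object travels with the runtime; flags are read
# as self.rbln_config.use_attention_mask, self.rbln_config.attn_impl, etc.
class RuntimeAfter:
    def __init__(self, rbln_config):
        self.rbln_config = rbln_config
```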
```diff
@@ -752,81 +751,17 @@ class RBLNGemma3ForCausalLM(RBLNDecoderOnlyModelForCausalLM):
         return embed_tokens
 
     @classmethod
-    def
-
-
-
-
-
-
-
-        kvcache_block_size: int,
-        kvcache_num_blocks: int,
-        num_key_value_heads: int,
-        num_hidden_layers: int,
-        hidden_size: int,
-        head_dim: int,
-        sliding_window: int,
-        sliding_window_pattern: int,
-        dec_batch_size: int,
-    ):
-        if use_inputs_embeds:
-            main_input = ("inputs_embeds", [batch_size, query_length, hidden_size], "float32")
-        else:
-            main_input = ("input_ids", [batch_size, query_length], "int64")
-
-        input_info = [
-            main_input,
-            (
-                "attention_mask",
-                [batch_size, 1, query_length, max_seq_len] if not use_position_ids else [batch_size, max_seq_len],
-                "float32",
-            ),
-            (
-                "cache_position",
-                [batch_size, query_length],
-                "int32",
-            ),
-            (
-                "position_ids",
-                [batch_size, query_length],
-                "int32",
-            ),
-        ]
-
-        if query_length > 1:
-            input_info.extend(
-                [
-                    ("query_position", [], "int16"),
-                ]
-            )
-
-        max_block_cnt = max_seq_len // kvcache_block_size
-
-        if query_length > 1:
-            input_info.extend([("global_block_tables", [max_block_cnt], "int16")])
-            input_info.extend([("local_block_tables", [1], "int16")])
-        else:
-            input_info.extend([("global_block_tables", [batch_size, max_block_cnt], "int16")])
-            input_info.extend([("local_block_tables", [batch_size, 1], "int16")])
-
-        def is_sliding(layer_idx: int) -> bool:
-            return bool((layer_idx + 1) % sliding_window_pattern)
-
-        local_kvcache_shape = [dec_batch_size, num_key_value_heads, sliding_window, head_dim]
-        global_kvcache_shape = [kvcache_num_blocks, num_key_value_heads, kvcache_block_size, head_dim]
-        input_info.extend(
-            [
-                (
-                    f"past_key_values_{i}",
-                    local_kvcache_shape if is_sliding(i // 2) else global_kvcache_shape,
-                    "float32",
-                )
-                for i in range(num_hidden_layers * 2)
+    def _update_sliding_window_config(cls, model_config: PretrainedConfig, rbln_config: RBLNGemma3ForCausalLMConfig):
+        sliding_window = getattr(model_config, "sliding_window", None)
+        sliding_window_pattern = getattr(model_config, "sliding_window_pattern", None)
+        if sliding_window_pattern <= model_config.num_hidden_layers:
+            rbln_config.cache_impl = "hybrid"
+            rbln_config.sliding_window = sliding_window
+            rbln_config.sliding_window_layers = [
+                i for i in range(model_config.num_hidden_layers) if (i + 1) % sliding_window_pattern > 0
             ]
-        )
 
-        return
+        return rbln_config
 
     @classmethod
     def _update_submodule_config(cls, model: "PreTrainedModel", rbln_config: RBLNModelConfig):
```
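A worked example of the new layer-selection rule in `_update_sliding_window_config`, assuming a Gemma3-like setup with 12 hidden layers and `sliding_window_pattern = 6` (values chosen for illustration):

```python
# Every layer except each 6th one uses the sliding (local) window; the
# remainder stay on the global KV cache.
num_hidden_layers, sliding_window_pattern = 12, 6
sliding_window_layers = [
    i for i in range(num_hidden_layers) if (i + 1) % sliding_window_pattern > 0
]
print(sliding_window_layers)  # [0, 1, 2, 3, 4, 6, 7, 8, 9, 10] (5 and 11 stay global)
```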
```diff
@@ -847,102 +782,18 @@ class RBLNGemma3ForCausalLM(RBLNDecoderOnlyModelForCausalLM):
         model_config: Optional["PretrainedConfig"] = None,
         rbln_config: Optional[RBLNGemma3ForCausalLMConfig] = None,
     ) -> RBLNGemma3ForCausalLMConfig:
-
-
-        if rbln_config.max_seq_len is None:
-            raise ValueError("`max_seq_len` should be specified.")
-
-        rbln_config.attn_impl, rbln_config.kvcache_partition_len, rbln_config.kvcache_block_size = set_default_values(
-            attn_impl=rbln_config.attn_impl,
-            kvcache_partition_len=rbln_config.kvcache_partition_len,
-            kvcache_block_size=rbln_config.kvcache_block_size,
-            max_seq_len=rbln_config.max_seq_len,
-        )
-
-        validate_attention_method(
-            attn_impl=rbln_config.attn_impl,
-            kvcache_partition_len=rbln_config.kvcache_partition_len,
-            kvcache_block_size=rbln_config.kvcache_block_size,
-            max_seq_len=rbln_config.max_seq_len,
-        )
-
-        required_num_blocks = (rbln_config.max_seq_len // rbln_config.kvcache_block_size) * rbln_config.batch_size
-        max_num_blocks = required_num_blocks
+        # Update rbln_config with super class
+        rbln_config = super()._update_rbln_config(preprocessors, model, model_config, rbln_config)
 
-
-
-        if max_num_blocks < flash_min_blocks:
-            max_num_blocks = flash_min_blocks
-
-        if max_num_blocks < rbln_config.batch_size:
-            raise RuntimeError(
-                f"Batch size ({rbln_config.batch_size}) exceeds available KV cache blocks ({max_num_blocks}). "
-                "Ensure the number of blocks is at least equal to the batch size."
-            )
-
-        if rbln_config.kvcache_num_blocks is None:
-            rbln_config.kvcache_num_blocks = max_num_blocks
-        elif rbln_config.kvcache_num_blocks > max_num_blocks:
-            logger.warning(
-                f"The set `kvcache_num_blocks` ({rbln_config.kvcache_num_blocks}) is greater"
-                f" than the estimated maximum number of blocks ({max_num_blocks})."
-                "This can cause a failure during model compilation."
-            )
-        logger.info(f"[KVCache] Compiling with num_blocks: {rbln_config.kvcache_num_blocks}")
-
-        num_attention_heads = getattr(model_config, "n_head", None) or getattr(model_config, "num_attention_heads")
-        num_key_value_heads = getattr(model_config, "num_key_value_heads", None) or num_attention_heads
-        num_hidden_layers = getattr(model_config, "n_layer", None) or getattr(model_config, "num_hidden_layers")
-        hidden_size = getattr(model_config, "n_embd", None) or getattr(model_config, "hidden_size")
-        head_dim = getattr(model_config, "head_dim", None) or hidden_size // num_attention_heads
-        sliding_window = getattr(model_config, "sliding_window", None)
-        sliding_window_pattern = getattr(model_config, "sliding_window_pattern", None)
-
-        prefill_input_info = cls.get_input_info(
-            batch_size=1,
-            query_length=rbln_config.prefill_chunk_size,
-            use_inputs_embeds=rbln_config.use_inputs_embeds,
-            use_attention_mask=rbln_config.use_attention_mask,
-            use_position_ids=rbln_config.use_position_ids,
-            max_seq_len=rbln_config.max_seq_len,
-            kvcache_block_size=rbln_config.kvcache_block_size,
-            kvcache_num_blocks=rbln_config.kvcache_num_blocks,
-            num_key_value_heads=num_key_value_heads,
-            num_hidden_layers=num_hidden_layers,
-            hidden_size=hidden_size,
-            head_dim=head_dim,
-            sliding_window=sliding_window,
-            sliding_window_pattern=sliding_window_pattern,
-            dec_batch_size=max(rbln_config.decoder_batch_sizes),
-        )
-        prefill_compile_config = RBLNCompileConfig(compiled_model_name="prefill", input_info=prefill_input_info)
+        # Assume that prefill compile config is at index 0
+        compile_cfgs = rbln_config.compile_cfgs
         image_prefill_compile_config = RBLNCompileConfig(
-            compiled_model_name="image_prefill", input_info=
+            compiled_model_name="image_prefill", input_info=compile_cfgs[0].input_info
         )
-
-
-
-
-            batch_size=batch_size,
-            query_length=1,
-            use_inputs_embeds=rbln_config.use_inputs_embeds,
-            use_attention_mask=rbln_config.use_attention_mask,
-            use_position_ids=rbln_config.use_position_ids,
-            max_seq_len=rbln_config.max_seq_len,
-            kvcache_block_size=rbln_config.kvcache_block_size,
-            kvcache_num_blocks=rbln_config.kvcache_num_blocks,
-            num_key_value_heads=num_key_value_heads,
-            num_hidden_layers=num_hidden_layers,
-            hidden_size=hidden_size,
-            head_dim=head_dim,
-            sliding_window=sliding_window,
-            sliding_window_pattern=sliding_window_pattern,
-            dec_batch_size=batch_size,
-        )
-        dec_compile_configs.append(
-            RBLNCompileConfig(compiled_model_name=f"decoder_batch_{batch_size}", input_info=dec_input_info)
-        )
-        rbln_config.set_compile_cfgs([prefill_compile_config, image_prefill_compile_config, *dec_compile_configs])
+        # Insert image_prefill compile config at index 1
+        image_idx = 1
+        compile_cfgs.insert(image_idx, image_prefill_compile_config)
+        rbln_config.set_compile_cfgs(compile_cfgs)
 
         return rbln_config
 
```
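The rewritten `_update_rbln_config` no longer builds input infos by hand; it lets the superclass produce the prefill and decoder configs, then splices in the `image_prefill` config. A toy sketch of the resulting ordering (names are stand-ins, not real compile configs):

```python
# The image_prefill graph reuses the prefill graph's input_info and is
# slotted in right behind it, ahead of the decoder configs.
compile_cfgs = ["prefill", "decoder_batch_1", "decoder_batch_4"]
compile_cfgs.insert(1, "image_prefill")
print(compile_cfgs)  # ['prefill', 'image_prefill', 'decoder_batch_1', 'decoder_batch_4']
```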
```diff
--- a/optimum/rbln/transformers/models/gpt2/configuration_gpt2.py
+++ b/optimum/rbln/transformers/models/gpt2/configuration_gpt2.py
@@ -16,4 +16,7 @@ from ..decoderonly.configuration_decoderonly import RBLNDecoderOnlyModelForCausa
 
 
 class RBLNGPT2LMHeadModelConfig(RBLNDecoderOnlyModelForCausalLMConfig):
-
+    """
+    Configuration class for GPT-2 causal language model.
+    Inherits from RBLNDecoderOnlyModelForCausalLMConfig with no additional parameters.
+    """
```
```diff
--- a/optimum/rbln/transformers/models/gpt2/gpt2_architecture.py
+++ b/optimum/rbln/transformers/models/gpt2/gpt2_architecture.py
@@ -45,7 +45,12 @@ class GPT2Wrapper(DecoderOnlyWrapper):
             )
             new_layer = GPT2Layer(layer, new_self_attn)
             new_layers.append(new_layer)
-        new_model = GPT2Model(
+        new_model = GPT2Model(
+            causal_lm.transformer,
+            new_layers,
+            max_seq_len=max_seq_len,
+            sliding_window_layers=self.sliding_window_layers,
+        )
         new_causal_lm = DecoderOnlyForCausalLM(causal_lm, new_model)
         return new_causal_lm
 
```
```diff
--- a/optimum/rbln/transformers/models/idefics3/configuration_idefics3.py
+++ b/optimum/rbln/transformers/models/idefics3/configuration_idefics3.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Optional
+from typing import Any, Dict, Optional
 
 from ....configuration_utils import RBLNModelConfig
 
@@ -22,6 +22,16 @@ class RBLNIdefics3VisionTransformerConfig(RBLNModelConfig):
 
 
 class RBLNIdefics3ForConditionalGenerationConfig(RBLNModelConfig):
+    """
+    Configuration class for RBLNIdefics3ForConditionalGeneration models.
+
+    This class extends `RBLNModelConfig` to include settings specific to the Idefics3 vision-language model optimized for RBLN devices.
+    It allows configuration of the batch size and separate configurations for the vision and text submodules.
+
+    Attributes:
+        submodules (List[str]): List of submodules included in the model. Defaults to `["vision_model", "text_model"]`.
+    """
+
     submodules = ["vision_model", "text_model"]
 
     def __init__(
@@ -29,7 +39,7 @@ class RBLNIdefics3ForConditionalGenerationConfig(RBLNModelConfig):
         batch_size: Optional[int] = None,
         vision_model: Optional[RBLNModelConfig] = None,
         text_model: Optional[RBLNModelConfig] = None,
-        **kwargs,
+        **kwargs: Dict[str, Any],
     ):
         """
         Args:
```
````diff
--- a/optimum/rbln/transformers/models/idefics3/modeling_idefics3.py
+++ b/optimum/rbln/transformers/models/idefics3/modeling_idefics3.py
@@ -102,10 +102,9 @@ class RBLNIdefics3VisionTransformer(RBLNModel):
         subfolder: str,
         rbln_config: RBLNModelConfig,
     ):
-
-
-
-        """
+        # If you are unavoidably running on a CPU rather than an RBLN device,
+        # store the torch tensor, weight, etc. in this function.
+
         save_dict = {}
         save_dict["embeddings"] = model.get_input_embeddings().state_dict()
         torch.save(save_dict, save_dir_path / subfolder / "torch_artifacts.pth")
@@ -190,6 +189,44 @@ class RBLNIdefics3VisionTransformer(RBLNModel):
 
 
 class RBLNIdefics3ForConditionalGeneration(RBLNModel):
+    """
+    RBLNIdefics3ForConditionalGeneration is a multi-modal model that integrates vision and language processing capabilities,
+    optimized for RBLN NPUs. It is designed for conditional generation tasks that involve both image and text inputs.
+
+    This model inherits from [`RBLNModel`]. Check the superclass documentation for the generic methods the library implements for all its models.
+
+    Important Note:
+        This model includes a Large Language Model (LLM) as a submodule. For optimal performance, it is highly recommended to use
+        tensor parallelism for the language model. This can be achieved by using the `rbln_config` parameter in the
+        `from_pretrained` method. Refer to the `from_pretrained` documentation and the RBLNIdefics3ForConditionalGenerationConfig class for details.
+
+    Examples:
+        ```python
+        from optimum.rbln import RBLNIdefics3ForConditionalGeneration
+
+        model = RBLNIdefics3ForConditionalGeneration.from_pretrained(
+            "HuggingFaceM4/idefics3-8b",
+            export=True,
+            rbln_config={
+                "vision_model": {
+                    "device": 0,
+                },
+                "text_model": {
+                    "batch_size": 1,
+                    "max_seq_len": 131_072,
+                    "tensor_parallel_size": 8,
+                    "use_inputs_embeds": True,
+                    "attn_impl": "flash_attn",
+                    "kvcache_partition_len": 16_384,
+                    "device": [0, 1, 2, 3, 4, 5, 6, 7],
+                },
+            },
+        )
+
+        model.save_pretrained("compiled-idefics3-8b")
+        ```
+    """
+
     auto_model_class = AutoModelForVision2Seq
     _rbln_submodules = [{"name": "vision_model"}, {"name": "text_model"}]
     _rbln_submodule_prefix = "model"
````
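As a hedged companion to the docstring example: a directory written by `save_pretrained` can be reloaded for inference with `export=False`, so the precompiled RBLN binaries are used rather than recompiling from PyTorch:

```python
from optimum.rbln import RBLNIdefics3ForConditionalGeneration

# Reload the artifacts saved by save_pretrained() in the example above;
# export=False loads the precompiled binaries instead of recompiling.
model = RBLNIdefics3ForConditionalGeneration.from_pretrained(
    "compiled-idefics3-8b",
    export=False,
)
```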
````diff
--- a/optimum/rbln/transformers/models/llama/configuration_llama.py
+++ b/optimum/rbln/transformers/models/llama/configuration_llama.py
@@ -16,4 +16,27 @@ from ..decoderonly.configuration_decoderonly import RBLNDecoderOnlyModelForCausa
 
 
 class RBLNLlamaForCausalLMConfig(RBLNDecoderOnlyModelForCausalLMConfig):
-
+    """
+    Configuration class for RBLN Llama models.
+
+    This class is an alias of RBLNDecoderOnlyModelForCausalLMConfig.
+
+    Example usage:
+    ```python
+    from optimum.rbln import RBLNLlamaForCausalLM, RBLNLlamaForCausalLMConfig
+
+    # Create a configuration object
+    config = RBLNLlamaForCausalLMConfig(
+        batch_size=1,
+        max_seq_len=4096,
+        tensor_parallel_size=4
+    )
+
+    # Use the configuration with from_pretrained
+    model = RBLNLlamaForCausalLM.from_pretrained(
+        "meta-llama/Llama-2-7b-hf",
+        export=True,
+        rbln_config=config
+    )
+    ```
+    """
````