optimum-rbln 0.9.3rc0__py3-none-any.whl → 0.9.4a2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- optimum/rbln/__init__.py +12 -0
- optimum/rbln/__version__.py +2 -2
- optimum/rbln/configuration_utils.py +16 -6
- optimum/rbln/diffusers/__init__.py +12 -0
- optimum/rbln/diffusers/configurations/__init__.py +3 -0
- optimum/rbln/diffusers/configurations/models/__init__.py +2 -0
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_temporal_decoder.py +67 -0
- optimum/rbln/diffusers/configurations/models/configuration_unet_spatio_temporal_condition.py +59 -0
- optimum/rbln/diffusers/configurations/pipelines/__init__.py +3 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_video_diffusion.py +114 -0
- optimum/rbln/diffusers/modeling_diffusers.py +1 -1
- optimum/rbln/diffusers/models/__init__.py +17 -3
- optimum/rbln/diffusers/models/autoencoders/__init__.py +1 -0
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py +3 -3
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +275 -0
- optimum/rbln/diffusers/models/autoencoders/vae.py +27 -8
- optimum/rbln/diffusers/models/controlnet.py +17 -2
- optimum/rbln/diffusers/models/transformers/prior_transformer.py +16 -2
- optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +16 -1
- optimum/rbln/diffusers/models/transformers/transformer_sd3.py +14 -1
- optimum/rbln/diffusers/models/unets/__init__.py +1 -0
- optimum/rbln/diffusers/models/unets/unet_2d_condition.py +18 -2
- optimum/rbln/diffusers/models/unets/unet_spatio_temporal_condition.py +201 -0
- optimum/rbln/diffusers/pipelines/__init__.py +4 -0
- optimum/rbln/diffusers/pipelines/auto_pipeline.py +2 -2
- optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +20 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +13 -4
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +13 -4
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +13 -4
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +13 -4
- optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +1 -1
- optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +1 -1
- optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +1 -2
- optimum/rbln/diffusers/pipelines/stable_video_diffusion/__init__.py +15 -0
- optimum/rbln/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +46 -0
- optimum/rbln/modeling.py +20 -45
- optimum/rbln/modeling_base.py +12 -8
- optimum/rbln/transformers/configuration_generic.py +0 -27
- optimum/rbln/transformers/modeling_attention_utils.py +242 -109
- optimum/rbln/transformers/modeling_generic.py +2 -61
- optimum/rbln/transformers/modeling_outputs.py +1 -0
- optimum/rbln/transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py +28 -2
- optimum/rbln/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +68 -5
- optimum/rbln/transformers/models/auto/auto_factory.py +1 -0
- optimum/rbln/transformers/models/bart/modeling_bart.py +23 -2
- optimum/rbln/transformers/models/bert/modeling_bert.py +86 -1
- optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +42 -15
- optimum/rbln/transformers/models/clip/modeling_clip.py +40 -2
- optimum/rbln/transformers/models/colpali/colpali_architecture.py +2 -2
- optimum/rbln/transformers/models/colpali/modeling_colpali.py +6 -45
- optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +0 -2
- optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +10 -1
- optimum/rbln/transformers/models/decoderonly/configuration_lora.py +1 -1
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +92 -43
- optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +207 -64
- optimum/rbln/transformers/models/decoderonly/generation_decoderonly.py +17 -9
- optimum/rbln/transformers/models/decoderonly/lora_architecture.py +1 -1
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +140 -46
- optimum/rbln/transformers/models/depth_anything/modeling_depth_anything.py +17 -0
- optimum/rbln/transformers/models/distilbert/modeling_distilbert.py +24 -0
- optimum/rbln/transformers/models/dpt/modeling_dpt.py +17 -0
- optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +7 -1
- optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py +42 -70
- optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +46 -31
- optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +1 -1
- optimum/rbln/transformers/models/grounding_dino/modeling_grounding_dino.py +24 -9
- optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +3 -5
- optimum/rbln/transformers/models/llava/modeling_llava.py +37 -25
- optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +3 -5
- optimum/rbln/transformers/models/mistral/modeling_mistral.py +0 -22
- optimum/rbln/transformers/models/opt/modeling_opt.py +2 -2
- optimum/rbln/transformers/models/pegasus/modeling_pegasus.py +1 -1
- optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +13 -1
- optimum/rbln/transformers/models/pixtral/pixtral_architecture.py +2 -2
- optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +0 -28
- optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +8 -9
- optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py +6 -7
- optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +1 -1
- optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +0 -20
- optimum/rbln/transformers/models/resnet/configuration_resnet.py +17 -0
- optimum/rbln/transformers/models/resnet/modeling_resnet.py +73 -0
- optimum/rbln/transformers/models/roberta/modeling_roberta.py +33 -0
- optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py +2 -4
- optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +36 -12
- optimum/rbln/transformers/models/siglip/modeling_siglip.py +17 -1
- optimum/rbln/transformers/models/swin/modeling_swin.py +17 -4
- optimum/rbln/transformers/models/t5/modeling_t5.py +2 -2
- optimum/rbln/transformers/models/t5/t5_architecture.py +1 -1
- optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +25 -10
- optimum/rbln/transformers/models/vit/modeling_vit.py +19 -0
- optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec2.py +15 -3
- optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +60 -8
- optimum/rbln/transformers/models/whisper/generation_whisper.py +48 -14
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +2 -2
- optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +53 -0
- optimum/rbln/transformers/utils/rbln_quantization.py +9 -0
- optimum/rbln/utils/deprecation.py +213 -0
- optimum/rbln/utils/hub.py +14 -3
- optimum/rbln/utils/import_utils.py +7 -1
- optimum/rbln/utils/runtime_utils.py +32 -0
- optimum/rbln/utils/submodule.py +3 -1
- {optimum_rbln-0.9.3rc0.dist-info → optimum_rbln-0.9.4a2.dist-info}/METADATA +2 -2
- {optimum_rbln-0.9.3rc0.dist-info → optimum_rbln-0.9.4a2.dist-info}/RECORD +106 -99
- {optimum_rbln-0.9.3rc0.dist-info → optimum_rbln-0.9.4a2.dist-info}/WHEEL +1 -1
- optimum/rbln/utils/depreacate_utils.py +0 -16
- {optimum_rbln-0.9.3rc0.dist-info → optimum_rbln-0.9.4a2.dist-info}/entry_points.txt +0 -0
- {optimum_rbln-0.9.3rc0.dist-info → optimum_rbln-0.9.4a2.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py

@@ -88,8 +88,12 @@ class RBLNDecoderOnlyModel(RBLNModel, RBLNDecoderOnlyFlashAttentionMixin):
     def setup_runtime(self):
         # Initialize resources to be used across Runtime instances (prefill and decode phases)
         page_table_manager = RBLNPageTableManager(self.rbln_config)
-
-
+        if self.rbln_config.use_position_ids:
+            dec_attn_mask = torch.zeros(self.rbln_config.batch_size, self.rbln_config.max_seq_len, dtype=self.dtype)
+        else:
+            dec_attn_mask = torch.zeros(
+                self.rbln_config.batch_size, 1, 1, self.rbln_config.max_seq_len, dtype=self.dtype
+            )
 
         common_kwargs = {
             "main_input_name": "inputs_embeds" if self.rbln_config.use_inputs_embeds else "input_ids",
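The two mask layouts differ only in rank: with `use_position_ids` the runtime keeps a flat per-token mask, otherwise a broadcastable attention-style mask. A minimal shape sketch (illustrative, not library code):

```python
import torch

batch_size, max_seq_len = 2, 8

# use_position_ids=True: one row of token slots per sequence
mask_2d = torch.zeros(batch_size, max_seq_len)        # (batch, max_seq_len)

# use_position_ids=False: broadcastable over heads and query positions
mask_4d = torch.zeros(batch_size, 1, 1, max_seq_len)  # (batch, 1, 1, max_seq_len)

print(mask_2d.shape, mask_4d.shape)  # torch.Size([2, 8]) torch.Size([2, 1, 1, 8])
```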
@@ -97,12 +101,13 @@ class RBLNDecoderOnlyModel(RBLNModel, RBLNDecoderOnlyFlashAttentionMixin):
             "dec_attn_mask": dec_attn_mask,
             "page_table_manager": page_table_manager,
             "rbln_config": self.rbln_config,
+            "config": self.config,
         }
         self.prefill_decoder = RBLNRuntimeModel(
             runtime=self.model[0],
             phase="prefill",
             batch_size=self.rbln_config.batch_size,
-
+            logits_last_dim=self.logits_last_dim,
             **common_kwargs,
         )
         if self.can_generate():
@@ -119,12 +124,8 @@ class RBLNDecoderOnlyModel(RBLNModel, RBLNDecoderOnlyFlashAttentionMixin):
         self.decoder = self.decoders[self.rbln_config.batch_size]
 
     @property
-    def
-        return
-            1,
-            self.rbln_config.prefill_chunk_size if self.rbln_config.logits_to_keep == 0 else 1,
-            self.config.hidden_size,
-        )
+    def logits_last_dim(self):
+        return self.config.hidden_size
 
     @classmethod
     def get_quantized_model(
@@ -216,7 +217,7 @@ class RBLNDecoderOnlyModel(RBLNModel, RBLNDecoderOnlyFlashAttentionMixin):
         return self.rbln_config.kvcache_num_blocks
 
     @classmethod
-    def
+    def _wrap_model_if_needed(cls, model: PreTrainedModel, rbln_config: "RBLNDecoderOnlyModelConfig"):
         return cls._decoder_wrapper_cls(model, rbln_config, cls._use_rotary_emb).eval()
 
     @classmethod
@@ -272,7 +273,7 @@ class RBLNDecoderOnlyModel(RBLNModel, RBLNDecoderOnlyFlashAttentionMixin):
     @classmethod
     @torch.inference_mode()
     def get_compiled_model(cls, model: PreTrainedModel, rbln_config: RBLNDecoderOnlyModelForCausalLMConfig):
-        wrapped_model = cls.
+        wrapped_model = cls._wrap_model_if_needed(model, rbln_config)
         prefill_compile_config = rbln_config.compile_cfgs[0]
 
         # Here we use meta tensor, for the memory efficiency.
@@ -340,10 +341,10 @@ class RBLNDecoderOnlyModel(RBLNModel, RBLNDecoderOnlyFlashAttentionMixin):
         rbln_config: RBLNDecoderOnlyModelForCausalLMConfig,
         model_config: PretrainedConfig,
     ):
-        num_attention_heads = getattr(model_config, "n_head", None) or
+        num_attention_heads = getattr(model_config, "n_head", None) or model_config.num_attention_heads
         num_key_value_heads = getattr(model_config, "num_key_value_heads", None) or num_attention_heads
-        num_hidden_layers = getattr(model_config, "n_layer", None) or
-        hidden_size = getattr(model_config, "n_embd", None) or
+        num_hidden_layers = getattr(model_config, "n_layer", None) or model_config.num_hidden_layers
+        hidden_size = getattr(model_config, "n_embd", None) or model_config.hidden_size
         head_dim = getattr(model_config, "head_dim", None) or hidden_size // num_attention_heads
         is_prefill = query_length > 1
 
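The `getattr(...) or ...` chain resolves both config naming schemes: GPT-2-style configs expose `n_head` / `n_layer` / `n_embd`, while most newer configs expose `num_attention_heads` / `num_hidden_layers` / `hidden_size`. A minimal sketch of the same resolution against a stock GPT-2 config:

```python
from transformers import AutoConfig

config = AutoConfig.for_model("gpt2")  # defines n_head / n_layer / n_embd

num_attention_heads = getattr(config, "n_head", None) or config.num_attention_heads
num_hidden_layers = getattr(config, "n_layer", None) or config.num_hidden_layers
hidden_size = getattr(config, "n_embd", None) or config.hidden_size
head_dim = getattr(config, "head_dim", None) or hidden_size // num_attention_heads

print(num_attention_heads, num_hidden_layers, hidden_size, head_dim)  # 12 12 768 64
```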
@@ -439,10 +440,22 @@ class RBLNDecoderOnlyModel(RBLNModel, RBLNDecoderOnlyFlashAttentionMixin):
         # Returns:
         #     RBLNDecoderOnlyModelConfig: The updated RBLN model configuration.
 
-
-
-
+        rbln_config.sliding_window = model_config.sliding_window
+        sliding_window_layers = []
+
+        for i in range(model_config.num_hidden_layers):
+            if hasattr(model_config, "layer_types"):
+                if model_config.layer_types[i] == "sliding_attention":
+                    sliding_window_layers.append(i)
+            else:
+                sliding_window_layers.append(i)
+
+        rbln_config.sliding_window_layers = sliding_window_layers
+
+        rbln_config.cache_impl = (
+            "sliding_window" if len(sliding_window_layers) == model_config.num_hidden_layers else "hybrid"
         )
+        return rbln_config
 
     @classmethod
     def _update_attention_config(
@@ -466,13 +479,8 @@ class RBLNDecoderOnlyModel(RBLNModel, RBLNDecoderOnlyFlashAttentionMixin):
 
         # Update kvcache_num_blocks based on the attention implementation.
         if rbln_config.attn_impl == "flash_attn":
-            estimated_max_num_blocks = cls.
-
-                tensor_parallel_size=rbln_config.tensor_parallel_size or 1,
-                kvcache_block_size=rbln_config.kvcache_block_size,
-                nbits_per_param=16 if not rbln_config.quantization else 4,  # TODO(jongho): FIX Ad-hoc
-                n_model_params=sum(p.numel() for p in model.parameters()),
-                num_runtimes=1 if not rbln_config.can_generate else 1 + len(rbln_config.decoder_batch_sizes),
+            estimated_max_num_blocks = cls.get_maximum_num_blocks_by_model(
+                model=model, model_config=model_config, rbln_config=rbln_config
             )
 
             if rbln_config.kvcache_num_blocks is None:
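The deleted call site shows which inputs the old estimate consumed (tensor_parallel_size, kvcache_block_size, nbits_per_param, n_model_params, num_runtimes). Illustrative arithmetic only: the real `get_maximum_num_blocks_by_model()` lives in the SDK and will differ, and the device-memory constant and per-block byte formula below are assumptions for the sketch.

```python
DEVICE_DRAM_BYTES = 16 * 2**30  # assumed per-device DRAM

def estimate_max_num_blocks(
    n_model_params: int,
    nbits_per_param: int,
    num_hidden_layers: int,
    num_key_value_heads: int,
    head_dim: int,
    kvcache_block_size: int,
    tensor_parallel_size: int = 1,
) -> int:
    weight_bytes = n_model_params * nbits_per_param // 8 // tensor_parallel_size
    # One block stores K and V (factor 2) for every layer in fp16 (2 bytes).
    block_bytes = 2 * num_hidden_layers * kvcache_block_size * num_key_value_heads * head_dim * 2
    return max(DEVICE_DRAM_BYTES - weight_bytes, 0) // block_bytes

print(estimate_max_num_blocks(
    n_model_params=1_100_000_000, nbits_per_param=16,
    num_hidden_layers=22, num_key_value_heads=4, head_dim=64, kvcache_block_size=128,
))  # 5194 blocks under these assumptions
```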
@@ -511,7 +519,6 @@ class RBLNDecoderOnlyModel(RBLNModel, RBLNDecoderOnlyFlashAttentionMixin):
                 f" than the required number of blocks ({num_full_blocks})."
                 "This can cause a failure during model compilation."
             )
-
         logger.info(f"[KVCache] Compiling with num_blocks: {rbln_config.kvcache_num_blocks}")
 
         return rbln_config
@@ -531,8 +538,13 @@ class RBLNDecoderOnlyModel(RBLNModel, RBLNDecoderOnlyFlashAttentionMixin):
         if rbln_config.max_seq_len is None:
             raise ValueError("`max_seq_len` should be specified.")
 
-
-
+        layer_types = getattr(model_config, "layer_types", None)
+        all_full_attention = layer_types is not None and all(t == "full_attention" for t in layer_types)
+
+        if (
+            getattr(model_config, "sliding_window", None) is not None
+            and getattr(model_config, "use_sliding_window", True)
+            and not all_full_attention
         ):
             rbln_config = cls._update_sliding_window_config(model_config, rbln_config)
             if rbln_config.sliding_window is not None:
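Taken together, the two hunks above make two decisions: whether sliding-window handling applies at all, and, if so, whether the cache is purely sliding-window or hybrid. A minimal sketch (not library code) against a toy config object:

```python
from types import SimpleNamespace

cfg = SimpleNamespace(
    sliding_window=512,
    use_sliding_window=True,
    num_hidden_layers=4,
    layer_types=["sliding_attention", "full_attention", "sliding_attention", "full_attention"],
)

layer_types = getattr(cfg, "layer_types", None)
all_full_attention = layer_types is not None and all(t == "full_attention" for t in layer_types)

use_sliding = (
    getattr(cfg, "sliding_window", None) is not None
    and getattr(cfg, "use_sliding_window", True)
    and not all_full_attention
)

if use_sliding:
    sliding_layers = [i for i, t in enumerate(layer_types) if t == "sliding_attention"]
    cache_impl = "sliding_window" if len(sliding_layers) == cfg.num_hidden_layers else "hybrid"
    print(use_sliding, sliding_layers, cache_impl)  # True [0, 2] hybrid
```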
@@ -608,34 +620,74 @@ class RBLNDecoderOnlyModel(RBLNModel, RBLNDecoderOnlyFlashAttentionMixin):
         input_ids: Optional[torch.LongTensor] = None,
         inputs_embeds: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.LongTensor] = None,
+        position_ids: Optional[torch.Tensor] = None,
         position_embed: Optional[torch.Tensor] = None,
+        output_hidden_states: Optional[bool] = None,
         **kwargs,
-    ) ->
+    ) -> BaseModelOutputWithPast:
+        """
+        Args:
+            input_ids (torch.LongTensor, optional): The input IDs to the model.
+            inputs_embeds (torch.Tensor, optional): The input embeddings to the model.
+            attention_mask (torch.LongTensor, optional): The attention mask to the model.
+            kwargs (dict[str, Any], optional): Additional keyword arguments.
+
+        Returns:
+            Dataclass containing the last hidden states of the model.
+        """
         inputs = inputs_embeds if inputs_embeds is not None else input_ids
         batch_size = inputs.shape[0]
+        position_embed = kwargs.get("position_embed", None)
 
         if batch_size != self.rbln_config.batch_size:
             raise ValueError(
                 f"Batch size ({batch_size}) must be equal to the batch size of the model ({self.rbln_config.batch_size})."
             )
 
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.rbln_config.output_hidden_states
+        )
+        if output_hidden_states != self.rbln_config.output_hidden_states:
+            raise ValueError(
+                f"Variable output_hidden_states {output_hidden_states} is not equal to rbln_config.output_hidden_states {self.rbln_config.output_hidden_states} "
+                f"Please compile again with the correct argument."
+            )
+
         all_last_hidden_states = []
+        all_hidden_states = (
+            tuple(
+                torch.zeros(
+                    self.rbln_config.batch_size,
+                    inputs.shape[1],
+                    self.config.hidden_size,
+                    dtype=self.rbln_config.torch_dtype,
+                )
+                for _ in range(self.config.num_hidden_layers + 1)
+            )
+            if output_hidden_states
+            else None
+        )
         for b_idx in range(self.rbln_config.batch_size):
             query_length = (
                 attention_mask[b_idx].sum(dim=-1).int().item() if attention_mask is not None else inputs.shape[1]
             )
             cache_position = torch.arange(query_length, dtype=torch.int32).unsqueeze(0)
-
-                inputs[b_idx : b_idx + 1],
+            outputs = self.prefill_decoder(
+                input_ids=inputs[b_idx : b_idx + 1] if inputs_embeds is None else None,
+                inputs_embeds=inputs[b_idx : b_idx + 1] if inputs_embeds is not None else None,
                 attention_mask=attention_mask[b_idx] if attention_mask is not None else None,
+                position_ids=position_ids[b_idx : b_idx + 1] if position_ids is not None else None,
                 position_embed=position_embed[b_idx : b_idx + 1] if position_embed is not None else None,
                 cache_position=cache_position,
                 batch_idx=b_idx,
-            )
-            all_last_hidden_states.append(
+            )
+            all_last_hidden_states.append(outputs.logits)
+            if self.rbln_config.output_hidden_states:
+                for l_idx in range(self.config.num_hidden_layers + 1):
+                    all_hidden_states[l_idx][b_idx].copy_(outputs.hidden_states[l_idx][0])
 
         last_hidden_states = torch.concat(all_last_hidden_states, dim=0)
-        return BaseModelOutputWithPast(last_hidden_state=last_hidden_states)
+        return BaseModelOutputWithPast(last_hidden_state=last_hidden_states, hidden_states=all_hidden_states)
 
 
 class RBLNDecoderOnlyModelForCausalLM(RBLNDecoderOnlyModel, RBLNDecoderOnlyGenerationMixin):
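Note the compile-time contract this hunk introduces: `output_hidden_states` is baked into the compiled artifact, so the runtime argument may only confirm it. A hedged usage sketch; the class and argument names follow optimum-rbln conventions (`RBLNLlamaForCausalLM`, `export=True`, `rbln_config` dict) but verify against your installed version:

```python
import torch
from optimum.rbln import RBLNLlamaForCausalLM

model = RBLNLlamaForCausalLM.from_pretrained(
    "meta-llama/Llama-3.2-1B",
    export=True,
    rbln_config={"output_hidden_states": True, "max_seq_len": 4096, "batch_size": 1},
)

input_ids = torch.tensor([[1, 15043, 3186]])
out = model(input_ids=input_ids, output_hidden_states=True)   # OK: matches the compiled flag
out = model(input_ids=input_ids, output_hidden_states=False)  # ValueError: recompile instead
```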
@@ -661,12 +713,8 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNDecoderOnlyModel, RBLNDecoderOnlyGenerationMixin):
     auto_model_class = AutoModelForCausalLM
 
     @property
-    def
-        return
-            1,
-            self.rbln_config.prefill_chunk_size if self.rbln_config.logits_to_keep == 0 else 1,
-            self.config.vocab_size,
-        )
+    def logits_last_dim(self):
+        return self.config.vocab_size
 
     @classmethod
     def use_query_position(cls, use_local_attention: bool, is_prefill: bool = True):
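This pair of hunks is one refactor: the runtime no longer receives a full output shape, only its last dimension, and each class reports its own. A sketch of the override pattern with stand-in classes (the numbers are illustrative):

```python
class BaseModel:
    hidden_size, vocab_size = 768, 32000

    @property
    def logits_last_dim(self) -> int:
        return self.hidden_size  # plain model emits hidden states

class CausalLM(BaseModel):
    @property
    def logits_last_dim(self) -> int:
        return self.vocab_size  # LM head emits vocabulary logits

print(BaseModel().logits_last_dim, CausalLM().logits_last_dim)  # 768 32000
```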
@@ -731,6 +779,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNDecoderOnlyModel, RBLNDecoderOnlyGenerationMixin):
         token_type_ids: Optional[torch.Tensor] = None,
         lora_int_ids: Optional[torch.Tensor] = None,
         return_dict: Optional[torch.Tensor] = None,
+        output_hidden_states: Optional[bool] = None,
         **kwargs,
     ) -> Tuple[torch.FloatTensor]:
         # Forward method for the RBLN-optimized model, designed for integration with the HuggingFace generate API.
@@ -754,24 +803,55 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNDecoderOnlyModel, RBLNDecoderOnlyGenerationMixin):
         )
         padded_cache_lengths = torch.zeros_like(generate_idx)
 
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.rbln_config.output_hidden_states
+        )
+        if output_hidden_states != self.rbln_config.output_hidden_states:
+            raise ValueError(
+                f"Variable output_hidden_states {output_hidden_states} is not equal to rbln_config.output_hidden_states {self.rbln_config.output_hidden_states} "
+                f"Please compile again with the correct argument."
+            )
+
         # Prefill
         if cache_position is None:
             logits = []
             inputs = inputs_embeds if inputs_embeds is not None else input_ids
             batch_size = inputs.shape[0]
+            input_len = inputs.shape[1]
+            if batch_size > self.rbln_config.batch_size:
+                raise ValueError(
+                    f"Input's batch({batch_size}) exceeds compiled batch_size({self.rbln_config.batch_size})"
+                )
+            if input_len > self.rbln_config.max_seq_len:
+                raise ValueError(
+                    f"Input's length({input_len}) exceeds compiled max_seq_len({self.rbln_config.max_seq_len})."
+                )
+
+            all_hidden_states = (
+                tuple(
+                    torch.zeros(batch_size, input_len, self.config.hidden_size, dtype=self.rbln_config.torch_dtype)
+                    for _ in range(self.config.num_hidden_layers + 1)
+                )
+                if self.rbln_config.output_hidden_states
+                else None
+            )
             for b_idx in range(batch_size):
                 cache_position = torch.arange(0, generate_idx[b_idx].item(), dtype=torch.int32).unsqueeze(0)
-
+                outputs = self.prefill_decoder(
                     input_ids=inputs[b_idx : b_idx + 1] if inputs_embeds is None else None,
                     inputs_embeds=inputs[b_idx : b_idx + 1] if inputs_embeds is not None else None,
                     attention_mask=attention_mask[b_idx] if attention_mask is not None else None,
+                    position_ids=position_ids[b_idx : b_idx + 1] if position_ids is not None else None,
                     cache_position=cache_position,
                     batch_idx=b_idx,
                     token_type_ids=token_type_ids[b_idx : b_idx + 1] if token_type_ids is not None else None,
                     lora_int_ids=lora_int_ids[b_idx : b_idx + 1] if lora_int_ids is not None else None,
                 )
-                padded_cache_lengths[b_idx] +=
-                logits.append(
+                padded_cache_lengths[b_idx] += outputs.padded_cache_lengths
+                logits.append(outputs.logits)
+                if self.rbln_config.output_hidden_states:
+                    for l_idx in range(self.config.num_hidden_layers + 1):
+                        all_hidden_states[l_idx][b_idx].copy_(outputs.hidden_states[l_idx][0])
             logits = torch.cat(logits, dim=0)
         # Decoder
         else:
@@ -783,17 +863,31 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNDecoderOnlyModel, RBLNDecoderOnlyGenerationMixin):
                     f"Available batch sizes are: {list(self.decoders.keys())}. "
                     f"Please run your model with one of these batch sizes or add support for batch size {batch_size}."
                 )
-
+            if max(cache_position.reshape(-1)) >= self.rbln_config.max_seq_len:
+                raise ValueError(
+                    f"Cache position exceeds the maximum sequence length.\n"
+                    f" - Current max cache position: {int(torch.max(cache_position).item())}\n"
+                    f" - Allowed max_seq_len: {self.rbln_config.max_seq_len}\n"
+                    f"Solution: Reduce the generation length by adjusting `max_new_tokens` "
+                    f"or `max_length` in the generation config."
+                )
+
+            outputs = self.decoders[batch_size](
                 input_ids=input_ids,
                 inputs_embeds=inputs_embeds,
                 cache_position=cache_position,
                 position_ids=position_ids if self.rbln_config.use_position_ids else None,
                 lora_int_ids=lora_int_ids,
-            )
+            )
+            logits = outputs.logits
+            all_hidden_states = outputs.hidden_states
 
         if not return_dict:
-            return logits, generate_idx, padded_cache_lengths
+            return logits, generate_idx, padded_cache_lengths, all_hidden_states
         else:
             return RBLNDecoderOnlyOutput(
-                logits=logits,
+                logits=logits,
+                generate_idx=generate_idx,
+                padded_cache_lengths=padded_cache_lengths,
+                hidden_states=all_hidden_states,
             )
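A worked example of the constraint the new check enforces: every decoded token advances `cache_position`, which must stay strictly below the compiled `max_seq_len`, so generation headroom is whatever the prompt left unused.

```python
max_seq_len = 4096   # compiled limit
prompt_len = 3900    # tokens consumed by prefill

max_new_tokens = max_seq_len - prompt_len
print(max_new_tokens)  # 196: asking for more trips the new ValueError
```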
optimum/rbln/transformers/models/depth_anything/modeling_depth_anything.py

@@ -13,6 +13,11 @@
 # limitations under the License.
 
 
+from typing import Tuple, Union
+
+import torch
+from transformers.modeling_outputs import DepthEstimatorOutput
+
 from ...modeling_generic import RBLNModelForDepthEstimation
 
 
@@ -23,3 +28,15 @@ class RBLNDepthAnythingForDepthEstimation(RBLNModelForDepthEstimation):
     This class provides hardware-accelerated inference for Depth Anything V2
     models on RBLN devices, providing the most capable monocular depth estimation (MDE) model.
     """
+
+    def forward(self, pixel_values: torch.Tensor, **kwargs) -> Union[Tuple, DepthEstimatorOutput]:
+        """
+        Forward pass for the RBLN-optimized DepthAnythingForDepthEstimation model.
+
+        Args:
+            pixel_values (torch.FloatTensor of shape (batch_size, num_channels, height, width)): The tensors corresponding to the input images.
+
+        Returns:
+            The model outputs. If return_dict=False is passed, returns a tuple of tensors. Otherwise, returns a DepthEstimatorOutput object.
+        """
+        return super().forward(pixel_values, **kwargs)
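A hedged usage sketch for the newly documented signature; the checkpoint id is an assumption (any Depth Anything V2 checkpoint should fit) and `export=True` follows optimum-rbln's usual export convention:

```python
from PIL import Image
from transformers import AutoImageProcessor
from optimum.rbln import RBLNDepthAnythingForDepthEstimation

model_id = "depth-anything/Depth-Anything-V2-Small-hf"  # assumed checkpoint
processor = AutoImageProcessor.from_pretrained(model_id)
model = RBLNDepthAnythingForDepthEstimation.from_pretrained(model_id, export=True)

image = Image.open("room.jpg")
inputs = processor(images=image, return_tensors="pt")
outputs = model(pixel_values=inputs.pixel_values)
print(outputs.predicted_depth.shape)  # (1, height, width) depth map
```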
optimum/rbln/transformers/models/distilbert/modeling_distilbert.py

@@ -12,6 +12,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from typing import Optional, Tuple, Union
+
+import torch
+from transformers.modeling_outputs import QuestionAnsweringModelOutput
+
 from ...modeling_generic import RBLNModelForQuestionAnswering
 
 
@@ -25,3 +30,22 @@ class RBLNDistilBertForQuestionAnswering(RBLNModelForQuestionAnswering):
     """
 
     rbln_model_input_names = ["input_ids", "attention_mask"]
+
+    def forward(
+        self,
+        input_ids: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        **kwargs,
+    ) -> Union[Tuple, QuestionAnsweringModelOutput]:
+        """
+        Forward pass for the RBLN-optimized DistilBERT model for question answering tasks.
+
+        Args:
+            input_ids (torch.Tensor of shape (batch_size, sequence_length), optional): Indices of input sequence tokens in the vocabulary.
+            attention_mask (torch.Tensor of shape (batch_size, sequence_length), optional): Mask to avoid performing attention on padding token indices.
+
+        Returns:
+            The model outputs. If return_dict=False is passed, returns a tuple of tensors. Otherwise, returns a QuestionAnsweringModelOutput object.
+        """
+
+        return super().forward(input_ids, attention_mask, **kwargs)
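A hedged usage sketch for extractive QA with this class; the checkpoint and the fixed-length padding (RBLN models compile at a fixed sequence length) are assumptions:

```python
from transformers import AutoTokenizer
from optimum.rbln import RBLNDistilBertForQuestionAnswering

model_id = "distilbert-base-cased-distilled-squad"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = RBLNDistilBertForQuestionAnswering.from_pretrained(model_id, export=True)

question, context = "Who wrote Hamlet?", "Hamlet is a tragedy written by William Shakespeare."
inputs = tokenizer(question, context, return_tensors="pt", padding="max_length", max_length=384)
outputs = model(input_ids=inputs.input_ids, attention_mask=inputs.attention_mask)

start = int(outputs.start_logits.argmax())
end = int(outputs.end_logits.argmax())
print(tokenizer.decode(inputs.input_ids[0, start : end + 1]))  # "William Shakespeare"
```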
optimum/rbln/transformers/models/dpt/modeling_dpt.py

@@ -13,6 +13,11 @@
 # limitations under the License.
 
 
+from typing import Tuple, Union
+
+import torch
+from transformers.modeling_outputs import DepthEstimatorOutput
+
 from ...modeling_generic import RBLNModelForDepthEstimation
 
 
@@ -23,3 +28,15 @@ class RBLNDPTForDepthEstimation(RBLNModelForDepthEstimation):
     This class provides hardware-accelerated inference for DPT (Dense Prediction Transformer)
     models on RBLN devices, supporting monocular depth estimation from single images.
     """
+
+    def forward(self, pixel_values: torch.Tensor, **kwargs) -> Union[Tuple, DepthEstimatorOutput]:
+        """
+        Forward pass for the RBLN-optimized DPT model.
+
+        Args:
+            pixel_values (torch.FloatTensor of shape (batch_size, num_channels, image_size, image_size)): The tensors corresponding to the input images.
+
+        Returns:
+            The model outputs. If return_dict=False is passed, returns a tuple of tensors. Otherwise, returns a DepthEstimatorOutput object.
+        """
+        return super().forward(pixel_values, **kwargs)
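For completeness, a post-processing sketch that applies equally to the DPT and Depth Anything classes above: resize a `DepthEstimatorOutput.predicted_depth` map back to the source image size and render it as an 8-bit grayscale image.

```python
import torch
from PIL import Image

def depth_to_image(predicted_depth: torch.Tensor, size: tuple[int, int]) -> Image.Image:
    # (1, H, W) -> (1, 1, H, W) so interpolate sees a channel dimension
    depth = torch.nn.functional.interpolate(
        predicted_depth.unsqueeze(1), size=size, mode="bicubic", align_corners=False
    )[0, 0]
    depth = (depth - depth.min()) / (depth.max() - depth.min())  # normalize to [0, 1]
    return Image.fromarray((depth * 255).to(torch.uint8).numpy())
```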
optimum/rbln/transformers/models/gemma3/gemma3_architecture.py

@@ -64,6 +64,7 @@ class Gemma3TextModel(DecoderOnlyModel):
         global_block_tables: Optional[torch.Tensor] = None,
         local_block_tables: Optional[torch.Tensor] = None,
         lora_int_id: Optional[torch.Tensor] = None,
+        output_hidden_states: Optional[bool] = None,
     ):
         # retrieve input_ids and inputs_embeds
         if (input_ids is None) ^ (inputs_embeds is not None):
@@ -96,7 +97,10 @@ class Gemma3TextModel(DecoderOnlyModel):
 
         sliding_cache_pos = self.get_local_cache_positions(position_ids, query_position)
 
+        all_hidden_states = () if output_hidden_states else None
         for layer_idx, layer in enumerate(self.layers):
+            if output_hidden_states:
+                all_hidden_states += (hidden_states,)
             is_sliding = True if layer_idx in self.sliding_window_layers else False
             hidden_states = layer(
                 hidden_states=hidden_states,
@@ -110,7 +114,9 @@ class Gemma3TextModel(DecoderOnlyModel):
             )
 
         hidden_states = self.get_last_layernorm()(hidden_states)
-
+        if output_hidden_states:
+            all_hidden_states += (hidden_states,)
+        return hidden_states, all_hidden_states
 
 
 class Gemma3DecoderLayer(DecoderOnlyLayer):
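This follows the standard transformers accumulation pattern: the input to each layer is recorded before the layer runs, and the post-final-norm states are appended last, so an N-layer model yields N + 1 entries (embeddings first, final states last). A self-contained sketch:

```python
import torch

num_layers, hidden = 3, 4
hidden_states = torch.randn(1, 2, hidden)  # stand-in for the embedding output
layers = [torch.nn.Linear(hidden, hidden) for _ in range(num_layers)]
final_norm = torch.nn.LayerNorm(hidden)

all_hidden_states = ()
for layer in layers:
    all_hidden_states += (hidden_states,)  # input to this layer
    hidden_states = layer(hidden_states)
hidden_states = final_norm(hidden_states)
all_hidden_states += (hidden_states,)      # final normalized states

print(len(all_hidden_states))  # 4 == num_layers + 1
```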
optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py

@@ -16,7 +16,7 @@ from typing import Optional
 import rebel
 import torch
 
-from ...modeling_outputs import
+from ...modeling_outputs import RBLNGemma3ForCausalLMOutput
 from ..decoderonly.decoderonly_runtime_utils import RBLNPytorchRuntime
 from ..decoderonly.modeling_decoderonly import RBLNRuntimeModel
 
@@ -26,7 +26,6 @@ class RBLNGemma3RuntimeModel(RBLNRuntimeModel):
         super().__init__(*args, **kwargs)
         self.image_prefill = RBLNPytorchRuntime(image_prefill)  # FIXME(taehoon)
         self.prefill = RBLNPytorchRuntime(self.runtime) if self.phase == "prefill" else None  # FIXME
-        self.decode = RBLNPytorchRuntime(self.runtime) if self.phase == "decode" else None
 
     def _prepare_prefill_inputs(self, *args, **kwargs):
         (
@@ -106,6 +105,8 @@ class RBLNGemma3RuntimeModel(RBLNRuntimeModel):
         )
 
         step = 0
+        output_logits = []
+        all_hidden_states = [] if self.rbln_config.output_hidden_states else None
         while step < query_length:
             if self.rbln_config.use_image_prefill:
                 # Check if the prefill chunk is an image prefill
@@ -146,7 +147,7 @@ class RBLNGemma3RuntimeModel(RBLNRuntimeModel):
                 query_position = torch.tensor(num_processed_tokens - 1, dtype=torch.int16)
 
                 if is_image_prefill:
-
+                    outputs = self.image_prefill(
                         input_chunk,
                         cache_pos_chunk,
                         block_tables,
@@ -157,7 +158,7 @@ class RBLNGemma3RuntimeModel(RBLNRuntimeModel):
                         lora_int_ids if self.rbln_config.use_lora else None,
                     )
                 else:
-
+                    outputs = self.prefill(
                         input_chunk,
                         cache_pos_chunk,
                         block_tables,
@@ -168,78 +169,49 @@ class RBLNGemma3RuntimeModel(RBLNRuntimeModel):
                         lora_int_ids if self.rbln_config.use_lora else None,
                     )
 
+            if self.rbln_config.output_hidden_states:
+                output_logits.append(outputs[0])
+                all_hidden_states.append(tuple(outputs[1:]))
+            else:
+                output_logits.append(outputs)
+
             padded_cache_lengths += current_padded_cache_lengths
             step += num_processed_tokens
 
-        if
-
-
-
-
-
-
-    def decode_forward(
-        self,
-        inputs: torch.Tensor,
-        cache_position: torch.Tensor = None,
-        block_tables: torch.Tensor = None,
-        is_external_block_tables: bool = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        position_embed: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.Tensor] = None,
-        local_block_tables: Optional[torch.Tensor] = None,
-        lora_int_ids: Optional[torch.Tensor] = None,
-    ) -> torch.FloatTensor:
-        if self.rbln_config.use_lora and lora_int_ids is None:
-            if self.lora_int_ids is None:
-                raise ValueError(
-                    "lora_int_id is required when using LoRA. "
-                    "You should call set_lora_int_ids() before forward() or pass lora_int_id to forward()."
-                )
-
-            lora_int_ids = self.lora_int_ids
-
-        if lora_int_ids is not None and lora_int_ids.shape[0] != self.batch_size:
-            raise ValueError(f"lora_int_ids size mismatch: got {lora_int_ids.shape[0]}, expected {self.batch_size}.")
-
-        batch_size = inputs.shape[0]
-        if batch_size != self.batch_size:
-            raise RuntimeError(
-                f"Batch size mismatch: got {batch_size}, expected {self.batch_size} (compiled batch size)."
-            )
+        if self.rbln_config.output_hidden_states:
+            num_hidden_layers = len(all_hidden_states[0]) - 1
+            concatenated_hidden_states = ()
+            for l_idx in range(num_hidden_layers + 1):
+                l_hidden_states = torch.cat([hidden_states[l_idx] for hidden_states in all_hidden_states], dim=1)
+                l_hidden_states = l_hidden_states[:, :query_length, :]
+                concatenated_hidden_states += (l_hidden_states,)
 
-
-            raise RuntimeError(f"Cache position size mismatch: got {cache_position.shape[0]}, expected {batch_size}.")
+            all_hidden_states = concatenated_hidden_states
 
-        #
-
-
-
-        if local_block_tables is None:
-            raise ValueError("local_block_tables should be provided with external block tables.")
+        # Aggregate output_logits
+        output_logits = torch.concat(output_logits, dim=-2)
+        if self.rbln_config.logits_to_keep > 0:
+            output_logits = output_logits[:, -self.rbln_config.logits_to_keep :, :]
         else:
-
-
-
-
-
-
-
-
-
-                    f"Decoding step {decoding_step} out of bounds for attention mask with shape {self.dec_attn_mask.shape}."
-                )
-                self.dec_attn_mask[b_idx, decoding_step] = 1
-
-            attention_mask = self.dec_attn_mask
-
-        if self.batch_size < block_tables.shape[0]:
-            block_tables = block_tables[: self.batch_size]
+            output_logits = output_logits[:, :query_length, :]
+            # index copy for masked output_logits
+            if attention_mask is not None:
+                new_output_logits = torch.full(
+                    (1, attention_mask.shape[-1], output_logits.shape[-1]),
+                    fill_value=1e-10,
+                    dtype=output_logits.dtype,
+                )
+                mask_indices = torch.nonzero(attention_mask, as_tuple=True)[0]
+                new_output_logits.index_copy_(dim=-2, index=mask_indices, source=output_logits)
 
-
-            attention_mask = attention_mask[: self.batch_size]
+                output_logits = new_output_logits
 
-
+        if not is_external_block_tables:
+            self.dec_attn_mask[batch_idx : batch_idx + 1] = chunked_attention_mask
 
-        return
+        return RBLNGemma3ForCausalLMOutput(
+            logits=output_logits,
+            padded_cache_lengths=padded_cache_lengths,
+            attention_mask=chunked_attention_mask,
+            hidden_states=all_hidden_states,
+        )
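The masked-logits scatter above is the interesting part: chunked prefill produces logits only for real (unpadded) tokens, and `index_copy_` places them back at their original positions in a fill-valued buffer. A self-contained sketch of that mechanic:

```python
import torch

attention_mask = torch.tensor([1, 1, 0, 1, 0])  # 3 real tokens among 5 slots
vocab = 7
real_logits = torch.randn(1, int(attention_mask.sum()), vocab)  # (1, 3, vocab)

padded = torch.full((1, attention_mask.shape[-1], vocab), fill_value=1e-10)
mask_indices = torch.nonzero(attention_mask, as_tuple=True)[0]  # tensor([0, 1, 3])
padded.index_copy_(dim=-2, index=mask_indices, source=real_logits)

print(padded[0, 2].abs().max())  # ~1e-10: masked slot keeps the fill value
```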