optimum-rbln 0.9.3rc0__py3-none-any.whl → 0.9.5a4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- optimum/rbln/__init__.py +48 -0
- optimum/rbln/__version__.py +2 -2
- optimum/rbln/configuration_utils.py +50 -21
- optimum/rbln/diffusers/__init__.py +12 -0
- optimum/rbln/diffusers/configurations/__init__.py +3 -0
- optimum/rbln/diffusers/configurations/models/__init__.py +2 -0
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_temporal_decoder.py +67 -0
- optimum/rbln/diffusers/configurations/models/configuration_unet_spatio_temporal_condition.py +59 -0
- optimum/rbln/diffusers/configurations/pipelines/__init__.py +3 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_video_diffusion.py +114 -0
- optimum/rbln/diffusers/modeling_diffusers.py +1 -1
- optimum/rbln/diffusers/models/__init__.py +17 -3
- optimum/rbln/diffusers/models/autoencoders/__init__.py +1 -0
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py +3 -3
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +275 -0
- optimum/rbln/diffusers/models/autoencoders/vae.py +27 -8
- optimum/rbln/diffusers/models/controlnet.py +17 -2
- optimum/rbln/diffusers/models/transformers/prior_transformer.py +16 -2
- optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +16 -1
- optimum/rbln/diffusers/models/transformers/transformer_sd3.py +14 -1
- optimum/rbln/diffusers/models/unets/__init__.py +1 -0
- optimum/rbln/diffusers/models/unets/unet_2d_condition.py +18 -2
- optimum/rbln/diffusers/models/unets/unet_spatio_temporal_condition.py +201 -0
- optimum/rbln/diffusers/pipelines/__init__.py +4 -0
- optimum/rbln/diffusers/pipelines/auto_pipeline.py +2 -2
- optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +20 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +13 -4
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +13 -4
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +13 -4
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +13 -4
- optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +1 -1
- optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +1 -1
- optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +1 -2
- optimum/rbln/diffusers/pipelines/stable_video_diffusion/__init__.py +15 -0
- optimum/rbln/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +46 -0
- optimum/rbln/modeling.py +20 -45
- optimum/rbln/modeling_base.py +18 -14
- optimum/rbln/ops/__init__.py +1 -0
- optimum/rbln/ops/attn.py +10 -0
- optimum/rbln/ops/flash_attn.py +8 -0
- optimum/rbln/ops/moe.py +180 -0
- optimum/rbln/ops/sliding_window_attn.py +9 -0
- optimum/rbln/transformers/__init__.py +36 -0
- optimum/rbln/transformers/configuration_generic.py +0 -27
- optimum/rbln/transformers/modeling_attention_utils.py +156 -127
- optimum/rbln/transformers/modeling_generic.py +2 -61
- optimum/rbln/transformers/modeling_outputs.py +26 -0
- optimum/rbln/transformers/modeling_rope_utils.py +78 -42
- optimum/rbln/transformers/models/__init__.py +28 -0
- optimum/rbln/transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py +28 -2
- optimum/rbln/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +68 -5
- optimum/rbln/transformers/models/auto/auto_factory.py +1 -0
- optimum/rbln/transformers/models/bart/bart_architecture.py +24 -24
- optimum/rbln/transformers/models/bart/modeling_bart.py +23 -2
- optimum/rbln/transformers/models/bert/modeling_bert.py +86 -1
- optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +42 -15
- optimum/rbln/transformers/models/clip/modeling_clip.py +40 -2
- optimum/rbln/transformers/models/colpali/colpali_architecture.py +14 -20
- optimum/rbln/transformers/models/colpali/configuration_colpali.py +12 -17
- optimum/rbln/transformers/models/colpali/modeling_colpali.py +66 -221
- optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +38 -23
- optimum/rbln/transformers/models/colqwen2/modeling_colqwen2.py +107 -371
- optimum/rbln/transformers/models/decoderonly/__init__.py +2 -0
- optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +128 -17
- optimum/rbln/transformers/models/decoderonly/configuration_lora.py +2 -2
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +211 -89
- optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +205 -64
- optimum/rbln/transformers/models/decoderonly/generation_decoderonly.py +17 -9
- optimum/rbln/transformers/models/decoderonly/lora_architecture.py +1 -1
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +194 -132
- optimum/rbln/transformers/models/depth_anything/modeling_depth_anything.py +17 -0
- optimum/rbln/transformers/models/distilbert/modeling_distilbert.py +24 -0
- optimum/rbln/transformers/models/dpt/modeling_dpt.py +17 -0
- optimum/rbln/transformers/models/exaone/exaone_architecture.py +0 -36
- optimum/rbln/transformers/models/gemma/gemma_architecture.py +1 -1
- optimum/rbln/transformers/models/gemma2/__init__.py +16 -0
- optimum/rbln/transformers/models/gemma2/configuration_gemma2.py +45 -0
- optimum/rbln/transformers/models/gemma2/gemma2_architecture.py +83 -0
- optimum/rbln/transformers/models/gemma2/modeling_gemma2.py +101 -0
- optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +23 -19
- optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py +42 -70
- optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +46 -31
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +8 -34
- optimum/rbln/transformers/models/gpt_oss/__init__.py +16 -0
- optimum/rbln/transformers/models/gpt_oss/configuration_gpt_oss.py +41 -0
- optimum/rbln/transformers/models/gpt_oss/gpt_oss_architecture.py +122 -0
- optimum/rbln/transformers/models/gpt_oss/modeling_gpt_oss.py +165 -0
- optimum/rbln/transformers/models/grounding_dino/configuration_grounding_dino.py +8 -5
- optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +7 -5
- optimum/rbln/transformers/models/grounding_dino/modeling_grounding_dino.py +24 -9
- optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +3 -5
- optimum/rbln/transformers/models/llava/modeling_llava.py +37 -26
- optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +3 -5
- optimum/rbln/transformers/models/midm/midm_architecture.py +29 -22
- optimum/rbln/transformers/models/mistral/modeling_mistral.py +0 -22
- optimum/rbln/transformers/models/opt/modeling_opt.py +2 -2
- optimum/rbln/transformers/models/opt/opt_architecture.py +1 -44
- optimum/rbln/transformers/models/paligemma/__init__.py +16 -0
- optimum/rbln/transformers/models/paligemma/configuration_paligemma.py +129 -0
- optimum/rbln/transformers/models/paligemma/modeling_paligemma.py +564 -0
- optimum/rbln/transformers/models/pegasus/modeling_pegasus.py +1 -1
- optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +24 -24
- optimum/rbln/transformers/models/phi/phi_architecture.py +13 -21
- optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +13 -1
- optimum/rbln/transformers/models/pixtral/pixtral_architecture.py +2 -2
- optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +0 -28
- optimum/rbln/transformers/models/qwen2_5_vl/__init__.py +6 -1
- optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +11 -1
- optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +278 -130
- optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +43 -39
- optimum/rbln/transformers/models/qwen2_moe/__init__.py +16 -0
- optimum/rbln/transformers/models/qwen2_moe/configuration_qwen2_moe.py +38 -0
- optimum/rbln/transformers/models/qwen2_moe/modeling_qwen2_moe.py +68 -0
- optimum/rbln/transformers/models/qwen2_moe/qwen2_moe_architecture.py +94 -0
- optimum/rbln/transformers/models/qwen2_vl/__init__.py +6 -1
- optimum/rbln/transformers/models/qwen2_vl/configuration_qwen2_vl.py +11 -1
- optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py +268 -111
- optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +27 -35
- optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +0 -20
- optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +7 -7
- optimum/rbln/transformers/models/qwen3_moe/__init__.py +16 -0
- optimum/rbln/transformers/models/qwen3_moe/configuration_qwen3_moe.py +38 -0
- optimum/rbln/transformers/models/qwen3_moe/modeling_qwen3_moe.py +68 -0
- optimum/rbln/transformers/models/qwen3_moe/qwen3_moe_architecture.py +100 -0
- optimum/rbln/transformers/models/resnet/configuration_resnet.py +17 -0
- optimum/rbln/transformers/models/resnet/modeling_resnet.py +73 -0
- optimum/rbln/transformers/models/roberta/modeling_roberta.py +33 -0
- optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py +2 -4
- optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +36 -12
- optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +14 -12
- optimum/rbln/transformers/models/siglip/modeling_siglip.py +21 -19
- optimum/rbln/transformers/models/swin/configuration_swin.py +1 -6
- optimum/rbln/transformers/models/swin/modeling_swin.py +17 -4
- optimum/rbln/transformers/models/t5/modeling_t5.py +2 -2
- optimum/rbln/transformers/models/t5/t5_architecture.py +16 -17
- optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +25 -10
- optimum/rbln/transformers/models/time_series_transformer/time_series_transformers_architecture.py +0 -3
- optimum/rbln/transformers/models/vit/modeling_vit.py +19 -0
- optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec2.py +15 -3
- optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +60 -8
- optimum/rbln/transformers/models/whisper/generation_whisper.py +48 -14
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +2 -2
- optimum/rbln/transformers/models/whisper/whisper_architecture.py +0 -3
- optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +53 -0
- optimum/rbln/transformers/utils/rbln_quantization.py +29 -12
- optimum/rbln/utils/deprecation.py +213 -0
- optimum/rbln/utils/hub.py +14 -3
- optimum/rbln/utils/import_utils.py +23 -2
- optimum/rbln/utils/runtime_utils.py +42 -6
- optimum/rbln/utils/submodule.py +27 -1
- {optimum_rbln-0.9.3rc0.dist-info → optimum_rbln-0.9.5a4.dist-info}/METADATA +6 -6
- {optimum_rbln-0.9.3rc0.dist-info → optimum_rbln-0.9.5a4.dist-info}/RECORD +155 -129
- {optimum_rbln-0.9.3rc0.dist-info → optimum_rbln-0.9.5a4.dist-info}/WHEEL +1 -1
- optimum/rbln/transformers/models/colqwen2/colqwen2_architecture.py +0 -233
- optimum/rbln/utils/depreacate_utils.py +0 -16
- {optimum_rbln-0.9.3rc0.dist-info → optimum_rbln-0.9.5a4.dist-info}/entry_points.txt +0 -0
- {optimum_rbln-0.9.3rc0.dist-info → optimum_rbln-0.9.5a4.dist-info}/licenses/LICENSE +0 -0
--- a/optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py
+++ b/optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py
@@ -26,15 +26,16 @@ from transformers.modeling_utils import no_init_weights
 from ....configuration_utils import RBLNCompileConfig
 from ....modeling import RBLNModel
 from ....utils.logging import get_logger
+from ....utils.runtime_utils import is_compiler_supports_buffer_resize
 from ...modeling_attention_utils import (
     RBLNDecoderOnlyFlashAttentionMixin,
     set_default_values,
     validate_attention_method,
     validate_sliding_window,
 )
-from ...modeling_outputs import RBLNDecoderOnlyOutput
+from ...modeling_outputs import RBLNDecoderOnlyOutput, _validate_output_hidden_states
 from ...utils.rbln_quantization import get_quantized_model
-from .configuration_decoderonly import RBLNDecoderOnlyModelConfig, RBLNDecoderOnlyModelForCausalLMConfig
+from .configuration_decoderonly import KVCacheMeta, RBLNDecoderOnlyModelConfig, RBLNDecoderOnlyModelForCausalLMConfig
 from .decoderonly_architecture import DecoderOnlyWrapper
 from .decoderonly_runtime_utils import RBLNPageTableManager, RBLNRuntimeModel
 from .generation_decoderonly import RBLNDecoderOnlyGenerationMixin
@@ -88,8 +89,12 @@ class RBLNDecoderOnlyModel(RBLNModel, RBLNDecoderOnlyFlashAttentionMixin):
     def setup_runtime(self):
         # Initialize resources to be used across Runtime instances (prefill and decode phases)
         page_table_manager = RBLNPageTableManager(self.rbln_config)
-
-
+        if self.rbln_config.use_position_ids:
+            dec_attn_mask = torch.zeros(self.rbln_config.batch_size, self.rbln_config.max_seq_len, dtype=self.dtype)
+        else:
+            dec_attn_mask = torch.zeros(
+                self.rbln_config.batch_size, 1, 1, self.rbln_config.max_seq_len, dtype=self.dtype
+            )
 
         common_kwargs = {
             "main_input_name": "inputs_embeds" if self.rbln_config.use_inputs_embeds else "input_ids",
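The decode-phase mask now takes one of two layouts keyed off `use_position_ids`: a per-token 2D mask when position ids drive the kernel, and the broadcastable 4D mask otherwise. A minimal sketch of the two shapes, with hypothetical sizes:

```python
import torch

# Hypothetical compile-time settings, for illustration only.
batch_size, max_seq_len = 2, 8

# With position ids: one mask row per sequence.
mask_2d = torch.zeros(batch_size, max_seq_len)        # [batch, max_seq_len]

# Without position ids: broadcastable 4D mask as consumed by attention kernels.
mask_4d = torch.zeros(batch_size, 1, 1, max_seq_len)  # [batch, 1, 1, max_seq_len]

print(mask_2d.shape, mask_4d.shape)  # torch.Size([2, 8]) torch.Size([2, 1, 1, 8])
```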
@@ -97,12 +102,13 @@ class RBLNDecoderOnlyModel(RBLNModel, RBLNDecoderOnlyFlashAttentionMixin):
             "dec_attn_mask": dec_attn_mask,
             "page_table_manager": page_table_manager,
             "rbln_config": self.rbln_config,
+            "config": self.config,
         }
         self.prefill_decoder = RBLNRuntimeModel(
             runtime=self.model[0],
             phase="prefill",
             batch_size=self.rbln_config.batch_size,
-
+            logits_last_dim=self.logits_last_dim,
             **common_kwargs,
         )
         if self.can_generate():
@@ -119,12 +125,8 @@ class RBLNDecoderOnlyModel(RBLNModel, RBLNDecoderOnlyFlashAttentionMixin):
         self.decoder = self.decoders[self.rbln_config.batch_size]
 
     @property
-    def
-        return
-            1,
-            self.rbln_config.prefill_chunk_size if self.rbln_config.logits_to_keep == 0 else 1,
-            self.config.hidden_size,
-        )
+    def logits_last_dim(self):
+        return self.config.hidden_size
 
     @classmethod
     def get_quantized_model(
@@ -216,7 +218,7 @@ class RBLNDecoderOnlyModel(RBLNModel, RBLNDecoderOnlyFlashAttentionMixin):
         return self.rbln_config.kvcache_num_blocks
 
     @classmethod
-    def
+    def _wrap_model_if_needed(cls, model: PreTrainedModel, rbln_config: "RBLNDecoderOnlyModelConfig"):
         return cls._decoder_wrapper_cls(model, rbln_config, cls._use_rotary_emb).eval()
 
     @classmethod
@@ -229,7 +231,7 @@ class RBLNDecoderOnlyModel(RBLNModel, RBLNDecoderOnlyFlashAttentionMixin):
         rbln_config: RBLNDecoderOnlyModelForCausalLMConfig,
         quantization=None,
         phase: str = "prefill",
-    ):
+    ) -> rebel.RBLNCompiledModel:
         try:
             wrapped_model.phase = phase
             if quantization:
@@ -251,28 +253,22 @@ class RBLNDecoderOnlyModel(RBLNModel, RBLNDecoderOnlyFlashAttentionMixin):
             quantization.maybe_reset_quantization_env()
 
     @classmethod
-    def _get_compile_context(
-        cls,
-        compile_config: RBLNCompileConfig,
-        example_inputs: List[torch.Tensor],
-    ):
+    def _get_compile_context(cls, compile_config: RBLNCompileConfig, example_inputs: List[torch.Tensor]):
         context = CompileContext(use_weight_sharing=True)
 
         # Mark static tensors (self kv states)
         static_tensors = {}
-        idx = 0
         for (name, _, _), tensor in zip(compile_config.input_info, example_inputs):
             if "past_key_values" in name:
                 static_tensors[name] = tensor
-                context.mark_static_address(tensor,
-                idx += 1
+                context.mark_static_address(tensor, name)
 
         return context, static_tensors
 
     @classmethod
     @torch.inference_mode()
     def get_compiled_model(cls, model: PreTrainedModel, rbln_config: RBLNDecoderOnlyModelForCausalLMConfig):
-        wrapped_model = cls.
+        wrapped_model = cls._wrap_model_if_needed(model, rbln_config)
         prefill_compile_config = rbln_config.compile_cfgs[0]
 
         # Here we use meta tensor, for the memory efficiency.
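`mark_static_address` is now keyed by the tensor's input name instead of a running index. A standalone sketch of the same filtering pattern; `FakeCompileContext` stands in for the RBLN SDK's `CompileContext`, and the `input_info` triples are illustrative:

```python
import torch

# Stand-in for rebel's CompileContext; only the call pattern matters here.
class FakeCompileContext:
    def __init__(self):
        self.static = {}

    def mark_static_address(self, tensor, name):
        self.static[name] = tensor

# (name, shape, dtype) triples as produced by get_input_info(); illustrative values.
input_info = [
    ("input_ids", [1, 8], "int64"),
    ("past_key_values_0", [1, 2, 8, 4], "float32"),
    ("past_key_values_1", [1, 2, 8, 4], "float32"),
]
example_inputs = [torch.zeros(shape) for _, shape, _ in input_info]

context = FakeCompileContext()
static_tensors = {}
for (name, _, _), tensor in zip(input_info, example_inputs):
    if "past_key_values" in name:
        static_tensors[name] = tensor
        context.mark_static_address(tensor, name)

print(sorted(static_tensors))  # ['past_key_values_0', 'past_key_values_1']
```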
@@ -280,7 +276,7 @@ class RBLNDecoderOnlyModel(RBLNModel, RBLNDecoderOnlyFlashAttentionMixin):
         prefill_example_inputs = prefill_compile_config.get_dummy_inputs(fill=0, meta_tensor_names=meta_tensor_names)
         context, static_tensors = cls._get_compile_context(prefill_compile_config, prefill_example_inputs)
 
-        compiled_models = {}
+        compiled_models: dict[str, rebel.RBLNCompiledModel] = {}
         compiled_models["prefill"] = cls._compile_model(
             wrapped_model,
             prefill_compile_config,
@@ -306,14 +302,10 @@ class RBLNDecoderOnlyModel(RBLNModel, RBLNDecoderOnlyFlashAttentionMixin):
             )
             compiled_models[f"decoder_batch_{batch_size}"] = compiled_decoder
 
-
-
-
-
-            compiled_models=compiled_models,
-            model_config=model.config,
-            rbln_config=rbln_config,
-        )
+        if rbln_config.is_auto_num_blocks:
+            if not is_compiler_supports_buffer_resize():
+                raise RuntimeError("`kvcache_num_blocks` must be set.")
+            cls.set_kvcache_num_blocks_after_compilation(compiled_models, rbln_config)
 
         return compiled_models
 
@@ -329,8 +321,8 @@ class RBLNDecoderOnlyModel(RBLNModel, RBLNDecoderOnlyFlashAttentionMixin):
         return model
 
     @classmethod
-    def use_query_position(cls, use_local_attention: bool, is_prefill: bool = True):
-        return use_local_attention
+    def use_query_position(cls, use_local_attention: bool, is_prefill: bool = True, logits_to_keep: int = None):
+        return is_prefill and (use_local_attention or logits_to_keep == 1)
 
     @classmethod
     def get_input_info(
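`query_position` is now emitted only for prefill graphs, and only when local attention is active or exactly one logit row is kept. A quick enumeration of the new predicate:

```python
def use_query_position(use_local_attention, is_prefill=True, logits_to_keep=None):
    # Mirrors the new predicate above.
    return is_prefill and (use_local_attention or logits_to_keep == 1)

# query_position is only compiled in for prefill graphs, and only when either
# local attention is active or exactly one logit row is kept.
for local in (False, True):
    for prefill in (False, True):
        for keep in (0, 1):
            print(local, prefill, keep, "->", use_query_position(local, prefill, keep))
```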
@@ -340,16 +332,16 @@ class RBLNDecoderOnlyModel(RBLNModel, RBLNDecoderOnlyFlashAttentionMixin):
         rbln_config: RBLNDecoderOnlyModelForCausalLMConfig,
         model_config: PretrainedConfig,
     ):
-        num_attention_heads = getattr(model_config, "n_head", None) or
+        num_attention_heads = getattr(model_config, "n_head", None) or model_config.num_attention_heads
         num_key_value_heads = getattr(model_config, "num_key_value_heads", None) or num_attention_heads
-        num_hidden_layers = getattr(model_config, "n_layer", None) or
-        hidden_size = getattr(model_config, "n_embd", None) or
+        num_hidden_layers = getattr(model_config, "n_layer", None) or model_config.num_hidden_layers
+        hidden_size = getattr(model_config, "n_embd", None) or model_config.hidden_size
         head_dim = getattr(model_config, "head_dim", None) or hidden_size // num_attention_heads
         is_prefill = query_length > 1
 
         input_info = []
         if rbln_config.use_inputs_embeds:
-            input_info.append(("inputs_embeds", [batch_size, query_length, hidden_size], rbln_config.
+            input_info.append(("inputs_embeds", [batch_size, query_length, hidden_size], rbln_config.dtype))
         else:
             input_info.append(("input_ids", [batch_size, query_length], "int64"))
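The `getattr(..., None) or ...` chains make one code path serve both GPT-2-style configs (`n_head`, `n_layer`, `n_embd`) and Llama-style configs (`num_attention_heads`, `num_hidden_layers`, `hidden_size`). A toy illustration:

```python
from types import SimpleNamespace

gpt2_style = SimpleNamespace(n_head=12, n_layer=12, n_embd=768)
llama_style = SimpleNamespace(num_attention_heads=32, num_hidden_layers=32, hidden_size=4096)

def heads(cfg):
    # Prefer the legacy GPT-2 name, fall back to the modern one.
    return getattr(cfg, "n_head", None) or cfg.num_attention_heads

print(heads(gpt2_style), heads(llama_style))  # 12 32
```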
@@ -363,15 +355,15 @@ class RBLNDecoderOnlyModel(RBLNModel, RBLNDecoderOnlyFlashAttentionMixin):
         if rbln_config.use_local_attention:
             input_info.append(("local_block_tables", [1] if is_prefill else [batch_size, 1], "int16"))
 
-        if cls.use_query_position(rbln_config.use_local_attention, is_prefill):
+        if cls.use_query_position(rbln_config.use_local_attention, is_prefill, rbln_config.logits_to_keep):
             input_info.append(("query_position", [], "int16"))
 
         if rbln_config.use_attention_mask:
             if rbln_config.use_position_ids:
-                input_info.append(("attention_mask", [batch_size, rbln_config.max_seq_len], rbln_config.
+                input_info.append(("attention_mask", [batch_size, rbln_config.max_seq_len], rbln_config.dtype))
             else:
                 input_info.append(
-                    ("attention_mask", [batch_size, 1, query_length, rbln_config.max_seq_len], rbln_config.
+                    ("attention_mask", [batch_size, 1, query_length, rbln_config.max_seq_len], rbln_config.dtype)
                 )
 
         if rbln_config.use_position_ids:
@@ -380,29 +372,36 @@ class RBLNDecoderOnlyModel(RBLNModel, RBLNDecoderOnlyFlashAttentionMixin):
         if rbln_config.use_lora:
             input_info.append(("lora_int_ids", [batch_size], "int32"))
 
-
-
-
+        if len(rbln_config.kvcache_metas) > 0:
+            # Meta is already set, use it
+            input_info.extend(
+                [
+                    (kvcache_meta.name, kvcache_meta.compile_shape, kvcache_meta.dtype)
+                    for kvcache_meta in rbln_config.kvcache_metas
+                ]
+            )
 
-
-            rbln_config.
-
-
-
-
-
-
-            (
-
-
-
-                kvcache_dtype,
+        else:
+            kvcache_dtype = rbln_config.dtype
+            if rbln_config.quantization and rbln_config.quantization.kv_caches == "fp8":
+                kvcache_dtype = "float8_e4m3fn"
+
+            kvcache_metas = []
+            for i in range(num_hidden_layers * 2):
+                layer_idx = i // 2
+                name = f"past_key_values_{i}"
+                kvcache_meta = KVCacheMeta.make(
+                    name,
+                    layer_idx,
+                    num_key_value_heads,
+                    head_dim,
+                    RBLNCompileConfig.normalize_dtype(kvcache_dtype),
+                    rbln_config,
                 )
-
-
-
+                kvcache_metas.append(kvcache_meta)
+                input_info.append((name, kvcache_meta.compile_shape, kvcache_meta.dtype))
+
+            rbln_config.kvcache_metas.extend(kvcache_metas)
 
         return input_info
 
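Each layer contributes two cache tensors, so `i` runs over `2 * num_hidden_layers` and `i // 2` recovers the layer index. A sketch of the resulting names; the even/odd key/value split is an assumption inferred from the pairing, not confirmed by this diff:

```python
num_hidden_layers = 3  # illustrative only

for i in range(num_hidden_layers * 2):
    layer_idx = i // 2
    kind = "key" if i % 2 == 0 else "value"  # assumed convention for the even/odd slot
    print(f"past_key_values_{i} -> layer {layer_idx} ({kind})")
```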
@@ -439,10 +438,22 @@ class RBLNDecoderOnlyModel(RBLNModel, RBLNDecoderOnlyFlashAttentionMixin):
         # Returns:
         #     RBLNDecoderOnlyModelConfig: The updated RBLN model configuration.
 
-
-
-
+        rbln_config.sliding_window = model_config.sliding_window
+        sliding_window_layers = []
+
+        for i in range(model_config.num_hidden_layers):
+            if hasattr(model_config, "layer_types"):
+                if model_config.layer_types[i] == "sliding_attention":
+                    sliding_window_layers.append(i)
+            else:
+                sliding_window_layers.append(i)
+
+        rbln_config.sliding_window_layers = sliding_window_layers
+
+        rbln_config.cache_impl = (
+            "sliding_window" if len(sliding_window_layers) == model_config.num_hidden_layers else "hybrid"
         )
+        return rbln_config
 
     @classmethod
     def _update_attention_config(
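Models that declare `layer_types` (Gemma-style hybrids) get only their sliding layers registered, and `cache_impl` follows from whether every layer slides. A toy classification, folding in the `use_sliding_window` guard added further below; field names follow HF config conventions:

```python
from types import SimpleNamespace

# Toy config mimicking a Gemma-style layout: alternating sliding/full layers.
cfg = SimpleNamespace(
    num_hidden_layers=4,
    sliding_window=512,
    use_sliding_window=True,
    layer_types=["sliding_attention", "full_attention"] * 2,
)

sliding_window_layers = [
    i for i in range(cfg.num_hidden_layers)
    if cfg.layer_types[i] == "sliding_attention"
]
cache_impl = (
    "sliding_window"
    if len(sliding_window_layers) == cfg.num_hidden_layers
    else "hybrid"
)
print(sliding_window_layers, cache_impl)  # [0, 2] hybrid
```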
@@ -462,58 +473,40 @@ class RBLNDecoderOnlyModel(RBLNModel, RBLNDecoderOnlyFlashAttentionMixin):
             max_seq_len=rbln_config.max_seq_len,
         )
 
-
-
-        #
+        # Validate kvcache_num_blocks based on the number of full blocks required.
+        # Eager mode restriction:
+        #   - num_blocks must be at least equal to the batch size
+        # Flash attention restriction:
+        #   - num_blocks must be at least equal to (max_seq_len // kvcache_block_size) + 1
+        #   - num_blocks must be no greater than the number of full blocks.
         if rbln_config.attn_impl == "flash_attn":
-
-
-
-                kvcache_block_size=rbln_config.kvcache_block_size,
-                nbits_per_param=16 if not rbln_config.quantization else 4,  # TODO(jongho): FIX Ad-hoc
-                n_model_params=sum(p.numel() for p in model.parameters()),
-                num_runtimes=1 if not rbln_config.can_generate else 1 + len(rbln_config.decoder_batch_sizes),
-            )
+            if rbln_config.is_auto_num_blocks:
+                # Do nothing
+                pass
 
-
-            if
-
-
-
+            else:
+                if rbln_config.kvcache_num_blocks > rbln_config.num_full_blocks:
+                    logger.warning(
+                        f"The set `kvcache_num_blocks` ({rbln_config.kvcache_num_blocks}) is greater"
+                        f" than the required number of blocks ({rbln_config.num_full_blocks})."
+                        "This can cause a failure during model compilation."
+                    )
+                elif rbln_config.kvcache_num_blocks < rbln_config.num_min_blocks:
+                    raise ValueError(
+                        f"The set `kvcache_num_blocks` ({rbln_config.kvcache_num_blocks}) is less"
+                        f" than the minimum number of blocks ({rbln_config.num_min_blocks})."
                     )
-            if min_blocks_for_flash > estimated_max_num_blocks:
-                # NOTE: Just try to compile with lower bound of blocks for flash attention.
-                # Even if it's larger than the estimated maximum number of blocks.
-                rbln_config.kvcache_num_blocks = min_blocks_for_flash
-            else:
-                logger.info(f"[KVCache] Compiling with num_blocks: {rbln_config.kvcache_num_blocks}")
-                rbln_config.kvcache_num_blocks = estimated_max_num_blocks
-
-            if rbln_config.kvcache_num_blocks < rbln_config.batch_size:
-                raise RuntimeError(
-                    f"Batch size ({rbln_config.batch_size}) exceeds num_blocks ({rbln_config.kvcache_num_blocks}). "
-                    "Ensure the number of blocks is at least equal to the batch size."
-                )
-            else:
-                rbln_config.kvcache_num_blocks = num_full_blocks
-        elif rbln_config.kvcache_num_blocks > estimated_max_num_blocks:
-            logger.warning(
-                f"The set `kvcache_num_blocks` ({rbln_config.kvcache_num_blocks}) is greater"
-                f" than the estimated maximum number of blocks ({estimated_max_num_blocks})."
-                "This can cause a failure during model compilation."
-            )
         else:
-            if rbln_config.
-
-
+            if rbln_config.is_auto_num_blocks:
+                # Eager attention should use fixed number of blocks.
+                rbln_config.kvcache_num_blocks = rbln_config.num_full_blocks
+            elif rbln_config.kvcache_num_blocks > rbln_config.num_full_blocks:
                 logger.warning(
                     f"The set `kvcache_num_blocks` ({rbln_config.kvcache_num_blocks}) is greater"
-                    f" than the required number of blocks ({num_full_blocks})."
+                    f" than the required number of blocks ({rbln_config.num_full_blocks})."
                     "This can cause a failure during model compilation."
                 )
 
-        logger.info(f"[KVCache] Compiling with num_blocks: {rbln_config.kvcache_num_blocks}")
-
         return rbln_config
 
     @classmethod
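A sketch of the bounds this validation enforces. `num_min_blocks` and `num_full_blocks` follow the comment above; the exact formulas below are assumptions for illustration only:

```python
def check_kvcache_num_blocks(num_blocks, batch_size, max_seq_len, block_size, flash_attn):
    # Assumed formulas, mirroring the restrictions in the comment above.
    blocks_per_seq = max_seq_len // block_size
    num_full_blocks = batch_size * blocks_per_seq
    num_min_blocks = (blocks_per_seq + 1) if flash_attn else batch_size

    if num_blocks < num_min_blocks:
        raise ValueError(f"need at least {num_min_blocks} blocks, got {num_blocks}")
    if num_blocks > num_full_blocks:
        print(f"warning: {num_blocks} > {num_full_blocks}, compilation may fail")

check_kvcache_num_blocks(num_blocks=9, batch_size=2, max_seq_len=1024, block_size=128, flash_attn=True)
```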
@@ -531,8 +524,13 @@ class RBLNDecoderOnlyModel(RBLNModel, RBLNDecoderOnlyFlashAttentionMixin):
         if rbln_config.max_seq_len is None:
             raise ValueError("`max_seq_len` should be specified.")
 
-
-
+        layer_types = getattr(model_config, "layer_types", None)
+        all_full_attention = layer_types is not None and all(t == "full_attention" for t in layer_types)
+
+        if (
+            getattr(model_config, "sliding_window", None) is not None
+            and getattr(model_config, "use_sliding_window", True)
+            and not all_full_attention
         ):
             rbln_config = cls._update_sliding_window_config(model_config, rbln_config)
             if rbln_config.sliding_window is not None:
@@ -608,34 +606,66 @@ class RBLNDecoderOnlyModel(RBLNModel, RBLNDecoderOnlyFlashAttentionMixin):
         input_ids: Optional[torch.LongTensor] = None,
         inputs_embeds: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.LongTensor] = None,
+        position_ids: Optional[torch.Tensor] = None,
         position_embed: Optional[torch.Tensor] = None,
+        output_hidden_states: Optional[bool] = None,
         **kwargs,
-    ) ->
+    ) -> BaseModelOutputWithPast:
+        """
+        Args:
+            input_ids (torch.LongTensor, optional): The input IDs to the model.
+            inputs_embeds (torch.Tensor, optional): The input embeddings to the model.
+            attention_mask (torch.LongTensor, optional): The attention mask to the model.
+            kwargs (dict[str, Any], optional): Additional keyword arguments.
+
+        Returns:
+            Dataclass containing the last hidden states of the model.
+        """
         inputs = inputs_embeds if inputs_embeds is not None else input_ids
         batch_size = inputs.shape[0]
+        position_embed = kwargs.get("position_embed", None)
 
         if batch_size != self.rbln_config.batch_size:
             raise ValueError(
                 f"Batch size ({batch_size}) must be equal to the batch size of the model ({self.rbln_config.batch_size})."
             )
+        output_hidden_states = _validate_output_hidden_states(output_hidden_states, self.rbln_config)
 
         all_last_hidden_states = []
+        all_hidden_states = (
+            tuple(
+                torch.zeros(
+                    self.rbln_config.batch_size,
+                    inputs.shape[1],
+                    self.config.hidden_size,
+                    dtype=self.rbln_config.dtype,
+                )
+                for _ in range(self.config.num_hidden_layers + 1)
+            )
+            if output_hidden_states
+            else None
+        )
         for b_idx in range(self.rbln_config.batch_size):
             query_length = (
                 attention_mask[b_idx].sum(dim=-1).int().item() if attention_mask is not None else inputs.shape[1]
             )
             cache_position = torch.arange(query_length, dtype=torch.int32).unsqueeze(0)
-
-                inputs[b_idx : b_idx + 1],
+            outputs = self.prefill_decoder(
+                input_ids=inputs[b_idx : b_idx + 1] if inputs_embeds is None else None,
+                inputs_embeds=inputs[b_idx : b_idx + 1] if inputs_embeds is not None else None,
                 attention_mask=attention_mask[b_idx] if attention_mask is not None else None,
+                position_ids=position_ids[b_idx : b_idx + 1] if position_ids is not None else None,
                 position_embed=position_embed[b_idx : b_idx + 1] if position_embed is not None else None,
                 cache_position=cache_position,
                 batch_idx=b_idx,
-            )
-            all_last_hidden_states.append(
+            )
+            all_last_hidden_states.append(outputs.logits)
+            if self.rbln_config.output_hidden_states:
+                for l_idx in range(self.config.num_hidden_layers + 1):
+                    all_hidden_states[l_idx][b_idx].copy_(outputs.hidden_states[l_idx][0])
 
         last_hidden_states = torch.concat(all_last_hidden_states, dim=0)
-        return BaseModelOutputWithPast(last_hidden_state=last_hidden_states)
+        return BaseModelOutputWithPast(last_hidden_state=last_hidden_states, hidden_states=all_hidden_states)
 
 
 class RBLNDecoderOnlyModelForCausalLM(RBLNDecoderOnlyModel, RBLNDecoderOnlyGenerationMixin):
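From a caller's perspective, the new `output_hidden_states` path looks like this; `model` is assumed to be an already-compiled `RBLNDecoderOnlyModel` whose `rbln_config` enables hidden-state output, and shapes are illustrative:

```python
import torch

# `model` assumed already loaded; batch must equal the compiled batch_size.
input_ids = torch.ones(1, 16, dtype=torch.int64)

out = model.forward(input_ids=input_ids, output_hidden_states=True)
print(out.last_hidden_state.shape)  # [batch, seq, hidden_size]
print(len(out.hidden_states))       # num_hidden_layers + 1 (embedding layer included)
```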
@@ -648,6 +678,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNDecoderOnlyModel, RBLNDecoderOnlyGenerationMixin):
     1. Converting pre-trained transformer models to RBLN-optimized format
     2. Handling the compilation process for RBLN devices
     3. Managing inference operations for causal language modeling
+
     This class inherits from RBLNModel and implements specific methods required for
     decoder-only architectures and causal language modeling tasks.
 
@@ -661,16 +692,8 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNDecoderOnlyModel, RBLNDecoderOnlyGenerationMixin):
     auto_model_class = AutoModelForCausalLM
 
     @property
-    def
-        return
-            1,
-            self.rbln_config.prefill_chunk_size if self.rbln_config.logits_to_keep == 0 else 1,
-            self.config.vocab_size,
-        )
-
-    @classmethod
-    def use_query_position(cls, use_local_attention: bool, is_prefill: bool = True):
-        return is_prefill
+    def logits_last_dim(self):
+        return self.config.vocab_size
 
     def set_lora_int_ids(self, lora_int_ids: Optional[torch.Tensor]):
         if isinstance(lora_int_ids, int):
@@ -731,6 +754,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNDecoderOnlyModel, RBLNDecoderOnlyGenerationMixin):
         token_type_ids: Optional[torch.Tensor] = None,
         lora_int_ids: Optional[torch.Tensor] = None,
         return_dict: Optional[torch.Tensor] = None,
+        output_hidden_states: Optional[bool] = None,
         **kwargs,
     ) -> Tuple[torch.FloatTensor]:
         # Forward method for the RBLN-optimized model, designed for integration with the HuggingFace generate API.
@@ -754,24 +778,48 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNDecoderOnlyModel, RBLNDecoderOnlyGenerationMixin):
         )
         padded_cache_lengths = torch.zeros_like(generate_idx)
 
+        output_hidden_states = _validate_output_hidden_states(output_hidden_states, self.rbln_config)
+
         # Prefill
         if cache_position is None:
             logits = []
             inputs = inputs_embeds if inputs_embeds is not None else input_ids
             batch_size = inputs.shape[0]
+            input_len = inputs.shape[1]
+            if batch_size > self.rbln_config.batch_size:
+                raise ValueError(
+                    f"Input's batch({batch_size}) exceeds compiled batch_size({self.rbln_config.batch_size})"
+                )
+            if input_len > self.rbln_config.max_seq_len:
+                raise ValueError(
+                    f"Input's length({input_len}) exceeds compiled max_seq_len({self.rbln_config.max_seq_len})."
+                )
+
+            all_hidden_states = (
+                tuple(
+                    torch.zeros(batch_size, input_len, self.config.hidden_size, dtype=self.rbln_config.dtype)
+                    for _ in range(self.config.num_hidden_layers + 1)
+                )
+                if self.rbln_config.output_hidden_states
+                else None
+            )
             for b_idx in range(batch_size):
                 cache_position = torch.arange(0, generate_idx[b_idx].item(), dtype=torch.int32).unsqueeze(0)
-
+                outputs = self.prefill_decoder(
                     input_ids=inputs[b_idx : b_idx + 1] if inputs_embeds is None else None,
                     inputs_embeds=inputs[b_idx : b_idx + 1] if inputs_embeds is not None else None,
                     attention_mask=attention_mask[b_idx] if attention_mask is not None else None,
+                    position_ids=position_ids[b_idx : b_idx + 1] if position_ids is not None else None,
                     cache_position=cache_position,
                     batch_idx=b_idx,
                     token_type_ids=token_type_ids[b_idx : b_idx + 1] if token_type_ids is not None else None,
                     lora_int_ids=lora_int_ids[b_idx : b_idx + 1] if lora_int_ids is not None else None,
                 )
-                padded_cache_lengths[b_idx] +=
-                logits.append(
+                padded_cache_lengths[b_idx] += outputs.padded_cache_lengths
+                logits.append(outputs.logits)
+                if self.rbln_config.output_hidden_states:
+                    for l_idx in range(self.config.num_hidden_layers + 1):
+                        all_hidden_states[l_idx][b_idx].copy_(outputs.hidden_states[l_idx][0])
             logits = torch.cat(logits, dim=0)
         # Decoder
         else:
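The prefill path now fails fast instead of silently misbehaving when inputs exceed the compiled shapes. A minimal reproduction of the two guards, with toy numbers:

```python
def validate_prefill(batch_size, input_len, compiled_batch_size=4, max_seq_len=8192):
    # Mirrors the two ValueError guards added to the prefill branch above.
    if batch_size > compiled_batch_size:
        raise ValueError(f"Input's batch({batch_size}) exceeds compiled batch_size({compiled_batch_size})")
    if input_len > max_seq_len:
        raise ValueError(f"Input's length({input_len}) exceeds compiled max_seq_len({max_seq_len}).")

validate_prefill(batch_size=2, input_len=1024)   # fine
# validate_prefill(batch_size=8, input_len=1024) # would raise: batch too large
```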
@@ -783,17 +831,31 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNDecoderOnlyModel, RBLNDecoderOnlyGenerationMixin):
                     f"Available batch sizes are: {list(self.decoders.keys())}. "
                     f"Please run your model with one of these batch sizes or add support for batch size {batch_size}."
                 )
-
+            if max(cache_position.reshape(-1)) >= self.rbln_config.max_seq_len:
+                raise ValueError(
+                    f"Cache position exceeds the maximum sequence length.\n"
+                    f" - Current max cache position: {int(torch.max(cache_position).item())}\n"
+                    f" - Allowed max_seq_len: {self.rbln_config.max_seq_len}\n"
+                    f"Solution: Reduce the generation length by adjusting `max_new_tokens` "
+                    f"or `max_length` in the generation config."
+                )
+
+            outputs = self.decoders[batch_size](
                 input_ids=input_ids,
                 inputs_embeds=inputs_embeds,
                 cache_position=cache_position,
                 position_ids=position_ids if self.rbln_config.use_position_ids else None,
                 lora_int_ids=lora_int_ids,
-            )
+            )
+            logits = outputs.logits
+            all_hidden_states = outputs.hidden_states
 
         if not return_dict:
-            return logits, generate_idx, padded_cache_lengths
+            return logits, generate_idx, padded_cache_lengths, all_hidden_states
         else:
             return RBLNDecoderOnlyOutput(
-                logits=logits,
+                logits=logits,
+                generate_idx=generate_idx,
+                padded_cache_lengths=padded_cache_lengths,
+                hidden_states=all_hidden_states,
             )
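On the user side, the decode-phase guard surfaces as a `ValueError` once generation pushes the cache past the compiled `max_seq_len`. A hedged sketch; `model` and `tokenizer` are assumed already created, and reading the bound from `model.rbln_config.max_seq_len` mirrors the attribute used internally:

```python
# `model` and `tokenizer` assumed already created; illustrative only.
inputs = tokenizer("Hello", return_tensors="pt")
prompt_len = inputs.input_ids.shape[1]

# Keep prompt + new tokens within the compiled bound to avoid the ValueError above.
max_new = model.rbln_config.max_seq_len - prompt_len
output_ids = model.generate(**inputs, max_new_tokens=max_new)
```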
--- a/optimum/rbln/transformers/models/depth_anything/modeling_depth_anything.py
+++ b/optimum/rbln/transformers/models/depth_anything/modeling_depth_anything.py
@@ -13,6 +13,11 @@
 # limitations under the License.
 
 
+from typing import Tuple, Union
+
+import torch
+from transformers.modeling_outputs import DepthEstimatorOutput
+
 from ...modeling_generic import RBLNModelForDepthEstimation
 
 
@@ -23,3 +28,15 @@ class RBLNDepthAnythingForDepthEstimation(RBLNModelForDepthEstimation):
     This class provides hardware-accelerated inference for Depth Anything V2
     models on RBLN devices, providing the most capable monocular depth estimation (MDE) model.
     """
+
+    def forward(self, pixel_values: torch.Tensor, **kwargs) -> Union[Tuple, DepthEstimatorOutput]:
+        """
+        Forward pass for the RBLN-optimized DepthAnythingForDepthEstimation model.
+
+        Args:
+            pixel_values (torch.FloatTensor of shape (batch_size, num_channels, height, width)): The tensors corresponding to the input images.
+
+        Returns:
+            The model outputs. If return_dict=False is passed, returns a tuple of tensors. Otherwise, returns a DepthEstimatorOutput object.
+        """
+        return super().forward(pixel_values, **kwargs)
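A hedged end-to-end sketch for the documented forward; the checkpoint name and input size are illustrative, and the `from_pretrained(..., export=True)` compile flow follows the usual optimum pattern:

```python
import torch
from optimum.rbln import RBLNDepthAnythingForDepthEstimation

model_id = "depth-anything/Depth-Anything-V2-Small-hf"  # illustrative checkpoint
model = RBLNDepthAnythingForDepthEstimation.from_pretrained(model_id, export=True)

# In practice pixel_values comes from AutoImageProcessor; a random tensor stands in here.
pixel_values = torch.rand(1, 3, 518, 518)  # 518x518 is an assumed compile-time size
outputs = model(pixel_values)
print(outputs.predicted_depth.shape)
```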
--- a/optimum/rbln/transformers/models/distilbert/modeling_distilbert.py
+++ b/optimum/rbln/transformers/models/distilbert/modeling_distilbert.py
@@ -12,6 +12,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from typing import Optional, Tuple, Union
+
+import torch
+from transformers.modeling_outputs import QuestionAnsweringModelOutput
+
 from ...modeling_generic import RBLNModelForQuestionAnswering
 
 
@@ -25,3 +30,22 @@ class RBLNDistilBertForQuestionAnswering(RBLNModelForQuestionAnswering):
     """
 
     rbln_model_input_names = ["input_ids", "attention_mask"]
+
+    def forward(
+        self,
+        input_ids: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        **kwargs,
+    ) -> Union[Tuple, QuestionAnsweringModelOutput]:
+        """
+        Forward pass for the RBLN-optimized DistilBERT model for question answering tasks.
+
+        Args:
+            input_ids (torch.Tensor of shape (batch_size, sequence_length), optional): Indices of input sequence tokens in the vocabulary.
+            attention_mask (torch.Tensor of shape (batch_size, sequence_length), optional): Mask to avoid performing attention on padding token indices.
+
+        Returns:
+            The model outputs. If return_dict=False is passed, returns a tuple of tensors. Otherwise, returns a QuestionAnsweringModelOutput object.
+        """
+
+        return super().forward(input_ids, attention_mask, **kwargs)
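A similar hedged sketch for the QA head; the checkpoint is illustrative, and inputs are padded to a fixed length since RBLN graphs are compiled with static shapes:

```python
import torch
from transformers import AutoTokenizer
from optimum.rbln import RBLNDistilBertForQuestionAnswering

model_id = "distilbert-base-cased-distilled-squad"  # illustrative checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = RBLNDistilBertForQuestionAnswering.from_pretrained(model_id, export=True)

# Pad to the compiled sequence length; 384 here is an assumed compile-time setting.
enc = tokenizer(
    "Who wrote the book?",
    "The book was written by Jane.",
    return_tensors="pt",
    padding="max_length",
    max_length=384,
)
out = model(enc.input_ids, enc.attention_mask)
start, end = int(out.start_logits.argmax()), int(out.end_logits.argmax())
print(tokenizer.decode(enc.input_ids[0, start : end + 1]))
```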
--- a/optimum/rbln/transformers/models/dpt/modeling_dpt.py
+++ b/optimum/rbln/transformers/models/dpt/modeling_dpt.py
@@ -13,6 +13,11 @@
 # limitations under the License.
 
 
+from typing import Tuple, Union
+
+import torch
+from transformers.modeling_outputs import DepthEstimatorOutput
+
 from ...modeling_generic import RBLNModelForDepthEstimation
 
 
@@ -23,3 +28,15 @@ class RBLNDPTForDepthEstimation(RBLNModelForDepthEstimation):
     This class provides hardware-accelerated inference for DPT (Dense Prediction Transformer)
     models on RBLN devices, supporting monocular depth estimation from single images.
     """
+
+    def forward(self, pixel_values: torch.Tensor, **kwargs) -> Union[Tuple, DepthEstimatorOutput]:
+        """
+        Forward pass for the RBLN-optimized DPT model.
+
+        Args:
+            pixel_values (torch.FloatTensor of shape (batch_size, num_channels, image_size, image_size)): The tensors corresponding to the input images.
+
+        Returns:
+            The model outputs. If return_dict=False is passed, returns a tuple of tensors. Otherwise, returns a DepthEstimatorOutput object.
+        """
+        return super().forward(pixel_values, **kwargs)
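To turn the raw `predicted_depth` into a full-resolution map, the standard transformers recipe applies unchanged; a sketch assuming a model compiled for 384x384 inputs:

```python
import torch

# `model` assumed to be an RBLNDPTForDepthEstimation compiled for 384x384 inputs.
pixel_values = torch.rand(1, 3, 384, 384)
outputs = model(pixel_values)

# Upsample the depth map back to the original image resolution (e.g. 480x640).
depth = torch.nn.functional.interpolate(
    outputs.predicted_depth.unsqueeze(1),
    size=(480, 640),
    mode="bicubic",
    align_corners=False,
).squeeze(1)
print(depth.shape)  # torch.Size([1, 480, 640])
```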