optimum-rbln 0.9.3.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of optimum-rbln might be problematic. Click here for more details.
- optimum/rbln/__init__.py +505 -0
- optimum/rbln/__version__.py +34 -0
- optimum/rbln/cli.py +660 -0
- optimum/rbln/configuration_utils.py +968 -0
- optimum/rbln/diffusers/__init__.py +198 -0
- optimum/rbln/diffusers/configurations/__init__.py +37 -0
- optimum/rbln/diffusers/configurations/models/__init__.py +10 -0
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +73 -0
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_cosmos.py +84 -0
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_temporal_decoder.py +67 -0
- optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +64 -0
- optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +59 -0
- optimum/rbln/diffusers/configurations/models/configuration_transformer_cosmos.py +78 -0
- optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +63 -0
- optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +81 -0
- optimum/rbln/diffusers/configurations/models/configuration_unet_spatio_temporal_condition.py +59 -0
- optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +74 -0
- optimum/rbln/diffusers/configurations/pipelines/__init__.py +34 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +316 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_cosmos.py +117 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +363 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +156 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +176 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +159 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_video_diffusion.py +114 -0
- optimum/rbln/diffusers/modeling_diffusers.py +451 -0
- optimum/rbln/diffusers/models/__init__.py +64 -0
- optimum/rbln/diffusers/models/autoencoders/__init__.py +18 -0
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +255 -0
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py +245 -0
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +275 -0
- optimum/rbln/diffusers/models/autoencoders/vae.py +178 -0
- optimum/rbln/diffusers/models/autoencoders/vq_model.py +211 -0
- optimum/rbln/diffusers/models/controlnet.py +281 -0
- optimum/rbln/diffusers/models/transformers/__init__.py +17 -0
- optimum/rbln/diffusers/models/transformers/prior_transformer.py +160 -0
- optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +344 -0
- optimum/rbln/diffusers/models/transformers/transformer_sd3.py +191 -0
- optimum/rbln/diffusers/models/unets/__init__.py +16 -0
- optimum/rbln/diffusers/models/unets/unet_2d_condition.py +408 -0
- optimum/rbln/diffusers/models/unets/unet_spatio_temporal_condition.py +201 -0
- optimum/rbln/diffusers/pipelines/__init__.py +113 -0
- optimum/rbln/diffusers/pipelines/auto_pipeline.py +307 -0
- optimum/rbln/diffusers/pipelines/controlnet/__init__.py +19 -0
- optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +139 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +669 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +640 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +825 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +837 -0
- optimum/rbln/diffusers/pipelines/cosmos/__init__.py +17 -0
- optimum/rbln/diffusers/pipelines/cosmos/configuration_cosmos_guardrail.py +113 -0
- optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +425 -0
- optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +128 -0
- optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +128 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/__init__.py +23 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +34 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +207 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +34 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpaint.py +34 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/__init__.py +17 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +32 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/__init__.py +17 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/__init__.py +17 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_video_diffusion/__init__.py +15 -0
- optimum/rbln/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +46 -0
- optimum/rbln/modeling.py +364 -0
- optimum/rbln/modeling_base.py +637 -0
- optimum/rbln/ops/__init__.py +19 -0
- optimum/rbln/ops/attn.py +455 -0
- optimum/rbln/ops/flash_attn.py +350 -0
- optimum/rbln/ops/kv_cache_update.py +29 -0
- optimum/rbln/ops/linear.py +32 -0
- optimum/rbln/ops/sliding_window_attn.py +111 -0
- optimum/rbln/transformers/__init__.py +340 -0
- optimum/rbln/transformers/configuration_generic.py +120 -0
- optimum/rbln/transformers/modeling_attention_utils.py +385 -0
- optimum/rbln/transformers/modeling_generic.py +280 -0
- optimum/rbln/transformers/modeling_outputs.py +37 -0
- optimum/rbln/transformers/modeling_rope_utils.py +314 -0
- optimum/rbln/transformers/models/__init__.py +343 -0
- optimum/rbln/transformers/models/audio_spectrogram_transformer/__init__.py +17 -0
- optimum/rbln/transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py +47 -0
- optimum/rbln/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +91 -0
- optimum/rbln/transformers/models/auto/__init__.py +31 -0
- optimum/rbln/transformers/models/auto/auto_factory.py +267 -0
- optimum/rbln/transformers/models/auto/modeling_auto.py +162 -0
- optimum/rbln/transformers/models/bart/__init__.py +17 -0
- optimum/rbln/transformers/models/bart/bart_architecture.py +163 -0
- optimum/rbln/transformers/models/bart/configuration_bart.py +36 -0
- optimum/rbln/transformers/models/bart/modeling_bart.py +86 -0
- optimum/rbln/transformers/models/bert/__init__.py +16 -0
- optimum/rbln/transformers/models/bert/bert_architecture.py +16 -0
- optimum/rbln/transformers/models/bert/configuration_bert.py +46 -0
- optimum/rbln/transformers/models/bert/modeling_bert.py +148 -0
- optimum/rbln/transformers/models/blip_2/__init__.py +20 -0
- optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +115 -0
- optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +526 -0
- optimum/rbln/transformers/models/clip/__init__.py +26 -0
- optimum/rbln/transformers/models/clip/configuration_clip.py +103 -0
- optimum/rbln/transformers/models/clip/modeling_clip.py +384 -0
- optimum/rbln/transformers/models/colpali/__init__.py +2 -0
- optimum/rbln/transformers/models/colpali/colpali_architecture.py +218 -0
- optimum/rbln/transformers/models/colpali/configuration_colpali.py +84 -0
- optimum/rbln/transformers/models/colpali/modeling_colpali.py +361 -0
- optimum/rbln/transformers/models/colqwen2/__init__.py +2 -0
- optimum/rbln/transformers/models/colqwen2/colqwen2_architecture.py +233 -0
- optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +74 -0
- optimum/rbln/transformers/models/colqwen2/modeling_colqwen2.py +446 -0
- optimum/rbln/transformers/models/decoderonly/__init__.py +27 -0
- optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +300 -0
- optimum/rbln/transformers/models/decoderonly/configuration_lora.py +411 -0
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +1224 -0
- optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +508 -0
- optimum/rbln/transformers/models/decoderonly/generation_decoderonly.py +119 -0
- optimum/rbln/transformers/models/decoderonly/lora_architecture.py +204 -0
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +823 -0
- optimum/rbln/transformers/models/depth_anything/__init__.py +16 -0
- optimum/rbln/transformers/models/depth_anything/configuration_depth_anything.py +24 -0
- optimum/rbln/transformers/models/depth_anything/modeling_depth_anything.py +42 -0
- optimum/rbln/transformers/models/distilbert/__init__.py +19 -0
- optimum/rbln/transformers/models/distilbert/configuration_distilbert.py +24 -0
- optimum/rbln/transformers/models/distilbert/modeling_distilbert.py +51 -0
- optimum/rbln/transformers/models/dpt/__init__.py +16 -0
- optimum/rbln/transformers/models/dpt/configuration_dpt.py +24 -0
- optimum/rbln/transformers/models/dpt/modeling_dpt.py +42 -0
- optimum/rbln/transformers/models/exaone/__init__.py +24 -0
- optimum/rbln/transformers/models/exaone/configuration_exaone.py +42 -0
- optimum/rbln/transformers/models/exaone/exaone_architecture.py +77 -0
- optimum/rbln/transformers/models/exaone/modeling_exaone.py +145 -0
- optimum/rbln/transformers/models/gemma/__init__.py +16 -0
- optimum/rbln/transformers/models/gemma/configuration_gemma.py +50 -0
- optimum/rbln/transformers/models/gemma/gemma_architecture.py +27 -0
- optimum/rbln/transformers/models/gemma/modeling_gemma.py +104 -0
- optimum/rbln/transformers/models/gemma3/__init__.py +16 -0
- optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +109 -0
- optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +170 -0
- optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py +245 -0
- optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +611 -0
- optimum/rbln/transformers/models/gpt2/__init__.py +16 -0
- optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +50 -0
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +93 -0
- optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +55 -0
- optimum/rbln/transformers/models/grounding_dino/__init__.py +10 -0
- optimum/rbln/transformers/models/grounding_dino/configuration_grounding_dino.py +92 -0
- optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +599 -0
- optimum/rbln/transformers/models/grounding_dino/modeling_grounding_dino.py +1048 -0
- optimum/rbln/transformers/models/idefics3/__init__.py +16 -0
- optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +89 -0
- optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +497 -0
- optimum/rbln/transformers/models/llama/__init__.py +16 -0
- optimum/rbln/transformers/models/llama/configuration_llama.py +50 -0
- optimum/rbln/transformers/models/llama/llama_architecture.py +19 -0
- optimum/rbln/transformers/models/llama/modeling_llama.py +104 -0
- optimum/rbln/transformers/models/llava/__init__.py +16 -0
- optimum/rbln/transformers/models/llava/configuration_llava.py +72 -0
- optimum/rbln/transformers/models/llava/modeling_llava.py +490 -0
- optimum/rbln/transformers/models/llava_next/__init__.py +16 -0
- optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +69 -0
- optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +493 -0
- optimum/rbln/transformers/models/midm/__init__.py +24 -0
- optimum/rbln/transformers/models/midm/configuration_midm.py +42 -0
- optimum/rbln/transformers/models/midm/midm_architecture.py +144 -0
- optimum/rbln/transformers/models/midm/modeling_midm.py +144 -0
- optimum/rbln/transformers/models/mistral/__init__.py +16 -0
- optimum/rbln/transformers/models/mistral/configuration_mistral.py +50 -0
- optimum/rbln/transformers/models/mistral/mistral_architecture.py +19 -0
- optimum/rbln/transformers/models/mistral/modeling_mistral.py +115 -0
- optimum/rbln/transformers/models/opt/__init__.py +16 -0
- optimum/rbln/transformers/models/opt/configuration_opt.py +29 -0
- optimum/rbln/transformers/models/opt/modeling_opt.py +102 -0
- optimum/rbln/transformers/models/opt/opt_architecture.py +74 -0
- optimum/rbln/transformers/models/pegasus/__init__.py +17 -0
- optimum/rbln/transformers/models/pegasus/configuration_pegasus.py +38 -0
- optimum/rbln/transformers/models/pegasus/modeling_pegasus.py +71 -0
- optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +161 -0
- optimum/rbln/transformers/models/phi/__init__.py +16 -0
- optimum/rbln/transformers/models/phi/configuration_phi.py +50 -0
- optimum/rbln/transformers/models/phi/modeling_phi.py +92 -0
- optimum/rbln/transformers/models/phi/phi_architecture.py +115 -0
- optimum/rbln/transformers/models/pixtral/__init__.py +16 -0
- optimum/rbln/transformers/models/pixtral/configuration_pixtral.py +43 -0
- optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +322 -0
- optimum/rbln/transformers/models/pixtral/pixtral_architecture.py +73 -0
- optimum/rbln/transformers/models/qwen2/__init__.py +16 -0
- optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +50 -0
- optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +123 -0
- optimum/rbln/transformers/models/qwen2/qwen2_architecture.py +19 -0
- optimum/rbln/transformers/models/qwen2_5_vl/__init__.py +19 -0
- optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +111 -0
- optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +636 -0
- optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +220 -0
- optimum/rbln/transformers/models/qwen2_vl/__init__.py +19 -0
- optimum/rbln/transformers/models/qwen2_vl/configuration_qwen2_vl.py +88 -0
- optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py +513 -0
- optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +165 -0
- optimum/rbln/transformers/models/qwen3/__init__.py +16 -0
- optimum/rbln/transformers/models/qwen3/configuration_qwen3.py +71 -0
- optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +133 -0
- optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +31 -0
- optimum/rbln/transformers/models/resnet/__init__.py +23 -0
- optimum/rbln/transformers/models/resnet/configuration_resnet.py +42 -0
- optimum/rbln/transformers/models/resnet/modeling_resnet.py +99 -0
- optimum/rbln/transformers/models/roberta/__init__.py +24 -0
- optimum/rbln/transformers/models/roberta/configuration_roberta.py +33 -0
- optimum/rbln/transformers/models/roberta/modeling_roberta.py +72 -0
- optimum/rbln/transformers/models/seq2seq/__init__.py +16 -0
- optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py +71 -0
- optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +477 -0
- optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +527 -0
- optimum/rbln/transformers/models/siglip/__init__.py +16 -0
- optimum/rbln/transformers/models/siglip/configuration_siglip.py +76 -0
- optimum/rbln/transformers/models/siglip/modeling_siglip.py +199 -0
- optimum/rbln/transformers/models/swin/__init__.py +16 -0
- optimum/rbln/transformers/models/swin/configuration_swin.py +42 -0
- optimum/rbln/transformers/models/swin/modeling_swin.py +354 -0
- optimum/rbln/transformers/models/t5/__init__.py +17 -0
- optimum/rbln/transformers/models/t5/configuration_t5.py +36 -0
- optimum/rbln/transformers/models/t5/modeling_t5.py +130 -0
- optimum/rbln/transformers/models/t5/t5_architecture.py +264 -0
- optimum/rbln/transformers/models/time_series_transformer/__init__.py +26 -0
- optimum/rbln/transformers/models/time_series_transformer/configuration_time_series_transformer.py +41 -0
- optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +435 -0
- optimum/rbln/transformers/models/time_series_transformer/time_series_transformers_architecture.py +337 -0
- optimum/rbln/transformers/models/vit/__init__.py +19 -0
- optimum/rbln/transformers/models/vit/configuration_vit.py +24 -0
- optimum/rbln/transformers/models/vit/modeling_vit.py +44 -0
- optimum/rbln/transformers/models/wav2vec2/__init__.py +16 -0
- optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec2.py +38 -0
- optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +104 -0
- optimum/rbln/transformers/models/whisper/__init__.py +17 -0
- optimum/rbln/transformers/models/whisper/configuration_whisper.py +72 -0
- optimum/rbln/transformers/models/whisper/generation_whisper.py +159 -0
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +475 -0
- optimum/rbln/transformers/models/whisper/whisper_architecture.py +349 -0
- optimum/rbln/transformers/models/xlm_roberta/__init__.py +24 -0
- optimum/rbln/transformers/models/xlm_roberta/configuration_xlm_roberta.py +32 -0
- optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +82 -0
- optimum/rbln/transformers/utils/__init__.py +0 -0
- optimum/rbln/transformers/utils/rbln_quantization.py +589 -0
- optimum/rbln/transformers/utils/rbln_runtime_wrapper.py +79 -0
- optimum/rbln/utils/__init__.py +16 -0
- optimum/rbln/utils/decorator_utils.py +86 -0
- optimum/rbln/utils/deprecation.py +213 -0
- optimum/rbln/utils/hub.py +94 -0
- optimum/rbln/utils/import_utils.py +170 -0
- optimum/rbln/utils/logging.py +110 -0
- optimum/rbln/utils/model_utils.py +63 -0
- optimum/rbln/utils/runtime_utils.py +249 -0
- optimum/rbln/utils/save_utils.py +102 -0
- optimum/rbln/utils/submodule.py +152 -0
- optimum_rbln-0.9.3.post1.dist-info/METADATA +124 -0
- optimum_rbln-0.9.3.post1.dist-info/RECORD +264 -0
- optimum_rbln-0.9.3.post1.dist-info/WHEEL +4 -0
- optimum_rbln-0.9.3.post1.dist-info/entry_points.txt +2 -0
- optimum_rbln-0.9.3.post1.dist-info/licenses/LICENSE +201 -0
|
@@ -0,0 +1,477 @@
|
|
|
1
|
+
# Copyright 2025 Rebellions Inc. All rights reserved.
|
|
2
|
+
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at:
|
|
6
|
+
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
import inspect
|
|
16
|
+
from abc import ABC
|
|
17
|
+
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
|
|
18
|
+
|
|
19
|
+
import rebel
|
|
20
|
+
import torch
|
|
21
|
+
from rebel.compile_context import CompileContext
|
|
22
|
+
from transformers import AutoModelForSeq2SeqLM, PretrainedConfig, PreTrainedModel
|
|
23
|
+
from transformers.generation.configuration_utils import GenerationConfig
|
|
24
|
+
from transformers.generation.utils import GenerationMixin
|
|
25
|
+
from transformers.modeling_outputs import BaseModelOutput, ModelOutput, Seq2SeqLMOutput
|
|
26
|
+
|
|
27
|
+
from ....configuration_utils import RBLNCompileConfig
|
|
28
|
+
from ....modeling import RBLNModel
|
|
29
|
+
from ....utils.logging import get_logger
|
|
30
|
+
from ....utils.runtime_utils import RBLNPytorchRuntime
|
|
31
|
+
from .configuration_seq2seq import RBLNModelForSeq2SeqLMConfig
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
# Module-level logger for this file, obtained from the package's logging helper.
logger = get_logger(__name__)
|
|
35
|
+
|
|
36
|
+
if TYPE_CHECKING:
|
|
37
|
+
from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer, PretrainedConfig
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
class RBLNRuntimeEncoder(RBLNPytorchRuntime):
    """Runtime wrapper for the compiled seq2seq encoder graph.

    Execution is delegated to ``RBLNPytorchRuntime.forward``; whatever the
    runtime returns is exposed as ``last_hidden_state`` of a
    ``BaseModelOutput`` so it can be consumed like a Hugging Face encoder.
    """

    mandatory_members = ["main_input_name"]

    def forward(self, *args: List[torch.Tensor], **kwargs: torch.Tensor):
        """Run the encoder runtime and wrap its result in ``BaseModelOutput``."""
        encoder_result = super().forward(*args, **kwargs)
        wrapped = BaseModelOutput(last_hidden_state=encoder_result)
        return wrapped
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
class RBLNRuntimeDecoder(RBLNPytorchRuntime):
    """Runtime wrapper for the compiled seq2seq decoder graph.

    Each call validates the batch against the compiled batch size. When an
    explicit decoder attention mask is in use, the call fills that mask up to
    each row's current decoding step. It supplies default block tables and
    wraps the runtime's output as the ``logits`` of a ``Seq2SeqLMOutput``.
    """

    mandatory_members = ["main_input_name"]

    def __init__(
        self,
        runtime: rebel.Runtime,
        batch_size: int,
        dec_max_seq_len: int,
        use_attention_mask: Optional[bool] = None,
        **kwargs: Any,
    ) -> None:
        """
        Args:
            runtime: Compiled decoder runtime.
            batch_size: Batch size the decoder graph was compiled with.
            dec_max_seq_len: Length of the decoder self-attention mask.
            use_attention_mask: When truthy, a decoder attention mask is built
                and passed to the runtime; otherwise ``None`` is passed.
            **kwargs: Forwarded to ``RBLNPytorchRuntime.__init__``
                (e.g. ``main_input_name``).
        """
        super().__init__(runtime, **kwargs)
        self.batch_size = batch_size
        self.dec_max_seq_len = dec_max_seq_len
        self.use_attention_mask = use_attention_mask
        # Shape (batch_size, 1): row i maps to block i when no block table is given.
        self.default_block_tables = torch.arange(0, self.batch_size, dtype=torch.int16).view(self.batch_size, 1)

    def forward(
        self,
        decoder_input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        decoder_attention_mask: Optional[torch.BoolTensor] = None,
        cache_position: Optional[torch.Tensor] = None,
        block_tables: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Seq2SeqLMOutput:
        """Run one decoder step on the compiled runtime.

        Args:
            decoder_input_ids: Decoder input ids; the first dimension must equal
                the compiled batch size.
            attention_mask: Encoder attention mask, passed through unchanged.
            decoder_attention_mask: Decoder self-attention mask. Only used when
                ``use_attention_mask`` is set, in which case it is updated in
                place so that positions ``0..cache_position[b]`` of row ``b``
                are 1. If omitted in that mode, a float32 zero mask of shape
                ``(batch_size, dec_max_seq_len)`` is allocated first.
            cache_position: Per-row decoding step; the first dimension must
                equal the batch size.
            block_tables: KV-cache block table; defaults to
                ``self.default_block_tables``.
            **kwargs: Ignored.

        Returns:
            ``Seq2SeqLMOutput`` whose ``logits`` is the runtime output.

        Raises:
            ValueError: if ``decoder_input_ids`` or ``cache_position`` is missing,
                or a decoding step falls outside ``[0, dec_max_seq_len)`` while
                ``use_attention_mask`` is set.
            RuntimeError: on batch-size mismatches.
        """
        if decoder_input_ids is None:
            raise ValueError("`decoder_input_ids` must be provided.")
        if cache_position is None:
            raise ValueError("`cache_position` must be provided.")

        batch_size = decoder_input_ids.shape[0]
        if batch_size != self.batch_size:
            raise RuntimeError(
                f"Batch size mismatch: got {batch_size}, expected {self.batch_size} (compiled batch size)."
            )

        if batch_size != cache_position.shape[0]:
            raise RuntimeError(f"Cache position size mismatch: got {cache_position.shape[0]}, expected {batch_size}.")

        if self.use_attention_mask:
            if decoder_attention_mask is None:
                # Previously this crashed on item assignment into None; start from an
                # all-zero mask matching the compiled (batch, dec_max_seq_len) float32 input.
                decoder_attention_mask = torch.zeros(self.batch_size, self.dec_max_seq_len, dtype=torch.float32)
            for b_idx in range(self.batch_size):
                decoding_step = cache_position[b_idx].item()
                if not (0 <= decoding_step < self.dec_max_seq_len):
                    raise ValueError(
                        f"Decoding step {decoding_step} out of bounds for decoder_max_seq_len ({self.dec_max_seq_len})."
                    )
                # Unmask every position up to and including the current step.
                decoder_attention_mask[b_idx, : decoding_step + 1] = 1

        if block_tables is None:
            block_tables = self.default_block_tables

        lm_logits = super().forward(
            decoder_input_ids,
            decoder_attention_mask if self.use_attention_mask else None,
            attention_mask,
            cache_position,
            block_tables=block_tables,
        )

        return Seq2SeqLMOutput(logits=lm_logits)
|
|
104
|
+
|
|
105
|
+
|
|
106
|
+
class RBLNModelForSeq2SeqLM(RBLNModel, GenerationMixin, ABC):
|
|
107
|
+
"""
|
|
108
|
+
This is a generic model class that will be instantiated as one of the model classes of the library (with a sequence-to-sequence language modeling head) when created with the from_pretrained() class method.
|
|
109
|
+
This model inherits from [`RBLNModel`]. Check the superclass documentation for the generic methods the library implements for all its models.
|
|
110
|
+
|
|
111
|
+
A class to convert and run pre-trained transformers based Seq2SeqLM models on RBLN devices.
|
|
112
|
+
It implements the methods to convert a pre-trained transformers Seq2SeqLM model into a RBLN transformer model by:
|
|
113
|
+
- transferring the checkpoint weights of the original into an optimized RBLN graph,
|
|
114
|
+
- compiling the resulting graph using the RBLN compiler.
|
|
115
|
+
|
|
116
|
+
Currently, this model class only supports the 'bart' and 't5' models from the transformers library. Future updates may include support for additional model types.
|
|
117
|
+
"""
|
|
118
|
+
|
|
119
|
+
main_input_name = "input_ids"
|
|
120
|
+
auto_model_class = AutoModelForSeq2SeqLM
|
|
121
|
+
support_causal_attn = None
|
|
122
|
+
_is_stateful = False
|
|
123
|
+
|
|
124
|
+
def __post_init__(self, **kwargs):
    """Wrap the loaded runtimes in encoder/decoder helper objects.

    ``self.model[0]`` becomes ``self.encoder`` and ``self.model[1]`` becomes
    ``self.decoder``. The decoder is configured with the batch size, decoder
    length and attention-mask mode read from ``self.rbln_config``.
    """
    cfg = self.rbln_config
    compiled_batch = cfg.batch_size
    decoder_len = cfg.dec_max_seq_len
    self.use_attention_mask = cfg.use_attention_mask

    # Runtime order matches _create_runtimes: index 0 is the encoder, index 1 the decoder.
    self.encoder = RBLNRuntimeEncoder(runtime=self.model[0], main_input_name="input_ids")
    self.decoder = RBLNRuntimeDecoder(
        runtime=self.model[1],
        main_input_name="input_ids",
        batch_size=compiled_batch,
        dec_max_seq_len=decoder_len,
        use_attention_mask=self.use_attention_mask,
    )
|
|
140
|
+
|
|
141
|
+
@classmethod
@torch.inference_mode()
def get_compiled_model(cls, model: PreTrainedModel, rbln_config: RBLNModelForSeq2SeqLMConfig):
    """Compile the encoder and decoder graphs of ``model``.

    Both graphs are compiled against a single ``CompileContext`` created with
    ``use_weight_sharing=False``. Every example input whose name contains
    ``"key_value_states"`` is marked as a static address in that context. The
    encoder's KV-state tensors are also handed to the decoder's dummy-input
    builder through ``static_tensors``; presumably this lets the decoder reuse
    the same tensors (TODO confirm against ``get_dummy_inputs``).

    Returns:
        dict: ``{"encoder": <compiled encoder>, "decoder": <compiled decoder>}``.
    """
    wrapped_model = cls._wrap_model_if_needed(model, rbln_config)
    enc_cfg = rbln_config.compile_cfgs[0]
    dec_cfg = rbln_config.compile_cfgs[1]

    shared_context = CompileContext(use_weight_sharing=False)

    enc_inputs = enc_cfg.get_dummy_inputs(fill=0)

    # Encoder KV-state inputs: keep them by name for the decoder and mark them static.
    enc_kv_tensors = {}
    for (input_name, _, _), example in zip(enc_cfg.input_info, enc_inputs):
        if "key_value_states" not in input_name:
            continue
        enc_kv_tensors[input_name] = example
        shared_context.mark_static_address(example)

    dec_inputs = dec_cfg.get_dummy_inputs(fill=0, static_tensors=enc_kv_tensors)

    # Decoder KV-state inputs (both the cross_ and self_ families) are marked static too.
    for (input_name, _, _), example in zip(dec_cfg.input_info, dec_inputs):
        if "key_value_states" in input_name:
            shared_context.mark_static_address(example)

    compiled_encoder = cls.compile(
        wrapped_model.encoder,
        enc_cfg,
        create_runtimes=rbln_config.create_runtimes,
        device=rbln_config.device,
        example_inputs=enc_inputs,
        compile_context=shared_context,
    )
    compiled_decoder = cls.compile(
        wrapped_model.decoder,
        dec_cfg,
        create_runtimes=rbln_config.create_runtimes,
        device=rbln_config.device,
        example_inputs=dec_inputs,
        compile_context=shared_context,
    )

    return {"encoder": compiled_encoder, "decoder": compiled_decoder}
|
|
186
|
+
|
|
187
|
+
@classmethod
def _update_paged_attention_config(cls, model_config: PretrainedConfig, rbln_config: RBLNModelForSeq2SeqLMConfig):
    """Default and validate the KV-cache paging settings on ``rbln_config``.

    A falsy ``kvcache_num_blocks`` defaults to ``batch_size``, and a falsy
    ``kvcache_block_size`` defaults to ``dec_max_seq_len``. Both are assigned
    before any check runs. ``model_config`` is accepted but not used here.

    Raises:
        NotImplementedError: if either setting differs from its required value.
    """
    # (setting, the rbln_config attribute it must equal)
    requirements = (
        ("kvcache_num_blocks", "batch_size"),
        ("kvcache_block_size", "dec_max_seq_len"),
    )

    for setting, reference in requirements:
        setattr(rbln_config, setting, getattr(rbln_config, setting) or getattr(rbln_config, reference))

    for setting, reference in requirements:
        current = getattr(rbln_config, setting)
        expected = getattr(rbln_config, reference)
        if current != expected:
            raise NotImplementedError(
                f"{setting} ({current}) must be equal to {reference} ({expected}) as flash attention is not supported yet."
            )
|
|
201
|
+
|
|
202
|
+
@classmethod
def _update_rbln_config(
    cls,
    preprocessors: Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"],
    model: Optional["PreTrainedModel"] = None,
    model_config: Optional["PretrainedConfig"] = None,
    rbln_config: Optional[RBLNModelForSeq2SeqLMConfig] = None,
) -> RBLNModelForSeq2SeqLMConfig:
    """Resolve sequence lengths and register encoder/decoder compile configs.

    Steps:
      1. Force ``use_attention_mask`` on when the class lacks causal-attention support.
      2. Read the decoder layer count, head count and per-head dim from ``model_config``.
      3. Resolve ``enc_max_seq_len`` / ``dec_max_seq_len``. An explicit value is used
         as-is. Otherwise the model's ``n_positions`` / ``max_position_embeddings`` is
         used, falling back to the first preprocessor's ``model_max_length``. Both
         lengths are then checked against the model limit.
      4. Apply paged-attention defaults when ``support_paged_attention`` is set.
      5. Build the encoder and decoder ``RBLNCompileConfig`` input signatures.

    ``model`` is accepted for interface compatibility and is not used here.

    Returns:
        The updated ``rbln_config``.

    Raises:
        ValueError: if a sequence length cannot be determined or exceeds
            ``max_position_embeddings``.
    """
    if not cls.support_causal_attn:
        # Without causal-attention support, always compile the decoder with an explicit mask input.
        rbln_config.use_attention_mask = True

    n_layer = getattr(model_config, "decoder_layers", None) or getattr(model_config, "num_layers")
    n_head = getattr(model_config, "decoder_attention_heads", None) or getattr(model_config, "num_heads")
    # Per-head dim of the decoder's KV caches. Without an explicit `d_kv`, derive it from the
    # *decoder* head count so that n_head * d_kv == d_model even when encoder and decoder
    # use different numbers of heads (the previous code divided by encoder_attention_heads).
    d_kv = model_config.d_kv if hasattr(model_config, "d_kv") else model_config.d_model // n_head

    max_position_embeddings = getattr(model_config, "n_positions", None) or getattr(
        model_config, "max_position_embeddings", None
    )

    if rbln_config.enc_max_seq_len is None:
        rbln_config.enc_max_seq_len = cls._infer_max_seq_len(
            "enc_max_seq_len", max_position_embeddings, preprocessors
        )
    if max_position_embeddings is not None and rbln_config.enc_max_seq_len > max_position_embeddings:
        raise ValueError("`enc_max_seq_len` should be less or equal than max_position_embeddings!")

    if rbln_config.dec_max_seq_len is None:
        rbln_config.dec_max_seq_len = cls._infer_max_seq_len(
            "dec_max_seq_len", max_position_embeddings, preprocessors
        )
    if max_position_embeddings is not None and rbln_config.dec_max_seq_len > max_position_embeddings:
        raise ValueError("`dec_max_seq_len` should be less or equal than max_position_embeddings!")

    if rbln_config.support_paged_attention:
        cls._update_paged_attention_config(model_config, rbln_config)

    batch_size = rbln_config.batch_size
    enc_max_seq_len = rbln_config.enc_max_seq_len
    dec_max_seq_len = rbln_config.dec_max_seq_len

    def kv_state_inputs(prefix: str, seq_len: int) -> List[Tuple[str, List[int], str]]:
        # n_layer * 2 entries named `{prefix}_key_value_states_{i}` (presumably a key and a
        # value tensor per layer), each of shape [batch_size, n_head, seq_len, d_kv].
        # A fresh shape list is built for every entry.
        return [
            (f"{prefix}_key_value_states_{i}", [batch_size, n_head, seq_len, d_kv], "float32")
            for i in range(n_layer * 2)
        ]

    enc_input_info = [
        ("input_ids", [1, enc_max_seq_len], "int64"),
        ("attention_mask", [1, enc_max_seq_len], "float32"),
        ("block_tables", [1], "int16"),
        *kv_state_inputs("cross", enc_max_seq_len),
    ]

    dec_input_info = [
        ("input_ids", [batch_size, 1], "int64"),
        ("encoder_attention_mask", [batch_size, enc_max_seq_len], "float32"),
        ("cache_position", [batch_size, 1], "int32"),
        ("block_tables", [batch_size, 1], "int16"),
        *kv_state_inputs("cross", enc_max_seq_len),
        *kv_state_inputs("self", dec_max_seq_len),
    ]
    if rbln_config.use_attention_mask:
        # The decoder self-attention mask goes right after input_ids, matching the
        # positional order in which RBLNRuntimeDecoder.forward calls the runtime.
        dec_input_info.insert(1, ("attention_mask", [batch_size, dec_max_seq_len], "float32"))

    enc_compile_config = RBLNCompileConfig(compiled_model_name="encoder", input_info=enc_input_info)
    dec_compile_config = RBLNCompileConfig(compiled_model_name="decoder", input_info=dec_input_info)

    rbln_config.set_compile_cfgs([enc_compile_config, dec_compile_config])

    return rbln_config

@staticmethod
def _infer_max_seq_len(name: str, max_position_embeddings: Optional[int], preprocessors: Any):
    """Pick a default value for the sequence-length setting ``name``.

    Uses ``max_position_embeddings`` when it is truthy. Otherwise it uses the
    ``model_max_length`` of the first preprocessor that defines one.
    ``preprocessors`` may be ``None``; it is only iterated when needed.

    Raises:
        ValueError: if neither source yields a value.
    """
    seq_len = max_position_embeddings
    if not seq_len:
        for preprocessor in preprocessors or []:
            if hasattr(preprocessor, "model_max_length"):
                seq_len = preprocessor.model_max_length
                break

    if seq_len is None:
        raise ValueError(f"`{name}` should be specified!")
    return seq_len
|
|
330
|
+
|
|
331
|
+
@classmethod
def _create_runtimes(
    cls,
    compiled_models: List[rebel.RBLNCompiledModel],
    rbln_config: RBLNModelForSeq2SeqLMConfig,
) -> List[rebel.Runtime]:
    """Instantiate the encoder and decoder runtimes from their compiled models.

    Args:
        compiled_models: Compiled graphs ordered as ``[encoder, decoder]``;
            entries beyond the second are ignored.
        rbln_config: Supplies ``device_map`` (which must contain both
            ``"encoder"`` and ``"decoder"``), ``activate_profiler`` and ``timeout``.

    Returns:
        ``[encoder_runtime, decoder_runtime]``, both created with ``tensor_type="pt"``.
    """
    runtime_names = ["encoder", "decoder"]

    # Both graphs need a device assignment; otherwise report the missing compiled files.
    has_missing_device = not all(name in rbln_config.device_map for name in runtime_names)
    if has_missing_device:
        cls._raise_missing_compiled_file_error(["encoder", "decoder"])

    runtimes = []
    for position, name in enumerate(runtime_names):
        # The position in `compiled_models` determines which graph is which.
        runtime = rebel.Runtime(
            compiled_models[position],
            tensor_type="pt",
            device=rbln_config.device_map[name],
            activate_profiler=rbln_config.activate_profiler,
            timeout=rbln_config.timeout,
        )
        runtimes.append(runtime)
    return runtimes
|
|
356
|
+
|
|
357
|
+
def can_generate(self):
    """Report that this model supports ``generate()``; always ``True``."""
    supports_generation = True
    return supports_generation
|
|
359
|
+
|
|
360
|
+
def get_encoder(self):
    """Return the encoder stored on ``self.encoder``."""
    encoder_module = self.encoder
    return encoder_module
|
|
362
|
+
|
|
363
|
+
def get_decoder(self):
    """Return the decoder stored on ``self.decoder``."""
    decoder_module = self.decoder
    return decoder_module
|
|
365
|
+
|
|
366
|
+
def prepare_inputs_for_generation(
    self,
    input_ids,
    attention_mask=None,
    decoder_attention_mask=None,
    **kwargs,
):
    """Build the single-step decoder inputs used by ``forward`` during generation.

    Only the most recent token is fed to the decoder. Its position in the
    decoder cache is ``cur_seq_len - 1``.

    Args:
        input_ids: Decoder token ids generated so far, shape ``(batch, cur_seq_len)``.
        attention_mask: Encoder attention mask; converted to ``float32``.
            It must not be ``None``, because ``.to`` is called on it.
        decoder_attention_mask: Ignored; rebuilt here with width ``dec_max_seq_len``.
        **kwargs: Ignored.

    Returns:
        A dict with ``decoder_input_ids`` (shape ``(batch, 1)``), ``attention_mask``,
        ``decoder_attention_mask`` (``float32``, ones for the first
        ``cur_seq_len`` columns) and ``cache_position`` (a Python int).

    Raises:
        ValueError: If ``cur_seq_len`` exceeds the compiled ``dec_max_seq_len``.
    """
    cur_seq_len = input_ids.shape[-1]
    max_seq_len = self.rbln_config.dec_max_seq_len
    # Without this check the mask slice below silently clamps to `max_seq_len`,
    # while `cache_position` would point past the end of the compiled decoder cache.
    if cur_seq_len > max_seq_len:
        raise ValueError(
            f"Decoder sequence length ({cur_seq_len}) exceeds the compiled `dec_max_seq_len` "
            f"({max_seq_len}). Reduce `max_length`/`max_new_tokens` or recompile with a larger "
            "`dec_max_seq_len`."
        )

    cache_position = cur_seq_len - 1
    decoder_batch_size = input_ids.shape[0]
    # Feed only the last generated token; earlier tokens are already cached.
    input_ids = input_ids[:, cur_seq_len - 1 : cur_seq_len].contiguous()
    decoder_attention_mask = torch.zeros(decoder_batch_size, max_seq_len, dtype=torch.float32)
    decoder_attention_mask[:, :cur_seq_len] = 1

    return {
        "decoder_input_ids": input_ids,
        "attention_mask": attention_mask.to(torch.float32),
        "decoder_attention_mask": decoder_attention_mask,
        "cache_position": cache_position,
    }
|
|
387
|
+
|
|
388
|
+
def forward(
    self,
    decoder_input_ids: torch.LongTensor = None,
    cache_position: Union[List[torch.Tensor], torch.Tensor] = None,
    **kwargs,
) -> Tuple[torch.FloatTensor]:
    """Run one decoder step and wrap its logits in a ``Seq2SeqLMOutput``.

    ``cache_position`` is broadcast into an ``int32`` tensor of shape
    ``(rbln_config.batch_size, 1)`` before the decoder is called. Within this
    file it is produced by ``prepare_inputs_for_generation`` as a Python int.
    All remaining keyword arguments are forwarded to ``self.decoder`` unchanged.
    """
    compiled_batch = self.rbln_config.batch_size
    # Every row in the batch shares the same write position in the decoder cache.
    position_tensor = torch.full((compiled_batch, 1), cache_position, dtype=torch.int32)

    decoder_output = self.decoder(
        decoder_input_ids=decoder_input_ids,
        cache_position=position_tensor,
        **kwargs,
    )
    step_logits = decoder_output.logits
    return Seq2SeqLMOutput(logits=step_logits)
|
|
401
|
+
|
|
402
|
+
def _prepare_encoder_decoder_kwargs_for_generation(
    self,
    inputs_tensor: torch.Tensor,
    model_kwargs,
    model_input_name: Optional[str] = None,
    generation_config: Optional["GenerationConfig"] = None,
) -> Dict[str, Any]:
    """Pad the encoder inputs to the compiled length and run the encoder once per batch row.

    The encoder is invoked separately for each row ``b``, with ``block_tables=[b]``
    (``int16``). The input ids and attention mask are right-padded to
    ``rbln_config.enc_max_seq_len``.

    Args:
        inputs_tensor: Encoder input ids, shape ``(batch, input_len)``.
        model_kwargs: Generation kwargs. They must contain ``"attention_mask"``,
            which is replaced in place by its padded version.
        model_input_name: Unused; kept for signature compatibility with HuggingFace.
        generation_config: Unused; kept for signature compatibility with HuggingFace.

    Returns:
        ``model_kwargs``, mutated in place. It receives the padded
        ``attention_mask`` and an ``encoder_outputs`` entry holding the
        encoder's return value for the last row.

    Raises:
        ValueError: If ``input_len`` exceeds ``rbln_config.enc_max_seq_len``.
            Previously the negative pad amount silently cropped the prompt.
    """
    # 1. get encoder
    encoder = self.get_encoder()

    # 2. Prepare encoder kwargs from model kwargs, dropping decoder/cache-specific entries.
    irrelevant_prefix = ("decoder_", "cross_attn", "use_cache")
    encoder_kwargs = {
        argument: value
        for argument, value in model_kwargs.items()
        if not argument.startswith(irrelevant_prefix)
    }
    encoder_signature = set(inspect.signature(encoder.forward).parameters)
    encoder_accepts_wildcard = "kwargs" in encoder_signature or "model_kwargs" in encoder_signature
    if not encoder_accepts_wildcard:
        encoder_kwargs = {
            argument: value for argument, value in encoder_kwargs.items() if argument in encoder_signature
        }

    batch_size, input_len = inputs_tensor.shape
    enc_max_seq_len = self.rbln_config.enc_max_seq_len
    # A negative pad amount makes `F.pad` crop, which would silently drop prompt tokens.
    if input_len > enc_max_seq_len:
        raise ValueError(
            f"Encoder input length ({input_len}) exceeds the compiled `enc_max_seq_len` "
            f"({enc_max_seq_len}). Truncate the input or recompile with a larger `enc_max_seq_len`."
        )
    pad_len = enc_max_seq_len - input_len
    inputs_tensor = torch.nn.functional.pad(
        inputs_tensor,
        (0, pad_len),
        value=self.config.pad_token_id,
    )
    # Later decoding steps reuse this padded mask as the encoder attention mask.
    model_kwargs["attention_mask"] = torch.nn.functional.pad(model_kwargs["attention_mask"], (0, pad_len))

    # 3. make sure that encoder returns `ModelOutput`
    encoder_kwargs["return_dict"] = True
    encoder_kwargs["output_hidden_states"] = False
    encoder_kwargs["output_attentions"] = False

    for b in range(batch_size):
        block_tables = torch.tensor([b], dtype=torch.int16)
        encoder_kwargs["input_ids"] = inputs_tensor[b].unsqueeze(0)
        encoder_kwargs["attention_mask"] = model_kwargs["attention_mask"][b].unsqueeze(0).to(torch.float32)
        # NOTE(review): each iteration overwrites `encoder_outputs`, so only the last row's
        # return value survives. This looks intentional if the encoder runtime keeps per-row
        # state selected by `block_tables` — confirm.
        model_kwargs["encoder_outputs"] = encoder(**encoder_kwargs, block_tables=block_tables)

    return model_kwargs
|
|
449
|
+
|
|
450
|
+
def generate(
    self,
    input_ids: torch.LongTensor,
    attention_mask: Optional[torch.LongTensor] = None,
    generation_config: Optional[GenerationConfig] = None,
    **kwargs,
) -> Union[ModelOutput, torch.LongTensor]:
    """
    Generate text with the standard HuggingFace transformers ``generate`` API.
    Check the [HuggingFace transformers documentation](https://huggingface.co/docs/transformers/v4.57.1/en/main_classes/text_generation#transformers.GenerationMixin.generate) for more details.

    Args:
        input_ids (torch.LongTensor): The input ids to the model.
        attention_mask (torch.LongTensor, optional): The attention mask to the model.
        generation_config (GenerationConfig, optional): The generation configuration used as the base parametrization for the generation call. Any **kwargs passed to generate that match attributes of generation_config override them.
            If generation_config is not provided, the default is used, with the following loading priority: 1) the generation_config.json model file, if it exists; 2) the model configuration.
            Unspecified parameters inherit the default values of [GenerationConfig](https://huggingface.co/docs/transformers/v4.57.1/en/main_classes/text_generation#transformers.GenerationConfig).
        kwargs (dict[str, Any], optional): Additional arguments passed to the generate function. See the HuggingFace transformers documentation for more details.

    Returns:
        Sequences of generated token ids for models with a language modeling head.
    """
    # Forward only the explicitly supplied optional arguments, in a fixed order.
    explicit_overrides = (
        ("generation_config", generation_config),
        ("attention_mask", attention_mask),
    )
    for key, value in explicit_overrides:
        if value is None:
            continue
        kwargs[key] = value

    return super().generate(input_ids, **kwargs)
|