optimum-rbln 0.9.3.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of optimum-rbln might be problematic.
- optimum/rbln/__init__.py +505 -0
- optimum/rbln/__version__.py +34 -0
- optimum/rbln/cli.py +660 -0
- optimum/rbln/configuration_utils.py +968 -0
- optimum/rbln/diffusers/__init__.py +198 -0
- optimum/rbln/diffusers/configurations/__init__.py +37 -0
- optimum/rbln/diffusers/configurations/models/__init__.py +10 -0
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +73 -0
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_cosmos.py +84 -0
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_temporal_decoder.py +67 -0
- optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +64 -0
- optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +59 -0
- optimum/rbln/diffusers/configurations/models/configuration_transformer_cosmos.py +78 -0
- optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +63 -0
- optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +81 -0
- optimum/rbln/diffusers/configurations/models/configuration_unet_spatio_temporal_condition.py +59 -0
- optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +74 -0
- optimum/rbln/diffusers/configurations/pipelines/__init__.py +34 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +316 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_cosmos.py +117 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +363 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +156 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +176 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +159 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_video_diffusion.py +114 -0
- optimum/rbln/diffusers/modeling_diffusers.py +451 -0
- optimum/rbln/diffusers/models/__init__.py +64 -0
- optimum/rbln/diffusers/models/autoencoders/__init__.py +18 -0
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +255 -0
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py +245 -0
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +275 -0
- optimum/rbln/diffusers/models/autoencoders/vae.py +178 -0
- optimum/rbln/diffusers/models/autoencoders/vq_model.py +211 -0
- optimum/rbln/diffusers/models/controlnet.py +281 -0
- optimum/rbln/diffusers/models/transformers/__init__.py +17 -0
- optimum/rbln/diffusers/models/transformers/prior_transformer.py +160 -0
- optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +344 -0
- optimum/rbln/diffusers/models/transformers/transformer_sd3.py +191 -0
- optimum/rbln/diffusers/models/unets/__init__.py +16 -0
- optimum/rbln/diffusers/models/unets/unet_2d_condition.py +408 -0
- optimum/rbln/diffusers/models/unets/unet_spatio_temporal_condition.py +201 -0
- optimum/rbln/diffusers/pipelines/__init__.py +113 -0
- optimum/rbln/diffusers/pipelines/auto_pipeline.py +307 -0
- optimum/rbln/diffusers/pipelines/controlnet/__init__.py +19 -0
- optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +139 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +669 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +640 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +825 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +837 -0
- optimum/rbln/diffusers/pipelines/cosmos/__init__.py +17 -0
- optimum/rbln/diffusers/pipelines/cosmos/configuration_cosmos_guardrail.py +113 -0
- optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +425 -0
- optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +128 -0
- optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +128 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/__init__.py +23 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +34 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +207 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +34 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpaint.py +34 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/__init__.py +17 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +32 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/__init__.py +17 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/__init__.py +17 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_video_diffusion/__init__.py +15 -0
- optimum/rbln/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +46 -0
- optimum/rbln/modeling.py +364 -0
- optimum/rbln/modeling_base.py +637 -0
- optimum/rbln/ops/__init__.py +19 -0
- optimum/rbln/ops/attn.py +455 -0
- optimum/rbln/ops/flash_attn.py +350 -0
- optimum/rbln/ops/kv_cache_update.py +29 -0
- optimum/rbln/ops/linear.py +32 -0
- optimum/rbln/ops/sliding_window_attn.py +111 -0
- optimum/rbln/transformers/__init__.py +340 -0
- optimum/rbln/transformers/configuration_generic.py +120 -0
- optimum/rbln/transformers/modeling_attention_utils.py +385 -0
- optimum/rbln/transformers/modeling_generic.py +280 -0
- optimum/rbln/transformers/modeling_outputs.py +37 -0
- optimum/rbln/transformers/modeling_rope_utils.py +314 -0
- optimum/rbln/transformers/models/__init__.py +343 -0
- optimum/rbln/transformers/models/audio_spectrogram_transformer/__init__.py +17 -0
- optimum/rbln/transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py +47 -0
- optimum/rbln/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +91 -0
- optimum/rbln/transformers/models/auto/__init__.py +31 -0
- optimum/rbln/transformers/models/auto/auto_factory.py +267 -0
- optimum/rbln/transformers/models/auto/modeling_auto.py +162 -0
- optimum/rbln/transformers/models/bart/__init__.py +17 -0
- optimum/rbln/transformers/models/bart/bart_architecture.py +163 -0
- optimum/rbln/transformers/models/bart/configuration_bart.py +36 -0
- optimum/rbln/transformers/models/bart/modeling_bart.py +86 -0
- optimum/rbln/transformers/models/bert/__init__.py +16 -0
- optimum/rbln/transformers/models/bert/bert_architecture.py +16 -0
- optimum/rbln/transformers/models/bert/configuration_bert.py +46 -0
- optimum/rbln/transformers/models/bert/modeling_bert.py +148 -0
- optimum/rbln/transformers/models/blip_2/__init__.py +20 -0
- optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +115 -0
- optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +526 -0
- optimum/rbln/transformers/models/clip/__init__.py +26 -0
- optimum/rbln/transformers/models/clip/configuration_clip.py +103 -0
- optimum/rbln/transformers/models/clip/modeling_clip.py +384 -0
- optimum/rbln/transformers/models/colpali/__init__.py +2 -0
- optimum/rbln/transformers/models/colpali/colpali_architecture.py +218 -0
- optimum/rbln/transformers/models/colpali/configuration_colpali.py +84 -0
- optimum/rbln/transformers/models/colpali/modeling_colpali.py +361 -0
- optimum/rbln/transformers/models/colqwen2/__init__.py +2 -0
- optimum/rbln/transformers/models/colqwen2/colqwen2_architecture.py +233 -0
- optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +74 -0
- optimum/rbln/transformers/models/colqwen2/modeling_colqwen2.py +446 -0
- optimum/rbln/transformers/models/decoderonly/__init__.py +27 -0
- optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +300 -0
- optimum/rbln/transformers/models/decoderonly/configuration_lora.py +411 -0
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +1224 -0
- optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +508 -0
- optimum/rbln/transformers/models/decoderonly/generation_decoderonly.py +119 -0
- optimum/rbln/transformers/models/decoderonly/lora_architecture.py +204 -0
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +823 -0
- optimum/rbln/transformers/models/depth_anything/__init__.py +16 -0
- optimum/rbln/transformers/models/depth_anything/configuration_depth_anything.py +24 -0
- optimum/rbln/transformers/models/depth_anything/modeling_depth_anything.py +42 -0
- optimum/rbln/transformers/models/distilbert/__init__.py +19 -0
- optimum/rbln/transformers/models/distilbert/configuration_distilbert.py +24 -0
- optimum/rbln/transformers/models/distilbert/modeling_distilbert.py +51 -0
- optimum/rbln/transformers/models/dpt/__init__.py +16 -0
- optimum/rbln/transformers/models/dpt/configuration_dpt.py +24 -0
- optimum/rbln/transformers/models/dpt/modeling_dpt.py +42 -0
- optimum/rbln/transformers/models/exaone/__init__.py +24 -0
- optimum/rbln/transformers/models/exaone/configuration_exaone.py +42 -0
- optimum/rbln/transformers/models/exaone/exaone_architecture.py +77 -0
- optimum/rbln/transformers/models/exaone/modeling_exaone.py +145 -0
- optimum/rbln/transformers/models/gemma/__init__.py +16 -0
- optimum/rbln/transformers/models/gemma/configuration_gemma.py +50 -0
- optimum/rbln/transformers/models/gemma/gemma_architecture.py +27 -0
- optimum/rbln/transformers/models/gemma/modeling_gemma.py +104 -0
- optimum/rbln/transformers/models/gemma3/__init__.py +16 -0
- optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +109 -0
- optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +170 -0
- optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py +245 -0
- optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +611 -0
- optimum/rbln/transformers/models/gpt2/__init__.py +16 -0
- optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +50 -0
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +93 -0
- optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +55 -0
- optimum/rbln/transformers/models/grounding_dino/__init__.py +10 -0
- optimum/rbln/transformers/models/grounding_dino/configuration_grounding_dino.py +92 -0
- optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +599 -0
- optimum/rbln/transformers/models/grounding_dino/modeling_grounding_dino.py +1048 -0
- optimum/rbln/transformers/models/idefics3/__init__.py +16 -0
- optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +89 -0
- optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +497 -0
- optimum/rbln/transformers/models/llama/__init__.py +16 -0
- optimum/rbln/transformers/models/llama/configuration_llama.py +50 -0
- optimum/rbln/transformers/models/llama/llama_architecture.py +19 -0
- optimum/rbln/transformers/models/llama/modeling_llama.py +104 -0
- optimum/rbln/transformers/models/llava/__init__.py +16 -0
- optimum/rbln/transformers/models/llava/configuration_llava.py +72 -0
- optimum/rbln/transformers/models/llava/modeling_llava.py +490 -0
- optimum/rbln/transformers/models/llava_next/__init__.py +16 -0
- optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +69 -0
- optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +493 -0
- optimum/rbln/transformers/models/midm/__init__.py +24 -0
- optimum/rbln/transformers/models/midm/configuration_midm.py +42 -0
- optimum/rbln/transformers/models/midm/midm_architecture.py +144 -0
- optimum/rbln/transformers/models/midm/modeling_midm.py +144 -0
- optimum/rbln/transformers/models/mistral/__init__.py +16 -0
- optimum/rbln/transformers/models/mistral/configuration_mistral.py +50 -0
- optimum/rbln/transformers/models/mistral/mistral_architecture.py +19 -0
- optimum/rbln/transformers/models/mistral/modeling_mistral.py +115 -0
- optimum/rbln/transformers/models/opt/__init__.py +16 -0
- optimum/rbln/transformers/models/opt/configuration_opt.py +29 -0
- optimum/rbln/transformers/models/opt/modeling_opt.py +102 -0
- optimum/rbln/transformers/models/opt/opt_architecture.py +74 -0
- optimum/rbln/transformers/models/pegasus/__init__.py +17 -0
- optimum/rbln/transformers/models/pegasus/configuration_pegasus.py +38 -0
- optimum/rbln/transformers/models/pegasus/modeling_pegasus.py +71 -0
- optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +161 -0
- optimum/rbln/transformers/models/phi/__init__.py +16 -0
- optimum/rbln/transformers/models/phi/configuration_phi.py +50 -0
- optimum/rbln/transformers/models/phi/modeling_phi.py +92 -0
- optimum/rbln/transformers/models/phi/phi_architecture.py +115 -0
- optimum/rbln/transformers/models/pixtral/__init__.py +16 -0
- optimum/rbln/transformers/models/pixtral/configuration_pixtral.py +43 -0
- optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +322 -0
- optimum/rbln/transformers/models/pixtral/pixtral_architecture.py +73 -0
- optimum/rbln/transformers/models/qwen2/__init__.py +16 -0
- optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +50 -0
- optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +123 -0
- optimum/rbln/transformers/models/qwen2/qwen2_architecture.py +19 -0
- optimum/rbln/transformers/models/qwen2_5_vl/__init__.py +19 -0
- optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +111 -0
- optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +636 -0
- optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +220 -0
- optimum/rbln/transformers/models/qwen2_vl/__init__.py +19 -0
- optimum/rbln/transformers/models/qwen2_vl/configuration_qwen2_vl.py +88 -0
- optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py +513 -0
- optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +165 -0
- optimum/rbln/transformers/models/qwen3/__init__.py +16 -0
- optimum/rbln/transformers/models/qwen3/configuration_qwen3.py +71 -0
- optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +133 -0
- optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +31 -0
- optimum/rbln/transformers/models/resnet/__init__.py +23 -0
- optimum/rbln/transformers/models/resnet/configuration_resnet.py +42 -0
- optimum/rbln/transformers/models/resnet/modeling_resnet.py +99 -0
- optimum/rbln/transformers/models/roberta/__init__.py +24 -0
- optimum/rbln/transformers/models/roberta/configuration_roberta.py +33 -0
- optimum/rbln/transformers/models/roberta/modeling_roberta.py +72 -0
- optimum/rbln/transformers/models/seq2seq/__init__.py +16 -0
- optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py +71 -0
- optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +477 -0
- optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +527 -0
- optimum/rbln/transformers/models/siglip/__init__.py +16 -0
- optimum/rbln/transformers/models/siglip/configuration_siglip.py +76 -0
- optimum/rbln/transformers/models/siglip/modeling_siglip.py +199 -0
- optimum/rbln/transformers/models/swin/__init__.py +16 -0
- optimum/rbln/transformers/models/swin/configuration_swin.py +42 -0
- optimum/rbln/transformers/models/swin/modeling_swin.py +354 -0
- optimum/rbln/transformers/models/t5/__init__.py +17 -0
- optimum/rbln/transformers/models/t5/configuration_t5.py +36 -0
- optimum/rbln/transformers/models/t5/modeling_t5.py +130 -0
- optimum/rbln/transformers/models/t5/t5_architecture.py +264 -0
- optimum/rbln/transformers/models/time_series_transformer/__init__.py +26 -0
- optimum/rbln/transformers/models/time_series_transformer/configuration_time_series_transformer.py +41 -0
- optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +435 -0
- optimum/rbln/transformers/models/time_series_transformer/time_series_transformers_architecture.py +337 -0
- optimum/rbln/transformers/models/vit/__init__.py +19 -0
- optimum/rbln/transformers/models/vit/configuration_vit.py +24 -0
- optimum/rbln/transformers/models/vit/modeling_vit.py +44 -0
- optimum/rbln/transformers/models/wav2vec2/__init__.py +16 -0
- optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec2.py +38 -0
- optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +104 -0
- optimum/rbln/transformers/models/whisper/__init__.py +17 -0
- optimum/rbln/transformers/models/whisper/configuration_whisper.py +72 -0
- optimum/rbln/transformers/models/whisper/generation_whisper.py +159 -0
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +475 -0
- optimum/rbln/transformers/models/whisper/whisper_architecture.py +349 -0
- optimum/rbln/transformers/models/xlm_roberta/__init__.py +24 -0
- optimum/rbln/transformers/models/xlm_roberta/configuration_xlm_roberta.py +32 -0
- optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +82 -0
- optimum/rbln/transformers/utils/__init__.py +0 -0
- optimum/rbln/transformers/utils/rbln_quantization.py +589 -0
- optimum/rbln/transformers/utils/rbln_runtime_wrapper.py +79 -0
- optimum/rbln/utils/__init__.py +16 -0
- optimum/rbln/utils/decorator_utils.py +86 -0
- optimum/rbln/utils/deprecation.py +213 -0
- optimum/rbln/utils/hub.py +94 -0
- optimum/rbln/utils/import_utils.py +170 -0
- optimum/rbln/utils/logging.py +110 -0
- optimum/rbln/utils/model_utils.py +63 -0
- optimum/rbln/utils/runtime_utils.py +249 -0
- optimum/rbln/utils/save_utils.py +102 -0
- optimum/rbln/utils/submodule.py +152 -0
- optimum_rbln-0.9.3.post1.dist-info/METADATA +124 -0
- optimum_rbln-0.9.3.post1.dist-info/RECORD +264 -0
- optimum_rbln-0.9.3.post1.dist-info/WHEEL +4 -0
- optimum_rbln-0.9.3.post1.dist-info/entry_points.txt +2 -0
- optimum_rbln-0.9.3.post1.dist-info/licenses/LICENSE +201 -0

optimum/rbln/transformers/models/t5/modeling_t5.py
ADDED
@@ -0,0 +1,130 @@
# Copyright 2025 Rebellions Inc. All rights reserved.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at:

#     http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
from typing import TYPE_CHECKING, Any, Callable

import torch
from transformers import AutoModelForTextEncoding, T5EncoderModel, T5ForConditionalGeneration
from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions

from ...modeling_generic import RBLNTransformerEncoderForFeatureExtraction
from ...models.seq2seq import RBLNModelForSeq2SeqLM
from .configuration_t5 import RBLNT5EncoderModelConfig, RBLNT5ForConditionalGenerationConfig
from .t5_architecture import T5Wrapper


if TYPE_CHECKING:
    from transformers import PreTrainedModel

    from ....diffusers.modeling_diffusers import RBLNDiffusionMixin, RBLNDiffusionMixinConfig


class T5EncoderWrapper(torch.nn.Module):
    def __init__(self, model: "T5EncoderModel") -> None:
        super().__init__()
        self.model = model

    def forward(self, *args, **kwargs):
        kwargs.pop("return_dict", None)
        return self.model(*args, **kwargs, return_dict=False)


class RBLNT5EncoderModel(RBLNTransformerEncoderForFeatureExtraction):
    """
    The T5 Model transformer with an encoder-only architecture for feature extraction.
    This model inherits from [`RBLNTransformerEncoderForFeatureExtraction`]. Check the superclass documentation for the generic methods the library implements for all its models.

    Important Note:
        This model supports various sizes of the T5EncoderModel. For optimal performance, it is highly recommended to adjust the tensor parallelism setting
        based on the model size. Please refer to the [Optimum RBLN Overview](../../../optimum_rbln.md) for guidance on choosing the appropriate tensor parallelism size for your model.

    Examples:
        ```python
        from optimum.rbln import RBLNT5EncoderModel

        model = RBLNT5EncoderModel.from_pretrained(
            "sentence-transformers/sentence-t5-xxl",
            export=True,
            rbln_tensor_parallel_size=4,
        )

        model.save_pretrained("compiled-sentence-t5-xxl")
        ```
    """

    auto_model_class = AutoModelForTextEncoding
    output_class = BaseModelOutputWithPastAndCrossAttentions

    @classmethod
    def _wrap_model_if_needed(self, model: "PreTrainedModel", rbln_config: RBLNT5EncoderModelConfig):
        return T5EncoderWrapper(model)

    @classmethod
    def update_rbln_config_using_pipe(
        cls, pipe: "RBLNDiffusionMixin", rbln_config: "RBLNDiffusionMixinConfig", submodule_name: str
    ) -> "RBLNDiffusionMixinConfig":
        return rbln_config

    def forward(self, input_ids=None, attention_mask=None, **kwargs):
        input_dict = {"input_ids": input_ids.long()}
        if attention_mask is not None:
            input_dict["attention_mask"] = attention_mask.long()

        output = super().forward(**input_dict, **kwargs)
        return output


class RBLNT5ForConditionalGeneration(RBLNModelForSeq2SeqLM):
    """
    The T5 Model transformer with a language modeling head for conditional generation.
    This model inherits from [`RBLNModelForSeq2SeqLM`]. Check the superclass documentation for the generic methods the library implements for all its models.

    Important Note:
        This model supports various sizes of the T5ForConditionalGeneration. For optimal performance, it is highly recommended to adjust the tensor parallelism setting
        based on the model size. Please refer to the [Optimum RBLN Overview](../../../optimum_rbln.md) for guidance on choosing the appropriate tensor parallelism size for your model.


    Examples:
        ```python
        from optimum.rbln import RBLNT5ForConditionalGeneration

        model = RBLNT5ForConditionalGeneration.from_pretrained(
            "google-t5/t5-11b",
            export=True,
            rbln_tensor_parallel_size=4,
        )

        model.save_pretrained("compiled-sentence-t5-xxl")
        ```
    """

    support_causal_attn = False

    @classmethod
    def _wrap_model_if_needed(self, model: "PreTrainedModel", rbln_config: RBLNT5ForConditionalGenerationConfig):
        return T5Wrapper(
            model, enc_max_seq_len=rbln_config.enc_max_seq_len, dec_max_seq_len=rbln_config.dec_max_seq_len
        )

    def __getattr__(self, __name: str) -> Any:
        def redirect(func):
            return lambda *pargs, **kwargs: func(self, *pargs, **kwargs)

        val = getattr(T5ForConditionalGeneration, __name)

        if isinstance(val, Callable) and "self" in set(inspect.signature(val).parameters):
            return redirect(val)

        return val
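
A minimal, hedged usage sketch for the two classes above: compile once with `export=True`, reload the compiled artifact, and run generation. The model id, output directory, and prompt are illustrative placeholders, and the `export=False` reload plus the `generate` call follow the usual optimum-rbln / transformers pattern rather than anything stated in this diff.

```python
# Hypothetical end-to-end sketch (not from the package): compile, save, reload, generate.
from transformers import AutoTokenizer

from optimum.rbln import RBLNT5ForConditionalGeneration

model_id = "google-t5/t5-small"  # illustrative; large checkpoints would also set rbln_tensor_parallel_size
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Compile for RBLN NPUs and save the compiled model, as in the class docstring above.
model = RBLNT5ForConditionalGeneration.from_pretrained(model_id, export=True)
model.save_pretrained("compiled-t5-small")

# Reload the already-compiled artifact; export=False (the default) skips recompilation.
model = RBLNT5ForConditionalGeneration.from_pretrained("compiled-t5-small", export=False)

inputs = tokenizer("translate English to German: The house is wonderful.", return_tensors="pt")
output_ids = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```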

optimum/rbln/transformers/models/t5/t5_architecture.py
ADDED
@@ -0,0 +1,264 @@
# Copyright 2025 Rebellions Inc. All rights reserved.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at:

#     http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Tuple

import torch
from torch import nn
from transformers.utils import logging

from ..seq2seq.seq2seq_architecture import (
    Seq2SeqDecoder,
    Seq2SeqDecoderLayer,
    Seq2SeqDecoderWrapper,
    Seq2SeqEncoderWrapper,
    Seq2SeqForConditionalGeneration,
    Seq2SeqSelfAttention,
)


logger = logging.get_logger(__name__)


class T5Wrapper:
    def __init__(self, model: nn.Module, enc_max_seq_len: int, dec_max_seq_len: int = None):
        self.encoder = T5EncoderWrapper(model, enc_max_seq_len)
        self.decoder = T5DecoderWrapper(model, dec_max_seq_len=dec_max_seq_len)


class T5EncoderWrapper(Seq2SeqEncoderWrapper):
    def __post_init__(self, model: nn.Module):
        self.n_layer = getattr(self.config, "num_layers")
        self.cross_k_projects, self.cross_v_projects = self._extract_cross_kv_projects(model.get_decoder().block)
        self.num_heads = self.config.num_heads
        self.d_kv = self.config.d_kv

    def _extract_cross_kv_projects(self, t5_block: nn.Module):
        return (
            # different from bart
            nn.ModuleList(t5_block[i].layer[1].EncDecAttention.k for i in range(self.n_layer)),
            nn.ModuleList(t5_block[i].layer[1].EncDecAttention.v for i in range(self.n_layer)),
        )


class T5DecoderWrapper(Seq2SeqDecoderWrapper):
    def __post_init__(self, model, dec_max_seq_len: int = None):
        self.num_layers = self.config.num_layers
        self.conditional_generation = self.convert_to_rbln_conditional_generation(model, dec_max_seq_len)

    def convert_to_rbln_conditional_generation(self, model: nn.Module, dec_max_seq_len: int):
        new_blocks = []
        for block in model.get_decoder().block:
            self_attn = T5LayerSelfAttention(block.layer[0].SelfAttention)
            block = T5Block(block, self_attn)
            new_blocks.append(block)

        decoder_model = T5Decoder(model.get_decoder(), new_blocks, dec_max_seq_len=dec_max_seq_len)
        new_model = T5ForConditionalGeneration(model, decoder_model)

        return new_model

    def forward(
        self,
        input_ids,
        attention_mask,
        encoder_attention_mask,
        cache_position,
        block_tables,
        *kv_cache,
    ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor]]:
        self_past_key_values = ()
        cross_past_key_values = ()
        self_kv_cache = kv_cache[self.num_layers * 2 :]
        cross_kv_cache = kv_cache[: self.num_layers * 2]

        for i in range(0, self.num_layers * 2, 2):
            self_past_key_values = self_past_key_values + ((self_kv_cache[i], self_kv_cache[i + 1]),)
            cross_past_key_values = cross_past_key_values + ((cross_kv_cache[i], cross_kv_cache[i + 1]),)

        # decode
        lm_logits = self.conditional_generation(
            input_ids=input_ids,
            attention_mask=attention_mask,
            encoder_attention_mask=encoder_attention_mask,
            self_past_key_values=self_past_key_values,
            cross_past_key_values=cross_past_key_values,
            cache_position=cache_position,
            block_tables=block_tables,
        )

        return lm_logits


class T5ForConditionalGeneration(Seq2SeqForConditionalGeneration):
    has_rescaling = True

    def __post_init__(self):
        self.scaling = self.config.d_model**-0.5


class T5Decoder(Seq2SeqDecoder):
    has_pos_emb = False

    def __post_init__(self, dec_max_seq_len: int = None):
        self.invert_attention_mask = self._original_mod.invert_attention_mask
        self._dec_position_bias = self.precompute_dec_position_bias(self._original_mod, dec_max_seq_len)

    def precompute_dec_position_bias(self, model, dec_max_length):
        attn_layer = model.block[0].layer[0].SelfAttention
        return attn_layer.compute_bias(dec_max_length, dec_max_length)

    def prepare_attn_mask(self, attention_mask, encoder_attention_mask, cache_position):
        attention_mask = self.invert_attention_mask(attention_mask)
        encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask)

        b_size = attention_mask.shape[0]
        batch_decoder_position_bias = []
        for i in range(b_size):
            if torch.compiler.is_exporting():
                cache_pos = cache_position[i][0].item()
                torch._check_is_size(cache_pos)
                torch._check(cache_pos >= 0)
                torch._check(cache_pos < self._dec_position_bias.shape[2])
            else:
                cache_pos = cache_position[i][0]
            batch_position_bias = torch.select(self._dec_position_bias, dim=2, index=cache_pos).unsqueeze(2)
            batch_decoder_position_bias.append(batch_position_bias)
        position_bias = torch.cat(batch_decoder_position_bias, dim=0)

        attention_mask = position_bias + attention_mask

        return attention_mask, encoder_attention_mask


class T5Block(Seq2SeqDecoderLayer):
    def __init__(self, decoder_layer, self_attn):
        super().__init__(decoder_layer, self_attn, cross_attn=None)
        self.__post_init__()

    def __post_init__(self):
        self.self_attn_layer_norm = self._original_mod.layer[0].layer_norm
        self.encoder_attn_layer_norm = self._original_mod.layer[1].layer_norm
        self.cross_attn = T5CrossAttention(self._original_mod.layer[1].EncDecAttention)
        self.ff_layer = self._original_mod.layer[2]

    def pre_self_attn_layer_norm(self, hidden_states):
        return self.self_attn_layer_norm(hidden_states)

    def post_self_attn_layer_norm(self, hidden_states):
        return hidden_states

    def pre_cross_attn_layer_norm(self, hidden_states):
        return self.encoder_attn_layer_norm(hidden_states)

    def post_cross_attn_layer_norm(self, hidden_states):
        return hidden_states


class T5LayerSelfAttention(Seq2SeqSelfAttention):
    def __post_init__(self):
        self.q_proj = self._original_mod.q
        self.k_proj = self._original_mod.k
        self.v_proj = self._original_mod.v
        self.out_proj = self._original_mod.o
        self.num_heads = self._original_mod.n_heads
        self.head_dim = self._original_mod.key_value_proj_dim
        self.attn_decode = torch.ops.rbln_custom_ops.paged_add_softmax_attn_decode

    def projection(self, hidden_states) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)
        return query_states, key_states, value_states

    def forward(
        self,
        hidden_states: torch.Tensor,
        past_key_value: Tuple[torch.Tensor],
        attention_mask: torch.Tensor,
        cache_position: torch.Tensor,
        block_tables: torch.Tensor,
        **kwargs,
    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor]]:
        bsz, tgt_len, _ = hidden_states.size()

        query_states, key_states, value_states = self.projection(hidden_states=hidden_states)
        query_states = self._shape(query_states, tgt_len, bsz)
        key_states = self._shape(key_states, -1, bsz)
        value_states = self._shape(value_states, -1, bsz)

        block_size = past_key_value[0].shape[-2]
        attn_output = self.attn_decode(
            query_states,
            key_states,
            value_states,
            attention_mask.unsqueeze(
                2
            ),  # Unsqueeze group axis since CustomKernel expects it for group query attention
            past_key_value[0].view(bsz, self.num_heads, 1, -1, self.head_dim),
            past_key_value[1].view(bsz, self.num_heads, 1, -1, self.head_dim),
            cache_position,
            torch.tensor(1.0, dtype=torch.float32),  # scale
            block_tables,
            block_size,
        )

        attn_output = attn_output.view(bsz, self.num_heads, -1, self.head_dim).transpose(1, 2)
        attn_output = attn_output.reshape(bsz, -1, self.num_heads * self.head_dim)

        attn_output = self.out_proj(attn_output)
        return attn_output


class T5CrossAttention(nn.Module):
    def __init__(self, attn):
        super().__init__()
        self.attn = attn
        self.q = attn.q
        self.o = attn.o
        self.n_heads = attn.n_heads
        self.key_value_proj_dim = attn.key_value_proj_dim
        self.inner_dim = attn.inner_dim

    def forward(
        self,
        hidden_states: torch.Tensor = None,
        past_key_value: torch.Tensor = None,
        attention_mask: torch.Tensor = None,
        key_value_states: torch.Tensor = None,
    ):
        batch_size = hidden_states.shape[0]

        query_states = self.q(hidden_states)
        query_states = query_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)

        # reuse k,v, cross_attentions
        key_states = past_key_value[0]
        value_states = past_key_value[1]

        # compute scores, equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9
        scores = torch.matmul(query_states, key_states.transpose(3, 2))
        scores += attention_mask

        # (batch_size, n_heads, seq_length, key_length)
        attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(scores)
        attn_output = torch.matmul(attn_weights, value_states)

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(batch_size, -1, self.inner_dim)
        attn_output = self.o(attn_output)

        outputs = (attn_output, past_key_value)

        return outputs
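
The flat `*kv_cache` argument unpacked in `T5DecoderWrapper.forward` above is laid out with the cross-attention tensors first and the self-attention tensors second, key and value interleaved per layer. A small self-contained sketch of that regrouping, using dummy tensors (layer count and shapes are arbitrary assumptions; only the slicing mirrors the code above):

```python
# Sketch of the kv_cache regrouping performed in T5DecoderWrapper.forward.
# Assumed flat layout, inferred from the slicing in that method:
#   (cross_k_0, cross_v_0, ..., cross_k_{L-1}, cross_v_{L-1},
#    self_k_0,  self_v_0,  ..., self_k_{L-1},  self_v_{L-1})
import torch

num_layers = 2
kv_cache = tuple(torch.zeros(1, 8, 16, 64) for _ in range(num_layers * 4))  # dummy tensors

cross_kv_cache = kv_cache[: num_layers * 2]  # first half: cross-attention K/V
self_kv_cache = kv_cache[num_layers * 2 :]   # second half: self-attention K/V

self_past_key_values = tuple(
    (self_kv_cache[i], self_kv_cache[i + 1]) for i in range(0, num_layers * 2, 2)
)
cross_past_key_values = tuple(
    (cross_kv_cache[i], cross_kv_cache[i + 1]) for i in range(0, num_layers * 2, 2)
)

assert len(self_past_key_values) == num_layers   # one (key, value) pair per layer
assert len(cross_past_key_values) == num_layers
```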

optimum/rbln/transformers/models/time_series_transformer/__init__.py
ADDED
@@ -0,0 +1,26 @@
# Copyright 2024 Rebellions Inc.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at:

#     http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Portions of this software are licensed under the Apache License,
# Version 2.0. See the NOTICE file distributed with this work for
# additional information regarding copyright ownership.

# All other portions of this software, including proprietary code,
# are the intellectual property of Rebellions Inc. and may not be
# copied, modified, or distributed without prior written permission
# from Rebellions Inc.

from ....ops import paged_add_softmax_attn_decode, rbln_cache_update
from .configuration_time_series_transformer import RBLNTimeSeriesTransformerForPredictionConfig
from .modeling_time_series_transformer import RBLNTimeSeriesTransformerForPrediction
optimum/rbln/transformers/models/time_series_transformer/configuration_time_series_transformer.py
ADDED
@@ -0,0 +1,41 @@
from typing import Any, Optional

from ....configuration_utils import RBLNModelConfig


class RBLNTimeSeriesTransformerForPredictionConfig(RBLNModelConfig):
    """
    Configuration class for RBLNTimeSeriesTransformerForPrediction.

    This configuration class stores the configuration parameters specific to
    RBLN-optimized Time Series Transformer models for time series forecasting tasks.
    """

    def __init__(
        self,
        batch_size: Optional[int] = None,
        enc_max_seq_len: Optional[int] = None,
        dec_max_seq_len: Optional[int] = None,
        num_parallel_samples: Optional[int] = None,
        **kwargs: Any,
    ):
        """
        Args:
            batch_size (Optional[int]): The batch size for inference. Defaults to 1.
            enc_max_seq_len (Optional[int]): Maximum sequence length for the encoder.
            dec_max_seq_len (Optional[int]): Maximum sequence length for the decoder.
            num_parallel_samples (Optional[int]): Number of samples to generate in parallel during prediction.
            kwargs: Additional arguments passed to the parent RBLNModelConfig.

        Raises:
            ValueError: If batch_size is not a positive integer.
        """
        super().__init__(**kwargs)

        self.batch_size = batch_size or 1
        if not isinstance(self.batch_size, int) or self.batch_size <= 0:
            raise ValueError(f"batch_size must be a positive integer, got {self.batch_size}")

        self.enc_max_seq_len = enc_max_seq_len
        self.dec_max_seq_len = dec_max_seq_len
        self.num_parallel_samples = num_parallel_samples
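
A short sketch of how this configuration behaves, based only on the `__init__` shown above; the numeric values are illustrative. The import uses the module path from this diff (the class may also be re-exported at the optimum.rbln top level, but that is not shown here).

```python
# Illustrative values; only the behavior of the __init__ above is exercised.
from optimum.rbln.transformers.models.time_series_transformer.configuration_time_series_transformer import (
    RBLNTimeSeriesTransformerForPredictionConfig,
)

config = RBLNTimeSeriesTransformerForPredictionConfig(
    enc_max_seq_len=24,
    dec_max_seq_len=64,
    num_parallel_samples=100,
)
assert config.batch_size == 1  # batch_size defaults to 1 when omitted (or when 0, via `batch_size or 1`)

# A negative or non-integer batch size fails the validation above.
try:
    RBLNTimeSeriesTransformerForPredictionConfig(batch_size=-1)
except ValueError as err:
    print(err)  # batch_size must be a positive integer, got -1
```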