optimum_rbln-0.9.3.post1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- optimum/rbln/__init__.py +505 -0
- optimum/rbln/__version__.py +34 -0
- optimum/rbln/cli.py +660 -0
- optimum/rbln/configuration_utils.py +968 -0
- optimum/rbln/diffusers/__init__.py +198 -0
- optimum/rbln/diffusers/configurations/__init__.py +37 -0
- optimum/rbln/diffusers/configurations/models/__init__.py +10 -0
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +73 -0
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_cosmos.py +84 -0
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_temporal_decoder.py +67 -0
- optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +64 -0
- optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +59 -0
- optimum/rbln/diffusers/configurations/models/configuration_transformer_cosmos.py +78 -0
- optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +63 -0
- optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +81 -0
- optimum/rbln/diffusers/configurations/models/configuration_unet_spatio_temporal_condition.py +59 -0
- optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +74 -0
- optimum/rbln/diffusers/configurations/pipelines/__init__.py +34 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +316 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_cosmos.py +117 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +363 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +156 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +176 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +159 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_video_diffusion.py +114 -0
- optimum/rbln/diffusers/modeling_diffusers.py +451 -0
- optimum/rbln/diffusers/models/__init__.py +64 -0
- optimum/rbln/diffusers/models/autoencoders/__init__.py +18 -0
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +255 -0
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py +245 -0
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +275 -0
- optimum/rbln/diffusers/models/autoencoders/vae.py +178 -0
- optimum/rbln/diffusers/models/autoencoders/vq_model.py +211 -0
- optimum/rbln/diffusers/models/controlnet.py +281 -0
- optimum/rbln/diffusers/models/transformers/__init__.py +17 -0
- optimum/rbln/diffusers/models/transformers/prior_transformer.py +160 -0
- optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +344 -0
- optimum/rbln/diffusers/models/transformers/transformer_sd3.py +191 -0
- optimum/rbln/diffusers/models/unets/__init__.py +16 -0
- optimum/rbln/diffusers/models/unets/unet_2d_condition.py +408 -0
- optimum/rbln/diffusers/models/unets/unet_spatio_temporal_condition.py +201 -0
- optimum/rbln/diffusers/pipelines/__init__.py +113 -0
- optimum/rbln/diffusers/pipelines/auto_pipeline.py +307 -0
- optimum/rbln/diffusers/pipelines/controlnet/__init__.py +19 -0
- optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +139 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +669 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +640 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +825 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +837 -0
- optimum/rbln/diffusers/pipelines/cosmos/__init__.py +17 -0
- optimum/rbln/diffusers/pipelines/cosmos/configuration_cosmos_guardrail.py +113 -0
- optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +425 -0
- optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +128 -0
- optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +128 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/__init__.py +23 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +34 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +207 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +34 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpaint.py +34 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/__init__.py +17 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +32 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/__init__.py +17 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/__init__.py +17 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_video_diffusion/__init__.py +15 -0
- optimum/rbln/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +46 -0
- optimum/rbln/modeling.py +364 -0
- optimum/rbln/modeling_base.py +637 -0
- optimum/rbln/ops/__init__.py +19 -0
- optimum/rbln/ops/attn.py +455 -0
- optimum/rbln/ops/flash_attn.py +350 -0
- optimum/rbln/ops/kv_cache_update.py +29 -0
- optimum/rbln/ops/linear.py +32 -0
- optimum/rbln/ops/sliding_window_attn.py +111 -0
- optimum/rbln/transformers/__init__.py +340 -0
- optimum/rbln/transformers/configuration_generic.py +120 -0
- optimum/rbln/transformers/modeling_attention_utils.py +385 -0
- optimum/rbln/transformers/modeling_generic.py +280 -0
- optimum/rbln/transformers/modeling_outputs.py +37 -0
- optimum/rbln/transformers/modeling_rope_utils.py +314 -0
- optimum/rbln/transformers/models/__init__.py +343 -0
- optimum/rbln/transformers/models/audio_spectrogram_transformer/__init__.py +17 -0
- optimum/rbln/transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py +47 -0
- optimum/rbln/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +91 -0
- optimum/rbln/transformers/models/auto/__init__.py +31 -0
- optimum/rbln/transformers/models/auto/auto_factory.py +267 -0
- optimum/rbln/transformers/models/auto/modeling_auto.py +162 -0
- optimum/rbln/transformers/models/bart/__init__.py +17 -0
- optimum/rbln/transformers/models/bart/bart_architecture.py +163 -0
- optimum/rbln/transformers/models/bart/configuration_bart.py +36 -0
- optimum/rbln/transformers/models/bart/modeling_bart.py +86 -0
- optimum/rbln/transformers/models/bert/__init__.py +16 -0
- optimum/rbln/transformers/models/bert/bert_architecture.py +16 -0
- optimum/rbln/transformers/models/bert/configuration_bert.py +46 -0
- optimum/rbln/transformers/models/bert/modeling_bert.py +148 -0
- optimum/rbln/transformers/models/blip_2/__init__.py +20 -0
- optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +115 -0
- optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +526 -0
- optimum/rbln/transformers/models/clip/__init__.py +26 -0
- optimum/rbln/transformers/models/clip/configuration_clip.py +103 -0
- optimum/rbln/transformers/models/clip/modeling_clip.py +384 -0
- optimum/rbln/transformers/models/colpali/__init__.py +2 -0
- optimum/rbln/transformers/models/colpali/colpali_architecture.py +218 -0
- optimum/rbln/transformers/models/colpali/configuration_colpali.py +84 -0
- optimum/rbln/transformers/models/colpali/modeling_colpali.py +361 -0
- optimum/rbln/transformers/models/colqwen2/__init__.py +2 -0
- optimum/rbln/transformers/models/colqwen2/colqwen2_architecture.py +233 -0
- optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +74 -0
- optimum/rbln/transformers/models/colqwen2/modeling_colqwen2.py +446 -0
- optimum/rbln/transformers/models/decoderonly/__init__.py +27 -0
- optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +300 -0
- optimum/rbln/transformers/models/decoderonly/configuration_lora.py +411 -0
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +1224 -0
- optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +508 -0
- optimum/rbln/transformers/models/decoderonly/generation_decoderonly.py +119 -0
- optimum/rbln/transformers/models/decoderonly/lora_architecture.py +204 -0
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +823 -0
- optimum/rbln/transformers/models/depth_anything/__init__.py +16 -0
- optimum/rbln/transformers/models/depth_anything/configuration_depth_anything.py +24 -0
- optimum/rbln/transformers/models/depth_anything/modeling_depth_anything.py +42 -0
- optimum/rbln/transformers/models/distilbert/__init__.py +19 -0
- optimum/rbln/transformers/models/distilbert/configuration_distilbert.py +24 -0
- optimum/rbln/transformers/models/distilbert/modeling_distilbert.py +51 -0
- optimum/rbln/transformers/models/dpt/__init__.py +16 -0
- optimum/rbln/transformers/models/dpt/configuration_dpt.py +24 -0
- optimum/rbln/transformers/models/dpt/modeling_dpt.py +42 -0
- optimum/rbln/transformers/models/exaone/__init__.py +24 -0
- optimum/rbln/transformers/models/exaone/configuration_exaone.py +42 -0
- optimum/rbln/transformers/models/exaone/exaone_architecture.py +77 -0
- optimum/rbln/transformers/models/exaone/modeling_exaone.py +145 -0
- optimum/rbln/transformers/models/gemma/__init__.py +16 -0
- optimum/rbln/transformers/models/gemma/configuration_gemma.py +50 -0
- optimum/rbln/transformers/models/gemma/gemma_architecture.py +27 -0
- optimum/rbln/transformers/models/gemma/modeling_gemma.py +104 -0
- optimum/rbln/transformers/models/gemma3/__init__.py +16 -0
- optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +109 -0
- optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +170 -0
- optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py +245 -0
- optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +611 -0
- optimum/rbln/transformers/models/gpt2/__init__.py +16 -0
- optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +50 -0
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +93 -0
- optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +55 -0
- optimum/rbln/transformers/models/grounding_dino/__init__.py +10 -0
- optimum/rbln/transformers/models/grounding_dino/configuration_grounding_dino.py +92 -0
- optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +599 -0
- optimum/rbln/transformers/models/grounding_dino/modeling_grounding_dino.py +1048 -0
- optimum/rbln/transformers/models/idefics3/__init__.py +16 -0
- optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +89 -0
- optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +497 -0
- optimum/rbln/transformers/models/llama/__init__.py +16 -0
- optimum/rbln/transformers/models/llama/configuration_llama.py +50 -0
- optimum/rbln/transformers/models/llama/llama_architecture.py +19 -0
- optimum/rbln/transformers/models/llama/modeling_llama.py +104 -0
- optimum/rbln/transformers/models/llava/__init__.py +16 -0
- optimum/rbln/transformers/models/llava/configuration_llava.py +72 -0
- optimum/rbln/transformers/models/llava/modeling_llava.py +490 -0
- optimum/rbln/transformers/models/llava_next/__init__.py +16 -0
- optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +69 -0
- optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +493 -0
- optimum/rbln/transformers/models/midm/__init__.py +24 -0
- optimum/rbln/transformers/models/midm/configuration_midm.py +42 -0
- optimum/rbln/transformers/models/midm/midm_architecture.py +144 -0
- optimum/rbln/transformers/models/midm/modeling_midm.py +144 -0
- optimum/rbln/transformers/models/mistral/__init__.py +16 -0
- optimum/rbln/transformers/models/mistral/configuration_mistral.py +50 -0
- optimum/rbln/transformers/models/mistral/mistral_architecture.py +19 -0
- optimum/rbln/transformers/models/mistral/modeling_mistral.py +115 -0
- optimum/rbln/transformers/models/opt/__init__.py +16 -0
- optimum/rbln/transformers/models/opt/configuration_opt.py +29 -0
- optimum/rbln/transformers/models/opt/modeling_opt.py +102 -0
- optimum/rbln/transformers/models/opt/opt_architecture.py +74 -0
- optimum/rbln/transformers/models/pegasus/__init__.py +17 -0
- optimum/rbln/transformers/models/pegasus/configuration_pegasus.py +38 -0
- optimum/rbln/transformers/models/pegasus/modeling_pegasus.py +71 -0
- optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +161 -0
- optimum/rbln/transformers/models/phi/__init__.py +16 -0
- optimum/rbln/transformers/models/phi/configuration_phi.py +50 -0
- optimum/rbln/transformers/models/phi/modeling_phi.py +92 -0
- optimum/rbln/transformers/models/phi/phi_architecture.py +115 -0
- optimum/rbln/transformers/models/pixtral/__init__.py +16 -0
- optimum/rbln/transformers/models/pixtral/configuration_pixtral.py +43 -0
- optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +322 -0
- optimum/rbln/transformers/models/pixtral/pixtral_architecture.py +73 -0
- optimum/rbln/transformers/models/qwen2/__init__.py +16 -0
- optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +50 -0
- optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +123 -0
- optimum/rbln/transformers/models/qwen2/qwen2_architecture.py +19 -0
- optimum/rbln/transformers/models/qwen2_5_vl/__init__.py +19 -0
- optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +111 -0
- optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +636 -0
- optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +220 -0
- optimum/rbln/transformers/models/qwen2_vl/__init__.py +19 -0
- optimum/rbln/transformers/models/qwen2_vl/configuration_qwen2_vl.py +88 -0
- optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py +513 -0
- optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +165 -0
- optimum/rbln/transformers/models/qwen3/__init__.py +16 -0
- optimum/rbln/transformers/models/qwen3/configuration_qwen3.py +71 -0
- optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +133 -0
- optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +31 -0
- optimum/rbln/transformers/models/resnet/__init__.py +23 -0
- optimum/rbln/transformers/models/resnet/configuration_resnet.py +42 -0
- optimum/rbln/transformers/models/resnet/modeling_resnet.py +99 -0
- optimum/rbln/transformers/models/roberta/__init__.py +24 -0
- optimum/rbln/transformers/models/roberta/configuration_roberta.py +33 -0
- optimum/rbln/transformers/models/roberta/modeling_roberta.py +72 -0
- optimum/rbln/transformers/models/seq2seq/__init__.py +16 -0
- optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py +71 -0
- optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +477 -0
- optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +527 -0
- optimum/rbln/transformers/models/siglip/__init__.py +16 -0
- optimum/rbln/transformers/models/siglip/configuration_siglip.py +76 -0
- optimum/rbln/transformers/models/siglip/modeling_siglip.py +199 -0
- optimum/rbln/transformers/models/swin/__init__.py +16 -0
- optimum/rbln/transformers/models/swin/configuration_swin.py +42 -0
- optimum/rbln/transformers/models/swin/modeling_swin.py +354 -0
- optimum/rbln/transformers/models/t5/__init__.py +17 -0
- optimum/rbln/transformers/models/t5/configuration_t5.py +36 -0
- optimum/rbln/transformers/models/t5/modeling_t5.py +130 -0
- optimum/rbln/transformers/models/t5/t5_architecture.py +264 -0
- optimum/rbln/transformers/models/time_series_transformer/__init__.py +26 -0
- optimum/rbln/transformers/models/time_series_transformer/configuration_time_series_transformer.py +41 -0
- optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +435 -0
- optimum/rbln/transformers/models/time_series_transformer/time_series_transformers_architecture.py +337 -0
- optimum/rbln/transformers/models/vit/__init__.py +19 -0
- optimum/rbln/transformers/models/vit/configuration_vit.py +24 -0
- optimum/rbln/transformers/models/vit/modeling_vit.py +44 -0
- optimum/rbln/transformers/models/wav2vec2/__init__.py +16 -0
- optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec2.py +38 -0
- optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +104 -0
- optimum/rbln/transformers/models/whisper/__init__.py +17 -0
- optimum/rbln/transformers/models/whisper/configuration_whisper.py +72 -0
- optimum/rbln/transformers/models/whisper/generation_whisper.py +159 -0
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +475 -0
- optimum/rbln/transformers/models/whisper/whisper_architecture.py +349 -0
- optimum/rbln/transformers/models/xlm_roberta/__init__.py +24 -0
- optimum/rbln/transformers/models/xlm_roberta/configuration_xlm_roberta.py +32 -0
- optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +82 -0
- optimum/rbln/transformers/utils/__init__.py +0 -0
- optimum/rbln/transformers/utils/rbln_quantization.py +589 -0
- optimum/rbln/transformers/utils/rbln_runtime_wrapper.py +79 -0
- optimum/rbln/utils/__init__.py +16 -0
- optimum/rbln/utils/decorator_utils.py +86 -0
- optimum/rbln/utils/deprecation.py +213 -0
- optimum/rbln/utils/hub.py +94 -0
- optimum/rbln/utils/import_utils.py +170 -0
- optimum/rbln/utils/logging.py +110 -0
- optimum/rbln/utils/model_utils.py +63 -0
- optimum/rbln/utils/runtime_utils.py +249 -0
- optimum/rbln/utils/save_utils.py +102 -0
- optimum/rbln/utils/submodule.py +152 -0
- optimum_rbln-0.9.3.post1.dist-info/METADATA +124 -0
- optimum_rbln-0.9.3.post1.dist-info/RECORD +264 -0
- optimum_rbln-0.9.3.post1.dist-info/WHEEL +4 -0
- optimum_rbln-0.9.3.post1.dist-info/entry_points.txt +2 -0
- optimum_rbln-0.9.3.post1.dist-info/licenses/LICENSE +201 -0
--- /dev/null
+++ b/optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py
@@ -0,0 +1,435 @@
+# Copyright 2024 Rebellions Inc.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+#     http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Portions of this software are licensed under the Apache License,
+# Version 2.0. See the NOTICE file distributed with this work for
+# additional information regarding copyright ownership.
+
+# All other portions of this software, including proprietary code,
+# are the intellectual property of Rebellions Inc. and may not be
+# copied, modified, or distributed without prior written permission
+# from Rebellions Inc.
+
+import inspect
+import logging
+from pathlib import Path
+from typing import TYPE_CHECKING, Any, Callable, List, Optional, Union
+
+import rebel
+import torch
+from rebel.compile_context import CompileContext
+from transformers import PretrainedConfig, TimeSeriesTransformerForPrediction, TimeSeriesTransformerModel
+from transformers.modeling_outputs import SampleTSPredictionOutput, Seq2SeqTSModelOutput
+from transformers.modeling_utils import no_init_weights
+
+from ....configuration_utils import RBLNCompileConfig
+from ....modeling import RBLNModel
+from ....utils.runtime_utils import RBLNPytorchRuntime
+from ...modeling_outputs import RBLNSeq2SeqTSDecoderOutput
+from .configuration_time_series_transformer import RBLNTimeSeriesTransformerForPredictionConfig
+from .time_series_transformers_architecture import TimeSeriesTransformersWrapper
+
+
+logger = logging.getLogger(__name__)
+
+if TYPE_CHECKING:
+    from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer, PretrainedConfig, PreTrainedModel
+
+
+class RBLNRuntimeEncoder(RBLNPytorchRuntime):
+    mandatory_members = ["main_input_name"]
+
+    def __init__(
+        self,
+        runtime: rebel.Runtime,
+        model: TimeSeriesTransformerModel,
+        **kwargs: Any,
+    ) -> None:
+        super().__init__(runtime, **kwargs)
+        self._origin_model = model
+
+    def forward(
+        self,
+        past_values: torch.Tensor,
+        past_time_features: torch.Tensor,
+        static_categorical_features: Optional[torch.Tensor] = None,
+        static_real_features: Optional[torch.Tensor] = None,
+        past_observed_mask: Optional[torch.Tensor] = None,
+        future_values: Optional[torch.Tensor] = None,
+        future_time_features: Optional[torch.Tensor] = None,
+    ):
+        # preprocess
+        transformer_inputs, loc, scale, static_feat = self._origin_model.create_network_inputs(
+            past_values=past_values,
+            past_time_features=past_time_features,
+            past_observed_mask=past_observed_mask,
+            static_categorical_features=static_categorical_features,
+            static_real_features=static_real_features,
+            future_values=future_values,
+            future_time_features=future_time_features,
+        )
+        enc_input = transformer_inputs[:, : self._origin_model.config.context_length, ...]
+
+        # enc_attn_key_value_caches is updated to device dram in-place
+        _ = super().forward(inputs_embeds=enc_input)
+
+        return Seq2SeqTSModelOutput(
+            loc=loc,
+            scale=scale,
+            static_features=static_feat,
+        )
+
+
+class RBLNRuntimeDecoder(RBLNPytorchRuntime):
+    mandatory_members = ["main_input_name"]
+
+    def forward(
+        self,
+        inputs_embeds: torch.Tensor = None,
+        attention_mask: torch.Tensor = None,
+        cache_position: torch.Tensor = None,
+    ):
+        block_tables = torch.zeros(1, 1, dtype=torch.int16)
+        outputs = super().forward(inputs_embeds, attention_mask, cache_position, block_tables)
+
+        return RBLNSeq2SeqTSDecoderOutput(
+            params=outputs[:-1],
+            last_hidden_states=outputs[-1],
+        )
+
+
+class RBLNTimeSeriesTransformerForPrediction(RBLNModel):
+    """
+    The Time Series Transformer Model with a distribution head on top for time-series forecasting. e.g., for datasets like M4, NN5, or other time series forecasting benchmarks.
+    This model inherits from [`RBLNModel`]. Check the superclass documentation for the generic methods the library implements for all its models.
+
+    A class to convert and run pre-trained transformer-based `TimeSeriesTransformerForPrediction` models on RBLN devices.
+    It implements the methods to convert a pre-trained transformers `TimeSeriesTransformerForPrediction` model into a RBLN transformer model by:
+
+    - transferring the checkpoint weights of the original into an optimized RBLN graph,
+    - compiling the resulting graph using the RBLN Compiler.
+    """
+
+    auto_model_class = None
+    main_input_name = "inputs_embeds"
+
+    def __post_init__(self, **kwargs):
+        super().__post_init__(**kwargs)
+        self.batch_size = self.rbln_config.batch_size
+        self.dec_max_seq_len = self.rbln_config.dec_max_seq_len
+        self.num_parallel_samples = self.rbln_config.num_parallel_samples
+
+        with no_init_weights():
+            self._origin_model = TimeSeriesTransformerForPrediction._from_config(self.config)
+        artifacts = torch.load(self.model_save_dir / self.subfolder / "torch_artifacts.pth", weights_only=False)
+        self._origin_model.model.embedder.load_state_dict(artifacts["embedder"])
+        self.encoder = RBLNRuntimeEncoder(
+            runtime=self.model[0],
+            main_input_name="inputs_embeds",
+            model=self._origin_model.model,
+        )
+        self.decoder = RBLNRuntimeDecoder(
+            runtime=self.model[1],
+            main_input_name="inputs_embeds",
+        )
+
+    def __getattr__(self, __name: str) -> Any:
+        def redirect(func):
+            return lambda *pargs, **kwargs: func(self, *pargs, **kwargs)
+
+        val = getattr(TimeSeriesTransformerForPrediction, __name)
+        if val is not None and isinstance(val, Callable) and "self" in set(inspect.signature(val).parameters):
+            return redirect(val)
+
+    @classmethod
+    def _wrap_model_if_needed(
+        self, model: "PreTrainedModel", rbln_config: RBLNTimeSeriesTransformerForPredictionConfig
+    ):
+        return TimeSeriesTransformersWrapper(model, rbln_config.num_parallel_samples)
+
+    @classmethod
+    @torch.inference_mode()
+    def get_compiled_model(cls, model, rbln_config: RBLNTimeSeriesTransformerForPredictionConfig):
+        wrapped_model = cls._wrap_model_if_needed(model, rbln_config)
+
+        enc_compile_config = rbln_config.compile_cfgs[0]
+        dec_compile_config = rbln_config.compile_cfgs[1]
+
+        context = CompileContext(use_weight_sharing=False)
+
+        enc_example_inputs = enc_compile_config.get_dummy_inputs(fill=0)
+
+        # Mark encoder's static tensors (cross kv states)
+        static_tensors = {}
+        for (name, _, _), tensor in zip(enc_compile_config.input_info, enc_example_inputs):
+            if "key_value_states" in name:
+                static_tensors[name] = tensor
+                context.mark_static_address(tensor)
+
+        dec_example_inputs = dec_compile_config.get_dummy_inputs(fill=0, static_tensors=static_tensors)
+
+        # Mark decoder's static tensors (self kv states)
+        for (name, _, _), tensor in zip(dec_compile_config.input_info, dec_example_inputs):
+            if "key_value_states" in name:
+                context.mark_static_address(tensor)
+
+        compiled_decoder = cls.compile(
+            wrapped_model.decoder,
+            dec_compile_config,
+            create_runtimes=rbln_config.create_runtimes,
+            device=rbln_config.device,
+            example_inputs=dec_example_inputs,
+            compile_context=context,
+        )
+        compiled_encoder = cls.compile(
+            wrapped_model.encoder,
+            enc_compile_config,
+            create_runtimes=rbln_config.create_runtimes,
+            device=rbln_config.device,
+            example_inputs=enc_example_inputs,
+            compile_context=context,
+        )
+
+        return {"encoder": compiled_encoder, "decoder": compiled_decoder}
+
+    @classmethod
+    def save_torch_artifacts(
+        cls,
+        model: "PreTrainedModel",
+        save_dir_path: Path,
+        subfolder: str,
+        rbln_config: RBLNTimeSeriesTransformerForPredictionConfig,
+    ):
+        # If you are unavoidably running on a CPU rather than an RBLN device,
+        # store the torch tensor, weight, etc. in this function.
+
+        save_dict = {}
+        save_dict["embedder"] = model.model.embedder.state_dict()
+        torch.save(save_dict, save_dir_path / subfolder / "torch_artifacts.pth")
+
+    @classmethod
+    def _update_rbln_config(
+        cls,
+        preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]] = None,
+        model: Optional["PreTrainedModel"] = None,
+        model_config: Optional["PretrainedConfig"] = None,
+        rbln_config: Optional[RBLNTimeSeriesTransformerForPredictionConfig] = None,
+    ) -> RBLNTimeSeriesTransformerForPredictionConfig:
+        rbln_config.num_parallel_samples = rbln_config.num_parallel_samples or model_config.num_parallel_samples
+
+        if rbln_config.dec_max_seq_len is None:
+            predict_length = model_config.prediction_length
+            rbln_config.dec_max_seq_len = (
+                predict_length if predict_length % 64 == 0 else predict_length + (64 - predict_length % 64)
+            )
+
+        # model input info
+        enc_input_info = [
+            (
+                "inputs_embeds",
+                [rbln_config.batch_size, model_config.context_length, model_config.feature_size],
+                "float32",
+            ),
+        ]
+        enc_input_info.extend(
+            [
+                (
+                    "cross_key_value_states",
+                    [
+                        model_config.decoder_layers * 2,
+                        rbln_config.batch_size,
+                        model_config.decoder_attention_heads,
+                        model_config.context_length,
+                        model_config.d_model // model_config.decoder_attention_heads,
+                    ],
+                    "float32",
+                )
+            ]
+        )
+
+        dec_input_info = [
+            (
+                "inputs_embeds",
+                [rbln_config.batch_size * rbln_config.num_parallel_samples, 1, model_config.feature_size],
+                "float32",
+            ),
+            ("attention_mask", [1, rbln_config.dec_max_seq_len], "float32"),
+            ("cache_position", [], "int32"),
+            ("block_tables", [1, 1], "int16"),
+        ]
+        dec_input_info.extend(
+            [
+                (
+                    "cross_key_value_states",
+                    [
+                        model_config.decoder_layers * 2,  # 4
+                        rbln_config.batch_size,  # 64
+                        model_config.decoder_attention_heads,  # 2
+                        model_config.context_length,  # 24
+                        model_config.d_model // model_config.decoder_attention_heads,  # 13
+                    ],
+                    "float32",
+                )
+            ]
+        )
+        dec_input_info.extend(
+            [
+                (
+                    f"self_key_value_states_{i}",
+                    [
+                        1,
+                        model_config.decoder_attention_heads
+                        * rbln_config.num_parallel_samples
+                        * rbln_config.batch_size,
+                        rbln_config.dec_max_seq_len,
+                        model_config.d_model // model_config.encoder_attention_heads,
+                    ],
+                    "float32",
+                )
+                for i in range(model_config.decoder_layers * 2)
+            ]
+        )
+        enc_compile_config = RBLNCompileConfig(compiled_model_name="encoder", input_info=enc_input_info)
+        dec_compile_config = RBLNCompileConfig(compiled_model_name="decoder", input_info=dec_input_info)
+
+        rbln_config.set_compile_cfgs([enc_compile_config, dec_compile_config])
+        return rbln_config
+
+    @classmethod
+    def _create_runtimes(
+        cls,
+        compiled_models: List[rebel.RBLNCompiledModel],
+        rbln_config: RBLNTimeSeriesTransformerForPredictionConfig,
+    ) -> List[rebel.Runtime]:
+        if any(model_name not in rbln_config.device_map for model_name in ["encoder", "decoder"]):
+            cls._raise_missing_compiled_file_error(["encoder", "decoder"])
+
+        return [
+            rebel.Runtime(
+                compiled_models[0],
+                tensor_type="pt",
+                device=rbln_config.device_map["encoder"],
+                activate_profiler=rbln_config.activate_profiler,
+                timeout=rbln_config.timeout,
+            ),
+            rebel.Runtime(
+                compiled_models[1],
+                tensor_type="pt",
+                device=rbln_config.device_map["decoder"],
+                activate_profiler=rbln_config.activate_profiler,
+                timeout=rbln_config.timeout,
+            ),
+        ]
+
+    def validate_batch_size(self, **kwargs):
+        for k, v in kwargs.items():
+            if v is not None and v.shape[0] != self.batch_size:
+                raise RuntimeError(
+                    f"Batch size mismatch in '{k}': Expected {self.batch_size}, but got {v.shape[0]}. \n"
+                    f"Tensor shape: {v.shape} \n\n"
+                    f"Note: `batch_size` is set at compile time. \n"
+                    f"To change it, pass `export=True` along with `rbln_batch_size` when calling `from_pretrained()` to trigger recompilation."
+                )
+
+    @torch.no_grad()
+    def generate(
+        self,
+        past_values: torch.Tensor,
+        past_time_features: torch.Tensor,
+        future_time_features: torch.Tensor,
+        past_observed_mask: Optional[torch.Tensor] = None,
+        static_categorical_features: Optional[torch.Tensor] = None,
+        static_real_features: Optional[torch.Tensor] = None,
+        **kwargs,
+    ) -> SampleTSPredictionOutput:
+        """
+        Generate pass for the RBLN-optimized Time Series Transformer model for time series forecasting.
+
+        Args:
+            past_values (torch.FloatTensor of shape (batch_size, sequence_length) or (batch_size, sequence_length, input_size)): Past values of the time series, that serve as context in order to predict the future.
+            past_time_features (torch.FloatTensor of shape (batch_size, sequence_length, num_features)): Required time features, which the model internally will add to past_values.
+            future_time_features (torch.FloatTensor of shape (batch_size, prediction_length, num_features)): Required time features for the prediction window, which the model internally will add to future_values.
+            past_observed_mask (torch.BoolTensor of shape (batch_size, sequence_length) or (batch_size, sequence_length, input_size), optional): Boolean mask to indicate which past_values were observed and which were missing.
+            static_categorical_features (torch.LongTensor of shape (batch_size, number of static categorical features), optional): Optional static categorical features for which the model will learn an embedding, which it will add to the values of the time series.
+            static_real_features (torch.FloatTensor of shape (batch_size, number of static real features), optional): Optional static real features which the model will add to the values of the time series.
+
+        Returns:
+            The model outputs. If return_dict=False is passed, returns a tuple of tensors. Otherwise, returns a SampleTSPredictionOutput object.
+        """
+        self.validate_batch_size(**{k: v for k, v in locals().items() if isinstance(v, torch.Tensor)})
+
+        outputs = self.encoder(
+            static_categorical_features=static_categorical_features,
+            static_real_features=static_real_features,
+            past_time_features=past_time_features,
+            past_values=past_values,
+            past_observed_mask=past_observed_mask,
+            future_time_features=future_time_features,
+        )
+
+        loc = outputs.loc
+        scale = outputs.scale
+        static_feat = outputs.static_features
+
+        num_parallel_samples = self.num_parallel_samples
+        repeated_loc = loc.repeat_interleave(repeats=num_parallel_samples, dim=0)
+        repeated_scale = scale.repeat_interleave(repeats=num_parallel_samples, dim=0)
+
+        repeated_past_values = (
+            past_values.repeat_interleave(repeats=num_parallel_samples, dim=0) - repeated_loc
+        ) / repeated_scale
+
+        expanded_static_feat = static_feat.unsqueeze(1).expand(-1, future_time_features.shape[1], -1)
+        features = torch.cat((expanded_static_feat, future_time_features), dim=-1)
+        repeated_features = features.repeat_interleave(repeats=num_parallel_samples, dim=0)
+
+        # greedy decoding
+        future_samples = []
+        dec_attn_mask = torch.zeros(1, self.dec_max_seq_len)
+        for k in range(self.config.prediction_length):
+            lagged_sequence = self._origin_model.model.get_lagged_subsequences(
+                sequence=repeated_past_values,
+                subsequences_length=1 + k,
+                shift=1,
+            )
+
+            lags_shape = lagged_sequence.shape
+            reshaped_lagged_sequence = lagged_sequence.reshape(lags_shape[0], lags_shape[1], -1)
+            decoder_input = torch.cat((reshaped_lagged_sequence, repeated_features[:, : k + 1]), dim=-1)
+
+            dec_attn_mask[:, k] = 1
+            dec_inputs_embeds = decoder_input[:, -1:]
+
+            decoder_out = self.decoder(
+                inputs_embeds=dec_inputs_embeds.contiguous(),
+                attention_mask=dec_attn_mask,
+                cache_position=torch.tensor(k, dtype=torch.int32),
+            )
+            params = decoder_out.params
+
+            distr = self._origin_model.output_distribution(params, loc=repeated_loc, scale=repeated_scale)
+            next_sample = distr.sample()
+
+            repeated_past_values = torch.cat(
+                (repeated_past_values, (next_sample - repeated_loc) / repeated_scale), dim=1
+            )
+            future_samples.append(next_sample)
+
+        concat_future_samples = torch.cat(future_samples, dim=1)
+
+        return SampleTSPredictionOutput(
+            sequences=concat_future_samples.reshape(
+                (-1, num_parallel_samples, self.config.prediction_length) + self._origin_model.target_shape,
+            )
+        )