optimum-rbln 0.8.2a0__py3-none-any.whl → 0.9.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- optimum/rbln/__init__.py +116 -9
- optimum/rbln/__version__.py +16 -3
- optimum/rbln/cli.py +660 -0
- optimum/rbln/configuration_utils.py +171 -43
- optimum/rbln/diffusers/__init__.py +19 -0
- optimum/rbln/diffusers/configurations/__init__.py +3 -0
- optimum/rbln/diffusers/configurations/models/__init__.py +2 -0
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +3 -3
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_cosmos.py +1 -1
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_temporal_decoder.py +67 -0
- optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +3 -3
- optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +4 -4
- optimum/rbln/diffusers/configurations/models/configuration_transformer_cosmos.py +12 -4
- optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +9 -4
- optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +3 -3
- optimum/rbln/diffusers/configurations/models/configuration_unet_spatio_temporal_condition.py +59 -0
- optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +3 -3
- optimum/rbln/diffusers/configurations/pipelines/__init__.py +3 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +35 -19
- optimum/rbln/diffusers/configurations/pipelines/configuration_cosmos.py +14 -11
- optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +30 -20
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +13 -9
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +17 -13
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +17 -10
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_video_diffusion.py +114 -0
- optimum/rbln/diffusers/modeling_diffusers.py +33 -18
- optimum/rbln/diffusers/models/__init__.py +4 -0
- optimum/rbln/diffusers/models/autoencoders/__init__.py +1 -0
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +32 -3
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py +32 -6
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +275 -0
- optimum/rbln/diffusers/models/autoencoders/vae.py +27 -8
- optimum/rbln/diffusers/models/autoencoders/vq_model.py +32 -3
- optimum/rbln/diffusers/models/controlnet.py +16 -1
- optimum/rbln/diffusers/models/transformers/prior_transformer.py +17 -3
- optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +26 -3
- optimum/rbln/diffusers/models/transformers/transformer_sd3.py +23 -2
- optimum/rbln/diffusers/models/unets/__init__.py +1 -0
- optimum/rbln/diffusers/models/unets/unet_2d_condition.py +23 -4
- optimum/rbln/diffusers/models/unets/unet_spatio_temporal_condition.py +201 -0
- optimum/rbln/diffusers/pipelines/__init__.py +15 -5
- optimum/rbln/diffusers/pipelines/auto_pipeline.py +307 -0
- optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +20 -0
- optimum/rbln/diffusers/pipelines/cosmos/configuration_cosmos_guardrail.py +23 -12
- optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +16 -46
- optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +31 -1
- optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +31 -1
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +1 -6
- optimum/rbln/diffusers/pipelines/stable_video_diffusion/__init__.py +15 -0
- optimum/rbln/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +46 -0
- optimum/rbln/modeling.py +50 -24
- optimum/rbln/modeling_base.py +116 -35
- optimum/rbln/ops/attn.py +158 -0
- optimum/rbln/ops/flash_attn.py +166 -0
- optimum/rbln/ops/kv_cache_update.py +5 -0
- optimum/rbln/ops/linear.py +7 -0
- optimum/rbln/transformers/__init__.py +100 -0
- optimum/rbln/transformers/configuration_generic.py +7 -32
- optimum/rbln/transformers/modeling_attention_utils.py +385 -0
- optimum/rbln/transformers/modeling_generic.py +48 -65
- optimum/rbln/transformers/modeling_outputs.py +37 -0
- optimum/rbln/transformers/models/__init__.py +93 -30
- optimum/rbln/transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py +28 -2
- optimum/rbln/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +68 -5
- optimum/rbln/transformers/models/auto/__init__.py +2 -0
- optimum/rbln/transformers/models/auto/auto_factory.py +92 -17
- optimum/rbln/transformers/models/auto/modeling_auto.py +45 -0
- optimum/rbln/transformers/models/bart/bart_architecture.py +2 -7
- optimum/rbln/transformers/models/bart/configuration_bart.py +2 -0
- optimum/rbln/transformers/models/bart/modeling_bart.py +23 -2
- optimum/rbln/transformers/models/bert/bert_architecture.py +16 -0
- optimum/rbln/transformers/models/bert/modeling_bert.py +93 -4
- optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +42 -11
- optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +135 -44
- optimum/rbln/transformers/models/clip/configuration_clip.py +21 -7
- optimum/rbln/transformers/models/clip/modeling_clip.py +183 -27
- optimum/rbln/transformers/models/colpali/colpali_architecture.py +3 -6
- optimum/rbln/transformers/models/colpali/configuration_colpali.py +37 -21
- optimum/rbln/transformers/models/colpali/modeling_colpali.py +82 -104
- optimum/rbln/transformers/models/colqwen2/__init__.py +2 -0
- optimum/rbln/transformers/models/colqwen2/colqwen2_architecture.py +233 -0
- optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +74 -0
- optimum/rbln/transformers/models/colqwen2/modeling_colqwen2.py +446 -0
- optimum/rbln/transformers/models/decoderonly/__init__.py +3 -2
- optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +114 -37
- optimum/rbln/transformers/models/decoderonly/configuration_lora.py +411 -0
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +323 -316
- optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +508 -0
- optimum/rbln/transformers/models/decoderonly/generation_decoderonly.py +119 -0
- optimum/rbln/transformers/models/decoderonly/lora_architecture.py +204 -0
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +486 -892
- optimum/rbln/transformers/models/depth_anything/__init__.py +16 -0
- optimum/rbln/transformers/models/depth_anything/configuration_depth_anything.py +24 -0
- optimum/rbln/transformers/models/depth_anything/modeling_depth_anything.py +42 -0
- optimum/rbln/transformers/models/distilbert/modeling_distilbert.py +24 -0
- optimum/rbln/transformers/models/dpt/modeling_dpt.py +17 -0
- optimum/rbln/transformers/models/exaone/modeling_exaone.py +42 -4
- optimum/rbln/transformers/models/gemma/__init__.py +2 -2
- optimum/rbln/transformers/models/gemma/configuration_gemma.py +9 -1
- optimum/rbln/transformers/models/gemma/gemma_architecture.py +1 -4
- optimum/rbln/transformers/models/gemma/modeling_gemma.py +22 -1
- optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +49 -14
- optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +12 -2
- optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py +245 -0
- optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +212 -504
- optimum/rbln/transformers/models/gpt2/__init__.py +2 -2
- optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +31 -3
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +10 -8
- optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +18 -1
- optimum/rbln/transformers/models/grounding_dino/__init__.py +10 -0
- optimum/rbln/transformers/models/grounding_dino/configuration_grounding_dino.py +92 -0
- optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +599 -0
- optimum/rbln/transformers/models/grounding_dino/modeling_grounding_dino.py +1048 -0
- optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +35 -7
- optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +29 -32
- optimum/rbln/transformers/models/llama/__init__.py +2 -2
- optimum/rbln/transformers/models/llama/configuration_llama.py +9 -1
- optimum/rbln/transformers/models/llama/modeling_llama.py +22 -1
- optimum/rbln/transformers/models/llava/__init__.py +16 -0
- optimum/rbln/transformers/models/llava/configuration_llava.py +72 -0
- optimum/rbln/transformers/models/llava/modeling_llava.py +490 -0
- optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +21 -6
- optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +234 -376
- optimum/rbln/transformers/models/midm/midm_architecture.py +4 -1
- optimum/rbln/transformers/models/midm/modeling_midm.py +42 -4
- optimum/rbln/transformers/models/mistral/__init__.py +2 -2
- optimum/rbln/transformers/models/mistral/configuration_mistral.py +9 -1
- optimum/rbln/transformers/models/mistral/mistral_architecture.py +1 -1
- optimum/rbln/transformers/models/mistral/modeling_mistral.py +26 -3
- optimum/rbln/transformers/models/opt/__init__.py +2 -2
- optimum/rbln/transformers/models/opt/configuration_opt.py +8 -1
- optimum/rbln/transformers/models/opt/modeling_opt.py +29 -17
- optimum/rbln/transformers/models/opt/opt_architecture.py +4 -4
- optimum/rbln/transformers/models/pegasus/__init__.py +17 -0
- optimum/rbln/transformers/models/pegasus/configuration_pegasus.py +38 -0
- optimum/rbln/transformers/models/pegasus/modeling_pegasus.py +71 -0
- optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +161 -0
- optimum/rbln/transformers/models/phi/__init__.py +2 -2
- optimum/rbln/transformers/models/phi/configuration_phi.py +9 -1
- optimum/rbln/transformers/models/phi/modeling_phi.py +10 -1
- optimum/rbln/transformers/models/phi/phi_architecture.py +11 -7
- optimum/rbln/transformers/models/pixtral/__init__.py +16 -0
- optimum/rbln/transformers/models/pixtral/configuration_pixtral.py +43 -0
- optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +322 -0
- optimum/rbln/transformers/models/pixtral/pixtral_architecture.py +73 -0
- optimum/rbln/transformers/models/qwen2/__init__.py +2 -2
- optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +9 -1
- optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +27 -1
- optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +21 -6
- optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +15 -22
- optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +28 -7
- optimum/rbln/transformers/models/qwen2_vl/__init__.py +19 -0
- optimum/rbln/transformers/models/qwen2_vl/configuration_qwen2_vl.py +88 -0
- optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py +513 -0
- optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +165 -0
- optimum/rbln/transformers/models/qwen3/__init__.py +16 -0
- optimum/rbln/transformers/models/qwen3/configuration_qwen3.py +71 -0
- optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +133 -0
- optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +31 -0
- optimum/rbln/transformers/models/resnet/configuration_resnet.py +17 -0
- optimum/rbln/transformers/models/resnet/modeling_resnet.py +73 -0
- optimum/rbln/transformers/models/roberta/modeling_roberta.py +33 -0
- optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py +21 -16
- optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +60 -13
- optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +2 -2
- optimum/rbln/transformers/models/siglip/__init__.py +2 -6
- optimum/rbln/transformers/models/siglip/configuration_siglip.py +1 -1
- optimum/rbln/transformers/models/siglip/modeling_siglip.py +21 -16
- optimum/rbln/transformers/models/swin/__init__.py +16 -0
- optimum/rbln/transformers/models/swin/configuration_swin.py +42 -0
- optimum/rbln/transformers/models/swin/modeling_swin.py +354 -0
- optimum/rbln/transformers/models/t5/configuration_t5.py +2 -0
- optimum/rbln/transformers/models/t5/modeling_t5.py +2 -2
- optimum/rbln/transformers/models/t5/t5_architecture.py +8 -1
- optimum/rbln/transformers/models/time_series_transformer/configuration_time_series_transformer.py +3 -3
- optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +22 -16
- optimum/rbln/transformers/models/time_series_transformer/time_series_transformers_architecture.py +7 -1
- optimum/rbln/transformers/models/vit/modeling_vit.py +19 -0
- optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec2.py +15 -3
- optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +61 -8
- optimum/rbln/transformers/models/whisper/configuration_whisper.py +12 -13
- optimum/rbln/transformers/models/whisper/generation_whisper.py +62 -6
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +32 -5
- optimum/rbln/transformers/models/xlm_roberta/__init__.py +2 -8
- optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +43 -0
- optimum/rbln/transformers/utils/rbln_quantization.py +400 -75
- optimum/rbln/transformers/utils/rbln_runtime_wrapper.py +79 -0
- optimum/rbln/utils/deprecation.py +213 -0
- optimum/rbln/utils/hub.py +22 -50
- optimum/rbln/utils/runtime_utils.py +85 -17
- optimum/rbln/utils/submodule.py +31 -9
- {optimum_rbln-0.8.2a0.dist-info → optimum_rbln-0.9.3.dist-info}/METADATA +8 -7
- optimum_rbln-0.9.3.dist-info/RECORD +264 -0
- {optimum_rbln-0.8.2a0.dist-info → optimum_rbln-0.9.3.dist-info}/WHEEL +1 -1
- optimum_rbln-0.9.3.dist-info/entry_points.txt +2 -0
- optimum_rbln-0.8.2a0.dist-info/RECORD +0 -211
- {optimum_rbln-0.8.2a0.dist-info → optimum_rbln-0.9.3.dist-info}/licenses/LICENSE +0 -0
|
@@ -20,7 +20,9 @@ import rebel
|
|
|
20
20
|
import torch
|
|
21
21
|
from rebel.compile_context import CompileContext
|
|
22
22
|
from transformers import AutoModelForSeq2SeqLM, PretrainedConfig, PreTrainedModel
|
|
23
|
-
from transformers.
|
|
23
|
+
from transformers.generation.configuration_utils import GenerationConfig
|
|
24
|
+
from transformers.generation.utils import GenerationMixin
|
|
25
|
+
from transformers.modeling_outputs import BaseModelOutput, ModelOutput, Seq2SeqLMOutput
|
|
24
26
|
|
|
25
27
|
from ....configuration_utils import RBLNCompileConfig
|
|
26
28
|
from ....modeling import RBLNModel
|
|
@@ -32,13 +34,13 @@ from .configuration_seq2seq import RBLNModelForSeq2SeqLMConfig
|
|
|
32
34
|
logger = get_logger(__name__)
|
|
33
35
|
|
|
34
36
|
if TYPE_CHECKING:
|
|
35
|
-
from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer,
|
|
37
|
+
from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer, PretrainedConfig
|
|
36
38
|
|
|
37
39
|
|
|
38
40
|
class RBLNRuntimeEncoder(RBLNPytorchRuntime):
|
|
39
41
|
mandatory_members = ["main_input_name"]
|
|
40
42
|
|
|
41
|
-
def forward(self, *args: List[torch.Tensor], **kwargs:
|
|
43
|
+
def forward(self, *args: List[torch.Tensor], **kwargs: torch.Tensor):
|
|
42
44
|
output = super().forward(*args, **kwargs)
|
|
43
45
|
return BaseModelOutput(last_hidden_state=output)
|
|
44
46
|
|
|
@@ -83,7 +85,7 @@ class RBLNRuntimeDecoder(RBLNPytorchRuntime):
|
|
|
83
85
|
decoding_step = cache_position[b_idx].item()
|
|
84
86
|
if not (0 <= decoding_step < self.dec_max_seq_len):
|
|
85
87
|
raise ValueError(
|
|
86
|
-
f"Decoding step {decoding_step} out of bounds for
|
|
88
|
+
f"Decoding step {decoding_step} out of bounds for decoder_max_seq_len ({self.dec_max_seq_len})."
|
|
87
89
|
)
|
|
88
90
|
decoder_attention_mask[b_idx, : decoding_step + 1] = 1
|
|
89
91
|
|
|
@@ -101,7 +103,7 @@ class RBLNRuntimeDecoder(RBLNPytorchRuntime):
|
|
|
101
103
|
return Seq2SeqLMOutput(logits=lm_logits)
|
|
102
104
|
|
|
103
105
|
|
|
104
|
-
class RBLNModelForSeq2SeqLM(RBLNModel, ABC):
|
|
106
|
+
class RBLNModelForSeq2SeqLM(RBLNModel, GenerationMixin, ABC):
|
|
105
107
|
"""
|
|
106
108
|
This is a generic model class that will be instantiated as one of the model classes of the library (with a sequence-to-sequence language modeling head) when created with the from_pretrained() class method.
|
|
107
109
|
This model inherits from [`RBLNModel`]. Check the superclass documentation for the generic methods the library implements for all its models.
|
|
@@ -117,6 +119,7 @@ class RBLNModelForSeq2SeqLM(RBLNModel, ABC):
|
|
|
117
119
|
main_input_name = "input_ids"
|
|
118
120
|
auto_model_class = AutoModelForSeq2SeqLM
|
|
119
121
|
support_causal_attn = None
|
|
122
|
+
_is_stateful = False
|
|
120
123
|
|
|
121
124
|
def __post_init__(self, **kwargs):
|
|
122
125
|
batch_size = self.rbln_config.batch_size
|
|
@@ -138,7 +141,7 @@ class RBLNModelForSeq2SeqLM(RBLNModel, ABC):
|
|
|
138
141
|
@classmethod
|
|
139
142
|
@torch.inference_mode()
|
|
140
143
|
def get_compiled_model(cls, model: PreTrainedModel, rbln_config: RBLNModelForSeq2SeqLMConfig):
|
|
141
|
-
wrapped_model = cls.
|
|
144
|
+
wrapped_model = cls._wrap_model_if_needed(model, rbln_config)
|
|
142
145
|
|
|
143
146
|
enc_compile_config = rbln_config.compile_cfgs[0]
|
|
144
147
|
dec_compile_config = rbln_config.compile_cfgs[1]
|
|
@@ -181,6 +184,21 @@ class RBLNModelForSeq2SeqLM(RBLNModel, ABC):
|
|
|
181
184
|
|
|
182
185
|
return {"encoder": compiled_encoder, "decoder": compiled_decoder}
|
|
183
186
|
|
|
187
|
+
@classmethod
|
|
188
|
+
def _update_paged_attention_config(cls, model_config: PretrainedConfig, rbln_config: RBLNModelForSeq2SeqLMConfig):
|
|
189
|
+
rbln_config.kvcache_num_blocks = rbln_config.kvcache_num_blocks or rbln_config.batch_size
|
|
190
|
+
rbln_config.kvcache_block_size = rbln_config.kvcache_block_size or rbln_config.dec_max_seq_len
|
|
191
|
+
|
|
192
|
+
if rbln_config.kvcache_num_blocks != rbln_config.batch_size:
|
|
193
|
+
raise NotImplementedError(
|
|
194
|
+
f"kvcache_num_blocks ({rbln_config.kvcache_num_blocks}) must be equal to batch_size ({rbln_config.batch_size}) as flash attention is not supported yet."
|
|
195
|
+
)
|
|
196
|
+
|
|
197
|
+
if rbln_config.kvcache_block_size != rbln_config.dec_max_seq_len:
|
|
198
|
+
raise NotImplementedError(
|
|
199
|
+
f"kvcache_block_size ({rbln_config.kvcache_block_size}) must be equal to dec_max_seq_len ({rbln_config.dec_max_seq_len}) as flash attention is not supported yet."
|
|
200
|
+
)
|
|
201
|
+
|
|
184
202
|
@classmethod
|
|
185
203
|
def _update_rbln_config(
|
|
186
204
|
cls,
|
|
@@ -204,12 +222,6 @@ class RBLNModelForSeq2SeqLM(RBLNModel, ABC):
|
|
|
204
222
|
model_config, "max_position_embeddings", None
|
|
205
223
|
)
|
|
206
224
|
|
|
207
|
-
pad_token_id = getattr(model_config, "pad_token_id", None)
|
|
208
|
-
pad_token_id = pad_token_id or getattr(model_config, "bos_token_id", None)
|
|
209
|
-
pad_token_id = pad_token_id or getattr(model_config, "eos_token_id", None)
|
|
210
|
-
pad_token_id = pad_token_id or -1
|
|
211
|
-
rbln_config.pad_token_id = pad_token_id
|
|
212
|
-
|
|
213
225
|
if rbln_config.enc_max_seq_len is None:
|
|
214
226
|
enc_max_seq_len = max_position_embeddings
|
|
215
227
|
for tokenizer in preprocessors:
|
|
@@ -238,6 +250,9 @@ class RBLNModelForSeq2SeqLM(RBLNModel, ABC):
|
|
|
238
250
|
if max_position_embeddings is not None and rbln_config.dec_max_seq_len > max_position_embeddings:
|
|
239
251
|
raise ValueError("`dec_max_seq_len` should be less or equal than max_position_embeddings!")
|
|
240
252
|
|
|
253
|
+
if rbln_config.support_paged_attention:
|
|
254
|
+
cls._update_paged_attention_config(model_config, rbln_config)
|
|
255
|
+
|
|
241
256
|
# model input info
|
|
242
257
|
enc_input_info = [
|
|
243
258
|
("input_ids", [1, rbln_config.enc_max_seq_len], "int64"),
|
|
@@ -310,6 +325,7 @@ class RBLNModelForSeq2SeqLM(RBLNModel, ABC):
|
|
|
310
325
|
dec_compile_config = RBLNCompileConfig(compiled_model_name="decoder", input_info=dec_input_info)
|
|
311
326
|
|
|
312
327
|
rbln_config.set_compile_cfgs([enc_compile_config, dec_compile_config])
|
|
328
|
+
|
|
313
329
|
return rbln_config
|
|
314
330
|
|
|
315
331
|
@classmethod
|
|
@@ -327,12 +343,14 @@ class RBLNModelForSeq2SeqLM(RBLNModel, ABC):
|
|
|
327
343
|
tensor_type="pt",
|
|
328
344
|
device=rbln_config.device_map["encoder"],
|
|
329
345
|
activate_profiler=rbln_config.activate_profiler,
|
|
346
|
+
timeout=rbln_config.timeout,
|
|
330
347
|
),
|
|
331
348
|
rebel.Runtime(
|
|
332
349
|
compiled_models[1],
|
|
333
350
|
tensor_type="pt",
|
|
334
351
|
device=rbln_config.device_map["decoder"],
|
|
335
352
|
activate_profiler=rbln_config.activate_profiler,
|
|
353
|
+
timeout=rbln_config.timeout,
|
|
336
354
|
),
|
|
337
355
|
]
|
|
338
356
|
|
|
@@ -409,7 +427,7 @@ class RBLNModelForSeq2SeqLM(RBLNModel, ABC):
|
|
|
409
427
|
inputs_tensor = torch.nn.functional.pad(
|
|
410
428
|
inputs_tensor,
|
|
411
429
|
(0, self.rbln_config.enc_max_seq_len - input_len),
|
|
412
|
-
value=self.
|
|
430
|
+
value=self.config.pad_token_id,
|
|
413
431
|
)
|
|
414
432
|
model_kwargs["attention_mask"] = torch.nn.functional.pad(
|
|
415
433
|
model_kwargs["attention_mask"], (0, self.rbln_config.enc_max_seq_len - input_len)
|
|
@@ -428,3 +446,32 @@ class RBLNModelForSeq2SeqLM(RBLNModel, ABC):
|
|
|
428
446
|
model_kwargs["encoder_outputs"] = encoder(**encoder_kwargs, block_tables=block_tables)
|
|
429
447
|
|
|
430
448
|
return model_kwargs
|
|
449
|
+
|
|
450
|
+
def generate(
|
|
451
|
+
self,
|
|
452
|
+
input_ids: torch.LongTensor,
|
|
453
|
+
attention_mask: Optional[torch.LongTensor] = None,
|
|
454
|
+
generation_config: Optional[GenerationConfig] = None,
|
|
455
|
+
**kwargs,
|
|
456
|
+
) -> Union[ModelOutput, torch.LongTensor]:
|
|
457
|
+
"""
|
|
458
|
+
The generate function is utilized in its standard form as in the HuggingFace transformers library. User can use this function to generate text from the model.
|
|
459
|
+
Check the [HuggingFace transformers documentation](https://huggingface.co/docs/transformers/v4.57.1/en/main_classes/text_generation#transformers.GenerationMixin.generate) for more details.
|
|
460
|
+
|
|
461
|
+
Args:
|
|
462
|
+
input_ids (torch.LongTensor): The input ids to the model.
|
|
463
|
+
attention_mask (torch.LongTensor, optional): The attention mask to the model.
|
|
464
|
+
generation_config (GenerationConfig, optional): The generation configuration to be used as base parametrization for the generation call. **kwargs passed to generate matching the attributes of generation_config will override them.
|
|
465
|
+
If generation_config is not provided, the default will be used, which had the following loading priority: 1) from the generation_config.json model file, if it exists; 2) from the model configuration.
|
|
466
|
+
Please note that unspecified parameters will inherit [GenerationConfig](https://huggingface.co/docs/transformers/v4.57.1/en/main_classes/text_generation#transformers.GenerationConfig)’s default values.
|
|
467
|
+
kwargs (dict[str, Any], optional): Additional arguments passed to the generate function. See the HuggingFace transformers documentation for more details.
|
|
468
|
+
|
|
469
|
+
Returns:
|
|
470
|
+
Generates sequences of token ids for models with a language modeling head.
|
|
471
|
+
"""
|
|
472
|
+
if generation_config is not None:
|
|
473
|
+
kwargs["generation_config"] = generation_config
|
|
474
|
+
if attention_mask is not None:
|
|
475
|
+
kwargs["attention_mask"] = attention_mask
|
|
476
|
+
|
|
477
|
+
return super().generate(input_ids, **kwargs)
|
|
@@ -31,7 +31,7 @@ class Seq2SeqWrapper:
|
|
|
31
31
|
Args:
|
|
32
32
|
model (nn.Module): The Seq2Seq model to wrap.
|
|
33
33
|
enc_max_seq_len (int): Maximum sequence length for the encoder's position embeddings and cache sizes.
|
|
34
|
-
|
|
34
|
+
kwargs: Additional arguments to pass to the decoder wrapper.
|
|
35
35
|
"""
|
|
36
36
|
|
|
37
37
|
def __init__(self, model: nn.Module, enc_max_seq_len: int, **kwargs):
|
|
@@ -125,7 +125,7 @@ class Seq2SeqDecoderWrapper(nn.Module):
|
|
|
125
125
|
|
|
126
126
|
Args:
|
|
127
127
|
model (nn.Module): The Seq2Seq model containing the decoder.
|
|
128
|
-
|
|
128
|
+
kwargs: Additional arguments for decoder configuration.
|
|
129
129
|
"""
|
|
130
130
|
|
|
131
131
|
def __init__(self, model: nn.Module, use_attention_mask: bool = True, **kwargs):
|
|
@@ -12,9 +12,5 @@
|
|
|
12
12
|
# See the License for the specific language governing permissions and
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
|
|
15
|
-
from .configuration_siglip import
|
|
16
|
-
|
|
17
|
-
)
|
|
18
|
-
from .modeling_siglip import (
|
|
19
|
-
RBLNSiglipVisionModel,
|
|
20
|
-
)
|
|
15
|
+
from .configuration_siglip import RBLNSiglipVisionModelConfig
|
|
16
|
+
from .modeling_siglip import RBLNSiglipVisionModel
|
|
@@ -42,7 +42,7 @@ class RBLNSiglipVisionModelConfig(RBLNModelConfig):
|
|
|
42
42
|
interpolate_pos_encoding (Optional[bool]): Whether to interpolate the position encoding.
|
|
43
43
|
output_hidden_states: (Optional[bool]): Whether to return hidden states.
|
|
44
44
|
output_attentions: (Optional[bool]): Whether to return attentions.
|
|
45
|
-
|
|
45
|
+
kwargs: Additional arguments passed to the parent RBLNModelConfig.
|
|
46
46
|
|
|
47
47
|
Raises:
|
|
48
48
|
ValueError: If batch_size is not a positive integer.
|
|
@@ -12,7 +12,7 @@
|
|
|
12
12
|
# See the License for the specific language governing permissions and
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
|
|
15
|
-
from typing import TYPE_CHECKING, Any,
|
|
15
|
+
from typing import TYPE_CHECKING, Any, Optional, Tuple, Union
|
|
16
16
|
|
|
17
17
|
import torch
|
|
18
18
|
from transformers import SiglipVisionConfig, SiglipVisionModel
|
|
@@ -29,8 +29,6 @@ logger = get_logger(__name__)
|
|
|
29
29
|
if TYPE_CHECKING:
|
|
30
30
|
from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer, PreTrainedModel
|
|
31
31
|
|
|
32
|
-
from ....diffusers.modeling_diffusers import RBLNDiffusionMixin, RBLNDiffusionMixinConfig
|
|
33
|
-
|
|
34
32
|
|
|
35
33
|
class _SiglipVisionModel(torch.nn.Module):
|
|
36
34
|
def __init__(
|
|
@@ -65,8 +63,12 @@ class RBLNSiglipVisionModel(RBLNModel):
|
|
|
65
63
|
on RBLN devices, supporting image encoding for multimodal vision-language tasks.
|
|
66
64
|
"""
|
|
67
65
|
|
|
66
|
+
_tp_support = False
|
|
67
|
+
|
|
68
68
|
@classmethod
|
|
69
|
-
def
|
|
69
|
+
def _wrap_model_if_needed(
|
|
70
|
+
cls, model: torch.nn.Module, rbln_config: RBLNSiglipVisionModelConfig
|
|
71
|
+
) -> torch.nn.Module:
|
|
70
72
|
wrapper_cfg = {
|
|
71
73
|
"interpolate_pos_encoding": rbln_config.interpolate_pos_encoding,
|
|
72
74
|
"output_hidden_states": rbln_config.output_hidden_states,
|
|
@@ -74,12 +76,6 @@ class RBLNSiglipVisionModel(RBLNModel):
|
|
|
74
76
|
}
|
|
75
77
|
return _SiglipVisionModel(model, **wrapper_cfg).eval()
|
|
76
78
|
|
|
77
|
-
@classmethod
|
|
78
|
-
def update_rbln_config_using_pipe(
|
|
79
|
-
cls, pipe: "RBLNDiffusionMixin", rbln_config: "RBLNDiffusionMixinConfig", submodule_name: str
|
|
80
|
-
) -> "RBLNDiffusionMixinConfig":
|
|
81
|
-
return rbln_config
|
|
82
|
-
|
|
83
79
|
@classmethod
|
|
84
80
|
def _update_rbln_config(
|
|
85
81
|
cls,
|
|
@@ -126,12 +122,21 @@ class RBLNSiglipVisionModel(RBLNModel):
|
|
|
126
122
|
output_attentions: bool = None,
|
|
127
123
|
output_hidden_states: bool = None,
|
|
128
124
|
interpolate_pos_encoding: bool = False,
|
|
129
|
-
**kwargs:
|
|
125
|
+
**kwargs: Any,
|
|
130
126
|
) -> Union[Tuple, BaseModelOutputWithPooling]:
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
|
|
127
|
+
"""
|
|
128
|
+
Forward pass for the RBLN-optimized SigLIP vision model.
|
|
129
|
+
|
|
130
|
+
Args:
|
|
131
|
+
pixel_values (torch.FloatTensor of shape (batch_size, num_channels, image_size, image_size), optional): The tensors corresponding to the input images. Pixel values can be obtained using ViTImageProcessor. See ViTImageProcessor.call() for details (processor_class uses ViTImageProcessor for processing images).
|
|
132
|
+
return_dict (bool, optional): Whether or not to return a ModelOutput instead of a plain tuple.
|
|
133
|
+
output_attentions (bool, optional): Whether or not to return the attentions tensors of all attention layers. See attentions under returned tensors for more detail.
|
|
134
|
+
output_hidden_states (bool, optional): Whether or not to return the hidden states of all layers. See hidden_states under returned tensors for more detail.
|
|
135
|
+
interpolate_pos_encoding (bool, defaults to False): Whether to interpolate the pre-trained position encodings.
|
|
136
|
+
|
|
137
|
+
Returns:
|
|
138
|
+
The model outputs. If return_dict=False is passed, returns a tuple of tensors. Otherwise, returns a BaseModelOutputWithPooling object.
|
|
139
|
+
"""
|
|
135
140
|
|
|
136
141
|
output_attentions = output_attentions if output_attentions is not None else self.rbln_config.output_attentions
|
|
137
142
|
output_hidden_states = (
|
|
@@ -156,7 +161,7 @@ class RBLNSiglipVisionModel(RBLNModel):
|
|
|
156
161
|
f"Please compile again with the correct argument."
|
|
157
162
|
)
|
|
158
163
|
|
|
159
|
-
output = super().forward(pixel_values, return_dict=return_dict)
|
|
164
|
+
output = super().forward(pixel_values, return_dict=return_dict, **kwargs)
|
|
160
165
|
return output
|
|
161
166
|
|
|
162
167
|
def _prepare_output(self, output, return_dict):
|
|
@@ -0,0 +1,16 @@
|
|
|
1
|
+
# Copyright 2025 Rebellions Inc. All rights reserved.
|
|
2
|
+
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at:
|
|
6
|
+
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
from .configuration_swin import RBLNSwinBackboneConfig
|
|
16
|
+
from .modeling_swin import RBLNSwinBackbone
|
|
@@ -0,0 +1,42 @@
|
|
|
1
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
2
|
+
# you may not use this file except in compliance with the License.
|
|
3
|
+
# You may obtain a copy of the License at:
|
|
4
|
+
|
|
5
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
6
|
+
|
|
7
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
8
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
9
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
10
|
+
# See the License for the specific language governing permissions and
|
|
11
|
+
# limitations under the License.
|
|
12
|
+
|
|
13
|
+
from typing import Any, Optional, Tuple, Union
|
|
14
|
+
|
|
15
|
+
from ...configuration_generic import RBLNModelForImageClassificationConfig
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class RBLNSwinBackboneConfig(RBLNModelForImageClassificationConfig):
|
|
19
|
+
def __init__(
|
|
20
|
+
self,
|
|
21
|
+
image_size: Optional[Union[int, Tuple[int, int]]] = None,
|
|
22
|
+
batch_size: Optional[int] = None,
|
|
23
|
+
output_hidden_states: Optional[bool] = None,
|
|
24
|
+
output_attentions: Optional[bool] = None,
|
|
25
|
+
**kwargs: Any,
|
|
26
|
+
):
|
|
27
|
+
"""
|
|
28
|
+
Args:
|
|
29
|
+
batch_size (Optional[int]): The batch size for text processing. Defaults to 1.
|
|
30
|
+
kwargs: Additional arguments passed to the parent RBLNModelConfig.
|
|
31
|
+
|
|
32
|
+
Raises:
|
|
33
|
+
ValueError: If batch_size is not a positive integer.
|
|
34
|
+
"""
|
|
35
|
+
super().__init__(**kwargs)
|
|
36
|
+
self.batch_size = batch_size or 1
|
|
37
|
+
if not isinstance(self.batch_size, int) or self.batch_size < 0:
|
|
38
|
+
raise ValueError(f"batch_size must be a positive integer, got {self.batch_size}")
|
|
39
|
+
|
|
40
|
+
self.image_size = image_size
|
|
41
|
+
self.output_hidden_states = output_hidden_states
|
|
42
|
+
self.output_attentions = output_attentions
|