optimum-rbln 0.9.3rc0__py3-none-any.whl → 0.9.5a4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- optimum/rbln/__init__.py +48 -0
- optimum/rbln/__version__.py +2 -2
- optimum/rbln/configuration_utils.py +50 -21
- optimum/rbln/diffusers/__init__.py +12 -0
- optimum/rbln/diffusers/configurations/__init__.py +3 -0
- optimum/rbln/diffusers/configurations/models/__init__.py +2 -0
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_temporal_decoder.py +67 -0
- optimum/rbln/diffusers/configurations/models/configuration_unet_spatio_temporal_condition.py +59 -0
- optimum/rbln/diffusers/configurations/pipelines/__init__.py +3 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_video_diffusion.py +114 -0
- optimum/rbln/diffusers/modeling_diffusers.py +1 -1
- optimum/rbln/diffusers/models/__init__.py +17 -3
- optimum/rbln/diffusers/models/autoencoders/__init__.py +1 -0
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py +3 -3
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +275 -0
- optimum/rbln/diffusers/models/autoencoders/vae.py +27 -8
- optimum/rbln/diffusers/models/controlnet.py +17 -2
- optimum/rbln/diffusers/models/transformers/prior_transformer.py +16 -2
- optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +16 -1
- optimum/rbln/diffusers/models/transformers/transformer_sd3.py +14 -1
- optimum/rbln/diffusers/models/unets/__init__.py +1 -0
- optimum/rbln/diffusers/models/unets/unet_2d_condition.py +18 -2
- optimum/rbln/diffusers/models/unets/unet_spatio_temporal_condition.py +201 -0
- optimum/rbln/diffusers/pipelines/__init__.py +4 -0
- optimum/rbln/diffusers/pipelines/auto_pipeline.py +2 -2
- optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +20 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +13 -4
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +13 -4
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +13 -4
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +13 -4
- optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +1 -1
- optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +1 -1
- optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +1 -2
- optimum/rbln/diffusers/pipelines/stable_video_diffusion/__init__.py +15 -0
- optimum/rbln/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +46 -0
- optimum/rbln/modeling.py +20 -45
- optimum/rbln/modeling_base.py +18 -14
- optimum/rbln/ops/__init__.py +1 -0
- optimum/rbln/ops/attn.py +10 -0
- optimum/rbln/ops/flash_attn.py +8 -0
- optimum/rbln/ops/moe.py +180 -0
- optimum/rbln/ops/sliding_window_attn.py +9 -0
- optimum/rbln/transformers/__init__.py +36 -0
- optimum/rbln/transformers/configuration_generic.py +0 -27
- optimum/rbln/transformers/modeling_attention_utils.py +156 -127
- optimum/rbln/transformers/modeling_generic.py +2 -61
- optimum/rbln/transformers/modeling_outputs.py +26 -0
- optimum/rbln/transformers/modeling_rope_utils.py +78 -42
- optimum/rbln/transformers/models/__init__.py +28 -0
- optimum/rbln/transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py +28 -2
- optimum/rbln/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +68 -5
- optimum/rbln/transformers/models/auto/auto_factory.py +1 -0
- optimum/rbln/transformers/models/bart/bart_architecture.py +24 -24
- optimum/rbln/transformers/models/bart/modeling_bart.py +23 -2
- optimum/rbln/transformers/models/bert/modeling_bert.py +86 -1
- optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +42 -15
- optimum/rbln/transformers/models/clip/modeling_clip.py +40 -2
- optimum/rbln/transformers/models/colpali/colpali_architecture.py +14 -20
- optimum/rbln/transformers/models/colpali/configuration_colpali.py +12 -17
- optimum/rbln/transformers/models/colpali/modeling_colpali.py +66 -221
- optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +38 -23
- optimum/rbln/transformers/models/colqwen2/modeling_colqwen2.py +107 -371
- optimum/rbln/transformers/models/decoderonly/__init__.py +2 -0
- optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +128 -17
- optimum/rbln/transformers/models/decoderonly/configuration_lora.py +2 -2
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +211 -89
- optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +205 -64
- optimum/rbln/transformers/models/decoderonly/generation_decoderonly.py +17 -9
- optimum/rbln/transformers/models/decoderonly/lora_architecture.py +1 -1
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +194 -132
- optimum/rbln/transformers/models/depth_anything/modeling_depth_anything.py +17 -0
- optimum/rbln/transformers/models/distilbert/modeling_distilbert.py +24 -0
- optimum/rbln/transformers/models/dpt/modeling_dpt.py +17 -0
- optimum/rbln/transformers/models/exaone/exaone_architecture.py +0 -36
- optimum/rbln/transformers/models/gemma/gemma_architecture.py +1 -1
- optimum/rbln/transformers/models/gemma2/__init__.py +16 -0
- optimum/rbln/transformers/models/gemma2/configuration_gemma2.py +45 -0
- optimum/rbln/transformers/models/gemma2/gemma2_architecture.py +83 -0
- optimum/rbln/transformers/models/gemma2/modeling_gemma2.py +101 -0
- optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +23 -19
- optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py +42 -70
- optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +46 -31
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +8 -34
- optimum/rbln/transformers/models/gpt_oss/__init__.py +16 -0
- optimum/rbln/transformers/models/gpt_oss/configuration_gpt_oss.py +41 -0
- optimum/rbln/transformers/models/gpt_oss/gpt_oss_architecture.py +122 -0
- optimum/rbln/transformers/models/gpt_oss/modeling_gpt_oss.py +165 -0
- optimum/rbln/transformers/models/grounding_dino/configuration_grounding_dino.py +8 -5
- optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +7 -5
- optimum/rbln/transformers/models/grounding_dino/modeling_grounding_dino.py +24 -9
- optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +3 -5
- optimum/rbln/transformers/models/llava/modeling_llava.py +37 -26
- optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +3 -5
- optimum/rbln/transformers/models/midm/midm_architecture.py +29 -22
- optimum/rbln/transformers/models/mistral/modeling_mistral.py +0 -22
- optimum/rbln/transformers/models/opt/modeling_opt.py +2 -2
- optimum/rbln/transformers/models/opt/opt_architecture.py +1 -44
- optimum/rbln/transformers/models/paligemma/__init__.py +16 -0
- optimum/rbln/transformers/models/paligemma/configuration_paligemma.py +129 -0
- optimum/rbln/transformers/models/paligemma/modeling_paligemma.py +564 -0
- optimum/rbln/transformers/models/pegasus/modeling_pegasus.py +1 -1
- optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +24 -24
- optimum/rbln/transformers/models/phi/phi_architecture.py +13 -21
- optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +13 -1
- optimum/rbln/transformers/models/pixtral/pixtral_architecture.py +2 -2
- optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +0 -28
- optimum/rbln/transformers/models/qwen2_5_vl/__init__.py +6 -1
- optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +11 -1
- optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +278 -130
- optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +43 -39
- optimum/rbln/transformers/models/qwen2_moe/__init__.py +16 -0
- optimum/rbln/transformers/models/qwen2_moe/configuration_qwen2_moe.py +38 -0
- optimum/rbln/transformers/models/qwen2_moe/modeling_qwen2_moe.py +68 -0
- optimum/rbln/transformers/models/qwen2_moe/qwen2_moe_architecture.py +94 -0
- optimum/rbln/transformers/models/qwen2_vl/__init__.py +6 -1
- optimum/rbln/transformers/models/qwen2_vl/configuration_qwen2_vl.py +11 -1
- optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py +268 -111
- optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +27 -35
- optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +0 -20
- optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +7 -7
- optimum/rbln/transformers/models/qwen3_moe/__init__.py +16 -0
- optimum/rbln/transformers/models/qwen3_moe/configuration_qwen3_moe.py +38 -0
- optimum/rbln/transformers/models/qwen3_moe/modeling_qwen3_moe.py +68 -0
- optimum/rbln/transformers/models/qwen3_moe/qwen3_moe_architecture.py +100 -0
- optimum/rbln/transformers/models/resnet/configuration_resnet.py +17 -0
- optimum/rbln/transformers/models/resnet/modeling_resnet.py +73 -0
- optimum/rbln/transformers/models/roberta/modeling_roberta.py +33 -0
- optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py +2 -4
- optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +36 -12
- optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +14 -12
- optimum/rbln/transformers/models/siglip/modeling_siglip.py +21 -19
- optimum/rbln/transformers/models/swin/configuration_swin.py +1 -6
- optimum/rbln/transformers/models/swin/modeling_swin.py +17 -4
- optimum/rbln/transformers/models/t5/modeling_t5.py +2 -2
- optimum/rbln/transformers/models/t5/t5_architecture.py +16 -17
- optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +25 -10
- optimum/rbln/transformers/models/time_series_transformer/time_series_transformers_architecture.py +0 -3
- optimum/rbln/transformers/models/vit/modeling_vit.py +19 -0
- optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec2.py +15 -3
- optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +60 -8
- optimum/rbln/transformers/models/whisper/generation_whisper.py +48 -14
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +2 -2
- optimum/rbln/transformers/models/whisper/whisper_architecture.py +0 -3
- optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +53 -0
- optimum/rbln/transformers/utils/rbln_quantization.py +29 -12
- optimum/rbln/utils/deprecation.py +213 -0
- optimum/rbln/utils/hub.py +14 -3
- optimum/rbln/utils/import_utils.py +23 -2
- optimum/rbln/utils/runtime_utils.py +42 -6
- optimum/rbln/utils/submodule.py +27 -1
- {optimum_rbln-0.9.3rc0.dist-info → optimum_rbln-0.9.5a4.dist-info}/METADATA +6 -6
- {optimum_rbln-0.9.3rc0.dist-info → optimum_rbln-0.9.5a4.dist-info}/RECORD +155 -129
- {optimum_rbln-0.9.3rc0.dist-info → optimum_rbln-0.9.5a4.dist-info}/WHEEL +1 -1
- optimum/rbln/transformers/models/colqwen2/colqwen2_architecture.py +0 -233
- optimum/rbln/utils/depreacate_utils.py +0 -16
- {optimum_rbln-0.9.3rc0.dist-info → optimum_rbln-0.9.5a4.dist-info}/entry_points.txt +0 -0
- {optimum_rbln-0.9.3rc0.dist-info → optimum_rbln-0.9.5a4.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/transformers/models/siglip/modeling_siglip.py
CHANGED

@@ -21,6 +21,7 @@ from transformers.modeling_outputs import BaseModelOutputWithPooling
 from ....configuration_utils import RBLNCompileConfig
 from ....modeling import RBLNModel
 from ....utils.logging import get_logger
+from ...modeling_outputs import _validate_output_attentions, _validate_output_hidden_states
 from .configuration_siglip import RBLNSiglipVisionModelConfig
 
 
@@ -52,7 +53,7 @@ class _SiglipVisionModel(torch.nn.Module):
             interpolate_pos_encoding=self.interpolate_pos_encoding,
             output_attentions=self.output_attentions,
         )
-        return
+        return enc_out
 
 
 class RBLNSiglipVisionModel(RBLNModel):
@@ -66,7 +67,9 @@ class RBLNSiglipVisionModel(RBLNModel):
     _tp_support = False
 
     @classmethod
-    def
+    def _wrap_model_if_needed(
+        cls, model: torch.nn.Module, rbln_config: RBLNSiglipVisionModelConfig
+    ) -> torch.nn.Module:
         wrapper_cfg = {
             "interpolate_pos_encoding": rbln_config.interpolate_pos_encoding,
             "output_hidden_states": rbln_config.output_hidden_states,
@@ -122,23 +125,22 @@ class RBLNSiglipVisionModel(RBLNModel):
         interpolate_pos_encoding: bool = False,
         **kwargs: Any,
     ) -> Union[Tuple, BaseModelOutputWithPooling]:
-        [17 lines removed; content not shown in this diff view]
+        """
+        Forward pass for the RBLN-optimized SigLIP vision model.
+
+        Args:
+            pixel_values (torch.FloatTensor of shape (batch_size, num_channels, image_size, image_size), optional): The tensors corresponding to the input images. Pixel values can be obtained using ViTImageProcessor. See ViTImageProcessor.__call__() for details (the processor class uses ViTImageProcessor for processing images).
+            return_dict (bool, optional): Whether or not to return a ModelOutput instead of a plain tuple.
+            output_attentions (bool, optional): Whether or not to return the attentions tensors of all attention layers. See attentions under returned tensors for more detail.
+            output_hidden_states (bool, optional): Whether or not to return the hidden states of all layers. See hidden_states under returned tensors for more detail.
+            interpolate_pos_encoding (bool, defaults to False): Whether to interpolate the pre-trained position encodings.
+
+        Returns:
+            The model outputs. If return_dict=False is passed, returns a tuple of tensors. Otherwise, returns a BaseModelOutputWithPooling object.
+        """
+
+        output_attentions = _validate_output_attentions(output_attentions, self.rbln_config)
+        output_hidden_states = _validate_output_hidden_states(output_hidden_states, self.rbln_config)
         if interpolate_pos_encoding != self.rbln_config.interpolate_pos_encoding:
             raise ValueError(
                 f"Variable interpolate_pos_encoding {interpolate_pos_encoding} is not equal to rbln_config.interpolate_pos_encoding {self.rbln_config.interpolate_pos_encoding} "
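
The two new _validate_* helpers tie the forward-time flags to the compiled configuration. A minimal usage sketch (the checkpoint id, the export=True loading path, and the rbln_config dict are illustrative assumptions, not taken from this diff):

import torch
from optimum.rbln import RBLNSiglipVisionModel

# Compile once with fixed flags (assumed entry point; arguments are illustrative).
model = RBLNSiglipVisionModel.from_pretrained(
    "google/siglip-base-patch16-224",
    export=True,
    rbln_config={"output_hidden_states": True, "interpolate_pos_encoding": False},
)

pixel_values = torch.randn(1, 3, 224, 224)

# output_attentions / output_hidden_states are validated against rbln_config, and
# interpolate_pos_encoding must match the compiled value or a ValueError is raised.
outputs = model(pixel_values, output_hidden_states=True)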
optimum/rbln/transformers/models/swin/configuration_swin.py
CHANGED

@@ -32,11 +32,6 @@ class RBLNSwinBackboneConfig(RBLNModelForImageClassificationConfig):
         Raises:
             ValueError: If batch_size is not a positive integer.
         """
-        super().__init__(**kwargs)
-        self.batch_size = batch_size or 1
-        if not isinstance(self.batch_size, int) or self.batch_size < 0:
-            raise ValueError(f"batch_size must be a positive integer, got {self.batch_size}")
-
-        self.image_size = image_size
+        super().__init__(batch_size=batch_size, image_size=image_size, **kwargs)
         self.output_hidden_states = output_hidden_states
         self.output_attentions = output_attentions
optimum/rbln/transformers/models/swin/modeling_swin.py
CHANGED

@@ -203,7 +203,7 @@ class _SwinBackbone(torch.nn.Module):
 
 class RBLNSwinBackbone(RBLNModel):
     @classmethod
-    def
+    def _wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNSwinBackboneConfig) -> torch.nn.Module:
         for layer in model.encoder.layers:
             for block in layer.blocks:
                 block.get_attn_mask = types.MethodType(get_attn_mask, block)
@@ -278,6 +278,19 @@ class RBLNSwinBackbone(RBLNModel):
         output_hidden_states: bool = None,
         **kwargs,
     ) -> Union[Tuple, BackboneOutput]:
+        """
+        Forward pass for the RBLN-optimized Swin backbone model.
+
+        Args:
+            pixel_values (torch.FloatTensor of shape (batch_size, num_channels, image_size, image_size), optional): The tensors corresponding to the input images. Pixel values can be obtained using ViTImageProcessor. See ViTImageProcessor.__call__() for details (the processor class uses ViTImageProcessor for processing images).
+            return_dict (bool, optional): Whether or not to return a ModelOutput instead of a plain tuple.
+            output_attentions (bool, optional): Whether or not to return the attentions tensors of all attention layers. See attentions under returned tensors for more detail.
+            output_hidden_states (bool, optional): Whether or not to return the hidden states of all layers. See hidden_states under returned tensors for more detail.
+
+        Returns:
+            The model outputs. If return_dict=False is passed, returns a tuple of tensors. Otherwise, returns a BackboneOutput object.
+        """
+
         if len(kwargs) > 0 and any(value is not None for value in kwargs.values()):
             logger.warning(
                 f"Currently, optimum-rbln does not support kwargs {kwargs.keys()} for {self.__class__.__name__}."
@@ -314,19 +327,19 @@ class RBLNSwinBackbone(RBLNModel):
         output = self.model[0](padded_pixel_values)
 
         feature_maps = ()
-        for
+        for _ in range(len(self.config.out_features)):
             feature_maps += (output.pop(0),)
 
         if self.rbln_config.output_hidden_states:
             hidden_states = ()
-            for
+            for _ in range(len(self.config.stage_names)):
                 hidden_states += (output.pop(0),)
         else:
             hidden_states = None
 
         if self.rbln_config.output_attentions:
             attentions = ()
-            for
+            for _ in range(len(self.config.depths)):
                 attentions += (output.pop(0),)
         else:
             attentions = None
optimum/rbln/transformers/models/t5/modeling_t5.py
CHANGED

@@ -68,7 +68,7 @@ class RBLNT5EncoderModel(RBLNTransformerEncoderForFeatureExtraction):
     output_class = BaseModelOutputWithPastAndCrossAttentions
 
     @classmethod
-    def
+    def _wrap_model_if_needed(self, model: "PreTrainedModel", rbln_config: RBLNT5EncoderModelConfig):
         return T5EncoderWrapper(model)
 
     @classmethod
@@ -113,7 +113,7 @@ class RBLNT5ForConditionalGeneration(RBLNModelForSeq2SeqLM):
     support_causal_attn = False
 
     @classmethod
-    def
+    def _wrap_model_if_needed(self, model: "PreTrainedModel", rbln_config: RBLNT5ForConditionalGenerationConfig):
         return T5Wrapper(
             model, enc_max_seq_len=rbln_config.enc_max_seq_len, dec_max_seq_len=rbln_config.dec_max_seq_len
         )
optimum/rbln/transformers/models/t5/t5_architecture.py
CHANGED

@@ -39,7 +39,7 @@ class T5Wrapper:
 
 class T5EncoderWrapper(Seq2SeqEncoderWrapper):
     def __post_init__(self, model: nn.Module):
-        self.n_layer =
+        self.n_layer = self.config.num_layers
         self.cross_k_projects, self.cross_v_projects = self._extract_cross_kv_projects(model.get_decoder().block)
         self.num_heads = self.config.num_heads
         self.d_kv = self.config.d_kv
@@ -111,9 +111,9 @@ class T5ForConditionalGeneration(Seq2SeqForConditionalGeneration):
 class T5Decoder(Seq2SeqDecoder):
     has_pos_emb = False
 
-    def __post_init__(self, dec_max_seq_len: int = None):
-        self.invert_attention_mask =
-        self._dec_position_bias = self.precompute_dec_position_bias(
+    def __post_init__(self, model: nn.Module, dec_max_seq_len: int = None):
+        self.invert_attention_mask = model.invert_attention_mask
+        self._dec_position_bias = self.precompute_dec_position_bias(model, dec_max_seq_len)
 
     def precompute_dec_position_bias(self, model, dec_max_length):
         attn_layer = model.block[0].layer[0].SelfAttention
@@ -145,13 +145,12 @@ class T5Decoder(Seq2SeqDecoder):
 class T5Block(Seq2SeqDecoderLayer):
     def __init__(self, decoder_layer, self_attn):
         super().__init__(decoder_layer, self_attn, cross_attn=None)
-        self.__post_init__()
 
-    def __post_init__(self):
-        self.self_attn_layer_norm =
-        self.encoder_attn_layer_norm =
-        self.cross_attn = T5CrossAttention(
-        self.ff_layer =
+    def __post_init__(self, decoder_layer: nn.Module):
+        self.self_attn_layer_norm = decoder_layer.layer[0].layer_norm
+        self.encoder_attn_layer_norm = decoder_layer.layer[1].layer_norm
+        self.cross_attn = T5CrossAttention(decoder_layer.layer[1].EncDecAttention)
+        self.ff_layer = decoder_layer.layer[2]
 
     def pre_self_attn_layer_norm(self, hidden_states):
         return self.self_attn_layer_norm(hidden_states)
@@ -167,13 +166,13 @@ class T5Block(Seq2SeqDecoderLayer):
 
 
 class T5LayerSelfAttention(Seq2SeqSelfAttention):
-    def __post_init__(self):
-        self.q_proj =
-        self.k_proj =
-        self.v_proj =
-        self.out_proj =
-        self.num_heads =
-        self.head_dim =
+    def __post_init__(self, attn: nn.Module):
+        self.q_proj = attn.q
+        self.k_proj = attn.k
+        self.v_proj = attn.v
+        self.out_proj = attn.o
+        self.num_heads = attn.n_heads
+        self.head_dim = attn.key_value_proj_dim
         self.attn_decode = torch.ops.rbln_custom_ops.paged_add_softmax_attn_decode
 
     def projection(self, hidden_states) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
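
For reference, the attribute names that T5LayerSelfAttention.__post_init__ now reads (q, k, v, o, n_heads, key_value_proj_dim) are the standard fields of the Hugging Face T5Attention module; a small, checkpoint-free inspection sketch:

from transformers import T5Config, T5ForConditionalGeneration

# Randomly initialized T5, used only to inspect module attributes.
model = T5ForConditionalGeneration(T5Config())
attn = model.get_decoder().block[0].layer[0].SelfAttention

# These are the fields mapped onto q_proj / k_proj / v_proj / out_proj / num_heads / head_dim above.
print(attn.q, attn.k, attn.v, attn.o)
print(attn.n_heads, attn.key_value_proj_dim)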
optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py
CHANGED

@@ -153,7 +153,7 @@ class RBLNTimeSeriesTransformerForPrediction(RBLNModel):
         return redirect(val)
 
     @classmethod
-    def
+    def _wrap_model_if_needed(
         self, model: "PreTrainedModel", rbln_config: RBLNTimeSeriesTransformerForPredictionConfig
     ):
         return TimeSeriesTransformersWrapper(model, rbln_config.num_parallel_samples)
@@ -161,7 +161,7 @@ class RBLNTimeSeriesTransformerForPrediction(RBLNModel):
     @classmethod
     @torch.inference_mode()
     def get_compiled_model(cls, model, rbln_config: RBLNTimeSeriesTransformerForPredictionConfig):
-        wrapped_model = cls.
+        wrapped_model = cls._wrap_model_if_needed(model, rbln_config)
 
         enc_compile_config = rbln_config.compile_cfgs[0]
         dec_compile_config = rbln_config.compile_cfgs[1]
@@ -184,14 +184,6 @@ class RBLNTimeSeriesTransformerForPrediction(RBLNModel):
             if "key_value_states" in name:
                 context.mark_static_address(tensor)
 
-        compiled_decoder = cls.compile(
-            wrapped_model.decoder,
-            dec_compile_config,
-            create_runtimes=rbln_config.create_runtimes,
-            device=rbln_config.device,
-            example_inputs=dec_example_inputs,
-            compile_context=context,
-        )
         compiled_encoder = cls.compile(
             wrapped_model.encoder,
             enc_compile_config,
@@ -201,6 +193,15 @@ class RBLNTimeSeriesTransformerForPrediction(RBLNModel):
             compile_context=context,
         )
 
+        compiled_decoder = cls.compile(
+            wrapped_model.decoder,
+            dec_compile_config,
+            create_runtimes=rbln_config.create_runtimes,
+            device=rbln_config.device,
+            example_inputs=dec_example_inputs,
+            compile_context=context,
+        )
+
         return {"encoder": compiled_encoder, "decoder": compiled_decoder}
 
     @classmethod
@@ -353,6 +354,20 @@ class RBLNTimeSeriesTransformerForPrediction(RBLNModel):
         static_real_features: Optional[torch.Tensor] = None,
         **kwargs,
     ) -> SampleTSPredictionOutput:
+        """
+        Generate pass for the RBLN-optimized Time Series Transformer model for time series forecasting.
+
+        Args:
+            past_values (torch.FloatTensor of shape (batch_size, sequence_length) or (batch_size, sequence_length, input_size)): Past values of the time series, that serve as context in order to predict the future.
+            past_time_features (torch.FloatTensor of shape (batch_size, sequence_length, num_features)): Required time features, which the model internally will add to past_values.
+            future_time_features (torch.FloatTensor of shape (batch_size, prediction_length, num_features)): Required time features for the prediction window, which the model internally will add to future_values.
+            past_observed_mask (torch.BoolTensor of shape (batch_size, sequence_length) or (batch_size, sequence_length, input_size), optional): Boolean mask to indicate which past_values were observed and which were missing.
+            static_categorical_features (torch.LongTensor of shape (batch_size, number of static categorical features), optional): Optional static categorical features for which the model will learn an embedding, which it will add to the values of the time series.
+            static_real_features (torch.FloatTensor of shape (batch_size, number of static real features), optional): Optional static real features which the model will add to the values of the time series.
+
+        Returns:
+            The model outputs. If return_dict=False is passed, returns a tuple of tensors. Otherwise, returns a SampleTSPredictionOutput object.
+        """
         self.validate_batch_size(**{k: v for k, v in locals().items() if isinstance(v, torch.Tensor)})
 
         outputs = self.encoder(
optimum/rbln/transformers/models/time_series_transformer/time_series_transformers_architecture.py
CHANGED

@@ -140,7 +140,6 @@ class TimeSeriesTransformersDecoderWrapper(torch.nn.Module):
 class TimeSeriesTransformersDecoder(nn.Module):
     def __init__(self, model, layers, **kwargs):
         super().__init__()
-        self._original_mod = model
         self.config = model.config
         self.layers = nn.ModuleList(layers)
         self.value_embedding = model.value_embedding
@@ -190,7 +189,6 @@ class TimeSeriesTransformersDecoder(nn.Module):
 class TimeSeriesTransformersDecoderLayer(nn.Module):
     def __init__(self, decoder_layer, self_attn, cross_attn):
         super().__init__()
-        self._original_mod = decoder_layer
         self.self_attn = self_attn
         self.encoder_attn = cross_attn
         self.embed_dim = decoder_layer.embed_dim
@@ -245,7 +243,6 @@ class TimeSeriesTransformersDecoderLayer(nn.Module):
 class TimeSeriesTransformersAttention(nn.Module):
     def __init__(self, attn, num_parallel_samples):
         super().__init__()
-        self._original_mod = attn
         self.q_proj = attn.q_proj
         self.k_proj = attn.k_proj
         self.v_proj = attn.v_proj
optimum/rbln/transformers/models/vit/modeling_vit.py
CHANGED

@@ -12,6 +12,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from typing import Tuple, Union
+
+import torch
+from transformers.modeling_outputs import ImageClassifierOutput
+
 from ...modeling_generic import RBLNModelForImageClassification
 
 
@@ -23,3 +28,17 @@ class RBLNViTForImageClassification(RBLNModelForImageClassification):
     on RBLN devices, supporting image classification with transformer-based architectures
     that process images as sequences of patches.
     """
+
+    def forward(self, pixel_values: torch.Tensor, **kwargs) -> Union[ImageClassifierOutput, Tuple]:
+        """
+        Forward pass for the RBLN-optimized Vision Transformer model for image classification.
+
+        Args:
+            pixel_values (torch.FloatTensor of shape (batch_size, channels, height, width)):
+                The tensors corresponding to the input images.
+
+        Returns:
+            The model outputs. If return_dict=False is passed, returns a tuple of tensors. Otherwise, returns an ImageClassifierOutput object.
+
+        """
+        return super().forward(pixel_values, **kwargs)
optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec2.py
CHANGED

@@ -12,10 +12,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from
+from typing import Any, Optional
 
+from ....configuration_utils import RBLNModelConfig
 
-
+
+class RBLNWav2Vec2ForCTCConfig(RBLNModelConfig):
     """
     Configuration class for RBLNWav2Vec2ForCTC.
 
@@ -23,4 +25,14 @@ class RBLNWav2Vec2ForCTCConfig(RBLNModelForMaskedLMConfig):
     RBLN-optimized Wav2Vec2 models for Connectionist Temporal Classification (CTC) tasks.
     """
 
-
+    def __init__(
+        self,
+        max_seq_len: Optional[int] = None,
+        batch_size: Optional[int] = None,
+        **kwargs: Any,
+    ):
+        super().__init__(**kwargs)
+        self.max_seq_len = max_seq_len
+        self.batch_size = batch_size or 1
+        if not isinstance(self.batch_size, int) or self.batch_size < 0:
+            raise ValueError(f"batch_size must be a positive integer, got {self.batch_size}")
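
A minimal sketch of the constructor added above (the top-level import path is assumed to mirror the other RBLN*Config re-exports; values are illustrative):

from optimum.rbln import RBLNWav2Vec2ForCTCConfig

cfg = RBLNWav2Vec2ForCTCConfig(max_seq_len=160_000, batch_size=1)
assert cfg.max_seq_len == 160_000 and cfg.batch_size == 1

# batch_size defaults to 1 when omitted; a non-integer or negative value raises ValueError.
default_cfg = RBLNWav2Vec2ForCTCConfig()
assert default_cfg.batch_size == 1

If max_seq_len is left unset, the modeling change below falls back to a preprocessor's model_max_length at export time.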
optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py
CHANGED

@@ -13,13 +13,21 @@
 # limitations under the License.
 
 
+from typing import TYPE_CHECKING, Optional, Union
+
 import torch
-from transformers import
+from transformers import AutoModelForCTC, Wav2Vec2Config, Wav2Vec2ForCTC
+from transformers.modeling_outputs import CausalLMOutput
 
-from
+from ....configuration_utils import RBLNCompileConfig
+from ....modeling import RBLNModel
 from .configuration_wav2vec2 import RBLNWav2Vec2ForCTCConfig
 
 
+if TYPE_CHECKING:
+    from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer, PreTrainedModel
+
+
 class _Wav2Vec2(torch.nn.Module):
     def __init__(self, model: "Wav2Vec2ForCTC"):
         super().__init__()
@@ -30,13 +38,10 @@ class _Wav2Vec2(torch.nn.Module):
         return self.model.lm_head(output[0])
 
 
-class RBLNWav2Vec2ForCTC(
+class RBLNWav2Vec2ForCTC(RBLNModel):
     """
     Wav2Vec2 Model with a `language modeling` head on top for Connectionist Temporal Classification (CTC).
 
-    This model inherits from [`RBLNModelForMaskedLM`]. Check the superclass documentation for the generic methods the
-    library implements for all its model.
-
     It implements the methods to convert a pre-trained Wav2Vec2 model into a RBLN Wav2Vec2 model by:
 
     - transferring the checkpoint weights of the original into an optimized RBLN graph,
@@ -44,9 +49,56 @@ class RBLNWav2Vec2ForCTC(RBLNModelForMaskedLM):
     """
 
     main_input_name = "input_values"
-    auto_model_class =
+    auto_model_class = AutoModelForCTC
    rbln_dtype = "float32"
 
     @classmethod
-    def
+    def _wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNWav2Vec2ForCTCConfig) -> torch.nn.Module:
         return _Wav2Vec2(model).eval()
+
+    @classmethod
+    def _update_rbln_config(
+        cls,
+        preprocessors: Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"],
+        model: Optional["PreTrainedModel"] = None,
+        model_config: "Wav2Vec2Config" = None,
+        rbln_config: Optional[RBLNWav2Vec2ForCTCConfig] = None,
+    ) -> RBLNWav2Vec2ForCTCConfig:
+        if rbln_config.max_seq_len is None:
+            for tokenizer in preprocessors:
+                if hasattr(tokenizer, "model_max_length"):
+                    rbln_config.max_seq_len = tokenizer.model_max_length
+                    break
+            if rbln_config.max_seq_len is None:
+                raise ValueError("`rbln_max_seq_len` should be specified!")
+
+        rbln_compile_config = RBLNCompileConfig(
+            input_info=[
+                (
+                    "input_values",
+                    [
+                        rbln_config.batch_size,
+                        rbln_config.max_seq_len,
+                    ],
+                    "float32",
+                )
+            ]
+        )
+
+        rbln_config.set_compile_cfgs([rbln_compile_config])
+        return rbln_config
+
+    def forward(
+        self, input_values: torch.Tensor, return_dict: Optional[bool] = None, **kwargs
+    ) -> Union[CausalLMOutput, tuple]:
+        """
+        Forward pass for the RBLN-optimized Wav2Vec2 model for Connectionist Temporal Classification (CTC).
+
+        Args:
+            input_values (torch.FloatTensor of shape (batch_size, sequence_length)): Float values of input raw speech waveform. Values can be obtained by loading a .flac or .wav audio file into an array of type List[float] or a numpy.ndarray, e.g. via the soundfile library (pip install soundfile). To prepare the array into input_values, the AutoProcessor should be used for padding and conversion into a tensor of type torch.FloatTensor.
+            return_dict (bool, optional): Whether or not to return a ModelOutput instead of a plain tuple.
+
+        Returns:
+            The model outputs. If return_dict=False is passed, returns a tuple of tensors. Otherwise, returns a CausalLMOutput object.
+        """
+        return super().forward(input_values=input_values, return_dict=return_dict, **kwargs)
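
Because the compile config above pins input_values to a static [batch_size, max_seq_len] shape, raw waveforms have to be padded or truncated to that exact length before calling the compiled model; a sketch with illustrative values:

import torch

max_seq_len = 160_000                      # must match rbln_config.max_seq_len
waveform = torch.randn(1, 123_456)         # e.g. a feature-extractor output

pad = max_seq_len - waveform.shape[-1]
input_values = torch.nn.functional.pad(waveform, (0, pad)) if pad > 0 else waveform[..., :max_seq_len]
# logits = compiled_model(input_values).logits   # compiled_model: an RBLNWav2Vec2ForCTC instance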
optimum/rbln/transformers/models/whisper/generation_whisper.py
CHANGED

@@ -31,29 +31,63 @@ Generation utilities for Whisper.
 Modified from `transformers.models.whisper.generation_whisper.py`
 """
 
+from typing import Any, Dict, Optional, Union
+
 import torch
 import transformers
 from packaging import version
 from transformers import GenerationMixin
+from transformers.generation.configuration_utils import GenerationConfig
+from transformers.modeling_outputs import ModelOutput
 from transformers.models.whisper.generation_whisper import WhisperGenerationMixin
 
 
 class RBLNWhisperGenerationMixin(WhisperGenerationMixin, GenerationMixin):
-    def generate(
-        [12 lines removed; content not shown in this diff view]
+    def generate(
+        self,
+        input_features: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        generation_config: Optional[GenerationConfig] = None,
+        return_segments: Optional[bool] = None,
+        return_timestamps: Optional[bool] = None,
+        return_token_timestamps: Optional[bool] = None,
+        **kwargs,
+    ) -> Union[ModelOutput, Dict[str, Any], torch.LongTensor]:
+        """
+        The generate function is used in its standard form, as in the HuggingFace transformers library, to generate text from the model.
+        Check the [HuggingFace transformers documentation](https://huggingface.co/docs/transformers/v4.57.1/en/model_doc/whisper#transformers.WhisperForConditionalGeneration.generate) for more details.
+
+        Args:
+            input_features (torch.Tensor, optional): The input features to the model.
+            attention_mask (torch.Tensor, optional): Attention mask needs to be passed when doing long-form transcription using a batch size > 1.
+            generation_config (GenerationConfig, optional): The generation configuration to be used as base parametrization for the generation call. **kwargs passed to generate matching the attributes of generation_config will override them.
+                If generation_config is not provided, the default will be used, which has the following loading priority: 1) from the generation_config.json model file, if it exists; 2) from the model configuration.
+                Please note that unspecified parameters will inherit [GenerationConfig](https://huggingface.co/docs/transformers/v4.57.1/en/main_classes/text_generation#transformers.GenerationConfig)'s default values.
+            return_segments (bool, optional): Whether to return segments.
+            return_timestamps (bool, optional): Whether to return the timestamps with the text. For audios longer than 30 seconds, it is necessary to set return_timestamps=True.
+            return_token_timestamps (bool, optional): Whether to return token timestamps.
+            kwargs (dict[str, Any], optional): Additional arguments passed to the generate function.
+
+        Returns:
+            Transcribes or translates log-mel input features to a sequence of auto-regressively generated token ids.
+        """
+        if kwargs.get("num_beams", None) is not None:
+            if kwargs.get("num_beams") != 1:
+                raise ValueError(
+                    "Beam search is not supported in RBLNWhisperGenerationMixin. "
+                    "Received num_beams={num_beams}, but only num_beams=1 is allowed. "
+                    "Please set num_beams=1 for greedy search or adjust your configuration."
+                )
 
-        return super().generate(
+        return super().generate(
+            input_features,
+            attention_mask=attention_mask,
+            generation_config=generation_config,
+            return_segments=return_segments,
+            return_timestamps=return_timestamps,
+            return_token_timestamps=return_token_timestamps,
+            **kwargs,
+        )
 
     def _postprocess_outputs(
         self,
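
A usage sketch for the rewritten generate (the checkpoint id, the export=True loading path, and the processor flow are assumptions for illustration; only greedy search is supported):

import numpy as np
from transformers import WhisperProcessor
from optimum.rbln import RBLNWhisperForConditionalGeneration

processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
model = RBLNWhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny", export=True)

audio = np.zeros(16_000 * 5, dtype=np.float32)   # 5 s of silence, illustrative input
inputs = processor(audio, sampling_rate=16_000, return_tensors="pt")

tokens = model.generate(inputs.input_features, return_timestamps=True)   # num_beams defaults to 1
text = processor.batch_decode(tokens, skip_special_tokens=True)

# model.generate(inputs.input_features, num_beams=4)  # raises ValueError per the check above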
optimum/rbln/transformers/models/whisper/modeling_whisper.py
CHANGED

@@ -203,7 +203,7 @@ class RBLNWhisperForConditionalGeneration(RBLNModel, RBLNWhisperGenerationMixin)
         raise NotImplementedError
 
     @classmethod
-    def
+    def _wrap_model_if_needed(self, model: "PreTrainedModel", rbln_config: RBLNWhisperForConditionalGenerationConfig):
         return WhisperWrapper(
             model,
             use_attention_mask=rbln_config.use_attention_mask,
@@ -213,7 +213,7 @@ class RBLNWhisperForConditionalGeneration(RBLNModel, RBLNWhisperGenerationMixin)
     @classmethod
     @torch.inference_mode()
     def get_compiled_model(cls, model, rbln_config: RBLNWhisperForConditionalGenerationConfig):
-        wrapped_model = cls.
+        wrapped_model = cls._wrap_model_if_needed(model, rbln_config)
 
         enc_compile_config = rbln_config.compile_cfgs[0]
         dec_compile_config = rbln_config.compile_cfgs[1]
optimum/rbln/transformers/models/whisper/whisper_architecture.py
CHANGED

@@ -154,7 +154,6 @@ class WhisperDecoderWrapper(torch.nn.Module):
 class WhisperDecoder(nn.Module):
     def __init__(self, model, layers, **kwargs):
         super().__init__()
-        self._original_mod = model
         self.layers = nn.ModuleList(layers)
         self.embed_tokens = model.embed_tokens
         self.layer_norm = model.layer_norm
@@ -210,7 +209,6 @@ class WhisperDecoder(nn.Module):
 class WhisperDecoderLayer(nn.Module):
     def __init__(self, decoder_layer, self_attn, cross_attn):
         super().__init__()
-        self._original_mod = decoder_layer
         self.self_attn = self_attn
         self.encoder_attn = cross_attn
         self.self_attn_layer_norm = decoder_layer.self_attn_layer_norm
@@ -263,7 +261,6 @@ class WhisperDecoderLayer(nn.Module):
 class WhisperAttention(nn.Module):
     def __init__(self, attn):
         super().__init__()
-        self._original_mod = attn
         self.q_proj = attn.q_proj
         self.k_proj = attn.k_proj
         self.v_proj = attn.v_proj