optimum-rbln 0.8.2a4__py3-none-any.whl → 0.9.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- optimum/rbln/__init__.py +108 -9
- optimum/rbln/__version__.py +16 -3
- optimum/rbln/cli.py +660 -0
- optimum/rbln/configuration_utils.py +156 -43
- optimum/rbln/diffusers/__init__.py +19 -0
- optimum/rbln/diffusers/configurations/__init__.py +3 -0
- optimum/rbln/diffusers/configurations/models/__init__.py +2 -0
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +3 -3
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_cosmos.py +1 -1
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_temporal_decoder.py +67 -0
- optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +3 -3
- optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +4 -4
- optimum/rbln/diffusers/configurations/models/configuration_transformer_cosmos.py +9 -4
- optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +9 -4
- optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +3 -3
- optimum/rbln/diffusers/configurations/models/configuration_unet_spatio_temporal_condition.py +59 -0
- optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +3 -3
- optimum/rbln/diffusers/configurations/pipelines/__init__.py +3 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +35 -19
- optimum/rbln/diffusers/configurations/pipelines/configuration_cosmos.py +14 -11
- optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +30 -20
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +13 -9
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +17 -13
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +17 -10
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_video_diffusion.py +114 -0
- optimum/rbln/diffusers/modeling_diffusers.py +30 -14
- optimum/rbln/diffusers/models/__init__.py +4 -0
- optimum/rbln/diffusers/models/autoencoders/__init__.py +1 -0
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +31 -3
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py +31 -6
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +275 -0
- optimum/rbln/diffusers/models/autoencoders/vae.py +27 -8
- optimum/rbln/diffusers/models/autoencoders/vq_model.py +31 -3
- optimum/rbln/diffusers/models/controlnet.py +16 -1
- optimum/rbln/diffusers/models/transformers/prior_transformer.py +17 -3
- optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +25 -2
- optimum/rbln/diffusers/models/transformers/transformer_sd3.py +23 -2
- optimum/rbln/diffusers/models/unets/__init__.py +1 -0
- optimum/rbln/diffusers/models/unets/unet_2d_condition.py +23 -4
- optimum/rbln/diffusers/models/unets/unet_spatio_temporal_condition.py +201 -0
- optimum/rbln/diffusers/pipelines/__init__.py +15 -5
- optimum/rbln/diffusers/pipelines/auto_pipeline.py +307 -0
- optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +20 -0
- optimum/rbln/diffusers/pipelines/cosmos/configuration_cosmos_guardrail.py +19 -16
- optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +14 -18
- optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +31 -1
- optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +31 -1
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +1 -6
- optimum/rbln/diffusers/pipelines/stable_video_diffusion/__init__.py +15 -0
- optimum/rbln/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +46 -0
- optimum/rbln/modeling.py +48 -21
- optimum/rbln/modeling_base.py +99 -22
- optimum/rbln/ops/attn.py +158 -0
- optimum/rbln/ops/flash_attn.py +166 -0
- optimum/rbln/ops/kv_cache_update.py +5 -0
- optimum/rbln/ops/linear.py +7 -0
- optimum/rbln/transformers/__init__.py +92 -0
- optimum/rbln/transformers/configuration_generic.py +7 -32
- optimum/rbln/transformers/modeling_attention_utils.py +385 -0
- optimum/rbln/transformers/modeling_generic.py +48 -65
- optimum/rbln/transformers/modeling_outputs.py +37 -0
- optimum/rbln/transformers/models/__init__.py +91 -30
- optimum/rbln/transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py +28 -2
- optimum/rbln/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +68 -5
- optimum/rbln/transformers/models/auto/__init__.py +2 -0
- optimum/rbln/transformers/models/auto/auto_factory.py +92 -17
- optimum/rbln/transformers/models/auto/modeling_auto.py +45 -0
- optimum/rbln/transformers/models/bart/bart_architecture.py +1 -3
- optimum/rbln/transformers/models/bart/configuration_bart.py +2 -0
- optimum/rbln/transformers/models/bart/modeling_bart.py +23 -2
- optimum/rbln/transformers/models/bert/bert_architecture.py +16 -0
- optimum/rbln/transformers/models/bert/modeling_bert.py +93 -4
- optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +42 -11
- optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +135 -44
- optimum/rbln/transformers/models/clip/configuration_clip.py +10 -7
- optimum/rbln/transformers/models/clip/modeling_clip.py +67 -6
- optimum/rbln/transformers/models/colpali/colpali_architecture.py +3 -6
- optimum/rbln/transformers/models/colpali/configuration_colpali.py +37 -21
- optimum/rbln/transformers/models/colpali/modeling_colpali.py +82 -104
- optimum/rbln/transformers/models/colqwen2/__init__.py +2 -0
- optimum/rbln/transformers/models/colqwen2/colqwen2_architecture.py +233 -0
- optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +74 -0
- optimum/rbln/transformers/models/colqwen2/modeling_colqwen2.py +446 -0
- optimum/rbln/transformers/models/decoderonly/__init__.py +3 -2
- optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +114 -37
- optimum/rbln/transformers/models/decoderonly/configuration_lora.py +411 -0
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +318 -309
- optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +508 -0
- optimum/rbln/transformers/models/decoderonly/generation_decoderonly.py +119 -0
- optimum/rbln/transformers/models/decoderonly/lora_architecture.py +204 -0
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +485 -905
- optimum/rbln/transformers/models/depth_anything/__init__.py +16 -0
- optimum/rbln/transformers/models/depth_anything/configuration_depth_anything.py +24 -0
- optimum/rbln/transformers/models/depth_anything/modeling_depth_anything.py +42 -0
- optimum/rbln/transformers/models/distilbert/modeling_distilbert.py +24 -0
- optimum/rbln/transformers/models/dpt/modeling_dpt.py +17 -0
- optimum/rbln/transformers/models/exaone/modeling_exaone.py +42 -4
- optimum/rbln/transformers/models/gemma/__init__.py +2 -2
- optimum/rbln/transformers/models/gemma/configuration_gemma.py +9 -1
- optimum/rbln/transformers/models/gemma/gemma_architecture.py +1 -4
- optimum/rbln/transformers/models/gemma/modeling_gemma.py +22 -1
- optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +49 -13
- optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +12 -2
- optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py +245 -0
- optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +201 -351
- optimum/rbln/transformers/models/gpt2/__init__.py +2 -2
- optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +31 -3
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +10 -8
- optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +18 -1
- optimum/rbln/transformers/models/grounding_dino/__init__.py +10 -0
- optimum/rbln/transformers/models/grounding_dino/configuration_grounding_dino.py +92 -0
- optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +599 -0
- optimum/rbln/transformers/models/grounding_dino/modeling_grounding_dino.py +1048 -0
- optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +35 -7
- optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +29 -32
- optimum/rbln/transformers/models/llama/__init__.py +2 -2
- optimum/rbln/transformers/models/llama/configuration_llama.py +9 -1
- optimum/rbln/transformers/models/llama/modeling_llama.py +22 -1
- optimum/rbln/transformers/models/llava/__init__.py +16 -0
- optimum/rbln/transformers/models/llava/configuration_llava.py +72 -0
- optimum/rbln/transformers/models/llava/modeling_llava.py +490 -0
- optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +15 -17
- optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +234 -376
- optimum/rbln/transformers/models/midm/midm_architecture.py +4 -1
- optimum/rbln/transformers/models/midm/modeling_midm.py +42 -4
- optimum/rbln/transformers/models/mistral/__init__.py +2 -2
- optimum/rbln/transformers/models/mistral/configuration_mistral.py +9 -1
- optimum/rbln/transformers/models/mistral/mistral_architecture.py +1 -1
- optimum/rbln/transformers/models/mistral/modeling_mistral.py +26 -3
- optimum/rbln/transformers/models/opt/__init__.py +2 -2
- optimum/rbln/transformers/models/opt/configuration_opt.py +8 -1
- optimum/rbln/transformers/models/opt/modeling_opt.py +29 -17
- optimum/rbln/transformers/models/opt/opt_architecture.py +4 -4
- optimum/rbln/transformers/models/pegasus/__init__.py +17 -0
- optimum/rbln/transformers/models/pegasus/configuration_pegasus.py +38 -0
- optimum/rbln/transformers/models/pegasus/modeling_pegasus.py +71 -0
- optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +161 -0
- optimum/rbln/transformers/models/phi/__init__.py +2 -2
- optimum/rbln/transformers/models/phi/configuration_phi.py +9 -1
- optimum/rbln/transformers/models/phi/modeling_phi.py +10 -1
- optimum/rbln/transformers/models/phi/phi_architecture.py +11 -7
- optimum/rbln/transformers/models/pixtral/__init__.py +16 -0
- optimum/rbln/transformers/models/pixtral/configuration_pixtral.py +43 -0
- optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +322 -0
- optimum/rbln/transformers/models/pixtral/pixtral_architecture.py +73 -0
- optimum/rbln/transformers/models/qwen2/__init__.py +2 -2
- optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +9 -1
- optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +27 -1
- optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +21 -6
- optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +15 -22
- optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +28 -7
- optimum/rbln/transformers/models/qwen2_vl/__init__.py +19 -0
- optimum/rbln/transformers/models/qwen2_vl/configuration_qwen2_vl.py +88 -0
- optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py +513 -0
- optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +165 -0
- optimum/rbln/transformers/models/qwen3/configuration_qwen3.py +2 -2
- optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +86 -330
- optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +1 -245
- optimum/rbln/transformers/models/resnet/configuration_resnet.py +17 -0
- optimum/rbln/transformers/models/resnet/modeling_resnet.py +73 -0
- optimum/rbln/transformers/models/roberta/modeling_roberta.py +33 -0
- optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py +21 -16
- optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +58 -13
- optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +2 -2
- optimum/rbln/transformers/models/siglip/__init__.py +2 -6
- optimum/rbln/transformers/models/siglip/configuration_siglip.py +1 -1
- optimum/rbln/transformers/models/siglip/modeling_siglip.py +21 -16
- optimum/rbln/transformers/models/swin/__init__.py +16 -0
- optimum/rbln/transformers/models/swin/configuration_swin.py +42 -0
- optimum/rbln/transformers/models/swin/modeling_swin.py +354 -0
- optimum/rbln/transformers/models/t5/configuration_t5.py +2 -0
- optimum/rbln/transformers/models/t5/modeling_t5.py +2 -2
- optimum/rbln/transformers/models/t5/t5_architecture.py +8 -1
- optimum/rbln/transformers/models/time_series_transformer/configuration_time_series_transformer.py +3 -3
- optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +20 -16
- optimum/rbln/transformers/models/time_series_transformer/time_series_transformers_architecture.py +7 -1
- optimum/rbln/transformers/models/vit/modeling_vit.py +19 -0
- optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec2.py +15 -3
- optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +61 -8
- optimum/rbln/transformers/models/whisper/configuration_whisper.py +12 -13
- optimum/rbln/transformers/models/whisper/generation_whisper.py +62 -6
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +30 -5
- optimum/rbln/transformers/models/xlm_roberta/__init__.py +2 -8
- optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +43 -0
- optimum/rbln/transformers/utils/rbln_quantization.py +400 -75
- optimum/rbln/transformers/utils/rbln_runtime_wrapper.py +79 -0
- optimum/rbln/utils/deprecation.py +213 -0
- optimum/rbln/utils/hub.py +14 -3
- optimum/rbln/utils/runtime_utils.py +60 -18
- optimum/rbln/utils/submodule.py +31 -9
- {optimum_rbln-0.8.2a4.dist-info → optimum_rbln-0.9.3.dist-info}/METADATA +8 -7
- optimum_rbln-0.9.3.dist-info/RECORD +264 -0
- {optimum_rbln-0.8.2a4.dist-info → optimum_rbln-0.9.3.dist-info}/WHEEL +1 -1
- optimum_rbln-0.9.3.dist-info/entry_points.txt +2 -0
- optimum_rbln-0.8.2a4.dist-info/RECORD +0 -215
- {optimum_rbln-0.8.2a4.dist-info → optimum_rbln-0.9.3.dist-info}/licenses/LICENSE +0 -0
|
@@ -12,16 +12,8 @@
|
|
|
12
12
|
# See the License for the specific language governing permissions and
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
|
|
15
|
-
import torch
|
|
16
|
-
import torch.nn as nn
|
|
17
|
-
from transformers import PreTrainedModel
|
|
18
15
|
|
|
19
|
-
from ..decoderonly.decoderonly_architecture import
|
|
20
|
-
DecoderOnlyAttention,
|
|
21
|
-
DecoderOnlyLayer,
|
|
22
|
-
DecoderOnlyWrapper,
|
|
23
|
-
RotaryEmbedding,
|
|
24
|
-
)
|
|
16
|
+
from ..decoderonly.decoderonly_architecture import DecoderOnlyAttention, DecoderOnlyWrapper
|
|
25
17
|
|
|
26
18
|
|
|
27
19
|
class Qwen3Wrapper(DecoderOnlyWrapper):
|
|
@@ -37,239 +29,3 @@ class Qwen3Attention(DecoderOnlyAttention):
|
|
|
37
29
|
self.o_proj = self._original_mod.o_proj
|
|
38
30
|
self.q_norm = self._original_mod.q_norm
|
|
39
31
|
self.k_norm = self._original_mod.k_norm
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
class Qwen3ModelWrapper(nn.Module):
|
|
43
|
-
def __init__(
|
|
44
|
-
self,
|
|
45
|
-
model,
|
|
46
|
-
attn_impl=None,
|
|
47
|
-
use_inputs_embeds=None,
|
|
48
|
-
use_attention_mask=None,
|
|
49
|
-
use_rotary_emb=None,
|
|
50
|
-
cache_impl=None,
|
|
51
|
-
kvcache_partition_len=None,
|
|
52
|
-
max_seq_len=None,
|
|
53
|
-
kvcache_block_size=None,
|
|
54
|
-
sliding_window=None,
|
|
55
|
-
sliding_window_layers=None,
|
|
56
|
-
):
|
|
57
|
-
super().__init__()
|
|
58
|
-
self.config = model.config
|
|
59
|
-
|
|
60
|
-
if use_rotary_emb:
|
|
61
|
-
rotary_embs = self.get_rotary_emb(max_seq_len=max_seq_len)
|
|
62
|
-
if isinstance(rotary_embs, tuple):
|
|
63
|
-
self.rotary_emb_global, self.rotary_emb_local = rotary_embs
|
|
64
|
-
else:
|
|
65
|
-
self.rotary_emb = rotary_embs
|
|
66
|
-
else:
|
|
67
|
-
self.rotary_emb = None
|
|
68
|
-
|
|
69
|
-
self._original_mod = model
|
|
70
|
-
self.use_inputs_embeds = use_inputs_embeds
|
|
71
|
-
self.attn_impl = attn_impl
|
|
72
|
-
self.cache_impl = cache_impl
|
|
73
|
-
self.use_attention_mask = use_attention_mask
|
|
74
|
-
self.kvcache_partition_len = kvcache_partition_len
|
|
75
|
-
self.kvcache_block_size = kvcache_block_size
|
|
76
|
-
self.max_seq_len = max_seq_len
|
|
77
|
-
self.sliding_window = sliding_window
|
|
78
|
-
self.sliding_window_layers = sliding_window_layers
|
|
79
|
-
self.model = self.convert_to_rbln_model(model)
|
|
80
|
-
|
|
81
|
-
def get_rotary_emb(self, max_seq_len):
|
|
82
|
-
return RotaryEmbedding(config=self.config, max_seq_len_cached=max_seq_len)
|
|
83
|
-
|
|
84
|
-
def convert_to_rbln_model(self, base_model: PreTrainedModel):
|
|
85
|
-
for layer_idx, layer in enumerate(base_model.layers):
|
|
86
|
-
is_sliding = layer_idx in self.sliding_window_layers
|
|
87
|
-
new_self_attn = Qwen3Attention(
|
|
88
|
-
layer.self_attn,
|
|
89
|
-
self.use_attention_mask if not is_sliding else True,
|
|
90
|
-
use_position_ids=None,
|
|
91
|
-
kvcache_block_size=self.sliding_window
|
|
92
|
-
if layer_idx in self.sliding_window_layers
|
|
93
|
-
else self.kvcache_block_size,
|
|
94
|
-
is_sliding=is_sliding,
|
|
95
|
-
attn_impl=self.attn_impl if not is_sliding else "eager",
|
|
96
|
-
kvcache_partition_len=self.kvcache_partition_len,
|
|
97
|
-
)
|
|
98
|
-
base_model.layers[layer_idx] = DecoderOnlyLayer(layer, new_self_attn)
|
|
99
|
-
|
|
100
|
-
return base_model
|
|
101
|
-
|
|
102
|
-
@property
|
|
103
|
-
def hidden_multiplier(self):
|
|
104
|
-
return 1
|
|
105
|
-
|
|
106
|
-
def get_last_layernorm(self) -> nn.LayerNorm:
|
|
107
|
-
return self._original_mod.norm
|
|
108
|
-
|
|
109
|
-
def get_embedding(self) -> nn.Embedding:
|
|
110
|
-
return self._original_mod.embed_tokens
|
|
111
|
-
|
|
112
|
-
def get_pos_embedding(self) -> nn.Embedding:
|
|
113
|
-
raise NotImplementedError(
|
|
114
|
-
"The 'get_pos_embedding' method is not implemented. Please define this method in a subclass."
|
|
115
|
-
)
|
|
116
|
-
|
|
117
|
-
def convert_sequence_positions_for_flash_attn(self, seq_positions, max_seq_len):
|
|
118
|
-
if self.attn_impl not in ["flash_attn"]:
|
|
119
|
-
raise NotImplementedError(f"Unknown attn_impl ({self.attn_impl}).")
|
|
120
|
-
partition_len = self.kvcache_partition_len
|
|
121
|
-
num_partition = max_seq_len // partition_len
|
|
122
|
-
|
|
123
|
-
cs = seq_positions.repeat(num_partition, 1).transpose(0, 1)
|
|
124
|
-
pidx = torch.arange(num_partition)
|
|
125
|
-
cache_pos_for_partitions = torch.clamp(cs - pidx * partition_len, 0, partition_len)
|
|
126
|
-
return cache_pos_for_partitions
|
|
127
|
-
|
|
128
|
-
def get_local_cache_positions(self, position_ids, query_position):
|
|
129
|
-
max_cache_len = self.model.config.sliding_window
|
|
130
|
-
valid_input_len = 1 if query_position is None else query_position + 1
|
|
131
|
-
cache_seq_len = torch.clamp(position_ids, max=max_cache_len)[:, :1] # past seen tokens
|
|
132
|
-
cache_offset = (
|
|
133
|
-
torch.clamp(position_ids, max=max_cache_len)[:, :1] + valid_input_len
|
|
134
|
-
) # cache offset for next steps
|
|
135
|
-
|
|
136
|
-
return cache_seq_len, cache_offset
|
|
137
|
-
|
|
138
|
-
def prepare_forward_args(self, *args):
|
|
139
|
-
args = list(args)
|
|
140
|
-
input_ids = None if self.use_inputs_embeds else args.pop(0)
|
|
141
|
-
inputs_embeds = args.pop(0) if self.use_inputs_embeds else None
|
|
142
|
-
cache_position = args.pop(0)
|
|
143
|
-
global_block_tables = args.pop(0) if self.cache_impl in ["hybrid", "static"] else None
|
|
144
|
-
local_block_tables = args.pop(0) if self.cache_impl in ["hybrid", "sliding_window"] else None
|
|
145
|
-
query_position = args.pop(0) if self.sliding_window else None
|
|
146
|
-
attention_mask = args.pop(0) if self.use_attention_mask else None
|
|
147
|
-
position_ids = None
|
|
148
|
-
past_key_values = args
|
|
149
|
-
|
|
150
|
-
if len(past_key_values) != 2 * self.config.num_hidden_layers:
|
|
151
|
-
raise ValueError(
|
|
152
|
-
f"Different past_key_values to model's config. {len(past_key_values)} != {2 * self.config.num_hidden_layers}"
|
|
153
|
-
)
|
|
154
|
-
|
|
155
|
-
# [key, value] * n_layer -> ( (key, value) ) * n_layer
|
|
156
|
-
# cache shape : batch, n_heads, 1, max_seq_len, head_dim
|
|
157
|
-
_past_key_values = []
|
|
158
|
-
for i in range(self.config.num_hidden_layers):
|
|
159
|
-
key_states = past_key_values[i * 2]
|
|
160
|
-
value_states = past_key_values[i * 2 + 1]
|
|
161
|
-
past_key_value = [key_states, value_states]
|
|
162
|
-
_past_key_values.append(past_key_value)
|
|
163
|
-
past_key_values = _past_key_values
|
|
164
|
-
|
|
165
|
-
if hasattr(self, "rotary_emb_global") and hasattr(self, "rotary_emb_local"):
|
|
166
|
-
rotary_emb = (self.rotary_emb_global, self.rotary_emb_local)
|
|
167
|
-
else:
|
|
168
|
-
rotary_emb = self.rotary_emb
|
|
169
|
-
|
|
170
|
-
return (
|
|
171
|
-
input_ids,
|
|
172
|
-
inputs_embeds,
|
|
173
|
-
cache_position,
|
|
174
|
-
global_block_tables,
|
|
175
|
-
local_block_tables,
|
|
176
|
-
attention_mask,
|
|
177
|
-
position_ids,
|
|
178
|
-
query_position,
|
|
179
|
-
past_key_values,
|
|
180
|
-
rotary_emb,
|
|
181
|
-
)
|
|
182
|
-
|
|
183
|
-
def forward(self, *args):
|
|
184
|
-
(
|
|
185
|
-
input_ids,
|
|
186
|
-
inputs_embeds,
|
|
187
|
-
cache_position,
|
|
188
|
-
global_block_tables,
|
|
189
|
-
local_block_tables,
|
|
190
|
-
attention_mask,
|
|
191
|
-
position_ids,
|
|
192
|
-
query_position,
|
|
193
|
-
past_key_values,
|
|
194
|
-
rotary_emb,
|
|
195
|
-
) = self.prepare_forward_args(*args)
|
|
196
|
-
|
|
197
|
-
# retrieve input_ids and inputs_embeds
|
|
198
|
-
if (input_ids is None) ^ (inputs_embeds is not None):
|
|
199
|
-
raise ValueError(
|
|
200
|
-
"You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
|
|
201
|
-
)
|
|
202
|
-
|
|
203
|
-
# embed positions
|
|
204
|
-
if inputs_embeds is None:
|
|
205
|
-
inputs_embeds = self.get_embedding()(input_ids)
|
|
206
|
-
|
|
207
|
-
hidden_states = inputs_embeds * self.hidden_multiplier
|
|
208
|
-
|
|
209
|
-
# get cos,sin vector if needed
|
|
210
|
-
position_ids = position_ids if position_ids is not None else cache_position
|
|
211
|
-
if rotary_emb is not None:
|
|
212
|
-
if isinstance(rotary_emb, torch.Tensor):
|
|
213
|
-
cos = rotary_emb[0]
|
|
214
|
-
sin = rotary_emb[1]
|
|
215
|
-
else:
|
|
216
|
-
cos, sin = rotary_emb(hidden_states, self.max_seq_len) # dtype carrier, max_seq_len
|
|
217
|
-
cos, sin = slice_and_unsqueeze_cos_sin(cos, sin, position_ids)
|
|
218
|
-
else:
|
|
219
|
-
batch_size = inputs_embeds.shape[0]
|
|
220
|
-
if position_ids.shape[0] > 1:
|
|
221
|
-
position_embeds = []
|
|
222
|
-
for b_idx in range(batch_size):
|
|
223
|
-
position_embed = self.get_pos_embedding()(position_ids[b_idx])
|
|
224
|
-
position_embeds.append(position_embed)
|
|
225
|
-
|
|
226
|
-
position_embeds = torch.cat(position_embeds, dim=0).unsqueeze(1)
|
|
227
|
-
else:
|
|
228
|
-
position_embeds = self.get_pos_embedding()(position_ids)
|
|
229
|
-
hidden_states = hidden_states + position_embeds
|
|
230
|
-
cos, sin = None, None
|
|
231
|
-
|
|
232
|
-
# Get sequence positions for flash attention
|
|
233
|
-
if self.attn_impl == "flash_attn":
|
|
234
|
-
seq_positions = cache_position[:, 0]
|
|
235
|
-
seq_positions = self.convert_sequence_positions_for_flash_attn(
|
|
236
|
-
seq_positions=seq_positions, max_seq_len=self.max_seq_len
|
|
237
|
-
)
|
|
238
|
-
else:
|
|
239
|
-
seq_positions = cache_position[:, :1]
|
|
240
|
-
|
|
241
|
-
# Get local cache positions for sliding window layers
|
|
242
|
-
if len(self.sliding_window_layers) > 0:
|
|
243
|
-
sliding_cache_pos = self.get_local_cache_positions(position_ids, query_position)
|
|
244
|
-
|
|
245
|
-
for layer_idx, layer in enumerate(self.model.layers):
|
|
246
|
-
is_sliding = True if layer_idx in self.sliding_window_layers else False
|
|
247
|
-
hidden_states = layer(
|
|
248
|
-
hidden_states=hidden_states,
|
|
249
|
-
attention_mask=attention_mask,
|
|
250
|
-
seq_positions=sliding_cache_pos if is_sliding else seq_positions,
|
|
251
|
-
past_key_values=past_key_values,
|
|
252
|
-
cos=cos,
|
|
253
|
-
sin=sin,
|
|
254
|
-
block_tables=local_block_tables if is_sliding else global_block_tables,
|
|
255
|
-
)
|
|
256
|
-
|
|
257
|
-
hidden_states = self.get_last_layernorm()(hidden_states)
|
|
258
|
-
return hidden_states
|
|
259
|
-
|
|
260
|
-
|
|
261
|
-
def slice_and_unsqueeze_cos_sin(cos, sin, cache_position, unsqueeze_dim=1):
|
|
262
|
-
"""Slice cos[cache_position], sin[cache_position] vector for the query."""
|
|
263
|
-
if cache_position.shape[0] > 1:
|
|
264
|
-
cos_all = []
|
|
265
|
-
sin_all = []
|
|
266
|
-
for i in range(cache_position.shape[0]):
|
|
267
|
-
cos_all.append(cos[cache_position[i : i + 1]].unsqueeze(unsqueeze_dim))
|
|
268
|
-
sin_all.append(sin[cache_position[i : i + 1]].unsqueeze(unsqueeze_dim))
|
|
269
|
-
cos = torch.cat(cos_all, dim=0)
|
|
270
|
-
sin = torch.cat(sin_all, dim=0)
|
|
271
|
-
else:
|
|
272
|
-
cos = cos[cache_position].unsqueeze(unsqueeze_dim)
|
|
273
|
-
sin = sin[cache_position].unsqueeze(unsqueeze_dim)
|
|
274
|
-
|
|
275
|
-
return cos, sin
|
|
@@ -13,6 +13,8 @@
|
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
|
|
15
15
|
|
|
16
|
+
from typing import Optional
|
|
17
|
+
|
|
16
18
|
from ...configuration_generic import RBLNModelForImageClassificationConfig
|
|
17
19
|
|
|
18
20
|
|
|
@@ -23,3 +25,18 @@ class RBLNResNetForImageClassificationConfig(RBLNModelForImageClassificationConf
|
|
|
23
25
|
This configuration class stores the configuration parameters specific to
|
|
24
26
|
RBLN-optimized ResNet models for image classification tasks.
|
|
25
27
|
"""
|
|
28
|
+
|
|
29
|
+
def __init__(self, output_hidden_states: Optional[bool] = None, **kwargs):
|
|
30
|
+
"""
|
|
31
|
+
Args:
|
|
32
|
+
image_size (Optional[Union[int, Tuple[int, int]]]): The size of input images.
|
|
33
|
+
Can be an integer for square images or a tuple (height, width).
|
|
34
|
+
batch_size (Optional[int]): The batch size for inference. Defaults to 1.
|
|
35
|
+
output_hidden_states (Optional[bool]): Whether or not to return the hidden states of all layers.
|
|
36
|
+
kwargs: Additional arguments passed to the parent RBLNModelConfig.
|
|
37
|
+
|
|
38
|
+
Raises:
|
|
39
|
+
ValueError: If batch_size is not a positive integer.
|
|
40
|
+
"""
|
|
41
|
+
super().__init__(**kwargs)
|
|
42
|
+
self.output_hidden_states = output_hidden_states
|
|
@@ -13,7 +13,17 @@
|
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
|
|
15
15
|
|
|
16
|
+
from typing import TYPE_CHECKING, Optional, Tuple, Union
|
|
17
|
+
|
|
18
|
+
import torch
|
|
19
|
+
from transformers.modeling_outputs import ImageClassifierOutputWithNoAttention
|
|
20
|
+
|
|
16
21
|
from ...modeling_generic import RBLNModelForImageClassification
|
|
22
|
+
from .configuration_resnet import RBLNResNetForImageClassificationConfig
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
if TYPE_CHECKING:
|
|
26
|
+
from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer, PretrainedConfig, PreTrainedModel
|
|
17
27
|
|
|
18
28
|
|
|
19
29
|
class RBLNResNetForImageClassification(RBLNModelForImageClassification):
|
|
@@ -24,3 +34,66 @@ class RBLNResNetForImageClassification(RBLNModelForImageClassification):
|
|
|
24
34
|
on RBLN devices, supporting image classification with convolutional neural networks
|
|
25
35
|
designed for computer vision tasks.
|
|
26
36
|
"""
|
|
37
|
+
|
|
38
|
+
@classmethod
|
|
39
|
+
def _update_rbln_config(
|
|
40
|
+
cls,
|
|
41
|
+
preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]] = None,
|
|
42
|
+
model: Optional["PreTrainedModel"] = None,
|
|
43
|
+
model_config: Optional["PretrainedConfig"] = None,
|
|
44
|
+
rbln_config: Optional["RBLNResNetForImageClassificationConfig"] = None,
|
|
45
|
+
) -> "RBLNResNetForImageClassificationConfig":
|
|
46
|
+
if rbln_config.output_hidden_states is None:
|
|
47
|
+
rbln_config.output_hidden_states = getattr(model_config, "output_hidden_states", False)
|
|
48
|
+
|
|
49
|
+
rbln_config = super()._update_rbln_config(
|
|
50
|
+
preprocessors=preprocessors,
|
|
51
|
+
model=model,
|
|
52
|
+
model_config=model_config,
|
|
53
|
+
rbln_config=rbln_config,
|
|
54
|
+
)
|
|
55
|
+
|
|
56
|
+
return rbln_config
|
|
57
|
+
|
|
58
|
+
@classmethod
|
|
59
|
+
def _wrap_model_if_needed(
|
|
60
|
+
cls, model: torch.nn.Module, rbln_config: "RBLNResNetForImageClassificationConfig"
|
|
61
|
+
) -> torch.nn.Module:
|
|
62
|
+
class _ResNetForImageClassification(torch.nn.Module):
|
|
63
|
+
def __init__(self, model: torch.nn.Module, output_hidden_states: bool):
|
|
64
|
+
super().__init__()
|
|
65
|
+
self.model = model
|
|
66
|
+
self.output_hidden_states = output_hidden_states
|
|
67
|
+
|
|
68
|
+
def forward(self, *args, **kwargs):
|
|
69
|
+
output = self.model(*args, output_hidden_states=self.output_hidden_states, **kwargs)
|
|
70
|
+
return output
|
|
71
|
+
|
|
72
|
+
return _ResNetForImageClassification(model, rbln_config.output_hidden_states)
|
|
73
|
+
|
|
74
|
+
def forward(
|
|
75
|
+
self, pixel_values: torch.Tensor, output_hidden_states: bool = None, return_dict: bool = None, **kwargs
|
|
76
|
+
) -> Union[Tuple, ImageClassifierOutputWithNoAttention]:
|
|
77
|
+
"""
|
|
78
|
+
Forward pass for the RBLN-optimized ResNet model for image classification.
|
|
79
|
+
|
|
80
|
+
Args:
|
|
81
|
+
pixel_values (torch.FloatTensor of shape (batch_size, channels, height, width)): The tensors corresponding to the input images.
|
|
82
|
+
output_hidden_states (bool, *optional*, defaults to False): Whether or not to return the hidden states of all layers.
|
|
83
|
+
See hidden_states under returned tensors for more details.
|
|
84
|
+
return_dict (bool, *optional*, defaults to True): Whether to return a dictionary of outputs.
|
|
85
|
+
|
|
86
|
+
Returns:
|
|
87
|
+
The model outputs. If return_dict=False is passed, returns a tuple of tensors. Otherwise, returns an ImageClassifierOutputWithNoAttention object.
|
|
88
|
+
"""
|
|
89
|
+
output_hidden_states = (
|
|
90
|
+
output_hidden_states if output_hidden_states is not None else self.rbln_config.output_hidden_states
|
|
91
|
+
)
|
|
92
|
+
|
|
93
|
+
if output_hidden_states != self.rbln_config.output_hidden_states:
|
|
94
|
+
raise ValueError(
|
|
95
|
+
f"Variable output_hidden_states {output_hidden_states} is not equal to rbln_config.output_hidden_states {self.rbln_config.output_hidden_states} "
|
|
96
|
+
f"Please compile again with the correct argument."
|
|
97
|
+
)
|
|
98
|
+
|
|
99
|
+
return super().forward(pixel_values=pixel_values, return_dict=return_dict, **kwargs)
|
|
@@ -12,6 +12,11 @@
|
|
|
12
12
|
# See the License for the specific language governing permissions and
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
|
|
15
|
+
from typing import Tuple, Union
|
|
16
|
+
|
|
17
|
+
import torch
|
|
18
|
+
from transformers.modeling_outputs import MaskedLMOutput, SequenceClassifierOutput
|
|
19
|
+
|
|
15
20
|
from ...modeling_generic import RBLNModelForMaskedLM, RBLNModelForSequenceClassification
|
|
16
21
|
|
|
17
22
|
|
|
@@ -26,6 +31,19 @@ class RBLNRobertaForMaskedLM(RBLNModelForMaskedLM):
|
|
|
26
31
|
|
|
27
32
|
rbln_model_input_names = ["input_ids", "attention_mask"]
|
|
28
33
|
|
|
34
|
+
def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs) -> Union[Tuple, MaskedLMOutput]:
|
|
35
|
+
"""
|
|
36
|
+
Forward pass for the RBLN-optimized RoBERTa model for masked language modeling tasks.
|
|
37
|
+
|
|
38
|
+
Args:
|
|
39
|
+
input_ids (torch.LongTensor of shape (batch_size, sequence_length), optional): Indices of input sequence tokens in the vocabulary.
|
|
40
|
+
attention_mask (torch.FloatTensor of shape (batch_size, sequence_length), optional): Mask to avoid performing attention on padding token indices.
|
|
41
|
+
|
|
42
|
+
Returns:
|
|
43
|
+
The model outputs. If return_dict=False is passed, returns a tuple of tensors. Otherwise, returns a MaskedLMOutput object.
|
|
44
|
+
"""
|
|
45
|
+
return super().forward(input_ids, attention_mask, **kwargs)
|
|
46
|
+
|
|
29
47
|
|
|
30
48
|
class RBLNRobertaForSequenceClassification(RBLNModelForSequenceClassification):
|
|
31
49
|
"""
|
|
@@ -37,3 +55,18 @@ class RBLNRobertaForSequenceClassification(RBLNModelForSequenceClassification):
|
|
|
37
55
|
"""
|
|
38
56
|
|
|
39
57
|
rbln_model_input_names = ["input_ids", "attention_mask"]
|
|
58
|
+
|
|
59
|
+
def forward(
|
|
60
|
+
self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs
|
|
61
|
+
) -> Union[Tuple, SequenceClassifierOutput]:
|
|
62
|
+
"""
|
|
63
|
+
Forward pass for the RBLN-optimized RoBERTa model for sequence classification tasks.
|
|
64
|
+
|
|
65
|
+
Args:
|
|
66
|
+
input_ids (torch.LongTensor of shape (batch_size, sequence_length), optional): Indices of input sequence tokens in the vocabulary.
|
|
67
|
+
attention_mask (torch.FloatTensor of shape (batch_size, sequence_length), optional): Mask to avoid performing attention on padding token indices.
|
|
68
|
+
|
|
69
|
+
Returns:
|
|
70
|
+
The model outputs. If return_dict=False is passed, returns a tuple of tensors. Otherwise, returns a SequenceClassifierOutput object.
|
|
71
|
+
"""
|
|
72
|
+
return super().forward(input_ids, attention_mask, **kwargs)
|
|
@@ -12,11 +12,10 @@
|
|
|
12
12
|
# See the License for the specific language governing permissions and
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
|
|
15
|
-
from typing import Any,
|
|
16
|
-
|
|
17
|
-
import rebel
|
|
15
|
+
from typing import Any, Optional
|
|
18
16
|
|
|
19
17
|
from ....configuration_utils import RBLNModelConfig
|
|
18
|
+
from ....utils.deprecation import deprecate_kwarg
|
|
20
19
|
from ....utils.logging import get_logger
|
|
21
20
|
|
|
22
21
|
|
|
@@ -24,14 +23,18 @@ logger = get_logger()
|
|
|
24
23
|
|
|
25
24
|
|
|
26
25
|
class RBLNModelForSeq2SeqLMConfig(RBLNModelConfig):
|
|
26
|
+
support_paged_attention = None
|
|
27
|
+
|
|
28
|
+
@deprecate_kwarg(old_name="pad_token_id", version="0.10.0")
|
|
27
29
|
def __init__(
|
|
28
30
|
self,
|
|
29
31
|
batch_size: Optional[int] = None,
|
|
30
32
|
enc_max_seq_len: Optional[int] = None,
|
|
31
33
|
dec_max_seq_len: Optional[int] = None,
|
|
32
34
|
use_attention_mask: Optional[bool] = None,
|
|
33
|
-
|
|
34
|
-
|
|
35
|
+
kvcache_num_blocks: Optional[int] = None,
|
|
36
|
+
kvcache_block_size: Optional[int] = None,
|
|
37
|
+
**kwargs: Any,
|
|
35
38
|
):
|
|
36
39
|
"""
|
|
37
40
|
Args:
|
|
@@ -39,9 +42,11 @@ class RBLNModelForSeq2SeqLMConfig(RBLNModelConfig):
|
|
|
39
42
|
enc_max_seq_len (Optional[int]): Maximum sequence length for the encoder.
|
|
40
43
|
dec_max_seq_len (Optional[int]): Maximum sequence length for the decoder.
|
|
41
44
|
use_attention_mask (Optional[bool]): Whether to use attention masks during inference.
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
45
|
+
kvcache_num_blocks (Optional[int]): The total number of blocks to allocate for the
|
|
46
|
+
PagedAttention KV cache for the SelfAttention. Defaults to batch_size.
|
|
47
|
+
kvcache_block_size (Optional[int]): Sets the size (in number of tokens) of each block
|
|
48
|
+
in the PagedAttention KV cache for the SelfAttention. Defaults to dec_max_seq_len.
|
|
49
|
+
kwargs: Additional arguments passed to the parent RBLNModelConfig.
|
|
45
50
|
|
|
46
51
|
Raises:
|
|
47
52
|
ValueError: If batch_size is not a positive integer.
|
|
@@ -55,12 +60,12 @@ class RBLNModelForSeq2SeqLMConfig(RBLNModelConfig):
|
|
|
55
60
|
self.dec_max_seq_len = dec_max_seq_len
|
|
56
61
|
|
|
57
62
|
self.use_attention_mask = use_attention_mask
|
|
58
|
-
npu = self.npu or rebel.get_npu_name()
|
|
59
|
-
if npu == "RBLN-CA02":
|
|
60
|
-
if self.use_attention_mask is False:
|
|
61
|
-
logger.warning("Attention mask should be used with RBLN-CA02. Setting use_attention_mask to True.")
|
|
62
|
-
self.use_attention_mask = True
|
|
63
|
-
else:
|
|
64
|
-
self.use_attention_mask = self.use_attention_mask or False
|
|
65
63
|
|
|
66
|
-
self.
|
|
64
|
+
if self.support_paged_attention:
|
|
65
|
+
self.kvcache_num_blocks = kvcache_num_blocks
|
|
66
|
+
self.kvcache_block_size = kvcache_block_size
|
|
67
|
+
else:
|
|
68
|
+
if kvcache_num_blocks is not None or kvcache_block_size is not None:
|
|
69
|
+
raise ValueError(
|
|
70
|
+
"You cannot set kvcache_num_blocks or kvcache_block_size as paged attention is not supported for the model."
|
|
71
|
+
)
|
|
@@ -20,7 +20,9 @@ import rebel
|
|
|
20
20
|
import torch
|
|
21
21
|
from rebel.compile_context import CompileContext
|
|
22
22
|
from transformers import AutoModelForSeq2SeqLM, PretrainedConfig, PreTrainedModel
|
|
23
|
-
from transformers.
|
|
23
|
+
from transformers.generation.configuration_utils import GenerationConfig
|
|
24
|
+
from transformers.generation.utils import GenerationMixin
|
|
25
|
+
from transformers.modeling_outputs import BaseModelOutput, ModelOutput, Seq2SeqLMOutput
|
|
24
26
|
|
|
25
27
|
from ....configuration_utils import RBLNCompileConfig
|
|
26
28
|
from ....modeling import RBLNModel
|
|
@@ -32,13 +34,13 @@ from .configuration_seq2seq import RBLNModelForSeq2SeqLMConfig
|
|
|
32
34
|
logger = get_logger(__name__)
|
|
33
35
|
|
|
34
36
|
if TYPE_CHECKING:
|
|
35
|
-
from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer,
|
|
37
|
+
from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer, PretrainedConfig
|
|
36
38
|
|
|
37
39
|
|
|
38
40
|
class RBLNRuntimeEncoder(RBLNPytorchRuntime):
|
|
39
41
|
mandatory_members = ["main_input_name"]
|
|
40
42
|
|
|
41
|
-
def forward(self, *args: List[torch.Tensor], **kwargs:
|
|
43
|
+
def forward(self, *args: List[torch.Tensor], **kwargs: torch.Tensor):
|
|
42
44
|
output = super().forward(*args, **kwargs)
|
|
43
45
|
return BaseModelOutput(last_hidden_state=output)
|
|
44
46
|
|
|
@@ -83,7 +85,7 @@ class RBLNRuntimeDecoder(RBLNPytorchRuntime):
|
|
|
83
85
|
decoding_step = cache_position[b_idx].item()
|
|
84
86
|
if not (0 <= decoding_step < self.dec_max_seq_len):
|
|
85
87
|
raise ValueError(
|
|
86
|
-
f"Decoding step {decoding_step} out of bounds for
|
|
88
|
+
f"Decoding step {decoding_step} out of bounds for decoder_max_seq_len ({self.dec_max_seq_len})."
|
|
87
89
|
)
|
|
88
90
|
decoder_attention_mask[b_idx, : decoding_step + 1] = 1
|
|
89
91
|
|
|
@@ -101,7 +103,7 @@ class RBLNRuntimeDecoder(RBLNPytorchRuntime):
|
|
|
101
103
|
return Seq2SeqLMOutput(logits=lm_logits)
|
|
102
104
|
|
|
103
105
|
|
|
104
|
-
class RBLNModelForSeq2SeqLM(RBLNModel, ABC):
|
|
106
|
+
class RBLNModelForSeq2SeqLM(RBLNModel, GenerationMixin, ABC):
|
|
105
107
|
"""
|
|
106
108
|
This is a generic model class that will be instantiated as one of the model classes of the library (with a sequence-to-sequence language modeling head) when created with the from_pretrained() class method.
|
|
107
109
|
This model inherits from [`RBLNModel`]. Check the superclass documentation for the generic methods the library implements for all its models.
|
|
@@ -117,6 +119,7 @@ class RBLNModelForSeq2SeqLM(RBLNModel, ABC):
|
|
|
117
119
|
main_input_name = "input_ids"
|
|
118
120
|
auto_model_class = AutoModelForSeq2SeqLM
|
|
119
121
|
support_causal_attn = None
|
|
122
|
+
_is_stateful = False
|
|
120
123
|
|
|
121
124
|
def __post_init__(self, **kwargs):
|
|
122
125
|
batch_size = self.rbln_config.batch_size
|
|
@@ -138,7 +141,7 @@ class RBLNModelForSeq2SeqLM(RBLNModel, ABC):
|
|
|
138
141
|
@classmethod
|
|
139
142
|
@torch.inference_mode()
|
|
140
143
|
def get_compiled_model(cls, model: PreTrainedModel, rbln_config: RBLNModelForSeq2SeqLMConfig):
|
|
141
|
-
wrapped_model = cls.
|
|
144
|
+
wrapped_model = cls._wrap_model_if_needed(model, rbln_config)
|
|
142
145
|
|
|
143
146
|
enc_compile_config = rbln_config.compile_cfgs[0]
|
|
144
147
|
dec_compile_config = rbln_config.compile_cfgs[1]
|
|
@@ -181,6 +184,21 @@ class RBLNModelForSeq2SeqLM(RBLNModel, ABC):
|
|
|
181
184
|
|
|
182
185
|
return {"encoder": compiled_encoder, "decoder": compiled_decoder}
|
|
183
186
|
|
|
187
|
+
@classmethod
|
|
188
|
+
def _update_paged_attention_config(cls, model_config: PretrainedConfig, rbln_config: RBLNModelForSeq2SeqLMConfig):
|
|
189
|
+
rbln_config.kvcache_num_blocks = rbln_config.kvcache_num_blocks or rbln_config.batch_size
|
|
190
|
+
rbln_config.kvcache_block_size = rbln_config.kvcache_block_size or rbln_config.dec_max_seq_len
|
|
191
|
+
|
|
192
|
+
if rbln_config.kvcache_num_blocks != rbln_config.batch_size:
|
|
193
|
+
raise NotImplementedError(
|
|
194
|
+
f"kvcache_num_blocks ({rbln_config.kvcache_num_blocks}) must be equal to batch_size ({rbln_config.batch_size}) as flash attention is not supported yet."
|
|
195
|
+
)
|
|
196
|
+
|
|
197
|
+
if rbln_config.kvcache_block_size != rbln_config.dec_max_seq_len:
|
|
198
|
+
raise NotImplementedError(
|
|
199
|
+
f"kvcache_block_size ({rbln_config.kvcache_block_size}) must be equal to dec_max_seq_len ({rbln_config.dec_max_seq_len}) as flash attention is not supported yet."
|
|
200
|
+
)
|
|
201
|
+
|
|
184
202
|
@classmethod
|
|
185
203
|
def _update_rbln_config(
|
|
186
204
|
cls,
|
|
@@ -204,12 +222,6 @@ class RBLNModelForSeq2SeqLM(RBLNModel, ABC):
|
|
|
204
222
|
model_config, "max_position_embeddings", None
|
|
205
223
|
)
|
|
206
224
|
|
|
207
|
-
pad_token_id = getattr(model_config, "pad_token_id", None)
|
|
208
|
-
pad_token_id = pad_token_id or getattr(model_config, "bos_token_id", None)
|
|
209
|
-
pad_token_id = pad_token_id or getattr(model_config, "eos_token_id", None)
|
|
210
|
-
pad_token_id = pad_token_id or -1
|
|
211
|
-
rbln_config.pad_token_id = pad_token_id
|
|
212
|
-
|
|
213
225
|
if rbln_config.enc_max_seq_len is None:
|
|
214
226
|
enc_max_seq_len = max_position_embeddings
|
|
215
227
|
for tokenizer in preprocessors:
|
|
@@ -238,6 +250,9 @@ class RBLNModelForSeq2SeqLM(RBLNModel, ABC):
|
|
|
238
250
|
if max_position_embeddings is not None and rbln_config.dec_max_seq_len > max_position_embeddings:
|
|
239
251
|
raise ValueError("`dec_max_seq_len` should be less or equal than max_position_embeddings!")
|
|
240
252
|
|
|
253
|
+
if rbln_config.support_paged_attention:
|
|
254
|
+
cls._update_paged_attention_config(model_config, rbln_config)
|
|
255
|
+
|
|
241
256
|
# model input info
|
|
242
257
|
enc_input_info = [
|
|
243
258
|
("input_ids", [1, rbln_config.enc_max_seq_len], "int64"),
|
|
@@ -310,6 +325,7 @@ class RBLNModelForSeq2SeqLM(RBLNModel, ABC):
|
|
|
310
325
|
dec_compile_config = RBLNCompileConfig(compiled_model_name="decoder", input_info=dec_input_info)
|
|
311
326
|
|
|
312
327
|
rbln_config.set_compile_cfgs([enc_compile_config, dec_compile_config])
|
|
328
|
+
|
|
313
329
|
return rbln_config
|
|
314
330
|
|
|
315
331
|
@classmethod
|
|
@@ -411,7 +427,7 @@ class RBLNModelForSeq2SeqLM(RBLNModel, ABC):
|
|
|
411
427
|
inputs_tensor = torch.nn.functional.pad(
|
|
412
428
|
inputs_tensor,
|
|
413
429
|
(0, self.rbln_config.enc_max_seq_len - input_len),
|
|
414
|
-
value=self.
|
|
430
|
+
value=self.config.pad_token_id,
|
|
415
431
|
)
|
|
416
432
|
model_kwargs["attention_mask"] = torch.nn.functional.pad(
|
|
417
433
|
model_kwargs["attention_mask"], (0, self.rbln_config.enc_max_seq_len - input_len)
|
|
@@ -430,3 +446,32 @@ class RBLNModelForSeq2SeqLM(RBLNModel, ABC):
|
|
|
430
446
|
model_kwargs["encoder_outputs"] = encoder(**encoder_kwargs, block_tables=block_tables)
|
|
431
447
|
|
|
432
448
|
return model_kwargs
|
|
449
|
+
|
|
450
|
+
def generate(
|
|
451
|
+
self,
|
|
452
|
+
input_ids: torch.LongTensor,
|
|
453
|
+
attention_mask: Optional[torch.LongTensor] = None,
|
|
454
|
+
generation_config: Optional[GenerationConfig] = None,
|
|
455
|
+
**kwargs,
|
|
456
|
+
) -> Union[ModelOutput, torch.LongTensor]:
|
|
457
|
+
"""
|
|
458
|
+
The generate function is utilized in its standard form as in the HuggingFace transformers library. User can use this function to generate text from the model.
|
|
459
|
+
Check the [HuggingFace transformers documentation](https://huggingface.co/docs/transformers/v4.57.1/en/main_classes/text_generation#transformers.GenerationMixin.generate) for more details.
|
|
460
|
+
|
|
461
|
+
Args:
|
|
462
|
+
input_ids (torch.LongTensor): The input ids to the model.
|
|
463
|
+
attention_mask (torch.LongTensor, optional): The attention mask to the model.
|
|
464
|
+
generation_config (GenerationConfig, optional): The generation configuration to be used as base parametrization for the generation call. **kwargs passed to generate matching the attributes of generation_config will override them.
|
|
465
|
+
If generation_config is not provided, the default will be used, which had the following loading priority: 1) from the generation_config.json model file, if it exists; 2) from the model configuration.
|
|
466
|
+
Please note that unspecified parameters will inherit [GenerationConfig](https://huggingface.co/docs/transformers/v4.57.1/en/main_classes/text_generation#transformers.GenerationConfig)’s default values.
|
|
467
|
+
kwargs (dict[str, Any], optional): Additional arguments passed to the generate function. See the HuggingFace transformers documentation for more details.
|
|
468
|
+
|
|
469
|
+
Returns:
|
|
470
|
+
Generates sequences of token ids for models with a language modeling head.
|
|
471
|
+
"""
|
|
472
|
+
if generation_config is not None:
|
|
473
|
+
kwargs["generation_config"] = generation_config
|
|
474
|
+
if attention_mask is not None:
|
|
475
|
+
kwargs["attention_mask"] = attention_mask
|
|
476
|
+
|
|
477
|
+
return super().generate(input_ids, **kwargs)
|
|
@@ -31,7 +31,7 @@ class Seq2SeqWrapper:
|
|
|
31
31
|
Args:
|
|
32
32
|
model (nn.Module): The Seq2Seq model to wrap.
|
|
33
33
|
enc_max_seq_len (int): Maximum sequence length for the encoder's position embeddings and cache sizes.
|
|
34
|
-
|
|
34
|
+
kwargs: Additional arguments to pass to the decoder wrapper.
|
|
35
35
|
"""
|
|
36
36
|
|
|
37
37
|
def __init__(self, model: nn.Module, enc_max_seq_len: int, **kwargs):
|
|
@@ -125,7 +125,7 @@ class Seq2SeqDecoderWrapper(nn.Module):
|
|
|
125
125
|
|
|
126
126
|
Args:
|
|
127
127
|
model (nn.Module): The Seq2Seq model containing the decoder.
|
|
128
|
-
|
|
128
|
+
kwargs: Additional arguments for decoder configuration.
|
|
129
129
|
"""
|
|
130
130
|
|
|
131
131
|
def __init__(self, model: nn.Module, use_attention_mask: bool = True, **kwargs):
|
|
@@ -12,9 +12,5 @@
|
|
|
12
12
|
# See the License for the specific language governing permissions and
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
|
|
15
|
-
from .configuration_siglip import
|
|
16
|
-
|
|
17
|
-
)
|
|
18
|
-
from .modeling_siglip import (
|
|
19
|
-
RBLNSiglipVisionModel,
|
|
20
|
-
)
|
|
15
|
+
from .configuration_siglip import RBLNSiglipVisionModelConfig
|
|
16
|
+
from .modeling_siglip import RBLNSiglipVisionModel
|