optimum-rbln 0.1.13__py3-none-any.whl → 0.1.15__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- optimum/rbln/__init__.py +22 -12
- optimum/rbln/__version__.py +16 -1
- optimum/rbln/diffusers/__init__.py +22 -2
- optimum/rbln/diffusers/models/__init__.py +34 -3
- optimum/rbln/{transformers/generation → diffusers/models/autoencoders}/__init__.py +1 -2
- optimum/rbln/diffusers/models/{autoencoder_kl.py → autoencoders/autoencoder_kl.py} +44 -58
- optimum/rbln/diffusers/models/autoencoders/vae.py +84 -0
- optimum/rbln/diffusers/models/controlnet.py +54 -14
- optimum/rbln/diffusers/models/transformers/__init__.py +24 -0
- optimum/rbln/diffusers/models/transformers/transformer_sd3.py +203 -0
- optimum/rbln/diffusers/models/unets/__init__.py +24 -0
- optimum/rbln/diffusers/models/{unet_2d_condition.py → unets/unet_2d_condition.py} +78 -16
- optimum/rbln/diffusers/pipelines/__init__.py +22 -2
- optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +5 -26
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +1 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +1 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +1 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/__init__.py +1 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +1 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +0 -11
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +32 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/__init__.py +26 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +32 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +32 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +32 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/__init__.py +1 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +14 -6
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +14 -6
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +32 -0
- optimum/rbln/modeling.py +572 -0
- optimum/rbln/modeling_alias.py +1 -1
- optimum/rbln/modeling_base.py +164 -758
- optimum/rbln/modeling_diffusers.py +51 -122
- optimum/rbln/transformers/__init__.py +0 -2
- optimum/rbln/transformers/models/auto/auto_factory.py +117 -23
- optimum/rbln/transformers/models/auto/modeling_auto.py +37 -12
- optimum/rbln/transformers/models/bart/modeling_bart.py +3 -6
- optimum/rbln/transformers/models/bert/modeling_bert.py +3 -6
- optimum/rbln/transformers/models/clip/modeling_clip.py +8 -25
- optimum/rbln/transformers/models/decoderonly/__init__.py +0 -3
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +672 -412
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +38 -155
- optimum/rbln/transformers/models/dpt/modeling_dpt.py +1 -1
- optimum/rbln/transformers/models/exaone/exaone_architecture.py +61 -45
- optimum/rbln/transformers/models/exaone/modeling_exaone.py +4 -2
- optimum/rbln/transformers/models/gemma/gemma_architecture.py +33 -104
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +50 -238
- optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +3 -2
- optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +2 -75
- optimum/rbln/transformers/models/midm/midm_architecture.py +88 -242
- optimum/rbln/transformers/models/midm/modeling_midm.py +6 -6
- optimum/rbln/transformers/models/phi/phi_architecture.py +61 -261
- optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +1 -46
- optimum/rbln/transformers/models/t5/modeling_t5.py +102 -4
- optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +1 -1
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +1 -1
- optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +3 -35
- optimum/rbln/transformers/utils/rbln_quantization.py +120 -3
- optimum/rbln/utils/decorator_utils.py +10 -6
- optimum/rbln/utils/hub.py +131 -0
- optimum/rbln/utils/import_utils.py +15 -1
- optimum/rbln/utils/model_utils.py +53 -0
- optimum/rbln/utils/runtime_utils.py +1 -1
- optimum/rbln/utils/submodule.py +114 -0
- optimum_rbln-0.1.15.dist-info/METADATA +106 -0
- {optimum_rbln-0.1.13.dist-info → optimum_rbln-0.1.15.dist-info}/RECORD +69 -66
- {optimum_rbln-0.1.13.dist-info → optimum_rbln-0.1.15.dist-info}/WHEEL +1 -1
- optimum/rbln/transformers/generation/streamers.py +0 -139
- optimum/rbln/transformers/generation/utils.py +0 -397
- optimum/rbln/transformers/models/exaone/hf_hub_cached/configuration_exaone.py +0 -181
- optimum/rbln/transformers/models/exaone/hf_hub_cached/modeling_exaone.py +0 -1725
- optimum/rbln/transformers/models/midm/hf_hub_cached/configuration_midm.py +0 -22
- optimum/rbln/transformers/models/midm/hf_hub_cached/midm_bitext_tokenization.py +0 -304
- optimum/rbln/transformers/models/midm/hf_hub_cached/modeling_midm.py +0 -1469
- optimum/rbln/transformers/models/midm/hf_hub_cached/rotary_position_embedding.py +0 -98
- optimum/rbln/utils/context.py +0 -58
- optimum_rbln-0.1.13.dist-info/METADATA +0 -120
- optimum_rbln-0.1.13.dist-info/entry_points.txt +0 -4
- {optimum_rbln-0.1.13.dist-info → optimum_rbln-0.1.15.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/transformers/models/gpt2/gpt2_architecture.py
@@ -21,262 +21,74 @@
 # copied, modified, or distributed without prior written permission
 # from Rebellions Inc.
 
-from typing import
+from typing import TYPE_CHECKING, Tuple
 
 import torch
 import torch.nn as nn
-from transformers.modeling_outputs import BaseModelOutputWithPast
 
-from
+from ..decoderonly.decoderonly_architecture import (
+    DecoderOnlyAttention,
+    DecoderOnlyForCausalLM,
+    DecoderOnlyLayer,
+    DecoderOnlyModel,
+    DecoderOnlyWrapper,
+)
 
 
-
-
-        super().__init__()
-        self.model = model.transformer
-        self.lm_head = model.lm_head
-        self.config = model.config
-        self.max_seq_len = max_seq_len
-        self.forward_dict = self.get_forward_dict()
+if TYPE_CHECKING:
+    from transformers import GPT2LMHeadModel
 
-    def get_forward_dict(self):
-        forward_dict = {
-            "wrapper": _GPT2Model.forward,
-            "model": _GPT2Block.forward,
-            "decoder_layer": _GPT2Attention.forward,
-        }
-        return forward_dict
 
-
-
-
-
-
-
-
-
-
-
-
-
-            rbln_batch_position = batch_position
+class GPT2Wrapper(DecoderOnlyWrapper):
+    def convert_to_rbln_causal_lm(self, causal_lm: "GPT2LMHeadModel"):
+        if self.attn_impl != "eager":
+            raise NotImplementedError(f"flash attention ({self.attn_impl}) is not implemented for {self.__class__}")
+        new_layers = []
+        for layer in causal_lm.transformer.h:
+            new_self_attn = GPT2Attention(layer.attn)
+            new_layer = GPT2Layer(layer, new_self_attn)
+            new_layers.append(new_layer)
+        new_model = GPT2Model(causal_lm.transformer, new_layers)
+        new_causal_lm = DecoderOnlyForCausalLM(causal_lm, new_model)
+        return new_causal_lm
 
-        # Formatting list of past_kv to DynamicCache class.
-        past_key_value = RebelDynamicCache_4D.from_input_format(
-            cache_position,
-            self.config.n_layer,
-            *past_key_values,
-        )
-
-        outputs = self.forward_dict["wrapper"](
-            self.model,
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            position_ids=cache_position,
-            past_key_value=past_key_value,
-            batch_ids=rbln_batch_position,
-            forward_dict=self.forward_dict,
-            # rotary_emb differenct from_llama
-        )
-
-        hidden_states = outputs[0]
-        if batch_position >= 0:
-            hidden_states = hidden_states[:, query_idx].unsqueeze(1)
-        logits = self.lm_head(hidden_states)
-
-        output = (logits,) + outputs[1:]
 
-
+class GPT2Model(DecoderOnlyModel):
+    mask_fmin = torch.finfo(torch.float32).min
 
+    def get_last_layernorm(self) -> nn.LayerNorm:
+        return self._original_mod.ln_f
 
-
-
-        self,
-        input_ids: torch.LongTensor = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
-        past_key_value: Optional[RebelDynamicCache_4D] = None,
-        batch_ids: Optional[torch.LongTensor] = None,
-        forward_dict: Optional[Dict[str, classmethod]] = None,
-    ) -> BaseModelOutputWithPast:
-        b_size, q_len = input_ids.shape
-        inputs_embeds = self.wte(input_ids)
+    def get_embedding(self) -> nn.Embedding:
+        return self._original_mod.wte
 
-
-
-            for b_idx in range(b_size):
-                position_embed = self.wpe(position_ids[b_idx])
-                # position_embed = position_embed.dtype(inputs_embeds.dtype)
-                position_embeds.append(position_embed)
+    def get_pos_embedding(self) -> nn.Embedding:
+        return self._original_mod.wpe
 
-            position_embeds = torch.cat(position_embeds, dim=0).unsqueeze(1)
-        else:
-            position_embeds = self.wpe(position_ids)
 
-
+class GPT2Layer(DecoderOnlyLayer):
+    def get_pre_attention_layernorm(self) -> nn.LayerNorm:
+        return self._original_mod.ln_1
 
-
-
-        attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min
+    def get_post_attention_layernorm(self) -> nn.LayerNorm:
+        return self._original_mod.ln_2
 
-        for layer_idx, block in enumerate(self.h):
-            hidden_states, updated_cache = forward_dict["model"](
-                block,
-                hidden_states,
-                layer_idx,
-                attention_mask=attention_mask,
-                past_key_value=past_key_value,
-                position_ids=position_ids,
-                batch_ids=batch_ids,
-                forward_dict=forward_dict,
-            )
-
-        hidden_states = self.ln_f(hidden_states)
-        output_shape = (-1,) + (q_len,) + (hidden_states.size(-1),)
-        hidden_states = hidden_states.view(output_shape)
-
-        # convert RebelDynamicCache to legacy Tuple[Tuple[torch.Tensor]]
-        next_cache = updated_cache.to_legacy_cache()
-
-        return BaseModelOutputWithPast(
-            last_hidden_state=hidden_states,
-            past_key_values=next_cache,
-        )
 
+class GPT2Attention(DecoderOnlyAttention):
+    def __post_init__(self):
+        self.c_attn = self._original_mod.c_attn
+        self.o_proj = self._original_mod.c_proj
+        self.split_size = self._original_mod.split_size
+        self.num_key_value_heads = self._original_mod.num_heads
 
-
-
-
-        hidden_states: Optional[Tuple[torch.FloatTensor]],
-        layer_idx: int,
-        attention_mask: Optional[torch.FloatTensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
-        past_key_value: Optional[RebelDynamicCache_4D] = None,
-        batch_ids: Optional[torch.LongTensor] = None,
-        forward_dict: Optional[Dict[str, classmethod]] = None,
-        **kwargs,
-    ) -> Tuple[torch.Tensor, RebelDynamicCache_4D]:
-        residual = hidden_states
-        hidden_states = self.ln_1(hidden_states)
+    def projection(self, hidden_states) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        query_states, key_states, value_states = self.c_attn(hidden_states).split(self.split_size, dim=2)
+        return query_states, key_states, value_states
 
-
-
-
-
-
-
-            batch_index=batch_ids,
+    def rbln_attention(self, *args, **kwargs):
+        return super().rbln_attention(
+            *args,
+            **kwargs,
+            layer_idx=self.layer_idx,
+            scale_attn_by_inverse_layer_idx=self._original_mod.scale_attn_by_inverse_layer_idx,
         )
-        past_key_value.assign(k, v, layer_idx)
-
-        # residual connection
-        hidden_states = residual + hidden_states
-
-        residual = hidden_states
-        hidden_states = self.ln_2(hidden_states)
-        hidden_states = self.mlp(hidden_states)
-        hidden_states = residual + hidden_states
-
-        return hidden_states, past_key_value
-
-
-class _GPT2Attention:
-    def _attn(self, query, key, value, attention_mask=None, head_mask=None):
-        attn_weights = torch.matmul(query, key.transpose(-1, -2))
-
-        if self.scale_attn_weights:
-            attn_weights = attn_weights / torch.full(
-                [], value.size(-1) ** 0.5, dtype=attn_weights.dtype, device=attn_weights.device
-            )
-
-        # Layer-wise attention scaling
-        if self.scale_attn_by_inverse_layer_idx:
-            attn_weights = attn_weights / float(self.layer_idx + 1)
-
-        # -------------------
-        # Below are deleted since "where" op does not supported on RBLN graph.
-        # -------------------
-        # if not self.is_cross_attention:
-        # # if only "normal" attention layer implements causal mask
-        # query_length, key_length = query.size(-2), key.size(-2)
-        # causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length]
-        # mask_value = torch.finfo(attn_weights.dtype).min
-        # # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
-        # # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
-        # mask_value = torch.full([], mask_value, dtype=attn_weights.dtype, device=attn_weights.device)
-        # attn_weights = torch.where(causal_mask, attn_weights.to(attn_weights.dtype), mask_value)
-
-        # Apply the attention mask
-        attn_weights.view(
-            -1,
-        )
-        attn_weights = attn_weights + attention_mask
-        attn_weights = nn.functional.softmax(attn_weights, dim=-1)
-        attn_output = torch.matmul(attn_weights, value)
-
-        return attn_output, attn_weights
-
-    def forward(
-        self,
-        hidden_states: Optional[Tuple[torch.FloatTensor]],
-        attention_mask: Optional[torch.FloatTensor] = None,
-        past_key_value: Optional[RebelDynamicCache_4D] = None,
-        batch_index: Optional[int] = None,
-        **kwargs,
-    ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]:
-        bsz, q_len, _ = hidden_states.size()
-        query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2)
-
-        querys = self._split_heads(query, self.num_heads, self.head_dim)  # (batch, head, seq_length, head_features)
-        keys = self._split_heads(key, self.num_heads, self.head_dim)
-        values = self._split_heads(value, self.num_heads, self.head_dim)
-
-        # Decoder
-        if (batch_index is None or batch_index == -1) and bsz > 1:
-            all_keys = []
-            all_values = []
-            all_attn_output = []
-
-            for b in range(bsz):
-                query = querys[b].unsqueeze(0)
-                attn_mask = attention_mask[b].unsqueeze(0)
-                key = keys[b].unsqueeze(0)
-                value = values[b].unsqueeze(0)
-
-                key, value = past_key_value.update(
-                    key,
-                    value,
-                    self.layer_idx,
-                    b,
-                )
-
-                attn_output, _ = _GPT2Attention._attn(self, query, key, value, attn_mask)
-                attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)
-
-                all_keys.append(key)
-                all_values.append(value)
-                all_attn_output.append(attn_output)
-
-            keys = torch.cat(all_keys, dim=0)
-            values = torch.cat(all_values, dim=0)
-            attn_output = torch.cat(all_attn_output, dim=0)
-
-        # Prefill
-        else:
-            if batch_index is None or batch_index == -1:
-                batch_index = 0
-
-            keys, values = past_key_value.update(
-                keys,
-                values,
-                self.layer_idx,
-                batch_index,
-                read_first_step=True,
-            )
-
-            attn_output, _ = _GPT2Attention._attn(self, querys, keys, values, attention_mask)
-            attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)
-
-        attn_output = self.c_proj(attn_output)
-
-        return attn_output, keys, values
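The gpt2_architecture.py hunk above replaces the hand-written `_GPT2Model` / `_GPT2Block` / `_GPT2Attention` forwards with thin subclasses of the shared decoder-only wrappers from `decoderonly_architecture.py`. As a rough sketch of that extension pattern (not code from the package; the `qkv_proj`, `out_proj`, `final_norm`, etc. attribute names are hypothetical), a new decoder-only backend only maps its Hugging Face module layout onto the base-class accessors:

import torch.nn as nn

from optimum.rbln.transformers.models.decoderonly.decoderonly_architecture import (
    DecoderOnlyAttention,
    DecoderOnlyLayer,
    DecoderOnlyModel,
)


class MyModel(DecoderOnlyModel):
    # Point the shared wrapper at this architecture's modules.
    def get_last_layernorm(self) -> nn.LayerNorm:
        return self._original_mod.final_norm  # hypothetical attribute name

    def get_embedding(self) -> nn.Embedding:
        return self._original_mod.tok_embeddings  # hypothetical attribute name


class MyLayer(DecoderOnlyLayer):
    def get_pre_attention_layernorm(self) -> nn.LayerNorm:
        return self._original_mod.norm_1  # hypothetical attribute name

    def get_post_attention_layernorm(self) -> nn.LayerNorm:
        return self._original_mod.norm_2  # hypothetical attribute name


class MyAttention(DecoderOnlyAttention):
    def __post_init__(self):
        # Expose this architecture's projections under the names the base class
        # reads, as GPT2Attention above does with c_attn / c_proj.
        self.fused_qkv = self._original_mod.qkv_proj  # hypothetical fused QKV module
        self.o_proj = self._original_mod.out_proj  # hypothetical output projection
        self.num_key_value_heads = self._original_mod.num_heads

    def projection(self, hidden_states):
        # Split the fused QKV output into query/key/value, mirroring
        # GPT2Attention.projection in the hunk above.
        return self.fused_qkv(hidden_states).split(self._original_mod.hidden_size, dim=2)

The attention, KV-cache, and prefill/decode handling that each model previously reimplemented now lives once in the `DecoderOnly*` base classes; the per-model files shrink to these accessor overrides.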
optimum/rbln/transformers/models/gpt2/modeling_gpt2.py
@@ -23,7 +23,7 @@
 
 from ....utils import logging
 from ...models.decoderonly import RBLNDecoderOnlyModelForCausalLM
-from .gpt2_architecture import GPT2LMHeadModelWrapper
+from .gpt2_architecture import GPT2Wrapper  # GPT2LMHeadModelWrapper
 
 
 logger = logging.get_logger(__name__)
@@ -43,4 +43,5 @@ class RBLNGPT2LMHeadModel(RBLNDecoderOnlyModelForCausalLM):
 
    """
 
-    _decoder_wrapper_cls =
+    _decoder_wrapper_cls = GPT2Wrapper
+    _use_rotary_emb = False
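On the modeling side, `RBLNGPT2LMHeadModel` now only declares `_decoder_wrapper_cls = GPT2Wrapper` and `_use_rotary_emb = False`; compilation and generation go through the common `RBLNDecoderOnlyModelForCausalLM` path. A minimal usage sketch, assuming the usual optimum-rbln export flow (`export=True` plus `rbln_*` keywords; treat the exact keyword names and values as indicative, not verified against 0.1.15):

from transformers import AutoTokenizer

from optimum.rbln import RBLNGPT2LMHeadModel

model_id = "gpt2"
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Compile from the Hugging Face checkpoint, then save the compiled artifacts
# so later runs can load them without re-exporting.
model = RBLNGPT2LMHeadModel.from_pretrained(
    model_id,
    export=True,
    rbln_max_seq_len=1024,
    rbln_batch_size=1,
)
model.save_pretrained("gpt2-rbln")

inputs = tokenizer("Hello, my dog is", return_tensors="pt")
output_ids = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))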
optimum/rbln/transformers/models/llava_next/modeling_llava_next.py
@@ -23,7 +23,7 @@
 import inspect
 import logging
 from pathlib import Path
-from typing import TYPE_CHECKING, Any, Callable, Dict,
+from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple, Union
 
 import numpy as np
 import torch
@@ -36,7 +36,7 @@ from transformers import (
 from transformers.modeling_outputs import BaseModelOutputWithPooling
 from transformers.models.llava_next.modeling_llava_next import LlavaNextCausalLMOutputWithPast
 
-from ....
+from ....modeling import RBLNModel
 from ....modeling_config import RBLNCompileConfig, RBLNConfig
 from ..decoderonly.modeling_decoderonly import RBLNDecoderOnlyOutput
 
@@ -166,19 +166,6 @@ class RBLNLlavaNextForConditionalGeneration(RBLNModel):
         self._padding_side = "left"  # set it to left by default, user can use setter to change padding_sides
         return super().__post_init__(**kwargs)
 
-    @classmethod
-    def get_pytorch_model(
-        cls,
-        model_id: str,
-        *args,
-        rbln_kwargs: Optional[Dict[str, Any]] = None,
-        **kwargs,
-    ) -> "PreTrainedModel":
-        # Optimum's TasksManager does not handle Llava.
-        kwargs = cls.update_kwargs(kwargs)
-        model = LlavaNextForConditionalGeneration.from_pretrained(model_id, *args, **kwargs)
-        return model
-
     def get_input_embeddings(self):
         return self.language_model.get_input_embeddings()
 
@@ -422,66 +409,6 @@ class RBLNLlavaNextForConditionalGeneration(RBLNModel):
 
         return outputs
 
-    def vllm_forward(
-        self,
-        input_ids: torch.LongTensor = None,
-        pixel_values: torch.FloatTensor = None,
-        image_sizes: Optional[torch.LongTensor] = None,
-        inputs_embeds: Optional[torch.FloatTensor] = None,
-        vision_feature_layer: Optional[int] = None,
-        vision_feature_select_strategy: Optional[str] = None,
-        cache_position: Union[List[torch.Tensor], torch.Tensor] = None,  # vllm keyword argument
-        batch_idx: Optional[int] = None,
-        **kwargs,
-    ) -> Union[Tuple, LlavaNextCausalLMOutputWithPast]:
-        is_prefill = cache_position.shape[-1] > 1
-
-        if inputs_embeds is not None:
-            raise NotImplementedError("Specifying inputs_embeds is not supported.")
-
-        if is_prefill:
-            # Get text_embeds
-            inputs_embeds = self.text_embedding(input_ids)
-
-            # If any images in the prompt, get image_embeds and merge with text
-            if pixel_values is not None and input_ids.shape[1] != 1 and pixel_values.size(0) > 0:
-                image_features, _ = self.image_embedding(
-                    image_sizes, pixel_values, vision_feature_layer, vision_feature_select_strategy
-                )
-
-                def merge_vllm_multimodal_embeddings(
-                    input_ids: torch.Tensor,
-                    inputs_embeds: torch.Tensor,
-                    multimodal_embeddings: torch.Tensor,
-                    placeholder_token_id: int,
-                ) -> torch.Tensor:
-                    mask = input_ids == placeholder_token_id
-                    num_expected_tokens = mask.sum().item()
-
-                    if multimodal_embeddings.shape[0] != num_expected_tokens:
-                        raise ValueError(
-                            f"Attempted to assign {inputs_embeds[mask].shape} = {multimodal_embeddings.shape} "
-                            f"multimodal tokens to {num_expected_tokens} placeholders"
-                        )
-
-                    inputs_embeds[mask] = multimodal_embeddings
-                    return inputs_embeds
-
-                inputs_embeds = merge_vllm_multimodal_embeddings(
-                    input_ids, inputs_embeds, image_features, self.config.image_token_index
-                )
-
-        else:
-            inputs_embeds = self.text_embedding(input_ids=input_ids)
-
-        outputs: RBLNDecoderOnlyOutput = self.language_model.vllm_forward(
-            inputs_embeds=inputs_embeds,
-            batch_idx=batch_idx,
-            cache_position=cache_position,
-        )
-
-        return outputs
-
     # Almost copied from : https://github.com/huggingface/transformers/blob/5af7d41e49bbfc8319f462eb45253dcb3863dfb7/src/transformers/models/llava_next/modeling_llava_next.py
     def pack_image_features(self, image_features, image_sizes, image_newline=None):
         """