optimum-rbln 0.1.13__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
- optimum/rbln/__init__.py +41 -38
- optimum/rbln/__version__.py +16 -1
- optimum/rbln/diffusers/__init__.py +26 -2
- optimum/rbln/{modeling_diffusers.py → diffusers/modeling_diffusers.py} +97 -126
- optimum/rbln/diffusers/models/__init__.py +36 -3
- optimum/rbln/{transformers/generation → diffusers/models/autoencoders}/__init__.py +1 -2
- optimum/rbln/diffusers/models/{autoencoder_kl.py → autoencoders/autoencoder_kl.py} +73 -61
- optimum/rbln/diffusers/models/autoencoders/vae.py +83 -0
- optimum/rbln/diffusers/models/controlnet.py +54 -14
- optimum/rbln/diffusers/models/transformers/__init__.py +24 -0
- optimum/rbln/diffusers/models/transformers/transformer_sd3.py +203 -0
- optimum/rbln/diffusers/models/unets/__init__.py +24 -0
- optimum/rbln/diffusers/models/{unet_2d_condition.py → unets/unet_2d_condition.py} +82 -22
- optimum/rbln/diffusers/pipelines/__init__.py +23 -2
- optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +13 -33
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +17 -2
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +18 -2
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +18 -2
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +18 -2
- optimum/rbln/diffusers/pipelines/stable_diffusion/__init__.py +1 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +2 -2
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +1 -13
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/__init__.py +26 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/__init__.py +24 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +15 -8
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +15 -8
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +31 -0
- optimum/rbln/modeling.py +238 -0
- optimum/rbln/modeling_base.py +186 -760
- optimum/rbln/modeling_config.py +31 -7
- optimum/rbln/ops/__init__.py +26 -0
- optimum/rbln/ops/attn.py +221 -0
- optimum/rbln/ops/flash_attn.py +70 -0
- optimum/rbln/ops/kv_cache_update.py +69 -0
- optimum/rbln/transformers/__init__.py +20 -2
- optimum/rbln/{modeling_alias.py → transformers/modeling_alias.py} +5 -1
- optimum/rbln/transformers/modeling_generic.py +385 -0
- optimum/rbln/transformers/models/auto/__init__.py +23 -0
- optimum/rbln/transformers/models/auto/auto_factory.py +117 -23
- optimum/rbln/transformers/models/auto/modeling_auto.py +36 -12
- optimum/rbln/transformers/models/bart/__init__.py +0 -1
- optimum/rbln/transformers/models/bart/bart_architecture.py +107 -464
- optimum/rbln/transformers/models/bart/modeling_bart.py +10 -9
- optimum/rbln/transformers/models/bert/modeling_bert.py +3 -6
- optimum/rbln/transformers/models/clip/modeling_clip.py +8 -25
- optimum/rbln/transformers/models/decoderonly/__init__.py +0 -10
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +775 -514
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +128 -260
- optimum/rbln/transformers/models/dpt/modeling_dpt.py +1 -1
- optimum/rbln/transformers/models/exaone/exaone_architecture.py +60 -45
- optimum/rbln/transformers/models/exaone/modeling_exaone.py +4 -2
- optimum/rbln/transformers/models/gemma/gemma_architecture.py +33 -104
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +50 -238
- optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +3 -2
- optimum/rbln/transformers/models/llama/llama_architecture.py +0 -1
- optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +3 -75
- optimum/rbln/transformers/models/midm/midm_architecture.py +84 -238
- optimum/rbln/transformers/models/midm/modeling_midm.py +5 -6
- optimum/rbln/transformers/models/mistral/mistral_architecture.py +0 -1
- optimum/rbln/transformers/models/phi/phi_architecture.py +60 -261
- optimum/rbln/transformers/models/qwen2/qwen2_architecture.py +0 -1
- optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +58 -103
- optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +498 -0
- optimum/rbln/transformers/models/t5/__init__.py +0 -1
- optimum/rbln/transformers/models/t5/modeling_t5.py +106 -5
- optimum/rbln/transformers/models/t5/t5_architecture.py +106 -448
- optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +1 -1
- optimum/rbln/transformers/models/whisper/generation_whisper.py +42 -0
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +78 -55
- optimum/rbln/transformers/models/whisper/whisper_architecture.py +219 -312
- optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +3 -35
- optimum/rbln/transformers/utils/rbln_quantization.py +120 -4
- optimum/rbln/utils/decorator_utils.py +51 -11
- optimum/rbln/utils/hub.py +131 -0
- optimum/rbln/utils/import_utils.py +22 -1
- optimum/rbln/utils/logging.py +37 -0
- optimum/rbln/utils/model_utils.py +52 -0
- optimum/rbln/utils/runtime_utils.py +10 -4
- optimum/rbln/utils/save_utils.py +17 -0
- optimum/rbln/utils/submodule.py +137 -0
- optimum_rbln-0.2.0.dist-info/METADATA +117 -0
- optimum_rbln-0.2.0.dist-info/RECORD +114 -0
- {optimum_rbln-0.1.13.dist-info → optimum_rbln-0.2.0.dist-info}/WHEEL +1 -1
- optimum_rbln-0.2.0.dist-info/licenses/LICENSE +288 -0
- optimum/rbln/transformers/cache_utils.py +0 -107
- optimum/rbln/transformers/generation/streamers.py +0 -139
- optimum/rbln/transformers/generation/utils.py +0 -397
- optimum/rbln/transformers/models/exaone/hf_hub_cached/configuration_exaone.py +0 -181
- optimum/rbln/transformers/models/exaone/hf_hub_cached/modeling_exaone.py +0 -1725
- optimum/rbln/transformers/models/midm/hf_hub_cached/configuration_midm.py +0 -22
- optimum/rbln/transformers/models/midm/hf_hub_cached/midm_bitext_tokenization.py +0 -304
- optimum/rbln/transformers/models/midm/hf_hub_cached/modeling_midm.py +0 -1469
- optimum/rbln/transformers/models/midm/hf_hub_cached/rotary_position_embedding.py +0 -98
- optimum/rbln/utils/context.py +0 -58
- optimum/rbln/utils/timer_utils.py +0 -43
- optimum_rbln-0.1.13.dist-info/METADATA +0 -120
- optimum_rbln-0.1.13.dist-info/RECORD +0 -107
- optimum_rbln-0.1.13.dist-info/entry_points.txt +0 -4
- optimum_rbln-0.1.13.dist-info/licenses/LICENSE +0 -201
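Taken together, the listing shows a broad repackaging: diffusers and transformers modules move under dedicated subpackages (e.g. `optimum/rbln/modeling_diffusers.py` → `optimum/rbln/diffusers/modeling_diffusers.py`), the monolithic `modeling_base.py` is split into a new `modeling.py`, attention kernels land under `optimum/rbln/ops/`, and the vendored `hf_hub_cached` copies of EXAONE and Midm are dropped. Code that imported 0.1.13's private module paths may therefore break on upgrade. A minimal sketch of the safer pattern, importing from the package root (the checkpoint id and options below are placeholders, not taken from this diff):

```python
# Hedged sketch: prefer root-level imports, which survive the 0.2.0 relayout;
# private module paths such as optimum.rbln.modeling_alias moved in this release.
from optimum.rbln import RBLNLlamaForCausalLM

model = RBLNLlamaForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",  # placeholder model id
    export=True,                 # compile the Hugging Face checkpoint for RBLN NPUs
)
```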
optimum/rbln/transformers/models/phi/phi_architecture.py

```diff
@@ -21,302 +21,101 @@
 # copied, modified, or distributed without prior written permission
 # from Rebellions Inc.
 
-import math
-from typing import Dict, Optional, Tuple
+from typing import TYPE_CHECKING, Optional, Tuple
 
 import torch
-import torch.nn as nn
-from transformers.modeling_outputs import (
-    BaseModelOutputWithPast,
-)
+from transformers import PhiForCausalLM
 
-from ...cache_utils import RebelDynamicCache
-from ..decoderonly import (
+from ..decoderonly.decoderonly_architecture import (
+    DecoderOnlyAttention,
+    DecoderOnlyForCausalLM,
+    DecoderOnlyLayer,
+    DecoderOnlyModel,
     DecoderOnlyWrapper,
-    apply_rotary_pos_emb,
-    slice_and_unsqueeze_cos_sin,
+    apply_rotary_pos_emb_partial,
 )
 
 
-class PhiWrapper(DecoderOnlyWrapper):
-    def get_forward_dict(self):
-        forward_dict = {}
-        forward_dict.update(
-            {
-                "wrapper": PhiModel.forward,
-                "model": PhiDecoderLayer.forward,
-                "decoder_layer": PhiAttention.forward,
-            }
-        )
-        return forward_dict
-
-
-class PhiAttention:
-    def _attn(self, query_state, key_state, value_state, attn_mask, past_key_value, batch_idx=0, is_prefill=False):
-        # reshape for removing repeat_kv (batch=1 , num_head, 1, q_len=1, head_dim)
-        key_state = key_state.unsqueeze(2)
-        value_state = value_state.unsqueeze(2)
-        attn_mask = attn_mask.unsqueeze(2)
-
-        query_state = query_state.view(
-            1,
-            self.num_key_value_heads,
-            self.num_heads // self.num_key_value_heads,
-            -1,
-            self.head_dim,
-        )
-
-        key_state, value_state = past_key_value.update(key_state, value_state, self.layer_idx, batch_idx, is_prefill)
-
-        # Queries and keys upcast to fp32 is required by Phi-2 to avoid overflow
-        attn_weights = torch.matmul(
-            query_state.to(torch.float32),
-            key_state.to(torch.float32).transpose(3, 4),
-        ) / math.sqrt(self.head_dim)
-        attn_weights = attn_weights + attn_mask
-
-        # upcast attention to fp32
-        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_state.dtype)
-        attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
-        attn_output = torch.matmul(attn_weights, value_state)
-
-        # reshape for removing repeat_kv
-        attn_output = attn_output.view(1, self.num_heads, -1, self.head_dim)
-        attn_output = attn_output.transpose(1, 2).contiguous()
-        attn_output = attn_output.reshape(1, -1, self.num_heads * self.head_dim)
+if TYPE_CHECKING:
+    from transformers import PhiForCausalLM
 
-        return attn_output, key_state, value_state
-
-    def forward(
-        self,
-        hidden_states: torch.Tensor,
-        attention_mask: Optional[torch.Tensor] = None,
-        past_key_value: Optional[RebelDynamicCache] = None,
-        batch_index: Optional[int] = None,
-        output_attentions: bool = False,
-        cos: Optional[torch.Tensor] = None,
-        sin: Optional[torch.Tensor] = None,
-        **kwargs,
-    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-        bsz, q_len, _ = hidden_states.size()
 
+class PhiWrapper(DecoderOnlyWrapper):
+    def convert_to_rbln_causal_lm(self, causal_lm: "PhiForCausalLM"):
+        new_layers = []
+        for layer in causal_lm.model.layers:
+            if self.attn_impl == "eager":
+                new_self_attn = PhiAttention(layer.self_attn)
+            elif self.attn_impl == "flash_attn":
+                raise NotImplementedError(f"flash attn for {self.__class__} is not implemented yet.")
+            else:
+                raise NotImplementedError(f"Unknwon attn : {self.attn_impl}")
+            new_layer = PhiLayer(layer, new_self_attn)
+            new_layers.append(new_layer)
+        new_model = PhiModel(causal_lm.model, new_layers)
+        new_causal_lm = DecoderOnlyForCausalLM(causal_lm, new_model)
+        return new_causal_lm
+
+
+class PhiAttention(DecoderOnlyAttention):
+    def __post_init__(self):
+        self.q_proj = self._original_mod.q_proj
+        self.k_proj = self._original_mod.k_proj
+        self.v_proj = self._original_mod.v_proj
+        self.o_proj = self._original_mod.dense
+        self.qk_layernorm = self._original_mod.qk_layernorm
+        self.rotary_ndims = self._original_mod.rotary_ndims
+
+    def projection(self, hidden_states) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         query_states = self.q_proj(hidden_states)
         key_states = self.k_proj(hidden_states)
         value_states = self.v_proj(hidden_states)
 
         if self.qk_layernorm:
-            query_states = self.q_layernorm(query_states)
-            key_states = self.k_layernorm(key_states)
-
-        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
-        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
-        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
-
-        # Partial rotary embedding
-        query_rot, query_pass = (
-            query_states[..., : self.rotary_ndims],
-            query_states[..., self.rotary_ndims :],
-        )
-        key_rot, key_pass = (
-            key_states[..., : self.rotary_ndims],
-            key_states[..., self.rotary_ndims :],
-        )
+            query_states = self._original_mod.q_layernorm(query_states)
+            key_states = self._original_mod.k_layernorm(key_states)
 
-
-        query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin)
+        return query_states, key_states, value_states
 
-
-        query_states = torch.cat((query_rot, query_pass), dim=-1)
-        key_states = torch.cat((key_rot, key_pass), dim=-1)
+    def apply_rotary_pos_embed(self, query_states, key_states, cos, sin):
+        return apply_rotary_pos_emb_partial(query_states, key_states, cos, sin, ndim=self.rotary_ndims)
 
-        # Decoder (bsz > 1)
-        if bsz > 1:
-            iterate_results = {"key_states": [], "value_states": [], "attn_output": []}
-            for b in range(bsz):
-                attn_output, key_state, value_state = PhiAttention._attn(
-                    self,
-                    query_states[b].unsqueeze(0),
-                    key_states[b].unsqueeze(0),
-                    value_states[b].unsqueeze(0),
-                    attention_mask[b].unsqueeze(0),
-                    past_key_value,
-                    batch_idx=b,
-                    is_prefill=False,
-                )
-                iterate_results["key_states"].append(key_state)
-                iterate_results["value_states"].append(value_state)
-                iterate_results["attn_output"].append(attn_output)
 
-            key_states = torch.cat(iterate_results["key_states"], dim=0)
-            value_states = torch.cat(iterate_results["value_states"], dim=0)
-            attn_output = torch.cat(iterate_results["attn_output"], dim=0)
-        # Prefill & Decoder (bsz == 1)
-        else:
-            attn_output, key_states, value_states = PhiAttention._attn(
-                self,
-                query_states,
-                key_states,
-                value_states,
-                attention_mask,
-                past_key_value,
-                batch_idx=batch_index,
-                is_prefill=True,
-            )
+class PhiLayer(DecoderOnlyLayer):
+    def get_post_attention_layernorm(self):
+        raise NotImplementedError
 
-        attn_output = self.dense(attn_output)
-
-        if not output_attentions:
-            attn_weights = None
-
-        return attn_output, attn_weights, key_states, value_states
-
-
-class PhiDecoderLayer:
     def forward(
         self,
         hidden_states: torch.Tensor,
-        layer_idx: int,
-        attention_mask: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
-        past_key_value: Optional[RebelDynamicCache] = None,
-        output_attentions: Optional[bool] = None,
-        use_cache: Optional[bool] = None,
-        batch_ids: Optional[torch.LongTensor] = None,
+        attention_mask: torch.Tensor,
+        seq_positions: torch.LongTensor,
+        batch_position: torch.Tensor,
+        past_key_values: Tuple[Tuple[torch.Tensor]],
         cos: Optional[torch.Tensor] = None,
         sin: Optional[torch.Tensor] = None,
-
-        **kwargs,
-    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
-        """
-        Args:
-            hidden_states (`torch.FloatTensor`):
-                input to the layer of shape `(batch, seq_len, embed_dim)`
-            attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
-                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
-            position_ids (`torch.LongTensor` of shape `({0})`, *optional*):
-                Indices of positions of each input sequence tokens in the position embeddings. Selected in the range
-                `[0, config.n_positions - 1]`. [What are position IDs?](../glossary#position-ids)
-            output_attentions (`bool`, *optional*):
-                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
-                returned tensors for more detail.
-            use_cache (`bool`, *optional*):
-                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
-                (see `past_key_values`).
-            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
-        """
-
+    ):
         residual = hidden_states
 
-        hidden_states = self.input_layernorm(hidden_states)
+        hidden_states = self.get_pre_attention_layernorm()(hidden_states)
 
-
-        attn_outputs, self_attn_weights, key_states, value_states = forward_dict["decoder_layer"](
-            self.self_attn,
+        attn_outputs, present_key_values = self.self_attn(
             hidden_states=hidden_states,
             attention_mask=attention_mask,
-            position_ids=position_ids,
-            past_key_value=past_key_value,
-            output_attentions=output_attentions,
-            batch_index=batch_ids,
-            use_cache=use_cache,
+            seq_positions=seq_positions,
+            batch_position=batch_position,
+            past_key_values=past_key_values,
            cos=cos,
             sin=sin,
-            **kwargs,
         )
-        past_key_value.assign(key_states, value_states, layer_idx)
 
-        attn_outputs = self.resid_dropout(attn_outputs)
+        feed_forward_hidden_states = self._original_mod.mlp(hidden_states)
 
-        feed_forward_hidden_states = self.resid_dropout(self.mlp(hidden_states))
         hidden_states = attn_outputs + feed_forward_hidden_states + residual
-        outputs = (hidden_states,)
-
-        if output_attentions:
-            outputs += (self_attn_weights,)
-
-        if use_cache:
-            outputs += (past_key_value,)
-
-        return outputs
-
-
-class PhiModel:
-    def forward(
-        self,
-        input_ids: torch.LongTensor = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
-        past_key_values: Optional[RebelDynamicCache] = None,
-        batch_ids: Optional[torch.LongTensor] = None,
-        inputs_embeds: Optional[torch.FloatTensor] = None,
-        use_cache: Optional[bool] = True,
-        output_attentions: Optional[bool] = False,
-        output_hidden_states: Optional[bool] = False,
-        cache_pos_for_partitions: Optional[torch.Tensor] = None,
-        kvcache_partition_size: Optional[torch.Tensor] = None,
-        forward_dict: Optional[Dict[str, classmethod]] = None,
-        rotary_pos_emb=None,
-    ) -> BaseModelOutputWithPast:
-        # retrieve input_ids and inputs_embeds
-        if (input_ids is None) ^ (inputs_embeds is not None):
-            raise ValueError(
-                "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
-            )
-
-        # embed positions
-        if inputs_embeds is None:
-            inputs_embeds = self.embed_tokens(input_ids)
-
-        hidden_states = inputs_embeds
-        attention_mask = (1 - attention_mask) * torch.finfo(torch.float16).min
-
-        # get cos,sin vector
-        cos, sin = rotary_pos_emb(inputs_embeds, attention_mask.shape[-1])
-        cos, sin = slice_and_unsqueeze_cos_sin(cos, sin, position_ids)
-
-        # decoder layers
-        all_hidden_states = () if output_hidden_states else None
-        all_self_attns = () if output_attentions else None
-
-        for layer_idx, decoder_layer in enumerate(self.layers):
-            if output_hidden_states:
-                all_hidden_states += (hidden_states,)
-            layer_outputs = forward_dict["model"](
-                decoder_layer,
-                hidden_states,
-                layer_idx,
-                attention_mask=attention_mask,
-                position_ids=position_ids,
-                past_key_value=past_key_values,
-                output_attentions=output_attentions,
-                use_cache=use_cache,
-                batch_ids=batch_ids,
-                cos=cos,
-                sin=sin,
-                cache_pos_for_partitions=cache_pos_for_partitions,
-                kvcache_partition_size=kvcache_partition_size,
-                forward_dict=forward_dict,
-            )
 
-            hidden_states = layer_outputs[0]
+        return hidden_states, present_key_values
 
-            updated_cache = layer_outputs[2 if output_attentions else 1]
 
-            if output_attentions:
-                all_self_attns += (layer_outputs[1],)
-
-        hidden_states = self.final_layernorm(hidden_states)
-
-        # add hidden states from the last decoder layer
-        if output_hidden_states:
-            all_hidden_states += (hidden_states,)
-
-        # convert RebelDynamicCache to legacy Tuple[Tuple[torch.Tensor]]
-        next_cache = updated_cache.to_legacy_cache()
-
-        return BaseModelOutputWithPast(
-            last_hidden_state=hidden_states,
-            past_key_values=next_cache,
-            hidden_states=all_hidden_states,
-            attentions=all_self_attns,
-        )
+class PhiModel(DecoderOnlyModel):
+    def get_last_layernorm(self):
+        return self._original_mod.final_layernorm
```
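The hunk above replaces 0.1.13's `forward_dict` monkey-patching, in which bare `PhiAttention`/`PhiDecoderLayer`/`PhiModel` classes donated their `forward` methods to the Hugging Face modules, with small subclasses of the shared `DecoderOnly*` bases that wrap the original module (`_original_mod`) and override narrow hooks such as `projection` and `apply_rotary_pos_embed`. A minimal runnable sketch of that pattern, reduced to plain PyTorch (the base-class internals here are simplified assumptions, not the actual `DecoderOnlyAttention` implementation):

```python
import torch
import torch.nn as nn


class HookedAttention(nn.Module):
    """Generic attention flow lives in the base; subclasses override hooks."""

    def __init__(self, original_mod: nn.Module):
        super().__init__()
        self._original_mod = original_mod
        self.__post_init__()  # subclasses rebind projections here

    def __post_init__(self):
        self.q_proj = self._original_mod.q_proj
        self.k_proj = self._original_mod.k_proj
        self.v_proj = self._original_mod.v_proj

    def projection(self, hidden_states):
        # Default hook; the Phi subclass adds the optional qk_layernorm on top.
        return self.q_proj(hidden_states), self.k_proj(hidden_states), self.v_proj(hidden_states)


class TinyHFAttention(nn.Module):
    """Stand-in for the original Hugging Face attention module."""

    def __init__(self, dim: int = 8):
        super().__init__()
        self.q_proj = nn.Linear(dim, dim)
        self.k_proj = nn.Linear(dim, dim)
        self.v_proj = nn.Linear(dim, dim)


attn = HookedAttention(TinyHFAttention())
q, k, v = attn.projection(torch.randn(1, 4, 8))
```

This is why `PhiAttention.__post_init__` above only rebinds `o_proj` to the module's `dense` projection and records `rotary_ndims`: everything else rides on the base class.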
optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py

```diff
@@ -26,13 +26,14 @@ import logging
 from abc import ABC
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
 
-import rebel
-import torch
+import rebel
+import torch
+from rebel.compile_context import CompileContext
 from transformers import AutoModelForSeq2SeqLM, GenerationConfig, PretrainedConfig, PreTrainedModel
 from transformers.modeling_outputs import BaseModelOutput, Seq2SeqLMOutput
 
-from ....modeling_base import RBLNModel
-from ....modeling_config import
+from ....modeling import RBLNModel
+from ....modeling_config import RBLNCompileConfig, RBLNConfig
 from ....utils.runtime_utils import RBLNPytorchRuntime
 
 
@@ -66,7 +67,7 @@ class RBLNRuntimeDecoder(RBLNPytorchRuntime):
 class RBLNModelForSeq2SeqLM(RBLNModel, ABC):
     """
     This is a generic model class that will be instantiated as one of the model classes of the library (with a sequence-to-sequence language modeling head) when created with the from_pretrained() class method.
-    This model inherits from [`
+    This model inherits from [`RBLNModel`]. Check the superclass documentation for the generic methods the library implements for all its models.
 
     A class to convert and run pre-trained transformers based Seq2SeqLM models on RBLN devices.
     It implements the methods to convert a pre-trained transformers Seq2SeqLM model into a RBLN transformer model by:
@@ -88,49 +89,42 @@ class RBLNModelForSeq2SeqLM(RBLNModel, ABC):
     def get_compiled_model(cls, model: PreTrainedModel, rbln_config: RBLNConfig):
         wrapped_model = cls.wrap_model_if_needed(model, rbln_config)
 
-
-
+        enc_compile_config = rbln_config.compile_cfgs[0]
+        dec_compile_config = rbln_config.compile_cfgs[1]
 
-
-        wrapped_model.decoder.decoder_max_length = rbln_config.model_cfg["dec_max_seq_len"]
+        context = CompileContext(use_weight_sharing=False)
 
-        enc_rbln_compile_config = rbln_config.compile_cfgs[0]
-        dec_rbln_compile_config = rbln_config.compile_cfgs[1]
+        enc_example_inputs = enc_compile_config.get_dummy_inputs(fill=0)
 
-
-
+        # Mark encoder's static tensors (cross kv states)
+        static_tensors = {}
+        for (name, _, _), tensor in zip(enc_compile_config.input_info, enc_example_inputs):
+            if "key_value_states" in name:
+                static_tensors[name] = tensor
+                context.mark_static_address(tensor)
 
-
-        dec_example_inputs[4].fill_(-1)
+        dec_example_inputs = dec_compile_config.get_dummy_inputs(fill=0, static_tensors=static_tensors)
 
-
-
+        # Mark decoder's static tensors (self kv states)
+        for (name, _, _), tensor in zip(dec_compile_config.input_info, dec_example_inputs):
+            if "key_value_states" in name:
+                context.mark_static_address(tensor)
 
-        enc_ir = rebel.torchscript_to_ir(
-            enc_scripted_model,
-            input_names=[v[0] for v in enc_rbln_compile_config.input_info],
-            name=enc_rbln_compile_config.mod_name,
+        compiled_encoder = super().compile(
+            wrapped_model.encoder,
+            enc_compile_config,
+            example_inputs=enc_example_inputs,
+            compile_context=context,
         )
-        dec_ir = rebel.torchscript_to_ir(
-            dec_scripted_model,
-            input_names=[v[0] for v in dec_rbln_compile_config.input_info],
-            name=dec_rbln_compile_config.mod_name,
-        )
-
-        connections = [
-            (enc_ir.outputs[0], enc_ir.inputs[2], dec_ir.inputs[6]),
-            (dec_ir.outputs[1], dec_ir.inputs[5]),
-        ]
 
-
-
-
-
-
-            npu=enc_rbln_compile_config.npu,
-            tensor_parallel_size=enc_rbln_compile_config.tensor_parallel_size,
+        compiled_decoder = super().compile(
+            wrapped_model.decoder,
+            dec_compile_config,
+            example_inputs=dec_example_inputs,
+            compile_context=context,
         )
-        return compiled_model
+
+        return {"encoder": compiled_encoder, "decoder": compiled_decoder}
 
     @classmethod
     def _get_rbln_config(
```
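The rewritten `get_compiled_model` drops the IR-level `connections` wiring between encoder and decoder in favor of compiling both under a single `CompileContext` and pinning the shared key/value tensors to static addresses. A framework-free analogy of why those addresses must stay fixed (shapes are illustrative, and the rebel SDK is not used here):

```python
import torch

# The encoder writes cross-attention K/V into a preallocated buffer and the
# decoder later reads the very same storage, so the buffer address must be
# stable across both compiled graphs: the role of mark_static_address above.
n_layer, batch, n_head, enc_len, d_kv = 2, 1, 8, 512, 64
cross_kv = torch.zeros(n_layer * 2, batch, n_head, enc_len, d_kv)


def encoder_step(kv: torch.Tensor) -> None:
    cross_kv.copy_(kv)  # in-place write keeps the address fixed


def decoder_step() -> torch.Tensor:
    return cross_kv  # reads the storage the encoder filled


encoder_step(torch.randn_like(cross_kv))
assert decoder_step().data_ptr() == cross_kv.data_ptr()
```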
```diff
@@ -204,7 +198,7 @@
                 ],
                 "float32",
             ),
-            ("
+            ("batch_position", [], "int16"),
         ]
 
         dec_input_info = [
@@ -216,17 +210,16 @@
                 [rbln_batch_size, 1],
                 "int32",
             ),
-            ("batch_position", [], "int32"),
         ]
         dec_input_info.extend(
             [
                 (
-                    "
+                    "cross_key_value_states",
                     [
                         n_layer * 2,
                         rbln_batch_size,
                         n_head,
-
+                        rbln_enc_max_seq_len,
                         d_kv,
                     ],
                     "float32",
```
```diff
@@ -236,24 +229,24 @@
         dec_input_info.extend(
             [
                 (
-                    "
+                    f"self_key_value_states_{i}",
                     [
-                        n_layer * 2,
                         rbln_batch_size,
                         n_head,
-
+                        rbln_dec_max_seq_len,
                         d_kv,
                     ],
                     "float32",
                 )
+                for i in range(n_layer * 2)
             ]
         )
-
-
+        enc_compile_config = RBLNCompileConfig(compiled_model_name="encoder", input_info=enc_input_info)
+        dec_compile_config = RBLNCompileConfig(compiled_model_name="decoder", input_info=dec_input_info)
 
         rbln_config = RBLNConfig(
             rbln_cls=cls.__name__,
-            compile_cfgs=[
+            compile_cfgs=[enc_compile_config, dec_compile_config],
             rbln_kwargs=rbln_kwargs,
         )
 
```
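One practical consequence of the new `input_info` layout above: the decoder's self-attention cache is now declared as `n_layer * 2` separate float32 tensors of shape `[rbln_batch_size, n_head, rbln_dec_max_seq_len, d_kv]`, so its footprint can be read straight off the declared shapes. A worked example with placeholder dimensions (not values from this package):

```python
# Placeholder dimensions; the real values come from the model config at compile time.
n_layer, batch, n_head, dec_max_seq_len, d_kv = 12, 1, 12, 1024, 64

elems = (n_layer * 2) * batch * n_head * dec_max_seq_len * d_kv
print(f"self-KV cache: {elems * 4 / 2**20:.1f} MiB in float32")  # 4 bytes per element
```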
```diff
@@ -270,12 +263,21 @@
 
     @classmethod
     def _create_runtimes(
-        cls,
+        cls,
+        compiled_models: List[rebel.RBLNCompiledModel],
+        rbln_device_map: Dict[str, int],
+        activate_profiler: Optional[bool] = None,
     ) -> List[rebel.Runtime]:
-
+        if any(model_name not in rbln_device_map for model_name in ["encoder", "decoder"]):
+            cls._raise_missing_compiled_file_error(["encoder", "decoder"])
+
         return [
-            compiled_models[0].create_runtime(
-
+            compiled_models[0].create_runtime(
+                tensor_type="pt", device=rbln_device_map["encoder"], activate_profiler=activate_profiler
+            ),
+            compiled_models[1].create_runtime(
+                tensor_type="pt", device=rbln_device_map["decoder"], activate_profiler=activate_profiler
+            ),
         ]
 
     def can_generate(self):
```
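`_create_runtimes` now addresses runtimes by compiled-model name through `rbln_device_map` and validates both entries before creating anything. A hedged sketch of the calling convention the new signature implies (device indices are placeholders):

```python
# The keys must match the compiled_model_name values ("encoder", "decoder")
# declared in the RBLNCompileConfig objects earlier in this diff.
rbln_device_map = {"encoder": 0, "decoder": 0}

missing = [name for name in ("encoder", "decoder") if name not in rbln_device_map]
if missing:
    raise FileNotFoundError(f"Missing compiled model file(s): {missing}")
```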
```diff
@@ -340,57 +342,11 @@
             attention_mask=dec_attention_mask,
             encoder_attention_mask=attention_mask,
             cache_position=cache_position,
-            batch_position=torch.tensor(0, dtype=torch.int32),
         )
-        lm_logits = decoder_output.logits
+        lm_logits = decoder_output.logits
 
         return Seq2SeqLMOutput(logits=lm_logits)
 
-    def vllm_forward(
-        self,
-        input_ids: torch.LongTensor = None,
-        cache_position: Union[List[torch.Tensor], torch.Tensor] = None,  # vllm keyword argument
-        batch_idx: Optional[torch.LongTensor] = None,
-        enc_lengths: List[int] = None,  # vllm return current attention_mask length
-        **kwargs,
-    ) -> Tuple[torch.FloatTensor]:
-        # When using vllm, need the output of the encoder (ex. vocab_size + 100) and use that value act as start_token_id in decoder (ex. vocab_size + 99)
-        # encoder
-        if batch_idx is not None:
-            enc_attention_mask = torch.zeros(1, self.rbln_config.model_cfg["enc_max_seq_len"], dtype=torch.float32)
-            enc_attention_mask[0][: enc_lengths[batch_idx] + 1] = 1
-            padding_need = self.rbln_config.model_cfg["enc_max_seq_len"] - input_ids.shape[-1]
-            input_ids = torch.nn.functional.pad(input_ids, (0, padding_need))
-            _ = self.encoder(input_ids, enc_attention_mask, batch_idx=batch_idx.to(torch.int32))
-            logits = torch.zeros(1, 1, self.config.vocab_size + 100)
-            logits[0][0][-1] = 1
-        # decoder
-        else:
-            input_ids[input_ids == (self.config.vocab_size + 99)] = self.config.decoder_start_token_id
-            cache_position[cache_position != 0] = cache_position[cache_position != 0] - 2
-
-            enc_attention_mask = torch.zeros(
-                self.rbln_config.model_cfg["batch_size"],
-                self.rbln_config.model_cfg["enc_max_seq_len"],
-                dtype=torch.float32,
-            )
-            dec_attention_mask = torch.zeros(
-                self.rbln_config.model_cfg["batch_size"],
-                self.rbln_config.model_cfg["dec_max_seq_len"],
-                dtype=torch.float32,
-            )
-            for batch_idx in range(self.rbln_config.model_cfg["batch_size"]):
-                enc_attention_mask[batch_idx, : enc_lengths[batch_idx] + 1] = 1
-
-            logits = self._forward_decoder(
-                attention_mask=enc_attention_mask,
-                decoder_input_ids=input_ids,
-                decoder_attention_mask=dec_attention_mask,
-                cache_position=cache_position,
-            ).logits
-
-        return Seq2SeqLMOutput(logits=logits)
-
     def _prepare_encoder_decoder_kwargs_for_generation(
         self,
         inputs_tensor: torch.Tensor,
@@ -426,15 +382,14 @@
         )
 
         # 3. make sure that encoder returns `ModelOutput`
-        model_input_name = model_input_name if model_input_name is not None else self.main_input_name
         encoder_kwargs["return_dict"] = True
         encoder_kwargs["output_hidden_states"] = False
         encoder_kwargs["output_attentions"] = False
 
         for b in range(batch_size):
-
+            batch_position = torch.tensor(b, dtype=torch.int16)
             encoder_kwargs["input_ids"] = inputs_tensor[b].unsqueeze(0)
             encoder_kwargs["attention_mask"] = model_kwargs["attention_mask"][b].unsqueeze(0).to(torch.float32)
-            model_kwargs["encoder_outputs"] = encoder(**encoder_kwargs,
+            model_kwargs["encoder_outputs"] = encoder(**encoder_kwargs, batch_position=batch_position)
 
         return model_kwargs
```