optimum-rbln 0.1.13__py3-none-any.whl → 0.1.15__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- optimum/rbln/__init__.py +22 -12
- optimum/rbln/__version__.py +16 -1
- optimum/rbln/diffusers/__init__.py +22 -2
- optimum/rbln/diffusers/models/__init__.py +34 -3
- optimum/rbln/{transformers/generation → diffusers/models/autoencoders}/__init__.py +1 -2
- optimum/rbln/diffusers/models/{autoencoder_kl.py → autoencoders/autoencoder_kl.py} +44 -58
- optimum/rbln/diffusers/models/autoencoders/vae.py +84 -0
- optimum/rbln/diffusers/models/controlnet.py +54 -14
- optimum/rbln/diffusers/models/transformers/__init__.py +24 -0
- optimum/rbln/diffusers/models/transformers/transformer_sd3.py +203 -0
- optimum/rbln/diffusers/models/unets/__init__.py +24 -0
- optimum/rbln/diffusers/models/{unet_2d_condition.py → unets/unet_2d_condition.py} +78 -16
- optimum/rbln/diffusers/pipelines/__init__.py +22 -2
- optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +5 -26
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +1 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +1 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +1 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/__init__.py +1 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +1 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +0 -11
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +32 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/__init__.py +26 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +32 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +32 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +32 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/__init__.py +1 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +14 -6
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +14 -6
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +32 -0
- optimum/rbln/modeling.py +572 -0
- optimum/rbln/modeling_alias.py +1 -1
- optimum/rbln/modeling_base.py +164 -758
- optimum/rbln/modeling_diffusers.py +51 -122
- optimum/rbln/transformers/__init__.py +0 -2
- optimum/rbln/transformers/models/auto/auto_factory.py +117 -23
- optimum/rbln/transformers/models/auto/modeling_auto.py +37 -12
- optimum/rbln/transformers/models/bart/modeling_bart.py +3 -6
- optimum/rbln/transformers/models/bert/modeling_bert.py +3 -6
- optimum/rbln/transformers/models/clip/modeling_clip.py +8 -25
- optimum/rbln/transformers/models/decoderonly/__init__.py +0 -3
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +672 -412
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +38 -155
- optimum/rbln/transformers/models/dpt/modeling_dpt.py +1 -1
- optimum/rbln/transformers/models/exaone/exaone_architecture.py +61 -45
- optimum/rbln/transformers/models/exaone/modeling_exaone.py +4 -2
- optimum/rbln/transformers/models/gemma/gemma_architecture.py +33 -104
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +50 -238
- optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +3 -2
- optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +2 -75
- optimum/rbln/transformers/models/midm/midm_architecture.py +88 -242
- optimum/rbln/transformers/models/midm/modeling_midm.py +6 -6
- optimum/rbln/transformers/models/phi/phi_architecture.py +61 -261
- optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +1 -46
- optimum/rbln/transformers/models/t5/modeling_t5.py +102 -4
- optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +1 -1
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +1 -1
- optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +3 -35
- optimum/rbln/transformers/utils/rbln_quantization.py +120 -3
- optimum/rbln/utils/decorator_utils.py +10 -6
- optimum/rbln/utils/hub.py +131 -0
- optimum/rbln/utils/import_utils.py +15 -1
- optimum/rbln/utils/model_utils.py +53 -0
- optimum/rbln/utils/runtime_utils.py +1 -1
- optimum/rbln/utils/submodule.py +114 -0
- optimum_rbln-0.1.15.dist-info/METADATA +106 -0
- {optimum_rbln-0.1.13.dist-info → optimum_rbln-0.1.15.dist-info}/RECORD +69 -66
- {optimum_rbln-0.1.13.dist-info → optimum_rbln-0.1.15.dist-info}/WHEEL +1 -1
- optimum/rbln/transformers/generation/streamers.py +0 -139
- optimum/rbln/transformers/generation/utils.py +0 -397
- optimum/rbln/transformers/models/exaone/hf_hub_cached/configuration_exaone.py +0 -181
- optimum/rbln/transformers/models/exaone/hf_hub_cached/modeling_exaone.py +0 -1725
- optimum/rbln/transformers/models/midm/hf_hub_cached/configuration_midm.py +0 -22
- optimum/rbln/transformers/models/midm/hf_hub_cached/midm_bitext_tokenization.py +0 -304
- optimum/rbln/transformers/models/midm/hf_hub_cached/modeling_midm.py +0 -1469
- optimum/rbln/transformers/models/midm/hf_hub_cached/rotary_position_embedding.py +0 -98
- optimum/rbln/utils/context.py +0 -58
- optimum_rbln-0.1.13.dist-info/METADATA +0 -120
- optimum_rbln-0.1.13.dist-info/entry_points.txt +0 -4
- {optimum_rbln-0.1.13.dist-info → optimum_rbln-0.1.15.dist-info}/licenses/LICENSE +0 -0
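Taken together, the file list shows two structural moves in 0.1.15: the shared `RBLNModel` base class now lives in a new `optimum/rbln/modeling.py` (with `modeling_base.py` slimmed down accordingly), and the diffusers models are regrouped into `autoencoders/`, `transformers/`, and `unets/` subpackages alongside new Stable Diffusion 3 pipelines. The snippet below is a minimal sketch of what that means for import paths; the deep module paths come from the renames above, while the top-level re-export and the `RBLNAutoencoderKL` class name are assumptions based on the package's usual naming, not something this diff shows.

```python
# Top-level imports are the stable entry point and should be unaffected (assumed).
from optimum.rbln import RBLNAutoencoderKL  # noqa: F401

# Code pinned to deep 0.1.13 module paths would need updating, e.g.:
#   0.1.13: optimum.rbln.diffusers.models.autoencoder_kl
#   0.1.15: optimum.rbln.diffusers.models.autoencoders.autoencoder_kl
# Internal imports of the base class likewise move from optimum.rbln.modeling_base
# to optimum.rbln.modeling, as the per-file hunks below show.
```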
optimum/rbln/transformers/models/phi/phi_architecture.py

@@ -21,302 +21,102 @@
 # copied, modified, or distributed without prior written permission
 # from Rebellions Inc.
 
-import
-from typing import Dict, Optional, Tuple
+from typing import TYPE_CHECKING, Optional, Tuple
 
 import torch
-
-from transformers.modeling_outputs import (
-    BaseModelOutputWithPast,
-)
+from transformers import PhiForCausalLM
 
-from
-
+from ..decoderonly.decoderonly_architecture import (
+    DecoderOnlyAttention,
+    DecoderOnlyForCausalLM,
+    DecoderOnlyLayer,
+    DecoderOnlyModel,
     DecoderOnlyWrapper,
-
-    slice_and_unsqueeze_cos_sin,
+    apply_rotary_pos_emb_partial,
 )
 
 
-
-
-    forward_dict = {}
-    forward_dict.update(
-        {
-            "wrapper": PhiModel.forward,
-            "model": PhiDecoderLayer.forward,
-            "decoder_layer": PhiAttention.forward,
-        }
-    )
-    return forward_dict
-
-
-class PhiAttention:
-    def _attn(self, query_state, key_state, value_state, attn_mask, past_key_value, batch_idx=0, is_prefill=False):
-        # reshape for removing repeat_kv (batch=1 , num_head, 1, q_len=1, head_dim)
-        key_state = key_state.unsqueeze(2)
-        value_state = value_state.unsqueeze(2)
-        attn_mask = attn_mask.unsqueeze(2)
-
-        query_state = query_state.view(
-            1,
-            self.num_key_value_heads,
-            self.num_heads // self.num_key_value_heads,
-            -1,
-            self.head_dim,
-        )
-
-        key_state, value_state = past_key_value.update(key_state, value_state, self.layer_idx, batch_idx, is_prefill)
-
-        # Queries and keys upcast to fp32 is required by Phi-2 to avoid overflow
-        attn_weights = torch.matmul(
-            query_state.to(torch.float32),
-            key_state.to(torch.float32).transpose(3, 4),
-        ) / math.sqrt(self.head_dim)
-        attn_weights = attn_weights + attn_mask
-
-        # upcast attention to fp32
-        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_state.dtype)
-        attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
-        attn_output = torch.matmul(attn_weights, value_state)
-
-        # reshape for removing repeat_kv
-        attn_output = attn_output.view(1, self.num_heads, -1, self.head_dim)
-        attn_output = attn_output.transpose(1, 2).contiguous()
-        attn_output = attn_output.reshape(1, -1, self.num_heads * self.head_dim)
+if TYPE_CHECKING:
+    from transformers import PhiForCausalLM
 
-        return attn_output, key_state, value_state
-
-    def forward(
-        self,
-        hidden_states: torch.Tensor,
-        attention_mask: Optional[torch.Tensor] = None,
-        past_key_value: Optional[RebelDynamicCache] = None,
-        batch_index: Optional[int] = None,
-        output_attentions: bool = False,
-        cos: Optional[torch.Tensor] = None,
-        sin: Optional[torch.Tensor] = None,
-        **kwargs,
-    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-        bsz, q_len, _ = hidden_states.size()
 
+class PhiWrapper(DecoderOnlyWrapper):
+    def convert_to_rbln_causal_lm(self, causal_lm: "PhiForCausalLM"):
+        new_layers = []
+        for layer in causal_lm.model.layers:
+            if self.attn_impl == "eager":
+                new_self_attn = PhiAttention(layer.self_attn)
+            elif self.attn_impl == "flash_attn":
+                raise NotImplementedError(f"flash attn for {self.__class__} is not implemented yet.")
+            else:
+                raise NotImplementedError(f"Unknwon attn : {self.attn_impl}")
+            new_layer = PhiLayer(layer, new_self_attn)
+            new_layers.append(new_layer)
+        new_model = PhiModel(causal_lm.model, new_layers)
+        new_causal_lm = DecoderOnlyForCausalLM(causal_lm, new_model)
+        return new_causal_lm
+
+
+class PhiAttention(DecoderOnlyAttention):
+    def __post_init__(self):
+        self.q_proj = self._original_mod.q_proj
+        self.k_proj = self._original_mod.k_proj
+        self.v_proj = self._original_mod.v_proj
+        self.o_proj = self._original_mod.dense
+        self.qk_layernorm = self._original_mod.qk_layernorm
+        self.rotary_ndims = self._original_mod.rotary_ndims
+        self.num_key_value_heads = self.num_heads
+
+    def projection(self, hidden_states) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         query_states = self.q_proj(hidden_states)
         key_states = self.k_proj(hidden_states)
         value_states = self.v_proj(hidden_states)
 
         if self.qk_layernorm:
-            query_states = self.q_layernorm(query_states)
-            key_states = self.k_layernorm(key_states)
-
-        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
-        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
-        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
-
-        # Partial rotary embedding
-        query_rot, query_pass = (
-            query_states[..., : self.rotary_ndims],
-            query_states[..., self.rotary_ndims :],
-        )
-        key_rot, key_pass = (
-            key_states[..., : self.rotary_ndims],
-            key_states[..., self.rotary_ndims :],
-        )
+            query_states = self._original_mod.q_layernorm(query_states)
+            key_states = self._original_mod.k_layernorm(key_states)
 
-
-        query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin)
+        return query_states, key_states, value_states
 
-
-        query_states
-        key_states = torch.cat((key_rot, key_pass), dim=-1)
+    def apply_rotary_pos_embed(self, query_states, key_states, cos, sin):
+        return apply_rotary_pos_emb_partial(query_states, key_states, cos, sin, ndim=self.rotary_ndims)
 
-        # Decoder (bsz > 1)
-        if bsz > 1:
-            iterate_results = {"key_states": [], "value_states": [], "attn_output": []}
-            for b in range(bsz):
-                attn_output, key_state, value_state = PhiAttention._attn(
-                    self,
-                    query_states[b].unsqueeze(0),
-                    key_states[b].unsqueeze(0),
-                    value_states[b].unsqueeze(0),
-                    attention_mask[b].unsqueeze(0),
-                    past_key_value,
-                    batch_idx=b,
-                    is_prefill=False,
-                )
-                iterate_results["key_states"].append(key_state)
-                iterate_results["value_states"].append(value_state)
-                iterate_results["attn_output"].append(attn_output)
 
-
-
-
-        # Prefill & Decoder (bsz == 1)
-        else:
-            attn_output, key_states, value_states = PhiAttention._attn(
-                self,
-                query_states,
-                key_states,
-                value_states,
-                attention_mask,
-                past_key_value,
-                batch_idx=batch_index,
-                is_prefill=True,
-            )
+class PhiLayer(DecoderOnlyLayer):
+    def get_post_attention_layernorm(self):
+        raise NotImplementedError
 
-        attn_output = self.dense(attn_output)
-
-        if not output_attentions:
-            attn_weights = None
-
-        return attn_output, attn_weights, key_states, value_states
-
-
-class PhiDecoderLayer:
     def forward(
         self,
         hidden_states: torch.Tensor,
-
-
-
-
-        output_attentions: Optional[bool] = None,
-        use_cache: Optional[bool] = None,
-        batch_ids: Optional[torch.LongTensor] = None,
+        attention_mask: torch.Tensor,
+        current_steps: torch.LongTensor,
+        batch_position: torch.Tensor,
+        past_key_values: Tuple[Tuple[torch.Tensor]],
         cos: Optional[torch.Tensor] = None,
         sin: Optional[torch.Tensor] = None,
-
-        **kwargs,
-    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
-        """
-        Args:
-            hidden_states (`torch.FloatTensor`):
-                input to the layer of shape `(batch, seq_len, embed_dim)`
-            attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
-                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
-            position_ids (`torch.LongTensor` of shape `({0})`, *optional*):
-                Indices of positions of each input sequence tokens in the position embeddings. Selected in the range
-                `[0, config.n_positions - 1]`. [What are position IDs?](../glossary#position-ids)
-            output_attentions (`bool`, *optional*):
-                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
-                returned tensors for more detail.
-            use_cache (`bool`, *optional*):
-                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
-                (see `past_key_values`).
-            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
-        """
-
+    ):
         residual = hidden_states
 
-        hidden_states = self.
+        hidden_states = self.get_pre_attention_layernorm()(hidden_states)
 
-
-        attn_outputs, self_attn_weights, key_states, value_states = forward_dict["decoder_layer"](
-            self.self_attn,
+        attn_outputs, present_key_values = self.self_attn(
             hidden_states=hidden_states,
             attention_mask=attention_mask,
-
-
-
-            batch_index=batch_ids,
-            use_cache=use_cache,
+            current_steps=current_steps,
+            batch_position=batch_position,
+            past_key_values=past_key_values,
             cos=cos,
             sin=sin,
-            **kwargs,
         )
-        past_key_value.assign(key_states, value_states, layer_idx)
 
-
+        feed_forward_hidden_states = self._original_mod.mlp(hidden_states)
 
-        feed_forward_hidden_states = self.resid_dropout(self.mlp(hidden_states))
         hidden_states = attn_outputs + feed_forward_hidden_states + residual
-        outputs = (hidden_states,)
-
-        if output_attentions:
-            outputs += (self_attn_weights,)
-
-        if use_cache:
-            outputs += (past_key_value,)
-
-        return outputs
-
-
-class PhiModel:
-    def forward(
-        self,
-        input_ids: torch.LongTensor = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
-        past_key_values: Optional[RebelDynamicCache] = None,
-        batch_ids: Optional[torch.LongTensor] = None,
-        inputs_embeds: Optional[torch.FloatTensor] = None,
-        use_cache: Optional[bool] = True,
-        output_attentions: Optional[bool] = False,
-        output_hidden_states: Optional[bool] = False,
-        cache_pos_for_partitions: Optional[torch.Tensor] = None,
-        kvcache_partition_size: Optional[torch.Tensor] = None,
-        forward_dict: Optional[Dict[str, classmethod]] = None,
-        rotary_pos_emb=None,
-    ) -> BaseModelOutputWithPast:
-        # retrieve input_ids and inputs_embeds
-        if (input_ids is None) ^ (inputs_embeds is not None):
-            raise ValueError(
-                "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
-            )
-
-        # embed positions
-        if inputs_embeds is None:
-            inputs_embeds = self.embed_tokens(input_ids)
-
-        hidden_states = inputs_embeds
-        attention_mask = (1 - attention_mask) * torch.finfo(torch.float16).min
-
-        # get cos,sin vector
-        cos, sin = rotary_pos_emb(inputs_embeds, attention_mask.shape[-1])
-        cos, sin = slice_and_unsqueeze_cos_sin(cos, sin, position_ids)
-
-        # decoder layers
-        all_hidden_states = () if output_hidden_states else None
-        all_self_attns = () if output_attentions else None
-
-        for layer_idx, decoder_layer in enumerate(self.layers):
-            if output_hidden_states:
-                all_hidden_states += (hidden_states,)
-            layer_outputs = forward_dict["model"](
-                decoder_layer,
-                hidden_states,
-                layer_idx,
-                attention_mask=attention_mask,
-                position_ids=position_ids,
-                past_key_value=past_key_values,
-                output_attentions=output_attentions,
-                use_cache=use_cache,
-                batch_ids=batch_ids,
-                cos=cos,
-                sin=sin,
-                cache_pos_for_partitions=cache_pos_for_partitions,
-                kvcache_partition_size=kvcache_partition_size,
-                forward_dict=forward_dict,
-            )
 
-
+        return hidden_states, present_key_values
 
-            updated_cache = layer_outputs[2 if output_attentions else 1]
 
-
-
-
-        hidden_states = self.final_layernorm(hidden_states)
-
-        # add hidden states from the last decoder layer
-        if output_hidden_states:
-            all_hidden_states += (hidden_states,)
-
-        # convert RebelDynamicCache to legacy Tuple[Tuple[torch.Tensor]]
-        next_cache = updated_cache.to_legacy_cache()
-
-        return BaseModelOutputWithPast(
-            last_hidden_state=hidden_states,
-            past_key_values=next_cache,
-            hidden_states=all_hidden_states,
-            attentions=all_self_attns,
-        )
+class PhiModel(DecoderOnlyModel):
+    def get_last_layernorm(self):
+        return self._original_mod.final_layernorm
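The rewritten `phi_architecture.py` drops the copied Hugging Face forward passes and instead subclasses the shared decoder-only building blocks, overriding only what Phi does differently: the `dense` output projection, the optional QK layernorm, and partial rotary embedding over the first `rotary_ndims` channels. The removed 0.1.13 code spelled that partial rotation out by hand; the sketch below shows what the new `apply_rotary_pos_emb_partial` call is expected to compute, reconstructed from that removed code. The helper itself lives in `decoderonly_architecture.py` and is not shown in this diff, so treat the function body (and the `rotate_half` helper name) as an illustration rather than the package's exact implementation.

```python
import torch


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # Standard RoPE helper: swap and negate the two halves of the last dim.
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb_partial_sketch(q, k, cos, sin, ndim):
    # Rotate only the first `ndim` channels and pass the rest through unchanged,
    # mirroring the query_rot/query_pass split in the removed 0.1.13 code.
    q_rot, q_pass = q[..., :ndim], q[..., ndim:]
    k_rot, k_pass = k[..., :ndim], k[..., ndim:]
    q_rot = (q_rot * cos) + (rotate_half(q_rot) * sin)
    k_rot = (k_rot * cos) + (rotate_half(k_rot) * sin)
    return torch.cat((q_rot, q_pass), dim=-1), torch.cat((k_rot, k_pass), dim=-1)
```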
optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py

@@ -31,7 +31,7 @@ import torch  # noqa: F401
 from transformers import AutoModelForSeq2SeqLM, GenerationConfig, PretrainedConfig, PreTrainedModel
 from transformers.modeling_outputs import BaseModelOutput, Seq2SeqLMOutput
 
-from ....
+from ....modeling import RBLNModel
 from ....modeling_config import DEFAULT_COMPILED_MODEL_NAME, RBLNCompileConfig, RBLNConfig
 from ....utils.runtime_utils import RBLNPytorchRuntime
 
@@ -346,51 +346,6 @@ class RBLNModelForSeq2SeqLM(RBLNModel, ABC):
 
         return Seq2SeqLMOutput(logits=lm_logits)
 
-    def vllm_forward(
-        self,
-        input_ids: torch.LongTensor = None,
-        cache_position: Union[List[torch.Tensor], torch.Tensor] = None,  # vllm keyword argument
-        batch_idx: Optional[torch.LongTensor] = None,
-        enc_lengths: List[int] = None,  # vllm return current attention_mask length
-        **kwargs,
-    ) -> Tuple[torch.FloatTensor]:
-        # When using vllm, need the output of the encoder (ex. vocab_size + 100) and use that value act as start_token_id in decoder (ex. vocab_size + 99)
-        # encoder
-        if batch_idx is not None:
-            enc_attention_mask = torch.zeros(1, self.rbln_config.model_cfg["enc_max_seq_len"], dtype=torch.float32)
-            enc_attention_mask[0][: enc_lengths[batch_idx] + 1] = 1
-            padding_need = self.rbln_config.model_cfg["enc_max_seq_len"] - input_ids.shape[-1]
-            input_ids = torch.nn.functional.pad(input_ids, (0, padding_need))
-            _ = self.encoder(input_ids, enc_attention_mask, batch_idx=batch_idx.to(torch.int32))
-            logits = torch.zeros(1, 1, self.config.vocab_size + 100)
-            logits[0][0][-1] = 1
-        # decoder
-        else:
-            input_ids[input_ids == (self.config.vocab_size + 99)] = self.config.decoder_start_token_id
-            cache_position[cache_position != 0] = cache_position[cache_position != 0] - 2
-
-            enc_attention_mask = torch.zeros(
-                self.rbln_config.model_cfg["batch_size"],
-                self.rbln_config.model_cfg["enc_max_seq_len"],
-                dtype=torch.float32,
-            )
-            dec_attention_mask = torch.zeros(
-                self.rbln_config.model_cfg["batch_size"],
-                self.rbln_config.model_cfg["dec_max_seq_len"],
-                dtype=torch.float32,
-            )
-            for batch_idx in range(self.rbln_config.model_cfg["batch_size"]):
-                enc_attention_mask[batch_idx, : enc_lengths[batch_idx] + 1] = 1
-
-            logits = self._forward_decoder(
-                attention_mask=enc_attention_mask,
-                decoder_input_ids=input_ids,
-                decoder_attention_mask=dec_attention_mask,
-                cache_position=cache_position,
-            ).logits
-
-        return Seq2SeqLMOutput(logits=logits)
-
     def _prepare_encoder_decoder_kwargs_for_generation(
         self,
         inputs_tensor: torch.Tensor,
optimum/rbln/transformers/models/t5/modeling_t5.py

@@ -22,17 +22,23 @@
 # from Rebellions Inc.
 
 import inspect
-from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Union
+from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple, Union
 
+import torch
+import transformers
 from transformers import (
     AutoModelForTextEncoding,
     PretrainedConfig,
+    T5EncoderModel,
     T5ForConditionalGeneration,
 )
+from transformers.modeling_outputs import BaseModelOutput
 
-from ....
+from ....modeling import RBLNModel
 from ....modeling_config import RBLNCompileConfig, RBLNConfig
+from ....modeling_diffusers import RBLNDiffusionMixin
 from ....utils.logging import get_logger
+from ....utils.runtime_utils import RBLNPytorchRuntime
 from ...models.seq2seq import RBLNModelForSeq2SeqLM
 from .t5_architecture import T5Wrapper
 
@@ -43,8 +49,60 @@ if TYPE_CHECKING:
     from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer, PreTrainedModel
 
 
+class RBLNRuntimeModel(RBLNPytorchRuntime):
+    def forward(
+        self,
+        input_ids: torch.LongTensor,
+        attention_mask: torch.FloatTensor,
+        head_mask: torch.FloatTensor,
+        inputs_embeds: torch.FloatTensor,
+        **kwargs,
+    ):
+        return super().forward(
+            input_ids,
+            attention_mask,
+            head_mask,
+            inputs_embeds,
+            **kwargs,
+        )
+
+
+class T5EncoderWrapper(torch.nn.Module):
+    def __init__(self, model: "T5EncoderModel") -> None:
+        super().__init__()
+        self.model = model
+
+    def forward(self, *args, **kwargs):
+        kwargs.pop("return_dict", None)
+        return self.model(*args, **kwargs, return_dict=False)
+
+
 class RBLNT5EncoderModel(RBLNModel):
     auto_model_class = AutoModelForTextEncoding
+    rbln_model_input_names = ["input_ids", "attention_mask"]
+
+    def __post_init__(self, **kwargs):
+        self.model = RBLNRuntimeModel(runtime=self.model[0])
+
+    @classmethod
+    def wrap_model_if_needed(self, model: "PreTrainedModel", rbln_config: "RBLNConfig"):
+        return T5EncoderWrapper(model)
+
+    @classmethod
+    def update_rbln_config_using_pipe(cls, pipe: RBLNDiffusionMixin, rbln_config: Dict[str, Any]) -> Dict[str, Any]:
+        batch_size = rbln_config.get("batch_size", 1)
+        max_sequence_length = rbln_config.get("max_sequence_length", 256)
+        model_input_names = ["input_ids"]
+
+        rbln_config.update(
+            {
+                "batch_size": batch_size,
+                "max_seq_len": max_sequence_length,
+                "model_input_names": model_input_names,
+            }
+        )
+
+        return rbln_config
 
     @classmethod
     def _get_rbln_config(
@@ -54,6 +112,7 @@ class RBLNT5EncoderModel(RBLNModel):
         rbln_kwargs: Dict[str, Any] = {},
     ) -> RBLNConfig:
         rbln_max_seq_len = rbln_kwargs.get("max_seq_len", None)
+        rbln_model_input_names = rbln_kwargs.get("model_input_names", None)
         rbln_batch_size = rbln_kwargs.get("batch_size", None)
 
         max_position_embeddings = getattr(model_config, "n_positions", None)
@@ -71,12 +130,27 @@ class RBLNT5EncoderModel(RBLNModel):
         if max_position_embeddings is not None and rbln_max_seq_len > max_position_embeddings:
             raise ValueError("`rbln_max_seq_len` should be less or equal than max_position_embeddings!")
 
+        if rbln_model_input_names is None:
+            for tokenizer in preprocessors:
+                if hasattr(tokenizer, "model_input_names"):
+                    rbln_model_input_names = tokenizer.model_input_names
+                    break
+            if rbln_model_input_names is None and hasattr(cls, "rbln_model_input_names"):
+                rbln_model_input_names = cls.rbln_model_input_names
+            elif rbln_model_input_names is None and hasattr(cls, "rbln_model_input_names") is False:
+                original_model_class = getattr(transformers, model_config.architectures[0])
+                input_names_order = inspect.signature(original_model_class.forward).parameters.keys()
+                raise ValueError(
+                    "Specify the model input names obtained by the tokenizer via `rbln_model_input_names`, "
+                    f"and be sure to make the order of the inputs same as T5EncoderModel forward() arguments like ({list(input_names_order)})"
+                )
+
         if rbln_batch_size is None:
             rbln_batch_size = 1
 
         input_info = [
-            (
-
+            (model_input_name, [rbln_batch_size, rbln_max_seq_len], "int64")
+            for model_input_name in rbln_model_input_names
         ]
 
         rbln_compile_config = RBLNCompileConfig(input_info=input_info)
@@ -90,6 +164,30 @@ class RBLNT5EncoderModel(RBLNModel):
         rbln_config.model_cfg.update({"max_seq_len": rbln_max_seq_len})
         return rbln_config
 
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple[torch.FloatTensor], BaseModelOutput]:
+        encoder_outputs = self.model(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            inputs_embeds=inputs_embeds,
+            head_mask=head_mask,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+        if not return_dict:
+            return (encoder_outputs,)
+        else:
+            return BaseModelOutput(last_hidden_state=encoder_outputs)
+
 
 class RBLNT5ForConditionalGeneration(RBLNModelForSeq2SeqLM):
     @classmethod
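The new `RBLNT5EncoderModel` config path resolves `rbln_model_input_names` from the caller's kwargs, then from the tokenizer's `model_input_names`, then from the class default, and builds one static int64 input per name for compilation. Below is a small sketch of the resulting `input_info`, with illustrative batch size and sequence length rather than package defaults:

```python
# Illustrative values; the hunk above shows the construction, not these numbers.
rbln_model_input_names = ["input_ids", "attention_mask"]
rbln_batch_size, rbln_max_seq_len = 1, 512

input_info = [
    (name, [rbln_batch_size, rbln_max_seq_len], "int64")
    for name in rbln_model_input_names
]
# -> [("input_ids", [1, 512], "int64"), ("attention_mask", [1, 512], "int64")]
```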
optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py

@@ -28,7 +28,7 @@ import torch
 from transformers import AutoModelForMaskedLM, PretrainedConfig, Wav2Vec2ForCTC
 from transformers.modeling_outputs import CausalLMOutput
 
-from ....
+from ....modeling import RBLNModel
 from ....modeling_config import RBLNCompileConfig, RBLNConfig
 
 
optimum/rbln/transformers/models/whisper/modeling_whisper.py

@@ -36,7 +36,7 @@ from transformers import (
 )
 from transformers.modeling_outputs import BaseModelOutput, Seq2SeqLMOutput
 
-from ....
+from ....modeling import RBLNModel
 from ....modeling_config import DEFAULT_COMPILED_MODEL_NAME, RBLNCompileConfig, RBLNConfig
 from ....utils.runtime_utils import RBLNPytorchRuntime
 from .generation_whisper import RBLNWhisperGenerationMixin
optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py

@@ -22,12 +22,12 @@
 # from Rebellions Inc.
 
 import logging
-from typing import TYPE_CHECKING,
+from typing import TYPE_CHECKING, Optional, Union
 
 import torch
-from transformers import PretrainedConfig
+from transformers import PretrainedConfig
 
-from ....
+from ....modeling import RBLNModel
 from ....modeling_config import RBLNCompileConfig, RBLNConfig
 
 
@@ -38,38 +38,6 @@ if TYPE_CHECKING:
 
 
 class RBLNXLMRobertaModel(RBLNModel):
-    original_model_class = XLMRobertaModel
-    original_config_class = XLMRobertaConfig
-
-    @classmethod
-    def get_pytorch_model(
-        cls,
-        model_id: str,
-        use_auth_token: Optional[Union[bool, str]] = None,
-        revision: Optional[str] = None,
-        force_download: bool = False,
-        cache_dir: Optional[str] = None,
-        subfolder: str = "",
-        local_files_only: bool = False,
-        trust_remote_code: bool = False,
-        rbln_kwargs: Optional[Dict[str, Any]] = None,
-        **kwargs,
-    ) -> "PreTrainedModel":
-        model: "PreTrainedModel" = super().get_pytorch_model(
-            model_id=model_id,
-            use_auth_token=use_auth_token,
-            revision=revision,
-            force_download=force_download,
-            cache_dir=cache_dir,
-            subfolder=subfolder,
-            local_files_only=local_files_only,
-            trust_remote_code=trust_remote_code,
-            rbln_kwargs=rbln_kwargs,
-            library_name="transformers",
-        )
-
-        return model
-
     @classmethod
     def _get_rbln_config(
         cls,