optimum-rbln 0.1.4__py3-none-any.whl → 0.1.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- optimum/rbln/__init__.py +21 -0
- optimum/rbln/__version__.py +1 -1
- optimum/rbln/diffusers/__init__.py +0 -1
- optimum/rbln/diffusers/models/autoencoder_kl.py +16 -98
- optimum/rbln/diffusers/models/controlnet.py +3 -0
- optimum/rbln/diffusers/models/unet_2d_condition.py +3 -3
- optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +22 -146
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +109 -53
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +114 -53
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +130 -71
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +131 -72
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +3 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +8 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +8 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +9 -0
- optimum/rbln/modeling_alias.py +14 -0
- optimum/rbln/modeling_base.py +282 -100
- optimum/rbln/modeling_seq2seq.py +58 -132
- optimum/rbln/transformers/__init__.py +8 -0
- optimum/rbln/transformers/cache_utils.py +111 -0
- optimum/rbln/transformers/generation/utils.py +0 -2
- optimum/rbln/transformers/models/__init__.py +3 -0
- optimum/rbln/transformers/models/bart/bart_architecture.py +0 -5
- optimum/rbln/transformers/models/clip/modeling_clip.py +0 -1
- optimum/rbln/transformers/models/decoderonly/__init__.py +36 -0
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +515 -0
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +349 -0
- optimum/rbln/transformers/models/dpt/__init__.py +24 -0
- optimum/rbln/transformers/models/dpt/modeling_dpt.py +89 -0
- optimum/rbln/transformers/models/gemma/__init__.py +24 -0
- optimum/rbln/transformers/models/gemma/gemma_architecture.py +116 -0
- optimum/rbln/transformers/models/gemma/modeling_gemma.py +61 -0
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +200 -174
- optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +57 -293
- optimum/rbln/transformers/models/llama/llama_architecture.py +3 -613
- optimum/rbln/transformers/models/llama/modeling_llama.py +9 -469
- optimum/rbln/transformers/models/midm/hf_hub_cached/midm_bitext_tokenization.py +2 -1
- optimum/rbln/transformers/models/midm/hf_hub_cached/modeling_midm.py +0 -4
- optimum/rbln/transformers/models/midm/midm_architecture.py +160 -357
- optimum/rbln/transformers/models/midm/modeling_midm.py +40 -308
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +53 -123
- optimum/rbln/transformers/models/whisper/whisper_architecture.py +0 -6
- optimum/rbln/transformers/models/xlm_roberta/__init__.py +24 -0
- optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +125 -0
- optimum/rbln/utils/__init__.py +1 -1
- optimum/rbln/utils/import_utils.py +46 -0
- {optimum_rbln-0.1.4.dist-info → optimum_rbln-0.1.8.dist-info}/METADATA +18 -53
- optimum_rbln-0.1.8.dist-info/RECORD +73 -0
- {optimum_rbln-0.1.4.dist-info → optimum_rbln-0.1.8.dist-info}/WHEEL +1 -1
- optimum/rbln/transformers/models/llama/llama_architecture_cb.py +0 -759
- optimum_rbln-0.1.4.dist-info/RECORD +0 -63
- {optimum_rbln-0.1.4.dist-info → optimum_rbln-0.1.8.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/transformers/models/gpt2/gpt2_architecture.py
@@ -21,233 +21,259 @@
 # copied, modified, or distributed without prior written permission
 # from Rebellions Inc.
 
-from typing import Optional, Tuple, Union
+from typing import Dict, Optional, Tuple, Union
 
 import torch
 import torch.nn as nn
-from transformers.modeling_outputs import (
-    BaseModelOutputWithPast,
-    BaseModelOutputWithPastAndCrossAttentions,
-)
-from transformers.models.gpt2.modeling_gpt2 import GPT2Attention, GPT2Block, GPT2Model
+from transformers.modeling_outputs import BaseModelOutputWithPast
 
+from ...cache_utils import RebelDynamicCache_4D
 
-class _GPT2Attention(GPT2Attention):
-    def _attn(self, query, key, value, attention_mask=None, head_mask=None):
-        attn_weights = torch.matmul(query, key.transpose(-1, -2))
-
-        if self.scale_attn_weights:
-            attn_weights = attn_weights / torch.full(
-                [], value.size(-1) ** 0.5, dtype=attn_weights.dtype, device=attn_weights.device
-            )
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+class GPT2LMHeadModelWrapper(torch.nn.Module):
+    def __init__(self, model, max_seq_len):
+        super().__init__()
+        self.model = model.transformer
+        self.lm_head = model.lm_head
+        self.config = model.config
+        self.max_seq_len = max_seq_len
+        self.forward_dict = self.get_forward_dict()
+
+    def get_forward_dict(self):
+        forward_dict = {
+            "wrapper": _GPT2Model.forward,
+            "model": _GPT2Block.forward,
+            "decoder_layer": _GPT2Attention.forward,
+        }
+        return forward_dict
 
-
-
-
+    def forward(
+        self,
+        input_ids,
+        attention_mask,
+        cache_position,
+        batch_position,
+        *past_key_values,
+    ):
+        if input_ids.shape[1] == 1:
+            rbln_batch_position = None
+        else:
+            rbln_batch_position = batch_position
+
+        # Formatting list of past_kv to DynamicCache class.
+        past_key_value = RebelDynamicCache_4D.from_input_format(
+            cache_position,
+            self.config.n_layer,
+            *past_key_values,
+        )
 
-
+        outputs = self.forward_dict["wrapper"](
+            self.model,
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            position_ids=cache_position,
+            past_key_value=past_key_value,
+            batch_ids=rbln_batch_position,
+            forward_dict=self.forward_dict,
+            # rotary_emb differenct from_llama
+        )
 
-
-
-        # attn_weights = self.attn_dropout(attn_weights)
+        hidden_states = outputs[0]
+        logits = self.lm_head(hidden_states)
 
-
-        if head_mask is not None:
-            attn_weights = attn_weights * head_mask
+        output = (logits,) + outputs[1:]
 
-
+        return output, batch_position
 
-        return attn_output, attn_weights
 
+class _GPT2Model:
     def forward(
         self,
-
-
-
-
-
-
-
-
-
-    ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]:
-        query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2)
+        input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_value: Optional[RebelDynamicCache_4D] = None,
+        batch_ids: Optional[torch.LongTensor] = None,
+        forward_dict: Optional[Dict[str, classmethod]] = None,
+    ) -> BaseModelOutputWithPast:
+        b_size, q_len = input_ids.shape
+        inputs_embeds = self.wte(input_ids)
 
-
-
-
+        if position_ids.shape[0] > 1:
+            position_embeds = []
+            for b_idx in range(b_size):
+                position_embed = self.wpe(position_ids[b_idx])
+                # position_embed = position_embed.dtype(inputs_embeds.dtype)
+                position_embeds.append(position_embed)
 
-
-
-
+            position_embeds = torch.cat(position_embeds, dim=0).unsqueeze(1)
+        else:
+            position_embeds = self.wpe(position_ids)
 
-
-        value = torch.slice_scatter(
-            past_value, value, dim=2, start=cache_position, end=cache_position + query_length
-        )
+        hidden_states = inputs_embeds + position_embeds
 
-
-
+        # GPT2Attention mask.
+        # Here we assume mask is causal mask, (batch, 1, query_length, key_length + query_length)
+        attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min
 
-
+        for layer_idx, block in enumerate(self.h):
+            hidden_states, updated_cache = forward_dict["model"](
+                block,
+                hidden_states,
+                layer_idx,
+                attention_mask=attention_mask,
+                past_key_value=past_key_value,
+                position_ids=position_ids,
+                batch_ids=batch_ids,
+                forward_dict=forward_dict,
+            )
 
-
-
+        hidden_states = self.ln_f(hidden_states)
+        output_shape = (-1,) + (q_len,) + (hidden_states.size(-1),)
+        hidden_states = hidden_states.view(output_shape)
 
-
+        # convert RebelDynamicCache to legacy Tuple[Tuple[torch.Tensor]]
+        next_cache = updated_cache.to_legacy_cache()
 
-        return
+        return BaseModelOutputWithPast(
+            last_hidden_state=hidden_states,
+            past_key_values=next_cache,
+        )
 
 
-class _GPT2Block
+class _GPT2Block:
     def forward(
         self,
         hidden_states: Optional[Tuple[torch.FloatTensor]],
-
+        layer_idx: int,
         attention_mask: Optional[torch.FloatTensor] = None,
-
-
-
-
-
-
-    ) -> Union[Tuple[torch.Tensor], Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]]]:
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_value: Optional[RebelDynamicCache_4D] = None,
+        batch_ids: Optional[torch.LongTensor] = None,
+        forward_dict: Optional[Dict[str, classmethod]] = None,
+        **kwargs,
+    ) -> Tuple[torch.Tensor, RebelDynamicCache_4D]:
         residual = hidden_states
         hidden_states = self.ln_1(hidden_states)
 
-
+        hidden_states, k, v = forward_dict["decoder_layer"](
             self.attn,
-            hidden_states,
-            layer_past=layer_past,
+            hidden_states=hidden_states,
             attention_mask=attention_mask,
-
-
+            position_ids=position_ids,
+            past_key_value=past_key_value,
+            batch_index=batch_ids,
         )
+        past_key_value.assign(k, v, layer_idx)
 
-        attn_output = attn_outputs[0]  # output_attn: a, present, (attentions)
-        outputs = attn_outputs[1:]
         # residual connection
-        hidden_states =
+        hidden_states = residual + hidden_states
 
         residual = hidden_states
         hidden_states = self.ln_2(hidden_states)
-
-
-        hidden_states = residual + feed_forward_hidden_states
-
-        outputs = (hidden_states,) + outputs
-        return outputs  # hidden_states, present, (attentions, cross_attentions)
+        hidden_states = self.mlp(hidden_states)
+        hidden_states = residual + hidden_states
 
+        return hidden_states, past_key_value
 
-class _GPT2Model(GPT2Model):
-    def forward(
-        self,
-        input_ids: Optional[torch.LongTensor] = None,
-        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
-        attention_mask: Optional[torch.FloatTensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
-        head_mask: Optional[torch.FloatTensor] = None,
-        cache_position: Optional[torch.LongTensor] = None,
-    ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
-        input_shape = input_ids.size()
 
-
-
-
-        position_ids = position_ids.unsqueeze(0)
-
-        # GPT2Attention mask.
-        # Here we assume mask is causal mask, (batch, 1, query_length, key_length + query_length)
-        attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min
-
-        # Prepare head mask if needed
-        # 1.0 in head_mask indicate we keep the head
-        # attention_probs has shape bsz x n_heads x N x N
-        # head_mask has shape n_layer x batch x n_heads x N x N
-        head_mask = self.get_head_mask(head_mask, self.config.n_layer)
-
-        inputs_embeds = self.wte(input_ids)
-        position_embeds = self.wpe(position_ids)
-        hidden_states = inputs_embeds + position_embeds
-
-        hidden_states = self.drop(hidden_states)
-
-        output_shape = (-1,) + input_shape[1:] + (hidden_states.size(-1),)
+class _GPT2Attention:
+    def _attn(self, query, key, value, attention_mask=None, head_mask=None):
+        attn_weights = torch.matmul(query, key.transpose(-1, -2))
 
-
-
-
-                block,
-                hidden_states,
-                layer_past=layer_past,
-                attention_mask=attention_mask,
-                head_mask=head_mask[i],
-                cache_position=cache_position,
+        if self.scale_attn_weights:
+            attn_weights = attn_weights / torch.full(
+                [], value.size(-1) ** 0.5, dtype=attn_weights.dtype, device=attn_weights.device
             )
-        hidden_states = outputs[0]
 
-
+        # Layer-wise attention scaling
+        if self.scale_attn_by_inverse_layer_idx:
+            attn_weights = attn_weights / float(self.layer_idx + 1)
 
-
-
-
+        # -------------------
+        # Below are deleted since "where" op does not supported on RBLN graph.
+        # -------------------
+        # if not self.is_cross_attention:
+        #     # if only "normal" attention layer implements causal mask
+        #     query_length, key_length = query.size(-2), key.size(-2)
+        #     causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length]
+        #     mask_value = torch.finfo(attn_weights.dtype).min
+        #     # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
+        #     # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
+        #     mask_value = torch.full([], mask_value, dtype=attn_weights.dtype, device=attn_weights.device)
+        #     attn_weights = torch.where(causal_mask, attn_weights.to(attn_weights.dtype), mask_value)
 
+        # Apply the attention mask
+        attn_weights.view(
+            -1,
+        )
+        attn_weights = attn_weights + attention_mask
+        attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+        attn_output = torch.matmul(attn_weights, value)
 
-
-    def __init__(self, gpt):
-        super().__init__()
-        self.model = gpt
+        return attn_output, attn_weights
 
     def forward(
         self,
-
-
-
-
-
-
-
-
-
-        transformer_outputs = _GPT2Model.forward(
-            self.model.transformer,
-            input_ids=input_ids,
-            past_key_values=kv_cache,
-            attention_mask=attention_mask,
-            cache_position=cache_position,
-        )
-
-        hidden_states = transformer_outputs[0]
+        hidden_states: Optional[Tuple[torch.FloatTensor]],
+        attention_mask: Optional[torch.FloatTensor] = None,
+        past_key_value: Optional[RebelDynamicCache_4D] = None,
+        batch_index: Optional[int] = None,
+        **kwargs,
+    ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]:
+        bsz, q_len, _ = hidden_states.size()
+        query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2)
 
-
-
-
-
-        #
+        querys = self._split_heads(query, self.num_heads, self.head_dim)  # (batch, head, seq_length, head_features)
+        keys = self._split_heads(key, self.num_heads, self.head_dim)
+        values = self._split_heads(value, self.num_heads, self.head_dim)
+
+        # Decoder
+        if (batch_index is None or batch_index == -1) and bsz > 1:
+            all_keys = []
+            all_values = []
+            all_attn_output = []
+
+            for b in range(bsz):
+                query = querys[b].unsqueeze(0)
+                attn_mask = attention_mask[b].unsqueeze(0)
+                key = keys[b].unsqueeze(0)
+                value = values[b].unsqueeze(0)
+
+                key, value = past_key_value.update(
+                    key,
+                    value,
+                    self.layer_idx,
+                    b,
+                )
+
+                attn_output, _ = _GPT2Attention._attn(self, query, key, value, attn_mask)
+                attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)
+
+                all_keys.append(key)
+                all_values.append(value)
+                all_attn_output.append(attn_output)
+
+            keys = torch.cat(all_keys, dim=0)
+            values = torch.cat(all_values, dim=0)
+            attn_output = torch.cat(all_attn_output, dim=0)
+
+        # Prefill
+        else:
+            if batch_index is None or batch_index == -1:
+                batch_index = 0
+
+            keys, values = past_key_value.update(
+                keys,
+                values,
+                self.layer_idx,
+                batch_index,
+                read_first_step=True,
+            )
 
-
-
+            attn_output, _ = _GPT2Attention._attn(self, querys, keys, values, attention_mask)
+            attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)
 
-
-        for i in range(self.model.config.n_layer):
-            past_key_values.append(torch.stack(kv_cache[i], dim=0))
-        past_key_values = torch.stack(past_key_values, dim=0)
+        attn_output = self.c_proj(attn_output)
 
-        return
+        return attn_output, keys, values
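
Usage note: a minimal sketch (not part of the diff) of how the GPT2LMHeadModelWrapper added in gpt2_architecture.py wraps a Hugging Face GPT-2 model. The checkpoint name and the max_seq_len value are illustrative assumptions only.

from transformers import GPT2LMHeadModel

from optimum.rbln.transformers.models.gpt2.gpt2_architecture import GPT2LMHeadModelWrapper

# The wrapper keeps the model's transformer, lm_head, and config (see __init__ in the diff).
model = GPT2LMHeadModel.from_pretrained("gpt2")           # illustrative checkpoint
wrapper = GPT2LMHeadModelWrapper(model, max_seq_len=128)  # max_seq_len chosen arbitrarily here

# wrapper.forward(input_ids, attention_mask, cache_position, batch_position, *past_key_values)
# rebuilds the flattened past_key_values into a RebelDynamicCache_4D and returns
# (logits plus cache outputs, batch_position), matching the forward method shown above.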