optimum-rbln 0.1.7__py3-none-any.whl → 0.1.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- optimum/rbln/__init__.py +14 -0
- optimum/rbln/__version__.py +1 -1
- optimum/rbln/diffusers/__init__.py +0 -1
- optimum/rbln/diffusers/models/controlnet.py +3 -0
- optimum/rbln/diffusers/models/unet_2d_condition.py +2 -2
- optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +22 -144
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +107 -59
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +106 -54
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +130 -71
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +131 -72
- optimum/rbln/modeling_alias.py +14 -0
- optimum/rbln/modeling_base.py +110 -0
- optimum/rbln/transformers/__init__.py +6 -0
- optimum/rbln/transformers/cache_utils.py +111 -0
- optimum/rbln/transformers/generation/utils.py +0 -2
- optimum/rbln/transformers/models/__init__.py +2 -0
- optimum/rbln/transformers/models/bart/bart_architecture.py +0 -5
- optimum/rbln/transformers/models/decoderonly/__init__.py +36 -0
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +515 -0
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +349 -0
- optimum/rbln/transformers/models/gemma/__init__.py +24 -0
- optimum/rbln/transformers/models/gemma/gemma_architecture.py +116 -0
- optimum/rbln/transformers/models/gemma/modeling_gemma.py +61 -0
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +201 -166
- optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +56 -220
- optimum/rbln/transformers/models/llama/llama_architecture.py +3 -610
- optimum/rbln/transformers/models/llama/modeling_llama.py +8 -442
- optimum/rbln/transformers/models/midm/hf_hub_cached/midm_bitext_tokenization.py +2 -1
- optimum/rbln/transformers/models/midm/hf_hub_cached/modeling_midm.py +0 -4
- optimum/rbln/transformers/models/midm/midm_architecture.py +160 -357
- optimum/rbln/transformers/models/midm/modeling_midm.py +40 -272
- optimum/rbln/transformers/models/whisper/whisper_architecture.py +0 -6
- optimum/rbln/transformers/models/xlm_roberta/__init__.py +24 -0
- optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +125 -0
- {optimum_rbln-0.1.7.dist-info → optimum_rbln-0.1.8.dist-info}/METADATA +2 -3
- {optimum_rbln-0.1.7.dist-info → optimum_rbln-0.1.8.dist-info}/RECORD +38 -30
- optimum/rbln/transformers/models/llama/llama_architecture_cb.py +0 -764
- {optimum_rbln-0.1.7.dist-info → optimum_rbln-0.1.8.dist-info}/WHEEL +0 -0
- {optimum_rbln-0.1.7.dist-info → optimum_rbln-0.1.8.dist-info}/licenses/LICENSE +0 -0
@@ -21,224 +21,259 @@
 # copied, modified, or distributed without prior written permission
 # from Rebellions Inc.

-from typing import
+from typing import Dict, Optional, Tuple, Union

 import torch
 import torch.nn as nn
-from transformers.modeling_outputs import
-    BaseModelOutputWithPast,
-    BaseModelOutputWithPastAndCrossAttentions,
-)
-from transformers.models.gpt2.modeling_gpt2 import GPT2Attention, GPT2Block, GPT2Model
+from transformers.modeling_outputs import BaseModelOutputWithPast

+from ...cache_utils import RebelDynamicCache_4D

-class _GPT2Attention(GPT2Attention):
-    def _attn(self, query, key, value, attention_mask=None, head_mask=None):
-        attn_weights = torch.matmul(query, key.transpose(-1, -2))
-
-        if self.scale_attn_weights:
-            attn_weights = attn_weights / torch.full(
-                [], value.size(-1) ** 0.5, dtype=attn_weights.dtype, device=attn_weights.device
-            )

-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+class GPT2LMHeadModelWrapper(torch.nn.Module):
+    def __init__(self, model, max_seq_len):
+        super().__init__()
+        self.model = model.transformer
+        self.lm_head = model.lm_head
+        self.config = model.config
+        self.max_seq_len = max_seq_len
+        self.forward_dict = self.get_forward_dict()
+
+    def get_forward_dict(self):
+        forward_dict = {
+            "wrapper": _GPT2Model.forward,
+            "model": _GPT2Block.forward,
+            "decoder_layer": _GPT2Attention.forward,
+        }
+        return forward_dict

-
-
-
+    def forward(
+        self,
+        input_ids,
+        attention_mask,
+        cache_position,
+        batch_position,
+        *past_key_values,
+    ):
+        if input_ids.shape[1] == 1:
+            rbln_batch_position = None
+        else:
+            rbln_batch_position = batch_position
+
+        # Formatting list of past_kv to DynamicCache class.
+        past_key_value = RebelDynamicCache_4D.from_input_format(
+            cache_position,
+            self.config.n_layer,
+            *past_key_values,
+        )

-
+        outputs = self.forward_dict["wrapper"](
+            self.model,
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            position_ids=cache_position,
+            past_key_value=past_key_value,
+            batch_ids=rbln_batch_position,
+            forward_dict=self.forward_dict,
+            # rotary_emb differenct from_llama
+        )

-
-
-        # attn_weights = self.attn_dropout(attn_weights)
+        hidden_states = outputs[0]
+        logits = self.lm_head(hidden_states)

-
-        if head_mask is not None:
-            attn_weights = attn_weights * head_mask
+        output = (logits,) + outputs[1:]

-
+        return output, batch_position

-        return attn_output, attn_weights

+class _GPT2Model:
     def forward(
         self,
-
-
-
-
-
-
-
-
-
-    ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]:
-        query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2)
+        input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_value: Optional[RebelDynamicCache_4D] = None,
+        batch_ids: Optional[torch.LongTensor] = None,
+        forward_dict: Optional[Dict[str, classmethod]] = None,
+    ) -> BaseModelOutputWithPast:
+        b_size, q_len = input_ids.shape
+        inputs_embeds = self.wte(input_ids)

-
-
-
+        if position_ids.shape[0] > 1:
+            position_embeds = []
+            for b_idx in range(b_size):
+                position_embed = self.wpe(position_ids[b_idx])
+                # position_embed = position_embed.dtype(inputs_embeds.dtype)
+                position_embeds.append(position_embed)

-
-
-
+            position_embeds = torch.cat(position_embeds, dim=0).unsqueeze(1)
+        else:
+            position_embeds = self.wpe(position_ids)

-
-        value = past_value.slice_scatter(value, dim=2, start=cache_position, end=cache_position + query_length)
+        hidden_states = inputs_embeds + position_embeds

-
+        # GPT2Attention mask.
+        # Here we assume mask is causal mask, (batch, 1, query_length, key_length + query_length)
+        attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min

-
+        for layer_idx, block in enumerate(self.h):
+            hidden_states, updated_cache = forward_dict["model"](
+                block,
+                hidden_states,
+                layer_idx,
+                attention_mask=attention_mask,
+                past_key_value=past_key_value,
+                position_ids=position_ids,
+                batch_ids=batch_ids,
+                forward_dict=forward_dict,
+            )

-
+        hidden_states = self.ln_f(hidden_states)
+        output_shape = (-1,) + (q_len,) + (hidden_states.size(-1),)
+        hidden_states = hidden_states.view(output_shape)

-
-
+        # convert RebelDynamicCache to legacy Tuple[Tuple[torch.Tensor]]
+        next_cache = updated_cache.to_legacy_cache()

-        return
+        return BaseModelOutputWithPast(
+            last_hidden_state=hidden_states,
+            past_key_values=next_cache,
+        )


-class _GPT2Block
+class _GPT2Block:
     def forward(
         self,
         hidden_states: Optional[Tuple[torch.FloatTensor]],
-
+        layer_idx: int,
         attention_mask: Optional[torch.FloatTensor] = None,
-
-
-
-
-
-
-    ) -> Union[Tuple[torch.Tensor], Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]]]:
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_value: Optional[RebelDynamicCache_4D] = None,
+        batch_ids: Optional[torch.LongTensor] = None,
+        forward_dict: Optional[Dict[str, classmethod]] = None,
+        **kwargs,
+    ) -> Tuple[torch.Tensor, RebelDynamicCache_4D]:
         residual = hidden_states
         hidden_states = self.ln_1(hidden_states)

-
+        hidden_states, k, v = forward_dict["decoder_layer"](
             self.attn,
-            hidden_states,
-            past_key_values=past_key_values,
+            hidden_states=hidden_states,
             attention_mask=attention_mask,
-
-
+            position_ids=position_ids,
+            past_key_value=past_key_value,
+            batch_index=batch_ids,
         )
+        past_key_value.assign(k, v, layer_idx)

         # residual connection
-        hidden_states =
+        hidden_states = residual + hidden_states

         residual = hidden_states
         hidden_states = self.ln_2(hidden_states)
-
-
-        hidden_states = residual + feed_forward_hidden_states
-
-        return hidden_states
-
+        hidden_states = self.mlp(hidden_states)
+        hidden_states = residual + hidden_states

-
-    def forward(
-        self,
-        input_ids: Optional[torch.LongTensor] = None,
-        past_key_values: List[List[torch.Tensor]] = None,
-        attention_mask: Optional[torch.FloatTensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
-        head_mask: Optional[torch.FloatTensor] = None,
-        cache_position: Optional[torch.LongTensor] = None,
-    ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
-        input_shape = input_ids.size()
+        return hidden_states, past_key_value

-        if position_ids is None:
-            # force dtype to torch.long -> torch.int32 (to match cache_position)
-            position_ids = torch.arange(0, input_shape[-1], dtype=torch.int32) + cache_position
-            position_ids = position_ids.unsqueeze(0)
-
-        # GPT2Attention mask.
-        # Here we assume mask is causal mask, (batch, 1, query_length, key_length + query_length)
-        attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min

-
-
-
-        # head_mask has shape n_layer x batch x n_heads x N x N
-        head_mask = self.get_head_mask(head_mask, self.config.n_layer)
-
-        inputs_embeds = self.wte(input_ids)
-        position_embeds = self.wpe(position_ids)
-        hidden_states = inputs_embeds + position_embeds
-
-        hidden_states = self.drop(hidden_states)
-
-        output_shape = (-1,) + input_shape[1:] + (hidden_states.size(-1),)
+class _GPT2Attention:
+    def _attn(self, query, key, value, attention_mask=None, head_mask=None):
+        attn_weights = torch.matmul(query, key.transpose(-1, -2))

-
-
-
-                hidden_states,
-                past_key_values=past_key_values,
-                attention_mask=attention_mask,
-                head_mask=head_mask[i],
-                cache_position=cache_position,
+        if self.scale_attn_weights:
+            attn_weights = attn_weights / torch.full(
+                [], value.size(-1) ** 0.5, dtype=attn_weights.dtype, device=attn_weights.device
             )

-
-
-
+        # Layer-wise attention scaling
+        if self.scale_attn_by_inverse_layer_idx:
+            attn_weights = attn_weights / float(self.layer_idx + 1)

+        # -------------------
+        # Below are deleted since "where" op does not supported on RBLN graph.
+        # -------------------
+        # if not self.is_cross_attention:
+        #     # if only "normal" attention layer implements causal mask
+        #     query_length, key_length = query.size(-2), key.size(-2)
+        #     causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length]
+        #     mask_value = torch.finfo(attn_weights.dtype).min
+        #     # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
+        #     # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
+        #     mask_value = torch.full([], mask_value, dtype=attn_weights.dtype, device=attn_weights.device)
+        #     attn_weights = torch.where(causal_mask, attn_weights.to(attn_weights.dtype), mask_value)

-
-
-
-
+        # Apply the attention mask
+        attn_weights.view(
+            -1,
+        )
+        attn_weights = attn_weights + attention_mask
+        attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+        attn_output = torch.matmul(attn_weights, value)
+
+        return attn_output, attn_weights

     def forward(
         self,
-
-        attention_mask: torch.
-
-
-
-
-
-
-
-        transformer_outputs = _GPT2Model.forward(
-            self.model.transformer,
-            input_ids=input_ids,
-            past_key_values=kv_cache,
-            attention_mask=attention_mask,
-            cache_position=cache_position,
-        )
-
-        hidden_states = transformer_outputs[0]
+        hidden_states: Optional[Tuple[torch.FloatTensor]],
+        attention_mask: Optional[torch.FloatTensor] = None,
+        past_key_value: Optional[RebelDynamicCache_4D] = None,
+        batch_index: Optional[int] = None,
+        **kwargs,
+    ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]:
+        bsz, q_len, _ = hidden_states.size()
+        query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2)

-
-
-
-
-        #
+        querys = self._split_heads(query, self.num_heads, self.head_dim)  # (batch, head, seq_length, head_features)
+        keys = self._split_heads(key, self.num_heads, self.head_dim)
+        values = self._split_heads(value, self.num_heads, self.head_dim)
+
+        # Decoder
+        if (batch_index is None or batch_index == -1) and bsz > 1:
+            all_keys = []
+            all_values = []
+            all_attn_output = []
+
+            for b in range(bsz):
+                query = querys[b].unsqueeze(0)
+                attn_mask = attention_mask[b].unsqueeze(0)
+                key = keys[b].unsqueeze(0)
+                value = values[b].unsqueeze(0)
+
+                key, value = past_key_value.update(
+                    key,
+                    value,
+                    self.layer_idx,
+                    b,
+                )
+
+                attn_output, _ = _GPT2Attention._attn(self, query, key, value, attn_mask)
+                attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)
+
+                all_keys.append(key)
+                all_values.append(value)
+                all_attn_output.append(attn_output)
+
+            keys = torch.cat(all_keys, dim=0)
+            values = torch.cat(all_values, dim=0)
+            attn_output = torch.cat(all_attn_output, dim=0)
+
+        # Prefill
+        else:
+            if batch_index is None or batch_index == -1:
+                batch_index = 0
+
+            keys, values = past_key_value.update(
+                keys,
+                values,
+                self.layer_idx,
+                batch_index,
+                read_first_step=True,
+            )

-
-
+            attn_output, _ = _GPT2Attention._attn(self, querys, keys, values, attention_mask)
+            attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)

-
-        for i in range(self.model.config.n_layer):
-            past_key_values.append(kv_cache[i][0])
-            past_key_values.append(kv_cache[i][1])
+        attn_output = self.c_proj(attn_output)

-
-        return output
+        return attn_output, keys, values