optimum-rbln 0.1.7__py3-none-any.whl → 0.1.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (39)
  1. optimum/rbln/__init__.py +14 -0
  2. optimum/rbln/__version__.py +1 -1
  3. optimum/rbln/diffusers/__init__.py +0 -1
  4. optimum/rbln/diffusers/models/controlnet.py +3 -0
  5. optimum/rbln/diffusers/models/unet_2d_condition.py +2 -2
  6. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +22 -144
  7. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +107 -59
  8. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +106 -54
  9. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +130 -71
  10. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +131 -72
  11. optimum/rbln/modeling_alias.py +14 -0
  12. optimum/rbln/modeling_base.py +110 -0
  13. optimum/rbln/transformers/__init__.py +6 -0
  14. optimum/rbln/transformers/cache_utils.py +111 -0
  15. optimum/rbln/transformers/generation/utils.py +0 -2
  16. optimum/rbln/transformers/models/__init__.py +2 -0
  17. optimum/rbln/transformers/models/bart/bart_architecture.py +0 -5
  18. optimum/rbln/transformers/models/decoderonly/__init__.py +36 -0
  19. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +515 -0
  20. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +349 -0
  21. optimum/rbln/transformers/models/gemma/__init__.py +24 -0
  22. optimum/rbln/transformers/models/gemma/gemma_architecture.py +116 -0
  23. optimum/rbln/transformers/models/gemma/modeling_gemma.py +61 -0
  24. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +201 -166
  25. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +56 -220
  26. optimum/rbln/transformers/models/llama/llama_architecture.py +3 -610
  27. optimum/rbln/transformers/models/llama/modeling_llama.py +8 -442
  28. optimum/rbln/transformers/models/midm/hf_hub_cached/midm_bitext_tokenization.py +2 -1
  29. optimum/rbln/transformers/models/midm/hf_hub_cached/modeling_midm.py +0 -4
  30. optimum/rbln/transformers/models/midm/midm_architecture.py +160 -357
  31. optimum/rbln/transformers/models/midm/modeling_midm.py +40 -272
  32. optimum/rbln/transformers/models/whisper/whisper_architecture.py +0 -6
  33. optimum/rbln/transformers/models/xlm_roberta/__init__.py +24 -0
  34. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +125 -0
  35. {optimum_rbln-0.1.7.dist-info → optimum_rbln-0.1.8.dist-info}/METADATA +2 -3
  36. {optimum_rbln-0.1.7.dist-info → optimum_rbln-0.1.8.dist-info}/RECORD +38 -30
  37. optimum/rbln/transformers/models/llama/llama_architecture_cb.py +0 -764
  38. {optimum_rbln-0.1.7.dist-info → optimum_rbln-0.1.8.dist-info}/WHEEL +0 -0
  39. {optimum_rbln-0.1.7.dist-info → optimum_rbln-0.1.8.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/transformers/models/gpt2/gpt2_architecture.py
@@ -21,224 +21,259 @@
 # copied, modified, or distributed without prior written permission
 # from Rebellions Inc.
 
-from typing import List, Optional, Tuple, Union
+from typing import Dict, Optional, Tuple, Union
 
 import torch
 import torch.nn as nn
-from transformers.modeling_outputs import (
-    BaseModelOutputWithPast,
-    BaseModelOutputWithPastAndCrossAttentions,
-)
-from transformers.models.gpt2.modeling_gpt2 import GPT2Attention, GPT2Block, GPT2Model
+from transformers.modeling_outputs import BaseModelOutputWithPast
 
+from ...cache_utils import RebelDynamicCache_4D
 
-class _GPT2Attention(GPT2Attention):
-    def _attn(self, query, key, value, attention_mask=None, head_mask=None):
-        attn_weights = torch.matmul(query, key.transpose(-1, -2))
-
-        if self.scale_attn_weights:
-            attn_weights = attn_weights / torch.full(
-                [], value.size(-1) ** 0.5, dtype=attn_weights.dtype, device=attn_weights.device
-            )
 
-        # Layer-wise attention scaling
-        if self.scale_attn_by_inverse_layer_idx:
-            attn_weights = attn_weights / float(self.layer_idx + 1)
-
-        # -------------------
-        # Below are deleted since "where" op does not supported on RBLN graph.
-        # -------------------
-        # if not self.is_cross_attention:
-        #     # if only "normal" attention layer implements causal mask
-        #     query_length, key_length = query.size(-2), key.size(-2)
-        #     causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length]
-        #     mask_value = torch.finfo(attn_weights.dtype).min
-        #     # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
-        #     # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
-        #     mask_value = torch.full([], mask_value, dtype=attn_weights.dtype, device=attn_weights.device)
-        #     attn_weights = torch.where(causal_mask, attn_weights.to(attn_weights.dtype), mask_value)
+class GPT2LMHeadModelWrapper(torch.nn.Module):
+    def __init__(self, model, max_seq_len):
+        super().__init__()
+        self.model = model.transformer
+        self.lm_head = model.lm_head
+        self.config = model.config
+        self.max_seq_len = max_seq_len
+        self.forward_dict = self.get_forward_dict()
+
+    def get_forward_dict(self):
+        forward_dict = {
+            "wrapper": _GPT2Model.forward,
+            "model": _GPT2Block.forward,
+            "decoder_layer": _GPT2Attention.forward,
+        }
+        return forward_dict
 
-        if attention_mask is not None:
-            # Apply the attention mask
-            attn_weights = attn_weights + attention_mask
+    def forward(
+        self,
+        input_ids,
+        attention_mask,
+        cache_position,
+        batch_position,
+        *past_key_values,
+    ):
+        if input_ids.shape[1] == 1:
+            rbln_batch_position = None
+        else:
+            rbln_batch_position = batch_position
+
+        # Formatting list of past_kv to DynamicCache class.
+        past_key_value = RebelDynamicCache_4D.from_input_format(
+            cache_position,
+            self.config.n_layer,
+            *past_key_values,
+        )
 
-        attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+        outputs = self.forward_dict["wrapper"](
+            self.model,
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            position_ids=cache_position,
+            past_key_value=past_key_value,
+            batch_ids=rbln_batch_position,
+            forward_dict=self.forward_dict,
+            # rotary_emb differenct from_llama
+        )
 
-        # Downcast (if necessary) back to V's dtype (if in mixed-precision) -- No-Op otherwise
-        attn_weights = attn_weights.type(value.dtype)
-        # attn_weights = self.attn_dropout(attn_weights)
+        hidden_states = outputs[0]
+        logits = self.lm_head(hidden_states)
 
-        # Mask heads if we want to
-        if head_mask is not None:
-            attn_weights = attn_weights * head_mask
+        output = (logits,) + outputs[1:]
 
-        attn_output = torch.matmul(attn_weights, value)
+        return output, batch_position
 
-        return attn_output, attn_weights
 
+class _GPT2Model:
     def forward(
         self,
-        hidden_states: Optional[Tuple[torch.FloatTensor]],
-        past_key_values: List[List[torch.Tensor]] = None,
-        attention_mask: Optional[torch.FloatTensor] = None,
-        head_mask: Optional[torch.FloatTensor] = None,
-        encoder_hidden_states: Optional[torch.Tensor] = None,
-        encoder_attention_mask: Optional[torch.FloatTensor] = None,
-        use_cache: Optional[bool] = False,
-        output_attentions: Optional[bool] = False,
-        cache_position: Optional[torch.LongTensor] = None,
-    ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]:
-        query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2)
+        input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_value: Optional[RebelDynamicCache_4D] = None,
+        batch_ids: Optional[torch.LongTensor] = None,
+        forward_dict: Optional[Dict[str, classmethod]] = None,
+    ) -> BaseModelOutputWithPast:
+        b_size, q_len = input_ids.shape
+        inputs_embeds = self.wte(input_ids)
 
-        query = self._split_heads(query, self.num_heads, self.head_dim)
-        key = self._split_heads(key, self.num_heads, self.head_dim)
-        value = self._split_heads(value, self.num_heads, self.head_dim)
+        if position_ids.shape[0] > 1:
+            position_embeds = []
+            for b_idx in range(b_size):
+                position_embed = self.wpe(position_ids[b_idx])
+                # position_embed = position_embed.dtype(inputs_embeds.dtype)
+                position_embeds.append(position_embed)
 
-        if past_key_values is not None:
-            past_key, past_value = past_key_values[self.layer_idx]
-            query_length = query.shape[-2]
+            position_embeds = torch.cat(position_embeds, dim=0).unsqueeze(1)
+        else:
+            position_embeds = self.wpe(position_ids)
 
-            key = past_key.slice_scatter(key, dim=2, start=cache_position, end=cache_position + query_length)
-            value = past_value.slice_scatter(value, dim=2, start=cache_position, end=cache_position + query_length)
+        hidden_states = inputs_embeds + position_embeds
 
-            past_key_values[self.layer_idx] = [key, value]
+        # GPT2Attention mask.
+        # Here we assume mask is causal mask, (batch, 1, query_length, key_length + query_length)
+        attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min
 
-        attn_output, _ = _GPT2Attention._attn(self, query, key, value, attention_mask, head_mask)
+        for layer_idx, block in enumerate(self.h):
+            hidden_states, updated_cache = forward_dict["model"](
+                block,
+                hidden_states,
+                layer_idx,
+                attention_mask=attention_mask,
+                past_key_value=past_key_value,
+                position_ids=position_ids,
+                batch_ids=batch_ids,
+                forward_dict=forward_dict,
+            )
 
-        attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)
+        hidden_states = self.ln_f(hidden_states)
+        output_shape = (-1,) + (q_len,) + (hidden_states.size(-1),)
+        hidden_states = hidden_states.view(output_shape)
 
-        attn_output = self.c_proj(attn_output)
-        attn_output = self.resid_dropout(attn_output)
+        # convert RebelDynamicCache to legacy Tuple[Tuple[torch.Tensor]]
+        next_cache = updated_cache.to_legacy_cache()
 
-        return attn_output
+        return BaseModelOutputWithPast(
+            last_hidden_state=hidden_states,
+            past_key_values=next_cache,
+        )
 
 
-class _GPT2Block(GPT2Block):
+class _GPT2Block:
     def forward(
         self,
         hidden_states: Optional[Tuple[torch.FloatTensor]],
-        past_key_values: List[List[torch.Tensor]] = None,
+        layer_idx: int,
         attention_mask: Optional[torch.FloatTensor] = None,
-        head_mask: Optional[torch.FloatTensor] = None,
-        encoder_hidden_states: Optional[torch.Tensor] = None,
-        encoder_attention_mask: Optional[torch.FloatTensor] = None,
-        use_cache: Optional[bool] = False,
-        output_attentions: Optional[bool] = False,
-        cache_position: Optional[torch.LongTensor] = None,
-    ) -> Union[Tuple[torch.Tensor], Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]]]:
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_value: Optional[RebelDynamicCache_4D] = None,
+        batch_ids: Optional[torch.LongTensor] = None,
+        forward_dict: Optional[Dict[str, classmethod]] = None,
+        **kwargs,
+    ) -> Tuple[torch.Tensor, RebelDynamicCache_4D]:
         residual = hidden_states
         hidden_states = self.ln_1(hidden_states)
 
-        attn_output = _GPT2Attention.forward(
+        hidden_states, k, v = forward_dict["decoder_layer"](
             self.attn,
-            hidden_states,
-            past_key_values=past_key_values,
+            hidden_states=hidden_states,
             attention_mask=attention_mask,
-            head_mask=head_mask,
-            cache_position=cache_position,
+            position_ids=position_ids,
+            past_key_value=past_key_value,
+            batch_index=batch_ids,
         )
+        past_key_value.assign(k, v, layer_idx)
 
         # residual connection
-        hidden_states = attn_output + residual
+        hidden_states = residual + hidden_states
 
         residual = hidden_states
         hidden_states = self.ln_2(hidden_states)
-        feed_forward_hidden_states = self.mlp(hidden_states)
-        # residual connection
-        hidden_states = residual + feed_forward_hidden_states
-
-        return hidden_states
-
+        hidden_states = self.mlp(hidden_states)
+        hidden_states = residual + hidden_states
 
-class _GPT2Model(GPT2Model):
-    def forward(
-        self,
-        input_ids: Optional[torch.LongTensor] = None,
-        past_key_values: List[List[torch.Tensor]] = None,
-        attention_mask: Optional[torch.FloatTensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
-        head_mask: Optional[torch.FloatTensor] = None,
-        cache_position: Optional[torch.LongTensor] = None,
-    ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
-        input_shape = input_ids.size()
+        return hidden_states, past_key_value
 
-        if position_ids is None:
-            # force dtype to torch.long -> torch.int32 (to match cache_position)
-            position_ids = torch.arange(0, input_shape[-1], dtype=torch.int32) + cache_position
-            position_ids = position_ids.unsqueeze(0)
-
-        # GPT2Attention mask.
-        # Here we assume mask is causal mask, (batch, 1, query_length, key_length + query_length)
-        attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min
 
-        # Prepare head mask if needed
-        # 1.0 in head_mask indicate we keep the head
-        # attention_probs has shape bsz x n_heads x N x N
-        # head_mask has shape n_layer x batch x n_heads x N x N
-        head_mask = self.get_head_mask(head_mask, self.config.n_layer)
-
-        inputs_embeds = self.wte(input_ids)
-        position_embeds = self.wpe(position_ids)
-        hidden_states = inputs_embeds + position_embeds
-
-        hidden_states = self.drop(hidden_states)
-
-        output_shape = (-1,) + input_shape[1:] + (hidden_states.size(-1),)
+class _GPT2Attention:
+    def _attn(self, query, key, value, attention_mask=None, head_mask=None):
+        attn_weights = torch.matmul(query, key.transpose(-1, -2))
 
-        for i, block in enumerate(self.h):
-            hidden_states = _GPT2Block.forward(
-                block,
-                hidden_states,
-                past_key_values=past_key_values,
-                attention_mask=attention_mask,
-                head_mask=head_mask[i],
-                cache_position=cache_position,
+        if self.scale_attn_weights:
+            attn_weights = attn_weights / torch.full(
+                [], value.size(-1) ** 0.5, dtype=attn_weights.dtype, device=attn_weights.device
             )
 
-        hidden_states = self.ln_f(hidden_states)
-        hidden_states = hidden_states.view(output_shape)
-        return BaseModelOutputWithPast(last_hidden_state=hidden_states, past_key_values=past_key_values)
+        # Layer-wise attention scaling
+        if self.scale_attn_by_inverse_layer_idx:
+            attn_weights = attn_weights / float(self.layer_idx + 1)
 
+        # -------------------
+        # Below are deleted since "where" op does not supported on RBLN graph.
+        # -------------------
+        # if not self.is_cross_attention:
+        #     # if only "normal" attention layer implements causal mask
+        #     query_length, key_length = query.size(-2), key.size(-2)
+        #     causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length]
+        #     mask_value = torch.finfo(attn_weights.dtype).min
+        #     # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
+        #     # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
+        #     mask_value = torch.full([], mask_value, dtype=attn_weights.dtype, device=attn_weights.device)
+        #     attn_weights = torch.where(causal_mask, attn_weights.to(attn_weights.dtype), mask_value)
 
-class GPT2LMHeadModelWrapper(torch.nn.Module):
-    def __init__(self, gpt):
-        super().__init__()
-        self.model = gpt
+        # Apply the attention mask
+        attn_weights.view(
+            -1,
+        )
+        attn_weights = attn_weights + attention_mask
+        attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+        attn_output = torch.matmul(attn_weights, value)
+
+        return attn_output, attn_weights
 
     def forward(
         self,
-        input_ids: torch.Tensor,
-        attention_mask: torch.Tensor,
-        cache_position: torch.LongTensor,
-        *past_key_values: torch.Tensor,
-    ):
-        kv_cache = []
-        for i in range(self.model.config.n_layer):
-            kv_cache.append((past_key_values[2 * i], past_key_values[2 * i + 1]))
-
-        transformer_outputs = _GPT2Model.forward(
-            self.model.transformer,
-            input_ids=input_ids,
-            past_key_values=kv_cache,
-            attention_mask=attention_mask,
-            cache_position=cache_position,
-        )
-
-        hidden_states = transformer_outputs[0]
+        hidden_states: Optional[Tuple[torch.FloatTensor]],
+        attention_mask: Optional[torch.FloatTensor] = None,
+        past_key_value: Optional[RebelDynamicCache_4D] = None,
+        batch_index: Optional[int] = None,
+        **kwargs,
+    ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]:
+        bsz, q_len, _ = hidden_states.size()
+        query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2)
 
-        # TODO : Use query_length here to pick last logit
-        # batch_size, sequence_length = hidden_states.shape[:2]
-        # hidden_states = hidden_states.view(batch_size * sequence_length, -1)
-        # hidden_states = torch.nn.functional.embedding(query_length, hidden_states)
-        # hidden_states = hidden_states.view(batch_size, 1, -1)
+        querys = self._split_heads(query, self.num_heads, self.head_dim)  # (batch, head, seq_length, head_features)
+        keys = self._split_heads(key, self.num_heads, self.head_dim)
+        values = self._split_heads(value, self.num_heads, self.head_dim)
+
+        # Decoder
+        if (batch_index is None or batch_index == -1) and bsz > 1:
+            all_keys = []
+            all_values = []
+            all_attn_output = []
+
+            for b in range(bsz):
+                query = querys[b].unsqueeze(0)
+                attn_mask = attention_mask[b].unsqueeze(0)
+                key = keys[b].unsqueeze(0)
+                value = values[b].unsqueeze(0)
+
+                key, value = past_key_value.update(
+                    key,
+                    value,
+                    self.layer_idx,
+                    b,
+                )
+
+                attn_output, _ = _GPT2Attention._attn(self, query, key, value, attn_mask)
+                attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)
+
+                all_keys.append(key)
+                all_values.append(value)
+                all_attn_output.append(attn_output)
+
+            keys = torch.cat(all_keys, dim=0)
+            values = torch.cat(all_values, dim=0)
+            attn_output = torch.cat(all_attn_output, dim=0)
+
+        # Prefill
+        else:
+            if batch_index is None or batch_index == -1:
+                batch_index = 0
+
+            keys, values = past_key_value.update(
+                keys,
+                values,
+                self.layer_idx,
+                batch_index,
+                read_first_step=True,
+            )
 
-        lm_logits = self.model.lm_head(hidden_states)
-        kv_cache = transformer_outputs[1]
+            attn_output, _ = _GPT2Attention._attn(self, querys, keys, values, attention_mask)
+            attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)
 
-        past_key_values = []
-        for i in range(self.model.config.n_layer):
-            past_key_values.append(kv_cache[i][0])
-            past_key_values.append(kv_cache[i][1])
+        attn_output = self.c_proj(attn_output)
 
-        output = (lm_logits,) + tuple(past_key_values)
-        return output
+        return attn_output, keys, values
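For orientation, the rewritten GPT2LMHeadModelWrapper shown in the diff is built directly from a Hugging Face GPT2LMHeadModel plus a compile-time maximum sequence length, and it dispatches execution through forward_dict instead of the stock GPT2 forward methods. The following is a minimal sketch of that construction step; the checkpoint name "gpt2" and the max_seq_len value of 128 are illustrative assumptions, not values taken from this diff.

from transformers import GPT2LMHeadModel

from optimum.rbln.transformers.models.gpt2.gpt2_architecture import GPT2LMHeadModelWrapper

# Illustrative only: any GPT-2 checkpoint and sequence length would do here.
model = GPT2LMHeadModel.from_pretrained("gpt2")
wrapper = GPT2LMHeadModelWrapper(model, max_seq_len=128)

# The wrapper keeps the transformer body and lm_head separately and routes the
# layer stack through forward_dict ("wrapper" -> _GPT2Model, "model" -> _GPT2Block,
# "decoder_layer" -> _GPT2Attention) rather than the original GPT2 classes.
print(type(wrapper.model).__name__)  # GPT2Model
print(sorted(wrapper.forward_dict))  # ['decoder_layer', 'model', 'wrapper']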