optimum-rbln 0.1.7__py3-none-any.whl → 0.1.8__py3-none-any.whl

This diff compares two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
Files changed (39)
  1. optimum/rbln/__init__.py +14 -0
  2. optimum/rbln/__version__.py +1 -1
  3. optimum/rbln/diffusers/__init__.py +0 -1
  4. optimum/rbln/diffusers/models/controlnet.py +3 -0
  5. optimum/rbln/diffusers/models/unet_2d_condition.py +2 -2
  6. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +22 -144
  7. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +107 -59
  8. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +106 -54
  9. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +130 -71
  10. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +131 -72
  11. optimum/rbln/modeling_alias.py +14 -0
  12. optimum/rbln/modeling_base.py +110 -0
  13. optimum/rbln/transformers/__init__.py +6 -0
  14. optimum/rbln/transformers/cache_utils.py +111 -0
  15. optimum/rbln/transformers/generation/utils.py +0 -2
  16. optimum/rbln/transformers/models/__init__.py +2 -0
  17. optimum/rbln/transformers/models/bart/bart_architecture.py +0 -5
  18. optimum/rbln/transformers/models/decoderonly/__init__.py +36 -0
  19. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +515 -0
  20. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +349 -0
  21. optimum/rbln/transformers/models/gemma/__init__.py +24 -0
  22. optimum/rbln/transformers/models/gemma/gemma_architecture.py +116 -0
  23. optimum/rbln/transformers/models/gemma/modeling_gemma.py +61 -0
  24. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +201 -166
  25. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +56 -220
  26. optimum/rbln/transformers/models/llama/llama_architecture.py +3 -610
  27. optimum/rbln/transformers/models/llama/modeling_llama.py +8 -442
  28. optimum/rbln/transformers/models/midm/hf_hub_cached/midm_bitext_tokenization.py +2 -1
  29. optimum/rbln/transformers/models/midm/hf_hub_cached/modeling_midm.py +0 -4
  30. optimum/rbln/transformers/models/midm/midm_architecture.py +160 -357
  31. optimum/rbln/transformers/models/midm/modeling_midm.py +40 -272
  32. optimum/rbln/transformers/models/whisper/whisper_architecture.py +0 -6
  33. optimum/rbln/transformers/models/xlm_roberta/__init__.py +24 -0
  34. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +125 -0
  35. {optimum_rbln-0.1.7.dist-info → optimum_rbln-0.1.8.dist-info}/METADATA +2 -3
  36. {optimum_rbln-0.1.7.dist-info → optimum_rbln-0.1.8.dist-info}/RECORD +38 -30
  37. optimum/rbln/transformers/models/llama/llama_architecture_cb.py +0 -764
  38. {optimum_rbln-0.1.7.dist-info → optimum_rbln-0.1.8.dist-info}/WHEEL +0 -0
  39. {optimum_rbln-0.1.7.dist-info → optimum_rbln-0.1.8.dist-info}/licenses/LICENSE +0 -0
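
The headline change in 0.1.8 is a shared decoder-only stack: the new cache_utils.py and decoderonly/ modules listed above are what GPT-2, Llama, Gemma, and Midm now build on, and the midm_architecture.py diff below drops the file-local RebelDynamicCache in favor of them. As a rough, self-contained sketch of the in-place KV-cache write that this cache family relies on (the function, tensor names, and shapes here are illustrative only, not the package's API; the pattern mirrors the slice_scatter update of the removed RebelDynamicCache shown below):

import torch

def write_kv_cache(cache: torch.Tensor, new_states: torch.Tensor, step: int) -> torch.Tensor:
    """Write new key or value states into a preallocated cache at position `step`.

    cache:      [batch, num_heads, max_seq_len, head_dim], allocated once up front
    new_states: [batch, num_heads, new_len, head_dim]
    """
    new_len = new_states.shape[2]
    # slice_scatter writes new_states into cache[:, :, step:step+new_len, :]
    # without changing the cache's shape, so every decode step sees static shapes.
    return cache.slice_scatter(new_states, dim=2, start=step, end=step + new_len)

# Toy usage: batch=1, heads=4, max_seq_len=16, head_dim=8
cache = torch.zeros(1, 4, 16, 8)
prefill = torch.randn(1, 4, 5, 8)      # 5 prompt tokens
cache = write_kv_cache(cache, prefill, step=0)
next_token = torch.randn(1, 4, 1, 8)   # one decoded token
cache = write_kv_cache(cache, next_token, step=5)

Keeping the cache at a fixed maximum length and writing into it in place is presumably what keeps the traced graphs static enough to compile ahead of time.
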
optimum/rbln/transformers/models/midm/midm_architecture.py

@@ -21,155 +21,108 @@
 # copied, modified, or distributed without prior written permission
 # from Rebellions Inc.
 
-
-from typing import Dict, Optional, Tuple, Union
+from typing import Optional, Tuple, Union
 
 import torch
 import torch.nn as nn
-from transformers.cache_utils import Cache, DynamicCache
-from transformers.modeling_outputs import (
-    BaseModelOutputWithPastAndCrossAttentions,
-)
+from transformers.modeling_outputs import BaseModelOutputWithPast
 
-from .hf_hub_cached.modeling_midm import (
-    MidmAttention,
-    MidmBlock,
-    MidmModel,
+from ....transformers.models.decoderonly.decoderonly_architecture import (
+    RotaryEmbedding,
+    rotate_half,
+    slice_and_unsqueeze_cos_sin,
 )
+from ...cache_utils import RebelDynamicCache_4D
 
 
-class _MidmRotaryEmbedding(nn.Module):
-    """
-    Implements Rotary Position Embedding from https://arxiv.org/abs/2104.09864.
-    """
-
-    def __init__(
-        self, dim: int, seq_len_interpolation_factor: int = None, pretrained_max_position_embeddings: int = None
-    ):
-        """
-        Args:
-
-            dim (int): rotary embedding dimension
-            seq_len_interpolation_factor (int): if not None, discrete positions will be interpolated
-                by this factor via the trick in https://arxiv.org/abs/2306.15595.
-            pretrained_max_position_embeddings (int): pre-trained max_position_embeddings before position interpolation.
-        """
-        super().__init__()
-        self.seq_len_interpolation_factor = seq_len_interpolation_factor
-        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
-        self.register_buffer("inv_freq", inv_freq, persistent=False)
-        self.pretrained_max_position_embeddings = pretrained_max_position_embeddings
-
-        seq_len = pretrained_max_position_embeddings
-        device = self.inv_freq.device
-        dtype = torch.get_default_dtype()
-        self.max_seq_len_cached = seq_len
-        t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
-
-        freqs = torch.outer(t, self.inv_freq)
-
-        emb = torch.cat((freqs, freqs), dim=-1)
-        self.register_buffer("emb_cached", emb.to(dtype), persistent=False)
-
-    def forward(self, max_seq_len, offset=0):
-
-        if max_seq_len > self.max_seq_len_cached:
-            self._set_emb_cache(seq_len=max_seq_len)
-
-        return self.emb_cached[:max_seq_len]
+def apply_rotary_to_tensor(tensor, cos, sin, rot_dim):
+    """Applies rotary position embedding to the specified dimension of the tensor."""
+    tensor_, tensor_pass = tensor[..., :rot_dim], tensor[..., rot_dim:]
+    tensor_embed = (tensor_ * cos) + (rotate_half(tensor_) * sin)
+    return torch.cat((tensor_embed, tensor_pass), dim=-1)
 
 
-def _rotate_half(x):
-    """
-    change sign so the last dimension
-    [A, B, C, D] -> [-C, -D, A, B]
-    """
-    x1 = x[..., : x.shape[-1] // 2]
-    x2 = x[..., x.shape[-1] // 2 :]
-    return torch.cat((-x2, x1), dim=-1)
-
-
-def apply_rotary_pos_emb(t: torch.Tensor, cache_kwargs: Optional[Dict[str, torch.Tensor]] = None) -> torch.Tensor:
-    """
-    input tensor t is of shape [seq_length, ..., dim]
-    rotary positional embeding tensor freqs is of shape [seq_length, ..., dim]
-    check https://kexue.fm/archives/8265 for detailed formulas
-    """
-
-    freqs = cache_kwargs["rotary_pos_emb"]
-    position_ids = cache_kwargs["position_ids"]
-    unsqueeze_dim = 1
-
-    rot_dim = freqs.shape[-1]
-
-    t, t_pass = t[..., :rot_dim], t[..., rot_dim:]
-    cos = freqs.cos()[position_ids].unsqueeze(unsqueeze_dim)
-    sin = freqs.sin()[position_ids].unsqueeze(unsqueeze_dim)
-
-    embed = (t * cos) + (_rotate_half(t) * sin)
-    embed = torch.cat((embed, t_pass), dim=-1)
-
-    return embed
+def apply_rotary_pos_emb(q, k, cos, sin):
+    """Applies Rotary Position Embedding to the query and key tensors."""
+    rot_dim = cos.shape[-1]
+    q_embed = apply_rotary_to_tensor(q, cos, sin, rot_dim)
+    k_embed = apply_rotary_to_tensor(k, cos, sin, rot_dim)
+    return q_embed, k_embed
 
 
 class MidmLMHeadModelWrapper(torch.nn.Module):
-    def __init__(self, model):
+    """A wrapper class for the Midm model with a language modeling head."""
+
+    def __init__(self, model, max_seq_len):
         super().__init__()
-        self.model = model
-        self.confg = model.config
-
-        self.use_rotary_position_embedding = model.config.use_rotary_position_embedding
-        if self.use_rotary_position_embedding:
-            rotary_dim = model.config.hidden_size // model.config.num_attention_heads
-            assert 0 < model.config.rotary_percentage <= 1
-            if model.config.rotary_percentage < 1:
-                rotary_dim = int(rotary_dim * model.config.rotary_percentage)
-            self._rotary_pos_emb = _MidmRotaryEmbedding(
-                rotary_dim,
-                seq_len_interpolation_factor=None,
-                pretrained_max_position_embeddings=model.config.max_position_embeddings,
-            )
+        self.model = model.transformer
+        self.lm_head = model.lm_head
+        self.config = model.config
+        self.head_dim = self.config.n_embd // self.config.n_head
+        self.max_position_embeddings = (
+            self.config.max_position_embeddings if max_seq_len > self.config.max_position_embeddings else max_seq_len
+        )
+        self.max_seq_len = max_seq_len
+        self.rotary_dim = int(
+            model.config.hidden_size // model.config.num_attention_heads * model.config.rotary_percentage
+        )
+        self.rotary_emb = self._init_rope()
+
+    def _init_rope(self):
+        """Initializes the Rotary Position Embeddings."""
+        rotary_emb = RotaryEmbedding(
+            self.rotary_dim,
+            max_position_embeddings=self.max_position_embeddings,
+        )
+        return rotary_emb
 
     def forward(
         self,
         input_ids: torch.Tensor,
         attention_mask: torch.Tensor,
         cache_position: torch.LongTensor,
+        batch_position: int,
         *past_key_values,
     ):
-        past_kv_list = []
-        for i in range(self.model.config.n_layer):
-            cur_kv_layer = []
-            for j in range(2):
-                cur_kv_layer.append(past_key_values[2 * i + j])
-            past_kv_list.append(cur_kv_layer)
-
-        transformer_outputs = _MidmModel.forward(
-            self.model.transformer,
+        """Defines the forward pass for the wrapper model."""
+        if input_ids.shape[1] == 1:
+            rbln_batch_position = None
+        else:
+            rbln_batch_position = batch_position
+
+        past_key_values = RebelDynamicCache_4D.from_input_format(
+            cache_position,
+            self.config.num_hidden_layers,
+            *past_key_values,
+        )
+
+        outputs = _MidmModel.forward(
+            self.model,
             input_ids=input_ids,
-            past_key_values=past_kv_list,
+            past_key_values=past_key_values,
             attention_mask=attention_mask,
             position_ids=cache_position,
-            rotary_pos_emb=self._rotary_pos_emb,
+            rotary_pos_emb=self.rotary_emb,
+            batch_ids=rbln_batch_position,
         )
 
-        hidden_states = transformer_outputs[0]
+        hidden_states = outputs[0]
+        logits = self.lm_head(hidden_states)
+        output = (logits,) + outputs[1:]
 
-        # For the input_ids, we assume right-alignment.
-        # This assumption allows us to bypass dynamic indexing.
-        hidden_states = hidden_states[:, -1:]
-        lm_logits = self.model.lm_head(hidden_states)
-        kv_cache = transformer_outputs[1]
-
-        return lm_logits, kv_cache
+        return output, batch_position
 
 
 def layernorm1p(module, input):
+    """Applies Layer Normalization with a slight modification on the weights."""
     return torch.nn.functional.layer_norm(input, module.normalized_shape, module.weight + 1, module.bias, module.eps)
 
 
-class _MidmAttention(MidmAttention):
+class _MidmAttention:
+    """Custom implementation of the MidmAttention class with specific modifications."""
+
     def _attn(self, query, key, value, attention_mask=None, head_mask=None):
+        """Computes the attention weights and output."""
         attn_weights = torch.matmul(query, key.transpose(-1, -2))
 
         if self.scale_attn_weights:
@@ -187,320 +140,170 @@ class _MidmAttention(MidmAttention):
             attn_weights = attn_weights * float(self.layer_idx + 1)
 
         attn_weights = nn.functional.softmax(attn_weights, dim=-1)
-
         attn_weights = attn_weights.type(value.dtype)
 
         if head_mask is not None:
             attn_weights = attn_weights * head_mask
 
         attn_output = torch.matmul(attn_weights, value)
-
         return attn_output, attn_weights
 
     def forward(
         self,
         hidden_states: Optional[Tuple[torch.FloatTensor]],
         attention_mask: Optional[torch.FloatTensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
-        past_key_value: Optional[Cache] = None,
-        head_mask: Optional[torch.FloatTensor] = None,
-        use_cache: Optional[bool] = False,
-        rotary_pos_emb=None,
+        past_key_value: Optional[RebelDynamicCache_4D] = None,
+        batch_index: Optional[int] = None,
+        cos: Optional[torch.Tensor] = None,
+        sin: Optional[torch.Tensor] = None,
     ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]:
-        query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2)
-        query = self._split_heads(query, self.num_heads, self.head_dim)
-        key = self._split_heads(key, self.num_heads, self.head_dim)
-        value = self._split_heads(value, self.num_heads, self.head_dim)
-
-        kv_seq_len = key.shape[-2]
-        if past_key_value is not None:
-            if self.layer_idx is None:
-                raise ValueError(
-                    f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
-                    "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
-                    "with a layer index."
+        """Defines the forward pass for the attention mechanism."""
+        bsz, q_len, _ = hidden_states.size()
+
+        querys, keys, values = self.c_attn(hidden_states).split(self.split_size, dim=2)
+
+        querys = self._split_heads(querys, self.num_heads, self.head_dim).contiguous()
+        keys = self._split_heads(keys, self.num_heads, self.head_dim).contiguous()
+        values = self._split_heads(values, self.num_heads, self.head_dim).contiguous()
+
+        querys, keys = apply_rotary_pos_emb(querys, keys, cos, sin)
+
+        # Decoder
+        if (batch_index is None or batch_index == -1) and bsz > 1:
+            all_key_states = []
+            all_value_states = []
+            all_attn_output = []
+
+            for b in range(bsz):
+                query = querys[b].unsqueeze(0)
+                attn_mask = attention_mask[b].unsqueeze(0)
+                key = keys[b].unsqueeze(0)
+                value = values[b].unsqueeze(0)
+
+                key, value = past_key_value.update(
+                    key,
+                    value,
+                    self.layer_idx,
+                    b,
                 )
-            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
 
-        if use_cache is True:
-            present = (key, value)
-        else:
-            present = None
+                attn_output, _ = _MidmAttention._attn(self, query, key, value, attn_mask)
+                attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)
 
-        if rotary_pos_emb is not None:
-            query = apply_rotary_pos_emb(query, {"rotary_pos_emb": rotary_pos_emb, "position_ids": position_ids})
-            key = apply_rotary_pos_emb(key, {"rotary_pos_emb": rotary_pos_emb, "position_ids": position_ids})
+                all_key_states.append(key)
+                all_value_states.append(value)
+                all_attn_output.append(attn_output)
 
-        if past_key_value is not None:
-            key, value = past_key_value.update(key, value, self.layer_idx)
+            keys = torch.cat(all_key_states, dim=0)
+            values = torch.cat(all_value_states, dim=0)
+            attn_output = torch.cat(all_attn_output, dim=0)
 
-        attn_output, _ = _MidmAttention._attn(self, query, key, value, attention_mask, head_mask)
+        else:
+            if batch_index is None or batch_index == -1:
+                batch_index = 0
+
+            keys, values = past_key_value.update(
+                keys,
+                values,
+                self.layer_idx,
+                batch_index,
+                read_first_step=True,
+            )
 
-        attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)
+            attn_output, _ = _MidmAttention._attn(self, querys, keys, values, attention_mask)
+            attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)
 
         attn_output = self.c_proj(attn_output)
-        attn_output = self.resid_dropout(attn_output)
-
-        outputs = (attn_output, present)
+        return attn_output, keys, values
 
-        return outputs, past_key_value
 
+class _MidmBlock:
+    """Custom implementation of the MidmBlock class with specific modifications."""
 
-class _MidmBlock(MidmBlock):
     def forward(
         self,
         hidden_states: Optional[Tuple[torch.FloatTensor]],
+        layer_idx: int,
         attention_mask: Optional[torch.FloatTensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
-        past_key_value: Optional[Tuple[torch.Tensor]] = None,
-        head_mask: Optional[torch.FloatTensor] = None,
-        use_cache: Optional[bool] = False,
-        rotary_pos_emb=None,
+        past_key_value: Optional[RebelDynamicCache_4D] = None,
+        batch_ids: Optional[torch.LongTensor] = None,
+        cos: Optional[torch.Tensor] = None,
+        sin: Optional[torch.Tensor] = None,
    ) -> Union[Tuple[torch.Tensor], Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]]]:
+        """Defines the forward pass for the block."""
         residual = hidden_states
         if self.use_layernorm1p:
             hidden_states = layernorm1p(self.ln_1, hidden_states)
         else:
             hidden_states = self.ln_1(hidden_states)
 
-        attn_outputs, present_key_value = _MidmAttention.forward(
+        hidden_states, k, v = _MidmAttention.forward(
             self.attn,
             hidden_states,
             attention_mask=attention_mask,
-            position_ids=position_ids,
             past_key_value=past_key_value,
-            head_mask=head_mask,
-            rotary_pos_emb=rotary_pos_emb,
-            use_cache=use_cache,
+            cos=cos,
+            sin=sin,
+            batch_index=batch_ids,
         )
+        past_key_value.assign(k, v, layer_idx)
 
-        attn_output = attn_outputs[0]
-        outputs = attn_outputs[1:]
-
-        hidden_states = attn_output + residual
+        hidden_states = hidden_states + residual
 
         residual = hidden_states
         if self.use_layernorm1p:
             hidden_states = layernorm1p(self.ln_2, hidden_states)
         else:
             hidden_states = self.ln_2(hidden_states)
-        feed_forward_hidden_states = self.mlp(hidden_states)
 
+        feed_forward_hidden_states = self.mlp(hidden_states)
         hidden_states = residual + feed_forward_hidden_states
 
-        if use_cache:
-            outputs = (hidden_states,) + outputs
-        else:
-            outputs = (hidden_states,) + outputs[1:]
-
-        if use_cache:
-            outputs += (present_key_value,)
+        return hidden_states, past_key_value
 
-        return outputs
 
+class _MidmModel:
+    """Custom implementation of the MidmModel class with specific modifications."""
 
-class _MidmModel(MidmModel):
     def forward(
         self,
         input_ids: Optional[torch.LongTensor] = None,
-        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
+        past_key_values: Optional[RebelDynamicCache_4D] = None,
         attention_mask: Optional[torch.FloatTensor] = None,
-        token_type_ids=None,
         position_ids: Optional[torch.LongTensor] = None,
         rotary_pos_emb=None,
-        head_mask: Optional[torch.FloatTensor] = None,
-        inputs_embeds=None,
-        use_cache: Optional[bool] = None,
-    ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
+        batch_ids: Optional[torch.LongTensor] = None,
+    ) -> Union[Tuple, BaseModelOutputWithPast]:
+        """Defines the forward pass for the model."""
         input_shape = input_ids.size()
 
-        use_cache = use_cache if use_cache is not None else self.config.use_cache
-
-        current_step = position_ids
-
-        if input_ids is not None and inputs_embeds is not None:
-            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
-        elif input_ids is not None:
-            batch_size, seq_length = input_ids.shape[:2]
-        elif inputs_embeds is not None:
-            batch_size, seq_length = inputs_embeds.shape[:2]
-        else:
-            raise ValueError("You have to specify either input_ids or inputs_embeds")
-
-        if token_type_ids is not None:
-            token_type_ids = token_type_ids.view(-1, input_shape[-1])
-
-        past_key_values_length = 0
-
-        if use_cache:
-            use_legacy_cache = not isinstance(past_key_values, Cache)
-            if use_legacy_cache:
-                past_key_values = RebelDynamicCache.from_legacy_cache(
-                    current_step=current_step,
-                    max_length=self.config.max_position_embeddings,
-                    past_key_values=past_key_values,
-                )
-
-        position_ids = torch.arange(0, input_ids.shape[-1], dtype=torch.int32).unsqueeze(0) + current_step
-
-        if position_ids is None:
-            device = input_ids.device if input_ids is not None else inputs_embeds.device
-            position_ids = torch.arange(
-                past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
-            )
-            position_ids = position_ids.unsqueeze(0)
-
         attention_mask = (1.0 - attention_mask) * -10000.0
 
-        if self.use_rotary_position_embedding:
-            rotary_pos_emb = rotary_pos_emb(self.config.max_position_embeddings)
+        inputs_embeds = self.wte(input_ids)
 
-        head_mask = self.get_head_mask(head_mask, self.config.n_layer)
-
-        if inputs_embeds is None:
-            inputs_embeds = self.wte(input_ids)
+        cos, sin = rotary_pos_emb(inputs_embeds, attention_mask.shape[-1])
+        cos, sin = slice_and_unsqueeze_cos_sin(cos, sin, position_ids)
         hidden_states = inputs_embeds
 
-        if token_type_ids is not None:
-            token_type_embeds = self.wte(token_type_ids)
-            hidden_states = hidden_states + token_type_embeds
-
-        hidden_states = self.drop(hidden_states)
-
-        output_shape = (-1,) + input_shape[1:] + (hidden_states.size(-1),)
-
-        next_decoder_cache = () if use_cache else None
-
-        for i, (block, _) in enumerate(zip(self.h, past_key_values)):
-            outputs = _MidmBlock.forward(
+        for layer_idx, (block, _) in enumerate(zip(self.h, past_key_values)):
+            hidden_states, updated_cache = _MidmBlock.forward(
                 block,
                 hidden_states,
+                layer_idx,
                 attention_mask=attention_mask,
-                position_ids=position_ids,
                 past_key_value=past_key_values,
-                head_mask=head_mask[i],
-                rotary_pos_emb=rotary_pos_emb,
-                use_cache=use_cache,
+                batch_ids=batch_ids,
+                cos=cos,
+                sin=sin,
             )
-            hidden_states = outputs[0]
-
-            if use_cache:
-                next_decoder_cache = outputs[2]
 
-        if self.use_layernorm1p:
-            hidden_states = layernorm1p(self.ln_f, hidden_states)
-        else:
-            hidden_states = self.ln_f(hidden_states)
+        hidden_states = layernorm1p(self.ln_f, hidden_states)
+        output_shape = (-1,) + input_shape[1:] + (hidden_states.size(-1),)
         hidden_states = hidden_states.view(output_shape)
 
-        next_cache = None
-        if use_cache:
-            next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
-
-        # return BaseModelOutputWithPast(last_hidden_state=hidden_states, past_key_values=next_cache)
-        return hidden_states, next_cache
-
-
-class RebelDynamicCache(DynamicCache):
-    """
-    A cache that grows dynamically as more tokens are generated. This is the default for generative models.
+        next_cache = updated_cache.to_legacy_cache()
 
-    It stores the Key and Value states as a list of tensors, one for each layer. The expected shape for each tensor is
-    `[batch_size, num_heads, seq_len, head_dim]`.
-    """
-
-    def __init__(self, current_step, max_length) -> None:
-        super().__init__()
-        self.current_step = current_step
-        self.max_length = max_length
-
-    def copy(
-        self,
-        key_states: torch.Tensor,
-        value_states: torch.Tensor,
-        layer_idx: int,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
-        """
-        Copy the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
-        just for from_legacy_cache function
-
-        Parameters:
-            key_states (`torch.Tensor`):
-                The new key states to cache.
-            value_states (`torch.Tensor`):
-                The new value states to cache.
-            layer_idx (`int`):
-                The index of the layer to cache the states for.
-            cache_kwargs (`Dict[str, Any]`, `optional`):
-                Additional arguments for the cache subclass. No additional arguments are used in `DynamicCache`.
-
-        Return:
-            A tuple containing the updated key and value states.
-        """
-
-        if len(self.key_cache) <= layer_idx:
-            self.key_cache.append(key_states)
-            self.value_cache.append(value_states)
-        else:
-            self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
-            self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)
-
-        return self.key_cache[layer_idx], self.value_cache[layer_idx]
-
-    def update(
-        self,
-        key_states: torch.Tensor,
-        value_states: torch.Tensor,
-        layer_idx: int,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
-        """
-        Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`
-        based on self.current_step,
-
-        Parameters:
-            key_states (`torch.Tensor`):
-                The new key states to cache.
-            value_states (`torch.Tensor`):
-                The new value states to cache.
-            layer_idx (`int`):
-                The index of the layer to cache the states for.
-            cache_kwargs (`Dict[str, Any]`, `optional`):
-                Additional arguments for the cache subclass. No additional arguments are used in `DynamicCache`.
-
-        Return:
-            A tuple containing the updated key and value states.
-        """
-
-        if len(self.key_cache) <= layer_idx:
-            self.key_cache.append(key_states)
-            self.value_cache.append(value_states)
-        else:
-            self.key_cache[layer_idx] = self.key_cache[layer_idx].slice_scatter(
-                key_states, dim=2, start=self.current_step, end=self.current_step + key_states.shape[2]
-            )
-            self.value_cache[layer_idx] = self.value_cache[layer_idx].slice_scatter(
-                value_states, dim=2, start=self.current_step, end=self.current_step + value_states.shape[2]
-            )
-
-        return self.key_cache[layer_idx], self.value_cache[layer_idx]
-
-    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
-        """Returns the sequence length of the cached states. A layer index can be optionally passed."""
-        if len(self.key_cache) <= layer_idx:
-            return 0
-        return self.key_cache[layer_idx].shape[-2]
-
-    def get_max_length(self) -> Optional[int]:
-        return self.max_length
-
-    @classmethod
-    def from_legacy_cache(
-        cls, current_step, max_length, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
-    ) -> "DynamicCache":
-        """Converts a cache in the legacy cache format into an equivalent `DynamicCache`."""
-        cache = cls(current_step, max_length)
-        if past_key_values is not None:
-            for layer_idx in range(len(past_key_values)):
-                key_states, value_states = past_key_values[layer_idx]
-                cache.copy(key_states, value_states, layer_idx)
-        return cache
+        return BaseModelOutputWithPast(
+            last_hidden_state=hidden_states,
+            past_key_values=next_cache,
+        )
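
For readers unfamiliar with the partial-rotary scheme above: the new apply_rotary_pos_emb(q, k, cos, sin) rotates only the first rot_dim channels of each head (Midm applies rotary embeddings to rotary_percentage of the head dimension) and passes the rest through unchanged. A minimal sketch of the same math, with rotate_half re-implemented locally (in the real file it is imported from decoderonly_architecture) and toy shapes chosen only for the example:

import torch

def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # [A, B, C, D] -> [-C, -D, A, B] on the last dimension
    x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_to_tensor(tensor, cos, sin, rot_dim):
    # Rotate only the first `rot_dim` channels; pass the remainder through unchanged.
    t_rot, t_pass = tensor[..., :rot_dim], tensor[..., rot_dim:]
    t_rot = (t_rot * cos) + (rotate_half(t_rot) * sin)
    return torch.cat((t_rot, t_pass), dim=-1)

# Toy shapes: batch=1, heads=2, seq=4, head_dim=8, rotary applied to half the head dim
q = torch.randn(1, 2, 4, 8)
k = torch.randn(1, 2, 4, 8)
rot_dim = 4
pos = torch.arange(4, dtype=torch.float32)
inv_freq = 1.0 / (10000 ** (torch.arange(0, rot_dim, 2).float() / rot_dim))
freqs = torch.outer(pos, inv_freq)                       # [seq, rot_dim // 2]
emb = torch.cat((freqs, freqs), dim=-1)                  # [seq, rot_dim]
cos, sin = emb.cos()[None, None], emb.sin()[None, None]  # broadcast over batch and heads
q_embed = apply_rotary_to_tensor(q, cos, sin, rot_dim)
k_embed = apply_rotary_to_tensor(k, cos, sin, rot_dim)

The cos/sin tables are precomputed once from the inverse frequencies, so at run time the rotation reduces to elementwise multiplies and a concat, which fits the static-shape style of the rest of the rewrite.
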