optimum-rbln 0.1.1__py3-none-any.whl → 0.1.7__py3-none-any.whl

This diff represents the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registries.
Files changed (37)
  1. optimum/rbln/__init__.py +9 -0
  2. optimum/rbln/__version__.py +1 -1
  3. optimum/rbln/diffusers/models/autoencoder_kl.py +16 -98
  4. optimum/rbln/diffusers/models/unet_2d_condition.py +1 -1
  5. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +9 -11
  6. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +8 -0
  7. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +9 -0
  8. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +3 -0
  9. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +8 -0
  10. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +8 -0
  11. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +9 -0
  12. optimum/rbln/modeling_base.py +175 -103
  13. optimum/rbln/modeling_seq2seq.py +58 -132
  14. optimum/rbln/transformers/__init__.py +4 -0
  15. optimum/rbln/transformers/models/__init__.py +2 -0
  16. optimum/rbln/transformers/models/clip/modeling_clip.py +0 -1
  17. optimum/rbln/transformers/models/dpt/__init__.py +24 -0
  18. optimum/rbln/transformers/models/dpt/modeling_dpt.py +89 -0
  19. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +24 -33
  20. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +52 -124
  21. optimum/rbln/transformers/models/llama/llama_architecture.py +62 -33
  22. optimum/rbln/transformers/models/llama/llama_architecture_cb.py +764 -0
  23. optimum/rbln/transformers/models/llama/modeling_llama.py +208 -140
  24. optimum/rbln/transformers/models/midm/__init__.py +32 -0
  25. optimum/rbln/transformers/models/midm/hf_hub_cached/configuration_midm.py +22 -0
  26. optimum/rbln/transformers/models/midm/hf_hub_cached/midm_bitext_tokenization.py +303 -0
  27. optimum/rbln/transformers/models/midm/hf_hub_cached/modeling_midm.py +1473 -0
  28. optimum/rbln/transformers/models/midm/hf_hub_cached/rotary_position_embedding.py +98 -0
  29. optimum/rbln/transformers/models/midm/midm_architecture.py +506 -0
  30. optimum/rbln/transformers/models/midm/modeling_midm.py +390 -0
  31. optimum/rbln/transformers/models/whisper/modeling_whisper.py +53 -123
  32. optimum/rbln/utils/__init__.py +1 -1
  33. optimum/rbln/utils/import_utils.py +46 -0
  34. {optimum_rbln-0.1.1.dist-info → optimum_rbln-0.1.7.dist-info}/METADATA +17 -50
  35. {optimum_rbln-0.1.1.dist-info → optimum_rbln-0.1.7.dist-info}/RECORD +37 -27
  36. {optimum_rbln-0.1.1.dist-info → optimum_rbln-0.1.7.dist-info}/WHEEL +1 -1
  37. {optimum_rbln-0.1.1.dist-info → optimum_rbln-0.1.7.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,98 @@
+ # coding=utf-8
+ # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import torch
+ from einops import rearrange
+ from torch import einsum, nn
+
+
+ __all__ = ["RotaryEmbedding", "apply_rotary_pos_emb"]
+
+
+ class RotaryEmbedding(nn.Module):
+     """
+     Implements Rotary Position Embedding from https://arxiv.org/abs/2104.09864.
+     """
+
+     def __init__(
+         self, dim: int, seq_len_interpolation_factor: int = None, pretrained_max_position_embeddings: int = None
+     ):
+         """
+         Args:
+
+             dim (int): rotary embedding dimension
+             seq_len_interpolation_factor (int): if not None, discrete positions will be interpolated
+                 by this factor via the trick in https://arxiv.org/abs/2306.15595.
+             pretrained_max_position_embeddings (int): pre-trained max_position_embeddings before position interpolation.
+         """
+         super().__init__()
+         self.seq_len_interpolation_factor = seq_len_interpolation_factor
+         inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
+         self.register_buffer("inv_freq", inv_freq)
+         self.pretrained_max_position_embeddings = pretrained_max_position_embeddings
+
+     def forward(self, max_seq_len, offset=0):
+         seq = torch.arange(max_seq_len, device=self.inv_freq.device) + offset
+         seq = seq.type_as(self.inv_freq)
+
+         if self.pretrained_max_position_embeddings is not None and self.seq_len_interpolation_factor is not None:
+             if max_seq_len > self.pretrained_max_position_embeddings * self.seq_len_interpolation_factor:
+                 # dynamic linear scaling (length > position we have learned)
+                 seq *= 1 / (max_seq_len / self.pretrained_max_position_embeddings)
+             else:
+                 # fixed linear scaling
+                 seq *= 1 / self.seq_len_interpolation_factor
+
+         freqs = einsum("i , j -> i j", seq, self.inv_freq)
+         # first part even vector components, second part odd vector components,
+         # 2 * dim in dimension size
+         emb = torch.cat((freqs, freqs), dim=-1)
+         # emb [seq_length, .., dim]
+         return rearrange(emb, "n d -> n 1 1 d")
+
+
+ def _rotate_half(x):
+     """
+     change sign so the last dimension
+     [A, B, C, D] -> [-C, -D, A, B]
+     """
+     x = rearrange(x, "... (j d) -> ... j d", j=2)
+     x1, x2 = x.unbind(dim=-2)
+     return torch.cat((-x2, x1), dim=-1)
+
+
+ def apply_rotary_pos_emb(t, freqs):
+     """
+     input tensor t is of shape [seq_length, ..., dim]
+     rotary positional embeding tensor freqs is of shape [seq_length, ..., dim]
+     check https://kexue.fm/archives/8265 for detailed formulas
+     """
+     # Changes from the original RoPE implementation
+     # 1. The original NeMo implementation assumes the input tensor of shape
+     # [seq_length, ..., dim], but the HF layout is [..., seq_length, dim].
+     # Thus freqs needs to be viewed as [..., seq_length, dim].
+     freqs = freqs.permute(1, 2, 0, 3)
+     # 2. Support for queries which past tokens are truncated
+     assert freqs.shape[-2] >= t.shape[-2]
+     if freqs.shape[-2] != t.shape[-2]:
+         freqs = freqs[:, :, -t.shape[-2] :, :]
+
+     rot_dim = freqs.shape[-1]
+     # ideally t_pass is empty so rotary pos embedding is applied to all tensor t
+     t, t_pass = t[..., :rot_dim], t[..., rot_dim:]
+     # first part is cosine component
+     # second part is sine component, need to change signs with _rotate_half method
+     t = (t * freqs.cos()) + (_rotate_half(t) * freqs.sin())
+     return torch.cat((t, t_pass), dim=-1)
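The 98-line hunk above corresponds to the new hf_hub_cached/rotary_position_embedding.py in the files-changed list. As a rough usage sketch (not part of the diff), the module can be exercised on its own; the import path mirrors the file's location in the wheel and assumes the hf_hub_cached directory is importable as a package, the tensor sizes are illustrative, and apply_rotary_pos_emb expects queries/keys in the HF layout [batch, heads, seq_len, head_dim] with freqs as returned by RotaryEmbedding.forward:

```python
# Illustrative sketch only; RotaryEmbedding / apply_rotary_pos_emb come from the
# module shown above, and the concrete sizes are arbitrary example values.
import torch
from optimum.rbln.transformers.models.midm.hf_hub_cached.rotary_position_embedding import (
    RotaryEmbedding,
    apply_rotary_pos_emb,
)

rope = RotaryEmbedding(dim=64, pretrained_max_position_embeddings=2048)
freqs = rope(max_seq_len=128)            # [128, 1, 1, 64], one frequency row per position
q = torch.randn(1, 8, 128, 64)           # query in HF layout [batch, heads, seq_len, head_dim]
q_rot = apply_rotary_pos_emb(q, freqs)   # same shape as q, rotation applied to all 64 dims
```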
@@ -0,0 +1,506 @@
+ # Copyright 2024 Rebellions Inc.
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at:
+
+ # http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ # Portions of this software are licensed under the Apache License,
+ # Version 2.0. See the NOTICE file distributed with this work for
+ # additional information regarding copyright ownership.
+
+ # All other portions of this software, including proprietary code,
+ # are the intellectual property of Rebellions Inc. and may not be
+ # copied, modified, or distributed without prior written permission
+ # from Rebellions Inc.
+
+
+ from typing import Dict, Optional, Tuple, Union
+
+ import torch
+ import torch.nn as nn
+ from transformers.cache_utils import Cache, DynamicCache
+ from transformers.modeling_outputs import (
+     BaseModelOutputWithPastAndCrossAttentions,
+ )
+
+ from .hf_hub_cached.modeling_midm import (
+     MidmAttention,
+     MidmBlock,
+     MidmModel,
+ )
+
+
+ class _MidmRotaryEmbedding(nn.Module):
+     """
+     Implements Rotary Position Embedding from https://arxiv.org/abs/2104.09864.
+     """
+
+     def __init__(
+         self, dim: int, seq_len_interpolation_factor: int = None, pretrained_max_position_embeddings: int = None
+     ):
+         """
+         Args:
+
+             dim (int): rotary embedding dimension
+             seq_len_interpolation_factor (int): if not None, discrete positions will be interpolated
+                 by this factor via the trick in https://arxiv.org/abs/2306.15595.
+             pretrained_max_position_embeddings (int): pre-trained max_position_embeddings before position interpolation.
+         """
+         super().__init__()
+         self.seq_len_interpolation_factor = seq_len_interpolation_factor
+         inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
+         self.register_buffer("inv_freq", inv_freq, persistent=False)
+         self.pretrained_max_position_embeddings = pretrained_max_position_embeddings
+
+         seq_len = pretrained_max_position_embeddings
+         device = self.inv_freq.device
+         dtype = torch.get_default_dtype()
+         self.max_seq_len_cached = seq_len
+         t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
+
+         freqs = torch.outer(t, self.inv_freq)
+
+         emb = torch.cat((freqs, freqs), dim=-1)
+         self.register_buffer("emb_cached", emb.to(dtype), persistent=False)
+
+     def forward(self, max_seq_len, offset=0):
+
+         if max_seq_len > self.max_seq_len_cached:
+             self._set_emb_cache(seq_len=max_seq_len)
+
+         return self.emb_cached[:max_seq_len]
+
+
+ def _rotate_half(x):
+     """
+     change sign so the last dimension
+     [A, B, C, D] -> [-C, -D, A, B]
+     """
+     x1 = x[..., : x.shape[-1] // 2]
+     x2 = x[..., x.shape[-1] // 2 :]
+     return torch.cat((-x2, x1), dim=-1)
+
+
+ def apply_rotary_pos_emb(t: torch.Tensor, cache_kwargs: Optional[Dict[str, torch.Tensor]] = None) -> torch.Tensor:
+     """
+     input tensor t is of shape [seq_length, ..., dim]
+     rotary positional embeding tensor freqs is of shape [seq_length, ..., dim]
+     check https://kexue.fm/archives/8265 for detailed formulas
+     """
+
+     freqs = cache_kwargs["rotary_pos_emb"]
+     position_ids = cache_kwargs["position_ids"]
+     unsqueeze_dim = 1
+
+     rot_dim = freqs.shape[-1]
+
+     t, t_pass = t[..., :rot_dim], t[..., rot_dim:]
+     cos = freqs.cos()[position_ids].unsqueeze(unsqueeze_dim)
+     sin = freqs.sin()[position_ids].unsqueeze(unsqueeze_dim)
+
+     embed = (t * cos) + (_rotate_half(t) * sin)
+     embed = torch.cat((embed, t_pass), dim=-1)
+
+     return embed
+
+
+ class MidmLMHeadModelWrapper(torch.nn.Module):
+     def __init__(self, model):
+         super().__init__()
+         self.model = model
+         self.confg = model.config
+
+         self.use_rotary_position_embedding = model.config.use_rotary_position_embedding
+         if self.use_rotary_position_embedding:
+             rotary_dim = model.config.hidden_size // model.config.num_attention_heads
+             assert 0 < model.config.rotary_percentage <= 1
+             if model.config.rotary_percentage < 1:
+                 rotary_dim = int(rotary_dim * model.config.rotary_percentage)
+             self._rotary_pos_emb = _MidmRotaryEmbedding(
+                 rotary_dim,
+                 seq_len_interpolation_factor=None,
+                 pretrained_max_position_embeddings=model.config.max_position_embeddings,
+             )
+
+     def forward(
+         self,
+         input_ids: torch.Tensor,
+         attention_mask: torch.Tensor,
+         cache_position: torch.LongTensor,
+         *past_key_values,
+     ):
+         past_kv_list = []
+         for i in range(self.model.config.n_layer):
+             cur_kv_layer = []
+             for j in range(2):
+                 cur_kv_layer.append(past_key_values[2 * i + j])
+             past_kv_list.append(cur_kv_layer)
+
+         transformer_outputs = _MidmModel.forward(
+             self.model.transformer,
+             input_ids=input_ids,
+             past_key_values=past_kv_list,
+             attention_mask=attention_mask,
+             position_ids=cache_position,
+             rotary_pos_emb=self._rotary_pos_emb,
+         )
+
+         hidden_states = transformer_outputs[0]
+
+         # For the input_ids, we assume right-alignment.
+         # This assumption allows us to bypass dynamic indexing.
+         hidden_states = hidden_states[:, -1:]
+         lm_logits = self.model.lm_head(hidden_states)
+         kv_cache = transformer_outputs[1]
+
+         return lm_logits, kv_cache
+
+
+ def layernorm1p(module, input):
+     return torch.nn.functional.layer_norm(input, module.normalized_shape, module.weight + 1, module.bias, module.eps)
+
+
+ class _MidmAttention(MidmAttention):
+     def _attn(self, query, key, value, attention_mask=None, head_mask=None):
+         attn_weights = torch.matmul(query, key.transpose(-1, -2))
+
+         if self.scale_attn_weights:
+             attn_weights = attn_weights / torch.full(
+                 [], value.size(-1) ** 0.5, dtype=attn_weights.dtype, device=attn_weights.device
+             )
+
+         if self.scale_attn_by_inverse_layer_idx or self.scale_qk_by_inverse_layer_idx:
+             attn_weights = attn_weights / float(self.layer_idx + 1)
+
+         if attention_mask is not None:
+             attn_weights = attn_weights + attention_mask
+
+         if self.scale_qk_by_inverse_layer_idx:
+             attn_weights = attn_weights * float(self.layer_idx + 1)
+
+         attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+
+         attn_weights = attn_weights.type(value.dtype)
+
+         if head_mask is not None:
+             attn_weights = attn_weights * head_mask
+
+         attn_output = torch.matmul(attn_weights, value)
+
+         return attn_output, attn_weights
+
+     def forward(
+         self,
+         hidden_states: Optional[Tuple[torch.FloatTensor]],
+         attention_mask: Optional[torch.FloatTensor] = None,
+         position_ids: Optional[torch.LongTensor] = None,
+         past_key_value: Optional[Cache] = None,
+         head_mask: Optional[torch.FloatTensor] = None,
+         use_cache: Optional[bool] = False,
+         rotary_pos_emb=None,
+     ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]:
+         query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2)
+         query = self._split_heads(query, self.num_heads, self.head_dim)
+         key = self._split_heads(key, self.num_heads, self.head_dim)
+         value = self._split_heads(value, self.num_heads, self.head_dim)
+
+         kv_seq_len = key.shape[-2]
+         if past_key_value is not None:
+             if self.layer_idx is None:
+                 raise ValueError(
+                     f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
+                     "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
+                     "with a layer index."
+                 )
+             kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
+
+         if use_cache is True:
+             present = (key, value)
+         else:
+             present = None
+
+         if rotary_pos_emb is not None:
+             query = apply_rotary_pos_emb(query, {"rotary_pos_emb": rotary_pos_emb, "position_ids": position_ids})
+             key = apply_rotary_pos_emb(key, {"rotary_pos_emb": rotary_pos_emb, "position_ids": position_ids})
+
+         if past_key_value is not None:
+             key, value = past_key_value.update(key, value, self.layer_idx)
+
+         attn_output, _ = _MidmAttention._attn(self, query, key, value, attention_mask, head_mask)
+
+         attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)
+
+         attn_output = self.c_proj(attn_output)
+         attn_output = self.resid_dropout(attn_output)
+
+         outputs = (attn_output, present)
+
+         return outputs, past_key_value
+
+
+ class _MidmBlock(MidmBlock):
+     def forward(
+         self,
+         hidden_states: Optional[Tuple[torch.FloatTensor]],
+         attention_mask: Optional[torch.FloatTensor] = None,
+         position_ids: Optional[torch.LongTensor] = None,
+         past_key_value: Optional[Tuple[torch.Tensor]] = None,
+         head_mask: Optional[torch.FloatTensor] = None,
+         use_cache: Optional[bool] = False,
+         rotary_pos_emb=None,
+     ) -> Union[Tuple[torch.Tensor], Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]]]:
+         residual = hidden_states
+         if self.use_layernorm1p:
+             hidden_states = layernorm1p(self.ln_1, hidden_states)
+         else:
+             hidden_states = self.ln_1(hidden_states)
+
+         attn_outputs, present_key_value = _MidmAttention.forward(
+             self.attn,
+             hidden_states,
+             attention_mask=attention_mask,
+             position_ids=position_ids,
+             past_key_value=past_key_value,
+             head_mask=head_mask,
+             rotary_pos_emb=rotary_pos_emb,
+             use_cache=use_cache,
+         )
+
+         attn_output = attn_outputs[0]
+         outputs = attn_outputs[1:]
+
+         hidden_states = attn_output + residual
+
+         residual = hidden_states
+         if self.use_layernorm1p:
+             hidden_states = layernorm1p(self.ln_2, hidden_states)
+         else:
+             hidden_states = self.ln_2(hidden_states)
+         feed_forward_hidden_states = self.mlp(hidden_states)
+
+         hidden_states = residual + feed_forward_hidden_states
+
+         if use_cache:
+             outputs = (hidden_states,) + outputs
+         else:
+             outputs = (hidden_states,) + outputs[1:]
+
+         if use_cache:
+             outputs += (present_key_value,)
+
+         return outputs
+
+
+ class _MidmModel(MidmModel):
+     def forward(
+         self,
+         input_ids: Optional[torch.LongTensor] = None,
+         past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
+         attention_mask: Optional[torch.FloatTensor] = None,
+         token_type_ids=None,
+         position_ids: Optional[torch.LongTensor] = None,
+         rotary_pos_emb=None,
+         head_mask: Optional[torch.FloatTensor] = None,
+         inputs_embeds=None,
+         use_cache: Optional[bool] = None,
+     ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
+         input_shape = input_ids.size()
+
+         use_cache = use_cache if use_cache is not None else self.config.use_cache
+
+         current_step = position_ids
+
+         if input_ids is not None and inputs_embeds is not None:
+             raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+         elif input_ids is not None:
+             batch_size, seq_length = input_ids.shape[:2]
+         elif inputs_embeds is not None:
+             batch_size, seq_length = inputs_embeds.shape[:2]
+         else:
+             raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+         if token_type_ids is not None:
+             token_type_ids = token_type_ids.view(-1, input_shape[-1])
+
+         past_key_values_length = 0
+
+         if use_cache:
+             use_legacy_cache = not isinstance(past_key_values, Cache)
+             if use_legacy_cache:
+                 past_key_values = RebelDynamicCache.from_legacy_cache(
+                     current_step=current_step,
+                     max_length=self.config.max_position_embeddings,
+                     past_key_values=past_key_values,
+                 )
+
+         position_ids = torch.arange(0, input_ids.shape[-1], dtype=torch.int32).unsqueeze(0) + current_step
+
+         if position_ids is None:
+             device = input_ids.device if input_ids is not None else inputs_embeds.device
+             position_ids = torch.arange(
+                 past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
+             )
+             position_ids = position_ids.unsqueeze(0)
+
+         attention_mask = (1.0 - attention_mask) * -10000.0
+
+         if self.use_rotary_position_embedding:
+             rotary_pos_emb = rotary_pos_emb(self.config.max_position_embeddings)
+
+         head_mask = self.get_head_mask(head_mask, self.config.n_layer)
+
+         if inputs_embeds is None:
+             inputs_embeds = self.wte(input_ids)
+         hidden_states = inputs_embeds
+
+         if token_type_ids is not None:
+             token_type_embeds = self.wte(token_type_ids)
+             hidden_states = hidden_states + token_type_embeds
+
+         hidden_states = self.drop(hidden_states)
+
+         output_shape = (-1,) + input_shape[1:] + (hidden_states.size(-1),)
+
+         next_decoder_cache = () if use_cache else None
+
+         for i, (block, _) in enumerate(zip(self.h, past_key_values)):
+             outputs = _MidmBlock.forward(
+                 block,
+                 hidden_states,
+                 attention_mask=attention_mask,
+                 position_ids=position_ids,
+                 past_key_value=past_key_values,
+                 head_mask=head_mask[i],
+                 rotary_pos_emb=rotary_pos_emb,
+                 use_cache=use_cache,
+             )
+             hidden_states = outputs[0]
+
+             if use_cache:
+                 next_decoder_cache = outputs[2]
+
+         if self.use_layernorm1p:
+             hidden_states = layernorm1p(self.ln_f, hidden_states)
+         else:
+             hidden_states = self.ln_f(hidden_states)
+         hidden_states = hidden_states.view(output_shape)
+
+         next_cache = None
+         if use_cache:
+             next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
+
+         # return BaseModelOutputWithPast(last_hidden_state=hidden_states, past_key_values=next_cache)
+         return hidden_states, next_cache
+
+
+ class RebelDynamicCache(DynamicCache):
+     """
+     A cache that grows dynamically as more tokens are generated. This is the default for generative models.
+
+     It stores the Key and Value states as a list of tensors, one for each layer. The expected shape for each tensor is
+     `[batch_size, num_heads, seq_len, head_dim]`.
+     """
+
+     def __init__(self, current_step, max_length) -> None:
+         super().__init__()
+         self.current_step = current_step
+         self.max_length = max_length
+
+     def copy(
+         self,
+         key_states: torch.Tensor,
+         value_states: torch.Tensor,
+         layer_idx: int,
+     ) -> Tuple[torch.Tensor, torch.Tensor]:
+         """
+         Copy the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
+         just for from_legacy_cache function
+
+         Parameters:
+             key_states (`torch.Tensor`):
+                 The new key states to cache.
+             value_states (`torch.Tensor`):
+                 The new value states to cache.
+             layer_idx (`int`):
+                 The index of the layer to cache the states for.
+             cache_kwargs (`Dict[str, Any]`, `optional`):
+                 Additional arguments for the cache subclass. No additional arguments are used in `DynamicCache`.
+
+         Return:
+             A tuple containing the updated key and value states.
+         """
+
+         if len(self.key_cache) <= layer_idx:
+             self.key_cache.append(key_states)
+             self.value_cache.append(value_states)
+         else:
+             self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
+             self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)
+
+         return self.key_cache[layer_idx], self.value_cache[layer_idx]
+
+     def update(
+         self,
+         key_states: torch.Tensor,
+         value_states: torch.Tensor,
+         layer_idx: int,
+     ) -> Tuple[torch.Tensor, torch.Tensor]:
+         """
+         Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`
+         based on self.current_step,
+
+         Parameters:
+             key_states (`torch.Tensor`):
+                 The new key states to cache.
+             value_states (`torch.Tensor`):
+                 The new value states to cache.
+             layer_idx (`int`):
+                 The index of the layer to cache the states for.
+             cache_kwargs (`Dict[str, Any]`, `optional`):
+                 Additional arguments for the cache subclass. No additional arguments are used in `DynamicCache`.
+
+         Return:
+             A tuple containing the updated key and value states.
+         """
+
+         if len(self.key_cache) <= layer_idx:
+             self.key_cache.append(key_states)
+             self.value_cache.append(value_states)
+         else:
+             self.key_cache[layer_idx] = self.key_cache[layer_idx].slice_scatter(
+                 key_states, dim=2, start=self.current_step, end=self.current_step + key_states.shape[2]
+             )
+             self.value_cache[layer_idx] = self.value_cache[layer_idx].slice_scatter(
+                 value_states, dim=2, start=self.current_step, end=self.current_step + value_states.shape[2]
+             )
+
+         return self.key_cache[layer_idx], self.value_cache[layer_idx]
+
+     def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
+         """Returns the sequence length of the cached states. A layer index can be optionally passed."""
+         if len(self.key_cache) <= layer_idx:
+             return 0
+         return self.key_cache[layer_idx].shape[-2]
+
+     def get_max_length(self) -> Optional[int]:
+         return self.max_length
+
+     @classmethod
+     def from_legacy_cache(
+         cls, current_step, max_length, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
+     ) -> "DynamicCache":
+         """Converts a cache in the legacy cache format into an equivalent `DynamicCache`."""
+         cache = cls(current_step, max_length)
+         if past_key_values is not None:
+             for layer_idx in range(len(past_key_values)):
+                 key_states, value_states = past_key_values[layer_idx]
+                 cache.copy(key_states, value_states, layer_idx)
+         return cache
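The 506-line hunk above corresponds to the new midm_architecture.py in the files-changed list. Its central piece is RebelDynamicCache.update, which writes each step's key/value states into a pre-allocated cache at self.current_step via slice_scatter instead of concatenating, so the cache tensor keeps a fixed [batch, heads, max_length, head_dim] shape across decoding steps. Below is a standalone sketch of that pattern; the shapes and the step value are illustrative assumptions, not taken from the package:

```python
# Standalone illustration of the slice_scatter cache-update pattern used by
# RebelDynamicCache.update; sizes and current_step are made-up example values.
import torch

batch, heads, max_length, head_dim = 1, 2, 16, 4
key_cache = torch.zeros(batch, heads, max_length, head_dim)   # pre-allocated static cache
new_keys = torch.randn(batch, heads, 1, head_dim)             # one decoding step's keys
current_step = 5

# Overwrite the slot at `current_step` along the sequence dim without changing the cache's shape.
key_cache = key_cache.slice_scatter(
    new_keys, dim=2, start=current_step, end=current_step + new_keys.shape[2]
)
assert key_cache.shape == (batch, heads, max_length, head_dim)
```

Keeping the cache shape constant in this way avoids the torch.cat growth of the stock DynamicCache, which appears to be what allows the decoder to be traced and compiled once with static tensor shapes.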