optimum-rbln 0.1.9__py3-none-any.whl → 0.1.11__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as published to their public registry. It is provided for informational purposes only and reflects the changes between those versions.
Files changed (61)
  1. optimum/rbln/__init__.py +37 -2
  2. optimum/rbln/__version__.py +1 -1
  3. optimum/rbln/diffusers/models/autoencoder_kl.py +36 -29
  4. optimum/rbln/diffusers/models/controlnet.py +56 -40
  5. optimum/rbln/diffusers/models/unet_2d_condition.py +40 -28
  6. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +22 -15
  7. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +22 -15
  8. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +23 -17
  9. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +24 -18
  10. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +22 -11
  11. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +22 -11
  12. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +24 -14
  13. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +24 -14
  14. optimum/rbln/modeling_alias.py +3 -3
  15. optimum/rbln/modeling_base.py +471 -231
  16. optimum/rbln/modeling_config.py +152 -77
  17. optimum/rbln/modeling_seq2seq.py +166 -77
  18. optimum/rbln/transformers/__init__.py +35 -1
  19. optimum/rbln/transformers/models/__init__.py +20 -1
  20. optimum/rbln/transformers/models/auto/__init__.py +14 -0
  21. optimum/rbln/transformers/models/auto/auto_factory.py +84 -0
  22. optimum/rbln/transformers/models/auto/modeling_auto.py +94 -0
  23. optimum/rbln/transformers/models/bart/__init__.py +1 -0
  24. optimum/rbln/transformers/models/bart/bart_architecture.py +189 -50
  25. optimum/rbln/transformers/models/bart/modeling_bart.py +106 -0
  26. optimum/rbln/transformers/models/bert/__init__.py +24 -0
  27. optimum/rbln/transformers/models/bert/modeling_bert.py +102 -0
  28. optimum/rbln/transformers/models/clip/__init__.py +1 -1
  29. optimum/rbln/transformers/models/clip/modeling_clip.py +127 -25
  30. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +28 -4
  31. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +302 -115
  32. optimum/rbln/transformers/models/dpt/modeling_dpt.py +21 -7
  33. optimum/rbln/transformers/models/gemma/modeling_gemma.py +1 -1
  34. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +4 -1
  35. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +1 -1
  36. optimum/rbln/transformers/models/llama/modeling_llama.py +1 -1
  37. optimum/rbln/transformers/models/llava_next/__init__.py +24 -0
  38. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +666 -0
  39. optimum/rbln/transformers/models/midm/midm_architecture.py +5 -1
  40. optimum/rbln/transformers/models/midm/modeling_midm.py +1 -1
  41. optimum/rbln/transformers/models/mistral/modeling_mistral.py +1 -1
  42. optimum/rbln/transformers/models/phi/__init__.py +24 -0
  43. optimum/rbln/transformers/models/phi/modeling_phi.py +69 -0
  44. optimum/rbln/transformers/models/phi/phi_architecture.py +406 -0
  45. optimum/rbln/transformers/models/t5/t5_architecture.py +92 -31
  46. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +17 -11
  47. optimum/rbln/transformers/models/whisper/generation_whisper.py +68 -0
  48. optimum/rbln/transformers/models/whisper/modeling_whisper.py +141 -105
  49. optimum/rbln/transformers/models/whisper/whisper_architecture.py +44 -17
  50. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +17 -14
  51. optimum/rbln/transformers/utils/rbln_quantization.py +48 -60
  52. optimum/rbln/utils/import_utils.py +36 -1
  53. optimum/rbln/utils/logging.py +82 -0
  54. optimum/rbln/utils/runtime_utils.py +33 -0
  55. optimum/rbln/utils/timer_utils.py +19 -0
  56. {optimum_rbln-0.1.9.dist-info → optimum_rbln-0.1.11.dist-info}/METADATA +8 -7
  57. optimum_rbln-0.1.11.dist-info/RECORD +93 -0
  58. {optimum_rbln-0.1.9.dist-info → optimum_rbln-0.1.11.dist-info}/WHEEL +1 -1
  59. optimum_rbln-0.1.11.dist-info/entry_points.txt +4 -0
  60. optimum_rbln-0.1.9.dist-info/RECORD +0 -78
  61. {optimum_rbln-0.1.9.dist-info → optimum_rbln-0.1.11.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/transformers/models/mistral/modeling_mistral.py
@@ -53,7 +53,7 @@ class RBLNMistralForCausalLM(RBLNDecoderOnlyModelForCausalLM):
 
     @classmethod
     def wrap_model_if_needed(self, model: "PreTrainedModel", rbln_config: "RBLNConfig"):
-        rbln_max_seq_len = rbln_config.meta["rbln_max_seq_len"]
+        rbln_max_seq_len = rbln_config.model_cfg["max_seq_len"]
         return MistralForCausalLMWrapper(model, rbln_max_seq_len).eval()
 
     def __getattr__(self, __name: str) -> Any:
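
The hunk above replaces the old `rbln_config.meta["rbln_max_seq_len"]` lookup with the new `rbln_config.model_cfg["max_seq_len"]` layout. A minimal sketch of what such a hook now consumes; the stand-in config object below is invented for illustration (the real RBLNConfig lives in optimum/rbln/modeling_config.py):

    from types import SimpleNamespace

    # Hypothetical stand-in for RBLNConfig: in 0.1.11 the user-facing compile
    # options live in a plain dict, model_cfg, instead of the old meta dict
    # with "rbln_"-prefixed keys.
    rbln_config = SimpleNamespace(model_cfg={"max_seq_len": 4096, "batch_size": 1})

    def wrap_model_if_needed(model, rbln_config):
        # 0.1.11-style lookup; 0.1.9 read rbln_config.meta["rbln_max_seq_len"]
        max_seq_len = rbln_config.model_cfg["max_seq_len"]
        return model, max_seq_len  # a real hook returns a wrapper module instead

    print(wrap_model_if_needed(object(), rbln_config)[1])  # -> 4096
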
optimum/rbln/transformers/models/phi/__init__.py (new file)
@@ -0,0 +1,24 @@
+# Copyright 2024 Rebellions Inc.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Portions of this software are licensed under the Apache License,
+# Version 2.0. See the NOTICE file distributed with this work for
+# additional information regarding copyright ownership.
+
+# All other portions of this software, including proprietary code,
+# are the intellectual property of Rebellions Inc. and may not be
+# copied, modified, or distributed without prior written permission
+# from Rebellions Inc.
+
+from .modeling_phi import RBLNPhiForCausalLM
optimum/rbln/transformers/models/phi/modeling_phi.py (new file)
@@ -0,0 +1,69 @@
+# Copyright 2024 Rebellions Inc.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Portions of this software are licensed under the Apache License,
+# Version 2.0. See the NOTICE file distributed with this work for
+# additional information regarding copyright ownership.
+
+# All other portions of this software, including proprietary code,
+# are the intellectual property of Rebellions Inc. and may not be
+# copied, modified, or distributed without prior written permission
+# from Rebellions Inc.
+
+import inspect
+import logging
+from typing import TYPE_CHECKING, Any, Callable
+
+from transformers import PhiForCausalLM
+
+from ..decoderonly import RBLNDecoderOnlyModelForCausalLM
+from .phi_architecture import PhiWrapper
+
+
+if TYPE_CHECKING:
+    from transformers import PreTrainedModel
+
+    from ....modeling_config import RBLNConfig
+
+logger = logging.getLogger(__name__)
+
+
+class RBLNPhiForCausalLM(RBLNDecoderOnlyModelForCausalLM):
+    """
+    The Phi Model transformer with a language modeling head (linear layer) on top.
+    This model inherits from [`RBLNMultiModel`]. Check the superclass documentation for the generic methods the library implements for all its models.
+
+    A class to convert and run pre-trained transformers based PhiForCausalLM model on RBLN devices.
+    It implements the methods to convert a pre-trained transformers PhiForCausalLM model into a RBLN transformer model by:
+    - transferring the checkpoint weights of the original into an optimized RBLN graph,
+    - compiling the resulting graph using the RBLN compiler.
+    """
+
+    @classmethod
+    def wrap_model_if_needed(self, model: "PreTrainedModel", rbln_config: "RBLNConfig"):
+        rbln_max_seq_len = rbln_config.model_cfg["max_seq_len"]
+        return PhiWrapper(model, rbln_max_seq_len).eval()
+
+    def __getattr__(self, __name: str) -> Any:
+        def redirect(func):
+            return lambda *pargs, **kwargs: func(self, *pargs, **kwargs)
+
+        val = getattr(PhiForCausalLM, __name)
+
+        if isinstance(val, Callable) and "self" in set(
+            inspect.signature(val).parameters
+        ):
+            return redirect(val)
+
+        return val
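
The docstring above describes RBLNPhiForCausalLM as a drop-in that converts a Hugging Face PhiForCausalLM checkpoint into a compiled RBLN graph. A minimal usage sketch, assuming the usual optimum-style `from_pretrained(..., export=True)` entry point and the `rbln_max_seq_len` / `rbln_batch_size` keyword names (these kwargs are assumptions based on the library's conventions, not confirmed by this diff):

    from optimum.rbln import RBLNPhiForCausalLM
    from transformers import AutoTokenizer

    # Compile the Hugging Face checkpoint into an RBLN graph and run generation.
    # The export=True flag and the rbln_* keyword names are assumed from the
    # optimum-rbln convention; adjust to the released API if they differ.
    model = RBLNPhiForCausalLM.from_pretrained(
        "microsoft/phi-2",
        export=True,
        rbln_max_seq_len=2048,  # ends up in rbln_config.model_cfg["max_seq_len"]
        rbln_batch_size=1,
    )
    tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2")
    inputs = tokenizer("def fibonacci(n):", return_tensors="pt")
    outputs = model.generate(**inputs, max_new_tokens=32)
    print(tokenizer.decode(outputs[0], skip_special_tokens=True))
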
optimum/rbln/transformers/models/phi/phi_architecture.py (new file)
@@ -0,0 +1,406 @@
+# Copyright 2024 Rebellions Inc.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Portions of this software are licensed under the Apache License,
+# Version 2.0. See the NOTICE file distributed with this work for
+# additional information regarding copyright ownership.
+
+# All other portions of this software, including proprietary code,
+# are the intellectual property of Rebellions Inc. and may not be
+# copied, modified, or distributed without prior written permission
+# from Rebellions Inc.
+
+import math
+from typing import Dict, Optional, Tuple
+
+import torch
+import torch.nn as nn
+from transformers.modeling_outputs import (
+    BaseModelOutputWithPast,
+)
+
+from ...cache_utils import RebelDynamicCache
+from ..decoderonly import (
+    DecoderOnlyWrapper,
+    DynamicNTKScalingRotaryEmbedding,
+    LinearScalingRotaryEmbedding,
+    RotaryEmbedding,
+    apply_rotary_pos_emb,
+    slice_and_unsqueeze_cos_sin,
+)
+
+
+class PhiWrapper(DecoderOnlyWrapper):
+    def _init_rope(self):
+        if self.rope_scaling is None:
+            rotary_emb = RotaryEmbedding(
+                int(self.config.partial_rotary_factor * self.head_dim),
+                max_position_embeddings=self.max_position_embeddings,
+                base=self.config.rope_theta,
+            )
+        else:
+            scaling_type = self.rope_scaling["type"]
+            scaling_factor = self.rope_scaling["factor"]
+            if scaling_type == "linear":
+                rotary_emb = LinearScalingRotaryEmbedding(
+                    int(self.config.partial_rotary_factor * self.head_dim),
+                    max_position_embeddings=self.max_position_embeddings,
+                    scaling_factor=scaling_factor,
+                    base=self.config.rope_theta,
+                    max_seq_len=self.max_seq_len,
+                )
+            elif scaling_type == "dynamic":
+                rotary_emb = DynamicNTKScalingRotaryEmbedding(
+                    int(self.config.partial_rotary_factor * self.head_dim),
+                    max_position_embeddings=self.max_position_embeddings,
+                    scaling_factor=scaling_factor,
+                    base=self.config.rope_theta,
+                    max_seq_len=self.max_seq_len,
+                )
+            else:
+                raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
+
+        return rotary_emb
+
+    def get_forward_dict(self):
+        forward_dict = {}
+        forward_dict.update(
+            {
+                "wrapper": PhiModel.forward,
+                "model": PhiDecoderLayer.forward,
+                "decoder_layer": PhiAttention.forward,
+            }
+        )
+        return forward_dict
+
+
+class PhiAttention:
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        past_key_value: Optional[RebelDynamicCache] = None,
+        batch_index: Optional[int] = None,
+        output_attentions: bool = False,
+        cos: Optional[torch.Tensor] = None,
+        sin: Optional[torch.Tensor] = None,
+        rotary_pos_emb=None,
+        **kwargs,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+        bsz, q_len, _ = hidden_states.size()
+
+        query_states = self.q_proj(hidden_states)
+        key_states = self.k_proj(hidden_states)
+        value_states = self.v_proj(hidden_states)
+
+        if self.qk_layernorm:
+            query_states = self.q_layernorm(query_states)
+            key_states = self.k_layernorm(key_states)
+
+        query_states = query_states.view(
+            bsz, q_len, self.num_heads, self.head_dim
+        ).transpose(1, 2)
+        key_states = key_states.view(
+            bsz, q_len, self.num_key_value_heads, self.head_dim
+        ).transpose(1, 2)
+        value_states = value_states.view(
+            bsz, q_len, self.num_key_value_heads, self.head_dim
+        ).transpose(1, 2)
+
+        # Partial rotary embedding
+        query_rot, query_pass = (
+            query_states[..., : rotary_pos_emb.dim],
+            query_states[..., rotary_pos_emb.dim :],
+        )
+        key_rot, key_pass = (
+            key_states[..., : rotary_pos_emb.dim],
+            key_states[..., rotary_pos_emb.dim :],
+        )
+
+        # [batch_size, seq_length, num_heads, head_dim // config.partial_rotary_factor]
+        query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin)
+
+        # [batch_size, seq_length, num_heads, head_dim]
+        query_states = torch.cat((query_rot, query_pass), dim=-1)
+        key_states = torch.cat((key_rot, key_pass), dim=-1)
+
+        # Decoder
+        if (batch_index is None or batch_index == -1) and bsz > 1:
+            all_key_states = []
+            all_value_states = []
+            all_attn_output = []
+
+            for b in range(bsz):
+                query_state = query_states[b].unsqueeze(0)
+                attn_mask = attention_mask[b].unsqueeze(0)
+                key_state = key_states[b].unsqueeze(0)
+                value_state = value_states[b].unsqueeze(0)
+
+                # reshape for removing repeat_kv (batch=1 , num_head, 1, q_len=1, head_dim)
+                key_state = key_state.unsqueeze(2)
+                value_state = value_state.unsqueeze(2)
+                attn_mask = attn_mask.unsqueeze(2)
+
+                query_state = query_state.view(
+                    1,
+                    self.num_key_value_heads,
+                    self.num_heads // self.num_key_value_heads,
+                    q_len,
+                    self.head_dim,
+                )
+
+                key_state, value_state = past_key_value.update(
+                    key_state,
+                    value_state,
+                    self.layer_idx,
+                    b,
+                )
+
+                # Queries and keys upcast to fp32 is required by Phi-2 to avoid overflow
+                attn_weights = torch.matmul(
+                    query_state.to(torch.float32),
+                    key_state.to(torch.float32).transpose(3, 4),
+                ) / math.sqrt(self.head_dim)
+                attn_weights = attn_weights + attn_mask
+
+                # upcast attention to fp32
+                attn_weights = nn.functional.softmax(
+                    attn_weights, dim=-1, dtype=torch.float32
+                ).to(query_states.dtype)
+                attn_weights = nn.functional.dropout(
+                    attn_weights, p=self.attention_dropout, training=self.training
+                )
+                attn_output = torch.matmul(attn_weights, value_state)
+
+                # reshape for removing repeat_kv
+                attn_output = attn_output.view(1, self.num_heads, q_len, self.head_dim)
+                attn_output = attn_output.transpose(1, 2).contiguous()
+                attn_output = attn_output.reshape(
+                    1, q_len, self.num_heads * self.head_dim
+                )
+
+                all_key_states.append(key_state)
+                all_value_states.append(value_state)
+                all_attn_output.append(attn_output)
+
+            key_states = torch.cat(all_key_states, dim=0)
+            value_states = torch.cat(all_value_states, dim=0)
+            attn_output = torch.cat(all_attn_output, dim=0)
+        else:
+            if batch_index is None or batch_index == -1:
+                batch_index = 0
+
+            # reshape for removing repeat_kv
+            key_states = key_states.unsqueeze(2)
+            value_states = value_states.unsqueeze(2)
+            attention_mask = attention_mask.unsqueeze(2)
+            query_states = query_states.view(
+                1,
+                self.num_key_value_heads,
+                self.num_heads // self.num_key_value_heads,
+                q_len,
+                self.head_dim,
+            )
+
+            key_states, value_states = past_key_value.update(
+                key_states,
+                value_states,
+                self.layer_idx,
+                batch_index,
+                read_first_step=True,
+            )
+
+            # Queries and keys upcast to fp32 is required by Phi-2 to avoid overflow
+            attn_weights = torch.matmul(
+                query_states.to(torch.float32),
+                key_states.to(torch.float32).transpose(3, 4),
+            ) / math.sqrt(self.head_dim)
+            attn_weights = attn_weights + attention_mask
+
+            # upcast attention to fp32
+            attn_weights = torch.nn.functional.softmax(
+                attn_weights, dim=-1, dtype=torch.float32
+            ).to(value_states.dtype)
+            attn_weights = torch.nn.functional.dropout(
+                attn_weights, p=self.attention_dropout, training=self.training
+            )
+            attn_output = torch.matmul(attn_weights, value_states)
+
+            # reshape for removing repeat_kv
+            attn_output = attn_output.view(1, self.num_heads, q_len, self.head_dim)
+            attn_output = attn_output.transpose(1, 2).contiguous()
+            attn_output = attn_output.reshape(
+                bsz, q_len, self.num_heads * self.head_dim
+            )
+
+        attn_output = self.dense(attn_output)
+
+        if not output_attentions:
+            attn_weights = None
+
+        return attn_output, attn_weights, key_states, value_states
+
+
+class PhiDecoderLayer:
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        layer_idx: int,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_value: Optional[RebelDynamicCache] = None,
+        output_attentions: Optional[bool] = None,
+        use_cache: Optional[bool] = None,
+        batch_ids: Optional[torch.LongTensor] = None,
+        cos: Optional[torch.Tensor] = None,
+        sin: Optional[torch.Tensor] = None,
+        rotary_pos_emb=None,
+        forward_dict: Optional[Dict[str, classmethod]] = None,
+        **kwargs,
+    ) -> Tuple[
+        torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
+    ]:
+        """
+        Args:
+            hidden_states (`torch.FloatTensor`):
+                input to the layer of shape `(batch, seq_len, embed_dim)`
+            attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
+                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
+            position_ids (`torch.LongTensor` of shape `({0})`, *optional*):
+                Indices of positions of each input sequence tokens in the position embeddings. Selected in the range
+                `[0, config.n_positions - 1]`. [What are position IDs?](../glossary#position-ids)
+            output_attentions (`bool`, *optional*):
+                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+                returned tensors for more detail.
+            use_cache (`bool`, *optional*):
+                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
+                (see `past_key_values`).
+            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
+        """
+
+        residual = hidden_states
+
+        hidden_states = self.input_layernorm(hidden_states)
+
+        # Self Attention
+        attn_outputs, self_attn_weights, key_states, value_states = forward_dict[
+            "decoder_layer"
+        ](
+            self.self_attn,
+            hidden_states=hidden_states,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_value=past_key_value,
+            output_attentions=output_attentions,
+            batch_index=batch_ids,
+            use_cache=use_cache,
+            cos=cos,
+            sin=sin,
+            rotary_pos_emb=rotary_pos_emb,
+            **kwargs,
+        )
+        past_key_value.assign(key_states, value_states, layer_idx)
+
+        attn_outputs = self.resid_dropout(attn_outputs)
+
+        feed_forward_hidden_states = self.resid_dropout(self.mlp(hidden_states))
+        hidden_states = attn_outputs + feed_forward_hidden_states + residual
+        outputs = (hidden_states,)
+
+        if output_attentions:
+            outputs += (self_attn_weights,)
+
+        if use_cache:
+            outputs += (past_key_value,)
+
+        return outputs
+
+
+class PhiModel:
+    def forward(
+        self,
+        input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[RebelDynamicCache] = None,
+        batch_ids: Optional[torch.LongTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        use_cache: Optional[bool] = True,
+        output_attentions: Optional[bool] = False,
+        output_hidden_states: Optional[bool] = False,
+        forward_dict: Optional[Dict[str, classmethod]] = None,
+        rotary_pos_emb=None,
+    ) -> BaseModelOutputWithPast:
+        # retrieve input_ids and inputs_embeds
+        if (input_ids is None) ^ (inputs_embeds is not None):
+            raise ValueError(
+                "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
+            )
+
+        # embed positions
+        if inputs_embeds is None:
+            inputs_embeds = self.embed_tokens(input_ids)
+
+        hidden_states = inputs_embeds
+        attention_mask = (1 - attention_mask) * torch.finfo(torch.float16).min
+
+        # get cos,sin vector
+        cos, sin = rotary_pos_emb(inputs_embeds, attention_mask.shape[-1])
+        cos, sin = slice_and_unsqueeze_cos_sin(cos, sin, position_ids)
+
+        # decoder layers
+        all_hidden_states = () if output_hidden_states else None
+        all_self_attns = () if output_attentions else None
+
+        for layer_idx, decoder_layer in enumerate(self.layers):
+            if output_hidden_states:
+                all_hidden_states += (hidden_states,)
+            layer_outputs = forward_dict["model"](
+                decoder_layer,
+                hidden_states,
+                layer_idx,
+                attention_mask=attention_mask,
+                position_ids=position_ids,
+                past_key_value=past_key_values,
+                output_attentions=output_attentions,
+                use_cache=use_cache,
+                batch_ids=batch_ids,
+                cos=cos,
+                sin=sin,
+                rotary_pos_emb=rotary_pos_emb,
+                forward_dict=forward_dict,
+            )
+
+            hidden_states = layer_outputs[0]
+
+            updated_cache = layer_outputs[2 if output_attentions else 1]
+
+            if output_attentions:
+                all_self_attns += (layer_outputs[1],)
+
+        hidden_states = self.final_layernorm(hidden_states)
+
+        # add hidden states from the last decoder layer
+        if output_hidden_states:
+            all_hidden_states += (hidden_states,)
+
+        # convert RebelDynamicCache to legacy Tuple[Tuple[torch.Tensor]]
+        next_cache = updated_cache.to_legacy_cache()
+
+        return BaseModelOutputWithPast(
+            last_hidden_state=hidden_states,
+            past_key_values=next_cache,
+            hidden_states=all_hidden_states,
+            attentions=all_self_attns,
+        )
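
get_forward_dict in the hunk above swaps the stock Phi forwards for the RBLN-friendly ones defined in this file: the dictionary entries are unbound functions, and the caller passes the original Hugging Face module explicitly as the first argument, so the module's weights are reused while the replacement forward logic runs. A minimal sketch of that dispatch pattern; both classes below are stand-ins invented for illustration:

    # The dict maps a role name to an unbound forward; calling it with the stock
    # module as the first argument binds `self` to that module, exactly as the
    # forward_dict["decoder_layer"](self.self_attn, ...) call does above.
    class StockAttention:  # plays the role of transformers' PhiAttention module
        def __init__(self):
            self.scale = 2.0

    class PatchedAttention:  # plays the role of the PhiAttention class in this file
        def forward(self, x):  # `self` is the stock module, not PatchedAttention
            return x * self.scale

    forward_dict = {"decoder_layer": PatchedAttention.forward}

    module = StockAttention()
    print(forward_dict["decoder_layer"](module, 3.0))  # -> 6.0
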