optimum-rbln 0.1.4__py3-none-any.whl → 0.1.7__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
Files changed (31)
  1. optimum/rbln/__init__.py +7 -0
  2. optimum/rbln/__version__.py +1 -1
  3. optimum/rbln/diffusers/models/autoencoder_kl.py +16 -98
  4. optimum/rbln/diffusers/models/unet_2d_condition.py +1 -1
  5. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +9 -11
  6. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +8 -0
  7. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +9 -0
  8. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +3 -0
  9. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +8 -0
  10. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +8 -0
  11. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +9 -0
  12. optimum/rbln/modeling_base.py +172 -100
  13. optimum/rbln/modeling_seq2seq.py +58 -132
  14. optimum/rbln/transformers/__init__.py +2 -0
  15. optimum/rbln/transformers/models/__init__.py +1 -0
  16. optimum/rbln/transformers/models/clip/modeling_clip.py +0 -1
  17. optimum/rbln/transformers/models/dpt/__init__.py +24 -0
  18. optimum/rbln/transformers/models/dpt/modeling_dpt.py +89 -0
  19. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +24 -33
  20. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +52 -124
  21. optimum/rbln/transformers/models/llama/llama_architecture.py +13 -16
  22. optimum/rbln/transformers/models/llama/llama_architecture_cb.py +41 -36
  23. optimum/rbln/transformers/models/llama/modeling_llama.py +94 -120
  24. optimum/rbln/transformers/models/midm/modeling_midm.py +85 -121
  25. optimum/rbln/transformers/models/whisper/modeling_whisper.py +53 -123
  26. optimum/rbln/utils/__init__.py +1 -1
  27. optimum/rbln/utils/import_utils.py +46 -0
  28. {optimum_rbln-0.1.4.dist-info → optimum_rbln-0.1.7.dist-info}/METADATA +17 -51
  29. {optimum_rbln-0.1.4.dist-info → optimum_rbln-0.1.7.dist-info}/RECORD +31 -29
  30. {optimum_rbln-0.1.4.dist-info → optimum_rbln-0.1.7.dist-info}/WHEEL +1 -1
  31. {optimum_rbln-0.1.4.dist-info → optimum_rbln-0.1.7.dist-info}/licenses/LICENSE +0 -0
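
The version bump itself is the one-line change in `optimum/rbln/__version__.py`. A quick post-upgrade sanity check might look like the following (a sketch; it assumes the wheel is installed from a public registry, e.g. `pip install optimum-rbln==0.1.7`):

```python
# Hedged sanity check: confirm the installed optimum-rbln matches this diff's target version.
# Assumes optimum/rbln/__version__.py defines __version__, as the one-line change above suggests.
from optimum.rbln.__version__ import __version__

assert __version__ == "0.1.7", f"unexpected version: {__version__}"
```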
optimum/rbln/transformers/models/gpt2/modeling_gpt2.py
@@ -23,20 +23,16 @@

  import inspect
  import logging
- from pathlib import Path
- from tempfile import TemporaryDirectory
  from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union

  import rebel
  import torch
- from optimum.exporters import TasksManager
  from transformers import AutoModelForCausalLM, GPT2LMHeadModel, PretrainedConfig
  from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions, Seq2SeqLMOutput

- from ....modeling_base import RBLNBaseModel
+ from ....modeling_base import RBLNModel
  from ....modeling_config import DEFAULT_COMPILED_MODEL_NAME, RBLNConfig, RBLNRuntimeConfig
  from ....utils.runtime_utils import RBLNPytorchRuntime
- from ....utils.save_utils import maybe_save_preprocessors
  from ...generation.utils import RBLNGenerationMixin
  from .gpt2_architecture import GPT2LMHeadModelWrapper

@@ -59,12 +55,12 @@ class RBLNRuntimeDecoder(RBLNPytorchRuntime):
          return Seq2SeqLMOutput(logits=logits)


- class RBLNGPT2LMHeadModel(RBLNBaseModel, RBLNGenerationMixin):
+ class RBLNGPT2LMHeadModel(RBLNModel, RBLNGenerationMixin):
      """
      The GPT2 Model transformer with a language modeling head on top (linear layer with weights tied to the input
      embeddings).

-     This model inherits from [`RBLNBaseModel`]. Check the superclass documentation for the generic methods the
+     This model inherits from [`RBLNMultiModel`]. Check the superclass documentation for the generic methods the
      library implements for all its model.

      It implements the methods to convert a pre-trained transformers GPT2 model into a RBLN transformer model by:
@@ -89,8 +85,8 @@ class RBLNGPT2LMHeadModel(RBLNBaseModel, RBLNGenerationMixin):
              torch.ones(batch_size, 1, self.prefill_chunk_size, self.prefill_chunk_size), diagonal=1
          )

-         self.prefill_decoder = RBLNRuntimeDecoder(runtime=self.runtimes[0])
-         self.decoder = RBLNRuntimeDecoder(runtime=self.runtimes[1])
+         self.prefill_decoder = RBLNRuntimeDecoder(runtime=self.model[0])
+         self.decoder = RBLNRuntimeDecoder(runtime=self.model[1])
          self.pad_token_id = self.rbln_config.meta["rbln_pad_token_id"]
          self.past_cached_length = 0

@@ -117,38 +113,7 @@ class RBLNGPT2LMHeadModel(RBLNBaseModel, RBLNGenerationMixin):
          raise NotImplementedError

      @classmethod
-     def _export(
-         cls,
-         model_id: str,
-         config: "PretrainedConfig",
-         use_auth_token: Optional[Union[bool, str]] = None,
-         revision: Optional[str] = None,
-         force_download: bool = False,
-         cache_dir: Optional[str] = None,
-         subfolder: str = "",
-         local_files_only: bool = False,
-         trust_remote_code: bool = False,
-         model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
-         **kwargs,
-     ) -> "RBLNGPT2LMHeadModel":
-         """
-         Exports a vanilla Transformers model into a rbln-compiled Module.
-         """
-         task = kwargs.pop("task", None)
-         if task is None:
-             task = TasksManager.infer_task_from_model(cls.auto_model_class)
-
-         if model_save_dir is None:
-             save_dir = TemporaryDirectory()
-             save_dir_path = Path(save_dir.name)
-         else:
-             save_dir = model_save_dir
-             if isinstance(save_dir, TemporaryDirectory):
-                 save_dir_path = Path(model_save_dir.name)
-             else:
-                 save_dir_path = Path(model_save_dir)
-             save_dir_path.mkdir(exist_ok=True)
-
+     def update_kwargs(cls, kwargs):
          kwargs.update(
              {
                  "torchscript": True,
@@ -156,82 +121,45 @@ class RBLNGPT2LMHeadModel(RBLNBaseModel, RBLNGenerationMixin):
                  "use_cache": True,
              }
          )
+         return kwargs

-         rbln_config_kwargs, rbln_constructor_kwargs = cls.pop_rbln_kwargs_from_kwargs(kwargs)
-
-         model: GPT2LMHeadModel = TasksManager.get_model_from_task(
-             task=task,
-             model_name_or_path=model_id,
-             subfolder=subfolder,
-             revision=revision,
-             framework="pt",
-             cache_dir=cache_dir,
-             use_auth_token=use_auth_token,
-             local_files_only=local_files_only,
-             force_download=force_download,
-             trust_remote_code=trust_remote_code,
-             **kwargs,
-         )
-
-         if config is None:
-             config = model.config
-
-         config.save_pretrained(save_dir_path)
-         preprocessors = maybe_save_preprocessors(model_id, save_dir_path, src_subfolder=subfolder)
-
-         # Get compilation arguments
-         if rbln_config_kwargs.get("rbln_config", None) is None:
-             rbln_config = cls.get_rbln_config(
-                 preprocessors=preprocessors, model_config=model.config, **rbln_config_kwargs
-             )
-
-         def compile_gpt2():
-             wrapped_decoder = GPT2LMHeadModelWrapper(model).eval()
-
-             prefill_rbln_runtime_config = rbln_config[DEFAULT_COMPILED_MODEL_NAME][0]
-             dec_rbln_runtime_config = rbln_config[DEFAULT_COMPILED_MODEL_NAME][1]
+     @classmethod
+     @torch.inference_mode()
+     def get_compiled_model(cls, model: GPT2LMHeadModel, rbln_config: RBLNConfig):
+         wrapped_decoder = GPT2LMHeadModelWrapper(model).eval()

-             prefill_example_inputs = prefill_rbln_runtime_config.get_dummy_inputs(fill=0)
-             dec_example_inputs = dec_rbln_runtime_config.get_dummy_inputs(fill=0)
+         prefill_rbln_runtime_config = rbln_config[DEFAULT_COMPILED_MODEL_NAME][0]
+         dec_rbln_runtime_config = rbln_config[DEFAULT_COMPILED_MODEL_NAME][1]

-             prefill_scripted_model = torch.jit.trace(wrapped_decoder, prefill_example_inputs)
-             dec_scripted_model = torch.jit.trace(wrapped_decoder, dec_example_inputs)
+         prefill_example_inputs = prefill_rbln_runtime_config.get_dummy_inputs(fill=0)
+         dec_example_inputs = dec_rbln_runtime_config.get_dummy_inputs(fill=0)

-             prefill_ir = rebel.torchscript_to_ir(
-                 prefill_scripted_model,
-                 input_names=[v[0] for v in prefill_rbln_runtime_config.input_info],
-             )
-             dec_ir = rebel.torchscript_to_ir(
-                 dec_scripted_model,
-                 input_names=[v[0] for v in dec_rbln_runtime_config.input_info],
-             )
-
-             connections = [
-                 (prefill_ir.outputs[1], prefill_ir.inputs[1]),
-             ]
+         prefill_scripted_model = torch.jit.trace(wrapped_decoder, prefill_example_inputs, check_trace=False)
+         dec_scripted_model = torch.jit.trace(wrapped_decoder, dec_example_inputs, check_trace=False)

-             compiled_model = rebel.compile(
-                 prefill_ir,
-                 dec_ir,
-                 connections=connections,
-                 fusion=prefill_rbln_runtime_config.fusion,
-                 npu=prefill_rbln_runtime_config.npu,
-                 tensor_parallel_size=prefill_rbln_runtime_config.tensor_parallel_size,
-                 use_weight_sharing=True,
-             )
-             compiled_model.save(save_dir_path / f"{DEFAULT_COMPILED_MODEL_NAME}.rbln")
+         prefill_ir = rebel.torchscript_to_ir(
+             prefill_scripted_model,
+             input_names=[v[0] for v in prefill_rbln_runtime_config.input_info],
+         )
+         dec_ir = rebel.torchscript_to_ir(
+             dec_scripted_model,
+             input_names=[v[0] for v in dec_rbln_runtime_config.input_info],
+         )

-         compile_gpt2()
-         rbln_config.save(save_dir_path)
+         connections = [(prefill_ir.outputs[1 + i], prefill_ir.inputs[3 + i]) for i in range(model.config.n_layer * 2)]

-         return cls._from_pretrained(
-             model_id=save_dir_path,
-             config=config,
-             model_save_dir=save_dir,
-             **rbln_constructor_kwargs,
-             **kwargs,
+         compiled_model = rebel.compile(
+             prefill_ir,
+             dec_ir,
+             connections=connections,
+             fusion=prefill_rbln_runtime_config.fusion,
+             npu=prefill_rbln_runtime_config.npu,
+             tensor_parallel_size=prefill_rbln_runtime_config.tensor_parallel_size,
+             use_weight_sharing=True,
          )

+         return compiled_model
+
      @classmethod
      def _get_rbln_config(
          cls,
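
With this change the GPT-2 specific `_export` path is gone: the subclass now only supplies `update_kwargs` and `get_compiled_model`, and the generic load/save/runtime flow is inherited from `RBLNModel` (see `modeling_base.py` in the file list). A hedged usage sketch of the export flow follows; `export=True` and the exact `rbln_*` keyword names are assumptions based on the optimum-style API, not copied from this diff:

```python
# Hedged usage sketch: exporting a vanilla GPT-2 checkpoint through the refactored path.
# `export=True` and the rbln_* keyword names are assumptions (optimum-style API); check
# the optimum-rbln docs for the exact arguments accepted by from_pretrained().
from optimum.rbln import RBLNGPT2LMHeadModel

model = RBLNGPT2LMHeadModel.from_pretrained(
    "gpt2",                 # plain Transformers checkpoint
    export=True,            # trace + compile via get_compiled_model() above
    rbln_max_seq_len=1024,  # static sequence length baked into the KV-cache shapes
    rbln_batch_size=1,
)
model.save_pretrained("gpt2-rbln")  # persists the compiled .rbln artifact and config
```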
@@ -271,24 +199,24 @@ class RBLNGPT2LMHeadModel(RBLNBaseModel, RBLNGenerationMixin):
          def get_input_info(query_length):
              return [
                  ("input_ids", [rbln_batch_size, query_length], "int64"),
+                 ("attention_mask", [rbln_batch_size, 1, query_length, rbln_max_seq_len], "int64"),
+                 (
+                     "cache_position",
+                     [],
+                     "int32",
+                 ),
+             ] + [
                  (
-                     "past_key_values",
+                     f"past_key_values_{i}",
                      [
-                         model_config.n_layer,
-                         2,
                          rbln_batch_size,
                          model_config.n_head,
                          rbln_max_seq_len,
                          model_config.hidden_size // model_config.n_head,
                      ],
                      "float32",
-                 ),
-                 ("attention_mask", [rbln_batch_size, 1, query_length, rbln_max_seq_len], "int64"),
-                 (
-                     "cache_position",
-                     [],
-                     "int32",
-                 ),
+                 )
+                 for i in range(model_config.n_layer * 2)
              ]

          # model input info
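
The single stacked `past_key_values` input (shape `[n_layer, 2, batch, n_head, max_seq_len, head_dim]`) is split into `n_layer * 2` per-layer key/value inputs, which is what allows the per-tensor `connections` wiring in `get_compiled_model` above. A standalone sketch of the resulting input layout for a toy GPT-2-like config (the numbers are illustrative; the shapes follow the `get_input_info` change):

```python
# Standalone sketch of the new per-layer KV-cache input layout.
from types import SimpleNamespace

model_config = SimpleNamespace(n_layer=2, n_head=12, hidden_size=768)
rbln_batch_size, rbln_max_seq_len, query_length = 1, 1024, 1

input_info = [
    ("input_ids", [rbln_batch_size, query_length], "int64"),
    ("attention_mask", [rbln_batch_size, 1, query_length, rbln_max_seq_len], "int64"),
    ("cache_position", [], "int32"),
] + [
    (
        f"past_key_values_{i}",
        [
            rbln_batch_size,
            model_config.n_head,
            rbln_max_seq_len,
            model_config.hidden_size // model_config.n_head,
        ],
        "float32",
    )
    for i in range(model_config.n_layer * 2)  # one key and one value tensor per layer
]

for name, shape, dtype in input_info:
    print(name, shape, dtype)
```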
@@ -305,11 +233,14 @@ class RBLNGPT2LMHeadModel(RBLNBaseModel, RBLNGenerationMixin):

          return rbln_config

-     def _create_runtimes(self, rbln_device_map: Dict[str, int]) -> List[rebel.Runtime]:
+     @classmethod
+     def _create_runtimes(
+         cls, compiled_models: List[rebel.RBLNCompiledModel], rbln_device_map: Dict[str, int]
+     ) -> List[rebel.Runtime]:
          device_val = rbln_device_map[DEFAULT_COMPILED_MODEL_NAME]
          return [
-             self.compiled_models[0].create_runtime(input_info_index=0, tensor_type="pt", device=device_val),
-             self.compiled_models[0].create_runtime(input_info_index=1, tensor_type="pt", device=device_val),
+             compiled_models[0].create_runtime(input_info_index=0, tensor_type="pt", device=device_val),
+             compiled_models[0].create_runtime(input_info_index=1, tensor_type="pt", device=device_val),
          ]

      def prepare_inputs_for_generation(self, input_ids, past_key_values=0, attention_mask=None, **kwargs):
@@ -386,6 +317,3 @@ class RBLNGPT2LMHeadModel(RBLNBaseModel, RBLNGenerationMixin):
          output = output.logits

          return CausalLMOutputWithCrossAttentions(logits=output, past_key_values=past_key_values)
-
-     def __repr__(self):
-         return repr(self.runtimes[0]) + "\n" + repr(self.runtimes[1])
optimum/rbln/transformers/models/llama/llama_architecture.py
@@ -107,7 +107,6 @@ class _LlamaAttention(LlamaAttention):
          use_cache: bool = False,
          **kwargs,
      ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-
          bsz, q_len, _ = hidden_states.size()

          if self.config.pretraining_tp > 1:
@@ -227,7 +226,6 @@ class _LlamaDecoderLayer(LlamaDecoderLayer):
          use_cache: Optional[bool] = False,
          **kwargs,
      ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
-
          residual = hidden_states

          hidden_states = self.input_layernorm(hidden_states)
@@ -414,7 +412,6 @@ class _LlamaForCausalLM(LlamaForCausalLM):
          output_hidden_states: Optional[bool] = None,
          return_dict: Optional[bool] = None,
      ) -> Union[Tuple, CausalLMOutputWithPast]:
-
          output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
          output_hidden_states = (
              output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@@ -617,23 +614,23 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
      return q_embed, k_embed


- def wrap_llama():
-     origin_mehtods = {}
-     origin_mehtods["LlamaRotaryEmbedding_INIT"] = LlamaRotaryEmbedding.__init__
-     origin_mehtods["LlamaRotaryEmbedding_forward"] = LlamaRotaryEmbedding.forward
-     origin_mehtods["LlamaModel_forward"] = LlamaModel.forward
-     origin_mehtods["LlamaForCausalLM_forward"] = LlamaForCausalLM.forward
+ origin_methods = {}
+ origin_methods["LlamaRotaryEmbedding_INIT"] = LlamaRotaryEmbedding.__init__
+ origin_methods["LlamaRotaryEmbedding_forward"] = LlamaRotaryEmbedding.forward
+ origin_methods["LlamaModel_forward"] = LlamaModel.forward
+ origin_methods["LlamaForCausalLM_forward"] = LlamaForCausalLM.forward

+
+ def wrap_llama():
      LlamaRotaryEmbedding.__init__ = _LlamaRotaryEmbedding.__init__
      LlamaRotaryEmbedding.forward = _LlamaRotaryEmbedding.forward
      LlamaModel.forward = _LlamaModel.forward
      LlamaForCausalLM.forward = _LlamaForCausalLM.forward

-     return origin_mehtods
-

- def unwrap_llama(origin_mehtods):
-     LlamaRotaryEmbedding.__init__ = origin_mehtods["LlamaRotaryEmbedding_INIT"]
-     LlamaRotaryEmbedding.forward = origin_mehtods["LlamaRotaryEmbedding_forward"]
-     LlamaModel.forward = origin_mehtods["LlamaModel_forward"]
-     LlamaForCausalLM.forward = origin_mehtods["LlamaForCausalLM_forward"]
+ def unwrap_llama():
+     global origin_methods
+     LlamaRotaryEmbedding.__init__ = origin_methods["LlamaRotaryEmbedding_INIT"]
+     LlamaRotaryEmbedding.forward = origin_methods["LlamaRotaryEmbedding_forward"]
+     LlamaModel.forward = origin_methods["LlamaModel_forward"]
+     LlamaForCausalLM.forward = origin_methods["LlamaForCausalLM_forward"]
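
`wrap_llama`/`unwrap_llama` now rely on a module-level `origin_methods` dict captured at import time, so callers no longer have to hold on to the return value of `wrap_llama` in order to restore the originals. A minimal, self-contained sketch of that save/patch/restore pattern (the `Target` class and `_patched_forward` are illustrative stand-ins, not optimum-rbln symbols):

```python
# Minimal sketch of the module-level save/patch/restore pattern used above.
class Target:
    def forward(self):
        return "original"


def _patched_forward(self):
    return "patched"


origin_methods = {"Target_forward": Target.forward}  # captured once at import time


def wrap_target():
    Target.forward = _patched_forward


def unwrap_target():
    Target.forward = origin_methods["Target_forward"]


wrap_target()
try:
    assert Target().forward() == "patched"  # patched while tracing/compiling
finally:
    unwrap_target()                          # always restore the originals
assert Target().forward() == "original"
```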
optimum/rbln/transformers/models/llama/llama_architecture_cb.py
@@ -118,6 +118,9 @@ class _LlamaAttention(LlamaAttention):
          batch_index: Optional[int] = None,
          output_attentions: bool = False,
          use_cache: bool = False,
+         cos: Optional[torch.Tensor] = None,
+         sin: Optional[torch.Tensor] = None,
+         layer_id: int = 0,
          **kwargs,
      ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
          bsz, q_len, _ = hidden_states.size()
@@ -156,8 +159,11 @@ class _LlamaAttention(LlamaAttention):
                      "with a layer index."
                  )
              kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
-         cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
-         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
+         if layer_id == 0:
+             cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+         query_states, key_states, cos, sin = apply_rotary_pos_emb(
+             query_states, key_states, cos, sin, position_ids, layer_id
+         )
          if past_key_value is not None:
              cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE models
              if (batch_index is None or batch_index == -1) and bsz > 1:
@@ -261,7 +267,7 @@ class _LlamaAttention(LlamaAttention):
          if not output_attentions:
              attn_weights = None

-         return attn_output, attn_weights, key_states, value_states
+         return attn_output, attn_weights, key_states, value_states, cos, sin


  class _LlamaDecoderLayer(LlamaDecoderLayer):
@@ -275,6 +281,9 @@ class _LlamaDecoderLayer(LlamaDecoderLayer):
          output_attentions: Optional[bool] = False,
          use_cache: Optional[bool] = False,
          batch_ids: Optional[torch.LongTensor] = None,
+         cos: Optional[torch.Tensor] = None,
+         sin: Optional[torch.Tensor] = None,
+         layer_id: int = 0,
          **kwargs,
      ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
          residual = hidden_states
@@ -282,7 +291,7 @@ class _LlamaDecoderLayer(LlamaDecoderLayer):
          hidden_states = self.input_layernorm(hidden_states)
          bsz, _, _ = hidden_states.size()

-         hidden_states, self_attn_weights, k, v = _LlamaAttention.forward(
+         hidden_states, self_attn_weights, k, v, cos, sin = _LlamaAttention.forward(
              self.self_attn,
              hidden_states=hidden_states,
              attention_mask=attention_mask,
@@ -291,6 +300,9 @@ class _LlamaDecoderLayer(LlamaDecoderLayer):
              output_attentions=output_attentions,
              batch_index=batch_ids,
              use_cache=use_cache,
+             cos=cos,
+             sin=sin,
+             layer_id=layer_id,
              **kwargs,
          )
          past_key_value.assign(k, v, layer_idx)
@@ -313,7 +325,7 @@ class _LlamaDecoderLayer(LlamaDecoderLayer):
          if use_cache:
              outputs += (present_key_value,)

-         return outputs
+         return outputs, cos, sin


  class _LlamaModel(LlamaModel):
@@ -415,10 +427,11 @@ class _LlamaModel(LlamaModel):
          all_self_attns = () if output_attentions else None
          next_decoder_cache = () if use_cache else None

+         cos = None
+         sin = None
          for layer_idx, decoder_layer in enumerate(self.layers):
              if output_hidden_states:
                  all_hidden_states += (hidden_states,)
-
              layer_outputs = _LlamaDecoderLayer.forward(
                  decoder_layer,
                  hidden_states,
@@ -429,7 +442,13 @@ class _LlamaModel(LlamaModel):
                  output_attentions=output_attentions,
                  use_cache=use_cache,
                  batch_ids=batch_ids,
+                 cos=cos,
+                 sin=sin,
+                 layer_id=layer_idx,
              )
+             cos = layer_outputs[-2]
+             sin = layer_outputs[-1]
+             layer_outputs = layer_outputs[0]

              hidden_states = layer_outputs[0]

@@ -697,7 +716,7 @@ def rotate_half(x):
      return torch.cat((-x2, x1), dim=-1)


- def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, layer_id, unsqueeze_dim=1):
      """Applies Rotary Position Embedding to the query and key tensors.

      Args:
@@ -718,42 +737,28 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
      Returns:
          `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
      """
-     if position_ids.shape[0] > 1:
-         cos_all = []
-         sin_all = []
-         for i in range(position_ids.shape[0]):
-             cos_all.append(cos[position_ids[i : i + 1]].unsqueeze(unsqueeze_dim))
-             sin_all.append(sin[position_ids[i : i + 1]].unsqueeze(unsqueeze_dim))
-         cos = torch.cat(cos_all, dim=0)
-         sin = torch.cat(sin_all, dim=0)
-     else:
-         cos = cos[position_ids].unsqueeze(unsqueeze_dim)
-         sin = sin[position_ids].unsqueeze(unsqueeze_dim)
-     # cos = cos[position_ids].unsqueeze(unsqueeze_dim)
-     # sin = sin[position_ids].unsqueeze(unsqueeze_dim)
+     if layer_id == 0:
+         if position_ids.shape[0] > 1:
+             cos_all = []
+             sin_all = []
+             for i in range(position_ids.shape[0]):
+                 cos_all.append(cos[position_ids[i : i + 1]].unsqueeze(unsqueeze_dim))
+                 sin_all.append(sin[position_ids[i : i + 1]].unsqueeze(unsqueeze_dim))
+             cos = torch.cat(cos_all, dim=0)
+             sin = torch.cat(sin_all, dim=0)
+         else:
+             cos = cos[position_ids].unsqueeze(unsqueeze_dim)
+             sin = sin[position_ids].unsqueeze(unsqueeze_dim)
+         # cos = cos[position_ids].unsqueeze(unsqueeze_dim)
+         # sin = sin[position_ids].unsqueeze(unsqueeze_dim)

      q_embed = (q * cos) + (rotate_half(q) * sin)
      k_embed = (k * cos) + (rotate_half(k) * sin)
-     return q_embed, k_embed
+     return q_embed, k_embed, cos, sin


  def wrap_llama():
-     origin_mehtods = {}
-     origin_mehtods["LlamaRotaryEmbedding_INIT"] = LlamaRotaryEmbedding.__init__
-     origin_mehtods["LlamaRotaryEmbedding_forward"] = LlamaRotaryEmbedding.forward
-     origin_mehtods["LlamaModel_forward"] = LlamaModel.forward
-     origin_mehtods["LlamaForCausalLM_forward"] = LlamaForCausalLM.forward
-
      LlamaRotaryEmbedding.__init__ = _LlamaRotaryEmbedding.__init__
      LlamaRotaryEmbedding.forward = _LlamaRotaryEmbedding.forward
      LlamaModel.forward = _LlamaModel.forward
      LlamaForCausalLM.forward = _LlamaForCausalLM.forward
-
-     return origin_mehtods
-
-
- def unwrap_llama(origin_mehtods):
-     LlamaRotaryEmbedding.__init__ = origin_mehtods["LlamaRotaryEmbedding_INIT"]
-     LlamaRotaryEmbedding.forward = origin_mehtods["LlamaRotaryEmbedding_forward"]
-     LlamaModel.forward = origin_mehtods["LlamaModel_forward"]
-     LlamaForCausalLM.forward = origin_mehtods["LlamaForCausalLM_forward"]
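
In the continuous-batching variant, the rotary `cos`/`sin` tables are now computed only in layer 0 and threaded through the remaining decoder layers via the new `cos`, `sin`, and `layer_id` arguments and the extended return tuples. A small self-contained sketch of that gating, mirroring the new `apply_rotary_pos_emb` signature (toy shapes, not the package's code):

```python
# Toy sketch of the layer-0 cos/sin gating introduced above (shapes are illustrative).
import torch


def rotate_half(x):
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin, position_ids, layer_id, unsqueeze_dim=1):
    if layer_id == 0:                      # only layer 0 gathers/reshapes the tables
        cos = cos[position_ids].unsqueeze(unsqueeze_dim)
        sin = sin[position_ids].unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed, cos, sin      # later layers reuse the returned cos/sin


bsz, n_head, seq, head_dim, n_layer = 1, 2, 4, 8, 3
cos_table = torch.randn(seq, head_dim)     # stand-in for the rotary embedding cache
sin_table = torch.randn(seq, head_dim)
position_ids = torch.arange(seq).unsqueeze(0)

cos = sin = None
for layer_id in range(n_layer):
    q = torch.randn(bsz, n_head, seq, head_dim)  # per-layer query/key projections
    k = torch.randn(bsz, n_head, seq, head_dim)
    if layer_id == 0:
        cos, sin = cos_table, sin_table          # what self.rotary_emb(...) would return
    q, k, cos, sin = apply_rotary_pos_emb(q, k, cos, sin, position_ids, layer_id)
```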