optimum-rbln 0.1.1__py3-none-any.whl → 0.1.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (37)
  1. optimum/rbln/__init__.py +9 -0
  2. optimum/rbln/__version__.py +1 -1
  3. optimum/rbln/diffusers/models/autoencoder_kl.py +16 -98
  4. optimum/rbln/diffusers/models/unet_2d_condition.py +1 -1
  5. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +9 -11
  6. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +8 -0
  7. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +9 -0
  8. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +3 -0
  9. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +8 -0
  10. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +8 -0
  11. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +9 -0
  12. optimum/rbln/modeling_base.py +175 -103
  13. optimum/rbln/modeling_seq2seq.py +58 -132
  14. optimum/rbln/transformers/__init__.py +4 -0
  15. optimum/rbln/transformers/models/__init__.py +2 -0
  16. optimum/rbln/transformers/models/clip/modeling_clip.py +0 -1
  17. optimum/rbln/transformers/models/dpt/__init__.py +24 -0
  18. optimum/rbln/transformers/models/dpt/modeling_dpt.py +89 -0
  19. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +24 -33
  20. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +52 -124
  21. optimum/rbln/transformers/models/llama/llama_architecture.py +62 -33
  22. optimum/rbln/transformers/models/llama/llama_architecture_cb.py +764 -0
  23. optimum/rbln/transformers/models/llama/modeling_llama.py +208 -140
  24. optimum/rbln/transformers/models/midm/__init__.py +32 -0
  25. optimum/rbln/transformers/models/midm/hf_hub_cached/configuration_midm.py +22 -0
  26. optimum/rbln/transformers/models/midm/hf_hub_cached/midm_bitext_tokenization.py +303 -0
  27. optimum/rbln/transformers/models/midm/hf_hub_cached/modeling_midm.py +1473 -0
  28. optimum/rbln/transformers/models/midm/hf_hub_cached/rotary_position_embedding.py +98 -0
  29. optimum/rbln/transformers/models/midm/midm_architecture.py +506 -0
  30. optimum/rbln/transformers/models/midm/modeling_midm.py +390 -0
  31. optimum/rbln/transformers/models/whisper/modeling_whisper.py +53 -123
  32. optimum/rbln/utils/__init__.py +1 -1
  33. optimum/rbln/utils/import_utils.py +46 -0
  34. {optimum_rbln-0.1.1.dist-info → optimum_rbln-0.1.7.dist-info}/METADATA +17 -50
  35. {optimum_rbln-0.1.1.dist-info → optimum_rbln-0.1.7.dist-info}/RECORD +37 -27
  36. {optimum_rbln-0.1.1.dist-info → optimum_rbln-0.1.7.dist-info}/WHEEL +1 -1
  37. {optimum_rbln-0.1.1.dist-info → optimum_rbln-0.1.7.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/transformers/models/gpt2/modeling_gpt2.py

@@ -23,20 +23,16 @@

  import inspect
  import logging
- from pathlib import Path
- from tempfile import TemporaryDirectory
  from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union

  import rebel
  import torch
- from optimum.exporters import TasksManager
  from transformers import AutoModelForCausalLM, GPT2LMHeadModel, PretrainedConfig
  from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions, Seq2SeqLMOutput

- from ....modeling_base import RBLNBaseModel
+ from ....modeling_base import RBLNModel
  from ....modeling_config import DEFAULT_COMPILED_MODEL_NAME, RBLNConfig, RBLNRuntimeConfig
  from ....utils.runtime_utils import RBLNPytorchRuntime
- from ....utils.save_utils import maybe_save_preprocessors
  from ...generation.utils import RBLNGenerationMixin
  from .gpt2_architecture import GPT2LMHeadModelWrapper

@@ -59,12 +55,12 @@ class RBLNRuntimeDecoder(RBLNPytorchRuntime):
          return Seq2SeqLMOutput(logits=logits)


- class RBLNGPT2LMHeadModel(RBLNBaseModel, RBLNGenerationMixin):
+ class RBLNGPT2LMHeadModel(RBLNModel, RBLNGenerationMixin):
      """
      The GPT2 Model transformer with a language modeling head on top (linear layer with weights tied to the input
      embeddings).

-     This model inherits from [`RBLNBaseModel`]. Check the superclass documentation for the generic methods the
+     This model inherits from [`RBLNMultiModel`]. Check the superclass documentation for the generic methods the
      library implements for all its model.

      It implements the methods to convert a pre-trained transformers GPT2 model into a RBLN transformer model by:
@@ -89,8 +85,8 @@ class RBLNGPT2LMHeadModel(RBLNBaseModel, RBLNGenerationMixin):
              torch.ones(batch_size, 1, self.prefill_chunk_size, self.prefill_chunk_size), diagonal=1
          )

-         self.prefill_decoder = RBLNRuntimeDecoder(runtime=self.runtimes[0])
-         self.decoder = RBLNRuntimeDecoder(runtime=self.runtimes[1])
+         self.prefill_decoder = RBLNRuntimeDecoder(runtime=self.model[0])
+         self.decoder = RBLNRuntimeDecoder(runtime=self.model[1])
          self.pad_token_id = self.rbln_config.meta["rbln_pad_token_id"]
          self.past_cached_length = 0

@@ -117,38 +113,7 @@ class RBLNGPT2LMHeadModel(RBLNBaseModel, RBLNGenerationMixin):
          raise NotImplementedError

      @classmethod
-     def _export(
-         cls,
-         model_id: str,
-         config: "PretrainedConfig",
-         use_auth_token: Optional[Union[bool, str]] = None,
-         revision: Optional[str] = None,
-         force_download: bool = False,
-         cache_dir: Optional[str] = None,
-         subfolder: str = "",
-         local_files_only: bool = False,
-         trust_remote_code: bool = False,
-         model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
-         **kwargs,
-     ) -> "RBLNGPT2LMHeadModel":
-         """
-         Exports a vanilla Transformers model into a rbln-compiled Module.
-         """
-         task = kwargs.pop("task", None)
-         if task is None:
-             task = TasksManager.infer_task_from_model(cls.auto_model_class)
-
-         if model_save_dir is None:
-             save_dir = TemporaryDirectory()
-             save_dir_path = Path(save_dir.name)
-         else:
-             save_dir = model_save_dir
-             if isinstance(save_dir, TemporaryDirectory):
-                 save_dir_path = Path(model_save_dir.name)
-             else:
-                 save_dir_path = Path(model_save_dir)
-             save_dir_path.mkdir(exist_ok=True)
-
+     def update_kwargs(cls, kwargs):
          kwargs.update(
              {
                  "torchscript": True,
@@ -156,82 +121,45 @@ class RBLNGPT2LMHeadModel(RBLNBaseModel, RBLNGenerationMixin):
                  "use_cache": True,
              }
          )
+         return kwargs

-         rbln_config_kwargs, rbln_constructor_kwargs = cls.pop_rbln_kwargs_from_kwargs(kwargs)
-
-         model: GPT2LMHeadModel = TasksManager.get_model_from_task(
-             task=task,
-             model_name_or_path=model_id,
-             subfolder=subfolder,
-             revision=revision,
-             framework="pt",
-             cache_dir=cache_dir,
-             use_auth_token=use_auth_token,
-             local_files_only=local_files_only,
-             force_download=force_download,
-             trust_remote_code=trust_remote_code,
-             **kwargs,
-         )
-
-         if config is None:
-             config = model.config
-
-         config.save_pretrained(save_dir_path)
-         preprocessors = maybe_save_preprocessors(model_id, save_dir_path, src_subfolder=subfolder)
-
-         # Get compilation arguments
-         if rbln_config_kwargs.get("rbln_config", None) is None:
-             rbln_config = cls.get_rbln_config(
-                 preprocessors=preprocessors, model_config=model.config, **rbln_config_kwargs
-             )
-
-         def compile_gpt2():
-             wrapped_decoder = GPT2LMHeadModelWrapper(model).eval()
-
-             prefill_rbln_runtime_config = rbln_config[DEFAULT_COMPILED_MODEL_NAME][0]
-             dec_rbln_runtime_config = rbln_config[DEFAULT_COMPILED_MODEL_NAME][1]
+     @classmethod
+     @torch.inference_mode()
+     def get_compiled_model(cls, model: GPT2LMHeadModel, rbln_config: RBLNConfig):
+         wrapped_decoder = GPT2LMHeadModelWrapper(model).eval()

-             prefill_example_inputs = prefill_rbln_runtime_config.get_dummy_inputs(fill=0)
-             dec_example_inputs = dec_rbln_runtime_config.get_dummy_inputs(fill=0)
+         prefill_rbln_runtime_config = rbln_config[DEFAULT_COMPILED_MODEL_NAME][0]
+         dec_rbln_runtime_config = rbln_config[DEFAULT_COMPILED_MODEL_NAME][1]

-             prefill_scripted_model = torch.jit.trace(wrapped_decoder, prefill_example_inputs)
-             dec_scripted_model = torch.jit.trace(wrapped_decoder, dec_example_inputs)
+         prefill_example_inputs = prefill_rbln_runtime_config.get_dummy_inputs(fill=0)
+         dec_example_inputs = dec_rbln_runtime_config.get_dummy_inputs(fill=0)

-             prefill_ir = rebel.torchscript_to_ir(
-                 prefill_scripted_model,
-                 input_names=[v[0] for v in prefill_rbln_runtime_config.input_info],
-             )
-             dec_ir = rebel.torchscript_to_ir(
-                 dec_scripted_model,
-                 input_names=[v[0] for v in dec_rbln_runtime_config.input_info],
-             )
-
-             connections = [
-                 (prefill_ir.outputs[1], prefill_ir.inputs[1]),
-             ]
+         prefill_scripted_model = torch.jit.trace(wrapped_decoder, prefill_example_inputs, check_trace=False)
+         dec_scripted_model = torch.jit.trace(wrapped_decoder, dec_example_inputs, check_trace=False)

-             compiled_model = rebel.compile(
-                 prefill_ir,
-                 dec_ir,
-                 connections=connections,
-                 fusion=prefill_rbln_runtime_config.fusion,
-                 npu=prefill_rbln_runtime_config.npu,
-                 tensor_parallel_size=prefill_rbln_runtime_config.tensor_parallel_size,
-                 use_weight_sharing=True,
-             )
-             compiled_model.save(save_dir_path / f"{DEFAULT_COMPILED_MODEL_NAME}.rbln")
+         prefill_ir = rebel.torchscript_to_ir(
+             prefill_scripted_model,
+             input_names=[v[0] for v in prefill_rbln_runtime_config.input_info],
+         )
+         dec_ir = rebel.torchscript_to_ir(
+             dec_scripted_model,
+             input_names=[v[0] for v in dec_rbln_runtime_config.input_info],
+         )

-         compile_gpt2()
-         rbln_config.save(save_dir_path)
+         connections = [(prefill_ir.outputs[1 + i], prefill_ir.inputs[3 + i]) for i in range(model.config.n_layer * 2)]

-         return cls._from_pretrained(
-             model_id=save_dir_path,
-             config=config,
-             model_save_dir=save_dir,
-             **rbln_constructor_kwargs,
-             **kwargs,
+         compiled_model = rebel.compile(
+             prefill_ir,
+             dec_ir,
+             connections=connections,
+             fusion=prefill_rbln_runtime_config.fusion,
+             npu=prefill_rbln_runtime_config.npu,
+             tensor_parallel_size=prefill_rbln_runtime_config.tensor_parallel_size,
+             use_weight_sharing=True,
          )

+         return compiled_model
+
      @classmethod
      def _get_rbln_config(
          cls,
@@ -271,24 +199,24 @@ class RBLNGPT2LMHeadModel(RBLNBaseModel, RBLNGenerationMixin):
          def get_input_info(query_length):
              return [
                  ("input_ids", [rbln_batch_size, query_length], "int64"),
+                 ("attention_mask", [rbln_batch_size, 1, query_length, rbln_max_seq_len], "int64"),
+                 (
+                     "cache_position",
+                     [],
+                     "int32",
+                 ),
+             ] + [
                  (
-                     "past_key_values",
+                     f"past_key_values_{i}",
                      [
-                         model_config.n_layer,
-                         2,
                          rbln_batch_size,
                          model_config.n_head,
                          rbln_max_seq_len,
                          model_config.hidden_size // model_config.n_head,
                      ],
                      "float32",
-                 ),
-                 ("attention_mask", [rbln_batch_size, 1, query_length, rbln_max_seq_len], "int64"),
-                 (
-                     "cache_position",
-                     [],
-                     "int32",
-                 ),
+                 )
+                 for i in range(model_config.n_layer * 2)
              ]

          # model input info
@@ -305,11 +233,14 @@ class RBLNGPT2LMHeadModel(RBLNBaseModel, RBLNGenerationMixin):

          return rbln_config

-     def _create_runtimes(self, rbln_device_map: Dict[str, int]) -> List[rebel.Runtime]:
+     @classmethod
+     def _create_runtimes(
+         cls, compiled_models: List[rebel.RBLNCompiledModel], rbln_device_map: Dict[str, int]
+     ) -> List[rebel.Runtime]:
          device_val = rbln_device_map[DEFAULT_COMPILED_MODEL_NAME]
          return [
-             self.compiled_models[0].create_runtime(input_info_index=0, tensor_type="pt", device=device_val),
-             self.compiled_models[0].create_runtime(input_info_index=1, tensor_type="pt", device=device_val),
+             compiled_models[0].create_runtime(input_info_index=0, tensor_type="pt", device=device_val),
+             compiled_models[0].create_runtime(input_info_index=1, tensor_type="pt", device=device_val),
          ]

      def prepare_inputs_for_generation(self, input_ids, past_key_values=0, attention_mask=None, **kwargs):
@@ -386,6 +317,3 @@ class RBLNGPT2LMHeadModel(RBLNBaseModel, RBLNGenerationMixin):
          output = output.logits

          return CausalLMOutputWithCrossAttentions(logits=output, past_key_values=past_key_values)
-
-     def __repr__(self):
-         return repr(self.runtimes[0]) + "\n" + repr(self.runtimes[1])
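
Note on the gpt2 hunks above: 0.1.7 replaces the single stacked past_key_values input of shape [n_layer, 2, batch, n_head, max_seq_len, head_dim] with 2 * n_layer flat per-layer tensors named past_key_values_{i}, which is also why compilation now wires prefill_ir.outputs[1 + i] to prefill_ir.inputs[3 + i] (inputs 0-2 are input_ids, attention_mask and cache_position; output 0 is presumably the logits). A minimal sketch of the resulting input layout, with stock GPT-2 dimensions and illustrative batch/sequence sizes (only the names and shapes are taken from the diff):

# Illustrative sketch, not part of the package: enumerate the per-layer
# KV-cache inputs the way the new get_input_info builds them.
from transformers import GPT2Config

config = GPT2Config()  # n_layer=12, n_head=12, hidden_size=768 by default
batch_size, max_seq_len, query_length = 1, 1024, 1  # example values only

input_info = [
    ("input_ids", [batch_size, query_length], "int64"),
    ("attention_mask", [batch_size, 1, query_length, max_seq_len], "int64"),
    ("cache_position", [], "int32"),
] + [
    # 2 * n_layer flat cache tensors, one for the keys and one for the values of each layer
    (
        f"past_key_values_{i}",
        [batch_size, config.n_head, max_seq_len, config.hidden_size // config.n_head],
        "float32",
    )
    for i in range(config.n_layer * 2)
]

assert len(input_info) == 3 + 2 * config.n_layer  # 27 entries for the default config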
optimum/rbln/transformers/models/llama/llama_architecture.py

@@ -36,7 +36,6 @@ from transformers.models.llama.modeling_llama import (
      LlamaForCausalLM,
      LlamaModel,
      LlamaRotaryEmbedding,
-     repeat_kv,
  )


@@ -108,7 +107,6 @@ class _LlamaAttention(LlamaAttention):
          use_cache: bool = False,
          **kwargs,
      ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-
          bsz, q_len, _ = hidden_states.size()

          if self.config.pretraining_tp > 1:
@@ -149,26 +147,41 @@ class _LlamaAttention(LlamaAttention):
          cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
          query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

+         # change to remove repeat
+         key_states = key_states.unsqueeze(2)
+         value_states = value_states.unsqueeze(2)
+         query_states = query_states.view(
+             bsz, self.num_key_value_heads, self.num_heads // self.num_key_value_heads, q_len, self.head_dim
+         )
+
          if past_key_value is not None:
              cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE models
              key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

-         key_states = repeat_kv(key_states, self.num_key_value_groups)
-         value_states = repeat_kv(value_states, self.num_key_value_groups)
+         # change to remove repeat
+         # key_states = repeat_kv(key_states, self.num_key_value_groups)
+         # value_states = repeat_kv(value_states, self.num_key_value_groups)

-         attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
+         # attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

-         if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
-             raise ValueError(
-                 f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
-                 f" {attn_weights.size()}"
-             )
+         attn_weights = torch.matmul(query_states, key_states.transpose(3, 4)) / math.sqrt(self.head_dim)
+
+         # change to remove repeat
+         # if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
+         #     raise ValueError(
+         #         f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
+         #         f" {attn_weights.size()}"
+         #     )

          if attention_mask is not None:
              if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
                  raise ValueError(
                      f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
                  )
+             else:
+                 # change to remove repeat
+                 attention_mask = attention_mask.unsqueeze(2)
+
              attn_weights = attn_weights + attention_mask

          # upcast attention to fp32
@@ -176,6 +189,9 @@ class _LlamaAttention(LlamaAttention):
          attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
          attn_output = torch.matmul(attn_weights, value_states)

+         # change to remove repeat
+         attn_output = attn_output.view(bsz, self.num_heads, q_len, self.head_dim)
+
          if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
              raise ValueError(
                  f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
@@ -210,7 +226,6 @@ class _LlamaDecoderLayer(LlamaDecoderLayer):
          use_cache: Optional[bool] = False,
          **kwargs,
      ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
-
          residual = hidden_states

          hidden_states = self.input_layernorm(hidden_states)
@@ -397,7 +412,6 @@ class _LlamaForCausalLM(LlamaForCausalLM):
          output_hidden_states: Optional[bool] = None,
          return_dict: Optional[bool] = None,
      ) -> Union[Tuple, CausalLMOutputWithPast]:
-
          output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
          output_hidden_states = (
              output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@@ -516,17 +530,32 @@ class RebelDynamicCache(DynamicCache):
          if len(self.key_cache) <= layer_idx:
              self.key_cache.append(key_states)
              self.value_cache.append(value_states)
+             return self.key_cache[layer_idx], self.value_cache[layer_idx]
          else:
-             self.key_cache[layer_idx] = self.key_cache[layer_idx].slice_scatter(
-                 key_states, dim=2, start=self.current_step, end=self.current_step + key_states.shape[2]
+             # change to remove repeat
+             # self.key_cache[layer_idx] = self.key_cache[layer_idx].slice_scatter(
+             #     key_states, dim=2, start=self.current_step, end=self.current_step + key_states.shape[2]
+             # )
+             # self.value_cache[layer_idx] = self.value_cache[layer_idx].slice_scatter(
+             #     value_states, dim=2, start=self.current_step, end=self.current_step + value_states.shape[2]
+             # )
+             updated_key = (
+                 self.key_cache[layer_idx]
+                 .unsqueeze(2)
+                 .slice_scatter(
+                     key_states, dim=-2, start=self.current_step, end=self.current_step + key_states.shape[-2]
+                 )
              )
-             self.value_cache[layer_idx] = self.value_cache[layer_idx].slice_scatter(
-                 value_states, dim=2, start=self.current_step, end=self.current_step + value_states.shape[2]
+             updated_value = (
+                 self.value_cache[layer_idx]
+                 .unsqueeze(2)
+                 .slice_scatter(
+                     value_states, dim=-2, start=self.current_step, end=self.current_step + value_states.shape[-2]
+                 )
              )
-             # self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
-             # self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)
-
-         return self.key_cache[layer_idx], self.value_cache[layer_idx]
+             self.key_cache[layer_idx] = updated_key.squeeze(2)
+             self.value_cache[layer_idx] = updated_value.squeeze(2)
+             return updated_key, updated_value

      def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
          """Returns the sequence length of the cached states. A layer index can be optionally passed."""
@@ -585,23 +614,23 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
      return q_embed, k_embed


- def wrap_llama():
-     origin_mehtods = {}
-     origin_mehtods["LlamaRotaryEmbedding_INIT"] = LlamaRotaryEmbedding.__init__
-     origin_mehtods["LlamaRotaryEmbedding_forward"] = LlamaRotaryEmbedding.forward
-     origin_mehtods["LlamaModel_forward"] = LlamaModel.forward
-     origin_mehtods["LlamaForCausalLM_forward"] = LlamaForCausalLM.forward
+ origin_methods = {}
+ origin_methods["LlamaRotaryEmbedding_INIT"] = LlamaRotaryEmbedding.__init__
+ origin_methods["LlamaRotaryEmbedding_forward"] = LlamaRotaryEmbedding.forward
+ origin_methods["LlamaModel_forward"] = LlamaModel.forward
+ origin_methods["LlamaForCausalLM_forward"] = LlamaForCausalLM.forward
+

+ def wrap_llama():
      LlamaRotaryEmbedding.__init__ = _LlamaRotaryEmbedding.__init__
      LlamaRotaryEmbedding.forward = _LlamaRotaryEmbedding.forward
      LlamaModel.forward = _LlamaModel.forward
      LlamaForCausalLM.forward = _LlamaForCausalLM.forward

-     return origin_mehtods
-

- def unwrap_llama(origin_mehtods):
-     LlamaRotaryEmbedding.__init__ = origin_mehtods["LlamaRotaryEmbedding_INIT"]
-     LlamaRotaryEmbedding.forward = origin_mehtods["LlamaRotaryEmbedding_forward"]
-     LlamaModel.forward = origin_mehtods["LlamaModel_forward"]
-     LlamaForCausalLM.forward = origin_mehtods["LlamaForCausalLM_forward"]
+ def unwrap_llama():
+     global origin_methods
+     LlamaRotaryEmbedding.__init__ = origin_methods["LlamaRotaryEmbedding_INIT"]
+     LlamaRotaryEmbedding.forward = origin_methods["LlamaRotaryEmbedding_forward"]
+     LlamaModel.forward = origin_methods["LlamaModel_forward"]
+     LlamaForCausalLM.forward = origin_methods["LlamaForCausalLM_forward"]
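
The last hunk moves the saved originals to module scope, so unwrap_llama no longer needs the dictionary that wrap_llama used to return. A generic illustration of that save/patch/restore pattern on a toy class (assumption for illustration only, not code from the package):

# Toy example: originals captured once at import time, mirroring the
# module-level origin_methods dict, so the restore function takes no arguments.
class Greeter:
    def hello(self):
        return "hello"


origin_methods = {"Greeter_hello": Greeter.hello}


def wrap_greeter():
    Greeter.hello = lambda self: "patched"


def unwrap_greeter():
    Greeter.hello = origin_methods["Greeter_hello"]


wrap_greeter()
assert Greeter().hello() == "patched"
unwrap_greeter()
assert Greeter().hello() == "hello"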