optimum_rbln-0.1.1-py3-none-any.whl → optimum_rbln-0.1.7-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- optimum/rbln/__init__.py +9 -0
- optimum/rbln/__version__.py +1 -1
- optimum/rbln/diffusers/models/autoencoder_kl.py +16 -98
- optimum/rbln/diffusers/models/unet_2d_condition.py +1 -1
- optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +9 -11
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +8 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +9 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +3 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +8 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +8 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +9 -0
- optimum/rbln/modeling_base.py +175 -103
- optimum/rbln/modeling_seq2seq.py +58 -132
- optimum/rbln/transformers/__init__.py +4 -0
- optimum/rbln/transformers/models/__init__.py +2 -0
- optimum/rbln/transformers/models/clip/modeling_clip.py +0 -1
- optimum/rbln/transformers/models/dpt/__init__.py +24 -0
- optimum/rbln/transformers/models/dpt/modeling_dpt.py +89 -0
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +24 -33
- optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +52 -124
- optimum/rbln/transformers/models/llama/llama_architecture.py +62 -33
- optimum/rbln/transformers/models/llama/llama_architecture_cb.py +764 -0
- optimum/rbln/transformers/models/llama/modeling_llama.py +208 -140
- optimum/rbln/transformers/models/midm/__init__.py +32 -0
- optimum/rbln/transformers/models/midm/hf_hub_cached/configuration_midm.py +22 -0
- optimum/rbln/transformers/models/midm/hf_hub_cached/midm_bitext_tokenization.py +303 -0
- optimum/rbln/transformers/models/midm/hf_hub_cached/modeling_midm.py +1473 -0
- optimum/rbln/transformers/models/midm/hf_hub_cached/rotary_position_embedding.py +98 -0
- optimum/rbln/transformers/models/midm/midm_architecture.py +506 -0
- optimum/rbln/transformers/models/midm/modeling_midm.py +390 -0
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +53 -123
- optimum/rbln/utils/__init__.py +1 -1
- optimum/rbln/utils/import_utils.py +46 -0
- {optimum_rbln-0.1.1.dist-info → optimum_rbln-0.1.7.dist-info}/METADATA +17 -50
- {optimum_rbln-0.1.1.dist-info → optimum_rbln-0.1.7.dist-info}/RECORD +37 -27
- {optimum_rbln-0.1.1.dist-info → optimum_rbln-0.1.7.dist-info}/WHEEL +1 -1
- {optimum_rbln-0.1.1.dist-info → optimum_rbln-0.1.7.dist-info}/licenses/LICENSE +0 -0
--- optimum/rbln/transformers/models/gpt2/modeling_gpt2.py (0.1.1)
+++ optimum/rbln/transformers/models/gpt2/modeling_gpt2.py (0.1.7)
@@ -23,20 +23,16 @@
 
 import inspect
 import logging
-from pathlib import Path
-from tempfile import TemporaryDirectory
 from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
 
 import rebel
 import torch
-from optimum.exporters import TasksManager
 from transformers import AutoModelForCausalLM, GPT2LMHeadModel, PretrainedConfig
 from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions, Seq2SeqLMOutput
 
-from ....modeling_base import RBLNBaseModel
+from ....modeling_base import RBLNModel
 from ....modeling_config import DEFAULT_COMPILED_MODEL_NAME, RBLNConfig, RBLNRuntimeConfig
 from ....utils.runtime_utils import RBLNPytorchRuntime
-from ....utils.save_utils import maybe_save_preprocessors
 from ...generation.utils import RBLNGenerationMixin
 from .gpt2_architecture import GPT2LMHeadModelWrapper
 
@@ -59,12 +55,12 @@ class RBLNRuntimeDecoder(RBLNPytorchRuntime):
         return Seq2SeqLMOutput(logits=logits)
 
 
-class RBLNGPT2LMHeadModel(RBLNBaseModel, RBLNGenerationMixin):
+class RBLNGPT2LMHeadModel(RBLNModel, RBLNGenerationMixin):
     """
     The GPT2 Model transformer with a language modeling head on top (linear layer with weights tied to the input
     embeddings).
 
-    This model inherits from [`
+    This model inherits from [`RBLNMultiModel`]. Check the superclass documentation for the generic methods the
     library implements for all its model.
 
     It implements the methods to convert a pre-trained transformers GPT2 model into a RBLN transformer model by:
@@ -89,8 +85,8 @@ class RBLNGPT2LMHeadModel(RBLNBaseModel, RBLNGenerationMixin):
             torch.ones(batch_size, 1, self.prefill_chunk_size, self.prefill_chunk_size), diagonal=1
         )
 
-        self.prefill_decoder = RBLNRuntimeDecoder(runtime=self.
-        self.decoder = RBLNRuntimeDecoder(runtime=self.
+        self.prefill_decoder = RBLNRuntimeDecoder(runtime=self.model[0])
+        self.decoder = RBLNRuntimeDecoder(runtime=self.model[1])
         self.pad_token_id = self.rbln_config.meta["rbln_pad_token_id"]
         self.past_cached_length = 0
 
@@ -117,38 +113,7 @@ class RBLNGPT2LMHeadModel(RBLNBaseModel, RBLNGenerationMixin):
         raise NotImplementedError
 
     @classmethod
-    def
-        cls,
-        model_id: str,
-        config: "PretrainedConfig",
-        use_auth_token: Optional[Union[bool, str]] = None,
-        revision: Optional[str] = None,
-        force_download: bool = False,
-        cache_dir: Optional[str] = None,
-        subfolder: str = "",
-        local_files_only: bool = False,
-        trust_remote_code: bool = False,
-        model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
-        **kwargs,
-    ) -> "RBLNGPT2LMHeadModel":
-        """
-        Exports a vanilla Transformers model into a rbln-compiled Module.
-        """
-        task = kwargs.pop("task", None)
-        if task is None:
-            task = TasksManager.infer_task_from_model(cls.auto_model_class)
-
-        if model_save_dir is None:
-            save_dir = TemporaryDirectory()
-            save_dir_path = Path(save_dir.name)
-        else:
-            save_dir = model_save_dir
-            if isinstance(save_dir, TemporaryDirectory):
-                save_dir_path = Path(model_save_dir.name)
-            else:
-                save_dir_path = Path(model_save_dir)
-            save_dir_path.mkdir(exist_ok=True)
-
+    def update_kwargs(cls, kwargs):
         kwargs.update(
             {
                 "torchscript": True,
@@ -156,82 +121,45 @@ class RBLNGPT2LMHeadModel(RBLNBaseModel, RBLNGenerationMixin):
                 "use_cache": True,
             }
         )
+        return kwargs
 
-
-
-
-
-            model_name_or_path=model_id,
-            subfolder=subfolder,
-            revision=revision,
-            framework="pt",
-            cache_dir=cache_dir,
-            use_auth_token=use_auth_token,
-            local_files_only=local_files_only,
-            force_download=force_download,
-            trust_remote_code=trust_remote_code,
-            **kwargs,
-        )
-
-        if config is None:
-            config = model.config
-
-        config.save_pretrained(save_dir_path)
-        preprocessors = maybe_save_preprocessors(model_id, save_dir_path, src_subfolder=subfolder)
-
-        # Get compilation arguments
-        if rbln_config_kwargs.get("rbln_config", None) is None:
-            rbln_config = cls.get_rbln_config(
-                preprocessors=preprocessors, model_config=model.config, **rbln_config_kwargs
-            )
-
-        def compile_gpt2():
-            wrapped_decoder = GPT2LMHeadModelWrapper(model).eval()
-
-            prefill_rbln_runtime_config = rbln_config[DEFAULT_COMPILED_MODEL_NAME][0]
-            dec_rbln_runtime_config = rbln_config[DEFAULT_COMPILED_MODEL_NAME][1]
+    @classmethod
+    @torch.inference_mode()
+    def get_compiled_model(cls, model: GPT2LMHeadModel, rbln_config: RBLNConfig):
+        wrapped_decoder = GPT2LMHeadModelWrapper(model).eval()
 
-
-
+        prefill_rbln_runtime_config = rbln_config[DEFAULT_COMPILED_MODEL_NAME][0]
+        dec_rbln_runtime_config = rbln_config[DEFAULT_COMPILED_MODEL_NAME][1]
 
-
-
+        prefill_example_inputs = prefill_rbln_runtime_config.get_dummy_inputs(fill=0)
+        dec_example_inputs = dec_rbln_runtime_config.get_dummy_inputs(fill=0)
 
-
-
-                input_names=[v[0] for v in prefill_rbln_runtime_config.input_info],
-            )
-            dec_ir = rebel.torchscript_to_ir(
-                dec_scripted_model,
-                input_names=[v[0] for v in dec_rbln_runtime_config.input_info],
-            )
-
-            connections = [
-                (prefill_ir.outputs[1], prefill_ir.inputs[1]),
-            ]
+        prefill_scripted_model = torch.jit.trace(wrapped_decoder, prefill_example_inputs, check_trace=False)
+        dec_scripted_model = torch.jit.trace(wrapped_decoder, dec_example_inputs, check_trace=False)
 
-
-
-
-
-
-
-
-
-            )
-            compiled_model.save(save_dir_path / f"{DEFAULT_COMPILED_MODEL_NAME}.rbln")
+        prefill_ir = rebel.torchscript_to_ir(
+            prefill_scripted_model,
+            input_names=[v[0] for v in prefill_rbln_runtime_config.input_info],
+        )
+        dec_ir = rebel.torchscript_to_ir(
+            dec_scripted_model,
+            input_names=[v[0] for v in dec_rbln_runtime_config.input_info],
+        )
 
-
-        rbln_config.save(save_dir_path)
+        connections = [(prefill_ir.outputs[1 + i], prefill_ir.inputs[3 + i]) for i in range(model.config.n_layer * 2)]
 
-
-
-
-
-
-
+        compiled_model = rebel.compile(
+            prefill_ir,
+            dec_ir,
+            connections=connections,
+            fusion=prefill_rbln_runtime_config.fusion,
+            npu=prefill_rbln_runtime_config.npu,
+            tensor_parallel_size=prefill_rbln_runtime_config.tensor_parallel_size,
+            use_weight_sharing=True,
         )
 
+        return compiled_model
+
     @classmethod
     def _get_rbln_config(
         cls,
@@ -271,24 +199,24 @@ class RBLNGPT2LMHeadModel(RBLNBaseModel, RBLNGenerationMixin):
         def get_input_info(query_length):
             return [
                 ("input_ids", [rbln_batch_size, query_length], "int64"),
+                ("attention_mask", [rbln_batch_size, 1, query_length, rbln_max_seq_len], "int64"),
+                (
+                    "cache_position",
+                    [],
+                    "int32",
+                ),
+            ] + [
                 (
-                    "
+                    f"past_key_values_{i}",
                     [
-                        model_config.n_layer,
-                        2,
                         rbln_batch_size,
                         model_config.n_head,
                         rbln_max_seq_len,
                         model_config.hidden_size // model_config.n_head,
                     ],
                     "float32",
-                )
-
-                (
-                    "cache_position",
-                    [],
-                    "int32",
-                ),
+                )
+                for i in range(model_config.n_layer * 2)
             ]
 
         # model input info
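
The hunk above changes the compiled graph's cache layout: instead of a single stacked `past_key_values` input of shape `[n_layer, 2, batch, n_head, max_seq_len, head_dim]`, the graph now takes `2 * n_layer` separate inputs named `past_key_values_{i}`, each `[batch, n_head, max_seq_len, head_dim]`, plus an explicit `attention_mask` and a scalar `cache_position`. A small sketch of the resulting input list for a toy configuration; the `SimpleNamespace` config and the sizes are made up for illustration.

```python
from types import SimpleNamespace

# Toy stand-ins for the values _get_rbln_config works with.
cfg = SimpleNamespace(n_layer=2, n_head=4, hidden_size=64)
batch_size, max_seq_len, query_length = 1, 32, 1

input_info = [
    ("input_ids", [batch_size, query_length], "int64"),
    ("attention_mask", [batch_size, 1, query_length, max_seq_len], "int64"),
    ("cache_position", [], "int32"),
] + [
    (
        f"past_key_values_{i}",
        [batch_size, cfg.n_head, max_seq_len, cfg.hidden_size // cfg.n_head],
        "float32",
    )
    for i in range(cfg.n_layer * 2)  # one key tensor and one value tensor per layer
]

# 3 "regular" inputs followed by past_key_values_0 ... past_key_values_3
assert [name for name, *_ in input_info][3:] == [f"past_key_values_{i}" for i in range(4)]
```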
@@ -305,11 +233,14 @@ class RBLNGPT2LMHeadModel(RBLNBaseModel, RBLNGenerationMixin):
 
         return rbln_config
 
-
+    @classmethod
+    def _create_runtimes(
+        cls, compiled_models: List[rebel.RBLNCompiledModel], rbln_device_map: Dict[str, int]
+    ) -> List[rebel.Runtime]:
         device_val = rbln_device_map[DEFAULT_COMPILED_MODEL_NAME]
         return [
-
-
+            compiled_models[0].create_runtime(input_info_index=0, tensor_type="pt", device=device_val),
+            compiled_models[0].create_runtime(input_info_index=1, tensor_type="pt", device=device_val),
         ]
 
     def prepare_inputs_for_generation(self, input_ids, past_key_values=0, attention_mask=None, **kwargs):
@@ -386,6 +317,3 @@ class RBLNGPT2LMHeadModel(RBLNBaseModel, RBLNGenerationMixin):
             output = output.logits
 
         return CausalLMOutputWithCrossAttentions(logits=output, past_key_values=past_key_values)
-
-    def __repr__(self):
-        return repr(self.runtimes[0]) + "\n" + repr(self.runtimes[1])
--- optimum/rbln/transformers/models/llama/llama_architecture.py (0.1.1)
+++ optimum/rbln/transformers/models/llama/llama_architecture.py (0.1.7)
@@ -36,7 +36,6 @@ from transformers.models.llama.modeling_llama import (
     LlamaForCausalLM,
     LlamaModel,
     LlamaRotaryEmbedding,
-    repeat_kv,
 )
 
 
@@ -108,7 +107,6 @@ class _LlamaAttention(LlamaAttention):
         use_cache: bool = False,
         **kwargs,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-
         bsz, q_len, _ = hidden_states.size()
 
         if self.config.pretraining_tp > 1:
@@ -149,26 +147,41 @@ class _LlamaAttention(LlamaAttention):
         cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
 
+        # change to remove repeat
+        key_states = key_states.unsqueeze(2)
+        value_states = value_states.unsqueeze(2)
+        query_states = query_states.view(
+            bsz, self.num_key_value_heads, self.num_heads // self.num_key_value_heads, q_len, self.head_dim
+        )
+
         if past_key_value is not None:
             cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE models
             key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
 
-
-
+        # change to remove repeat
+        # key_states = repeat_kv(key_states, self.num_key_value_groups)
+        # value_states = repeat_kv(value_states, self.num_key_value_groups)
 
-        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
+        # attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
 
-
-
-
-
-
+        attn_weights = torch.matmul(query_states, key_states.transpose(3, 4)) / math.sqrt(self.head_dim)
+
+        # change to remove repeat
+        # if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
+        #     raise ValueError(
+        #         f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
+        #         f" {attn_weights.size()}"
+        #     )
 
         if attention_mask is not None:
             if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
                 raise ValueError(
                     f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
                 )
+            else:
+                # change to remove repeat
+                attention_mask = attention_mask.unsqueeze(2)
+
             attn_weights = attn_weights + attention_mask
 
         # upcast attention to fp32
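
The "change to remove repeat" edits drop `repeat_kv`, which would materialize every key/value head `num_key_value_groups` times, and instead view the queries as `[bsz, num_kv_heads, groups, q_len, head_dim]` while keys and values keep a singleton group axis, so the grouping is handled by broadcasting rather than by copying. A standalone check of that equivalence; the shapes and the local `repeat_kv` helper are illustrative (the helper mirrors the one removed from the import list above).

```python
import torch

def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Reference behaviour of transformers' repeat_kv: copy each KV head n_rep times."""
    b, kvh, s, d = x.shape
    return x[:, :, None, :, :].expand(b, kvh, n_rep, s, d).reshape(b, kvh * n_rep, s, d)

b, kv_heads, groups, q_len, kv_len, head_dim = 2, 4, 2, 3, 7, 8
num_heads = kv_heads * groups

query = torch.randn(b, num_heads, q_len, head_dim)
key = torch.randn(b, kv_heads, kv_len, head_dim)
value = torch.randn(b, kv_heads, kv_len, head_dim)

# Reference path: repeat K/V across the query groups, then 4-D matmuls.
ref_weights = query @ repeat_kv(key, groups).transpose(2, 3)
ref_output = ref_weights.softmax(-1) @ repeat_kv(value, groups)

# Repeat-free path: view Q per KV head, give K/V a singleton group axis, and let
# broadcasting supply the groups (this is the axis that transpose(3, 4) acts on).
q5 = query.view(b, kv_heads, groups, q_len, head_dim)
k5 = key.unsqueeze(2)
v5 = value.unsqueeze(2)
weights = q5 @ k5.transpose(3, 4)                                  # [b, kv_heads, groups, q_len, kv_len]
output = (weights.softmax(-1) @ v5).view(b, num_heads, q_len, head_dim)

assert torch.allclose(ref_weights, weights.view(b, num_heads, q_len, kv_len), atol=1e-5)
assert torch.allclose(ref_output, output, atol=1e-5)
```

The `attn_output.view(bsz, self.num_heads, q_len, self.head_dim)` added in the next hunk is the same flattening of the `(kv_heads, groups)` pair back into `num_heads` that the last line of the sketch performs.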
@@ -176,6 +189,9 @@ class _LlamaAttention(LlamaAttention):
         attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
         attn_output = torch.matmul(attn_weights, value_states)
 
+        # change to remove repeat
+        attn_output = attn_output.view(bsz, self.num_heads, q_len, self.head_dim)
+
         if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
             raise ValueError(
                 f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
@@ -210,7 +226,6 @@ class _LlamaDecoderLayer(LlamaDecoderLayer):
         use_cache: Optional[bool] = False,
         **kwargs,
     ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
-
         residual = hidden_states
 
         hidden_states = self.input_layernorm(hidden_states)
@@ -397,7 +412,6 @@ class _LlamaForCausalLM(LlamaForCausalLM):
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
     ) -> Union[Tuple, CausalLMOutputWithPast]:
-
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@@ -516,17 +530,32 @@ class RebelDynamicCache(DynamicCache):
         if len(self.key_cache) <= layer_idx:
             self.key_cache.append(key_states)
             self.value_cache.append(value_states)
+            return self.key_cache[layer_idx], self.value_cache[layer_idx]
         else:
-
-
+            # change to remove repeat
+            # self.key_cache[layer_idx] = self.key_cache[layer_idx].slice_scatter(
+            #     key_states, dim=2, start=self.current_step, end=self.current_step + key_states.shape[2]
+            # )
+            # self.value_cache[layer_idx] = self.value_cache[layer_idx].slice_scatter(
+            #     value_states, dim=2, start=self.current_step, end=self.current_step + value_states.shape[2]
+            # )
+            updated_key = (
+                self.key_cache[layer_idx]
+                .unsqueeze(2)
+                .slice_scatter(
+                    key_states, dim=-2, start=self.current_step, end=self.current_step + key_states.shape[-2]
+                )
             )
-
-
+            updated_value = (
+                self.value_cache[layer_idx]
+                .unsqueeze(2)
+                .slice_scatter(
+                    value_states, dim=-2, start=self.current_step, end=self.current_step + value_states.shape[-2]
+                )
             )
-
-
-
-            return self.key_cache[layer_idx], self.value_cache[layer_idx]
+            self.key_cache[layer_idx] = updated_key.squeeze(2)
+            self.value_cache[layer_idx] = updated_value.squeeze(2)
+            return updated_key, updated_value
 
     def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
         """Returns the sequence length of the cached states. A layer index can be optionally passed."""
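
`RebelDynamicCache.update` now writes new key/value states into a fixed-shape per-layer cache with `Tensor.slice_scatter` instead of growing the cache step by step as the stock `DynamicCache` does, and the `unsqueeze(2)`/`squeeze(2)` round-trip matches the singleton group axis introduced by the attention change above. A minimal sketch of the `slice_scatter` write pattern; the shapes and the `current_step` value are illustrative.

```python
import torch

batch, heads, max_seq_len, head_dim = 1, 2, 8, 4
current_step = 3                                            # positions already filled

cache = torch.zeros(batch, heads, max_seq_len, head_dim)    # preallocated, fixed shape
new_states = torch.randn(batch, heads, 2, head_dim)         # two freshly computed positions

# Write the new states into positions [current_step, current_step + 2) without
# changing the cache's shape; this is the core of the new update() path.
cache = cache.slice_scatter(
    new_states, dim=-2, start=current_step, end=current_step + new_states.shape[-2]
)

assert torch.equal(cache[..., current_step:current_step + 2, :], new_states)
assert cache.shape == (batch, heads, max_seq_len, head_dim)
```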
@@ -585,23 +614,23 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
     return q_embed, k_embed
 
 
-
-
-
-
-
-
+origin_methods = {}
+origin_methods["LlamaRotaryEmbedding_INIT"] = LlamaRotaryEmbedding.__init__
+origin_methods["LlamaRotaryEmbedding_forward"] = LlamaRotaryEmbedding.forward
+origin_methods["LlamaModel_forward"] = LlamaModel.forward
+origin_methods["LlamaForCausalLM_forward"] = LlamaForCausalLM.forward
+
 
+def wrap_llama():
     LlamaRotaryEmbedding.__init__ = _LlamaRotaryEmbedding.__init__
     LlamaRotaryEmbedding.forward = _LlamaRotaryEmbedding.forward
     LlamaModel.forward = _LlamaModel.forward
     LlamaForCausalLM.forward = _LlamaForCausalLM.forward
 
-    return origin_mehtods
-
 
-def unwrap_llama(
-
-    LlamaRotaryEmbedding.
-
-
+def unwrap_llama():
+    global origin_methods
+    LlamaRotaryEmbedding.__init__ = origin_methods["LlamaRotaryEmbedding_INIT"]
+    LlamaRotaryEmbedding.forward = origin_methods["LlamaRotaryEmbedding_forward"]
+    LlamaModel.forward = origin_methods["LlamaModel_forward"]
+    LlamaForCausalLM.forward = origin_methods["LlamaForCausalLM_forward"]
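
`wrap_llama` and `unwrap_llama` now rely on a module-level `origin_methods` dict: the stock `transformers` methods are captured once at import time, swapped out for the RBLN-friendly `_Llama*` implementations, and restored afterwards (the removed variant instead returned `origin_mehtods` from `wrap_llama` and took arguments in `unwrap_llama`). A generic sketch of this save-patch-restore pattern, with made-up classes rather than the real `Llama*` ones.

```python
class Original:
    def forward(self):
        return "original"

class Patched:
    def forward(self):
        return "patched"

# Capture the originals once, at import time, keyed by name.
origin_methods = {"Original_forward": Original.forward}

def wrap():
    Original.forward = Patched.forward                          # swap in the trace-friendly implementation

def unwrap():
    Original.forward = origin_methods["Original_forward"]       # restore the stock behaviour

obj = Original()
wrap()
assert obj.forward() == "patched"
unwrap()
assert obj.forward() == "original"
```

Presumably `wrap_llama()` is called right before tracing or compilation and `unwrap_llama()` afterwards so the patched forwards never leak into ordinary `transformers` usage, but that call site is not visible in these hunks.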
|