optimum-rbln 0.1.4__py3-none-any.whl → 0.1.7__py3-none-any.whl
This diff compares publicly available package versions as released to one of the supported registries. It is provided for informational purposes only and reflects the package contents as they appear in their respective public registries.
- optimum/rbln/__init__.py +7 -0
- optimum/rbln/__version__.py +1 -1
- optimum/rbln/diffusers/models/autoencoder_kl.py +16 -98
- optimum/rbln/diffusers/models/unet_2d_condition.py +1 -1
- optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +9 -11
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +8 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +9 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +3 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +8 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +8 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +9 -0
- optimum/rbln/modeling_base.py +172 -100
- optimum/rbln/modeling_seq2seq.py +58 -132
- optimum/rbln/transformers/__init__.py +2 -0
- optimum/rbln/transformers/models/__init__.py +1 -0
- optimum/rbln/transformers/models/clip/modeling_clip.py +0 -1
- optimum/rbln/transformers/models/dpt/__init__.py +24 -0
- optimum/rbln/transformers/models/dpt/modeling_dpt.py +89 -0
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +24 -33
- optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +52 -124
- optimum/rbln/transformers/models/llama/llama_architecture.py +13 -16
- optimum/rbln/transformers/models/llama/llama_architecture_cb.py +41 -36
- optimum/rbln/transformers/models/llama/modeling_llama.py +94 -120
- optimum/rbln/transformers/models/midm/modeling_midm.py +85 -121
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +53 -123
- optimum/rbln/utils/__init__.py +1 -1
- optimum/rbln/utils/import_utils.py +46 -0
- {optimum_rbln-0.1.4.dist-info → optimum_rbln-0.1.7.dist-info}/METADATA +17 -51
- {optimum_rbln-0.1.4.dist-info → optimum_rbln-0.1.7.dist-info}/RECORD +31 -29
- {optimum_rbln-0.1.4.dist-info → optimum_rbln-0.1.7.dist-info}/WHEEL +1 -1
- {optimum_rbln-0.1.4.dist-info → optimum_rbln-0.1.7.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/transformers/models/gpt2/modeling_gpt2.py

@@ -23,20 +23,16 @@
 
 import inspect
 import logging
-from pathlib import Path
-from tempfile import TemporaryDirectory
 from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
 
 import rebel
 import torch
-from optimum.exporters import TasksManager
 from transformers import AutoModelForCausalLM, GPT2LMHeadModel, PretrainedConfig
 from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions, Seq2SeqLMOutput
 
-from ....modeling_base import RBLNBaseModel
+from ....modeling_base import RBLNModel
 from ....modeling_config import DEFAULT_COMPILED_MODEL_NAME, RBLNConfig, RBLNRuntimeConfig
 from ....utils.runtime_utils import RBLNPytorchRuntime
-from ....utils.save_utils import maybe_save_preprocessors
 from ...generation.utils import RBLNGenerationMixin
 from .gpt2_architecture import GPT2LMHeadModelWrapper
 
@@ -59,12 +55,12 @@ class RBLNRuntimeDecoder(RBLNPytorchRuntime):
         return Seq2SeqLMOutput(logits=logits)
 
 
-class RBLNGPT2LMHeadModel(RBLNBaseModel, RBLNGenerationMixin):
+class RBLNGPT2LMHeadModel(RBLNModel, RBLNGenerationMixin):
     """
     The GPT2 Model transformer with a language modeling head on top (linear layer with weights tied to the input
    embeddings).
 
-    This model inherits from [`
+    This model inherits from [`RBLNMultiModel`]. Check the superclass documentation for the generic methods the
    library implements for all its model.
 
    It implements the methods to convert a pre-trained transformers GPT2 model into a RBLN transformer model by:
@@ -89,8 +85,8 @@ class RBLNGPT2LMHeadModel(RBLNBaseModel, RBLNGenerationMixin):
             torch.ones(batch_size, 1, self.prefill_chunk_size, self.prefill_chunk_size), diagonal=1
         )
 
-        self.prefill_decoder = RBLNRuntimeDecoder(runtime=self.runtimes[0])
-        self.decoder = RBLNRuntimeDecoder(runtime=self.runtimes[1])
+        self.prefill_decoder = RBLNRuntimeDecoder(runtime=self.model[0])
+        self.decoder = RBLNRuntimeDecoder(runtime=self.model[1])
         self.pad_token_id = self.rbln_config.meta["rbln_pad_token_id"]
         self.past_cached_length = 0
 
@@ -117,38 +113,7 @@ class RBLNGPT2LMHeadModel(RBLNBaseModel, RBLNGenerationMixin):
         raise NotImplementedError
 
     @classmethod
-    def
-        cls,
-        model_id: str,
-        config: "PretrainedConfig",
-        use_auth_token: Optional[Union[bool, str]] = None,
-        revision: Optional[str] = None,
-        force_download: bool = False,
-        cache_dir: Optional[str] = None,
-        subfolder: str = "",
-        local_files_only: bool = False,
-        trust_remote_code: bool = False,
-        model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
-        **kwargs,
-    ) -> "RBLNGPT2LMHeadModel":
-        """
-        Exports a vanilla Transformers model into a rbln-compiled Module.
-        """
-        task = kwargs.pop("task", None)
-        if task is None:
-            task = TasksManager.infer_task_from_model(cls.auto_model_class)
-
-        if model_save_dir is None:
-            save_dir = TemporaryDirectory()
-            save_dir_path = Path(save_dir.name)
-        else:
-            save_dir = model_save_dir
-            if isinstance(save_dir, TemporaryDirectory):
-                save_dir_path = Path(model_save_dir.name)
-            else:
-                save_dir_path = Path(model_save_dir)
-            save_dir_path.mkdir(exist_ok=True)
-
+    def update_kwargs(cls, kwargs):
         kwargs.update(
             {
                 "torchscript": True,
@@ -156,82 +121,45 @@ class RBLNGPT2LMHeadModel(RBLNBaseModel, RBLNGenerationMixin):
                 "use_cache": True,
             }
         )
+        return kwargs
 
-
-
-
-
-            model_name_or_path=model_id,
-            subfolder=subfolder,
-            revision=revision,
-            framework="pt",
-            cache_dir=cache_dir,
-            use_auth_token=use_auth_token,
-            local_files_only=local_files_only,
-            force_download=force_download,
-            trust_remote_code=trust_remote_code,
-            **kwargs,
-        )
-
-        if config is None:
-            config = model.config
-
-        config.save_pretrained(save_dir_path)
-        preprocessors = maybe_save_preprocessors(model_id, save_dir_path, src_subfolder=subfolder)
-
-        # Get compilation arguments
-        if rbln_config_kwargs.get("rbln_config", None) is None:
-            rbln_config = cls.get_rbln_config(
-                preprocessors=preprocessors, model_config=model.config, **rbln_config_kwargs
-            )
-
-        def compile_gpt2():
-            wrapped_decoder = GPT2LMHeadModelWrapper(model).eval()
-
-            prefill_rbln_runtime_config = rbln_config[DEFAULT_COMPILED_MODEL_NAME][0]
-            dec_rbln_runtime_config = rbln_config[DEFAULT_COMPILED_MODEL_NAME][1]
+    @classmethod
+    @torch.inference_mode()
+    def get_compiled_model(cls, model: GPT2LMHeadModel, rbln_config: RBLNConfig):
+        wrapped_decoder = GPT2LMHeadModelWrapper(model).eval()
 
-
-
+        prefill_rbln_runtime_config = rbln_config[DEFAULT_COMPILED_MODEL_NAME][0]
+        dec_rbln_runtime_config = rbln_config[DEFAULT_COMPILED_MODEL_NAME][1]
 
-
-
+        prefill_example_inputs = prefill_rbln_runtime_config.get_dummy_inputs(fill=0)
+        dec_example_inputs = dec_rbln_runtime_config.get_dummy_inputs(fill=0)
 
-
-
-                input_names=[v[0] for v in prefill_rbln_runtime_config.input_info],
-            )
-            dec_ir = rebel.torchscript_to_ir(
-                dec_scripted_model,
-                input_names=[v[0] for v in dec_rbln_runtime_config.input_info],
-            )
-
-            connections = [
-                (prefill_ir.outputs[1], prefill_ir.inputs[1]),
-            ]
+        prefill_scripted_model = torch.jit.trace(wrapped_decoder, prefill_example_inputs, check_trace=False)
+        dec_scripted_model = torch.jit.trace(wrapped_decoder, dec_example_inputs, check_trace=False)
 
-
-
-
-
-
-
-
-
-            )
-            compiled_model.save(save_dir_path / f"{DEFAULT_COMPILED_MODEL_NAME}.rbln")
+        prefill_ir = rebel.torchscript_to_ir(
+            prefill_scripted_model,
+            input_names=[v[0] for v in prefill_rbln_runtime_config.input_info],
+        )
+        dec_ir = rebel.torchscript_to_ir(
+            dec_scripted_model,
+            input_names=[v[0] for v in dec_rbln_runtime_config.input_info],
+        )
 
-
-        rbln_config.save(save_dir_path)
+        connections = [(prefill_ir.outputs[1 + i], prefill_ir.inputs[3 + i]) for i in range(model.config.n_layer * 2)]
 
-
-
-
-
-
-
+        compiled_model = rebel.compile(
+            prefill_ir,
+            dec_ir,
+            connections=connections,
+            fusion=prefill_rbln_runtime_config.fusion,
+            npu=prefill_rbln_runtime_config.npu,
+            tensor_parallel_size=prefill_rbln_runtime_config.tensor_parallel_size,
+            use_weight_sharing=True,
+        )
 
+        return compiled_model
+
     @classmethod
     def _get_rbln_config(
         cls,
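For orientation, here is a compressed sketch of the prefill/decode compile strategy that `get_compiled_model` now implements: trace one chunked prefill graph and one single-token decode graph from the same module, lower both to IR, connect the prefill graph's cache outputs back to the cache inputs, and compile the pair into a single artifact with weight sharing. The `ToyDecoder` module, its shapes, and the variable names are hypothetical stand-ins for `GPT2LMHeadModelWrapper` and the runtime configs; the `rebel` calls are used exactly as they appear in the hunk above (not otherwise verified against the SDK), and the `fusion`/`npu`/`tensor_parallel_size` arguments are omitted on the assumption that they are optional.

```python
import torch
import rebel  # RBLN SDK; required to actually run this sketch


class ToyDecoder(torch.nn.Module):
    """Hypothetical stand-in for GPT2LMHeadModelWrapper: takes
    (input_ids, attention_mask, cache_position, past_k, past_v) and
    returns (logits, new_k, new_v)."""

    def __init__(self, vocab=100, d=16):
        super().__init__()
        self.emb = torch.nn.Embedding(vocab, d)
        self.head = torch.nn.Linear(d, vocab)

    def forward(self, input_ids, attention_mask, cache_position, past_k, past_v):
        h = self.emb(input_ids)                    # [batch, query_len, d]
        logits = self.head(h)
        return logits, past_k + 1.0, past_v + 1.0  # pretend the KV cache was updated


def dummy(query_len, batch=1, max_seq=32, d=16):
    # Mirrors get_dummy_inputs(fill=0): ids, mask, cache_position, then the cache tensors.
    return (
        torch.zeros(batch, query_len, dtype=torch.int64),
        torch.zeros(batch, 1, query_len, max_seq, dtype=torch.int64),
        torch.zeros((), dtype=torch.int32),
        torch.zeros(batch, max_seq, d),
        torch.zeros(batch, max_seq, d),
    )


decoder = ToyDecoder().eval()
names = ["input_ids", "attention_mask", "cache_position", "past_k", "past_v"]

# One traced graph for chunked prefill, one for single-token decode.
prefill_ts = torch.jit.trace(decoder, dummy(query_len=8), check_trace=False)
dec_ts = torch.jit.trace(decoder, dummy(query_len=1), check_trace=False)

prefill_ir = rebel.torchscript_to_ir(prefill_ts, input_names=names)
dec_ir = rebel.torchscript_to_ir(dec_ts, input_names=names)

# Wire each updated cache output back to the corresponding cache input
# (outputs[0] is logits, inputs[0:3] are ids / mask / cache_position).
connections = [(prefill_ir.outputs[1 + i], prefill_ir.inputs[3 + i]) for i in range(2)]
compiled = rebel.compile(prefill_ir, dec_ir, connections=connections, use_weight_sharing=True)
```

On the old path only a single connection was needed because the cache was one stacked tensor; with per-layer cache tensors the connections list grows to `n_layer * 2` entries, which is exactly what the `(outputs[1 + i], inputs[3 + i])` comprehension in the hunk expresses.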
@@ -271,24 +199,24 @@ class RBLNGPT2LMHeadModel(RBLNBaseModel, RBLNGenerationMixin):
         def get_input_info(query_length):
             return [
                 ("input_ids", [rbln_batch_size, query_length], "int64"),
+                ("attention_mask", [rbln_batch_size, 1, query_length, rbln_max_seq_len], "int64"),
+                (
+                    "cache_position",
+                    [],
+                    "int32",
+                ),
+            ] + [
                 (
-                    "past_key_values",
+                    f"past_key_values_{i}",
                     [
-                        model_config.n_layer,
-                        2,
                         rbln_batch_size,
                         model_config.n_head,
                         rbln_max_seq_len,
                         model_config.hidden_size // model_config.n_head,
                     ],
                     "float32",
-                )
-
-                (
-                    "cache_position",
-                    [],
-                    "int32",
-                ),
+                )
+                for i in range(model_config.n_layer * 2)
             ]
 
         # model input info
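To make the new layout concrete, the snippet below builds the same `input_info` structure for a hypothetical 2-layer configuration (made-up sizes, illustration only) and prints it; it mirrors the list comprehension in the hunk above.

```python
# Hypothetical config: 2 layers, 4 heads, hidden size 64, batch 1, max sequence length 8.
n_layer, n_head, hidden_size = 2, 4, 64
rbln_batch_size, rbln_max_seq_len, query_length = 1, 8, 1

input_info = [
    ("input_ids", [rbln_batch_size, query_length], "int64"),
    ("attention_mask", [rbln_batch_size, 1, query_length, rbln_max_seq_len], "int64"),
    ("cache_position", [], "int32"),
] + [
    # One named input per layer for keys and one per layer for values.
    (f"past_key_values_{i}", [rbln_batch_size, n_head, rbln_max_seq_len, hidden_size // n_head], "float32")
    for i in range(n_layer * 2)
]

for name, shape, dtype in input_info:
    print(f"{name:20s} {str(shape):16s} {dtype}")
# past_key_values_0 .. past_key_values_3 each get shape [1, 4, 8, 16], replacing the old
# single stacked tensor of shape [n_layer, 2, batch, n_head, max_seq_len, head_dim].
```

Splitting the cache into per-layer named inputs is presumably what allows each cache tensor to be wired individually in the `connections` list during compilation.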
@@ -305,11 +233,14 @@ class RBLNGPT2LMHeadModel(RBLNBaseModel, RBLNGenerationMixin):
 
         return rbln_config
 
-
+    @classmethod
+    def _create_runtimes(
+        cls, compiled_models: List[rebel.RBLNCompiledModel], rbln_device_map: Dict[str, int]
+    ) -> List[rebel.Runtime]:
         device_val = rbln_device_map[DEFAULT_COMPILED_MODEL_NAME]
         return [
-
-
+            compiled_models[0].create_runtime(input_info_index=0, tensor_type="pt", device=device_val),
+            compiled_models[0].create_runtime(input_info_index=1, tensor_type="pt", device=device_val),
         ]
 
     def prepare_inputs_for_generation(self, input_ids, past_key_values=0, attention_mask=None, **kwargs):
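A minimal sketch of what `_create_runtimes` now does, assuming only the `create_runtime` signature shown in the hunk above (one compiled artifact, two entry points selected by `input_info_index`); the helper name below is made up for illustration.

```python
from typing import List

import rebel  # RBLN SDK


def create_gpt2_runtimes(compiled_model: "rebel.RBLNCompiledModel", device: int) -> List["rebel.Runtime"]:
    # Both runtimes come from the same compiled model and land on the same device:
    # index 0 selects the chunked prefill graph, index 1 the single-token decode graph.
    return [
        compiled_model.create_runtime(input_info_index=0, tensor_type="pt", device=device),
        compiled_model.create_runtime(input_info_index=1, tensor_type="pt", device=device),
    ]
```

These two runtimes are what `__init__` wraps as `self.prefill_decoder` and `self.decoder` (now via `self.model[0]` / `self.model[1]`, per the earlier hunk).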
@@ -386,6 +317,3 @@ class RBLNGPT2LMHeadModel(RBLNBaseModel, RBLNGenerationMixin):
         output = output.logits
 
         return CausalLMOutputWithCrossAttentions(logits=output, past_key_values=past_key_values)
-
-    def __repr__(self):
-        return repr(self.runtimes[0]) + "\n" + repr(self.runtimes[1])

optimum/rbln/transformers/models/llama/llama_architecture.py

@@ -107,7 +107,6 @@ class _LlamaAttention(LlamaAttention):
         use_cache: bool = False,
         **kwargs,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-
         bsz, q_len, _ = hidden_states.size()
 
         if self.config.pretraining_tp > 1:
@@ -227,7 +226,6 @@ class _LlamaDecoderLayer(LlamaDecoderLayer):
         use_cache: Optional[bool] = False,
         **kwargs,
     ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
-
         residual = hidden_states
 
         hidden_states = self.input_layernorm(hidden_states)
@@ -414,7 +412,6 @@ class _LlamaForCausalLM(LlamaForCausalLM):
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
     ) -> Union[Tuple, CausalLMOutputWithPast]:
-
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@@ -617,23 +614,23 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
     return q_embed, k_embed
 
 
-
-
-
-
-
-    origin_mehtods["LlamaForCausalLM_forward"] = LlamaForCausalLM.forward
+origin_methods = {}
+origin_methods["LlamaRotaryEmbedding_INIT"] = LlamaRotaryEmbedding.__init__
+origin_methods["LlamaRotaryEmbedding_forward"] = LlamaRotaryEmbedding.forward
+origin_methods["LlamaModel_forward"] = LlamaModel.forward
+origin_methods["LlamaForCausalLM_forward"] = LlamaForCausalLM.forward
 
+
+def wrap_llama():
     LlamaRotaryEmbedding.__init__ = _LlamaRotaryEmbedding.__init__
     LlamaRotaryEmbedding.forward = _LlamaRotaryEmbedding.forward
     LlamaModel.forward = _LlamaModel.forward
     LlamaForCausalLM.forward = _LlamaForCausalLM.forward
 
-    return origin_mehtods
-
 
-def unwrap_llama(origin_mehtods):
-
-    LlamaRotaryEmbedding.
-
-
+def unwrap_llama():
+    global origin_methods
+    LlamaRotaryEmbedding.__init__ = origin_methods["LlamaRotaryEmbedding_INIT"]
+    LlamaRotaryEmbedding.forward = origin_methods["LlamaRotaryEmbedding_forward"]
+    LlamaModel.forward = origin_methods["LlamaModel_forward"]
+    LlamaForCausalLM.forward = origin_methods["LlamaForCausalLM_forward"]
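The wrap/unwrap change above is the familiar save-once / patch / restore monkeypatch pattern: the original methods are captured in a module-level dict at import time, so `unwrap_llama()` no longer needs the caller to pass back the dict that `wrap_llama()` used to return (and the `origin_mehtods` typo disappears along the way). A self-contained miniature of the pattern, with a hypothetical `Greeter` class standing in for the patched transformers classes:

```python
class Greeter:
    def hello(self):
        return "original"


def _patched_hello(self):
    return "patched"


# Capture the originals once, at import time, like the new module-level origin_methods.
origin_methods = {"Greeter_hello": Greeter.hello}


def wrap_greeter():
    Greeter.hello = _patched_hello


def unwrap_greeter():
    Greeter.hello = origin_methods["Greeter_hello"]


g = Greeter()
wrap_greeter()
assert g.hello() == "patched"
unwrap_greeter()
assert g.hello() == "original"  # callers no longer have to carry the returned dict around
```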
optimum/rbln/transformers/models/llama/llama_architecture_cb.py

@@ -118,6 +118,9 @@ class _LlamaAttention(LlamaAttention):
         batch_index: Optional[int] = None,
         output_attentions: bool = False,
         use_cache: bool = False,
+        cos: Optional[torch.Tensor] = None,
+        sin: Optional[torch.Tensor] = None,
+        layer_id: int = 0,
         **kwargs,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         bsz, q_len, _ = hidden_states.size()
@@ -156,8 +159,11 @@ class _LlamaAttention(LlamaAttention):
                     "with a layer index."
                 )
             kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
-
-
+        if layer_id == 0:
+            cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+        query_states, key_states, cos, sin = apply_rotary_pos_emb(
+            query_states, key_states, cos, sin, position_ids, layer_id
+        )
         if past_key_value is not None:
             cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE models
             if (batch_index is None or batch_index == -1) and bsz > 1:
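The `layer_id == 0` guard works because the RoPE cos/sin tables are a function of position and head dimension only, never of per-layer weights, so the table computed in the first layer can be reused by every other layer. A small self-contained check (toy sizes; `rope_tables` is a generic RoPE helper written for this sketch, not the optimum-rbln implementation):

```python
import torch


def rope_tables(seq_len: int, head_dim: int, base: float = 10000.0):
    # Standard rotary embedding tables: depend only on position index and head_dim.
    inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
    t = torch.arange(seq_len).float()
    freqs = torch.outer(t, inv_freq)
    emb = torch.cat((freqs, freqs), dim=-1)
    return emb.cos(), emb.sin()


cos = sin = None
for layer_id in range(4):              # pretend 4 decoder layers
    if layer_id == 0:                  # only the first layer pays for the table
        cos, sin = rope_tables(seq_len=16, head_dim=8)
    # ... apply_rotary_pos_emb(q, k, cos, sin, position_ids, layer_id) would run here ...

ref_cos, ref_sin = rope_tables(seq_len=16, head_dim=8)
assert torch.equal(cos, ref_cos) and torch.equal(sin, ref_sin)  # identical for every layer
```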
@@ -261,7 +267,7 @@ class _LlamaAttention(LlamaAttention):
         if not output_attentions:
             attn_weights = None
 
-        return attn_output, attn_weights, key_states, value_states
+        return attn_output, attn_weights, key_states, value_states, cos, sin
 
 
 class _LlamaDecoderLayer(LlamaDecoderLayer):
@@ -275,6 +281,9 @@ class _LlamaDecoderLayer(LlamaDecoderLayer):
         output_attentions: Optional[bool] = False,
         use_cache: Optional[bool] = False,
         batch_ids: Optional[torch.LongTensor] = None,
+        cos: Optional[torch.Tensor] = None,
+        sin: Optional[torch.Tensor] = None,
+        layer_id: int = 0,
         **kwargs,
     ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
         residual = hidden_states
@@ -282,7 +291,7 @@ class _LlamaDecoderLayer(LlamaDecoderLayer):
         hidden_states = self.input_layernorm(hidden_states)
         bsz, _, _ = hidden_states.size()
 
-        hidden_states, self_attn_weights, k, v = _LlamaAttention.forward(
+        hidden_states, self_attn_weights, k, v, cos, sin = _LlamaAttention.forward(
             self.self_attn,
             hidden_states=hidden_states,
             attention_mask=attention_mask,
@@ -291,6 +300,9 @@ class _LlamaDecoderLayer(LlamaDecoderLayer):
             output_attentions=output_attentions,
             batch_index=batch_ids,
             use_cache=use_cache,
+            cos=cos,
+            sin=sin,
+            layer_id=layer_id,
             **kwargs,
         )
         past_key_value.assign(k, v, layer_idx)
@@ -313,7 +325,7 @@ class _LlamaDecoderLayer(LlamaDecoderLayer):
         if use_cache:
             outputs += (present_key_value,)
 
-        return outputs
+        return outputs, cos, sin
 
 
 class _LlamaModel(LlamaModel):
@@ -415,10 +427,11 @@ class _LlamaModel(LlamaModel):
         all_self_attns = () if output_attentions else None
         next_decoder_cache = () if use_cache else None
 
+        cos = None
+        sin = None
         for layer_idx, decoder_layer in enumerate(self.layers):
             if output_hidden_states:
                 all_hidden_states += (hidden_states,)
-
             layer_outputs = _LlamaDecoderLayer.forward(
                 decoder_layer,
                 hidden_states,
@@ -429,7 +442,13 @@ class _LlamaModel(LlamaModel):
                 output_attentions=output_attentions,
                 use_cache=use_cache,
                 batch_ids=batch_ids,
+                cos=cos,
+                sin=sin,
+                layer_id=layer_idx,
             )
+            cos = layer_outputs[-2]
+            sin = layer_outputs[-1]
+            layer_outputs = layer_outputs[0]
 
             hidden_states = layer_outputs[0]
 
@@ -697,7 +716,7 @@ def rotate_half(x):
     return torch.cat((-x2, x1), dim=-1)
 
 
-def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
+def apply_rotary_pos_emb(q, k, cos, sin, position_ids, layer_id, unsqueeze_dim=1):
     """Applies Rotary Position Embedding to the query and key tensors.
 
     Args:
@@ -718,42 +737,28 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
     Returns:
         `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
     """
-    if
-
-
-
-
-
-
-
-
-
-
-
-
+    if layer_id == 0:
+        if position_ids.shape[0] > 1:
+            cos_all = []
+            sin_all = []
+            for i in range(position_ids.shape[0]):
+                cos_all.append(cos[position_ids[i : i + 1]].unsqueeze(unsqueeze_dim))
+                sin_all.append(sin[position_ids[i : i + 1]].unsqueeze(unsqueeze_dim))
+            cos = torch.cat(cos_all, dim=0)
+            sin = torch.cat(sin_all, dim=0)
+        else:
+            cos = cos[position_ids].unsqueeze(unsqueeze_dim)
+            sin = sin[position_ids].unsqueeze(unsqueeze_dim)
+    # cos = cos[position_ids].unsqueeze(unsqueeze_dim)
+    # sin = sin[position_ids].unsqueeze(unsqueeze_dim)
 
     q_embed = (q * cos) + (rotate_half(q) * sin)
     k_embed = (k * cos) + (rotate_half(k) * sin)
-    return q_embed, k_embed
+    return q_embed, k_embed, cos, sin
 
 
 def wrap_llama():
-    origin_mehtods = {}
-    origin_mehtods["LlamaRotaryEmbedding_INIT"] = LlamaRotaryEmbedding.__init__
-    origin_mehtods["LlamaRotaryEmbedding_forward"] = LlamaRotaryEmbedding.forward
-    origin_mehtods["LlamaModel_forward"] = LlamaModel.forward
-    origin_mehtods["LlamaForCausalLM_forward"] = LlamaForCausalLM.forward
-
     LlamaRotaryEmbedding.__init__ = _LlamaRotaryEmbedding.__init__
     LlamaRotaryEmbedding.forward = _LlamaRotaryEmbedding.forward
     LlamaModel.forward = _LlamaModel.forward
     LlamaForCausalLM.forward = _LlamaForCausalLM.forward
-
-    return origin_mehtods
-
-
-def unwrap_llama(origin_mehtods):
-    LlamaRotaryEmbedding.__init__ = origin_mehtods["LlamaRotaryEmbedding_INIT"]
-    LlamaRotaryEmbedding.forward = origin_mehtods["LlamaRotaryEmbedding_forward"]
-    LlamaModel.forward = origin_mehtods["LlamaModel_forward"]
-    LlamaForCausalLM.forward = origin_mehtods["LlamaForCausalLM_forward"]
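Finally, the per-row gather that the new `apply_rotary_pos_emb` performs when `position_ids` has more than one row (the continuous-batching case, where every sequence sits at a different decode position) can be sanity-checked in isolation: the loop-and-concat form produces exactly the same tensor as a single advanced-indexing gather, and presumably exists so that each sequence becomes an explicit gather in the traced graph. Toy sizes, illustration only:

```python
import torch

max_seq, head_dim, unsqueeze_dim = 8, 4, 1
cos_table = torch.randn(max_seq, head_dim)       # stands in for the rotary embedding's cos cache

position_ids = torch.tensor([[3], [5]])          # two sequences at different decode positions

# Batched path (position_ids.shape[0] > 1), as in the hunk above: gather per sequence, then concat.
cos_all = [
    cos_table[position_ids[i : i + 1]].unsqueeze(unsqueeze_dim)
    for i in range(position_ids.shape[0])
]
cos_batched = torch.cat(cos_all, dim=0)          # shape [2, 1, 1, head_dim]

# It matches the single advanced-indexing gather, row for row.
assert torch.equal(cos_batched, cos_table[position_ids].unsqueeze(unsqueeze_dim))
assert torch.equal(cos_batched[0, 0, 0], cos_table[3])
assert torch.equal(cos_batched[1, 0, 0], cos_table[5])
```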