ipex-llm 2.2.0b20250107__py3-none-manylinux2010_x86_64.whl → 2.2.0b20250108__py3-none-manylinux2010_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ipex_llm/libs/libbloom_amx.so +0 -0
- ipex_llm/libs/libbloom_avx.so +0 -0
- ipex_llm/libs/libbloom_avx2.so +0 -0
- ipex_llm/libs/libbloom_avx512.so +0 -0
- ipex_llm/libs/libbloom_avxvnni.so +0 -0
- ipex_llm/libs/libgptneox_amx.so +0 -0
- ipex_llm/libs/libgptneox_avx.so +0 -0
- ipex_llm/libs/libgptneox_avx2.so +0 -0
- ipex_llm/libs/libgptneox_avx512.so +0 -0
- ipex_llm/libs/libgptneox_avxvnni.so +0 -0
- ipex_llm/libs/libllama_amx.so +0 -0
- ipex_llm/libs/libllama_avx.so +0 -0
- ipex_llm/libs/libllama_avx2.so +0 -0
- ipex_llm/libs/libllama_avx512.so +0 -0
- ipex_llm/libs/libllama_avxvnni.so +0 -0
- ipex_llm/libs/libstarcoder_amx.so +0 -0
- ipex_llm/libs/libstarcoder_avx.so +0 -0
- ipex_llm/libs/libstarcoder_avx2.so +0 -0
- ipex_llm/libs/libstarcoder_avx512.so +0 -0
- ipex_llm/libs/libstarcoder_avxvnni.so +0 -0
- ipex_llm/libs/quantize-bloom +0 -0
- ipex_llm/libs/quantize-gptneox +0 -0
- ipex_llm/libs/quantize-llama +0 -0
- ipex_llm/libs/quantize-starcoder +0 -0
- ipex_llm/transformers/convert.py +15 -37
- ipex_llm/transformers/loader.py +1 -1
- ipex_llm/transformers/low_bit_linear.py +10 -25
- ipex_llm/transformers/model.py +0 -7
- ipex_llm/transformers/models/chatglm4v.py +1 -0
- ipex_llm/transformers/models/glm.py +3 -1
- ipex_llm/transformers/models/llama.py +1 -1
- ipex_llm/transformers/models/minicpm.py +2 -1
- ipex_llm/transformers/models/minicpmv.py +1 -0
- ipex_llm/transformers/models/utils.py +3 -16
- ipex_llm/transformers/speculative.py +2 -14
- ipex_llm/transformers/utils.py +2 -14
- ipex_llm/transformers/xpu_ops.py +25 -19
- {ipex_llm-2.2.0b20250107.dist-info → ipex_llm-2.2.0b20250108.dist-info}/METADATA +20 -20
- {ipex_llm-2.2.0b20250107.dist-info → ipex_llm-2.2.0b20250108.dist-info}/RECORD +45 -46
- ipex_llm/transformers/models/gptj.py +0 -441
- {ipex_llm-2.2.0b20250107.data → ipex_llm-2.2.0b20250108.data}/scripts/ipex-llm-init +0 -0
- {ipex_llm-2.2.0b20250107.data → ipex_llm-2.2.0b20250108.data}/scripts/llm-chat +0 -0
- {ipex_llm-2.2.0b20250107.data → ipex_llm-2.2.0b20250108.data}/scripts/llm-cli +0 -0
- {ipex_llm-2.2.0b20250107.dist-info → ipex_llm-2.2.0b20250108.dist-info}/WHEEL +0 -0
- {ipex_llm-2.2.0b20250107.dist-info → ipex_llm-2.2.0b20250108.dist-info}/entry_points.txt +0 -0
- {ipex_llm-2.2.0b20250107.dist-info → ipex_llm-2.2.0b20250108.dist-info}/top_level.txt +0 -0
@@ -1,441 +0,0 @@
|
|
1
|
-
#
|
2
|
-
# Copyright 2016 The BigDL Authors.
|
3
|
-
#
|
4
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
5
|
-
# you may not use this file except in compliance with the License.
|
6
|
-
# You may obtain a copy of the License at
|
7
|
-
#
|
8
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
9
|
-
#
|
10
|
-
# Unless required by applicable law or agreed to in writing, software
|
11
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
12
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13
|
-
# See the License for the specific language governing permissions and
|
14
|
-
# limitations under the License.
|
15
|
-
#
|
16
|
-
# This file is adapted from
|
17
|
-
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/gptj/modeling_gptj.py
|
18
|
-
#
|
19
|
-
|
20
|
-
import torch
|
21
|
-
from typing import Optional, Tuple, Union
|
22
|
-
from ipex_llm.transformers.models.utils import init_kv_cache, extend_kv_cache, \
|
23
|
-
apply_rotary_pos_emb, append_kv_cache, apply_ipex_rotate_every_two
|
24
|
-
from transformers.utils.import_utils import is_torch_fx_proxy
|
25
|
-
from transformers.modeling_outputs import BaseModelOutputWithPast
|
26
|
-
from transformers.models.gptj.modeling_gptj import GPTJModel
|
27
|
-
from ipex_llm.utils.common import invalidInputError
|
28
|
-
|
29
|
-
import os
|
30
|
-
|
31
|
-
KV_CACHE_ALLOC_BLOCK_LENGTH = int(os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256))
|
32
|
-
|
33
|
-
|
34
|
-
def _get_embed_positions(self, position_ids):
|
35
|
-
embed_positions = self.embed_positions
|
36
|
-
if embed_positions.device != position_ids.device:
|
37
|
-
embed_positions = embed_positions.to(position_ids.device)
|
38
|
-
self.embed_positions = embed_positions
|
39
|
-
return embed_positions.repeat(position_ids.shape[0], 1, 1)
|
40
|
-
|
41
|
-
|
42
|
-
def _attn(
|
43
|
-
self,
|
44
|
-
query,
|
45
|
-
key,
|
46
|
-
value,
|
47
|
-
attention_mask=None,
|
48
|
-
head_mask=None,
|
49
|
-
):
|
50
|
-
# compute causal mask from causal mask buffer
|
51
|
-
query_length, key_length = query.size(-2), key.size(-2)
|
52
|
-
causal_mask = self.bias[:, :, key_length - query_length: key_length, :key_length]
|
53
|
-
|
54
|
-
# Keep the attention weights computation in fp32 to avoid overflow issues
|
55
|
-
query = query.to(torch.float32)
|
56
|
-
key = key.to(torch.float32)
|
57
|
-
|
58
|
-
attn_weights = torch.matmul(query, key.transpose(-1, -2))
|
59
|
-
|
60
|
-
mask_value = torch.finfo(attn_weights.dtype).min
|
61
|
-
# Need to be a tensor, otherwise we get error:
|
62
|
-
# `RuntimeError: expected scalar type float but found double`.
|
63
|
-
# Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
|
64
|
-
mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype).to(attn_weights.device)
|
65
|
-
attn_weights = torch.where(causal_mask, attn_weights, mask_value)
|
66
|
-
|
67
|
-
attn_weights = attn_weights / self.scale_attn
|
68
|
-
|
69
|
-
if attention_mask is not None:
|
70
|
-
# Apply the attention mask
|
71
|
-
attn_weights = attn_weights + attention_mask
|
72
|
-
|
73
|
-
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
|
74
|
-
attn_weights = attn_weights.to(value.dtype)
|
75
|
-
attn_weights = self.attn_dropout(attn_weights)
|
76
|
-
|
77
|
-
# Mask heads if we want to
|
78
|
-
if head_mask is not None:
|
79
|
-
attn_weights = attn_weights * head_mask
|
80
|
-
|
81
|
-
attn_output = torch.matmul(attn_weights, value)
|
82
|
-
|
83
|
-
return attn_output, attn_weights
|
84
|
-
|
85
|
-
|
86
|
-
def gptj_attention_forward(
|
87
|
-
self,
|
88
|
-
hidden_states: torch.FloatTensor,
|
89
|
-
layer_past: Optional[Tuple[torch.Tensor]] = None,
|
90
|
-
attention_mask: Optional[torch.FloatTensor] = None,
|
91
|
-
position_ids: Optional[torch.LongTensor] = None,
|
92
|
-
head_mask: Optional[torch.FloatTensor] = None,
|
93
|
-
use_cache: Optional[bool] = False,
|
94
|
-
rotary_emb: Optional[Tuple]=None,
|
95
|
-
output_attentions: Optional[bool] = False,
|
96
|
-
) -> Union[
|
97
|
-
Tuple[torch.Tensor, Tuple[torch.Tensor]],
|
98
|
-
Optional[Tuple[torch.Tensor, Tuple[torch.Tensor], Tuple[torch.Tensor, ...]]],
|
99
|
-
]:
|
100
|
-
query = self.q_proj(hidden_states)
|
101
|
-
key = self.k_proj(hidden_states)
|
102
|
-
value = self.v_proj(hidden_states)
|
103
|
-
|
104
|
-
query = self._split_heads(query, self.num_attention_heads, self.head_dim, True)
|
105
|
-
key = self._split_heads(key, self.num_attention_heads, self.head_dim, True)
|
106
|
-
value = self._split_heads(value, self.num_attention_heads, self.head_dim, False)
|
107
|
-
|
108
|
-
sin, cos = rotary_emb
|
109
|
-
use_fuse_rope = hidden_states.device.type == "xpu" and not self.training
|
110
|
-
|
111
|
-
if self.rotary_dim is not None:
|
112
|
-
k_rot = key[:, :, :, : self.rotary_dim]
|
113
|
-
q_rot = query[:, :, :, : self.rotary_dim]
|
114
|
-
|
115
|
-
if use_fuse_rope:
|
116
|
-
apply_ipex_rotate_every_two(q_rot, k_rot, cos, sin)
|
117
|
-
else:
|
118
|
-
k_pass = key[:, :, :, self.rotary_dim:]
|
119
|
-
q_pass = query[:, :, :, self.rotary_dim:]
|
120
|
-
q_rot, k_rot = apply_rotary_pos_emb(q_rot, k_rot, cos, sin, position_ids, "gptj")
|
121
|
-
key = torch.cat([k_rot, k_pass], dim=-1)
|
122
|
-
query = torch.cat([q_rot, q_pass], dim=-1)
|
123
|
-
else:
|
124
|
-
if use_fuse_rope:
|
125
|
-
apply_ipex_rotate_every_two(query, key, cos, sin)
|
126
|
-
else:
|
127
|
-
query, key = apply_rotary_pos_emb(query, key, cos, sin, position_ids, "gptj")
|
128
|
-
|
129
|
-
batch_size, q_len, _ = hidden_states.size()
|
130
|
-
|
131
|
-
key = key.permute(0, 2, 1, 3).contiguous()
|
132
|
-
query = query.permute(0, 2, 1, 3).contiguous()
|
133
|
-
|
134
|
-
kv_seq_len = key.size(-2)
|
135
|
-
device = hidden_states.device
|
136
|
-
|
137
|
-
if layer_past is not None:
|
138
|
-
kv_seq_len += layer_past[0].size(2)
|
139
|
-
|
140
|
-
if layer_past is not None:
|
141
|
-
cache_k = layer_past[0]
|
142
|
-
cache_v = layer_past[1]
|
143
|
-
past_length = cache_k.size(2)
|
144
|
-
if cache_k.stride()[1] < kv_seq_len * cache_k.size(3):
|
145
|
-
new_cache_k, new_cache_v = extend_kv_cache(batch_size,
|
146
|
-
self.num_attention_heads,
|
147
|
-
self.head_dim,
|
148
|
-
past_length,
|
149
|
-
kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH,
|
150
|
-
dtype=cache_v.dtype,
|
151
|
-
device=device)
|
152
|
-
new_cache_k[:] = cache_k
|
153
|
-
new_cache_v[:] = cache_v
|
154
|
-
cache_k = new_cache_k
|
155
|
-
cache_v = new_cache_v
|
156
|
-
key, value = append_kv_cache(cache_k, cache_v, key, value)
|
157
|
-
|
158
|
-
elif use_cache:
|
159
|
-
key_cache, value_cache = init_kv_cache(batch_size,
|
160
|
-
self.num_attention_heads,
|
161
|
-
self.head_dim,
|
162
|
-
kv_seq_len,
|
163
|
-
kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH,
|
164
|
-
dtype=value.dtype,
|
165
|
-
device=device)
|
166
|
-
key_cache[:] = key
|
167
|
-
value_cache[:] = value
|
168
|
-
key = key_cache
|
169
|
-
value = value_cache
|
170
|
-
|
171
|
-
if use_cache is True:
|
172
|
-
present = (key, value)
|
173
|
-
else:
|
174
|
-
present = None
|
175
|
-
|
176
|
-
# compute self-attention: V x Softmax(QK^T)
|
177
|
-
attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)
|
178
|
-
|
179
|
-
attn_output = self._merge_heads(attn_output, self.num_attention_heads, self.head_dim)
|
180
|
-
attn_output = self.out_proj(attn_output)
|
181
|
-
attn_output = self.resid_dropout(attn_output)
|
182
|
-
|
183
|
-
outputs = (attn_output, present)
|
184
|
-
if output_attentions:
|
185
|
-
outputs += (attn_weights,)
|
186
|
-
|
187
|
-
return outputs # a, present, (attentions)
|
188
|
-
|
189
|
-
|
190
|
-
def gptj_block_forward(
|
191
|
-
self,
|
192
|
-
hidden_states: Optional[torch.FloatTensor],
|
193
|
-
layer_past: Optional[Tuple[torch.Tensor]] = None,
|
194
|
-
attention_mask: Optional[torch.FloatTensor] = None,
|
195
|
-
position_ids: Optional[torch.LongTensor] = None,
|
196
|
-
head_mask: Optional[torch.FloatTensor] = None,
|
197
|
-
use_cache: Optional[bool] = False,
|
198
|
-
rotary_emb: Optional[Tuple]=None,
|
199
|
-
output_attentions: Optional[bool] = False,
|
200
|
-
) -> Union[Tuple[torch.Tensor], Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]]]:
|
201
|
-
residual = hidden_states
|
202
|
-
hidden_states = self.ln_1(hidden_states)
|
203
|
-
attn_outputs = self.attn(
|
204
|
-
hidden_states=hidden_states,
|
205
|
-
layer_past=layer_past,
|
206
|
-
attention_mask=attention_mask,
|
207
|
-
position_ids=position_ids,
|
208
|
-
head_mask=head_mask,
|
209
|
-
use_cache=use_cache,
|
210
|
-
rotary_emb=rotary_emb,
|
211
|
-
output_attentions=output_attentions,
|
212
|
-
)
|
213
|
-
attn_output = attn_outputs[0] # output_attn: a, present, (attentions)
|
214
|
-
outputs = attn_outputs[1:]
|
215
|
-
|
216
|
-
feed_forward_hidden_states = self.mlp(hidden_states)
|
217
|
-
hidden_states = attn_output + feed_forward_hidden_states + residual
|
218
|
-
|
219
|
-
if use_cache:
|
220
|
-
outputs = (hidden_states,) + outputs
|
221
|
-
else:
|
222
|
-
outputs = (hidden_states,) + outputs[1:]
|
223
|
-
|
224
|
-
return outputs # hidden_states, present, (attentions)
|
225
|
-
|
226
|
-
|
227
|
-
def create_sinusoidal_positions(num_pos: int, dim: int) -> torch.Tensor:
|
228
|
-
inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2) / dim))
|
229
|
-
sinusoid_inp = torch.einsum("i , j -> i j",
|
230
|
-
torch.arange(num_pos, dtype=torch.float), inv_freq).float()
|
231
|
-
return torch.cat((torch.sin(sinusoid_inp), torch.cos(sinusoid_inp)), dim=1)
|
232
|
-
|
233
|
-
|
234
|
-
old_init = GPTJModel.__init__
|
235
|
-
|
236
|
-
|
237
|
-
def gptj_model_new_init(self, config):
|
238
|
-
old_init(self, config)
|
239
|
-
embed_dim = config.hidden_size
|
240
|
-
rotary_dim = config.rotary_dim
|
241
|
-
pos_embd_dim = rotary_dim or embed_dim
|
242
|
-
max_positions = config.max_position_embeddings
|
243
|
-
self.embed_positions = create_sinusoidal_positions(max_positions, pos_embd_dim)
|
244
|
-
|
245
|
-
|
246
|
-
def get_new_embed_positions(position_ids, prev_embed_positions):
|
247
|
-
embed_positions = prev_embed_positions
|
248
|
-
if embed_positions.device != position_ids.device:
|
249
|
-
embed_positions = embed_positions.to(position_ids.device)
|
250
|
-
prev_embed_positions = embed_positions
|
251
|
-
return embed_positions.repeat(position_ids.shape[0], 1, 1), prev_embed_positions
|
252
|
-
|
253
|
-
|
254
|
-
def gptj_model_forward(
|
255
|
-
self,
|
256
|
-
input_ids: Optional[torch.LongTensor] = None,
|
257
|
-
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
|
258
|
-
attention_mask: Optional[torch.FloatTensor] = None,
|
259
|
-
token_type_ids: Optional[torch.LongTensor] = None,
|
260
|
-
position_ids: Optional[torch.LongTensor] = None,
|
261
|
-
head_mask: Optional[torch.FloatTensor] = None,
|
262
|
-
inputs_embeds: Optional[torch.FloatTensor] = None,
|
263
|
-
use_cache: Optional[bool] = None,
|
264
|
-
output_attentions: Optional[bool] = None,
|
265
|
-
output_hidden_states: Optional[bool] = None,
|
266
|
-
return_dict: Optional[bool] = None,
|
267
|
-
) -> Union[Tuple, BaseModelOutputWithPast]:
|
268
|
-
output_attentions = output_attentions if output_attentions is not None \
|
269
|
-
else self.config.output_attentions
|
270
|
-
output_hidden_states = (
|
271
|
-
output_hidden_states if output_hidden_states is not None
|
272
|
-
else self.config.output_hidden_states
|
273
|
-
)
|
274
|
-
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
275
|
-
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
276
|
-
|
277
|
-
if input_ids is not None and inputs_embeds is not None:
|
278
|
-
invalidInputError(False,
|
279
|
-
"You cannot specify both input_ids and inputs_embeds at the same time")
|
280
|
-
elif input_ids is not None:
|
281
|
-
self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
|
282
|
-
input_shape = input_ids.size()
|
283
|
-
input_ids = input_ids.view(-1, input_shape[-1])
|
284
|
-
batch_size = input_ids.shape[0]
|
285
|
-
elif inputs_embeds is not None:
|
286
|
-
input_shape = inputs_embeds.size()[:-1]
|
287
|
-
batch_size = inputs_embeds.shape[0]
|
288
|
-
else:
|
289
|
-
invalidInputError(False, "You have to specify either input_ids or inputs_embeds")
|
290
|
-
|
291
|
-
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
292
|
-
|
293
|
-
if token_type_ids is not None:
|
294
|
-
token_type_ids = token_type_ids.view(-1, input_shape[-1])
|
295
|
-
|
296
|
-
if past_key_values is None:
|
297
|
-
past_length = 0
|
298
|
-
past_key_values = tuple([None] * len(self.h))
|
299
|
-
else:
|
300
|
-
past_length = past_key_values[0][0].size(-2)
|
301
|
-
|
302
|
-
if position_ids is None:
|
303
|
-
position_ids = torch.arange(past_length, input_shape[-1] + past_length,
|
304
|
-
dtype=torch.long, device=device)
|
305
|
-
position_ids = position_ids.unsqueeze(0)
|
306
|
-
|
307
|
-
# Attention mask.
|
308
|
-
if attention_mask is not None:
|
309
|
-
if batch_size <= 0:
|
310
|
-
invalidInputError(False, "batch_size has to be defined and > 0")
|
311
|
-
attention_mask = attention_mask.view(batch_size, -1)
|
312
|
-
# We create a 3D attention mask from a 2D tensor mask.
|
313
|
-
# Sizes are [batch_size, 1, 1, to_seq_length]
|
314
|
-
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
|
315
|
-
# this attention mask is more simple than the triangular masking of causal attention
|
316
|
-
# used in OpenAI GPT, we just need to prepare the broadcast dimension here.
|
317
|
-
attention_mask = attention_mask[:, None, None, :]
|
318
|
-
|
319
|
-
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
|
320
|
-
# masked positions, this operation will create a tensor which is 0.0 for
|
321
|
-
# positions we want to attend and the dtype's smallest value for masked positions.
|
322
|
-
# Since we are adding it to the raw scores before the softmax, this is
|
323
|
-
# effectively the same as removing these entirely.
|
324
|
-
attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility
|
325
|
-
attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min
|
326
|
-
|
327
|
-
# Prepare head mask if needed
|
328
|
-
# 1.0 in head_mask indicate we keep the head
|
329
|
-
# attention_probs has shape bsz x num_attention_heads x N x N
|
330
|
-
# head_mask has shape n_layer x batch x num_attention_heads x N x N
|
331
|
-
head_mask = self.get_head_mask(head_mask, self.config.n_layer)
|
332
|
-
|
333
|
-
if inputs_embeds is None:
|
334
|
-
inputs_embeds = self.wte(input_ids)
|
335
|
-
|
336
|
-
hidden_states = inputs_embeds
|
337
|
-
|
338
|
-
if token_type_ids is not None:
|
339
|
-
token_type_embeds = self.wte(token_type_ids)
|
340
|
-
hidden_states = hidden_states + token_type_embeds
|
341
|
-
|
342
|
-
hidden_states = self.drop(hidden_states)
|
343
|
-
|
344
|
-
output_shape = (-1,) + input_shape[1:] + (hidden_states.size(-1),)
|
345
|
-
|
346
|
-
if self.gradient_checkpointing and self.training:
|
347
|
-
if use_cache:
|
348
|
-
logger.warning_once(
|
349
|
-
"`use_cache=True` is incompatible with gradient checkpointing."
|
350
|
-
"Setting `use_cache=False`..."
|
351
|
-
)
|
352
|
-
use_cache = False
|
353
|
-
|
354
|
-
presents = () if use_cache else None
|
355
|
-
all_self_attentions = () if output_attentions else None
|
356
|
-
all_hidden_states = () if output_hidden_states else None
|
357
|
-
|
358
|
-
# Repeat cos sin here, call only once for each token.
|
359
|
-
# If put this to attension forward, it will generate too many times.
|
360
|
-
if is_torch_fx_proxy(position_ids) or torch.jit.is_tracing():
|
361
|
-
# The logic to conditionally copy to GPU could not be traced, so we do this
|
362
|
-
# every time in the torch.fx case
|
363
|
-
embed_positions = get_embed_positions(self.embed_positions, position_ids)
|
364
|
-
else:
|
365
|
-
embed_positions, self.embed_positions = get_new_embed_positions(position_ids,
|
366
|
-
self.embed_positions)
|
367
|
-
|
368
|
-
repeated_position_ids = position_ids.unsqueeze(-1).repeat(1, 1, embed_positions.shape[-1])
|
369
|
-
sincos = torch.gather(embed_positions, 1, repeated_position_ids)
|
370
|
-
sin, cos = torch.split(sincos, sincos.shape[-1] // 2, dim=-1)
|
371
|
-
sin = torch.repeat_interleave(sin[:, :, None, :], 2, 3)
|
372
|
-
cos = torch.repeat_interleave(cos[:, :, None, :], 2, 3)
|
373
|
-
|
374
|
-
for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
|
375
|
-
# Model parallel
|
376
|
-
if self.model_parallel:
|
377
|
-
torch.cuda.set_device(hidden_states.device)
|
378
|
-
# Ensure layer_past is on same device as hidden_states (might not be correct)
|
379
|
-
if layer_past is not None:
|
380
|
-
layer_past = tuple(past_state.to(hidden_states.device) for past_state in layer_past)
|
381
|
-
# Ensure that attention_mask is always on the same device as hidden_states
|
382
|
-
if attention_mask is not None:
|
383
|
-
attention_mask = attention_mask.to(hidden_states.device)
|
384
|
-
if isinstance(head_mask, torch.Tensor):
|
385
|
-
head_mask = head_mask.to(hidden_states.device)
|
386
|
-
if output_hidden_states:
|
387
|
-
all_hidden_states = all_hidden_states + (hidden_states,)
|
388
|
-
|
389
|
-
if self.gradient_checkpointing and self.training:
|
390
|
-
outputs = self._gradient_checkpointing_func(
|
391
|
-
block.__call__,
|
392
|
-
hidden_states,
|
393
|
-
None,
|
394
|
-
attention_mask,
|
395
|
-
position_ids,
|
396
|
-
head_mask[i],
|
397
|
-
use_cache,
|
398
|
-
output_attentions,
|
399
|
-
)
|
400
|
-
else:
|
401
|
-
outputs = block(
|
402
|
-
hidden_states=hidden_states,
|
403
|
-
layer_past=layer_past,
|
404
|
-
attention_mask=attention_mask,
|
405
|
-
position_ids=position_ids,
|
406
|
-
head_mask=head_mask[i],
|
407
|
-
use_cache=use_cache,
|
408
|
-
rotary_emb=(sin, cos),
|
409
|
-
output_attentions=output_attentions,
|
410
|
-
)
|
411
|
-
|
412
|
-
hidden_states = outputs[0]
|
413
|
-
if use_cache is True:
|
414
|
-
presents = presents + (outputs[1],)
|
415
|
-
|
416
|
-
if output_attentions:
|
417
|
-
all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
|
418
|
-
|
419
|
-
# Model Parallel: If it's the last layer for that device, put things on the next device
|
420
|
-
if self.model_parallel:
|
421
|
-
for k, v in self.device_map.items():
|
422
|
-
if i == v[-1] and "cuda:" + str(k) != self.last_device:
|
423
|
-
hidden_states = hidden_states.to("cuda:" + str(k + 1))
|
424
|
-
|
425
|
-
hidden_states = self.ln_f(hidden_states)
|
426
|
-
|
427
|
-
hidden_states = hidden_states.view(output_shape)
|
428
|
-
# Add last hidden state
|
429
|
-
if output_hidden_states:
|
430
|
-
all_hidden_states = all_hidden_states + (hidden_states,)
|
431
|
-
|
432
|
-
if not return_dict:
|
433
|
-
return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions]
|
434
|
-
if v is not None)
|
435
|
-
|
436
|
-
return BaseModelOutputWithPast(
|
437
|
-
last_hidden_state=hidden_states,
|
438
|
-
past_key_values=presents,
|
439
|
-
hidden_states=all_hidden_states,
|
440
|
-
attentions=all_self_attentions,
|
441
|
-
)
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|