dlinfer-ascend 0.1.0 (cp38-cp38-manylinux2014_aarch64.whl)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dlinfer/__init__.py +4 -0
- dlinfer/framework/__init__.py +1 -0
- dlinfer/framework/lmdeploy_ext/__init__.py +2 -0
- dlinfer/framework/transformers_ext/__init__.py +19 -0
- dlinfer/framework/transformers_ext/cogvlm.py +16 -0
- dlinfer/framework/transformers_ext/internlm2.py +241 -0
- dlinfer/framework/transformers_ext/internvl.py +40 -0
- dlinfer/framework/transformers_ext/patch.py +33 -0
- dlinfer/ops/__init__.py +2 -0
- dlinfer/ops/llm.py +401 -0
- dlinfer/utils/__init__.py +1 -0
- dlinfer/utils/graph/__init__.py +1 -0
- dlinfer/utils/graph/custom_op.py +69 -0
- dlinfer/utils/registry.py +9 -0
- dlinfer/utils/type_annotation.py +3 -0
- dlinfer/vendor/__init__.py +26 -0
- dlinfer/vendor/ascend/__init__.py +7 -0
- dlinfer/vendor/ascend/ascend_extension.so +0 -0
- dlinfer/vendor/ascend/pytorch_patch.py +8 -0
- dlinfer/vendor/ascend/torch_npu_ops.py +354 -0
- dlinfer/vendor/vendor.yaml +2 -0
- dlinfer_ascend-0.1.0.dist-info/LICENSE +28 -0
- dlinfer_ascend-0.1.0.dist-info/METADATA +125 -0
- dlinfer_ascend-0.1.0.dist-info/RECORD +26 -0
- dlinfer_ascend-0.1.0.dist-info/WHEEL +5 -0
- dlinfer_ascend-0.1.0.dist-info/top_level.txt +1 -0
dlinfer/__init__.py
ADDED
@@ -0,0 +1 @@
+# Copyright (c) 2024, DeepLink. All rights reserved.
dlinfer/framework/transformers_ext/__init__.py
ADDED
@@ -0,0 +1,19 @@
+# Copyright (c) 2024, DeepLink. All rights reserved.
+import importlib
+import os, sys
+import typing
+from typing import Any, Dict, List, Optional, Union
+import transformers
+from .patch import apply_model_patches
+
+
+def patched_get_class_in_module(
+    class_name: str, module_path: Union[str, os.PathLike]
+) -> typing.Type:
+    ret_class = transformers_get_class_in_module(class_name, module_path)
+    apply_model_patches(importlib.import_module(ret_class.__module__))
+    return ret_class
+
+
+transformers_get_class_in_module = transformers.dynamic_module_utils.get_class_in_module
+transformers.dynamic_module_utils.get_class_in_module = patched_get_class_in_module
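The hook above is installed as a side effect of importing the package: the original transformers.dynamic_module_utils.get_class_in_module is saved and replaced, so every remote-code class that transformers resolves passes through apply_model_patches. A minimal usage sketch follows; the checkpoint name is only illustrative, and any InternLM2, InternVL, or CogVLM remote-code checkpoint would take the same path.

# Sketch: importing the extension replaces get_class_in_module, so the
# modeling module of a trust_remote_code checkpoint is patched on load.
import dlinfer.framework.transformers_ext  # installs patched_get_class_in_module
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "internlm/internlm2-chat-7b",  # illustrative checkpoint
    trust_remote_code=True,
)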
dlinfer/framework/transformers_ext/cogvlm.py
ADDED
@@ -0,0 +1,16 @@
+import torch
+from torch import nn
+import dlinfer.ops as ext_ops
+
+
+def PatchedAttention_forward(self, x: "tensor(B, L, D)") -> "tensor(B, L, D)":
+    B, L, H = x.shape
+    qkv = self.query_key_value(x)
+    qkv = qkv.reshape(B, L, 3, H).permute(2, 0, 1, 3)  # 3, B, L, H
+    q, k, v = qkv[0], qkv[1], qkv[2]
+
+    ext_ops.prefill_attention(
+        q, k, v, None, None, L, self.num_heads, self.num_heads, None, attn_output=q
+    )
+    output = self.dense(q.view(B, L, -1))
+    return output
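For orientation, here is a plain-PyTorch sketch of the computation the patched CogVLM vision attention hands to ext_ops.prefill_attention: full (unmasked) multi-head attention over the B x L visual tokens, with the dlinfer kernel writing its result in place via attn_output=q. This reference is not part of the package and assumes the default 1/sqrt(head_dim) scaling.

import torch
import torch.nn.functional as F

def eager_vit_attention(q, k, v, num_heads):
    # q, k, v: (B, L, H) with H = num_heads * head_dim; no attention mask,
    # mirroring the None masks passed to prefill_attention above.
    B, L, H = q.shape
    head_dim = H // num_heads
    q, k, v = (t.view(B, L, num_heads, head_dim).transpose(1, 2) for t in (q, k, v))
    out = F.scaled_dot_product_attention(q, k, v)  # bidirectional, default scaling
    return out.transpose(1, 2).reshape(B, L, H)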
dlinfer/framework/transformers_ext/internlm2.py
ADDED
@@ -0,0 +1,241 @@
+# Copyright (c) 2024, DeepLink. All rights reserved.
+# Copyright 2024 HuggingFace Inc.
+
+import torch
+from einops import rearrange
+from transformers.cache_utils import Cache
+from typing import Optional, Dict, Any, Tuple
+from dataclasses import dataclass
+import dlinfer.ops as ext_ops
+
+
+@dataclass
+class TransformerBlockContext: ...
+
+
+transformer_block_context = TransformerBlockContext()
+
+
+def modeling_internlm2_InternLM2RMSNorm_forward(self, hidden_states):
+    return ext_ops.rms_norm(hidden_states, self.weight, self.variance_epsilon)
+
+
+def modeling_internlm2_InternLM2Attention_forward(
+    self,
+    hidden_states: torch.Tensor,
+    attention_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.LongTensor] = None,
+    past_key_value: Optional[Cache] = None,
+    output_attentions: bool = False,
+    use_cache: bool = False,  # pylint: disable=unused-argument
+    cache_position: Optional[torch.LongTensor] = None,
+) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+    bsz, q_len, _ = hidden_states.size()
+
+    if self.config.pretraining_tp > 1:
+        # split qkv_states by tp size
+        key_value_slicing = (
+            self.num_key_value_heads * self.head_dim
+        ) // self.config.pretraining_tp
+        qkv_slices = self.wqkv.weight.split(key_value_slicing, dim=0)
+        qkv_states = torch.cat(
+            [F.linear(hidden_states, qkv_slice) for qkv_slice in qkv_slices],
+            dim=-1,  # pylint: disable=E1102
+        )
+    else:
+        qkv_states = self.wqkv(hidden_states)
+
+    qkv_states = rearrange(
+        qkv_states,
+        "b q (h gs d) -> b q h gs d",
+        gs=2 + self.num_key_value_groups,
+        d=self.head_dim,
+    )
+
+    query_states = qkv_states[..., : self.num_key_value_groups, :]
+    query_states = rearrange(query_states, "b q h gs d -> b q (h gs) d")
+    key_states = qkv_states[..., -2, :]
+    value_states = qkv_states[..., -1, :]
+
+    global transformer_block_context
+    if self.layer_idx == 0:
+        cos, sin = self.rotary_emb(value_states, position_ids)
+        setattr(transformer_block_context, "sin", sin)
+        setattr(transformer_block_context, "cos", cos)
+    sin = transformer_block_context.sin
+    cos = transformer_block_context.cos
+    query_states, key_states = ext_ops.apply_rotary_pos_emb(
+        query_states, key_states, cos, sin, None, None
+    )
+
+    if past_key_value is not None:
+        # sin and cos are specific to RoPE models; cache_position needed for the static cache
+        cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+        key_states, value_states = past_key_value.update(
+            key_states, value_states, self.layer_idx, cache_kwargs
+        )
+
+    attn_output = ext_ops.fused_attention(
+        query_states, key_states, value_states, [attention_mask.to(torch.bool)]
+    )
+
+    attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+
+    if self.config.pretraining_tp > 1:
+        attn_output = attn_output.split(
+            self.hidden_size // self.config.pretraining_tp, dim=2
+        )
+        o_proj_slices = self.wo.weight.split(
+            self.hidden_size // self.config.pretraining_tp, dim=1
+        )
+        attn_output = sum(
+            [
+                F.linear(attn_output[i], o_proj_slices[i])  # pylint: disable=E1102
+                for i in range(self.config.pretraining_tp)
+            ]
+        )
+    else:
+        attn_output = self.wo(attn_output)
+
+    if not output_attentions:
+        attn_weights = None
+
+    return attn_output, attn_weights, past_key_value
+
+
+def modeling_internlm2_InternLM2ForCausalLM_prepare_inputs_for_generation(
+    self,
+    input_ids,
+    past_key_values=None,
+    attention_mask=None,
+    inputs_embeds=None,
+    cache_position=None,
+    use_cache=True,
+    **kwargs,
+):
+    past_length = 0
+    if past_key_values is not None:
+        if isinstance(past_key_values, Cache):
+            past_length = (
+                cache_position[0]
+                if cache_position is not None
+                else past_key_values.get_seq_length()
+            )
+            max_cache_length = (
+                torch.tensor(past_key_values.get_max_length(), device=input_ids.device)
+                if past_key_values.get_max_length() is not None
+                else None
+            )
+            cache_length = (
+                past_length
+                if max_cache_length is None
+                else torch.min(max_cache_length, past_length)
+            )
+        # TODO joao: remove this `else` after `generate` prioritizes `Cache` objects
+        else:
+            cache_length = past_length = ext_ops.get_cache_len(past_key_values[0][0])
+            max_cache_length = None
+
+        # Keep only the unprocessed tokens:
+        # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
+        # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as input)
+        if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
+            input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
+        # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
+        # input_ids based on the past_length.
+        elif past_length < input_ids.shape[1]:
+            input_ids = input_ids[:, past_length:]
+        # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
+
+        # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
+        if (
+            max_cache_length is not None
+            and attention_mask is not None
+            and cache_length + input_ids.shape[1] > max_cache_length
+        ):
+            attention_mask = attention_mask[
+                :, -max_cache_length:
+            ]  # pylint: disable=E1130
+
+    position_ids = kwargs.get("position_ids", None)
+    if attention_mask is not None and position_ids is None:
+        # create position_ids on the fly for batch generation
+        position_ids = attention_mask.long().cumsum(-1) - 1
+        position_ids.masked_fill_(attention_mask == 0, 1)
+        if past_key_values:
+            position_ids = position_ids[:, -input_ids.shape[1] :]
+
+    # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
+    if inputs_embeds is not None and past_key_values is None:
+        model_inputs = {"inputs_embeds": inputs_embeds}
+    else:
+        # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
+        # recompiles graphs as the stride of the inputs is a guard.
+        # Ref: https://github.com/huggingface/transformers/pull/29114
+        # TODO: use `next_tokens` directly instead.
+        model_inputs = {"input_ids": input_ids.contiguous()}
+
+    input_length = (
+        position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1]
+    )
+    if cache_position is None:
+        cache_position = torch.arange(
+            past_length, past_length + input_length, device=input_ids.device
+        )
+    elif use_cache:
+        cache_position = cache_position[-input_length:]
+
+    model_inputs.update(
+        {
+            "position_ids": position_ids,
+            "cache_position": cache_position,
+            "past_key_values": past_key_values,
+            "use_cache": use_cache,
+            "attention_mask": attention_mask,
+        }
+    )
+    return model_inputs
+
+
+def transformers_cache_utils_dynamiccache_update(
+    self,
+    key_states: torch.Tensor,
+    value_states: torch.Tensor,
+    layer_idx: int,
+    cache_kwargs: Optional[Dict[str, Any]] = None,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """
+    Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
+
+    Parameters:
+        key_states (`torch.Tensor`):
+            The new key states to cache.
+        value_states (`torch.Tensor`):
+            The new value states to cache.
+        layer_idx (`int`):
+            The index of the layer to cache the states for.
+        cache_kwargs (`Dict[str, Any]`, `optional`):
+            Additional arguments for the cache subclass. No additional arguments are used in `DynamicCache`.
+
+    Return:
+        A tuple containing the updated key and value states.
+    """
+    # Update the number of seen tokens
+    if layer_idx == 0:
+        self._seen_tokens += key_states.shape[-2]
+
+    # Update the cache
+    if len(self.key_cache) <= layer_idx:
+        self.key_cache.append(key_states)
+        self.value_cache.append(value_states)
+    else:
+        self.key_cache[layer_idx], self.value_cache[layer_idx] = (
+            ext_ops.fill_contiguous_kvcache(
+                self.key_cache[layer_idx],
+                self.value_cache[layer_idx],
+                key_states,
+                value_states,
+            )
+        )
+
+    return self.key_cache[layer_idx], self.value_cache[layer_idx]
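One detail worth noting in the attention forward above: the rotary cos/sin tables are computed only when layer_idx == 0 and stashed on the module-level TransformerBlockContext, so the remaining decoder layers of the same forward pass reuse them instead of recomputing. A stripped-down sketch of that pattern (the names here are illustrative, not dlinfer API):

class _BlockContext:
    """Per-forward-pass scratch space shared by all decoder layers."""

def rotary_tables(ctx, layer_idx, rotary_emb, value_states, position_ids):
    # Only the first layer recomputes cos/sin; later layers reuse the
    # cached tensors, saving one rotary computation per layer.
    if layer_idx == 0:
        ctx.cos, ctx.sin = rotary_emb(value_states, position_ids)
    return ctx.cos, ctx.sin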
dlinfer/framework/transformers_ext/internvl.py
ADDED
@@ -0,0 +1,40 @@
+# Copyright (c) 2024, DeepLink. All rights reserved.
+import torch
+import dlinfer.ops as ext_ops
+
+
+def InternAttention_naive_attn(self, x):
+    B, N, C = x.shape
+    qkv = self.qkv(x).reshape(B, N, 3, C).permute(2, 0, 1, 3)
+    q, k, v = qkv.unbind(0)
+    if self.qk_normalization:
+        q = self.q_norm(q)
+        k = self.k_norm(k)
+
+    attention_mask = None
+    start_loc = torch.tensor(
+        [N * i for i in range(B)], device=q.device, dtype=torch.int64
+    )
+    seq_len = torch.tensor([N for _ in range(B)], device=q.device, dtype=torch.int64)
+    ext_ops.prefill_attention(
+        q,
+        k,
+        v,
+        start_loc,
+        seq_len,
+        N,
+        self.num_heads,
+        self.num_heads,
+        attn_mask=attention_mask,
+        softmax_scale=None,
+        alibi_slopes=None,
+        attn_output=x,
+    )
+
+    x = self.proj(x.reshape(B, N, C))
+    x = self.proj_drop(x)
+    return x
+
+
+def InternRMSNorm_forward(self, hidden_states):
+    return ext_ops.rms_norm(hidden_states, self.weight, self.variance_epsilon)
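In the InternVL patch above, start_loc and seq_len describe the batch as B packed sequences of N patch tokens each, which is what the varlen-style prefill_attention interface appears to expect (an assumption stated here, not documented in the diff). The tensor construction is equivalent to the list comprehensions in the code:

import torch

B, N = 4, 1025  # illustrative batch size and patch-token count
start_loc = torch.arange(B, dtype=torch.int64) * N  # [0, N, 2N, 3N]
seq_len = torch.full((B,), N, dtype=torch.int64)    # [N, N, N, N]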
dlinfer/framework/transformers_ext/patch.py
ADDED
@@ -0,0 +1,33 @@
+# Copyright (c) 2024, DeepLink. All rights reserved.
+import transformers
+import inspect
+
+
+def apply_model_patches(module):
+    if module.__name__.endswith(".modeling_internlm2"):
+        from . import internlm2
+
+        module.InternLM2RMSNorm.forward = (
+            internlm2.modeling_internlm2_InternLM2RMSNorm_forward
+        )
+        module.InternLM2Attention.forward = (
+            internlm2.modeling_internlm2_InternLM2Attention_forward
+        )
+        module.InternLM2ForCausalLM.prepare_inputs_for_generation = (
+            internlm2.modeling_internlm2_InternLM2ForCausalLM_prepare_inputs_for_generation
+        )
+        transformers.cache_utils.DynamicCache.update = (
+            internlm2.transformers_cache_utils_dynamiccache_update
+        )
+    elif module.__name__.endswith(".modeling_internvl_chat"):
+        from . import internvl
+
+        vit_module = inspect.getmodule(module.InternVisionModel)
+        vit_module.InternAttention._naive_attn = internvl.InternAttention_naive_attn
+        vit_module.InternRMSNorm.forward = internvl.InternRMSNorm_forward
+    elif module.__name__.endswith(".modeling_cogvlm"):
+        from . import cogvlm
+
+        # get parent module from another source code file
+        vit_module = inspect.getmodule(module.EVA2CLIPModel)
+        vit_module.Attention.forward = cogvlm.PatchedAttention_forward
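apply_model_patches dispatches purely on the module name suffix, so it can also be invoked by hand on an already-imported remote-code modeling module instead of relying on the get_class_in_module hook. A hedged sketch (the transformers_modules path is hypothetical and depends on the local Hugging Face dynamic-module cache layout):

import importlib
from dlinfer.framework.transformers_ext.patch import apply_model_patches

# Hypothetical dynamic-module path; it must end with ".modeling_internlm2",
# ".modeling_internvl_chat", or ".modeling_cogvlm" to match a branch above.
mod = importlib.import_module("transformers_modules.internlm2.modeling_internlm2")
apply_model_patches(mod)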
dlinfer/ops/__init__.py
ADDED