dlinfer-ascend 0.1.0__cp38-cp38-manylinux2014_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
dlinfer/__init__.py ADDED
@@ -0,0 +1,4 @@
+ # Copyright (c) 2024, DeepLink. All rights reserved.
+ import dlinfer.vendor as vendor
+
+ vendor.vendor_torch_init()
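
Note: the top-level `__init__.py` calls `vendor.vendor_torch_init()` at import time, so the vendor (Ascend) backend is initialized before any dlinfer op is used. A minimal sketch of the intended import order (assumed, not mandated by the package):

    import dlinfer                 # runs vendor.vendor_torch_init() as an import side effect
    import dlinfer.ops as ext_ops  # the vendor backend is already initialized at this point
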
@@ -0,0 +1 @@
+ # Copyright (c) 2024, DeepLink. All rights reserved.
@@ -0,0 +1,2 @@
+ # Copyright (c) 2024, DeepLink. All rights reserved.
+ import dlinfer.framework.transformers_ext
@@ -0,0 +1,19 @@
+ # Copyright (c) 2024, DeepLink. All rights reserved.
+ import importlib
+ import os, sys
+ import typing
+ from typing import Any, Dict, List, Optional, Union
+ import transformers
+ from .patch import apply_model_patches
+
+
+ def patched_get_class_in_module(
+     class_name: str, module_path: Union[str, os.PathLike]
+ ) -> typing.Type:
+     ret_class = transformers_get_class_in_module(class_name, module_path)
+     apply_model_patches(importlib.import_module(ret_class.__module__))
+     return ret_class
+
+
+ transformers_get_class_in_module = transformers.dynamic_module_utils.get_class_in_module
+ transformers.dynamic_module_utils.get_class_in_module = patched_get_class_in_module
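
Note: the wrapper above hooks transformers' dynamic-module loader, so the model patches are applied transparently whenever a `trust_remote_code` checkpoint is instantiated. A minimal usage sketch (the model ID is illustrative, not part of the package):

    import dlinfer.framework.transformers_ext  # installs the get_class_in_module hook
    from transformers import AutoModelForCausalLM

    # Loading remote code now goes through the patched loader; modules whose names
    # match apply_model_patches get their forward methods swapped for dlinfer ops.
    model = AutoModelForCausalLM.from_pretrained(
        "internlm/internlm2-chat-7b",  # illustrative model ID
        trust_remote_code=True,
    )
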
@@ -0,0 +1,16 @@
+ import torch
+ from torch import nn
+ import dlinfer.ops as ext_ops
+
+
+ def PatchedAttention_forward(self, x: "tensor(B, L, D)") -> "tensor(B, L, D)":
+     B, L, H = x.shape
+     qkv = self.query_key_value(x)
+     qkv = qkv.reshape(B, L, 3, H).permute(2, 0, 1, 3)  # 3, B, L, H
+     q, k, v = qkv[0], qkv[1], qkv[2]
+
+     ext_ops.prefill_attention(
+         q, k, v, None, None, L, self.num_heads, self.num_heads, None, attn_output=q
+     )
+     output = self.dense(q.view(B, L, -1))
+     return output
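
Note: `prefill_attention` is called with `attn_output=q`, so the kernel writes its result back into the query buffer in place. For reference, a plain-PyTorch sketch of what the fused call is expected to compute for this bidirectional vision attention (standard multi-head attention is assumed; this is not part of the package):

    import torch.nn.functional as F

    def reference_prefill_attention(q, k, v, num_heads):
        # q, k, v: (B, L, H) with H = num_heads * head_dim
        B, L, H = q.shape
        head_dim = H // num_heads
        # (B, L, H) -> (B, num_heads, L, head_dim)
        q, k, v = (t.view(B, L, num_heads, head_dim).transpose(1, 2) for t in (q, k, v))
        out = F.scaled_dot_product_attention(q, k, v)  # full (non-causal) attention
        return out.transpose(1, 2).reshape(B, L, H)
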
@@ -0,0 +1,242 @@
+ # Copyright (c) 2024, DeepLink. All rights reserved.
+ # Copyright 2024 HuggingFace Inc.
+
+ import torch
+ import torch.nn.functional as F  # used by the pretraining_tp > 1 branches below
+ from einops import rearrange
+ from transformers.cache_utils import Cache
+ from typing import Optional, Dict, Any, Tuple
+ from dataclasses import dataclass
+ import dlinfer.ops as ext_ops
+
+
+ @dataclass
+ class TransformerBlockContext: ...
+
+
+ transformer_block_context = TransformerBlockContext()
+
+
+ def modeling_internlm2_InternLM2RMSNorm_forward(self, hidden_states):
+     return ext_ops.rms_norm(hidden_states, self.weight, self.variance_epsilon)
+
+
+ def modeling_internlm2_InternLM2Attention_forward(
+     self,
+     hidden_states: torch.Tensor,
+     attention_mask: Optional[torch.Tensor] = None,
+     position_ids: Optional[torch.LongTensor] = None,
+     past_key_value: Optional[Cache] = None,
+     output_attentions: bool = False,
+     use_cache: bool = False,  # pylint: disable=unused-argument
+     cache_position: Optional[torch.LongTensor] = None,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+     bsz, q_len, _ = hidden_states.size()
+
+     if self.config.pretraining_tp > 1:
+         # split qkv_states by tp size
+         key_value_slicing = (
+             self.num_key_value_heads * self.head_dim
+         ) // self.config.pretraining_tp
+         qkv_slices = self.wqkv.weight.split(key_value_slicing, dim=0)
+         qkv_states = torch.cat(
+             [F.linear(hidden_states, qkv_slice) for qkv_slice in qkv_slices],
+             dim=-1,  # pylint: disable=E1102
+         )
+     else:
+         qkv_states = self.wqkv(hidden_states)
+
+     qkv_states = rearrange(
+         qkv_states,
+         "b q (h gs d) -> b q h gs d",
+         gs=2 + self.num_key_value_groups,
+         d=self.head_dim,
+     )
+
+     query_states = qkv_states[..., : self.num_key_value_groups, :]
+     query_states = rearrange(query_states, "b q h gs d -> b q (h gs) d")
+     key_states = qkv_states[..., -2, :]
+     value_states = qkv_states[..., -1, :]
+
+     global transformer_block_context
+     if self.layer_idx == 0:
+         cos, sin = self.rotary_emb(value_states, position_ids)
+         setattr(transformer_block_context, "sin", sin)
+         setattr(transformer_block_context, "cos", cos)
+     sin = transformer_block_context.sin
+     cos = transformer_block_context.cos
+     query_states, key_states = ext_ops.apply_rotary_pos_emb(
+         query_states, key_states, cos, sin, None, None
+     )
+
+     if past_key_value is not None:
+         # sin and cos are specific to RoPE models; cache_position needed for the static cache
+         cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+         key_states, value_states = past_key_value.update(
+             key_states, value_states, self.layer_idx, cache_kwargs
+         )
+
+     attn_output = ext_ops.fused_attention(
+         query_states, key_states, value_states, [attention_mask.to(torch.bool)]
+     )
+
+     attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+
+     if self.config.pretraining_tp > 1:
+         attn_output = attn_output.split(
+             self.hidden_size // self.config.pretraining_tp, dim=2
+         )
+         o_proj_slices = self.wo.weight.split(
+             self.hidden_size // self.config.pretraining_tp, dim=1
+         )
+         attn_output = sum(
+             [
+                 F.linear(attn_output[i], o_proj_slices[i])  # pylint: disable=E1102
+                 for i in range(self.config.pretraining_tp)
+             ]
+         )
+     else:
+         attn_output = self.wo(attn_output)
+
+     # ext_ops.fused_attention does not materialize attention weights, so none are returned
+     attn_weights = None
+
+     return attn_output, attn_weights, past_key_value
+
+
+ def modeling_internlm2_InternLM2ForCausalLM_prepare_inputs_for_generation(
+     self,
+     input_ids,
+     past_key_values=None,
+     attention_mask=None,
+     inputs_embeds=None,
+     cache_position=None,
+     use_cache=True,
+     **kwargs,
+ ):
+     past_length = 0
+     if past_key_values is not None:
+         if isinstance(past_key_values, Cache):
+             past_length = (
+                 cache_position[0]
+                 if cache_position is not None
+                 else past_key_values.get_seq_length()
+             )
+             max_cache_length = (
+                 torch.tensor(past_key_values.get_max_length(), device=input_ids.device)
+                 if past_key_values.get_max_length() is not None
+                 else None
+             )
+             cache_length = (
+                 past_length
+                 if max_cache_length is None
+                 else torch.min(max_cache_length, past_length)
+             )
+         # TODO joao: remove this `else` after `generate` prioritizes `Cache` objects
+         else:
+             cache_length = past_length = ext_ops.get_cache_len(past_key_values[0][0])
+             max_cache_length = None
+
+     # Keep only the unprocessed tokens:
+     # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
+     # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as input)
+     if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
+         input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
+     # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
+     # input_ids based on the past_length.
+     elif past_length < input_ids.shape[1]:
+         input_ids = input_ids[:, past_length:]
+     # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
+
+     # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
+     if (
+         max_cache_length is not None
+         and attention_mask is not None
+         and cache_length + input_ids.shape[1] > max_cache_length
+     ):
+         attention_mask = attention_mask[
+             :, -max_cache_length:
+         ]  # pylint: disable=E1130
+
+     position_ids = kwargs.get("position_ids", None)
+     if attention_mask is not None and position_ids is None:
+         # create position_ids on the fly for batch generation
+         position_ids = attention_mask.long().cumsum(-1) - 1
+         position_ids.masked_fill_(attention_mask == 0, 1)
+         if past_key_values:
+             position_ids = position_ids[:, -input_ids.shape[1] :]
+
+     # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
+     if inputs_embeds is not None and past_key_values is None:
+         model_inputs = {"inputs_embeds": inputs_embeds}
+     else:
+         # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
+         # recompiles graphs as the stride of the inputs is a guard.
+         # Ref: https://github.com/huggingface/transformers/pull/29114
+         # TODO: use `next_tokens` directly instead.
+         model_inputs = {"input_ids": input_ids.contiguous()}
+
+     input_length = (
+         position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1]
+     )
+     if cache_position is None:
+         cache_position = torch.arange(
+             past_length, past_length + input_length, device=input_ids.device
+         )
+     elif use_cache:
+         cache_position = cache_position[-input_length:]
+
+     model_inputs.update(
+         {
+             "position_ids": position_ids,
+             "cache_position": cache_position,
+             "past_key_values": past_key_values,
+             "use_cache": use_cache,
+             "attention_mask": attention_mask,
+         }
+     )
+     return model_inputs
+
+
+ def transformers_cache_utils_dynamiccache_update(
+     self,
+     key_states: torch.Tensor,
+     value_states: torch.Tensor,
+     layer_idx: int,
+     cache_kwargs: Optional[Dict[str, Any]] = None,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+     """
+     Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
+
+     Parameters:
+         key_states (`torch.Tensor`):
+             The new key states to cache.
+         value_states (`torch.Tensor`):
+             The new value states to cache.
+         layer_idx (`int`):
+             The index of the layer to cache the states for.
+         cache_kwargs (`Dict[str, Any]`, `optional`):
+             Additional arguments for the cache subclass. No additional arguments are used in `DynamicCache`.
+
+     Return:
+         A tuple containing the updated key and value states.
+     """
+     # Update the number of seen tokens
+     if layer_idx == 0:
+         self._seen_tokens += key_states.shape[-2]
+
+     # Update the cache
+     if len(self.key_cache) <= layer_idx:
+         self.key_cache.append(key_states)
+         self.value_cache.append(value_states)
+     else:
+         self.key_cache[layer_idx], self.value_cache[layer_idx] = (
+             ext_ops.fill_contiguous_kvcache(
+                 self.key_cache[layer_idx],
+                 self.value_cache[layer_idx],
+                 key_states,
+                 value_states,
+             )
+         )
+
+     return self.key_cache[layer_idx], self.value_cache[layer_idx]
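
Note: the `DynamicCache.update` override routes the KV-cache append through `ext_ops.fill_contiguous_kvcache`. For comparison, a sketch of the stock HuggingFace behavior it replaces, which simply concatenates along the sequence axis (reference only, not part of the package):

    import torch

    def baseline_dynamiccache_append(key_cache, value_cache, key_states, value_states):
        # key_cache / value_cache: (batch, num_kv_heads, seq_len, head_dim)
        # Stock DynamicCache grows the cache by concatenating on the seq_len axis.
        key_cache = torch.cat([key_cache, key_states], dim=-2)
        value_cache = torch.cat([value_cache, value_states], dim=-2)
        return key_cache, value_cache
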
@@ -0,0 +1,40 @@
+ # Copyright (c) 2024, DeepLink. All rights reserved.
+ import torch
+ import dlinfer.ops as ext_ops
+
+
+ def InternAttention_naive_attn(self, x):
+     B, N, C = x.shape
+     qkv = self.qkv(x).reshape(B, N, 3, C).permute(2, 0, 1, 3)
+     q, k, v = qkv.unbind(0)
+     if self.qk_normalization:
+         q = self.q_norm(q)
+         k = self.k_norm(k)
+
+     attention_mask = None
+     start_loc = torch.tensor(
+         [N * i for i in range(B)], device=q.device, dtype=torch.int64
+     )
+     seq_len = torch.tensor([N for _ in range(B)], device=q.device, dtype=torch.int64)
+     ext_ops.prefill_attention(
+         q,
+         k,
+         v,
+         start_loc,
+         seq_len,
+         N,
+         self.num_heads,
+         self.num_heads,
+         attn_mask=attention_mask,
+         softmax_scale=None,
+         alibi_slopes=None,
+         attn_output=x,
+     )
+
+     x = self.proj(x.reshape(B, N, C))
+     x = self.proj_drop(x)
+     return x
+
+
+ def InternRMSNorm_forward(self, hidden_states):
+     return ext_ops.rms_norm(hidden_states, self.weight, self.variance_epsilon)
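
Note: `start_loc` and `seq_len` describe the batch as B packed sequences of equal length N, which is how the prefill kernel locates each sequence. A tiny illustration with made-up sizes:

    import torch

    B, N = 2, 4  # illustrative values only
    start_loc = torch.tensor([N * i for i in range(B)], dtype=torch.int64)  # tensor([0, 4])
    seq_len = torch.tensor([N] * B, dtype=torch.int64)                      # tensor([4, 4])
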
@@ -0,0 +1,33 @@
+ # Copyright (c) 2024, DeepLink. All rights reserved.
+ import transformers
+ import inspect
+
+
+ def apply_model_patches(module):
+     if module.__name__.endswith(".modeling_internlm2"):
+         from . import internlm2
+
+         module.InternLM2RMSNorm.forward = (
+             internlm2.modeling_internlm2_InternLM2RMSNorm_forward
+         )
+         module.InternLM2Attention.forward = (
+             internlm2.modeling_internlm2_InternLM2Attention_forward
+         )
+         module.InternLM2ForCausalLM.prepare_inputs_for_generation = (
+             internlm2.modeling_internlm2_InternLM2ForCausalLM_prepare_inputs_for_generation
+         )
+         transformers.cache_utils.DynamicCache.update = (
+             internlm2.transformers_cache_utils_dynamiccache_update
+         )
+     elif module.__name__.endswith(".modeling_internvl_chat"):
+         from . import internvl
+
+         vit_module = inspect.getmodule(module.InternVisionModel)
+         vit_module.InternAttention._naive_attn = internvl.InternAttention_naive_attn
+         vit_module.InternRMSNorm.forward = internvl.InternRMSNorm_forward
+     elif module.__name__.endswith(".modeling_cogvlm"):
+         from . import cogvlm
+
+         # get parent module from another source code file
+         vit_module = inspect.getmodule(module.EVA2CLIPModel)
+         vit_module.Attention.forward = cogvlm.PatchedAttention_forward
@@ -0,0 +1,2 @@
+ # Copyright (c) 2024, DeepLink. All rights reserved.
+ from .llm import *