dlinfer_ascend-0.2.3.post2-cp311-cp311-manylinux2014_aarch64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dlinfer/__init__.py +5 -0
- dlinfer/framework/__init__.py +1 -0
- dlinfer/framework/lmdeploy_ext/__init__.py +6 -0
- dlinfer/framework/lmdeploy_ext/cudagraph/__init__.py +20 -0
- dlinfer/framework/lmdeploy_ext/cudagraph/ascend_cudagraph.py +391 -0
- dlinfer/framework/lmdeploy_ext/cudagraph/camb_cudagraph.py +133 -0
- dlinfer/framework/lmdeploy_ext/cudagraph/maca_cudagraph.py +128 -0
- dlinfer/framework/lmdeploy_ext/cudagraph/ppu_cudagraph.py +131 -0
- dlinfer/framework/lmdeploy_ext/device/__init__.py +79 -0
- dlinfer/framework/lmdeploy_ext/device/ascend.py +205 -0
- dlinfer/framework/lmdeploy_ext/device/camb.py +24 -0
- dlinfer/framework/lmdeploy_ext/quants/__init__.py +20 -0
- dlinfer/framework/lmdeploy_ext/quants/ascend_awq.py +248 -0
- dlinfer/framework/torch_npu_ext/__init__.py +12 -0
- dlinfer/framework/torch_npu_ext/aclgraph.py +59 -0
- dlinfer/framework/transformers_ext/__init__.py +17 -0
- dlinfer/framework/transformers_ext/cogvlm.py +25 -0
- dlinfer/framework/transformers_ext/internlm2.py +242 -0
- dlinfer/framework/transformers_ext/internvl.py +33 -0
- dlinfer/framework/transformers_ext/patch.py +33 -0
- dlinfer/graph/__init__.py +5 -0
- dlinfer/graph/custom_op.py +147 -0
- dlinfer/graph/dicp/__init__.py +0 -0
- dlinfer/graph/dicp/dynamo_bridge/__init__.py +0 -0
- dlinfer/graph/dicp/dynamo_bridge/compile.py +42 -0
- dlinfer/graph/dicp/dynamo_bridge/compile_fx.py +305 -0
- dlinfer/graph/dicp/dynamo_bridge/conversion.py +75 -0
- dlinfer/graph/dicp/dynamo_bridge/decompositions.py +38 -0
- dlinfer/graph/dicp/dynamo_bridge/graph.py +141 -0
- dlinfer/graph/dicp/dynamo_bridge/op_transformer.py +293 -0
- dlinfer/graph/dicp/dynamo_bridge/operator.py +87 -0
- dlinfer/graph/dicp/dynamo_bridge/pt_patch.py +320 -0
- dlinfer/graph/dicp/dynamo_bridge/torch_version.py +38 -0
- dlinfer/graph/dicp/dynamo_bridge/utils.py +158 -0
- dlinfer/graph/dicp/vendor/AtbGraph/__init__.py +13 -0
- dlinfer/graph/dicp/vendor/AtbGraph/atb_op.py +853 -0
- dlinfer/graph/dicp/vendor/AtbGraph/codegen/__init__.py +0 -0
- dlinfer/graph/dicp/vendor/AtbGraph/codegen/atb.py +318 -0
- dlinfer/graph/dicp/vendor/AtbGraph/codegen/atb_graph.py +768 -0
- dlinfer/graph/dicp/vendor/AtbGraph/codegen/atb_infer_param.py +763 -0
- dlinfer/graph/dicp/vendor/AtbGraph/codegen/atb_op.py +1279 -0
- dlinfer/graph/dicp/vendor/AtbGraph/codegen/libdicp_model.so +0 -0
- dlinfer/graph/dicp/vendor/AtbGraph/codegen/load_and_run.py +21 -0
- dlinfer/graph/dicp/vendor/AtbGraph/codegen/utils.py +178 -0
- dlinfer/graph/dicp/vendor/AtbGraph/compile_job.py +52 -0
- dlinfer/graph/dicp/vendor/AtbGraph/config.py +36 -0
- dlinfer/graph/dicp/vendor/AtbGraph/conversion.py +908 -0
- dlinfer/graph/dicp/vendor/AtbGraph/ext_ops.py +95 -0
- dlinfer/graph/dicp/vendor/AtbGraph/infer_res_utils.py +200 -0
- dlinfer/graph/dicp/vendor/AtbGraph/opset_convert.py +70 -0
- dlinfer/graph/dicp/vendor/AtbGraph/pattern_replacement.py +152 -0
- dlinfer/graph/dicp/vendor/__init__.py +0 -0
- dlinfer/ops/__init__.py +2 -0
- dlinfer/ops/llm.py +879 -0
- dlinfer/utils/__init__.py +1 -0
- dlinfer/utils/config.py +18 -0
- dlinfer/utils/registry.py +8 -0
- dlinfer/utils/type_annotation.py +3 -0
- dlinfer/vendor/__init__.py +33 -0
- dlinfer/vendor/ascend/__init__.py +5 -0
- dlinfer/vendor/ascend/pytorch_patch.py +55 -0
- dlinfer/vendor/ascend/torch_npu_ops.py +601 -0
- dlinfer/vendor/ascend/utils.py +20 -0
- dlinfer/vendor/vendor.yaml +2 -0
- dlinfer_ascend-0.2.3.post2.dist-info/LICENSE +28 -0
- dlinfer_ascend-0.2.3.post2.dist-info/METADATA +213 -0
- dlinfer_ascend-0.2.3.post2.dist-info/RECORD +70 -0
- dlinfer_ascend-0.2.3.post2.dist-info/WHEEL +5 -0
- dlinfer_ascend-0.2.3.post2.dist-info/entry_points.txt +2 -0
- dlinfer_ascend-0.2.3.post2.dist-info/top_level.txt +1 -0
dlinfer/framework/lmdeploy_ext/cudagraph/ppu_cudagraph.py

@@ -0,0 +1,131 @@
+# Copyright (c) 2025, OpenMMLab and DeepLink. All rights reserved.
+import torch
+from torch import Tensor
+
+from typing import Any, Dict, List
+
+from lmdeploy.pytorch.models.utils.cudagraph import CudaGraphMeta
+from lmdeploy.pytorch.models.utils.cudagraph import CudaGraphMixin, next_power_of_2
+
+BuffType = Dict[str, Tensor]
+
+
+def PPUCudaGraphMixin_make_buffers_cudagraph(
+    self, graph_meta: CudaGraphMeta, *args, **kwargs
+) -> BuffType:
+    """make cudagraph buffers from forward inputs."""
+    max_batches = graph_meta.max_batchs
+    max_tokens = graph_meta.max_tokens
+    num_blocks = graph_meta.num_blocks
+    device = graph_meta.device
+    input_buffers: BuffType = dict()
+    input_buffers["input_ids"] = torch.zeros(
+        1, max_tokens, dtype=torch.int32, device=device
+    )
+
+    input_buffers["position_ids"] = torch.zeros(
+        (1, max_tokens), dtype=torch.int32, device=device
+    )
+
+    input_buffers["block_offsets"] = torch.zeros(
+        (max_batches, num_blocks), dtype=torch.int32, device=device
+    )
+
+    input_buffers["q_seqlens"] = torch.zeros(
+        max_batches, dtype=torch.int32, device=device
+    )
+
+    input_buffers["kv_seqlens"] = torch.zeros(
+        max_batches, dtype=torch.int32, device=device
+    )
+
+    input_buffers["q_start_loc"] = torch.zeros(
+        max_batches + 1, dtype=torch.int32, device=device
+    )
+
+    input_buffers["kv_start_indices"] = torch.zeros(
+        (max_batches, 1), dtype=torch.int64, device=device
+    )
+    return input_buffers
+
+
+def PPUCudaGraphMixin_fill_buffers_cudagraph(
+    self,
+    graph_meta: CudaGraphMeta,
+    input_ids: Tensor,
+    position_ids: Tensor,
+    past_key_values: List,
+    attn_metadata: Any,
+    inputs_embeds: Tensor,
+    **kwargs
+) -> Dict[str, Tensor]:
+    """fill cudagraph buffers from forward inputs."""
+    block_offsets: Tensor = attn_metadata.block_offsets
+    q_start_loc: Tensor = attn_metadata.q_start_loc
+    q_seqlens: Tensor = attn_metadata.q_seqlens
+    kv_seqlens: Tensor = attn_metadata.kv_seqlens
+    kv_start_indices: Tensor = attn_metadata.kv_start_indices
+
+    input_buffers: BuffType = graph_meta.input_buffers
+
+    batch_size, num_blocks = block_offsets.size()
+    num_tokens = input_ids.size(-1)
+
+    # fill buffer
+    input_buffers["input_ids"][:, :num_tokens] = input_ids
+    input_buffers["position_ids"][:, :num_tokens] = position_ids
+    input_buffers["block_offsets"][:batch_size, :num_blocks] = block_offsets
+    input_buffers["q_start_loc"][: batch_size + 1] = q_start_loc
+    input_buffers["q_seqlens"][:batch_size] = q_seqlens
+    input_buffers["kv_seqlens"][:batch_size] = kv_seqlens
+    input_buffers["kv_start_indices"][:batch_size] = kv_start_indices
+
+    if inputs_embeds is not None:
+        emb_size = inputs_embeds.size(-1)
+        if "inputs_embeds" not in input_buffers:
+            max_num_tokens = input_buffers["input_ids"].size(-1)
+            input_buffers["inputs_embeds"] = inputs_embeds.new_zeros(
+                1, max_num_tokens, emb_size
+            )
+        input_buffers["inputs_embeds"][:, :num_tokens] = inputs_embeds
+
+    # create inputs
+    new_batch_size = next_power_of_2(batch_size)
+
+    attn_metadata.block_offsets = input_buffers["block_offsets"][:new_batch_size]
+    attn_metadata.q_start_loc = input_buffers["q_start_loc"][: new_batch_size + 1]
+    attn_metadata.q_seqlens = input_buffers["q_seqlens"][:new_batch_size]
+    attn_metadata.kv_seqlens = input_buffers["kv_seqlens"][:new_batch_size]
+    attn_metadata.kv_start_indices = input_buffers["kv_start_indices"][:new_batch_size]
+    attn_metadata.max_q_seq_len = int(attn_metadata.q_seqlens.max().item())
+    attn_metadata.max_kv_seq_len = int(attn_metadata.kv_seqlens.max().item())
+
+    new_inputs = dict(
+        past_key_values=past_key_values,
+        attn_metadata=attn_metadata,
+    )
+
+    new_inputs["input_ids"] = input_buffers["input_ids"][:, :new_batch_size]
+    new_inputs["position_ids"] = input_buffers["position_ids"][:, :new_batch_size]
+
+    if inputs_embeds is not None:
+        new_inputs["inputs_embeds"] = input_buffers["inputs_embeds"][:, :new_batch_size]
+
+    new_inputs.update(kwargs)
+
+    return new_inputs
+
+
+def PPUCudaGraphMixin_update_context_cudagraph(self, graph_meta, context):
+    """update step context with input buffers."""
+    input_buffers = graph_meta.input_buffers
+    context.block_offsets = input_buffers["block_offsets"]
+    context.q_seqlens = input_buffers["q_seqlens"]
+    context.kv_seqlens = input_buffers["kv_seqlens"]
+    context.q_start_loc = input_buffers["q_start_loc"]
+    context.kv_start_indices = input_buffers["kv_start_indices"]
+
+
+CudaGraphMixin.make_buffers_cudagraph = PPUCudaGraphMixin_make_buffers_cudagraph
+CudaGraphMixin.fill_buffers_cudagraph = PPUCudaGraphMixin_fill_buffers_cudagraph
+CudaGraphMixin.update_context_cudagraph = PPUCudaGraphMixin_update_context_cudagraph
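The three functions above replace CudaGraphMixin's default buffer handling: persistent device buffers are allocated once at the maximum sizes, each decoding step copies its real inputs into them, and the batch is padded to the next power of two so a small set of captured graphs can be reused across batch sizes. A minimal sketch of that padding idea, with next_power_of_2 re-implemented locally for illustration (the real helper comes from lmdeploy.pytorch.models.utils.cudagraph):

    import torch

    def next_power_of_2(n: int) -> int:
        # smallest power of two >= n; buckets batch sizes so captured graphs can be reused
        return 1 if n == 0 else 1 << (n - 1).bit_length()

    max_batches = 8
    q_seqlens_buf = torch.zeros(max_batches, dtype=torch.int32)     # persistent buffer

    step_q_seqlens = torch.tensor([3, 1, 5], dtype=torch.int32)     # this step's real values
    batch_size = step_q_seqlens.numel()

    q_seqlens_buf[:batch_size] = step_q_seqlens                     # copy into the captured buffer
    padded = q_seqlens_buf[: next_power_of_2(batch_size)]           # view handed to the graph
    print(padded.tolist())                                          # [3, 1, 5, 0]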
dlinfer/framework/lmdeploy_ext/device/__init__.py

@@ -0,0 +1,79 @@
+# Copyright (c) 2024, DeepLink. All rights reserved.
+import importlib
+import torch
+from functools import lru_cache
+from dlinfer.vendor import vendor_name
+from lmdeploy.pytorch import models
+
+
+vendor = ["camb", "ascend"]
+
+
+def fake_torch_compile(dynamic=False):
+    def decorator(func):
+        def wrapper(*args, **kwargs):
+            return func(*args, **kwargs)
+
+        return wrapper
+
+    return decorator
+
+
+def pre_rms_norm(q: torch.Tensor, k: torch.Tensor) -> torch.Tensor:
+    """Pre rms norm."""
+    q = q.to(torch.float32)
+    k = k.to(torch.float32)
+    variance_q = (q * q).sum(-1, keepdim=True)
+    variance_k = (k * k).sum(-1, keepdim=True)
+    variance = torch.stack([variance_q, variance_k], dim=0)
+    return variance
+
+
+def post_rms_norm(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    weight_q: torch.Tensor,
+    weight_k: torch.Tensor,
+    variance: torch.Tensor,
+    eps: float,
+    embed_dim: int,
+    dtype: torch.dtype,
+):
+    """Post rms norm."""
+    q = q.to(torch.float32)
+    k = k.to(torch.float32)
+    variance = variance / embed_dim + eps
+    variance_q, variance_k = variance
+    q = q * torch.rsqrt(variance_q)
+    q = q.to(dtype) * weight_q
+    k = k * torch.rsqrt(variance_k)
+    k = k.to(dtype) * weight_k
+    return q, k
+
+
+def patch_compiled_func():
+    import torch
+
+    real_torch_compile = torch.compile
+    torch.compile = fake_torch_compile
+    from lmdeploy.pytorch.models import internvl, internvl3_hf
+
+    internvl.pre_rms_norm = pre_rms_norm
+    internvl.post_rms_norm = post_rms_norm
+    internvl3_hf.pre_rms_norm = pre_rms_norm
+    internvl3_hf.post_rms_norm = post_rms_norm
+    torch.compile = real_torch_compile
+
+
+@lru_cache(1)
+def import_vendor_module(vendor_name_str):
+    if vendor_name_str in vendor:
+        importlib.import_module(f".{vendor_name_str}", __package__)
+
+
+def vendor_device_init():
+    import_vendor_module(vendor_name)
+    patch_compiled_func()
+
+
+vendor_device_init()
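patch_compiled_func relies on a small trick: lmdeploy decorates its pre_rms_norm/post_rms_norm helpers with torch.compile at import time, so the module temporarily swaps torch.compile for the no-op fake_torch_compile factory, imports the model modules, rebinds the helpers to the eager implementations defined above, and then restores the real torch.compile. A self-contained sketch of the same swap pattern (the rms_norm function below is illustrative, not lmdeploy code):

    import torch

    def fake_torch_compile(dynamic=False):
        # decorator factory with the same call shape as torch.compile(dynamic=...),
        # but returning a plain pass-through wrapper instead of a compiled function
        def decorator(func):
            def wrapper(*args, **kwargs):
                return func(*args, **kwargs)
            return wrapper
        return decorator

    real_torch_compile = torch.compile
    torch.compile = fake_torch_compile
    try:
        # any code imported here that applies @torch.compile(dynamic=True)
        # at import time now gets the eager wrapper
        @torch.compile(dynamic=True)
        def rms_norm(x: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
            return x * torch.rsqrt((x * x).mean(-1, keepdim=True) + eps)
    finally:
        torch.compile = real_torch_compile

    print(rms_norm(torch.ones(2, 4)))  # runs eagerly, no Inductor compilation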
dlinfer/framework/lmdeploy_ext/device/ascend.py

@@ -0,0 +1,205 @@
+# Copyright (c) 2024, DeepLink. All rights reserved.
+import os
+import torch
+
+from lmdeploy.pytorch.backends.dlinfer.moe import DlinferFusedMoEImpl
+from lmdeploy.pytorch.models.chatglm2 import SelfAttention
+from lmdeploy.pytorch.engine import logits_process
+
+from dlinfer.vendor.ascend.utils import SocVersion
+
+
+def rl_update_weights(self, gate_up_weights: torch.Tensor, down_weights: torch.Tensor):
+    """Update weights."""
+    return gate_up_weights, down_weights
+
+
+if os.getenv("DLINFER_RESET_MOE_UPDATE_WEIGHTS", "0") == "1":
+    DlinferFusedMoEImpl.update_weights = rl_update_weights
+
+
+@staticmethod
+def ascend_chatglm2_fill_rope(states: torch.Tensor, rope: torch.Tensor):
+    """fill rope."""
+    rope_part = states.chunk(2, -1)[1]
+    rope = rope.unflatten(-1, (2, -1))
+    rope = rope.transpose(-2, -1).flatten(-2, -1)
+    states = torch.cat([rope_part, rope], dim=-1)
+
+    return states
+
+
+SelfAttention._fill_rope = ascend_chatglm2_fill_rope
+
+
+# modify bad words process for aclgraph performance
+def _process_bad_words_(
+    scores: torch.Tensor,
+    bad_words: torch.LongTensor,
+    mask: torch.BoolTensor,
+    filter_value: float = -99999.9999,
+):
+    """Process bad words."""
+    filtered_scores = scores.gather(1, bad_words)
+    filtered_scores = mask.to(filtered_scores.dtype) * filter_value + filtered_scores
+    scores.scatter_(1, bad_words, filtered_scores)
+    return scores
+
+
+logits_process._process_bad_words_ = _process_bad_words_
+
+########## below is for ascend310P ##########
+
+if SocVersion.is_Ascend310P():
+    # Lazy import for Ascend310P
+    import torch.distributed as dist
+    from lmdeploy.utils import get_logger
+    from lmdeploy.pytorch.distributed import get_dist_manager, DistContext
+    from lmdeploy.pytorch.engine.model_agent import (
+        msg_with_rank,
+        BaseModelAgent,
+    )
+    from lmdeploy.pytorch.engine.cache_engine import CacheEngine
+    from lmdeploy.pytorch.models.patch import (
+        update_custom_module_map,
+        build_patched_model,
+        add_adapters,
+    )
+    from lmdeploy.pytorch.weight_loader.model_weight_loader import ModelWeightLoader
+    from lmdeploy.pytorch.disagg.config import EngineRole
+
+    logger = get_logger("lmdeploy")
+
+    def _broadcast_next_token_310P(
+        self, next_token_ids: torch.Tensor, dist_ctx: DistContext = None
+    ):
+        # NOTE: Ascend310P does not support broadcast, so we need to use gloo to
+        # broadcast next_token_ids and then transfer them to npu.
+        # This mock properly broadcasts next_token_ids on Ascend310P devices.
+        if dist_ctx is None:
+            dist_ctx = get_dist_manager().current_context()
+        if self.cache_config.role == EngineRole.Decode:
+            next_token_ids = next_token_ids.cpu()
+            tp_cpu_group = dist_ctx.tp_cpu_group
+            dist.all_reduce(next_token_ids, op=dist.ReduceOp.SUM, group=tp_cpu_group)
+        else:
+            # NOTE: Ascend310P does not support broadcast, so we need to use gloo to
+            # broadcast next_token_ids and then transfer them to npu.
+            tp_cpu_group = dist_ctx.tp_cpu_group
+            original_device = next_token_ids.device
+            next_token_ids = next_token_ids.cpu()
+            dist.broadcast(next_token_ids, src=0, group=tp_cpu_group)
+            next_token_ids = next_token_ids.to(original_device)
+        return next_token_ids

+    def _allocate_cache_310P(self, num_blocks: int, device: torch.device):
+        """
+        Allocate cache implementation.
+        # NOTE: Ascend 300I Duo devices require kv_cache to be in acl NZ format.
+        """
+        key_block_shape = self.get_key_block_shape(local=True)
+        value_block_shape = self.get_value_block_shape(local=True)
+
+        num_layers = self.num_layers
+        kv_cache_dtype = self.kv_cache_dtype
+
+        if device != "cpu":
+            import torch_npu
+
+            key_cache = torch_npu.empty_with_format(
+                size=(num_layers, num_blocks, *key_block_shape),
+                dtype=kv_cache_dtype,
+                device="npu",
+                acl_format=29,  # 29 for acl NZ format
+            )
+            value_cache = torch_npu.empty_with_format(
+                size=(num_layers, num_blocks, *value_block_shape),
+                dtype=kv_cache_dtype,
+                device="npu",
+                acl_format=29,
+            )
+        else:
+            key_cache = torch.empty(
+                size=(num_layers, num_blocks, *key_block_shape),
+                dtype=kv_cache_dtype,
+                device=device,
+            )
+            value_cache = torch.empty(
+                size=(num_layers, num_blocks, *value_block_shape),
+                dtype=kv_cache_dtype,
+                device=device,
+            )
+
+        output = (key_cache, value_cache)
+
+        if self.cache_config.quant_policy in (4, 8):
+            dtype = self.model_config.dtype
+            key_sz_cache = torch.empty(
+                size=(num_layers, num_blocks, *key_block_shape[:-1], 2),
+                dtype=dtype,
+                device=device,
+            )
+            val_sz_cache = torch.empty(
+                size=(num_layers, num_blocks, *value_block_shape[:-1], 2),
+                dtype=dtype,
+                device=device,
+            )
+            output = output + (key_sz_cache, val_sz_cache)
+
+        return output
+
+    @torch.inference_mode()
+    def load_model_weights_310P(
+        model: torch.nn.Module,
+        checkpoint_path: str,
+        prefix: str = None,
+        device: torch.device = None,
+    ):
+        """Loading model weights."""
+        loader = ModelWeightLoader(checkpoint_path, prefix=prefix)
+        loader.load_model_weights(model, device=device)
+        model.eval()
+        # NOTE: Ascend310P converts Linear weights to NZ format by default in graph mode.
+        # However, the vision_model part is not compiled in graph mode, so we skip converting its weights.
+        # This is a workaround for Ascend310P.
+        for name, mod in model.named_modules():
+            if (
+                not hasattr(mod, "update_weights")
+                or name.startswith("vision_model")
+                or name.startswith("visual")
+            ):
+                continue
+            mod.update_weights()
+
+    def _build_model_310P(self):
+        """
+        Build patched model.
+        NOTE: Ascend310P converts Linear weights to NZ format by default in graph mode.
+        However, the vision_model part is not compiled in graph mode, so we skip converting its weights.
+        """
+        model_path = self.model_path
+        adapters = self.adapters
+        device = self.device
+        rank = self.rank
+        custom_module_map = self.model_config.custom_module_map
+        if custom_module_map is not None:
+            update_custom_module_map(custom_module_map)
+        logger.debug(msg_with_rank(rank, "build model."))
+        patched_model = build_patched_model(
+            self.model_config, device=device, model_format=self.misc_config.model_format
+        )
+        logger.debug(msg_with_rank(rank, "loading weights."))
+        if not self.misc_config.empty_init:
+            load_model_weights_310P(patched_model, model_path, device=device)
+        if adapters is not None:
+            logger.debug(msg_with_rank(rank, "loading adapters."))
+            add_adapters(
+                patched_model, adapters, dtype=self.model_config.dtype, device=device
+            )
+        self.patched_model = patched_model
+
+    # Ascend310P doesn't support broadcast for now, so we use gloo to broadcast next_token_ids and then transfer them to npu.
+    BaseModelAgent._broadcast_next_token = _broadcast_next_token_310P
+    # Ascend310P requires kv_cache to be in acl NZ format, so allocate the gpu cache in NZ format.
+    CacheEngine._allocate_cache = _allocate_cache_310P
+    # We convert Linear weights to NZ format on Ascend310P devices by default in graph mode.
+    # However, the vision_model part is not compiled in graph mode, so we skip converting its weights.
+    BaseModelAgent._build_model = _build_model_310P
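The Ascend310P branch works around the missing NPU broadcast by routing next_token_ids through a CPU (gloo) process group and moving the result back to the original device. The round trip can be exercised in isolation; the single-rank group, address, and port below are placeholders, while in the real path the group is dist_ctx.tp_cpu_group and the tensor starts on an npu device:

    import torch
    import torch.distributed as dist

    # one-rank gloo group, just to demonstrate the CPU round trip
    dist.init_process_group(
        backend="gloo", init_method="tcp://127.0.0.1:29500", rank=0, world_size=1
    )

    next_token_ids = torch.tensor([42, 7, 13])    # pretend this was sampled on the NPU
    original_device = next_token_ids.device

    cpu_ids = next_token_ids.cpu()                # gloo only handles CPU tensors
    dist.broadcast(cpu_ids, src=0)                # rank 0 shares its sampled tokens
    next_token_ids = cpu_ids.to(original_device)  # back to the compute device

    print(next_token_ids.tolist())
    dist.destroy_process_group()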
dlinfer/framework/lmdeploy_ext/device/camb.py

@@ -0,0 +1,24 @@
+# Copyright (c) 2024, DeepLink. All rights reserved.
+import torch
+
+from lmdeploy.pytorch.backends.default.multinomial_sampling import (
+    DefaultMultinomialSamplingImpl,
+)
+
+
+def CambDefaultMultinomialSamplingImpl_forward(
+    self,
+    scores: torch.Tensor,
+    seeds: torch.LongTensor,
+    offsets: torch.LongTensor,
+    indices: torch.Tensor = None,
+):
+    r"""
+    Note: torch_mlu.multinomial doesn't support replacement=True, whereas lmdeploy sets replacement=True by default.
+    """
+    sampled_index = torch.multinomial(scores, num_samples=1, replacement=False)
+    outputs = torch.gather(indices, dim=1, index=sampled_index)
+    return outputs.view(-1)
+
+
+DefaultMultinomialSamplingImpl.forward = CambDefaultMultinomialSamplingImpl_forward
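The Cambricon patch is safe because only one token is drawn per row: with num_samples=1, sampling without replacement yields the same distribution as sampling with replacement, so the unsupported replacement=True flag can simply be dropped. A toy illustration with made-up scores and token ids:

    import torch

    torch.manual_seed(0)
    scores = torch.tensor([[0.1, 0.7, 0.2],
                           [0.6, 0.3, 0.1]])    # per-row sampling weights
    indices = torch.tensor([[10, 11, 12],
                            [20, 21, 22]])      # token ids behind each column

    # one draw per row; replacement is irrelevant for a single sample
    sampled_index = torch.multinomial(scores, num_samples=1, replacement=False)
    outputs = torch.gather(indices, dim=1, index=sampled_index)
    print(outputs.view(-1))                     # e.g. tensor([11, 20])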
dlinfer/framework/lmdeploy_ext/quants/__init__.py

@@ -0,0 +1,20 @@
+# Copyright (c) 2024, DeepLink. All rights reserved.
+import importlib
+from functools import lru_cache
+from dlinfer.vendor import vendor_name
+
+
+awq_vendor = ["ascend"]
+
+
+@lru_cache(1)
+def import_vendor_module(vendor_name_str):
+    if vendor_name_str in awq_vendor:
+        importlib.import_module(f".{vendor_name_str}_awq", __package__)
+
+
+def vendor_quant_init():
+    import_vendor_module(vendor_name)
+
+
+vendor_quant_init()