dlinfer-ascend 0.2.3.post2__cp311-cp311-manylinux2014_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (70)
  1. dlinfer/__init__.py +5 -0
  2. dlinfer/framework/__init__.py +1 -0
  3. dlinfer/framework/lmdeploy_ext/__init__.py +6 -0
  4. dlinfer/framework/lmdeploy_ext/cudagraph/__init__.py +20 -0
  5. dlinfer/framework/lmdeploy_ext/cudagraph/ascend_cudagraph.py +391 -0
  6. dlinfer/framework/lmdeploy_ext/cudagraph/camb_cudagraph.py +133 -0
  7. dlinfer/framework/lmdeploy_ext/cudagraph/maca_cudagraph.py +128 -0
  8. dlinfer/framework/lmdeploy_ext/cudagraph/ppu_cudagraph.py +131 -0
  9. dlinfer/framework/lmdeploy_ext/device/__init__.py +79 -0
  10. dlinfer/framework/lmdeploy_ext/device/ascend.py +205 -0
  11. dlinfer/framework/lmdeploy_ext/device/camb.py +24 -0
  12. dlinfer/framework/lmdeploy_ext/quants/__init__.py +20 -0
  13. dlinfer/framework/lmdeploy_ext/quants/ascend_awq.py +248 -0
  14. dlinfer/framework/torch_npu_ext/__init__.py +12 -0
  15. dlinfer/framework/torch_npu_ext/aclgraph.py +59 -0
  16. dlinfer/framework/transformers_ext/__init__.py +17 -0
  17. dlinfer/framework/transformers_ext/cogvlm.py +25 -0
  18. dlinfer/framework/transformers_ext/internlm2.py +242 -0
  19. dlinfer/framework/transformers_ext/internvl.py +33 -0
  20. dlinfer/framework/transformers_ext/patch.py +33 -0
  21. dlinfer/graph/__init__.py +5 -0
  22. dlinfer/graph/custom_op.py +147 -0
  23. dlinfer/graph/dicp/__init__.py +0 -0
  24. dlinfer/graph/dicp/dynamo_bridge/__init__.py +0 -0
  25. dlinfer/graph/dicp/dynamo_bridge/compile.py +42 -0
  26. dlinfer/graph/dicp/dynamo_bridge/compile_fx.py +305 -0
  27. dlinfer/graph/dicp/dynamo_bridge/conversion.py +75 -0
  28. dlinfer/graph/dicp/dynamo_bridge/decompositions.py +38 -0
  29. dlinfer/graph/dicp/dynamo_bridge/graph.py +141 -0
  30. dlinfer/graph/dicp/dynamo_bridge/op_transformer.py +293 -0
  31. dlinfer/graph/dicp/dynamo_bridge/operator.py +87 -0
  32. dlinfer/graph/dicp/dynamo_bridge/pt_patch.py +320 -0
  33. dlinfer/graph/dicp/dynamo_bridge/torch_version.py +38 -0
  34. dlinfer/graph/dicp/dynamo_bridge/utils.py +158 -0
  35. dlinfer/graph/dicp/vendor/AtbGraph/__init__.py +13 -0
  36. dlinfer/graph/dicp/vendor/AtbGraph/atb_op.py +853 -0
  37. dlinfer/graph/dicp/vendor/AtbGraph/codegen/__init__.py +0 -0
  38. dlinfer/graph/dicp/vendor/AtbGraph/codegen/atb.py +318 -0
  39. dlinfer/graph/dicp/vendor/AtbGraph/codegen/atb_graph.py +768 -0
  40. dlinfer/graph/dicp/vendor/AtbGraph/codegen/atb_infer_param.py +763 -0
  41. dlinfer/graph/dicp/vendor/AtbGraph/codegen/atb_op.py +1279 -0
  42. dlinfer/graph/dicp/vendor/AtbGraph/codegen/libdicp_model.so +0 -0
  43. dlinfer/graph/dicp/vendor/AtbGraph/codegen/load_and_run.py +21 -0
  44. dlinfer/graph/dicp/vendor/AtbGraph/codegen/utils.py +178 -0
  45. dlinfer/graph/dicp/vendor/AtbGraph/compile_job.py +52 -0
  46. dlinfer/graph/dicp/vendor/AtbGraph/config.py +36 -0
  47. dlinfer/graph/dicp/vendor/AtbGraph/conversion.py +908 -0
  48. dlinfer/graph/dicp/vendor/AtbGraph/ext_ops.py +95 -0
  49. dlinfer/graph/dicp/vendor/AtbGraph/infer_res_utils.py +200 -0
  50. dlinfer/graph/dicp/vendor/AtbGraph/opset_convert.py +70 -0
  51. dlinfer/graph/dicp/vendor/AtbGraph/pattern_replacement.py +152 -0
  52. dlinfer/graph/dicp/vendor/__init__.py +0 -0
  53. dlinfer/ops/__init__.py +2 -0
  54. dlinfer/ops/llm.py +879 -0
  55. dlinfer/utils/__init__.py +1 -0
  56. dlinfer/utils/config.py +18 -0
  57. dlinfer/utils/registry.py +8 -0
  58. dlinfer/utils/type_annotation.py +3 -0
  59. dlinfer/vendor/__init__.py +33 -0
  60. dlinfer/vendor/ascend/__init__.py +5 -0
  61. dlinfer/vendor/ascend/pytorch_patch.py +55 -0
  62. dlinfer/vendor/ascend/torch_npu_ops.py +601 -0
  63. dlinfer/vendor/ascend/utils.py +20 -0
  64. dlinfer/vendor/vendor.yaml +2 -0
  65. dlinfer_ascend-0.2.3.post2.dist-info/LICENSE +28 -0
  66. dlinfer_ascend-0.2.3.post2.dist-info/METADATA +213 -0
  67. dlinfer_ascend-0.2.3.post2.dist-info/RECORD +70 -0
  68. dlinfer_ascend-0.2.3.post2.dist-info/WHEEL +5 -0
  69. dlinfer_ascend-0.2.3.post2.dist-info/entry_points.txt +2 -0
  70. dlinfer_ascend-0.2.3.post2.dist-info/top_level.txt +1 -0
dlinfer/framework/lmdeploy_ext/cudagraph/ppu_cudagraph.py
@@ -0,0 +1,131 @@
+ # Copyright (c) 2025, OpenMMLab and DeepLink. All rights reserved.
+ import torch
+ from torch import Tensor
+
+ from typing import Any, Dict, List
+
+ from lmdeploy.pytorch.models.utils.cudagraph import CudaGraphMeta
+ from lmdeploy.pytorch.models.utils.cudagraph import CudaGraphMixin, next_power_of_2
+
+ BuffType = Dict[str, Tensor]
+
+
+ def PPUCudaGraphMixin_make_buffers_cudagraph(
+     self, graph_meta: CudaGraphMeta, *args, **kwargs
+ ) -> BuffType:
+     """make cudagraph buffers from forward inputs."""
+     max_batches = graph_meta.max_batchs
+     max_tokens = graph_meta.max_tokens
+     num_blocks = graph_meta.num_blocks
+     device = graph_meta.device
+     input_buffers: BuffType = dict()
+     input_buffers["input_ids"] = torch.zeros(
+         1, max_tokens, dtype=torch.int32, device=device
+     )
+
+     input_buffers["position_ids"] = torch.zeros(
+         (1, max_tokens), dtype=torch.int32, device=device
+     )
+
+     input_buffers["block_offsets"] = torch.zeros(
+         (max_batches, num_blocks), dtype=torch.int32, device=device
+     )
+
+     input_buffers["q_seqlens"] = torch.zeros(
+         max_batches, dtype=torch.int32, device=device
+     )
+
+     input_buffers["kv_seqlens"] = torch.zeros(
+         max_batches, dtype=torch.int32, device=device
+     )
+
+     input_buffers["q_start_loc"] = torch.zeros(
+         max_batches + 1, dtype=torch.int32, device=device
+     )
+
+     input_buffers["kv_start_indices"] = torch.zeros(
+         (max_batches, 1), dtype=torch.int64, device=device
+     )
+     return input_buffers
+
+
+ def PPUCudaGraphMixin_fill_buffers_cudagraph(
+     self,
+     graph_meta: CudaGraphMeta,
+     input_ids: Tensor,
+     position_ids: Tensor,
+     past_key_values: List,
+     attn_metadata: Any,
+     inputs_embeds: Tensor,
+     **kwargs
+ ) -> Dict[str, Tensor]:
+     """fill cudagraph buffers from forward inputs."""
+     block_offsets: Tensor = attn_metadata.block_offsets
+     q_start_loc: Tensor = attn_metadata.q_start_loc
+     q_seqlens: Tensor = attn_metadata.q_seqlens
+     kv_seqlens: Tensor = attn_metadata.kv_seqlens
+     kv_start_indices: Tensor = attn_metadata.kv_start_indices
+
+     input_buffers: BuffType = graph_meta.input_buffers
+
+     batch_size, num_blocks = block_offsets.size()
+     num_tokens = input_ids.size(-1)
+
+     # fill buffer
+     input_buffers["input_ids"][:, :num_tokens] = input_ids
+     input_buffers["position_ids"][:, :num_tokens] = position_ids
+     input_buffers["block_offsets"][:batch_size, :num_blocks] = block_offsets
+     input_buffers["q_start_loc"][: batch_size + 1] = q_start_loc
+     input_buffers["q_seqlens"][:batch_size] = q_seqlens
+     input_buffers["kv_seqlens"][:batch_size] = kv_seqlens
+     input_buffers["kv_start_indices"][:batch_size] = kv_start_indices
+
+     if inputs_embeds is not None:
+         emb_size = inputs_embeds.size(-1)
+         if "inputs_embeds" not in input_buffers:
+             max_num_tokens = input_buffers["input_ids"].size(-1)
+             input_buffers["inputs_embeds"] = inputs_embeds.new_zeros(
+                 1, max_num_tokens, emb_size
+             )
+         input_buffers["inputs_embeds"][:, :num_tokens] = inputs_embeds
+
+     # create inputs
+     new_batch_size = next_power_of_2(batch_size)
+
+     attn_metadata.block_offsets = input_buffers["block_offsets"][:new_batch_size]
+     attn_metadata.q_start_loc = input_buffers["q_start_loc"][: new_batch_size + 1]
+     attn_metadata.q_seqlens = input_buffers["q_seqlens"][:new_batch_size]
+     attn_metadata.kv_seqlens = input_buffers["kv_seqlens"][:new_batch_size]
+     attn_metadata.kv_start_indices = input_buffers["kv_start_indices"][:new_batch_size]
+     attn_metadata.max_q_seq_len = int(attn_metadata.q_seqlens.max().item())
+     attn_metadata.max_kv_seq_len = int(attn_metadata.kv_seqlens.max().item())
+
+     new_inputs = dict(
+         past_key_values=past_key_values,
+         attn_metadata=attn_metadata,
+     )
+
+     new_inputs["input_ids"] = input_buffers["input_ids"][:, :new_batch_size]
+     new_inputs["position_ids"] = input_buffers["position_ids"][:, :new_batch_size]
+
+     if inputs_embeds is not None:
+         new_inputs["inputs_embeds"] = input_buffers["inputs_embeds"][:, :new_batch_size]
+
+     new_inputs.update(kwargs)
+
+     return new_inputs
+
+
+ def PPUCudaGraphMixin_update_context_cudagraph(self, graph_meta, context):
+     """update step context with input buffers."""
+     input_buffers = graph_meta.input_buffers
+     context.block_offsets = input_buffers["block_offsets"]
+     context.q_seqlens = input_buffers["q_seqlens"]
+     context.kv_seqlens = input_buffers["kv_seqlens"]
+     context.q_start_loc = input_buffers["q_start_loc"]
+     context.kv_start_indices = input_buffers["kv_start_indices"]
+
+
+ CudaGraphMixin.make_buffers_cudagraph = PPUCudaGraphMixin_make_buffers_cudagraph
+ CudaGraphMixin.fill_buffers_cudagraph = PPUCudaGraphMixin_fill_buffers_cudagraph
+ CudaGraphMixin.update_context_cudagraph = PPUCudaGraphMixin_update_context_cudagraph
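
The decode batch is padded up to a power of two before the buffers above are sliced, so only a small, fixed set of graph shapes ever needs to be captured. Below is a minimal sketch of that rounding, using a hypothetical stand-in for lmdeploy's next_power_of_2 (the real helper is imported above); it is only meant to show why fill_buffers_cudagraph slices everything to new_batch_size.

    # Hypothetical stand-in for lmdeploy's next_power_of_2, for illustration only.
    def next_power_of_2(n: int) -> int:
        # Round a positive batch size up to the nearest power of two so a small,
        # fixed set of captured graph shapes can serve any runtime batch size.
        return 1 if n <= 1 else 1 << (n - 1).bit_length()

    assert [next_power_of_2(n) for n in (1, 3, 5, 9)] == [1, 4, 8, 16]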
dlinfer/framework/lmdeploy_ext/device/__init__.py
@@ -0,0 +1,79 @@
+ # Copyright (c) 2024, DeepLink. All rights reserved.
+ import importlib
+ import torch
+ from functools import lru_cache
+ from dlinfer.vendor import vendor_name
+ from lmdeploy.pytorch import models
+
+
+ vendor = ["camb", "ascend"]
+
+
+ def fake_torch_compile(dynamic=False):
+     def decorator(func):
+         def wrapper(*args, **kwargs):
+             return func(*args, **kwargs)
+
+         return wrapper
+
+     return decorator
+
+
+ def pre_rms_norm(q: torch.Tensor, k: torch.Tensor) -> torch.Tensor:
+     """Pre rms norm."""
+     q = q.to(torch.float32)
+     k = k.to(torch.float32)
+     variance_q = (q * q).sum(-1, keepdim=True)
+     variance_k = (k * k).sum(-1, keepdim=True)
+     variance = torch.stack([variance_q, variance_k], dim=0)
+     return variance
+
+
+ def post_rms_norm(
+     q: torch.Tensor,
+     k: torch.Tensor,
+     weight_q: torch.Tensor,
+     weight_k: torch.Tensor,
+     variance: torch.Tensor,
+     eps: float,
+     embed_dim: int,
+     dtype: torch.dtype,
+ ):
+     """Post rms norm."""
+     q = q.to(torch.float32)
+     k = k.to(torch.float32)
+     variance = variance / embed_dim + eps
+     variance_q, variance_k = variance
+     q = q * torch.rsqrt(variance_q)
+     q = q.to(dtype) * weight_q
+     k = k * torch.rsqrt(variance_k)
+     k = k.to(dtype) * weight_k
+     return q, k
+
+
+ def patch_compiled_func():
+     import torch
+
+     real_torch_compile = torch.compile
+     torch.compile = fake_torch_compile
+     from lmdeploy.pytorch.models import internvl, internvl3_hf
+
+     internvl.pre_rms_norm = pre_rms_norm
+     internvl.post_rms_norm = post_rms_norm
+     internvl3_hf.pre_rms_norm = pre_rms_norm
+     internvl3_hf.post_rms_norm = post_rms_norm
+     torch.compile = real_torch_compile
+
+
+ @lru_cache(1)
+ def import_vendor_module(vendor_name_str):
+     if vendor_name_str in vendor:
+         importlib.import_module(f".{vendor_name_str}", __package__)
+
+
+ def vendor_device_init():
+     import_vendor_module(vendor_name)
+     patch_compiled_func()
+
+
+ vendor_device_init()
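
patch_compiled_func relies on a temporary swap: torch.compile is replaced with the no-op fake_torch_compile while internvl and internvl3_hf are imported, so their @torch.compile-decorated rms-norm helpers stay eager, and the real compiler is restored afterwards. A standalone sketch of that pattern follows; the decorated function here is purely illustrative.

    import torch

    def fake_torch_compile(dynamic=False):
        # No-op replacement: returns the wrapped function unchanged.
        def decorator(func):
            def wrapper(*args, **kwargs):
                return func(*args, **kwargs)
            return wrapper
        return decorator

    real_torch_compile = torch.compile
    torch.compile = fake_torch_compile
    try:
        @torch.compile(dynamic=True)  # resolves to the no-op decorator above
        def double(x):
            return x * 2
    finally:
        torch.compile = real_torch_compile  # restore the real compiler

    print(double(torch.tensor([1, 2])))  # runs eagerly: tensor([2, 4])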
dlinfer/framework/lmdeploy_ext/device/ascend.py
@@ -0,0 +1,205 @@
+ # Copyright (c) 2024, DeepLink. All rights reserved.
+ import os
+ import torch
+
+ from lmdeploy.pytorch.backends.dlinfer.moe import DlinferFusedMoEImpl
+ from lmdeploy.pytorch.models.chatglm2 import SelfAttention
+ from lmdeploy.pytorch.engine import logits_process
+
+ from dlinfer.vendor.ascend.utils import SocVersion
+
+
+ def rl_update_weights(self, gate_up_weights: torch.Tensor, down_weights: torch.Tensor):
+     """Update weights."""
+     return gate_up_weights, down_weights
+
+
+ if os.getenv("DLINFER_RESET_MOE_UPDATE_WEIGHTS", "0") == "1":
+     DlinferFusedMoEImpl.update_weights = rl_update_weights
+
+
+ @staticmethod
+ def ascend_chatglm2_fill_rope(states: torch.Tensor, rope: torch.Tensor):
+     """fill rope."""
+     rope_part = states.chunk(2, -1)[1]
+     rope = rope.unflatten(-1, (2, -1))
+     rope = rope.transpose(-2, -1).flatten(-2, -1)
+     states = torch.cat([rope_part, rope], dim=-1)
+
+     return states
+
+
+ SelfAttention._fill_rope = ascend_chatglm2_fill_rope
+
+
+ # modify bad words process for aclgraph performance
+ def _process_bad_words_(
+     scores: torch.Tensor,
+     bad_words: torch.LongTensor,
+     mask: torch.BoolTensor,
+     filter_value: float = -99999.9999,
+ ):
+     """Process bad words."""
+     filtered_scores = scores.gather(1, bad_words)
+     filtered_scores = mask.to(filtered_scores.dtype) * filter_value + filtered_scores
+     scores.scatter_(1, bad_words, filtered_scores)
+     return scores
+
+
+ logits_process._process_bad_words_ = _process_bad_words_
+
+ ########## below is for ascend310P ##########
+
+ if SocVersion.is_Ascend310P():
+     # Lazy import for Ascend310P
+     import torch.distributed as dist
+     from lmdeploy.utils import get_logger
+     from lmdeploy.pytorch.distributed import get_dist_manager, DistContext
+     from lmdeploy.pytorch.engine.model_agent import (
+         msg_with_rank,
+         BaseModelAgent,
+     )
+     from lmdeploy.pytorch.engine.cache_engine import CacheEngine
+     from lmdeploy.pytorch.models.patch import (
+         update_custom_module_map,
+         build_patched_model,
+         add_adapters,
+     )
+     from lmdeploy.pytorch.weight_loader.model_weight_loader import ModelWeightLoader
+     from lmdeploy.pytorch.disagg.config import EngineRole
+
+     logger = get_logger("lmdeploy")
+
+     def _broadcast_next_token_310P(
+         self, next_token_ids: torch.Tensor, dist_ctx: DistContext = None
+     ):
+         # NOTE: Ascend310P does not support broadcast, so we need to use gloo to
+         # broadcast next_token_ids and then transfer it to npu.
+         # This mock properly broadcasts next_token_ids on Ascend310P devices.
+         if dist_ctx is None:
+             dist_ctx = get_dist_manager().current_context()
+         if self.cache_config.role == EngineRole.Decode:
+             next_token_ids = next_token_ids.cpu()
+             tp_cpu_group = dist_ctx.tp_cpu_group
+             dist.all_reduce(next_token_ids, op=dist.ReduceOp.SUM, group=tp_cpu_group)
+         else:
+             tp_cpu_group = dist_ctx.tp_cpu_group
+             original_device = next_token_ids.device
+             next_token_ids = next_token_ids.cpu()
+             dist.broadcast(next_token_ids, src=0, group=tp_cpu_group)
+             next_token_ids = next_token_ids.to(original_device)
+         return next_token_ids
+
+     def _allocate_cache_310P(self, num_blocks: int, device: torch.device):
+         """Allocate cache implementation.
+
+         NOTE: Ascend 300I Duo devices require kv_cache to be in ACL NZ format.
+         """
+         key_block_shape = self.get_key_block_shape(local=True)
+         value_block_shape = self.get_value_block_shape(local=True)
+
+         num_layers = self.num_layers
+         kv_cache_dtype = self.kv_cache_dtype
+
+         if device != "cpu":
+             import torch_npu
+
+             key_cache = torch_npu.empty_with_format(
+                 size=(num_layers, num_blocks, *key_block_shape),
+                 dtype=kv_cache_dtype,
+                 device="npu",
+                 acl_format=29,  # 29 for ACL NZ format
+             )
+             value_cache = torch_npu.empty_with_format(
+                 size=(num_layers, num_blocks, *value_block_shape),
+                 dtype=kv_cache_dtype,
+                 device="npu",
+                 acl_format=29,
+             )
+         else:
+             key_cache = torch.empty(
+                 size=(num_layers, num_blocks, *key_block_shape),
+                 dtype=kv_cache_dtype,
+                 device=device,
+             )
+             value_cache = torch.empty(
+                 size=(num_layers, num_blocks, *value_block_shape),
+                 dtype=kv_cache_dtype,
+                 device=device,
+             )
+
+         output = (key_cache, value_cache)
+
+         if self.cache_config.quant_policy in (4, 8):
+             dtype = self.model_config.dtype
+             key_sz_cache = torch.empty(
+                 size=(num_layers, num_blocks, *key_block_shape[:-1], 2),
+                 dtype=dtype,
+                 device=device,
+             )
+             val_sz_cache = torch.empty(
+                 size=(num_layers, num_blocks, *value_block_shape[:-1], 2),
+                 dtype=dtype,
+                 device=device,
+             )
+             output = output + (key_sz_cache, val_sz_cache)
+
+         return output
+
+     @torch.inference_mode()
+     def load_model_weights_310P(
+         model: torch.nn.Module,
+         checkpoint_path: str,
+         prefix: str = None,
+         device: torch.device = None,
+     ):
+         """Loading model weights."""
+         loader = ModelWeightLoader(checkpoint_path, prefix=prefix)
+         loader.load_model_weights(model, device=device)
+         model.eval()
+         # NOTE: Ascend310P converts Linear weights to NZ format by default in graph mode.
+         # However, the vision_model part is not compiled in graph mode, so we skip
+         # converting weights of the vision_model part. This is a workaround for Ascend310P.
+         for name, mod in model.named_modules():
+             if (
+                 not hasattr(mod, "update_weights")
+                 or name.startswith("vision_model")
+                 or name.startswith("visual")
+             ):
+                 continue
+             mod.update_weights()
+
+     def _build_model_310P(self):
+         """Build patched model.
+
+         NOTE: Ascend310P converts Linear weights to NZ format by default in graph mode.
+         However, the vision_model part is not compiled in graph mode, so we skip
+         converting weights of the vision_model part.
+         """
+         model_path = self.model_path
+         adapters = self.adapters
+         device = self.device
+         rank = self.rank
+         custom_module_map = self.model_config.custom_module_map
+         if custom_module_map is not None:
+             update_custom_module_map(custom_module_map)
+         logger.debug(msg_with_rank(rank, "build model."))
+         patched_model = build_patched_model(
+             self.model_config, device=device, model_format=self.misc_config.model_format
+         )
+         logger.debug(msg_with_rank(rank, "loading weights."))
+         if not self.misc_config.empty_init:
+             load_model_weights_310P(patched_model, model_path, device=device)
+         if adapters is not None:
+             logger.debug(msg_with_rank(rank, "loading adapters."))
+             add_adapters(
+                 patched_model, adapters, dtype=self.model_config.dtype, device=device
+             )
+         self.patched_model = patched_model
+
+     # Ascend310P doesn't support broadcast for now, so we need to use gloo to
+     # broadcast next_token_ids and then transfer it to npu.
+     BaseModelAgent._broadcast_next_token = _broadcast_next_token_310P
+     # Ascend310P requires kv_cache to be in ACL NZ format, so allocate the device cache in NZ format.
+     CacheEngine._allocate_cache = _allocate_cache_310P
+     # Linear weights are converted to NZ format on Ascend310P by default in graph mode.
+     # However, the vision_model part is not compiled in graph mode, so we skip
+     # converting weights of the vision_model part.
+     BaseModelAgent._build_model = _build_model_310P
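
The _process_bad_words_ override above avoids data-dependent control flow: it gathers the bad-word logits, adds a large negative penalty scaled by the mask, and scatters the result back, which keeps the op sequence static for ACL graph capture. A small self-contained illustration with made-up values:

    import torch

    scores = torch.zeros(1, 6)           # toy logits for a 6-token vocabulary
    bad_words = torch.tensor([[2, 4]])   # token ids to suppress
    mask = torch.tensor([[True, True]])  # which bad-word slots are active
    filter_value = -99999.9999

    filtered = scores.gather(1, bad_words)
    filtered = mask.to(filtered.dtype) * filter_value + filtered
    scores.scatter_(1, bad_words, filtered)
    print(scores)  # positions 2 and 4 now hold a large negative logit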
dlinfer/framework/lmdeploy_ext/device/camb.py
@@ -0,0 +1,24 @@
+ # Copyright (c) 2024, DeepLink. All rights reserved.
+ import torch
+
+ from lmdeploy.pytorch.backends.default.multinomial_sampling import (
+     DefaultMultinomialSamplingImpl,
+ )
+
+
+ def CambDefaultMultinomialSamplingImpl_forward(
+     self,
+     scores: torch.Tensor,
+     seeds: torch.LongTensor,
+     offsets: torch.LongTensor,
+     indices: torch.Tensor = None,
+ ):
+     r"""
+     Note: torch_mlu.multinomial doesn't support replacement=True, whereas lmdeploy
+     sets replacement=True by default.
+     """
+     sampled_index = torch.multinomial(scores, num_samples=1, replacement=False)
+     outputs = torch.gather(indices, dim=1, index=sampled_index)
+     return outputs.view(-1)
+
+
+ DefaultMultinomialSamplingImpl.forward = CambDefaultMultinomialSamplingImpl_forward
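
The Camb override above draws a single sample per row without replacement (torch_mlu's multinomial does not support replacement=True) and maps the drawn column back through the caller-supplied indices. A minimal sketch with illustrative shapes and values:

    import torch

    scores = torch.softmax(torch.randn(2, 5), dim=-1)  # per-row sampling probabilities
    indices = torch.arange(5).repeat(2, 1)             # candidate token ids per row
    picked = torch.multinomial(scores, num_samples=1, replacement=False)
    token_ids = torch.gather(indices, dim=1, index=picked).view(-1)
    print(token_ids)  # one sampled token id per row, shape (2,)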
dlinfer/framework/lmdeploy_ext/quants/__init__.py
@@ -0,0 +1,20 @@
+ # Copyright (c) 2024, DeepLink. All rights reserved.
+ import importlib
+ from functools import lru_cache
+ from dlinfer.vendor import vendor_name
+
+
+ awq_vendor = ["ascend"]
+
+
+ @lru_cache(1)
+ def import_vendor_module(vendor_name_str):
+     if vendor_name_str in awq_vendor:
+         importlib.import_module(f".{vendor_name_str}_awq", __package__)
+
+
+ def vendor_quant_init():
+     import_vendor_module(vendor_name)
+
+
+ vendor_quant_init()