dlinfer-ascend 0.2.3.post2__cp311-cp311-manylinux2014_aarch64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dlinfer/__init__.py +5 -0
- dlinfer/framework/__init__.py +1 -0
- dlinfer/framework/lmdeploy_ext/__init__.py +6 -0
- dlinfer/framework/lmdeploy_ext/cudagraph/__init__.py +20 -0
- dlinfer/framework/lmdeploy_ext/cudagraph/ascend_cudagraph.py +391 -0
- dlinfer/framework/lmdeploy_ext/cudagraph/camb_cudagraph.py +133 -0
- dlinfer/framework/lmdeploy_ext/cudagraph/maca_cudagraph.py +128 -0
- dlinfer/framework/lmdeploy_ext/cudagraph/ppu_cudagraph.py +131 -0
- dlinfer/framework/lmdeploy_ext/device/__init__.py +79 -0
- dlinfer/framework/lmdeploy_ext/device/ascend.py +205 -0
- dlinfer/framework/lmdeploy_ext/device/camb.py +24 -0
- dlinfer/framework/lmdeploy_ext/quants/__init__.py +20 -0
- dlinfer/framework/lmdeploy_ext/quants/ascend_awq.py +248 -0
- dlinfer/framework/torch_npu_ext/__init__.py +12 -0
- dlinfer/framework/torch_npu_ext/aclgraph.py +59 -0
- dlinfer/framework/transformers_ext/__init__.py +17 -0
- dlinfer/framework/transformers_ext/cogvlm.py +25 -0
- dlinfer/framework/transformers_ext/internlm2.py +242 -0
- dlinfer/framework/transformers_ext/internvl.py +33 -0
- dlinfer/framework/transformers_ext/patch.py +33 -0
- dlinfer/graph/__init__.py +5 -0
- dlinfer/graph/custom_op.py +147 -0
- dlinfer/graph/dicp/__init__.py +0 -0
- dlinfer/graph/dicp/dynamo_bridge/__init__.py +0 -0
- dlinfer/graph/dicp/dynamo_bridge/compile.py +42 -0
- dlinfer/graph/dicp/dynamo_bridge/compile_fx.py +305 -0
- dlinfer/graph/dicp/dynamo_bridge/conversion.py +75 -0
- dlinfer/graph/dicp/dynamo_bridge/decompositions.py +38 -0
- dlinfer/graph/dicp/dynamo_bridge/graph.py +141 -0
- dlinfer/graph/dicp/dynamo_bridge/op_transformer.py +293 -0
- dlinfer/graph/dicp/dynamo_bridge/operator.py +87 -0
- dlinfer/graph/dicp/dynamo_bridge/pt_patch.py +320 -0
- dlinfer/graph/dicp/dynamo_bridge/torch_version.py +38 -0
- dlinfer/graph/dicp/dynamo_bridge/utils.py +158 -0
- dlinfer/graph/dicp/vendor/AtbGraph/__init__.py +13 -0
- dlinfer/graph/dicp/vendor/AtbGraph/atb_op.py +853 -0
- dlinfer/graph/dicp/vendor/AtbGraph/codegen/__init__.py +0 -0
- dlinfer/graph/dicp/vendor/AtbGraph/codegen/atb.py +318 -0
- dlinfer/graph/dicp/vendor/AtbGraph/codegen/atb_graph.py +768 -0
- dlinfer/graph/dicp/vendor/AtbGraph/codegen/atb_infer_param.py +763 -0
- dlinfer/graph/dicp/vendor/AtbGraph/codegen/atb_op.py +1279 -0
- dlinfer/graph/dicp/vendor/AtbGraph/codegen/libdicp_model.so +0 -0
- dlinfer/graph/dicp/vendor/AtbGraph/codegen/load_and_run.py +21 -0
- dlinfer/graph/dicp/vendor/AtbGraph/codegen/utils.py +178 -0
- dlinfer/graph/dicp/vendor/AtbGraph/compile_job.py +52 -0
- dlinfer/graph/dicp/vendor/AtbGraph/config.py +36 -0
- dlinfer/graph/dicp/vendor/AtbGraph/conversion.py +908 -0
- dlinfer/graph/dicp/vendor/AtbGraph/ext_ops.py +95 -0
- dlinfer/graph/dicp/vendor/AtbGraph/infer_res_utils.py +200 -0
- dlinfer/graph/dicp/vendor/AtbGraph/opset_convert.py +70 -0
- dlinfer/graph/dicp/vendor/AtbGraph/pattern_replacement.py +152 -0
- dlinfer/graph/dicp/vendor/__init__.py +0 -0
- dlinfer/ops/__init__.py +2 -0
- dlinfer/ops/llm.py +879 -0
- dlinfer/utils/__init__.py +1 -0
- dlinfer/utils/config.py +18 -0
- dlinfer/utils/registry.py +8 -0
- dlinfer/utils/type_annotation.py +3 -0
- dlinfer/vendor/__init__.py +33 -0
- dlinfer/vendor/ascend/__init__.py +5 -0
- dlinfer/vendor/ascend/pytorch_patch.py +55 -0
- dlinfer/vendor/ascend/torch_npu_ops.py +601 -0
- dlinfer/vendor/ascend/utils.py +20 -0
- dlinfer/vendor/vendor.yaml +2 -0
- dlinfer_ascend-0.2.3.post2.dist-info/LICENSE +28 -0
- dlinfer_ascend-0.2.3.post2.dist-info/METADATA +213 -0
- dlinfer_ascend-0.2.3.post2.dist-info/RECORD +70 -0
- dlinfer_ascend-0.2.3.post2.dist-info/WHEEL +5 -0
- dlinfer_ascend-0.2.3.post2.dist-info/entry_points.txt +2 -0
- dlinfer_ascend-0.2.3.post2.dist-info/top_level.txt +1 -0
dlinfer/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
1
|
+
# Copyright (c) 2024, DeepLink. All rights reserved.
|
|
@@ -0,0 +1,20 @@
|
|
|
1
|
+
# Copyright (c) 2024, DeepLink. All rights reserved.
import importlib
from functools import lru_cache

from dlinfer.vendor import vendor_name


graph_vendor = ["ascend", "maca", "camb", "ppu"]


@lru_cache(1)
def import_vendor_module(vendor_name_str):
    """Load the vendor-specific cudagraph patch module exactly once."""
    if vendor_name_str not in graph_vendor:
        return
    importlib.import_module(f".{vendor_name_str}_cudagraph", __package__)


def vendor_graph_init():
    """Apply the cudagraph patches for the currently configured vendor."""
    import_vendor_module(vendor_name)


vendor_graph_init()
|
|
@@ -0,0 +1,391 @@
|
|
|
1
|
+
# Copyright (c) 2024, OpenMMLab and DeepLink. All rights reserved.
|
|
2
|
+
# this file implements the cudagraph for ascend backend.
|
|
3
|
+
import functools
|
|
4
|
+
from typing import Any, Dict, List
|
|
5
|
+
from contextlib import ExitStack
|
|
6
|
+
import torch
|
|
7
|
+
from torch import Tensor
|
|
8
|
+
from torch.profiler import record_function
|
|
9
|
+
|
|
10
|
+
from lmdeploy.pytorch.models.utils.cudagraph import CudaGraphMeta
|
|
11
|
+
from lmdeploy.pytorch.models.utils.cudagraph import CudaGraphMixin
|
|
12
|
+
from lmdeploy.pytorch.config import BackendConfig, CacheConfig, ModelConfig
|
|
13
|
+
from lmdeploy.pytorch.model_inputs import StepContext, get_step_ctx_manager
|
|
14
|
+
from lmdeploy.pytorch.backends.graph_runner import GraphRunner
|
|
15
|
+
from lmdeploy.pytorch.backends.cuda import graph_runner
|
|
16
|
+
|
|
17
|
+
from lmdeploy.utils import get_logger
|
|
18
|
+
|
|
19
|
+
logger = get_logger("dlinfer")
|
|
20
|
+
BuffType = Dict[str, Tensor]
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
# AscendCudaGraphMixin methods for cudagraph buffer management.
def AscendCudaGraphMixin_make_buffers_cudagraph(
    self, graph_meta: CudaGraphMeta, *args, **kwargs
) -> BuffType:
    """Allocate the persistent capture-input buffers sized to the meta limits."""
    batches = graph_meta.max_batchs
    tokens = graph_meta.max_tokens
    blocks = graph_meta.num_blocks
    dev = graph_meta.device

    buffers: BuffType = {
        "input_ids": torch.empty(1, tokens, dtype=torch.int32, device=dev),
        "position_ids": torch.empty((1, tokens), dtype=torch.int32, device=dev),
        "block_offsets": torch.zeros(
            (batches, blocks), dtype=torch.int32, device=dev
        ),
        "q_seqlens": torch.ones(batches, dtype=torch.int32, device=dev),
        # NOTE: a host-side python list, not a tensor (updated on CPU).
        "kv_seqlens": [0] * batches,
        "q_start_loc": torch.arange(batches + 1, dtype=torch.int32, device=dev),
        # filled with -1 so unused slots never alias a valid cache index
        "kv_start_indices": -torch.ones((batches), dtype=torch.int64, device=dev),
    }
    return buffers
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
def AscendCudaGraphMixin_fill_buffers_cudagraph(
    self,
    graph_meta: CudaGraphMeta,
    input_ids: Tensor,
    position_ids: Tensor,
    past_key_values: List,
    attn_metadata: Any,
    inputs_embeds: Tensor,
    **kwargs,
) -> Dict[str, Tensor]:
    """fill cudagraph buffers from forward inputs.

    Copies this step's real inputs into the persistent capture buffers and
    returns padded views over those buffers for the captured graph to consume.
    Mutates ``attn_metadata`` in place to point at the padded buffer slices.
    """
    block_offsets: Tensor = attn_metadata.block_offsets
    kv_seqlens: List = attn_metadata.kv_seqlens
    kv_start_indices: Tensor = attn_metadata.kv_start_indices

    input_buffers: BuffType = graph_meta.input_buffers

    batch_size, num_blocks = block_offsets.size()
    num_tokens = input_ids.size(-1)

    # fill buffer
    input_buffers["input_ids"][:, :num_tokens] = input_ids
    input_buffers["position_ids"][:, :num_tokens] = position_ids
    input_buffers["block_offsets"][:batch_size, :num_blocks] = block_offsets
    # NOTE(review): "kv_seqlens" was allocated as a python list, so this is a
    # list slice-assignment; if len(kv_seqlens) != batch_size the buffer's
    # length silently changes -- confirm callers always pass batch_size items.
    input_buffers["kv_seqlens"][:batch_size] = kv_seqlens
    input_buffers["kv_start_indices"][:batch_size] = kv_start_indices

    if inputs_embeds is not None:
        emb_size = inputs_embeds.size(-1)
        if "inputs_embeds" not in input_buffers:
            # lazily allocate the embeddings buffer on first use
            max_num_tokens = input_buffers["input_ids"].size(-1)
            input_buffers["inputs_embeds"] = inputs_embeds.new_zeros(
                1, max_num_tokens, emb_size
            )
        input_buffers["inputs_embeds"][:, :num_tokens] = inputs_embeds
    # create inputs
    # pad the batch to an Ascend-friendly size so one captured graph can
    # serve a range of real batch sizes
    new_batch_size = get_ascend_compatible_size(batch_size)

    attn_metadata.block_offsets = input_buffers["block_offsets"][:new_batch_size]
    attn_metadata.kv_seqlens = input_buffers["kv_seqlens"][:new_batch_size]
    attn_metadata.kv_start_indices = input_buffers["kv_start_indices"][:new_batch_size]

    new_inputs = dict(
        past_key_values=past_key_values,
        attn_metadata=attn_metadata,
    )

    # NOTE(review): the token dimension is sliced by new_batch_size, which
    # assumes decoding (one token per sequence) -- confirm for prefill paths.
    new_inputs["input_ids"] = input_buffers["input_ids"][:, :new_batch_size]
    new_inputs["position_ids"] = input_buffers["position_ids"][:, :new_batch_size]

    if inputs_embeds is not None:
        new_inputs["inputs_embeds"] = input_buffers["inputs_embeds"][:, :new_batch_size]

    new_inputs.update(kwargs)

    return new_inputs
|
|
117
|
+
|
|
118
|
+
|
|
119
|
+
def AscendCudaGraphMixin_update_context_cudagraph(self, graph_meta, context):
    """Point the step context at the persistent capture buffers."""
    buffers = graph_meta.input_buffers
    for field in (
        "block_offsets",
        "q_seqlens",
        "kv_seqlens",
        "q_start_loc",
        "kv_start_indices",
    ):
        setattr(context, field, buffers[field])
|
|
127
|
+
|
|
128
|
+
|
|
129
|
+
# Monkey-patch lmdeploy's CudaGraphMixin so every model using it picks up
# the Ascend-specific buffer management defined above (module import is the
# activation point).
CudaGraphMixin.make_buffers_cudagraph = AscendCudaGraphMixin_make_buffers_cudagraph
CudaGraphMixin.fill_buffers_cudagraph = AscendCudaGraphMixin_fill_buffers_cudagraph
CudaGraphMixin.update_context_cudagraph = AscendCudaGraphMixin_update_context_cudagraph
|
|
132
|
+
|
|
133
|
+
|
|
134
|
+
def next_power_of_2(n: int):
    """Return the smallest power of 2 greater than or equal to n.

    Non-positive inputs collapse to 0, matching the original
    bit-smearing formulation.
    """
    if n <= 0:
        return 0
    return 1 << (n - 1).bit_length()
|
|
145
|
+
|
|
146
|
+
|
|
147
|
+
def get_ascend_compatible_size(n: int):
    """Round n up to a size the Ascend kernels accept.

    <= 16: next power of two; <= 256: next multiple of 16;
    otherwise: next multiple of 256.
    """
    if n <= 16:
        # next power of two (0 for non-positive n)
        return 0 if n <= 0 else 1 << (n - 1).bit_length()
    if n <= 256:
        # round up to a multiple of 16
        return (n + 15) & ~0xF
    # round up to a multiple of 256
    return (((n - 1) >> 8) + 1) << 8
|
|
156
|
+
|
|
157
|
+
|
|
158
|
+
@functools.lru_cache
|
|
159
|
+
def _get_capture_batch_size_impl(max_batches: int):
|
|
160
|
+
"""Capture batch size."""
|
|
161
|
+
ret = []
|
|
162
|
+
batch_size = 1
|
|
163
|
+
batch_step_1, batch_step_2 = 16, 256
|
|
164
|
+
# power of 2
|
|
165
|
+
while batch_size <= min(batch_step_1, max_batches):
|
|
166
|
+
ret.append(batch_size)
|
|
167
|
+
batch_size *= 2
|
|
168
|
+
|
|
169
|
+
# step 1
|
|
170
|
+
ret += list(range(batch_size, min(max_batches, batch_step_2) + 1, batch_step_1))
|
|
171
|
+
|
|
172
|
+
# step 2
|
|
173
|
+
ret += list(range(ret[-1] + batch_step_2, max_batches + 1, batch_step_2))
|
|
174
|
+
|
|
175
|
+
# ensure max_batches in ret
|
|
176
|
+
if max_batches != ret[-1]:
|
|
177
|
+
ret.append(max_batches)
|
|
178
|
+
|
|
179
|
+
return ret
|
|
180
|
+
|
|
181
|
+
|
|
182
|
+
def _false(*args, **kwargs):
|
|
183
|
+
"""Default value of not support cuda graph."""
|
|
184
|
+
return False
|
|
185
|
+
|
|
186
|
+
|
|
187
|
+
class AscendSingleGraphRunner:
    """Cuda single graph runner.

    Owns one captured NPU graph for a fixed (max_batches, max_tokens,
    is_decoding) configuration: ``capture`` records it once, ``forward``
    replays it with refreshed buffer contents.
    """

    def __init__(
        self,
        model: torch.nn.Module,
        max_batches: int,
        max_tokens: int,
        num_blocks: int,
        is_decoding: bool,
        pool: Any,
        model_config: ModelConfig,
        device: torch.device,
    ):
        self.model = model
        self.ctx_mgr = model.ctx_mgr
        self.model_config = model_config

        # shared metadata handed to make/fill/update buffer hooks
        self.meta = CudaGraphMeta(
            max_batchs=max_batches,
            max_tokens=max_tokens,
            num_blocks=num_blocks,
            is_decoding=is_decoding,
            device=device,
            input_buffers=dict(),
            output_buffers=dict(),
            vocab_size=self.model_config.vocab_size,
        )
        self.device = device
        self.max_batches = max_batches
        self.max_tokens = max_tokens
        self.num_blocks = num_blocks
        self.is_decoding = is_decoding
        self.pool = pool
        # set by capture(); None until a graph has been recorded
        self._graph: torch.npu.NPUGraph = None

    @record_function("capture_cudagraph")
    def capture(self, **kwargs):
        """Capture graph.

        Allocates the persistent buffers, runs one forward under NPU graph
        capture, and stores the graph plus its output buffer for replay.
        """
        logger.debug(f"Capturing graph with meta: {self.meta}")
        self.meta.input_buffers = self.model.make_buffers_cudagraph(self.meta, **kwargs)
        padded_kwargs = self.model.fill_buffers_cudagraph(self.meta, **kwargs)
        context = self.ctx_mgr.current_context()
        self.model.update_context_cudagraph(self.meta, context)
        current_stream = torch.cuda.current_stream()

        aclgraph = torch.npu.NPUGraph()
        # NOTE(review): ExitStack is entered but nothing is pushed onto it --
        # presumably a leftover; confirm it can be removed.
        with ExitStack() as stack:
            with torch.npu.graph(
                aclgraph,
                auto_dispatch_capture=True,
                pool=self.pool,
                stream=current_stream,
            ):
                output = self.model(**padded_kwargs)

        # the captured output tensor is reused by every later replay
        output_buffers = dict(logits=output)
        self.meta.output_buffers = output_buffers
        self._graph = aclgraph
        return output

    @record_function("forward_cudagraph")
    def forward(self, **kwargs):
        """forward.

        Refreshes the capture buffers with this step's inputs, replays the
        recorded graph, and returns a view of the captured logits buffer.
        """
        num_tokens = kwargs["input_ids"].size(-1)
        assert self._graph is not None
        self.model.fill_buffers_cudagraph(self.meta, **kwargs)
        context = self.ctx_mgr.current_context()
        self.model.update_context_cudagraph(self.meta, context)
        # NOTE(review): synchronize-before-replay and update-after-replay
        # follow torch_npu's cpu_update flow -- confirm against the torch_npu
        # NPUGraph.update documentation for the installed version.
        torch.npu.synchronize()
        self._graph.replay()
        self._graph.update(
            cpu_update_input=[
                {"actual_seq_lengths_kv": self.meta.input_buffers["kv_seqlens"]}
            ]
        )

        # slice off the padding added by fill_buffers_cudagraph
        output = self.meta.output_buffers["logits"][:, :num_tokens]
        return output

    def __del__(self):
        """Release the captured graph when the runner is collected."""
        del self._graph
|
|
270
|
+
|
|
271
|
+
|
|
272
|
+
class AscendGraphRunner(GraphRunner):
    """Cuda graph runner.

    Dispatches each forward either to eager execution or to a cached
    AscendSingleGraphRunner keyed by (padded token count, is_decoding,
    enable_microbatch), capturing new graphs on demand.
    """

    def __init__(
        self,
        model: torch.nn.Module,
        model_config: ModelConfig,
        cache_config: CacheConfig,
        backend_config: BackendConfig,
        device: torch.device,
    ):
        super().__init__(model, model_config, cache_config, backend_config, device)
        self.max_batches = cache_config.max_batches
        self.max_tokens = cache_config.max_prefill_token_num
        self.num_blocks = cache_config.num_gpu_blocks
        # callable: invoked per-forward with the model kwargs
        self.enable_graph = self.check_enable_graph()
        # one shared memory pool for all captured graphs
        self.graph_pool_handle = torch.cuda.graph_pool_handle()
        self._runner_map: Dict[Any, AscendSingleGraphRunner] = dict()
        self.has_try_compile_model: bool = False

    def check_enable_graph(self):
        """Return the predicate deciding whether a forward may use a graph."""
        if self.backend_config.eager_mode:
            return _false

        # models opt in by defining support_cuda_graph
        return getattr(self.model, "support_cuda_graph", _false)

    def _get_capture_tokens(self, batch_size: int):
        """Return the smallest capture bucket >= batch_size."""
        cap_sizes = self.get_capture_batch_sizes()
        for size in cap_sizes:
            if size >= batch_size:
                return size
        # NOTE(review): assert is stripped under `python -O`; a raise would be
        # safer if this path is reachable in production.
        assert False, f"Unsupported batch_size={batch_size}"

    def get_graph_key(
        self,
        input_ids: torch.Tensor,
        **kwargs,
    ):
        """Build the cache key identifying which captured graph to use."""
        context = self.ctx_mgr.current_context()
        is_decoding = context.is_decoding
        num_tokens = input_ids.numel()
        meta = self.get_meta()
        enable_microbatch = get_step_ctx_manager().current_context().enable_microbatch
        if meta.padding_batch_size is None:
            new_num_tokens = self._get_capture_tokens(num_tokens)
        else:
            # dp padding already fixed the effective batch size
            new_num_tokens = self._get_capture_tokens(meta.padding_batch_size)
        return (new_num_tokens, is_decoding, enable_microbatch)

    def __call__(self, **kwargs):
        """Run eagerly when graphs are disabled, else capture/replay."""
        enable_graph = self.enable_graph(**kwargs)

        if not enable_graph:
            with record_function("forward_eager"):
                ret = self.model(**kwargs)
            return ret

        graph_key = self.get_graph_key(**kwargs)
        max_tokens = graph_key[0]
        is_decoding = graph_key[1]
        if graph_key not in self._runner_map:
            # decoding runs one token per sequence, so batches == tokens
            max_batches = max_tokens if is_decoding else self.max_batches
            runner = AscendSingleGraphRunner(
                self.model,
                max_batches=max_batches,
                max_tokens=max_tokens,
                num_blocks=self.num_blocks,
                is_decoding=is_decoding,
                pool=self.graph_pool_handle,
                model_config=self.model_config,
                device=self.device,
            )
            runner.capture(**kwargs)
            self._runner_map[graph_key] = runner
        else:
            runner = self._runner_map[graph_key]
        output = runner.forward(**kwargs)
        return output

    @record_function("prepare_inputs_for_generation")
    def prepare_inputs_for_generation(
        self,
        past_key_values: List[List[torch.Tensor]],
        inputs_embeds: torch.Tensor = None,
        context: StepContext = None,
    ):
        """Delegate input preparation to the wrapped model."""
        return self.model.prepare_inputs_for_generation(
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            context=context,
        )

    def reset(self):
        """Remove all graphs to prevent hanging on exit."""
        self._runner_map.clear()

    def update_inputs(self, inputs):
        """Align dp tp_sizes with the padded capture bucket for decoding."""
        if self.backend_config.eager_mode:
            return inputs
        is_decoding = inputs.is_decoding
        dp_meta = inputs.dp_meta
        if is_decoding and dp_meta is not None:
            meta = self.get_meta()
            # NOTE(review): padding_batch_size may be None here, which would
            # fail inside _get_capture_tokens -- confirm callers set it first.
            padding_batch_size = meta.padding_batch_size
            tp_size = self._get_capture_tokens(padding_batch_size)
            dp_meta.tp_sizes = [tp_size] * len(dp_meta.tp_sizes)
        return inputs

    def get_capture_batch_sizes(self) -> List[int]:
        """Return the ladder of batch sizes eligible for capture."""
        return _get_capture_batch_size_impl(self.cache_config.max_batches)
|
|
389
|
+
|
|
390
|
+
|
|
391
|
+
# Replace lmdeploy's CUDA graph runner with the Ascend implementation so the
# framework instantiates AscendGraphRunner transparently.
graph_runner.CUDAGraphRunner = AscendGraphRunner
|
|
@@ -0,0 +1,133 @@
|
|
|
1
|
+
# Copyright (c) 2024, OpenMMLab and DeepLink. All rights reserved.
|
|
2
|
+
import torch
|
|
3
|
+
from torch import Tensor
|
|
4
|
+
|
|
5
|
+
from typing import Any, Dict, List
|
|
6
|
+
|
|
7
|
+
from lmdeploy.pytorch.models.utils.cudagraph import CudaGraphMeta
|
|
8
|
+
from lmdeploy.pytorch.models.utils.cudagraph import CudaGraphMixin, next_power_of_2
|
|
9
|
+
|
|
10
|
+
BuffType = Dict[str, Tensor]
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def CambCudaGraphMixin_make_buffers_cudagraph(
    self, graph_meta: CudaGraphMeta, *args, **kwargs
) -> BuffType:
    """Allocate the persistent capture-input buffers for the Camb backend."""
    batches = graph_meta.max_batchs
    tokens = graph_meta.max_tokens
    blocks = graph_meta.num_blocks
    dev = graph_meta.device

    buffers: BuffType = {
        "input_ids": torch.zeros(1, tokens, dtype=torch.int32, device=dev),
        "position_ids": torch.ones((1, tokens), dtype=torch.int32, device=dev),
        "block_offsets": torch.zeros(
            (batches, blocks), dtype=torch.int32, device=dev
        ),
        # q_start_loc stands in for cu_seqlens in attn_metadata, hence the +1
        "q_start_loc": torch.arange(batches + 1, dtype=torch.int32, device=dev),
        "q_seqlens": torch.ones(batches, dtype=torch.int32, device=dev),
        "kv_seqlens": torch.ones(batches, dtype=torch.int32, device=dev),
        # must start negative: without the sentinel, two batches with the same
        # input tokens can produce different answers
        "kv_start_indices": -torch.ones(batches, dtype=torch.int32, device=dev),
    }

    return buffers
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
def CambCudaGraphMixin_fill_buffers_cudagraph(
    self,
    graph_meta: CudaGraphMeta,
    input_ids: Tensor,
    position_ids: Tensor,
    past_key_values: List,
    attn_metadata: Any,
    inputs_embeds: Tensor,
    **kwargs
) -> Dict[str, Tensor]:
    """fill cudagraph buffers from forward inputs.

    Copies this step's real inputs into the persistent capture buffers and
    returns padded views over those buffers; mutates ``attn_metadata`` in
    place to point at the padded slices.
    """
    block_offsets: Tensor = attn_metadata.block_offsets
    q_start_loc: Tensor = attn_metadata.q_start_loc
    q_seqlens: Tensor = attn_metadata.q_seqlens
    kv_seqlens: Tensor = attn_metadata.kv_seqlens
    kv_start_indices: Tensor = attn_metadata.kv_start_indices

    input_buffers: BuffType = graph_meta.input_buffers

    batch_size, num_blocks = block_offsets.size()
    num_tokens = input_ids.size(-1)
    # fill buffer
    input_buffers["input_ids"][:, :num_tokens] = input_ids
    input_buffers["position_ids"][:, :num_tokens] = position_ids
    input_buffers["block_offsets"][:batch_size, :num_blocks] = block_offsets
    input_buffers["q_seqlens"][:batch_size] = q_seqlens
    input_buffers["kv_seqlens"][:batch_size] = kv_seqlens
    # q_start_loc mocks cu_seqlens, so it carries batch_size + 1 entries
    input_buffers["q_start_loc"][: batch_size + 1] = q_start_loc
    # kv_start_indices is indexed per token on this backend
    input_buffers["kv_start_indices"][:num_tokens] = kv_start_indices[:num_tokens]

    if inputs_embeds is not None:
        emb_size = inputs_embeds.size(-1)
        if "inputs_embeds" not in input_buffers:
            # lazily allocate the embeddings buffer on first use
            max_num_tokens = input_buffers["input_ids"].size(-1)
            input_buffers["inputs_embeds"] = inputs_embeds.new_zeros(
                1, max_num_tokens, emb_size
            )
        input_buffers["inputs_embeds"][:, :num_tokens] = inputs_embeds

    # below only used for capture graph
    # create inputs
    new_num_tokens = next_power_of_2(num_tokens)
    # assumes decoding, where batch size equals token count
    new_batch_size = new_num_tokens

    attn_metadata.block_offsets = input_buffers["block_offsets"][:new_batch_size]
    attn_metadata.q_start_loc = input_buffers["q_start_loc"][: new_batch_size + 1]
    attn_metadata.q_seqlens = input_buffers["q_seqlens"][:new_batch_size]
    attn_metadata.kv_seqlens = input_buffers["kv_seqlens"][:new_batch_size]

    attn_metadata.kv_start_indices = input_buffers["kv_start_indices"][:new_num_tokens]
    new_inputs = dict(
        past_key_values=past_key_values,
        attn_metadata=attn_metadata,
    )

    # is_decoding:
    # NOTE(review): input_ids is sliced to new_batch_size + 1 while
    # position_ids is sliced to new_batch_size -- the off-by-one asymmetry
    # looks unintentional; confirm against the consuming kernels.
    new_inputs["input_ids"] = input_buffers["input_ids"][:, : new_batch_size + 1]
    new_inputs["position_ids"] = input_buffers["position_ids"][:, :new_batch_size]

    if inputs_embeds is not None:
        new_inputs["inputs_embeds"] = input_buffers["inputs_embeds"][:, :new_batch_size]

    new_inputs.update(kwargs)

    return new_inputs
|
|
121
|
+
|
|
122
|
+
|
|
123
|
+
def CambCudaGraphMixin_update_context_cudagraph(self, graph_meta, context):
    """Point the step context at the persistent capture buffers."""
    buffers = graph_meta.input_buffers
    for field in ("q_seqlens", "kv_seqlens", "q_start_loc"):
        setattr(context, field, buffers[field])
|
|
129
|
+
|
|
130
|
+
|
|
131
|
+
# Monkey-patch lmdeploy's CudaGraphMixin with the Camb-specific buffer hooks
# defined above (module import is the activation point).
CudaGraphMixin.make_buffers_cudagraph = CambCudaGraphMixin_make_buffers_cudagraph
CudaGraphMixin.fill_buffers_cudagraph = CambCudaGraphMixin_fill_buffers_cudagraph
CudaGraphMixin.update_context_cudagraph = CambCudaGraphMixin_update_context_cudagraph
|
|
@@ -0,0 +1,128 @@
|
|
|
1
|
+
# Copyright (c) 2024, OpenMMLab and DeepLink. All rights reserved.
|
|
2
|
+
import torch
|
|
3
|
+
from torch import Tensor
|
|
4
|
+
|
|
5
|
+
from typing import Any, Dict, List
|
|
6
|
+
|
|
7
|
+
from lmdeploy.pytorch.models.utils.cudagraph import CudaGraphMeta
|
|
8
|
+
from lmdeploy.pytorch.models.utils.cudagraph import CudaGraphMixin, next_power_of_2
|
|
9
|
+
|
|
10
|
+
BuffType = Dict[str, Tensor]
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def MacaCudaGraphMixin_make_buffers_cudagraph(
    self, graph_meta: CudaGraphMeta, *args, **kwargs
) -> BuffType:
    """Allocate the persistent capture-input buffers for the Maca backend."""
    batches = graph_meta.max_batchs
    tokens = graph_meta.max_tokens
    blocks = graph_meta.num_blocks
    dev = graph_meta.device

    buffers: BuffType = {
        "input_ids": torch.empty(1, tokens, dtype=torch.int32, device=dev),
        "position_ids": torch.empty((1, tokens), dtype=torch.int32, device=dev),
        "block_offsets": torch.zeros(
            (batches, blocks), dtype=torch.int32, device=dev
        ),
        "q_seqlens": torch.ones(batches, dtype=torch.int32, device=dev),
        "kv_seqlens": torch.zeros(batches, dtype=torch.int32, device=dev),
        "q_start_loc": torch.arange(batches + 1, dtype=torch.int32, device=dev),
        # -1 sentinel so unused slots never alias a valid cache index
        "kv_start_indices": -torch.ones(
            (batches, 1), dtype=torch.int64, device=dev
        ),
    }
    return buffers
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
def MacaCudaGraphMixin_fill_buffers_cudagraph(
    self,
    graph_meta: CudaGraphMeta,
    input_ids: Tensor,
    position_ids: Tensor,
    past_key_values: List,
    attn_metadata: Any,
    inputs_embeds: Tensor,
    **kwargs
) -> Dict[str, Tensor]:
    """fill cudagraph buffers from forward inputs.

    Copies this step's real inputs into the persistent capture buffers and
    returns padded views over those buffers; mutates ``attn_metadata`` in
    place to point at the padded slices.
    """
    block_offsets: Tensor = attn_metadata.block_offsets
    q_start_loc: Tensor = attn_metadata.q_start_loc
    q_seqlens: Tensor = attn_metadata.q_seqlens
    kv_seqlens: Tensor = attn_metadata.kv_seqlens
    kv_start_indices: Tensor = attn_metadata.kv_start_indices

    input_buffers: BuffType = graph_meta.input_buffers

    batch_size, num_blocks = block_offsets.size()
    num_tokens = input_ids.size(-1)

    # fill buffer
    input_buffers["input_ids"][:, :num_tokens] = input_ids
    input_buffers["position_ids"][:, :num_tokens] = position_ids
    input_buffers["block_offsets"][:batch_size, :num_blocks] = block_offsets
    # q_start_loc carries cu_seqlens-style prefix sums: batch_size + 1 entries
    input_buffers["q_start_loc"][: batch_size + 1] = q_start_loc
    input_buffers["q_seqlens"][:batch_size] = q_seqlens
    input_buffers["kv_seqlens"][:batch_size] = kv_seqlens
    input_buffers["kv_start_indices"][:batch_size] = kv_start_indices

    if inputs_embeds is not None:
        emb_size = inputs_embeds.size(-1)
        if "inputs_embeds" not in input_buffers:
            # lazily allocate the embeddings buffer on first use
            max_num_tokens = input_buffers["input_ids"].size(-1)
            input_buffers["inputs_embeds"] = inputs_embeds.new_zeros(
                1, max_num_tokens, emb_size
            )
        input_buffers["inputs_embeds"][:, :num_tokens] = inputs_embeds
    # create inputs
    # pad the batch to the next power of two so one captured graph serves a
    # range of real batch sizes
    new_batch_size = next_power_of_2(batch_size)

    attn_metadata.block_offsets = input_buffers["block_offsets"][:new_batch_size]
    attn_metadata.q_start_loc = input_buffers["q_start_loc"][: new_batch_size + 1]
    attn_metadata.q_seqlens = input_buffers["q_seqlens"][:new_batch_size]
    attn_metadata.kv_seqlens = input_buffers["kv_seqlens"][:new_batch_size]
    attn_metadata.kv_start_indices = input_buffers["kv_start_indices"][:new_batch_size]

    new_inputs = dict(
        past_key_values=past_key_values,
        attn_metadata=attn_metadata,
    )

    # NOTE(review): the token dimension is sliced by new_batch_size, which
    # assumes decoding (one token per sequence) -- confirm for prefill paths.
    new_inputs["input_ids"] = input_buffers["input_ids"][:, :new_batch_size]
    new_inputs["position_ids"] = input_buffers["position_ids"][:, :new_batch_size]

    if inputs_embeds is not None:
        new_inputs["inputs_embeds"] = input_buffers["inputs_embeds"][:, :new_batch_size]

    new_inputs.update(kwargs)

    return new_inputs
|
|
114
|
+
|
|
115
|
+
|
|
116
|
+
def MacaCudaGraphMixin_update_context_cudagraph(self, graph_meta, context):
    """Point the step context at the persistent capture buffers."""
    buffers = graph_meta.input_buffers
    for field in (
        "block_offsets",
        "q_seqlens",
        "kv_seqlens",
        "q_start_loc",
        "kv_start_indices",
    ):
        setattr(context, field, buffers[field])
|
|
124
|
+
|
|
125
|
+
|
|
126
|
+
# Monkey-patch lmdeploy's CudaGraphMixin with the Maca-specific buffer hooks
# defined above (module import is the activation point).
CudaGraphMixin.make_buffers_cudagraph = MacaCudaGraphMixin_make_buffers_cudagraph
CudaGraphMixin.fill_buffers_cudagraph = MacaCudaGraphMixin_fill_buffers_cudagraph
CudaGraphMixin.update_context_cudagraph = MacaCudaGraphMixin_update_context_cudagraph
|