dlinfer-ascend 0.2.3.post2__cp311-cp311-manylinux2014_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (70) hide show
  1. dlinfer/__init__.py +5 -0
  2. dlinfer/framework/__init__.py +1 -0
  3. dlinfer/framework/lmdeploy_ext/__init__.py +6 -0
  4. dlinfer/framework/lmdeploy_ext/cudagraph/__init__.py +20 -0
  5. dlinfer/framework/lmdeploy_ext/cudagraph/ascend_cudagraph.py +391 -0
  6. dlinfer/framework/lmdeploy_ext/cudagraph/camb_cudagraph.py +133 -0
  7. dlinfer/framework/lmdeploy_ext/cudagraph/maca_cudagraph.py +128 -0
  8. dlinfer/framework/lmdeploy_ext/cudagraph/ppu_cudagraph.py +131 -0
  9. dlinfer/framework/lmdeploy_ext/device/__init__.py +79 -0
  10. dlinfer/framework/lmdeploy_ext/device/ascend.py +205 -0
  11. dlinfer/framework/lmdeploy_ext/device/camb.py +24 -0
  12. dlinfer/framework/lmdeploy_ext/quants/__init__.py +20 -0
  13. dlinfer/framework/lmdeploy_ext/quants/ascend_awq.py +248 -0
  14. dlinfer/framework/torch_npu_ext/__init__.py +12 -0
  15. dlinfer/framework/torch_npu_ext/aclgraph.py +59 -0
  16. dlinfer/framework/transformers_ext/__init__.py +17 -0
  17. dlinfer/framework/transformers_ext/cogvlm.py +25 -0
  18. dlinfer/framework/transformers_ext/internlm2.py +242 -0
  19. dlinfer/framework/transformers_ext/internvl.py +33 -0
  20. dlinfer/framework/transformers_ext/patch.py +33 -0
  21. dlinfer/graph/__init__.py +5 -0
  22. dlinfer/graph/custom_op.py +147 -0
  23. dlinfer/graph/dicp/__init__.py +0 -0
  24. dlinfer/graph/dicp/dynamo_bridge/__init__.py +0 -0
  25. dlinfer/graph/dicp/dynamo_bridge/compile.py +42 -0
  26. dlinfer/graph/dicp/dynamo_bridge/compile_fx.py +305 -0
  27. dlinfer/graph/dicp/dynamo_bridge/conversion.py +75 -0
  28. dlinfer/graph/dicp/dynamo_bridge/decompositions.py +38 -0
  29. dlinfer/graph/dicp/dynamo_bridge/graph.py +141 -0
  30. dlinfer/graph/dicp/dynamo_bridge/op_transformer.py +293 -0
  31. dlinfer/graph/dicp/dynamo_bridge/operator.py +87 -0
  32. dlinfer/graph/dicp/dynamo_bridge/pt_patch.py +320 -0
  33. dlinfer/graph/dicp/dynamo_bridge/torch_version.py +38 -0
  34. dlinfer/graph/dicp/dynamo_bridge/utils.py +158 -0
  35. dlinfer/graph/dicp/vendor/AtbGraph/__init__.py +13 -0
  36. dlinfer/graph/dicp/vendor/AtbGraph/atb_op.py +853 -0
  37. dlinfer/graph/dicp/vendor/AtbGraph/codegen/__init__.py +0 -0
  38. dlinfer/graph/dicp/vendor/AtbGraph/codegen/atb.py +318 -0
  39. dlinfer/graph/dicp/vendor/AtbGraph/codegen/atb_graph.py +768 -0
  40. dlinfer/graph/dicp/vendor/AtbGraph/codegen/atb_infer_param.py +763 -0
  41. dlinfer/graph/dicp/vendor/AtbGraph/codegen/atb_op.py +1279 -0
  42. dlinfer/graph/dicp/vendor/AtbGraph/codegen/libdicp_model.so +0 -0
  43. dlinfer/graph/dicp/vendor/AtbGraph/codegen/load_and_run.py +21 -0
  44. dlinfer/graph/dicp/vendor/AtbGraph/codegen/utils.py +178 -0
  45. dlinfer/graph/dicp/vendor/AtbGraph/compile_job.py +52 -0
  46. dlinfer/graph/dicp/vendor/AtbGraph/config.py +36 -0
  47. dlinfer/graph/dicp/vendor/AtbGraph/conversion.py +908 -0
  48. dlinfer/graph/dicp/vendor/AtbGraph/ext_ops.py +95 -0
  49. dlinfer/graph/dicp/vendor/AtbGraph/infer_res_utils.py +200 -0
  50. dlinfer/graph/dicp/vendor/AtbGraph/opset_convert.py +70 -0
  51. dlinfer/graph/dicp/vendor/AtbGraph/pattern_replacement.py +152 -0
  52. dlinfer/graph/dicp/vendor/__init__.py +0 -0
  53. dlinfer/ops/__init__.py +2 -0
  54. dlinfer/ops/llm.py +879 -0
  55. dlinfer/utils/__init__.py +1 -0
  56. dlinfer/utils/config.py +18 -0
  57. dlinfer/utils/registry.py +8 -0
  58. dlinfer/utils/type_annotation.py +3 -0
  59. dlinfer/vendor/__init__.py +33 -0
  60. dlinfer/vendor/ascend/__init__.py +5 -0
  61. dlinfer/vendor/ascend/pytorch_patch.py +55 -0
  62. dlinfer/vendor/ascend/torch_npu_ops.py +601 -0
  63. dlinfer/vendor/ascend/utils.py +20 -0
  64. dlinfer/vendor/vendor.yaml +2 -0
  65. dlinfer_ascend-0.2.3.post2.dist-info/LICENSE +28 -0
  66. dlinfer_ascend-0.2.3.post2.dist-info/METADATA +213 -0
  67. dlinfer_ascend-0.2.3.post2.dist-info/RECORD +70 -0
  68. dlinfer_ascend-0.2.3.post2.dist-info/WHEEL +5 -0
  69. dlinfer_ascend-0.2.3.post2.dist-info/entry_points.txt +2 -0
  70. dlinfer_ascend-0.2.3.post2.dist-info/top_level.txt +1 -0
dlinfer/__init__.py ADDED
@@ -0,0 +1,5 @@
1
+ # Copyright (c) 2024, DeepLink. All rights reserved.
2
+ import dlinfer.vendor as vendor
3
+
4
+ vendor.vendor_torch_init()
5
+ __version__ = "0.2.3.post2"
@@ -0,0 +1 @@
1
+ # Copyright (c) 2024, DeepLink. All rights reserved.
@@ -0,0 +1,6 @@
1
+ # Copyright (c) 2024, DeepLink. All rights reserved.
2
+ import dlinfer.framework.transformers_ext
3
+ import dlinfer.framework.torch_npu_ext
4
+ from . import quants
5
+ from . import cudagraph
6
+ from . import device
@@ -0,0 +1,20 @@
1
+ # Copyright (c) 2024, DeepLink. All rights reserved.
2
+ import importlib
3
+ from functools import lru_cache
4
+ from dlinfer.vendor import vendor_name
5
+
6
+
7
# Vendors that ship a cudagraph backend implementation in this package.
graph_vendor = ["ascend", "maca", "camb", "ppu"]


@lru_cache(1)
def import_vendor_module(vendor_name_str):
    """Import the `<vendor>_cudagraph` submodule once for a known vendor."""
    if vendor_name_str not in graph_vendor:
        return
    importlib.import_module(f".{vendor_name_str}_cudagraph", __package__)


def vendor_graph_init():
    """Load the cudagraph patches for the active vendor."""
    import_vendor_module(vendor_name)


vendor_graph_init()
@@ -0,0 +1,391 @@
1
+ # Copyright (c) 2024, OpenMMLab and DeepLink. All rights reserved.
2
+ # this file implements the cudagraph for ascend backend.
3
+ import functools
4
+ from typing import Any, Dict, List
5
+ from contextlib import ExitStack
6
+ import torch
7
+ from torch import Tensor
8
+ from torch.profiler import record_function
9
+
10
+ from lmdeploy.pytorch.models.utils.cudagraph import CudaGraphMeta
11
+ from lmdeploy.pytorch.models.utils.cudagraph import CudaGraphMixin
12
+ from lmdeploy.pytorch.config import BackendConfig, CacheConfig, ModelConfig
13
+ from lmdeploy.pytorch.model_inputs import StepContext, get_step_ctx_manager
14
+ from lmdeploy.pytorch.backends.graph_runner import GraphRunner
15
+ from lmdeploy.pytorch.backends.cuda import graph_runner
16
+
17
+ from lmdeploy.utils import get_logger
18
+
19
+ logger = get_logger("dlinfer")
20
+ BuffType = Dict[str, Tensor]
21
+
22
+
23
# AscendCudaGraphMixin methods for cudagraph buffer management.
def AscendCudaGraphMixin_make_buffers_cudagraph(
    self, graph_meta: CudaGraphMeta, *args, **kwargs
) -> BuffType:
    """Allocate the persistent input buffers used for graph capture/replay.

    Buffers are sized from the capture metadata so each replay can copy a
    smaller live input into fixed-size storage.
    """
    dev = graph_meta.device
    n_batch = graph_meta.max_batchs
    n_tok = graph_meta.max_tokens
    n_blk = graph_meta.num_blocks

    bufs: BuffType = {
        "input_ids": torch.empty(1, n_tok, dtype=torch.int32, device=dev),
        "position_ids": torch.empty((1, n_tok), dtype=torch.int32, device=dev),
        "block_offsets": torch.zeros(
            (n_batch, n_blk), dtype=torch.int32, device=dev
        ),
        "q_seqlens": torch.ones(n_batch, dtype=torch.int32, device=dev),
        # plain python list: kv lengths are pushed to the graph through a
        # cpu-side update after replay, not through device memory
        "kv_seqlens": [0] * n_batch,
        "q_start_loc": torch.arange(n_batch + 1, dtype=torch.int32, device=dev),
        # start at -1 so unused slots never alias a valid cache index
        "kv_start_indices": -torch.ones(
            (n_batch), dtype=torch.int64, device=dev
        ),
    }
    return bufs
59
+
60
+
61
def AscendCudaGraphMixin_fill_buffers_cudagraph(
    self,
    graph_meta: CudaGraphMeta,
    input_ids: Tensor,
    position_ids: Tensor,
    past_key_values: List,
    attn_metadata: Any,
    inputs_embeds: Tensor,
    **kwargs,
) -> Dict[str, Tensor]:
    """Copy this step's inputs into the capture buffers and build padded kwargs."""
    bufs: BuffType = graph_meta.input_buffers

    block_offsets: Tensor = attn_metadata.block_offsets
    batch_size, num_blocks = block_offsets.size()
    num_tokens = input_ids.size(-1)

    # copy the live inputs into the persistent buffers
    bufs["input_ids"][:, :num_tokens] = input_ids
    bufs["position_ids"][:, :num_tokens] = position_ids
    bufs["block_offsets"][:batch_size, :num_blocks] = block_offsets
    bufs["kv_seqlens"][:batch_size] = attn_metadata.kv_seqlens
    bufs["kv_start_indices"][:batch_size] = attn_metadata.kv_start_indices

    if inputs_embeds is not None:
        if "inputs_embeds" not in bufs:
            # lazily sized from the first embedding seen
            bufs["inputs_embeds"] = inputs_embeds.new_zeros(
                1, bufs["input_ids"].size(-1), inputs_embeds.size(-1)
            )
        bufs["inputs_embeds"][:, :num_tokens] = inputs_embeds

    # pad the batch up to an ascend-friendly capture size
    new_batch_size = get_ascend_compatible_size(batch_size)

    attn_metadata.block_offsets = bufs["block_offsets"][:new_batch_size]
    attn_metadata.kv_seqlens = bufs["kv_seqlens"][:new_batch_size]
    attn_metadata.kv_start_indices = bufs["kv_start_indices"][:new_batch_size]

    new_inputs = dict(
        past_key_values=past_key_values,
        attn_metadata=attn_metadata,
        # decoding emits one token per sequence, so the token dim is
        # sliced with the padded batch size as well
        input_ids=bufs["input_ids"][:, :new_batch_size],
        position_ids=bufs["position_ids"][:, :new_batch_size],
    )
    if inputs_embeds is not None:
        new_inputs["inputs_embeds"] = bufs["inputs_embeds"][:, :new_batch_size]

    new_inputs.update(kwargs)
    return new_inputs
117
+
118
+
119
def AscendCudaGraphMixin_update_context_cudagraph(self, graph_meta, context):
    """Point the step context at the persistent capture buffers."""
    bufs = graph_meta.input_buffers
    for field in (
        "block_offsets",
        "q_seqlens",
        "kv_seqlens",
        "q_start_loc",
        "kv_start_indices",
    ):
        setattr(context, field, bufs[field])


# Patch the lmdeploy mixin so all models use the ascend buffer logic.
CudaGraphMixin.make_buffers_cudagraph = AscendCudaGraphMixin_make_buffers_cudagraph
CudaGraphMixin.fill_buffers_cudagraph = AscendCudaGraphMixin_fill_buffers_cudagraph
CudaGraphMixin.update_context_cudagraph = AscendCudaGraphMixin_update_context_cudagraph
132
+
133
+
134
def next_power_of_2(n: int):
    """Return the smallest power of 2 greater than or equal to `n`.

    Returns 0 for n <= 0, matching the previous bit-smearing
    implementation. Uses `int.bit_length`, so it is correct for integers
    of any magnitude — the replaced `|= n >> k` chain stopped at 32 and
    silently produced wrong answers above 2**64.
    """
    return (1 << (n - 1).bit_length()) if n > 0 else 0
145
+
146
+
147
def get_ascend_compatible_size(n: int):
    """Round `n` up to a batch size the ascend graph runtime accepts.

    Tiers: n <= 16 -> next power of two; n <= 256 -> next multiple of 16;
    otherwise -> next multiple of 256. Now self-contained (no call to the
    sibling `next_power_of_2` helper) and correct for arbitrarily large
    `n`; results are unchanged for every previous input.
    """
    if n <= 16:
        # smallest power of two >= n (0 for non-positive n, as before)
        return (1 << (n - 1).bit_length()) if n > 0 else 0
    if n <= 256:
        return (n + 15) & ~0xF
    return ((n + 255) >> 8) << 8
156
+
157
+
158
+ @functools.lru_cache
159
+ def _get_capture_batch_size_impl(max_batches: int):
160
+ """Capture batch size."""
161
+ ret = []
162
+ batch_size = 1
163
+ batch_step_1, batch_step_2 = 16, 256
164
+ # power of 2
165
+ while batch_size <= min(batch_step_1, max_batches):
166
+ ret.append(batch_size)
167
+ batch_size *= 2
168
+
169
+ # step 1
170
+ ret += list(range(batch_size, min(max_batches, batch_step_2) + 1, batch_step_1))
171
+
172
+ # step 2
173
+ ret += list(range(ret[-1] + batch_step_2, max_batches + 1, batch_step_2))
174
+
175
+ # ensure max_batches in ret
176
+ if max_batches != ret[-1]:
177
+ ret.append(max_batches)
178
+
179
+ return ret
180
+
181
+
182
+ def _false(*args, **kwargs):
183
+ """Default value of not support cuda graph."""
184
+ return False
185
+
186
+
187
class AscendSingleGraphRunner:
    """Captures and replays one NPU graph for a fixed (batch, token) shape."""

    def __init__(
        self,
        model: torch.nn.Module,
        max_batches: int,
        max_tokens: int,
        num_blocks: int,
        is_decoding: bool,
        pool: Any,
        model_config: ModelConfig,
        device: torch.device,
    ):
        self.model = model
        self.ctx_mgr = model.ctx_mgr
        self.model_config = model_config

        # capture metadata; input/output buffers are populated in capture()
        self.meta = CudaGraphMeta(
            max_batchs=max_batches,
            max_tokens=max_tokens,
            num_blocks=num_blocks,
            is_decoding=is_decoding,
            device=device,
            input_buffers=dict(),
            output_buffers=dict(),
            vocab_size=self.model_config.vocab_size,
        )
        self.device = device
        self.max_batches = max_batches
        self.max_tokens = max_tokens
        self.num_blocks = num_blocks
        self.is_decoding = is_decoding
        self.pool = pool
        self._graph: torch.npu.NPUGraph = None

    @record_function("capture_cudagraph")
    def capture(self, **kwargs):
        """Run one forward pass under graph capture and keep the output buffer."""
        logger.debug(f"Capturing graph with meta: {self.meta}")
        self.meta.input_buffers = self.model.make_buffers_cudagraph(self.meta, **kwargs)
        padded_kwargs = self.model.fill_buffers_cudagraph(self.meta, **kwargs)
        context = self.ctx_mgr.current_context()
        self.model.update_context_cudagraph(self.meta, context)
        current_stream = torch.cuda.current_stream()

        aclgraph = torch.npu.NPUGraph()
        with ExitStack() as stack:
            with torch.npu.graph(
                aclgraph,
                auto_dispatch_capture=True,
                pool=self.pool,
                stream=current_stream,
            ):
                output = self.model(**padded_kwargs)

        self.meta.output_buffers = dict(logits=output)
        self._graph = aclgraph
        return output

    @record_function("forward_cudagraph")
    def forward(self, **kwargs):
        """Refill the buffers and replay the captured graph."""
        num_tokens = kwargs["input_ids"].size(-1)
        assert self._graph is not None
        self.model.fill_buffers_cudagraph(self.meta, **kwargs)
        context = self.ctx_mgr.current_context()
        self.model.update_context_cudagraph(self.meta, context)
        torch.npu.synchronize()
        self._graph.replay()
        # kv lengths are a cpu-side list; push them into the graph via the
        # cpu-update path rather than a device copy
        self._graph.update(
            cpu_update_input=[
                {"actual_seq_lengths_kv": self.meta.input_buffers["kv_seqlens"]}
            ]
        )

        return self.meta.output_buffers["logits"][:, :num_tokens]

    def __del__(self):
        """Release the captured graph."""
        del self._graph
270
+
271
+
272
class AscendGraphRunner(GraphRunner):
    """Graph runner that caches one AscendSingleGraphRunner per input shape."""

    def __init__(
        self,
        model: torch.nn.Module,
        model_config: ModelConfig,
        cache_config: CacheConfig,
        backend_config: BackendConfig,
        device: torch.device,
    ):
        super().__init__(model, model_config, cache_config, backend_config, device)
        self.max_batches = cache_config.max_batches
        self.max_tokens = cache_config.max_prefill_token_num
        self.num_blocks = cache_config.num_gpu_blocks
        # callable deciding per-forward whether graph mode may be used
        self.enable_graph = self.check_enable_graph()
        self.graph_pool_handle = torch.cuda.graph_pool_handle()
        self._runner_map: Dict[Any, AscendSingleGraphRunner] = dict()
        self.has_try_compile_model: bool = False

    def check_enable_graph(self):
        """Return the predicate used to gate graph capture/replay."""
        if self.backend_config.eager_mode:
            return _false
        return getattr(self.model, "support_cuda_graph", _false)

    def _get_capture_tokens(self, batch_size: int):
        """Return the smallest captured size that can hold `batch_size`.

        Raises:
            ValueError: if `batch_size` exceeds every capture size.
        """
        for size in self.get_capture_batch_sizes():
            if size >= batch_size:
                return size
        # raise instead of `assert False`: asserts are stripped under -O
        raise ValueError(f"Unsupported batch_size={batch_size}")

    def get_graph_key(
        self,
        input_ids: torch.Tensor,
        **kwargs,
    ):
        """Build the runner cache key: (padded tokens, is_decoding, microbatch)."""
        context = self.ctx_mgr.current_context()
        is_decoding = context.is_decoding
        num_tokens = input_ids.numel()
        meta = self.get_meta()
        enable_microbatch = get_step_ctx_manager().current_context().enable_microbatch
        if meta.padding_batch_size is None:
            new_num_tokens = self._get_capture_tokens(num_tokens)
        else:
            new_num_tokens = self._get_capture_tokens(meta.padding_batch_size)
        return (new_num_tokens, is_decoding, enable_microbatch)

    def __call__(self, **kwargs):
        """Run eagerly, or capture/replay the graph matching this input shape."""
        enable_graph = self.enable_graph(**kwargs)

        if not enable_graph:
            with record_function("forward_eager"):
                return self.model(**kwargs)

        graph_key = self.get_graph_key(**kwargs)
        max_tokens = graph_key[0]
        is_decoding = graph_key[1]
        runner = self._runner_map.get(graph_key)
        if runner is None:
            # decoding emits one token per sequence: batch dim == token dim
            max_batches = max_tokens if is_decoding else self.max_batches
            runner = AscendSingleGraphRunner(
                self.model,
                max_batches=max_batches,
                max_tokens=max_tokens,
                num_blocks=self.num_blocks,
                is_decoding=is_decoding,
                pool=self.graph_pool_handle,
                model_config=self.model_config,
                device=self.device,
            )
            runner.capture(**kwargs)
            self._runner_map[graph_key] = runner
        return runner.forward(**kwargs)

    @record_function("prepare_inputs_for_generation")
    def prepare_inputs_for_generation(
        self,
        past_key_values: List[List[torch.Tensor]],
        inputs_embeds: torch.Tensor = None,
        context: StepContext = None,
    ):
        """Delegate input preparation to the wrapped model."""
        return self.model.prepare_inputs_for_generation(
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            context=context,
        )

    def reset(self):
        """Remove all graphs to prevent hanging on exit."""
        self._runner_map.clear()

    def update_inputs(self, inputs):
        """Align dp `tp_sizes` with the padded decode batch size."""
        if self.backend_config.eager_mode:
            return inputs
        is_decoding = inputs.is_decoding
        dp_meta = inputs.dp_meta
        if is_decoding and dp_meta is not None:
            meta = self.get_meta()
            # assumes padding_batch_size is set whenever dp_meta is present
            # during decoding — TODO(review): confirm upstream guarantees it
            padding_batch_size = meta.padding_batch_size
            tp_size = self._get_capture_tokens(padding_batch_size)
            dp_meta.tp_sizes = [tp_size] * len(dp_meta.tp_sizes)
        return inputs

    def get_capture_batch_sizes(self) -> List[int]:
        """All batch sizes for which graphs may be captured."""
        return _get_capture_batch_size_impl(self.cache_config.max_batches)


# Replace lmdeploy's CUDA graph runner with the ascend implementation.
graph_runner.CUDAGraphRunner = AscendGraphRunner
@@ -0,0 +1,133 @@
1
+ # Copyright (c) 2024, OpenMMLab and DeepLink. All rights reserved.
2
+ import torch
3
+ from torch import Tensor
4
+
5
+ from typing import Any, Dict, List
6
+
7
+ from lmdeploy.pytorch.models.utils.cudagraph import CudaGraphMeta
8
+ from lmdeploy.pytorch.models.utils.cudagraph import CudaGraphMixin, next_power_of_2
9
+
10
+ BuffType = Dict[str, Tensor]
11
+
12
+
13
def CambCudaGraphMixin_make_buffers_cudagraph(
    self, graph_meta: CudaGraphMeta, *args, **kwargs
) -> BuffType:
    """Allocate fixed-size graph input buffers for the camb backend."""
    dev = graph_meta.device
    n_batch = graph_meta.max_batchs
    n_tok = graph_meta.max_tokens
    n_blk = graph_meta.num_blocks

    bufs: BuffType = dict(
        input_ids=torch.zeros(1, n_tok, dtype=torch.int32, device=dev),
        position_ids=torch.ones((1, n_tok), dtype=torch.int32, device=dev),
        block_offsets=torch.zeros(
            (n_batch, n_blk), dtype=torch.int32, device=dev
        ),
        # in attn_metadata, q_start_loc is mocked as cu_seqlens, hence
        # the extra slot (max_batches + 1)
        q_start_loc=torch.arange(n_batch + 1, dtype=torch.int32, device=dev),
        q_seqlens=torch.ones(n_batch, dtype=torch.int32, device=dev),
        kv_seqlens=torch.ones(n_batch, dtype=torch.int32, device=dev),
        # critical to start negative: stale indices left over from a
        # previous batch would otherwise alias valid cache slots and two
        # batches with the same tokens could produce different answers
        kv_start_indices=-torch.ones(n_batch, dtype=torch.int32, device=dev),
    )
    return bufs
54
+
55
+
56
def CambCudaGraphMixin_fill_buffers_cudagraph(
    self,
    graph_meta: CudaGraphMeta,
    input_ids: Tensor,
    position_ids: Tensor,
    past_key_values: List,
    attn_metadata: Any,
    inputs_embeds: Tensor,
    **kwargs
) -> Dict[str, Tensor]:
    """Copy step inputs into the camb capture buffers and return padded kwargs."""
    bufs: BuffType = graph_meta.input_buffers

    block_offsets: Tensor = attn_metadata.block_offsets
    batch_size, num_blocks = block_offsets.size()
    num_tokens = input_ids.size(-1)

    # copy the live inputs into the persistent buffers
    bufs["input_ids"][:, :num_tokens] = input_ids
    bufs["position_ids"][:, :num_tokens] = position_ids
    bufs["block_offsets"][:batch_size, :num_blocks] = block_offsets
    bufs["q_seqlens"][:batch_size] = attn_metadata.q_seqlens
    bufs["kv_seqlens"][:batch_size] = attn_metadata.kv_seqlens
    bufs["q_start_loc"][: batch_size + 1] = attn_metadata.q_start_loc
    bufs["kv_start_indices"][:num_tokens] = attn_metadata.kv_start_indices[:num_tokens]

    if inputs_embeds is not None:
        if "inputs_embeds" not in bufs:
            bufs["inputs_embeds"] = inputs_embeds.new_zeros(
                1, bufs["input_ids"].size(-1), inputs_embeds.size(-1)
            )
        bufs["inputs_embeds"][:, :num_tokens] = inputs_embeds

    # below only used for capture graph: pad token count to a power of two
    new_num_tokens = next_power_of_2(num_tokens)
    new_batch_size = new_num_tokens

    attn_metadata.block_offsets = bufs["block_offsets"][:new_batch_size]
    attn_metadata.q_start_loc = bufs["q_start_loc"][: new_batch_size + 1]
    attn_metadata.q_seqlens = bufs["q_seqlens"][:new_batch_size]
    attn_metadata.kv_seqlens = bufs["kv_seqlens"][:new_batch_size]
    attn_metadata.kv_start_indices = bufs["kv_start_indices"][:new_num_tokens]

    new_inputs = dict(
        past_key_values=past_key_values,
        attn_metadata=attn_metadata,
        # is_decoding: NOTE(review) input_ids keeps one extra slot
        # (new_batch_size + 1) while position_ids does not — preserved
        # as-is, but worth confirming against the camb kernel contract
        input_ids=bufs["input_ids"][:, : new_batch_size + 1],
        position_ids=bufs["position_ids"][:, :new_batch_size],
    )
    if inputs_embeds is not None:
        new_inputs["inputs_embeds"] = bufs["inputs_embeds"][:, :new_batch_size]

    new_inputs.update(kwargs)
    return new_inputs
121
+
122
+
123
def CambCudaGraphMixin_update_context_cudagraph(self, graph_meta, context):
    """Point the step context at the persistent camb capture buffers."""
    bufs = graph_meta.input_buffers
    context.q_seqlens = bufs["q_seqlens"]
    context.kv_seqlens = bufs["kv_seqlens"]
    context.q_start_loc = bufs["q_start_loc"]


# Patch the lmdeploy mixin so all models use the camb buffer logic.
CudaGraphMixin.make_buffers_cudagraph = CambCudaGraphMixin_make_buffers_cudagraph
CudaGraphMixin.fill_buffers_cudagraph = CambCudaGraphMixin_fill_buffers_cudagraph
CudaGraphMixin.update_context_cudagraph = CambCudaGraphMixin_update_context_cudagraph
@@ -0,0 +1,128 @@
1
+ # Copyright (c) 2024, OpenMMLab and DeepLink. All rights reserved.
2
+ import torch
3
+ from torch import Tensor
4
+
5
+ from typing import Any, Dict, List
6
+
7
+ from lmdeploy.pytorch.models.utils.cudagraph import CudaGraphMeta
8
+ from lmdeploy.pytorch.models.utils.cudagraph import CudaGraphMixin, next_power_of_2
9
+
10
+ BuffType = Dict[str, Tensor]
11
+
12
+
13
def MacaCudaGraphMixin_make_buffers_cudagraph(
    self, graph_meta: CudaGraphMeta, *args, **kwargs
) -> BuffType:
    """Allocate fixed-size graph input buffers for the maca backend."""
    dev = graph_meta.device
    n_batch = graph_meta.max_batchs
    n_tok = graph_meta.max_tokens
    n_blk = graph_meta.num_blocks

    bufs: BuffType = dict(
        input_ids=torch.empty(1, n_tok, dtype=torch.int32, device=dev),
        position_ids=torch.empty((1, n_tok), dtype=torch.int32, device=dev),
        block_offsets=torch.zeros(
            (n_batch, n_blk), dtype=torch.int32, device=dev
        ),
        q_seqlens=torch.ones(n_batch, dtype=torch.int32, device=dev),
        kv_seqlens=torch.zeros(n_batch, dtype=torch.int32, device=dev),
        q_start_loc=torch.arange(n_batch + 1, dtype=torch.int32, device=dev),
        # start at -1 so unused slots never point at a valid cache entry
        kv_start_indices=-torch.ones(
            (n_batch, 1), dtype=torch.int64, device=dev
        ),
    )
    return bufs
50
+
51
+
52
def MacaCudaGraphMixin_fill_buffers_cudagraph(
    self,
    graph_meta: CudaGraphMeta,
    input_ids: Tensor,
    position_ids: Tensor,
    past_key_values: List,
    attn_metadata: Any,
    inputs_embeds: Tensor,
    **kwargs
) -> Dict[str, Tensor]:
    """Copy step inputs into the maca capture buffers and return padded kwargs."""
    bufs: BuffType = graph_meta.input_buffers

    block_offsets: Tensor = attn_metadata.block_offsets
    batch_size, num_blocks = block_offsets.size()
    num_tokens = input_ids.size(-1)

    # copy the live inputs into the persistent buffers
    bufs["input_ids"][:, :num_tokens] = input_ids
    bufs["position_ids"][:, :num_tokens] = position_ids
    bufs["block_offsets"][:batch_size, :num_blocks] = block_offsets
    bufs["q_start_loc"][: batch_size + 1] = attn_metadata.q_start_loc
    bufs["q_seqlens"][:batch_size] = attn_metadata.q_seqlens
    bufs["kv_seqlens"][:batch_size] = attn_metadata.kv_seqlens
    bufs["kv_start_indices"][:batch_size] = attn_metadata.kv_start_indices

    if inputs_embeds is not None:
        if "inputs_embeds" not in bufs:
            bufs["inputs_embeds"] = inputs_embeds.new_zeros(
                1, bufs["input_ids"].size(-1), inputs_embeds.size(-1)
            )
        bufs["inputs_embeds"][:, :num_tokens] = inputs_embeds

    # pad the batch up to the next power of two for capture
    new_batch_size = next_power_of_2(batch_size)

    attn_metadata.block_offsets = bufs["block_offsets"][:new_batch_size]
    attn_metadata.q_start_loc = bufs["q_start_loc"][: new_batch_size + 1]
    attn_metadata.q_seqlens = bufs["q_seqlens"][:new_batch_size]
    attn_metadata.kv_seqlens = bufs["kv_seqlens"][:new_batch_size]
    attn_metadata.kv_start_indices = bufs["kv_start_indices"][:new_batch_size]

    new_inputs = dict(
        past_key_values=past_key_values,
        attn_metadata=attn_metadata,
        input_ids=bufs["input_ids"][:, :new_batch_size],
        position_ids=bufs["position_ids"][:, :new_batch_size],
    )
    if inputs_embeds is not None:
        new_inputs["inputs_embeds"] = bufs["inputs_embeds"][:, :new_batch_size]

    new_inputs.update(kwargs)
    return new_inputs
114
+
115
+
116
def MacaCudaGraphMixin_update_context_cudagraph(self, graph_meta, context):
    """Point the step context at the persistent maca capture buffers."""
    bufs = graph_meta.input_buffers
    for field in (
        "block_offsets",
        "q_seqlens",
        "kv_seqlens",
        "q_start_loc",
        "kv_start_indices",
    ):
        setattr(context, field, bufs[field])


# Patch the lmdeploy mixin so all models use the maca buffer logic.
CudaGraphMixin.make_buffers_cudagraph = MacaCudaGraphMixin_make_buffers_cudagraph
CudaGraphMixin.fill_buffers_cudagraph = MacaCudaGraphMixin_fill_buffers_cudagraph
CudaGraphMixin.update_context_cudagraph = MacaCudaGraphMixin_update_context_cudagraph