vllm-npu 0.4.2__py3-none-any.whl

Files changed (219)
  1. vllm/__init__.py +23 -0
  2. vllm/_custom_ops.py +251 -0
  3. vllm/attention/__init__.py +13 -0
  4. vllm/attention/backends/__init__.py +0 -0
  5. vllm/attention/backends/abstract.py +127 -0
  6. vllm/attention/backends/flash_attn.py +271 -0
  7. vllm/attention/backends/flashinfer.py +220 -0
  8. vllm/attention/backends/rocm_flash_attn.py +374 -0
  9. vllm/attention/backends/torch_sdpa.py +250 -0
  10. vllm/attention/backends/xformers.py +393 -0
  11. vllm/attention/layer.py +56 -0
  12. vllm/attention/ops/__init__.py +0 -0
  13. vllm/attention/ops/paged_attn.py +216 -0
  14. vllm/attention/ops/prefix_prefill.py +792 -0
  15. vllm/attention/ops/triton_flash_attention.py +810 -0
  16. vllm/attention/selector.py +91 -0
  17. vllm/block.py +84 -0
  18. vllm/config.py +1225 -0
  19. vllm/core/__init__.py +0 -0
  20. vllm/core/block/__init__.py +0 -0
  21. vllm/core/block/block_table.py +295 -0
  22. vllm/core/block/common.py +199 -0
  23. vllm/core/block/cpu_gpu_block_allocator.py +228 -0
  24. vllm/core/block/interfaces.py +205 -0
  25. vllm/core/block/naive_block.py +318 -0
  26. vllm/core/block/prefix_caching_block.py +606 -0
  27. vllm/core/block_manager_v1.py +625 -0
  28. vllm/core/block_manager_v2.py +258 -0
  29. vllm/core/evictor_v1.py +105 -0
  30. vllm/core/evictor_v2.py +127 -0
  31. vllm/core/interfaces.py +113 -0
  32. vllm/core/policy.py +45 -0
  33. vllm/core/scheduler.py +1163 -0
  34. vllm/distributed/__init__.py +3 -0
  35. vllm/distributed/communication_op.py +237 -0
  36. vllm/distributed/device_communicators/__init__.py +0 -0
  37. vllm/distributed/device_communicators/custom_all_reduce.py +274 -0
  38. vllm/distributed/device_communicators/pynccl.py +287 -0
  39. vllm/distributed/device_communicators/pynccl_utils.py +66 -0
  40. vllm/distributed/parallel_state.py +339 -0
  41. vllm/distributed/utils.py +136 -0
  42. vllm/engine/__init__.py +0 -0
  43. vllm/engine/arg_utils.py +649 -0
  44. vllm/engine/async_llm_engine.py +737 -0
  45. vllm/engine/llm_engine.py +784 -0
  46. vllm/engine/metrics.py +368 -0
  47. vllm/engine/output_processor/__init__.py +0 -0
  48. vllm/engine/output_processor/interfaces.py +76 -0
  49. vllm/engine/output_processor/multi_step.py +142 -0
  50. vllm/engine/output_processor/single_step.py +284 -0
  51. vllm/engine/output_processor/stop_checker.py +101 -0
  52. vllm/engine/output_processor/util.py +19 -0
  53. vllm/entrypoints/__init__.py +0 -0
  54. vllm/entrypoints/api_server.py +119 -0
  55. vllm/entrypoints/llm.py +259 -0
  56. vllm/entrypoints/openai/__init__.py +0 -0
  57. vllm/entrypoints/openai/api_server.py +186 -0
  58. vllm/entrypoints/openai/cli_args.py +115 -0
  59. vllm/entrypoints/openai/protocol.py +460 -0
  60. vllm/entrypoints/openai/serving_chat.py +392 -0
  61. vllm/entrypoints/openai/serving_completion.py +347 -0
  62. vllm/entrypoints/openai/serving_engine.py +234 -0
  63. vllm/envs.py +217 -0
  64. vllm/executor/__init__.py +0 -0
  65. vllm/executor/cpu_executor.py +152 -0
  66. vllm/executor/distributed_gpu_executor.py +115 -0
  67. vllm/executor/executor_base.py +115 -0
  68. vllm/executor/gpu_executor.py +150 -0
  69. vllm/executor/multiproc_worker_utils.py +263 -0
  70. vllm/executor/neuron_executor.py +91 -0
  71. vllm/executor/ray_gpu_executor.py +327 -0
  72. vllm/executor/ray_utils.py +119 -0
  73. vllm/logger.py +153 -0
  74. vllm/logging/__init__.py +5 -0
  75. vllm/logging/formatter.py +15 -0
  76. vllm/lora/__init__.py +0 -0
  77. vllm/lora/fully_sharded_layers.py +262 -0
  78. vllm/lora/layers.py +1181 -0
  79. vllm/lora/lora.py +167 -0
  80. vllm/lora/models.py +645 -0
  81. vllm/lora/punica.py +213 -0
  82. vllm/lora/request.py +32 -0
  83. vllm/lora/utils.py +98 -0
  84. vllm/lora/worker_manager.py +251 -0
  85. vllm/model_executor/__init__.py +7 -0
  86. vllm/model_executor/guided_decoding/__init__.py +25 -0
  87. vllm/model_executor/guided_decoding/lm_format_enforcer_decoding.py +70 -0
  88. vllm/model_executor/guided_decoding/outlines_decoding.py +130 -0
  89. vllm/model_executor/guided_decoding/outlines_logits_processors.py +184 -0
  90. vllm/model_executor/layers/__init__.py +0 -0
  91. vllm/model_executor/layers/activation.py +173 -0
  92. vllm/model_executor/layers/fused_moe/__init__.py +7 -0
  93. vllm/model_executor/layers/fused_moe/configs/E=16,N=1344,device_name=NVIDIA_A100-SXM4-40GB.json +146 -0
  94. vllm/model_executor/layers/fused_moe/configs/E=16,N=1344,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  95. vllm/model_executor/layers/fused_moe/configs/E=16,N=1344,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  96. vllm/model_executor/layers/fused_moe/configs/E=16,N=2688,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  97. vllm/model_executor/layers/fused_moe/configs/E=16,N=2688,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  98. vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=NVIDIA_A100-SXM4-40GB.json +146 -0
  99. vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  100. vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  101. vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  102. vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  103. vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_A100-SXM4-40GB.json +146 -0
  104. vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  105. vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_H100_80GB_HBM3,dtype=float8.json +140 -0
  106. vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  107. vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  108. vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  109. vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  110. vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=float8.json +146 -0
  111. vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  112. vllm/model_executor/layers/fused_moe/fused_moe.py +479 -0
  113. vllm/model_executor/layers/layernorm.py +71 -0
  114. vllm/model_executor/layers/linear.py +709 -0
  115. vllm/model_executor/layers/logits_processor.py +115 -0
  116. vllm/model_executor/layers/ops/__init__.py +0 -0
  117. vllm/model_executor/layers/ops/rand.py +157 -0
  118. vllm/model_executor/layers/ops/sample.py +406 -0
  119. vllm/model_executor/layers/quantization/__init__.py +35 -0
  120. vllm/model_executor/layers/quantization/aqlm.py +376 -0
  121. vllm/model_executor/layers/quantization/awq.py +175 -0
  122. vllm/model_executor/layers/quantization/base_config.py +97 -0
  123. vllm/model_executor/layers/quantization/fp8.py +265 -0
  124. vllm/model_executor/layers/quantization/gptq.py +224 -0
  125. vllm/model_executor/layers/quantization/gptq_marlin.py +438 -0
  126. vllm/model_executor/layers/quantization/marlin.py +227 -0
  127. vllm/model_executor/layers/quantization/schema.py +84 -0
  128. vllm/model_executor/layers/quantization/squeezellm.py +137 -0
  129. vllm/model_executor/layers/rejection_sampler.py +405 -0
  130. vllm/model_executor/layers/rotary_embedding.py +525 -0
  131. vllm/model_executor/layers/sampler.py +1051 -0
  132. vllm/model_executor/layers/vocab_parallel_embedding.py +155 -0
  133. vllm/model_executor/model_loader/__init__.py +30 -0
  134. vllm/model_executor/model_loader/loader.py +362 -0
  135. vllm/model_executor/model_loader/neuron.py +136 -0
  136. vllm/model_executor/model_loader/tensorizer.py +368 -0
  137. vllm/model_executor/model_loader/utils.py +41 -0
  138. vllm/model_executor/model_loader/weight_utils.py +372 -0
  139. vllm/model_executor/models/__init__.py +119 -0
  140. vllm/model_executor/models/baichuan.py +410 -0
  141. vllm/model_executor/models/bloom.py +327 -0
  142. vllm/model_executor/models/chatglm.py +386 -0
  143. vllm/model_executor/models/commandr.py +373 -0
  144. vllm/model_executor/models/dbrx.py +413 -0
  145. vllm/model_executor/models/decilm.py +122 -0
  146. vllm/model_executor/models/deepseek.py +438 -0
  147. vllm/model_executor/models/falcon.py +444 -0
  148. vllm/model_executor/models/gemma.py +393 -0
  149. vllm/model_executor/models/gpt2.py +266 -0
  150. vllm/model_executor/models/gpt_bigcode.py +274 -0
  151. vllm/model_executor/models/gpt_j.py +281 -0
  152. vllm/model_executor/models/gpt_neox.py +295 -0
  153. vllm/model_executor/models/internlm2.py +323 -0
  154. vllm/model_executor/models/jais.py +333 -0
  155. vllm/model_executor/models/llama.py +442 -0
  156. vllm/model_executor/models/llava.py +239 -0
  157. vllm/model_executor/models/minicpm.py +531 -0
  158. vllm/model_executor/models/mixtral.py +583 -0
  159. vllm/model_executor/models/mixtral_quant.py +404 -0
  160. vllm/model_executor/models/mpt.py +295 -0
  161. vllm/model_executor/models/olmo.py +356 -0
  162. vllm/model_executor/models/opt.py +349 -0
  163. vllm/model_executor/models/orion.py +319 -0
  164. vllm/model_executor/models/phi.py +300 -0
  165. vllm/model_executor/models/qwen.py +284 -0
  166. vllm/model_executor/models/qwen2.py +367 -0
  167. vllm/model_executor/models/qwen2_moe.py +447 -0
  168. vllm/model_executor/models/stablelm.py +301 -0
  169. vllm/model_executor/models/starcoder2.py +302 -0
  170. vllm/model_executor/models/xverse.py +366 -0
  171. vllm/model_executor/sampling_metadata.py +588 -0
  172. vllm/model_executor/utils.py +35 -0
  173. vllm/outputs.py +150 -0
  174. vllm/py.typed +2 -0
  175. vllm/sampling_params.py +340 -0
  176. vllm/sequence.py +766 -0
  177. vllm/spec_decode/__init__.py +0 -0
  178. vllm/spec_decode/batch_expansion.py +397 -0
  179. vllm/spec_decode/interfaces.py +73 -0
  180. vllm/spec_decode/metrics.py +191 -0
  181. vllm/spec_decode/multi_step_worker.py +203 -0
  182. vllm/spec_decode/ngram_worker.py +176 -0
  183. vllm/spec_decode/spec_decode_worker.py +472 -0
  184. vllm/spec_decode/top1_proposer.py +200 -0
  185. vllm/spec_decode/util.py +228 -0
  186. vllm/test_utils.py +41 -0
  187. vllm/transformers_utils/__init__.py +0 -0
  188. vllm/transformers_utils/config.py +58 -0
  189. vllm/transformers_utils/configs/__init__.py +16 -0
  190. vllm/transformers_utils/configs/chatglm.py +68 -0
  191. vllm/transformers_utils/configs/dbrx.py +278 -0
  192. vllm/transformers_utils/configs/falcon.py +87 -0
  193. vllm/transformers_utils/configs/jais.py +236 -0
  194. vllm/transformers_utils/configs/mpt.py +178 -0
  195. vllm/transformers_utils/detokenizer.py +313 -0
  196. vllm/transformers_utils/tokenizer.py +149 -0
  197. vllm/transformers_utils/tokenizer_group/__init__.py +33 -0
  198. vllm/transformers_utils/tokenizer_group/base_tokenizer_group.py +55 -0
  199. vllm/transformers_utils/tokenizer_group/ray_tokenizer_group.py +169 -0
  200. vllm/transformers_utils/tokenizer_group/tokenizer_group.py +78 -0
  201. vllm/transformers_utils/tokenizers/__init__.py +5 -0
  202. vllm/transformers_utils/tokenizers/baichuan.py +255 -0
  203. vllm/usage/__init__.py +0 -0
  204. vllm/usage/usage_lib.py +209 -0
  205. vllm/utils.py +677 -0
  206. vllm/worker/__init__.py +0 -0
  207. vllm/worker/cache_engine.py +105 -0
  208. vllm/worker/cpu_model_runner.py +346 -0
  209. vllm/worker/cpu_worker.py +321 -0
  210. vllm/worker/model_runner.py +1168 -0
  211. vllm/worker/neuron_model_runner.py +196 -0
  212. vllm/worker/neuron_worker.py +98 -0
  213. vllm/worker/worker.py +345 -0
  214. vllm/worker/worker_base.py +146 -0
  215. vllm_npu-0.4.2.dist-info/LICENSE +201 -0
  216. vllm_npu-0.4.2.dist-info/METADATA +173 -0
  217. vllm_npu-0.4.2.dist-info/RECORD +219 -0
  218. vllm_npu-0.4.2.dist-info/WHEEL +5 -0
  219. vllm_npu-0.4.2.dist-info/top_level.txt +1 -0
vllm/spec_decode/multi_step_worker.py
@@ -0,0 +1,203 @@
+ import copy
+ from typing import List, Tuple
+
+ import torch
+
+ from vllm.sequence import (ExecuteModelRequest, SamplerOutput,
+                            SequenceGroupMetadata)
+ from vllm.spec_decode.interfaces import SpeculativeProposals
+ from vllm.spec_decode.top1_proposer import Top1Proposer
+ from vllm.worker.worker import Worker
+
+
+ class MultiStepWorker(Worker):
+     """The MultiStepWorker is equivalent to a Worker except that it allows
+     multiple forward passes in a single call, assuming the scheduler has
+     allocated enough space to store the additional KV. This reduces overhead
+     by invoking the scheduler less.
+
+     The MultiStepWorker does not support cache swap operations or beam search.
+     Cache swap operations do not require large modifications. On the other hand,
+     beam search requires memory allocations during sequence forks and thus
+     requires more thought for MultiStepWorker support.
+     """
+
+     def __init__(self, *args, **kwargs):
+         super().__init__(*args, **kwargs)
+
+         # Lazy initialization list.
+         self._proposer: Top1Proposer
+
+     def init_device(self):
+         super().init_device()
+
+         self._proposer = Top1Proposer(
+             self,
+             self.device,
+             self.vocab_size,
+             max_proposal_len=self.max_model_len,
+         )
+
+     def set_include_gpu_probs_tensor(self):
+         # Need include_gpu_probs_tensor for multi_step_worker
+         self.model_runner.model.sampler.include_gpu_probs_tensor = True
+
+     @torch.inference_mode()
+     def sampler_output(
+         self,
+         execute_model_req: ExecuteModelRequest,
+         sample_len: int,
+     ) -> Tuple[List[SamplerOutput], bool]:
+         """Run the model forward pass sample_len times. Returns the list of
+         sampler outputs, one per model forward pass, along with an indicator of
+         whether the torch tensors in the sampler output need to be transposed
+         by the later sampler_output_to_torch logic.
+
+         For the multi-step worker, this indicator is always True.
+         """
+         self._raise_if_unsupported(execute_model_req)
+
+         # Shallow copy input data so modifications (such as appending tokens)
+         # do not cause side-effects.
+         copied_seq_group_metadata_list = self._shallow_copy_inputs(
+             execute_model_req.seq_group_metadata_list)
+         copied_execute_model_req = execute_model_req.clone(
+             copied_seq_group_metadata_list)
+
+         # Assert enough KV space for sample_len tokens per sequence.
+         self._assert_enough_kv_space(execute_model_req.seq_group_metadata_list,
+                                      sample_len)
+
+         # Run model sample_len times.
+         model_outputs = []
+         for _ in range(sample_len):
+             model_output = super().execute_model(
+                 execute_model_req=copied_execute_model_req)
+             assert (len(model_output) == 1
+                     ), "composing multistep workers not supported"
+             model_output = model_output[0]
+
+             self._append_new_tokens(model_output,
+                                     copied_seq_group_metadata_list)
+             model_outputs.append(model_output)
+
+         return model_outputs, True
+
+     def get_spec_proposals(
+         self,
+         execute_model_req: ExecuteModelRequest,
+     ) -> SpeculativeProposals:
+         """Produce speculations given an input batch of sequences. The number of
+         speculative tokens per sequence is determined by max_proposal_len.
+         """
+
+         return self._proposer.get_proposals(execute_model_req)
+
+     def _append_new_tokens(
+             self, model_output: SamplerOutput,
+             seq_group_metadata_list: List[SequenceGroupMetadata]) -> None:
+         """Given model output from a single run, append the tokens to the
+         sequences. This is normally done outside of the worker, but it is
+         required if the worker is to perform multiple forward passes.
+         """
+         for seq_group_metadata, sequence_group_outputs in zip(
+                 seq_group_metadata_list, model_output):
+             seq_group_metadata.is_prompt = False
+
+             for seq_output in sequence_group_outputs.samples:
+                 # NOTE: Beam search is not supported, so we can assume that
+                 # parent_seq_id == seq_id.
+                 seq = seq_group_metadata.seq_data[seq_output.parent_seq_id]
+
+                 token_id = seq_output.output_token
+                 token_logprob = seq_output.logprobs[token_id]
+
+                 seq.append_token_id(token_id, token_logprob.logprob)
+
+     def _shallow_copy_inputs(
+         self, seq_group_metadata_list: List[SequenceGroupMetadata]
+     ) -> List[SequenceGroupMetadata]:
+         """Copy input data structures to remove side-effects when input data
+         structures are shared with other modules.
+
+         Helpful when the vLLM scheduler runs in the same process as the worker.
+         The alternative, deep-copying the inputs, has
+         performance downsides.
+         """
+
+         # Shallow-copy the list of SequenceGroupMetadata. This allows us to
+         # append tokens and change is_prompt without external side-effects.
+         new_seq_group_metadata_list = []
+
+         for old_seq_group_metadata in seq_group_metadata_list:
+             # We must shallow-copy seq_group_metadata as is_prompt could change.
+             seq_group_metadata = copy.copy(old_seq_group_metadata)
+             new_seq_group_metadata_list.append(seq_group_metadata)
+
+             # We must shallow-copy seq_data as we will append token ids
+             new_seq_data = {}
+             for seq_id, old_seq_data in seq_group_metadata.seq_data.items():
+                 new_seq_data[seq_id] = copy.copy(old_seq_data)
+                 new_seq_data[
+                     seq_id].output_token_ids = old_seq_data.output_token_ids[:]
+
+             seq_group_metadata.seq_data = new_seq_data
+
+         return new_seq_group_metadata_list
+
+     def _assert_enough_kv_space(
+             self, seq_group_metadata_list: List[SequenceGroupMetadata],
+             num_steps: int) -> None:
+         """Assert there are enough physical blocks per sequence to store the
+         current KV plus additional KV from num_steps tokens.
+         """
+         assert self.model_runner.block_size is not None
+         for seq_group_metadata in seq_group_metadata_list:
+             # Only one seq_id is guaranteed because there is no beam search.
+             seq_id = list(seq_group_metadata.seq_data.keys())[0]
+             seq = seq_group_metadata.seq_data[seq_id]
+
+             # After num_steps, the seq len will be the current seq len
+             # plus one token per step.
+             final_seq_len = seq.get_len() + num_steps
+
+             # We will have final_seq_len - 1 KV because vLLM saves KV for a
+             # token in the iteration after the token was generated.
+             required_num_kv_slots = final_seq_len - 1
+
+             # The allocated number of kv slots is the number of allocated
+             # blocks times the number of slots per block.
+             number_physical_blocks = len(
+                 seq_group_metadata.block_tables[seq_id])
+             allocated_kv_slots = (number_physical_blocks *
+                                   self.model_runner.block_size)
+
+             if required_num_kv_slots > allocated_kv_slots:
+                 request_id = seq_group_metadata.request_id
+                 raise ValueError(
+                     "The worker attempted to run "
+                     f"{num_steps} times but found insufficient KV space for "
+                     f"{request_id=} {seq_id=}. ({allocated_kv_slots=} "
+                     f"{required_num_kv_slots=}).")
+
+     def _raise_if_unsupported(
+         self,
+         execute_model_req: ExecuteModelRequest,
+     ) -> None:
+         """MultiStepWorker does not yet implement support for cache swap
+         operations or beam search.
+         """
+         if any([
+                 execute_model_req.blocks_to_swap_in,
+                 execute_model_req.blocks_to_swap_out,
+                 execute_model_req.blocks_to_copy
+         ]):
+             raise NotImplementedError(
+                 "MultiStepWorker does not support cache operations")
+
+         if any(
+                 len(seq_group_metadata.seq_data.keys()) != 1
+                 for seq_group_metadata in
+                 execute_model_req.seq_group_metadata_list):
+             raise NotImplementedError(
+                 "MultiStepWorker does not support beam search.")
vllm/spec_decode/ngram_worker.py
@@ -0,0 +1,176 @@
+ from typing import List, Optional, Tuple
+
+ import torch
+
+ from vllm.sequence import ExecuteModelRequest, SamplerOutput
+ from vllm.spec_decode.interfaces import SpeculativeProposals
+ from vllm.spec_decode.top1_proposer import Top1Proposer
+ from vllm.worker.worker_base import LoraNotSupportedWorkerBase
+
+
+ class NGramWorker(LoraNotSupportedWorkerBase):
+     """NGramWorker provides a lightweight drafter that needs no model.
+
+     Currently NGramWorker only implements prompt lookup decoding; in the
+     future it may also cover RAG-style drafters and other scenarios that
+     do not rely on an LLM to produce proposals.
+     """
+
+     def __init__(self, *args, **kwargs):
+         # Get local_rank/vocab_size from the kwargs
+         self.local_rank = kwargs["local_rank"]
+         self.vocab_size = kwargs["model_config"].get_vocab_size()
+
+         # Lazy initialization list.
+         self._proposer: Top1Proposer
+
+     def set_ngram_window_size(self, ngram_prompt_lookup_min: int,
+                               ngram_prompt_lookup_max: int):
+         # Search for candidate windows between
+         # ngram_prompt_lookup_min and ngram_prompt_lookup_max
+         self.ngram_prompt_lookup_max = ngram_prompt_lookup_max
+         self.ngram_prompt_lookup_min = ngram_prompt_lookup_min
+
+     def init_device(self):
+         self.device = torch.device(f"cuda:{self.local_rank}")
+         self.load_model = lambda *args, **kwargs: None
+
+         # Currently only Top1Proposer is supported
+         self._proposer = Top1Proposer(
+             self,
+             device=self.device,
+             vocab_size=self.vocab_size,
+         )
+
+     def set_include_gpu_probs_tensor(self):
+         # NGram doesn't need the GPU sampler
+         pass
+
+     def execute_model(self, execute_model_req: ExecuteModelRequest) -> None:
+         """NGram doesn't depend on model execution, so this is a no-op."""
+         pass
+
+     def determine_num_available_blocks(self) -> None:
+         """NGram doesn't depend on model execution, so no blocks are checked."""
+         pass
+
+     def initialize_cache(self, num_gpu_blocks: int,
+                          num_cpu_blocks: int) -> None:
+         """There is no cache to handle, so this is a no-op."""
+         pass
+
+     def get_cache_block_size_bytes(self):
+         """Return the size of a cache block in bytes."""
+         return 0
+
+     def sampler_output(
+         self,
+         execute_model_req: ExecuteModelRequest,
+         sample_len: int,
+     ) -> Tuple[Optional[List[SamplerOutput]], bool]:
+         """Use n-gram matching to pick proposal candidates. Returns the list
+         of sampler outputs, one per SequenceGroupMetadata.
+
+         For the ngram worker, the required transposition is already done
+         internally, so the indicator passed to sampler_output_to_torch is False.
+         """
+         self._raise_if_unsupported(execute_model_req)
+
+         arr = []
+         has_spec_out = False
+         for seq_group_metadata in execute_model_req.seq_group_metadata_list:
+             seq_data = next(iter(seq_group_metadata.seq_data.values()))
+
+             input_ids = torch.as_tensor(seq_data.get_token_ids(),
+                                         dtype=torch.long,
+                                         device=self.device)
+             input_length = seq_data.get_len()
+
+             for ngram_size in range(
+                     min(self.ngram_prompt_lookup_max, input_length - 1),
+                     self.ngram_prompt_lookup_min,
+                     -1,
+             ):
+                 ngram_tensor = input_ids[-1 * ngram_size:]
+                 windows = input_ids.unfold(dimension=0,
+                                            size=ngram_size,
+                                            step=1)
+                 matches = (windows == ngram_tensor).all(dim=1)
+                 match_indices = matches.nonzero(as_tuple=True)[0]
+                 if match_indices.size()[0] > 1:
+                     has_spec_out = True
+                     res = seq_data.get_token_ids()
+                     res = res[match_indices[0] + ngram_size:match_indices[0] +
+                               ngram_size + sample_len]
+                     res_len = len(res)
+                     # pad with 0 since sample_len tokens are required
+                     res += [0] * (sample_len - res_len)
+
+                     break
+             else:
+                 # if no candidate found, fill with 0
+                 res = [0] * sample_len
+
+             arr.append(res)
+
+         if not has_spec_out:
+             return None, False
+
+         outputs = []
+         token_ids = torch.as_tensor(arr, dtype=torch.long, device=self.device)
+         indices = token_ids.unsqueeze(2)
+
+         token_probs = torch.zeros(
+             (len(execute_model_req.seq_group_metadata_list), sample_len,
+              self.vocab_size),
+             dtype=torch.float32,
+             device=self.device,
+         )
+         token_probs.scatter_(2, indices, 1)
+         token_logprobs = torch.zeros(
+             (len(execute_model_req.seq_group_metadata_list), sample_len,
+              self.vocab_size),
+             dtype=torch.float32,
+             device=self.device,
+         )
+         for i in range(len(execute_model_req.seq_group_metadata_list)):
+             outputs.append(
+                 SamplerOutput(
+                     outputs=None,
+                     sampled_token_probs=token_probs[i],
+                     logprobs=token_logprobs,
+                     sampled_token_ids=token_ids[i],
+                 ))
+         return outputs, False
+
+     def get_spec_proposals(
+         self,
+         execute_model_req: ExecuteModelRequest,
+     ) -> SpeculativeProposals:
+         """Produce speculations given an input batch of sequences. The number of
+         speculative tokens per sequence is determined by max_proposal_len.
+         """
+
+         return self._proposer.get_proposals(execute_model_req)
+
+     def _raise_if_unsupported(
+         self,
+         execute_model_req: ExecuteModelRequest,
+     ) -> None:
+         """NGramWorker does not yet implement support for cache swap
+         operations or beam search.
+         """
+         if any([
+                 execute_model_req.blocks_to_swap_in,
+                 execute_model_req.blocks_to_swap_out,
+                 execute_model_req.blocks_to_copy
+         ]):
+             raise NotImplementedError(
+                 "NGramWorker does not support cache operations")
+
+         if any(
+                 len(seq_group_metadata.seq_data.keys()) != 1
+                 for seq_group_metadata in
+                 execute_model_req.seq_group_metadata_list):
+             raise NotImplementedError(
+                 "NGramWorker does not support beam search.")