vllm-npu 0.4.2__py3-none-any.whl

Files changed (219)
  1. vllm/__init__.py +23 -0
  2. vllm/_custom_ops.py +251 -0
  3. vllm/attention/__init__.py +13 -0
  4. vllm/attention/backends/__init__.py +0 -0
  5. vllm/attention/backends/abstract.py +127 -0
  6. vllm/attention/backends/flash_attn.py +271 -0
  7. vllm/attention/backends/flashinfer.py +220 -0
  8. vllm/attention/backends/rocm_flash_attn.py +374 -0
  9. vllm/attention/backends/torch_sdpa.py +250 -0
  10. vllm/attention/backends/xformers.py +393 -0
  11. vllm/attention/layer.py +56 -0
  12. vllm/attention/ops/__init__.py +0 -0
  13. vllm/attention/ops/paged_attn.py +216 -0
  14. vllm/attention/ops/prefix_prefill.py +792 -0
  15. vllm/attention/ops/triton_flash_attention.py +810 -0
  16. vllm/attention/selector.py +91 -0
  17. vllm/block.py +84 -0
  18. vllm/config.py +1225 -0
  19. vllm/core/__init__.py +0 -0
  20. vllm/core/block/__init__.py +0 -0
  21. vllm/core/block/block_table.py +295 -0
  22. vllm/core/block/common.py +199 -0
  23. vllm/core/block/cpu_gpu_block_allocator.py +228 -0
  24. vllm/core/block/interfaces.py +205 -0
  25. vllm/core/block/naive_block.py +318 -0
  26. vllm/core/block/prefix_caching_block.py +606 -0
  27. vllm/core/block_manager_v1.py +625 -0
  28. vllm/core/block_manager_v2.py +258 -0
  29. vllm/core/evictor_v1.py +105 -0
  30. vllm/core/evictor_v2.py +127 -0
  31. vllm/core/interfaces.py +113 -0
  32. vllm/core/policy.py +45 -0
  33. vllm/core/scheduler.py +1163 -0
  34. vllm/distributed/__init__.py +3 -0
  35. vllm/distributed/communication_op.py +237 -0
  36. vllm/distributed/device_communicators/__init__.py +0 -0
  37. vllm/distributed/device_communicators/custom_all_reduce.py +274 -0
  38. vllm/distributed/device_communicators/pynccl.py +287 -0
  39. vllm/distributed/device_communicators/pynccl_utils.py +66 -0
  40. vllm/distributed/parallel_state.py +339 -0
  41. vllm/distributed/utils.py +136 -0
  42. vllm/engine/__init__.py +0 -0
  43. vllm/engine/arg_utils.py +649 -0
  44. vllm/engine/async_llm_engine.py +737 -0
  45. vllm/engine/llm_engine.py +784 -0
  46. vllm/engine/metrics.py +368 -0
  47. vllm/engine/output_processor/__init__.py +0 -0
  48. vllm/engine/output_processor/interfaces.py +76 -0
  49. vllm/engine/output_processor/multi_step.py +142 -0
  50. vllm/engine/output_processor/single_step.py +284 -0
  51. vllm/engine/output_processor/stop_checker.py +101 -0
  52. vllm/engine/output_processor/util.py +19 -0
  53. vllm/entrypoints/__init__.py +0 -0
  54. vllm/entrypoints/api_server.py +119 -0
  55. vllm/entrypoints/llm.py +259 -0
  56. vllm/entrypoints/openai/__init__.py +0 -0
  57. vllm/entrypoints/openai/api_server.py +186 -0
  58. vllm/entrypoints/openai/cli_args.py +115 -0
  59. vllm/entrypoints/openai/protocol.py +460 -0
  60. vllm/entrypoints/openai/serving_chat.py +392 -0
  61. vllm/entrypoints/openai/serving_completion.py +347 -0
  62. vllm/entrypoints/openai/serving_engine.py +234 -0
  63. vllm/envs.py +217 -0
  64. vllm/executor/__init__.py +0 -0
  65. vllm/executor/cpu_executor.py +152 -0
  66. vllm/executor/distributed_gpu_executor.py +115 -0
  67. vllm/executor/executor_base.py +115 -0
  68. vllm/executor/gpu_executor.py +150 -0
  69. vllm/executor/multiproc_worker_utils.py +263 -0
  70. vllm/executor/neuron_executor.py +91 -0
  71. vllm/executor/ray_gpu_executor.py +327 -0
  72. vllm/executor/ray_utils.py +119 -0
  73. vllm/logger.py +153 -0
  74. vllm/logging/__init__.py +5 -0
  75. vllm/logging/formatter.py +15 -0
  76. vllm/lora/__init__.py +0 -0
  77. vllm/lora/fully_sharded_layers.py +262 -0
  78. vllm/lora/layers.py +1181 -0
  79. vllm/lora/lora.py +167 -0
  80. vllm/lora/models.py +645 -0
  81. vllm/lora/punica.py +213 -0
  82. vllm/lora/request.py +32 -0
  83. vllm/lora/utils.py +98 -0
  84. vllm/lora/worker_manager.py +251 -0
  85. vllm/model_executor/__init__.py +7 -0
  86. vllm/model_executor/guided_decoding/__init__.py +25 -0
  87. vllm/model_executor/guided_decoding/lm_format_enforcer_decoding.py +70 -0
  88. vllm/model_executor/guided_decoding/outlines_decoding.py +130 -0
  89. vllm/model_executor/guided_decoding/outlines_logits_processors.py +184 -0
  90. vllm/model_executor/layers/__init__.py +0 -0
  91. vllm/model_executor/layers/activation.py +173 -0
  92. vllm/model_executor/layers/fused_moe/__init__.py +7 -0
  93. vllm/model_executor/layers/fused_moe/configs/E=16,N=1344,device_name=NVIDIA_A100-SXM4-40GB.json +146 -0
  94. vllm/model_executor/layers/fused_moe/configs/E=16,N=1344,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  95. vllm/model_executor/layers/fused_moe/configs/E=16,N=1344,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  96. vllm/model_executor/layers/fused_moe/configs/E=16,N=2688,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  97. vllm/model_executor/layers/fused_moe/configs/E=16,N=2688,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  98. vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=NVIDIA_A100-SXM4-40GB.json +146 -0
  99. vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  100. vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  101. vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  102. vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  103. vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_A100-SXM4-40GB.json +146 -0
  104. vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  105. vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_H100_80GB_HBM3,dtype=float8.json +140 -0
  106. vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  107. vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  108. vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  109. vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  110. vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=float8.json +146 -0
  111. vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  112. vllm/model_executor/layers/fused_moe/fused_moe.py +479 -0
  113. vllm/model_executor/layers/layernorm.py +71 -0
  114. vllm/model_executor/layers/linear.py +709 -0
  115. vllm/model_executor/layers/logits_processor.py +115 -0
  116. vllm/model_executor/layers/ops/__init__.py +0 -0
  117. vllm/model_executor/layers/ops/rand.py +157 -0
  118. vllm/model_executor/layers/ops/sample.py +406 -0
  119. vllm/model_executor/layers/quantization/__init__.py +35 -0
  120. vllm/model_executor/layers/quantization/aqlm.py +376 -0
  121. vllm/model_executor/layers/quantization/awq.py +175 -0
  122. vllm/model_executor/layers/quantization/base_config.py +97 -0
  123. vllm/model_executor/layers/quantization/fp8.py +265 -0
  124. vllm/model_executor/layers/quantization/gptq.py +224 -0
  125. vllm/model_executor/layers/quantization/gptq_marlin.py +438 -0
  126. vllm/model_executor/layers/quantization/marlin.py +227 -0
  127. vllm/model_executor/layers/quantization/schema.py +84 -0
  128. vllm/model_executor/layers/quantization/squeezellm.py +137 -0
  129. vllm/model_executor/layers/rejection_sampler.py +405 -0
  130. vllm/model_executor/layers/rotary_embedding.py +525 -0
  131. vllm/model_executor/layers/sampler.py +1051 -0
  132. vllm/model_executor/layers/vocab_parallel_embedding.py +155 -0
  133. vllm/model_executor/model_loader/__init__.py +30 -0
  134. vllm/model_executor/model_loader/loader.py +362 -0
  135. vllm/model_executor/model_loader/neuron.py +136 -0
  136. vllm/model_executor/model_loader/tensorizer.py +368 -0
  137. vllm/model_executor/model_loader/utils.py +41 -0
  138. vllm/model_executor/model_loader/weight_utils.py +372 -0
  139. vllm/model_executor/models/__init__.py +119 -0
  140. vllm/model_executor/models/baichuan.py +410 -0
  141. vllm/model_executor/models/bloom.py +327 -0
  142. vllm/model_executor/models/chatglm.py +386 -0
  143. vllm/model_executor/models/commandr.py +373 -0
  144. vllm/model_executor/models/dbrx.py +413 -0
  145. vllm/model_executor/models/decilm.py +122 -0
  146. vllm/model_executor/models/deepseek.py +438 -0
  147. vllm/model_executor/models/falcon.py +444 -0
  148. vllm/model_executor/models/gemma.py +393 -0
  149. vllm/model_executor/models/gpt2.py +266 -0
  150. vllm/model_executor/models/gpt_bigcode.py +274 -0
  151. vllm/model_executor/models/gpt_j.py +281 -0
  152. vllm/model_executor/models/gpt_neox.py +295 -0
  153. vllm/model_executor/models/internlm2.py +323 -0
  154. vllm/model_executor/models/jais.py +333 -0
  155. vllm/model_executor/models/llama.py +442 -0
  156. vllm/model_executor/models/llava.py +239 -0
  157. vllm/model_executor/models/minicpm.py +531 -0
  158. vllm/model_executor/models/mixtral.py +583 -0
  159. vllm/model_executor/models/mixtral_quant.py +404 -0
  160. vllm/model_executor/models/mpt.py +295 -0
  161. vllm/model_executor/models/olmo.py +356 -0
  162. vllm/model_executor/models/opt.py +349 -0
  163. vllm/model_executor/models/orion.py +319 -0
  164. vllm/model_executor/models/phi.py +300 -0
  165. vllm/model_executor/models/qwen.py +284 -0
  166. vllm/model_executor/models/qwen2.py +367 -0
  167. vllm/model_executor/models/qwen2_moe.py +447 -0
  168. vllm/model_executor/models/stablelm.py +301 -0
  169. vllm/model_executor/models/starcoder2.py +302 -0
  170. vllm/model_executor/models/xverse.py +366 -0
  171. vllm/model_executor/sampling_metadata.py +588 -0
  172. vllm/model_executor/utils.py +35 -0
  173. vllm/outputs.py +150 -0
  174. vllm/py.typed +2 -0
  175. vllm/sampling_params.py +340 -0
  176. vllm/sequence.py +766 -0
  177. vllm/spec_decode/__init__.py +0 -0
  178. vllm/spec_decode/batch_expansion.py +397 -0
  179. vllm/spec_decode/interfaces.py +73 -0
  180. vllm/spec_decode/metrics.py +191 -0
  181. vllm/spec_decode/multi_step_worker.py +203 -0
  182. vllm/spec_decode/ngram_worker.py +176 -0
  183. vllm/spec_decode/spec_decode_worker.py +472 -0
  184. vllm/spec_decode/top1_proposer.py +200 -0
  185. vllm/spec_decode/util.py +228 -0
  186. vllm/test_utils.py +41 -0
  187. vllm/transformers_utils/__init__.py +0 -0
  188. vllm/transformers_utils/config.py +58 -0
  189. vllm/transformers_utils/configs/__init__.py +16 -0
  190. vllm/transformers_utils/configs/chatglm.py +68 -0
  191. vllm/transformers_utils/configs/dbrx.py +278 -0
  192. vllm/transformers_utils/configs/falcon.py +87 -0
  193. vllm/transformers_utils/configs/jais.py +236 -0
  194. vllm/transformers_utils/configs/mpt.py +178 -0
  195. vllm/transformers_utils/detokenizer.py +313 -0
  196. vllm/transformers_utils/tokenizer.py +149 -0
  197. vllm/transformers_utils/tokenizer_group/__init__.py +33 -0
  198. vllm/transformers_utils/tokenizer_group/base_tokenizer_group.py +55 -0
  199. vllm/transformers_utils/tokenizer_group/ray_tokenizer_group.py +169 -0
  200. vllm/transformers_utils/tokenizer_group/tokenizer_group.py +78 -0
  201. vllm/transformers_utils/tokenizers/__init__.py +5 -0
  202. vllm/transformers_utils/tokenizers/baichuan.py +255 -0
  203. vllm/usage/__init__.py +0 -0
  204. vllm/usage/usage_lib.py +209 -0
  205. vllm/utils.py +677 -0
  206. vllm/worker/__init__.py +0 -0
  207. vllm/worker/cache_engine.py +105 -0
  208. vllm/worker/cpu_model_runner.py +346 -0
  209. vllm/worker/cpu_worker.py +321 -0
  210. vllm/worker/model_runner.py +1168 -0
  211. vllm/worker/neuron_model_runner.py +196 -0
  212. vllm/worker/neuron_worker.py +98 -0
  213. vllm/worker/worker.py +345 -0
  214. vllm/worker/worker_base.py +146 -0
  215. vllm_npu-0.4.2.dist-info/LICENSE +201 -0
  216. vllm_npu-0.4.2.dist-info/METADATA +173 -0
  217. vllm_npu-0.4.2.dist-info/RECORD +219 -0
  218. vllm_npu-0.4.2.dist-info/WHEEL +5 -0
  219. vllm_npu-0.4.2.dist-info/top_level.txt +1 -0
vllm/model_executor/layers/quantization/schema.py
@@ -0,0 +1,84 @@
+"""
+This file contains the Pydantic schemas for various quantization-related
+parameters. When a relevant quantization technique is specified, these
+parameters are loaded in the form of a JSON alongside the model weights
+and augment the model with additional information needed for use of that
+technique. The format of this JSON should be specified by one or more
+schemas contained here.
+
+For example, when the KV cache is quantized to FP8-E4M3 (currently only
+possible on ROCm), the model can be optionally augmented with KV cache
+scaling factors.
+"""
+
+from typing import Dict, Optional
+
+from pydantic import BaseModel, ConfigDict, ValidationInfo, model_validator
+
+
+class KVCacheQuantSchema(BaseModel):
+    dtype: str
+    # Each key is a TP rank. Each value is a dictionary mapping a TP rank's
+    # layer indices to their per-tensor KV cache scaling factor.
+    # TODO: Consider pulling this and its validation methods out into its
+    # own schema class (tricky as its members are variable)
+    scaling_factor: Dict[int, Dict[int, float]]
+
+    @model_validator(mode="after")
+    def check_is_fp8(self) -> "KVCacheQuantSchema":
+        assert self.dtype == "float8_e4m3fn", (
+            "Loaded scaling factors intended for KV cache dtype = "
+            f"{self.dtype} rather than float8_e4m3fn!")
+        return self
+
+    @model_validator(mode="after")
+    def check_tp_ranks(self, info: ValidationInfo) -> "KVCacheQuantSchema":
+        context = info.context
+        if context:
+            tp_size = context["tp_size"]
+            num_hidden_layers = context["num_hidden_layers"]
+            assert len(self.scaling_factor) == tp_size, (
+                f"Loaded dictionary has TP size {len(self.scaling_factor)} "
+                f"but LLM engine is currently running with TP size {tp_size}.")
+            for tp_rank, layer_maps in self.scaling_factor.items():
+                assert len(layer_maps) == num_hidden_layers, (
+                    f"KV cache scales map for TP rank {tp_rank} is malformed. "
+                    f"Expected {num_hidden_layers} layers, got "
+                    f"{len(layer_maps)}.")
+            for i in range(tp_size):
+                assert i in self.scaling_factor, (
+                    f"KV cache scales map for TP rank {i} not found.")
+        return self
+
+    @model_validator(mode="after")
+    def check_current_rank(self, info: ValidationInfo) -> "KVCacheQuantSchema":
+        context = info.context
+        if context:
+            tp_rank = context["tp_rank"]
+            num_hidden_layers = context["num_hidden_layers"]
+            layer_scales_map = self.scaling_factor[tp_rank]
+            for i in range(num_hidden_layers):
+                assert i in layer_scales_map, (
+                    f"Could not find KV cache scales for layer {i} in "
+                    f"TP rank {tp_rank}.")
+        return self
+
+
+class QuantParamSchema(BaseModel):
+    # TODO: Generalize and extend with more fields
+    # (e.g. weights/activations params) once functionality is enabled
+    model_config = ConfigDict(protected_namespaces=())
+    model_type: Optional[str]
+    kv_cache: KVCacheQuantSchema
+
+    @model_validator(mode="after")
+    def check_model_type(self, info: ValidationInfo) -> "QuantParamSchema":
+        context = info.context
+        if context:
+            model_type = context.get("model_type", None)
+            if model_type is not None:
+                assert model_type == self.model_type, (
+                    f"Model type is {model_type} but loaded "
+                    f"scaling factors belonging to different "
+                    f"model type {self.model_type}!")
+        return self
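For reference, the schema above is meant to be validated against the running engine's configuration through pydantic's validation context. Below is a minimal sketch of such a call, assuming pydantic v2 and toy values (a 2-layer model at TP size 1); the snippet is illustrative only and is not code shipped in this wheel:

import json

from vllm.model_executor.layers.quantization.schema import QuantParamSchema

# Toy scaling-factor JSON: TP rank 0 maps each layer index to a per-tensor
# KV cache scale (values are made up for illustration).
quant_param_json = json.dumps({
    "model_type": "llama",
    "kv_cache": {
        "dtype": "float8_e4m3fn",
        "scaling_factor": {"0": {"0": 0.0215, "1": 0.0183}},
    },
})

schema = QuantParamSchema.model_validate_json(
    quant_param_json,
    context={
        "model_type": "llama",
        "tp_size": 1,
        "tp_rank": 0,
        "num_hidden_layers": 2,
    },
)
print(schema.kv_cache.scaling_factor[0][0])  # 0.0215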
vllm/model_executor/layers/quantization/squeezellm.py
@@ -0,0 +1,137 @@
+from typing import Any, Dict, List, Optional
+
+import torch
+from torch.nn.parameter import Parameter
+
+from vllm import _custom_ops as ops
+from vllm.model_executor.layers.linear import LinearBase
+from vllm.model_executor.layers.quantization.base_config import (
+    QuantizationConfig, QuantizeMethodBase)
+from vllm.model_executor.utils import set_weight_attrs
+from vllm.utils import is_hip
+
+
+class SqueezeLLMConfig(QuantizationConfig):
+    """Config class for SqueezeLLM.
+
+    Reference: https://arxiv.org/pdf/2306.07629
+    """
+
+    def __init__(
+        self,
+        weight_bits: int,
+    ) -> None:
+        self.weight_bits = weight_bits
+
+        if self.weight_bits != 4:
+            raise ValueError(
+                "Currently, only 4-bit weight quantization is supported for "
+                f"SqueezeLLM, but got {self.weight_bits} bits.")
+
+        self.pack_factor = 32 // self.weight_bits
+
+    def __repr__(self) -> str:
+        return f"SqueezeLLMConfig(weight_bits={self.weight_bits})"
+
+    def get_name(self) -> str:
+        return "squeezellm"
+
+    def get_supported_act_dtypes(self) -> List[torch.dtype]:
+        return [torch.half]
+
+    def get_min_capability(self) -> int:
+        return 70
+
+    @staticmethod
+    def get_config_filenames() -> List[str]:
+        return ["quant_config.json"]
+
+    @classmethod
+    def from_config(cls, config: Dict[str, Any]) -> "SqueezeLLMConfig":
+        weight_bits = cls.get_from_keys(config, ["wbits"])
+        return cls(weight_bits)
+
+    def get_quant_method(
+            self, layer: torch.nn.Module) -> Optional[QuantizeMethodBase]:
+        if isinstance(layer, LinearBase):
+            return SqueezeLLMLinearMethod(self)
+        return None
+
+    def get_scaled_act_names(self) -> List[str]:
+        return []
+
+
+class SqueezeLLMLinearMethod(QuantizeMethodBase):
+    """Linear method for SqueezeLLM.
+
+    Args:
+        quant_config: The SqueezeLLM quantization config.
+    """
+
+    def __init__(self, quant_config: SqueezeLLMConfig):
+        self.quant_config = quant_config
+
+    def create_weights(self, layer: torch.nn.Module,
+                       input_size_per_partition: int,
+                       output_partition_sizes: List[int], input_size: int,
+                       output_size: int, params_dtype: torch.dtype,
+                       **extra_weight_attrs):
+        if input_size_per_partition % self.quant_config.pack_factor != 0:
+            raise ValueError(
+                "The input size is not aligned with the quantized "
+                "weight shape. This can be caused by too large "
+                "tensor parallel size.")
+
+        output_size_per_partition = sum(output_partition_sizes)
+        qweight = Parameter(
+            torch.empty(
+                input_size_per_partition // self.quant_config.pack_factor,
+                output_size_per_partition,
+                dtype=torch.int32,
+            ),
+            requires_grad=False,
+        )
+        set_weight_attrs(
+            qweight, {
+                "input_dim": 0,
+                "output_dim": 1,
+                "packed_dim": 0,
+                "pack_factor": self.quant_config.pack_factor,
+            })
+        lookup_table = Parameter(
+            torch.empty(
+                output_size,
+                self.quant_config.weight_bits**2,
+                dtype=params_dtype,
+            ),
+            requires_grad=False,
+        )
+        set_weight_attrs(lookup_table, {
+            "output_dim": 0,
+        })
+
+        layer.register_parameter("qweight", qweight)
+        set_weight_attrs(qweight, extra_weight_attrs)
+        layer.register_parameter("lookup_table", lookup_table)
+        set_weight_attrs(lookup_table, extra_weight_attrs)
+
+    def apply(self,
+              layer: torch.nn.Module,
+              x: torch.Tensor,
+              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+        qweight = layer.qweight
+        lookup_table = layer.lookup_table
+        out_shape = x.shape[:-1] + (qweight.shape[-1], )
+        reshaped_x = x.reshape(-1, x.shape[-1])
+        if is_hip():
+            out_f = torch.zeros(out_shape, dtype=torch.float)
+            ops.squeezellm_gemm(reshaped_x, qweight, out_f, lookup_table)
+            out = out_f.to(dtype=torch.float16)
+        else:
+            # NOTE: The output tensor should be zero-initialized.
+            out = torch.zeros(out_shape, dtype=torch.float16)
+            ops.squeezellm_gemm(reshaped_x, qweight, out, lookup_table)
+
+        if bias is not None:
+            out.add_(bias)
+        return out.reshape(out_shape)
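The create_weights layout above packs eight 4-bit codebook indices into each int32 along the input dimension (pack_factor = 32 // 4) and keeps a 16-entry codebook per output channel in lookup_table. The actual matmul is performed by the fused ops.squeezellm_gemm kernel; the sketch below only illustrates the assumed layout by dequantizing it in plain PyTorch. In particular, the nibble order inside each int32 is an assumption here, not taken from the kernel:

import torch

def dequantize_squeezellm(qweight: torch.Tensor,       # [in // 8, out], int32
                          lookup_table: torch.Tensor   # [out, 16], params dtype
                          ) -> torch.Tensor:
    bits, pack_factor = 4, 8
    in_packed, out_features = qweight.shape
    # Unpack eight 4-bit indices from every int32 (low nibble first, assumed).
    shifts = torch.arange(0, 32, bits, device=qweight.device)
    idx = (qweight.unsqueeze(1) >> shifts.view(1, -1, 1)) & 0xF
    idx = idx.reshape(in_packed * pack_factor, out_features).long()
    # Look each index up in that output channel's 16-entry codebook.
    return torch.gather(lookup_table.t(), 0, idx)  # [in, out]

# An unfused reference forward would then be roughly: y = x @ w_dequant (+ bias).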
vllm/model_executor/layers/rejection_sampler.py
@@ -0,0 +1,405 @@
+from functools import cached_property
+from typing import Optional, Tuple
+
+import torch
+import torch.jit
+import torch.nn as nn
+
+
+class RejectionSampler(nn.Module):
+    """Apply modified rejection sampling as described in "Accelerating Large
+    Language Model Decoding with Speculative Sampling"
+    https://arxiv.org/pdf/2302.01318.pdf.
+    """
+
+    def __init__(self, strict_mode: bool = False):
+        """Create a rejection sampler.
+
+        Args:
+            strict_mode: Whether or not to perform shape/device/dtype checks
+                during sampling. This catches correctness issues but adds
+                nontrivial latency.
+        """
+        super().__init__()
+        self._strict_mode = strict_mode
+
+        # NOTE: A "bonus token" is accepted iff all proposal tokens are
+        # accepted. There is always only one possible bonus token. We store this
+        # value in a variable for readability.
+        self._num_bonus_tokens = 1
+
+        self.num_accepted_tokens: Optional[torch.Tensor] = None
+        self.num_emitted_tokens: Optional[torch.Tensor] = None
+        self.num_draft_tokens: int = 0
+
+    def init_gpu_tensors(self, rank: int) -> None:
+        assert self.num_accepted_tokens is None
+        device = f"cuda:{rank}"
+        self.num_accepted_tokens = torch.tensor(0,
+                                                dtype=torch.long,
+                                                device=device)
+        self.num_emitted_tokens = torch.tensor(0,
+                                               dtype=torch.long,
+                                               device=device)
+
+    @property
+    def probs_dtype(self):
+        return torch.float32
+
+    @property
+    def token_id_dtype(self):
+        return torch.int64
+
+    def forward(
+        self,
+        target_probs: torch.Tensor,
+        bonus_token_ids: torch.Tensor,
+        draft_probs: torch.Tensor,
+        draft_token_ids: torch.Tensor,
+    ) -> torch.Tensor:
+        """Sample token ids using rejection sampling. This accepts or rejects
+        tokens proposed by the draft model using the probability of each token
+        according to the draft and target models.
+
+        In the worst case where all draft tokens are rejected, it is guaranteed
+        one correct token will be emitted.
+
+        In the case where all draft tokens are accepted, a bonus token will be
+        accepted as it's cheap to have the target model score this speculative
+        sequence.
+
+        Args:
+            target_probs: The probability distribution over token ids given
+                context according to the target model.
+            shape = [batch_size, num_speculative_tokens, vocab_size]
+
+            bonus_token_ids: The "bonus" token ids that are accepted iff all
+                speculative tokens in a sequence are accepted.
+            shape = [batch_size, num_bonus_tokens]
+
+            draft_probs: The probability distribution over token ids given
+                context according to the draft model.
+            shape = [batch_size, num_speculative_tokens, vocab_size]
+
+            draft_token_ids: The token ids that were sampled from the draft
+                probabilities.
+            shape = [batch_size, num_speculative_tokens]
+
+        Returns:
+            output_token_ids: The token ids sampled via rejection sampling,
+                or -1 if unable to sample a token because the previous token
+                was rejected.
+            shape = [batch_size, num_speculative_tokens + num_bonus_tokens]
+        """
+        # Only perform shape/dtype/device checking in strict mode, as it adds
+        # overhead.
+        if self._strict_mode:
+            self._raise_if_incorrect_shape(target_probs, bonus_token_ids,
+                                           draft_probs, draft_token_ids)
+            self._raise_if_incorrect_dtype(target_probs, bonus_token_ids,
+                                           draft_probs, draft_token_ids)
+            self._raise_if_inconsistent_device(target_probs, bonus_token_ids,
+                                               draft_probs, draft_token_ids)
+            self._raise_if_out_of_bounds_vocab(target_probs.shape[-1],
+                                               bonus_token_ids,
+                                               draft_token_ids)
+
+        accepted, recovered_token_ids = self._batch_modified_rejection_sampling(
+            target_probs,
+            draft_probs,
+            draft_token_ids,
+        )
+
+        output_token_ids = self._create_output(
+            accepted,
+            recovered_token_ids,
+            draft_token_ids,
+            bonus_token_ids,
+        )
+        return output_token_ids
+
+    def _batch_modified_rejection_sampling(
+        self,
+        target_probs: torch.Tensor,  # [batch_size, k, vocab_size]
+        draft_probs: torch.Tensor,  # [batch_size, k, vocab_size]
+        draft_token_ids: torch.Tensor,  # [batch_size, k]
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Perform modified rejection sampling on each sequence.
+
+        Returns:
+            A tuple of two tensors:
+            0: A bool tensor of which tokens in each sequence are accepted.
+                shape = [batch_size, k]
+            1: Token ids sampled from a recovered distribution, to be used
+                when a token is rejected.
+                shape = [batch_size, k]
+        """
+
+        batch_size, k, vocab_size = draft_probs.shape
+
+        # shape [batch_size, k]
+        accepted = self._get_accepted(target_probs, draft_probs,
+                                      draft_token_ids)
+
+        recovered_probs = self._get_recovered_probs(
+            target_probs, draft_probs).reshape(batch_size * k, vocab_size)
+
+        # NOTE: the recovered_probs are overwritten by this method.
+        recovered_token_ids = _multinomial(recovered_probs,
+                                           num_samples=1).reshape(
+                                               batch_size, k)
+        return accepted, recovered_token_ids
+
+    def _get_accepted(
+        self,
+        target_probs: torch.Tensor,  # [batch_size, k, vocab_size]
+        draft_probs: torch.Tensor,  # [batch_size, k, vocab_size]
+        draft_token_ids: torch.Tensor,  # [batch_size, k]
+    ) -> torch.Tensor:
+        r"""Create bool matrix over the proposed draft tokens. If
+        True, then a token can be accepted, else it should be
+        rejected.
+
+        Given :math:`q(\hat{x}_{n+1}|x_1, \dots, x_n)`, the probability of
+        :math:`\hat{x}_{n+1}` given context :math:`x_1, \dots, x_n` according
+        to the target model, and :math:`p(\hat{x}_{n+1}|x_1, \dots, x_n)`, the
+        same conditional probability according to the draft model, the token
+        is accepted with probability:
+
+        .. math::
+            \min\left(1, \frac{q(\hat{x}_{n+1}|x_1, \dots, x_n)}
+            {p(\hat{x}_{n+1}|x_1, \dots, x_n)}\right)
+
+        This implementation does not apply causality. When using the output,
+        if a token is rejected, subsequent tokens should not be used.
+
+        Returns a bool tensor of shape [batch_size, k] specifying which tokens
+        are accepted.
+        """
+        batch_size, k, _ = draft_probs.shape
+        batch_indices = torch.arange(batch_size,
+                                     device=target_probs.device)[:, None]
+        probs_indices = torch.arange(k, device=target_probs.device)
+
+        # shape [batch_size, k]
+        selected_draft_probs = draft_probs[batch_indices, probs_indices,
+                                           draft_token_ids]
+
+        # shape [batch_size, k]
+        selected_target_probs = target_probs[batch_indices, probs_indices,
+                                             draft_token_ids]
+
+        uniform_rand = torch.rand(batch_size,
+                                  k,
+                                  dtype=self.probs_dtype,
+                                  device=target_probs.device)
+        capped_ratio = torch.minimum(
+            selected_target_probs / selected_draft_probs,
+            torch.full((1, ), 1, device=target_probs.device))
+        accepted = uniform_rand < capped_ratio
+
+        return accepted
+
+    def _get_recovered_probs(
+        self,
+        target_probs: torch.Tensor,  # [k, vocab_size]
+        draft_probs: torch.Tensor,  # [k, vocab_size]
+    ) -> torch.Tensor:
+        r"""Create a probability distribution for each proposed token which can
+        be sampled if the proposed token is rejected.
+
+        When this routine is applied sequentially, the true distribution of the
+        target model is recovered (within hardware numerics).
+
+        The probability distribution used in this rejection case is constructed
+        as follows. Given :math:`q(x|x_1, \dots, x_n)`, the probability of
+        :math:`x` given context :math:`x_1, \dots, x_n` according to the target
+        model and :math:`p(x|x_1, \dots, x_n)`, the same conditional probability
+        according to the draft model:
+
+        .. math::
+            x_{n+1} \sim (q(x|x_1, \dots, x_n) - p(x|x_1, \dots, x_n))_+
+
+        where :math:`(f(x))_+` is defined as:
+
+        .. math::
+            (f(x))_+ = \frac{\max(0, f(x))}{\sum_x \max(0, f(x))}
+
+        See https://github.com/vllm-project/vllm/pull/2336 for a visualization
+        of the draft, target, and recovered probability distributions.
+
+        Returns a tensor of shape [batch_size, k, vocab_size].
+
+        Note: This batches operations on GPU and thus constructs the recovered
+        distribution for all tokens, even if they are accepted. This causes
+        division-by-zero errors, so we use self._smallest_positive_value to
+        avoid that. This introduces some drift to the distribution.
+        """
+        _, k, _ = draft_probs.shape
+
+        # shape [batch_size, k, vocab_size]
+        difference = target_probs - draft_probs
+
+        # TODO(cade): Can we use logprobs instead of probs, and avoid the
+        # division-by-zero errors without introducing distribution drift?
+
+        # shape [batch_size, k, vocab_size]
+        f = torch.clamp(difference, min=self._smallest_positive_value)
+
+        # shape [batch_size, k, vocab_size]
+        recovered_probs = f / torch.sum(f, dim=-1).reshape(-1, k, 1)
+
+        return recovered_probs
+
+    @cached_property
+    def _smallest_positive_value(self) -> float:
+        """Return the smallest positive value representable by the probs dtype.
+        This value is used when constructing a distribution from which to sample
+        recovered tokens in the first rejection case.
+
+        See _get_recovered_probs for more details.
+
+        Note that this isn't actually the smallest positive value representable
+        by float32, but the smallest positive normal value.
+        See https://en.wikipedia.org/wiki/Subnormal_number for more information.
+        """
+        return torch.finfo(self.probs_dtype).tiny
+
+    def _create_output(
+        self,
+        accepted: torch.Tensor,  # [batch_size, k]
+        recovered_token_ids: torch.Tensor,  # [batch_size, k]
+        draft_token_ids: torch.Tensor,  # [batch_size, k]
+        bonus_token_ids: torch.Tensor,  # [batch_size]
+    ) -> torch.Tensor:
+        """Format output. Returns a matrix of token ids. When
+        a token is rejected via rejection sampling, all subsequent
+        token ids are set to -1 for the sequence.
+
+        shape = [batch_size, k + num_bonus_tokens]
+        """
+        bonus_token_ids = bonus_token_ids.squeeze()
+        batch_size, k = recovered_token_ids.shape
+
+        # Determine the index of the first False value for each row.
+        limits = (accepted == 0).max(1).indices
+        limits[~(accepted == 0).any(1)] = k
+
+        # Create masks using the indices.
+        indices = torch.arange(k, device=accepted.device).unsqueeze(0)
+        accepted_mask = indices < limits.unsqueeze(1)
+        after_false_mask = indices == limits.unsqueeze(1)
+
+        # Create an extended output tensor
+        output_with_bonus_tokens = -torch.ones(
+            (batch_size, k + self._num_bonus_tokens),
+            dtype=self.token_id_dtype,
+            device=accepted.device)
+        output = output_with_bonus_tokens[:, :k]
+
+        # Fill in the first k columns of the output tensor using masks and data
+        # tensors.
+        output[:, :k] = torch.where(accepted_mask, draft_token_ids,
+                                    -torch.ones_like(draft_token_ids))
+
+        # Fill the last column.
+        # We check output directly as accepted may have True values inconsistent
+        # with causal acceptance.
+        output_with_bonus_tokens[:, -1] = torch.where(output[:, -1] != -1,
+                                                      bonus_token_ids, -1)
+
+        # We disable bonus tokens because it causes corrupt KV cache for
+        # proposal methods that require KV cache. We can fix it by "prefilling"
+        # the bonus token in the proposer. The following issue tracks the fix.
+        # https://github.com/vllm-project/vllm/issues/4212
+        output_with_bonus_tokens[:, -1] = -1
+
+        # Fill the recovered token ids.
+        output.mul_(~after_false_mask).add_(
+            recovered_token_ids.mul(after_false_mask))
+
+        self.num_accepted_tokens += accepted.sum()
+        self.num_emitted_tokens += (output_with_bonus_tokens != -1).sum()
+        self.num_draft_tokens += batch_size * k
+
+        return output_with_bonus_tokens
+
+    def _raise_if_incorrect_shape(
+        self,
+        target_probs: torch.Tensor,
+        bonus_token_ids: torch.Tensor,
+        draft_probs: torch.Tensor,
+        draft_token_ids: torch.Tensor,
+    ) -> None:
+        (target_batch_size, num_target_probs,
+         target_vocab_size) = target_probs.shape
+        bonus_batch_size, num_bonus_tokens = bonus_token_ids.shape
+        draft_batch_size, num_draft_probs, draft_vocab_size = draft_probs.shape
+        draft_token_ids_batch_size, num_draft_token_ids = draft_token_ids.shape
+
+        assert draft_batch_size == target_batch_size
+        assert num_draft_probs == num_target_probs
+        assert (draft_vocab_size == target_vocab_size
+                ), f"{draft_vocab_size=} {target_vocab_size=}"
+
+        assert draft_token_ids_batch_size == draft_batch_size
+        assert num_draft_token_ids == num_draft_probs
+
+        assert bonus_batch_size == target_batch_size
+        assert num_bonus_tokens == self._num_bonus_tokens
+
+    def _raise_if_incorrect_dtype(
+        self,
+        target_probs: torch.Tensor,
+        bonus_token_ids: torch.Tensor,
+        draft_probs: torch.Tensor,
+        draft_token_ids: torch.Tensor,
+    ) -> None:
+        assert all(probs.dtype == self.probs_dtype
+                   for probs in [target_probs, draft_probs])
+        assert all(token_ids.dtype == self.token_id_dtype
+                   for token_ids in [bonus_token_ids, draft_token_ids])
+
+    def _raise_if_inconsistent_device(
+        self,
+        target_probs: torch.Tensor,
+        bonus_token_ids: torch.Tensor,
+        draft_probs: torch.Tensor,
+        draft_token_ids: torch.Tensor,
+    ) -> None:
+        devices = [
+            t.device for t in
+            [target_probs, bonus_token_ids, draft_probs, draft_token_ids]
+        ]
+        assert all([devices[0] == device for device in devices])
+
+    def _raise_if_out_of_bounds_vocab(
+        self,
+        vocab_size: int,
+        bonus_token_ids: torch.Tensor,
+        draft_token_ids: torch.Tensor,
+    ) -> None:
+        assert torch.all(bonus_token_ids < vocab_size)
+        assert torch.all(bonus_token_ids >= 0)
+        assert torch.all(draft_token_ids < vocab_size)
+        assert torch.all(draft_token_ids >= 0)
+
+
+# torch.multinomial forces a GPU<->CPU sync.
+# Therefore, we use an optimized implementation instead that skips the sync.
+# Note that we always sample with replacement.
+# probs will be modified in place, but this is fine, as we pass
+# in a copy already.
+@torch.jit.script
+def _multinomial(
+    probs: torch.Tensor,
+    num_samples: int,
+) -> torch.Tensor:
+    if num_samples > 1:
+        # This is equivalent to torch.repeat_interleave (which also
+        # forces a GPU<->CPU sync).
+        probs = probs[:, None, :].expand(probs.shape[0], num_samples,
+                                         probs.shape[1]).contiguous().view(
+                                             -1, probs.shape[1])
+    q = torch.empty_like(probs).exponential_(1.0)
+    return probs.div_(q).argmax(dim=1).view(-1, num_samples)
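As a quick illustration of the interface above, the sampler can be exercised on CPU with toy shapes (batch_size=1, k=2 speculative tokens, vocab_size=4). The counters are normally allocated by init_gpu_tensors(rank); here they are created on CPU just for this sketch, and all values are illustrative only:

import torch

from vllm.model_executor.layers.rejection_sampler import RejectionSampler

sampler = RejectionSampler(strict_mode=True)
# Stand-in for init_gpu_tensors(rank), which would place these on cuda:<rank>.
sampler.num_accepted_tokens = torch.tensor(0, dtype=torch.long)
sampler.num_emitted_tokens = torch.tensor(0, dtype=torch.long)

target_probs = torch.tensor([[[0.1, 0.6, 0.2, 0.1],
                              [0.7, 0.1, 0.1, 0.1]]], dtype=torch.float32)
draft_probs = torch.tensor([[[0.25, 0.25, 0.25, 0.25],
                             [0.25, 0.25, 0.25, 0.25]]], dtype=torch.float32)
draft_token_ids = torch.tensor([[1, 0]], dtype=torch.int64)
bonus_token_ids = torch.tensor([[3]], dtype=torch.int64)

# Shape [1, 3]: k sampled/recovered token ids plus the (currently disabled,
# always -1) bonus slot; positions after a rejection are also -1.
output_token_ids = sampler(target_probs, bonus_token_ids, draft_probs,
                           draft_token_ids)
print(output_token_ids)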