sglang 0.5.1.post3__py3-none-any.whl → 0.5.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (245)
  1. sglang/bench_one_batch.py +3 -0
  2. sglang/bench_one_batch_server.py +10 -1
  3. sglang/bench_serving.py +251 -26
  4. sglang/lang/interpreter.py +1 -1
  5. sglang/srt/configs/__init__.py +4 -0
  6. sglang/srt/configs/internvl.py +6 -0
  7. sglang/srt/configs/longcat_flash.py +104 -0
  8. sglang/srt/configs/model_config.py +37 -7
  9. sglang/srt/configs/qwen3_next.py +326 -0
  10. sglang/srt/connector/__init__.py +1 -1
  11. sglang/srt/connector/base_connector.py +1 -2
  12. sglang/srt/connector/redis.py +2 -2
  13. sglang/srt/connector/serde/__init__.py +1 -1
  14. sglang/srt/connector/serde/safe_serde.py +4 -3
  15. sglang/srt/custom_op.py +11 -1
  16. sglang/srt/debug_utils/dump_comparator.py +81 -44
  17. sglang/srt/debug_utils/dump_loader.py +97 -0
  18. sglang/srt/debug_utils/dumper.py +11 -3
  19. sglang/srt/debug_utils/text_comparator.py +73 -11
  20. sglang/srt/disaggregation/ascend/conn.py +75 -0
  21. sglang/srt/disaggregation/base/conn.py +1 -1
  22. sglang/srt/disaggregation/common/conn.py +15 -12
  23. sglang/srt/disaggregation/decode.py +6 -4
  24. sglang/srt/disaggregation/fake/conn.py +1 -1
  25. sglang/srt/disaggregation/mini_lb.py +6 -420
  26. sglang/srt/disaggregation/mooncake/conn.py +18 -10
  27. sglang/srt/disaggregation/nixl/conn.py +180 -16
  28. sglang/srt/disaggregation/prefill.py +6 -4
  29. sglang/srt/disaggregation/utils.py +5 -50
  30. sglang/srt/distributed/parallel_state.py +94 -58
  31. sglang/srt/entrypoints/engine.py +34 -14
  32. sglang/srt/entrypoints/http_server.py +172 -47
  33. sglang/srt/entrypoints/openai/protocol.py +63 -3
  34. sglang/srt/entrypoints/openai/serving_base.py +6 -2
  35. sglang/srt/entrypoints/openai/serving_chat.py +34 -19
  36. sglang/srt/entrypoints/openai/serving_completions.py +10 -4
  37. sglang/srt/entrypoints/openai/serving_embedding.py +8 -4
  38. sglang/srt/entrypoints/openai/serving_responses.py +7 -4
  39. sglang/srt/eplb/eplb_manager.py +28 -4
  40. sglang/srt/eplb/expert_distribution.py +55 -15
  41. sglang/srt/eplb/expert_location.py +8 -3
  42. sglang/srt/eplb/expert_location_updater.py +1 -1
  43. sglang/srt/function_call/ebnf_composer.py +11 -9
  44. sglang/srt/function_call/glm4_moe_detector.py +1 -1
  45. sglang/srt/function_call/gpt_oss_detector.py +1 -1
  46. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  47. sglang/srt/hf_transformers_utils.py +12 -0
  48. sglang/srt/layers/activation.py +44 -9
  49. sglang/srt/layers/attention/aiter_backend.py +93 -68
  50. sglang/srt/layers/attention/ascend_backend.py +250 -112
  51. sglang/srt/layers/attention/fla/chunk.py +242 -0
  52. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  53. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  54. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  55. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  56. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  57. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  58. sglang/srt/layers/attention/fla/index.py +37 -0
  59. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  60. sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
  61. sglang/srt/layers/attention/fla/op.py +66 -0
  62. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  63. sglang/srt/layers/attention/fla/utils.py +331 -0
  64. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  65. sglang/srt/layers/attention/flashinfer_backend.py +6 -4
  66. sglang/srt/layers/attention/flashinfer_mla_backend.py +16 -12
  67. sglang/srt/layers/attention/hybrid_attn_backend.py +47 -8
  68. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +584 -0
  69. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  70. sglang/srt/layers/attention/mamba/causal_conv1d.py +128 -0
  71. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +1052 -0
  72. sglang/srt/layers/attention/mamba/mamba.py +64 -0
  73. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  74. sglang/srt/layers/attention/trtllm_mla_backend.py +126 -36
  75. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  76. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  77. sglang/srt/layers/communicator.py +45 -7
  78. sglang/srt/layers/layernorm.py +54 -12
  79. sglang/srt/layers/logits_processor.py +10 -3
  80. sglang/srt/layers/moe/__init__.py +2 -1
  81. sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -12
  82. sglang/srt/layers/moe/ep_moe/kernels.py +74 -0
  83. sglang/srt/layers/moe/ep_moe/layer.py +110 -49
  84. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  85. sglang/srt/layers/moe/fused_moe_triton/__init__.py +5 -3
  86. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  87. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
  88. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/{E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=257,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json } +29 -29
  89. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  90. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  91. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  92. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  93. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +9 -1049
  94. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +212 -0
  95. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +799 -0
  96. sglang/srt/layers/moe/fused_moe_triton/layer.py +56 -45
  97. sglang/srt/layers/moe/fused_moe_triton/moe_align_block_size.py +87 -0
  98. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  99. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  100. sglang/srt/layers/moe/moe_runner/runner.py +80 -0
  101. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  102. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  103. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  104. sglang/srt/layers/moe/token_dispatcher/deepep.py +41 -38
  105. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  106. sglang/srt/layers/moe/topk.py +43 -12
  107. sglang/srt/layers/moe/utils.py +6 -5
  108. sglang/srt/layers/quantization/awq.py +19 -7
  109. sglang/srt/layers/quantization/base_config.py +11 -6
  110. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  111. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  112. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  113. sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +9 -1
  114. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +0 -3
  115. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
  116. sglang/srt/layers/quantization/fp8.py +76 -47
  117. sglang/srt/layers/quantization/fp8_utils.py +43 -29
  118. sglang/srt/layers/quantization/gptq.py +25 -17
  119. sglang/srt/layers/quantization/modelopt_quant.py +107 -40
  120. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  121. sglang/srt/layers/quantization/mxfp4.py +77 -45
  122. sglang/srt/layers/quantization/quark/quark_moe.py +32 -27
  123. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +49 -30
  124. sglang/srt/layers/quantization/quark/utils.py +97 -0
  125. sglang/srt/layers/quantization/rocm_mxfp4_utils.py +13 -0
  126. sglang/srt/layers/quantization/unquant.py +135 -47
  127. sglang/srt/layers/quantization/utils.py +13 -0
  128. sglang/srt/layers/quantization/w4afp8.py +60 -42
  129. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  130. sglang/srt/layers/quantization/w8a8_int8.py +83 -41
  131. sglang/srt/layers/rocm_linear_utils.py +44 -0
  132. sglang/srt/layers/rotary_embedding.py +28 -19
  133. sglang/srt/layers/sampler.py +29 -5
  134. sglang/srt/lora/backend/base_backend.py +50 -8
  135. sglang/srt/lora/backend/triton_backend.py +90 -2
  136. sglang/srt/lora/layers.py +32 -0
  137. sglang/srt/lora/lora.py +4 -1
  138. sglang/srt/lora/lora_manager.py +35 -112
  139. sglang/srt/lora/mem_pool.py +24 -10
  140. sglang/srt/lora/utils.py +18 -9
  141. sglang/srt/managers/cache_controller.py +242 -278
  142. sglang/srt/managers/data_parallel_controller.py +30 -15
  143. sglang/srt/managers/detokenizer_manager.py +13 -2
  144. sglang/srt/managers/disagg_service.py +46 -0
  145. sglang/srt/managers/io_struct.py +160 -11
  146. sglang/srt/managers/mm_utils.py +6 -1
  147. sglang/srt/managers/multi_tokenizer_mixin.py +579 -0
  148. sglang/srt/managers/schedule_batch.py +27 -44
  149. sglang/srt/managers/schedule_policy.py +4 -3
  150. sglang/srt/managers/scheduler.py +90 -115
  151. sglang/srt/managers/scheduler_metrics_mixin.py +114 -8
  152. sglang/srt/managers/scheduler_output_processor_mixin.py +29 -19
  153. sglang/srt/managers/scheduler_profiler_mixin.py +1 -1
  154. sglang/srt/managers/scheduler_update_weights_mixin.py +8 -1
  155. sglang/srt/managers/template_manager.py +3 -3
  156. sglang/srt/managers/tokenizer_communicator_mixin.py +491 -0
  157. sglang/srt/managers/tokenizer_manager.py +41 -477
  158. sglang/srt/managers/tp_worker.py +16 -4
  159. sglang/srt/managers/tp_worker_overlap_thread.py +8 -10
  160. sglang/srt/mem_cache/allocator.py +1 -1
  161. sglang/srt/mem_cache/chunk_cache.py +1 -1
  162. sglang/srt/mem_cache/hicache_storage.py +24 -22
  163. sglang/srt/mem_cache/hiradix_cache.py +184 -101
  164. sglang/srt/mem_cache/lora_radix_cache.py +1 -1
  165. sglang/srt/mem_cache/memory_pool.py +324 -41
  166. sglang/srt/mem_cache/memory_pool_host.py +25 -18
  167. sglang/srt/mem_cache/radix_cache.py +5 -6
  168. sglang/srt/mem_cache/radix_cache_cpp.py +1 -1
  169. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  170. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  171. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +61 -34
  172. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +149 -12
  173. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +280 -0
  174. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  175. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +74 -19
  176. sglang/srt/mem_cache/storage/mooncake_store/test_mooncake_store.py +161 -0
  177. sglang/srt/mem_cache/swa_radix_cache.py +1 -3
  178. sglang/srt/metrics/collector.py +484 -63
  179. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  180. sglang/srt/metrics/utils.py +48 -0
  181. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  182. sglang/srt/model_executor/cuda_graph_runner.py +13 -5
  183. sglang/srt/model_executor/forward_batch_info.py +72 -18
  184. sglang/srt/model_executor/model_runner.py +189 -31
  185. sglang/srt/model_loader/__init__.py +9 -3
  186. sglang/srt/model_loader/loader.py +33 -28
  187. sglang/srt/model_loader/utils.py +12 -0
  188. sglang/srt/model_loader/weight_utils.py +2 -1
  189. sglang/srt/models/deepseek_v2.py +311 -50
  190. sglang/srt/models/gemma3n_mm.py +1 -1
  191. sglang/srt/models/glm4_moe.py +10 -1
  192. sglang/srt/models/glm4v.py +4 -2
  193. sglang/srt/models/gpt_oss.py +5 -18
  194. sglang/srt/models/internvl.py +28 -0
  195. sglang/srt/models/llama4.py +9 -0
  196. sglang/srt/models/llama_eagle3.py +17 -0
  197. sglang/srt/models/longcat_flash.py +1026 -0
  198. sglang/srt/models/longcat_flash_nextn.py +699 -0
  199. sglang/srt/models/minicpmv.py +165 -3
  200. sglang/srt/models/mllama4.py +25 -0
  201. sglang/srt/models/opt.py +637 -0
  202. sglang/srt/models/qwen2.py +33 -3
  203. sglang/srt/models/qwen2_5_vl.py +90 -42
  204. sglang/srt/models/qwen2_moe.py +79 -14
  205. sglang/srt/models/qwen3.py +8 -2
  206. sglang/srt/models/qwen3_moe.py +39 -8
  207. sglang/srt/models/qwen3_next.py +1039 -0
  208. sglang/srt/models/qwen3_next_mtp.py +109 -0
  209. sglang/srt/models/torch_native_llama.py +1 -1
  210. sglang/srt/models/transformers.py +1 -1
  211. sglang/srt/multimodal/processors/base_processor.py +4 -2
  212. sglang/srt/multimodal/processors/glm4v.py +9 -9
  213. sglang/srt/multimodal/processors/internvl.py +141 -129
  214. sglang/srt/{reasoning_parser.py → parser/reasoning_parser.py} +1 -1
  215. sglang/srt/sampling/penaltylib/orchestrator.py +14 -2
  216. sglang/srt/sampling/sampling_batch_info.py +18 -15
  217. sglang/srt/server_args.py +297 -79
  218. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -0
  219. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +10 -1
  220. sglang/srt/speculative/eagle_worker.py +216 -120
  221. sglang/srt/speculative/spec_info.py +5 -0
  222. sglang/srt/speculative/standalone_worker.py +109 -0
  223. sglang/srt/utils.py +37 -2
  224. sglang/srt/weight_sync/utils.py +1 -1
  225. sglang/test/attention/test_trtllm_mla_backend.py +181 -8
  226. sglang/test/few_shot_gsm8k.py +1 -0
  227. sglang/test/runners.py +4 -0
  228. sglang/test/test_cutlass_moe.py +24 -6
  229. sglang/test/test_cutlass_w4a8_moe.py +24 -9
  230. sglang/test/test_disaggregation_utils.py +66 -0
  231. sglang/test/test_utils.py +25 -1
  232. sglang/utils.py +5 -0
  233. sglang/version.py +1 -1
  234. {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/METADATA +11 -9
  235. {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/RECORD +243 -194
  236. sglang/srt/disaggregation/launch_lb.py +0 -131
  237. sglang/srt/mem_cache/storage/mooncake_store/unit_test.py +0 -40
  238. /sglang/srt/{model_parallel.py → layers/model_parallel.py} +0 -0
  239. /sglang/srt/{code_completion_parser.py → parser/code_completion_parser.py} +0 -0
  240. /sglang/srt/{conversation.py → parser/conversation.py} +0 -0
  241. /sglang/srt/{harmony_parser.py → parser/harmony_parser.py} +0 -0
  242. /sglang/srt/{jinja_template_utils.py → parser/jinja_template_utils.py} +0 -0
  243. {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/WHEEL +0 -0
  244. {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/licenses/LICENSE +0 -0
  245. {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/top_level.txt +0 -0
sglang/srt/models/opt.py (new file)
@@ -0,0 +1,637 @@
1
+ # Copyright 2023-2024 SGLang Team
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ #
6
+ # http://www.apache.org/licenses/LICENSE-2.0
7
+ #
8
+ # Unless required by applicable law or agreed to in writing, software
9
+ # distributed under the License is distributed on an "AS IS" BASIS,
10
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11
+ # See the License for the specific language governing permissions and
12
+ # limitations under the License.
13
+ # ==============================================================================
14
+
15
+ """Inference-only OPT model compatible with HuggingFace weights."""
16
+ from collections.abc import Iterable
17
+ from typing import Optional, Union
18
+
19
+ import torch
20
+ import torch.nn.functional as F
21
+ from torch import nn
22
+ from transformers import OPTConfig
23
+
24
+ from sglang.srt.distributed import (
25
+ get_pp_group,
26
+ get_tensor_model_parallel_rank,
27
+ get_tensor_model_parallel_world_size,
28
+ )
29
+ from sglang.srt.layers.activation import get_act_fn
30
+ from sglang.srt.layers.linear import (
31
+ ColumnParallelLinear,
32
+ MergedColumnParallelLinear,
33
+ QKVParallelLinear,
34
+ ReplicatedLinear,
35
+ RowParallelLinear,
36
+ )
37
+ from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
38
+ from sglang.srt.layers.pooler import Pooler, PoolingType
39
+ from sglang.srt.layers.quantization.base_config import QuantizationConfig
40
+ from sglang.srt.layers.radix_attention import RadixAttention
41
+ from sglang.srt.layers.utils import PPMissingLayer, get_layer_id
42
+ from sglang.srt.layers.vocab_parallel_embedding import (
43
+ ParallelLMHead,
44
+ VocabParallelEmbedding,
45
+ )
46
+ from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
47
+ from sglang.srt.model_loader.weight_utils import (
48
+ default_weight_loader,
49
+ kv_cache_scales_loader,
50
+ maybe_remap_kv_scale_name,
51
+ )
52
+ from sglang.srt.utils import add_prefix, make_layers
53
+
54
+
55
+ def get_activation(name="relu"):
56
+ """Select an activation function by name
57
+
58
+ Args:
59
+ name: str
60
+ activation function name,
61
+ one of ["relu", "gelu", "swish", "sigmoid"],
62
+ default "relu".
63
+ """
64
+ name = name.lower()
65
+ if name == "relu":
66
+ return nn.ReLU()
67
+ if name == "gelu":
68
+ return nn.GELU()
69
+ if name == "sigmoid":
70
+ return torch.nn.Sigmoid()
71
+ return nn.Identity()
72
+
73
+
74
+ class OPTLearnedPositionalEmbedding(nn.Embedding):
75
+
76
+ def __init__(self, num_embeddings: int, embedding_dim: int):
77
+ # OPT is set up so that if padding_idx is specified then offset the
78
+ # embedding ids by 2 and adjust num_embeddings appropriately. Other
79
+ # models don't have this hack
80
+ self.offset = 2
81
+ super().__init__(num_embeddings + self.offset, embedding_dim)
82
+
83
+ def forward(self, positions: torch.Tensor):
84
+ return super().forward(positions + self.offset)
85
+
86
+
87
+ class OPTAttention(nn.Module):
88
+
89
+ def __init__(
90
+ self,
91
+ embed_dim: int,
92
+ num_heads: int,
93
+ layer_id: int = 0,
94
+ bias: bool = True,
95
+ quant_config: Optional[QuantizationConfig] = None,
96
+ prefix: str = "",
97
+ ) -> None:
98
+ super().__init__()
99
+ self.embed_dim = embed_dim
100
+ tensor_model_parallel_world_size = get_tensor_model_parallel_world_size()
101
+ total_num_heads = num_heads
102
+ assert num_heads % tensor_model_parallel_world_size == 0
103
+ self.num_heads = total_num_heads // tensor_model_parallel_world_size
104
+ self.head_dim = embed_dim // total_num_heads
105
+ self.scaling = self.head_dim**-0.5
106
+
107
+ self.qkv_proj = QKVParallelLinear(
108
+ embed_dim,
109
+ self.head_dim,
110
+ total_num_heads,
111
+ bias=bias,
112
+ quant_config=quant_config,
113
+ prefix=add_prefix("qkv_proj", prefix),
114
+ )
115
+ self.out_proj = RowParallelLinear(
116
+ embed_dim,
117
+ embed_dim,
118
+ bias=bias,
119
+ quant_config=quant_config,
120
+ prefix=add_prefix("o_proj", prefix),
121
+ )
122
+
123
+ self.attn = RadixAttention(
124
+ self.num_heads,
125
+ self.head_dim,
126
+ self.scaling,
127
+ num_kv_heads=self.num_heads,
128
+ layer_id=layer_id,
129
+ quant_config=quant_config,
130
+ prefix=add_prefix("attn", prefix),
131
+ )
132
+
133
+ def forward(
134
+ self,
135
+ hidden_states: torch.Tensor,
136
+ forward_batch: ForwardBatch,
137
+ ) -> torch.Tensor:
138
+ qkv, _ = self.qkv_proj(hidden_states)
139
+ q, k, v = qkv.chunk(chunks=3, dim=-1)
140
+ attn_output = self.attn(q, k, v, forward_batch)
141
+ output, _ = self.out_proj(attn_output)
142
+ return output
143
+
144
+
145
+ class OPTDecoderLayer(nn.Module):
146
+
147
+ def __init__(
148
+ self,
149
+ config: OPTConfig,
150
+ layer_id: int = 0,
151
+ quant_config: Optional[QuantizationConfig] = None,
152
+ prefix: str = "",
153
+ ):
154
+ super().__init__()
155
+ self.config = config
156
+ self.embed_dim = config.hidden_size
157
+ self.self_attn = OPTAttention(
158
+ embed_dim=self.embed_dim,
159
+ num_heads=config.num_attention_heads,
160
+ layer_id=layer_id,
161
+ bias=config.enable_bias,
162
+ quant_config=quant_config,
163
+ prefix=add_prefix("self_attn", prefix),
164
+ )
165
+ self.do_layer_norm_before = config.do_layer_norm_before
166
+
167
+ self.self_attn_layer_norm = nn.LayerNorm(
168
+ self.embed_dim, elementwise_affine=config.layer_norm_elementwise_affine
169
+ )
170
+ self.fc1 = ColumnParallelLinear(
171
+ self.embed_dim,
172
+ config.ffn_dim,
173
+ bias=config.enable_bias,
174
+ quant_config=quant_config,
175
+ prefix=add_prefix("fc1", prefix),
176
+ )
177
+ self.activation_fn = get_activation(config.activation_function)
178
+ self.fc2 = RowParallelLinear(
179
+ config.ffn_dim,
180
+ self.embed_dim,
181
+ bias=config.enable_bias,
182
+ quant_config=quant_config,
183
+ prefix=add_prefix("fc2", prefix),
184
+ )
185
+ self.final_layer_norm = nn.LayerNorm(
186
+ self.embed_dim, elementwise_affine=config.layer_norm_elementwise_affine
187
+ )
188
+
189
+ def forward(
190
+ self,
191
+ hidden_states: torch.Tensor,
192
+ forward_batch: ForwardBatch,
193
+ ) -> torch.Tensor:
194
+ # Self Attention
195
+ residual = hidden_states
196
+ # 125m, 1.7B, ..., 175B applies layer norm BEFORE attention
197
+ if self.do_layer_norm_before:
198
+ hidden_states = self.self_attn_layer_norm(hidden_states)
199
+ hidden_states = self.self_attn(
200
+ hidden_states=hidden_states, forward_batch=forward_batch
201
+ )
202
+ hidden_states = residual + hidden_states
203
+ # 350m applies layer norm AFTER attention
204
+ if not self.do_layer_norm_before:
205
+ hidden_states = self.self_attn_layer_norm(hidden_states)
206
+
207
+ # Fully Connected
208
+ residual = hidden_states
209
+ # 125m, 1.7B, ..., 175B applies layer norm BEFORE attention
210
+ if self.do_layer_norm_before:
211
+ hidden_states = self.final_layer_norm(hidden_states)
212
+ hidden_states, _ = self.fc1(hidden_states)
213
+ hidden_states = self.activation_fn(hidden_states)
214
+ hidden_states, _ = self.fc2(hidden_states)
215
+ hidden_states = residual + hidden_states
216
+ # 350m applies layer norm AFTER attention
217
+ if not self.do_layer_norm_before:
218
+ hidden_states = self.final_layer_norm(hidden_states)
219
+ return hidden_states
220
+
221
+
222
+ class OPTDecoder(nn.Module):
223
+
224
+ def __init__(
225
+ self,
226
+ config: OPTConfig,
227
+ layer_id: int = 0,
228
+ quant_config: Optional[QuantizationConfig] = None,
229
+ prefix: str = "",
230
+ ):
231
+ super().__init__()
232
+ self.config = config
233
+ self.max_target_positions = config.max_position_embeddings
234
+ self.vocab_size = config.vocab_size
235
+
236
+ self.pp_group = get_pp_group()
237
+
238
+ self.embed_tokens = VocabParallelEmbedding(
239
+ config.vocab_size,
240
+ config.word_embed_proj_dim,
241
+ prefix=add_prefix("embed_tokens", prefix),
242
+ )
243
+ # Positional embeddings are replicated (not sharded).
244
+ self.embed_positions = OPTLearnedPositionalEmbedding(
245
+ config.max_position_embeddings, config.hidden_size
246
+ )
247
+
248
+ # Project out & in will be replicated if they exist.
249
+ if config.word_embed_proj_dim != config.hidden_size:
250
+ self.project_out = ReplicatedLinear(
251
+ config.hidden_size,
252
+ config.word_embed_proj_dim,
253
+ bias=False,
254
+ quant_config=quant_config,
255
+ prefix=add_prefix("project_out", prefix),
256
+ )
257
+ else:
258
+ self.project_out = None
259
+
260
+ if config.word_embed_proj_dim != config.hidden_size:
261
+ self.project_in = ReplicatedLinear(
262
+ config.word_embed_proj_dim,
263
+ config.hidden_size,
264
+ bias=False,
265
+ quant_config=quant_config,
266
+ prefix=add_prefix("project_in", prefix),
267
+ )
268
+ else:
269
+ self.project_in = None
270
+
271
+ # Note that the only purpose of `config._remove_final_layer_norm` is to
272
+ # keep backward compatibility with checkpoints that have been fine-tuned
273
+ # before transformers v4.20.1
274
+ # see https://github.com/facebookresearch/metaseq/pull/164
275
+ if config.do_layer_norm_before and not config._remove_final_layer_norm:
276
+ self.final_layer_norm = nn.LayerNorm(
277
+ config.hidden_size,
278
+ elementwise_affine=config.layer_norm_elementwise_affine,
279
+ )
280
+ else:
281
+ self.final_layer_norm = None
282
+
283
+ self.layers, self.start_layer, self.end_layer = make_layers(
284
+ config.num_hidden_layers,
285
+ lambda idx, prefix: OPTDecoderLayer(
286
+ config=config, layer_id=idx, quant_config=quant_config, prefix=prefix
287
+ ),
288
+ pp_rank=self.pp_group.rank_in_group,
289
+ pp_size=self.pp_group.world_size,
290
+ prefix="model.layers",
291
+ )
292
+
293
+ def forward(
294
+ self,
295
+ input_ids: torch.Tensor,
296
+ positions: torch.Tensor,
297
+ forward_batch: ForwardBatch,
298
+ pp_proxy_tensors: Optional[PPProxyTensors] = None,
299
+ input_embeds: Optional[torch.Tensor] = None,
300
+ ) -> Union[torch.Tensor, PPProxyTensors]:
301
+ if self.pp_group.is_first_rank:
302
+ if input_embeds is None:
303
+ input_embeds = self.embed_tokens(input_ids)
304
+ pos_embeds = self.embed_positions(positions)
305
+ if self.project_in is not None:
306
+ input_embeds, _ = self.project_in(input_embeds)
307
+ hidden_states = input_embeds + pos_embeds
308
+ else:
309
+ assert pp_proxy_tensors is not None
310
+ hidden_states = pp_proxy_tensors["hidden_states"]
311
+
312
+ for layer in self.layers[self.start_layer : self.end_layer]:
313
+ hidden_states = layer(
314
+ hidden_states=hidden_states, forward_batch=forward_batch
315
+ )
316
+ if not self.pp_group.is_last_rank:
317
+ return PPProxyTensors({"hidden_states": hidden_states})
318
+ if self.final_layer_norm is not None:
319
+ hidden_states = self.final_layer_norm(hidden_states)
320
+ # This branch is not reached
321
+ if self.project_out is not None:
322
+ hidden_states, _ = self.project_out(hidden_states)
323
+ return hidden_states
324
+
325
+
326
+ class OPTModel(nn.Module):
327
+
328
+ def __init__(
329
+ self,
330
+ config: OPTConfig,
331
+ quant_config: Optional[QuantizationConfig] = None,
332
+ prefix: str = "",
333
+ ) -> None:
334
+ super().__init__()
335
+
336
+ # config = vllm_config.model_config.hf_config
337
+ # quant_config = vllm_config.quant_config
338
+ self.config = config
339
+ self.padding_idx = config.pad_token_id
340
+ self.vocab_size = config.vocab_size
341
+ self.pp_group = get_pp_group()
342
+
343
+ self.decoder = OPTDecoder(
344
+ config=config,
345
+ quant_config=quant_config,
346
+ prefix=add_prefix("decoder", prefix),
347
+ )
348
+
349
+ def forward(
350
+ self,
351
+ input_ids: torch.Tensor,
352
+ positions: torch.Tensor,
353
+ forward_batch: ForwardBatch,
354
+ pp_proxy_tensors: Optional[PPProxyTensors],
355
+ input_embeds: Optional[torch.Tensor] = None,
356
+ ) -> Union[torch.Tensor, PPProxyTensors]:
357
+ return self.decoder(
358
+ input_ids,
359
+ positions,
360
+ pp_proxy_tensors=pp_proxy_tensors,
361
+ input_embeds=input_embeds,
362
+ forward_batch=forward_batch,
363
+ )
364
+
365
+ def load_kv_cache_scales(self, quantization_param_path: str) -> None:
366
+ tp_size = get_tensor_model_parallel_world_size()
367
+ tp_rank = get_tensor_model_parallel_rank()
368
+ for layer_idx, scaling_factor in kv_cache_scales_loader(
369
+ quantization_param_path,
370
+ tp_rank,
371
+ tp_size,
372
+ self.config.num_hidden_layers,
373
+ self.config.__class__.model_type,
374
+ ):
375
+ if not isinstance(self.decoder.layers[layer_idx], nn.Identity):
376
+ layer_self_attn = self.decoder.layers[layer_idx].self_attn
377
+
378
+ if hasattr(layer_self_attn.attn, "k_scale"):
379
+ layer_self_attn.attn.k_scale = scaling_factor
380
+ layer_self_attn.attn.v_scale = scaling_factor
381
+ else:
382
+ raise RuntimeError(
383
+ "Self attention has no KV cache scaling " "factor attribute!"
384
+ )
385
+
386
+
387
+ class OPTForCausalLM(nn.Module):
388
+ # BitandBytes specific attributes
389
+ # in TP, these weights are partitioned along the column dimension (dim=-1)
390
+ column_parallel_weights_modules = [".down_proj.", ".o_proj."]
391
+
392
+ def __init__(
393
+ self,
394
+ config: OPTConfig,
395
+ quant_config: Optional[QuantizationConfig] = None,
396
+ prefix: str = "",
397
+ ):
398
+ super().__init__()
399
+ self.config = config
400
+ self.quant_config = quant_config
401
+
402
+ self.model = OPTModel(
403
+ config=config, quant_config=quant_config, prefix=add_prefix("model", prefix)
404
+ )
405
+ if self.config.tie_word_embeddings:
406
+ self.lm_head = self.model.decoder.embed_tokens
407
+ else:
408
+ self.lm_head = ParallelLMHead(
409
+ config.vocab_size,
410
+ config.word_embed_proj_dim,
411
+ prefix=add_prefix("lm_head", prefix),
412
+ )
413
+ self.logits_processor = LogitsProcessor(config)
414
+ self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
415
+ self.capture_aux_hidden_states = False
416
+ self.pp_group = get_pp_group()
417
+ self.stacked_params_mapping = [
418
+ # (param_name, shard_name, shard_id)
419
+ (".qkv_proj", ".q_proj", "q"),
420
+ (".qkv_proj", ".k_proj", "k"),
421
+ (".qkv_proj", ".v_proj", "v"),
422
+ ]
423
+
424
+ def forward(
425
+ self,
426
+ input_ids: torch.Tensor,
427
+ positions: torch.Tensor,
428
+ forward_batch: ForwardBatch,
429
+ pp_proxy_tensors: Optional[PPProxyTensors] = None,
430
+ input_embeds: Optional[torch.Tensor] = None,
431
+ get_embedding: bool = False,
432
+ ) -> LogitsProcessorOutput:
433
+ hidden_states = self.model(
434
+ input_ids=input_ids,
435
+ positions=positions,
436
+ forward_batch=forward_batch,
437
+ input_embeds=input_embeds,
438
+ pp_proxy_tensors=pp_proxy_tensors,
439
+ )
440
+ aux_hidden_states = None
441
+ if self.capture_aux_hidden_states:
442
+ hidden_states, aux_hidden_states = hidden_states
443
+
444
+ if self.pp_group.is_last_rank:
445
+ if not get_embedding:
446
+ return self.logits_processor(
447
+ input_ids,
448
+ hidden_states,
449
+ self.lm_head,
450
+ forward_batch,
451
+ aux_hidden_states=aux_hidden_states,
452
+ )
453
+ else:
454
+ return self.pooler(hidden_states, forward_batch)
455
+ else:
456
+ return hidden_states
457
+
458
+ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> None:
459
+ stacked_params_mapping = [
460
+ # (param_name, shard_name, shard_id)
461
+ ("qkv_proj", "q_proj", "q"),
462
+ ("qkv_proj", "k_proj", "k"),
463
+ ("qkv_proj", "v_proj", "v"),
464
+ ]
465
+ params_dict = dict(self.named_parameters(remove_duplicate=False))
466
+
467
+ for name, loaded_weight in weights:
468
+ if name.startswith("decoder"):
469
+ name = name.replace("decoder.", "model.decoder.")
470
+ layer_id = get_layer_id(name)
471
+ if (
472
+ layer_id is not None
473
+ and hasattr(self.model, "start_layer")
474
+ and (
475
+ layer_id < self.model.start_layer
476
+ or layer_id >= self.model.end_layer
477
+ )
478
+ ):
479
+ continue
480
+ for param_name, weight_name, shard_id in stacked_params_mapping:
481
+ if weight_name not in name:
482
+ continue
483
+ name = name.replace(weight_name, param_name)
484
+ # Skip loading extra bias for GPTQ models.
485
+ if name.endswith(".bias") and name not in params_dict:
486
+ continue
487
+ # if is_pp_missing_parameter(name, self):
488
+ # continue
489
+ param = params_dict[name]
490
+ weight_loader = param.weight_loader
491
+ weight_loader(param, loaded_weight, shard_id)
492
+ break
493
+ else:
494
+ # Skip loading extra bias for GPTQ models.
495
+ if name.endswith(".bias") and name not in params_dict:
496
+ continue
497
+ # if is_pp_missing_parameter(name, self):
498
+ # continue
499
+ if name not in params_dict:
500
+ continue
501
+ if name in params_dict.keys():
502
+ param = params_dict[name]
503
+ weight_loader = getattr(
504
+ param, "weight_loader", default_weight_loader
505
+ )
506
+ weight_loader(param, loaded_weight)
507
+ else:
508
+ logger.warning(f"Parameter {name} not found in params_dict")
509
+
510
+ @property
511
+ def start_layer(self):
512
+ return self.model.start_layer
513
+
514
+ @property
515
+ def end_layer(self):
516
+ return self.model.end_layer
517
+
518
+ def get_input_embeddings(self) -> nn.Embedding:
519
+ return self.model.embed_tokens
520
+
521
+ def get_module_name_from_weight_name(self, name):
522
+ for param_name, weight_name, shard_id, num_shard in self.stacked_params_mapping:
523
+ if weight_name in name:
524
+ return (
525
+ name.replace(weight_name, param_name)[: -len(".weight")],
526
+ num_shard,
527
+ )
528
+ return name[: -len(".weight")], 1
529
+
530
+ def get_num_params(self):
531
+ params_dict = dict(self.named_parameters())
532
+ return len(params_dict)
533
+
534
+ def get_weights_by_name(
535
+ self, name: str, truncate_size: int = 100, tp_size: int = 1
536
+ ) -> Optional[torch.Tensor]:
537
+ """Get the weights of the parameter by its name. Similar to `get_parameter` in Hugging Face.
538
+
539
+ Only used for unit test with an unoptimized performance.
540
+ For optimized performance, please use torch.save and torch.load.
541
+ """
542
+ try:
543
+ if name == "lm_head.weight" and self.config.tie_word_embeddings:
544
+ logger.info(
545
+ "word embedding is tied for this model, return embed_tokens.weight as lm_head.weight."
546
+ )
547
+ return (
548
+ self.model.embed_tokens.weight.cpu()
549
+ .to(torch.float32)
550
+ .numpy()
551
+ .tolist()[:truncate_size]
552
+ )
553
+
554
+ mapped_name = name
555
+ mapped_shard_id = None
556
+ for param_name, weight_name, shard_id in self.stacked_params_mapping:
557
+ if weight_name in name:
558
+ mapped_name = name.replace(weight_name, param_name)
559
+ mapped_shard_id = shard_id
560
+ break
561
+ params_dict = dict(self.named_parameters())
562
+ param = params_dict[mapped_name]
563
+ if mapped_shard_id is not None:
564
+ if mapped_shard_id in ["q", "k", "v"]:
565
+ num_heads = self.config.num_attention_heads // tp_size
566
+ num_kv_heads = self.config.num_attention_heads // tp_size
567
+ head_dim = (
568
+ self.config.hidden_size // self.config.num_attention_heads
569
+ )
570
+ if mapped_shard_id == "q":
571
+ offset = 0
572
+ size = num_heads * head_dim
573
+ elif mapped_shard_id == "k":
574
+ offset = num_heads * head_dim
575
+ size = num_kv_heads * head_dim
576
+ elif mapped_shard_id == "v":
577
+ offset = (num_heads + num_kv_heads) * head_dim
578
+ size = num_kv_heads * head_dim
579
+ weight = param.data.narrow(0, offset, size)
580
+ elif mapped_shard_id in [0, 1]:
581
+ intermediate_size = self.config.ffn_dim
582
+ slice_size = intermediate_size // tp_size
583
+ if mapped_shard_id == 0: # gate_proj
584
+ offset = 0
585
+ size = slice_size
586
+ elif mapped_shard_id == 1: # up_proj
587
+ offset = slice_size
588
+ size = slice_size
589
+
590
+ weight = param.data.narrow(0, offset, size)
591
+ else:
592
+ weight = param.data
593
+ else:
594
+ weight = param.data
595
+ if tp_size > 1 and ("o_proj" in name or "down_proj" in name):
596
+ gathered_weights = [torch.zeros_like(weight) for _ in range(tp_size)]
597
+ torch.distributed.all_gather(gathered_weights, weight)
598
+ weight = torch.cat(gathered_weights, dim=1)
599
+ return weight.cpu().to(torch.float32).numpy().tolist()[:truncate_size]
600
+
601
+ except Exception:
602
+ logger.error(
603
+ f"Error getting weights by name {name} in OPTForCausalLM: {get_exception_traceback()}"
604
+ )
605
+ return None
606
+
607
+ def get_embed_and_head(self):
608
+ return self.model.embed_tokens.weight, self.lm_head.weight
609
+
610
+ def set_embed_and_head(self, embed, head):
611
+ del self.model.embed_tokens.weight
612
+ del self.lm_head.weight
613
+ self.model.embed_tokens.weight = embed
614
+ self.lm_head.weight = head
615
+ torch.cuda.empty_cache()
616
+ torch.cuda.synchronize()
617
+
618
+ def get_embed(self):
619
+ return self.model.embed_tokens.weight
620
+
621
+ def set_embed(self, embed):
622
+ # NOTE: If draft hidden size != target hidden size, the embed weight cannot be shared for EAGLE3
623
+ if (
624
+ hasattr(self.config, "target_hidden_size")
625
+ and self.config.target_hidden_size != self.config.hidden_size
626
+ ):
627
+ return
628
+ del self.model.embed_tokens.weight
629
+ self.model.embed_tokens.weight = embed
630
+ torch.cuda.empty_cache()
631
+ torch.cuda.synchronize()
632
+
633
+ def load_kv_cache_scales(self, quantization_param_path: str) -> None:
634
+ self.model.load_kv_cache_scales(quantization_param_path)
635
+
636
+
637
+ EntryClass = [OPTForCausalLM]
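
The new module above registers OPTForCausalLM via EntryClass, so OPT checkpoints are picked up by sglang's model loader like any other supported architecture. A minimal sketch of exercising it through the offline Engine API follows; the checkpoint name facebook/opt-125m is only illustrative, and the sampling fields are assumptions based on sglang's documented interface rather than anything specific to this diff:

    import sglang as sgl

    # Any HuggingFace OPT checkpoint should now resolve to the new OPTForCausalLM entry.
    llm = sgl.Engine(model_path="facebook/opt-125m")

    prompts = ["The capital of France is"]
    sampling_params = {"temperature": 0.0, "max_new_tokens": 16}

    # generate() returns one dict per prompt; the generated text is under "text".
    outputs = llm.generate(prompts, sampling_params)
    for prompt, out in zip(prompts, outputs):
        print(prompt, "->", out["text"])

    llm.shutdown()
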
sglang/srt/models/qwen2.py
@@ -16,7 +16,7 @@
16
16
  # Modify details for the adaptation of Qwen2 model.
17
17
  """Inference-only Qwen2 model compatible with HuggingFace weights."""
18
18
  import logging
19
- from typing import Any, Dict, Iterable, Optional, Tuple, Union
19
+ from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
20
20
 
21
21
  import torch
22
22
  from torch import nn
@@ -431,7 +431,6 @@ class Qwen2ForCausalLM(nn.Module):
431
431
  quant_config=quant_config,
432
432
  prefix=add_prefix("lm_head", prefix),
433
433
  )
434
-
435
434
  else:
436
435
  # ranks other than the last rank will have a placeholder layer
437
436
  self.lm_head = PPMissingLayer()
@@ -452,6 +451,11 @@ class Qwen2ForCausalLM(nn.Module):
452
451
 
453
452
  self.logits_processor = LogitsProcessor(config)
454
453
  self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
454
+ # For EAGLE3 support
455
+ self.capture_aux_hidden_states = False
456
+
457
+ # For EAGLE3 support
458
+ self.capture_aux_hidden_states = False
455
459
 
456
460
  def get_input_embedding(self, input_ids: torch.Tensor) -> torch.Tensor:
457
461
  return self.model.get_input_embedding(input_ids)
@@ -476,11 +480,22 @@ class Qwen2ForCausalLM(nn.Module):
476
480
  input_embeds,
477
481
  pp_proxy_tensors=pp_proxy_tensors,
478
482
  )
483
+ aux_hidden_states = None
484
+ if self.capture_aux_hidden_states:
485
+ hidden_states, aux_hidden_states = hidden_states
486
+
487
+ aux_hidden_states = None
488
+ if self.capture_aux_hidden_states:
489
+ hidden_states, aux_hidden_states = hidden_states
479
490
 
480
491
  if self.pp_group.is_last_rank:
481
492
  if not get_embedding:
482
493
  return self.logits_processor(
483
- input_ids, hidden_states, self.lm_head, forward_batch
494
+ input_ids,
495
+ hidden_states,
496
+ self.lm_head,
497
+ forward_batch,
498
+ aux_hidden_states,
484
499
  )
485
500
  else:
486
501
  return self.pooler(hidden_states, forward_batch)
@@ -619,5 +634,20 @@ class Qwen2ForCausalLM(nn.Module):
619
634
  def load_kv_cache_scales(self, quantization_param_path: str) -> None:
620
635
  self.model.load_kv_cache_scales(quantization_param_path)
621
636
 
637
+ def set_eagle3_layers_to_capture(self, layer_ids: Optional[List[int]] = None):
638
+ if not self.pp_group.is_last_rank:
639
+ return
640
+
641
+ self.capture_aux_hidden_states = True
642
+ if layer_ids is None:
643
+ num_layers = self.config.num_hidden_layers
644
+ self.model.layers_to_capture = [
645
+ 2,
646
+ num_layers // 2,
647
+ num_layers - 3,
648
+ ] # Specific layers for EAGLE3 support
649
+ else:
650
+ self.model.layers_to_capture = [val + 1 for val in layer_ids]
651
+
622
652
 
623
653
  EntryClass = Qwen2ForCausalLM
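
The qwen2.py changes above wire Qwen2ForCausalLM into EAGLE3 speculative decoding: forward() now unpacks auxiliary hidden states when capture is enabled, and set_eagle3_layers_to_capture() chooses which decoder layers to tap. A small standalone sketch of the default selection logic shown in the diff (the helper name below is hypothetical, not part of the package):

    # Mirrors the default branch of set_eagle3_layers_to_capture():
    # capture an early, a middle, and a late decoder layer.
    def default_eagle3_capture_layers(num_hidden_layers: int) -> list[int]:
        return [2, num_hidden_layers // 2, num_hidden_layers - 3]

    # For a 28-layer Qwen2-style config this yields [2, 14, 25];
    # explicitly passed layer ids are instead shifted by one (val + 1).
    print(default_eagle3_capture_layers(28))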