sglang 0.4.3.post2__py3-none-any.whl → 0.4.3.post4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (205)
  1. sglang/api.py +1 -1
  2. sglang/bench_offline_throughput.py +19 -0
  3. sglang/bench_one_batch.py +2 -2
  4. sglang/bench_serving.py +123 -79
  5. sglang/global_config.py +8 -3
  6. sglang/lang/backend/runtime_endpoint.py +1 -1
  7. sglang/lang/ir.py +1 -1
  8. sglang/srt/_custom_ops.py +83 -91
  9. sglang/srt/configs/load_config.py +4 -1
  10. sglang/srt/configs/model_config.py +48 -2
  11. sglang/srt/configs/qwen2_5_vl_config.py +5 -2
  12. sglang/srt/constrained/base_grammar_backend.py +117 -15
  13. sglang/srt/constrained/llguidance_backend.py +151 -0
  14. sglang/srt/constrained/outlines_backend.py +24 -33
  15. sglang/srt/constrained/xgrammar_backend.py +69 -38
  16. sglang/srt/distributed/device_communicators/custom_all_reduce.py +225 -80
  17. sglang/srt/distributed/parallel_state.py +48 -3
  18. sglang/srt/entrypoints/engine.py +67 -9
  19. sglang/srt/entrypoints/http_server.py +190 -41
  20. sglang/srt/entrypoints/verl_engine.py +147 -0
  21. sglang/srt/function_call_parser.py +0 -1
  22. sglang/srt/layers/activation.py +11 -0
  23. sglang/srt/layers/attention/{__init__.py → base_attn_backend.py} +14 -6
  24. sglang/srt/layers/attention/double_sparsity_backend.py +1 -1
  25. sglang/srt/layers/attention/flashinfer_backend.py +302 -414
  26. sglang/srt/layers/attention/flashinfer_mla_backend.py +582 -0
  27. sglang/srt/layers/attention/torch_native_backend.py +1 -1
  28. sglang/srt/layers/attention/triton_backend.py +13 -8
  29. sglang/srt/layers/attention/triton_ops/decode_attention.py +3 -0
  30. sglang/srt/layers/attention/triton_ops/extend_attention.py +20 -4
  31. sglang/srt/layers/attention/triton_ops/rocm_mla_decode_rope.py +439 -0
  32. sglang/srt/layers/attention/utils.py +39 -0
  33. sglang/srt/layers/attention/vision.py +60 -63
  34. sglang/srt/layers/dp_attention.py +142 -1
  35. sglang/srt/layers/layernorm.py +1 -1
  36. sglang/srt/layers/linear.py +3 -1
  37. sglang/srt/layers/logits_processor.py +281 -45
  38. sglang/srt/layers/moe/ep_moe/kernels.py +126 -8
  39. sglang/srt/layers/moe/ep_moe/layer.py +140 -28
  40. sglang/srt/layers/moe/fused_moe_native.py +2 -0
  41. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  42. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +50 -50
  43. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI300X.json +18 -18
  44. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI325X.json +18 -18
  45. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Radeon_Graphics.json +18 -18
  46. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI300X.json +18 -18
  47. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI325X.json +18 -18
  48. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Radeon_Graphics.json +18 -18
  49. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI300X.json +18 -18
  50. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI325X.json +18 -18
  51. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Radeon_Graphics.json +18 -18
  52. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +16 -16
  53. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +16 -16
  54. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +16 -16
  55. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI300X.json +18 -18
  56. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI325X.json +18 -18
  57. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Radeon_Graphics.json +18 -18
  58. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +15 -15
  59. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +15 -15
  60. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +15 -15
  61. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +88 -20
  62. sglang/srt/layers/moe/fused_moe_triton/layer.py +34 -13
  63. sglang/srt/layers/moe/topk.py +13 -4
  64. sglang/srt/layers/quantization/__init__.py +111 -7
  65. sglang/srt/layers/quantization/blockwise_int8.py +409 -0
  66. sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  67. sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  68. sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  69. sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  70. sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  71. sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  72. sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  73. sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  74. sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  75. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  76. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  77. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  78. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  79. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  80. sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  81. sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  82. sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  83. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  84. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  85. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  86. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  87. sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  88. sglang/srt/layers/quantization/fp8.py +69 -28
  89. sglang/srt/layers/quantization/fp8_utils.py +17 -1
  90. sglang/srt/layers/quantization/gptq.py +416 -0
  91. sglang/srt/layers/quantization/int8_kernel.py +327 -0
  92. sglang/srt/layers/quantization/int8_utils.py +73 -0
  93. sglang/srt/layers/quantization/modelopt_quant.py +18 -1
  94. sglang/srt/layers/radix_attention.py +1 -0
  95. sglang/srt/layers/rotary_embedding.py +0 -1
  96. sglang/srt/layers/sampler.py +76 -31
  97. sglang/srt/layers/vocab_parallel_embedding.py +14 -13
  98. sglang/srt/lora/lora.py +17 -1
  99. sglang/srt/lora/lora_config.py +5 -0
  100. sglang/srt/lora/lora_manager.py +1 -3
  101. sglang/srt/managers/cache_controller.py +193 -62
  102. sglang/srt/managers/configure_logging.py +2 -1
  103. sglang/srt/managers/data_parallel_controller.py +6 -2
  104. sglang/srt/managers/detokenizer_manager.py +124 -102
  105. sglang/srt/managers/image_processor.py +2 -1
  106. sglang/srt/managers/io_struct.py +144 -6
  107. sglang/srt/managers/schedule_batch.py +237 -197
  108. sglang/srt/managers/schedule_policy.py +29 -29
  109. sglang/srt/managers/scheduler.py +773 -334
  110. sglang/srt/managers/session_controller.py +6 -2
  111. sglang/srt/managers/tokenizer_manager.py +225 -68
  112. sglang/srt/managers/tp_worker.py +15 -4
  113. sglang/srt/managers/tp_worker_overlap_thread.py +3 -4
  114. sglang/srt/mem_cache/chunk_cache.py +18 -11
  115. sglang/srt/mem_cache/hiradix_cache.py +394 -0
  116. sglang/srt/mem_cache/memory_pool.py +68 -37
  117. sglang/srt/mem_cache/radix_cache.py +58 -47
  118. sglang/srt/metrics/collector.py +102 -36
  119. sglang/srt/model_executor/cuda_graph_runner.py +56 -31
  120. sglang/srt/model_executor/forward_batch_info.py +49 -16
  121. sglang/srt/model_executor/model_runner.py +280 -81
  122. sglang/srt/model_loader/loader.py +3 -3
  123. sglang/srt/model_loader/weight_utils.py +36 -14
  124. sglang/srt/models/baichuan.py +31 -6
  125. sglang/srt/models/chatglm.py +39 -7
  126. sglang/srt/models/commandr.py +29 -5
  127. sglang/srt/models/dbrx.py +31 -5
  128. sglang/srt/models/deepseek.py +43 -6
  129. sglang/srt/models/deepseek_nextn.py +32 -19
  130. sglang/srt/models/deepseek_v2.py +265 -32
  131. sglang/srt/models/exaone.py +19 -9
  132. sglang/srt/models/gemma.py +22 -8
  133. sglang/srt/models/gemma2.py +25 -12
  134. sglang/srt/models/gemma2_reward.py +5 -1
  135. sglang/srt/models/gpt2.py +28 -13
  136. sglang/srt/models/gpt_bigcode.py +27 -5
  137. sglang/srt/models/granite.py +21 -9
  138. sglang/srt/models/grok.py +21 -4
  139. sglang/srt/models/internlm2.py +36 -6
  140. sglang/srt/models/internlm2_reward.py +5 -1
  141. sglang/srt/models/llama.py +26 -9
  142. sglang/srt/models/llama_classification.py +5 -1
  143. sglang/srt/models/llama_eagle.py +17 -4
  144. sglang/srt/models/llama_embedding.py +5 -1
  145. sglang/srt/models/llama_reward.py +7 -2
  146. sglang/srt/models/llava.py +19 -3
  147. sglang/srt/models/llavavid.py +10 -1
  148. sglang/srt/models/minicpm.py +26 -2
  149. sglang/srt/models/minicpm3.py +39 -3
  150. sglang/srt/models/minicpmv.py +45 -14
  151. sglang/srt/models/mixtral.py +20 -9
  152. sglang/srt/models/mixtral_quant.py +50 -8
  153. sglang/srt/models/mllama.py +57 -11
  154. sglang/srt/models/olmo.py +34 -6
  155. sglang/srt/models/olmo2.py +34 -13
  156. sglang/srt/models/olmoe.py +26 -4
  157. sglang/srt/models/phi3_small.py +29 -10
  158. sglang/srt/models/qwen.py +26 -3
  159. sglang/srt/models/qwen2.py +26 -4
  160. sglang/srt/models/qwen2_5_vl.py +46 -8
  161. sglang/srt/models/qwen2_eagle.py +17 -5
  162. sglang/srt/models/qwen2_moe.py +44 -6
  163. sglang/srt/models/qwen2_rm.py +78 -0
  164. sglang/srt/models/qwen2_vl.py +39 -8
  165. sglang/srt/models/stablelm.py +32 -5
  166. sglang/srt/models/torch_native_llama.py +5 -2
  167. sglang/srt/models/xverse.py +21 -9
  168. sglang/srt/models/xverse_moe.py +45 -7
  169. sglang/srt/models/yivl.py +2 -1
  170. sglang/srt/openai_api/adapter.py +109 -24
  171. sglang/srt/openai_api/protocol.py +17 -1
  172. sglang/srt/reasoning_parser.py +154 -0
  173. sglang/srt/sampling/penaltylib/__init__.py +4 -6
  174. sglang/srt/sampling/penaltylib/frequency_penalty.py +66 -0
  175. sglang/srt/sampling/penaltylib/{penalizers/min_new_tokens.py → min_new_tokens.py} +15 -23
  176. sglang/srt/sampling/penaltylib/orchestrator.py +39 -188
  177. sglang/srt/sampling/penaltylib/presence_penalty.py +66 -0
  178. sglang/srt/sampling/sampling_batch_info.py +79 -157
  179. sglang/srt/sampling/sampling_params.py +16 -13
  180. sglang/srt/server_args.py +135 -60
  181. sglang/srt/speculative/build_eagle_tree.py +8 -9
  182. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +1 -12
  183. sglang/srt/speculative/eagle_utils.py +92 -57
  184. sglang/srt/speculative/eagle_worker.py +238 -111
  185. sglang/srt/speculative/spec_info.py +1 -13
  186. sglang/srt/utils.py +43 -17
  187. sglang/srt/warmup.py +47 -0
  188. sglang/test/few_shot_gsm8k.py +4 -1
  189. sglang/test/runners.py +389 -126
  190. sglang/test/send_one.py +88 -0
  191. sglang/test/test_block_fp8_ep.py +361 -0
  192. sglang/test/test_programs.py +1 -1
  193. sglang/test/test_utils.py +138 -84
  194. sglang/utils.py +50 -60
  195. sglang/version.py +1 -1
  196. {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post4.dist-info}/METADATA +22 -15
  197. {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post4.dist-info}/RECORD +200 -166
  198. {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post4.dist-info}/WHEEL +1 -1
  199. sglang/bench_latency.py +0 -1
  200. sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py +0 -75
  201. sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py +0 -74
  202. sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +0 -85
  203. sglang/test/srt/sampling/penaltylib/utils.py +0 -344
  204. {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post4.dist-info}/LICENSE +0 -0
  205. {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post4.dist-info}/top_level.txt +0 -0
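
Most of the model-file hunks shown below follow a single pattern: each constructor gains a `prefix: str = ""` parameter and forwards a dotted, per-submodule prefix to its children via the `add_prefix` helper newly imported from `sglang/srt/utils.py`. The snippet below is a minimal illustrative sketch of that convention, assuming only that `add_prefix` joins a child name onto its parent's prefix with a dot (which is all the call sites in these diffs rely on); it is not the library's implementation.

# Hedged sketch of the dotted-prefix convention, not sglang's actual helper.
def add_prefix(name: str, prefix: str) -> str:
    # Join a child module name onto its parent's prefix; an empty prefix
    # means the child name is already fully qualified.
    return name if not prefix else f"{prefix}.{name}"

# With every constructor forwarding its prefix, nested modules end up with
# checkpoint-style weight names such as "model.layers.0.self_attn.qkv_proj".
model_prefix = add_prefix("model", "")
layer_prefix = add_prefix("layers.0", model_prefix)
attn_prefix = add_prefix("self_attn", layer_prefix)
print(add_prefix("qkv_proj", attn_prefix))  # -> model.layers.0.self_attn.qkv_proj
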
sglang/srt/models/qwen2_moe.py

@@ -46,6 +46,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
 )
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.utils import add_prefix


 class Qwen2MoeMLP(nn.Module):
@@ -56,10 +57,15 @@ class Qwen2MoeMLP(nn.Module):
         hidden_act: str,
         quant_config: Optional[QuantizationConfig] = None,
         reduce_results: bool = True,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.gate_up_proj = MergedColumnParallelLinear(
-            hidden_size, [intermediate_size] * 2, bias=False, quant_config=quant_config
+            hidden_size,
+            [intermediate_size] * 2,
+            bias=False,
+            quant_config=quant_config,
+            prefix=add_prefix("gate_up_proj", prefix),
         )
         self.down_proj = RowParallelLinear(
             intermediate_size,
@@ -67,6 +73,7 @@ class Qwen2MoeMLP(nn.Module):
             bias=False,
             quant_config=quant_config,
             reduce_results=reduce_results,
+            prefix=add_prefix("down_proj", prefix),
         )
         if hidden_act != "silu":
             raise ValueError(
@@ -87,6 +94,7 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
         self,
         config: PretrainedConfig,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ):
         super().__init__()
         self.tp_size = get_tensor_model_parallel_world_size()
@@ -105,10 +113,15 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
             reduce_results=False,
             renormalize=config.norm_topk_prob,
             quant_config=quant_config,
+            prefix=add_prefix("experts", prefix),
         )

         self.gate = ReplicatedLinear(
-            config.hidden_size, config.num_experts, bias=False, quant_config=None
+            config.hidden_size,
+            config.num_experts,
+            bias=False,
+            quant_config=None,
+            prefix=add_prefix("gate", prefix),
         )
         if config.shared_expert_intermediate_size > 0:
             self.shared_expert = Qwen2MoeMLP(
@@ -117,6 +130,7 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
                 hidden_act=config.hidden_act,
                 quant_config=quant_config,
                 reduce_results=False,
+                prefix=add_prefix("shared_expert", prefix),
             )
         else:
             self.shared_expert = None
@@ -157,6 +171,7 @@ class Qwen2MoeAttention(nn.Module):
         rope_scaling: Optional[Dict[str, Any]] = None,
         max_position_embeddings: int = 8192,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.hidden_size = hidden_size
@@ -188,6 +203,7 @@ class Qwen2MoeAttention(nn.Module):
             self.total_num_kv_heads,
             bias=True,
             quant_config=quant_config,
+            prefix=add_prefix("qkv_proj", prefix),
         )

         self.o_proj = RowParallelLinear(
@@ -195,6 +211,7 @@ class Qwen2MoeAttention(nn.Module):
             hidden_size,
             bias=False,
             quant_config=quant_config,
+            prefix=add_prefix("o_proj", prefix),
         )

         self.rotary_emb = get_rope(
@@ -210,6 +227,7 @@ class Qwen2MoeAttention(nn.Module):
             self.scaling,
             num_kv_heads=self.num_kv_heads,
             layer_id=layer_id,
+            prefix=add_prefix("attn", prefix),
         )

     def forward(
@@ -232,6 +250,7 @@ class Qwen2MoeDecoderLayer(nn.Module):
         config: PretrainedConfig,
         layer_id: int,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.hidden_size = config.hidden_size
@@ -247,6 +266,7 @@ class Qwen2MoeDecoderLayer(nn.Module):
             rope_scaling=rope_scaling,
             max_position_embeddings=max_position_embeddings,
             quant_config=quant_config,
+            prefix=add_prefix("self_attn", prefix),
         )

         # Note: Qwen/Qwen2-57B-A14B-Instruct does not have
@@ -257,13 +277,18 @@ class Qwen2MoeDecoderLayer(nn.Module):
         if (layer_id not in mlp_only_layers) and (
             config.num_experts > 0 and (layer_id + 1) % config.decoder_sparse_step == 0
         ):
-            self.mlp = Qwen2MoeSparseMoeBlock(config=config, quant_config=quant_config)
+            self.mlp = Qwen2MoeSparseMoeBlock(
+                config=config,
+                quant_config=quant_config,
+                prefix=add_prefix("mlp", prefix),
+            )
         else:
             self.mlp = Qwen2MoeMLP(
                 hidden_size=config.hidden_size,
                 intermediate_size=config.intermediate_size,
                 hidden_act=config.hidden_act,
                 quant_config=quant_config,
+                prefix=add_prefix("mlp", prefix),
             )
         self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.post_attention_layernorm = RMSNorm(
@@ -300,6 +325,7 @@ class Qwen2MoeModel(nn.Module):
         self,
         config: PretrainedConfig,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.padding_idx = config.pad_token_id
@@ -308,10 +334,16 @@ class Qwen2MoeModel(nn.Module):
         self.embed_tokens = VocabParallelEmbedding(
             config.vocab_size,
             config.hidden_size,
+            prefix=add_prefix("embed_tokens", prefix),
         )
         self.layers = nn.ModuleList(
             [
-                Qwen2MoeDecoderLayer(config, layer_id, quant_config=quant_config)
+                Qwen2MoeDecoderLayer(
+                    config,
+                    layer_id,
+                    quant_config=quant_config,
+                    prefix=add_prefix(f"layers.{layer_id}", prefix),
+                )
                 for layer_id in range(config.num_hidden_layers)
             ]
         )
@@ -346,13 +378,19 @@ class Qwen2MoeForCausalLM(nn.Module):
         self,
         config: PretrainedConfig,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.config = config
         self.quant_config = quant_config
-        self.model = Qwen2MoeModel(config, quant_config)
+        self.model = Qwen2MoeModel(
+            config, quant_config, prefix=add_prefix("model", prefix)
+        )
         self.lm_head = ParallelLMHead(
-            config.vocab_size, config.hidden_size, quant_config=quant_config
+            config.vocab_size,
+            config.hidden_size,
+            quant_config=quant_config,
+            prefix=add_prefix("lm_head", prefix),
         )
         self.logits_processor = LogitsProcessor(config)

sglang/srt/models/qwen2_rm.py (new file)

@@ -0,0 +1,78 @@
+# Copyright 2023-2024 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+from typing import Iterable, Optional, Tuple
+
+import torch
+from torch import nn
+from transformers import Qwen2Config
+
+from sglang.srt.layers.pooler import EmbeddingPoolerOutput, Pooler, PoolingType
+from sglang.srt.layers.quantization.base_config import QuantizationConfig
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.models.qwen2 import Qwen2ForCausalLM, Qwen2Model
+from sglang.srt.utils import add_prefix
+
+
+class Qwen2ForRewardModel(nn.Module):
+    def __init__(
+        self,
+        config: Qwen2Config,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+        self.config = config
+        self.quant_config = quant_config
+        self.num_labels = 1
+        self.model = Qwen2Model(
+            config, quant_config=quant_config, prefix=add_prefix("model", prefix)
+        )
+        self.score = nn.Sequential(
+            nn.Linear(config.hidden_size, config.hidden_size),
+            nn.ReLU(),
+            nn.Linear(config.hidden_size, self.num_labels),
+        )
+        self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=False)
+
+        self.eos_token_id = config.eos_token_id
+
+    @torch.no_grad()
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        input_embeds: torch.Tensor = None,
+        get_embedding: bool = True,
+    ) -> EmbeddingPoolerOutput:
+        assert get_embedding, "Qwen2ForRewardModel is only used for embedding"
+
+        hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
+        logits = self.score(hidden_states)
+        pooled_logits = self.pooler(logits, forward_batch).embeddings
+
+        return EmbeddingPoolerOutput(pooled_logits)
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        # Filter out lm_head weights of Qwen2ForCausalLM
+        filtered_weights = [
+            (name, w) for name, w in weights if not name.startswith("lm_head")
+        ]
+        return Qwen2ForCausalLM.load_weights(self, filtered_weights)
+
+
+EntryClass = [
+    Qwen2ForRewardModel,
+]

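Qwen2ForRewardModel reuses `Qwen2ForCausalLM.load_weights` after discarding any `lm_head.*` tensors, so a causal-LM checkpoint can be loaded into the reward model, whose head is the `score` MLP rather than an LM head. Below is a self-contained sketch of that filtering step, using made-up weight names and dummy tensors rather than sglang APIs.

# Illustrative only: mirrors the lm_head filtering in Qwen2ForRewardModel.load_weights.
import torch

weights = [
    ("model.layers.0.self_attn.qkv_proj.weight", torch.zeros(8, 8)),
    ("lm_head.weight", torch.zeros(8, 8)),  # present in causal-LM checkpoints, dropped here
    ("score.0.weight", torch.zeros(8, 8)),  # reward head, kept
]
filtered = [(name, w) for name, w in weights if not name.startswith("lm_head")]
print([name for name, _ in filtered])
# -> ['model.layers.0.self_attn.qkv_proj.weight', 'score.0.weight']
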
sglang/srt/models/qwen2_vl.py

@@ -46,6 +46,7 @@ from sglang.srt.managers.schedule_batch import ImageInputs
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.qwen2 import Qwen2Model
+from sglang.srt.utils import add_prefix

 logger = logging.getLogger(__name__)

@@ -91,14 +92,21 @@ class Qwen2VisionMLP(nn.Module):
         hidden_features: int = None,
         act_layer: Type[nn.Module] = QuickGELU,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ):
         super().__init__()
         self.fc1 = ColumnParallelLinear(
-            in_features, hidden_features, quant_config=quant_config
+            in_features,
+            hidden_features,
+            quant_config=quant_config,
+            prefix=add_prefix("fc1", prefix),
         )
         self.act = act_layer()
         self.fc2 = RowParallelLinear(
-            hidden_features, in_features, quant_config=quant_config
+            hidden_features,
+            in_features,
+            quant_config=quant_config,
+            prefix=add_prefix("fc2", prefix),
         )

     def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -119,6 +127,7 @@ class Qwen2VisionBlock(nn.Module):
         norm_layer: Type[nn.Module] = None,
         attn_implementation: Optional[str] = "sdpa",
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         if norm_layer is None:
@@ -145,9 +154,14 @@ class Qwen2VisionBlock(nn.Module):
             use_full_precision_softmax=use_full_precision_softmax,
             flatten_batch=True,
             quant_config=quant_config,
+            prefix=add_prefix("attn", prefix),
         )
         self.mlp = Qwen2VisionMLP(
-            dim, mlp_hidden_dim, act_layer=act_layer, quant_config=quant_config
+            dim,
+            mlp_hidden_dim,
+            act_layer=act_layer,
+            quant_config=quant_config,
+            prefix=add_prefix("mlp", prefix),
         )

     def forward(
@@ -199,6 +213,7 @@ class Qwen2VisionPatchMerger(nn.Module):
         norm_layer: Type[nn.Module] = None,
         spatial_merge_size: int = 2,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.hidden_size = context_dim * (spatial_merge_size**2)
@@ -212,10 +227,15 @@ class Qwen2VisionPatchMerger(nn.Module):
                     self.hidden_size,
                     bias=True,
                     quant_config=quant_config,
+                    prefix=add_prefix("mlp.0", prefix),
                 ),
                 nn.GELU(),
                 RowParallelLinear(
-                    self.hidden_size, d_model, bias=True, quant_config=quant_config
+                    self.hidden_size,
+                    d_model,
+                    bias=True,
+                    quant_config=quant_config,
+                    prefix=add_prefix("mlp.2", prefix),
                 ),
             ]
         )
@@ -273,6 +293,7 @@ class Qwen2VisionTransformer(nn.Module):
         vision_config: Qwen2VLVisionConfig,
         norm_eps: float = 1e-6,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()

@@ -307,8 +328,9 @@ class Qwen2VisionTransformer(nn.Module):
                     norm_layer=norm_layer,
                     attn_implementation="sdpa",
                     quant_config=quant_config,
+                    prefix=add_prefix(f"blocks.{i}", prefix),
                 )
-                for _ in range(depth)
+                for i in range(depth)
             ]
         )
         self.merger = Qwen2VisionPatchMerger(
@@ -316,6 +338,7 @@ class Qwen2VisionTransformer(nn.Module):
             context_dim=embed_dim,
             norm_layer=norm_layer,
             quant_config=quant_config,
+            prefix=add_prefix("merger", prefix),
         )

     @property
@@ -440,6 +463,7 @@ class Qwen2VLForConditionalGeneration(nn.Module):
         self,
         config: Qwen2VLConfig,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()

@@ -450,15 +474,21 @@ class Qwen2VLForConditionalGeneration(nn.Module):
            # NOTE: Qwen2-VL vision encoder does not support any
            # quantization method now.
            quant_config=None,
+           prefix=add_prefix("visual", prefix),
        )

-        self.model = Qwen2Model(config, quant_config)
+        self.model = Qwen2Model(
+            config, quant_config, prefix=add_prefix("model", prefix)
+        )

        if config.tie_word_embeddings:
            self.lm_head = self.model.embed_tokens
        else:
            self.lm_head = ParallelLMHead(
-                config.vocab_size, config.hidden_size, quant_config=quant_config
+                config.vocab_size,
+                config.hidden_size,
+                quant_config=quant_config,
+                prefix=add_prefix("lm_head", prefix),
            )

        self.logits_processor = LogitsProcessor(config)
@@ -559,7 +589,6 @@ class Qwen2VLForConditionalGeneration(nn.Module):
                    ]
                    image_embeds_offset += num_image_tokens

-        input_ids = None
        hidden_states = self.model(
            input_ids=input_ids,
            positions=positions,
@@ -587,6 +616,8 @@ class Qwen2VLForConditionalGeneration(nn.Module):
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue
+           if self.config.tie_word_embeddings and "lm_head.weight" in name:
+               continue

            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:

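The new guard in `load_weights` above skips `lm_head.weight` when `config.tie_word_embeddings` is set, presumably because the tied branch assigns `self.lm_head = self.model.embed_tokens`, so loading the embedding weight already populates the head and there is no separate `lm_head` parameter to fill. A toy PyTorch demonstration of that aliasing follows (hypothetical shapes, not sglang code).

# Toy demonstration of weight tying: the head and the embedding are the same
# module, so one weight load updates both.
import torch
from torch import nn

embed_tokens = nn.Embedding(1000, 16)
lm_head = embed_tokens  # what `self.lm_head = self.model.embed_tokens` does

new_weight = torch.randn(1000, 16)
embed_tokens.weight.data.copy_(new_weight)      # loading "model.embed_tokens.weight"
assert torch.equal(lm_head.weight, new_weight)  # the head sees the same tensor
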
sglang/srt/models/stablelm.py

@@ -42,6 +42,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
 )
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.utils import add_prefix


 class StablelmMLP(nn.Module):
@@ -49,6 +50,7 @@ class StablelmMLP(nn.Module):
         self,
         config: PretrainedConfig,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.config = config
@@ -59,12 +61,14 @@ class StablelmMLP(nn.Module):
             [config.intermediate_size] * 2,
             bias=False,
             quant_config=quant_config,
+            prefix=add_prefix("gate_up_proj", prefix),
         )
         self.down_proj = RowParallelLinear(
             config.intermediate_size,
             config.hidden_size,
             bias=False,
             quant_config=quant_config,
+            prefix=add_prefix("down_proj", prefix),
         )
         self.act_fn = SiluAndMul()

@@ -81,6 +85,7 @@ class StablelmAttention(nn.Module):
         config: PretrainedConfig,
         layer_id: int = 0,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.config = config
@@ -122,11 +127,15 @@ class StablelmAttention(nn.Module):
             self.total_num_heads,
             self.total_num_key_value_heads,
             self.qkv_bias,
+            quant_config=quant_config,
+            prefix=add_prefix("qkv_proj", prefix),
         )
         self.o_proj = RowParallelLinear(
             self.total_num_heads * self.head_dim,
             self.hidden_size,
             bias=False,
+            quant_config=quant_config,
+            prefix=add_prefix("o_proj", prefix),
         )
         self.rotary_emb = get_rope(
             self.head_dim,
@@ -140,6 +149,7 @@ class StablelmAttention(nn.Module):
             self.scaling,
             num_kv_heads=self.num_key_value_heads,
             layer_id=layer_id,
+            prefix=add_prefix("attn", prefix),
         )

     def forward(
@@ -162,10 +172,15 @@ class StablelmDecoderLayer(nn.Module):
         config: PretrainedConfig,
         layer_id: int = 0,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
-        self.self_attn = StablelmAttention(config, layer_id=layer_id)
-        self.mlp = StablelmMLP(config, quant_config=quant_config)
+        self.self_attn = StablelmAttention(
+            config, layer_id=layer_id, prefix=add_prefix("self_attn", prefix)
+        )
+        self.mlp = StablelmMLP(
+            config, quant_config=quant_config, prefix=add_prefix("mlp", prefix)
+        )
         norm_eps = getattr(config, "norm_eps", getattr(config, "layer_norm_eps", 1e-05))
         self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=norm_eps)
         self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=norm_eps)
@@ -200,15 +215,22 @@ class StableLMEpochModel(nn.Module):
         self,
         config: PretrainedConfig,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.embed_tokens = VocabParallelEmbedding(
             config.vocab_size,
             config.hidden_size,
+            prefix=add_prefix("embed_tokens", prefix),
         )
         self.layers = nn.ModuleList(
             [
-                StablelmDecoderLayer(config, i, quant_config=quant_config)
+                StablelmDecoderLayer(
+                    config,
+                    i,
+                    quant_config=quant_config,
+                    prefix=add_prefix(f"layers.{i}", prefix),
+                )
                 for i in range(config.num_hidden_layers)
             ]
         )
@@ -242,12 +264,17 @@ class StableLmForCausalLM(nn.Module):
         self,
         config: PretrainedConfig,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.config = config
         self.quant_config = quant_config
-        self.model = StableLMEpochModel(config, quant_config=quant_config)
-        self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
+        self.model = StableLMEpochModel(
+            config, quant_config=quant_config, prefix=add_prefix("model", prefix)
+        )
+        self.lm_head = ParallelLMHead(
+            config.vocab_size, config.hidden_size, prefix=add_prefix("lm_head", prefix)
+        )
         self.logits_processor = LogitsProcessor(config)

     @torch.no_grad()

sglang/srt/models/torch_native_llama.py

@@ -64,6 +64,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
 )
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.utils import add_prefix

 tp_size = get_tensor_model_parallel_world_size()
 tp_rank = get_tensor_model_parallel_rank()
@@ -294,14 +295,14 @@ class LlamaDecoderLayer(nn.Module):
             rope_is_neox_style=rope_is_neox_style,
             max_position_embeddings=max_position_embeddings,
             quant_config=quant_config,
-            prefix=f"{prefix}.self_attn",
+            prefix=add_prefix("self_attn", prefix),
         )
         self.mlp = LlamaMLP(
             hidden_size=self.hidden_size,
             intermediate_size=config.intermediate_size,
             hidden_act=config.hidden_act,
             quant_config=quant_config,
-            prefix=f"{prefix}.mlp",
+            prefix=add_prefix("mlp", prefix),
         )
         self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.post_attention_layernorm = RMSNorm(
@@ -486,6 +487,8 @@ class TorchNativeLlamaForCausalLM(nn.Module):
                 continue
             if name.startswith("model.vision_tower") and name not in params_dict:
                 continue
+            if self.config.tie_word_embeddings and "lm_head.weight" in name:
+                continue

             for param_name, weight_name, shard_id in stacked_params_mapping:
                 if weight_name not in name:

sglang/srt/models/xverse.py

@@ -40,6 +40,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
 )
 from sglang.srt.model_executor.model_runner import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.utils import add_prefix


 class XverseMLP(nn.Module):
@@ -57,14 +58,14 @@ class XverseMLP(nn.Module):
             [intermediate_size] * 2,
             bias=False,
             quant_config=quant_config,
-            prefix=f"{prefix}.gate_up_proj",
+            prefix=add_prefix("gate_up_proj", prefix),
         )
         self.down_proj = RowParallelLinear(
             intermediate_size,
             hidden_size,
             bias=False,
             quant_config=quant_config,
-            prefix=f"{prefix}.down_proj",
+            prefix=add_prefix("down_proj", prefix),
         )
         if hidden_act != "silu":
             raise ValueError(
@@ -128,14 +129,14 @@ class XverseAttention(nn.Module):
             self.total_num_kv_heads,
             bias=False,
             quant_config=quant_config,
-            prefix=f"{prefix}.qkv_proj",
+            prefix=add_prefix("qkv_proj", prefix),
         )
         self.o_proj = RowParallelLinear(
             self.total_num_heads * self.head_dim,
             hidden_size,
             bias=False,
             quant_config=quant_config,
-            prefix=f"{prefix}.o_proj",
+            prefix=add_prefix("o_proj", prefix),
         )

         self.rotary_emb = get_rope(
@@ -152,6 +153,7 @@ class XverseAttention(nn.Module):
             self.scaling,
             num_kv_heads=self.num_kv_heads,
             layer_id=layer_id,
+            prefix=add_prefix("attn", prefix),
         )

     def forward(
@@ -202,14 +204,14 @@ class XverseDecoderLayer(nn.Module):
             rope_is_neox_style=rope_is_neox_style,
             max_position_embeddings=max_position_embeddings,
             quant_config=quant_config,
-            prefix=f"{prefix}.self_attn",
+            prefix=add_prefix("self_attn", prefix),
         )
         self.mlp = XverseMLP(
             hidden_size=self.hidden_size,
             intermediate_size=config.intermediate_size,
             hidden_act=config.hidden_act,
             quant_config=quant_config,
-            prefix=f"{prefix}.mlp",
+            prefix=add_prefix("mlp", prefix),
         )
         self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.post_attention_layernorm = RMSNorm(
@@ -246,6 +248,7 @@ class XverseModel(nn.Module):
         self,
         config: LlamaConfig,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.config = config
@@ -254,11 +257,15 @@ class XverseModel(nn.Module):
         self.embed_tokens = VocabParallelEmbedding(
             config.vocab_size,
             config.hidden_size,
+            prefix=add_prefix("embed_tokens", prefix),
         )
         self.layers = nn.ModuleList(
             [
                 XverseDecoderLayer(
-                    config, i, quant_config=quant_config, prefix=f"model.layers.{i}"
+                    config,
+                    i,
+                    quant_config=quant_config,
+                    prefix=add_prefix(f"layers.{i}", prefix),
                 )
                 for i in range(config.num_hidden_layers)
             ]
@@ -295,12 +302,17 @@ class XverseForCausalLM(nn.Module):
         self,
         config: LlamaConfig,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.config = config
         self.quant_config = quant_config
-        self.model = XverseModel(config, quant_config=quant_config)
-        self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
+        self.model = XverseModel(
+            config, quant_config=quant_config, prefix=add_prefix("model", prefix)
+        )
+        self.lm_head = ParallelLMHead(
+            config.vocab_size, config.hidden_size, prefix=add_prefix("lm_head", prefix)
+        )
         self.logits_processor = LogitsProcessor(config)

     @torch.no_grad()