sglang 0.4.3.post1__py3-none-any.whl → 0.4.3.post3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (219)
  1. sglang/api.py +1 -1
  2. sglang/bench_offline_throughput.py +19 -0
  3. sglang/bench_one_batch.py +2 -2
  4. sglang/bench_serving.py +123 -79
  5. sglang/global_config.py +8 -3
  6. sglang/lang/backend/runtime_endpoint.py +1 -1
  7. sglang/lang/ir.py +1 -1
  8. sglang/srt/_custom_ops.py +83 -91
  9. sglang/srt/configs/load_config.py +4 -1
  10. sglang/srt/configs/model_config.py +48 -2
  11. sglang/srt/configs/qwen2_5_vl_config.py +5 -2
  12. sglang/srt/constrained/base_grammar_backend.py +117 -15
  13. sglang/srt/constrained/llguidance_backend.py +151 -0
  14. sglang/srt/constrained/outlines_backend.py +24 -33
  15. sglang/srt/constrained/xgrammar_backend.py +69 -38
  16. sglang/srt/distributed/device_communicators/custom_all_reduce.py +225 -80
  17. sglang/srt/distributed/parallel_state.py +48 -3
  18. sglang/srt/entrypoints/engine.py +67 -9
  19. sglang/srt/entrypoints/http_server.py +190 -41
  20. sglang/srt/entrypoints/verl_engine.py +147 -0
  21. sglang/srt/function_call_parser.py +0 -1
  22. sglang/srt/layers/activation.py +11 -0
  23. sglang/srt/layers/attention/{__init__.py → base_attn_backend.py} +14 -6
  24. sglang/srt/layers/attention/double_sparsity_backend.py +1 -1
  25. sglang/srt/layers/attention/flashinfer_backend.py +208 -295
  26. sglang/srt/layers/attention/flashinfer_mla_backend.py +582 -0
  27. sglang/srt/layers/attention/torch_native_backend.py +1 -1
  28. sglang/srt/layers/attention/triton_backend.py +9 -6
  29. sglang/srt/layers/attention/triton_ops/decode_attention.py +3 -0
  30. sglang/srt/layers/attention/triton_ops/extend_attention.py +20 -4
  31. sglang/srt/layers/attention/triton_ops/rocm_mla_decode_rope.py +439 -0
  32. sglang/srt/layers/attention/utils.py +39 -0
  33. sglang/srt/layers/attention/vision.py +60 -63
  34. sglang/srt/layers/dp_attention.py +142 -1
  35. sglang/srt/layers/layernorm.py +1 -1
  36. sglang/srt/layers/linear.py +3 -1
  37. sglang/srt/layers/logits_processor.py +281 -45
  38. sglang/srt/layers/moe/ep_moe/kernels.py +126 -8
  39. sglang/srt/layers/moe/ep_moe/layer.py +140 -28
  40. sglang/srt/layers/moe/fused_moe_native.py +2 -0
  41. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  42. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  43. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +50 -50
  44. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI300X.json +18 -18
  45. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI325X.json +18 -18
  46. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Radeon_Graphics.json +18 -18
  47. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI300X.json +18 -18
  48. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI325X.json +18 -18
  49. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Radeon_Graphics.json +18 -18
  50. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI300X.json +18 -18
  51. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI325X.json +18 -18
  52. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Radeon_Graphics.json +18 -18
  53. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +16 -16
  54. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +16 -16
  55. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +16 -16
  56. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI300X.json +18 -18
  57. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI325X.json +18 -18
  58. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Radeon_Graphics.json +18 -18
  59. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +15 -15
  60. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +15 -15
  61. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +15 -15
  62. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +88 -20
  63. sglang/srt/layers/moe/fused_moe_triton/layer.py +34 -13
  64. sglang/srt/layers/moe/topk.py +13 -4
  65. sglang/srt/layers/quantization/__init__.py +111 -7
  66. sglang/srt/layers/quantization/blockwise_int8.py +409 -0
  67. sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  68. sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  69. sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  70. sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  71. sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  72. sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  73. sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  74. sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  75. sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  76. sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  77. sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  78. sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  79. sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  80. sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  81. sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  82. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  83. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  84. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  85. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  86. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  87. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  88. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  89. sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  90. sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  91. sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  92. sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  93. sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  94. sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  95. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  96. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  97. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  98. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  99. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  100. sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  101. sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  102. sglang/srt/layers/quantization/fp8.py +69 -28
  103. sglang/srt/layers/quantization/fp8_utils.py +17 -1
  104. sglang/srt/layers/quantization/gptq.py +416 -0
  105. sglang/srt/layers/quantization/int8_kernel.py +327 -0
  106. sglang/srt/layers/quantization/int8_utils.py +73 -0
  107. sglang/srt/layers/quantization/modelopt_quant.py +18 -1
  108. sglang/srt/layers/radix_attention.py +1 -0
  109. sglang/srt/layers/rotary_embedding.py +0 -1
  110. sglang/srt/layers/sampler.py +76 -31
  111. sglang/srt/layers/vocab_parallel_embedding.py +14 -13
  112. sglang/srt/lora/lora.py +17 -1
  113. sglang/srt/lora/lora_config.py +5 -0
  114. sglang/srt/lora/lora_manager.py +1 -3
  115. sglang/srt/managers/cache_controller.py +193 -62
  116. sglang/srt/managers/configure_logging.py +2 -1
  117. sglang/srt/managers/data_parallel_controller.py +6 -2
  118. sglang/srt/managers/detokenizer_manager.py +124 -102
  119. sglang/srt/managers/image_processor.py +2 -1
  120. sglang/srt/managers/io_struct.py +143 -6
  121. sglang/srt/managers/schedule_batch.py +238 -197
  122. sglang/srt/managers/schedule_policy.py +29 -29
  123. sglang/srt/managers/scheduler.py +681 -259
  124. sglang/srt/managers/session_controller.py +6 -2
  125. sglang/srt/managers/tokenizer_manager.py +224 -68
  126. sglang/srt/managers/tp_worker.py +15 -4
  127. sglang/srt/managers/tp_worker_overlap_thread.py +3 -4
  128. sglang/srt/mem_cache/chunk_cache.py +18 -11
  129. sglang/srt/mem_cache/hiradix_cache.py +394 -0
  130. sglang/srt/mem_cache/memory_pool.py +44 -18
  131. sglang/srt/mem_cache/radix_cache.py +58 -47
  132. sglang/srt/metrics/collector.py +94 -36
  133. sglang/srt/model_executor/cuda_graph_runner.py +55 -24
  134. sglang/srt/model_executor/forward_batch_info.py +49 -16
  135. sglang/srt/model_executor/model_runner.py +209 -28
  136. sglang/srt/model_loader/loader.py +3 -3
  137. sglang/srt/model_loader/weight_utils.py +36 -14
  138. sglang/srt/models/baichuan.py +31 -6
  139. sglang/srt/models/chatglm.py +39 -7
  140. sglang/srt/models/commandr.py +29 -5
  141. sglang/srt/models/dbrx.py +31 -5
  142. sglang/srt/models/deepseek.py +43 -6
  143. sglang/srt/models/deepseek_nextn.py +32 -19
  144. sglang/srt/models/deepseek_v2.py +265 -29
  145. sglang/srt/models/exaone.py +19 -9
  146. sglang/srt/models/gemma.py +22 -8
  147. sglang/srt/models/gemma2.py +25 -12
  148. sglang/srt/models/gemma2_reward.py +5 -1
  149. sglang/srt/models/gpt2.py +28 -13
  150. sglang/srt/models/gpt_bigcode.py +27 -5
  151. sglang/srt/models/granite.py +21 -9
  152. sglang/srt/models/grok.py +21 -4
  153. sglang/srt/models/internlm2.py +36 -6
  154. sglang/srt/models/internlm2_reward.py +5 -1
  155. sglang/srt/models/llama.py +26 -9
  156. sglang/srt/models/llama_classification.py +5 -1
  157. sglang/srt/models/llama_eagle.py +17 -4
  158. sglang/srt/models/llama_embedding.py +5 -1
  159. sglang/srt/models/llama_reward.py +7 -2
  160. sglang/srt/models/llava.py +19 -3
  161. sglang/srt/models/llavavid.py +10 -1
  162. sglang/srt/models/minicpm.py +26 -2
  163. sglang/srt/models/minicpm3.py +39 -3
  164. sglang/srt/models/minicpmv.py +45 -14
  165. sglang/srt/models/mixtral.py +20 -9
  166. sglang/srt/models/mixtral_quant.py +50 -8
  167. sglang/srt/models/mllama.py +57 -11
  168. sglang/srt/models/olmo.py +34 -6
  169. sglang/srt/models/olmo2.py +34 -13
  170. sglang/srt/models/olmoe.py +26 -4
  171. sglang/srt/models/phi3_small.py +29 -10
  172. sglang/srt/models/qwen.py +26 -3
  173. sglang/srt/models/qwen2.py +26 -4
  174. sglang/srt/models/qwen2_5_vl.py +46 -8
  175. sglang/srt/models/qwen2_eagle.py +17 -5
  176. sglang/srt/models/qwen2_moe.py +44 -6
  177. sglang/srt/models/qwen2_rm.py +78 -0
  178. sglang/srt/models/qwen2_vl.py +39 -8
  179. sglang/srt/models/stablelm.py +32 -5
  180. sglang/srt/models/torch_native_llama.py +5 -2
  181. sglang/srt/models/xverse.py +21 -9
  182. sglang/srt/models/xverse_moe.py +45 -7
  183. sglang/srt/models/yivl.py +2 -1
  184. sglang/srt/openai_api/adapter.py +109 -24
  185. sglang/srt/openai_api/protocol.py +17 -1
  186. sglang/srt/reasoning_parser.py +154 -0
  187. sglang/srt/sampling/penaltylib/__init__.py +4 -6
  188. sglang/srt/sampling/penaltylib/frequency_penalty.py +66 -0
  189. sglang/srt/sampling/penaltylib/{penalizers/min_new_tokens.py → min_new_tokens.py} +15 -23
  190. sglang/srt/sampling/penaltylib/orchestrator.py +39 -188
  191. sglang/srt/sampling/penaltylib/presence_penalty.py +66 -0
  192. sglang/srt/sampling/sampling_batch_info.py +79 -157
  193. sglang/srt/sampling/sampling_params.py +16 -13
  194. sglang/srt/server_args.py +136 -52
  195. sglang/srt/speculative/build_eagle_tree.py +2 -8
  196. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +0 -1
  197. sglang/srt/speculative/eagle_utils.py +92 -58
  198. sglang/srt/speculative/eagle_worker.py +186 -94
  199. sglang/srt/speculative/spec_info.py +1 -13
  200. sglang/srt/utils.py +43 -17
  201. sglang/srt/warmup.py +47 -0
  202. sglang/test/few_shot_gsm8k.py +4 -1
  203. sglang/test/runners.py +389 -126
  204. sglang/test/send_one.py +88 -0
  205. sglang/test/test_block_fp8_ep.py +361 -0
  206. sglang/test/test_programs.py +1 -1
  207. sglang/test/test_utils.py +138 -84
  208. sglang/utils.py +50 -60
  209. sglang/version.py +1 -1
  210. {sglang-0.4.3.post1.dist-info → sglang-0.4.3.post3.dist-info}/METADATA +21 -15
  211. {sglang-0.4.3.post1.dist-info → sglang-0.4.3.post3.dist-info}/RECORD +214 -166
  212. {sglang-0.4.3.post1.dist-info → sglang-0.4.3.post3.dist-info}/WHEEL +1 -1
  213. sglang/bench_latency.py +0 -1
  214. sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py +0 -75
  215. sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py +0 -74
  216. sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +0 -85
  217. sglang/test/srt/sampling/penaltylib/utils.py +0 -344
  218. {sglang-0.4.3.post1.dist-info → sglang-0.4.3.post3.dist-info}/LICENSE +0 -0
  219. {sglang-0.4.3.post1.dist-info → sglang-0.4.3.post3.dist-info}/top_level.txt +0 -0
sglang/srt/models/granite.py CHANGED
@@ -42,6 +42,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
  )
  from sglang.srt.model_executor.forward_batch_info import ForwardBatch
  from sglang.srt.model_loader.weight_utils import default_weight_loader
+ from sglang.srt.utils import add_prefix
  from sglang.utils import get_exception_traceback

  logger = logging.getLogger(__name__)
@@ -62,14 +63,14 @@ class GraniteMLP(nn.Module):
  [intermediate_size] * 2,
  bias=False,
  quant_config=quant_config,
- prefix=f"{prefix}.gate_up_proj",
+ prefix=add_prefix("gate_up_proj", prefix),
  )
  self.down_proj = RowParallelLinear(
  intermediate_size,
  hidden_size,
  bias=False,
  quant_config=quant_config,
- prefix=f"{prefix}.down_proj",
+ prefix=add_prefix("down_proj", prefix),
  )
  if hidden_act != "silu":
  raise ValueError(
@@ -133,14 +134,14 @@ class GraniteAttention(nn.Module):
  self.total_num_kv_heads,
  bias=False,
  quant_config=quant_config,
- prefix=f"{prefix}.qkv_proj",
+ prefix=add_prefix("qkv_proj", prefix),
  )
  self.o_proj = RowParallelLinear(
  self.total_num_heads * self.head_dim,
  hidden_size,
  bias=False,
  quant_config=quant_config,
- prefix=f"{prefix}.o_proj",
+ prefix=add_prefix("o_proj", prefix),
  )

  self.rotary_emb = get_rope(
@@ -157,6 +158,7 @@ class GraniteAttention(nn.Module):
  self.scaling,
  num_kv_heads=self.num_kv_heads,
  layer_id=layer_id,
+ prefix=add_prefix("attn", prefix),
  )

  def forward(
@@ -205,14 +207,14 @@ class GraniteDecoderLayer(nn.Module):
  rope_is_neox_style=rope_is_neox_style,
  max_position_embeddings=max_position_embeddings,
  quant_config=quant_config,
- prefix=f"{prefix}.self_attn",
+ prefix=add_prefix("self_attn", prefix),
  )
  self.mlp = GraniteMLP(
  hidden_size=self.hidden_size,
  intermediate_size=config.intermediate_size,
  hidden_act=config.hidden_act,
  quant_config=quant_config,
- prefix=f"{prefix}.mlp",
+ prefix=add_prefix("mlp", prefix),
  )
  self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  self.post_attention_layernorm = RMSNorm(
@@ -252,6 +254,7 @@ class GraniteModel(nn.Module):
  self,
  config: GraniteConfig,
  quant_config: Optional[QuantizationConfig] = None,
+ prefix: str = "",
  ) -> None:
  super().__init__()
  self.config = config
@@ -263,7 +266,10 @@ class GraniteModel(nn.Module):
  self.layers = nn.ModuleList(
  [
  GraniteDecoderLayer(
- config, i, quant_config=quant_config, prefix=f"model.layers.{i}"
+ config,
+ i,
+ quant_config=quant_config,
+ prefix=add_prefix(f"layers.{i}", prefix),
  )
  for i in range(config.num_hidden_layers)
  ]
@@ -300,17 +306,23 @@ class GraniteForCausalLM(nn.Module):
  self,
  config: GraniteConfig,
  quant_config: Optional[QuantizationConfig] = None,
+ prefix: str = "",
  ) -> None:
  super().__init__()
  self.config = config
  self.quant_config = quant_config
- self.model = GraniteModel(config, quant_config=quant_config)
+ self.model = GraniteModel(
+ config, quant_config=quant_config, prefix=add_prefix("model", prefix)
+ )
  # If tie_word_embeddings == True, then input and output embeddings are
  # the same tensor. Enforce during object creation so that weights will
  # load correctly even if the LM head weights don't have a separate entry
  # in the state dict.
  self.lm_head = ParallelLMHead(
- config.vocab_size, config.hidden_size, quant_config=quant_config
+ config.vocab_size,
+ config.hidden_size,
+ quant_config=quant_config,
+ prefix=add_prefix("lm_head", prefix),
  )
  if self.config.tie_word_embeddings:
  self.lm_head.tie_weights(self.model.embed_tokens)
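
Note: the recurring change across these model files replaces hand-built f-string prefixes such as f"{prefix}.gate_up_proj" with a shared add_prefix helper imported from sglang.srt.utils, and threads a prefix: str = "" argument through every module constructor so the qualified parameter path is assembled from the top-level model down. The diff does not show the helper itself; the minimal Python sketch below assumes it simply dot-joins a child name onto an optional parent prefix, returning the bare name when the prefix is empty.

# Hypothetical sketch of the add_prefix helper used throughout this diff;
# the real implementation in sglang.srt.utils may differ in detail.
def add_prefix(name: str, prefix: str) -> str:
    # Dot-join the child name onto the parent prefix, or return the bare
    # name when there is no parent prefix.
    return name if not prefix else f"{prefix}.{name}"

# The prefix builds up as constructors nest, e.g. for the Granite model:
assert add_prefix("model", "") == "model"
assert add_prefix("layers.0", "model") == "model.layers.0"
assert add_prefix("gate_up_proj", "model.layers.0.mlp") == "model.layers.0.mlp.gate_up_proj"
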
sglang/srt/models/grok.py CHANGED
@@ -47,6 +47,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
  from sglang.srt.model_executor.forward_batch_info import ForwardBatch
  from sglang.srt.model_loader.loader import DefaultModelLoader
  from sglang.srt.model_loader.weight_utils import default_weight_loader
+ from sglang.srt.utils import add_prefix


  class Grok1MLP(nn.Module):
@@ -65,7 +66,7 @@ class Grok1MLP(nn.Module):
  [intermediate_size] * 2,
  bias=False,
  quant_config=quant_config,
- prefix=f"{prefix}.gate_up_proj",
+ prefix=add_prefix("gate_up_proj", prefix),
  use_presharded_weights=use_presharded_weights,
  )
  self.down_proj = RowParallelLinear(
@@ -73,7 +74,7 @@
  hidden_size,
  bias=False,
  quant_config=quant_config,
- prefix=f"{prefix}.down_proj",
+ prefix=add_prefix("down_proj", prefix),
  reduce_results=reduce_results,
  use_presharded_weights=use_presharded_weights,
  )
@@ -107,6 +108,7 @@ class Grok1MoE(nn.Module):
  tp_size: Optional[int] = None,
  reduce_results=True,
  use_presharded_weights: bool = False,
+ prefix: str = "",
  ):
  super().__init__()
  self.hidden_size = hidden_size
@@ -118,6 +120,7 @@
  bias=False,
  params_dtype=params_dtype,
  quant_config=None,
+ prefix=add_prefix("gate", prefix),
  )

  self.router_logit_softcapping = getattr(
@@ -135,6 +138,7 @@
  tp_size=tp_size,
  activation="gelu",
  use_presharded_weights=use_presharded_weights,
+ prefix=add_prefix("experts", prefix),
  )

  def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
@@ -163,6 +167,7 @@ class Grok1Attention(nn.Module):
  rope_theta: float = 10000,
  quant_config: Optional[QuantizationConfig] = None,
  reduce_results: bool = True,
+ prefix: str = "",
  ) -> None:
  super().__init__()
  self.config = config
@@ -195,6 +200,7 @@
  self.total_num_kv_heads,
  bias=False,
  quant_config=quant_config,
+ prefix=add_prefix("qkv_proj", prefix),
  )
  self.o_proj = RowParallelLinear(
  self.total_num_heads * self.head_dim,
@@ -202,6 +208,7 @@
  bias=False,
  quant_config=quant_config,
  reduce_results=reduce_results,
+ prefix=add_prefix("o_proj", prefix),
  )
  self.rotary_emb = get_rope(
  self.head_dim,
@@ -220,6 +227,7 @@
  num_kv_heads=self.num_kv_heads,
  layer_id=layer_id,
  logit_cap=logit_cap,
+ prefix=add_prefix("attn", prefix),
  )

  def forward(
@@ -243,6 +251,7 @@ class Grok1DecoderLayer(nn.Module):
  layer_id: int = 0,
  quant_config: Optional[QuantizationConfig] = None,
  use_presharded_weights: bool = False,
+ prefix: str = "",
  ) -> None:
  super().__init__()
  self.num_experts = config.num_local_experts
@@ -259,6 +268,7 @@
  layer_id=layer_id,
  rope_theta=rope_theta,
  quant_config=quant_config,
+ prefix=add_prefix("attn", prefix),
  )
  self.block_sparse_moe = Grok1MoE(
  config=config,
@@ -273,6 +283,7 @@
  quant_config=quant_config,
  reduce_results=True,
  use_presharded_weights=use_presharded_weights,
+ prefix=add_prefix("block_sparse_moe", prefix),
  )
  self.pre_attn_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  self.post_attn_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -311,6 +322,7 @@ class Grok1Model(nn.Module):
  config: PretrainedConfig,
  quant_config: Optional[QuantizationConfig] = None,
  use_presharded_weights: bool = False,
+ prefix: str = "",
  ) -> None:
  super().__init__()
  self.config = config
@@ -320,6 +332,7 @@
  self.embed_tokens = VocabParallelEmbedding(
  config.vocab_size,
  config.hidden_size,
+ prefix=add_prefix("embed_tokens", prefix),
  )
  self.layers = nn.ModuleList(
  [
@@ -328,6 +341,7 @@
  i,
  quant_config=quant_config,
  use_presharded_weights=use_presharded_weights,
+ prefix=add_prefix(f"layers.{i}", prefix),
  )
  for i in range(config.num_hidden_layers)
  ]
@@ -359,7 +373,7 @@ class Grok1ForCausalLM(nn.Module):
  self,
  config: PretrainedConfig,
  quant_config: Optional[QuantizationConfig] = None,
- cache_config=None,
+ prefix: str = "",
  ) -> None:
  super().__init__()
  self.config = config
@@ -378,8 +392,11 @@
  config,
  quant_config=quant_config,
  use_presharded_weights=self.use_presharded_weights,
+ prefix=add_prefix("model", prefix),
+ )
+ self.lm_head = ParallelLMHead(
+ config.vocab_size, config.hidden_size, prefix=add_prefix("lm_head", prefix)
  )
- self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
  self.logits_processor = LogitsProcessor(config)

  def forward(
sglang/srt/models/internlm2.py CHANGED
@@ -38,6 +38,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
  )
  from sglang.srt.model_executor.forward_batch_info import ForwardBatch
  from sglang.srt.model_loader.weight_utils import default_weight_loader
+ from sglang.srt.utils import add_prefix


  class InternLM2MLP(nn.Module):
@@ -47,13 +48,22 @@ class InternLM2MLP(nn.Module):
  intermediate_size: int,
  hidden_act: str,
  quant_config: Optional[QuantizationConfig] = None,
+ prefix: str = "",
  ) -> None:
  super().__init__()
  self.gate_up_proj = MergedColumnParallelLinear(
- hidden_size, [intermediate_size] * 2, bias=False, quant_config=quant_config
+ hidden_size,
+ [intermediate_size] * 2,
+ bias=False,
+ quant_config=quant_config,
+ prefix=add_prefix("gate_up_proj", prefix),
  )
  self.w2 = RowParallelLinear(
- intermediate_size, hidden_size, bias=False, quant_config=quant_config
+ intermediate_size,
+ hidden_size,
+ bias=False,
+ quant_config=quant_config,
+ prefix=add_prefix("w2", prefix),
  )
  if hidden_act != "silu":
  raise ValueError(
@@ -80,6 +90,7 @@ class InternLM2Attention(nn.Module):
  max_position_embeddings: int = 8192,
  layer_id: int = 0,
  quant_config: Optional[QuantizationConfig] = None,
+ prefix: str = "",
  ) -> None:
  super().__init__()
  self.hidden_size = hidden_size
@@ -111,12 +122,14 @@
  self.total_num_kv_heads,
  bias=False,
  quant_config=quant_config,
+ prefix=add_prefix("wqkv", prefix),
  )
  self.wo = RowParallelLinear(
  self.total_num_heads * self.head_dim,
  hidden_size,
  bias=False,
  quant_config=quant_config,
+ prefix=add_prefix("wo", prefix),
  )

  self.rotary_emb = get_rope(
@@ -127,7 +140,12 @@
  rope_scaling=rope_scaling,
  )
  self.attn = RadixAttention(
- self.num_heads, self.head_dim, self.scaling, self.num_kv_heads, layer_id
+ self.num_heads,
+ self.head_dim,
+ self.scaling,
+ self.num_kv_heads,
+ layer_id,
+ prefix=add_prefix("attn", prefix),
  )

  def forward(
@@ -150,6 +168,7 @@ class InternLMDecoderLayer(nn.Module):
  config: PretrainedConfig,
  layer_id: int = 0,
  quant_config: Optional[QuantizationConfig] = None,
+ prefix: str = "",
  ) -> None:
  super().__init__()
  self.hidden_size = config.hidden_size
@@ -165,12 +184,14 @@
  max_position_embeddings=max_position_embeddings,
  layer_id=layer_id,
  quant_config=quant_config,
+ prefix=add_prefix("attention", prefix),
  )
  self.feed_forward = InternLM2MLP(
  hidden_size=self.hidden_size,
  intermediate_size=config.intermediate_size,
  hidden_act=config.hidden_act,
  quant_config=quant_config,
+ prefix=add_prefix("feed_forward", prefix),
  )
  self.attention_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  self.ffn_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -205,6 +226,7 @@ class InternLM2Model(nn.Module):
  self,
  config: PretrainedConfig,
  quant_config: Optional[QuantizationConfig] = None,
+ prefix: str = "",
  ) -> None:
  super().__init__()
  self.config = config
@@ -213,10 +235,13 @@
  self.tok_embeddings = VocabParallelEmbedding(
  config.vocab_size,
  config.hidden_size,
+ prefix=add_prefix("tok_embeddings", prefix),
  )
  self.layers = nn.ModuleList(
  [
- InternLMDecoderLayer(config, i, quant_config)
+ InternLMDecoderLayer(
+ config, i, quant_config, prefix=add_prefix(f"layers.{i}", prefix)
+ )
  for i in range(config.num_hidden_layers)
  ]
  )
@@ -251,12 +276,17 @@ class InternLM2ForCausalLM(nn.Module):
  self,
  config: PretrainedConfig,
  quant_config: Optional[QuantizationConfig] = None,
+ prefix: str = "",
  ) -> None:
  super().__init__()
  self.config = config
  self.quant_config = quant_config
- self.model = InternLM2Model(config, quant_config)
- self.output = ParallelLMHead(config.vocab_size, config.hidden_size)
+ self.model = InternLM2Model(
+ config, quant_config, prefix=add_prefix("model", prefix)
+ )
+ self.output = ParallelLMHead(
+ config.vocab_size, config.hidden_size, prefix=add_prefix("output", prefix)
+ )
  self.logits_processor = LogitsProcessor(config)

  @torch.no_grad()
sglang/srt/models/internlm2_reward.py CHANGED
@@ -22,6 +22,7 @@ from sglang.srt.layers.pooler import EmbeddingPoolerOutput, Pooler, PoolingType
  from sglang.srt.layers.quantization.base_config import QuantizationConfig
  from sglang.srt.model_executor.forward_batch_info import ForwardBatch
  from sglang.srt.models.internlm2 import InternLM2ForCausalLM, InternLM2Model
+ from sglang.srt.utils import add_prefix


  class InternLM2ForRewardModel(nn.Module):
@@ -29,12 +30,15 @@ class InternLM2ForRewardModel(nn.Module):
  self,
  config: PretrainedConfig,
  quant_config: Optional[QuantizationConfig] = None,
+ prefix: str = "",
  ) -> None:
  super().__init__()
  self.config = config
  self.quant_config = quant_config
  self.vocab_size = config.vocab_size
- self.model = InternLM2Model(config, quant_config)
+ self.model = InternLM2Model(
+ config, quant_config, prefix=add_prefix("model", prefix)
+ )
  self.v_head = nn.Linear(config.hidden_size, 1, bias=False)
  self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=False)

sglang/srt/models/llama.py CHANGED
@@ -47,8 +47,9 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch
  from sglang.srt.model_loader.weight_utils import (
  default_weight_loader,
  kv_cache_scales_loader,
+ maybe_remap_kv_scale_name,
  )
- from sglang.srt.utils import make_layers
+ from sglang.srt.utils import add_prefix, make_layers
  from sglang.utils import get_exception_traceback

  logger = logging.getLogger(__name__)
@@ -69,14 +70,14 @@ class LlamaMLP(nn.Module):
  [intermediate_size] * 2,
  bias=False,
  quant_config=quant_config,
- prefix=f"{prefix}.gate_up_proj",
+ prefix=add_prefix("gate_up_proj", prefix),
  )
  self.down_proj = RowParallelLinear(
  intermediate_size,
  hidden_size,
  bias=False,
  quant_config=quant_config,
- prefix=f"{prefix}.down_proj",
+ prefix=add_prefix("down_proj", prefix),
  )
  if hidden_act != "silu":
  raise ValueError(
@@ -141,14 +142,14 @@ class LlamaAttention(nn.Module):
  self.total_num_kv_heads,
  bias=bias,
  quant_config=quant_config,
- prefix=f"{prefix}.qkv_proj",
+ prefix=add_prefix("qkv_proj", prefix),
  )
  self.o_proj = RowParallelLinear(
  self.total_num_heads * self.head_dim,
  hidden_size,
  bias=bias,
  quant_config=quant_config,
- prefix=f"{prefix}.o_proj",
+ prefix=add_prefix("o_proj", prefix),
  )

  self.rotary_emb = get_rope(
@@ -165,6 +166,7 @@ class LlamaAttention(nn.Module):
  self.scaling,
  num_kv_heads=self.num_kv_heads,
  layer_id=layer_id,
+ prefix=add_prefix("attn", prefix),
  )

  def forward(
@@ -217,7 +219,7 @@ class LlamaDecoderLayer(nn.Module):
  rope_is_neox_style=rope_is_neox_style,
  max_position_embeddings=max_position_embeddings,
  quant_config=quant_config,
- prefix=f"{prefix}.self_attn",
+ prefix=add_prefix("self_attn", prefix),
  bias=attention_bias,
  )
  self.mlp = LlamaMLP(
@@ -225,7 +227,7 @@
  intermediate_size=config.intermediate_size,
  hidden_act=config.hidden_act,
  quant_config=quant_config,
- prefix=f"{prefix}.mlp",
+ prefix=add_prefix("mlp", prefix),
  )
  self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  self.post_attention_layernorm = RMSNorm(
@@ -262,6 +264,7 @@ class LlamaModel(nn.Module):
  self,
  config: LlamaConfig,
  quant_config: Optional[QuantizationConfig] = None,
+ prefix: str = "",
  ) -> None:
  super().__init__()
  self.config = config
@@ -271,6 +274,7 @@
  config.vocab_size,
  config.hidden_size,
  quant_config=quant_config,
+ prefix=add_prefix("embed_tokens", prefix),
  )
  self.layers = make_layers(
  config.num_hidden_layers,
@@ -357,18 +361,24 @@ class LlamaForCausalLM(nn.Module):
  self,
  config: LlamaConfig,
  quant_config: Optional[QuantizationConfig] = None,
+ prefix: str = "",
  ) -> None:
  super().__init__()
  self.config = config
  self.quant_config = quant_config
- self.model = LlamaModel(config, quant_config=quant_config)
+ self.model = LlamaModel(
+ config, quant_config=quant_config, prefix=add_prefix("model", prefix)
+ )
  # Llama 3.2 1B Instruct set tie_word_embeddings to True
  # Llama 3.1 8B Instruct set tie_word_embeddings to False
  if self.config.tie_word_embeddings:
  self.lm_head = self.model.embed_tokens
  else:
  self.lm_head = ParallelLMHead(
- config.vocab_size, config.hidden_size, quant_config=quant_config
+ config.vocab_size,
+ config.hidden_size,
+ quant_config=quant_config,
+ prefix=add_prefix("lm_head", prefix),
  )
  self.logits_processor = LogitsProcessor(config)
  self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
@@ -457,6 +467,13 @@ class LlamaForCausalLM(nn.Module):
  continue
  if name.startswith("model.vision_tower") and name not in params_dict:
  continue
+ if self.config.tie_word_embeddings and "lm_head.weight" in name:
+ continue
+ # Handle FP8 kv-scale remapping
+ if "scale" in name:
+ name = maybe_remap_kv_scale_name(name, params_dict)
+ if name is None:
+ continue

  for param_name, weight_name, shard_id in stacked_params_mapping:
  if weight_name not in name:
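
The llama.py hunk above also tightens load_weights in two ways: checkpoint entries named lm_head.weight are skipped when tie_word_embeddings is set (the head already shares the embedding tensor), and names containing "scale" are passed through maybe_remap_kv_scale_name before lookup, dropping entries that have no matching parameter. The Python sketch below only illustrates that filtering logic with simplified names; it is not the full sglang load_weights.

# Illustrative filter mirroring the two new checks in the hunk above.
# remap_fn stands in for maybe_remap_kv_scale_name: it is assumed to return
# a remapped parameter name, or None when the scale entry should be dropped.
def filter_weight_name(name, params_dict, tie_word_embeddings, remap_fn):
    # Skip the standalone lm_head entry when embeddings are tied.
    if tie_word_embeddings and "lm_head.weight" in name:
        return None
    # Remap kv-cache scale names; drop them if no target parameter exists.
    if "scale" in name:
        name = remap_fn(name, params_dict)
        if name is None:
            return None
    return name
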
sglang/srt/models/llama_classification.py CHANGED
@@ -23,6 +23,7 @@ from sglang.srt.layers.quantization.base_config import QuantizationConfig
  from sglang.srt.model_executor.forward_batch_info import ForwardBatch
  from sglang.srt.model_loader.weight_utils import default_weight_loader
  from sglang.srt.models.llama import LlamaForCausalLM, LlamaModel
+ from sglang.srt.utils import add_prefix


  class LlamaForClassification(nn.Module):
@@ -30,11 +31,14 @@ class LlamaForClassification(nn.Module):
  self,
  config: LlamaConfig,
  quant_config: Optional[QuantizationConfig] = None,
+ prefix: str = "",
  ) -> None:
  super().__init__()
  self.config = config
  self.quant_config = quant_config
- self.model = LlamaModel(config, quant_config=quant_config)
+ self.model = LlamaModel(
+ config, quant_config=quant_config, prefix=add_prefix("model", prefix)
+ )

  self.classification_head = nn.Linear(
  config.hidden_size, config.classification_out_size, bias=False
sglang/srt/models/llama_eagle.py CHANGED
@@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
  limitations under the License.
  """

+ from sglang.srt.utils import add_prefix
+
  # Adapted from
  # https://github.com/SafeAILab/EAGLE/blob/main/eagle/model/cnets.py
  """Inference-only LLaMA-EAGLE model compatible with HuggingFace weights."""
@@ -55,6 +57,7 @@ class LlamaModel(nn.Module):
  self,
  config: LlamaConfig,
  quant_config: Optional[QuantizationConfig] = None,
+ prefix: str = "",
  ) -> None:
  super().__init__()
  self.config = config
@@ -62,11 +65,15 @@
  self.embed_tokens = VocabParallelEmbedding(
  config.vocab_size,
  config.hidden_size,
+ prefix=add_prefix("embed_tokens", prefix),
  )
  self.layers = nn.ModuleList(
  [
  LlamaDecoderLayer(
- config, i, quant_config=quant_config, prefix=f"model.layers.{i}"
+ config,
+ i,
+ quant_config=quant_config,
+ prefix=add_prefix(f"layers.{i}", prefix),
  )
  for i in range(config.num_hidden_layers)
  ]
@@ -106,20 +113,26 @@ class LlamaForCausalLMEagle(LlamaForCausalLM):
  self,
  config: LlamaConfig,
  quant_config: Optional[QuantizationConfig] = None,
- cache_config=None,
+ prefix: str = "",
  ) -> None:
  nn.Module.__init__(self)
  self.config = config
  self.quant_config = quant_config
- self.model = LlamaModel(config, quant_config=quant_config)
+ self.model = LlamaModel(
+ config, quant_config=quant_config, prefix=add_prefix("model", prefix)
+ )
  # Llama 3.2 1B Instruct set tie_word_embeddings to True
  # Llama 3.1 8B Instruct set tie_word_embeddings to False
  if self.config.tie_word_embeddings:
  self.lm_head = self.model.embed_tokens
  else:
  self.lm_head = ParallelLMHead(
- config.vocab_size, config.hidden_size, quant_config=quant_config
+ getattr(config, "hot_vocab_size", config.vocab_size),
+ config.hidden_size,
+ quant_config=quant_config,
+ prefix=add_prefix("lm_head", prefix),
  )
+
  self.logits_processor = LogitsProcessor(config)

  def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
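
In the EAGLE head above, the draft model's lm_head is sized with getattr(config, "hot_vocab_size", config.vocab_size), which presumably lets a draft checkpoint declare a reduced "hot" vocabulary while configs without that attribute keep the full vocabulary. A small Python illustration of the fallback (the config values here are hypothetical):

# Sketch of the vocab-size fallback used for the EAGLE lm_head above.
class _DraftConfig:
    vocab_size = 128256  # hypothetical full vocabulary size

# Without hot_vocab_size, the full vocabulary is used (old behavior).
assert getattr(_DraftConfig, "hot_vocab_size", _DraftConfig.vocab_size) == 128256

# With hot_vocab_size set, the draft lm_head shrinks to the hot vocabulary.
_DraftConfig.hot_vocab_size = 32000  # hypothetical reduced vocabulary
assert getattr(_DraftConfig, "hot_vocab_size", _DraftConfig.vocab_size) == 32000
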
sglang/srt/models/llama_embedding.py CHANGED
@@ -8,6 +8,7 @@ from sglang.srt.layers.pooler import EmbeddingPoolerOutput, Pooler, PoolingType
  from sglang.srt.model_executor.model_runner import ForwardBatch
  from sglang.srt.model_loader.weight_utils import default_weight_loader
  from sglang.srt.models.llama import LlamaModel
+ from sglang.srt.utils import add_prefix


  class LlamaEmbeddingModel(nn.Module):
@@ -15,9 +16,12 @@ class LlamaEmbeddingModel(nn.Module):
  self,
  config: LlamaConfig,
  quant_config=None,
+ prefix: str = "",
  ) -> None:
  super().__init__()
- self.model = LlamaModel(config, quant_config=quant_config)
+ self.model = LlamaModel(
+ config, quant_config=quant_config, prefix=add_prefix("model", prefix)
+ )
  self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)

  @torch.no_grad()
sglang/srt/models/llama_reward.py CHANGED
@@ -22,6 +22,7 @@ from sglang.srt.layers.pooler import EmbeddingPoolerOutput, Pooler, PoolingType
  from sglang.srt.layers.quantization.base_config import QuantizationConfig
  from sglang.srt.model_executor.forward_batch_info import ForwardBatch
  from sglang.srt.models.llama import LlamaForCausalLM, LlamaModel
+ from sglang.srt.utils import add_prefix


  class LlamaForSequenceClassification(nn.Module):
@@ -29,12 +30,15 @@ class LlamaForSequenceClassification(nn.Module):
  self,
  config: LlamaConfig,
  quant_config: Optional[QuantizationConfig] = None,
+ prefix: str = "",
  ) -> None:
  super().__init__()
  self.config = config
  self.quant_config = quant_config
  self.num_labels = config.num_labels
- self.model = LlamaModel(config, quant_config=quant_config)
+ self.model = LlamaModel(
+ config, quant_config=quant_config, prefix=add_prefix("model", prefix)
+ )
  self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
  self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=False)

@@ -82,8 +86,9 @@ class LlamaForSequenceClassificationWithNormal_Weights(LlamaForSequenceClassific
  self,
  config: LlamaConfig,
  quant_config: Optional[QuantizationConfig] = None,
+ prefix: str = "",
  ) -> None:
- super().__init__(config, quant_config)
+ super().__init__(config, quant_config, prefix=prefix)
  self.weights = self.Weights(config.hidden_size, self.num_labels)

  @torch.no_grad()