sglang 0.4.3.post2__py3-none-any.whl → 0.4.3.post3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (205)
  1. sglang/api.py +1 -1
  2. sglang/bench_offline_throughput.py +19 -0
  3. sglang/bench_one_batch.py +2 -2
  4. sglang/bench_serving.py +123 -79
  5. sglang/global_config.py +8 -3
  6. sglang/lang/backend/runtime_endpoint.py +1 -1
  7. sglang/lang/ir.py +1 -1
  8. sglang/srt/_custom_ops.py +83 -91
  9. sglang/srt/configs/load_config.py +4 -1
  10. sglang/srt/configs/model_config.py +48 -2
  11. sglang/srt/configs/qwen2_5_vl_config.py +5 -2
  12. sglang/srt/constrained/base_grammar_backend.py +117 -15
  13. sglang/srt/constrained/llguidance_backend.py +151 -0
  14. sglang/srt/constrained/outlines_backend.py +24 -33
  15. sglang/srt/constrained/xgrammar_backend.py +69 -38
  16. sglang/srt/distributed/device_communicators/custom_all_reduce.py +225 -80
  17. sglang/srt/distributed/parallel_state.py +48 -3
  18. sglang/srt/entrypoints/engine.py +67 -9
  19. sglang/srt/entrypoints/http_server.py +190 -41
  20. sglang/srt/entrypoints/verl_engine.py +147 -0
  21. sglang/srt/function_call_parser.py +0 -1
  22. sglang/srt/layers/activation.py +11 -0
  23. sglang/srt/layers/attention/{__init__.py → base_attn_backend.py} +14 -6
  24. sglang/srt/layers/attention/double_sparsity_backend.py +1 -1
  25. sglang/srt/layers/attention/flashinfer_backend.py +220 -378
  26. sglang/srt/layers/attention/flashinfer_mla_backend.py +582 -0
  27. sglang/srt/layers/attention/torch_native_backend.py +1 -1
  28. sglang/srt/layers/attention/triton_backend.py +9 -6
  29. sglang/srt/layers/attention/triton_ops/decode_attention.py +3 -0
  30. sglang/srt/layers/attention/triton_ops/extend_attention.py +20 -4
  31. sglang/srt/layers/attention/triton_ops/rocm_mla_decode_rope.py +439 -0
  32. sglang/srt/layers/attention/utils.py +39 -0
  33. sglang/srt/layers/attention/vision.py +60 -63
  34. sglang/srt/layers/dp_attention.py +142 -1
  35. sglang/srt/layers/layernorm.py +1 -1
  36. sglang/srt/layers/linear.py +3 -1
  37. sglang/srt/layers/logits_processor.py +281 -45
  38. sglang/srt/layers/moe/ep_moe/kernels.py +126 -8
  39. sglang/srt/layers/moe/ep_moe/layer.py +140 -28
  40. sglang/srt/layers/moe/fused_moe_native.py +2 -0
  41. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  42. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +50 -50
  43. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI300X.json +18 -18
  44. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI325X.json +18 -18
  45. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Radeon_Graphics.json +18 -18
  46. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI300X.json +18 -18
  47. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI325X.json +18 -18
  48. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Radeon_Graphics.json +18 -18
  49. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI300X.json +18 -18
  50. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI325X.json +18 -18
  51. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Radeon_Graphics.json +18 -18
  52. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +16 -16
  53. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +16 -16
  54. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +16 -16
  55. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI300X.json +18 -18
  56. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI325X.json +18 -18
  57. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Radeon_Graphics.json +18 -18
  58. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +15 -15
  59. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +15 -15
  60. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +15 -15
  61. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +88 -20
  62. sglang/srt/layers/moe/fused_moe_triton/layer.py +34 -13
  63. sglang/srt/layers/moe/topk.py +13 -4
  64. sglang/srt/layers/quantization/__init__.py +111 -7
  65. sglang/srt/layers/quantization/blockwise_int8.py +409 -0
  66. sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  67. sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  68. sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  69. sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  70. sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  71. sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  72. sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  73. sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  74. sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  75. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  76. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  77. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  78. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  79. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  80. sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  81. sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  82. sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  83. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  84. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  85. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  86. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  87. sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  88. sglang/srt/layers/quantization/fp8.py +69 -28
  89. sglang/srt/layers/quantization/fp8_utils.py +17 -1
  90. sglang/srt/layers/quantization/gptq.py +416 -0
  91. sglang/srt/layers/quantization/int8_kernel.py +327 -0
  92. sglang/srt/layers/quantization/int8_utils.py +73 -0
  93. sglang/srt/layers/quantization/modelopt_quant.py +18 -1
  94. sglang/srt/layers/radix_attention.py +1 -0
  95. sglang/srt/layers/rotary_embedding.py +0 -1
  96. sglang/srt/layers/sampler.py +76 -31
  97. sglang/srt/layers/vocab_parallel_embedding.py +14 -13
  98. sglang/srt/lora/lora.py +17 -1
  99. sglang/srt/lora/lora_config.py +5 -0
  100. sglang/srt/lora/lora_manager.py +1 -3
  101. sglang/srt/managers/cache_controller.py +193 -62
  102. sglang/srt/managers/configure_logging.py +2 -1
  103. sglang/srt/managers/data_parallel_controller.py +6 -2
  104. sglang/srt/managers/detokenizer_manager.py +124 -102
  105. sglang/srt/managers/image_processor.py +2 -1
  106. sglang/srt/managers/io_struct.py +143 -6
  107. sglang/srt/managers/schedule_batch.py +237 -197
  108. sglang/srt/managers/schedule_policy.py +29 -29
  109. sglang/srt/managers/scheduler.py +681 -259
  110. sglang/srt/managers/session_controller.py +6 -2
  111. sglang/srt/managers/tokenizer_manager.py +224 -68
  112. sglang/srt/managers/tp_worker.py +15 -4
  113. sglang/srt/managers/tp_worker_overlap_thread.py +3 -4
  114. sglang/srt/mem_cache/chunk_cache.py +18 -11
  115. sglang/srt/mem_cache/hiradix_cache.py +394 -0
  116. sglang/srt/mem_cache/memory_pool.py +44 -18
  117. sglang/srt/mem_cache/radix_cache.py +58 -47
  118. sglang/srt/metrics/collector.py +94 -36
  119. sglang/srt/model_executor/cuda_graph_runner.py +55 -24
  120. sglang/srt/model_executor/forward_batch_info.py +49 -16
  121. sglang/srt/model_executor/model_runner.py +208 -28
  122. sglang/srt/model_loader/loader.py +3 -3
  123. sglang/srt/model_loader/weight_utils.py +36 -14
  124. sglang/srt/models/baichuan.py +31 -6
  125. sglang/srt/models/chatglm.py +39 -7
  126. sglang/srt/models/commandr.py +29 -5
  127. sglang/srt/models/dbrx.py +31 -5
  128. sglang/srt/models/deepseek.py +43 -6
  129. sglang/srt/models/deepseek_nextn.py +32 -19
  130. sglang/srt/models/deepseek_v2.py +265 -32
  131. sglang/srt/models/exaone.py +19 -9
  132. sglang/srt/models/gemma.py +22 -8
  133. sglang/srt/models/gemma2.py +25 -12
  134. sglang/srt/models/gemma2_reward.py +5 -1
  135. sglang/srt/models/gpt2.py +28 -13
  136. sglang/srt/models/gpt_bigcode.py +27 -5
  137. sglang/srt/models/granite.py +21 -9
  138. sglang/srt/models/grok.py +21 -4
  139. sglang/srt/models/internlm2.py +36 -6
  140. sglang/srt/models/internlm2_reward.py +5 -1
  141. sglang/srt/models/llama.py +26 -9
  142. sglang/srt/models/llama_classification.py +5 -1
  143. sglang/srt/models/llama_eagle.py +17 -4
  144. sglang/srt/models/llama_embedding.py +5 -1
  145. sglang/srt/models/llama_reward.py +7 -2
  146. sglang/srt/models/llava.py +19 -3
  147. sglang/srt/models/llavavid.py +10 -1
  148. sglang/srt/models/minicpm.py +26 -2
  149. sglang/srt/models/minicpm3.py +39 -3
  150. sglang/srt/models/minicpmv.py +45 -14
  151. sglang/srt/models/mixtral.py +20 -9
  152. sglang/srt/models/mixtral_quant.py +50 -8
  153. sglang/srt/models/mllama.py +57 -11
  154. sglang/srt/models/olmo.py +34 -6
  155. sglang/srt/models/olmo2.py +34 -13
  156. sglang/srt/models/olmoe.py +26 -4
  157. sglang/srt/models/phi3_small.py +29 -10
  158. sglang/srt/models/qwen.py +26 -3
  159. sglang/srt/models/qwen2.py +26 -4
  160. sglang/srt/models/qwen2_5_vl.py +46 -8
  161. sglang/srt/models/qwen2_eagle.py +17 -5
  162. sglang/srt/models/qwen2_moe.py +44 -6
  163. sglang/srt/models/qwen2_rm.py +78 -0
  164. sglang/srt/models/qwen2_vl.py +39 -8
  165. sglang/srt/models/stablelm.py +32 -5
  166. sglang/srt/models/torch_native_llama.py +5 -2
  167. sglang/srt/models/xverse.py +21 -9
  168. sglang/srt/models/xverse_moe.py +45 -7
  169. sglang/srt/models/yivl.py +2 -1
  170. sglang/srt/openai_api/adapter.py +109 -24
  171. sglang/srt/openai_api/protocol.py +17 -1
  172. sglang/srt/reasoning_parser.py +154 -0
  173. sglang/srt/sampling/penaltylib/__init__.py +4 -6
  174. sglang/srt/sampling/penaltylib/frequency_penalty.py +66 -0
  175. sglang/srt/sampling/penaltylib/{penalizers/min_new_tokens.py → min_new_tokens.py} +15 -23
  176. sglang/srt/sampling/penaltylib/orchestrator.py +39 -188
  177. sglang/srt/sampling/penaltylib/presence_penalty.py +66 -0
  178. sglang/srt/sampling/sampling_batch_info.py +79 -157
  179. sglang/srt/sampling/sampling_params.py +16 -13
  180. sglang/srt/server_args.py +136 -52
  181. sglang/srt/speculative/build_eagle_tree.py +2 -8
  182. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +0 -1
  183. sglang/srt/speculative/eagle_utils.py +92 -58
  184. sglang/srt/speculative/eagle_worker.py +186 -94
  185. sglang/srt/speculative/spec_info.py +1 -13
  186. sglang/srt/utils.py +43 -17
  187. sglang/srt/warmup.py +47 -0
  188. sglang/test/few_shot_gsm8k.py +4 -1
  189. sglang/test/runners.py +389 -126
  190. sglang/test/send_one.py +88 -0
  191. sglang/test/test_block_fp8_ep.py +361 -0
  192. sglang/test/test_programs.py +1 -1
  193. sglang/test/test_utils.py +138 -84
  194. sglang/utils.py +50 -60
  195. sglang/version.py +1 -1
  196. {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post3.dist-info}/METADATA +21 -15
  197. {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post3.dist-info}/RECORD +200 -166
  198. {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post3.dist-info}/WHEEL +1 -1
  199. sglang/bench_latency.py +0 -1
  200. sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py +0 -75
  201. sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py +0 -74
  202. sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +0 -85
  203. sglang/test/srt/sampling/penaltylib/utils.py +0 -344
  204. {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post3.dist-info}/LICENSE +0 -0
  205. {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post3.dist-info}/top_level.txt +0 -0
sglang/srt/models/exaone.py CHANGED
@@ -39,6 +39,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
 )
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.utils import add_prefix
 
 
 class ExaoneGatedMLP(nn.Module):
@@ -56,14 +57,14 @@ class ExaoneGatedMLP(nn.Module):
             [intermediate_size] * 2,
             bias=False,
             quant_config=quant_config,
-            prefix=f"{prefix}.gate_up_proj",
+            prefix=add_prefix("gate_up_proj", prefix),
         )
         self.c_proj = RowParallelLinear(
             intermediate_size,
             hidden_size,
             bias=False,
             quant_config=quant_config,
-            prefix=f"{prefix}.c_proj",
+            prefix=add_prefix("c_proj", prefix),
         )
         if hidden_act != "silu":
             raise ValueError(
@@ -130,14 +131,14 @@ class ExaoneAttention(nn.Module):
             self.total_num_kv_heads,
             bias=False,
             quant_config=quant_config,
-            prefix=f"{prefix}.qkv_proj",
+            prefix=add_prefix("qkv_proj", prefix),
         )
         self.out_proj = RowParallelLinear(
             self.total_num_heads * self.head_dim,
             hidden_size,
             bias=False,
             quant_config=quant_config,
-            prefix=f"{prefix}.out_proj",
+            prefix=add_prefix("out_proj", prefix),
         )
 
         self.rotary_emb = get_rope(
@@ -201,14 +202,14 @@ class ExaoneDecoderLayer(nn.Module):
             rope_is_neox_style=rope_is_neox_style,
             max_position_embeddings=max_position_embeddings,
             quant_config=quant_config,
-            prefix=f"{prefix}.self_attn",
+            prefix=add_prefix("self_attn", prefix),
         )
         self.mlp = ExaoneGatedMLP(
             hidden_size=self.hidden_size,
             intermediate_size=config.intermediate_size,
             hidden_act=config.activation_function,
             quant_config=quant_config,
-            prefix=f"{prefix}.mlp",
+            prefix=add_prefix("mlp", prefix),
         )
         rms_norm_eps = config.layer_norm_epsilon
         self.ln_1 = RMSNorm(config.hidden_size, eps=rms_norm_eps)
@@ -244,6 +245,7 @@ class ExaoneModel(nn.Module):
         self,
         config,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.config = config
@@ -256,7 +258,10 @@ class ExaoneModel(nn.Module):
         self.h = nn.ModuleList(
             [
                 ExaoneDecoderLayer(
-                    config, i, quant_config=quant_config, prefix=f"model.h.{i}"
+                    config,
+                    i,
+                    quant_config=quant_config,
+                    prefix=add_prefix(f"h.{i}", prefix),
                 )
                 for i in range(config.num_hidden_layers)
             ]
@@ -293,12 +298,17 @@ class ExaoneForCausalLM(nn.Module):
         self,
         config,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.config = config
         self.quant_config = quant_config
-        self.transformer = ExaoneModel(config, quant_config=quant_config)
-        self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
+        self.transformer = ExaoneModel(
+            config, quant_config=quant_config, prefix=add_prefix("transformer", prefix)
+        )
+        self.lm_head = ParallelLMHead(
+            config.vocab_size, config.hidden_size, prefix=add_prefix("lm_head", prefix)
+        )
         self.logits_processor = LogitsProcessor(config)
 
     @torch.no_grad()
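
Most of the model changes in this release follow a single pattern: hard-coded f"{prefix}.name" strings are replaced with calls to a new add_prefix helper from sglang.srt.utils, and each module constructor gains a prefix argument that is threaded down from its parent. The helper itself is not shown in this diff, so the following is only a minimal sketch of the behavior the call sites imply (skip the dot when the parent prefix is empty):

# Hypothetical sketch of sglang.srt.utils.add_prefix; the real implementation is
# not part of this diff, so treat the exact signature and behavior as assumptions.
def add_prefix(name: str, prefix: str) -> str:
    # With an empty parent prefix, return the bare name ("gate_up_proj")
    # rather than a leading-dot name (".gate_up_proj").
    return name if not prefix else f"{prefix}.{name}"

Compared with f"{prefix}.gate_up_proj", this keeps top-level modules built with prefix="" from producing parameter names that start with a dot.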
sglang/srt/models/gemma.py CHANGED
@@ -37,6 +37,7 @@ from sglang.srt.layers.rotary_embedding import get_rope
 from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.utils import add_prefix
 
 
 class GemmaMLP(nn.Module):
@@ -45,6 +46,7 @@ class GemmaMLP(nn.Module):
         hidden_size: int,
         intermediate_size: int,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.gate_up_proj = MergedColumnParallelLinear(
@@ -52,12 +54,14 @@ class GemmaMLP(nn.Module):
             [intermediate_size] * 2,
             bias=False,
             quant_config=quant_config,
+            prefix=add_prefix("gate_up_proj", prefix),
         )
         self.down_proj = RowParallelLinear(
             intermediate_size,
             hidden_size,
             bias=False,
             quant_config=quant_config,
+            prefix=add_prefix("down_proj", prefix),
         )
         self.act_fn = GeluAndMul("none")
 
@@ -79,6 +83,7 @@ class GemmaAttention(nn.Module):
         max_position_embeddings: int = 8192,
         rope_theta: float = 10000,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.hidden_size = hidden_size
@@ -109,12 +114,14 @@ class GemmaAttention(nn.Module):
             self.total_num_kv_heads,
             bias=False,
             quant_config=quant_config,
+            prefix=add_prefix("qkv_proj", prefix),
         )
         self.o_proj = RowParallelLinear(
             self.total_num_heads * self.head_dim,
             hidden_size,
             bias=False,
             quant_config=quant_config,
+            prefix=add_prefix("o_proj", prefix),
         )
 
         self.rotary_emb = get_rope(
@@ -130,6 +137,7 @@ class GemmaAttention(nn.Module):
             self.scaling,
             num_kv_heads=self.num_kv_heads,
             layer_id=layer_id,
+            prefix=add_prefix("attn", prefix),
         )
 
     def forward(
@@ -152,6 +160,7 @@ class GemmaDecoderLayer(nn.Module):
         config: PretrainedConfig,
         layer_id: int = 0,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.hidden_size = config.hidden_size
@@ -164,11 +173,13 @@ class GemmaDecoderLayer(nn.Module):
             max_position_embeddings=config.max_position_embeddings,
             rope_theta=config.rope_theta,
             quant_config=quant_config,
+            prefix=add_prefix("self_attn", prefix),
         )
         self.mlp = GemmaMLP(
             hidden_size=self.hidden_size,
             intermediate_size=config.intermediate_size,
             quant_config=quant_config,
+            prefix=add_prefix("mlp", prefix),
         )
         self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.post_attention_layernorm = RMSNorm(
@@ -205,6 +216,7 @@ class GemmaModel(nn.Module):
         self,
         config: PretrainedConfig,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.config = config
@@ -215,7 +227,12 @@ class GemmaModel(nn.Module):
         )
         self.layers = nn.ModuleList(
             [
-                GemmaDecoderLayer(config, i, quant_config=quant_config)
+                GemmaDecoderLayer(
+                    config,
+                    i,
+                    quant_config=quant_config,
+                    prefix=add_prefix(f"layers.{i}", prefix),
+                )
                 for i in range(config.num_hidden_layers)
             ]
         )
@@ -277,11 +294,14 @@ class GemmaForCausalLM(nn.Module):
         self,
         config: PretrainedConfig,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.config = config
         self.quant_config = quant_config
-        self.model = GemmaModel(config, quant_config=quant_config)
+        self.model = GemmaModel(
+            config, quant_config=quant_config, prefix=add_prefix("model", prefix)
+        )
         self.logits_processor = LogitsProcessor(config)
 
     @torch.no_grad()
@@ -336,12 +356,6 @@ class GemmaForCausalLM(nn.Module):
                 weight_loader = getattr(param, "weight_loader", default_weight_loader)
                 weight_loader(param, loaded_weight)
             loaded_params.add(name)
-        unloaded_params = params_dict.keys() - loaded_params
-        if unloaded_params:
-            raise RuntimeError(
-                "Some weights are not initialized from checkpoints: "
-                f"{unloaded_params}"
-            )
 
 
 EntryClass = GemmaForCausalLM
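
Because every constructor now passes its own name together with the parent prefix, the fully qualified parameter names compose the same way the checkpoint keys do. A small illustration for the Gemma layers above, reusing the assumed add_prefix sketch from the previous note:

from functools import reduce

def add_prefix(name: str, prefix: str) -> str:
    # Assumed behavior, see the sketch above.
    return name if not prefix else f"{prefix}.{name}"

# Prefixes threaded GemmaForCausalLM -> GemmaModel -> layers.0 -> self_attn -> qkv_proj
parts = ["model", "layers.0", "self_attn", "qkv_proj"]
print(reduce(lambda prefix, name: add_prefix(name, prefix), parts, ""))
# -> model.layers.0.self_attn.qkv_proj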
sglang/srt/models/gemma2.py CHANGED
@@ -39,7 +39,7 @@ from sglang.srt.model_loader.weight_utils import (
     default_weight_loader,
     maybe_remap_kv_scale_name,
 )
-from sglang.srt.utils import make_layers
+from sglang.srt.utils import add_prefix, make_layers
 
 
 # Aligned with HF's implementation, using sliding window inclusive with the last token
@@ -56,13 +56,22 @@ class Gemma2MLP(nn.Module):
         hidden_act: str,
         hidden_activation: str,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.gate_up_proj = MergedColumnParallelLinear(
-            hidden_size, [intermediate_size] * 2, bias=False, quant_config=quant_config
+            hidden_size,
+            [intermediate_size] * 2,
+            bias=False,
+            quant_config=quant_config,
+            prefix=add_prefix("gate_up_proj", prefix),
         )
         self.down_proj = RowParallelLinear(
-            intermediate_size, hidden_size, bias=False, quant_config=quant_config
+            intermediate_size,
+            hidden_size,
+            bias=False,
+            quant_config=quant_config,
+            prefix=add_prefix("down_proj", prefix),
         )
         if not (hidden_act == hidden_activation == "gelu_pytorch_tanh"):
             raise ValueError(
@@ -91,6 +100,7 @@ class Gemma2Attention(nn.Module):
         max_position_embeddings: int,
         rope_theta: float,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
    ) -> None:
         super().__init__()
         self.layer_id = layer_id
@@ -123,12 +133,14 @@ class Gemma2Attention(nn.Module):
             self.total_num_kv_heads,
             bias=config.attention_bias,
             quant_config=quant_config,
+            prefix=add_prefix("qkv_proj", prefix),
         )
         self.o_proj = RowParallelLinear(
             self.total_num_heads * self.head_dim,
             hidden_size,
             bias=config.attention_bias,
             quant_config=quant_config,
+            prefix=add_prefix("o_proj", prefix),
         )
         self.rotary_emb = get_rope(
             self.head_dim,
@@ -151,6 +163,7 @@ class Gemma2Attention(nn.Module):
                 if use_sliding_window
                 else None
             ),
+            prefix=add_prefix("attn", prefix),
         )
 
     def forward(
@@ -173,6 +186,7 @@ class Gemma2DecoderLayer(nn.Module):
         layer_id: int,
         config: PretrainedConfig,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.hidden_size = config.hidden_size
@@ -186,6 +200,7 @@ class Gemma2DecoderLayer(nn.Module):
             max_position_embeddings=config.max_position_embeddings,
             rope_theta=config.rope_theta,
             quant_config=quant_config,
+            prefix=add_prefix("self_attn", prefix),
         )
         self.hidden_size = config.hidden_size
         self.mlp = Gemma2MLP(
@@ -194,6 +209,7 @@ class Gemma2DecoderLayer(nn.Module):
             hidden_act=config.hidden_act,
             hidden_activation=config.hidden_activation,
             quant_config=quant_config,
+            prefix=add_prefix("mlp", prefix),
         )
         self.input_layernorm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.post_attention_layernorm = GemmaRMSNorm(
@@ -238,6 +254,7 @@ class Gemma2Model(nn.Module):
         self,
         config: PretrainedConfig,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.config = config
@@ -253,7 +270,7 @@ class Gemma2Model(nn.Module):
                 config=config,
                 quant_config=quant_config,
             ),
-            prefix="",
+            prefix=add_prefix("layers", prefix),
         )
         self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
@@ -339,11 +356,14 @@ class Gemma2ForCausalLM(nn.Module):
         self,
         config: PretrainedConfig,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.config = config
         self.quant_config = quant_config
-        self.model = Gemma2Model(config, quant_config)
+        self.model = Gemma2Model(
+            config, quant_config, prefix=add_prefix("model", prefix)
+        )
         self.logits_processor = LogitsProcessor(config)
 
     @torch.no_grad()
@@ -437,12 +457,5 @@ class Gemma2ForCausalLM(nn.Module):
                 weight_loader(param, loaded_weight)
             loaded_params.add(name)
 
-        unloaded_params = params_dict.keys() - loaded_params
-        if unloaded_params:
-            raise RuntimeError(
-                "Some weights are not initialized from checkpoints: "
-                f"{unloaded_params}"
-            )
-
 
 EntryClass = Gemma2ForCausalLM
sglang/srt/models/gemma2_reward.py CHANGED
@@ -22,6 +22,7 @@ from sglang.srt.layers.pooler import EmbeddingPoolerOutput, Pooler, PoolingType
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.models.gemma2 import Gemma2ForCausalLM, Gemma2Model
+from sglang.srt.utils import add_prefix
 
 
 class Gemma2ForSequenceClassification(nn.Module):
@@ -29,12 +30,15 @@ class Gemma2ForSequenceClassification(nn.Module):
         self,
         config: Gemma2Config,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.config = config
         self.quant_config = quant_config
         self.num_labels = config.num_labels
-        self.model = Gemma2Model(config, quant_config=quant_config)
+        self.model = Gemma2Model(
+            config, quant_config=quant_config, prefix=add_prefix("model", prefix)
+        )
         self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
         self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=False)
 
sglang/srt/models/gpt2.py CHANGED
@@ -17,14 +17,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Inference-only GPT-2 model compatible with HuggingFace weights."""
-from typing import Iterable, Optional, Tuple
+from typing import Iterable, Optional, Tuple, Type
 
 import torch
 from torch import nn
 from transformers import GPT2Config
 
 from sglang.srt.distributed.parallel_state import get_tensor_model_parallel_world_size
-from sglang.srt.layers.activation import get_act_fn
+from sglang.srt.layers.activation import NewGELU
 from sglang.srt.layers.linear import (
     ColumnParallelLinear,
     QKVParallelLinear,
@@ -36,6 +36,7 @@ from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.utils import add_prefix
 
 
 class GPT2Attention(nn.Module):
@@ -62,14 +63,14 @@ class GPT2Attention(nn.Module):
             total_num_heads,
             bias=True,
             quant_config=quant_config,
-            prefix=f"{prefix}.c_attn",
+            prefix=add_prefix("c_attn", prefix),
         )
         self.c_proj = RowParallelLinear(
             self.hidden_size,
             self.hidden_size,
             bias=True,
             quant_config=quant_config,
-            prefix=f"{prefix}.c_proj",
+            prefix=add_prefix("c_proj", prefix),
         )
         self.attn = RadixAttention(
             self.num_heads,
@@ -97,6 +98,7 @@ class GPT2MLP(nn.Module):
         self,
         intermediate_size: int,
         config: GPT2Config,
+        act_layer: Type[nn.Module] = NewGELU,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
     ):
@@ -107,18 +109,16 @@ class GPT2MLP(nn.Module):
             intermediate_size,
             bias=True,
             quant_config=quant_config,
-            prefix=f"{prefix}.c_fc",
+            prefix=add_prefix("c_fc", prefix),
         )
         self.c_proj = RowParallelLinear(
             intermediate_size,
             hidden_size,
             bias=True,
             quant_config=quant_config,
-            prefix=f"{prefix}.c_proj",
-        )
-        self.act = get_act_fn(
-            config.activation_function, quant_config, intermediate_size
+            prefix=add_prefix("c_proj", prefix),
         )
+        self.act = act_layer()
 
     def forward(
         self,
@@ -136,6 +136,7 @@ class GPT2Block(nn.Module):
         self,
         layer_id: int,
         config: GPT2Config,
+        act_layer: Type[nn.Module] = NewGELU,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
     ):
@@ -145,10 +146,16 @@ class GPT2Block(nn.Module):
 
         self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
         self.attn = GPT2Attention(
-            layer_id, config, quant_config, prefix=f"{prefix}.attn"
+            layer_id, config, quant_config, prefix=add_prefix("attn", prefix)
         )
         self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
-        self.mlp = GPT2MLP(inner_dim, config, quant_config, prefix=f"{prefix}.mlp")
+        self.mlp = GPT2MLP(
+            inner_dim,
+            config,
+            act_layer=act_layer,
+            quant_config=quant_config,
+            prefix=add_prefix("mlp", prefix),
+        )
 
     def forward(
         self,
@@ -190,7 +197,12 @@ class GPT2Model(nn.Module):
         self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
         self.h = nn.ModuleList(
             [
-                GPT2Block(i, config, quant_config)
+                GPT2Block(
+                    i,
+                    config,
+                    quant_config=quant_config,
+                    prefix=add_prefix(f"h.{i}", prefix),
+                )
                 for i in range(config.num_hidden_layers)
             ]
         )
@@ -221,11 +233,14 @@ class GPT2LMHeadModel(nn.Module):
         self,
         config: GPT2Config,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ):
         super().__init__()
         self.config = config
         self.quant_config = quant_config
-        self.transformer = GPT2Model(config, quant_config, prefix="transformer")
+        self.transformer = GPT2Model(
+            config, quant_config, prefix=add_prefix("transformer", prefix)
+        )
         self.lm_head = self.transformer.wte
 
         self.logits_processor = LogitsProcessor(config)
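
Besides the prefix threading, the GPT-2 change replaces the string-based get_act_fn lookup with an act_layer class argument that defaults to NewGELU from sglang.srt.layers.activation. That module is not included in this diff; the sketch below assumes it implements the standard "gelu_new" tanh approximation used by the original GPT-2:

import math

import torch
from torch import nn


# Sketch of a NewGELU-style module under the assumption above; the actual
# sglang.srt.layers.activation.NewGELU may differ in detail.
class NewGELU(nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return (
            0.5
            * x
            * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * x.pow(3))))
        )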
sglang/srt/models/gpt_bigcode.py CHANGED
@@ -35,6 +35,7 @@ from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.utils import add_prefix
 
 
 class GPTBigCodeAttention(nn.Module):
@@ -44,6 +45,7 @@ class GPTBigCodeAttention(nn.Module):
         layer_id: int,
         config: GPTBigCodeConfig,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ):
         super().__init__()
         self.hidden_size = config.hidden_size
@@ -69,6 +71,7 @@ class GPTBigCodeAttention(nn.Module):
             total_num_kv_heads,
             bias=True,
             quant_config=quant_config,
+            prefix=add_prefix("c_attn", prefix),
         )
 
         self.c_proj = RowParallelLinear(
@@ -76,6 +79,7 @@ class GPTBigCodeAttention(nn.Module):
             self.hidden_size,
             bias=True,
             quant_config=quant_config,
+            prefix=add_prefix("c_proj", prefix),
         )
         self.attn = RadixAttention(
             self.num_heads,
@@ -83,6 +87,7 @@ class GPTBigCodeAttention(nn.Module):
             scaling=self.scale,
             num_kv_heads=self.num_kv_heads,
             layer_id=layer_id,
+            prefix=add_prefix("attn", prefix),
         )
 
     def forward(
@@ -111,6 +116,7 @@ class GPTBigMLP(nn.Module):
         intermediate_size: int,
         config: GPTBigCodeConfig,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ):
         super().__init__()
         hidden_size = config.hidden_size
@@ -119,12 +125,14 @@ class GPTBigMLP(nn.Module):
             intermediate_size,
             bias=True,
             quant_config=quant_config,
+            prefix=add_prefix("c_fc", prefix),
         )
         self.c_proj = RowParallelLinear(
             intermediate_size,
             hidden_size,
             bias=True,
             quant_config=quant_config,
+            prefix=add_prefix("c_proj", prefix),
         )
         self.act = get_act_fn(
             config.activation_function, quant_config, intermediate_size
@@ -144,15 +152,20 @@ class GPTBigCodeBlock(nn.Module):
         layer_id: int,
         config: GPTBigCodeConfig,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ):
         super().__init__()
         hidden_size = config.hidden_size
         inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size
 
         self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
-        self.attn = GPTBigCodeAttention(layer_id, config, quant_config)
+        self.attn = GPTBigCodeAttention(
+            layer_id, config, quant_config, prefix=add_prefix("attn", prefix)
+        )
         self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
-        self.mlp = GPTBigMLP(inner_dim, config, quant_config)
+        self.mlp = GPTBigMLP(
+            inner_dim, config, quant_config, prefix=add_prefix("mlp", prefix)
+        )
 
     def forward(
         self,
@@ -181,6 +194,7 @@ class GPTBigCodeModel(nn.Module):
         self,
         config: GPTBigCodeConfig,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ):
         super().__init__()
         self.config = config
@@ -190,12 +204,17 @@ class GPTBigCodeModel(nn.Module):
         lora_vocab = 0
         self.vocab_size = config.vocab_size + lora_vocab
         self.wte = VocabParallelEmbedding(
-            self.vocab_size, self.embed_dim, org_num_embeddings=config.vocab_size
+            self.vocab_size,
+            self.embed_dim,
+            org_num_embeddings=config.vocab_size,
+            prefix=add_prefix("wte", prefix),
         )
         self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
         self.h = nn.ModuleList(
             [
-                GPTBigCodeBlock(i, config, quant_config)
+                GPTBigCodeBlock(
+                    i, config, quant_config, prefix=add_prefix(f"h.{i}", prefix)
+                )
                 for i in range(config.num_hidden_layers)
             ]
         )
@@ -235,13 +254,16 @@ class GPTBigCodeForCausalLM(nn.Module):
         self,
         config: GPTBigCodeConfig,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ):
         super().__init__()
 
         self.config = config
 
         self.quant_config = quant_config
-        self.transformer = GPTBigCodeModel(config, quant_config)
+        self.transformer = GPTBigCodeModel(
+            config, quant_config, prefix=add_prefix("transformer", prefix)
+        )
         self.lm_head = self.transformer.wte
         self.unpadded_vocab_size = config.vocab_size
         self.logits_processor = LogitsProcessor(config)