sglang 0.4.3.post2__py3-none-any.whl → 0.4.3.post4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (205)
  1. sglang/api.py +1 -1
  2. sglang/bench_offline_throughput.py +19 -0
  3. sglang/bench_one_batch.py +2 -2
  4. sglang/bench_serving.py +123 -79
  5. sglang/global_config.py +8 -3
  6. sglang/lang/backend/runtime_endpoint.py +1 -1
  7. sglang/lang/ir.py +1 -1
  8. sglang/srt/_custom_ops.py +83 -91
  9. sglang/srt/configs/load_config.py +4 -1
  10. sglang/srt/configs/model_config.py +48 -2
  11. sglang/srt/configs/qwen2_5_vl_config.py +5 -2
  12. sglang/srt/constrained/base_grammar_backend.py +117 -15
  13. sglang/srt/constrained/llguidance_backend.py +151 -0
  14. sglang/srt/constrained/outlines_backend.py +24 -33
  15. sglang/srt/constrained/xgrammar_backend.py +69 -38
  16. sglang/srt/distributed/device_communicators/custom_all_reduce.py +225 -80
  17. sglang/srt/distributed/parallel_state.py +48 -3
  18. sglang/srt/entrypoints/engine.py +67 -9
  19. sglang/srt/entrypoints/http_server.py +190 -41
  20. sglang/srt/entrypoints/verl_engine.py +147 -0
  21. sglang/srt/function_call_parser.py +0 -1
  22. sglang/srt/layers/activation.py +11 -0
  23. sglang/srt/layers/attention/{__init__.py → base_attn_backend.py} +14 -6
  24. sglang/srt/layers/attention/double_sparsity_backend.py +1 -1
  25. sglang/srt/layers/attention/flashinfer_backend.py +302 -414
  26. sglang/srt/layers/attention/flashinfer_mla_backend.py +582 -0
  27. sglang/srt/layers/attention/torch_native_backend.py +1 -1
  28. sglang/srt/layers/attention/triton_backend.py +13 -8
  29. sglang/srt/layers/attention/triton_ops/decode_attention.py +3 -0
  30. sglang/srt/layers/attention/triton_ops/extend_attention.py +20 -4
  31. sglang/srt/layers/attention/triton_ops/rocm_mla_decode_rope.py +439 -0
  32. sglang/srt/layers/attention/utils.py +39 -0
  33. sglang/srt/layers/attention/vision.py +60 -63
  34. sglang/srt/layers/dp_attention.py +142 -1
  35. sglang/srt/layers/layernorm.py +1 -1
  36. sglang/srt/layers/linear.py +3 -1
  37. sglang/srt/layers/logits_processor.py +281 -45
  38. sglang/srt/layers/moe/ep_moe/kernels.py +126 -8
  39. sglang/srt/layers/moe/ep_moe/layer.py +140 -28
  40. sglang/srt/layers/moe/fused_moe_native.py +2 -0
  41. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  42. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +50 -50
  43. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI300X.json +18 -18
  44. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI325X.json +18 -18
  45. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Radeon_Graphics.json +18 -18
  46. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI300X.json +18 -18
  47. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI325X.json +18 -18
  48. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Radeon_Graphics.json +18 -18
  49. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI300X.json +18 -18
  50. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI325X.json +18 -18
  51. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Radeon_Graphics.json +18 -18
  52. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +16 -16
  53. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +16 -16
  54. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +16 -16
  55. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI300X.json +18 -18
  56. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI325X.json +18 -18
  57. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Radeon_Graphics.json +18 -18
  58. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +15 -15
  59. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +15 -15
  60. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +15 -15
  61. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +88 -20
  62. sglang/srt/layers/moe/fused_moe_triton/layer.py +34 -13
  63. sglang/srt/layers/moe/topk.py +13 -4
  64. sglang/srt/layers/quantization/__init__.py +111 -7
  65. sglang/srt/layers/quantization/blockwise_int8.py +409 -0
  66. sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  67. sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  68. sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  69. sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  70. sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  71. sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  72. sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  73. sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  74. sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  75. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  76. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  77. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  78. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  79. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  80. sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  81. sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  82. sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  83. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  84. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  85. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  86. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  87. sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  88. sglang/srt/layers/quantization/fp8.py +69 -28
  89. sglang/srt/layers/quantization/fp8_utils.py +17 -1
  90. sglang/srt/layers/quantization/gptq.py +416 -0
  91. sglang/srt/layers/quantization/int8_kernel.py +327 -0
  92. sglang/srt/layers/quantization/int8_utils.py +73 -0
  93. sglang/srt/layers/quantization/modelopt_quant.py +18 -1
  94. sglang/srt/layers/radix_attention.py +1 -0
  95. sglang/srt/layers/rotary_embedding.py +0 -1
  96. sglang/srt/layers/sampler.py +76 -31
  97. sglang/srt/layers/vocab_parallel_embedding.py +14 -13
  98. sglang/srt/lora/lora.py +17 -1
  99. sglang/srt/lora/lora_config.py +5 -0
  100. sglang/srt/lora/lora_manager.py +1 -3
  101. sglang/srt/managers/cache_controller.py +193 -62
  102. sglang/srt/managers/configure_logging.py +2 -1
  103. sglang/srt/managers/data_parallel_controller.py +6 -2
  104. sglang/srt/managers/detokenizer_manager.py +124 -102
  105. sglang/srt/managers/image_processor.py +2 -1
  106. sglang/srt/managers/io_struct.py +144 -6
  107. sglang/srt/managers/schedule_batch.py +237 -197
  108. sglang/srt/managers/schedule_policy.py +29 -29
  109. sglang/srt/managers/scheduler.py +773 -334
  110. sglang/srt/managers/session_controller.py +6 -2
  111. sglang/srt/managers/tokenizer_manager.py +225 -68
  112. sglang/srt/managers/tp_worker.py +15 -4
  113. sglang/srt/managers/tp_worker_overlap_thread.py +3 -4
  114. sglang/srt/mem_cache/chunk_cache.py +18 -11
  115. sglang/srt/mem_cache/hiradix_cache.py +394 -0
  116. sglang/srt/mem_cache/memory_pool.py +68 -37
  117. sglang/srt/mem_cache/radix_cache.py +58 -47
  118. sglang/srt/metrics/collector.py +102 -36
  119. sglang/srt/model_executor/cuda_graph_runner.py +56 -31
  120. sglang/srt/model_executor/forward_batch_info.py +49 -16
  121. sglang/srt/model_executor/model_runner.py +280 -81
  122. sglang/srt/model_loader/loader.py +3 -3
  123. sglang/srt/model_loader/weight_utils.py +36 -14
  124. sglang/srt/models/baichuan.py +31 -6
  125. sglang/srt/models/chatglm.py +39 -7
  126. sglang/srt/models/commandr.py +29 -5
  127. sglang/srt/models/dbrx.py +31 -5
  128. sglang/srt/models/deepseek.py +43 -6
  129. sglang/srt/models/deepseek_nextn.py +32 -19
  130. sglang/srt/models/deepseek_v2.py +265 -32
  131. sglang/srt/models/exaone.py +19 -9
  132. sglang/srt/models/gemma.py +22 -8
  133. sglang/srt/models/gemma2.py +25 -12
  134. sglang/srt/models/gemma2_reward.py +5 -1
  135. sglang/srt/models/gpt2.py +28 -13
  136. sglang/srt/models/gpt_bigcode.py +27 -5
  137. sglang/srt/models/granite.py +21 -9
  138. sglang/srt/models/grok.py +21 -4
  139. sglang/srt/models/internlm2.py +36 -6
  140. sglang/srt/models/internlm2_reward.py +5 -1
  141. sglang/srt/models/llama.py +26 -9
  142. sglang/srt/models/llama_classification.py +5 -1
  143. sglang/srt/models/llama_eagle.py +17 -4
  144. sglang/srt/models/llama_embedding.py +5 -1
  145. sglang/srt/models/llama_reward.py +7 -2
  146. sglang/srt/models/llava.py +19 -3
  147. sglang/srt/models/llavavid.py +10 -1
  148. sglang/srt/models/minicpm.py +26 -2
  149. sglang/srt/models/minicpm3.py +39 -3
  150. sglang/srt/models/minicpmv.py +45 -14
  151. sglang/srt/models/mixtral.py +20 -9
  152. sglang/srt/models/mixtral_quant.py +50 -8
  153. sglang/srt/models/mllama.py +57 -11
  154. sglang/srt/models/olmo.py +34 -6
  155. sglang/srt/models/olmo2.py +34 -13
  156. sglang/srt/models/olmoe.py +26 -4
  157. sglang/srt/models/phi3_small.py +29 -10
  158. sglang/srt/models/qwen.py +26 -3
  159. sglang/srt/models/qwen2.py +26 -4
  160. sglang/srt/models/qwen2_5_vl.py +46 -8
  161. sglang/srt/models/qwen2_eagle.py +17 -5
  162. sglang/srt/models/qwen2_moe.py +44 -6
  163. sglang/srt/models/qwen2_rm.py +78 -0
  164. sglang/srt/models/qwen2_vl.py +39 -8
  165. sglang/srt/models/stablelm.py +32 -5
  166. sglang/srt/models/torch_native_llama.py +5 -2
  167. sglang/srt/models/xverse.py +21 -9
  168. sglang/srt/models/xverse_moe.py +45 -7
  169. sglang/srt/models/yivl.py +2 -1
  170. sglang/srt/openai_api/adapter.py +109 -24
  171. sglang/srt/openai_api/protocol.py +17 -1
  172. sglang/srt/reasoning_parser.py +154 -0
  173. sglang/srt/sampling/penaltylib/__init__.py +4 -6
  174. sglang/srt/sampling/penaltylib/frequency_penalty.py +66 -0
  175. sglang/srt/sampling/penaltylib/{penalizers/min_new_tokens.py → min_new_tokens.py} +15 -23
  176. sglang/srt/sampling/penaltylib/orchestrator.py +39 -188
  177. sglang/srt/sampling/penaltylib/presence_penalty.py +66 -0
  178. sglang/srt/sampling/sampling_batch_info.py +79 -157
  179. sglang/srt/sampling/sampling_params.py +16 -13
  180. sglang/srt/server_args.py +135 -60
  181. sglang/srt/speculative/build_eagle_tree.py +8 -9
  182. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +1 -12
  183. sglang/srt/speculative/eagle_utils.py +92 -57
  184. sglang/srt/speculative/eagle_worker.py +238 -111
  185. sglang/srt/speculative/spec_info.py +1 -13
  186. sglang/srt/utils.py +43 -17
  187. sglang/srt/warmup.py +47 -0
  188. sglang/test/few_shot_gsm8k.py +4 -1
  189. sglang/test/runners.py +389 -126
  190. sglang/test/send_one.py +88 -0
  191. sglang/test/test_block_fp8_ep.py +361 -0
  192. sglang/test/test_programs.py +1 -1
  193. sglang/test/test_utils.py +138 -84
  194. sglang/utils.py +50 -60
  195. sglang/version.py +1 -1
  196. {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post4.dist-info}/METADATA +22 -15
  197. {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post4.dist-info}/RECORD +200 -166
  198. {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post4.dist-info}/WHEEL +1 -1
  199. sglang/bench_latency.py +0 -1
  200. sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py +0 -75
  201. sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py +0 -74
  202. sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +0 -85
  203. sglang/test/srt/sampling/penaltylib/utils.py +0 -344
  204. {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post4.dist-info}/LICENSE +0 -0
  205. {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post4.dist-info}/top_level.txt +0 -0
sglang/srt/model_loader/weight_utils.py CHANGED
@@ -25,10 +25,10 @@ import filelock
 import gguf
 import huggingface_hub.constants
 import numpy as np
+import safetensors.torch
 import torch
 from huggingface_hub import HfFileSystem, hf_hub_download, snapshot_download
 from pydantic import BaseModel, ConfigDict, ValidationInfo, model_validator
-from safetensors.torch import load_file, safe_open, save_file
 from tqdm.auto import tqdm
 
 from sglang.srt.configs.load_config import LoadConfig
@@ -62,7 +62,6 @@ enable_hf_transfer()
 
 
 class DisabledTqdm(tqdm):
-
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs, disable=True)
 
@@ -121,7 +120,7 @@ def convert_bin_to_safetensor_file(
     )
 
     # check if the tensors are the same
-    reloaded = load_file(sf_filename)
+    reloaded = safetensors.torch.load_file(sf_filename)
     for k in loaded:
         pt_tensor = loaded[k]
         sf_tensor = reloaded[k]
@@ -133,7 +132,6 @@
 def get_quant_config(
     model_config: ModelConfig, load_config: LoadConfig
 ) -> QuantizationConfig:
-
     quant_cls = get_quantization_config(model_config.quantization)
 
     # GGUF doesn't have config file
@@ -402,15 +400,34 @@ def np_cache_weights_iterator(
         yield name, torch.from_numpy(param)
 
 
+def decrypt(fn, key):
+    raise NotImplementedError()
+
+
+def safetensors_encrypted_weights_iterator(
+    hf_weights_files: List[str],
+    is_all_weights_sharded: bool = False,
+    decryption_key: Optional[str] = None,
+):
+    raise NotImplementedError()
+
+
 def safetensors_weights_iterator(
     hf_weights_files: List[str],
     is_all_weights_sharded: bool = False,
+    decryption_key: Optional[str] = None,
 ) -> Generator[Tuple[str, torch.Tensor], None, None]:
     """Iterate over the weights in the model safetensor files.
 
     If is_all_weights_sharded is True, it uses more optimize read by reading an
     entire file instead of reading each tensor one by one.
     """
+    if decryption_key:
+        yield from safetensors_encrypted_weights_iterator(
+            hf_weights_files, is_all_weights_sharded, decryption_key
+        )
+        return
+
     enable_tqdm = (
         not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0
     )
@@ -420,15 +437,9 @@ def safetensors_weights_iterator(
         disable=not enable_tqdm,
         bar_format=_BAR_FORMAT,
     ):
-        if not is_all_weights_sharded:
-            with safe_open(st_file, framework="pt") as f:
-                for name in f.keys():  # noqa: SIM118
-                    param = f.get_tensor(name)
-                    yield name, param
-        else:
-            result = load_file(st_file, device="cpu")
-            for name, param in result.items():
-                yield name, param
+        result = safetensors.torch.load_file(st_file, device="cpu")
+        for name, param in result.items():
+            yield name, param
 
 
 def pt_weights_iterator(
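Note: a minimal standalone sketch (not sglang's code; the function and argument names are illustrative) of how the simplified loop above behaves. Each shard is read whole with safetensors.torch.load_file and its tensors are yielded one by one, replacing the per-tensor safe_open path that the diff removes:

    from typing import Generator, List, Tuple

    import safetensors.torch
    import torch

    def iterate_safetensor_shards(
        shard_files: List[str],  # illustrative name, not the sglang parameter
    ) -> Generator[Tuple[str, torch.Tensor], None, None]:
        for shard_file in shard_files:
            # Load the whole shard into CPU memory, then hand out its tensors.
            state_dict = safetensors.torch.load_file(shard_file, device="cpu")
            for name, tensor in state_dict.items():
                yield name, tensor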
@@ -644,9 +655,20 @@ def maybe_remap_kv_scale_name(name: str, params_dict: dict) -> Optional[str]:
         return remapped_name
 
     possible_scale_names = [".k_scale", ".v_scale"]
+    modelopt_scale_names = [".self_attn.k_proj.k_scale", ".self_attn.v_proj.v_scale"]
     for scale_name in possible_scale_names:
         if name.endswith(scale_name):
-            remapped_name = name.replace(scale_name, f".attn{scale_name}")
+            # Check and remap the name based on modelopt scale names
+            if any(
+                modelopt_scale_name in name
+                for modelopt_scale_name in modelopt_scale_names
+            ):
+                remapped_name = name.replace(
+                    f".self_attn.{scale_name[1]}_proj{scale_name}",
+                    f".self_attn.attn{scale_name}",
+                )
+            else:
+                remapped_name = name.replace(scale_name, f".attn{scale_name}")
             if remapped_name not in params_dict:
                 print_warning_once(
                     f"Found {scale_name} in the checkpoint (e.g. {name}), "
sglang/srt/models/baichuan.py CHANGED
@@ -46,6 +46,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
 )
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.utils import add_prefix
 
 
 def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
@@ -80,13 +81,22 @@ class BaiChuanMLP(nn.Module):
         intermediate_size: int,
         hidden_act: str,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ):
         super().__init__()
         self.gate_up_proj = MergedColumnParallelLinear(
-            hidden_size, [intermediate_size] * 2, bias=False, quant_config=quant_config
+            hidden_size,
+            [intermediate_size] * 2,
+            bias=False,
+            quant_config=quant_config,
+            prefix=add_prefix("gate_up_proj", prefix),
         )
         self.down_proj = RowParallelLinear(
-            intermediate_size, hidden_size, bias=False, quant_config=quant_config
+            intermediate_size,
+            hidden_size,
+            bias=False,
+            quant_config=quant_config,
+            prefix=add_prefix("down_proj", prefix),
         )
         if hidden_act != "silu":
             raise ValueError(
@@ -114,6 +124,7 @@ class BaiChuanAttention(nn.Module):
         max_position_embeddings: int = 8192,
         quant_config: Optional[QuantizationConfig] = None,
         layer_id: int = 0,
+        prefix: str = "",
     ):
         super().__init__()
         self.hidden_size = hidden_size
@@ -167,6 +178,7 @@ class BaiChuanAttention(nn.Module):
                 scaling,
                 num_kv_heads=self.num_kv_heads,
                 layer_id=layer_id,
+                prefix=add_prefix("attn", prefix),
             )
         else:
             self.rotary_emb = get_rope(
@@ -182,6 +194,7 @@ class BaiChuanAttention(nn.Module):
                 self.scaling,
                 num_kv_heads=self.num_kv_heads,
                 layer_id=layer_id,
+                prefix=add_prefix("attn", prefix),
             )
 
     def forward(
@@ -207,6 +220,7 @@ class BaiChuanDecoderLayer(nn.Module):
         position_embedding: str,
         layer_id: int = 0,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ):
         super().__init__()
         self.hidden_size = config.hidden_size
@@ -220,12 +234,14 @@ class BaiChuanDecoderLayer(nn.Module):
             layer_id=layer_id,
             max_position_embeddings=max_position_embeddings,
             quant_config=quant_config,
+            prefix=add_prefix("self_attn", prefix),
         )
         self.mlp = BaiChuanMLP(
             hidden_size=self.hidden_size,
             intermediate_size=config.intermediate_size,
             hidden_act=config.hidden_act,
             quant_config=quant_config,
+            prefix=add_prefix("mlp", prefix),
         )
         self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.post_attention_layernorm = RMSNorm(
@@ -264,6 +280,7 @@ class BaiChuanModel(nn.Module):
         config: PretrainedConfig,
         position_embedding: str,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ):
         super().__init__()
         self.config = config
@@ -281,6 +298,7 @@ class BaiChuanModel(nn.Module):
                     layer_id=i,
                     position_embedding=position_embedding,
                     quant_config=quant_config,
+                    prefix=add_prefix(f"layers.{i}", prefix),
                 )
                 for i in range(config.num_hidden_layers)
             ]
@@ -330,18 +348,24 @@ class BaiChuanBaseForCausalLM(nn.Module):
         config: PretrainedConfig,
         position_embedding: str,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ):
         super().__init__()
 
         self.config = config
 
         self.quant_config = quant_config
-        self.model = BaiChuanModel(config, position_embedding, quant_config)
+        self.model = BaiChuanModel(
+            config, position_embedding, quant_config, prefix=add_prefix("model", prefix)
+        )
         if self.config.tie_word_embeddings:
             self.lm_head = self.model.embed_tokens
         else:
             self.lm_head = ParallelLMHead(
-                config.vocab_size, config.hidden_size, quant_config=quant_config
+                config.vocab_size,
+                config.hidden_size,
+                quant_config=quant_config,
+                prefix=add_prefix("lm_head", prefix),
             )
         self.logits_processor = LogitsProcessor(config)
 
@@ -404,11 +428,12 @@ class BaichuanForCausalLM(BaiChuanBaseForCausalLM):
         self,
         config,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ):
         if config.hidden_size == 4096:  # baichuan2 7b
-            super().__init__(config, "ROPE", quant_config)
+            super().__init__(config, "ROPE", quant_config, prefix=prefix)
         else:  # baichuan 13b, baichuan2 13b
-            super().__init__(config, "ALIBI", quant_config)
+            super().__init__(config, "ALIBI", quant_config, prefix=prefix)
 
 
 EntryClass = [BaichuanForCausalLM]
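The prefix plumbing added throughout baichuan.py (and repeated in the model files below) threads a dotted module path down to every submodule. A hedged sketch of a helper along the lines of sglang.srt.utils.add_prefix, assuming it simply joins the parent prefix and the child name with a dot (the real implementation may differ in details):

    def add_prefix(name: str, prefix: str) -> str:
        # Join a submodule name onto its parent's dotted prefix, so a projection
        # built with add_prefix("gate_up_proj", "model.layers.0.mlp") would be
        # registered as "model.layers.0.mlp.gate_up_proj" for weight loading
        # and quantization lookups.
        return name if not prefix else f"{prefix}.{name}"

    assert add_prefix("model", "") == "model"
    assert add_prefix("attn", "model.layers.3.self_attn") == "model.layers.3.self_attn.attn"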
sglang/srt/models/chatglm.py CHANGED
@@ -41,6 +41,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
 )
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.utils import add_prefix
 
 LoraConfig = None
 
@@ -51,6 +52,7 @@ class GLMAttention(nn.Module):
         config,
         layer_id: int = 0,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ):
         super().__init__()
         self.hidden_size = config.hidden_size
@@ -85,12 +87,14 @@ class GLMAttention(nn.Module):
             self.total_num_kv_heads,
             bias=config.add_bias_linear or config.add_qkv_bias,
             quant_config=quant_config,
+            prefix=add_prefix("query_key_value", prefix),
         )
         self.dense = RowParallelLinear(
             self.total_num_heads * self.head_dim,
             config.hidden_size,
             bias=config.add_bias_linear,
             quant_config=quant_config,
+            prefix=add_prefix("dense", prefix),
         )
 
         # https://huggingface.co/THUDM/chatglm3-6b-32k/blob/e210410255278dd9d74463cf396ba559c0ef801c/modeling_chatglm.py#L141
@@ -109,6 +113,7 @@ class GLMAttention(nn.Module):
             self.scaling,
             num_kv_heads=self.num_kv_heads,
             layer_id=layer_id,
+            prefix=add_prefix("attn", prefix),
         )
 
     def forward(
@@ -142,6 +147,7 @@ class GLMMLP(nn.Module):
         self,
         config,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ):
         super().__init__()
 
@@ -153,6 +159,7 @@ class GLMMLP(nn.Module):
             [config.ffn_hidden_size] * 2,
             bias=config.add_bias_linear,
             quant_config=quant_config,
+            prefix=add_prefix("dense_h_to_4h", prefix),
         )
 
         self.activation_func = SiluAndMul()
@@ -163,6 +170,7 @@ class GLMMLP(nn.Module):
             config.hidden_size,
             bias=config.add_bias_linear,
             quant_config=quant_config,
+            prefix=add_prefix("dense_4h_to_h", prefix),
         )
 
     def forward(self, hidden_states):
@@ -186,6 +194,7 @@ class GLMBlock(nn.Module):
         config,
         layer_id: int,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ):
         super().__init__()
         self.apply_residual_connection_post_layernorm = (
@@ -201,7 +210,9 @@ class GLMBlock(nn.Module):
         )
 
         # Self attention.
-        self.self_attention = GLMAttention(config, layer_id, quant_config)
+        self.self_attention = GLMAttention(
+            config, layer_id, quant_config, prefix=add_prefix("self_attention", prefix)
+        )
         self.hidden_dropout = config.hidden_dropout
 
         # Layernorm on the attention output
@@ -210,7 +221,7 @@ class GLMBlock(nn.Module):
         )
 
         # MLP
-        self.mlp = GLMMLP(config, quant_config)
+        self.mlp = GLMMLP(config, quant_config, prefix=add_prefix("mlp", prefix))
 
     def forward(
         self,
@@ -257,6 +268,7 @@ class GLMTransformer(nn.Module):
         self,
         config,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ):
         super().__init__()
         self.post_layer_norm = config.post_layer_norm
@@ -266,7 +278,15 @@ class GLMTransformer(nn.Module):
 
         # Transformer layers.
         self.layers = nn.ModuleList(
-            [GLMBlock(config, i, quant_config) for i in range(self.num_layers)]
+            [
+                GLMBlock(
+                    config,
+                    i,
+                    quant_config,
+                    prefix=add_prefix(f"layers.{i}", prefix),
+                )
+                for i in range(self.num_layers)
+            ]
         )
 
         if self.post_layer_norm:
@@ -301,19 +321,28 @@ class ChatGLMM(nn.Module):
         self,
         config,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ):
         super().__init__()
 
         self.embedding = VocabParallelEmbedding(
-            config.padded_vocab_size, config.hidden_size
+            config.padded_vocab_size,
+            config.hidden_size,
+            prefix=add_prefix("embedding", prefix),
         )
 
         self.num_layers = config.num_layers
         self.multi_query_group_num = config.multi_query_group_num
         self.kv_channels = config.kv_channels
-        self.encoder = GLMTransformer(config, quant_config)
+        self.encoder = GLMTransformer(
+            config, quant_config, add_prefix("encoder", prefix)
+        )
 
-        self.output_layer = ParallelLMHead(config.padded_vocab_size, config.hidden_size)
+        self.output_layer = ParallelLMHead(
+            config.padded_vocab_size,
+            config.hidden_size,
+            prefix=add_prefix("output_layer", prefix),
+        )
 
     def forward(
         self,
@@ -351,12 +380,15 @@ class ChatGLMForCausalLM(nn.Module):
         self,
         config: ChatGLMConfig,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ):
         super().__init__()
         self.config: ChatGLMConfig = config
         self.quant_config = quant_config
         self.max_position_embeddings = getattr(config, "max_sequence_length", 8192)
-        self.transformer = ChatGLMM(config, quant_config)
+        self.transformer = ChatGLMM(
+            config, quant_config, prefix=add_prefix("transformer", prefix)
+        )
         self.lm_head = self.transformer.output_layer
         self.logits_processor = LogitsProcessor(config)
 
sglang/srt/models/commandr.py CHANGED
@@ -65,7 +65,7 @@ from sglang.srt.model_loader.weight_utils import (
     default_weight_loader,
     maybe_remap_kv_scale_name,
 )
-from sglang.srt.utils import get_compiler_backend, set_weight_attrs
+from sglang.srt.utils import add_prefix, get_compiler_backend, set_weight_attrs
 
 
 @torch.compile(backend=get_compiler_backend())
@@ -110,6 +110,7 @@ class CohereMLP(nn.Module):
         self,
         config,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ):
         super().__init__()
         self.config = config
@@ -120,12 +121,14 @@ class CohereMLP(nn.Module):
             [self.intermediate_size] * 2,
             bias=False,
             quant_config=quant_config,
+            prefix=add_prefix("gate_up_proj", prefix),
         )
         self.down_proj = RowParallelLinear(
             self.intermediate_size,
             self.hidden_size,
             bias=False,
             quant_config=quant_config,
+            prefix=add_prefix("down_proj", prefix),
         )
         self.act_fn = SiluAndMul()
 
@@ -142,6 +145,7 @@ class CohereAttention(nn.Module):
         config: PretrainedConfig,
         layer_id: int = 0,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ):
         super().__init__()
         tp_size = get_tensor_model_parallel_world_size()
@@ -177,12 +181,14 @@ class CohereAttention(nn.Module):
             self.total_num_kv_heads,
             bias=False,
             quant_config=quant_config,
+            prefix=add_prefix("qkv_proj", prefix),
         )
         self.o_proj = RowParallelLinear(
             self.total_num_heads * self.head_dim,
             self.hidden_size,
             bias=False,
             quant_config=quant_config,
+            prefix=add_prefix("o_proj", prefix),
         )
         self.rotary_emb = get_rope(
             self.head_dim,
@@ -198,6 +204,7 @@ class CohereAttention(nn.Module):
             self.scaling,
             num_kv_heads=self.num_kv_heads,
             layer_id=layer_id,
+            prefix=add_prefix("attn", prefix),
         )
         if self.use_qk_norm:
             self.q_norm = LayerNorm(
@@ -239,15 +246,23 @@ class CohereDecoderLayer(nn.Module):
         config: PretrainedConfig,
         layer_id: int = 0,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ):
         super().__init__()
         self.hidden_size = config.hidden_size
 
         self.self_attn = CohereAttention(
-            config, layer_id=layer_id, quant_config=quant_config
+            config,
+            layer_id=layer_id,
+            quant_config=quant_config,
+            prefix=add_prefix("self_attn", prefix),
         )
 
-        self.mlp = CohereMLP(config, quant_config=quant_config)
+        self.mlp = CohereMLP(
+            config,
+            quant_config=quant_config,
+            prefix=add_prefix("mlp", prefix),
+        )
         self.input_layernorm = LayerNorm(
             param_shape=(config.hidden_size), eps=config.layer_norm_eps
         )
@@ -279,6 +294,7 @@ class CohereModel(nn.Module):
         self,
         config: PretrainedConfig,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ):
         super().__init__()
         self.config = config
@@ -288,7 +304,12 @@ class CohereModel(nn.Module):
         )
         self.layers = nn.ModuleList(
             [
-                CohereDecoderLayer(config, i, quant_config=quant_config)
+                CohereDecoderLayer(
+                    config,
+                    i,
+                    quant_config=quant_config,
+                    prefix=add_prefix(f"layers.{i}", prefix),
+                )
                 for i in range(config.num_hidden_layers)
             ]
         )
@@ -321,12 +342,15 @@ class CohereForCausalLM(nn.Module):
         self,
         config: PretrainedConfig,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.config = config
         self.quant_config = quant_config
         self.logits_processor = LogitsProcessor(config)
-        self.model = CohereModel(config, quant_config)
+        self.model = CohereModel(
+            config, quant_config, prefix=add_prefix("model", prefix)
+        )
 
     @torch.no_grad()
     def forward(
sglang/srt/models/dbrx.py CHANGED
@@ -46,7 +46,7 @@ from sglang.srt.model_loader.weight_utils import (
     default_weight_loader,
     maybe_remap_kv_scale_name,
 )
-from sglang.srt.utils import set_weight_attrs
+from sglang.srt.utils import add_prefix, set_weight_attrs
 
 
 class DbrxRouter(nn.Module):
@@ -58,6 +58,7 @@ class DbrxRouter(nn.Module):
         self,
         config: DbrxConfig,
         params_dtype: Optional[torch.dtype] = None,
+        prefix: str = "",
     ):
         super().__init__()
         self.tp_size = get_tensor_model_parallel_world_size()
@@ -89,6 +90,7 @@ class DbrxExperts(nn.Module):
         config: DbrxConfig,
         quant_config: Optional[QuantizationConfig] = None,
         params_dtype: Optional[torch.dtype] = None,
+        prefix: str = "",
     ):
         super().__init__()
         self.tp_size = get_tensor_model_parallel_world_size()
@@ -189,6 +191,7 @@ class DbrxAttention(nn.Module):
         config: DbrxConfig,
         layer_id: int = 0,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ):
         super().__init__()
         self.d_model = config.d_model
@@ -207,12 +210,14 @@ class DbrxAttention(nn.Module):
             self.total_num_kv_heads,
             bias=False,
             quant_config=quant_config,
+            prefix=add_prefix("Wqkv", prefix),
         )
         self.out_proj = RowParallelLinear(
             self.d_model,
             self.d_model,
             bias=False,
             quant_config=quant_config,
+            prefix=add_prefix("out_proj", prefix),
         )
         self.rotary_emb = get_rope(
             self.head_dim,
@@ -244,6 +249,7 @@ class DbrxAttention(nn.Module):
             self.scaling,
             num_kv_heads=self.num_kv_heads,
             layer_id=layer_id,
+            prefix=add_prefix("attn", prefix),
         )
 
     def forward(
@@ -268,10 +274,16 @@ class DbrxFusedNormAttention(nn.Module):
         config: DbrxConfig,
         layer_id: int = 0,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ):
         super().__init__()
         self.d_model = config.d_model
-        self.attn = DbrxAttention(config, layer_id, quant_config=quant_config)
+        self.attn = DbrxAttention(
+            config,
+            layer_id,
+            quant_config=quant_config,
+            prefix=add_prefix("attn", prefix),
+        )
         self.norm_1 = nn.LayerNorm(self.d_model)
         self.norm_2 = nn.LayerNorm(self.d_model)
 
@@ -300,10 +312,14 @@ class DbrxBlock(nn.Module):
         config: DbrxConfig,
         layer_id: int = 0,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ):
         super().__init__()
         self.norm_attn_norm = DbrxFusedNormAttention(
-            config, layer_id, quant_config=quant_config
+            config,
+            layer_id,
+            quant_config=quant_config,
+            prefix=add_prefix("norm_attn_norm", prefix),
         )
         self.ffn = DbrxExperts(config, quant_config=quant_config)
 
@@ -328,6 +344,7 @@ class DbrxModel(nn.Module):
         self,
         config: DbrxConfig,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ):
         super().__init__()
         self.wte = VocabParallelEmbedding(
@@ -336,7 +353,12 @@ class DbrxModel(nn.Module):
         )
         self.blocks = nn.ModuleList(
             [
-                DbrxBlock(config, i, quant_config=quant_config)
+                DbrxBlock(
+                    config,
+                    i,
+                    quant_config=quant_config,
+                    prefix=add_prefix(f"blocks.{i}", prefix),
+                )
                 for i in range(config.n_layers)
             ]
         )
@@ -369,17 +391,21 @@ class DbrxForCausalLM(nn.Module):
         self,
         config: DbrxConfig,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ):
         super().__init__()
         self.config = config
         self.quant_config = quant_config
         self.unpadded_vocab_size = config.vocab_size
-        self.transformer = DbrxModel(config, quant_config=quant_config)
+        self.transformer = DbrxModel(
+            config, quant_config=quant_config, prefix=add_prefix("transformer", prefix)
+        )
         self.lm_head = ParallelLMHead(
             config.vocab_size,
             config.d_model,
             org_num_embeddings=config.vocab_size,
             padding_size=DEFAULT_VOCAB_PADDING_SIZE,
+            prefix=add_prefix("lm_head", prefix),
        )
         self.logits_processor = LogitsProcessor(config)